It's really cool! If you don't mind me asking, does it support variable-size inputs? I'm a bit confused about JAX in that regard. I've been trying for a long time to run JAX StableHLO models in C++ for inference, but dynamic shapes were still an issue. If I understand correctly, JAX recompiles the kernels for each new shape at runtime, so if the inputs vary too much in shape it spends considerable time recompiling kernels, which makes C++ inference impractical. However, I could be wrong (I didn't fully understand the issue; the developer of gopjrt tried to explain it to me!). Do you have any thoughts on this?
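For what it's worth, the per-shape recompilation is easy to see from Python. A minimal sketch (the counter trick is mine, not from the thread): the traced function body only runs when JAX compiles a new shape signature, so a side-effect counter reveals how often compilation happens.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # side effect: runs only while JAX traces (i.e., compiles)
    return jnp.sum(x * 2)

f(jnp.ones(3))   # new shape (3,): traces and compiles
f(jnp.ones(3))   # same shape: cache hit, no retrace
f(jnp.ones(4))   # new shape (4,): traces and compiles again
print(trace_count)  # 2
```

Every distinct input shape triggers a fresh trace and XLA compile, which is exactly why wildly varying shapes hurt at inference time; common workarounds are padding/bucketing inputs to a small set of shapes.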
We wanted to use ONNX Runtime for a "model driver" for MD simulations, where any ML model could be used for molecular dynamics. The problem was that it was way too immature. For example, the ceiling function only works with single precision in ONNX. But the biggest issue was that we could not take derivatives in ONNX Runtime, so any complicated model that uses derivatives internally was a no-go. Does that limitation still exist? Do you know if it can take derivatives in training mode now?
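To make the "derivatives inside the model" requirement concrete, here's a toy sketch in JAX (my own illustration, not from the thread; `energy` is a made-up soft inverse-square potential): in MD the forces are the negative gradient of the energy with respect to the particle positions, so the differentiation is part of the model itself, not just part of training.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: total energy E(positions) for an MD-style system.
def energy(positions):
    diffs = positions[:, None, :] - positions[None, :, :]   # pairwise displacement vectors
    r2 = jnp.sum(diffs ** 2, axis=-1)                       # squared distances
    mask = ~jnp.eye(positions.shape[0], dtype=bool)         # exclude self-interaction
    # Soft inverse-square pair potential; masked diagonal keeps gradients finite.
    return jnp.sum(jnp.where(mask, 1.0 / (r2 + 1e-12), 0.0))

# Forces = -dE/d(positions): the derivative-inside-the-model that MD needs.
forces = jax.grad(lambda p: -energy(p))

pos = jnp.array([[0.0, 0.0, 0.0],
                 [1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0]])
f = forces(pos)
print(f.shape)  # (3, 3): one force vector per particle
```

An inference runtime that only evaluates a frozen forward graph can't express this unless the gradient was already baked into the exported graph.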
One option for your case is OpenVINO. It's written in C++ and has Python bindings. It can also be used to train new nets, and you can use ONNX files with OpenVINO too.
e.g.:
https://github.com/openxla/xla/issues/33092
https://github.com/openxla/xla/issues/35556
Explanation from the gopjrt dev:
https://github.com/gomlx/gopjrt/issues/59