Jax-metal on Apple M-series GPUs is barely usable, in my opinion. It's not possible to invert a matrix, for example, because Apple has not implemented the necessary triangular solve operator. It's also not possible to sample points from a multivariate normal distribution, because the Cholesky decomposition operator is not yet implemented. Apple hasn't responded to either of these issues [1][2] in the past year. It's difficult to take a numerical computing framework seriously if one cannot invert a matrix.
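For anyone who wants to check their own setup, here is a minimal sketch of the two operations I mean. It runs on the CPU backend; I'm assuming it is the jax-metal backend that rejects the underlying triangular solve and Cholesky ops, as described in the linked issues. The helper name `sample_mvn` is my own and not part of JAX.

```python
import jax
import jax.numpy as jnp


def sample_mvn(key, mean, cov, n):
    """Draw n samples from N(mean, cov) using a Cholesky factor of cov.

    Hypothetical helper for illustration; it mirrors the usual
    "mean + L z" construction rather than any specific JAX internals.
    """
    chol = jnp.linalg.cholesky(cov)  # lower-triangular L with cov = L @ L.T
    z = jax.random.normal(key, (n, mean.shape[0]))  # standard normal draws
    return mean + z @ chol.T


# A small symmetric positive-definite matrix.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])

# Matrix inversion: an LU factorization followed by triangular solves.
A_inv = jnp.linalg.inv(A)
identity_error = jnp.max(jnp.abs(A @ A_inv - jnp.eye(2)))

# Correlated Gaussian samples: needs the Cholesky decomposition.
samples = sample_mvn(jax.random.PRNGKey(0), jnp.zeros(2), A, 1000)

print("backend:", jax.default_backend())
print("max |A @ A_inv - I|:", float(identity_error))
print("samples shape:", samples.shape)
```

If either call raises an unimplemented-operation error on your machine, that is the limitation I'm describing.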
[1]: https://github.com/google/jax/issues/16321

[2]: https://github.com/google/jax/issues/17490