Not exactly an answer to your question, but from a PyTorch standpoint, there are still many operations that aren't supported on MPS [1]. I can't recall a case where an architecture I wanted to train was fully supported on MPS, so some of the training ends up running on the CPU instead.
[1] https://github.com/pytorch/pytorch/issues/77764
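
For anyone who wants to see what that CPU fallback looks like in practice, here is a rough sketch, not a recommended setup. It assumes a PyTorch build recent enough to have `torch.backends.mps`, and it relies on the `PYTORCH_ENABLE_MPS_FALLBACK` environment variable, which as far as I know has to be set before `torch` is imported. The `pick_device` helper and the tiny `Linear` model are made up for illustration.

```python
import os

# Ask PyTorch to run ops that MPS doesn't implement on the CPU instead of
# raising an error. Assumption: this must be set before torch is imported.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch


def pick_device() -> torch.device:
    """Hypothetical helper: use MPS when it is available, otherwise the CPU."""
    mps = getattr(torch.backends, "mps", None)  # absent on older builds
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()

# Toy model and batch, only here to show where the device gets used.
model = torch.nn.Linear(4, 2).to(device)
x = torch.randn(8, 4, device=device)
y = model(x)

print(device.type, tuple(y.shape))
```

With the fallback on, unsupported ops should quietly run on the CPU, which keeps training going but can be noticeably slower because of the extra transfers between devices. That is the "some of the training ends up on the CPU" situation described above.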