Hacker News

> they're dynamically shaped and hit near peak

While this is true for most common GEMM-looking ops, if you tread off the beaten path things get slow (odd channel sizes, batch sizes, etc.). Right now in PyTorch, GroupNorm is 2x slower than BatchNorm. There's no fundamental reason, just that the kernels loop over axes in a less-than-ideal order. Dynamic recompilation allows you to change the loop order too, not just deal with boundary conditions.




> tread off the beaten path things get slow

Yea, makes sense. I think there's something to be said for dynamic compilation solving this problem more elegantly than providing tons of hand-tuned kernels (PyTorch is 890MB lmao https://pypi.org/project/torch/#files), but I don't think it's a strict reason for a performance win.

> change the loop order too

Memory layout as well! I'm 100% for dynamic compilation, but I'm claiming that it really finds its stride when you fuse things.


Agreed. For anything at all common, most of the gains will be from fusion, the rest is just free. PyTorch also uses tons of GPU memory after only initializing, I wonder if it's copying all the kernels in?


JAX preallocates 90% of available GPU memory when the first operation is run, to minimize allocation overhead. Could PyTorch be grabbing that VRAM for a similar reason?


Yes, PyTorch uses what they call a caching memory allocator[0]. It basically seems like they are allocating a very large chunk of GPU memory and implementing a heap with it. If needed, they expose some knobs and functions that let you control it and observe the memory usage.

[0]: https://pytorch.org/docs/stable/notes/cuda.html#memory-manag...
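
To make the caching idea concrete, here is a toy model in plain Python. It is a sketch of the general technique, not PyTorch's implementation: the class name and fields are hypothetical, and the real allocator is considerably more involved. The two counters loosely mirror what `torch.cuda.memory_allocated()` and `torch.cuda.memory_reserved()` report, and `empty_cache` plays the role of `torch.cuda.empty_cache()`.

```python
class ToyCachingAllocator:
    """Toy model: freed blocks are kept in a per-size cache instead of
    being handed back to the device, so later requests can reuse them."""

    def __init__(self):
        self.free_blocks = {}   # block size -> number of cached free blocks
        self.allocated = 0      # bytes currently in use by live buffers
        self.reserved = 0       # bytes held from the device (live + cached)
        self.device_calls = 0   # how many times we hit the slow device allocator

    def malloc(self, size):
        if self.free_blocks.get(size, 0) > 0:
            # Cache hit: reuse a previously freed block of the same size.
            self.free_blocks[size] -= 1
        else:
            # Cache miss: this stands in for an expensive device allocation.
            self.device_calls += 1
            self.reserved += size
        self.allocated += size
        return size  # the "handle" is just the block size in this toy

    def free(self, handle):
        # Return the block to the cache; the device memory stays reserved.
        self.allocated -= handle
        self.free_blocks[handle] = self.free_blocks.get(handle, 0) + 1

    def empty_cache(self):
        # Hand all cached (unused) blocks back to the device.
        for size, count in self.free_blocks.items():
            self.reserved -= size * count
        self.free_blocks.clear()


alloc = ToyCachingAllocator()
for _ in range(100):
    h = alloc.malloc(1 << 20)   # 1 MiB scratch buffer per "step"
    alloc.free(h)

print(alloc.device_calls, alloc.allocated, alloc.reserved)
alloc.empty_cache()
print(alloc.reserved)
```

The point of the toy is the gap between the two counters: after the loop nothing is live, yet a block is still reserved, which is why a memory monitor can show usage that no tensor accounts for.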


> Right now in PyTorch, GroupNorm is 2x slower than BatchNorm

How did you benchmark this? I think there are like 3 or 4 different GN implementations in PyTorch..


Whole net performance at comma, when we switch from BatchNorm to GroupNorm it adds 70ms to the training step time, and it's -70ms for no norm. We also wrote a custom AllNorm that's like 10% slower than BatchNorm (and I put several hours into trying to optimize it). Obviously not indicative of everyone's experience, but my point is BatchNorm is hyperoptimized and others, which are pretty much the same thing, aren't.
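
To spell out "pretty much the same thing": both norms subtract a mean and divide by a standard deviation, and differ only in which elements share statistics. A minimal plain-Python sketch, assuming input indexed as [n][c][i], training-mode statistics, and no affine parameters. The function names are hypothetical, and this is not the custom AllNorm mentioned above, which isn't described here.

```python
import math


def normalize(x, key, eps=1e-5):
    """Normalize x[n][c][i]; elements whose key(n, c) match share mean/variance."""
    sums, squares, counts = {}, {}, {}
    # First pass: accumulate per-bucket statistics.
    for n, sample in enumerate(x):
        for c, row in enumerate(sample):
            k = key(n, c)
            sums[k] = sums.get(k, 0.0) + sum(row)
            squares[k] = squares.get(k, 0.0) + sum(v * v for v in row)
            counts[k] = counts.get(k, 0) + len(row)
    # Second pass: apply (v - mean) / sqrt(var + eps) per bucket.
    out = []
    for n, sample in enumerate(x):
        out_sample = []
        for c, row in enumerate(sample):
            k = key(n, c)
            mean = sums[k] / counts[k]
            var = squares[k] / counts[k] - mean * mean  # biased variance
            rstd = 1.0 / math.sqrt(var + eps)
            out_sample.append([(v - mean) * rstd for v in row])
        out.append(out_sample)
    return out


def batch_norm(x):
    # One bucket per channel, shared across the whole batch.
    return normalize(x, key=lambda n, c: c)


def group_norm(x, group_size):
    # One bucket per (sample, channel group), independent of the batch.
    return normalize(x, key=lambda n, c: (n, c // group_size))


x = [
    [[1.0, 2.0, 3.0], [4.0, 6.0, 8.0], [0.5, 0.0, -0.5], [10.0, 11.0, 15.0]],
    [[2.0, 2.5, 7.0], [-1.0, 0.0, 1.0], [3.0, 9.0, 1.0], [5.0, 5.5, 4.0]],
]
bn = batch_norm(x)
gn = group_norm(x, group_size=2)
```

Only the bucketing key changes between the two, which is the sense in which a speed gap would come from how the kernels walk memory rather than from the math.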


Thanks, that's certainly helpful anecdotal evidence.. yeah it seems like there should be an "AllNorm" implementation that covers all cases and is just fast. I was wondering because I'm currently looking at math_group_norm, which was ported from PyTorch/XLA and it results in a really weird decomposition that I'm astonished works at all. https://github.com/pytorch/pytorch/blob/master/aten/src/ATen...

I'm also wondering if the handcoded backward passes are actually "numerically correct", because e.g. epsilon doesn't appear in it at all. Someone worked out the gradients manually for BN here: https://web.archive.org/web/20180826123459/http://cthorey.gi...

You can clearly see epsilon appearing in the output. And of course there's the whole training vs. eval mode thing with BN which GN doesn't have.
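
One way to check this kind of thing is to compare a hand-written backward against finite differences. The sketch below does that for a single normalization group in plain Python. It assumes (I have not verified this against the PyTorch source) that a kernel might save the inverse standard deviation `rstd = 1/sqrt(var + eps)` from the forward pass, in which case epsilon never appears by name in the backward but is still accounted for. All function names are hypothetical.

```python
import math


def norm_forward(x, eps):
    """Normalize one group of values; returns the output and the saved rstd."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance
    rstd = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * rstd for v in x], rstd


def norm_backward(dy, y, rstd):
    """Hand-derived gradient wrt x. Note: eps only enters through rstd."""
    n = len(y)
    mean_dy = sum(dy) / n
    mean_dy_y = sum(g * v for g, v in zip(dy, y)) / n
    return [rstd * (g - mean_dy - v * mean_dy_y) for g, v in zip(dy, y)]


def numeric_grad(x, dy, eps, h=1e-6):
    """Central finite differences of sum(dy * y) with respect to each x[j]."""
    grads = []
    for j in range(len(x)):
        xp = list(x)
        xp[j] += h
        xm = list(x)
        xm[j] -= h
        yp, _ = norm_forward(xp, eps)
        ym, _ = norm_forward(xm, eps)
        grads.append(sum(g * (a - b) for g, a, b in zip(dy, yp, ym)) / (2 * h))
    return grads


x = [0.10, 0.11, 0.09, 0.12]   # tiny variance, so eps matters
dy = [1.0, -2.0, 0.5, 3.0]     # arbitrary upstream gradient
eps = 1e-3

y, rstd = norm_forward(x, eps)
analytic = norm_backward(dy, y, rstd)
numeric = numeric_grad(x, dy, eps)
max_err = max(abs(a - b) for a, b in zip(analytic, numeric))

# Same backward formula, but fed statistics computed as if eps were zero.
y0, rstd0 = norm_forward(x, 0.0)
no_eps = norm_backward(dy, y0, rstd0)
max_gap = max(abs(a - b) for a, b in zip(no_eps, numeric))

print(max_err, max_gap)
```

If `max_err` is tiny and `max_gap` is not, then a backward with no visible epsilon can still be correct, as long as the saved statistics already include it.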

In any case, thanks again.



