autograd with is_grads_batched=True fails on GroupNorm
#128703
Labels
module: autograd
Related to torch.autograd, and the autograd engine in general
module: functorch
Pertaining to torch.func or pytorch/functorch
module: norms and normalization
module: vmap
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 Describe the bug
Hi!
I'm trying to backpropagate through my model along multiple directions, so I'm using torch.autograd.grad with is_grads_batched=True. It worked without problems on an MLP, but when I tried a more complicated model, it crashed. I looked for the function in the model that isn't compatible with this autograd feature, and apparently GroupNorm isn't compatible:
You can reproduce it with:
My model works with tensors of shape (B, 32, 12), and I want 3 different projections of the Jacobian; that's the purpose of proj_directions. It works with Linear / ReLU / LayerNorm / BatchNorm etc., but not with GroupNorm.
Am I doing something wrong?
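The original reproduction script and error message were not preserved on this page. As a hedged illustration of the setup described above (not the reporter's code), the sketch below assumes a batch size of 2 and a small stand-in model built only from layers the report says work (Linear, ReLU, LayerNorm). The name proj_directions mirrors the report, but its contents here are random.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, C, L = 2, 32, 12  # per-sample shape (32, 12) from the report; B is arbitrary
n_dirs = 3           # three Jacobian projections, as in the report

# Hypothetical stand-in model using only layers the report says work with
# is_grads_batched=True; the real model is not shown in the issue.
model = nn.Sequential(nn.Linear(L, L), nn.ReLU(), nn.LayerNorm(L))

x = torch.randn(B, C, L, requires_grad=True)
y = model(x)  # shape (B, 32, 12)

# One cotangent per direction, stacked along a new leading dim:
# shape (n_dirs, B, 32, 12). Values are random for illustration.
proj_directions = torch.randn(n_dirs, *y.shape)

# With is_grads_batched=True, the leading dim of grad_outputs is treated as a
# batch dim and the backward pass is vectorized over it (via vmap), so vjps[k]
# is the vector-Jacobian product of proj_directions[k] with dy/dx.
(vjps,) = torch.autograd.grad(
    y, x, grad_outputs=proj_directions, is_grads_batched=True, retain_graph=True
)
print(vjps.shape)  # torch.Size([3, 2, 32, 12])
```

Swapping an nn.GroupNorm layer into the stand-in model gives the configuration the report says fails on torch 2.1.2; whether newer releases behave differently is not established here.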
I also tried a second version where I flatten both the input and the output, so that the resulting Jacobian projection has a 3D shape (number of projection directions, batch dim, total input size).
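A possible workaround, which is not from the issue thread, is to skip is_grads_batched entirely and run one ordinary backward pass per direction, since GroupNorm's regular (non-vmapped) backward is well supported. The sketch below uses the flattened layout described above and assumes a hypothetical stand-in model with nn.GroupNorm(8, 32); the cost is n_dirs sequential backward passes instead of one vectorized call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, C, L = 2, 32, 12  # per-sample shape (32, 12) as in the report; B is arbitrary
n_dirs = 3           # number of Jacobian projections

# Hypothetical stand-in model that includes GroupNorm (8 groups over 32 channels).
model = nn.Sequential(nn.Linear(L, L), nn.GroupNorm(8, C), nn.ReLU())

# Flattened input and output, as in the report's second attempt.
x_flat = torch.randn(B, C * L, requires_grad=True)   # (B, 384)
y_flat = model(x_flat.view(B, C, L)).reshape(B, -1)  # (B, 384)

# One cotangent per direction (random here; the report's values are unknown).
proj_directions = torch.randn(n_dirs, B, C * L)

# Workaround sketch: instead of is_grads_batched=True, which vmaps the backward
# pass, run one ordinary backward pass per direction and stack the results.
rows = []
for v in proj_directions:
    (g,) = torch.autograd.grad(y_flat, x_flat, grad_outputs=v, retain_graph=True)
    rows.append(g)                # (B, 384)
jac_proj = torch.stack(rows)      # (n_dirs, B, 384)
print(jac_proj.shape)  # torch.Size([3, 2, 384])
```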
Thank you for your help,
Versions
Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-112-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.1
[pip3] torch==2.1.2.post100
[pip3] torch-tb-profiler==0.4.1
[pip3] torchvision==0.16.1+b88453f
[conda] cpuonly 2.0 0 pytorch
[conda] libopenvino-pytorch-frontend 2023.3.0 h59595ed_3 conda-forge
[conda] libtorch 2.1.2 cpu_mkl_hadc400e_100 conda-forge
[conda] mkl 2023.2.0 h84fe81f_50496 conda-forge
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.8 py310ha3dbc2a_1 conda-forge
[conda] mkl_random 1.2.5 py310hbd113e2_1 conda-forge
[conda] numpy 1.26.4 py310heeff2f4_0
[conda] numpy-base 1.26.4 py310h8a23956_0
[conda] pytorch 2.1.2 cpu_mkl_py310h3ea73d3_100 conda-forge
[conda] pytorch-mutex 1.0 cpu pytorch
[conda] torchvision 0.16.1 cpu_py310h684a773_3 conda-forge
cc @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @zou3519 @Chillee @samdow @kshitij12345 @janeyx99