
autograd with is_grads_batched=True fails on GroupNorm #128703

Open
TheotimeLH opened this issue Jun 14, 2024 · 4 comments
Labels
module: autograd, module: functorch, module: norms and normalization, module: vmap, triaged

Comments


TheotimeLH commented Jun 14, 2024

🐛 Describe the bug

Hi!

I'm trying to backpropagate through my model along multiple directions, so I'm using torch.autograd.grad with is_grads_batched=True. I had no problem using it on an MLP, but when I tried it on a more complicated model, it crashed. I looked for which function of the model isn't compatible with this autograd feature, and apparently GroupNorm isn't compatible.

You can reproduce it with:

import torch

B = 11  # batch size
P = 3   # number of projection directions
x = torch.randn(B, 32, 12).requires_grad_()
net = torch.nn.GroupNorm(1, 32)
out = net(x)
proj_directions = torch.randn(P, B, 32, 12)  # one grad_outputs tensor per direction
autograd_res = torch.autograd.grad(
    outputs=out,
    inputs=[x],
    grad_outputs=proj_directions,
    retain_graph=True,
    create_graph=True,
    is_grads_batched=True,
)

My model works with tensors of shape (B, 32, 12), and I want 3 different projections of the Jacobian; that's the purpose of proj_directions.

It works with Linear / ReLU / LayerNorm / BatchNorm etc., but not GroupNorm.
Am I doing something wrong?
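For context, here is roughly how I compared the layers (a quick sketch; the helper function and module choices are just for illustration):

import torch

def batched_grad_ok(module, shape=(11, 32, 12), n_dirs=3):
    # Returns True if autograd.grad with is_grads_batched=True runs on this module.
    x = torch.randn(*shape).requires_grad_()
    out = module(x)
    proj = torch.randn(n_dirs, *out.shape)
    try:
        torch.autograd.grad(out, [x], grad_outputs=proj, is_grads_batched=True)
        return True
    except RuntimeError:
        return False

print(batched_grad_ok(torch.nn.LayerNorm([32, 12])))  # True for me
print(batched_grad_ok(torch.nn.BatchNorm1d(32)))      # True for me
print(batched_grad_ok(torch.nn.GroupNorm(1, 32)))     # False: crashes as described above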

I tried a second version where I flatten both the input and the output, so that the produced Jacobian has a 3D shape (number of projection directions; batch dim; total input size):

B = 11
P = 3
flat_x = torch.randn(B, 32 * 12).requires_grad_()
x = flat_x.view(B, 32, 12)
net = torch.nn.GroupNorm(1, 32)
out = net(x)
flat_out = out.flatten(1, 2)
proj_directions = torch.randn(P, B, 32 * 12)
autograd_res = torch.autograd.grad(
    outputs=flat_out,
    inputs=[flat_x],
    grad_outputs=proj_directions,
    retain_graph=True,
    create_graph=True,
    is_grads_batched=True,
)

Thank you for your help,

Versions

Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-112-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.1
[pip3] torch==2.1.2.post100
[pip3] torch-tb-profiler==0.4.1
[pip3] torchvision==0.16.1+b88453f
[conda] cpuonly 2.0 0 pytorch
[conda] libopenvino-pytorch-frontend 2023.3.0 h59595ed_3 conda-forge
[conda] libtorch 2.1.2 cpu_mkl_hadc400e_100 conda-forge
[conda] mkl 2023.2.0 h84fe81f_50496 conda-forge
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.8 py310ha3dbc2a_1 conda-forge
[conda] mkl_random 1.2.5 py310hbd113e2_1 conda-forge
[conda] numpy 1.26.4 py310heeff2f4_0
[conda] numpy-base 1.26.4 py310h8a23956_0
[conda] pytorch 2.1.2 cpu_mkl_py310h3ea73d3_100 conda-forge
[conda] pytorch-mutex 1.0 cpu pytorch
[conda] torchvision 0.16.1 cpu_py310h684a773_3 conda-forge

cc @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @zou3519 @Chillee @samdow @kshitij12345 @janeyx99

ezyang added the module: autograd, triaged, module: vmap, and module: norms and normalization labels on Jun 15, 2024
ezyang (Contributor) commented Jun 15, 2024

Does it work if you use vmap instead, by any chance?

TheotimeLH (Author) commented Jun 15, 2024

I'm sorry, but I don't know how to use vmap instead.
I tried to vmap the autograd call, but, as stated in https://pytorch.org/tutorials/intermediate/jacobians_hessians.html: "We can't directly apply vmap to torch.autograd.grad; instead, PyTorch provides a torch.func.vjp".

But apparently torch.func.vjp is designed to build the Jacobian, and it runs the forward too. My function is non-deterministic, so I don't want to (re)run the forward just to compute this autograd.grad.

Also, I tried both CPU and CUDA:
BatchNorm and InstanceNorm work,
only GroupNorm fails.

soulitzer (Contributor) commented:
Is it possible to avoid running the forward the first time? vjp returns both the forward output and the gradients.

soulitzer added the module: functorch label on Jun 17, 2024
zou3519 (Contributor) commented Jun 17, 2024

But apparently torch.func.vjp is designed to build the Jacobian, and it runs the forward too, whereas my function is non-deterministic, so I don't want to (re)run the forward to compute this autograd.grad.

You can use torch.func.vjp to run the forward once. It gives you a callable that can be used to run the backward. The idea is to vmap over said callable (this avoids re-running the forward).
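A rough, untested sketch of that pattern, reusing the shapes from the repro above (variable names are just illustrative):

import torch

B, P = 11, 3
net = torch.nn.GroupNorm(1, 32)
x = torch.randn(B, 32, 12)

# The forward runs exactly once here; vjp returns the output together with a
# callable that computes vector-Jacobian products against that output.
out, vjp_fn = torch.func.vjp(net, x)

# P projection directions, each with the same shape as `out`.
proj_directions = torch.randn(P, B, 32, 12)

# vmap over the vjp callable: the backward is batched over the P directions
# without re-running the (possibly non-deterministic) forward.
(grads,) = torch.func.vmap(vjp_fn)(proj_directions)
print(grads.shape)  # (P, B, 32, 12)

This is the same vmap(vjp_fn) pattern that the jacobians/hessians tutorial uses to build full Jacobians from unit vectors.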
