
vmap fails to call torch.compiled function #128711

Open
guilhermeleobas opened this issue Jun 14, 2024 · 2 comments
Labels
module: vmap · oncall: pt2 · triaged

Comments

@guilhermeleobas
Collaborator

guilhermeleobas commented Jun 14, 2024

🐛 Describe the bug

Reproducer

import torch

@torch.compile(backend="aot_eager", fullgraph=True)
def fn(x):
    return x.sin()

x = torch.randn(3, 4)
y = torch.func.vmap(fn)(x)

# ...
# torch/_dynamo/convert_frame.py:178: in _fn
#     return fn(*args, **kwargs)
# torch/_dynamo/convert_frame.py:564: in transform
#     tracer = InstructionTranslator(
# torch/_dynamo/symbolic_convert.py:2396: in __init__
#     self._throw_if_in_functorch()
# torch/_dynamo/symbolic_convert.py:2452: in _throw_if_in_functorch
#     unimplemented(msg)
# torch/_dynamo/exc.py:221: in unimplemented
#     raise Unsupported(msg)
# E   torch._dynamo.exc.Unsupported: torch.func.vmap(fn) requires the function to be inlined by dynamo
# E
# E   Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
# E
# E
# E   You can suppress this exception and fall back to eager by setting:
# E       import torch._dynamo
# E       torch._dynamo.config.suppress_errors = True

We incorrectly raise Unsupported when we detect that we're inside a functorch transform in InstructionTranslator.
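
For comparison (not part of the original report), composing in the other order does work on recent builds: wrapping the vmap call itself in torch.compile keeps the functorch transform inside the region Dynamo traces, instead of hitting a compiled frame from outside vmap. A minimal sketch:

import torch

def fn(x):
    return x.sin()

# Compiling the vmap'd function, rather than vmap-ing a compiled function,
# lets Dynamo trace through torch.func.vmap itself.
compiled_fn = torch.compile(torch.func.vmap(fn), backend="aot_eager", fullgraph=True)

x = torch.randn(3, 4)
y = compiled_fn(x)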

Versions

main

cc @zou3519 @ezyang @anijain2305 @chauhang

@guilhermeleobas
Collaborator Author

guilhermeleobas commented Jun 14, 2024

Fixing this bug reveals a limitation in aot_autograd, as the example_value is a BatchedTensor:

torch/_subclasses/fake_tensor.py:1153: in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
torch/_subclasses/fake_tensor.py:1539: in _dispatch_impl
    (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors(
torch/_subclasses/fake_tensor.py:1832: in validate_and_convert_non_fake_tensors
    validated_args = [validate(a) for a in flat_args]
torch/_subclasses/fake_tensor.py:1832: in <listcomp>
    validated_args = [validate(a) for a in flat_args]
torch/_subclasses/fake_tensor.py:1822: in validate
    raise AssertionError(
E   torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
E   AssertionError: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.sin.default(BatchedTensor(lvl=1, bdim=0, value=
E       FakeTensor(..., size=(3, 4))
E   ))
E
E   While executing %sin : [num_users=1] = call_method[target=sin](args = (%l_x_,), kwargs = {})
E   Original traceback:
E     File "/home/guilhermeleobas/git/pytorch/test/dynamo/test_higher_order_ops.py", line 5290, in wrapped_fn
E       return x.sin()
E

I guess this would need aot_autograd to support functorch tensors?
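
As an illustration (a hypothetical probe built on internal torch._C._functorch helpers, which may change between releases), the wrapping that trips the fake-tensor check can be observed directly inside a vmap'd function:

import torch
from torch._C._functorch import is_batchedtensor

def probe(x):
    # Inside vmap, x is a functorch BatchedTensor wrapping the underlying
    # tensor; with Dynamo in the mix, that inner value is a FakeTensor,
    # which is what aot_eager's validation rejects above.
    print(is_batchedtensor(x))  # True
    return x.sin()

torch.func.vmap(probe)(torch.randn(3, 4))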

@zou3519
Contributor

zou3519 commented Jun 14, 2024

I guess this would need aot_autograd to support functorch tensors?

Yes, exactly.
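
Roughly speaking (a sketch using internal helpers, not the actual fix), supporting this would mean peeling the functorch wrappers so aot_autograd reaches the underlying FakeTensor:

import torch
from torch._C._functorch import is_batchedtensor, get_unwrapped

def unwrap_batched(t):
    # Peel BatchedTensor layers until the underlying tensor is reached;
    # in the failing trace above, that inner value is already a FakeTensor.
    while isinstance(t, torch.Tensor) and is_batchedtensor(t):
        t = get_unwrapped(t)
    return t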

torch._dynamo.exc.Unsupported: torch.func.vmap(fn) requires the function to be inlined by dynamo

The error message is really awkward. Is it possible for us to make it better?

@eellison added the triaged label on Jun 17, 2024