
Unable to assign nn.Parameter(DTensor) (created outside of compile region) to an nn.Module param attribute during Dynamo tracing #128742

Open
yf225 opened this issue Jun 14, 2024 · 3 comments
Assignees
Labels
high priority module: dtensor distributed tensor tag module: dynamo oncall: pt2 triaged This issue has been looked at by a team member and triaged/prioritized into an appropriate module

Comments

@yf225
Contributor

yf225 commented Jun 14, 2024

🐛 Describe the bug

Repro:

"""
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.run --standalone --nproc_per_node=1 test_setattr.py >artifacts/test_output2.txt 2>&1
"""

import functools
import torch
from torch.distributed._tensor import DTensor, Replicate, init_device_mesh

class FSDPParam:
    def __init__(self):
        device_mesh = init_device_mesh("cuda", (1,))
        replica_placement = [Replicate()]
        local_tensor = torch.zeros(3, 3, device="cuda")
        dtensor = DTensor.from_local(local_tensor, device_mesh=device_mesh, placements=replica_placement)
        self.sharded_param = torch.nn.Parameter(dtensor)

class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = torch.nn.Parameter(torch.ones(3, 3, device="cuda"))

    def forward(self, x):
        return x + self.foo

def forward_post_hook(module, input, output, _fsdp_param):
    setattr(module, "foo", _fsdp_param.sharded_param)

# Eager test
# fsdp_param = FSDPParam()
# mod = TestModule()
# mod.register_forward_hook(functools.partial(forward_post_hook, _fsdp_param=fsdp_param))
# inp = torch.zeros(3, 3, device="cuda")
# mod(inp)
# assert torch.allclose(mod.foo.sum(), torch.tensor(0.))

# Compile test
fsdp_param = FSDPParam()
mod = TestModule()
mod.register_forward_hook(functools.partial(forward_post_hook, _fsdp_param=fsdp_param))
compiled_mod = torch.compile(mod, fullgraph=True)
inp = torch.zeros(3, 3, device="cuda")
compiled_mod(inp)
assert torch.allclose(compiled_mod.foo.sum(), torch.tensor(0.))

This throws an error at raise TypeError(f"cannot assign '{torch.typename(value)}' as parameter '{name}' ...) because the value that reaches the check is of type torch.cuda.FloatTensor.
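
For context, the same guard can be hit in plain eager code with no DTensor involved: nn.Module.__setattr__ only accepts an nn.Parameter (or None) for an attribute that is already a registered parameter, which is why a bare torch.cuda.FloatTensor is rejected here. A minimal sketch (the Linear module is just an illustration, not part of the repro):

import torch

m = torch.nn.Linear(3, 3)
try:
    m.weight = torch.zeros(3, 3)  # plain tensor -> rejected for an existing parameter slot
except TypeError as e:
    print(e)  # cannot assign 'torch.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)
m.weight = torch.nn.Parameter(torch.zeros(3, 3))  # wrapping in nn.Parameter is accepted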

A few observations:

  1. At the time of attribute assignment, the value shouldn't be a torch.cuda.FloatTensor - it should be a DTensor (at least that is the eager behavior). It's strange that under compile the reported type is FloatTensor. (I think Dynamo is treating the DTensor's _local_tensor as the value here, but why does it use the _local_tensor instead of reconstructing the DTensor and using that?)
  2. This error only reproduces if self.sharded_param is an nn.Parameter wrapping a DTensor - if it is an nn.Parameter wrapping a plain torch.Tensor, the error does not occur.

I suspect this is related to Dynamo's handling of tensor subclasses; in particular, we probably need to find the right place to "unflatten" the tensor subclass before assigning it to the module (see the sketch below).
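
For reference, "unflatten" here refers to the traceable tensor subclass protocol: a subclass like DTensor can be decomposed into its inner tensors plus metadata and then rebuilt. A minimal sketch, assuming DTensor exposes the __tensor_flatten__/__tensor_unflatten__ hooks used by that protocol (run under the same torchrun command as the repro; exact signatures may differ across nightlies):

import torch
from torch.distributed._tensor import DTensor, Replicate, init_device_mesh

device_mesh = init_device_mesh("cuda", (1,))
dtensor = DTensor.from_local(torch.zeros(3, 3, device="cuda"), device_mesh, [Replicate()])

# Flatten the subclass into its inner tensors plus context, then rebuild it.
attrs, ctx = dtensor.__tensor_flatten__()                  # e.g. (["_local_tensor"], context)
inner = {name: getattr(dtensor, name) for name in attrs}
rebuilt = type(dtensor).__tensor_unflatten__(inner, ctx, dtensor.shape, dtensor.stride())
assert isinstance(rebuilt, DTensor)
# The suspicion above is that Dynamo ends up handing __setattr__ the inner
# _local_tensor instead of a rebuilt DTensor like `rebuilt`.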

cc. @anijain2305 @bdhirsh

Versions

nightly

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @wanchaol @XilunWu @tianyu-l @d4l3k

@yf225
Contributor Author

yf225 commented Jun 14, 2024

This is on the critical path for Traceable FSDP2 work, so marking as high-priority.

@yf225
Contributor Author

yf225 commented Jun 17, 2024

Another discovery (maybe tangential):

If I have param = torch.nn.Parameter(DTensor) and print type(param), it doesn't say it's an nn.Parameter.

But if I check isinstance(param, torch.nn.Parameter), it returns True.

Is this intended behavior? cc. @wanchaol

(repro script for this behavior: https://gist.github.com/yf225/c67c4d0ff081be5a7eac72f0ea395abf)
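
Inline, the check looks roughly like this (a minimal sketch on a single-rank mesh, run under torchrun as above; the gist has the full script):

import torch
from torch.distributed._tensor import DTensor, Replicate, init_device_mesh

mesh = init_device_mesh("cuda", (1,))
dt = DTensor.from_local(torch.ones(3, 3, device="cuda"), mesh, [Replicate()])
param = torch.nn.Parameter(dt)

print(type(param))                            # reports the DTensor class, not nn.Parameter
print(isinstance(param, torch.nn.Parameter))  # True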

@bdhirsh self-assigned this Jun 18, 2024
@yf225
Contributor Author

yf225 commented Jun 18, 2024

Brian's PR: #128981

@masnesral added the triaged label (This issue has been looked at by a team member and triaged/prioritized into an appropriate module) and removed the triage review label Jun 25, 2024