Unable to assign nn.Parameter(DTensor) (created outside of compile region) to an nn.Module param attribute during Dynamo tracing
#128742
Labels
high priority
module: dtensor (distributed tensor tag)
module: dynamo
oncall: pt2
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
🐛 Describe the bug
Repro:
This throws an error at:
raise TypeError(f"cannot assign '{torch.typename(value)}' as parameter '{name}' "
because value is of type torch.cuda.FloatTensor.
A few observations:
1. value shouldn't be of type torch.cuda.FloatTensor; it should be a DTensor (at least, that is the eager behavior). It's odd that under compile the reported type is FloatTensor. (I think Dynamo is treating the DTensor's _local_tensor as value here, but why does it use the _local_tensor instead of reconstructing and using the DTensor?)
2. The self.sharded_param value is an nn.Parameter(DTensor). If it is an nn.Parameter(torch.Tensor) instead, the error doesn't occur.

I suspect this is related to Dynamo's handling of tensor subclasses; in particular, we probably need to find the right place to "unflatten" the tensor subclass before assigning it to the module.
cc @anijain2305 @bdhirsh
Versions
nightly
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @wanchaol @XilunWu @tianyu-l @d4l3k