partitioner doesn't appear to respect SAC region #128730
Labels
module: aotdispatch
module: pt2-dispatcher
oncall: pt2
triaged
I have a non-distributed repro that uses float8. It is a mini model that is basically `Float8Linear` + ReLU, wrapped in a selective activation checkpointing (SAC) region whose policy recomputes everything except matmuls. Running it requires installing or git cloning the `float8_experimental` repo.
Printing the generated forward graph with `TORCH_LOGS="aot"` shows that `primals_2`, `convert_element_type_4`, `permute_2`, `reciprocal_2`, and `reciprocal_3` are all outputs of the forward graph that correspond to saved activations.

One thing stands out: `convert_element_type_4` is not the output of a matmul, so under eager SAC semantics we would expect it to be recomputed rather than saved. It is an input to `torch._scaled_mm`, computed by taking the input to `nn.Linear` and converting it to float8.

cc @ezyang @anijain2305 @chauhang @zou3519
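The expectation about `convert_element_type_4` can be restated as a small check: under a "save only matmuls" policy, any saved activation produced by a non-matmul op should have been recomputed instead. The sketch below is PyTorch-free so it runs anywhere. `Policy` is a stand-in mirroring two values of `torch.utils.checkpoint.CheckpointPolicy`, and the op-name strings, the `MATMUL_OPS` set, `unexpected_saves`, and the node-to-op mapping are hypothetical illustrations, not taken from the actual repro or graph dump.

```python
from enum import Enum
from typing import Dict, List, Optional


class Policy(Enum):
    """Stand-in for two values of torch.utils.checkpoint.CheckpointPolicy."""
    MUST_SAVE = "must_save"
    PREFER_RECOMPUTE = "prefer_recompute"


# Assumed overload names for the ops the SAC region wants to keep.
MATMUL_OPS = {"aten.mm.default", "aten._scaled_mm.default"}


def policy_fn(op_name: str) -> Policy:
    """Save matmul outputs; recompute everything else."""
    if op_name in MATMUL_OPS:
        return Policy.MUST_SAVE
    return Policy.PREFER_RECOMPUTE


def unexpected_saves(saved: Dict[str, Optional[str]]) -> List[str]:
    """Return saved forward-graph outputs that the policy says to recompute.

    `saved` maps each saved node name to the op that produced it, or to None
    for graph inputs (primals), which are never recomputed.
    """
    return [
        name
        for name, op in saved.items()
        if op is not None and policy_fn(op) is Policy.PREFER_RECOMPUTE
    ]


# Hypothetical producer mapping for two of the saved nodes named in the issue.
saved_activations = {
    "primals_2": None,  # graph input, so saving it is expected
    "convert_element_type_4": "prims.convert_element_type.default",
}
print(unexpected_saves(saved_activations))  # ['convert_element_type_4']
```

In real code, assuming PyTorch 2.4 or later, a policy like this would instead be a function `policy_fn(ctx, op, *args, **kwargs)` returning a `CheckpointPolicy`, passed to `torch.utils.checkpoint.checkpoint` via `context_fn=functools.partial(create_selective_checkpoint_contexts, policy_fn)`.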