partitioner doesn't appear to respect SAC region #128730

Open
bdhirsh opened this issue Jun 14, 2024 · 2 comments
Labels
module: aotdispatch (umbrella label for AOTAutograd issues)
module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op, …)
oncall: pt2
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

bdhirsh (Contributor) commented Jun 14, 2024

I have a non-distributed repro that uses float8: below is a mini model that is essentially Float8Linear + relu, wrapped in an SAC region that wants to recompute everything except matmuls (it requires installing or git-cloning the float8_experimental repo).

import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    get_float8_layers,
    get_float8_linear,
    LinearType,
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = get_float8_linear(LinearType.DYNAMIC, torch.nn.Linear(32, 32, device='cuda', dtype=torch.bfloat16), emulate=False)
    def forward(self, x):
        out = self.lin1(x)
        out = torch.relu(out)
        return x + out

m = Mod()

def _get_custom_policy(no_recompute_list=None):
    # SAC policy: ops in no_recompute_list must be saved for backward;
    # everything else should preferably be recomputed.
    def _custom_policy(ctx, func, *args, **kwargs):
        if func in no_recompute_list:
            return CheckpointPolicy.MUST_SAVE
        else:
            return CheckpointPolicy.PREFER_RECOMPUTE

    return _custom_policy

def selective_checkpointing_context_fn():
    # Save only matmul outputs inside the checkpointed region.
    no_recompute_list = [
        torch.ops.aten.mm.default,
    ]
    return create_selective_checkpoint_contexts(
        _get_custom_policy(no_recompute_list=no_recompute_list)
    )

@torch.compile(backend="aot_eager_decomp_partition")
def f(x):
    # Non-reentrant AC region using the SAC policy above.
    return torch.utils.checkpoint.checkpoint(
        m, x, use_reentrant=False, context_fn=selective_checkpointing_context_fn
    )

x = torch.randn(32, 32, device='cuda', dtype=torch.bfloat16)
out = f(x)
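
As a quick aside, the graph dump below can also be enabled from inside the script instead of via the environment variable. A minimal sketch, assuming your build of torch._logging.set_logs exposes the aot_graphs artifact (check your version):

import torch._logging

# Roughly equivalent to running with TORCH_LOGS="aot_graphs": print the
# partitioned forward/backward graphs that AOTAutograd produces.
torch._logging.set_logs(aot_graphs=True)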

Printing the generated forward graph with TORCH_LOGS="aot", you get:

 ===== Forward graph 0 =====
 /home/hirsheybar/local/b/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[32, 32][32, 1]cuda:0", primals_2: "bf16[32][1]cuda:0", primals_3: "bf16[32, 32][32, 1]cuda:0"):
        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/float8_experimental/float8_utils.py:97 in tensor_to_amax, code: amax = torch.max(torch.abs(x))
        abs_1: "bf16[32, 32][32, 1]cuda:0" = torch.ops.aten.abs.default(primals_3)
        max_1: "bf16[][]cuda:0" = torch.ops.aten.max.default(abs_1);  abs_1 = None

        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/float8_experimental/float8_utils.py:38 in amax_to_scale, code: scale = torch.empty_like(amax, dtype=torch.float32)
        empty: "f32[][]cuda:0" = torch.ops.aten.empty.memory_format([], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        permute: "f32[][]cuda:0" = torch.ops.aten.permute.default(empty, []);  empty = None

        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/float8_experimental/float8_utils.py:40 in amax_to_scale, code: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
        convert_element_type: "f32[][]cuda:0" = torch.ops.prims.convert_element_type.default(max_1, torch.float32);  max_1 = None
        clamp_min: "f32[][]cuda:0" = torch.ops.aten.clamp_min.default(convert_element_type, 1e-12);  convert_element_type = None
        convert_element_type_1: "bf16[][]cuda:0" = torch.ops.prims.convert_element_type.default(clamp_min, torch.bfloat16);  clamp_min = None
        reciprocal: "bf16[][]cuda:0" = torch.ops.aten.reciprocal.default(convert_element_type_1);  convert_element_type_1 = None
        mul: "bf16[][]cuda:0" = torch.ops.aten.mul.Tensor(reciprocal, 448.0);  reciprocal = None

        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/float8_experimental/float8_utils.py:49 in amax_to_scale, code: scale.copy_(res)
        copy: "f32[][]cuda:0" = torch.ops.aten.copy.default(permute, mul);  mul = None

        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/float8_experimental/float8_tensor.py:267 in to_float8, code: return ToFloat8ConstrFunc.apply(
        mul_1: "bf16[32, 32][32, 1]cuda:0" = torch.ops.aten.mul.Tensor(primals_3, copy)
        convert_element_type_2: "f32[32, 32][32, 1]cuda:0" = torch.ops.prims.convert_element_type.default(mul_1, torch.float32);  mul_1 = None
        clamp_min_1: "f32[32, 32][32, 1]cuda:0" = torch.ops.aten.clamp_min.default(convert_element_type_2, -448.0);  convert_element_type_2 = None
        clamp_max: "f32[32, 32][32, 1]cuda:0" = torch.ops.aten.clamp_max.default(clamp_min_1, 448.0);  clamp_min_1 = None
        convert_element_type_3: "bf16[32, 32][32, 1]cuda:0" = torch.ops.prims.convert_element_type.default(clamp_max, torch.bfloat16);  clamp_max = None
        convert_element_type_4: "f8e4m3fn[32, 32][32, 1]cuda:0" = torch.ops.prims.convert_element_type.default(convert_element_type_3, torch.float8_e4m3fn);  convert_element_type_3 = None

        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/float8_experimental/float8_utils.py:97 in tensor_to_amax, code: amax = torch.max(torch.abs(x))
        abs_2: "bf16[32, 32][32, 1]cuda:0" = torch.ops.aten.abs.default(primals_1)
        max_2: "bf16[][]cuda:0" = torch.ops.aten.max.default(abs_2);  abs_2 = None

        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/float8_experimental/float8_utils.py:40 in amax_to_scale, code: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
        convert_element_type_5: "f32[][]cuda:0" = torch.ops.prims.convert_element_type.default(max_2, torch.float32);  max_2 = None
        clamp_min_2: "f32[][]cuda:0" = torch.ops.aten.clamp_min.default(convert_element_type_5, 1e-12);  convert_element_type_5 = None
        convert_element_type_6: "bf16[][]cuda:0" = torch.ops.prims.convert_element_type.default(clamp_min_2, torch.bfloat16);  clamp_min_2 = None
        reciprocal_1: "bf16[][]cuda:0" = torch.ops.aten.reciprocal.default(convert_element_type_6);  convert_element_type_6 = None
        mul_2: "bf16[][]cuda:0" = torch.ops.aten.mul.Tensor(reciprocal_1, 448.0);  reciprocal_1 = None

        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/float8_experimental/float8_utils.py:49 in amax_to_scale, code: scale.copy_(res)
        copy_1: "f32[][]cuda:0" = torch.ops.aten.copy.default(permute, mul_2);  permute = mul_2 = None

        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/float8_experimental/float8_tensor.py:267 in to_float8, code: return ToFloat8ConstrFunc.apply(
        mul_3: "bf16[32, 32][32, 1]cuda:0" = torch.ops.aten.mul.Tensor(primals_1, copy_1);  primals_1 = None
        convert_element_type_7: "f32[32, 32][32, 1]cuda:0" = torch.ops.prims.convert_element_type.default(mul_3, torch.float32);  mul_3 = None
        clamp_min_3: "f32[32, 32][32, 1]cuda:0" = torch.ops.aten.clamp_min.default(convert_element_type_7, -448.0);  convert_element_type_7 = None
        clamp_max_1: "f32[32, 32][32, 1]cuda:0" = torch.ops.aten.clamp_max.default(clamp_min_3, 448.0);  clamp_min_3 = None
        convert_element_type_8: "bf16[32, 32][32, 1]cuda:0" = torch.ops.prims.convert_element_type.default(clamp_max_1, torch.bfloat16);  clamp_max_1 = None
        convert_element_type_9: "f8e4m3fn[32, 32][32, 1]cuda:0" = torch.ops.prims.convert_element_type.default(convert_element_type_8, torch.float8_e4m3fn);  convert_element_type_8 = None

        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/float8_experimental/float8_dynamic_linear.py:71 in forward, code: y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
        permute_2: "f8e4m3fn[32, 32][1, 32]cuda:0" = torch.ops.aten.permute.default(convert_element_type_9, [1, 0]);  convert_element_type_9 = None
        reciprocal_2: "f32[][]cuda:0" = torch.ops.aten.reciprocal.default(copy);  copy = None
        reciprocal_3: "f32[][]cuda:0" = torch.ops.aten.reciprocal.default(copy_1);  copy_1 = None
        _scaled_mm = torch.ops.aten._scaled_mm.default(convert_element_type_4, permute_2, bias = primals_2, out_dtype = torch.bfloat16, scale_a = reciprocal_2, scale_b = reciprocal_3, use_fast_accum = True)
        getitem: "bf16[32, 32][32, 1]cuda:0" = _scaled_mm[0];  _scaled_mm = None

        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/float8_experimental/float8_dynamic_linear.py:117 in cast_to_float8_e5m2_bw, code: return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
        view: "bf16[32, 32][32, 1]cuda:0" = torch.ops.aten.view.default(getitem, [32, 32]);  getitem = None

        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/tmp.py:26 in forward, code: out = torch.relu(out)
        relu: "bf16[32, 32][32, 1]cuda:0" = torch.ops.aten.relu.default(view);  view = None

        # File: /home/hirsheybar/local/b/pytorch/float8_experimental/tmp.py:27 in forward, code: return x + out
        add: "bf16[32, 32][32, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, relu);  primals_3 = relu = None
        return [add, primals_2, convert_element_type_4, permute_2, reciprocal_2, reciprocal_3]

primals_2, convert_element_type_4, permute_2, reciprocal_2, reciprocal_3 are all outputs of the forward graph corresponding to saved activations.

One thing we can see: convert_element_type_4 is not the output of a matmul, so according to eager SAC we would want to recompute it rather than save it (it is an input to the _scaled_mm, computed by taking the input to nn.Linear and casting it to float8).
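
As a sanity check of what eager SAC decides here (a hypothetical snippet reusing _get_custom_policy from the repro above, not part of the original report): only aten.mm.default is marked MUST_SAVE, and the float8 casts (prims.convert_element_type) fall through to PREFER_RECOMPUTE, so eager SAC would recompute convert_element_type_4 rather than save it.

policy = _get_custom_policy(no_recompute_list=[torch.ops.aten.mm.default])
# The ctx argument is unused by this policy, so None is fine for a standalone check.
print(policy(None, torch.ops.aten.mm.default))                     # CheckpointPolicy.MUST_SAVE
print(policy(None, torch.ops.prims.convert_element_type.default))  # CheckpointPolicy.PREFER_RECOMPUTE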

cc @ezyang @anijain2305 @chauhang @zou3519

bdhirsh added the oncall: pt2, module: aotdispatch, and module: pt2-dispatcher labels on Jun 14, 2024
eellison added the triaged label on Jun 14, 2024
drisspg (Contributor) commented Jun 20, 2024

Curious what the possible solutions are here?

bdhirsh (Contributor, Author) commented Jun 20, 2024

I imagined that the partitioner's goal is to jointly optimize what we save for backward, subject to the restriction that it obeys any recompute/save decisions made through the user-specified AC/SAC APIs. If that's true, then "the partitioner not respecting user-specified SAC" would be a legitimate bug.

I'm not too familiar with how the partitioner maintains this invariant though. Maybe @Chillee and/or @yf225 have a better idea of where to look.
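
One possible starting point: dump whatever SAC tags make it onto the joint graph before the partitioner runs. Below is a rough sketch that wraps the min-cut partitioner in a custom aot_autograd backend; the meta keys ("recompute", "ac_graph_id") are assumptions from memory and may differ across versions, so treat it as an outline rather than a verified recipe.

from functorch.compile import make_boxed_func
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.partitioners import min_cut_rematerialization_partition

def debug_partition(joint_module, joint_inputs, **kwargs):
    # Print any SAC/recompute tags that AOTAutograd attached to the joint graph.
    # (Assumed to live in node.meta under "recompute"/"ac_graph_id"; key names may vary.)
    for node in joint_module.graph.nodes:
        tag = node.meta.get("recompute", None)
        if tag is not None:
            print(node.name, node.target, tag, node.meta.get("ac_graph_id", None))
    # Fall through to the stock min-cut partitioner.
    return min_cut_rematerialization_partition(joint_module, joint_inputs, **kwargs)

def nop_compiler(gm, example_inputs):
    # Run the partitioned graphs as-is, eager-style.
    return make_boxed_func(gm.forward)

debug_backend = aot_autograd(
    fw_compiler=nop_compiler,
    bw_compiler=nop_compiler,
    partition_fn=debug_partition,
)

Compiling the repro with backend=debug_backend instead of "aot_eager_decomp_partition" (note this sketch skips the decompositions) should at least show whether the SAC tags survive to the joint graph, or whether the partitioner never sees them.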
