
Unable to export Phi-3-vision model to PyTorch exported program #31622

Open

zewenli98 opened this issue Jun 25, 2024 · 1 comment

Comments

@zewenli98
System Info

  • transformers version: 4.41.2
  • Platform: Linux-6.5.0-35-generic-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.23.0
  • Safetensors version: 0.4.3
  • Accelerate version: 0.30.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0.dev20240412+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@amyeroberts

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm trying to export the Phi-3-vision model to a PyTorch exported program.

Repro:

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor


model_id = "microsoft/Phi-3-vision-128k-instruct"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype="auto"
).to("cuda")

user_prompt = "<|user|>\n"
assistant_prompt = "<|assistant|>\n"
prompt_suffix = "<|end|>\n"

# Single-image prompt
prompt = f"{user_prompt}<|image_1|>\nWhat is shown in this image?{prompt_suffix}{assistant_prompt}"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
print(f">>> Prompt\n{prompt}")

image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")

# inputs.keys(): dict_keys(['input_ids', 'pixel_values', 'image_sizes', 'attention_mask'])
input_ids = inputs["input_ids"]
pixel_values = inputs["pixel_values"]
image_sizes = inputs["image_sizes"]
inputs.pop("attention_mask")

with torch.no_grad():
    # max=1024 hits a constraint violation error: https://github.com/pytorch/pytorch/issues/125604
    seq_len = torch.export.Dim("seq_len", min=1, max=4096)
    kwargs = {"input_ids": input_ids, "pixel_values": pixel_values, "image_sizes": image_sizes}
    ep = torch.export.export(
        model,
        args=tuple(),
        kwargs=kwargs,
        # Only input_ids gets a dynamic sequence dimension; the other two inputs stay static.
        dynamic_shapes=({1: seq_len}, {}, {}),
        strict=False,
    )

Error message:

Traceback (most recent call last):
  File "/home/zewenl/Documents/pytorch/TensorRT/examples/dynamo/phi3.py", line 39, in <module>
    ep = torch.export.export(
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/export/__init__.py", line 174, in export
    return _export(
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/export/_trace.py", line 900, in wrapper
    raise e
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/export/_trace.py", line 883, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/export/exported_program.py", line 85, in wrapper
    return fn(*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/export/_trace.py", line 1062, in _export
    ep_non_strict = _export_non_strict(
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/export/_trace.py", line 583, in _export_non_strict
    gm, graph_signature = transform(aot_export_module)(
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/export/_trace.py", line 1023, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1059, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1249, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 265, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 549, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 150, in inner
    flat_f_outs = f(*flat_f_args)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 174, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 709, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/export/_trace.py", line 1010, in forward
    tree_out = self._export_root(*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zewenl/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-vision-128k-instruct/7b92b8c62807f5a98a9fa47cdfd4144f11fbd112/modeling_phi3_v.py", line 1301, in forward
    outputs = self.model(
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zewenl/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-vision-128k-instruct/7b92b8c62807f5a98a9fa47cdfd4144f11fbd112/modeling_phi3_v.py", line 1129, in forward
    inputs_embeds = self.vision_embed_tokens(input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zewenl/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-vision-128k-instruct/7b92b8c62807f5a98a9fa47cdfd4144f11fbd112/image_embedding_phi3_v.py", line 170, in forward
    if len(positions.tolist()) > 0:
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 219, in tolist
    return [elem.tolist() for elem in self.elem]
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_tensor.py", line 1066, in __iter__
    return iter(self.unbind(0))
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 421, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 842, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1187, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 920, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1339, in _dispatch_impl
    return decomposition_table[func](*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_refs/__init__.py", line 3906, in unbind
    if t.shape[dim] == 0:
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/__init__.py", line 377, in __bool__
    return self.node.bool_()
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 439, in bool_
    return self.guard_bool("", 0)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py", line 377, in guard_bool
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 265, in wrapper
    return event.run(self)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/fx/experimental/recording.py", line 160, in run
    return self.f(*args, **kwargs)
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4269, in evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, 0) (unhinted: Eq(u0, 0)).  (Size-like symbols: u0)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Potential framework code culprit (scroll up for full backtrace):
  File "/home/zewenl/anaconda3/envs/trt-10-py310/lib/python3.10/site-packages/torch/_refs/__init__.py", line 3906, in unbind
    if t.shape[dim] == 0:

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

It seems the error is due to:

File "/home/zewenl/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-vision-128k-instruct/7b92b8c62807f5a98a9fa47cdfd4144f11fbd112/image_embedding_phi3_v.py", line 170, in forward
    if len(positions.tolist()) > 0:
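
For context, the failing pattern is easy to reproduce in isolation. Below is a hypothetical minimal module (my own sketch, not the actual Phi-3 code) that hits the same class of data-dependent guard error under torch.export:

import torch

class DataDependentBranch(torch.nn.Module):
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # nonzero() yields a tensor whose length depends on the *values* of
        # input_ids, so its size is an unbacked symbol during tracing.
        positions = torch.nonzero(input_ids < 0)
        # Converting it to a Python list forces a guard on that symbol,
        # mirroring image_embedding_phi3_v.py line 170.
        if len(positions.tolist()) > 0:
            return input_ids.abs()
        return input_ids

try:
    torch.export.export(DataDependentBranch(), args=(torch.tensor([1, -1, 2]),), strict=False)
except Exception as e:
    print(type(e).__name__)  # expect GuardOnDataDependentSymNode (exact error may vary by version)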

Expected behavior

The code should run and produce an exported program without errors.
I'm not sure whether the issue is in Hugging Face or PyTorch; I submitted an issue to PyTorch as well, here, for your reference.
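
For reference, the error text above suggests guard_size_oblivious as the framework-side fix. A sketch of that pattern only (my assumption about how it would be applied, not a tested patch to torch/_refs/__init__.py):

import torch
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

# A plain bool() on a comparison against an unbacked symbolic size raises the
# guard error; guard_size_oblivious evaluates it assuming the size is not 0/1,
# which is what the error message recommends for unbind()'s
# "if t.shape[dim] == 0" check.
def dim_is_empty(t: torch.Tensor, dim: int) -> bool:
    return guard_size_oblivious(t.shape[dim] == 0)

print(dim_is_empty(torch.ones(3), 0))  # False; also accepts traced SymBools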

@amyeroberts
Collaborator

Hi @zewenli98, thanks for opening this issue!

I don't think this is a bug in either PyTorch or Hugging Face. When one uses torch.export, it tries to trace the model/function/callable into a graph, i.e. something which can then be compiled or serialized. Not all code is tracing-compatible: in particular, data-dependent shapes, unknown or changing input and output types, and certain Python control flow cannot be traced. In the error above, tracing breaks at "if len(positions.tolist()) > 0:", which combines a value-dependent tensor size with a Python branch.

As the modeling code lives on the Hub, it's the repo's authors who can update it. If you wish for the model to be exportable, I'd suggest opening a discussion on the checkpoint's community tab to request this feature.
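
For anyone hitting this in their own modeling code, one tracing-friendly pattern (a sketch, not the Phi-3 authors' actual code) is to replace the Python branch on a data-dependent length with pure tensor ops such as torch.where, so no Python bool is ever extracted from a symbolic size:

import torch

class TraceableVariant(torch.nn.Module):
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Branch-free equivalent of "if any id is negative, take abs":
        # torch.where keeps everything as tensor ops, so torch.export never
        # has to evaluate a data-dependent Python boolean.
        return torch.where(input_ids < 0, input_ids.abs(), input_ids)

ep = torch.export.export(TraceableVariant(), args=(torch.tensor([1, -1, 2]),))
print(ep)  # exports cleanly: the graph contains where/abs, no branch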
