
[RFC] PyTorch/XLA eager mode as default #7253

Open
JackCaoG opened this issue Jun 12, 2024 · 3 comments
Labels
eager, usability

Comments

@JackCaoG
Collaborator

JackCaoG commented Jun 12, 2024

Context

Objective

In this RFC I will lay out the roadmap for enabling eager mode as the default computation mode for PyTorch/XLA users, and explain how to enable graph compilation in this mode.

Background

PyTorch/XLA has used tracing mode as the default mode since the project started. All of the torch operations a user issues are accumulated in the background and sent to XLA for compilation and execution upon a mark_step call.

The upside of this approach is that users don’t need to change their model code much: as long as the user adds a mark_step in the right place, everything should just work. However, user feedback over the last couple of years shows that this approach creates too much confusion and frustration. Both PyTorch and JAX take the approach of using eager mode as the default and asking users to specify the function they want to compile. PyTorch/XLA should take the same approach.

Design

Eager mode

There is no real eager mode on TPU. However, we can fake eager mode by compiling and executing each torch operation individually. Such a mode already exists today as a debug-only mode; it was contributed by @aws-rhsoln two years ago in #3306. The work here is to provide better API-level wrapping and to make sure this mode works with the other features (debug output, SPMD, multiprocess etc.). This approach was far too slow a couple of years ago because XRT could not run many small executions efficiently, but with PJRT the performance is much better.

The whole eager mode still builds on top of the existing lazy tensor framework, but the tracing becomes invisible to the user. A couple of things we need to do to accommodate eager mode (a sketch of the resulting behavior follows this list) are

  1. Increase the compilation cache size from 1024 to 2048, since each torch op will also reside in the compilation cache. Each torch op also needs to be recompiled for every new input shape.
  2. Increase the maximum number of executions we can queue at the PJRT level, since we will now execute many more small computations.
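
As a rough sketch of the user-visible behavior (using the eager_mode API proposed in this RFC):

import torch
import torch_xla

# Proposed API from this RFC: turn eager mode on globally.
torch_xla.experimental.eager_mode(True)

a = torch.randn(3, 3, device=torch_xla.device())
b = a + 1  # compiled and executed immediately (cached for this shape)
c = b @ b  # the same op with a new input shape would trigger a recompile
print(c)   # values are already materialized; no mark_step call is needed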

Compile

For the compile part we currently have 2 options, lazy tensor and torch dynamo(torch.compile).

For lazy-tensor-based compile I will add a new API

torch_xla.experimental.compile(fn) -> compiled_fn

which under the hood enables tracing mode while running the function and executes the traced graph before returning. Here is the implementation. For torch.compile we can just use the existing API.
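
For intuition, here is a minimal sketch of what such a wrapper could do; this is an illustration rather than the actual implementation, and it assumes eager_mode can be toggled around the traced region:

import functools

import torch_xla
import torch_xla.core.xla_model as xm

def compile(fn):
  @functools.wraps(fn)
  def compiled_fn(*args, **kwargs):
    # Leave eager mode so the ops inside fn are traced instead of executed.
    torch_xla.experimental.eager_mode(False)
    try:
      result = fn(*args, **kwargs)
    finally:
      # Cut the graph here: compile (cached after the first call) and execute.
      xm.mark_step()
      torch_xla.experimental.eager_mode(True)
    return result
  return compiled_fn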

Example UX

import torch_xla
torch_xla.experimental.eager_mode(True)

# MyLoader, DecoderOnlyModel, self.config, self.optimizer, self.loss_fn and
# self.run_optimizer are assumed to be defined elsewhere in the trainer.
class TrainDecoderOnlyBase():
  def __init__(self):
    self.train_loader = MyLoader()
    self.model = DecoderOnlyModel(self.config).to(torch_xla.device())
    # if run with dynamo, use
    # self.step_fn = torch.compile(self.step_fn, backend="openxla")
    self.step_fn = torch_xla.experimental.compile(self.step_fn)

  def step_fn(self, data, target):
    self.optimizer.zero_grad()
    logits = self.model(data)
    loss = self.loss_fn(
        logits.view(-1, self.config.vocab_size), target.view(-1))
    loss.backward()
    self.run_optimizer()
    return loss

  def start_training(self):
    for step, (data, target) in enumerate(self.train_loader):
      loss = self.step_fn(data, target)

if __name__ == '__main__':
  base = TrainDecoderOnlyBase()
  base.start_training()

Note that the two changes a user needs to make are to enable eager mode via torch_xla.experimental.eager_mode(True) and to compile the step function with torch_xla.experimental.compile or torch.compile.

Users can also choose to run the whole model in eager mode.

Why

IMO using tracing mode as the default has a couple of very significant drawbacks

  1. Users are often confused about when the framework is tracing and when the framework is executing.
  2. Users don’t know where to add the mark_step.
  3. Random Python code (data preprocessing, for example) often generates small pending executions that leak into the main graph (step function) and cause recompilation. Recompiling the whole graph is usually very expensive.
  4. It is hard to debug when/why recompilation happens.

Both JAX and PyTorch take the approach of asking users to explicitly mark the region/function they want compiled. This methodology seems well received by users who want a compilation mode. I think this proposal will create a much better usability story by

  1. Allowing users to do initial model development in eager mode and then use compile mode to scale up. This also significantly lowers the bar for a regular PyTorch user to onboard to PyTorch/XLA.
  2. Reducing the number of recompilations generated by non-core model code, since that code will be executed eagerly.
  3. Making graph recompilation easier to debug, since only the compiled_fn should generate graphs.

Benchmark

I ran training of a two-layer decoder-only model (it is pretty much just a Llama 2) with fake data on a single chip of a v4-8 for 300 steps. This is not a very scientific benchmark, so take it with a grain of salt.

                             token/s
Tracing mode (baseline)      147
Eager mode                   65
Eager + torch_xla compile    147

Eager mode achieves roughly 45% of the performance of the fully compiled model (65 vs. 147 tokens/s) for this decoder-only model. The trainer I used to test can be found here and here.

Work Breakdown

  1. Enable eager mode (done)
  2. Enable torch_xla.experimental.compile (done)
  3. Support eager mode with torch.compile (PR)
  4. Test eager mode with SPMD (1 day)
  5. Test eager mode with multi-process distributed (2 days)
  6. Test eager mode with the Pallas kernel (1 day)
  7. Test eager mode with the rest of the PyTorch/XLA features (1 week)
  8. Enable more tests with eager mode (1 week)
  9. Enable more tests with eager mode + torch_xla.compile (1 week)
  10. Update examples and README to use eager + torch_xla.compile (1 week)
  11. Integrate eager mode with HF (2 weeks to 1 month)
  12. Integrate eager mode with Torch Lightning (2 weeks to 1 month)

Timeline

2.4 release -> experimental

2.5 release -> beta

2.6 release -> enable by default

cc @ezyang @bdhirsh @wconstab @baoleai @amithrm @jeffhataws @albanD @gkroiz @Liyang90

@JackCaoG added the usability label Jun 12, 2024
@baoleai
Contributor

baoleai commented Jun 12, 2024

Is it possible to support the mixed use of PyTorch eager (for CUDA tensors) and XLA eager (for XLA tensors), for instance, in cases where some operations are not supported by XLA or operations that introduce dynamic shapes need to be executed by PyTorch?

@JackCaoG
Collaborator Author

@baoleai You can do that. Now that @vanbasten23 has added dlpack support, you can use that API to do a zero-copy conversion from an XLA:GPU tensor to a CUDA tensor, and subsequent operations will happen on eager CUDA. IMO the downside is that PyTorch/XLA execution (even eager) is async, so you will need to call xm.wait_device_ops to wait for the XLA:GPU buffer to be ready before using dlpack to convert it to a CUDA tensor.
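
A hedged sketch of that interop (the torch_xla dlpack helper's module path is my assumption and may differ by version):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.utils.dlpack import to_dlpack  # assumed module path

torch_xla.experimental.eager_mode(True)
x = torch.randn(4, 4, device=torch_xla.device())  # XLA:GPU tensor
y = x @ x                                         # executed eagerly by XLA

xm.wait_device_ops()  # XLA execution is async; wait for the buffer first
y_cuda = torch.utils.dlpack.from_dlpack(to_dlpack(y))  # zero-copy to CUDA
z = torch.relu(y_cuda)  # subsequent ops run on native PyTorch CUDA eager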

@JackCaoG
Collaborator Author

That being said, @baoleai, I don't think there is much point in using PyTorch/XLA:GPU eager; you can just use PyTorch eager, which is likely faster, and use torch.compile or torch_xla.experimental.compile to wrap the function you want to compile. PyTorch/XLA's eager mode is more for backends that only support XLA and don't have another eager mode.

@JackCaoG pinned this issue Jun 13, 2024
@JackCaoG added the eager label Jun 20, 2024