In this RFC I will talk about the roadmap to enable eager mode as the default computation mode for PyTorch/XLA users and how to enable graph compilation in this mode.
Background
PyTorch/XLA has been using tracing mode as the default mode since the project started. All of the torch operations that users issue are accumulated in the background and sent to XLA for compilation and execution upon a mark_step call.
The upside of this approach is that users don’t need to change their model code much. As long as the user adds a mark_step in the right place, everything should just work. However, user feedback over the last couple of years shows that this approach creates too much confusion and frustration. Both PyTorch and JAX took the approach of using eager mode as the default and asking users to specify the function that they want to compile. PyTorch/XLA should take the same approach.
Design
Eager mode
There is no real eager mode on TPU. However, we can fake eager mode by compiling and executing each torch operation individually. Such a mode already exists today as a debug-only mode; it was contributed by @aws-rhsoln 2 years ago in #3306. The work here is to provide better API-level wrapping and make sure this mode works with other features (debug output, SPMD, multiprocess, etc.). This approach was far too slow a couple of years ago because XRT could not execute small computations efficiently, but with PJRT the performance is much better.
The whole eager mode still builds on top of the existing lazy tensor framework, but that framework becomes invisible to the user. A couple of things we need to do to accommodate eager mode are:
Increase the compilation cache size from 1024 to 2048, since each torch op will also reside in the compilation cache. We also need to recompile every torch op for each different input shape.
Increase the maximum number of executions we can queue at the PJRT level, since we will now execute many more small computations.
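To illustrate why per-op execution puts more pressure on the compilation cache, here is a minimal pure-Python sketch. It is not the PyTorch/XLA implementation: `ToyCompilationCache`, its `run` method, and the LRU eviction policy are assumptions made for illustration only. It shows that every distinct (op, input shapes) pair occupies its own cache entry and triggers its own compilation.

```python
from collections import OrderedDict


class ToyCompilationCache:
    """Hypothetical stand-in for a compilation cache keyed by op and input shapes.

    LRU eviction is an assumption of this sketch; capacity must be >= 1.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.entries = OrderedDict()
        self.compile_count = 0

    def run(self, op_name, *input_shapes):
        key = (op_name, input_shapes)
        if key in self.entries:
            # Cache hit: reuse the executable and mark it as recently used.
            self.entries.move_to_end(key)
        else:
            # Cache miss: "compile" a new executable for this op and shape combination.
            self.compile_count += 1
            self.entries[key] = f"executable<{op_name}{input_shapes}>"
            if len(self.entries) > self.capacity:
                # Evict the least recently used entry.
                self.entries.popitem(last=False)
        return self.entries[key]


cache = ToyCompilationCache(capacity=2048)
cache.run("add", (4, 4), (4, 4))     # miss: first time this op/shape pair is seen
cache.run("add", (4, 4), (4, 4))     # hit: same op, same shapes
cache.run("add", (8, 4), (8, 4))     # miss: same op, new input shapes
cache.run("matmul", (4, 4), (4, 4))  # miss: different op
print(cache.compile_count)           # 3
```

In this toy model, three of the four calls compile something new, which is why the same cache has to hold many more small entries once each torch op is compiled on its own.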
Compile
For the compile part we currently have 2 options: lazy tensor and torch dynamo (torch.compile).
For lazy tensor based compile I will add a new API:
torch_xla.experimental.compile(fn) -> compiled_fn
Under the hood, this just enables tracing mode while the function runs and executes the traced graph before returning. Here is the implementation. For torch.compile we can just use the existing API.
Example UX
The two changes users need to make are to enable eager mode with torch_xla.experimental.eager_mode(True) and then to compile the step function with torch_xla.experimental.compile or torch.compile.
Users can also choose to run the whole model in eager mode.
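The behavior described for the lazy tensor based compile (switch to tracing on entry, execute the traced graph before returning) can be sketched with a toy runtime. This is a simplified pure-Python model, not the actual torch_xla code: `ToyRuntime`, `toy_compile`, and the op names are hypothetical.

```python
import functools


class ToyRuntime:
    """Hypothetical runtime that either executes ops one by one or records them."""

    def __init__(self):
        self.tracing = False
        self.pending = []   # ops recorded while tracing
        self.executed = []  # each entry is one executed graph (a list of ops)

    def op(self, name):
        if self.tracing:
            self.pending.append(name)     # tracing: accumulate into the graph
        else:
            self.executed.append([name])  # eager: run the op as its own tiny graph

    def flush(self):
        # Execute everything recorded so far as a single graph.
        if self.pending:
            self.executed.append(self.pending)
            self.pending = []


def toy_compile(runtime, fn):
    """Wrap fn so that it is traced as a whole and executed before returning."""

    @functools.wraps(fn)
    def compiled_fn(*args, **kwargs):
        previous = runtime.tracing
        runtime.tracing = True
        try:
            result = fn(*args, **kwargs)
            runtime.flush()
        finally:
            runtime.tracing = previous  # restore eager mode for the caller
        return result

    return compiled_fn


runtime = ToyRuntime()


def step():
    runtime.op("matmul")
    runtime.op("relu")
    runtime.op("loss")


compiled_step = toy_compile(runtime, step)

runtime.op("load_batch")   # eager: executed immediately
compiled_step()            # traced and executed as one graph
runtime.op("log_metrics")  # eager again

print(runtime.executed)
# [['load_batch'], ['matmul', 'relu', 'loss'], ['log_metrics']]
```

Only the wrapped step function produces a multi-op graph; everything outside it runs op by op.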
Why
IMO using tracing mode as the default has a couple of very significant drawbacks:
Users are often confused about when the framework is tracing and when the framework is executing.
Users don’t know where to add the mark_step.
Random Python code (data preprocessing, for example) often generates small pending executions that get leaked into the main graph (the step function) and cause recompilation. Recompiling the whole graph is usually very expensive.
It is hard to debug when/why recompilation happens.
Both JAX and PyTorch took the approach of asking users to explicitly mark the region/function for compilation. This methodology seems to be well received by users who want compilation mode. I think this proposal will make for a much better usability story by:
Allowing users to use eager mode for initial model development and compile mode to scale up. This also significantly lowers the bar for a normal PyTorch user to onboard to PyTorch/XLA.
Reducing the number of recompilations generated by non-core model code, since that code will be executed eagerly.
Making graph recompilation easier to debug, since only the compiled_fn should generate graphs.
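The recompilation problem caused by stray ops can be illustrated with a small pure-Python sketch. It assumes, for illustration only, that a graph is identified by its exact sequence of ops; `run_graph`, `STEP_OPS`, and the op names are hypothetical and do not reflect how PyTorch/XLA actually hashes graphs.

```python
compiled_graphs = set()


def run_graph(ops):
    """Return True if this exact op sequence had to be compiled, False on a cache hit."""
    key = tuple(ops)
    needs_compile = key not in compiled_graphs
    compiled_graphs.add(key)
    return needs_compile


STEP_OPS = ["matmul", "relu", "matmul", "loss"]

# Tracing mode: the step graph is compiled once and then reused.
first_step = run_graph(STEP_OPS)   # True: initial compilation
second_step = run_graph(STEP_OPS)  # False: cache hit

# Tracing mode with a leak: a pending preprocessing op ends up in the step graph,
# so the whole graph looks new and is compiled again.
leaked_step = run_graph(["normalize"] + STEP_OPS)  # True: full recompilation

# Eager mode plus a compiled step: the stray op runs as its own small graph,
# and the step graph still hits the cache.
stray_op = run_graph(["normalize"])  # True: cheap single-op compilation
clean_step = run_graph(STEP_OPS)     # False: cache hit

print(first_step, second_step, leaked_step, stray_op, clean_step)
# True False True True False
```

In the eager case the only new compilation is the single preprocessing op, which is much cheaper than recompiling the whole step graph.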
Benchmark
I am training a 2-layer decoder-only model (it is pretty much just a llama2) with fake data on a single chip of a v4-8 for 300 steps. This is not a very scientific benchmark, so take it with a grain of salt.
| Mode | token/s |
| --- | --- |
| Tracing mode (baseline) | 147 |
| Eager mode | 65 |
| Eager + torch_xla compile | 147 |
Eager mode can achieve ~45% of the performance of the fully compiled model for the decoder-only model. The trainer I used to test can be found here and here.
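For reference, the ratios behind that figure can be recomputed from the reported throughput numbers. The variable names below are illustrative; the values are the token/s results listed above.

```python
tracing_tokens_per_s = 147
eager_tokens_per_s = 65
eager_compile_tokens_per_s = 147

# Throughput of each mode relative to the tracing-mode baseline.
eager_ratio = eager_tokens_per_s / tracing_tokens_per_s
eager_compile_ratio = eager_compile_tokens_per_s / tracing_tokens_per_s

print(f"{eager_ratio:.1%}")          # 44.2%
print(f"{eager_compile_ratio:.1%}")  # 100.0%
```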
Is it possible to support the mixed use of PyTorch eager (for CUDA tensors) and XLA eager (for XLA tensors), for instance, in cases where some operations are not supported by XLA or operations that introduce dynamic shapes need to be executed by PyTorch?
@baoleai You can do that. Now that @vanbasten23 has added dlpack support, you can use that API to do a zero-copy conversion from an XLA:GPU tensor to a CUDA tensor, and subsequent operations will then happen in eager CUDA. IMO the downside is that PyTorch/XLA execution (even for eager) is async, so you will need to add a xm.wait_device_ops call to wait for the XLA:GPU buffer to be ready before using dlpack to convert it to a CUDA tensor.
That being said, @baoleai, I don't think there is much point in using PyTorch/XLA:GPU eager. You can just use PyTorch eager, which is likely faster, and use torch.compile or torch_xla.experimental.compile to wrap the function you want to compile. PyTorch/XLA's eager mode is more for backends that only support XLA and don't have another eager mode.
Work Breakdown
torch_xla.experimental.compile (done)
torch.compile (pr)
Timeline
2.4 release -> experimental
2.5 release -> beta
2.6 release -> enable by default
cc @ezyang @bdhirsh @wconstab @baoleai @amithrm @jeffhataws @albanD @gkroiz @Liyang90