shard_map should support static_argnums #22043
Comments
Hey @reinerp! What if instead of
Assuming using
I think that's probably sufficient, thank you!
We've started using (a type-based derivative of) `shard_map` as a function decorator. Ideally we'd like to annotate our top-level training function directly with the `@jax.jit` and `@typed_shard_map` decorators, keeping everything in a single top-level function. However, `shard_map` doesn't support `static_argnums`, so we're forced to add a level of nested functions in order to capture the static arguments in a closure; see this example: https://github.com/MatX-inc/seqax/blob/871bd91efc004e569aa91228ef0a66931ed16986/train.py#L292-L295
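The closure workaround can be sketched roughly as follows. This is a minimal illustration, not the code from the linked `train.py`: it uses plain `shard_map` rather than the issue's `typed_shard_map`, and `train_step`, `scale`, and the one-axis mesh named `"x"` are hypothetical. The outer function is what `jax.jit` sees, so `scale` can be marked static there; the nested function receives only the array argument, because `shard_map` has no way to mark an argument as static and treats everything it receives as a traced array with an entry in `in_specs`.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map

# Hypothetical 1-D mesh over all available devices, axis name "x".
mesh = Mesh(np.array(jax.devices()), ("x",))


def train_step(x, scale):
    # `scale` is a static Python float. The nested function closes over it,
    # so shard_map only ever sees the array argument `x`.
    @partial(shard_map, mesh=mesh, in_specs=P("x"), out_specs=P())
    def inner(x_block):
        # Each shard sums its block; psum combines the partial sums across "x".
        return jax.lax.psum(jnp.sum(x_block) * scale, "x")

    return inner(x)


# static_argnums lives on jax.jit, which wraps the *outer* function.
train_step_jit = jax.jit(train_step, static_argnums=1)

x = jnp.arange(8 * mesh.devices.size, dtype=jnp.float32)
print(train_step_jit(x, 2.0))
```

The cost is the extra nesting level: the function that does the work is no longer the top-level function that carries the decorators.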
Would it be possible to add `static_argnums` support to `shard_map`? That would allow us to simplify the above to:
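Until `shard_map` supports this natively, a small user-space wrapper can approximate `static_argnums` by performing the closure trick internally. The sketch below is hypothetical: `shard_map_static` is not a JAX API, it assumes `in_specs` lists specs only for the non-static arguments, and it handles only non-negative indices. The mesh axis `"x"`, `train_step`, and `scale` are likewise illustrative.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("x",))  # hypothetical 1-D mesh


def shard_map_static(f, *, mesh, in_specs, out_specs, static_argnums=()):
    """Hypothetical helper (not part of JAX): shard_map plus static_argnums.

    `in_specs` must contain one PartitionSpec per *non-static* argument.
    """
    if isinstance(static_argnums, int):
        static_argnums = (static_argnums,)
    static_argnums = frozenset(static_argnums)

    def wrapped(*args):
        static_vals = {i: args[i] for i in static_argnums}
        dynamic = [a for i, a in enumerate(args) if i not in static_argnums]

        def inner(*dyn):
            # Re-interleave the closed-over static values with the per-shard
            # dynamic values, restoring the original argument order.
            it = iter(dyn)
            full = [static_vals[i] if i in static_argnums else next(it)
                    for i in range(len(args))]
            return f(*full)

        return shard_map(inner, mesh=mesh, in_specs=in_specs,
                         out_specs=out_specs)(*dynamic)

    return wrapped


# A single top-level function carrying both decorators.
@partial(jax.jit, static_argnums=1)
@partial(shard_map_static, mesh=mesh, in_specs=(P("x"),), out_specs=P(),
         static_argnums=1)
def train_step(x_block, scale):
    return jax.lax.psum(jnp.sum(x_block) * scale, "x")


x = jnp.arange(8 * mesh.devices.size, dtype=jnp.float32)
print(train_step(x, 2.0))
```

In this sketch `static_argnums` has to appear on both decorators: `jax.jit` must keep `scale` as a hashable Python value rather than tracing it, and the wrapper must know which positions to withhold from `shard_map`.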