shard_map should support static_argnums #22043

Closed · reinerp opened this issue on Jun 22, 2024 · 4 comments · Fixed by #22049
Labels: enhancement (New feature or request)

reinerp (Contributor) commented Jun 22, 2024

We've started using (a type-based derivative of) shard_map as a function decorator. Ideally we'd like to stack the @jax.jit and @typed_shard_map decorators on a single top-level training function. However, shard_map doesn't support static_argnums, so we're forced to add a level of nested functions to capture the static arguments in a closure; see this example:

https://github.com/MatX-inc/seqax/blob/871bd91efc004e569aa91228ef0a66931ed16986/train.py#L292-L295
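
A minimal sketch of that closure workaround, using a toy function rather than the seqax code: the mesh axis name "d" and the function scaled_sum are hypothetical. The outer function is jitted with static_argnums, and the static value reaches the shard_map body only because the nested per_shard function closes over it. The import fallback assumes newer JAX releases expose jax.shard_map while mid-2024 releases only have jax.experimental.shard_map.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases export shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases only have the experimental module
    from jax.experimental.shard_map import shard_map

# Hypothetical 1-D mesh over all available devices; a single CPU device works.
mesh = Mesh(np.array(jax.devices()), ("d",))


@partial(jax.jit, static_argnums=(1,))
def scaled_sum(x, scale):
    # `scale` is static under jit. Instead of passing it through shard_map
    # as an argument, the nested function captures it in a closure.
    def per_shard(x_block):
        # Sum this device's shard, scale it, then add up across the mesh.
        return jax.lax.psum(jnp.sum(x_block) * scale, "d")

    return shard_map(per_shard, mesh=mesh, in_specs=(P("d"),), out_specs=P())(x)


print(scaled_sum(jnp.arange(8.0), 3))  # 3 * (0 + 1 + ... + 7)
```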

Would it be possible to add static_argnums support to shard_map? That would allow us to simplify the above to:

```python
@partial(jax.jit, static_argnums=(2, 3), donate_argnums=(0,))
@typed_shard_map
def training_step(state: State, step: u32[b''], h: Hparams, hparams: TrainingHparams, batch: TokenBatch) -> Tuple[Any, Metrics]:
  ...
```
@reinerp added the enhancement (New feature or request) label on Jun 22, 2024
jakevdp (Collaborator) commented Jun 22, 2024

Assigning @mattjj as this is similar to #17461

mattjj (Member) commented Jun 22, 2024

Hey @reinerp !

What if, instead of static_argnums, we allowed passing in_specs=(P(...), P(...), None, None)? That'd be more like vmap's in_axes=None. (Or do you prefer static_argnums? In that case, what would you pass for the corresponding entries of in_specs? Or would we just leave those entries out of in_specs?)
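
For reference, a small sketch of the vmap behavior being used as the analogy: an in_axes entry of None leaves that argument unmapped, so every mapped call sees the same value. The function affine is a hypothetical example.

```python
import jax
import jax.numpy as jnp


def affine(x, w, b):
    return x * w + b


# Map over the leading axis of `x` only; `w` and `b` have in_axes None,
# so the same scalar values are used for every element of the batch.
batched = jax.vmap(affine, in_axes=(0, None, None))

print(batched(jnp.arange(3.0), 2.0, 1.0))
```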

mattjj (Member) commented Jun 22, 2024

Assuming using in_specs=(..., None, ...) rather than static_argnums=... is okay, #22049 should fix this!

reinerp (Contributor, Author) commented Jun 23, 2024

I think that's probably sufficient, thank you!
