lax.cond sometimes inserts a nonlinear lax.stop_gradient into its JVP rule. #22011

Open
patrick-kidger opened this issue Jun 20, 2024 · 1 comment

patrick-kidger commented Jun 20, 2024

Description

I finally have a MWE for an intermittent issue I've been seeing for months!

First of all, the root cause: the lax.stop_gradient on this line:

ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops]

is being applied unconditionally to all operands. However, if we've been linearised first, then some of those operands may be tangents -- for which only linear operations are valid! Note that JAX treats lax.stop_gradient as nonlinear, and does not offer a transpose rule for it.
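To make the two ingredients concrete, here is a minimal sketch. It is not the code inside lax.cond: `masked_identity` is a hypothetical helper, and `jnp.where` is used as a stand-in for the internal `_bcast_select`. It shows that the select-with-stop_gradient pattern leaves values unchanged while masking gradients, and that `jax.linear_transpose` handles an ordinary linear function but (on the JAX version reported below, at least) not one that passes through `lax.stop_gradient`.

```python
import jax
import jax.lax as lax
import jax.numpy as jnp


def masked_identity(pred, x):
    # Hypothetical stand-in for the pattern quoted above: the values equal x
    # everywhere, but gradients are blocked wherever pred is False.
    return jnp.where(pred, x, lax.stop_gradient(x))


pred = jnp.array([True, False])
x = jnp.array([1.0, 2.0])

values = masked_identity(pred, x)  # same values as x
grads = jax.grad(lambda v: masked_identity(pred, v).sum())(x)  # [1., 0.]

# A genuinely linear function can be transposed without trouble.
(ct_linear,) = jax.linear_transpose(lambda t: 2.0 * t, 1.0)(jnp.array(1.0))

# Transposing through stop_gradient: the issue reports a NotImplementedError
# on jax 0.4.30. Other versions may behave differently, so the outcome is
# only recorded here, not assumed.
try:
    jax.linear_transpose(lax.stop_gradient, 1.0)(jnp.array(1.0))
    transpose_error = None
except Exception as err:
    transpose_error = f"{type(err).__name__}: {err}"

print(values, grads, ct_linear)
print(transpose_error)
```

If the last line prints an error about a missing transpose rule for 'stop_gradient', that is the same failure as the one reported here, reached without going through lax.cond at all.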

Thus the following MWE:

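import jax
import jax.lax as lax
import jax.numpy as jnp

# Both branches return x, so this cond is the identity in x.
def cond(pred, x):
    return lax.cond(pred, lambda: x, lambda: x)

# Linearise the cond at 1. and apply the resulting linear map to x.
def linearize(pred, x):
    _, lin_fn = jax.linearize(lambda y: cond(pred, y), 1.)
    return lin_fn(x)

# Batch over both the predicate and the tangent input.
def vmap(pred, x):
    return jax.vmap(linearize)(pred[None], x[None])[0]

pred = jnp.array(True)
x = jnp.array(1.)
# The map y -> vmap(pred, y) is linear in y, so transposing it should be valid.
jax.linear_transpose(lambda y: vmap(pred, y), 1.)(x)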

crashes at trace-time with:

NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'stop_gradient' not implemented

Huge thanks to @dkweiss31 over in patrick-kidger/diffrax#387 for having enough of a MWE that I was able to isolate it down to this.

Tagging @mattjj as my best guess for the right person for this kind of composition-of-transforms stuff.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.11.9 (main, Apr 19 2024, 11:43:47) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Air', release='22.5.0', version='Darwin Kernel Version 22.5.0: Mon Apr 24 20:52:43 PDT 2023; root:xnu-8796.121.2~5/RELEASE_ARM64_T8112', machine='arm64')
@patrick-kidger patrick-kidger added the bug Something isn't working label Jun 20, 2024

mattjj commented Jun 20, 2024

Wow, thanks for doing all the hard work of figuring this out, @patrick-kidger and @dkweiss31.

@mattjj mattjj self-assigned this Jun 20, 2024