
Is there a difference in computing gradients between JAX and PyTorch? #22101

Open
ergo-zyh opened this issue Jun 26, 2024 · 1 comment
Labels: bug (Something isn't working)

Comments

@ergo-zyh

Description

I rewrote the JAX version of the SAC algorithm in PyTorch. In the reparameterization step, where the loss is computed by sampling from a uniform distribution, I found that with identical inputs and identical network weights the loss values match, but the gradients differ. Why is this?
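For reference, a minimal JAX sketch of the reparameterized-sampling pattern being described; the parameter names, shapes, and the toy objective are assumptions, since the actual code was not posted:

```python
import jax
import jax.numpy as jnp

# Hypothetical reparameterized actor loss; the parameter names,
# shapes, and the stand-in objective are assumptions, not the
# reporter's actual SAC code.
def actor_loss(params, key):
    mu, log_std = params["mu"], params["log_std"]
    # Reparameterization trick: sample noise independently of the
    # parameters so the sampled action stays differentiable.
    eps = jax.random.normal(key, mu.shape)
    action = jnp.tanh(mu + jnp.exp(log_std) * eps)
    return jnp.mean(action ** 2)  # placeholder for the SAC objective

params = {"mu": jnp.zeros(4), "log_std": jnp.zeros(4)}
key = jax.random.PRNGKey(0)
loss, grads = jax.value_and_grad(actor_loss)(params, key)
print(loss)
print(grads)
```

With identical weights and identical noise values, a faithful PyTorch port of a function like this should produce matching gradients up to floating-point differences, so the frameworks' RNGs drawing different noise is one thing worth ruling out.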

System info (python version, jaxlib version, accelerator, etc.)

jax==0.4.28, flax==0.8.0, torch==2.3.1

ergo-zyh added the bug (Something isn't working) label on Jun 26, 2024
@ayaka14732 (Member)

This sounds strange. Do you have a code example?
