Description
I rewrote the JAX version of the SAC algorithm in PyTorch. In the reparameterization step, when computing the loss by sampling from the distribution, I found that with identical inputs and network weights the loss values are the same, but the gradients differ. Why is this?
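One common source of "same loss, different gradients" in a PyTorch port is the sampling call itself: `Distribution.sample()` detaches the sample from the graph, while `rsample()` applies the reparameterization trick (`mean + std * eps`) and keeps gradients flowing to the policy parameters. The sketch below (hypothetical policy head, not the code from this issue) illustrates the difference; whether this is the cause here depends on the actual port.

```python
import torch
from torch.distributions import Normal

# Hypothetical minimal policy head: a linear layer producing mean and log_std.
torch.manual_seed(0)
policy = torch.nn.Linear(4, 2)
obs = torch.randn(3, 4)

mean, log_std = policy(obs).chunk(2, dim=-1)
dist = Normal(mean, log_std.exp())

# .sample() blocks gradients through the noise path; .rsample() uses the
# reparameterization trick, so the sample stays differentiable w.r.t. the
# policy parameters. Both draw from the same distribution, so a loss
# computed on either can take identical values while the gradients differ.
a_detached = dist.sample()
a_reparam = dist.rsample()

loss = (a_reparam ** 2).mean()
loss.backward()

assert a_detached.requires_grad is False
assert a_reparam.requires_grad is True
assert policy.weight.grad is not None
```

A practical way to debug further is to feed both implementations the same fixed noise `eps` (instead of sampling independently) and compare per-parameter gradients; differences then point to the loss graph rather than the randomness.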
System info (python version, jaxlib version, accelerator, etc.)
jax==0.4.28, flax==0.8.0, torch==2.3.1