
Propagate the gradient to part of an array only #5767

Open
minhtriet opened this issue May 30, 2024 · 2 comments
Labels
enhancement ✨ New feature or request


@minhtriet
Contributor

Feature details

I want to optimize the coordinates of the different molecules in a reaction A + B -> AB. It would make sense to fix the coordinates of A and optimize only those of B. However, this does not work right now.
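One way to express "keep A fixed, optimize B" in plain JAX, without splitting the array, is to keep all six coordinates in one array and block the gradient for A's entries with `jax.lax.stop_gradient`. This is a minimal sketch, assuming a hypothetical toy harmonic bond energy (`toy_energy`, not PennyLane's Hamiltonian expectation value) and a mask (`TRAINABLE`, also hypothetical) that is known ahead of time:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy stand-in for the molecular energy: a harmonic "bond"
# between two atoms with equilibrium length r0 (not PennyLane's energy).
def toy_energy(flat_coords, r0=1.4):
    atoms = flat_coords.reshape(2, 3)
    r = jnp.linalg.norm(atoms[1] - atoms[0])
    return (r - r0) ** 2

# Hypothetical mask: False for atom A's three coordinates, True for atom B's.
TRAINABLE = jnp.array([False] * 3 + [True] * 3)

def masked_energy(flat_coords):
    # stop_gradient leaves the forward value unchanged but blocks gradient
    # flow; jnp.where selects the frozen copy for atom A's entries.
    frozen = jax.lax.stop_gradient(flat_coords)
    coords = jnp.where(TRAINABLE, flat_coords, frozen)
    return toy_energy(coords)

coords = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
grads = jax.grad(masked_energy)(coords)
print(grads)  # the first three entries (atom A) are exactly zero
```

Because A's gradient is exactly zero, any optimizer step leaves A in place while B moves.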

Suppose I have this JAX traced array

>>> coords
Traced<ConcreteArray([1. 1. 1.], dtype=float32)>with<JVPTrace(level=2/0)>
...

Then passing jnp.array([0, 0, 0, *coords]) to PennyLane does not work. The error is jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[2,3].

This happens because \qchem\openfermion_obs.py contains the line geometry_dhf = qml.numpy.array(coordinates.reshape(len(symbols), 3)), which, at the end of the stack trace, converts the JAX array into a NumPy array.
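The failure mode described above can be reproduced without PennyLane. In this sketch, `energy_pure_jax` and `energy_via_numpy` are hypothetical toy functions (a harmonic bond energy with an assumed equilibrium length `R0`). The first builds the geometry exactly as in the issue and stays in JAX, so `jax.grad` works. The second converts the traced geometry with `numpy.asarray`, which, per the traceback above, is the same kind of `__array__()` call that `qml.numpy.array` triggers:

```python
import jax
import jax.numpy as jnp
import numpy as onp

R0 = 1.4  # hypothetical equilibrium bond length for the toy energy

def energy_pure_jax(coord):
    # Same construction as in the issue: atom A fixed at the origin,
    # atom B given by the traced input `coord`.
    geometry = jnp.array([0.0, 0.0, 0.0, *coord]).reshape(2, 3)
    r = jnp.linalg.norm(geometry[1] - geometry[0])
    return (r - R0) ** 2

def energy_via_numpy(coord):
    geometry = jnp.array([0.0, 0.0, 0.0, *coord]).reshape(2, 3)
    # Converting to a plain NumPy array calls __array__() on the tracer,
    # which JAX forbids while tracing.
    geometry = onp.asarray(geometry)
    r = onp.linalg.norm(geometry[1] - geometry[0])
    return (r - R0) ** 2

coord = jnp.array([1.0, 1.0, 1.0])
grad_b = jax.grad(energy_pure_jax)(coord)
print(grad_b)  # one gradient entry per coordinate of atom B

raised = False
try:
    jax.grad(energy_via_numpy)(coord)
except jax.errors.TracerArrayConversionError:
    raised = True
print(raised)  # True
```

So the partial-array construction itself is differentiable; the gradient is lost only where the traced value leaves JAX.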

Implementation

Here is my MVP to reproduce the error. The error is raised by the qml.qchem.molecular_hamiltonian call inside loss_f.

```python
import pennylane as qml
from pennylane import numpy as np
import jax
import optax

dev = qml.device("default.qubit", wires=4)

@qml.qnode(dev)
def circuit_expected(H):
    # Hartree-Fock state for 2 electrons in 4 spin orbitals, then a fixed-angle double excitation
    qml.BasisState([1, 1, 0, 0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(0.2, wires=[0, 1, 2, 3])
    return qml.expval(H)


def loss_f(coord):
    symbols = ["H", "H"]
    # Atom A is fixed at the origin; only atom B's coordinates (coord) are traced.
    # Under jax.grad, this call raises jax.errors.TracerArrayConversionError.
    H, qb = qml.qchem.molecular_hamiltonian(symbols, jax.numpy.array([0, 0, 0, *coord]))
    return circuit_expected(H)

H_1 = jax.numpy.array([1., 1., 1.])
opt = optax.sgd(learning_rate=0.4)
opt_coords_state = opt.init(H_1)

for i in range(10):
    grad_coordinates = jax.grad(loss_f, 0)(H_1)
    updates, opt_coords_state = opt.update(grad_coordinates, opt_coords_state)
    H_1 = optax.apply_updates(H_1, updates)
    print(grad_coordinates)
```
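For comparison, the same optimization loop runs end to end when everything stays in JAX. In this sketch the PennyLane expectation value is replaced by a hypothetical toy bond energy with an assumed equilibrium length `R0 = 1.4`, and `optax.sgd` is replaced by the equivalent plain gradient-descent update so that the block only needs JAX:

```python
import jax
import jax.numpy as jnp

R0 = 1.4  # hypothetical equilibrium bond length of the toy energy

def loss_f(coord):
    # Atom A fixed at the origin; only atom B's coordinates are trainable.
    geometry = jnp.array([0.0, 0.0, 0.0, *coord]).reshape(2, 3)
    r = jnp.linalg.norm(geometry[1] - geometry[0])
    return (r - R0) ** 2

grad_fn = jax.jit(jax.grad(loss_f))
H_1 = jnp.array([1.0, 1.0, 1.0])
learning_rate = 0.4

for i in range(10):
    grad_coordinates = grad_fn(H_1)
    # Plain gradient descent: the same update optax.sgd applies without momentum.
    H_1 = H_1 - learning_rate * grad_coordinates

print(jnp.linalg.norm(H_1))  # A-B bond length, close to R0
```

Because atom A never enters the differentiated argument, only H_1 moves, and with learning_rate = 0.4 the toy bond length converges to R0 within the 10 iterations.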

How important would you say this feature is?

2: Somewhat important. Needed this quarter.

Additional information

No response

minhtriet added the enhancement ✨ New feature or request label on May 30, 2024
@CatalinaAlbornoz
Contributor

Hi @minhtriet,

As I answered in the forum thread, it looks like molecular_hamiltonian unfortunately only works with Autograd, so you would need to go back to using PennyLane NumPy instead of JAX 😿.

We’ll look into adding a warning about this in the documentation and, hopefully, making it work with JAX in future releases.

@soranjh
Contributor

soranjh commented Jun 11, 2024

Thanks @minhtriet for opening the issue. As Catalina suggested, could you please try the diff_hamiltonian function, which works with different frameworks? You might also look at this demo for more insight into differentiable qchem workflows.
