How-to use JAX with qiskit-dynamics#

JAX enables just-in-time compilation, automatic differentiation, and GPU execution. JAX is integrated into qiskit-dynamics via the Array class, which allows most parts of the package to be executed with either numpy or jax.numpy.

This guide addresses the following topics:

  1. How do I configure dynamics to run with JAX?

  2. How do I write code using Array that can be executed with either numpy or JAX?

  3. How do I write JAX-transformable functions using the objects and functions in qiskit-dynamics?

  4. Pitfalls when using JAX with Dynamics.

1. How do I configure dynamics to run with JAX?#

The Array class provides a means of controlling whether array operations are performed using numpy or jax.numpy. In many cases, the “default backend” determines which of the two is used.

# configure jax to use 64 bit mode
import jax
jax.config.update("jax_enable_x64", True)

# tell JAX we are using CPU
jax.config.update('jax_platform_name', 'cpu')

# import Array and set default backend
from qiskit_dynamics.array import Array
Array.set_default_backend('jax')

The default backend can be observed via:

Array.default_backend()
'jax'

2. How do I write code using Array that can be executed with either numpy or JAX?#

The Array class wraps both numpy and jax.numpy arrays. The particular type is indicated by the backend property, and numpy functions called on an Array will automatically be dispatched to numpy or jax.numpy based on the Array’s backend. See the API documentation for qiskit_dynamics.array for details.
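
For illustration, here is a minimal sketch of this dispatch mechanism (assuming JAX has been configured as in the previous section, so that the 'jax' backend is available):

import numpy as np
from qiskit_dynamics.array import Array

# explicitly construct Arrays on each backend
a = Array(np.arange(3.), backend='numpy')
b = Array(np.arange(3.), backend='jax')

print(a.backend)   # 'numpy'
print(b.backend)   # 'jax'

# numpy functions called on an Array are dispatched according to its backend:
# np.sin(a) executes numpy.sin, np.sin(b) executes jax.numpy.sin,
# and both return Array instances
np.sin(a)
np.sin(b)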

3. How do I write JAX-transformable functions using the objects and functions in qiskit-dynamics?#

JAX-transformable functions must:

  • Be JAX-executable.

  • Take JAX arrays as input and output (see the JAX documentation for more details on accepted input and output types).

  • Be pure, in the sense that they have no side-effects.

The previous section shows how to handle the first two points using Array. The last point further restricts the type of code that can be safely transformed. Qiskit Dynamics uses various objects which can be updated by setting properties (models, solvers). If a function to be transformed requires updating an already-constructed object of this form, it is necessary to first make a copy.
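
For example, a minimal sketch of this pattern, using Python's copy module and a hypothetical single-operator model (the parameterized simulation below avoids the issue altogether by passing signals directly to solver.solve):

import copy
from qiskit.quantum_info import Operator
from qiskit_dynamics import Signal
from qiskit_dynamics.models import HamiltonianModel

# an already-constructed model, defined outside the function to be transformed
model = HamiltonianModel(operators=[Operator.from_label('X')])

def pure_update(amp):
    # copy first, then update the copy, so the globally-defined model is never
    # mutated and the function remains free of side-effects
    local_model = copy.copy(model)
    local_model.signals = [Signal(amp)]
    return local_model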

We demonstrate this process for both just-in-time compilation and automatic differentiation in the context of an anticipated common use-case: parameterized simulation of a model of a quantum system.

3.1 Just-in-time compiling a parameterized simulation#

“Just-in-time compiling” a function means to compile it at run time. Just-in-time compilation incurs an initial cost associated with the construction of the compiled function, but subsequent calls to the function will generally be faster than the uncompiled version. In JAX, just-in-time compilation is performed using the jax.jit function, which transforms a JAX-compatible function into optimized code using XLA. We demonstrate here how, using the JAX backend, functions built using Qiskit Dynamics can be just-in-time compiled, resulting in faster simulation times.

For convenience, the wrap function can be used to transform jax.jit to also work on functions that have Array objects as inputs and outputs.

from qiskit_dynamics.array import wrap

jit = wrap(jax.jit, decorator=True)

Construct a Solver instance with the model to be simulated.

import numpy as np
from qiskit.quantum_info import Operator
from qiskit_dynamics import Solver, Signal
from qiskit_dynamics.array import Array

r = 0.5
w = 1.
X = Operator.from_label('X')
Z = Operator.from_label('Z')

static_hamiltonian = 2 * np.pi * w * Z/2
hamiltonian_operators = [2 * np.pi * r * X/2]

solver = Solver(
    static_hamiltonian=static_hamiltonian,
    hamiltonian_operators=hamiltonian_operators,
    rotating_frame=static_hamiltonian
)

Next, define the function to be compiled:

  • The input is the amplitude of a constant-envelope signal on resonance, driven over time \([0, 3]\).

  • The output is the state of the system, starting in the ground state, at 100 points over the total evolution time.

Note that, as described at the beginning of this section, the simulation function must remain pure: rather than setting the signals on the already-constructed solver, we pass them directly to solver.solve.

def sim_function(amp):

    # define a constant signal
    amp = Array(amp)
    signals = [Signal(amp, carrier_freq=w)]

    # simulate and return results
    results = solver.solve(
        t_span=[0, 3.],
        y0=np.array([0., 1.], dtype=complex),
        signals=signals,
        t_eval=np.linspace(0, 3., 100),
        method='jax_odeint'
    )

    return results.y

Compile the function.

fast_sim = jit(sim_function)

The first time the function is called, JAX will compile an XLA version of the function, which is then executed. Hence, the time taken on the first call includes compilation time.

%time ys = fast_sim(1.).block_until_ready()
CPU times: user 769 ms, sys: 14.2 ms, total: 783 ms
Wall time: 771 ms

On subsequent calls the compiled function is executed directly, demonstrating its true speed.

%timeit fast_sim(1.).block_until_ready()
127 µs ± 279 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

We use this function to plot the \(Z\) expectation value over a range of input amplitudes.

import matplotlib.pyplot as plt

for amp in np.linspace(0, 1, 10):
    ys = fast_sim(amp)
    plt.plot(np.linspace(0, 3., 100), np.real(np.abs(ys[:, 0])**2-np.abs(ys[:, 1])**2))
[Figure: the resulting plot of the \(Z\) expectation value over the evolution time for amplitudes from 0 to 1.]

3.2 Automatically differentiating a parameterized simulation#

Next, we use jax.grad to automatically differentiate a parameterized simulation. In this case, jax.grad requires that the output be a real number, so we specifically compute the population in the excited state at the end of the previous simulation.

def excited_state_pop(amp):
    yf = sim_function(amp)[-1]
    return np.abs(Array(yf[0]))**2

Wrap jax.grad in the same way, then differentiate and compile excited_state_pop.

grad = wrap(jax.grad, decorator=True)

excited_pop_grad = jit(grad(excited_state_pop))

As before, the first execution includes compilation time.

%time excited_pop_grad(1.).block_until_ready()
CPU times: user 1.86 s, sys: 20 ms, total: 1.88 s
Wall time: 1.85 s
Array(-2.33674306)

Subsequent runs of the function reveal the execution time once compiled.

%timeit excited_pop_grad(1.).block_until_ready()
721 µs ± 2.46 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

4. Pitfalls when using JAX with Dynamics#

4.1 JAX must be set as the default backend before building any objects in Qiskit Dynamics#

To run Dynamics with JAX, it is necessary to configure JAX as the default backend before building any objects or running any functions. The internal behaviour of some objects depends on the default backend at the time of instantiation. For example, at instantiation the operators in a model or Solver instance will be wrapped in an Array whose backend is the current default backend, and changing the default backend after building the object won't change this.
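
As a minimal illustration of this ordering requirement (using a hypothetical single-operator Solver purely to make the point concrete):

from qiskit.quantum_info import Operator
from qiskit_dynamics import Solver
from qiskit_dynamics.array import Array

# built while 'numpy' is the default backend: the stored operators are
# numpy-backed, and later calls to set_default_backend will not change this
Array.set_default_backend('numpy')
numpy_solver = Solver(static_hamiltonian=Operator.from_label('Z'))

# built after switching the default backend to 'jax': the stored operators are
# JAX-backed, so only this solver is suitable for use with jax.jit or jax.grad
Array.set_default_backend('jax')
jax_solver = Solver(static_hamiltonian=Operator.from_label('Z'))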

4.2 Running Dynamics with JAX on CPU vs GPU#

Certain JAX-based features in Dynamics are primarily recommended for use only with CPU or only with GPU. In such cases, a warning is raised if non-recommended hardware is used; however, users are not prevented from configuring Dynamics and JAX in whatever way they choose.

Instances of such features are:

  • Setting evaluation_mode='sparse' for solvers and models is only recommended for use on CPU (see the sketch after this list).

  • Parallel fixed step solver options in solve_lmde are recommended only for use on GPU.
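
For the first point, a brief sketch of requesting sparse evaluation at construction time (assuming the evaluation_mode constructor argument referenced above, and hypothetical single-qubit operators):

from qiskit.quantum_info import Operator
from qiskit_dynamics import Solver

# sparse evaluation is recommended on CPU; a warning may be raised if this
# configuration is used on GPU
sparse_solver = Solver(
    static_hamiltonian=Operator.from_label('Z'),
    hamiltonian_operators=[Operator.from_label('X')],
    evaluation_mode='sparse'
)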