Hands-on Example: Fluxonium Qubit

Energy spectrum

There are two ways to pass the Fluxonium parameters EC (charging energy), EJ (Josephson energy) and EL (inductive energy):

  • Directly set parameters when creating a class instance

  • Use Haiku to manage model parameters

Haiku’s model parameter management:

  • supergrad.Helper.ls_params() returns a nested dictionary that maps each component name (here fluxonium) to its parameter names and values.

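  • Pass the parameter dictionary as the first argument when calling a method. Although energy_spectrum is defined as energy_spectrum(self, phi), it is called as explore.energy_spectrum(params, phi).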
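The key idea is that parameters are data passed into a pure function rather than state stored on an object, which is what lets JAX transform the function later. The sketch below shows the same pattern in plain JAX, without Haiku or SuperGrad. It is a minimal sketch under stated assumptions: the function name fluxonium_spectrum, the grid size and the sign convention of the flux term are hypothetical, and the model is the standard fluxonium Hamiltonian H = 4 EC n^2 - EJ cos(φ - φext) + EL φ^2 / 2 discretized with a finite-difference Laplacian. It is not SuperGrad's internal implementation, so its eigenvalues need not match the outputs in this notebook exactly.

```python
import jax.numpy as jnp

N_GRID = 201          # phase-grid points (assumed discretization)
PHI_MAX = 5 * jnp.pi  # grid spans [-PHI_MAX, PHI_MAX], like phi_max=5*np.pi
N_LEVELS = 10         # number of eigenenergies returned


def fluxonium_spectrum(params, phi):
    """Hypothetical helper: lowest N_LEVELS eigenenergies of the assumed
    Hamiltonian H = 4 EC n^2 - EJ cos(φ - φext) + EL φ^2 / 2.

    params: nested dict shaped like the ls_params() output, {'fluxonium': {...}}.
    phi:    external flux in units of the flux quantum, so φext = 2π * phi.
    """
    p = params["fluxonium"]
    grid = jnp.linspace(-PHI_MAX, PHI_MAX, N_GRID)
    dphi = grid[1] - grid[0]
    # n^2 = -d^2/dφ^2, approximated by the 3-point finite-difference Laplacian
    laplacian = (2 * jnp.eye(N_GRID) - jnp.eye(N_GRID, k=1)
                 - jnp.eye(N_GRID, k=-1)) / dphi**2
    potential = -p["ej"] * jnp.cos(grid - 2 * jnp.pi * phi) + 0.5 * p["el"] * grid**2
    hamiltonian = 4 * p["ec"] * laplacian + jnp.diag(potential)
    # eigvalsh returns eigenvalues in ascending order
    return jnp.linalg.eigvalsh(hamiltonian)[:N_LEVELS]


# Parameters travel as the first argument instead of living on an object
params = {"fluxonium": {"ec": jnp.array(1.68), "ej": jnp.array(3.5), "el": jnp.array(0.5)}}
spec_0 = fluxonium_spectrum(params, 0.0)
spec_half = fluxonium_spectrum(params, 0.5)
print(spec_0[:3], spec_half[:3])
```

Because the function only reads from params, the same call works unchanged when params is replaced by new values, traced by jax.jit, or differentiated.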

[1]:
import numpy as np
import jax
import jax.numpy as jnp

import supergrad
from supergrad.quantum_system import Fluxonium


class ExploreFluxonium(supergrad.Helper):

    def _init_quantum_system(self):
        self.fluxonium = Fluxonium(phiext=0, phi_max=5 * np.pi)

    def energy_spectrum(self, phi):
        self.fluxonium.phiext = phi * 2 * jnp.pi  # phi in units of the flux quantum; phiext in radians (default 0)
        return self.fluxonium.eigenenergies()


explore = ExploreFluxonium()
explore.ls_params()

[1]:
{'fluxonium': {'ec': Array(1., dtype=float32),
  'ej': Array(1., dtype=float32),
  'el': Array(1., dtype=float32)}}
[2]:
params = {
    'fluxonium': {
        'ec': jnp.array(1.68),
        'ej': jnp.array(3.5),
        'el': jnp.array(0.5)
    }
}
# each parameter value should be a float (here a 0-d JAX array)
explore.energy_spectrum(params, 0)

[2]:
Array([-0.24092618,  5.06891579,  7.16328215,  8.28020319, 10.81742776,
       13.5851526 , 16.36989903, 19.3794768 , 22.42063552, 25.39180377],      dtype=float64)

For a fluxonium, we can vary the external flux bias phiext and recompute the energy spectrum. Passing phi = 0.5 sets phiext = π, i.e. half a flux quantum.

[3]:
explore.energy_spectrum(params, 0.5)

[3]:
Array([ 1.3041627 ,  2.09823523,  6.19190726,  9.14847985, 12.57031594,
       15.37689449, 17.40065635, 19.16545642, 21.37778417, 23.89876961],      dtype=float64)
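The two spectra are easier to compare through transition energies E_k - E_0. The short NumPy sketch below copies the eigenenergies printed by cells [2] and [3] (in the same units as EC, EJ and EL) and computes the lowest transitions; it makes no SuperGrad call, and the helper name transition_energies is hypothetical.

```python
import numpy as np

# Eigenenergies copied from the outputs of cells [2] (phi = 0) and [3] (phi = 0.5)
spec_phi0 = np.array([-0.24092618, 5.06891579, 7.16328215, 8.28020319, 10.81742776,
                      13.5851526, 16.36989903, 19.3794768, 22.42063552, 25.39180377])
spec_phi_half = np.array([1.3041627, 2.09823523, 6.19190726, 9.14847985, 12.57031594,
                          15.37689449, 17.40065635, 19.16545642, 21.37778417, 23.89876961])


def transition_energies(spectrum):
    """Energies E_k - E_0 of the excited levels relative to the ground state."""
    return spectrum[1:] - spectrum[0]


f01_phi0 = transition_energies(spec_phi0)[0]
f01_phi_half = transition_energies(spec_phi_half)[0]
f12_phi_half = spec_phi_half[2] - spec_phi_half[1]
print(f"E1 - E0 at phi = 0:   {f01_phi0:.4f}")      # 5.3098
print(f"E1 - E0 at phi = 0.5: {f01_phi_half:.4f}")  # 0.7941
print(f"E2 - E1 at phi = 0.5: {f12_phi_half:.4f}")  # 4.0937
```

At half flux the 0-1 transition drops well below the 1-2 transition, which reflects the strong anharmonicity fluxonium qubits are known for.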

Below we show how JAX transformations can be applied to the function above.

Auto-vectorization with vmap()

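JAX provides a vectorizing transformation, vmap(), the vectorizing map. It maps a function over an array axis (here, the flux values in phi_list), but instead of keeping the loop on the outside, it pushes the loop down into the function’s primitive operations for better performance. In the cell below, in_axes=(None, 0) shares params across all calls and maps over the leading axis of phi_list.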
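To check what vmap computes (as opposed to how fast it is), the sketch below compares a vmapped flux sweep with an explicit Python loop. Since it should run without SuperGrad, it uses a hypothetical stand-in, fluxonium_spectrum, built on the assumed Hamiltonian H = 4 EC n^2 - EJ cos(φ - 2π·phi) + EL φ^2 / 2 with a 3-point finite-difference Laplacian on a 201-point grid over [-5π, 5π]. This is not SuperGrad's implementation, so its values need not match the outputs above.

```python
import jax
import jax.numpy as jnp

N_GRID, PHI_MAX, N_LEVELS = 201, 5 * jnp.pi, 10  # assumed discretization


def fluxonium_spectrum(params, phi):
    """Hypothetical finite-difference fluxonium spectrum; phi in flux quanta."""
    p = params["fluxonium"]
    grid = jnp.linspace(-PHI_MAX, PHI_MAX, N_GRID)
    dphi = grid[1] - grid[0]
    laplacian = (2 * jnp.eye(N_GRID) - jnp.eye(N_GRID, k=1)
                 - jnp.eye(N_GRID, k=-1)) / dphi**2
    potential = -p["ej"] * jnp.cos(grid - 2 * jnp.pi * phi) + 0.5 * p["el"] * grid**2
    return jnp.linalg.eigvalsh(4 * p["ec"] * laplacian + jnp.diag(potential))[:N_LEVELS]


params = {"fluxonium": {"ec": jnp.array(1.68), "ej": jnp.array(3.5), "el": jnp.array(0.5)}}
phi_list = jnp.linspace(0.0, 1.0, 20)

# in_axes=(None, 0): broadcast the whole params dict, map over axis 0 of phi_list
batched_spectrum = jax.vmap(fluxonium_spectrum, in_axes=(None, 0))
spec_vmap = batched_spectrum(params, phi_list)  # shape (20, N_LEVELS)

# Reference: the same sweep as an explicit Python loop, one call per flux value
spec_loop = jnp.stack([fluxonium_spectrum(params, phi) for phi in phi_list])
print(spec_vmap.shape, float(jnp.max(jnp.abs(spec_vmap - spec_loop))))
```

Both give the same (20, 10) array up to floating-point rounding. Since φext enters only through cos(φ - 2π·phi), the rows for phi = 0 and phi = 1 also coincide: the spectrum is periodic in one flux quantum.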

[4]:
phi_list = np.linspace(0, 1, 20)
vmap_energy_spectrum = jax.vmap(explore.energy_spectrum, in_axes=(None, 0))
%timeit vmap_energy_spectrum(params, phi_list).block_until_ready()

1.13 s ± 129 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[5]:
phi_list = np.linspace(0, 1, 20)
# Baseline: evaluate the spectrum one flux value at a time in a Python loop
def forloop_energy_spectrum(params, phi_list):
    spectrum_list = []
    for phi in phi_list:
        spectrum_list.append(explore.energy_spectrum(params, phi))
    return jnp.array(spectrum_list)

%timeit forloop_energy_spectrum(params, phi_list).block_until_ready()

3.08 s ± 272 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Use jit() to speed up functions

JAX runs transparently on CPU or GPU. However, in the example above, JAX dispatches kernels one operation at a time. When a function consists of a sequence of operations (for example, one evaluated repeatedly during parameter optimization), we can use jax.jit to compile them together with XLA. Wrapping vmap_energy_spectrum in jax.jit compiles it just-in-time (JIT) on the first call; the compiled version is cached and reused on later calls with the same input shapes and dtypes.
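The compile-once, reuse-afterwards behavior can be observed directly by counting how often JAX traces a function. The sketch below uses a small hypothetical toy_spectrum (a tight-binding matrix with a flux-dependent on-site term) instead of the fluxonium model so it runs quickly without SuperGrad. The counter is a Python side effect, which executes only while JAX traces the function, not when a cached compiled version runs.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented each time JAX traces the Python function body


def toy_spectrum(params, phi):
    """Hypothetical stand-in for energy_spectrum: eigenvalues of a small
    tight-binding matrix whose on-site energies depend on the flux phi."""
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing only
    grid = jnp.linspace(-jnp.pi, jnp.pi, 64)
    hopping = jnp.eye(64, k=1) + jnp.eye(64, k=-1)
    onsite = jnp.diag(-params["ej"] * jnp.cos(grid - 2 * jnp.pi * phi))
    return jnp.linalg.eigvalsh(onsite - params["t"] * hopping)[:4]


jit_spectrum = jax.jit(jax.vmap(toy_spectrum, in_axes=(None, 0)))
params = {"ej": jnp.array(3.5), "t": jnp.array(1.0)}
phi_list = jnp.linspace(0.0, 1.0, 20)

jit_spectrum(params, phi_list).block_until_ready()  # first call: trace + compile
count_first = trace_count
jit_spectrum(params, phi_list).block_until_ready()  # same shapes: cached
count_after_repeat = trace_count

# New parameter values with the same shapes/dtypes reuse the compiled code,
# which is what makes jit useful inside an optimization loop
new_params = {"ej": jnp.array(4.0), "t": jnp.array(0.8)}
jit_spectrum(new_params, phi_list).block_until_ready()
count_after_new_values = trace_count

# A different input shape triggers a new trace and compilation
jit_spectrum(params, jnp.linspace(0.0, 1.0, 7)).block_until_ready()
count_after_new_shape = trace_count
print(count_first, count_after_repeat, count_after_new_values, count_after_new_shape)
```

This is also why a warm-up call is made before timing a jitted function: the first call includes tracing and compilation, which later calls skip.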

[6]:
jit_energy_spectrum = jax.jit(vmap_energy_spectrum)
spec_out = jit_energy_spectrum(params, phi_list)  # warm-up call: triggers compilation before timing

[7]:
%timeit jit_energy_spectrum(params, phi_list).block_until_ready()

973 ms ± 116 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)