Hands-on Example: Fluxonium Qubit
Energy spectrum
There are different ways to pass the Fluxonium parameters (EC, EJ, EL):
- Directly set the parameters when creating a class instance.
- Use Haiku to manage the model parameters.
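Haiku’s model parameter management
supergrad.Helper.ls_params() returns a dictionary that maps each component name to its parameter names and values. Parameters are then passed to the model's methods through the first argument, as in explore.energy_spectrum(params, 0) below.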
[1]:
import numpy as np
import jax
import jax.numpy as jnp
import supergrad
from supergrad.quantum_system import Fluxonium
class ExploreFluxonium(supergrad.Helper):
    def _init_quantum_system(self):
        self.fluxonium = Fluxonium(phiext=0, phi_max=5 * np.pi)
    def energy_spectrum(self, phi):
        # phi is the external flux in units of the flux quantum;
        # phiext is stored in radians (default 0)
        self.fluxonium.phiext = phi * 2 * jnp.pi
        return self.fluxonium.eigenenergies()
explore = ExploreFluxonium()
explore.ls_params()
[1]:
{'fluxonium': {'ec': Array(1., dtype=float32),
'ej': Array(1., dtype=float32),
'el': Array(1., dtype=float32)}}
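The nested dictionary returned by ls_params() is an ordinary JAX pytree, so the helpers in jax.tree_util can check that a hand-built params dictionary has exactly the expected layout before it is passed to a model. A minimal sketch; the defaults dictionary is typed in by hand to mirror the output above and does not call supergrad, and the misspelled typo dictionary is hypothetical:

```python
import jax
import jax.numpy as jnp

# Hand-written copy of the default tree printed by ls_params() above.
defaults = {'fluxonium': {'ec': jnp.array(1.0),
                          'ej': jnp.array(1.0),
                          'el': jnp.array(1.0)}}

# Replacement values; the keys must match the default tree exactly.
params = {'fluxonium': {'ec': jnp.array(1.68),
                        'ej': jnp.array(3.5),
                        'el': jnp.array(0.5)}}

# Compare tree structures (nesting and key names), ignoring the leaf values.
same_structure = (jax.tree_util.tree_structure(defaults)
                  == jax.tree_util.tree_structure(params))
print(same_structure)  # True

# A misspelled key ('Ej' instead of 'ej') changes the structure.
typo = {'fluxonium': {'ec': 1.68, 'Ej': 3.5, 'el': 0.5}}
print(jax.tree_util.tree_structure(typo)
      == jax.tree_util.tree_structure(defaults))  # False

# Leaves of a dict are visited in sorted key order: ec, ej, el.
print(jax.tree_util.tree_leaves(params))
```

Comparing structures rather than values catches misspelled or missing keys early, without depending on the numbers themselves.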
[2]:
params = {
'fluxonium': {
'ec': jnp.array(1.68),
'ej': jnp.array(3.5),
'el': jnp.array(0.5)
}
}
# each parameter should be a float array
explore.energy_spectrum(params, 0)
[2]:
Array([-0.24092618, 5.06891579, 7.16328215, 8.28020319, 10.81742776,
13.5851526 , 16.36989903, 19.3794768 , 22.42063552, 25.39180377], dtype=float64)
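The comment in cell [2] asks for float parameters. One reason is that jax.grad rejects integer-valued inputs. A minimal sketch, using a hypothetical raw dictionary that contains an integer, of casting every leaf to JAX's default floating dtype with jax.tree_util.tree_map:

```python
import jax
import jax.numpy as jnp

# Hypothetical user input: the integer literal 3 would otherwise become an int array.
raw = {'fluxonium': {'ec': 1.68, 'ej': 3, 'el': 0.5}}

# Cast every leaf to JAX's default floating dtype
# (float32, or float64 when jax_enable_x64 is enabled).
params = jax.tree_util.tree_map(lambda x: jnp.asarray(x, dtype=float), raw)

for name, value in params['fluxonium'].items():
    print(name, value.dtype)

def total(p):
    # Toy scalar function of all parameters, used only to exercise jax.grad.
    return sum(jax.tree_util.tree_leaves(p))

# Differentiation works once every leaf is floating point.
grads = jax.grad(total)(params)
```

The gradient has the same nested structure as params, which is what makes the dictionary convenient for optimization.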
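For a fluxonium, one can vary the external flux bias phiext and recompute the energy spectrum. In energy_spectrum, phi is given in units of the flux quantum and converted to phiext = 2π·phi, so phi = 0.5 below corresponds to phiext = π (half a flux quantum).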
[3]:
explore.energy_spectrum(params, 0.5)
[3]:
Array([ 1.3041627 , 2.09823523, 6.19190726, 9.14847985, 12.57031594,
15.37689449, 17.40065635, 19.16545642, 21.37778417, 23.89876961], dtype=float64)
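To give a rough idea of what energy_spectrum computes, the following toy sketch diagonalizes the fluxonium Hamiltonian in one common convention, H = 4 EC n^2 - EJ cos(φ - φext) + EL φ^2 / 2, on a uniform phase grid. It is an independent illustration, not supergrad's implementation: the function toy_fluxonium_spectrum, the grid size n_grid and the finite-difference scheme are assumptions, and its eigenvalues are not expected to reproduce the outputs above exactly.

```python
import jax.numpy as jnp

def toy_fluxonium_spectrum(ec, ej, el, phiext, phi_max=5 * jnp.pi,
                           n_grid=201, n_levels=10):
    """Lowest eigenenergies of H = 4 EC n^2 - EJ cos(phi - phiext) + EL phi^2 / 2.

    Hypothetical toy discretization on a uniform phase grid.
    """
    phi = jnp.linspace(-phi_max, phi_max, n_grid)
    dphi = phi[1] - phi[0]
    # Central finite-difference approximation of d^2/dphi^2 (wavefunction zero at the grid edges).
    off = jnp.ones(n_grid - 1)
    d2 = (jnp.diag(off, 1) + jnp.diag(off, -1) - 2.0 * jnp.eye(n_grid)) / dphi**2
    # n = -i d/dphi, hence n^2 = -d^2/dphi^2.
    kinetic = -4.0 * ec * d2
    potential = jnp.diag(-ej * jnp.cos(phi - phiext) + 0.5 * el * phi**2)
    # The matrix is real symmetric; eigvalsh returns eigenvalues in ascending order.
    return jnp.linalg.eigvalsh(kinetic + potential)[:n_levels]

# Same EC, EJ, EL as in cell [2]; phi = 0.5 flux quanta is phiext = pi.
print(toy_fluxonium_spectrum(1.68, 3.5, 0.5, 0.0))
print(toy_fluxonium_spectrum(1.68, 3.5, 0.5, 2 * jnp.pi * 0.5))
```

The phase range phi_max = 5π mirrors the phi_max argument passed to Fluxonium in cell [1]; a finer grid trades speed for accuracy in the higher levels.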
Below we show how JAX transformations can be applied to the function above.
Auto-vectorization with vmap()
JAX provides a vectorizing-map transformation, vmap(). It maps a function along array axes (here, the flux values in phi_list), but instead of keeping the loop on the outside, it pushes the loop down into the function’s primitive operations for better performance. The argument in_axes=(None, 0) tells vmap not to map over params and to map over axis 0 of phi_list.
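The in_axes=(None, 0) pattern can be checked on a small stand-in function. A minimal sketch, assuming a hypothetical two-level toy_levels(params, phi) in place of explore.energy_spectrum; it verifies that vmap returns the same array as an explicit Python loop:

```python
import jax
import jax.numpy as jnp
import numpy as np

def toy_levels(params, phi):
    # Hypothetical stand-in for energy_spectrum: eigenvalues of a 2x2 flux-dependent matrix.
    eps = params['eps'] * jnp.cos(2 * jnp.pi * phi)
    delta = params['delta']
    h = jnp.array([[eps, delta], [delta, -eps]])
    return jnp.linalg.eigvalsh(h)

params = {'eps': jnp.array(1.0), 'delta': jnp.array(0.3)}
phi_list = np.linspace(0, 1, 20)

# in_axes=(None, 0): broadcast the whole params tree, map over axis 0 of phi_list.
batched = jax.vmap(toy_levels, in_axes=(None, 0))(params, phi_list)
print(batched.shape)  # (20, 2)

# Same values as an explicit loop that stacks one call per phi.
looped = jnp.stack([toy_levels(params, phi) for phi in phi_list])
print(jnp.allclose(batched, looped, atol=1e-5))
```

Passing None for params works because in_axes is matched against the argument pytrees: a single None covers every leaf of the dictionary.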
[4]:
phi_list = np.linspace(0, 1, 20)
vmap_energy_spectrum = jax.vmap(explore.energy_spectrum, in_axes=(None, 0))
%timeit vmap_energy_spectrum(params, phi_list).block_until_ready()
1.13 s ± 129 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[5]:
phi_list = np.linspace(0, 1, 20)
def forloop_energy_spectrum(params, phi_list):
    spectrum_list = []
    for phi in phi_list:
        spectrum_list.append(explore.energy_spectrum(params, phi))
    return jnp.array(spectrum_list)
%timeit forloop_energy_spectrum(params, phi_list)
3.08 s ± 272 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
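%timeit is an IPython magic. Outside a notebook, a comparable measurement can be sketched with the standard library's time.perf_counter. Because JAX dispatches work asynchronously, block_until_ready() is called so the timer covers the computation rather than only its dispatch. time_call is a hypothetical helper, not part of supergrad or JAX, and the timed function is a toy example:

```python
import time

import jax.numpy as jnp

def time_call(fn, *args, repeats=5):
    """Average wall-clock seconds per call of fn(*args); fn must return a JAX array."""
    fn(*args).block_until_ready()  # warm-up call; for jitted functions this includes compilation
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the result so asynchronous dispatch does not hide the work.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats

x = jnp.linspace(0.0, 1.0, 1_000_000)
seconds = time_call(lambda v: jnp.sin(v) ** 2 + jnp.cos(v) ** 2, x)
print(f"{seconds * 1e3:.3f} ms per call")
```

Separating the warm-up call from the timed calls keeps one-off compilation cost out of the average.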
Use jit() to speed up functions
JAX runs transparently on CPU or GPU; however, in the example above, JAX dispatches kernels one operation at a time. When a function chains many operations (for example, in parameter optimization), jax.jit can compile them together with XLA. Wrapping vmap_energy_spectrum in jax.jit compiles it just-in-time (JIT) the first time it is called; the compiled version is cached and reused for later calls with the same input shapes and dtypes. Cell [6] makes that first, compiling call, so the timing in cell [7] measures only the cached version.
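The compile-once-then-cache behaviour can be observed with a Python side effect, which runs only while JAX traces the function. A minimal sketch, assuming a hypothetical toy_levels stand-in and a global counter n_traces:

```python
import jax
import jax.numpy as jnp

n_traces = 0  # incremented only while JAX traces the Python function

def toy_levels(params, phi):
    global n_traces
    n_traces += 1  # Python side effect: runs at trace time, not on cached calls
    eps = params['eps'] * jnp.cos(2 * jnp.pi * phi)
    h = jnp.array([[eps, params['delta']], [params['delta'], -eps]])
    return jnp.linalg.eigvalsh(h)

jit_levels = jax.jit(jax.vmap(toy_levels, in_axes=(None, 0)))
params = {'eps': jnp.array(1.0), 'delta': jnp.array(0.3)}

jit_levels(params, jnp.linspace(0.0, 1.0, 20))  # first call: trace and compile
jit_levels(params, jnp.linspace(0.0, 0.5, 20))  # same shapes and dtypes: cached
print(n_traces)  # 1
jit_levels(params, jnp.linspace(0.0, 1.0, 10))  # new input shape: traced and compiled again
print(n_traces)  # 2
```

Changing the length of the flux array therefore triggers a fresh compilation, which is worth keeping in mind when sweeping grids of different sizes.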
[6]:
jit_energy_spectrum = jax.jit(vmap_energy_spectrum)
spec_out = jit_energy_spectrum(params, phi_list)
[7]:
%timeit jit_energy_spectrum(params, phi_list).block_until_ready()
973 ms ± 116 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
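The parameter-optimization case mentioned above typically compiles the gradient as well. A minimal sketch, assuming a hypothetical two-level gap model, a target value of 1.5 and a plain gradient-descent update (not supergrad's optimizer), of jitting jax.grad and applying the update with jax.tree_util.tree_map:

```python
import jax
import jax.numpy as jnp

def gap(params, phi):
    # Energy difference between the two levels of a hypothetical 2x2 toy model.
    eps = params['eps'] * jnp.cos(2 * jnp.pi * phi)
    h = jnp.array([[eps, params['delta']], [params['delta'], -eps]])
    levels = jnp.linalg.eigvalsh(h)
    return levels[1] - levels[0]

def loss(params, target=1.5):
    # Squared mismatch between the gap at phi = 0 and a target value.
    return (gap(params, 0.0) - target) ** 2

# The whole gradient computation is traced once and compiled by XLA.
grad_step = jax.jit(jax.grad(loss))

params = {'eps': jnp.array(1.0), 'delta': jnp.array(0.3)}
learning_rate = 0.1
initial = loss(params)
for _ in range(50):
    grads = grad_step(params)
    # Update every leaf of the parameter tree with its own gradient.
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
print(float(initial), float(loss(params)))
```

Because every iteration calls grad_step with arrays of unchanged shape, only the first iteration pays the compilation cost.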