supergrad.time_evolution.ode_expm

supergrad.time_evolution.ode_expm(func, y0, ts, *args, astep=100, trotter_order=None, progress_bar=False, custom_vjp=True, pb_fwd_ad=False, compatibility_mode=False)

ODE solver that uses matrix exponentiation to compute the propagator at each time step.
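The idea of a matrix-exponential propagator scheme can be illustrated with a minimal NumPy sketch. This is not supergrad's implementation. It assumes a Schrödinger-type linear ODE dy/dt = -i H(t) y, takes a hypothetical `hamiltonian(t)` callable instead of the general `func(y, t, *args)`, and assumes that `astep` means the number of sub-steps between consecutive entries of `ts`. Each sub-step multiplies the state by exp(-i H(t_mid) dt), evaluated at the sub-step midpoint.

```python
import numpy as np

def expm_hermitian(h, dt):
    """Return exp(-1j * h * dt) for a Hermitian matrix h via its eigendecomposition."""
    evals, evecs = np.linalg.eigh(h)
    return (evecs * np.exp(-1j * evals * dt)) @ evecs.conj().T

def expm_propagate(hamiltonian, y0, ts, astep=100):
    """Integrate dy/dt = -1j * H(t) @ y with one matrix-exponential propagator per sub-step.

    `hamiltonian(t)` is a hypothetical callable returning the Hermitian matrix H(t).
    `astep` sub-steps are taken between each pair of consecutive entries of `ts`
    (an assumed reading of the parameter, for illustration only).
    """
    y = np.asarray(y0, dtype=complex)
    out = [y]
    for t_start, t_end in zip(ts[:-1], ts[1:]):
        dt = (t_end - t_start) / astep
        for k in range(astep):
            t_mid = t_start + (k + 0.5) * dt  # evaluate H at the sub-step midpoint
            y = expm_hermitian(hamiltonian(t_mid), dt) @ y
        out.append(y)
    return np.stack(out)  # shape (len(ts), dim) for a vector y0

# Usage: resonant Rabi oscillation, H = (omega / 2) * sigma_x, starting in |0>.
omega = 1.0
sigma_x = np.array([[0, 1], [1, 0]], dtype=complex)
ts = np.linspace(0.0, np.pi, 5)
ys = expm_propagate(lambda t: 0.5 * omega * sigma_x, np.array([1.0, 0.0]), ts, astep=10)
print(abs(ys[-1, 1]) ** 2)  # population of |1> at t = pi; analytically sin^2(pi / 2) = 1
```

Because each propagator is an exact exponential of a Hermitian matrix, the sketch preserves the norm of the state at every step. Its accuracy for a time-dependent H is limited only by how finely H(t) is sampled.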

Parameters:
  • func – function to evaluate the time derivative of the solution y at time t as func(y, t, *args), producing the same shape/structure as y0.

  • y0 – array or pytree of arrays representing the initial value for the state.

  • ts – array of float times for evaluation, such as jnp.linspace(0., 10., 101), whose values must be strictly increasing.

  • *args – tuple of additional arguments for func, which must be arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types.

  • astep – int, absolute number of steps to take for each time point (optional, default 100).

  • trotter_order (int or complex) – the order of the Suzuki-Trotter decomposition. Supported values are: a) None, compute the matrix exponential without Trotter decomposition; b) 1, first-order Trotter decomposition; c) 2, second-order Trotter decomposition; d) 4, fourth-order real decomposition; e) 4j, fourth-order complex decomposition.

  • progress_bar (bool) – whether to display a progress bar (optional).

  • custom_vjp (bool or str) – selects the custom vector-Jacobian product (VJP) rule used for automatic differentiation. If None, the VJP is derived by the JAX framework itself. Supported values are: a) True or LCAM (recommended, the default): use the local continuous adjoint method, which computes the inverse time evolution of both the state and the adjoint state; b) CAM: use the continuous adjoint method.

  • pb_fwd_ad (bool) – whether to configure the progress bar for forward-mode automatic differentiation.

  • compatibility_mode (bool) – whether to use compatibility mode for func. Compatibility mode is disabled by default to reduce computational cost: when the evolution operator does not depend on y, func can be evaluated with y set to 0.
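The difference between the first- and second-order options of trotter_order can be seen in a small NumPy sketch. It assumes the generator splits into two Hermitian terms a and b (how ode_expm actually partitions the generator is not stated here), uses the standard Lie-Trotter product for order 1 and the symmetric Strang splitting for order 2, and does not cover the fourth-order variants (4 and 4j). The helper names are hypothetical.

```python
import numpy as np

def expm_herm(h, dt):
    """exp(-1j * h * dt) for Hermitian h, via eigendecomposition."""
    evals, evecs = np.linalg.eigh(h)
    return (evecs * np.exp(-1j * evals * dt)) @ evecs.conj().T

def trotter_step(a, b, dt, order):
    """Approximate exp(-1j * (a + b) * dt) with a product formula of order 1 or 2."""
    if order == 1:
        return expm_herm(a, dt) @ expm_herm(b, dt)
    if order == 2:
        half_a = expm_herm(a, dt / 2)
        return half_a @ expm_herm(b, dt) @ half_a
    raise ValueError("only orders 1 and 2 are sketched here")

def trotter_error(a, b, total_time, n_steps, order):
    """Spectral-norm error of n_steps Trotter steps against the exact propagator."""
    dt = total_time / n_steps
    approx = np.linalg.matrix_power(trotter_step(a, b, dt, order), n_steps)
    exact = expm_herm(a + b, total_time)
    return np.linalg.norm(approx - exact, 2)

sigma_x = np.array([[0, 1], [1, 0]], dtype=complex)
sigma_z = np.array([[1, 0], [0, -1]], dtype=complex)
ratios = {}
for order in (1, 2):
    err_coarse = trotter_error(sigma_x, sigma_z, 1.0, 100, order)
    err_fine = trotter_error(sigma_x, sigma_z, 1.0, 200, order)
    ratios[order] = err_coarse / err_fine  # halving dt should shrink the error by about 2**order
print(ratios)
```

Because sigma_x and sigma_z do not commute, neither split is exact. Halving the step roughly halves the first-order error and quarters the second-order error, which is the usual trade-off between the cost of extra exponentials per step and accuracy.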
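The description of custom_vjp says LCAM computes the inverse time evolution for both the state and the adjoint state. The toy NumPy sketch below illustrates that general idea. It is not supergrad's LCAM code, and the CAM variant is not shown. It assumes unitary evolution, so each step can be undone with the conjugate transpose U^†. It also uses a hypothetical split step U = exp(-i theta sigma_x dt) exp(-i sigma_z dt) with a single scalar parameter theta. The forward pass keeps only the current state. The backward pass then reconstructs earlier states and carries an adjoint state to accumulate dL/dtheta for L = |<1|y_N>|^2.

```python
import numpy as np

I2 = np.eye(2, dtype=complex)
SX = np.array([[0, 1], [1, 0]], dtype=complex)
SZ = np.array([[1, 0], [0, -1]], dtype=complex)

def rot(pauli, angle):
    """exp(-1j * angle * pauli) for a Pauli matrix (closed form)."""
    return np.cos(angle) * I2 - 1j * np.sin(angle) * pauli

def step(theta, dt):
    """Hypothetical split step U and its exact derivative with respect to theta."""
    u = rot(SX, theta * dt) @ rot(SZ, dt)
    du = -1j * dt * SX @ u  # SX commutes with exp(-1j * theta * SX * dt)
    return u, du

def loss_and_grad(theta, y0, dt, n_steps):
    """L = |<1|y_N>|^2 and dL/dtheta from a backward pass that inverts the evolution."""
    u, du = step(theta, dt)  # time-independent step, for brevity
    y = y0.astype(complex)
    for _ in range(n_steps):  # forward pass: no trajectory is stored
        y = u @ y
    loss = abs(y[1]) ** 2
    lam = np.array([0, y[1]])  # adjoint state at t_N: |1><1|y_N>
    grad = 0.0
    for _ in range(n_steps):  # backward pass
        y = u.conj().T @ y  # recover y_k from y_{k+1} with the inverse propagator
        grad += 2 * np.real(lam.conj() @ du @ y)
        lam = u.conj().T @ lam  # evolve the adjoint state one step back
    return loss, grad, y

y0 = np.array([1.0, 0.0])
theta, dt, n = 0.7, 0.05, 40
loss, grad, y_back = loss_and_grad(theta, y0, dt, n)
eps = 1e-6
fd = (loss_and_grad(theta + eps, y0, dt, n)[0]
      - loss_and_grad(theta - eps, y0, dt, n)[0]) / (2 * eps)
print(grad, fd)                 # adjoint gradient and central finite difference agree closely
print(np.allclose(y_back, y0))  # True: the initial state is recovered by inverse evolution
```

The memory cost of this backward pass does not grow with the number of steps, because each intermediate state is recomputed from the final one rather than stored. The price is one extra inverse propagation of the state.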

Returns:

Values of the solution y (i.e. integrated system values) at each time point in ts, represented as an array (or pytree of arrays) with the same shape/structure as y0.