supergrad.utils.optimize.adam_opt
- supergrad.utils.optimize.adam_opt(fun, x0, args=(), options={}, jit=True, fwd_ad=False)
The JAX-based implementation of the Adam optimization algorithm.
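The per-step Adam update can be sketched in JAX as follows. This is a minimal illustration of the standard Adam algorithm, not SuperGrad's own code. The helper names `adam_step` and `minimize` and the epsilon value `1e-8` are assumptions. The learning-rate decay controlled by `adam_lr_decay_rate` is deliberately left out because its exact schedule is not documented here.

```python
import jax
import jax.numpy as jnp


def adam_step(params, grads, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    """One standard Adam update (hypothetical helper); t is the 1-based step."""
    # Exponential moving averages of the gradient and of its elementwise square.
    m = b1 * m + (1.0 - b1) * grads
    v = b2 * v + (1.0 - b2) * grads**2
    # Bias correction compensates for initializing m and v at zero.
    m_hat = m / (1.0 - b1**t)
    v_hat = v / (1.0 - b2**t)
    new_params = params - lr * m_hat / (jnp.sqrt(v_hat) + eps)
    return new_params, m, v


def minimize(fun, x0, steps=2000, lr=0.001, b1=0.9, b2=0.999):
    """Hypothetical driver loop: runs `steps` Adam updates on a scalar cost."""
    grad_fn = jax.jit(jax.grad(fun))  # reverse-mode gradient, compiled once
    x = jnp.asarray(x0, dtype=jnp.float32)
    m = jnp.zeros_like(x)
    v = jnp.zeros_like(x)
    for t in range(1, steps + 1):
        x, m, v = adam_step(x, grad_fn(x), m, v, t, lr, b1, b2)
    return x


# Usage: minimize a simple quadratic whose minimum is at x = [3, 3].
x_opt = minimize(lambda x: jnp.sum((x - 3.0) ** 2), jnp.zeros(2), steps=500, lr=0.05)
print(x_opt)
```

The defaults of `minimize` mirror the documented `adam_lr`, `steps`, `adam_b1` and `adam_b2` values. The usage line raises the learning rate only so the toy problem converges in a few hundred steps.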
- Parameters:
fun – The cost function to minimize, with signature f(params, all_params, **kwargs).
x0 – Initial guess for the parameters; it is updated iteratively during optimization.
args (tuple, optional) – Extra arguments passed to the objective function and its derivatives (the fun, jac and hess functions).
options (dict, optional) – Dictionary of optimizer hyperparameters. The default is training_params = {'adam_lr': 0.001, 'adam_lr_decay_rate': 1000, 'steps': 2000, 'adam_b1': 0.9, 'adam_b2': 0.999}.
jit (bool) – If True, just-in-time compile the computation.
fwd_ad (bool) – If True, use forward-mode automatic differentiation to compute gradients.
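The `options`, `jit` and `fwd_ad` parameters can be illustrated with a small JAX sketch. The helpers `resolve_options` and `make_grad_fn` below are hypothetical and do not reproduce SuperGrad's internals. The sketch makes three assumptions: user-supplied options override the documented defaults key by key, `fwd_ad=True` corresponds to computing the gradient with `jax.jacfwd` instead of `jax.grad`, and `jit=True` wraps the gradient in `jax.jit`.

```python
import jax
import jax.numpy as jnp

# Default hyperparameters, copied from the documented `options` defaults.
DEFAULT_OPTIONS = {
    "adam_lr": 0.001,
    "adam_lr_decay_rate": 1000,
    "steps": 2000,
    "adam_b1": 0.9,
    "adam_b2": 0.999,
}


def resolve_options(options=None):
    """Hypothetical helper: overlay user-supplied options on the defaults."""
    merged = dict(DEFAULT_OPTIONS)
    merged.update(options or {})
    return merged


def make_grad_fn(fun, fwd_ad=False, jit=True):
    """Hypothetical helper: build a gradient function for a scalar cost."""
    # jax.jacfwd assembles the gradient from forward-mode JVPs (one per input
    # element); jax.grad uses a single reverse-mode VJP. Both agree for a
    # scalar-valued `fun`.
    grad_fn = jax.jacfwd(fun) if fwd_ad else jax.grad(fun)
    return jax.jit(grad_fn) if jit else grad_fn


def cost(x):
    return jnp.sum(jnp.sin(x) * x**2)


x = jnp.array([0.5, -1.0, 2.0])
g_rev = make_grad_fn(cost)(x)
g_fwd = make_grad_fn(cost, fwd_ad=True, jit=False)(x)
opts = resolve_options({"adam_lr": 0.01})
print(g_rev, g_fwd, opts)
```

Forward mode evaluates one Jacobian-vector product per parameter, so it tends to pay off when the parameter vector is small or when the memory reverse mode needs for storing intermediates is prohibitive. Reverse mode is usually cheaper for a scalar cost over many parameters.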