supergrad.utils.optimize.adam_opt


supergrad.utils.optimize.adam_opt(fun, x0, args=(), options={}, jit=True, fwd_ad=False)

A JAX-based implementation of the Adam optimization algorithm.

Parameters:
  • fun – the cost function f(params, all_params, **kwargs)

  • x0 – Initial guess for the parameters, which is optimized iteratively.

  • args (tuple, optional) – Extra arguments passed to the objective function and its derivatives (the fun, jac and hess functions).

  • options (dict, optional) – Dictionary of optimizer hyperparameters. The default dict is training_params = {'adam_lr': 0.001, 'adam_lr_decay_rate': 1000, 'steps': 2000, 'adam_b1': 0.9, 'adam_b2': 0.999}

  • jit (bool) – If True, use just-in-time compilation.

  • fwd_ad (bool) – If True, use forward-mode automatic differentiation.
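
The sketch below illustrates how the hyperparameters and flags above typically enter an Adam loop in plain JAX. It is not supergrad's implementation: adam_sketch, cost and the eps value are hypothetical names and choices made for illustration, the cost function takes a single argument rather than the f(params, all_params, **kwargs) form, and learning-rate decay is left out because the exact meaning of adam_lr_decay_rate is not stated here.

```python
import jax
import jax.numpy as jnp


def adam_sketch(fun, x0, lr=0.001, b1=0.9, b2=0.999, eps=1e-8,
                steps=2000, jit=True, fwd_ad=False):
    """Hypothetical helper: minimize fun(x) with constant-rate Adam updates."""
    # fwd_ad selects forward-mode (jacfwd) over reverse-mode (grad) differentiation.
    grad_fun = jax.jacfwd(fun) if fwd_ad else jax.grad(fun)

    def step(i, x, m, v):
        g = grad_fun(x)
        m = b1 * m + (1 - b1) * g        # first-moment (mean) estimate
        v = b2 * v + (1 - b2) * g ** 2   # second-moment estimate
        m_hat = m / (1 - b1 ** i)        # bias correction for early steps
        v_hat = v / (1 - b2 ** i)
        x = x - lr * m_hat / (jnp.sqrt(v_hat) + eps)
        return x, m, v

    if jit:
        # Compile the update once; the step index is passed in as a traced value.
        step = jax.jit(step)

    x = jnp.asarray(x0, dtype=jnp.float32)
    m = jnp.zeros_like(x)
    v = jnp.zeros_like(x)
    for i in range(1, steps + 1):
        x, m, v = step(i, x, m, v)
    return x


def cost(x):
    # Toy quadratic cost with its minimum at x = 3 (assumption for the demo).
    return jnp.sum((x - 3.0) ** 2)


x_rev = adam_sketch(cost, jnp.zeros(2), lr=0.1, steps=500)
x_fwd = adam_sketch(cost, jnp.zeros(2), lr=0.1, steps=500, fwd_ad=True)
```

Forward-mode differentiation costs roughly one pass per input parameter, so it tends to pay off only when the parameter vector is small; reverse mode is the usual choice for scalar costs with many parameters.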