colossalai.nn.optimizer.fused_adam

class colossalai.nn.optimizer.fused_adam.FusedAdam(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, adamw_mode=True, weight_decay=0.0, amsgrad=False, set_grad_none=True)

Implements the Adam algorithm.

FusedAdam requires CUDA extensions, which can be built either during installation or at runtime.

This version of fused Adam implements two fusions:

  • Fusion of the Adam update’s elementwise operations.

  • A multi-tensor apply launch that batches the elementwise updates applied to all the model’s parameters into one or a few kernel launches.

colossalai.nn.optimizer.FusedAdam may be used as a drop-in replacement for torch.optim.AdamW, or for torch.optim.Adam when adamw_mode=False.
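The drop-in relationship described above can be sketched as a small factory function. This is only an illustration: build_optimizer and its decoupled and fused arguments are hypothetical names, not part of colossalai or torch. The imports are deferred so that the fused path is attempted only when requested, and it is assumed that the parameters passed to FusedAdam already live on a CUDA device with the CUDA extensions available.

```python
def build_optimizer(params, lr=1e-3, weight_decay=0.0, decoupled=True, fused=True):
    """Return FusedAdam or the matching torch.optim optimizer.

    Hypothetical helper for illustration; `decoupled` maps onto adamw_mode.
    """
    if fused:
        # Import lazily: FusedAdam needs the CUDA extensions to be built.
        from colossalai.nn.optimizer import FusedAdam

        return FusedAdam(
            params, lr=lr, weight_decay=weight_decay, adamw_mode=decoupled
        )

    import torch

    # adamw_mode=True corresponds to AdamW, adamw_mode=False to Adam.
    optimizer_cls = torch.optim.AdamW if decoupled else torch.optim.Adam
    return optimizer_cls(params, lr=lr, weight_decay=weight_decay)
```

A call such as build_optimizer(model.parameters(), weight_decay=0.01) would then select FusedAdam in AdamW mode, while fused=False falls back to the stock PyTorch optimizer with the same hyperparameters.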

colossalai.nn.optimizer.FusedAdam may be used with or without Amp.

Adam was proposed in Adam: A Method for Stochastic Optimization.

Parameters:
  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float, optional) – learning rate. (default: 1e-3)

  • betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)

  • weight_decay (float, optional) – weight decay coefficient, applied either as decoupled weight decay or as an L2 penalty depending on adamw_mode. (default: 0)

  • amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond. Not supported in FusedAdam. (default: False)

  • adamw_mode (boolean, optional) – selects between decoupled weight decay and L2 regularization. True applies decoupled weight decay (also known as AdamW); False applies L2 regularization. (default: True)

  • set_grad_none (bool, optional) – whether to set gradients to None, rather than zeroing them, when the zero_grad() method is called. (default: True)
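To make the effect of adamw_mode, weight_decay, betas and eps concrete, the following is a minimal pure-Python sketch of one textbook Adam step on a single scalar parameter. It is a reference for the arithmetic only, not the fused CUDA kernel. The function name adam_step is hypothetical, and the reading of bias_correction (dividing the moment estimates by 1 - beta^step) is an assumption, since that argument is not described in the parameter list above.

```python
import math


def adam_step(p, g, m, v, step, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
              weight_decay=0.0, adamw_mode=True, bias_correction=True):
    """One textbook Adam/AdamW update for a scalar parameter (illustrative)."""
    beta1, beta2 = betas
    if not adamw_mode:
        # L2 regularization: fold the penalty into the gradient.
        g = g + weight_decay * p
    # Running averages of the gradient and of its square.
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    if bias_correction:
        # Assumed meaning of bias_correction: undo the zero initialization bias.
        m_hat = m / (1.0 - beta1 ** step)
        v_hat = v / (1.0 - beta2 ** step)
    else:
        m_hat, v_hat = m, v
    update = m_hat / (math.sqrt(v_hat) + eps)
    if adamw_mode:
        # Decoupled weight decay: shrink the parameter directly.
        update = update + weight_decay * p
    return p - lr * update, m, v


# First step from zero state, with the same gradient in both modes.
p_decoupled, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, step=1, weight_decay=0.1)
p_l2, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, step=1, weight_decay=0.1,
                       adamw_mode=False)
print(p_decoupled, p_l2)
```

On this first step the L2 penalty is absorbed by the gradient normalization, so the L2 variant lands at roughly 0.9990, whereas the decoupled variant additionally shrinks the parameter by lr * weight_decay * p and lands at roughly 0.9989.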

step(closure=None, grads=None, output_params=None, scale=None, grad_norms=None, div_scale=-1)

Performs a single optimization step.

Parameters:

closure (callable, optional) – A closure that reevaluates the model and returns the loss.

The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
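A closure-based call to step can be sketched as follows. The helper train_step and its arguments are hypothetical; the sketch assumes only that the optimizer exposes zero_grad() and step(closure) and that the loss object has a backward() method, as in the usual PyTorch training loop. Only closure is passed, since the remaining arguments are deprecated.

```python
def train_step(model, optimizer, loss_fn, inputs, targets):
    """Run one optimization step through a closure (illustrative helper)."""
    losses = []

    def closure():
        # Re-evaluate the model and recompute gradients from scratch.
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        losses.append(loss)
        return loss

    # Pass only the closure; the other step() arguments are deprecated.
    optimizer.step(closure)
    return losses[-1]
```

The loss is captured inside the closure rather than read from the return value of step, so the sketch does not depend on what step returns.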