class colossalai.nn.parallel.ColoDDP(module, process_group, bucket_cap_mb=25, rebuild_bucket=True)[source]

Distributed data parallel (DDP) wrapper for modules that use ColoTensor. Nesting one ColoDDP inside another is not currently supported.


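>>> import torch
>>> from colossalai.core import global_context as gpc
>>> from colossalai.context import ParallelMode
>>> from colossalai.tensor import ProcessGroup
>>> from colossalai.nn.parallel import ColoDDP
>>> # world_size, x, labels and criterion are assumed to be defined by the caller
>>> model = torch.nn.Linear(20, 1)
>>> pg = ProcessGroup(tp_degree=world_size // 2)
>>> model = ColoDDP(model, pg)
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
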
  • module (torch.nn.Module) – Module to apply DDP.

  • process_group (Optional[dist.ProcessGroup], optional) – The process group which DDP uses. If it’s None, the default data parallel group will be used. Defaults to None.

static set_params_to_ignore(params_to_ignore)[source]

Sets parameters to be ignored by DDP. This method must be called before initializing ColoDDP.


>>> params_to_ignore = []
>>> for p in module.parameters():
...     if should_ignore(p):
...         params_to_ignore.append(p)
>>> ColoDDP.set_params_to_ignore(params_to_ignore)
>>> module = ColoDDP(module)

  • params_to_ignore (Iterable[torch.Tensor]) – The parameters to be ignored.

class colossalai.nn.parallel.ZeroDDP(module, gemini_manager, pin_memory=False, force_outputs_fp32=False, strict_ddp_mode=False)[source]

ZeRO DDP for ColoTensor. It is designed to be used with ChunkManager and GeminiManager; for more details, see their API references. Warning: Nested ZeroDDP is not currently supported.

  • module (torch.nn.Module) – Module to apply ZeRO-DP.

  • gemini_manager (GeminiManager) – Manages the chunk manager and the heterogeneous memory space. For more details, see the API reference of GeminiManager.

  • pin_memory (bool) – If set to True, chunks stored in CPU memory use pinned memory. Defaults to False.

  • force_outputs_fp32 (bool) – If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.

  • strict_ddp_mode (bool) – If set to True, tensors are not sharded and each tensor is replicated. Defaults to False. Set it to True only when you are certain that plain DDP is all you need.

state_dict(destination=None, prefix='', keep_vars=False, only_rank_0=True)[source]

Returns a dictionary containing the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Warning: The non-strict state dict omits a parameter if its tensor is shared with another parameter that is already included in the dictionary. When you load such a state dict, set the argument strict to False.
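The warning above can be illustrated with a minimal sketch in plain Python. It is not ColossalAI's implementation: `FakeTensor` and `non_strict_state_dict` are hypothetical names, and the sketch assumes that "shared" means two parameter names refer to the same tensor object.

```python
from collections import OrderedDict


class FakeTensor:
    """Hypothetical stand-in for a tensor; only object identity matters here."""

    def __init__(self, label):
        self.label = label


def non_strict_state_dict(named_params):
    """Keep the first name seen for each tensor object and skip later aliases."""
    seen = set()
    out = OrderedDict()
    for name, tensor in named_params:
        if id(tensor) in seen:
            continue  # this tensor is already stored under an earlier key
        seen.add(id(tensor))
        out[name] = tensor
    return out


shared = FakeTensor("embedding")
named_params = [
    ("embed.weight", shared),
    ("encoder.weight", FakeTensor("encoder")),
    ("lm_head.weight", shared),  # tied to embed.weight
]

state = non_strict_state_dict(named_params)
skipped = [name for name, _ in named_params if name not in state]
print(list(state))  # ['embed.weight', 'encoder.weight']
print(skipped)      # ['lm_head.weight']
```

Because `lm_head.weight` is absent from the saved dictionary, a strict key comparison at load time would report it as missing, which is why strict should be set to False in that case.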


A dictionary containing the whole state of the module.

Return type:

dict

load_state_dict(state_dict, strict=True)[source]

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True


  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields


If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
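As a rough illustration of the key matching described above, the following plain-Python sketch computes the two lists and mimics the strict check. `IncompatibleKeys` and `compare_keys` are hypothetical names introduced for this example; raising RuntimeError on a strict mismatch is assumed to follow torch.nn.Module.load_state_dict and is not confirmed here for ZeroDDP.

```python
from collections import namedtuple

# Hypothetical stand-in for the NamedTuple with missing_keys and unexpected_keys.
IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])


def compare_keys(module_keys, state_dict, strict=True):
    """Compare the module's expected keys with the keys of a state dict."""
    missing = [k for k in module_keys if k not in state_dict]
    unexpected = [k for k in state_dict if k not in module_keys]
    if strict and (missing or unexpected):
        # Assumption: a strict mismatch is an error, as in torch.nn.Module.
        raise RuntimeError(f"missing keys: {missing}; unexpected keys: {unexpected}")
    return IncompatibleKeys(missing, unexpected)


module_keys = ["embed.weight", "encoder.weight", "lm_head.weight"]
checkpoint = {"embed.weight": 0, "encoder.weight": 0, "extra.bias": 0}

result = compare_keys(module_keys, checkpoint, strict=False)
print(result.missing_keys)     # ['lm_head.weight']
print(result.unexpected_keys)  # ['extra.bias']
```

With strict=True the same call raises instead of returning, so inspecting the returned lists is only useful for non-strict loads.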

colossalai.nn.parallel.zero_model_wrapper(model, zero_stage=1, gemini_config=None)[source]

This wrapper function is used to wrap your training model for ZeRO DDP.


>>> with ColoInitContext():
...     my_model = Bert()
>>> my_optim = SGD(my_model.parameters(), lr = 1e-3)
>>> zero_model = zero_model_wrapper(my_model, zero_stage=1)
>>> zero_optim = zero_optim_wrapper(zero_model, my_optim)
  • model (nn.Module) – The model used in ZeRO DDP.

  • zero_stage (int, optional) – The stage of ZeRO DDP. Defaults to 1. See the ZeRO paper for more information on the stages.

  • gemini_config (dict, optional) –

    The configuration dictionary of GeminiDDP. GeminiDDP is enabled when the stage is set to 3. You can set the arguments of GeminiDDP in gemini_config. The example below sets the device of the model, the placement policy of Gemini, and the size of the hidden dimension, which helps Gemini determine a unified chunk size.


    >>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
    >>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)

colossalai.nn.parallel.zero_optim_wrapper(model, optimizer, initial_scale=65536, growth_factor=2, backoff_factor=0.5, growth_interval=1000, hysteresis=2, min_scale=1, max_scale=4294967296, max_norm=0.0, norm_type=2.0, optim_config=None)[source]

This wrapper function is used to wrap your training optimizer for ZeRO DDP.

  • model (nn.Module) – Your model wrapped by zero_model_wrapper

  • optimizer (torch.optim.Optimizer) – Your initialized optimizer

  • initial_scale (float, optional) – initial_scale used by DynamicGradScaler.

  • min_scale (float, optional) – min_scale used by DynamicGradScaler.

  • growth_factor (float, optional) – growth_factor used by DynamicGradScaler.

  • backoff_factor (float, optional) – backoff_factor used by DynamicGradScaler.

  • growth_interval (float, optional) – growth_interval used by DynamicGradScaler.

  • hysteresis (float, optional) – hysteresis used by DynamicGradScaler.

  • max_scale (int, optional) – max_scale used by DynamicGradScaler.

  • max_norm (float, optional) – max_norm used for clip_grad_norm. Do not call clip_grad_norm yourself when using ZeRO DDP; the ZeRO optimizer takes care of gradient clipping.

  • norm_type (float, optional) – norm_type used for clip_grad_norm.

  • optim_config (dict, optional) –

    The configuration used for the ZeRO optimizer. Example:

    >>> zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True)
    >>> optim = zero_optim_wrapper(model, optim, optim_config=zero2_config)
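To give a feel for how the scaling parameters listed above interact, here is a toy dynamic loss scaler in plain Python. It is only a sketch and not ColossalAI's DynamicGradScaler: `ToyGradScaler` is a hypothetical class, and it assumes that hysteresis is the number of overflow steps tolerated before the scale is reduced and that growth_interval is the number of consecutive overflow-free steps before the scale grows. The real scaler may define these differently.

```python
class ToyGradScaler:
    """Illustrative dynamic loss scaler (assumed semantics, see lead-in)."""

    def __init__(self, initial_scale=65536, growth_factor=2, backoff_factor=0.5,
                 growth_interval=1000, hysteresis=2, min_scale=1,
                 max_scale=4294967296):
        self.scale = float(initial_scale)
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis
        self.min_scale = min_scale
        self.max_scale = max_scale
        self._good_steps = 0                # consecutive steps without overflow
        self._overflows_left = hysteresis   # overflows tolerated before backoff

    def update(self, overflow):
        """Adjust the scale after one step and return the new value."""
        if overflow:
            self._good_steps = 0
            self._overflows_left -= 1
            if self._overflows_left <= 0:
                # Shrink the scale, but never below min_scale.
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
        else:
            self._good_steps += 1
            if self._good_steps >= self.growth_interval:
                # Grow the scale, but never above max_scale.
                self._good_steps = 0
                self._overflows_left = self.hysteresis
                self.scale = min(self.scale * self.growth_factor, self.max_scale)
        return self.scale


scaler = ToyGradScaler(initial_scale=1024, growth_interval=2, hysteresis=2)
history = [scaler.update(overflow) for overflow in (True, True, False, False)]
print(history)  # [1024.0, 512.0, 512.0, 1024.0]
```

In this run the first overflow is tolerated, the second halves the scale, and two clean steps double it again.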