colossalai.zero.init_ctx.init_context

class colossalai.zero.init_ctx.init_context.ZeroContextConfig(target_device, replicated=True, shard_param=False)[source]

The configuration used to control zero context initialization.

Parameters:
  • target_device (torch.device) – The device where parameter data are placed after exiting the context.

  • replicated (bool, optional) – Whether the parameter is replicated across the data parallel group. Some parameters are not replicated, e.g. parameters in MoE experts. Defaults to True.

  • shard_param (bool, optional) – Whether the parameter is sharded after exiting the context. Defaults to False.
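A minimal construction sketch based on the signature above; the CUDA device choice is illustrative and the keyword names follow the signature as shown.

    import torch

    from colossalai.zero.init_ctx.init_context import ZeroContextConfig

    # Keep parameter data on a CUDA device after the context exits,
    # replicate parameters across the data parallel group, and shard them.
    config = ZeroContextConfig(
        target_device=torch.device('cuda:0'),
        replicated=True,
        shard_param=True,
    )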

class colossalai.zero.init_ctx.init_context.ZeroInitContext(target_device, shard_strategy, seed=1023, shard_param=False, default_dtype=None, model_numel_tensor=tensor([0]))[source]

A context manager to initialize the model; a usage sketch follows the parameter list below.

  1. Convert the model to fp16.

  2. The parameters of the module are adapted to the type ShardedParameter.

  3. Shard the parameters and gradients according to the flags.

Parameters:
  • target_device (torch.device) – The device where parameter data are placed after exiting the context.

  • shard_strategy (BaseShardStrategy) – Shard strategy instance.

  • seed (int, optional) – Random seed for weight initialization. Defaults to 1023.

  • shard_param (bool, optional) – Whether the parameter is sharded after exiting the context. Defaults to False.

  • default_dtype (torch.dtype, optional) – If not None, parameters are first initialized in default_dtype and then converted to fp16. Defaults to None.

  • model_numel_tensor (torch.Tensor, optional) – A tensor that stores the number of elements of the model. Defaults to torch.zeros(1, dtype=torch.int).
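A minimal usage sketch, assuming TensorShardStrategy from colossalai.zero.shard_utils as the shard strategy and an illustrative model; it also assumes the distributed environment has already been initialized (e.g. via colossalai.launch).

    import torch
    import torch.nn as nn

    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy  # assumed shard strategy

    # A tensor to accumulate the number of model elements during initialization.
    numel_counter = torch.zeros(1, dtype=torch.int)

    with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True,
                         model_numel_tensor=numel_counter):
        # Parameters created inside the context are converted to fp16,
        # wrapped as sharded parameters, and sharded on exit.
        model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 10))

    print(f'model numel: {numel_counter.item()}')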

static calc_fanin_fanout(tensor)[source]

We use this function to substitute the fan-in and fan-out calculation in torch.nn.init, so that we get the correct fan-in and fan-out for a sharded tensor.
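The substitution matters because a sharded parameter's local payload may be a flattened slice of the original weight, so computing fan-in and fan-out from the payload's shape would give wrong values. The hypothetical sketch below (not the library's implementation) shows the underlying rule applied to the parameter's original, unsharded shape.

    import torch

    def fanin_fanout_from_shape(shape):
        # Standard rule used by torch.nn.init: fan_in is the number of input
        # features, fan_out the number of output features; for conv weights
        # each is multiplied by the receptive field size.
        num_output_fmaps, num_input_fmaps = shape[0], shape[1]
        receptive_field_size = 1
        for s in shape[2:]:
            receptive_field_size *= s
        return num_input_fmaps * receptive_field_size, num_output_fmaps * receptive_field_size

    original_shape = (1024, 256)                  # shape before sharding
    local_payload = torch.empty(1024 * 256 // 4)  # e.g. one of four flattened shards

    # Using original_shape rather than local_payload.shape yields the correct values.
    fan_in, fan_out = fanin_fanout_from_shape(original_shape)
    print(fan_in, fan_out)  # 256 1024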