colossalai.zero.init_ctx

class colossalai.zero.init_ctx.ZeroInitContext(target_device, shard_strategy, seed=1023, shard_param=False, default_dtype=None, model_numel_tensor=tensor([0]))[source]

A context to initialize the model.

  1. Convert the model to fp16.

  2. The parameters of the module are adapted to type ShardedParameter.

  3. Shard the param and grad according to flags.

Parameters:
  • target_device (torch.device) – The device where parameter data reside after exiting the context.

  • shard_strategy (BaseShardStrategy) – Shard strategy instance.

  • seed (int, optional) – Random seed for weight initialization. Defaults to 1023.

  • shard_param (bool, optional) – Whether the parameters are sharded after exiting the context. Defaults to False.

  • default_dtype (torch.dtype, optional) – If not None, parameters are first initialized in default_dtype and then converted to fp16.

  • model_numel_tensor (torch.Tensor, optional) – A tensor in which the number of elements of the model will be stored. Defaults to torch.zeros(1, dtype=torch.int).
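
A minimal usage sketch (not taken from the original docs): it assumes TensorShardStrategy from colossalai.zero.shard_utils as the shard strategy and builds a toy model inside the context so its parameters are converted to fp16 and sharded on creation.

```python
import torch
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

# Tensor used by the context to accumulate the model's parameter count.
model_numel = torch.zeros(1, dtype=torch.int)

with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     model_numel_tensor=model_numel):
    # Parameters created here are converted to fp16 and sharded according
    # to the strategy, so the full model never needs to fit on one device.
    model = torch.nn.Linear(1024, 1024)

print(model_numel.item())  # total number of model elements
```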

static calc_fanin_fanout(tensor)[source]

This function substitutes the fan-in and fan-out calculation in torch.nn.init, so that correct fan-in and fan-out values are obtained for sharded tensors.
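
An illustrative sketch of the idea, not the library's implementation: it applies the standard torch.nn.init fan-in/fan-out rule to the tensor's original shape rather than its current (possibly flattened or sharded) shape. The attribute name original_shape is hypothetical and used only to show where the unsharded shape would come from.

```python
import torch

def calc_fanin_fanout_sketch(tensor: torch.Tensor):
    # Prefer a recorded original shape (hypothetical attribute); fall back
    # to the live shape when the tensor is not sharded.
    shape = getattr(tensor, 'original_shape', tensor.shape)
    if len(shape) < 2:
        raise ValueError("fan-in/fan-out requires at least 2 dimensions")
    num_input_fmaps, num_output_fmaps = shape[1], shape[0]
    # For conv weights, every spatial position contributes to the fan.
    receptive_field_size = 1
    for s in shape[2:]:
        receptive_field_size *= s
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out
```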