colossalai.zero.init_ctx.init_context
- class colossalai.zero.init_ctx.init_context.ZeroContextConfig(target_device, replicated=True, shard_param=False)[source]
The configuration used to control zero context initialization.
- Parameters:
target_device (torch.device) – The device where param data reside after exiting the context.
replicated (bool, optional) – Whether the param is replicated across the data parallel group. Some parameters are not replicated, e.g. parameters in MoE experts. Defaults to True.
shard_param (bool, optional) – Whether the param is sharded after exiting the context. Defaults to False.
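For illustration, a minimal sketch of constructing the config directly from the signature above; in typical use it is built internally by ZeroInitContext, and the direct import path is assumed from the module name:

```python
import torch

from colossalai.zero.init_ctx.init_context import ZeroContextConfig

# Place param data on GPU 0 after the context exits, keep params
# replicated across the data parallel group, and shard them on exit.
# replicated=False would mark non-replicated params, e.g. MoE expert weights.
config = ZeroContextConfig(
    target_device=torch.device('cuda', 0),
    replicated=True,
    shard_param=True,
)
```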
- class colossalai.zero.init_ctx.init_context.ZeroInitContext(target_device, shard_strategy, seed=1023, shard_param=False, default_dtype=None, model_numel_tensor=tensor([0]))[source]
A context to initialize the model. Within the context:
1. The model is converted to fp16.
2. The parameters of the module are adapted to type ShardedParameter.
3. The param and grad are sharded according to the flags.
- Parameters:
target_device (torch.device) – The device where param data reside after exiting the context.
shard_strategy (BaseShardStrategy) – Shard strategy instance.
seed (int, optional) – Random seed for weight initialization. Defaults to 1023.
shard_param (bool, optional) – Whether the param is sharded after exiting the context. Defaults to False.
default_dtype (torch.dtype, optional) – If not None, parameters are first initialized as default_dtype and then converted to fp16. Defaults to None.
model_numel_tensor (torch.Tensor, optional) – A tensor that will store the number of elements in the model. Defaults to torch.zeros(1, dtype=torch.int).
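A minimal usage sketch: initializing a model inside the context so its parameters are created as fp16 ShardedParameter instances and sharded as they are built, rather than materializing the full fp32 model on one device. The TensorShardStrategy import from colossalai.zero.shard_utils is assumed here; any BaseShardStrategy subclass should work.

```python
import torch
import torch.nn as nn

from colossalai.zero.init_ctx.init_context import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy  # assumed strategy implementation

# Will accumulate the number of elements of the model during initialization.
model_numel = torch.zeros(1, dtype=torch.int)

with ZeroInitContext(target_device=torch.device('cuda', 0),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     model_numel_tensor=model_numel):
    # Params are converted to fp16 and sharded as each layer is created.
    model = nn.Sequential(
        nn.Linear(1024, 4096),
        nn.GELU(),
        nn.Linear(4096, 1024),
    )

print(f'model has {model_numel.item()} elements')
```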