class colossalai.zero.init_ctx.ZeroInitContext(target_device, shard_strategy, seed=1023, shard_param=False, default_dtype=None, model_numel_tensor=tensor([0]))

A context to initialize the model.

  1. Convert the model to fp16.

  2. The parameters of the module are adapted to type ShardedParameter.

  3. Shard the param and grad according to flags.

Parameters

  • target_device (torch.device) – The device where parameter data reside after exiting the context.

  • shard_strategy (BaseShardStrategy) – Shard strategy instance.

  • seed (int, optional) – Random seed for weight initialization. Defaults to 1023.

  • shard_param (bool, optional) – Whether the parameters are sharded after exiting the context. Defaults to False.

  • default_dtype (torch.dtype, optional) – If not None, parameters are first initialized in default_dtype and then converted to fp16.

  • model_numel_tensor (torch.Tensor, optional) – A tensor that stores the number of elements of the model. Defaults to torch.zeros(1, dtype=torch.int).
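Below is a minimal usage sketch, assuming this is ColossalAI's ZeroInitContext and that TensorShardStrategy is importable from colossalai.zero.shard_utils in the same library version; the small nn.Sequential model is an illustrative placeholder.

   import torch
   import torch.nn as nn
   from colossalai.zero.init_ctx import ZeroInitContext
   from colossalai.zero.shard_utils import TensorShardStrategy

   # Counter tensor that the context fills with the model's element count.
   model_numel = torch.zeros(1, dtype=torch.int)

   with ZeroInitContext(target_device=torch.device('cuda:0'),
                        shard_strategy=TensorShardStrategy(),
                        shard_param=True,  # shard params when the context exits
                        model_numel_tensor=model_numel):
       # Parameters created here are converted to fp16 ShardedParameter
       # and sharded across ranks according to the strategy.
       model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 10))

   print(f'model has {model_numel.item()} elements')

Constructing the model inside the context lets each parameter be converted and sharded as it is created, so the full model never needs to fit in one device's memory at once.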

static calc_fanin_fanout(tensor)

We use this function to substitute the fan-in and fan-out calculation in torch.nn.init, so that the correct fan-in and fan-out are obtained for a sharded tensor.
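A conceptual sketch of the idea, not the library's actual implementation: compute fan-in and fan-out from the parameter's pre-sharding shape rather than from the local shard, which may be a flattened slice. The origin_shape attribute used below is an assumed stand-in for however the sharded tensor records its original shape.

   import torch

   def calc_fanin_fanout(tensor: torch.Tensor):
       # A sharded tensor may hold only a flattened local slice, so reading
       # dimensions off `tensor` directly (as torch.nn.init does) would give
       # wrong fan values. Prefer the recorded pre-sharding shape.
       shape = getattr(tensor, 'origin_shape', tensor.shape)  # assumed attribute
       if len(shape) < 2:
           raise ValueError('fan-in/fan-out requires at least 2 dimensions')
       receptive_field = 1
       for s in shape[2:]:  # spatial dims of conv-style weights
           receptive_field *= s
       fan_in = shape[1] * receptive_field
       fan_out = shape[0] * receptive_field
       return fan_in, fan_out

This mirrors the convention of torch.nn.init._calculate_fan_in_and_fan_out: dimension 1 counts input feature maps, dimension 0 counts output feature maps, and any remaining dimensions form the receptive field.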