class colossalai.zero.init_ctx.ZeroInitContext(target_device, shard_strategy, seed=1023, shard_param=False, default_dtype=None, model_numel_tensor=tensor([0]))

A context to initialize the model.

  1. Convert the model to fp16.

  2. The parameters of the module are adapted to type ShardedParameter.

  3. Shard the param and grad according to flags.

Parameters

  • target_device (torch.device) – The device where parameter data reside after exiting the context.

  • shard_strategy (BaseShardStrategy) – Shard strategy instance.

  • seed (int, optional) – Random seed for weight initialization. Defaults to 1023.

  • shard_param (bool, optional) – Whether the parameters are sharded after exiting the context. Defaults to False.

  • default_dtype (torch.dtype, optional) – If not None, parameters are first initialized in default_dtype and then converted to fp16.

  • model_numel_tensor (torch.Tensor, optional) – A tensor that stores the number of elements of the model. Defaults to torch.zeros(1, dtype=torch.int).
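Below is a minimal usage sketch, assuming this is ColossalAI's ZeroInitContext and that TensorShardStrategy is importable from colossalai.zero.shard_utils in the same library version; the small nn.Sequential model is an illustrative placeholder.

   import torch
   import torch.nn as nn
   from colossalai.zero.init_ctx import ZeroInitContext
   from colossalai.zero.shard_utils import TensorShardStrategy

   # Counter tensor that the context fills with the model's element count.
   model_numel = torch.zeros(1, dtype=torch.int)

   with ZeroInitContext(target_device=torch.device('cuda:0'),
                        shard_strategy=TensorShardStrategy(),
                        shard_param=True,  # shard params when the context exits
                        model_numel_tensor=model_numel):
       # Parameters created here are converted to fp16 ShardedParameter
       # and sharded across ranks according to the strategy.
       model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 10))

   print(f'model has {model_numel.item()} elements')

Constructing the model inside the context lets each parameter be converted and sharded as it is created, so the full model never needs to fit in one device's memory at once.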

static calc_fanin_fanout(tensor)

We use this function to substitute the fan-in and fan-out calculation in torch.nn.init, so that the correct fan-in and fan-out are obtained for a sharded tensor.
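A conceptual sketch of the idea, not the library's actual implementation: compute fan-in and fan-out from the parameter's pre-sharding shape rather than from the local shard, which may be a flattened slice. The origin_shape attribute used below is an assumed stand-in for however the sharded tensor records its original shape.

   import torch

   def calc_fanin_fanout(tensor: torch.Tensor):
       # A sharded tensor may hold only a flattened local slice, so reading
       # dimensions off `tensor` directly (as torch.nn.init does) would give
       # wrong fan values. Prefer the recorded pre-sharding shape.
       shape = getattr(tensor, 'origin_shape', tensor.shape)  # assumed attribute
       if len(shape) < 2:
           raise ValueError('fan-in/fan-out requires at least 2 dimensions')
       receptive_field = 1
       for s in shape[2:]:  # spatial dims of conv-style weights
           receptive_field *= s
       fan_in = shape[1] * receptive_field
       fan_out = shape[0] * receptive_field
       return fan_in, fan_out

This mirrors the convention of torch.nn.init._calculate_fan_in_and_fan_out: dimension 1 counts input feature maps, dimension 0 counts output feature maps, and any remaining dimensions form the receptive field.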