colossalai.tensor
- class colossalai.tensor.ColoTensor(data, spec)[source]
Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
A ColoTensor can be initialized from a PyTorch tensor in the following ways.
>>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2, 3), spec=ColoTensorSpec(pg, ReplicaSpec()))
>>> # The tensor passed in is a local tensor after sharding, not a global tensor.
>>> tp_pg = ProcessGroup(tp_degree=world_size)
>>> shard_spec = ShardSpec(dims=[0], num_partitions=[world_size])
>>> tensor_spec = ColoTensorSpec(tp_pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
- Parameters:
data (torch.Tensor) – a torch tensor used as the payload of the ColoTensor.
spec (ColoTensorSpec, optional) – the tensor spec used for initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
- set_process_group(pg)[source]
Change the process group of the ColoTensor. Note that the valid use cases are limited: it only works when the target pg is DP- or TP-only and the current dist spec of the tensor is Replicate.
- Parameters:
pg (ProcessGroup) – target pg
- set_dist_spec(dist_spec)[source]
Set the dist spec and change the payload accordingly.
- Parameters:
dist_spec (_DistSpec) – target dist spec.
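Example (a minimal sketch, not from the library docs, assuming torch.distributed has already been initialized with 2 ranks):
>>> import torch
>>> from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec
>>> pg = ProcessGroup(tp_degree=2)
>>> t = ColoTensor(torch.randn(4, 4), spec=ColoTensorSpec(pg, ReplicaSpec()))
>>> # Re-shard the replicated payload along dim 0 across the 2-way TP process group.
>>> t.set_dist_spec(ShardSpec(dims=[0], num_partitions=[2]))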
- redistribute(dist_spec, pg=None)[source]
Redistribute the tensor among processes. The rules are as follows:
1. If pg is None, redistribute the tensor payload among the TP process group and keep the DP process group unchanged.
2. If pg is not None and differs from the current process group: first, replicate the tensor among the current TP process group; second, reset the process group to the new pg; third, convert the tensor (now replicated among the new TP process group) to the new dist_spec.
- Parameters:
dist_spec (_DistSpec) – the new dist spec.
pg (Optional[ProcessGroup], optional) – the new process group. Defaults to None.
- Returns:
a redistributed colotensor
- Return type:
ColoTensor
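Example (a minimal sketch illustrating both rules, assuming torch.distributed has been initialized with 4 ranks):
>>> import torch
>>> from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec
>>> pg = ProcessGroup(tp_degree=4)
>>> t = ColoTensor(torch.randn(8, 8), spec=ColoTensorSpec(pg, ReplicaSpec()))
>>> # Rule 1: pg is None, so only the dist spec changes within the current TP process group.
>>> sharded = t.redistribute(ShardSpec(dims=[0], num_partitions=[4]))
>>> # Rule 2: a new pg triggers replicate -> switch process group -> convert to the new dist spec.
>>> new_pg = ProcessGroup(tp_degree=2, dp_degree=2)
>>> resharded = t.redistribute(ShardSpec(dims=[-1], num_partitions=[2]), pg=new_pg)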
- static from_torch_tensor(tensor, spec=None)[source]
A static method that builds a ColoTensor from a PyTorch tensor.
- Parameters:
tensor (torch.Tensor) – the PyTorch tensor, which is a local tensor for this rank, not a global tensor.
spec (Optional[ColoTensorSpec], optional) – tensor spec. Defaults to None.
- Returns:
a ColoTensor
- Return type:
ColoTensor
- class colossalai.tensor.ComputeSpec(compute_pattern)[source]
The specification for the computation pattern.
- Parameters:
compute_pattern (ComputePattern) – an Enum instance for compute pattern.
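Example (a minimal sketch; ComputePattern.TP1D is assumed to be the 1-D tensor-parallel pattern exported by colossalai.tensor):
>>> from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
>>> pg = ProcessGroup(tp_degree=2)
>>> spec = ColoTensorSpec(pg,
>>>                       dist_attr=ShardSpec(dims=[-1], num_partitions=[2]),
>>>                       compute_attr=ComputeSpec(ComputePattern.TP1D))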
- colossalai.tensor.named_params_with_colotensor(module, prefix='', recurse=True)[source]
Returns an iterator over module parameters (together with the ColoTensor parameters), yielding both the name of the parameter as well as the parameter itself. This is typically passed to a torchshard._shard.sharded_optim.ShardedOptimizer.
- Parameters:
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields:
(string, Union[Tensor, ColoTensor]) – a tuple containing the name and the parameter (or ColoTensor parameter).
Example
>>> model = torch.nn.Linear(*linear_size)
>>> delattr(model, 'weight')
>>> setattr(model, 'weight', ColoTensor(...))
>>> for name, param in named_params_with_colotensor(model):
>>>     if name in ['weight']:
>>>         print(param.size())
- class colossalai.tensor.ColoParameter(data=None, requires_grad=True, spec=None)[source]
A kind of ColoTensor to be considered as a module parameter.
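Example (a minimal sketch of building a replicated module parameter, assuming torch.distributed has been initialized):
>>> import torch
>>> from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup, ReplicaSpec
>>> pg = ProcessGroup()
>>> weight = ColoParameter(torch.randn(16, 16), requires_grad=True,
>>>                        spec=ColoTensorSpec(pg, ReplicaSpec()))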
- class colossalai.tensor.ColoParamOpHook[source]
Hook which is triggered by each operation whose operands contain a ColoParameter. To customize it, you must inherit from this abstract class and implement pre_forward, post_forward, pre_backward and post_backward. Each of these four methods takes a list of ColoParameter as its input argument.
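Example (a minimal sketch of a custom hook; the PrintingHook name and its printing behaviour are purely illustrative):
>>> from typing import List
>>> from colossalai.tensor import ColoParameter, ColoParamOpHook
>>> class PrintingHook(ColoParamOpHook):
>>>     # Each method receives the list of ColoParameters touched by the op.
>>>     def pre_forward(self, params: List[ColoParameter]) -> None:
>>>         print(f'forward op touches {len(params)} ColoParameter(s)')
>>>     def post_forward(self, params: List[ColoParameter]) -> None:
>>>         pass
>>>     def pre_backward(self, params: List[ColoParameter]) -> None:
>>>         print(f'backward op touches {len(params)} ColoParameter(s)')
>>>     def post_backward(self, params: List[ColoParameter]) -> None:
>>>         pass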
- class colossalai.tensor.ColoParamOpHookManager[source]
Manage your param op hooks. It only has static methods, and the only static method you should call is use_hooks(*hooks).
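Example (a minimal sketch, assuming use_hooks acts as a context manager that installs the hooks for the ops executed inside it; PrintingHook is the illustrative hook sketched above and torch.distributed is assumed to be initialized):
>>> import torch
>>> from colossalai.tensor import ColoParameter, ColoParamOpHookManager, ColoTensorSpec, ProcessGroup
>>> pg = ProcessGroup()
>>> w = ColoParameter(torch.randn(4, 4), spec=ColoTensorSpec(pg))
>>> x = torch.randn(2, 4)
>>> with ColoParamOpHookManager.use_hooks(PrintingHook()):
>>>     # The matmul has a ColoParameter operand, so the hook's pre/post methods fire.
>>>     y = torch.matmul(x, w)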
- class colossalai.tensor.ProcessGroup(rank=None, ranks=None, tp_degree=None, dp_degree=None)[source]
Process Group indicates how processes are organized in groups for parallel execution using Tensor Parallelism and Data Parallelism.
NOTE: the ProcessGroup must be used after torch.distributed has been initialized.
- Parameters:
rank – the global rank of the current process.
ranks – List[int], a list of rank ids belonging to this process group.
backend – str, the backend of the process group.
tp_degree – Optional[int], tensor parallelism degree, i.e. how many processes are inside a tp process group. Default None means 1.
dp_degree – Optional[int], data parallelism degree, i.e. how many processes are inside a dp process group. Default None means len(ranks).
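Example (a minimal sketch, assuming torch.distributed has been initialized with 8 ranks, split here into 2-way TP x 4-way DP):
>>> from colossalai.tensor import ProcessGroup
>>> pg = ProcessGroup(tp_degree=2, dp_degree=4)
>>> pg.rank(), pg.world_size()              # global rank and global world size
>>> pg.tp_local_rank(), pg.tp_world_size()  # position inside the TP process group
>>> pg.dp_local_rank(), pg.dp_world_size()  # position inside the DP process group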
- property has_cpu_groups
Whether the cpu process groups have been initialized.
- Returns:
cpu process groups have been initialized or not.
- Return type:
bool
- rank()[source]
The current rank in the global process group.
- Returns:
the rank number
- Return type:
int
- ranks_in_group()[source]
A list of rank numbers in the global process group.
- Returns:
a list of rank numbers.
- Return type:
List[int]
- world_size()[source]
The world size of the global process group.
- Returns:
world size
- Return type:
int
- tp_rank_list()[source]
the rank list in the TP process group containing the current rank.
- Returns:
the list of rank numbers.
- Return type:
List[int]
- dp_rank_list()[source]
the rank list in the DP process group containing the current rank.
- Returns:
the list of rank numbers.
- Return type:
List[int]
- tp_local_rank()[source]
The local rank number in the current TP process group.
- Returns:
tp rank number.
- Return type:
int
- dp_local_rank()[source]
The local rank number in the current DP process group.
- Returns:
dp rank number.
- Return type:
int
- dp_world_size()[source]
The world size of the current DP process group.
- Returns:
dp world size
- Return type:
int
- tp_world_size()[source]
The world size of the current TP process group.
- Returns:
tp world size
- Return type:
int
- dp_process_group()[source]
the pytorch DP process group containing the current rank.
- Returns:
the pytorch DP process group.
- Return type:
torch._C._distributed_c10d.ProcessGroup
- tp_process_group()[source]
the pytorch TP process group containing the current rank.
- Returns:
the pytorch TP process group.
- Return type:
torch._C._distributed_c10d.ProcessGroup
- cpu_dp_process_group()[source]
the pytorch CPU DP process group containing the current rank.
An assertion fails if the cpu process group is not initialized.
- Returns:
the pytorch DP process group.
- Return type:
torch._C._distributed_c10d.ProcessGroup
- cpu_tp_process_group()[source]
the pytorch CPU TP process group containing the current rank.
An assertion fails if the cpu process group is not initialized.
- Returns:
the pytorch TP process group.
- Return type:
torch._C._distributed_c10d.ProcessGroup
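Example (a minimal sketch: the returned objects are ordinary torch.distributed process groups, so they can be passed to any collective; assumes 4 initialized ranks and a CPU-capable backend such as gloo):
>>> import torch
>>> import torch.distributed as dist
>>> from colossalai.tensor import ProcessGroup
>>> pg = ProcessGroup(tp_degree=2, dp_degree=2)
>>> grad = torch.ones(4)
>>> # Average a gradient across the data-parallel replicas only.
>>> dist.all_reduce(grad, group=pg.dp_process_group())
>>> grad /= pg.dp_world_size()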
- class colossalai.tensor.ColoTensorSpec(pg, dist_attr=DistSpec(placement=DistPlacementPattern.REPLICATE), compute_attr=None)[source]
A data class for specifications of the ColoTensor. It contains attributes of ProcessGroup, _DistSpec, and ComputeSpec. The latter two attributes are optional; if not set, they default to a replicated dist spec and None, respectively.
- colossalai.tensor.ShardSpec(dims, num_partitions)[source]
A distributed specification indicating that the tensor is sharded among the tensor parallel process group.
Note
Currently, sharding along only one dimension is valid. In other words, dims should be of size 1.
- Parameters:
dims (List[int]) – a list of dimensions
num_partitions (List[int]) – a list of the number of partitions along each dimension.
- Returns:
a shard dist spec instance.
- Return type:
_DistSpec
- colossalai.tensor.ReplicaSpec()[source]
A distributed specification indicating that the tensor is replicated among the tensor parallel process group.
- Returns:
a replicated dist spec instance.
- Return type:
_DistSpec
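Example (a minimal sketch contrasting the two dist specs, assuming 4 initialized ranks):
>>> from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec
>>> pg = ProcessGroup(tp_degree=4)
>>> # Replicated layout: every rank in the TP group holds the full tensor.
>>> replica_spec = ColoTensorSpec(pg, ReplicaSpec())
>>> # Sharded layout: dim 0 is split into 4 partitions across the TP group.
>>> shard_spec = ColoTensorSpec(pg, ShardSpec(dims=[0], num_partitions=[4]))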
- class colossalai.tensor.CommSpec(comm_pattern, sharding_spec, gather_dim=None, shard_dim=None, logical_process_axis=None, forward_only=False, mix_gather=False)[source]
Communication spec is used to record the communication action. It has two main functions:
1. Compute the communication cost, which will be used in the auto parallel solver.
2. Convert the communication spec to a real action, which will be used at runtime.
It contains comm_pattern to determine the communication method, sharding_spec to determine the communication size, gather_dim and shard_dim to determine the buffer shape, and logical_process_axis to determine the device mesh dimension along which the communication happens.
- Argument:
comm_pattern (CollectiveCommPattern): describes the communication method used in this spec.
sharding_spec (ShardingSpec): the sharding spec of the tensor that will join the communication action.
gather_dim (int, optional): the dimension along which the tensor will be gathered.
shard_dim (int, optional): the dimension along which the tensor will be sharded.
logical_process_axis (Union[int, List[int]], optional): the mesh dimension used to implement the communication action.
- get_comm_cost()[source]
For the all_gather, all2all, and all_reduce operations, the alpha-beta model formula provided in DeviceMesh is used to compute the communication cost. The shard operation is an on-chip operation, so its communication cost is zero.
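The exact formulas live in DeviceMesh; the standalone sketch below only illustrates the shape of an alpha-beta cost model (alpha: per-message latency, beta: transfer time per byte) and is not the library's implementation:
>>> def all_gather_cost(num_bytes, n, alpha, beta):
>>>     # Ring all-gather moves (n - 1)/n of the full buffer through each link.
>>>     return alpha + (n - 1) / n * num_bytes * beta
>>> def all_to_all_cost(num_bytes, n, alpha, beta):
>>>     # Each rank exchanges 1/n of its buffer with every other rank.
>>>     return alpha + (n - 1) / n * num_bytes * beta
>>> def all_reduce_cost(num_bytes, n, alpha, beta):
>>>     # Ring all-reduce is roughly a reduce-scatter followed by an all-gather.
>>>     return alpha + 2 * (n - 1) / n * num_bytes * beta
>>> shard_cost = 0.0   # shard is an on-chip slicing operation, so no communication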
- covert_spec_to_action(tensor)[source]
Convert the CommSpec into a runtime action and perform the actual collective communication on the target tensor. The collective communication action is directed by the CommSpec.
- Argument:
tensor (torch.Tensor): the tensor stored on each device, which may differ across ranks.
- colossalai.tensor.convert_dim_partition_dict(dim_size, dim_partition_dict)[source]
This method converts negative dim values in dim_partition_dict to their positive counterparts.
- colossalai.tensor.merge_same_dim_mesh_list(dim_size, dim_partition_dict)[source]
This method merges different keys that point to the same physical position.
- For example:
For a 2-D tensor, dim 1 and dim -1 point to the same physical position, so a dim_partition_dict such as {1: [0], -1: [1]} will be converted to {1: [0, 1]}.
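Example (a minimal sketch based on the behaviour described above; the expected outputs follow the docstring example rather than a verified run):
>>> from colossalai.tensor import convert_dim_partition_dict, merge_same_dim_mesh_list
>>> # dim 1 and dim -1 address the same axis of a 2-D tensor.
>>> convert_dim_partition_dict(2, {-1: [1]})        # expected: {1: [1]}
>>> merge_same_dim_mesh_list(2, {1: [0], -1: [1]})  # expected: {1: [0, 1]}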
- colossalai.tensor.colo_parameter
- colossalai.tensor.colo_tensor
- colossalai.tensor.compute_spec
- colossalai.tensor.const
- colossalai.tensor.dist_spec_mgr
- colossalai.tensor.distspec
- colossalai.tensor.op_wrapper
- colossalai.tensor.param_op_hook
- colossalai.tensor.process_group
- colossalai.tensor.tensor_spec
- colossalai.tensor.utils