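Helper that buckets multiple reduce-scatter operations on small tensors into fewer, larger reduce-scatter operations. This improves communication efficiency because the fixed per-collective overhead is paid once per bucket instead of once per tensor.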
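The sketch below makes the idea concrete by simulating reduce-scatter with plain Python lists, with no torch.distributed involved. `simulated_reduce_scatter` and `concat_per_destination` are hypothetical helpers for illustration only. Each of `world_size` ranks contributes one chunk per destination rank, and rank r receives the element-wise sum of every rank's chunk r. Bucketing is valid because reduce-scattering concatenated chunks gives the same result as reduce-scattering each input separately and concatenating the outputs, so one large collective can replace several small ones.

```python
from typing import List

Vector = List[float]


def simulated_reduce_scatter(inputs_per_rank: List[List[Vector]]) -> List[Vector]:
    """Simulate a sum reduce-scatter.

    inputs_per_rank[src][dst] is the chunk that rank `src` contributes for rank `dst`.
    Returns outputs where outputs[dst] is the element-wise sum over all src.
    """
    world_size = len(inputs_per_rank)
    outputs = []
    for dst in range(world_size):
        chunks = [inputs_per_rank[src][dst] for src in range(world_size)]
        outputs.append([sum(values) for values in zip(*chunks)])
    return outputs


def concat_per_destination(a: List[List[Vector]], b: List[List[Vector]]) -> List[List[Vector]]:
    """Pack two reduce-scatter inputs into one 'bucket' by concatenating chunks per destination."""
    return [[a[src][dst] + b[src][dst] for dst in range(len(a[src]))] for src in range(len(a))]


# Two ranks, each with two small reduce-scatter inputs (indexed [src][dst]).
x = [[[1.0], [2.0]], [[10.0], [20.0]]]
y = [[[3.0, 4.0], [5.0, 6.0]], [[30.0, 40.0], [50.0, 60.0]]]

# Two separate collectives, outputs concatenated afterwards...
separate = [ox + oy for ox, oy in zip(simulated_reduce_scatter(x), simulated_reduce_scatter(y))]
# ...versus one collective on the packed bucket.
bucketed = simulated_reduce_scatter(concat_per_destination(x, y))

print(bucketed)               # [[11.0, 33.0, 44.0], [22.0, 55.0, 66.0]]
print(separate == bucketed)   # True
```

After the bucketed collective, each rank slices its output back into per-input pieces (here `[11.0]` and `[33.0, 44.0]` on rank 0) before handing them to the callbacks.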


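Usage:

bucketer = ReduceScatterBucketer()
# Assumes `group` is an initialized ProcessGroup and each *_tensors list holds group.size() tensors.
bucketer.reduce_scatter_async(small_tensors, group, callback_fn=lambda result: print("small"))
bucketer.reduce_scatter_async(big_tensors, group, callback_fn=lambda result: print("big"))
bucketer.reduce_scatter_async(more_small_tensors, group, callback_fn=lambda result: print("small2"))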
bucketer.flush()  # callbacks only guaranteed to be called after flush()
# Example output (note that it is out of order, due to bucketing):
# big
# small
# small2

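Parameters:

bucket_size_mb (int, Optional) – bucket size, in megabytes, used for communication. Buckets are subdivided by world_size, so each rank's shard of a bucket holds roughly bucket_size_mb / world_size MB. Values <= 0 disable bucketing.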
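One plausible reading of how buckets are subdivided by world_size is sketched below, assuming that bucket_size_mb counts mebibytes (1024 × 1024 bytes). The bucket's byte budget is converted to a number of elements using the dtype's element size and then split evenly across ranks. `shard_capacity_elements` is a hypothetical helper, not part of the documented API.

```python
def shard_capacity_elements(bucket_size_mb: int, element_size: int, world_size: int) -> int:
    """Approximate per-rank shard capacity, in elements, of one bucket.

    Assumes 1 MB = 1024 * 1024 bytes. Returns 0 when bucketing is disabled
    (bucket_size_mb <= 0), so every input counts as "large".
    """
    if bucket_size_mb <= 0:
        return 0
    bucket_elements = bucket_size_mb * 1024 * 1024 // element_size  # whole bucket, all ranks
    return bucket_elements // world_size  # one rank's share


# A 25 MB bucket of fp16 values (2 bytes each) shared by 8 ranks:
print(shard_capacity_elements(25, element_size=2, world_size=8))  # 1638400
print(shard_capacity_elements(0, element_size=4, world_size=8))   # 0
```

With a capacity of 0, every input satisfies numel >= capacity and would bypass the bucket, which is one simple way for non-positive values to disable bucketing.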

reduce_scatter_async(input_list, group, callback_fn=None)

Reduce-scatter a list of tensors asynchronously, so smaller reductions can be bucketed together. The given callback (callback_fn) will be called with the reduced result at some later time. Call flush() to force all queued ops and callbacks to be executed.

Note that inputs too large to fit in a bucket are reduced immediately instead of being queued, and this function may also flush the relevant bucket to make room for input_list.
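The queuing behavior described above can be modeled without any communication. `ToyBucketer` is a hypothetical, single-process stand-in, not the library class: each call passes a per-rank element count and a name instead of real tensors, and the callback receives the name in place of the reduced result. Two rules are assumptions: inputs with numel >= the shard capacity run immediately, and a bucket without enough room is flushed before the new input is queued. Under these rules the model reproduces the out-of-order output of the usage example.

```python
from typing import Callable, List, Tuple


class ToyBucketer:
    """Single-process model of bucketed reduce-scatter scheduling (no real communication)."""

    def __init__(self, shard_capacity: int) -> None:
        self.shard_capacity = shard_capacity  # max elements per rank in one bucket
        self.used = 0  # elements currently queued in the bucket
        self.pending: List[Tuple[str, Callable[[str], None]]] = []

    def reduce_scatter_async(self, name: str, numel: int, callback_fn: Callable[[str], None]) -> None:
        if numel >= self.shard_capacity:
            # Large input: bypass the bucket and "reduce" it right away.
            callback_fn(name)
            return
        if self.used + numel > self.shard_capacity:
            # Not enough room left: flush the bucket before queuing this input.
            self.flush()
        self.pending.append((name, callback_fn))
        self.used += numel

    def flush(self) -> None:
        # Stands in for one combined reduce-scatter; then run queued callbacks in arrival order.
        for name, callback_fn in self.pending:
            callback_fn(name)
        self.pending.clear()
        self.used = 0


log: List[str] = []
bucketer = ToyBucketer(shard_capacity=100)
bucketer.reduce_scatter_async("small", 10, log.append)
bucketer.reduce_scatter_async("big", 1000, log.append)
bucketer.reduce_scatter_async("small2", 10, log.append)
bucketer.flush()
print(log)  # ['big', 'small', 'small2']
```

Until flush() is called, "small" and "small2" sit in the bucket. This is why callbacks are only guaranteed to have run after flush().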

Parameters:

  • input_list (List[Tensor]) – list of tensors to reduce-scatter. The list should contain group.size() tensors, all with identical shape, dtype and device.

  • group (ProcessGroup) – process group for reduction

  • callback_fn (Callable, Optional) – callback function to call after the reduction executes. Function will be called with a single argument corresponding to the reduced result.
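A caller can check the input_list requirements before queuing a reduction. The helper below is a hypothetical sketch and not part of the documented API. It is duck-typed on the `shape`, `dtype` and `device` attributes that `torch.Tensor` exposes, and it takes the expected length as an argument. With a real process group, that argument would be `group.size()`.

```python
from collections import namedtuple
from typing import Any, Sequence


def check_reduce_scatter_inputs(input_list: Sequence[Any], group_size: int) -> None:
    """Raise ValueError unless input_list holds group_size tensors with matching shape, dtype and device."""
    if len(input_list) != group_size:
        raise ValueError(f"expected {group_size} tensors, got {len(input_list)}")
    if not input_list:
        return  # nothing to compare
    first = input_list[0]
    for i, tensor in enumerate(input_list[1:], start=1):
        for attr in ("shape", "dtype", "device"):
            expected, actual = getattr(first, attr), getattr(tensor, attr)
            if actual != expected:
                raise ValueError(f"tensor {i} has {attr}={actual!r}, expected {expected!r}")


# Stand-in objects for illustration; torch.Tensor exposes the same three attributes.
FakeTensor = namedtuple("FakeTensor", ["shape", "dtype", "device"])
ok = [FakeTensor((4, 4), "float16", "cuda:0"), FakeTensor((4, 4), "float16", "cuda:0")]
check_reduce_scatter_inputs(ok, group_size=2)  # passes silently

bad = [FakeTensor((4, 4), "float16", "cuda:0"), FakeTensor((4, 4), "float32", "cuda:0")]
try:
    check_reduce_scatter_inputs(bad, group_size=2)
except ValueError as err:
    print(err)  # tensor 1 has dtype='float32', expected 'float16'
```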


flush()

Reduce-scatter any partially filled buckets and run their queued callbacks.


teardown()

Free buffers from all buckets.