FullyShardedDataParallel

class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, cpu_offload=None, fsdp_auto_wrap_policy=None, backward_prefetch=None)

A wrapper for sharding Module parameters across data parallel workers. This is inspired by Xu et al. as well as the ZeRO Stage 3 from DeepSpeed. FullyShardedDataParallel is commonly shortened to FSDP.
Example:
>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> torch.cuda.set_device(device_id)
>>> sharded_module = FSDP(my_module)
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
>>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
>>> loss = x.sum()
>>> loss.backward()
>>> optim.step()
Warning
The optimizer must be initialized after the module has been wrapped, since FSDP will shard parameters in-place and this will break any previously initialized optimizers.
- Parameters
module (nn.Module) – module to be wrapped with FSDP.
process_group (Optional[ProcessGroup]) – process group for sharding
cpu_offload (Optional[CPUOffload]) – CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It can be enabled by passing in cpu_offload=CPUOffload(offload_params=True). Note that this currently implicitly enables gradient offloading to CPU as well, so that parameters and gradients are on the same device for the optimizer. This API is subject to change. Default is None, in which case there will be no offloading. See the sketch after this parameter list for a combined usage example.
fsdp_auto_wrap_policy (Optional[callable]) – A callable specifying a policy to recursively wrap layers with FSDP. Note that this policy currently only applies to child modules of the passed-in module. The remaining modules are always wrapped in the returned FSDP root instance. default_auto_wrap_policy, written in torch.distributed.fsdp.wrap, is an example of an fsdp_auto_wrap_policy callable; this policy wraps layers with parameter sizes larger than 100M. Users can supply a customized fsdp_auto_wrap_policy callable that should accept the following arguments: module: nn.Module, recurse: bool, unwrapped_params: int. Extra customized arguments can be added to the customized fsdp_auto_wrap_policy callable as well.

Example:
>>> def custom_auto_wrap_policy(
>>>     module: nn.Module,
>>>     recurse: bool,
>>>     unwrapped_params: int,
>>>     # These are customizable for this policy function.
>>>     min_num_params: int = int(1e8),
>>> ) -> bool:
>>>     return unwrapped_params >= min_num_params
backward_prefetch (Optional[BackwardPrefetch]) – This is an experimental feature that is subject to change in the near future. It allows users to enable two different backward prefetch algorithms to help overlap backward communication and computation. The pros and cons of each algorithm are explained in the BackwardPrefetch class.
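
Example (a minimal sketch combining the options above; it assumes CPUOffload and BackwardPrefetch can be imported from torch.distributed.fsdp and reuses the custom_auto_wrap_policy defined earlier, so treat the exact import paths and names as assumptions rather than a definitive recipe):

>>> import functools
>>> import torch
>>> from torch.distributed.fsdp import (
>>>     FullyShardedDataParallel as FSDP,
>>>     CPUOffload,
>>>     BackwardPrefetch,
>>> )
>>> # Bind a custom threshold via functools.partial so the remaining signature
>>> # still matches (module, recurse, unwrapped_params).
>>> my_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
>>> sharded_module = FSDP(
>>>     my_module,
>>>     cpu_offload=CPUOffload(offload_params=True),
>>>     fsdp_auto_wrap_policy=my_policy,
>>>     backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
>>> )
>>> # As in the first example, the optimizer is created only after wrapping.
>>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)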
property module

Make model.module accessible, just like DDP.
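
For instance, a minimal sketch reusing sharded_module from the example above (the exact type returned by this property is an internal wrapper detail; only its being an nn.Module is assumed here):

>>> # Access the wrapped module through the property, analogous to DDP's model.module.
>>> inner = sharded_module.module
>>> assert isinstance(inner, torch.nn.Module)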