ModelParallelStrategy
- class lightning.fabric.strategies.ModelParallelStrategy(parallelize_fn, data_parallel_size='auto', tensor_parallel_size='auto', save_distributed_checkpoint=True, process_group_backend=None, timeout=datetime.timedelta(seconds=1800))
- Bases: ParallelStrategy

  Enables user-defined parallelism applied to a model.

  Warning: This is an experimental feature.

  Currently supports up to 2D parallelism. Specifically, it supports the combination of Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are themselves still experimental. Requires PyTorch 2.4 or newer.

  Parameters:
  - parallelize_fn (Callable[[TModel, DeviceMesh], TModel], where TModel is bound to Module) – A function that applies parallelisms to a module. The strategy will provide the model and device mesh as input (a usage sketch follows this parameter list).
  - data_parallel_size (Union[Literal['auto'], int]) – The number of devices within a data-parallel group. Defaults to "auto", which sets this size to the number of nodes in the cluster.
  - tensor_parallel_size (Union[Literal['auto'], int]) – The number of devices within a tensor-parallel group. Defaults to "auto", which sets this size to the number of GPUs in a single node.
  - save_distributed_checkpoint (bool) – If True, each rank saves its shard of weights and optimizer states to a file. The checkpoint is a folder with as many files as the world size. If False, the full weights and optimizer states get assembled on rank 0 and saved to a single file.
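A minimal usage sketch, not an official example: it assumes a multi-GPU CUDA machine, that the device mesh exposes a data-parallel sub-mesh under the name "data_parallel", and that the FSDP2 fully_shard API from PyTorch 2.4 is available. Adapt the parallelization logic to your own model.

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy


def parallelize(model: nn.Module, device_mesh) -> nn.Module:
    # Apply user-defined parallelism here. As a simple example, shard the
    # model with FSDP2 over the data-parallel sub-mesh. Tensor parallelism
    # could additionally be applied over the tensor-parallel sub-mesh.
    fully_shard(model, mesh=device_mesh["data_parallel"])
    return model


strategy = ModelParallelStrategy(parallelize_fn=parallelize)
fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)
fabric.launch()

with fabric.init_module(empty_init=True):
    model = nn.Linear(128, 128)  # placeholder model

model = fabric.setup(model)
```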
 
- all_reduce(tensor, group=None, reduce_op='mean')

  Reduces the given tensor (e.g. across GPUs/processes).
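  For illustration, a hypothetical direct call, assuming the strategy has already been launched (user code would normally go through fabric.all_reduce instead):

  ```python
  import torch

  # Average a local scalar metric across all processes.
  local_loss = torch.tensor(0.25, device=strategy.root_device)
  global_loss = strategy.all_reduce(local_loss, reduce_op="mean")
  ```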
- barrier(*args, **kwargs)

  Synchronizes all processes, blocking each process until the whole group has entered this call.
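  A common pattern, shown here as a sketch with a hypothetical prepare_dataset helper:

  ```python
  # Let rank 0 do one-time work while all other ranks wait at the barrier.
  if strategy.global_rank == 0:
      prepare_dataset()  # hypothetical one-time preprocessing step
  strategy.barrier()
  ```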
- load_checkpoint(path, state=None, strict=True)

  Load the contents from a checkpoint and restore the state of the given objects.
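  A sketch of restoring state, assuming model and optimizer were set up with this strategy and the path points to a checkpoint written by save_checkpoint; the path and key names are illustrative:

  ```python
  state = {"model": model, "optimizer": optimizer, "step": 0}
  # Objects in `state` are restored in place; leftover checkpoint entries not
  # claimed by `state` are returned as a dictionary.
  remainder = strategy.load_checkpoint("checkpoints/step_1000", state=state, strict=True)
  ```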
- module_init_context(empty_init=None)

  A context manager wrapping the model instantiation.

  Here, the strategy can control how the parameters of the model get created (device, dtype) and/or apply other patches to the model.
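  A sketch with a hypothetical MyTransformer module; passing empty_init=True is an assumption about the desired behavior (delay weight materialization until the model is set up):

  ```python
  with strategy.module_init_context(empty_init=True):
      model = MyTransformer()  # hypothetical user-defined nn.Module
  ```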
- save_checkpoint(path, state, storage_options=None, filter=None)

  Save model, optimizer, and other state to a checkpoint on disk.

  If distributed checkpointing is enabled (the default), the checkpoint gets saved as a directory containing one file per process, with the model and optimizer shards stored per file. Additionally, a metadata file meta.pt with the rest of the user's state is created (saved only from rank 0). If distributed checkpointing is disabled (save_distributed_checkpoint=False), the checkpoint is written to a single file containing the weights, optimizer state, and other metadata.

  Return type: None
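  A sketch of writing a checkpoint, assuming model and optimizer were set up with this strategy; the path and extra keys are illustrative. With the default save_distributed_checkpoint=True, the path becomes a directory of per-rank files:

  ```python
  state = {"model": model, "optimizer": optimizer, "step": step}
  strategy.save_checkpoint("checkpoints/step_1000", state)
  ```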
 
- setup_environment()

  Set up any processes or distributed connections.

  This must be called by the framework at the beginning of every process, before any distributed communication takes place.

  Return type: None
 
- setup_module(module)

  Performs setup for the model, e.g., by wrapping it with another class.

  Return type: TModel