XLAFSDPStrategy¶
- class lightning.fabric.strategies.XLAFSDPStrategy(accelerator=None, parallel_devices=None, checkpoint_io=None, precision=None, auto_wrap_policy=None, activation_checkpointing_policy=None, state_dict_type='sharded', sequential_save=False, **kwargs)[source]¶
- Bases: ParallelStrategy, _Sharded
- Strategy for training across multiple XLA devices using the torch_xla.distributed.xla_fully_sharded_data_parallel.XlaFullyShardedDataParallel() method.
- Warning
- This is an experimental feature.
- For more information, check out https://github.com/pytorch/xla/blob/v2.5.0/docs/fsdp.md
- Parameters:
- auto_wrap_policy¶ (Union[set[type[Module]], Callable[[Module, bool, int], bool], None]) – Same as the auto_wrap_policy parameter in torch_xla.distributed.fsdp.XlaFullyShardedDataParallel. For convenience, this also accepts a set of the layer classes to wrap (see the configuration sketch after this parameter list).
- activation_checkpointing_policy¶ (Optional[set[type[Module]]]) – Selects the modules for which activation checkpointing is enabled. Enabling this can free up a significant amount of memory at the cost of speed, since activations in these layers need to be recomputed during backpropagation. This accepts a set of the layer classes to wrap.
- state_dict_type¶ (Literal['full', 'sharded']) – The format in which the state of the model and optimizers gets saved into the checkpoint.
  - "full": The full weights and optimizer states get assembled on rank 0 and saved to a single file.
  - "sharded": Each rank saves its shard of weights and optimizer states to a file. The checkpoint is a folder with one file per shard on the host. Note that TPU VM multi-host setups do not have a shared filesystem.
 
- sequential_save¶ (bool) – With this enabled, individual ranks save their state-dictionary shards one after another, which reduces peak system RAM usage but lengthens the saving process.
- **kwargs¶ (Any) – See the available parameters in torch_xla.distributed.fsdp.XlaFullyShardedDataParallel.
 
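Below is a minimal configuration sketch of how these options are typically combined with Fabric. The Block class, the accelerator choice, and the checkpoint settings are illustrative assumptions rather than part of this reference.

```python
import torch.nn as nn

from lightning.fabric import Fabric
from lightning.fabric.strategies import XLAFSDPStrategy


class Block(nn.Module):
    # Hypothetical transformer block, used only to illustrate the wrapping policies.
    ...


strategy = XLAFSDPStrategy(
    # Wrap every Block submodule in its own XlaFullyShardedDataParallel instance.
    auto_wrap_policy={Block},
    # Recompute activations of Block layers during backward to save memory.
    activation_checkpointing_policy={Block},
    # Each rank writes its own shard of the checkpoint (the default).
    state_dict_type="sharded",
)

fabric = Fabric(accelerator="tpu", strategy=strategy)
fabric.launch()
```

Later examples in this section reuse this fabric object, together with a model and optimizer assumed to have been set up through it.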
 - all_gather(tensor, group=None, sync_grads=False)[source]¶
- Function to gather a tensor from several distributed processes. 
 - all_reduce(output, group=None, reduce_op=None)[source]¶
- Reduces the given tensor (e.g. across GPUs/processes). 
 - barrier(name=None, *args, **kwargs)[source]¶
- Synchronizes all processes, blocking each process until the whole group enters this function.
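The collectives above are normally reached through the Fabric front end rather than by calling the strategy directly. A hedged sketch, where fabric is the object from the configuration sketch earlier:

```python
import torch

fabric.barrier()  # block until every process reaches this point

# Gather one value per process into a stacked tensor.
ranks = fabric.all_gather(torch.tensor(fabric.global_rank))

# Average a scalar metric across processes.
avg_loss = fabric.all_reduce(torch.tensor(0.5), reduce_op="mean")
```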
 - clip_gradients_norm(module, optimizer, max_norm, norm_type=2.0, error_if_nonfinite=True)[source]¶
- Clip gradients by norm.
- Return type:
- Tensor
 
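In typical Fabric code, gradient clipping goes through fabric.clip_gradients(), which dispatches to this strategy method. A hedged sketch; model, optimizer, and loss stand for objects already set up with Fabric:

```python
fabric.backward(loss)
# Clip the gradient norm of the wrapped model's parameters to at most 1.0.
total_norm = fabric.clip_gradients(model, optimizer, max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```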
 - load_checkpoint(path, state=None, strict=True)[source]¶
- Given a folder, load the contents from a checkpoint and restore the state of the given objects.
- The strategy currently only supports saving and loading sharded checkpoints, which are stored as a directory of multiple files rather than a single file.
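A hedged sketch of restoring a previously saved sharded checkpoint through fabric.load(), which forwards to this method. The "my-checkpoint" directory name and the extra "step" entry are illustrative:

```python
state = {"model": model, "optimizer": optimizer, "step": 0}
# Objects in `state` are restored in place; any leftover entries are returned.
remainder = fabric.load("my-checkpoint", state)
```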
 - module_init_context(empty_init=None)[source]¶
- A context manager wrapping the model instantiation.
- Here, the strategy can control how the parameters of the model get created (device, dtype) and/or apply other patches to the model.
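In user code this context is typically entered via fabric.init_module(). A hedged sketch, where MyModel stands in for any nn.Module:

```python
# Parameters are created with the device and dtype chosen by the strategy.
with fabric.init_module(empty_init=False):
    model = MyModel()
```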
 - module_sharded_context()[source]¶
- A context manager that wraps the instantiation of a torch.nn.Module and handles sharding of parameters on creation.
- By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time.
- Return type:
 
 - optimizer_step(optimizer, **kwargs)[source]¶
- Overrides the default TPU optimizer_step, since FSDP should not call torch_xla.core.xla_model.optimizer_step(). Performs the actual optimizer step.
 - process_dataloader(dataloader)[source]¶
- Wraps the dataloader if necessary.
- Parameters:
- dataloader¶ (DataLoader) – iterable, ideally of type torch.utils.data.DataLoader
- Return type:
- MpDeviceLoader
 
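Fabric calls process_dataloader() when a dataloader is registered through its setup step. A hedged sketch, with dataset as a placeholder for any map-style dataset:

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset, batch_size=8)
# The returned loader is wrapped in an MpDeviceLoader that moves batches to the XLA device.
train_loader = fabric.setup_dataloaders(train_loader)
```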
 - save_checkpoint(path, state, storage_options=None, filter=None)[source]¶
- Save model, optimizer, and other state in the provided checkpoint directory.
- If the user specifies sharded checkpointing, the directory will contain one file per process, with model and optimizer shards stored per file. If the user specifies full checkpointing, the directory will contain a consolidated checkpoint combining all of the sharded checkpoints.
- Return type:
- None
 
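A hedged sketch of saving through fabric.save(), which forwards to this method; every process calls it, and the directory name is illustrative:

```python
state = {"model": model, "optimizer": optimizer, "step": 100}
# With the default state_dict_type="sharded", each rank writes its own shard into the directory.
fabric.save("my-checkpoint", state)
```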
 - setup_environment()[source]¶
- Set up any processes or distributed connections.
- This must be called by the framework at the beginning of every process, before any distributed communication takes place.
- Return type:
- None
 
 - setup_module(module)[source]¶
- Performs setup for the model, e.g., by wrapping it with another class.
- Return type:
- Module
 
 - setup_module_and_optimizers(module, optimizers, scheduler=None)[source]¶
- Raises NotImplementedError, since for XLAFSDP the optimizer setup must happen after the module setup (see the ordering sketch at the end of this section).
 - setup_optimizer(optimizer)[source]¶
- Set up an optimizer for a model wrapped with XLAFSDP.
- This setup method doesn't modify or wrap the optimizer. It currently only verifies that the optimizer was created after the model was wrapped with setup_module() and holds a reference to the flattened parameters.
- Return type:
- Optimizer
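Since setup_module_and_optimizers() is not supported, the module and the optimizer have to be set up in two steps. A hedged sketch of the required ordering, reusing the fabric object and a model from the earlier sketches:

```python
import torch

# Wrap the model first; its parameters get flattened and sharded here.
model = fabric.setup_module(model)

# Create the optimizer only after wrapping, so it references the flattened parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
optimizer = fabric.setup_optimizers(optimizer)
```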