WeightAveraging¶
- class lightning.pytorch.callbacks.WeightAveraging(device=None, use_buffers=True, **kwargs)[source]¶
- Bases: `Callback`

A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) after each training step.

Arguments given to the constructor will be passed to the `AveragedModel` constructor. If no `device` is specified, the device of the original model will be used. Contrary to `AveragedModel`, `use_buffers` is set to `True` by default. That is, by default the callback computes running averages for both the parameters and the buffers of the model. Setting `use_buffers` to `False` will cause only the model parameters to be averaged, leaving the update of the batch normalization statistics to the user (using `torch.optim.swa_utils.update_bn()`).

You can provide a custom averaging function with the `avg_fn` or `multi_avg_fn` parameter. See the `AveragedModel` class for details. If no averaging function is provided, the default is to compute the equally weighted average of the weights (SWA).

You can customize when the average model is updated by overriding the `should_update()` method. The callback calls it with either `step_idx` or `epoch_idx`, and the method returns a boolean indicating whether to update after the given step or epoch. The default is to update after every step.

During validation and after training finishes, the current model parameters are replaced with the averaged values.

See also the documentation on the weight averaging callbacks provided by Lightning.

Note

To ensure that the `AveragedModel` will contain all layers, `setup()` calls `configure_model()` before instantiating the `AveragedModel`. However, because that hook is not called in a strategy-aware context, sharded models do not work with weight averaging, and a warning will be issued.

Example:

```python
from lightning.pytorch.callbacks import WeightAveraging
from torch.optim.swa_utils import get_ema_avg_fn

class EMAWeightAveraging(WeightAveraging):
    def __init__(self):
        super().__init__(avg_fn=get_ema_avg_fn())

    def should_update(self, step_idx=None, epoch_idx=None):
        # Start after 100 steps.
        return (step_idx is not None) and (step_idx >= 100)

trainer = Trainer(callbacks=EMAWeightAveraging(), max_epochs=10)
trainer.fit(model, dataloader)
```

- Parameters:
- device¶ (`Union[device, str, int, None]`) – By default, the `AveragedModel` will be stored on the same device as the original model. If the `device` argument is provided, the `AveragedModel` will be stored on this device instead. If you run out of GPU memory, you might want to use `"cpu"`.
- use_buffers¶ (`bool`) – If `False`, the buffers of the model will not be averaged.
- kwargs¶ (`Any`) – Additional keyword arguments to be passed to the `AveragedModel` constructor, such as `avg_fn` or `multi_avg_fn`.
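If you set `use_buffers=False`, the batch normalization statistics are not averaged and should be recomputed after training, once the averaged parameters are in place. A minimal sketch of that workflow, where `model` and `dataloader` stand in for your own `LightningModule` and training dataloader:

```python
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging

# Average only the parameters; leave the buffers (e.g. BatchNorm statistics) untouched.
trainer = Trainer(callbacks=WeightAveraging(use_buffers=False), max_epochs=10)
trainer.fit(model, dataloader)  # "model" and "dataloader" are placeholders

# After training, the averaged parameters have been copied into the model, but
# the BatchNorm statistics still reflect the raw weights. Recompute them with
# one pass over the training data.
torch.optim.swa_utils.update_bn(dataloader, model)
```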
 
 - load_state_dict(state_dict)[source]¶
- Called when loading a checkpoint. Reloads the callback state given a `state_dict`.
 - on_load_checkpoint(trainer, pl_module, checkpoint)[source]¶
- Called when loading a model checkpoint. Loads the current model and the `AveragedModel` parameters from the checkpoint.
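Resuming therefore works through the normal Lightning mechanism. A minimal sketch, where `model`, `dataloader`, and the path `averaged.ckpt` are placeholders for your own module, data, and a checkpoint written during an earlier run with this callback:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging

# Resuming restores both the current and the averaged weights from the checkpoint.
trainer = Trainer(callbacks=WeightAveraging(), max_epochs=20)
trainer.fit(model, dataloader, ckpt_path="averaged.ckpt")
```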
 - on_save_checkpoint(trainer, pl_module, checkpoint)[source]¶
- Called when saving a checkpoint. Moves the current model state to the key `current_model_state`, and places the average model state in `state_dict` instead. Any other state variables of the `AveragedModel` will be saved in `averaging_state`.
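Because the averaged weights are stored under the standard `state_dict` key, a checkpoint saved while this callback is active loads like any other Lightning checkpoint and yields the averaged model. A minimal sketch of reading both states back (the file name `averaged.ckpt` is a placeholder):

```python
import torch

checkpoint = torch.load("averaged.ckpt", map_location="cpu")

# Averaged parameters, under the usual key that Lightning loads from.
averaged_state = checkpoint["state_dict"]

# The raw (non-averaged) model state, preserved by this callback.
current_state = checkpoint["current_model_state"]

# Remaining AveragedModel state, such as the update counter.
averaging_state = checkpoint["averaging_state"]
```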
 - on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]¶
- Called when a training batch ends. Updates the `AveragedModel` parameters, if requested by `self.should_update()`.
- Parameters:
- trainer¶ (`Trainer`) – The current `Trainer` instance.
- pl_module¶ (`LightningModule`) – The current `LightningModule` instance.
- outputs¶ – Outputs of the training step.
- batch¶ – The current training batch.
- batch_idx¶ (`int`) – Index of the current batch.

- Return type: `None`
 
 - on_train_end(trainer, pl_module)[source]¶
- Called when training ends. Transfers parameters from the `AveragedModel` to the current model.
- Parameters:
- trainer¶ (`Trainer`) – The current `Trainer` instance.
- pl_module¶ (`LightningModule`) – The current `LightningModule` instance.
 
- Return type: `None`
 
 - on_train_epoch_end(trainer, pl_module)[source]¶
- Called when a training epoch ends. Updates the `AveragedModel` parameters, if requested by `self.should_update()`.
- Parameters:
- trainer¶ (`Trainer`) – The current `Trainer` instance.
- pl_module¶ (`LightningModule`) – The current `LightningModule` instance.
 
- Return type: `None`
 
 - on_validation_epoch_end(trainer, pl_module)[source]¶
- Called when a validation epoch ends. Recovers the current model parameters from the `AveragedModel`.
- Parameters:
- trainer¶ (`Trainer`) – The current `Trainer` instance.
- pl_module¶ (`LightningModule`) – The current `LightningModule` instance.
 
- Return type: `None`
 
 - on_validation_epoch_start(trainer, pl_module)[source]¶
- Called when a validation epoch begins. Transfers parameter values from the `AveragedModel` to the current model.
- Parameters:
- trainer¶ (`Trainer`) – The current `Trainer` instance.
- pl_module¶ (`LightningModule`) – The current `LightningModule` instance.
 
- Return type: `None`
 
 - setup(trainer, pl_module, stage)[source]¶
- Called when fit, validate, test, predict, or tune begins. Creates an `AveragedModel` when fit begins.
- Parameters:
- trainer¶ (`Trainer`) – The current `Trainer` instance.
- pl_module¶ (`LightningModule`) – The current `LightningModule` instance.
- stage¶ (`str`) – Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
 
- Return type: `None`
 
 - should_update(step_idx=None, epoch_idx=None)[source]¶
- Called after every optimizer step and after every training epoch to check whether the average model should be updated. One of the arguments is set to the zero-based index of the last training step or epoch. The default implementation returns `True` when any `step_idx` is provided. The user can customize when the average model gets updated by overriding this method.
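For example, a subclass can update the average once per epoch instead of after every optimizer step. A minimal sketch (the class name `EpochWeightAveraging` is illustrative):

```python
from lightning.pytorch.callbacks import WeightAveraging

class EpochWeightAveraging(WeightAveraging):
    def should_update(self, step_idx=None, epoch_idx=None):
        # Ignore the per-step calls and update only at epoch boundaries.
        return epoch_idx is not None
```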