BaseFinetuning

class lightning.pytorch.callbacks.BaseFinetuning[source]
Bases: Callback

This class implements the base logic for writing your own finetuning callback.

Warning: This is an experimental feature.

Override the freeze_before_training and finetune_function methods with your own logic:

- freeze_before_training: This method is called before configure_optimizers and should be used to freeze any module's parameters.
- finetune_function: This method is called on every train epoch start and should be used to unfreeze any parameters. Those parameters need to be added in a new param_group within the optimizer.

Note: Make sure to filter the parameters based on requires_grad.

Example:

>>> from torch.optim import Adam
>>> class MyModel(pl.LightningModule):
...     def configure_optimizers(self):
...         # Make sure to filter the parameters based on `requires_grad`
...         return Adam(filter(lambda p: p.requires_grad, self.parameters()))
...
>>> class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
...     def __init__(self, unfreeze_at_epoch=10):
...         super().__init__()
...         self._unfreeze_at_epoch = unfreeze_at_epoch
...
...     def freeze_before_training(self, pl_module):
...         # freeze any module you want
...         # Here, we are freezing `feature_extractor`
...         self.freeze(pl_module.feature_extractor)
...
...     def finetune_function(self, pl_module, current_epoch, optimizer):
...         # When `current_epoch` is 10, feature_extractor will start training.
...         if current_epoch == self._unfreeze_at_epoch:
...             self.unfreeze_and_add_param_group(
...                 modules=pl_module.feature_extractor,
...                 optimizer=optimizer,
...                 train_bn=True,
...             )
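To use a callback built this way, pass an instance to the Trainer like any other callback. A minimal sketch, reusing the FeatureExtractorFreezeUnfreeze class from the example above:

import lightning.pytorch as pl

# Attach the finetuning callback like any other Lightning callback.
# `FeatureExtractorFreezeUnfreeze` is the illustrative class from the example above.
trainer = pl.Trainer(callbacks=[FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=10)])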
static filter_on_optimizer(optimizer, params)[source]

Excludes any parameter that already exists in the given optimizer.
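A minimal sketch of the filtering behaviour; the two-module setup below is hypothetical and used only for illustration:

from torch import nn
from torch.optim import Adam
from lightning.pytorch.callbacks import BaseFinetuning

# Hypothetical split into a backbone and a head, for illustration only.
backbone = nn.Linear(4, 4)
head = nn.Linear(4, 2)

# The optimizer starts out tracking only the head's parameters.
optimizer = Adam(head.parameters(), lr=1e-3)

# Candidate parameters: everything from both modules.
candidates = list(backbone.parameters()) + list(head.parameters())

# Parameters already present in the optimizer (the head's) are excluded,
# leaving only the backbone's weight and bias.
new_params = BaseFinetuning.filter_on_optimizer(optimizer, candidates)
print(len(new_params))  # 2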
static filter_params(modules, train_bn=True, requires_grad=True)[source]
Yields the requires_grad parameters of a given module or list of modules.

Return type: Generator
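A small sketch of what gets yielded; the model layout below is hypothetical:

from torch import nn
from lightning.pytorch.callbacks import BaseFinetuning

# Hypothetical model with a frozen layer, a BatchNorm layer, and a trainable layer.
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
model[0].requires_grad_(False)

# Yield only parameters that still require gradients, skipping BatchNorm parameters.
trainable = list(BaseFinetuning.filter_params(model, train_bn=False, requires_grad=True))
print(len(trainable))  # 2 -> weight and bias of the last Linear layer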
 
finetune_function(pl_module, epoch, optimizer)[source]
Override to add your unfreeze logic.

Return type: None
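Beyond the single-epoch example above, this hook can implement staged unfreezing. A sketch, assuming a hypothetical pl_module.backbone whose last block is exposed as backbone.layer4:

from lightning.pytorch.callbacks import BaseFinetuning

class TwoStageUnfreeze(BaseFinetuning):
    """Hypothetical callback that unfreezes the backbone in two steps."""

    def freeze_before_training(self, pl_module):
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, epoch, optimizer):
        if epoch == 5:
            # Unfreeze only the deepest block first, with its own learning rate.
            self.unfreeze_and_add_param_group(pl_module.backbone.layer4, optimizer, lr=1e-4)
        elif epoch == 10:
            # Then unfreeze the rest of the backbone; parameters already in the
            # optimizer are filtered out via filter_on_optimizer.
            self.unfreeze_and_add_param_group(pl_module.backbone, optimizer, lr=1e-5)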
 
static flatten_modules(modules)[source]
Flattens a module or an iterable of modules into a list of its leaf modules (modules with no children) and parent modules that hold parameters directly.
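For example, with a hypothetical nested container:

from torch import nn
from lightning.pytorch.callbacks import BaseFinetuning

# A nested container, purely for illustration.
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()),
    nn.Linear(8, 2),
)

leaves = BaseFinetuning.flatten_modules(model)
print([type(m).__name__ for m in leaves])  # ['Conv2d', 'ReLU', 'Linear']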
load_state_dict(state_dict)[source]
Called when loading a checkpoint. Implement to reload callback state given the callback's state_dict.
setup(trainer, pl_module, stage)[source]
Called when fit, validate, test, predict, or tune begins.

Return type: None
 
static unfreeze_and_add_param_group(modules, optimizer, lr=None, initial_denom_lr=10.0, train_bn=True)[source]
Unfreezes a module and adds its parameters to an optimizer.

Parameters:
- modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A module or iterable of modules to unfreeze. Their parameters will be added to the optimizer as a new param group.
- optimizer (Optimizer) – The optimizer that will receive the new parameters via add_param_group.
- lr (Optional[float]) – Learning rate for the new param group.
- initial_denom_lr (float) – If no lr is provided, the learning rate from the first param group will be used and divided by initial_denom_lr.
- train_bn (bool) – Whether to train the BatchNorm layers.
 
Return type: None
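A minimal sketch of the default learning-rate behaviour when lr is not given; the frozen backbone / trainable head split below is hypothetical:

from torch import nn
from torch.optim import SGD
from lightning.pytorch.callbacks import BaseFinetuning

# Hypothetical frozen backbone and trainable head, for illustration only.
backbone = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
head = nn.Linear(16, 2)
BaseFinetuning.freeze(backbone)

optimizer = SGD(head.parameters(), lr=0.1)

# Unfreeze the backbone and append its parameters as a new param group.
# With no explicit lr, the new group uses 0.1 / initial_denom_lr = 0.01.
BaseFinetuning.unfreeze_and_add_param_group(backbone, optimizer, initial_denom_lr=10.0)
print(len(optimizer.param_groups))      # 2
print(optimizer.param_groups[1]["lr"])  # 0.01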