zamba.pytorch.finetuning
Classes
BackboneFinetuning
Bases: pl.callbacks.finetuning.BackboneFinetuning
Derived from PyTorch Lightning's (PTL) built-in BackboneFinetuning. The difference is that during the backbone freeze phase you can choose whether to freeze batch norm layers, even if train_bn is True (i.e., even if those layers will be trained during the backbone unfreeze phase).
Finetunes a backbone model according to a user-defined learning rate schedule. When the backbone learning rate reaches the current model learning rate and should_align is True, the backbone learning rate stays aligned with the model learning rate for the rest of training.
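To make the schedule concrete, here is a minimal pure-Python sketch that simulates the backbone learning rate epoch by epoch. It assumes a fixed model learning rate, a backbone rate of 0.0 while the backbone is frozen, and that after unfreezing the previous backbone rate is multiplied by lambda_func(epoch) each epoch. The helper simulate_backbone_lrs is hypothetical and does not call PyTorch Lightning, whose exact update (for example, which epoch index it passes to lambda_func) may differ.

```python
from typing import Callable, List


def simulate_backbone_lrs(
    model_lr: float,
    initial_backbone_lr: float,
    lambda_func: Callable[[int], float],
    unfreeze_epoch: int,
    num_epochs: int,
    should_align: bool = True,
) -> List[float]:
    """Hypothetical helper: backbone learning rate per epoch (0.0 while frozen).

    Assumes a constant model learning rate; it only mimics the policy
    described in the docs and does not use PyTorch Lightning.
    """
    lrs: List[float] = []
    backbone_lr = 0.0
    for epoch in range(num_epochs):
        if epoch < unfreeze_epoch:
            lrs.append(0.0)  # backbone still frozen
            continue
        if epoch == unfreeze_epoch:
            backbone_lr = initial_backbone_lr  # first epoch after unfreezing
        else:
            # Grow the previous backbone rate by the user-defined factor.
            backbone_lr = lambda_func(epoch) * backbone_lr
            # Once it reaches the model rate, stay aligned with it.
            if should_align and backbone_lr >= model_lr:
                backbone_lr = model_lr
        lrs.append(backbone_lr)
    return lrs


lrs = simulate_backbone_lrs(0.01, 0.001, lambda epoch: 2.0, unfreeze_epoch=2, num_epochs=7)
```

With model_lr=0.01, an initial backbone rate of 0.001, a constant factor of 2.0, and unfreezing at epoch 2, the backbone rate doubles each epoch (0.001, 0.002, 0.004, 0.008) and is then capped at 0.01 instead of reaching 0.016, because should_align is True.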
Parameters:
Name | Type | Description | Default |
---|---|---|---|
unfreeze_backbone_at_epoch | | Epoch at which the backbone will be unfrozen. | required |
lambda_func | | Scheduling function for increasing the backbone learning rate. | required |
backbone_initial_ratio_lr | | Used to scale down the backbone learning rate relative to the rest of the model. | required |
backbone_initial_lr | | Optional initial learning rate for the backbone. By default, current_learning / backbone_initial_ratio_lr is used. | required |
should_align | | Whether to align with the current learning rate when the backbone learning rate reaches it. | required |
initial_denom_lr | | When unfreezing the backbone, the initial learning rate will be current_learning_rate / initial_denom_lr. | required |
train_bn | | Whether to make batch normalization layers trainable. | required |
verbose | | Whether to display the current learning rate for the model and the backbone. | required |
round | | Precision for displaying the learning rate. | required |
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import BackboneFinetuning
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])
Source code in zamba/pytorch/finetuning.py (lines 19-74)
Attributes
pre_train_bn = pre_train_bn (instance attribute)
Functions
__init__(*args, multiplier: Optional[float] = 1, pre_train_bn: bool = False, **kwargs)
Source code in zamba/pytorch/finetuning.py (lines 64-71)
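The __init__ signature suggests that multiplier and pre_train_bn are consumed by the subclass while all other arguments are forwarded to the parent callback. The sketch below illustrates one plausible version of that pattern, under the assumption that a non-None multiplier is turned into a constant lambda_func. StandInBackboneFinetuning, constant_multiplier, and SketchBackboneFinetuning are hypothetical names, and the stand-in base class mimics only three of the parent's parameters.

```python
from typing import Any, Callable, Optional


def constant_multiplier(rate: float) -> Callable[..., float]:
    # Hypothetical helper: ignore all arguments and return the fixed rate.
    def multiplier(*args: Any, **kwargs: Any) -> float:
        return rate

    return multiplier


class StandInBackboneFinetuning:
    """Hypothetical stand-in for the PyTorch Lightning base class."""

    def __init__(
        self,
        unfreeze_backbone_at_epoch: int = 10,
        lambda_func: Callable[[int], float] = lambda epoch: 1.5,
        train_bn: bool = True,
    ) -> None:
        self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch
        self.lambda_func = lambda_func
        self.train_bn = train_bn


class SketchBackboneFinetuning(StandInBackboneFinetuning):
    def __init__(
        self,
        *args: Any,
        multiplier: Optional[float] = 1,
        pre_train_bn: bool = False,
        **kwargs: Any,
    ) -> None:
        # Assumption: a non-None multiplier replaces lambda_func with a constant.
        if multiplier is not None:
            kwargs["lambda_func"] = constant_multiplier(multiplier)
        super().__init__(*args, **kwargs)
        # Kept on the instance so the freeze phase can consult it later.
        self.pre_train_bn = pre_train_bn


callback = SketchBackboneFinetuning(5, multiplier=2.0, pre_train_bn=True)
```

If this assumption holds, the default multiplier=1 always overrides lambda_func, so a caller who wants a custom schedule would pass multiplier=None. In this sketch, passing lambda_func positionally together with a non-None multiplier raises a TypeError, because the argument ends up supplied twice.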
freeze_before_training(pl_module: 'pl.LightningModule')
Source code in zamba/pytorch/finetuning.py (lines 73-74)
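The freeze phase can be illustrated without PyTorch. This sketch assumes that freeze_before_training freezes the module's backbone and decides batch norm trainability from pre_train_bn rather than train_bn. The Layer dataclass and freeze_backbone function are hypothetical stand-ins for torch.nn modules and their requires_grad flags.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Layer:
    """Hypothetical stand-in for a backbone submodule."""

    name: str
    is_batch_norm: bool
    requires_grad: bool = True


def freeze_backbone(layers: List[Layer], pre_train_bn: bool) -> None:
    """Freeze every backbone layer, optionally leaving batch norm trainable."""
    for layer in layers:
        # Batch norm layers stay trainable only when pre_train_bn is True;
        # all other layers are always frozen during this phase.
        layer.requires_grad = layer.is_batch_norm and pre_train_bn


backbone = [Layer("conv1", False), Layer("bn1", True), Layer("fc", False)]
freeze_backbone(backbone, pre_train_bn=True)
```

With pre_train_bn=False every backbone layer is frozen; with pre_train_bn=True only the batch norm layers keep requires_grad=True. train_bn, by contrast, only matters once the backbone is unfrozen.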
Functions
multiplier_factory(rate: float)
Returns a function that returns a constant value for use in computing a constant learning rate multiplier.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rate | float | Constant multiplier. | required |
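A constant multiplier can be built as a closure. The following is a re-implementation sketch consistent with the documented signature, not the zamba source; it assumes the returned function accepts and ignores any arguments (such as the epoch), so it can be passed wherever a lambda_func is expected.

```python
from typing import Any, Callable


def multiplier_factory(rate: float) -> Callable[..., float]:
    """Sketch: return a callable that ignores its arguments and returns rate."""

    def multiplier(*args: Any, **kwargs: Any) -> float:
        # The same factor is returned for every epoch.
        return rate

    return multiplier


constant_1_5 = multiplier_factory(1.5)
```

Passed as lambda_func, multiplier_factory(1.5) behaves like multiplicative = lambda epoch: 1.5 in the BackboneFinetuning example, growing the backbone learning rate by a factor of 1.5 per epoch until it aligns with the model learning rate.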
Source code in zamba/pytorch/finetuning.py (lines 5-16)