Pytorch
Class: Trainer#
The Trainer
class in the training.pytorch
module is designed to facilitate efficient and effective training of PyTorch models, particularly on single or no GPU systems. It incorporates various functionalities such as early stopping, model saving, metrics evaluation, and optional loss plotting to streamline the training process.
Constructor#
def __init__(
self, model: torch.nn.Module, train_data: DataLoader, test_data: DataLoader,
optimizer: torch.optim.Optimizer, best_model_path: str,
early_stopping: bool = True, patience: int = 10, loss: Callable = None,
metrics: Dict[str, Callable] = None, epochs: int = 10, mlflow_run=None,
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
create_loss_plot: bool = False,
)
Parameters#
- model (
torch.nn.Module
): The PyTorch model to be trained. - train_data (
DataLoader
): DataLoader for the training dataset. - test_data (
DataLoader
): DataLoader for the validation/test dataset. - optimizer (
torch.optim.Optimizer
): Optimizer used for training the model. - best_model_path (
str
): Path where the best model according to validation loss is saved. - early_stopping (
bool
): Indicates whether training should stop early if there's no improvement, with a default setting of True - patience (
int
): Number of epochs to wait for improvement in validation loss before early stopping. Defaults to10
. - loss (
Callable
): Loss function used during training. Must be specified. - metrics (
Dict[str, Callable]
): Dictionary containing metrics to be evaluated during validation. - epochs (
int
): Total number of epochs to train. Defaults to10
. - mlflow_run (optional): Optional MLflow run instance to log training parameters, metrics, and models. Default is
None
. - device (
torch.device
): Device on which to train the model (cuda
orcpu
). Automatically set based on availability. - create_loss_plot (
bool
): IfTrue
, generates a plot for training and validation losses after training. Defaults toFalse
.
train#
Conducts the training process across all epochs, handles early stopping, and loads the best model state at the end.
train(self) -> torch.nn.Module
Parameters#
None
: This method has no input parameters.
Returns#
torch.nn.Module
: The trained model with the best performance on validation data.
Example#
# Assuming model, train_loader, and test_loader are predefined:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
trainer = Trainer(model, train_loader, test_loader, optimizer, "path/to/save/best_model.pth")
trained_model = trainer.train()