# Source code for egc.utils.model_management

"""Model Management"""
from pathlib import Path
from pathlib import PurePath
from typing import Tuple

import torch


def get_checkpoint_path(
    model_filename: str, checkpoint_dir: str = "checkpoints"
) -> PurePath:
    """Return the path ``{checkpoint_dir}/{model_filename}.pt``.

    The parent directory of the returned path is created if it is missing,
    so the result can be handed straight to ``torch.save``. The checkpoint
    file itself is not created.

    Args:
        model_filename (str): checkpoint name without the ``.pt`` suffix.
            It may contain ``/`` to place the file in a subdirectory; any
            missing intermediate directories are created.
        checkpoint_dir (str): directory holding checkpoints. Defaults to
            ``"checkpoints"``, relative to the current working directory.

    Returns:
        PurePath: path to the checkpoint file.
    """
    # Keep the f-string form (rather than ``Path(dir) / name``) so that a
    # filename starting with "/" still resolves under checkpoint_dir.
    model_path = Path(f"{checkpoint_dir}/{model_filename}.pt")
    # exist_ok=True already tolerates an existing directory, so no separate
    # exists() pre-check is needed.
    model_path.parent.mkdir(parents=True, exist_ok=True)
    return model_path
def _get_file_path(model_filename: str) -> PurePath:
    """Return ``checkpoints/{model_filename}.pt``, creating its parent directory.

    Used by ``save_model`` and ``load_model`` to locate checkpoint files. The
    file itself is not created; only missing parent directories are.

    Args:
        model_filename (str): checkpoint name without the ``.pt`` suffix.

    Returns:
        PurePath: path to the checkpoint file, relative to the current
        working directory.
    """
    model_path = Path(f"checkpoints/{model_filename}.pt")
    # exist_ok=True makes an explicit exists() pre-check redundant.
    model_path.parent.mkdir(parents=True, exist_ok=True)
    return model_path
def save_model(
    model_filename: str,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    current_epoch: int,
    loss: float,
) -> None:
    """Write a training checkpoint to ``checkpoints/${model_filename}.pt``.

    The checkpoint is one dict passed to ``torch.save``. It holds the epoch,
    ``model.state_dict()``, ``optimizer.state_dict()`` and the loss.

    Args:
        model_filename (str): checkpoint name, without the ``.pt`` suffix.
        model (torch.nn.Module): model whose ``state_dict()`` is stored.
        optimizer (torch.optim.Optimizer): optimizer whose ``state_dict()``
            is stored.
        current_epoch (int): value stored under ``"epoch"``.
        loss (float): value stored under ``"loss"``.
    """
    # Resolve (and create the directory for) the destination before
    # collecting state.
    destination = _get_file_path(model_filename)
    checkpoint = {
        "epoch": current_epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }
    torch.save(checkpoint, destination)
def load_model(
    model_filename: str,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    *,
    map_location=None,
) -> Tuple[torch.nn.Module, torch.optim.Optimizer, int, float]:
    """Load a checkpoint from ``checkpoints/${model_filename}.pt``.

    ``model`` and ``optimizer`` are restored in place via
    ``load_state_dict``. They are also returned, together with the stored
    epoch and loss.

    Args:
        model_filename (str): checkpoint name, without the ``.pt`` suffix.
        model (torch.nn.Module): model to restore into.
        optimizer (torch.optim.Optimizer): optimizer to restore into.
        map_location: forwarded unchanged to ``torch.load``. For example,
            pass ``"cpu"`` to load a checkpoint saved from GPU tensors on a
            CPU-only machine. The default ``None`` keeps torch's default
            behaviour.

    Returns:
        Tuple[torch.nn.Module, torch.optim.Optimizer, int, float]:
        [model, optimizer, epoch, loss]
    """
    # _get_file_path also creates the checkpoints directory if it is missing.
    model_path = _get_file_path(model_filename)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]
    loss = checkpoint["loss"]
    return model, optimizer, epoch, loss