core.training

import logging
from typing import Any

import torch
from torch import nn
from tqdm import tqdm

import wandb
from core.distribution.utils import DistributionT, compute_kl
from core.model import bounded_call
from core.objective import AbstractObjective


def __raise_exception_on_invalid_value(value: torch.Tensor):
    """
    Raise a ValueError if the given tensor is None or contains NaN values.

    This helper function is used to catch numerical issues that might arise
    during training (e.g., invalid gradients, extreme KL values).

    Args:
        value (torch.Tensor): A scalar tensor to check for validity.

    Raises:
        ValueError: If `value` is None or if any entry is NaN.
    """
    if value is None or torch.isnan(value).any():
        raise ValueError(f"Invalid value {value}")


def train(
    model: nn.Module,
    posterior: DistributionT,
    prior: DistributionT,
    objective: AbstractObjective,
    train_loader: torch.utils.data.dataloader.DataLoader,
    val_loader: torch.utils.data.dataloader.DataLoader,
    parameters: dict[str, Any],
    device: torch.device,
    wandb_params: dict = None,
):
    """
    Train a probabilistic neural network by optimizing a PAC-Bayes-inspired objective.

    At each iteration:
      1) Optionally clamp model outputs using `bounded_call` if `pmin` is provided in `parameters`.
      2) Compute KL divergence between posterior and prior.
      3) Compute the empirical loss (NLL by default).
      4) Combine loss and KL via the given `objective`.
      5) Backpropagate and update model parameters.

    Logs intermediate results (objective, loss, KL) to Python's logger and optionally to wandb.

    Args:
        model (nn.Module): The probabilistic neural network to train.
        posterior (DistributionT): The current (learnable) posterior distribution.
        prior (DistributionT): The (fixed or partially learnable) prior distribution.
        objective (AbstractObjective): An object that merges empirical loss and KL
            into a single differentiable objective.
        train_loader (DataLoader): Dataloader for the training dataset.
        val_loader (DataLoader): Dataloader for the validation dataset (currently unused here).
        parameters (Dict[str, Any]): A dictionary of training hyperparameters, which can include:
            - 'lr': Learning rate.
            - 'momentum': Momentum term for SGD.
            - 'epochs': Number of epochs.
            - 'num_samples': Usually the size of the training set (or mini-batch size times steps).
            - 'seed': Random seed (optional).
            - 'pmin': Minimum probability for bounding (optional).
        device (torch.device): The device (CPU or GPU) for training.
        wandb_params (Dict, optional): Configuration for logging to Weights & Biases. Expects keys:
            - "log_wandb": bool, whether to log or not
            - "name_wandb": str, run name / prefix for logging

    Returns:
        None: The model (and its posterior) are updated in-place over the specified epochs.
    """
    criterion = torch.nn.NLLLoss()
    optimizer = torch.optim.SGD(
        model.parameters(), lr=parameters["lr"], momentum=parameters["momentum"]
    )

    if "seed" in parameters:
        torch.manual_seed(parameters["seed"])
    for epoch in range(parameters["epochs"]):
        for _i, (data, target) in tqdm(enumerate(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            # Clamp the model's outputs using 'pmin' as the minimum probability, if given.
            if "pmin" in parameters:
                output = bounded_call(model, data, parameters["pmin"])
            else:
                output = model(data)
            # KL divergence between the current posterior and the prior.
            kl = compute_kl(posterior, prior)
            # Empirical negative log-likelihood on the mini-batch.
            loss = criterion(output, target)
            # Combine empirical loss and KL into the PAC-Bayes-style objective.
            objective_value = objective.calculate(loss, kl, parameters["num_samples"])
            __raise_exception_on_invalid_value(objective_value)
            objective_value.backward()
            optimizer.step()
        logging.info(
            f"Epoch: {epoch}, Objective: {objective_value}, Loss: {loss}, KL/n: {kl / parameters['num_samples']}"
        )
        if wandb_params is not None and wandb_params["log_wandb"]:
            wandb.log(
                {
                    wandb_params["name_wandb"] + "/Epoch": epoch,
                    wandb_params["name_wandb"] + "/Objective": objective_value,
                    wandb_params["name_wandb"] + "/Loss": loss,
                    wandb_params["name_wandb"] + "/KL-n": kl
                    / parameters["num_samples"],
                }
            )
def train(
    model: torch.nn.modules.module.Module,
    posterior: dict[tuple[str, ...], dict[str, core.distribution.AbstractVariable.AbstractVariable]],
    prior: dict[tuple[str, ...], dict[str, core.distribution.AbstractVariable.AbstractVariable]],
    objective: core.objective.AbstractObjective.AbstractObjective,
    train_loader: torch.utils.data.dataloader.DataLoader,
    val_loader: torch.utils.data.dataloader.DataLoader,
    parameters: dict[str, typing.Any],
    device: torch.device,
    wandb_params: dict = None,
):

Train a probabilistic neural network by optimizing a PAC-Bayes-inspired objective.

At each iteration:

  1) Optionally clamp model outputs using bounded_call if pmin is provided in parameters.
  2) Compute KL divergence between posterior and prior.
  3) Compute the empirical loss (NLL by default).
  4) Combine loss and KL via the given objective.
  5) Backpropagate and update model parameters.
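
To make step 4 concrete, the sketch below shows one way an objective could combine the two terms. The class and its constants are illustrative assumptions, not the objective implemented in core.objective; it only mirrors the calculate(loss, kl, num_samples) interface that train relies on.

import math

import torch


class ToyObjective:
    """Illustrative stand-in exposing the same calculate(loss, kl, num_samples) interface."""

    def __init__(self, lam: float = 1.0, delta: float = 0.025):
        self.lam = lam      # hypothetical trade-off between data fit and complexity
        self.delta = delta  # hypothetical confidence level of the bound

    def calculate(
        self, loss: torch.Tensor, kl: torch.Tensor, num_samples: int
    ) -> torch.Tensor:
        # Empirical loss plus a KL complexity penalty scaled by the sample count:
        # loss + (KL + ln(1/delta)) / (lam * n).
        return loss + (kl + math.log(1.0 / self.delta)) / (self.lam * num_samples)

Since train only calls objective.calculate(loss, kl, num_samples), any object exposing that method in this form can be passed as the objective.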

Logs intermediate results (objective, loss, KL) to Python's logger and optionally to wandb.

Arguments:
  • model (nn.Module): The probabilistic neural network to train.
  • posterior (DistributionT): The current (learnable) posterior distribution.
  • prior (DistributionT): The (fixed or partially learnable) prior distribution.
  • objective (AbstractObjective): An object that merges empirical loss and KL into a single differentiable objective.
  • train_loader (DataLoader): Dataloader for the training dataset.
  • val_loader (DataLoader): Dataloader for the validation dataset (currently unused here).
  • parameters (Dict[str, Any]): A dictionary of training hyperparameters (see the example dictionaries after this list), which can include:
    • 'lr': Learning rate.
    • 'momentum': Momentum term for SGD.
    • 'epochs': Number of epochs.
    • 'num_samples': Usually the size of the training set (or mini-batch size times steps).
    • 'seed': Random seed (optional).
    • 'pmin': Minimum probability for bounding (optional).
  • device (torch.device): The device (CPU or GPU) for training.
  • wandb_params (Dict, optional): Configuration for logging to Weights & Biases. Expects keys:
    • "log_wandb": bool, whether to log or not
    • "name_wandb": str, run name / prefix for logging
Returns:

None: The model (and its posterior) are updated in-place over the specified epochs.
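
With logging enabled and name_wandb set to "posterior", each epoch produces a Weights & Biases entry of the following shape. The numeric values below are placeholders, and an active run started with wandb.init is assumed:

import wandb

wandb.log(
    {
        "posterior/Epoch": 3,         # current epoch index
        "posterior/Objective": 0.87,  # objective value on the last mini-batch
        "posterior/Loss": 0.52,       # empirical NLL on the last mini-batch
        "posterior/KL-n": 0.35,       # KL divergence divided by num_samples
    }
)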