core.training
import logging
from typing import Any

import torch
from torch import nn
from tqdm import tqdm

import wandb
from core.distribution.utils import DistributionT, compute_kl
from core.model import bounded_call
from core.objective import AbstractObjective


def __raise_exception_on_invalid_value(value: torch.Tensor):
    """
    Raise a ValueError if the given tensor is None or contains NaN values.

    This helper function is used to catch numerical issues that might arise
    during training (e.g., invalid gradients, extreme KL values).

    Args:
        value (torch.Tensor): A scalar tensor to check for validity.

    Raises:
        ValueError: If `value` is None or if any entry is NaN.
    """
    if value is None or torch.isnan(value).any():
        raise ValueError(f"Invalid value {value}")


def train(
    model: nn.Module,
    posterior: DistributionT,
    prior: DistributionT,
    objective: AbstractObjective,
    train_loader: torch.utils.data.dataloader.DataLoader,
    val_loader: torch.utils.data.dataloader.DataLoader,
    parameters: dict[str, Any],
    device: torch.device,
    wandb_params: dict = None,
):
    """
    Train a probabilistic neural network by optimizing a PAC-Bayes-inspired objective.

    At each iteration:
      1) Optionally clamp model outputs using `bounded_call` if `pmin` is provided in `parameters`.
      2) Compute KL divergence between posterior and prior.
      3) Compute the empirical loss (NLL by default).
      4) Combine loss and KL via the given `objective`.
      5) Backpropagate and update model parameters.

    Logs intermediate results (objective, loss, KL) to Python's logger and optionally to wandb.

    Args:
        model (nn.Module): The probabilistic neural network to train.
        posterior (DistributionT): The current (learnable) posterior distribution.
        prior (DistributionT): The (fixed or partially learnable) prior distribution.
        objective (AbstractObjective): An object that merges empirical loss and KL
            into a single differentiable objective.
        train_loader (DataLoader): Dataloader for the training dataset.
        val_loader (DataLoader): Dataloader for the validation dataset (currently unused here).
        parameters (Dict[str, Any]): A dictionary of training hyperparameters, which can include:
            - 'lr': Learning rate.
            - 'momentum': Momentum term for SGD.
            - 'epochs': Number of epochs.
            - 'num_samples': Usually the size of the training set (or mini-batch size times steps).
            - 'seed': Random seed (optional).
            - 'pmin': Minimum probability for bounding (optional).
        device (torch.device): The device (CPU or GPU) for training.
        wandb_params (Dict, optional): Configuration for logging to Weights & Biases. Expects keys:
            - "log_wandb": bool, whether to log or not
            - "name_wandb": str, run name / prefix for logging

    Returns:
        None: The model (and its posterior) are updated in-place over the specified epochs.
    """
    criterion = torch.nn.NLLLoss()
    optimizer = torch.optim.SGD(
        model.parameters(), lr=parameters["lr"], momentum=parameters["momentum"]
    )

    if "seed" in parameters:
        torch.manual_seed(parameters["seed"])
    for epoch in range(parameters["epochs"]):
        for _i, (data, target) in tqdm(enumerate(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            if "pmin" in parameters:
                output = bounded_call(model, data, parameters["pmin"])
            else:
                output = model(data)
            kl = compute_kl(posterior, prior)
            loss = criterion(output, target)
            objective_value = objective.calculate(loss, kl, parameters["num_samples"])
            __raise_exception_on_invalid_value(objective_value)
            objective_value.backward()
            optimizer.step()
        logging.info(
            f"Epoch: {epoch}, Objective: {objective_value}, Loss: {loss}, KL/n: {kl / parameters['num_samples']}"
        )
        if wandb_params is not None and wandb_params["log_wandb"]:
            wandb.log(
                {
                    wandb_params["name_wandb"] + "/Epoch": epoch,
                    wandb_params["name_wandb"] + "/Objective": objective_value,
                    wandb_params["name_wandb"] + "/Loss": loss,
                    wandb_params["name_wandb"] + "/KL-n": kl / parameters["num_samples"],
                }
            )
def train(
    model: torch.nn.modules.module.Module,
    posterior: dict[tuple[str, ...], dict[str, core.distribution.AbstractVariable.AbstractVariable]],
    prior: dict[tuple[str, ...], dict[str, core.distribution.AbstractVariable.AbstractVariable]],
    objective: core.objective.AbstractObjective.AbstractObjective,
    train_loader: torch.utils.data.dataloader.DataLoader,
    val_loader: torch.utils.data.dataloader.DataLoader,
    parameters: dict[str, typing.Any],
    device: torch.device,
    wandb_params: dict = None,
):
Train a probabilistic neural network by optimizing a PAC-Bayes-inspired objective.
At each iteration:
1) Optionally clamp model outputs using `bounded_call` if `pmin` is provided in `parameters`.
2) Compute KL divergence between posterior and prior.
3) Compute the empirical loss (NLL by default).
4) Combine loss and KL via the given `objective`.
5) Backpropagate and update model parameters.
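For intuition about step 4, below is a minimal sketch of what an objective might compute. The class name, the delta value, and the formula (a classic McAllester-style PAC-Bayes bound, loss + sqrt((KL + ln(2*sqrt(n)/delta)) / (2n))) are purely illustrative; the actual `AbstractObjective` subclasses in `core.objective` may use different formulas.

import math

import torch


class IllustrativeMcAllesterObjective:
    # Hypothetical objective, NOT part of core.objective. It merges the
    # empirical loss and the KL term as
    #     loss + sqrt((kl + ln(2 * sqrt(n) / delta)) / (2 * n)),
    # the classic McAllester-style PAC-Bayes bound.
    def __init__(self, delta: float = 0.025):
        self.delta = delta

    def calculate(self, loss: torch.Tensor, kl: torch.Tensor, num_samples: int) -> torch.Tensor:
        log_term = math.log(2 * math.sqrt(num_samples) / self.delta)
        # Both `loss` and `kl` are differentiable tensors, so the returned value
        # supports .backward(), exactly as `train` requires of objective_value.
        return loss + torch.sqrt((kl + log_term) / (2 * num_samples))

Structurally, `train` only requires that `objective.calculate(loss, kl, num_samples)` return a scalar tensor it can backpropagate through.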
Logs intermediate results (objective, loss, KL) to Python's logger and optionally to wandb.
Arguments:
- model (nn.Module): The probabilistic neural network to train.
- posterior (DistributionT): The current (learnable) posterior distribution.
- prior (DistributionT): The (fixed or partially learnable) prior distribution.
- objective (AbstractObjective): An object that merges empirical loss and KL into a single differentiable objective.
- train_loader (DataLoader): Dataloader for the training dataset.
- val_loader (DataLoader): Dataloader for the validation dataset (currently unused here).
- parameters (Dict[str, Any]): A dictionary of training hyperparameters (see the example call after the Returns section), which can include:
- 'lr': Learning rate.
- 'momentum': Momentum term for SGD.
- 'epochs': Number of epochs.
- 'num_samples': Usually the size of the training set (or mini-batch size times steps).
- 'seed': Random seed (optional).
- 'pmin': Minimum probability for bounding (optional).
- device (torch.device): The device (CPU or GPU) for training.
- wandb_params (Dict, optional): Configuration for logging to Weights & Biases. Expects keys:
- "log_wandb": bool, whether to log or not
- "name_wandb": str, run name / prefix for logging
Returns:
None: The model (and its posterior) are updated in-place over the specified epochs.
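To make the `parameters` and `wandb_params` dictionaries concrete, here is a hypothetical call. All hyperparameter values and the wandb prefix are placeholders, and the `model`, `posterior`, `prior`, `objective`, and dataloader objects are assumed to have been constructed elsewhere (their constructors live outside this module and are not shown).

import torch

from core.training import train

# `model`, `posterior`, `prior`, `objective`, `train_loader`, and `val_loader`
# are assumed to already exist; building them is outside the scope of this module.

parameters = {
    "lr": 0.001,           # learning rate for the SGD optimizer
    "momentum": 0.95,      # SGD momentum
    "epochs": 10,          # number of passes over train_loader
    "num_samples": 60000,  # the "n" used to scale the KL term (training-set size)
    "seed": 42,            # optional: seeds torch's RNG before training
    "pmin": 1e-5,          # optional: passed to bounded_call to bound output probabilities
}

wandb_params = {
    "log_wandb": True,          # set to False (or pass wandb_params=None) to skip wandb logging
    "name_wandb": "posterior",  # placeholder prefix; metrics appear as e.g. "posterior/Objective"
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train(
    model=model,
    posterior=posterior,
    prior=prior,
    objective=objective,
    train_loader=train_loader,
    val_loader=val_loader,  # accepted but currently unused by train
    parameters=parameters,
    device=device,
    wandb_params=wandb_params,
)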