core.loss
```python
from collections.abc import Callable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch import Tensor
from tqdm import tqdm

from core.model import bounded_call


def rescale_loss(loss: Tensor, pmin: float) -> Tensor:
    """
    Rescale a loss value by dividing it by log(1/pmin).

    This is often used in PAC-Bayes settings to keep losses within a certain range
    (e.g., converting losses to the [0, 1] interval) or to improve numerical stability
    when probabilities must be bounded below by `pmin`.

    Args:
        loss (Tensor): A scalar or batched loss tensor.
        pmin (float): A lower bound for probabilities (e.g., 1e-5).

    Returns:
        Tensor: The loss tensor divided by ln(1/pmin).
    """
    return loss / np.log(1.0 / pmin)


def nll_loss(outputs: Tensor, targets: Tensor, pmin: float | None = None) -> Tensor:
    """
    Compute the negative log-likelihood (NLL) loss for classification.

    In typical classification settings, `outputs` is the log-probability of each class
    (e.g., from a log-softmax layer), and `targets` are the true class indices.

    Args:
        outputs (Tensor): Log-probabilities of shape (batch_size, num_classes).
        targets (Tensor): Ground truth class indices of shape (batch_size,).
        pmin (float, optional): Not used directly here; kept for uniform interface
            with other loss functions.

    Returns:
        Tensor: A scalar tensor representing the average NLL loss over the batch.
    """
    return f.nll_loss(outputs, targets)


def scaled_nll_loss(outputs: Tensor, targets: Tensor, pmin: float) -> Tensor:
    """
    Compute the negative log-likelihood (NLL) loss and then rescale it by log(1/pmin).

    This is a combination of `nll_loss` and `rescale_loss`, often used to ensure
    that the final loss remains within a desired numeric range for PAC-Bayes optimization.

    Args:
        outputs (Tensor): Log-probabilities of shape (batch_size, num_classes).
        targets (Tensor): Ground truth class indices of shape (batch_size,).
        pmin (float): A lower bound for probabilities.

    Returns:
        Tensor: The rescaled NLL loss as a scalar tensor.
    """
    return rescale_loss(nll_loss(outputs, targets), pmin)


def zero_one_loss(outputs: Tensor, targets: Tensor, pmin: float | None = None) -> Tensor:
    """
    Compute the 0-1 classification error.

    This function returns a loss between 0 and 1, where 0 indicates perfect
    classification on the given batch and 1 indicates total misclassification.

    Args:
        outputs (Tensor): Logits or log-probabilities for each class.
        targets (Tensor): Ground truth class indices of shape (batch_size,).
        pmin (float, optional): Not used here; kept for consistency with other losses.

    Returns:
        Tensor: A single-element tensor with the 0-1 error (1 - accuracy).
    """
    predictions = outputs.max(1, keepdim=True)[1]
    correct = predictions.eq(targets.view_as(predictions)).sum().item()
    total = targets.size(0)
    loss_01 = 1 - (correct / total)
    return Tensor([loss_01])


def _compute_losses(
    model: nn.Module,
    inputs: Tensor,
    targets: Tensor,
    loss_func_list: list[Callable],
    pmin: float | None = None,
) -> list[Tensor]:
    """
    Compute a list of loss values for a single forward pass of the model.

    This function optionally applies a bounded call if `pmin` is specified, then
    evaluates each loss function in `loss_func_list` on the model outputs.

    Args:
        model (nn.Module): A (probabilistic) neural network model.
        inputs (Tensor): Input data for one batch, of shape (batch_size, ...).
        targets (Tensor): Ground truth labels for the batch.
        loss_func_list (list[Callable]): A list of loss functions to compute
            (e.g., [nll_loss, zero_one_loss]).
        pmin (float, optional): A lower bound for probabilities. If given,
            `bounded_call` is used before computing losses.

    Returns:
        list[Tensor]: A list of scalar loss tensors, each corresponding to one
            function in `loss_func_list`.
    """
    if pmin:
        # Bound the predicted probabilities to [pmin, 1].
        outputs = bounded_call(model, inputs, pmin)
    else:
        outputs = model(inputs)
    losses = []
    for loss_func in loss_func_list:
        loss = (
            loss_func(outputs, targets, pmin) if pmin else loss_func(outputs, targets)
        )
        losses.append(loss)
    return losses


def compute_losses(
    model: nn.Module,
    bound_loader: torch.utils.data.DataLoader,
    mc_samples: int,
    loss_func_list: list[Callable],
    device: torch.device,
    pmin: float | None = None,
) -> Tensor:
    """
    Compute average losses over multiple Monte Carlo samples for a given dataset.

    This function is typically used to estimate the expected risk under the
    posterior by sampling the model `mc_samples` times for each batch in the
    `bound_loader`.

    Args:
        model (nn.Module): A probabilistic neural network model.
        bound_loader (DataLoader): A DataLoader for the dataset on which
            the losses should be computed (e.g., a bound or test set).
        mc_samples (int): Number of Monte Carlo samples to draw from the posterior
            for each batch.
        loss_func_list (list[Callable]): List of loss functions to evaluate
            (e.g., [nll_loss, zero_one_loss]).
        device (torch.device): The device (CPU/GPU) on which computations are performed.
        pmin (float, optional): A lower bound for probabilities. If provided,
            `bounded_call` will be used to clamp model outputs.

    Returns:
        Tensor: A tensor of shape (len(loss_func_list),) containing the average losses
            across the entire dataset for each loss function. The result is typically
            used to estimate or bound the generalization error in PAC-Bayes experiments.
    """
    with torch.no_grad():
        batch_wise_loss_list = []
        for data, targets in tqdm(bound_loader):
            data, targets = data.to(device), targets.to(device)
            mc_loss_list = []
            for _i in range(mc_samples):
                losses = _compute_losses(model, data, targets, loss_func_list, pmin)
                # .item() handles both 0-dim and single-element loss tensors,
                # so heterogeneous loss functions can be averaged together.
                mc_loss_list.append(torch.tensor([loss.item() for loss in losses]))
            batch_wise_loss_list.append(torch.stack(mc_loss_list).mean(dim=0))
    return torch.stack(batch_wise_loss_list).mean(dim=0)
```
```python
def rescale_loss(loss: Tensor, pmin: float) -> Tensor:
    """
    Rescale a loss value by dividing it by log(1/pmin).

    This is often used in PAC-Bayes settings to keep losses within a certain range
    (e.g., converting losses to the [0, 1] interval) or to improve numerical stability
    when probabilities must be bounded below by `pmin`.

    Args:
        loss (Tensor): A scalar or batched loss tensor.
        pmin (float): A lower bound for probabilities (e.g., 1e-5).

    Returns:
        Tensor: The loss tensor divided by ln(1/pmin).
    """
    return loss / np.log(1.0 / pmin)
```
Rescale a loss value by dividing it by log(1/pmin).
This is often used in PAC-Bayes settings to keep losses within a certain range
(e.g., converting losses to the [0, 1] interval) or to improve numerical stability
when probabilities must be bounded below by `pmin`.
Arguments:
- loss (Tensor): A scalar or batched loss tensor.
- pmin (float): A lower bound for probabilities (e.g., 1e-5).
Returns:
Tensor: The loss tensor divided by ln(1/pmin).
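For a quick sanity check, here is a minimal usage sketch (assuming the `core` package is importable; the loss value is an illustrative stand-in for an NLL computed on `pmin`-bounded probabilities):

```python
import torch

from core.loss import rescale_loss

# If probabilities are bounded below by pmin, the NLL is at most
# log(1/pmin), so dividing by log(1/pmin) maps it into [0, 1].
pmin = 1e-5
raw_nll = torch.tensor(4.2)  # hypothetical bounded NLL value
scaled = rescale_loss(raw_nll, pmin)
print(scaled)  # tensor(0.3648), since log(1/1e-5) ~= 11.51
```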
```python
def nll_loss(outputs: Tensor, targets: Tensor, pmin: float | None = None) -> Tensor:
    """
    Compute the negative log-likelihood (NLL) loss for classification.

    In typical classification settings, `outputs` is the log-probability of each class
    (e.g., from a log-softmax layer), and `targets` are the true class indices.

    Args:
        outputs (Tensor): Log-probabilities of shape (batch_size, num_classes).
        targets (Tensor): Ground truth class indices of shape (batch_size,).
        pmin (float, optional): Not used directly here; kept for uniform interface
            with other loss functions.

    Returns:
        Tensor: A scalar tensor representing the average NLL loss over the batch.
    """
    return f.nll_loss(outputs, targets)
```
Compute the negative log-likelihood (NLL) loss for classification.
In typical classification settings, `outputs` is the log-probability of each class
(e.g., from a log-softmax layer), and `targets` are the true class indices.
Arguments:
- outputs (Tensor): Log-probabilities of shape (batch_size, num_classes).
- targets (Tensor): Ground truth class indices of shape (batch_size,).
- pmin (float, optional): Not used directly here; kept for uniform interface with other loss functions.
Returns:
Tensor: A scalar tensor representing the average NLL loss over the batch.
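A minimal usage sketch with illustrative shapes (3 examples, 4 classes):

```python
import torch
import torch.nn.functional as F

from core.loss import nll_loss

# Random logits converted to log-probabilities, as the docstring expects.
log_probs = F.log_softmax(torch.randn(3, 4), dim=1)
targets = torch.tensor([0, 2, 1])

loss = nll_loss(log_probs, targets)  # equivalent to F.nll_loss(log_probs, targets)
print(loss)  # 0-dim tensor: mean NLL over the batch
```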
```python
def scaled_nll_loss(outputs: Tensor, targets: Tensor, pmin: float) -> Tensor:
    """
    Compute the negative log-likelihood (NLL) loss and then rescale it by log(1/pmin).

    This is a combination of `nll_loss` and `rescale_loss`, often used to ensure
    that the final loss remains within a desired numeric range for PAC-Bayes optimization.

    Args:
        outputs (Tensor): Log-probabilities of shape (batch_size, num_classes).
        targets (Tensor): Ground truth class indices of shape (batch_size,).
        pmin (float): A lower bound for probabilities.

    Returns:
        Tensor: The rescaled NLL loss as a scalar tensor.
    """
    return rescale_loss(nll_loss(outputs, targets), pmin)
```
Compute the negative log-likelihood (NLL) loss and then rescale it by log(1/pmin).
This is a combination of `nll_loss` and `rescale_loss`, often used to ensure
that the final loss remains within a desired numeric range for PAC-Bayes optimization.
Arguments:
- outputs (Tensor): Log-probabilities of shape (batch_size, num_classes).
- targets (Tensor): Ground truth class indices of shape (batch_size,).
- pmin (float): A lower bound for probabilities.
Returns:
Tensor: The rescaled NLL loss as a scalar tensor.
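A minimal sketch of the intended pipeline. The hand-rolled clamp below is an assumption standing in for `bounded_call` (whose actual implementation lives in `core.model`); it only illustrates why the rescaled loss lands in [0, 1]:

```python
import torch
import torch.nn.functional as F

from core.loss import scaled_nll_loss

pmin = 1e-5
# Clamp class probabilities below at pmin before taking logs; this stands in
# for what bounded_call does to model outputs when pmin is supplied.
probs = F.softmax(torch.randn(3, 4), dim=1).clamp(min=pmin)
targets = torch.tensor([0, 2, 1])

loss = scaled_nll_loss(probs.log(), targets, pmin)
print(loss)  # NLL / log(1/pmin), guaranteed to lie in [0, 1] after clamping
```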
```python
def zero_one_loss(outputs: Tensor, targets: Tensor, pmin: float | None = None) -> Tensor:
    """
    Compute the 0-1 classification error.

    This function returns a loss between 0 and 1, where 0 indicates perfect
    classification on the given batch and 1 indicates total misclassification.

    Args:
        outputs (Tensor): Logits or log-probabilities for each class.
        targets (Tensor): Ground truth class indices of shape (batch_size,).
        pmin (float, optional): Not used here; kept for consistency with other losses.

    Returns:
        Tensor: A single-element tensor with the 0-1 error (1 - accuracy).
    """
    predictions = outputs.max(1, keepdim=True)[1]
    correct = predictions.eq(targets.view_as(predictions)).sum().item()
    total = targets.size(0)
    loss_01 = 1 - (correct / total)
    return Tensor([loss_01])
```
Compute the 0-1 classification error.
This function returns a loss between 0 and 1, where 0 indicates perfect classification on the given batch and 1 indicates total misclassification.
Arguments:
- outputs (Tensor): Logits or log-probabilities for each class.
- targets (Tensor): Ground truth class indices of shape (batch_size,).
- pmin (float, optional): Not used here; kept for consistency with other losses.
Returns:
Tensor: A single-element tensor with the 0-1 error (1 - accuracy).
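A small worked example with hand-picked logits:

```python
import torch

from core.loss import zero_one_loss

outputs = torch.tensor([[2.0, 0.1, -1.0],   # argmax -> class 0 (correct)
                        [0.3, 1.5,  0.2],   # argmax -> class 1 (correct)
                        [0.0, 0.2,  0.1]])  # argmax -> class 1 (wrong)
targets = torch.tensor([0, 1, 2])

err = zero_one_loss(outputs, targets)
print(err)  # tensor([0.3333]): one of three predictions is wrong
```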
```python
def compute_losses(
    model: nn.Module,
    bound_loader: torch.utils.data.DataLoader,
    mc_samples: int,
    loss_func_list: list[Callable],
    device: torch.device,
    pmin: float | None = None,
) -> Tensor:
    """
    Compute average losses over multiple Monte Carlo samples for a given dataset.

    This function is typically used to estimate the expected risk under the
    posterior by sampling the model `mc_samples` times for each batch in the
    `bound_loader`.

    Args:
        model (nn.Module): A probabilistic neural network model.
        bound_loader (DataLoader): A DataLoader for the dataset on which
            the losses should be computed (e.g., a bound or test set).
        mc_samples (int): Number of Monte Carlo samples to draw from the posterior
            for each batch.
        loss_func_list (list[Callable]): List of loss functions to evaluate
            (e.g., [nll_loss, zero_one_loss]).
        device (torch.device): The device (CPU/GPU) on which computations are performed.
        pmin (float, optional): A lower bound for probabilities. If provided,
            `bounded_call` will be used to clamp model outputs.

    Returns:
        Tensor: A tensor of shape (len(loss_func_list),) containing the average losses
            across the entire dataset for each loss function. The result is typically
            used to estimate or bound the generalization error in PAC-Bayes experiments.
    """
    with torch.no_grad():
        batch_wise_loss_list = []
        for data, targets in tqdm(bound_loader):
            data, targets = data.to(device), targets.to(device)
            mc_loss_list = []
            for _i in range(mc_samples):
                losses = _compute_losses(model, data, targets, loss_func_list, pmin)
                # .item() handles both 0-dim and single-element loss tensors,
                # so heterogeneous loss functions can be averaged together.
                mc_loss_list.append(torch.tensor([loss.item() for loss in losses]))
            batch_wise_loss_list.append(torch.stack(mc_loss_list).mean(dim=0))
    return torch.stack(batch_wise_loss_list).mean(dim=0)
```
Compute average losses over multiple Monte Carlo samples for a given dataset.
This function is typically used to estimate the expected risk under the
posterior by sampling the model `mc_samples` times for each batch in the `bound_loader`.
Arguments:
- model (nn.Module): A probabilistic neural network model.
- bound_loader (DataLoader): A DataLoader for the dataset on which the losses should be computed (e.g. a bound or test set).
- mc_samples (int): Number of Monte Carlo samples to draw from the posterior for each batch.
- loss_func_list (List[Callable]): List of loss functions to evaluate (e.g., [nll_loss, zero_one_loss]).
- device (torch.device): The device (CPU/GPU) on which computations are performed.
- pmin (float, optional): A lower bound for probabilities. If provided,
`bounded_call` will be used to clamp model outputs.
Returns:
Tensor: A tensor of shape (len(loss_func_list),) containing the average losses across the entire dataset for each loss function. The result is typically used to estimate or bound the generalization error in PAC-Bayes experiments.
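An end-to-end usage sketch; `NoisyClassifier` is a hypothetical stand-in for a probabilistic network whose forward passes are stochastic, which is what makes averaging over `mc_samples` meaningful:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from core.loss import compute_losses, nll_loss, zero_one_loss

class NoisyClassifier(nn.Module):
    """Hypothetical stand-in: perturbs its logits with fresh noise on every
    forward pass, mimicking weight sampling from a posterior."""

    def __init__(self, dim: int = 10, classes: int = 3):
        super().__init__()
        self.fc = nn.Linear(dim, classes)

    def forward(self, x):
        logits = self.fc(x)
        noisy = logits + 0.1 * torch.randn_like(logits)  # new noise per call
        return F.log_softmax(noisy, dim=1)

device = torch.device("cpu")
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 3, (64,)))
loader = DataLoader(dataset, batch_size=16)

model = NoisyClassifier().to(device)
avg_losses = compute_losses(
    model, loader, mc_samples=5,
    loss_func_list=[nll_loss, zero_one_loss], device=device,
)
print(avg_losses)  # shape (2,): [mean NLL, mean 0-1 error] over the dataset
```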