core.loss

from collections.abc import Callable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm

from core.model import bounded_call


def rescale_loss(loss: Tensor, pmin: float) -> Tensor:
    """
    Rescale a loss value by dividing it by log(1/pmin).

    This is often used in PAC-Bayes settings to keep losses within a certain range
    (e.g., converting losses to the [0, 1] interval) or to improve numerical stability
    when probabilities must be bounded below by `pmin`.

    Args:
        loss (Tensor): A scalar or batched loss tensor.
        pmin (float): A lower bound for probabilities (e.g., 1e-5).

    Returns:
        Tensor: The loss tensor divided by ln(1/pmin).
    """
    return loss / np.log(1.0 / pmin)


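# A quick sketch of the rescaling, assuming pmin = 1e-5 (so ln(1/pmin) ~ 11.5129):
# when probabilities are bounded below by pmin, the NLL is capped at ln(1/pmin),
# and dividing by that cap maps it into [0, 1].
#
#     rescale_loss(torch.tensor(11.5129), pmin=1e-5)  # ~ tensor(1.0)
#     rescale_loss(torch.tensor(0.0), pmin=1e-5)      # tensor(0.)

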
def nll_loss(outputs: Tensor, targets: Tensor, pmin: float | None = None) -> Tensor:
    """
    Compute the negative log-likelihood (NLL) loss for classification.

    In typical classification settings, `outputs` is the log-probability of each class
    (e.g., from a log-softmax layer), and `targets` are the true class indices.

    Args:
        outputs (Tensor): Log-probabilities of shape (batch_size, num_classes).
        targets (Tensor): Ground truth class indices of shape (batch_size,).
        pmin (float, optional): Not used directly here; kept for a uniform interface with other loss functions.

    Returns:
        Tensor: A scalar tensor representing the average NLL loss over the batch.
    """
    return F.nll_loss(outputs, targets)


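# Sketch of the expected input format: `outputs` must be log-probabilities,
# e.g., produced with log_softmax (the names below are illustrative only).
#
#     logits = torch.randn(4, 3)
#     log_probs = F.log_softmax(logits, dim=1)
#     targets = torch.tensor([0, 2, 1, 0])
#     nll_loss(log_probs, targets)  # same value as F.cross_entropy(logits, targets)

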
def scaled_nll_loss(outputs: Tensor, targets: Tensor, pmin: float) -> Tensor:
    """
    Compute the negative log-likelihood (NLL) loss and then rescale it by log(1/pmin).

    This is a combination of `nll_loss` and `rescale_loss`, often used to ensure
    that the final loss remains within a desired numeric range for PAC-Bayes optimization.

    Args:
        outputs (Tensor): Log-probabilities of shape (batch_size, num_classes).
        targets (Tensor): Ground truth class indices of shape (batch_size,).
        pmin (float): A lower bound for probabilities.

    Returns:
        Tensor: The rescaled NLL loss as a scalar tensor.
    """
    return rescale_loss(nll_loss(outputs, targets), pmin)


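# Sketch: when `outputs` come from bounded_call (probabilities clamped to
# [pmin, 1]), the NLL is at most ln(1/pmin), so the rescaled loss lies in [0, 1].
#
#     log_probs = torch.full((2, 4), 0.25).log()  # uniform over 4 classes
#     targets = torch.tensor([0, 3])
#     scaled_nll_loss(log_probs, targets, pmin=1e-5)  # ln(4)/ln(1e5) ~ 0.12

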
def zero_one_loss(outputs: Tensor, targets: Tensor, pmin: float | None = None) -> Tensor:
    """
    Compute the 0-1 classification error.

    This function returns a loss between 0 and 1, where 0 indicates perfect
    classification on the given batch and 1 indicates total misclassification.

    Args:
        outputs (Tensor): Logits or log-probabilities for each class.
        targets (Tensor): Ground truth class indices of shape (batch_size,).
        pmin (float, optional): Not used here; kept for consistency with other losses.

    Returns:
        Tensor: A scalar tensor with the 0-1 error (1 - accuracy).
    """
    predictions = outputs.argmax(dim=1, keepdim=True)
    correct = predictions.eq(targets.view_as(predictions)).sum().item()
    total = targets.size(0)
    loss_01 = 1 - (correct / total)
    # Keep the result on the same device as the model outputs so it can be
    # stacked with the other losses.
    return torch.tensor(loss_01, device=outputs.device)


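# Sketch on a toy batch: the predicted class is the row-wise argmax.
#
#     outputs = torch.tensor([[2.0, 0.1], [0.2, 1.5], [0.3, 0.9]])
#     targets = torch.tensor([0, 1, 0])
#     zero_one_loss(outputs, targets)  # predictions [0, 1, 1] -> 1 - 2/3 ~ 0.3333

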
def _compute_losses(
    model: nn.Module,
    inputs: Tensor,
    targets: Tensor,
    loss_func_list: list[Callable],
    pmin: float | None = None,
) -> list[Tensor]:
    """
    Compute a list of loss values for a single forward pass of the model.

    This function optionally applies a bounded call if `pmin` is specified, then
    evaluates each loss function in `loss_func_list` on the model outputs.

    Args:
        model (nn.Module): A (probabilistic) neural network model.
        inputs (Tensor): Input data for one batch, of shape (batch_size, ...).
        targets (Tensor): Ground truth labels for the batch.
        loss_func_list (List[Callable]): A list of loss functions to compute
            (e.g., [nll_loss, zero_one_loss]).
        pmin (float, optional): A lower bound for probabilities. If given,
            `bounded_call` is used before computing losses.

    Returns:
        List[Tensor]: A list of scalar loss tensors, each corresponding to one function
            in `loss_func_list`.
    """
    if pmin is not None:
        # Bound the output probabilities to lie in [pmin, 1].
        outputs = bounded_call(model, inputs, pmin)
    else:
        outputs = model(inputs)
    losses = []
    for loss_func in loss_func_list:
        loss = (
            loss_func(outputs, targets, pmin)
            if pmin is not None
            else loss_func(outputs, targets)
        )
        losses.append(loss)
    return losses


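# Sketch of a single-batch call; with pmin given, bounded_call clamps the
# output probabilities before each loss in the list is evaluated.
#
#     losses = _compute_losses(model, data, targets,
#                              [scaled_nll_loss, zero_one_loss], pmin=1e-5)
#     # -> [scaled NLL tensor, 0-1 error tensor]

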
def compute_losses(
    model: nn.Module,
    bound_loader: torch.utils.data.DataLoader,
    mc_samples: int,
    loss_func_list: list[Callable],
    device: torch.device,
    pmin: float | None = None,
) -> Tensor:
    """
    Compute average losses over multiple Monte Carlo samples for a given dataset.

    This function is typically used to estimate the expected risk under the
    posterior by sampling the model `mc_samples` times for each batch in the `bound_loader`.

    Args:
        model (nn.Module): A probabilistic neural network model.
        bound_loader (DataLoader): A DataLoader for the dataset on which
            the losses should be computed (e.g., a bound or test set).
        mc_samples (int): Number of Monte Carlo samples to draw from the posterior
            for each batch.
        loss_func_list (List[Callable]): List of loss functions to evaluate
            (e.g., [nll_loss, zero_one_loss]).
        device (torch.device): The device (CPU/GPU) on which computations are performed.
        pmin (float, optional): A lower bound for probabilities. If provided,
            `bounded_call` will be used to clamp model outputs.

    Returns:
        Tensor: A tensor of shape (len(loss_func_list),) containing the average losses
            across the entire dataset for each loss function. The result is typically
            used to estimate or bound the generalization error in PAC-Bayes experiments.
    """
    with torch.no_grad():
        batch_wise_loss_list = []
        for data, targets in tqdm(bound_loader):
            data, targets = data.to(device), targets.to(device)
            mc_loss_list = []
            for _ in range(mc_samples):
                losses = _compute_losses(model, data, targets, loss_func_list, pmin)
                mc_loss_list.append(torch.stack(losses))
            batch_wise_loss_list.append(torch.stack(mc_loss_list).mean(dim=0))
    # Note: the unweighted mean over batches assumes equal batch sizes; a smaller
    # final batch would otherwise be slightly over-weighted.
    return torch.stack(batch_wise_loss_list).mean(dim=0)
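# End-to-end usage sketch. `ProbNet` is a hypothetical stand-in for a
# probabilistic model from core.model whose forward pass samples fresh weights;
# the dataset below is a random placeholder.
#
#     from torch.utils.data import DataLoader, TensorDataset
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     dataset = TensorDataset(torch.randn(256, 20), torch.randint(0, 3, (256,)))
#     loader = DataLoader(dataset, batch_size=64)
#     model = ProbNet().to(device)  # hypothetical model
#     avg_losses = compute_losses(
#         model, loader, mc_samples=10,
#         loss_func_list=[scaled_nll_loss, zero_one_loss],
#         device=device, pmin=1e-5,
#     )
#     scaled_nll, err_01 = avg_losses.tolist()  # one entry per loss function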