core.risk

  1import logging
  2from collections.abc import Callable
  3
  4import torch
  5from torch import Tensor, nn
  6
  7import wandb
  8from core.bound import AbstractBound
  9from core.distribution.utils import DistributionT, compute_kl
 10from core.loss import compute_losses
 11
 12
 13def certify_risk(
 14    model: nn.Module,
 15    bounds: dict[str, AbstractBound],
 16    losses: dict[str, Callable],
 17    posterior: DistributionT,
 18    prior: DistributionT,
 19    bound_loader: torch.utils.data.dataloader.DataLoader,
 20    num_samples_loss: int,
 21    device: torch.device,
 22    pmin: float = 1e-5,
 23    wandb_params: dict = None,
 24) -> dict[str, dict[str, dict[str, Tensor]]]:
 25    """
 26    Certify (evaluate) the generalization risk of a probabilistic neural network
 27    using one or more PAC-Bayes bounds on a given dataset.
 28
 29    Steps:
 30      1) Compute average losses (e.g., NLL, 0-1 error) via multiple Monte Carlo samples
 31         from the posterior (`compute_losses`).
 32      2) Calculate the KL divergence between the posterior and prior distributions.
 33      3) For each bound in `bounds`, calculate a PAC-Bayes risk bound for each loss in `losses`.
 34      4) Optionally log intermediate results (loss, risk) to Weights & Biases (wandb).
 35
 36    Args:
 37        model (nn.Module): The probabilistic neural network used for risk evaluation.
 38        bounds (Dict[str, AbstractBound]): A mapping from bound names to bound objects
 39            that implement a PAC-Bayes bound (`AbstractBound`).
 40        losses (Dict[str, Callable]): A mapping from loss names to loss functions
 41            (e.g., {"nll": nll_loss, "01": zero_one_loss}).
 42        posterior (DistributionT): Posterior distribution of the model parameters.
 43        prior (DistributionT): Prior distribution of the model parameters.
 44        bound_loader (DataLoader): DataLoader for the dataset on which bounds and losses are computed.
 45        num_samples_loss (int): Number of Monte Carlo samples to draw from the posterior
 46            for estimating the average losses.
 47        device (torch.device): The device (CPU or GPU) to perform computations on.
 48        pmin (float, optional): A minimum probability bound for clamping model outputs in log space.
 49            Defaults to 1e-5.
 50        wandb_params (Dict, optional): Configuration for Weights & Biases logging. Expects keys:
 51            - "log_wandb": bool, whether to log
 52            - "name_wandb": str, prefix for metric names
 53
 54    Returns:
 55        Dict[str, Dict[str, Dict[str, Tensor]]]: A nested dictionary of the form:
 56            {
 57              bound_name: {
 58                loss_name: {
 59                  'risk': risk_value,
 60                  'loss': avg_loss_value
 61                }
 62              }
 63            }
 64        where `risk_value` is the computed bound on the risk, and `avg_loss_value` is the
 65        empirical loss estimate for that loss and bound.
 66    """
 67    avg_losses = compute_losses(
 68        model=model,
 69        bound_loader=bound_loader,
 70        mc_samples=num_samples_loss,
 71        loss_func_list=list(losses.values()),
 72        pmin=pmin,
 73        device=device,
 74    )
 75    avg_losses = dict(zip(losses.keys(), avg_losses, strict=False))
 76    logging.info("Average losses:")
 77    logging.info(avg_losses)
 78
 79    # Evaluate bound
 80    kl = compute_kl(dist1=posterior, dist2=prior)
 81    num_samples_bound = len(bound_loader.sampler)
 82
 83    result = {}
 84    for bound_name, bound in bounds.items():
 85        logging.info(f"Bound name: {bound_name}")
 86        result[bound_name] = {}
 87        for loss_name, avg_loss in avg_losses.items():
 88            risk, loss = bound.calculate(
 89                avg_loss=avg_loss,
 90                kl=kl,
 91                num_samples_bound=num_samples_bound,
 92                num_samples_loss=num_samples_loss,
 93            )
 94            result[bound_name][loss_name] = {"risk": risk, "loss": loss}
 95            logging.info(
 96                f"Loss name: {loss_name}, "
 97                f"Risk: {risk.item():.5f}, "
 98                f"Loss: {loss.item():.5f}, "
 99                f"KL per sample bound: {kl / num_samples_bound:.5f}"
100            )
101            if wandb_params is not None and wandb_params["log_wandb"]:
102                wandb.log(
103                    {
104                        f"{wandb_params['name_wandb']}/{bound_name}/{loss_name}_loss": loss.item(),
105                        f"{wandb_params['name_wandb']}/{bound_name}/{loss_name}_risk": risk.item(),
106                    }
107                )
108    if wandb_params is not None and wandb_params["log_wandb"]:
109        wandb.log({f"{wandb_params['name_wandb']}/KL-n/": kl / num_samples_bound})
110
111    return result
def certify_risk( model: torch.nn.modules.module.Module, bounds: dict[str, core.bound.AbstractBound.AbstractBound], losses: dict[str, Callable], posterior: dict[tuple[str, ...], dict[str, core.distribution.AbstractVariable.AbstractVariable]], prior: dict[tuple[str, ...], dict[str, core.distribution.AbstractVariable.AbstractVariable]], bound_loader: torch.utils.data.dataloader.DataLoader, num_samples_loss: int, device: torch.device, pmin: float = 1e-05, wandb_params: dict = None) -> dict[str, dict[str, dict[str, torch.Tensor]]]:
 14def certify_risk(
 15    model: nn.Module,
 16    bounds: dict[str, AbstractBound],
 17    losses: dict[str, Callable],
 18    posterior: DistributionT,
 19    prior: DistributionT,
 20    bound_loader: torch.utils.data.dataloader.DataLoader,
 21    num_samples_loss: int,
 22    device: torch.device,
 23    pmin: float = 1e-5,
 24    wandb_params: dict = None,
 25) -> dict[str, dict[str, dict[str, Tensor]]]:
 26    """
 27    Certify (evaluate) the generalization risk of a probabilistic neural network
 28    using one or more PAC-Bayes bounds on a given dataset.
 29
 30    Steps:
 31      1) Compute average losses (e.g., NLL, 0-1 error) via multiple Monte Carlo samples
 32         from the posterior (`compute_losses`).
 33      2) Calculate the KL divergence between the posterior and prior distributions.
 34      3) For each bound in `bounds`, calculate a PAC-Bayes risk bound for each loss in `losses`.
 35      4) Optionally log intermediate results (loss, risk) to Weights & Biases (wandb).
 36
 37    Args:
 38        model (nn.Module): The probabilistic neural network used for risk evaluation.
 39        bounds (Dict[str, AbstractBound]): A mapping from bound names to bound objects
 40            that implement a PAC-Bayes bound (`AbstractBound`).
 41        losses (Dict[str, Callable]): A mapping from loss names to loss functions
 42            (e.g., {"nll": nll_loss, "01": zero_one_loss}).
 43        posterior (DistributionT): Posterior distribution of the model parameters.
 44        prior (DistributionT): Prior distribution of the model parameters.
 45        bound_loader (DataLoader): DataLoader for the dataset on which bounds and losses are computed.
 46        num_samples_loss (int): Number of Monte Carlo samples to draw from the posterior
 47            for estimating the average losses.
 48        device (torch.device): The device (CPU or GPU) to perform computations on.
 49        pmin (float, optional): A minimum probability bound for clamping model outputs in log space.
 50            Defaults to 1e-5.
 51        wandb_params (Dict, optional): Configuration for Weights & Biases logging. Expects keys:
 52            - "log_wandb": bool, whether to log
 53            - "name_wandb": str, prefix for metric names
 54
 55    Returns:
 56        Dict[str, Dict[str, Dict[str, Tensor]]]: A nested dictionary of the form:
 57            {
 58              bound_name: {
 59                loss_name: {
 60                  'risk': risk_value,
 61                  'loss': avg_loss_value
 62                }
 63              }
 64            }
 65        where `risk_value` is the computed bound on the risk, and `avg_loss_value` is the
 66        empirical loss estimate for that loss and bound.
 67    """
 68    avg_losses = compute_losses(
 69        model=model,
 70        bound_loader=bound_loader,
 71        mc_samples=num_samples_loss,
 72        loss_func_list=list(losses.values()),
 73        pmin=pmin,
 74        device=device,
 75    )
 76    avg_losses = dict(zip(losses.keys(), avg_losses, strict=False))
 77    logging.info("Average losses:")
 78    logging.info(avg_losses)
 79
 80    # Evaluate bound
 81    kl = compute_kl(dist1=posterior, dist2=prior)
 82    num_samples_bound = len(bound_loader.sampler)
 83
 84    result = {}
 85    for bound_name, bound in bounds.items():
 86        logging.info(f"Bound name: {bound_name}")
 87        result[bound_name] = {}
 88        for loss_name, avg_loss in avg_losses.items():
 89            risk, loss = bound.calculate(
 90                avg_loss=avg_loss,
 91                kl=kl,
 92                num_samples_bound=num_samples_bound,
 93                num_samples_loss=num_samples_loss,
 94            )
 95            result[bound_name][loss_name] = {"risk": risk, "loss": loss}
 96            logging.info(
 97                f"Loss name: {loss_name}, "
 98                f"Risk: {risk.item():.5f}, "
 99                f"Loss: {loss.item():.5f}, "
100                f"KL per sample bound: {kl / num_samples_bound:.5f}"
101            )
102            if wandb_params is not None and wandb_params["log_wandb"]:
103                wandb.log(
104                    {
105                        f"{wandb_params['name_wandb']}/{bound_name}/{loss_name}_loss": loss.item(),
106                        f"{wandb_params['name_wandb']}/{bound_name}/{loss_name}_risk": risk.item(),
107                    }
108                )
109    if wandb_params is not None and wandb_params["log_wandb"]:
110        wandb.log({f"{wandb_params['name_wandb']}/KL-n/": kl / num_samples_bound})
111
112    return result

Certify (evaluate) the generalization risk of a probabilistic neural network using one or more PAC-Bayes bounds on a given dataset.

Steps:

1) Compute average losses (e.g., NLL, 0-1 error) via multiple Monte Carlo samples from the posterior (compute_losses).
2) Calculate the KL divergence between the posterior and prior distributions.
3) For each bound in bounds, calculate a PAC-Bayes risk bound for each loss in losses.
4) Optionally log intermediate results (loss, risk) to Weights & Biases (wandb).

Arguments:
  • model (nn.Module): The probabilistic neural network used for risk evaluation.
  • bounds (Dict[str, AbstractBound]): A mapping from bound names to bound objects that implement a PAC-Bayes bound (AbstractBound).
  • losses (Dict[str, Callable]): A mapping from loss names to loss functions (e.g., {"nll": nll_loss, "01": zero_one_loss}).
  • posterior (DistributionT): Posterior distribution of the model parameters.
  • prior (DistributionT): Prior distribution of the model parameters.
  • bound_loader (DataLoader): DataLoader for the dataset on which bounds and losses are computed.
  • num_samples_loss (int): Number of Monte Carlo samples to draw from the posterior for estimating the average losses.
  • device (torch.device): The device (CPU or GPU) to perform computations on.
  • pmin (float, optional): A minimum probability bound for clamping model outputs in log space. Defaults to 1e-5.
  • wandb_params (Dict, optional): Configuration for Weights & Biases logging. Expects keys:
    • "log_wandb": bool, whether to log
    • "name_wandb": str, prefix for metric names
Returns:

Dict[str, Dict[str, Dict[str, Tensor]]]: A nested dictionary of the form:

    {
      bound_name: {
        loss_name: {
          'risk': risk_value,
          'loss': avg_loss_value
        }
      }
    }

where risk_value is the computed bound on the risk, and avg_loss_value is the empirical loss estimate for that loss and bound.