core.risk
```python
import logging
from collections.abc import Callable

import torch
from torch import Tensor, nn

import wandb
from core.bound import AbstractBound
from core.distribution.utils import DistributionT, compute_kl
from core.loss import compute_losses


def certify_risk(
    model: nn.Module,
    bounds: dict[str, AbstractBound],
    losses: dict[str, Callable],
    posterior: DistributionT,
    prior: DistributionT,
    bound_loader: torch.utils.data.DataLoader,
    num_samples_loss: int,
    device: torch.device,
    pmin: float = 1e-5,
    wandb_params: dict | None = None,
) -> dict[str, dict[str, dict[str, Tensor]]]:
    """
    Certify (evaluate) the generalization risk of a probabilistic neural network
    using one or more PAC-Bayes bounds on a given dataset.

    Steps:
        1) Compute average losses (e.g., NLL, 0-1 error) via multiple Monte Carlo
           samples from the posterior (`compute_losses`).
        2) Calculate the KL divergence between the posterior and prior distributions.
        3) For each bound in `bounds`, calculate a PAC-Bayes risk bound for each
           loss in `losses`.
        4) Optionally log intermediate results (loss, risk) to Weights & Biases (wandb).

    Args:
        model (nn.Module): The probabilistic neural network used for risk evaluation.
        bounds (Dict[str, AbstractBound]): A mapping from bound names to bound objects
            that implement a PAC-Bayes bound (`AbstractBound`).
        losses (Dict[str, Callable]): A mapping from loss names to loss functions
            (e.g., {"nll": nll_loss, "01": zero_one_loss}).
        posterior (DistributionT): Posterior distribution of the model parameters.
        prior (DistributionT): Prior distribution of the model parameters.
        bound_loader (DataLoader): DataLoader for the dataset on which bounds and
            losses are computed.
        num_samples_loss (int): Number of Monte Carlo samples to draw from the
            posterior for estimating the average losses.
        device (torch.device): The device (CPU or GPU) to perform computations on.
        pmin (float, optional): A minimum probability bound for clamping model outputs
            in log space. Defaults to 1e-5.
        wandb_params (Dict, optional): Configuration for Weights & Biases logging.
            Expects keys:
            - "log_wandb": bool, whether to log
            - "name_wandb": str, prefix for metric names

    Returns:
        Dict[str, Dict[str, Dict[str, Tensor]]]: A nested dictionary of the form:
            {
                bound_name: {
                    loss_name: {
                        'risk': risk_value,
                        'loss': avg_loss_value
                    }
                }
            }
        where `risk_value` is the computed bound on the risk, and `avg_loss_value`
        is the empirical loss estimate for that loss and bound.
    """
    # Monte Carlo estimate of each requested loss on the bound dataset.
    avg_losses = compute_losses(
        model=model,
        bound_loader=bound_loader,
        mc_samples=num_samples_loss,
        loss_func_list=list(losses.values()),
        pmin=pmin,
        device=device,
    )
    avg_losses = dict(zip(losses.keys(), avg_losses, strict=False))
    logging.info("Average losses:")
    logging.info(avg_losses)

    # Evaluate bound
    kl = compute_kl(dist1=posterior, dist2=prior)
    num_samples_bound = len(bound_loader.sampler)

    result = {}
    for bound_name, bound in bounds.items():
        logging.info(f"Bound name: {bound_name}")
        result[bound_name] = {}
        for loss_name, avg_loss in avg_losses.items():
            risk, loss = bound.calculate(
                avg_loss=avg_loss,
                kl=kl,
                num_samples_bound=num_samples_bound,
                num_samples_loss=num_samples_loss,
            )
            result[bound_name][loss_name] = {"risk": risk, "loss": loss}
            logging.info(
                f"Loss name: {loss_name}, "
                f"Risk: {risk.item():.5f}, "
                f"Loss: {loss.item():.5f}, "
                f"KL per sample bound: {kl / num_samples_bound:.5f}"
            )
            if wandb_params is not None and wandb_params["log_wandb"]:
                wandb.log(
                    {
                        f"{wandb_params['name_wandb']}/{bound_name}/{loss_name}_loss": loss.item(),
                        f"{wandb_params['name_wandb']}/{bound_name}/{loss_name}_risk": risk.item(),
                    }
                )
    if wandb_params is not None and wandb_params["log_wandb"]:
        wandb.log({f"{wandb_params['name_wandb']}/KL-n/": kl / num_samples_bound})

    return result
```
Certify (evaluate) the generalization risk of a probabilistic neural network using one or more PAC-Bayes bounds on a given dataset.
Steps:
1) Compute average losses (e.g., NLL, 0-1 error) via multiple Monte Carlo samples from the posterior (`compute_losses`).
2) Calculate the KL divergence between the posterior and prior distributions.
3) For each bound in `bounds`, calculate a PAC-Bayes risk bound for each loss in `losses` (see the sketch below).
4) Optionally log intermediate results (loss, risk) to Weights & Biases (wandb).
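For step 3, each bound object is consumed through a `calculate(avg_loss=..., kl=..., num_samples_bound=..., num_samples_loss=...)` call, as visible in the source above. As a concrete illustration, here is a minimal sketch of such an object implementing the classic McAllester bound, risk <= avg_loss + sqrt((KL + ln(2*sqrt(n)/delta)) / (2n)). The class itself is a hypothetical stand-in, not one of the repository's `AbstractBound` implementations; only the `calculate` signature is taken from `certify_risk`:

```python
import math

import torch
from torch import Tensor


class McAllesterBound:
    """Illustrative PAC-Bayes bound object (hypothetical, not a repo class)."""

    def __init__(self, delta: float = 0.05):
        self.delta = delta  # confidence parameter: bound holds with prob. 1 - delta

    def calculate(
        self,
        avg_loss: Tensor,
        kl: Tensor,
        num_samples_bound: int,
        num_samples_loss: int,  # unused here; real bounds may use it to account for MC error
    ) -> tuple[Tensor, Tensor]:
        n = num_samples_bound
        # McAllester complexity term: sqrt((KL + ln(2*sqrt(n)/delta)) / (2n)).
        complexity = torch.sqrt((kl + math.log(2 * math.sqrt(n) / self.delta)) / (2 * n))
        risk = avg_loss + complexity
        # Return (risk bound, empirical loss), matching how certify_risk unpacks it.
        return risk, avg_loss
```

Note that the repository's actual bounds may additionally tighten or correct for the Monte Carlo loss estimate via `num_samples_loss`, which this sketch ignores.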
Arguments:
- model (nn.Module): The probabilistic neural network used for risk evaluation.
- bounds (Dict[str, AbstractBound]): A mapping from bound names to bound objects
  that implement a PAC-Bayes bound (`AbstractBound`).
- losses (Dict[str, Callable]): A mapping from loss names to loss functions (e.g., `{"nll": nll_loss, "01": zero_one_loss}`).
- posterior (DistributionT): Posterior distribution of the model parameters.
- prior (DistributionT): Prior distribution of the model parameters.
- bound_loader (DataLoader): DataLoader for the dataset on which bounds and losses are computed.
- num_samples_loss (int): Number of Monte Carlo samples to draw from the posterior for estimating the average losses.
- device (torch.device): The device (CPU or GPU) to perform computations on.
- pmin (float, optional): A minimum probability bound for clamping model outputs in log space. Defaults to 1e-5.
- wandb_params (Dict, optional): Configuration for Weights & Biases logging. Expects keys:
- "log_wandb": bool, whether to log
- "name_wandb": str, prefix for metric names
Returns:
Dict[str, Dict[str, Dict[str, Tensor]]]: A nested dictionary of the form `{bound_name: {loss_name: {'risk': risk_value, 'loss': avg_loss_value}}}`, where `risk_value` is the computed bound on the risk, and `avg_loss_value` is the empirical loss estimate for that loss and bound.
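Putting it together, a hedged usage sketch: `model`, `posterior`, `prior`, `bound_loader`, the loss functions, and the bound name "mcallester" are hypothetical placeholders; only the call signature and the result layout come from the source above.

```python
import torch

# Placeholder objects; these constructors are not part of this module.
bounds = {"mcallester": McAllesterBound(delta=0.05)}  # sketch class from above
losses = {"nll": nll_loss, "01": zero_one_loss}       # as in the docstring example

result = certify_risk(
    model=model,                # trained probabilistic nn.Module
    bounds=bounds,
    losses=losses,
    posterior=posterior,        # DistributionT over model parameters
    prior=prior,
    bound_loader=bound_loader,  # DataLoader over the certification split
    num_samples_loss=1000,      # Monte Carlo samples for the loss estimate
    device=torch.device("cpu"),
)

# Nested as {bound_name: {loss_name: {"risk": ..., "loss": ...}}}.
print(f'0-1 risk bound: {result["mcallester"]["01"]["risk"].item():.5f}')
```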