core.metric
import logging
from collections.abc import Callable

import torch
from torch import Tensor, nn

import wandb
from core.loss import compute_losses


def evaluate_metrics(
    model: nn.Module,
    metrics: dict[str, Callable],
    test_loader: torch.utils.data.DataLoader,
    num_samples_metric: int,
    device: torch.device,
    pmin: float = 1e-5,
    wandb_params: dict | None = None,
) -> dict[str, Tensor]:
    """
    Evaluate a set of metric functions on a test set with multiple Monte Carlo samples.

    This function uses `compute_losses` under the hood to compute each metric
    (e.g., NLL, 0-1 error) over `num_samples_metric` samples from the posterior.
    Optionally logs the results to Weights & Biases (wandb).

    Args:
        model (nn.Module): A probabilistic neural network model.
        metrics (Dict[str, Callable]): A dictionary mapping metric names
            to metric functions (e.g., {"zero_one": zero_one_loss}).
        test_loader (DataLoader): DataLoader for the test/validation dataset.
        num_samples_metric (int): Number of Monte Carlo samples to draw
            when evaluating each metric on the test set.
        device (torch.device): The device (CPU/GPU) to run computations on.
        pmin (float, optional): A lower bound for probabilities. If specified,
            `bounded_call` is applied to model outputs.
        wandb_params (Dict, optional): Configuration for logging to wandb.
            Expects keys:
                - "log_wandb": bool, whether to log or not
                - "name_wandb": str, prefix for logging metrics

    Returns:
        Dict[str, Tensor]: A dictionary mapping each metric name to its average value
            across the entire test dataset and all Monte Carlo samples.
    """
    avg_metrics = compute_losses(
        model=model,
        bound_loader=test_loader,
        mc_samples=num_samples_metric,
        loss_func_list=list(metrics.values()),
        pmin=pmin,
        device=device,
    )
    # Pair each metric name with its averaged value, preserving dict order.
    avg_metrics = dict(zip(metrics.keys(), avg_metrics, strict=False))
    logging.info("Average metrics:")
    logging.info(avg_metrics)
    if wandb_params is not None and wandb_params["log_wandb"]:
        for name, metric in avg_metrics.items():
            wandb.log({f"{wandb_params['name_wandb']}/{name}": metric.item()})
    return avg_metrics
def evaluate_metrics(model: nn.Module, metrics: dict[str, Callable], test_loader: torch.utils.data.DataLoader, num_samples_metric: int, device: torch.device, pmin: float = 1e-5, wandb_params: dict | None = None) -> dict[str, Tensor]:
Evaluate a set of metric functions on a test set with multiple Monte Carlo samples.
This function uses compute_losses under the hood to compute each metric
(e.g., NLL, 0-1 error) over num_samples_metric samples from the posterior.
Optionally logs the results to Weights & Biases (wandb).
Arguments:
- model (nn.Module): A probabilistic neural network model.
- metrics (Dict[str, Callable]): A dictionary mapping metric names to metric functions (e.g., {"zero_one": zero_one_loss}).
- test_loader (DataLoader): DataLoader for the test/validation dataset.
- num_samples_metric (int): Number of Monte Carlo samples to draw when evaluating each metric on the test set.
- device (torch.device): The device (CPU/GPU) to run computations on.
- pmin (float, optional): A lower bound for probabilities. If specified, bounded_call is applied to model outputs.
- wandb_params (Dict, optional): Configuration for logging to wandb.
Expects keys:
- "log_wandb": bool, whether to log or not
- "name_wandb": str, prefix for logging metrics
Returns:
Dict[str, Tensor]: A dictionary mapping each metric name to its average value across the entire test dataset and all Monte Carlo samples.
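Example usage (a minimal sketch): the zero_one_loss below is hypothetical, and its signature (model outputs and targets in, scalar Tensor out) is an assumption about what compute_losses passes to each metric function; check core.loss for the actual contract. The nn.Linear model is a deterministic stand-in used only to make the snippet self-contained; in real use, evaluate_metrics expects a probabilistic network whose Monte Carlo samples differ.

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader, TensorDataset

from core.metric import evaluate_metrics


def zero_one_loss(outputs: Tensor, targets: Tensor) -> Tensor:
    # Hypothetical metric: fraction of misclassified examples in a batch.
    return (outputs.argmax(dim=-1) != targets).float().mean()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 3).to(device)  # deterministic stand-in for a probabilistic net
dataset = TensorDataset(torch.randn(32, 4), torch.randint(0, 3, (32,)))
loader = DataLoader(dataset, batch_size=8)

avg_metrics = evaluate_metrics(
    model=model,
    metrics={"zero_one": zero_one_loss},
    test_loader=loader,
    num_samples_metric=5,
    device=device,
    wandb_params=None,  # skip wandb logging
)
print({name: value.item() for name, value in avg_metrics.items()})

Passing wandb_params=None skips logging entirely; to log, supply {"log_wandb": True, "name_wandb": "test"} so each metric appears under the "test/" prefix in wandb.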