core.metric

import logging
from collections.abc import Callable

import torch
from torch import Tensor, nn

import wandb
from core.loss import compute_losses


def evaluate_metrics(
    model: nn.Module,
    metrics: dict[str, Callable],
    test_loader: torch.utils.data.DataLoader,
    num_samples_metric: int,
    device: torch.device,
    pmin: float = 1e-5,
    wandb_params: dict | None = None,
) -> dict[str, Tensor]:
20    """
21    Evaluate a set of metric functions on a test set with multiple Monte Carlo samples.
22
23    This function uses `compute_losses` under the hood to compute each metric
24    (e.g., NLL, 0-1 error) over `num_samples_metric` samples from the posterior.
25    Optionally logs the results to Weights & Biases (wandb).
26
27    Args:
28        model (nn.Module): A probabilistic neural network model.
29        metrics (Dict[str, Callable]): A dictionary mapping metric names
30            to metric functions (e.g., {"zero_one": zero_one_loss}).
31        test_loader (DataLoader): DataLoader for the test/validation dataset.
32        num_samples_metric (int): Number of Monte Carlo samples to draw
33            when evaluating each metric on the test set.
34        device (torch.device): The device (CPU/GPU) to run computations on.
35        pmin (float, optional): A lower bound for probabilities. If specified,
36            `bounded_call` is applied to model outputs.
37        wandb_params (Dict, optional): Configuration for logging to wandb.
38            Expects keys:
39            - "log_wandb": bool, whether to log or not
40            - "name_wandb": str, prefix for logging metrics
41
42    Returns:
43        Dict[str, Tensor]: A dictionary mapping each metric name to its average value
44        across the entire test dataset and all Monte Carlo samples.
45    """
    avg_metrics = compute_losses(
        model=model,
        bound_loader=test_loader,
        mc_samples=num_samples_metric,
        loss_func_list=list(metrics.values()),
        pmin=pmin,
        device=device,
    )
    avg_metrics = dict(zip(metrics.keys(), avg_metrics, strict=False))
    logging.info("Average metrics:")
    logging.info(avg_metrics)
    if wandb_params is not None and wandb_params["log_wandb"]:
        for name, metric in avg_metrics.items():
            wandb.log({f"{wandb_params['name_wandb']}/{name}": metric.item()})
    return avg_metrics
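
A note on `pmin`: bounding predicted probabilities away from zero keeps log-based metrics such as the NLL finite. The actual `bounded_call` is defined elsewhere in this codebase; the sketch below is only a guess at the kind of transformation it performs, assuming the model outputs log-probabilities:

import math

import torch
from torch import Tensor


def bounded_call_sketch(log_probs: Tensor, pmin: float) -> Tensor:
    # Illustrative guess, not the real implementation: clamp log-probabilities
    # from below at log(pmin) so that no class probability falls under pmin.
    return torch.clamp(log_probs, min=math.log(pmin))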
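For orientation, here is a minimal usage sketch. Everything model- and metric-specific is a hypothetical stand-in (the `nn.Sequential` model, the random dataset, `zero_one_loss`), and it assumes `compute_losses` calls each metric as `metric(outputs, targets)`; see `core.loss` for the actual contract:

import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader, TensorDataset

from core.metric import evaluate_metrics


def zero_one_loss(outputs: Tensor, targets: Tensor) -> Tensor:
    # Hypothetical metric: fraction of misclassified examples in the batch.
    return (outputs.argmax(dim=-1) != targets).float().mean()


# Hypothetical stand-ins: a deterministic model and a random dataset. In
# practice the model would be probabilistic, drawing a fresh weight sample
# on each forward pass.
model = nn.Sequential(nn.Linear(20, 10), nn.LogSoftmax(dim=-1))
test_data = TensorDataset(torch.randn(256, 20), torch.randint(0, 10, (256,)))
test_loader = DataLoader(test_data, batch_size=64)

avg = evaluate_metrics(
    model=model,
    metrics={"zero_one": zero_one_loss},
    test_loader=test_loader,
    num_samples_metric=10,
    device=torch.device("cpu"),
    wandb_params={"log_wandb": False, "name_wandb": "eval"},
)
print({name: value.item() for name, value in avg.items()})

With "log_wandb" set to True, each averaged metric would instead be logged to wandb under a key such as "eval/zero_one".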