core.model

from collections.abc import Callable, Iterator

import numpy as np
import torch
from torch import Tensor, nn

from core.distribution.utils import DistributionT
from core.layer import LAYER_MAPPING, AbstractProbLayer
from core.layer.utils import LayerNameT, get_torch_layers


def bounded_call(model: nn.Module, data: Tensor, pmin: float) -> Tensor:
    """
    Forward data through the model and clamp the output to a minimum log-probability.

    This is typically used to avoid numerical instability in PAC-Bayes experiments
    when dealing with small probabilities. The output is clamped at log(pmin) to
    ensure log-probabilities do not fall below this threshold.

    Args:
        model (nn.Module): The (probabilistic) neural network model.
        data (Tensor): Input data of shape (batch_size, ...).
        pmin (float): A lower bound for probabilities. Outputs are clamped at log(pmin).

    Returns:
        Tensor: Model outputs with each element >= log(pmin).
    """
    return torch.clamp(model(data), min=np.log(pmin))
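
For example, with a classifier whose forward pass ends in a log-softmax, bounded_call keeps every output at or above log(pmin). A minimal sketch (the toy model below is illustrative, not part of this module):

    import numpy as np
    import torch
    from torch import nn

    # Toy classifier returning log-probabilities (illustrative only).
    toy_model = nn.Sequential(nn.Linear(10, 3), nn.LogSoftmax(dim=1))
    x = torch.randn(4, 10)

    log_probs = bounded_call(toy_model, x, pmin=1e-5)
    assert torch.all(log_probs >= np.log(1e-5))  # no value below log(pmin)
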
def dnn_to_probnn(
    model: nn.Module,
    weight_dist: DistributionT,
    prior_weight_dist: DistributionT,
    get_layers_func: Callable[
        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
    ] = get_torch_layers,
):
    """
    Convert a deterministic PyTorch model into a probabilistic neural network (ProbNN)
    by attaching weight/bias distributions to its layers.

    This function iterates through each layer in the model (using `get_layers_func`),
    and if the layer type is supported, it:
      - Registers prior and posterior distributions for weights and biases.
      - Marks the layer as probabilistic (so that it samples weights/biases in forward calls).
      - Replaces the layer class with its probabilistic counterpart from `LAYER_MAPPING`.

    Args:
        model (nn.Module): A deterministic PyTorch model (e.g., a CNN).
        weight_dist (DistributionT): A dictionary containing posterior distributions
            for weights and biases, keyed by layer name.
        prior_weight_dist (DistributionT): A dictionary containing prior distributions
            for weights and biases, keyed by layer name.
        get_layers_func (Callable): A function that returns an iterator of (layer_name, layer_module)
            pairs. Defaults to `get_torch_layers`.

    Returns:
        None: The function modifies `model` in place, converting certain layers
        to their probabilistic equivalents.
    """
    for name, layer in get_layers_func(model):
        layer_type = type(layer)
        if layer_type in LAYER_MAPPING:
            # Attach priors and posteriors as submodules so their parameters
            # are tracked by the model.
            layer.register_module(
                "_prior_weight_dist", prior_weight_dist[name]["weight"]
            )
            layer.register_module("_prior_bias_dist", prior_weight_dist[name]["bias"])
            layer.register_module("_weight_dist", weight_dist[name]["weight"])
            layer.register_module("_bias_dist", weight_dist[name]["bias"])
            layer.probabilistic_mode = True
            # Swap in the probabilistic counterpart of the layer class.
            layer.__class__ = LAYER_MAPPING[layer_type]
    # Bind AbstractProbLayer's `probabilistic` method to this model instance.
    model.probabilistic = AbstractProbLayer.probabilistic.__get__(model, nn.Module)
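
A sketch of the expected call pattern, assuming `nn.Linear` is among the layer types in `LAYER_MAPPING`. The `_ToyVariable` class is a hypothetical stand-in for whatever `AbstractVariable` subclass `core.distribution` actually provides; only the dictionary layout (layer name mapped to `{"weight", "bias"}` entries) comes from this module:

    import torch
    from torch import nn

    class _ToyVariable(nn.Module):
        """Hypothetical stand-in for a core.distribution AbstractVariable."""
        def __init__(self, shape):
            super().__init__()
            self.mu = nn.Parameter(torch.zeros(shape))
            self.rho = nn.Parameter(torch.full(shape, -6.0))

    model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))

    def make_dists(m):
        # One {"weight", "bias"} entry per supported layer, keyed by layer name.
        return {
            name: {
                "weight": _ToyVariable(layer.weight.shape),
                "bias": _ToyVariable(layer.bias.shape),
            }
            for name, layer in get_torch_layers(m)
            if type(layer) in LAYER_MAPPING
        }

    dnn_to_probnn(model, weight_dist=make_dists(model), prior_weight_dist=make_dists(model))

Forward passes through the converted model will only work with real AbstractVariable objects that know how to sample; the stand-in above merely exercises the conversion itself.
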
def update_dist(
    model: nn.Module,
    weight_dist: DistributionT | None = None,
    prior_weight_dist: DistributionT | None = None,
    get_layers_func: Callable[
        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
    ] = get_torch_layers,
):
    """
    Update the weight/bias distributions of an already converted probabilistic model.

    This is useful when you want to load a different set of posterior or prior
    distributions into the same network structure, without re-running the entire
    `dnn_to_probnn` procedure.

    Args:
        model (nn.Module): The probabilistic neural network model (already converted).
        weight_dist (DistributionT, optional): New posterior distributions keyed by layer name.
            If provided, each layer's '_weight_dist' and '_bias_dist' are updated.
        prior_weight_dist (DistributionT, optional): New prior distributions keyed by layer name.
            If provided, each layer's '_prior_weight_dist' and '_prior_bias_dist' are updated.
        get_layers_func (Callable): Function that returns an iterator of (layer_name, layer_module).
            Defaults to `get_torch_layers`.

    Returns:
        None: The distributions in the model are updated in place.
    """
    if weight_dist is not None:
        # Distributions are paired with layers positionally, so `weight_dist`
        # must iterate in the same order as `get_layers_func(model)`.
        for (_name, distribution), (_, layer) in zip(
            weight_dist.items(), get_layers_func(model), strict=False
        ):
            layer_type = type(layer)
            if layer_type in LAYER_MAPPING.values():
                layer._weight_dist = distribution["weight"]
                layer._bias_dist = distribution["bias"]

    if prior_weight_dist is not None:
        for (_name, distribution), (_, layer) in zip(
            prior_weight_dist.items(), get_layers_func(model), strict=False
        ):
            layer_type = type(layer)
            if layer_type in LAYER_MAPPING.values():
                layer._prior_weight_dist = distribution["weight"]
                layer._prior_bias_dist = distribution["bias"]
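
Continuing the sketch above (reusing `model` and the hypothetical `_ToyVariable`), a new set of distributions can be loaded without reconverting. Because update_dist pairs dictionary entries with layers positionally via zip, the dictionary should carry an entry for every layer yielded by get_torch_layers, in the same order; entries for unsupported layers are never read but keep the pairing aligned:

    # Hypothetical: one entry per layer, empty for unsupported layer types.
    new_posterior = {
        name: (
            {
                "weight": _ToyVariable(layer.weight.shape),
                "bias": _ToyVariable(layer.bias.shape),
            }
            if type(layer) in LAYER_MAPPING.values()
            else {}
        )
        for name, layer in get_torch_layers(model)
    }

    update_dist(model, weight_dist=new_posterior)        # replace the posterior only
    update_dist(model, prior_weight_dist=new_posterior)  # or replace the prior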