core.distribution.utils

import math
from collections.abc import Callable, Iterator

import torch
from torch import Tensor, nn

from core.distribution import AbstractVariable
from core.layer.utils import LayerNameT, get_bayesian_torch_layers, get_torch_layers

DistributionT = dict[LayerNameT, dict[str, AbstractVariable]]


def from_ivon(
    model: nn.Module,
    optimizer: "ivon.IVON",
    distribution: type[AbstractVariable],
    requires_grad: bool = True,
    get_layers_func: Callable[
        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
    ] = get_torch_layers,
) -> DistributionT:
    """
    Construct a distribution from an IVON optimizer's parameters and Hessian approximations.

    This function extracts weight and bias information (as well as Hessians) from the
    IVON optimizer, then creates instances of the specified `distribution` for each
    parameter. The newly created distributions can be used as a posterior or prior
    in a PAC-Bayes setting.

    Args:
        model (nn.Module): The (deterministic) model whose parameters correspond to the IVON optimizer's weights.
        optimizer (ivon.IVON): An instance of the IVON optimizer containing the Hessian approximations.
        distribution (Type[AbstractVariable]): The subclass of `AbstractVariable` to instantiate.
        requires_grad (bool, optional): If True, gradients will be computed on the newly created parameters.
        get_layers_func (Callable, optional): A function to retrieve model layers. Defaults to `get_torch_layers`.

    Returns:
        DistributionT: A dictionary mapping layer names to dicts of {'weight': ..., 'bias': ...},
        each containing an instance of `AbstractVariable`.
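
    Example:
        A minimal sketch of turning an IVON training run into a PAC-Bayes
        posterior. `GaussianVariable` stands in for a concrete
        `AbstractVariable` subclass, and the optimizer settings are
        illustrative assumptions rather than recommendations:

            import ivon

            optimizer = ivon.IVON(model.parameters(), lr=0.1, ess=len(train_set))
            # ... train `model` with `optimizer` as usual ...
            posterior = from_ivon(model, optimizer, distribution=GaussianVariable)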
 40    """
 41    distributions = {}
 42    i = 0
 43    shift = 0
 44    weights = optimizer.param_groups[0]["params"]
 45    hessians = optimizer.param_groups[0]["hess"]
 46    weight_decay = optimizer.param_groups[0]["weight_decay"]
 47    ess = optimizer.param_groups[0]["ess"]
 48    sigma = 1 / (ess * (hessians + weight_decay)).sqrt()
 49    rho = torch.log(torch.exp(sigma) - 1)
 50
 51    for name, layer in get_layers_func(model):
 52        if layer.weight is not None:
 53            weight_cutoff = shift + math.prod(layer.weight.shape)
 54            weight_distribution = distribution(
 55                mu=weights[i],
 56                rho=rho[shift:weight_cutoff].reshape(*layer.weight.shape),
 57                mu_requires_grad=requires_grad,
 58                rho_requires_grad=requires_grad,
 59            )
 60            shift = weight_cutoff
 61            i += 1
 62        else:
 63            weight_distribution = None
 64        if layer.bias is not None:
 65            bias_cutoff = shift + math.prod(layer.bias.shape)
 66            bias_distribution = distribution(
 67                mu=weights[i],
 68                rho=rho[shift:bias_cutoff].reshape(*layer.bias.shape),
 69                mu_requires_grad=requires_grad,
 70                rho_requires_grad=requires_grad,
 71            )
 72            shift = bias_cutoff
 73            i += 1
 74        else:
 75            bias_distribution = None
 76        distributions[name] = {"weight": weight_distribution, "bias": bias_distribution}
 77    return distributions
 78
 79
 80def from_flat_rho(
 81    model: nn.Module,
 82    rho: Tensor | list[float],
 83    distribution: type[AbstractVariable],
 84    requires_grad: bool = True,
 85    get_layers_func: Callable[
 86        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
 87    ] = get_torch_layers,
 88) -> DistributionT:
 89    """
 90    Create distributions for each layer using a shared or flat `rho` array.
 91
 92    This function takes a model and a `rho` tensor/list containing values for all weight/bias
 93    elements in consecutive order. Each layer's `mu` is initialized from the current layer weights,
 94    and `rho` is reshaped accordingly for the layer's shape.
 95
 96    Args:
 97        model (nn.Module): The PyTorch model whose layers are converted into distributions.
 98        rho (Union[Tensor, List[float]]): A 1D tensor or list of floating values used to set `rho`.
 99        distribution (Type[AbstractVariable]): The subclass of `AbstractVariable` to instantiate.
100        requires_grad (bool, optional): If True, gradients will be computed for the distribution parameters.
101        get_layers_func (Callable, optional): Function for iterating over the model layers. Defaults to `get_torch_layers`.
102
103    Returns:
104        DistributionT: A dictionary of layer distributions keyed by layer names, each
105        containing 'weight' and 'bias' distributions if they exist.
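
    Example:
        A minimal sketch: one `rho` value per parameter, laid out in the same
        order as the layers are iterated. `GaussianVariable` is an assumed
        `AbstractVariable` subclass:

            n_params = sum(p.numel() for p in model.parameters())
            flat_rho = torch.full((n_params,), -5.0)
            prior = from_flat_rho(model, flat_rho, distribution=GaussianVariable)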
106    """
107    distributions = {}
108    shift = 0
109    for name, layer in get_layers_func(model):
110        if layer.weight is not None:
111            # weight_cutoff = shift + layer.out_features * layer.in_features
112            weight_cutoff = shift + math.prod(layer.weight.shape)
113            weight_distribution = distribution(
114                mu=layer.weight,
115                rho=rho[shift:weight_cutoff].reshape(*layer.weight.shape),
116                mu_requires_grad=requires_grad,
117                rho_requires_grad=requires_grad,
118            )
119        else:
120            weight_distribution = None
121        if layer.bias is not None:
122            # bias_cutoff = weight_cutoff + layer.out_features
123            bias_cutoff = weight_cutoff + math.prod(layer.bias.shape)
124            bias_distribution = distribution(
125                mu=layer.bias,
126                rho=rho[weight_cutoff:bias_cutoff].reshape(*layer.bias.shape),
127                mu_requires_grad=requires_grad,
128                rho_requires_grad=requires_grad,
129            )
130            shift = bias_cutoff
131        else:
132            bias_distribution = None
133            shift = weight_cutoff
134        distributions[name] = {"weight": weight_distribution, "bias": bias_distribution}
135    return distributions
136
137
138def _from_any(
139    model: nn.Module,
140    distribution: type[AbstractVariable],
141    requires_grad: bool,
142    get_layers_func: Callable[[nn.Module], Iterator[tuple[LayerNameT, nn.Module]]],
143    weight_mu_fill_func: Callable[[nn.Module], Tensor],
144    weight_rho_fill_func: Callable[[nn.Module], Tensor],
145    bias_mu_fill_func: Callable[[nn.Module], Tensor],
146    bias_rho_fill_func: Callable[[nn.Module], Tensor],
147    weight_exists: Callable[[nn.Module], bool] = lambda layer: hasattr(layer, "weight"),
148    bias_exists: Callable[[nn.Module], bool] = lambda layer: hasattr(layer, "bias"),
149) -> DistributionT:
150    """
151    A helper function to create a distribution for each layer in a model using
152    user-provided fill functions for `mu` and `rho`.
153
154    Args:
155        model (nn.Module): The model to convert into distributions.
156        distribution (Type[AbstractVariable]): The type of variable distribution to instantiate.
157        requires_grad (bool): If True, gradients will be computed on the distribution parameters.
158        get_layers_func (Callable): A function to iterate over model layers and yield (name, layer) pairs.
159        weight_mu_fill_func (Callable): A function that returns a tensor for initializing the weight `mu`.
160        weight_rho_fill_func (Callable): A function that returns a tensor for initializing the weight `rho`.
161        bias_mu_fill_func (Callable): A function that returns a tensor for initializing the bias `mu`.
162        bias_rho_fill_func (Callable): A function that returns a tensor for initializing the bias `rho`.
163        weight_exists (Callable, optional): A predicate to check if a layer contains a weight attribute.
164        bias_exists (Callable, optional): A predicate to check if a layer contains a bias attribute.
165
166    Returns:
167        DistributionT: A dictionary mapping each layer to weight/bias distributions.
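
    Example:
        A sketch of the fill-function pattern used by the public helpers below,
        here initializing `mu` from the trained weights and `rho` to a constant
        (`GaussianVariable` is an assumed `AbstractVariable` subclass):

            dist = _from_any(
                model,
                GaussianVariable,
                requires_grad=True,
                get_layers_func=get_torch_layers,
                weight_mu_fill_func=lambda layer: layer.weight.detach().clone(),
                weight_rho_fill_func=lambda layer: torch.full_like(layer.weight, -5.0),
                bias_mu_fill_func=lambda layer: layer.bias.detach().clone(),
                bias_rho_fill_func=lambda layer: torch.full_like(layer.bias, -5.0),
            )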
168    """
169    distributions = {}
170    for name, layer in get_layers_func(model):
171        if weight_exists(layer) and layer.weight is not None:
172            weight_distribution = distribution(
173                mu=weight_mu_fill_func(layer),
174                rho=weight_rho_fill_func(layer),
175                mu_requires_grad=requires_grad,
176                rho_requires_grad=requires_grad,
177            )
178        else:
179            weight_distribution = None
180        if bias_exists(layer) and layer.bias is not None:
181            bias_distribution = distribution(
182                mu=bias_mu_fill_func(layer),
183                rho=bias_rho_fill_func(layer),
184                mu_requires_grad=requires_grad,
185                rho_requires_grad=requires_grad,
186            )
187        else:
188            bias_distribution = None
189        distributions[name] = {"weight": weight_distribution, "bias": bias_distribution}
190    return distributions
191
192
193def from_random(
194    model: nn.Module,
195    rho: Tensor,
196    distribution: type[AbstractVariable],
197    requires_grad: bool = True,
198    get_layers_func: Callable[
199        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
200    ] = get_torch_layers,
201) -> DistributionT:
202    """
203    Create a distribution for each layer with randomly initialized mean (using truncated normal)
204    and a constant `rho` value.
205
206    Args:
207        model (nn.Module): The target PyTorch model.
208        rho (Tensor): A single scalar tensor defining the initial `rho` for all weights/biases.
209        distribution (Type[AbstractVariable]): The class for creating each weight/bias distribution.
210        requires_grad (bool, optional): If True, allows gradient-based updates of `mu` and `rho`.
211        get_layers_func (Callable, optional): Function to iterate over model layers. Defaults to `get_torch_layers`.
212
213    Returns:
214        DistributionT: A dictionary containing layer-wise distributions for weights and biases.
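
    Example:
        A minimal sketch creating a randomly initialized prior with a shared
        scalar `rho` (`GaussianVariable` is an assumed `AbstractVariable`
        subclass):

            prior = from_random(
                model, rho=torch.tensor(-5.0), distribution=GaussianVariable
            )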
215    """
216
217    def get_truncated_normal_fill_tensor(layer: nn.Module) -> Tensor:
218        t = torch.Tensor(*layer.weight.shape)
219        if hasattr(layer, "weight") and layer.weight is not None:
220            in_features = math.prod(layer.weight.shape[1:])
221        else:
222            raise ValueError(f"Unsupported layer of type: {type(layer)}")
223        w = 1 / math.sqrt(in_features)
224        return truncated_normal_fill_tensor(t, 0, w, -2 * w, 2 * w)
225
226    return _from_any(
227        model,
228        distribution,
229        requires_grad,
230        get_layers_func,
231        weight_mu_fill_func=get_truncated_normal_fill_tensor,
232        weight_rho_fill_func=lambda layer: torch.ones(*layer.weight.shape) * rho,
233        bias_mu_fill_func=lambda layer: torch.zeros(*layer.bias.shape),
234        bias_rho_fill_func=lambda layer: torch.ones(*layer.bias.shape) * rho,
235    )
236
237
238def from_zeros(
239    model: nn.Module,
240    rho: Tensor,
241    distribution: type[AbstractVariable],
242    requires_grad: bool = True,
243    get_layers_func: Callable[
244        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
245    ] = get_torch_layers,
246) -> DistributionT:
247    """
248    Create distributions for each layer by setting `mu` to zero and `rho` to a constant value.
249
250    Args:
251        model (nn.Module): The PyTorch model.
252        rho (Tensor): A scalar defining the initial `rho` for all weights/biases.
253        distribution (Type[AbstractVariable]): Distribution class to instantiate.
254        requires_grad (bool, optional): Whether to track gradients for `mu` and `rho`.
255        get_layers_func (Callable, optional): Layer iteration function. Defaults to `get_torch_layers`.
256
257    Returns:
258        DistributionT: A dictionary mapping layer names to weight/bias distributions.
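
    Example:
        A minimal sketch of a zero-mean prior with a shared scalar `rho`
        (`GaussianVariable` is an assumed `AbstractVariable` subclass):

            prior = from_zeros(
                model, rho=torch.tensor(-5.0), distribution=GaussianVariable
            )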
259    """
260    return _from_any(
261        model,
262        distribution,
263        requires_grad,
264        get_layers_func,
265        weight_mu_fill_func=lambda layer: torch.zeros(*layer.weight.shape),
266        weight_rho_fill_func=lambda layer: torch.ones(*layer.weight.shape) * rho,
267        bias_mu_fill_func=lambda layer: torch.zeros(*layer.bias.shape),
268        bias_rho_fill_func=lambda layer: torch.ones(*layer.bias.shape) * rho,
269    )
270
271
272def from_layered(
273    model: torch.nn.Module,
274    attribute_mapping: dict[str, str],
275    distribution: type[AbstractVariable],
276    requires_grad: bool = True,
277    get_layers_func: Callable[
278        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
279    ] = get_torch_layers,
280) -> DistributionT:
281    """
282    Create distributions by extracting `mu` and `rho` from specified attributes in the model layers.
283
284    This function looks up layer attributes for weight and bias (e.g., "weight_mu", "weight_rho")
285    using `attribute_mapping`, then initializes each distribution accordingly.
286
287    Args:
288        model (nn.Module): The model whose layers contain the specified attributes.
289        attribute_mapping (dict[str, str]): A mapping of attribute names, for example:
290            {
291              "weight_mu": "mu_weight",
292              "weight_rho": "rho_weight",
293              "bias_mu": "mu_bias",
294              "bias_rho": "rho_bias"
295            }
296        distribution (Type[AbstractVariable]): The class used to create weight/bias distributions.
297        requires_grad (bool, optional): If True, gradients will be computed on `mu` and `rho`.
298        get_layers_func (Callable, optional): Layer iteration function.
299
300    Returns:
301        DistributionT: A dictionary of distributions keyed by layer names.
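
    Example:
        A minimal sketch of reading `mu`/`rho` out of a model whose layers
        expose `mu_weight`, `rho_weight`, `mu_bias` and `rho_bias` attributes
        (`bayesian_model` and `GaussianVariable` are illustrative assumptions):

            posterior = from_layered(
                bayesian_model,
                attribute_mapping={
                    "weight_mu": "mu_weight",
                    "weight_rho": "rho_weight",
                    "bias_mu": "mu_bias",
                    "bias_rho": "rho_bias",
                },
                distribution=GaussianVariable,
            )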
302    """
303    return _from_any(
304        model,
305        distribution,
306        requires_grad,
307        get_layers_func,
308        weight_exists=lambda layer: hasattr(layer, attribute_mapping["weight_mu"])
309        and hasattr(layer, attribute_mapping["weight_rho"]),
310        bias_exists=lambda layer: hasattr(layer, attribute_mapping["weight_mu"])
311        and hasattr(layer, attribute_mapping["weight_rho"]),
312        weight_mu_fill_func=lambda layer: layer.__getattr__(
313            attribute_mapping["weight_mu"]
314        )
315        .detach()
316        .clone(),
317        weight_rho_fill_func=lambda layer: layer.__getattr__(
318            attribute_mapping["weight_rho"]
319        )
320        .detach()
321        .clone(),
322        bias_mu_fill_func=lambda layer: layer.__getattr__(attribute_mapping["bias_mu"])
323        .detach()
324        .clone(),
325        bias_rho_fill_func=lambda layer: layer.__getattr__(
326            attribute_mapping["bias_rho"]
327        )
328        .detach()
329        .clone(),
330    )
331
332
333def from_bnn(
334    model: nn.Module,
335    distribution: type[AbstractVariable],
336    requires_grad: bool = True,
337    get_layers_func: Callable[
338        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
339    ] = get_bayesian_torch_layers,
340) -> DistributionT:
341    """
342    Construct distributions by reading the attributes (e.g., mu_weight, rho_weight, mu_bias, rho_bias)
343    from layers typically found in BayesianTorch modules.
344
345    Args:
346        model (nn.Module): The Bayesian Torch model containing layer attributes such as mu_weight, rho_weight, etc.
347        distribution (Type[AbstractVariable]): The subclass of `AbstractVariable` for each parameter.
348        requires_grad (bool, optional): If True, allows gradient-based optimization of `mu` and `rho`.
349        get_layers_func (Callable, optional): A function that retrieves BayesianTorch layers. Defaults to `get_bayesian_torch_layers`.
350
351    Returns:
352        DistributionT: A dictionary mapping layer names to weight/bias distributions.
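
    Example:
        A minimal sketch for a model built with the `bayesian-torch` library.
        The `dnn_to_bnn` conversion call and `bnn_prior_parameters` are
        illustrative assumptions; only the attribute names read below are
        relied upon. `GaussianVariable` is an assumed `AbstractVariable`
        subclass:

            from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn

            dnn_to_bnn(model, bnn_prior_parameters)
            posterior = from_bnn(model, distribution=GaussianVariable)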
353    """
354    distributions = {}
355    for name, layer in get_layers_func(model):
356        if hasattr(layer, "mu_weight") and hasattr(layer, "rho_weight"):
357            weight_distribution = distribution(
358                mu=layer.__getattr__("mu_weight").detach().clone(),
359                rho=layer.__getattr__("rho_weight").detach().clone(),
360                mu_requires_grad=requires_grad,
361                rho_requires_grad=requires_grad,
362            )
363        elif hasattr(layer, "mu_kernel") and hasattr(layer, "rho_kernel"):
364            weight_distribution = distribution(
365                mu=layer.__getattr__("mu_kernel").detach().clone(),
366                rho=layer.__getattr__("rho_kernel").detach().clone(),
367                mu_requires_grad=requires_grad,
368                rho_requires_grad=requires_grad,
369            )
370        else:
371            weight_distribution = None
372        if hasattr(layer, "mu_bias") and hasattr(layer, "rho_bias"):
373            bias_distribution = distribution(
374                mu=layer.__getattr__("mu_bias").detach().clone(),
375                rho=layer.__getattr__("rho_bias").detach().clone(),
376                mu_requires_grad=requires_grad,
377                rho_requires_grad=requires_grad,
378            )
379        else:
380            bias_distribution = None
381        distributions[name] = {"weight": weight_distribution, "bias": bias_distribution}
382    return distributions
383
384
385def from_copy(
386    dist: DistributionT,
387    distribution: type[AbstractVariable],
388    requires_grad: bool = True,
389) -> DistributionT:
390    """
391    Create a new distribution by copying `mu` and `rho` from an existing distribution.
392
393    Args:
394        dist (DistributionT): A distribution dictionary to copy from.
395        distribution (Type[AbstractVariable]): The class to instantiate for each weight/bias.
396        requires_grad (bool, optional): If True, the new distribution parameters can be updated via gradients.
397
398    Returns:
399        DistributionT: A new distribution dictionary with the same layer structure,
400        but new `mu` and `rho` parameters cloned from `dist`.
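
    Example:
        A minimal sketch: clone an existing posterior to serve as an
        independent, fixed prior (`posterior` and `GaussianVariable` are
        assumed to come from one of the constructors above):

            prior = from_copy(
                posterior, distribution=GaussianVariable, requires_grad=False
            )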
401    """
402    distributions = {}
403    for name, layer in dist.items():
404        weight_distribution = distribution(
405            mu=layer["weight"].mu.detach().clone(),
406            rho=layer["weight"].rho.detach().clone(),
407            mu_requires_grad=requires_grad,
408            rho_requires_grad=requires_grad,
409        )
410        if layer["bias"] is not None:
411            bias_distribution = distribution(
412                mu=layer["bias"].mu.detach().clone(),
413                rho=layer["bias"].rho.detach().clone(),
414                mu_requires_grad=requires_grad,
415                rho_requires_grad=requires_grad,
416            )
417        else:
418            bias_distribution = None
419        distributions[name] = {"weight": weight_distribution, "bias": bias_distribution}
420    return distributions
421
422
423def compute_kl(dist1: DistributionT, dist2: DistributionT) -> Tensor:
424    """
425    Compute the total KL divergence between two distributions of the same structure.
426
427    Each corresponding layer's weight/bias KL is summed to produce a single scalar.
428
429    Args:
430        dist1 (DistributionT): The first distribution dictionary.
431        dist2 (DistributionT): The second distribution dictionary.
432
433    Returns:
434        Tensor: A scalar tensor representing the total KL divergence across all layers.
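
    Example:
        A minimal sketch of the complexity term of a standard PAC-Bayes-kl
        bound, assuming `posterior` and `prior` were built over the same model
        and that `n` (sample size) and `delta` (confidence level) are given:

            kl = compute_kl(posterior, prior)
            penalty = (kl + math.log(2 * math.sqrt(n) / delta)) / n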
435    """
436    kl_list = []
437    for idx in dist1:
438        for key in dist1[idx]:
439            if dist1[idx][key] is not None and dist2[idx][key] is not None:
440                kl = dist1[idx][key].compute_kl(dist2[idx][key])
441                kl_list.append(kl)
442    return torch.stack(kl_list).sum()
443
444
445def compute_standard_normal_cdf(x: float) -> float:
446    """
447    Compute the cumulative distribution function (CDF) of a standard normal at point x.
448
449    Args:
450        x (float): The input value at which to evaluate the standard normal CDF.
451
452    Returns:
453        float: The CDF value of the standard normal distribution at x.
454    """
455    # TODO: replace with numpy
456    return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
457
458
459def truncated_normal_fill_tensor(
460    tensor: torch.Tensor,
461    mean: float = 0.0,
462    std: float = 1.0,
463    a: float = -2.0,
464    b: float = 2.0,
465) -> torch.Tensor:
466    """
467    Fill a tensor in-place with values drawn from a truncated normal distribution.
468
469    The resulting values lie in the interval [a, b], centered around `mean`
470    with approximate std `std`.
471
472    Args:
473        tensor (torch.Tensor): The tensor to fill.
474        mean (float, optional): Mean of the desired distribution. Defaults to 0.0.
475        std (float, optional): Standard deviation of the desired distribution. Defaults to 1.0.
476        a (float, optional): Lower bound of truncation. Defaults to -2.0.
477        b (float, optional): Upper bound of truncation. Defaults to 2.0.
478
479    Returns:
480        torch.Tensor: The same tensor, filled in-place with truncated normal values.
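
    Example:
        A minimal sketch, mirroring how `from_random` draws layer means:

            t = torch.empty(64, 32)
            truncated_normal_fill_tensor(t, mean=0.0, std=0.1, a=-0.2, b=0.2)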
481    """
482    with torch.no_grad():
483        # Get upper and lower cdf values
484        l_ = compute_standard_normal_cdf((a - mean) / std)
485        u_ = compute_standard_normal_cdf((b - mean) / std)
486
487        # Fill tensor with uniform values from [l_, u_]
488        tensor.uniform_(l_, u_)
489
490        # Use inverse cdf transform from normal distribution
491        tensor.mul_(2)
492        tensor.sub_(1)
493
494        # Ensure that the values are strictly between -1 and 1 for erfinv
495        eps = torch.finfo(tensor.dtype).eps
496        tensor.clamp_(min=-(1.0 - eps), max=(1.0 - eps))
497        tensor.erfinv_()
498
499        # Transform to proper mean, std
500        tensor.mul_(std * math.sqrt(2.0))
501        tensor.add_(mean)
502
503        # Clamp one last time to ensure it's still in the proper range
504        tensor.clamp_(min=a, max=b)
505        return tensor