core.distribution.utils
Module imports and the shared DistributionT type alias:

```python
import math
from collections.abc import Callable, Iterator

import torch
from torch import Tensor, nn

from core.distribution import AbstractVariable
from core.layer.utils import LayerNameT, get_bayesian_torch_layers, get_torch_layers

DistributionT = dict[LayerNameT, dict[str, AbstractVariable]]
```

The constructors from_random, from_zeros, and from_layered documented below delegate to the private helper _from_any:

```python
def _from_any(
    model: nn.Module,
    distribution: type[AbstractVariable],
    requires_grad: bool,
    get_layers_func: Callable[[nn.Module], Iterator[tuple[LayerNameT, nn.Module]]],
    weight_mu_fill_func: Callable[[nn.Module], Tensor],
    weight_rho_fill_func: Callable[[nn.Module], Tensor],
    bias_mu_fill_func: Callable[[nn.Module], Tensor],
    bias_rho_fill_func: Callable[[nn.Module], Tensor],
    weight_exists: Callable[[nn.Module], bool] = lambda layer: hasattr(layer, "weight"),
    bias_exists: Callable[[nn.Module], bool] = lambda layer: hasattr(layer, "bias"),
) -> DistributionT:
    """
    A helper function to create a distribution for each layer in a model using
    user-provided fill functions for `mu` and `rho`.

    Args:
        model (nn.Module): The model to convert into distributions.
        distribution (Type[AbstractVariable]): The type of variable distribution to instantiate.
        requires_grad (bool): If True, gradients will be computed on the distribution parameters.
        get_layers_func (Callable): A function to iterate over model layers and yield (name, layer) pairs.
        weight_mu_fill_func (Callable): A function that returns a tensor for initializing the weight `mu`.
        weight_rho_fill_func (Callable): A function that returns a tensor for initializing the weight `rho`.
        bias_mu_fill_func (Callable): A function that returns a tensor for initializing the bias `mu`.
        bias_rho_fill_func (Callable): A function that returns a tensor for initializing the bias `rho`.
        weight_exists (Callable, optional): A predicate to check if a layer contains a weight attribute.
        bias_exists (Callable, optional): A predicate to check if a layer contains a bias attribute.

    Returns:
        DistributionT: A dictionary mapping each layer to weight/bias distributions.
    """
    distributions = {}
    for name, layer in get_layers_func(model):
        if weight_exists(layer) and layer.weight is not None:
            weight_distribution = distribution(
                mu=weight_mu_fill_func(layer),
                rho=weight_rho_fill_func(layer),
                mu_requires_grad=requires_grad,
                rho_requires_grad=requires_grad,
            )
        else:
            weight_distribution = None
        if bias_exists(layer) and layer.bias is not None:
            bias_distribution = distribution(
                mu=bias_mu_fill_func(layer),
                rho=bias_rho_fill_func(layer),
                mu_requires_grad=requires_grad,
                rho_requires_grad=requires_grad,
            )
        else:
            bias_distribution = None
        distributions[name] = {"weight": weight_distribution, "bias": bias_distribution}
    return distributions
```
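Every constructor in this module expects a `distribution` class implementing `AbstractVariable` (imported from `core.distribution` above). The actual interface is defined there; the sketch below is a hypothetical Gaussian implementation inferred only from how these utilities call it (`mu`, `rho`, `mu_requires_grad`, `rho_requires_grad` constructor arguments and a `compute_kl` method), so treat it as illustrative rather than the package's real class.

```python
import torch
from torch import Tensor, nn


class GaussianVariable(nn.Module):
    """Illustrative Gaussian variable; sigma is parameterized as softplus(rho)."""

    def __init__(
        self,
        mu: Tensor,
        rho: Tensor,
        mu_requires_grad: bool = True,
        rho_requires_grad: bool = True,
    ):
        super().__init__()
        self.mu = nn.Parameter(mu.detach().clone().float(), requires_grad=mu_requires_grad)
        self.rho = nn.Parameter(rho.detach().clone().float(), requires_grad=rho_requires_grad)

    @property
    def sigma(self) -> Tensor:
        # softplus keeps sigma positive and inverts rho = log(exp(sigma) - 1)
        return torch.log1p(torch.exp(self.rho))

    def sample(self) -> Tensor:
        # reparameterized draw: mu + sigma * eps
        return self.mu + self.sigma * torch.randn_like(self.mu)

    def compute_kl(self, other: "GaussianVariable") -> Tensor:
        # KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)), summed over all elements
        var1, var2 = self.sigma.pow(2), other.sigma.pow(2)
        return 0.5 * (
            torch.log(var2 / var1) + (var1 + (self.mu - other.mu).pow(2)) / var2 - 1.0
        ).sum()
```

The examples below reuse this class wherever a `distribution` argument is needed.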
```python
def from_ivon(
    model: nn.Module,
    optimizer: "ivon.IVON",
    distribution: type[AbstractVariable],
    requires_grad: bool = True,
    get_layers_func: Callable[
        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
    ] = get_torch_layers,
) -> DistributionT:
    """
    Construct a distribution from an IVON optimizer's parameters and Hessian approximations.

    This function extracts weight and bias information (as well as Hessians) from the
    IVON optimizer, then creates instances of the specified `distribution` for each
    parameter. The newly created distributions can be used as a posterior or prior
    in a PAC-Bayes setting.

    Args:
        model (nn.Module): The (deterministic) model whose parameters correspond to the IVON's weights.
        optimizer (ivon.IVON): An instance of IVON optimizer containing the Hessian approximations.
        distribution (Type[AbstractVariable]): The subclass of `AbstractVariable` to instantiate.
        requires_grad (bool, optional): If True, gradients will be computed on the newly created parameters.
        get_layers_func (Callable, optional): A function to retrieve model layers. Defaults to `get_torch_layers`.

    Returns:
        DistributionT: A dictionary mapping layer names to dicts of {'weight': ..., 'bias': ...},
            each containing an instance of `AbstractVariable`.
    """
    distributions = {}
    i = 0
    shift = 0
    weights = optimizer.param_groups[0]["params"]
    hessians = optimizer.param_groups[0]["hess"]
    weight_decay = optimizer.param_groups[0]["weight_decay"]
    ess = optimizer.param_groups[0]["ess"]
    sigma = 1 / (ess * (hessians + weight_decay)).sqrt()
    rho = torch.log(torch.exp(sigma) - 1)

    for name, layer in get_layers_func(model):
        if layer.weight is not None:
            weight_cutoff = shift + math.prod(layer.weight.shape)
            weight_distribution = distribution(
                mu=weights[i],
                rho=rho[shift:weight_cutoff].reshape(*layer.weight.shape),
                mu_requires_grad=requires_grad,
                rho_requires_grad=requires_grad,
            )
            shift = weight_cutoff
            i += 1
        else:
            weight_distribution = None
        if layer.bias is not None:
            bias_cutoff = shift + math.prod(layer.bias.shape)
            bias_distribution = distribution(
                mu=weights[i],
                rho=rho[shift:bias_cutoff].reshape(*layer.bias.shape),
                mu_requires_grad=requires_grad,
                rho_requires_grad=requires_grad,
            )
            shift = bias_cutoff
            i += 1
        else:
            bias_distribution = None
        distributions[name] = {"weight": weight_distribution, "bias": bias_distribution}
    return distributions
```
Construct a distribution from an IVON optimizer's parameters and Hessian approximations.
This function extracts weight and bias information (as well as Hessians) from the
IVON optimizer, then creates instances of the specified distribution for each
parameter. The newly created distributions can be used as a posterior or prior
in a PAC-Bayes setting.
Arguments:
- model (nn.Module): The (deterministic) model whose parameters correspond to the IVON's weights.
- optimizer (ivon.IVON): An instance of the IVON optimizer containing the Hessian approximations.
- distribution (Type[AbstractVariable]): The subclass of AbstractVariable to instantiate.
- requires_grad (bool, optional): If True, gradients will be computed on the newly created parameters.
- get_layers_func (Callable, optional): A function to retrieve model layers. Defaults to get_torch_layers.
Returns:
DistributionT: A dictionary mapping layer names to dicts of {'weight': ..., 'bias': ...}, each containing an instance of AbstractVariable.
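A usage sketch (hedged): it assumes the external `ivon` package's `IVON` optimizer and reuses the illustrative `GaussianVariable` from the sketch above; the training loop is elided and the `IVON` constructor arguments shown are indicative only.

```python
import ivon  # external IVON optimizer package (assumed available)

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
optimizer = ivon.IVON(model.parameters(), lr=0.1, ess=60_000)  # ess: effective sample size

# ... variational training loop using `optimizer` goes here ...

# sigma = 1 / sqrt(ess * (hess + weight_decay)) per element, stored as rho = log(exp(sigma) - 1)
posterior = from_ivon(model, optimizer, distribution=GaussianVariable, requires_grad=True)
print(list(posterior.keys()))  # one entry per layer yielded by get_torch_layers
```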
```python
def from_flat_rho(
    model: nn.Module,
    rho: Tensor | list[float],
    distribution: type[AbstractVariable],
    requires_grad: bool = True,
    get_layers_func: Callable[
        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
    ] = get_torch_layers,
) -> DistributionT:
    """
    Create distributions for each layer using a shared or flat `rho` array.

    This function takes a model and a `rho` tensor/list containing values for all weight/bias
    elements in consecutive order. Each layer's `mu` is initialized from the current layer weights,
    and `rho` is reshaped accordingly for the layer's shape.

    Args:
        model (nn.Module): The PyTorch model whose layers are converted into distributions.
        rho (Union[Tensor, List[float]]): A 1D tensor or list of floating values used to set `rho`.
        distribution (Type[AbstractVariable]): The subclass of `AbstractVariable` to instantiate.
        requires_grad (bool, optional): If True, gradients will be computed for the distribution parameters.
        get_layers_func (Callable, optional): Function for iterating over the model layers. Defaults to `get_torch_layers`.

    Returns:
        DistributionT: A dictionary of layer distributions keyed by layer names, each
            containing 'weight' and 'bias' distributions if they exist.
    """
    distributions = {}
    shift = 0
    for name, layer in get_layers_func(model):
        if layer.weight is not None:
            # weight_cutoff = shift + layer.out_features * layer.in_features
            weight_cutoff = shift + math.prod(layer.weight.shape)
            weight_distribution = distribution(
                mu=layer.weight,
                rho=rho[shift:weight_cutoff].reshape(*layer.weight.shape),
                mu_requires_grad=requires_grad,
                rho_requires_grad=requires_grad,
            )
        else:
            weight_distribution = None
        if layer.bias is not None:
            # bias_cutoff = weight_cutoff + layer.out_features
            bias_cutoff = weight_cutoff + math.prod(layer.bias.shape)
            bias_distribution = distribution(
                mu=layer.bias,
                rho=rho[weight_cutoff:bias_cutoff].reshape(*layer.bias.shape),
                mu_requires_grad=requires_grad,
                rho_requires_grad=requires_grad,
            )
            shift = bias_cutoff
        else:
            bias_distribution = None
            shift = weight_cutoff
        distributions[name] = {"weight": weight_distribution, "bias": bias_distribution}
    return distributions
```
Create distributions for each layer using a shared or flat rho array.
This function takes a model and a rho tensor/list containing values for all weight/bias
elements in consecutive order. Each layer's mu is initialized from the current layer weights,
and rho is reshaped accordingly for the layer's shape.
Arguments:
- model (nn.Module): The PyTorch model whose layers are converted into distributions.
- rho (Union[Tensor, List[float]]): A 1D tensor or list of floating-point values used to set rho.
- distribution (Type[AbstractVariable]): The subclass of AbstractVariable to instantiate.
- requires_grad (bool, optional): If True, gradients will be computed for the distribution parameters.
- get_layers_func (Callable, optional): Function for iterating over the model layers. Defaults to get_torch_layers.
Returns:
DistributionT: A dictionary of layer distributions keyed by layer names, each containing 'weight' and 'bias' distributions if they exist.
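A sketch of sizing the flat `rho` vector (assuming `get_torch_layers` visits the two Linear layers of this toy model, and reusing the illustrative `GaussianVariable` from above):

```python
model = nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 2))

# One rho value per weight/bias element, consumed in the order the layers are visited.
n_params = sum(p.numel() for p in model.parameters())
flat_rho = torch.full((n_params,), -5.0)  # sigma = log(1 + exp(-5)) ~ 6.7e-3 per element

posterior = from_flat_rho(model, flat_rho, distribution=GaussianVariable)
```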
```python
def from_random(
    model: nn.Module,
    rho: Tensor,
    distribution: type[AbstractVariable],
    requires_grad: bool = True,
    get_layers_func: Callable[
        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
    ] = get_torch_layers,
) -> DistributionT:
    """
    Create a distribution for each layer with randomly initialized mean (using truncated normal)
    and a constant `rho` value.

    Args:
        model (nn.Module): The target PyTorch model.
        rho (Tensor): A single scalar tensor defining the initial `rho` for all weights/biases.
        distribution (Type[AbstractVariable]): The class for creating each weight/bias distribution.
        requires_grad (bool, optional): If True, allows gradient-based updates of `mu` and `rho`.
        get_layers_func (Callable, optional): Function to iterate over model layers. Defaults to `get_torch_layers`.

    Returns:
        DistributionT: A dictionary containing layer-wise distributions for weights and biases.
    """

    def get_truncated_normal_fill_tensor(layer: nn.Module) -> Tensor:
        t = torch.Tensor(*layer.weight.shape)
        if hasattr(layer, "weight") and layer.weight is not None:
            in_features = math.prod(layer.weight.shape[1:])
        else:
            raise ValueError(f"Unsupported layer of type: {type(layer)}")
        w = 1 / math.sqrt(in_features)
        return truncated_normal_fill_tensor(t, 0, w, -2 * w, 2 * w)

    return _from_any(
        model,
        distribution,
        requires_grad,
        get_layers_func,
        weight_mu_fill_func=get_truncated_normal_fill_tensor,
        weight_rho_fill_func=lambda layer: torch.ones(*layer.weight.shape) * rho,
        bias_mu_fill_func=lambda layer: torch.zeros(*layer.bias.shape),
        bias_rho_fill_func=lambda layer: torch.ones(*layer.bias.shape) * rho,
    )
```
Create a distribution for each layer with randomly initialized mean (using truncated normal)
and a constant rho value.
Arguments:
- model (nn.Module): The target PyTorch model.
- rho (Tensor): A single scalar tensor defining the initial rho for all weights/biases.
- distribution (Type[AbstractVariable]): The class for creating each weight/bias distribution.
- requires_grad (bool, optional): If True, allows gradient-based updates of mu and rho.
- get_layers_func (Callable, optional): Function to iterate over model layers. Defaults to get_torch_layers.
Returns:
DistributionT: A dictionary containing layer-wise distributions for weights and biases.
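For example, a data-independent prior with truncated-normal means and one shared `rho`, frozen so it is not trained (a sketch reusing the names from the earlier examples):

```python
prior = from_random(
    model,
    rho=torch.tensor(-4.0),        # sigma = log(1 + exp(-4)) ~ 0.018 for every element
    distribution=GaussianVariable,
    requires_grad=False,           # priors are typically kept fixed
)
# Weight means follow a truncated normal with std 1/sqrt(fan_in); bias means are zero.
```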
```python
def from_zeros(
    model: nn.Module,
    rho: Tensor,
    distribution: type[AbstractVariable],
    requires_grad: bool = True,
    get_layers_func: Callable[
        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
    ] = get_torch_layers,
) -> DistributionT:
    """
    Create distributions for each layer by setting `mu` to zero and `rho` to a constant value.

    Args:
        model (nn.Module): The PyTorch model.
        rho (Tensor): A scalar defining the initial `rho` for all weights/biases.
        distribution (Type[AbstractVariable]): Distribution class to instantiate.
        requires_grad (bool, optional): Whether to track gradients for `mu` and `rho`.
        get_layers_func (Callable, optional): Layer iteration function. Defaults to `get_torch_layers`.

    Returns:
        DistributionT: A dictionary mapping layer names to weight/bias distributions.
    """
    return _from_any(
        model,
        distribution,
        requires_grad,
        get_layers_func,
        weight_mu_fill_func=lambda layer: torch.zeros(*layer.weight.shape),
        weight_rho_fill_func=lambda layer: torch.ones(*layer.weight.shape) * rho,
        bias_mu_fill_func=lambda layer: torch.zeros(*layer.bias.shape),
        bias_rho_fill_func=lambda layer: torch.ones(*layer.bias.shape) * rho,
    )
```
Create distributions for each layer by setting mu to zero and rho to a constant value.
Arguments:
- model (nn.Module): The PyTorch model.
- rho (Tensor): A scalar defining the initial rho for all weights/biases.
- distribution (Type[AbstractVariable]): Distribution class to instantiate.
- requires_grad (bool, optional): Whether to track gradients for mu and rho.
- get_layers_func (Callable, optional): Layer iteration function. Defaults to get_torch_layers.
Returns:
DistributionT: A dictionary mapping layer names to weight/bias distributions.
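Since `from_ivon` stores rho = log(exp(sigma) - 1), i.e. sigma = log(1 + exp(rho)), a target standard deviation can be converted to the matching scalar rho before calling `from_zeros`; a sketch (reusing `model` and `GaussianVariable` from above):

```python
target_sigma = 0.03
rho0 = math.log(math.exp(target_sigma) - 1.0)  # inverse of sigma = log(1 + exp(rho)), ~ -3.49

zero_prior = from_zeros(model, rho=torch.tensor(rho0), distribution=GaussianVariable,
                        requires_grad=False)
```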
```python
def from_layered(
    model: torch.nn.Module,
    attribute_mapping: dict[str, str],
    distribution: type[AbstractVariable],
    requires_grad: bool = True,
    get_layers_func: Callable[
        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
    ] = get_torch_layers,
) -> DistributionT:
    """
    Create distributions by extracting `mu` and `rho` from specified attributes in the model layers.

    This function looks up layer attributes for weight and bias (e.g., "weight_mu", "weight_rho")
    using `attribute_mapping`, then initializes each distribution accordingly.

    Args:
        model (nn.Module): The model whose layers contain the specified attributes.
        attribute_mapping (dict[str, str]): A mapping of attribute names, for example:
            {
                "weight_mu": "mu_weight",
                "weight_rho": "rho_weight",
                "bias_mu": "mu_bias",
                "bias_rho": "rho_bias"
            }
        distribution (Type[AbstractVariable]): The class used to create weight/bias distributions.
        requires_grad (bool, optional): If True, gradients will be computed on `mu` and `rho`.
        get_layers_func (Callable, optional): Layer iteration function.

    Returns:
        DistributionT: A dictionary of distributions keyed by layer names.
    """
    return _from_any(
        model,
        distribution,
        requires_grad,
        get_layers_func,
        weight_exists=lambda layer: hasattr(layer, attribute_mapping["weight_mu"])
        and hasattr(layer, attribute_mapping["weight_rho"]),
        bias_exists=lambda layer: hasattr(layer, attribute_mapping["weight_mu"])
        and hasattr(layer, attribute_mapping["weight_rho"]),
        weight_mu_fill_func=lambda layer: layer.__getattr__(
            attribute_mapping["weight_mu"]
        )
        .detach()
        .clone(),
        weight_rho_fill_func=lambda layer: layer.__getattr__(
            attribute_mapping["weight_rho"]
        )
        .detach()
        .clone(),
        bias_mu_fill_func=lambda layer: layer.__getattr__(attribute_mapping["bias_mu"])
        .detach()
        .clone(),
        bias_rho_fill_func=lambda layer: layer.__getattr__(
            attribute_mapping["bias_rho"]
        )
        .detach()
        .clone(),
    )
```
Create distributions by extracting mu and rho from specified attributes in the model layers.
This function looks up layer attributes for weight and bias (e.g., "weight_mu", "weight_rho")
using attribute_mapping, then initializes each distribution accordingly.
Arguments:
- model (nn.Module): The model whose layers contain the specified attributes.
- attribute_mapping (dict[str, str]): A mapping of attribute names, for example: { "weight_mu": "mu_weight", "weight_rho": "rho_weight", "bias_mu": "mu_bias", "bias_rho": "rho_bias" }
- distribution (Type[AbstractVariable]): The class used to create weight/bias distributions.
- requires_grad (bool, optional): If True, gradients will be computed on mu and rho.
- get_layers_func (Callable, optional): Layer iteration function.
Returns:
DistributionT: A dictionary of distributions keyed by layer names.
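A sketch of the attribute mapping for layers that store their parameters as `mu_weight`/`rho_weight` and `mu_bias`/`rho_bias` (e.g. bayesian-torch style layers); `bayesian_model` is a hypothetical model whose layers expose those attributes and are visited by `get_layers_func`:

```python
attribute_mapping = {
    "weight_mu": "mu_weight",
    "weight_rho": "rho_weight",
    "bias_mu": "mu_bias",
    "bias_rho": "rho_bias",
}
snapshot = from_layered(bayesian_model, attribute_mapping, distribution=GaussianVariable)
```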
```python
def from_bnn(
    model: nn.Module,
    distribution: type[AbstractVariable],
    requires_grad: bool = True,
    get_layers_func: Callable[
        [nn.Module], Iterator[tuple[LayerNameT, nn.Module]]
    ] = get_bayesian_torch_layers,
) -> DistributionT:
    """
    Construct distributions by reading the attributes (e.g., mu_weight, rho_weight, mu_bias, rho_bias)
    from layers typically found in BayesianTorch modules.

    Args:
        model (nn.Module): The Bayesian Torch model containing layer attributes such as mu_weight, rho_weight, etc.
        distribution (Type[AbstractVariable]): The subclass of `AbstractVariable` for each parameter.
        requires_grad (bool, optional): If True, allows gradient-based optimization of `mu` and `rho`.
        get_layers_func (Callable, optional): A function that retrieves BayesianTorch layers. Defaults to `get_bayesian_torch_layers`.

    Returns:
        DistributionT: A dictionary mapping layer names to weight/bias distributions.
    """
    distributions = {}
    for name, layer in get_layers_func(model):
        if hasattr(layer, "mu_weight") and hasattr(layer, "rho_weight"):
            weight_distribution = distribution(
                mu=layer.__getattr__("mu_weight").detach().clone(),
                rho=layer.__getattr__("rho_weight").detach().clone(),
                mu_requires_grad=requires_grad,
                rho_requires_grad=requires_grad,
            )
        elif hasattr(layer, "mu_kernel") and hasattr(layer, "rho_kernel"):
            weight_distribution = distribution(
                mu=layer.__getattr__("mu_kernel").detach().clone(),
                rho=layer.__getattr__("rho_kernel").detach().clone(),
                mu_requires_grad=requires_grad,
                rho_requires_grad=requires_grad,
            )
        else:
            weight_distribution = None
        if hasattr(layer, "mu_bias") and hasattr(layer, "rho_bias"):
            bias_distribution = distribution(
                mu=layer.__getattr__("mu_bias").detach().clone(),
                rho=layer.__getattr__("rho_bias").detach().clone(),
                mu_requires_grad=requires_grad,
                rho_requires_grad=requires_grad,
            )
        else:
            bias_distribution = None
        distributions[name] = {"weight": weight_distribution, "bias": bias_distribution}
    return distributions
```
Construct distributions by reading the attributes (e.g., mu_weight, rho_weight, mu_bias, rho_bias) from layers typically found in BayesianTorch modules.
Arguments:
- model (nn.Module): The Bayesian Torch model containing layer attributes such as mu_weight, rho_weight, etc.
- distribution (Type[AbstractVariable]): The subclass of AbstractVariable for each parameter.
- requires_grad (bool, optional): If True, allows gradient-based optimization of mu and rho.
- get_layers_func (Callable, optional): A function that retrieves BayesianTorch layers. Defaults to get_bayesian_torch_layers.
Returns:
DistributionT: A dictionary mapping layer names to weight/bias distributions.
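A sketch using bayesian-torch layers directly (hedged: it assumes the external `bayesian_torch` package and its `LinearReparameterization` layer, whose forward pass also returns a KL term), reusing the illustrative `GaussianVariable` from above:

```python
from bayesian_torch.layers import LinearReparameterization  # external package (assumed)


class TinyBNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = LinearReparameterization(20, 50)
        self.fc2 = LinearReparameterization(50, 2)

    def forward(self, x):
        x, _ = self.fc1(x)  # bayesian-torch layers return (output, kl)
        x, _ = self.fc2(x)
        return x


bnn = TinyBNN()
# Reads mu_weight/rho_weight and mu_bias/rho_bias from each Bayesian layer.
posterior = from_bnn(bnn, distribution=GaussianVariable)
```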
```python
def from_copy(
    dist: DistributionT,
    distribution: type[AbstractVariable],
    requires_grad: bool = True,
) -> DistributionT:
    """
    Create a new distribution by copying `mu` and `rho` from an existing distribution.

    Args:
        dist (DistributionT): A distribution dictionary to copy from.
        distribution (Type[AbstractVariable]): The class to instantiate for each weight/bias.
        requires_grad (bool, optional): If True, the new distribution parameters can be updated via gradients.

    Returns:
        DistributionT: A new distribution dictionary with the same layer structure,
            but new `mu` and `rho` parameters cloned from `dist`.
    """
    distributions = {}
    for name, layer in dist.items():
        weight_distribution = distribution(
            mu=layer["weight"].mu.detach().clone(),
            rho=layer["weight"].rho.detach().clone(),
            mu_requires_grad=requires_grad,
            rho_requires_grad=requires_grad,
        )
        if layer["bias"] is not None:
            bias_distribution = distribution(
                mu=layer["bias"].mu.detach().clone(),
                rho=layer["bias"].rho.detach().clone(),
                mu_requires_grad=requires_grad,
                rho_requires_grad=requires_grad,
            )
        else:
            bias_distribution = None
        distributions[name] = {"weight": weight_distribution, "bias": bias_distribution}
    return distributions
```
Create a new distribution by copying mu and rho from an existing distribution.
Arguments:
- dist (DistributionT): A distribution dictionary to copy from.
- distribution (Type[AbstractVariable]): The class to instantiate for each weight/bias.
- requires_grad (bool, optional): If True, the new distribution parameters can be updated via gradients.
Returns:
DistributionT: A new distribution dictionary with the same layer structure, but new mu and rho parameters cloned from dist.
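A common pattern is to initialize a trainable posterior from a frozen prior (a sketch, reusing the names from the earlier examples):

```python
prior = from_random(model, rho=torch.tensor(-4.0), distribution=GaussianVariable,
                    requires_grad=False)
posterior = from_copy(prior, distribution=GaussianVariable, requires_grad=True)
# The posterior starts at the prior, but its mu/rho are fresh, trainable clones.
```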
```python
def compute_kl(dist1: DistributionT, dist2: DistributionT) -> Tensor:
    """
    Compute the total KL divergence between two distributions of the same structure.

    Each corresponding layer's weight/bias KL is summed to produce a single scalar.

    Args:
        dist1 (DistributionT): The first distribution dictionary.
        dist2 (DistributionT): The second distribution dictionary.

    Returns:
        Tensor: A scalar tensor representing the total KL divergence across all layers.
    """
    kl_list = []
    for idx in dist1:
        for key in dist1[idx]:
            if dist1[idx][key] is not None and dist2[idx][key] is not None:
                kl = dist1[idx][key].compute_kl(dist2[idx][key])
                kl_list.append(kl)
    return torch.stack(kl_list).sum()
```
Compute the total KL divergence between two distributions of the same structure.
Each corresponding layer's weight/bias KL is summed to produce a single scalar.
Arguments:
- dist1 (DistributionT): The first distribution dictionary.
- dist2 (DistributionT): The second distribution dictionary.
Returns:
Tensor: A scalar tensor representing the total KL divergence across all layers.
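Continuing the prior/posterior sketch above, the result can serve as the KL term of a PAC-Bayes bound (the McAllester-style penalty below is one common choice, shown only as an example):

```python
kl = compute_kl(posterior, prior)  # scalar tensor, differentiable w.r.t. the posterior
n, delta = 60_000, 0.05
# sqrt((KL + ln(2*sqrt(n)/delta)) / (2n)) bounds the gap between empirical and true risk
complexity = torch.sqrt((kl + math.log(2 * math.sqrt(n) / delta)) / (2 * n))
```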
```python
def compute_standard_normal_cdf(x: float) -> float:
    """
    Compute the cumulative distribution function (CDF) of a standard normal at point x.

    Args:
        x (float): The input value at which to evaluate the standard normal CDF.

    Returns:
        float: The CDF value of the standard normal distribution at x.
    """
    # TODO: replace with numpy
    return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
```
Compute the cumulative distribution function (CDF) of a standard normal at point x.
Arguments:
- x (float): The input value at which to evaluate the standard normal CDF.
Returns:
float: The CDF value of the standard normal distribution at x.
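A quick sanity check against PyTorch's own normal CDF (both evaluate Phi(x) = (1 + erf(x / sqrt(2))) / 2):

```python
x = 1.0
reference = torch.distributions.Normal(0.0, 1.0).cdf(torch.tensor(x)).item()
assert abs(compute_standard_normal_cdf(x) - reference) < 1e-6
```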
```python
def truncated_normal_fill_tensor(
    tensor: torch.Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
) -> torch.Tensor:
    """
    Fill a tensor in-place with values drawn from a truncated normal distribution.

    The resulting values lie in the interval [a, b], centered around `mean`
    with approximate std `std`.

    Args:
        tensor (torch.Tensor): The tensor to fill.
        mean (float, optional): Mean of the desired distribution. Defaults to 0.0.
        std (float, optional): Standard deviation of the desired distribution. Defaults to 1.0.
        a (float, optional): Lower bound of truncation. Defaults to -2.0.
        b (float, optional): Upper bound of truncation. Defaults to 2.0.

    Returns:
        torch.Tensor: The same tensor, filled in-place with truncated normal values.
    """
    with torch.no_grad():
        # Get upper and lower cdf values
        l_ = compute_standard_normal_cdf((a - mean) / std)
        u_ = compute_standard_normal_cdf((b - mean) / std)

        # Fill tensor with uniform values from [l_, u_]
        tensor.uniform_(l_, u_)

        # Use inverse cdf transform from normal distribution
        tensor.mul_(2)
        tensor.sub_(1)

        # Ensure that the values are strictly between -1 and 1 for erfinv
        eps = torch.finfo(tensor.dtype).eps
        tensor.clamp_(min=-(1.0 - eps), max=(1.0 - eps))
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp one last time to ensure it's still in the proper range
        tensor.clamp_(min=a, max=b)
    return tensor
```
Fill a tensor in-place with values drawn from a truncated normal distribution.
The resulting values lie in the interval [a, b], centered around mean
with approximate std std.
Arguments:
- tensor (torch.Tensor): The tensor to fill.
- mean (float, optional): Mean of the desired distribution. Defaults to 0.0.
- std (float, optional): Standard deviation of the desired distribution. Defaults to 1.0.
- a (float, optional): Lower bound of truncation. Defaults to -2.0.
- b (float, optional): Upper bound of truncation. Defaults to 2.0.
Returns:
torch.Tensor: The same tensor, filled in-place with truncated normal values.
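PyTorch ships an initializer with the same kind of inverse-CDF construction, `nn.init.trunc_normal_`, which can be used as a cross-check (a sketch):

```python
t1 = torch.empty(3, 5)
truncated_normal_fill_tensor(t1, mean=0.0, std=0.1, a=-0.2, b=0.2)

# Built-in equivalent with the same truncation bounds
t2 = nn.init.trunc_normal_(torch.empty(3, 5), mean=0.0, std=0.1, a=-0.2, b=0.2)

assert t1.min().item() >= -0.2 and t1.max().item() <= 0.2
```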