core.distribution.AbstractVariable
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from core.utils import KLDivergenceInterface


class AbstractVariable(nn.Module, KLDivergenceInterface, ABC):
    """
    An abstract class representing a single random variable for a probabilistic
    neural network parameter (e.g., weight or bias).

    Each variable holds:
    - A mean parameter (`mu`)
    - A `rho` parameter that is used to derive the standard deviation (`sigma`)
    - A method to sample from the underlying distribution
    - A method to compute KL divergence with another variable of the same type

    This class inherits from `nn.Module` for parameter registration in PyTorch
    and from `KLDivergenceInterface` for consistent KL divergence handling.
    """

    def __init__(
        self,
        mu: torch.Tensor,
        rho: torch.Tensor,
        mu_requires_grad: bool = False,
        rho_requires_grad: bool = False,
    ):
        """
        Initialize an AbstractVariable with given `mu` and `rho` tensors.

        Args:
            mu (torch.Tensor): The mean parameter for the distribution.
            rho (torch.Tensor): The parameter from which we derive sigma = log(1 + exp(rho)).
            mu_requires_grad (bool, optional): If True, allow gradients on `mu`.
            rho_requires_grad (bool, optional): If True, allow gradients on `rho`.
        """
        super().__init__()
        self.mu = nn.Parameter(mu.detach().clone(), requires_grad=mu_requires_grad)
        self.rho = nn.Parameter(rho.detach().clone(), requires_grad=rho_requires_grad)
        self.kl_div = None

    @property
    def sigma(self) -> torch.Tensor:
        """
        The standard deviation of the distribution, computed as:
            sigma = log(1 + exp(rho)).

        Returns:
            torch.Tensor: A tensor representing the current standard deviation.
        """
        return torch.log(1 + torch.exp(self.rho))

    @abstractmethod
    def sample(self) -> torch.Tensor:
        """
        Draw a sample from this variable's underlying distribution.

        Returns:
            torch.Tensor: A sampled value of the same shape as `mu`.
        """
        pass

    @abstractmethod
    def compute_kl(self, other: "AbstractVariable") -> torch.Tensor:
        """
        Compute the KL divergence between this variable and another variable
        of the same distribution type.

        Args:
            other (AbstractVariable): Another AbstractVariable instance
                with comparable parameters (e.g., mu, rho).

        Returns:
            torch.Tensor: A scalar tensor with the KL divergence value.
        """
        pass
class AbstractVariable(torch.nn.modules.module.Module, core.utils.kl.KLDivergenceInterface, abc.ABC):
An abstract class representing a single random variable for a probabilistic neural network parameter (e.g., weight or bias).
Each variable holds:
- A mean parameter (mu)
- A rho parameter that is used to derive the standard deviation (sigma)
- A method to sample from the underlying distribution
- A method to compute KL divergence with another variable of the same type
This class inherits from nn.Module for parameter registration in PyTorch and from KLDivergenceInterface for consistent KL divergence handling.
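To make the contract concrete, here is a minimal, hypothetical subclass sketch. GaussianVariable is not part of the library, the import path is assumed from the source above, and the closed-form Gaussian KL is only one possible choice for compute_kl.

import torch

from core.distribution import AbstractVariable  # assumed import path


class GaussianVariable(AbstractVariable):
    # Hypothetical concrete variable; illustrates the interface only.

    def sample(self) -> torch.Tensor:
        # Reparameterization trick: mu + sigma * eps keeps the draw
        # differentiable with respect to mu and rho.
        eps = torch.randn_like(self.mu)
        return self.mu + self.sigma * eps

    def compute_kl(self, other: "GaussianVariable") -> torch.Tensor:
        # Closed-form KL between two diagonal Gaussians, summed over entries.
        var_ratio = (self.sigma / other.sigma) ** 2
        mean_term = ((self.mu - other.mu) / other.sigma) ** 2
        return 0.5 * (var_ratio + mean_term - 1.0 - torch.log(var_ratio)).sum()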
AbstractVariable(mu: torch.Tensor, rho: torch.Tensor, mu_requires_grad: bool = False, rho_requires_grad: bool = False)
Initialize an AbstractVariable with given mu and rho tensors. The tensors are detached and cloned into nn.Parameter objects, so the originals are not modified.

Arguments:
- mu (torch.Tensor): The mean parameter for the distribution.
- rho (torch.Tensor): The parameter from which we derive sigma = log(1 + exp(rho)).
- mu_requires_grad (bool, optional): If True, allow gradients on mu.
- rho_requires_grad (bool, optional): If True, allow gradients on rho.
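As a usage sketch (reusing the hypothetical GaussianVariable from above), a trainable posterior and a frozen prior might be built from the same initial tensors; because the constructor detaches and clones its inputs, the two variables do not share storage.

import torch

mu = torch.zeros(64, 128)
rho = torch.full((64, 128), -3.0)  # sigma = log(1 + exp(-3)) ~= 0.049

# Trainable posterior: gradients flow into both mu and rho.
posterior_w = GaussianVariable(mu, rho, mu_requires_grad=True, rho_requires_grad=True)

# Frozen prior: the defaults keep both parameters fixed.
prior_w = GaussianVariable(mu, rho)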
sigma: torch.Tensor
The standard deviation of the distribution, computed as: sigma = log(1 + exp(rho)).
Returns:
torch.Tensor: A tensor representing the current standard deviation.
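Because sigma is the softplus of rho, a common way to pick rho for a desired standard deviation is the inverse softplus, rho = log(exp(sigma) - 1). This initialization recipe is an assumption for illustration, not something prescribed by the class itself.

import torch

target_sigma = torch.tensor(0.1)
# Inverse softplus: the rho for which log(1 + exp(rho)) equals target_sigma.
rho = torch.log(torch.expm1(target_sigma))

recovered_sigma = torch.log1p(torch.exp(rho))
print(recovered_sigma)  # tensor(0.1000)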
@abstractmethod
def sample(self) -> torch.Tensor:
Draw a sample from this variable's underlying distribution.
Returns:
torch.Tensor: A sampled value of the same shape as mu.
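In a probabilistic layer, sample is typically called once per forward pass so that every call uses a fresh realization of the weights. The layer below is a hypothetical sketch, not part of the library; it assumes two variables shaped like a linear layer's weight and bias.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ProbLinear(nn.Module):
    # Hypothetical layer wrapping a weight variable and a bias variable.
    def __init__(self, weight_var, bias_var):
        super().__init__()
        # AbstractVariable is an nn.Module, so assigning it here registers
        # its mu/rho parameters with this layer.
        self.weight_var = weight_var
        self.bias_var = bias_var

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Draw a new realization of the weights on every call.
        weight = self.weight_var.sample()
        bias = self.bias_var.sample()
        return F.linear(x, weight, bias)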
@abstractmethod
def compute_kl(self, other: AbstractVariable) -> torch.Tensor:
Compute the KL divergence between this variable and another variable of the same distribution type.
Arguments:
- other (AbstractVariable): Another AbstractVariable instance with comparable parameters (e.g., mu, rho).
Returns:
torch.Tensor: A scalar tensor with the KL divergence value.
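In training, compute_kl is usually evaluated against a fixed prior variable of the same type, and the per-variable terms are summed into the objective. The helper below is a hedged sketch; the exact weighting of the KL term depends on the bound (ELBO, PAC-Bayes, etc.) being optimized.

def kl_penalty(posterior_vars, prior_vars):
    # Sum KL(posterior || prior) over all corresponding variable pairs.
    return sum(q.compute_kl(p) for q, p in zip(posterior_vars, prior_vars))


# Hypothetical usage, where nll is the mini-batch data-fit term and n is the
# training-set size:
# loss = nll + kl_penalty(posterior_vars, prior_vars) / n
# loss.backward()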