core.distribution.LaplaceVariable
import torch
from torch import Tensor

from core.distribution import AbstractVariable


class LaplaceVariable(AbstractVariable):
    def __init__(
        self,
        mu: Tensor,
        rho: Tensor,
        mu_requires_grad: bool = False,
        rho_requires_grad: bool = False,
    ):
        """
        Initialize the LaplaceVariable.

        Args:
            mu (Tensor): The mean (location) of the Laplace distribution.
            rho (Tensor): Reparametrized scale, rho = log(exp(sigma) - 1).
            mu_requires_grad (bool): Whether mu is trainable (requires gradients).
            rho_requires_grad (bool): Whether rho is trainable (requires gradients).
        """
        super().__init__(mu, rho, mu_requires_grad, rho_requires_grad)

    def sample(self):
        """
        Sample from the Laplace distribution via inverse-CDF sampling.

        Returns:
            Tensor: Sampled values from the Laplace distribution.
        """
        # epsilon ~ Uniform(-0.5, 0.5), bounded away from the endpoints so that
        # log(1 - 2 * |epsilon|) stays finite.
        epsilon = 0.999 * torch.rand(self.sigma.size()) - 0.49999
        epsilon = epsilon.to(self.mu.device)
        # Inverse Laplace CDF applied to epsilon (reparametrization trick), so the
        # draw stays differentiable with respect to mu and scale.
        return self.mu - torch.mul(
            torch.mul(self.scale, torch.sign(epsilon)),
            torch.log(1 - 2 * torch.abs(epsilon)),
        )

    def compute_kl(self, other: "LaplaceVariable") -> Tensor:
        """
        Compute the KL divergence KL(self || other) between two Laplace distributions.

        Args:
            other (LaplaceVariable): The other Laplace distribution.

        Returns:
            Tensor: The KL divergence between the two distributions, summed over entries.
        """
        # Closed form for Laplace(mu1, b1) vs Laplace(mu0, b0):
        # KL = log(b0 / b1) + |mu1 - mu0| / b0 + (b1 / b0) * exp(-|mu1 - mu0| / b1) - 1
        b1 = self.scale
        b0 = other.scale
        term1 = torch.log(torch.div(b0, b1))
        aux = torch.abs(self.mu - other.mu)
        term2 = torch.div(aux, b0)
        term3 = torch.div(b1, b0) * torch.exp(torch.div(-aux, b1))

        kl_div = (term1 + term2 + term3 - 1).sum()
        return kl_div
class LaplaceVariable(AbstractVariable):
An abstract class representing a single random variable for a probabilistic neural network parameter (e.g., weight or bias); LaplaceVariable is its Laplace-distributed implementation.

Each variable holds:

- mu (Tensor): the location (mean) parameter of the distribution.
- rho (Tensor): a reparametrized spread, rho = log(exp(sigma) - 1), equivalently sigma = log(1 + exp(rho)) (a softplus transform).

This class inherits from nn.Module for parameter registration in PyTorch and from KLDivergenceInterface for consistent KL divergence handling.
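For intuition only, here is a minimal sketch of how such a variable might register its parameters and derive sigma from rho via the softplus transform. This is an illustrative assumption, not the actual AbstractVariable implementation in core.distribution:

import torch
from torch import nn, Tensor


class VariableSketch(nn.Module):
    """Hypothetical illustration of the mu/rho parametrization (not library code)."""

    def __init__(self, mu: Tensor, rho: Tensor,
                 mu_requires_grad: bool = False, rho_requires_grad: bool = False):
        super().__init__()
        # nn.Parameter registration makes mu and rho visible to PyTorch optimizers.
        self.mu = nn.Parameter(mu.clone(), requires_grad=mu_requires_grad)
        self.rho = nn.Parameter(rho.clone(), requires_grad=rho_requires_grad)

    @property
    def sigma(self) -> Tensor:
        # Inverse of rho = log(exp(sigma) - 1): sigma = log(1 + exp(rho)) = softplus(rho).
        return torch.nn.functional.softplus(self.rho)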
LaplaceVariable(mu: torch.Tensor, rho: torch.Tensor, mu_requires_grad: bool = False, rho_requires_grad: bool = False)
Initialize the LaplaceVariable.
Arguments:
- mu (Tensor): The mean (location) of the Laplace distribution.
- rho (Tensor): Reparametrized scale, rho = log(exp(sigma) - 1), so that sigma = log(1 + exp(rho)).
- mu_requires_grad (bool): Whether mu is trainable (i.e., requires gradients).
- rho_requires_grad (bool): Whether rho is trainable (i.e., requires gradients).
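A short usage sketch, assuming the class is importable from core.distribution (the exact import path depends on how the package re-exports it) and that AbstractVariable registers mu and rho as in the source above:

import torch
from core.distribution import LaplaceVariable  # assumed re-export; adjust to your package layout

mu = torch.zeros(10)            # location of each of the 10 weights
rho = torch.full((10,), -3.0)   # small derived scale, since sigma = log(1 + exp(rho))

weight = LaplaceVariable(mu, rho, mu_requires_grad=True, rho_requires_grad=True)
w = weight.sample()             # one Monte Carlo draw of the 10 weights
print(w.shape)                  # torch.Size([10])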
def sample(self):
Sample from the Laplace distribution using the inverse-CDF (reparametrization) trick, so the draw remains differentiable with respect to mu and the scale.
Returns:
Tensor: Sampled values from the Laplace distribution.
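For reference, the same inverse-CDF transform can be checked in isolation; the snippet below uses hypothetical scalar values rather than the class itself:

import torch

mu, b = 1.0, 0.5                                      # example location and scale
eps = 0.999 * torch.rand(100_000) - 0.49999           # ~ Uniform(-0.5, 0.5), endpoints excluded
x = mu - b * torch.sign(eps) * torch.log(1 - 2 * torch.abs(eps))

print(x.mean().item())                # close to mu = 1.0
print((x - mu).abs().mean().item())   # close to b = 0.5 (mean absolute deviation of a Laplace)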
def compute_kl(self, other: "LaplaceVariable") -> Tensor:
Compute the closed-form KL divergence KL(self || other) between two Laplace distributions, summed over all entries.
Arguments:
- other (LaplaceVariable): The other Laplace distribution.
Returns:
Tensor: The KL divergence between the two distributions.
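As a sanity check, this closed form agrees with PyTorch's built-in Laplace-Laplace KL; the tensors below are made-up example values, not library state:

import torch
from torch.distributions import Laplace, kl_divergence

mu_p, b_p = torch.tensor([0.0, 1.0]), torch.tensor([0.5, 2.0])    # plays the role of self
mu_q, b_q = torch.tensor([0.3, -1.0]), torch.tensor([1.0, 1.5])   # plays the role of other

# Same closed form as compute_kl:
# KL = log(b_q / b_p) + |mu_p - mu_q| / b_q + (b_p / b_q) * exp(-|mu_p - mu_q| / b_p) - 1
diff = (mu_p - mu_q).abs()
kl_manual = (torch.log(b_q / b_p) + diff / b_q + (b_p / b_q) * torch.exp(-diff / b_p) - 1).sum()

kl_torch = kl_divergence(Laplace(mu_p, b_p), Laplace(mu_q, b_q)).sum()
print(torch.allclose(kl_manual, kl_torch))  # True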