core.distribution.LaplaceVariable

 1import torch
 2from torch import Tensor
 3
 4from core.distribution import AbstractVariable
 5
 6
 7class LaplaceVariable(AbstractVariable):
 8    def __init__(
 9        self,
10        mu: Tensor,
11        rho: Tensor,
12        mu_requires_grad: bool = False,
13        rho_requires_grad: bool = False,
14    ):
15        """
16        Initialize the LaplaceVariable.
17
18        Args:
19            mu (Tensor): The mean of the Laplace distribution.
20            rho (Tensor): rho = log(exp(sigma)-1) of the Laplace distribution.
21            mu_requires_grad (bool): Flag indicating whether mu is fixed.
22            rho_requires_grad (bool): Flag indicating whether rho is fixed.
23        """
24        super().__init__(mu, rho, mu_requires_grad, rho_requires_grad)
25
26    def sample(self):
27        """
28        Sample from the Laplace distribution.
29
30        Returns:
31            Tensor: Sampled values from the Laplace distribution.
32        """
33        epsilon = 0.999 * torch.rand(self.sigma.size()) - 0.49999
34        epsilon = epsilon.to(self.mu.device)
35        return self.mu - torch.mul(
36            torch.mul(self.scale, torch.sign(epsilon)),
37            torch.log(1 - 2 * torch.abs(epsilon)),
38        )
39
40    def compute_kl(self, other: "LaplaceVariable") -> Tensor:
41        """
42        Compute the KL divergence between two Laplace distributions.
43
44        Args:
45            other (LaplaceVariable): The other Laplace distribution.
46
47        Returns:
48            Tensor: The KL divergence between the two distributions.
49        """
50        b1 = self.scale
51        b0 = other.scale
52        term1 = torch.log(torch.div(b0, b1))
53        aux = torch.abs(self.mu - other.mu)
54        term2 = torch.div(aux, b0)
55        term3 = torch.div(b1, b0) * torch.exp(torch.div(-aux, b1))
56
57        kl_div = (term1 + term2 + term3 - 1).sum()
58        return kl_div
class LaplaceVariable(core.distribution.AbstractVariable.AbstractVariable):
 8class LaplaceVariable(AbstractVariable):
 9    def __init__(
10        self,
11        mu: Tensor,
12        rho: Tensor,
13        mu_requires_grad: bool = False,
14        rho_requires_grad: bool = False,
15    ):
16        """
17        Initialize the LaplaceVariable.
18
19        Args:
20            mu (Tensor): The mean of the Laplace distribution.
21            rho (Tensor): rho = log(exp(sigma)-1) of the Laplace distribution.
22            mu_requires_grad (bool): Flag indicating whether mu is fixed.
23            rho_requires_grad (bool): Flag indicating whether rho is fixed.
24        """
25        super().__init__(mu, rho, mu_requires_grad, rho_requires_grad)
26
27    def sample(self):
28        """
29        Sample from the Laplace distribution.
30
31        Returns:
32            Tensor: Sampled values from the Laplace distribution.
33        """
34        epsilon = 0.999 * torch.rand(self.sigma.size()) - 0.49999
35        epsilon = epsilon.to(self.mu.device)
36        return self.mu - torch.mul(
37            torch.mul(self.scale, torch.sign(epsilon)),
38            torch.log(1 - 2 * torch.abs(epsilon)),
39        )
40
41    def compute_kl(self, other: "LaplaceVariable") -> Tensor:
42        """
43        Compute the KL divergence between two Laplace distributions.
44
45        Args:
46            other (LaplaceVariable): The other Laplace distribution.
47
48        Returns:
49            Tensor: The KL divergence between the two distributions.
50        """
51        b1 = self.scale
52        b0 = other.scale
53        term1 = torch.log(torch.div(b0, b1))
54        aux = torch.abs(self.mu - other.mu)
55        term2 = torch.div(aux, b0)
56        term3 = torch.div(b1, b0) * torch.exp(torch.div(-aux, b1))
57
58        kl_div = (term1 + term2 + term3 - 1).sum()
59        return kl_div

(Inherited from the base class `AbstractVariable`.) An abstract class representing a single random variable for a probabilistic neural network parameter (e.g., a weight or bias).

Each variable holds:
  • A mean parameter (mu)
  • A rho parameter that is used to derive the standard deviation (sigma)
  • A method to sample from the underlying distribution
  • A method to compute KL divergence with another variable of the same type

This class inherits from nn.Module for parameter registration in PyTorch and from KLDivergenceInterface for consistent KL divergence handling.

LaplaceVariable(mu: torch.Tensor, rho: torch.Tensor, mu_requires_grad: bool = False, rho_requires_grad: bool = False)
 9    def __init__(
10        self,
11        mu: Tensor,
12        rho: Tensor,
13        mu_requires_grad: bool = False,
14        rho_requires_grad: bool = False,
15    ):
16        """
17        Initialize the LaplaceVariable.
18
19        Args:
20            mu (Tensor): The mean of the Laplace distribution.
21            rho (Tensor): rho = log(exp(sigma)-1) of the Laplace distribution.
22            mu_requires_grad (bool): Flag indicating whether mu is fixed.
23            rho_requires_grad (bool): Flag indicating whether rho is fixed.
24        """
25        super().__init__(mu, rho, mu_requires_grad, rho_requires_grad)

Initialize the LaplaceVariable.

Arguments:
  • mu (Tensor): The mean of the Laplace distribution.
  • rho (Tensor): rho = log(exp(sigma)-1) of the Laplace distribution.
  • mu_requires_grad (bool): Whether mu requires gradients (is trainable); the default False keeps mu fixed.
  • rho_requires_grad (bool): Whether rho requires gradients (is trainable); the default False keeps rho fixed.
def sample(self):
27    def sample(self):
28        """
29        Sample from the Laplace distribution.
30
31        Returns:
32            Tensor: Sampled values from the Laplace distribution.
33        """
34        epsilon = 0.999 * torch.rand(self.sigma.size()) - 0.49999
35        epsilon = epsilon.to(self.mu.device)
36        return self.mu - torch.mul(
37            torch.mul(self.scale, torch.sign(epsilon)),
38            torch.log(1 - 2 * torch.abs(epsilon)),
39        )

Sample from the Laplace distribution.

Returns:

Tensor: Sampled values from the Laplace distribution.

def compute_kl(self, other: LaplaceVariable) -> torch.Tensor:
41    def compute_kl(self, other: "LaplaceVariable") -> Tensor:
42        """
43        Compute the KL divergence between two Laplace distributions.
44
45        Args:
46            other (LaplaceVariable): The other Laplace distribution.
47
48        Returns:
49            Tensor: The KL divergence between the two distributions.
50        """
51        b1 = self.scale
52        b0 = other.scale
53        term1 = torch.log(torch.div(b0, b1))
54        aux = torch.abs(self.mu - other.mu)
55        term2 = torch.div(aux, b0)
56        term3 = torch.div(b1, b0) * torch.exp(torch.div(-aux, b1))
57
58        kl_div = (term1 + term2 + term3 - 1).sum()
59        return kl_div

Compute the KL divergence between two Laplace distributions.

Arguments:
  • other (LaplaceVariable): The other Laplace distribution.
Returns:

Tensor: The KL divergence between the two distributions.