core.distribution.GaussianVariable

 1import torch
 2from torch import Tensor
 3
 4from core.distribution import AbstractVariable
 5
 6
 7class GaussianVariable(AbstractVariable):
 8    """
 9    Represents a Gaussian random variable with mean mu and rho.
10    """
11
12    def __init__(
13        self,
14        mu: Tensor,
15        rho: Tensor,
16        mu_requires_grad: bool = False,
17        rho_requires_grad: bool = False,
18    ):
19        """
20        Initialize the GaussianVariable.
21
22        Args:
23            mu (Tensor): The mean of the Gaussian distribution.
24            rho (Tensor): rho = log(exp(sigma)-1) where sigma is a standard deviation of the Gaussian distribution.
25            mu_requires_grad (bool): Flag indicating whether mu is fixed.
26            rho_requires_grad (bool): Flag indicating whether rho is fixed.
27        """
28        super().__init__(mu, rho, mu_requires_grad, rho_requires_grad)
29
30    def sample(self) -> Tensor:
31        """
32        Sample from the Gaussian distribution.
33
34        Returns:
35            Tensor: Sampled values from the Gaussian distribution.
36        """
37        epsilon = torch.randn(self.sigma.size())
38        epsilon = epsilon.to(self.mu.device)
39        return self.mu + self.sigma * epsilon
40
41    def compute_kl(self, other: "GaussianVariable") -> Tensor:
42        """
43        Compute the KL divergence between two Gaussian distributions.
44
45        Args:
46            other (GaussianVariable): The other Gaussian distribution.
47
48        Returns:
49            Tensor: The KL divergence between the two distributions.
50        """
51        b1 = torch.pow(self.sigma, 2)
52        b0 = torch.pow(other.sigma, 2)
53
54        term1 = torch.log(torch.div(b0, b1))
55        term2 = torch.div(torch.pow(self.mu - other.mu, 2), b0)
56        term3 = torch.div(b1, b0)
57        kl_div = (torch.mul(term1 + term2 + term3 - 1, 0.5)).sum()
58        return kl_div
class GaussianVariable(core.distribution.AbstractVariable.AbstractVariable):
 8class GaussianVariable(AbstractVariable):
 9    """
10    Represents a Gaussian random variable with mean mu and rho.
11    """
12
13    def __init__(
14        self,
15        mu: Tensor,
16        rho: Tensor,
17        mu_requires_grad: bool = False,
18        rho_requires_grad: bool = False,
19    ):
20        """
21        Initialize the GaussianVariable.
22
23        Args:
24            mu (Tensor): The mean of the Gaussian distribution.
25            rho (Tensor): rho = log(exp(sigma)-1) where sigma is a standard deviation of the Gaussian distribution.
26            mu_requires_grad (bool): Flag indicating whether mu is fixed.
27            rho_requires_grad (bool): Flag indicating whether rho is fixed.
28        """
29        super().__init__(mu, rho, mu_requires_grad, rho_requires_grad)
30
31    def sample(self) -> Tensor:
32        """
33        Sample from the Gaussian distribution.
34
35        Returns:
36            Tensor: Sampled values from the Gaussian distribution.
37        """
38        epsilon = torch.randn(self.sigma.size())
39        epsilon = epsilon.to(self.mu.device)
40        return self.mu + self.sigma * epsilon
41
42    def compute_kl(self, other: "GaussianVariable") -> Tensor:
43        """
44        Compute the KL divergence between two Gaussian distributions.
45
46        Args:
47            other (GaussianVariable): The other Gaussian distribution.
48
49        Returns:
50            Tensor: The KL divergence between the two distributions.
51        """
52        b1 = torch.pow(self.sigma, 2)
53        b0 = torch.pow(other.sigma, 2)
54
55        term1 = torch.log(torch.div(b0, b1))
56        term2 = torch.div(torch.pow(self.mu - other.mu, 2), b0)
57        term3 = torch.div(b1, b0)
58        kl_div = (torch.mul(term1 + term2 + term3 - 1, 0.5)).sum()
59        return kl_div

Represents a Gaussian random variable parameterized by its mean mu and by rho, where rho = log(exp(sigma) - 1) encodes the standard deviation sigma.

GaussianVariable( mu: torch.Tensor, rho: torch.Tensor, mu_requires_grad: bool = False, rho_requires_grad: bool = False)
13    def __init__(
14        self,
15        mu: Tensor,
16        rho: Tensor,
17        mu_requires_grad: bool = False,
18        rho_requires_grad: bool = False,
19    ):
20        """
21        Initialize the GaussianVariable.
22
23        Args:
24            mu (Tensor): The mean of the Gaussian distribution.
25            rho (Tensor): rho = log(exp(sigma)-1) where sigma is a standard deviation of the Gaussian distribution.
26            mu_requires_grad (bool): Flag indicating whether mu is fixed.
27            rho_requires_grad (bool): Flag indicating whether rho is fixed.
28        """
29        super().__init__(mu, rho, mu_requires_grad, rho_requires_grad)

Initialize the GaussianVariable.

Arguments:
  • mu (Tensor): The mean of the Gaussian distribution.
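  • rho (Tensor): rho = log(exp(sigma) - 1), where sigma is the standard deviation of the Gaussian distribution (equivalently, sigma = log(1 + exp(rho))).
  • mu_requires_grad (bool): Flag indicating whether mu requires gradients, i.e. whether mu is trainable (False keeps mu fixed).
  • rho_requires_grad (bool): Flag indicating whether rho requires gradients, i.e. whether rho is trainable (False keeps rho fixed).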
def sample(self) -> torch.Tensor:
31    def sample(self) -> Tensor:
32        """
33        Sample from the Gaussian distribution.
34
35        Returns:
36            Tensor: Sampled values from the Gaussian distribution.
37        """
38        epsilon = torch.randn(self.sigma.size())
39        epsilon = epsilon.to(self.mu.device)
40        return self.mu + self.sigma * epsilon

Sample from the Gaussian distribution as mu + sigma * epsilon, where epsilon is drawn from a standard normal distribution.

Returns:

Tensor: Sampled values from the Gaussian distribution.

def compute_kl( self, other: GaussianVariable) -> torch.Tensor:
42    def compute_kl(self, other: "GaussianVariable") -> Tensor:
43        """
44        Compute the KL divergence between two Gaussian distributions.
45
46        Args:
47            other (GaussianVariable): The other Gaussian distribution.
48
49        Returns:
50            Tensor: The KL divergence between the two distributions.
51        """
52        b1 = torch.pow(self.sigma, 2)
53        b0 = torch.pow(other.sigma, 2)
54
55        term1 = torch.log(torch.div(b0, b1))
56        term2 = torch.div(torch.pow(self.mu - other.mu, 2), b0)
57        term3 = torch.div(b1, b0)
58        kl_div = (torch.mul(term1 + term2 + term3 - 1, 0.5)).sum()
59        return kl_div

Compute the KL divergence KL(self || other) between two Gaussian distributions, summed over all elements into a scalar.

Arguments:
  • other (GaussianVariable): The other Gaussian distribution.
Returns:

Tensor: The KL divergence between the two distributions.