core.distribution.GaussianVariable
import torch
from torch import Tensor

from core.distribution import AbstractVariable


class GaussianVariable(AbstractVariable):
    """
    Represents a Gaussian random variable parameterized by a mean mu and a
    scale parameter rho, where sigma = log(1 + exp(rho)).
    """

    def __init__(
        self,
        mu: Tensor,
        rho: Tensor,
        mu_requires_grad: bool = False,
        rho_requires_grad: bool = False,
    ):
        """
        Initialize the GaussianVariable.

        Args:
            mu (Tensor): The mean of the Gaussian distribution.
            rho (Tensor): rho = log(exp(sigma) - 1), where sigma is the standard deviation of the Gaussian distribution.
            mu_requires_grad (bool): Flag indicating whether mu requires gradients (i.e., is trainable).
            rho_requires_grad (bool): Flag indicating whether rho requires gradients (i.e., is trainable).
        """
        super().__init__(mu, rho, mu_requires_grad, rho_requires_grad)

    def sample(self) -> Tensor:
        """
        Sample from the Gaussian distribution using the reparameterization trick.

        Returns:
            Tensor: Sampled values from the Gaussian distribution.
        """
        # Draw standard-normal noise and move it to the parameters' device.
        epsilon = torch.randn(self.sigma.size())
        epsilon = epsilon.to(self.mu.device)
        return self.mu + self.sigma * epsilon

    def compute_kl(self, other: "GaussianVariable") -> Tensor:
        """
        Compute the KL divergence KL(self || other) between two Gaussian distributions.

        Args:
            other (GaussianVariable): The other Gaussian distribution.

        Returns:
            Tensor: The KL divergence between the two distributions.
        """
        # Variances: b1 for self, b0 for other.
        b1 = torch.pow(self.sigma, 2)
        b0 = torch.pow(other.sigma, 2)

        # Elementwise closed form 0.5 * (log(b0/b1) + (mu1 - mu0)^2 / b0 + b1 / b0 - 1), summed.
        term1 = torch.log(torch.div(b0, b1))
        term2 = torch.div(torch.pow(self.mu - other.mu, 2), b0)
        term3 = torch.div(b1, b0)
        kl_div = (torch.mul(term1 + term2 + term3 - 1, 0.5)).sum()
        return kl_div
Represents a Gaussian random variable parameterized by a mean mu and a scale parameter rho, where sigma = log(1 + exp(rho)).
GaussianVariable(mu: torch.Tensor, rho: torch.Tensor, mu_requires_grad: bool = False, rho_requires_grad: bool = False)
Initialize the GaussianVariable.
Arguments:
- mu (Tensor): The mean of the Gaussian distribution.
- rho (Tensor): rho = log(exp(sigma) - 1), where sigma is the standard deviation of the Gaussian distribution.
- mu_requires_grad (bool): Flag indicating whether mu requires gradients (i.e., is trainable).
- rho_requires_grad (bool): Flag indicating whether rho requires gradients (i.e., is trainable).
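Because sigma is recovered from rho via the softplus relation stated above (sigma = log(1 + exp(rho))), a variable can be initialized at a desired standard deviation by inverting that relation. A minimal sketch; rho_for_sigma is a hypothetical helper, not part of this module:

import math

import torch
import torch.nn.functional as F

# Hypothetical helper: invert the documented relation rho = log(exp(sigma) - 1)
# so a chosen standard deviation maps back to the rho value to store.
def rho_for_sigma(sigma: float) -> float:
    return math.log(math.expm1(sigma))  # log(exp(sigma) - 1)

rho0 = rho_for_sigma(0.05)
# Softplus recovers the target standard deviation from rho.
print(F.softplus(torch.tensor(rho0)).item())  # ~0.05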
def sample(self) -> torch.Tensor:
Sample from the Gaussian distribution using the reparameterization trick.
Returns:
Tensor: Sampled values from the Gaussian distribution.
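For intuition, the body of sample is the standard reparameterization trick: noise is drawn from N(0, I) and then shifted and scaled, so the sample stays differentiable with respect to mu and sigma. A self-contained sketch of the same computation, using stand-in tensors for self.mu and self.rho and assuming sigma = softplus(rho) as documented above:

import torch
import torch.nn.functional as F

# Stand-ins for self.mu and self.rho; sigma = softplus(rho) as documented.
mu = torch.zeros(3, requires_grad=True)
rho = torch.full((3,), -3.0, requires_grad=True)
sigma = F.softplus(rho)

# Same computation as sample(): eps ~ N(0, I), then shift and scale.
epsilon = torch.randn(sigma.size(), device=mu.device)
w = mu + sigma * epsilon

# Gradients flow through the sample back to mu and rho.
w.sum().backward()
print(mu.grad, rho.grad)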
def compute_kl(self, other: "GaussianVariable") -> torch.Tensor:
Compute the KL divergence KL(self || other) between two Gaussian distributions.
Arguments:
- other (GaussianVariable): The other Gaussian distribution.
Returns:
Tensor: The KL divergence between the two distributions.
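The closed form computed by compute_kl is the standard KL divergence between univariate Gaussians, summed over elements: KL(N(mu1, sigma1^2) || N(mu0, sigma0^2)) = 0.5 * (log(sigma0^2 / sigma1^2) + (mu1 - mu0)^2 / sigma0^2 + sigma1^2 / sigma0^2 - 1). A quick numeric check against torch.distributions; the tensors are arbitrary illustrative values:

import torch
from torch.distributions import Normal, kl_divergence

mu1, sigma1 = torch.tensor([0.0, 1.0]), torch.tensor([0.5, 0.2])  # plays the role of self
mu0, sigma0 = torch.tensor([0.1, 0.9]), torch.tensor([0.4, 0.3])  # plays the role of other

# Same terms as compute_kl: b1 = self variance, b0 = other variance.
b1, b0 = sigma1.pow(2), sigma0.pow(2)
closed_form = (0.5 * (torch.log(b0 / b1) + (mu1 - mu0).pow(2) / b0 + b1 / b0 - 1)).sum()
reference = kl_divergence(Normal(mu1, sigma1), Normal(mu0, sigma0)).sum()
print(torch.allclose(closed_form, reference))  # True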