core.distribution.AbstractVariable

from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from core.utils import KLDivergenceInterface


class AbstractVariable(nn.Module, KLDivergenceInterface, ABC):
    """
    An abstract class representing a single random variable for a probabilistic
    neural network parameter (e.g., weight or bias).

    Each variable holds:
      - A mean parameter (`mu`)
      - A `rho` parameter that is used to derive the standard deviation (`sigma`)
      - A method to sample from the underlying distribution
      - A method to compute KL divergence with another variable of the same type

    This class inherits from `nn.Module` for parameter registration in PyTorch
    and from `KLDivergenceInterface` for consistent KL divergence handling.
    """

    def __init__(
        self,
        mu: torch.Tensor,
        rho: torch.Tensor,
        mu_requires_grad: bool = False,
        rho_requires_grad: bool = False,
    ):
        """
        Initialize an AbstractVariable with given `mu` and `rho` tensors.

        Args:
            mu (torch.Tensor): The mean parameter for the distribution.
            rho (torch.Tensor): The parameter from which we derive sigma = log(1 + exp(rho)).
            mu_requires_grad (bool, optional): If True, allow gradients on `mu`.
            rho_requires_grad (bool, optional): If True, allow gradients on `rho`.
        """
        super().__init__()
        self.mu = nn.Parameter(mu.detach().clone(), requires_grad=mu_requires_grad)
        self.rho = nn.Parameter(rho.detach().clone(), requires_grad=rho_requires_grad)
        self.kl_div = None

    @property
    def sigma(self) -> torch.Tensor:
        """
        The standard deviation of the distribution, computed as:
            sigma = log(1 + exp(rho)).

        Returns:
            torch.Tensor: A tensor representing the current standard deviation.
        """
        return torch.log(1 + torch.exp(self.rho))

    @abstractmethod
    def sample(self) -> torch.Tensor:
        """
        Draw a sample from this variable's underlying distribution.

        Returns:
            torch.Tensor: A sampled value of the same shape as `mu`.
        """
        pass

    @abstractmethod
    def compute_kl(self, other: "AbstractVariable") -> torch.Tensor:
        """
        Compute the KL divergence between this variable and another variable
        of the same distribution type.

        Args:
            other (AbstractVariable): Another AbstractVariable instance
                with comparable parameters (e.g., mu, rho).

        Returns:
            torch.Tensor: A scalar tensor with the KL divergence value.
        """
        pass

class AbstractVariable(nn.Module, KLDivergenceInterface, ABC):

An abstract class representing a single random variable for a probabilistic neural network parameter (e.g., weight or bias).

Each variable holds:
  • A mean parameter (mu)
  • A rho parameter that is used to derive the standard deviation (sigma)
  • A method to sample from the underlying distribution
  • A method to compute KL divergence with another variable of the same type

This class inherits from nn.Module for parameter registration in PyTorch and from KLDivergenceInterface for consistent KL divergence handling.
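
Because sample and compute_kl are abstract, AbstractVariable cannot be instantiated directly; concrete distribution types subclass it and implement those two methods. Below is a minimal sketch of such a subclass, assuming a factorized Gaussian posterior; the class name GaussianVariable, its method bodies, and the import path are illustrative, not part of this module.

import torch

from core.distribution.AbstractVariable import AbstractVariable  # import path assumed


class GaussianVariable(AbstractVariable):
    """Hypothetical concrete subclass: a factorized Gaussian over one parameter tensor."""

    def sample(self) -> torch.Tensor:
        # Reparameterization trick: mu + sigma * eps stays differentiable w.r.t. mu and rho.
        eps = torch.randn_like(self.mu)
        return self.mu + self.sigma * eps

    def compute_kl(self, other: "GaussianVariable") -> torch.Tensor:
        # Closed-form KL between two factorized Gaussians, summed over all entries.
        var_ratio = (self.sigma / other.sigma) ** 2
        mean_term = ((self.mu - other.mu) / other.sigma) ** 2
        return 0.5 * (var_ratio + mean_term - 1 - torch.log(var_ratio)).sum()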

AbstractVariable(mu: torch.Tensor, rho: torch.Tensor, mu_requires_grad: bool = False, rho_requires_grad: bool = False)

Initialize an AbstractVariable with given mu and rho tensors.

Arguments:
  • mu (torch.Tensor): The mean parameter for the distribution.
  • rho (torch.Tensor): The parameter from which we derive sigma = log(1 + exp(rho)).
  • mu_requires_grad (bool, optional): If True, allow gradients on mu.
  • rho_requires_grad (bool, optional): If True, allow gradients on rho.
Instance attributes:
  • mu (nn.Parameter): the mean of the distribution
  • rho (nn.Parameter): the raw parameter from which sigma is derived
  • kl_div: placeholder for a cached KL divergence value, initialized to None
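
As a construction sketch (reusing the hypothetical GaussianVariable above), mu and rho are copied into nn.Parameter objects with the requested gradient flags, so they appear in named_parameters(), and kl_div starts out as None:

import torch

mu = torch.zeros(20, 10)           # one mean per weight entry
rho = torch.full((20, 10), -3.0)   # sigma = log(1 + exp(-3)) ≈ 0.049

v = GaussianVariable(mu, rho, mu_requires_grad=True, rho_requires_grad=True)

print([name for name, _ in v.named_parameters()])  # ['mu', 'rho']
print(v.mu.requires_grad, v.rho.requires_grad)     # True True
print(v.kl_div)                                    # None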
sigma: torch.Tensor

The standard deviation of the distribution, computed as: sigma = log(1 + exp(rho)).

Returns:
  torch.Tensor: A tensor representing the current standard deviation.
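
The expression log(1 + exp(rho)) is the softplus function, so sigma is always positive regardless of the sign of rho. As a small sketch, the property matches torch.nn.functional.softplus, and inverting it gives the rho that yields a desired sigma (handy when choosing an initial rho):

import torch
import torch.nn.functional as F

rho = torch.tensor([-5.0, 0.0, 5.0])
sigma = torch.log(1 + torch.exp(rho))        # as implemented above
assert torch.allclose(sigma, F.softplus(rho))

# Inverse mapping: pick rho so that log(1 + exp(rho)) equals a target sigma.
target_sigma = torch.tensor(0.05)
rho_init = torch.log(torch.expm1(target_sigma))  # log(exp(sigma) - 1)
assert torch.isclose(F.softplus(rho_init), target_sigma)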

@abstractmethod
def sample(self) -> torch.Tensor:

Draw a sample from this variable's underlying distribution.

Returns:
  torch.Tensor: A sampled value of the same shape as mu.
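
As a quick sanity check of a concrete implementation (the reparameterized GaussianVariable sketched earlier is assumed), repeated calls to sample() should have an empirical mean close to mu and an empirical standard deviation close to sigma:

import torch

v = GaussianVariable(torch.zeros(5), torch.full((5,), -1.0))

draws = torch.stack([v.sample() for _ in range(10_000)])
print(draws.mean(dim=0))  # each entry near 0
print(draws.std(dim=0))   # each entry near log(1 + exp(-1)) ≈ 0.313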

@abstractmethod
def compute_kl(self, other: AbstractVariable) -> torch.Tensor:

Compute the KL divergence between this variable and another variable of the same distribution type.

Arguments:
  • other (AbstractVariable): Another AbstractVariable instance with comparable parameters (e.g., mu, rho).

Returns:
  torch.Tensor: A scalar tensor with the KL divergence value.
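
For two factorized Gaussians the divergence has a closed form, KL = sum over entries of log(sigma_p / sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 * sigma_p^2) - 1/2, which is what the hypothetical GaussianVariable.compute_kl above implements. As a sketch, it can be cross-checked against PyTorch's built-in KL for Normal distributions:

import torch
from torch.distributions import Normal, kl_divergence

q = GaussianVariable(torch.randn(8), torch.full((8,), -2.0))  # posterior
p = GaussianVariable(torch.zeros(8), torch.full((8,), -1.0))  # prior

kl_manual = q.compute_kl(p)
kl_reference = kl_divergence(Normal(q.mu, q.sigma), Normal(p.mu, p.sigma)).sum()

assert torch.allclose(kl_manual, kl_reference)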