core.layer.ProbLinear

 1import torch.nn.functional as f
 2from torch import Tensor, nn
 3
 4from core.layer import AbstractProbLayer
 5
 6
 7class ProbLinear(nn.Linear, AbstractProbLayer):
 8    """
 9    A probabilistic linear (fully connected) layer.
10
11    Extends `nn.Linear` such that weights and bias are sampled from
12    a distribution during each forward pass if `probabilistic_mode` is True.
13    """
14
15    def forward(self, input: Tensor) -> Tensor:
16        """
17        Forward pass for a probabilistic linear layer.
18
19        Args:
20            input (Tensor): Input tensor of shape (N, in_features).
21
22        Returns:
23            Tensor: Output tensor of shape (N, out_features).
24        """
25        sampled_weight, sampled_bias = self.sample_from_distribution()
26        return f.linear(input, sampled_weight, sampled_bias)
class ProbLinear(torch.nn.modules.linear.Linear, core.layer.AbstractProbLayer.AbstractProbLayer):
 8class ProbLinear(nn.Linear, AbstractProbLayer):
 9    """
10    A probabilistic linear (fully connected) layer.
11
12    Extends `nn.Linear` such that weights and bias are sampled from
13    a distribution during each forward pass if `probabilistic_mode` is True.
14    """
15
16    def forward(self, input: Tensor) -> Tensor:
17        """
18        Forward pass for a probabilistic linear layer.
19
20        Args:
21            input (Tensor): Input tensor of shape (N, in_features).
22
23        Returns:
24            Tensor: Output tensor of shape (N, out_features).
25        """
26        sampled_weight, sampled_bias = self.sample_from_distribution()
27        return f.linear(input, sampled_weight, sampled_bias)

A probabilistic linear (fully connected) layer.

Extends `nn.Linear` such that the weight and bias are sampled from a distribution during each forward pass if `probabilistic_mode` is True.

def forward(self, input: torch.Tensor) -> torch.Tensor:
# `input` shadows the builtin, but the name is kept to match the
# signature of `nn.Linear.forward`, which callers may use by keyword.
def forward(self, input: Tensor) -> Tensor:
    """
    Forward pass for a probabilistic linear layer.

    Obtains a (weight, bias) pair from `self.sample_from_distribution()`
    and applies ``input @ weight.T + bias`` via `F.linear`.

    Args:
        input (Tensor): Input tensor of shape (N, in_features). As with
            `torch.nn.functional.linear`, any number of leading
            dimensions is accepted, i.e. (*, in_features).

    Returns:
        Tensor: Output tensor of shape (N, out_features), or
        (*, out_features) for general leading dimensions.
    """
    # The parameters are requested on every call, so the output is
    # stochastic whenever `sample_from_distribution` actually samples.
    sampled_weight, sampled_bias = self.sample_from_distribution()
    return f.linear(input, sampled_weight, sampled_bias)

Forward pass for a probabilistic linear layer.

Arguments:
  • input (Tensor): Input tensor of shape (N, in_features).

Returns:
  • Tensor: Output tensor of shape (N, out_features).