core.layer.ProbLinear
import torch.nn.functional as f
from torch import Tensor, nn

from core.layer import AbstractProbLayer


class ProbLinear(nn.Linear, AbstractProbLayer):
    """
    Fully connected layer whose parameters may be drawn from a distribution.

    Behaves like `nn.Linear`, except that `forward` does not read
    `self.weight` / `self.bias` directly: on every call it asks
    `sample_from_distribution()` (not defined here or on `nn.Linear`, so
    inherited from `AbstractProbLayer`) for the weight and bias to use.
    Per the layer's design, these are sampled when `probabilistic_mode`
    is True.
    """

    def forward(self, input: Tensor) -> Tensor:
        """
        Apply the linear map using parameters obtained for this call.

        Args:
            input (Tensor): Input tensor of shape (N, in_features).

        Returns:
            Tensor: Output tensor of shape (N, out_features).
        """
        # A new (weight, bias) pair is requested on every call, so repeated
        # forward passes may differ when the layer is in probabilistic mode.
        weight, bias = self.sample_from_distribution()
        output = f.linear(input, weight, bias)
        return output
class ProbLinear(nn.Linear, AbstractProbLayer):
    """
    Fully connected layer whose parameters may be drawn from a distribution.

    Behaves like `nn.Linear`, except that `forward` does not read
    `self.weight` / `self.bias` directly: on every call it asks
    `sample_from_distribution()` (not defined here or on `nn.Linear`, so
    inherited from `AbstractProbLayer`) for the weight and bias to use.
    Per the layer's design, these are sampled when `probabilistic_mode`
    is True.
    """

    def forward(self, input: Tensor) -> Tensor:
        """
        Apply the linear map using parameters obtained for this call.

        Args:
            input (Tensor): Input tensor of shape (N, in_features).

        Returns:
            Tensor: Output tensor of shape (N, out_features).
        """
        # A new (weight, bias) pair is requested on every call, so repeated
        # forward passes may differ when the layer is in probabilistic mode.
        weight, bias = self.sample_from_distribution()
        output = f.linear(input, weight, bias)
        return output
A probabilistic linear (fully connected) layer.
Extends `nn.Linear` such that weights and bias are sampled from
a distribution during each forward pass if `probabilistic_mode` is True.
def forward(self, input: torch.Tensor) -> torch.Tensor:
def forward(self, input: Tensor) -> Tensor:
    """
    Apply the linear map using parameters obtained for this call.

    Args:
        input (Tensor): Input tensor of shape (N, in_features).

    Returns:
        Tensor: Output tensor of shape (N, out_features).
    """
    # A new (weight, bias) pair is requested on every call, so repeated
    # forward passes may differ when the layer is in probabilistic mode.
    weight, bias = self.sample_from_distribution()
    output = f.linear(input, weight, bias)
    return output
Forward pass for a probabilistic linear layer.
Arguments:
- input (Tensor): Input tensor of shape (N, in_features).
Returns:
Tensor: Output tensor of shape (N, out_features).