core.layer.ProbConv2d
import torch.nn.functional as f
from torch import Tensor, nn

from core.layer import AbstractProbLayer


class ProbConv2d(nn.Conv2d, AbstractProbLayer):
    """
    A probabilistic 2D convolution layer.

    Inherits from `nn.Conv2d` and `AbstractProbLayer`. Weights and bias
    are sampled from associated distributions during forward passes.
    """

    def forward(self, input: Tensor) -> Tensor:
        """
        Perform a 2D convolution using sampled weights and bias.

        A weight/bias pair is obtained from ``self.sample_from_distribution()``
        (supplied by ``AbstractProbLayer``) on every call. It is then applied
        with the layer's ``stride``, ``padding``, ``dilation`` and ``groups``.

        The layer's ``padding_mode`` is honoured the same way ``nn.Conv2d``
        honours it. For any mode other than ``"zeros"``, the input is padded
        explicitly with ``f.pad`` and the convolution runs with zero implicit
        padding.

        Args:
            input (Tensor): The input tensor of shape (N, C_in, H_in, W_in).

        Returns:
            Tensor: The output tensor of shape (N, C_out, H_out, W_out).
        """
        sampled_weight, sampled_bias = self.sample_from_distribution()
        if self.padding_mode != "zeros":
            # conv2d itself only implements zero padding. Reflect/replicate/
            # circular padding must therefore be applied up front, exactly as
            # nn.Conv2d does. _reversed_padding_repeated_twice is the padding
            # tuple nn.Conv2d precomputes (in f.pad's last-dim-first order)
            # for this f.pad call.
            padded = f.pad(
                input, self._reversed_padding_repeated_twice, mode=self.padding_mode
            )
            return f.conv2d(
                padded,
                sampled_weight,
                sampled_bias,
                self.stride,
                0,
                self.dilation,
                self.groups,
            )
        return f.conv2d(
            input,
            sampled_weight,
            sampled_bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
class ProbConv2d(nn.Conv2d, AbstractProbLayer):
    """
    A probabilistic 2D convolution layer.

    Inherits from `nn.Conv2d` and `AbstractProbLayer`. Weights and bias
    are sampled from associated distributions during forward passes.
    """

    def forward(self, input: Tensor) -> Tensor:
        """
        Perform a 2D convolution using sampled weights and bias.

        A weight/bias pair is obtained from ``self.sample_from_distribution()``
        (supplied by ``AbstractProbLayer``) on every call. It is then applied
        with the layer's ``stride``, ``padding``, ``dilation`` and ``groups``.

        The layer's ``padding_mode`` is honoured the same way ``nn.Conv2d``
        honours it. For any mode other than ``"zeros"``, the input is padded
        explicitly with ``f.pad`` and the convolution runs with zero implicit
        padding.

        Args:
            input (Tensor): The input tensor of shape (N, C_in, H_in, W_in).

        Returns:
            Tensor: The output tensor of shape (N, C_out, H_out, W_out).
        """
        sampled_weight, sampled_bias = self.sample_from_distribution()
        if self.padding_mode != "zeros":
            # conv2d itself only implements zero padding. Reflect/replicate/
            # circular padding must therefore be applied up front, exactly as
            # nn.Conv2d does. _reversed_padding_repeated_twice is the padding
            # tuple nn.Conv2d precomputes (in f.pad's last-dim-first order)
            # for this f.pad call.
            padded = f.pad(
                input, self._reversed_padding_repeated_twice, mode=self.padding_mode
            )
            return f.conv2d(
                padded,
                sampled_weight,
                sampled_bias,
                self.stride,
                0,
                self.dilation,
                self.groups,
            )
        return f.conv2d(
            input,
            sampled_weight,
            sampled_bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
A probabilistic 2D convolution layer.
Inherits from `nn.Conv2d` and `AbstractProbLayer`. Weights and bias
are sampled from associated distributions during forward passes.
def forward(self, input: torch.Tensor) -> torch.Tensor:
def forward(self, input: Tensor) -> Tensor:
    """
    Perform a 2D convolution using sampled weights and bias.

    A weight/bias pair is obtained from ``self.sample_from_distribution()``
    (supplied by ``AbstractProbLayer``) on every call. It is then applied
    with the layer's ``stride``, ``padding``, ``dilation`` and ``groups``.

    The layer's ``padding_mode`` is honoured the same way ``nn.Conv2d``
    honours it. For any mode other than ``"zeros"``, the input is padded
    explicitly with ``f.pad`` and the convolution runs with zero implicit
    padding.

    Args:
        input (Tensor): The input tensor of shape (N, C_in, H_in, W_in).

    Returns:
        Tensor: The output tensor of shape (N, C_out, H_out, W_out).
    """
    sampled_weight, sampled_bias = self.sample_from_distribution()
    if self.padding_mode != "zeros":
        # conv2d itself only implements zero padding. Reflect/replicate/
        # circular padding must therefore be applied up front, exactly as
        # nn.Conv2d does. _reversed_padding_repeated_twice is the padding
        # tuple nn.Conv2d precomputes (in f.pad's last-dim-first order)
        # for this f.pad call.
        padded = f.pad(
            input, self._reversed_padding_repeated_twice, mode=self.padding_mode
        )
        return f.conv2d(
            padded,
            sampled_weight,
            sampled_bias,
            self.stride,
            0,
            self.dilation,
            self.groups,
        )
    return f.conv2d(
        input,
        sampled_weight,
        sampled_bias,
        self.stride,
        self.padding,
        self.dilation,
        self.groups,
    )
Perform a 2D convolution using sampled weights and bias.
Arguments:
- input (Tensor): The input tensor of shape (N, C_in, H_in, W_in).
Returns:
Tensor: The output tensor of shape (N, C_out, H_out, W_out).