core.layer.ProbBatchNorm2d
import torch.nn.functional as f
from torch import Tensor, nn

from core.layer import AbstractProbLayer


class ProbBatchNorm2d(nn.BatchNorm2d, AbstractProbLayer):
    """
    A probabilistic 2D batch normalization layer.

    Normalization and running-statistics bookkeeping follow
    ``nn.BatchNorm2d``. The affine weight and bias are fetched from
    ``self.sample_from_distribution()`` (provided by ``AbstractProbLayer``)
    on every forward call, rather than read directly from parameters.
    """

    def forward(self, input: Tensor) -> Tensor:
        """
        Normalize a 4D batch using weight and bias from the layer's distribution.

        Training mode: normalizes with mini-batch statistics. When
        ``track_running_stats`` is enabled, it also increments
        ``num_batches_tracked`` and lets ``f.batch_norm`` update the
        running buffers.

        Eval mode: normalizes with the running buffers, or with batch
        statistics if both buffers are ``None``.

        Weight and bias come from ``self.sample_from_distribution()``.
        Whether they are random draws presumably depends on
        ``probabilistic_mode`` (NOTE(review): confirm in AbstractProbLayer).

        Args:
            input (Tensor): Input of shape (N, C, H, W); the rank is
                validated by ``_check_input_dim`` from ``nn.BatchNorm2d``.

        Returns:
            Tensor: Normalized tensor with the same shape as ``input``.
        """
        self._check_input_dim(input)

        # Factor handed to f.batch_norm for the running-stat update.
        avg_factor = 0.0 if self.momentum is None else self.momentum

        if (
            self.training
            and self.track_running_stats
            and self.num_batches_tracked is not None
        ):
            self.num_batches_tracked.add_(1)
            if self.momentum is None:
                # momentum=None: cumulative moving average over all batches seen.
                avg_factor = 1.0 / float(self.num_batches_tracked)

        # Batch statistics are used while training, and in eval mode only when
        # there are no running buffers to fall back on.
        use_batch_stats = self.training or (
            self.running_mean is None and self.running_var is None
        )

        weight, bias = self.sample_from_distribution()

        # Withhold the running buffers only when training without tracking.
        keep_buffers = self.track_running_stats or not self.training
        running_mean = self.running_mean if keep_buffers else None
        running_var = self.running_var if keep_buffers else None

        return f.batch_norm(
            input,
            running_mean,
            running_var,
            weight,
            bias,
            use_batch_stats,
            avg_factor,
            self.eps,
        )
class ProbBatchNorm2d(torch.nn.modules.batchnorm.BatchNorm2d, core.layer.AbstractProbLayer.AbstractProbLayer):
class ProbBatchNorm2d(nn.BatchNorm2d, AbstractProbLayer):
    """
    A probabilistic 2D batch normalization layer.

    Normalization and running-statistics bookkeeping follow
    ``nn.BatchNorm2d``. The affine weight and bias are fetched from
    ``self.sample_from_distribution()`` (provided by ``AbstractProbLayer``)
    on every forward call, rather than read directly from parameters.
    """

    def forward(self, input: Tensor) -> Tensor:
        """
        Normalize a 4D batch using weight and bias from the layer's distribution.

        Training mode: normalizes with mini-batch statistics. When
        ``track_running_stats`` is enabled, it also increments
        ``num_batches_tracked`` and lets ``f.batch_norm`` update the
        running buffers.

        Eval mode: normalizes with the running buffers, or with batch
        statistics if both buffers are ``None``.

        Weight and bias come from ``self.sample_from_distribution()``.
        Whether they are random draws presumably depends on
        ``probabilistic_mode`` (NOTE(review): confirm in AbstractProbLayer).

        Args:
            input (Tensor): Input of shape (N, C, H, W); the rank is
                validated by ``_check_input_dim`` from ``nn.BatchNorm2d``.

        Returns:
            Tensor: Normalized tensor with the same shape as ``input``.
        """
        self._check_input_dim(input)

        # Factor handed to f.batch_norm for the running-stat update.
        avg_factor = 0.0 if self.momentum is None else self.momentum

        if (
            self.training
            and self.track_running_stats
            and self.num_batches_tracked is not None
        ):
            self.num_batches_tracked.add_(1)
            if self.momentum is None:
                # momentum=None: cumulative moving average over all batches seen.
                avg_factor = 1.0 / float(self.num_batches_tracked)

        # Batch statistics are used while training, and in eval mode only when
        # there are no running buffers to fall back on.
        use_batch_stats = self.training or (
            self.running_mean is None and self.running_var is None
        )

        weight, bias = self.sample_from_distribution()

        # Withhold the running buffers only when training without tracking.
        keep_buffers = self.track_running_stats or not self.training
        running_mean = self.running_mean if keep_buffers else None
        running_var = self.running_var if keep_buffers else None

        return f.batch_norm(
            input,
            running_mean,
            running_var,
            weight,
            bias,
            use_batch_stats,
            avg_factor,
            self.eps,
        )
A probabilistic 2D batch normalization layer.
Extends PyTorch's nn.BatchNorm2d to sample weight and bias from learned
distributions for use in a probabilistic neural network.
def forward(self, input: torch.Tensor) -> torch.Tensor:
def forward(self, input: Tensor) -> Tensor:
    """
    Normalize a 4D batch using weight and bias from the layer's distribution.

    Training mode: normalizes with mini-batch statistics. When
    ``track_running_stats`` is enabled, it also increments
    ``num_batches_tracked`` and lets ``f.batch_norm`` update the
    running buffers.

    Eval mode: normalizes with the running buffers, or with batch
    statistics if both buffers are ``None``.

    Weight and bias come from ``self.sample_from_distribution()``.
    Whether they are random draws presumably depends on
    ``probabilistic_mode`` (NOTE(review): confirm in AbstractProbLayer).

    Args:
        input (Tensor): Input of shape (N, C, H, W); the rank is
            validated by ``_check_input_dim`` from ``nn.BatchNorm2d``.

    Returns:
        Tensor: Normalized tensor with the same shape as ``input``.
    """
    self._check_input_dim(input)

    # Factor handed to f.batch_norm for the running-stat update.
    avg_factor = 0.0 if self.momentum is None else self.momentum

    if (
        self.training
        and self.track_running_stats
        and self.num_batches_tracked is not None
    ):
        self.num_batches_tracked.add_(1)
        if self.momentum is None:
            # momentum=None: cumulative moving average over all batches seen.
            avg_factor = 1.0 / float(self.num_batches_tracked)

    # Batch statistics are used while training, and in eval mode only when
    # there are no running buffers to fall back on.
    use_batch_stats = self.training or (
        self.running_mean is None and self.running_var is None
    )

    weight, bias = self.sample_from_distribution()

    # Withhold the running buffers only when training without tracking.
    keep_buffers = self.track_running_stats or not self.training
    running_mean = self.running_mean if keep_buffers else None
    running_var = self.running_var if keep_buffers else None

    return f.batch_norm(
        input,
        running_mean,
        running_var,
        weight,
        bias,
        use_batch_stats,
        avg_factor,
        self.eps,
    )
Forward pass for probabilistic 2D batch normalization.
During training:
- Uses mini-batch statistics to normalize.
- Samples weight and bias if `probabilistic_mode` is True.
Arguments:
- input (Tensor): Input tensor of shape (N, C, H, W).
Returns:
Tensor: Batch-normalized output of the same shape as `input`.