core.layer.ProbBatchNorm1d
```python
import torch.nn.functional as f
from torch import Tensor, nn

from core.layer import AbstractProbLayer


class ProbBatchNorm1d(nn.BatchNorm1d, AbstractProbLayer):
    """
    A probabilistic 1D batch normalization layer.

    This layer extends PyTorch's `nn.BatchNorm1d` to sample weight and bias
    from learned distributions. The forward pass behavior is the same as standard
    batch norm, except the parameters come from `sample_from_distribution`.
    """

    def forward(self, input: Tensor) -> Tensor:
        """
        Forward pass for probabilistic batch normalization.

        During training:
        - Maintains running statistics if `track_running_stats` is True.
        - Samples weight/bias if `probabilistic_mode` is True.

        Args:
            input (Tensor): Input tensor of shape (N, C, L) or (N, C).

        Returns:
            Tensor: Batch-normalized output of the same shape as `input`.
        """
        self._check_input_dim(input)

        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked.add_(1)
                if self.momentum is None:
                    # Use a cumulative moving average of the batch statistics.
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:
                    # Use an exponential moving average with the given momentum.
                    exponential_average_factor = self.momentum

        # In training mode, always normalize with batch statistics; in eval
        # mode, fall back to batch statistics only if no buffers are tracked.
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        # Draw this forward pass's affine parameters from their distributions.
        sampled_weight, sampled_bias = self.sample_from_distribution()

        return f.batch_norm(
            input,
            # Pass None when running buffers should be neither used nor updated.
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            sampled_weight,
            sampled_bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
```
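One detail worth flagging in the source above: when `momentum` is `None`, the factor becomes `1 / num_batches_tracked`, so the running statistics track a cumulative (equal-weight) average of all batches seen so far, matching PyTorch's own `nn.BatchNorm1d` semantics. A small standalone check of that recurrence:

```python
import math

# When momentum is None, the layer uses factor = 1/k at training step k:
#   running = (1 - 1/k) * running + (1/k) * batch_stat
# After k steps this equals the plain mean of the first k batch statistics.
values = [4.0, 8.0, 6.0, 2.0]
running = 0.0
for k, v in enumerate(values, start=1):
    factor = 1.0 / k  # exponential_average_factor with momentum=None
    running = (1.0 - factor) * running + factor * v
assert math.isclose(running, sum(values) / len(values))
```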
`class ProbBatchNorm1d(torch.nn.modules.batchnorm.BatchNorm1d, core.layer.AbstractProbLayer.AbstractProbLayer):`
A probabilistic 1D batch normalization layer.
This layer extends PyTorch's `nn.BatchNorm1d` to sample weight and bias
from learned distributions. The forward pass behavior is the same as standard
batch norm, except the parameters come from `sample_from_distribution`.
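`sample_from_distribution` is provided by `AbstractProbLayer`, whose implementation is not shown on this page. As a rough mental model only, a mean-field Gaussian version using the reparameterization trick could look like the sketch below; the attribute names (`weight_mu`, `weight_rho`, `bias_mu`, `bias_rho`) are hypothetical and not taken from this codebase.

```python
import torch
from torch import Tensor, nn


class GaussianSamplerSketch(nn.Module):
    """Hypothetical mean-field Gaussian parameter sampler (illustration only)."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        # One Gaussian per channel for weight and bias; rho parameterizes the
        # standard deviation through softplus to keep it positive.
        self.weight_mu = nn.Parameter(torch.ones(num_features))
        self.weight_rho = nn.Parameter(torch.full((num_features,), -5.0))
        self.bias_mu = nn.Parameter(torch.zeros(num_features))
        self.bias_rho = nn.Parameter(torch.full((num_features,), -5.0))
        self.probabilistic_mode = True

    def sample_from_distribution(self) -> tuple[Tensor, Tensor]:
        if not self.probabilistic_mode:
            # Deterministic fallback: use the distribution means.
            return self.weight_mu, self.bias_mu
        # Reparameterization trick: eps ~ N(0, I), then shift and scale, so
        # gradients flow to mu and rho through the sampled parameters.
        weight_std = nn.functional.softplus(self.weight_rho)
        bias_std = nn.functional.softplus(self.bias_rho)
        weight = self.weight_mu + weight_std * torch.randn_like(weight_std)
        bias = self.bias_mu + bias_std * torch.randn_like(bias_std)
        return weight, bias
```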
`def forward(self, input: torch.Tensor) -> torch.Tensor:`
Forward pass for probabilistic batch normalization.
During training:
- Maintains running statistics if `track_running_stats` is True.
- Samples weight/bias if `probabilistic_mode` is True.
Arguments:
- input (Tensor): Input tensor of shape (N, C, L) or (N, C).
Returns:
Tensor: Batch-normalized output of the same shape as `input`.
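A minimal usage sketch, assuming `ProbBatchNorm1d` keeps the standard `nn.BatchNorm1d` constructor signature and that `AbstractProbLayer` enables `probabilistic_mode` by default; neither is confirmed by this page:

```python
import torch

from core.layer import ProbBatchNorm1d  # import path assumed

bn = ProbBatchNorm1d(num_features=16)  # nn.BatchNorm1d-style constructor, assumed
x = torch.randn(8, 16, 32)  # (N, C, L)

bn.train()
y1 = bn(x)  # batch statistics; weight/bias drawn from their distributions
y2 = bn(x)  # a second call draws fresh parameters, so y1 != y2 in general

bn.eval()
y3 = bn(x)  # running statistics; parameters still come from sample_from_distribution
```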