core.layer.ProbBatchNorm1d

 1import torch.nn.functional as f
 2from torch import Tensor, nn
 3
 4from core.layer import AbstractProbLayer
 5
 6
 7class ProbBatchNorm1d(nn.BatchNorm1d, AbstractProbLayer):
 8    """
 9    A probabilistic 1D batch normalization layer.
10
11    This layer extends PyTorch's `nn.BatchNorm1d` to sample weight and bias
12    from learned distributions. The forward pass behavior is the same as standard
13    batch norm, except the parameters come from `sample_from_distribution`.
14    """
15
16    def forward(self, input: Tensor) -> Tensor:
17        """
18        Forward pass for probabilistic batch normalization.
19
20        During training:
21          - Maintains running statistics if `track_running_stats` is True.
22          - Samples weight/bias if `probabilistic_mode` is True.
23
24        Args:
25            input (Tensor): Input tensor of shape (N, C, L) or (N, C).
26
27        Returns:
28            Tensor: Batch-normalized output of the same shape as `input`.
29        """
30        self._check_input_dim(input)
31
32        if self.momentum is None:
33            exponential_average_factor = 0.0
34        else:
35            exponential_average_factor = self.momentum
36
37        if self.training and self.track_running_stats:
38            if self.num_batches_tracked is not None:
39                self.num_batches_tracked.add_(1)
40                if self.momentum is None:
41                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
42                else:
43                    exponential_average_factor = self.momentum
44
45        if self.training:
46            bn_training = True
47        else:
48            bn_training = (self.running_mean is None) and (self.running_var is None)
49
50        sampled_weight, sampled_bias = self.sample_from_distribution()
51
52        return f.batch_norm(
53            input,
54            self.running_mean
55            if not self.training or self.track_running_stats
56            else None,
57            self.running_var if not self.training or self.track_running_stats else None,
58            sampled_weight,
59            sampled_bias,
60            bn_training,
61            exponential_average_factor,
62            self.eps,
63        )
class ProbBatchNorm1d(torch.nn.modules.batchnorm.BatchNorm1d, core.layer.AbstractProbLayer.AbstractProbLayer):
 8class ProbBatchNorm1d(nn.BatchNorm1d, AbstractProbLayer):
 9    """
10    A probabilistic 1D batch normalization layer.
11
12    This layer extends PyTorch's `nn.BatchNorm1d` to sample weight and bias
13    from learned distributions. The forward pass behavior is the same as standard
14    batch norm, except the parameters come from `sample_from_distribution`.
15    """
16
17    def forward(self, input: Tensor) -> Tensor:
18        """
19        Forward pass for probabilistic batch normalization.
20
21        During training:
22          - Maintains running statistics if `track_running_stats` is True.
23          - Samples weight/bias if `probabilistic_mode` is True.
24
25        Args:
26            input (Tensor): Input tensor of shape (N, C, L) or (N, C).
27
28        Returns:
29            Tensor: Batch-normalized output of the same shape as `input`.
30        """
31        self._check_input_dim(input)
32
33        if self.momentum is None:
34            exponential_average_factor = 0.0
35        else:
36            exponential_average_factor = self.momentum
37
38        if self.training and self.track_running_stats:
39            if self.num_batches_tracked is not None:
40                self.num_batches_tracked.add_(1)
41                if self.momentum is None:
42                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
43                else:
44                    exponential_average_factor = self.momentum
45
46        if self.training:
47            bn_training = True
48        else:
49            bn_training = (self.running_mean is None) and (self.running_var is None)
50
51        sampled_weight, sampled_bias = self.sample_from_distribution()
52
53        return f.batch_norm(
54            input,
55            self.running_mean
56            if not self.training or self.track_running_stats
57            else None,
58            self.running_var if not self.training or self.track_running_stats else None,
59            sampled_weight,
60            sampled_bias,
61            bn_training,
62            exponential_average_factor,
63            self.eps,
64        )

A probabilistic 1D batch normalization layer.

This layer extends PyTorch's `nn.BatchNorm1d` to sample its weight and bias from learned distributions. The forward pass is the same as standard batch norm, except that the affine parameters (weight and bias) come from `sample_from_distribution()`.

def forward(self, input: torch.Tensor) -> torch.Tensor:
17    def forward(self, input: Tensor) -> Tensor:
18        """
19        Forward pass for probabilistic batch normalization.
20
21        During training:
22          - Maintains running statistics if `track_running_stats` is True.
23          - Samples weight/bias if `probabilistic_mode` is True.
24
25        Args:
26            input (Tensor): Input tensor of shape (N, C, L) or (N, C).
27
28        Returns:
29            Tensor: Batch-normalized output of the same shape as `input`.
30        """
31        self._check_input_dim(input)
32
33        if self.momentum is None:
34            exponential_average_factor = 0.0
35        else:
36            exponential_average_factor = self.momentum
37
38        if self.training and self.track_running_stats:
39            if self.num_batches_tracked is not None:
40                self.num_batches_tracked.add_(1)
41                if self.momentum is None:
42                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
43                else:
44                    exponential_average_factor = self.momentum
45
46        if self.training:
47            bn_training = True
48        else:
49            bn_training = (self.running_mean is None) and (self.running_var is None)
50
51        sampled_weight, sampled_bias = self.sample_from_distribution()
52
53        return f.batch_norm(
54            input,
55            self.running_mean
56            if not self.training or self.track_running_stats
57            else None,
58            self.running_var if not self.training or self.track_running_stats else None,
59            sampled_weight,
60            sampled_bias,
61            bn_training,
62            exponential_average_factor,
63            self.eps,
64        )

Forward pass for probabilistic batch normalization.

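During training:
  • Running statistics are maintained if `track_running_stats` is True.

On every call (training and evaluation):
  • Weight and bias are obtained from `sample_from_distribution()`; they are sampled if `probabilistic_mode` is True.

Arguments:
  • input (Tensor): Input tensor of shape (N, C, L) or (N, C).

Returns:
  • Tensor: Batch-normalized output of the same shape as `input`.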