core.layer.ProbBatchNorm2d

 1import torch.nn.functional as f
 2from torch import Tensor, nn
 3
 4from core.layer import AbstractProbLayer
 5
 6
 7class ProbBatchNorm2d(nn.BatchNorm2d, AbstractProbLayer):
 8    """
 9    A probabilistic 2D batch normalization layer.
10
11    Extends PyTorch's `nn.BatchNorm2d` to sample weight and bias from learned
12    distributions for use in a probabilistic neural network.
13    """
14
15    def forward(self, input: Tensor) -> Tensor:
16        """
17        Forward pass for probabilistic 2D batch normalization.
18
19        During training:
20          - Uses mini-batch statistics to normalize.
21          - Samples weight and bias if `probabilistic_mode` is True.
22
23        Args:
24            input (Tensor): Input tensor of shape (N, C, H, W).
25
26        Returns:
27            Tensor: Batch-normalized output of the same shape as `input`.
28        """
29        self._check_input_dim(input)
30
31        if self.momentum is None:
32            exponential_average_factor = 0.0
33        else:
34            exponential_average_factor = self.momentum
35
36        if self.training and self.track_running_stats:
37            if self.num_batches_tracked is not None:
38                self.num_batches_tracked.add_(1)
39                if self.momentum is None:
40                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
41                else:
42                    exponential_average_factor = self.momentum
43
44        if self.training:
45            bn_training = True
46        else:
47            bn_training = (self.running_mean is None) and (self.running_var is None)
48
49        sampled_weight, sampled_bias = self.sample_from_distribution()
50
51        return f.batch_norm(
52            input,
53            self.running_mean
54            if not self.training or self.track_running_stats
55            else None,
56            self.running_var if not self.training or self.track_running_stats else None,
57            sampled_weight,
58            sampled_bias,
59            bn_training,
60            exponential_average_factor,
61            self.eps,
62        )
class ProbBatchNorm2d(torch.nn.modules.batchnorm.BatchNorm2d, core.layer.AbstractProbLayer.AbstractProbLayer):
 8class ProbBatchNorm2d(nn.BatchNorm2d, AbstractProbLayer):
 9    """
10    A probabilistic 2D batch normalization layer.
11
12    Extends PyTorch's `nn.BatchNorm2d` to sample weight and bias from learned
13    distributions for use in a probabilistic neural network.
14    """
15
16    def forward(self, input: Tensor) -> Tensor:
17        """
18        Forward pass for probabilistic 2D batch normalization.
19
20        During training:
21          - Uses mini-batch statistics to normalize.
22          - Samples weight and bias if `probabilistic_mode` is True.
23
24        Args:
25            input (Tensor): Input tensor of shape (N, C, H, W).
26
27        Returns:
28            Tensor: Batch-normalized output of the same shape as `input`.
29        """
30        self._check_input_dim(input)
31
32        if self.momentum is None:
33            exponential_average_factor = 0.0
34        else:
35            exponential_average_factor = self.momentum
36
37        if self.training and self.track_running_stats:
38            if self.num_batches_tracked is not None:
39                self.num_batches_tracked.add_(1)
40                if self.momentum is None:
41                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
42                else:
43                    exponential_average_factor = self.momentum
44
45        if self.training:
46            bn_training = True
47        else:
48            bn_training = (self.running_mean is None) and (self.running_var is None)
49
50        sampled_weight, sampled_bias = self.sample_from_distribution()
51
52        return f.batch_norm(
53            input,
54            self.running_mean
55            if not self.training or self.track_running_stats
56            else None,
57            self.running_var if not self.training or self.track_running_stats else None,
58            sampled_weight,
59            sampled_bias,
60            bn_training,
61            exponential_average_factor,
62            self.eps,
63        )

A probabilistic 2D batch normalization layer.

Extends PyTorch's nn.BatchNorm2d to sample weight and bias from learned distributions for use in a probabilistic neural network.

def forward(self, input: torch.Tensor) -> torch.Tensor:
16    def forward(self, input: Tensor) -> Tensor:
17        """
18        Forward pass for probabilistic 2D batch normalization.
19
20        During training:
21          - Uses mini-batch statistics to normalize.
22          - Samples weight and bias if `probabilistic_mode` is True.
23
24        Args:
25            input (Tensor): Input tensor of shape (N, C, H, W).
26
27        Returns:
28            Tensor: Batch-normalized output of the same shape as `input`.
29        """
30        self._check_input_dim(input)
31
32        if self.momentum is None:
33            exponential_average_factor = 0.0
34        else:
35            exponential_average_factor = self.momentum
36
37        if self.training and self.track_running_stats:
38            if self.num_batches_tracked is not None:
39                self.num_batches_tracked.add_(1)
40                if self.momentum is None:
41                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
42                else:
43                    exponential_average_factor = self.momentum
44
45        if self.training:
46            bn_training = True
47        else:
48            bn_training = (self.running_mean is None) and (self.running_var is None)
49
50        sampled_weight, sampled_bias = self.sample_from_distribution()
51
52        return f.batch_norm(
53            input,
54            self.running_mean
55            if not self.training or self.track_running_stats
56            else None,
57            self.running_var if not self.training or self.track_running_stats else None,
58            sampled_weight,
59            sampled_bias,
60            bn_training,
61            exponential_average_factor,
62            self.eps,
63        )

Forward pass for probabilistic 2D batch normalization.

During training:
  • Uses mini-batch statistics to normalize.
  • Samples weight and bias if probabilistic_mode is True.
Arguments:
  • input (Tensor): Input tensor of shape (N, C, H, W).
Returns:
  • Tensor: Batch-normalized output of the same shape as input.