core.objective.BBBObjective

from torch import Tensor

from core.objective import AbstractObjective


class BBBObjective(AbstractObjective):
    """
    The Bayes By Backprop (BBB) objective from Blundell et al. (2015).

    This objective adds a KL penalty weighted by a user-defined factor and
    normalized by the number of training samples.
    """

    def __init__(self, kl_penalty: float):
        """
        Args:
            kl_penalty (float): The coefficient for scaling KL divergence in the objective.
        """
        self._kl_penalty = kl_penalty

    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        """
        Combine the empirical loss with the scaled KL divergence.

        Args:
            loss (Tensor): Empirical loss (e.g., NLL).
            kl (Tensor): KL divergence between posterior and prior.
            num_samples (float): The number of training samples or an equivalent factor.

        Returns:
            Tensor: A scalar objective = loss + (kl_penalty * KL / num_samples).
        """
        return loss + self._kl_penalty * (kl / num_samples)
class BBBObjective(AbstractObjective):

The Bayes By Backprop (BBB) objective from Blundell et al. (2015).

This objective adds a KL penalty weighted by a user-defined factor and normalized by the number of training samples.
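
A minimal training-step sketch, assuming a hypothetical Bayesian model whose forward pass samples weights and which exposes a kl() method returning the KL divergence between its weight posterior and prior (the model interface is illustrative, not part of this module):

from torch import nn

from core.objective import BBBObjective

objective = BBBObjective(kl_penalty=1.0)
criterion = nn.CrossEntropyLoss()

def training_step(model, optimizer, inputs, targets, num_samples):
    # model.kl() is an assumed interface: KL(posterior || prior) over all weights.
    optimizer.zero_grad()
    nll = criterion(model(inputs), targets)  # empirical loss term
    bound = objective.calculate(nll, model.kl(), num_samples)
    bound.backward()  # gradients flow through both the loss and the KL term
    optimizer.step()
    return bound.detach()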

BBBObjective(kl_penalty: float)
Arguments:
  • kl_penalty (float): The coefficient for scaling KL divergence in the objective.
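
Because the combination is linear, kl_penalty = 1.0 yields the plain loss + KL / num_samples bound, and smaller values down-weight the prior term. A minimal construction sketch:

from core.objective import BBBObjective

objective = BBBObjective(kl_penalty=1.0)  # 1.0 gives the unscaled loss + KL / num_samples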
def calculate(self, loss: torch.Tensor, kl: torch.Tensor, num_samples: float) -> torch.Tensor:

Combine the empirical loss with the scaled KL divergence.

Arguments:
  • loss (Tensor): Empirical loss (e.g., NLL).
  • kl (Tensor): KL divergence between posterior and prior.
  • num_samples (float): The number of training samples or an equivalent factor.
Returns:
  • Tensor: A scalar objective = loss + (kl_penalty * KL / num_samples).
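
A quick numeric check of the returned value (the tensors are arbitrary illustrative inputs):

import torch

from core.objective import BBBObjective

objective = BBBObjective(kl_penalty=1.0)
loss = torch.tensor(0.7)  # empirical NLL for a batch
kl = torch.tensor(50.0)   # KL(posterior || prior) over all weights
print(objective.calculate(loss, kl, num_samples=1000.0))
# tensor(0.7500), i.e. 0.7 + 1.0 * 50.0 / 1000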