core.objective.BBBObjective
from torch import Tensor

from core.objective import AbstractObjective


class BBBObjective(AbstractObjective):
    """
    The Bayes by Backprop (BBB) objective from Blundell et al. (2015).

    Adds the KL divergence to the empirical loss, scaled by a user-defined
    penalty coefficient and divided by num_samples.
    """

    def __init__(self, kl_penalty: float):
        """
        Args:
            kl_penalty (float): Coefficient that scales the KL divergence
                term in the objective.
        """
        self._kl_penalty = kl_penalty

    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        """
        Combine the loss with the KL divergence, scaled by kl_penalty / num_samples.

        Args:
            loss (Tensor): Empirical loss (e.g., NLL).
            kl (Tensor): KL divergence between the posterior and the prior.
            num_samples (float): Number of training samples, or another
                normalizing factor; the KL term is divided by this value.

        Returns:
            Tensor: loss + kl_penalty * (kl / num_samples); a scalar when
            both loss and kl are scalars.
        """
        # kl is divided by num_samples before the penalty is applied,
        # keeping the evaluation order loss + kl_penalty * (kl / num_samples).
        normalized_kl = kl / num_samples
        return loss + self._kl_penalty * normalized_kl
class BBBObjective(AbstractObjective):
    """
    The Bayes by Backprop (BBB) objective from Blundell et al. (2015).

    Adds the KL divergence to the empirical loss, scaled by a user-defined
    penalty coefficient and divided by num_samples.
    """

    def __init__(self, kl_penalty: float):
        """
        Args:
            kl_penalty (float): Coefficient that scales the KL divergence
                term in the objective.
        """
        self._kl_penalty = kl_penalty

    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        """
        Combine the loss with the KL divergence, scaled by kl_penalty / num_samples.

        Args:
            loss (Tensor): Empirical loss (e.g., NLL).
            kl (Tensor): KL divergence between the posterior and the prior.
            num_samples (float): Number of training samples, or another
                normalizing factor; the KL term is divided by this value.

        Returns:
            Tensor: loss + kl_penalty * (kl / num_samples); a scalar when
            both loss and kl are scalars.
        """
        # kl is divided by num_samples before the penalty is applied,
        # keeping the evaluation order loss + kl_penalty * (kl / num_samples).
        normalized_kl = kl / num_samples
        return loss + self._kl_penalty * normalized_kl
The Bayes by Backprop (BBB) objective from Blundell et al. (2015).
This objective adds the KL divergence to the empirical loss, scaled by a user-defined penalty coefficient and divided by num_samples.
BBBObjective(kl_penalty: float)
def __init__(self, kl_penalty: float):
    """
    Store the KL penalty coefficient for later use by ``calculate``.

    Args:
        kl_penalty (float): Coefficient that scales the KL divergence
            term in the objective.
    """
    self._kl_penalty = kl_penalty
Arguments:
- kl_penalty (float): The coefficient that scales the KL divergence term in the objective.
def calculate(self, loss: torch.Tensor, kl: torch.Tensor, num_samples: float) -> torch.Tensor:
def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
    """
    Combine the loss with the KL divergence, scaled by kl_penalty / num_samples.

    Args:
        loss (Tensor): Empirical loss (e.g., NLL).
        kl (Tensor): KL divergence between the posterior and the prior.
        num_samples (float): Number of training samples, or another
            normalizing factor; the KL term is divided by this value.

    Returns:
        Tensor: loss + kl_penalty * (kl / num_samples); a scalar when
        both loss and kl are scalars.
    """
    # kl is divided by num_samples before the penalty is applied,
    # keeping the evaluation order loss + kl_penalty * (kl / num_samples).
    normalized_kl = kl / num_samples
    return loss + self._kl_penalty * normalized_kl
Combine the loss with the KL divergence, scaled by kl_penalty / num_samples.
Arguments:
- loss (Tensor): Empirical loss (e.g., NLL).
- kl (Tensor): The KL divergence between the posterior and the prior.
- num_samples (float): The number of training samples, or another normalizing factor; the KL term is divided by this value.
Returns:
Tensor: The objective loss + (kl_penalty * KL / num_samples); a scalar when both loss and KL are scalars.