core.objective.TolstikhinObjective
```python
import numpy as np
import torch
from torch import Tensor

from core.objective import AbstractObjective


class TolstikhinObjective(AbstractObjective):
    """
    Objective based on Tolstikhin et al. (2013), combining the empirical loss
    with a square-root term and an additive term, both involving the KL divergence.
    """

    def __init__(self, kl_penalty: float, delta: float):
        """
        Args:
            kl_penalty (float): The coefficient multiplying the KL term.
            delta (float): Confidence parameter in the PAC-Bayes or related bound.
        """
        self._kl_penalty = kl_penalty
        self._delta = delta  # confidence value for the training objective

    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        """
        Compute the Tolstikhin objective.

        The objective is:
            loss + sqrt(2 * loss * ratio) + 2 * ratio
        where ratio = (kl_penalty * KL + ln(2n) - ln(delta)) / (2n) and n = num_samples.

        Args:
            loss (Tensor): Empirical loss.
            kl (Tensor): KL divergence.
            num_samples (float): Number of data samples for normalization.

        Returns:
            Tensor: The scalar objective value.
        """
        kl = kl * self._kl_penalty
        # Shared ratio: (kl_penalty * KL + ln(2n) - ln(delta)) / (2n), computed once.
        ratio = torch.div(
            kl + np.log(2 * num_samples) - np.log(self._delta), 2 * num_samples
        )
        second_term = 2 * loss * ratio
        third_term = 2 * ratio
        return loss + torch.sqrt(second_term) + third_term
```
Objective based on Tolstikhin et al. (2013), combining the empirical loss with a square-root term and an additive term, both involving the KL divergence.
TolstikhinObjective(kl_penalty: float, delta: float)
Arguments:
- kl_penalty (float): The coefficient multiplying the KL term.
- delta (float): Confidence parameter in the PAC-Bayes or related bound.
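A minimal instantiation sketch; the hyperparameter values here are illustrative, not taken from the source:

```python
# Illustrative values: kl_penalty = 1.0 leaves the KL term unscaled;
# delta = 0.05 corresponds to a bound that holds with probability 0.95.
objective = TolstikhinObjective(kl_penalty=1.0, delta=0.05)
```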
def calculate(self, loss: torch.Tensor, kl: torch.Tensor, num_samples: float) -> torch.Tensor
Compute the Tolstikhin objective.
The objective is:
    loss + sqrt(2 * loss * ratio) + 2 * ratio
where ratio = (kl_penalty * KL + ln(2n) - ln(delta)) / (2n) and n = num_samples.
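As a concrete sanity check of the formula (all numbers are made up for illustration), it can be evaluated directly with NumPy:

```python
import numpy as np

# Illustrative values only: loss = 0.1, KL = 10, kl_penalty = 1, delta = 0.05, n = 10000.
loss, kl, kl_penalty, delta, n = 0.1, 10.0, 1.0, 0.05, 10_000
ratio = (kl_penalty * kl + np.log(2 * n) - np.log(delta)) / (2 * n)
objective = loss + np.sqrt(2 * loss * ratio) + 2 * ratio
print(ratio, objective)  # ratio ~ 0.001145, objective ~ 0.1174
```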
Arguments:
- loss (Tensor): Empirical loss.
- kl (Tensor): KL divergence.
- num_samples (float): Number of data samples for normalization.
Returns:
Tensor: The scalar objective value.
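Putting it together, a sketch of how `calculate` might be called in a training step; the import path, tensor values, and surrounding setup are assumptions for illustration, not part of this module:

```python
import torch

from core.objective import TolstikhinObjective  # assumed export path; adjust if needed

objective = TolstikhinObjective(kl_penalty=1.0, delta=0.05)

# Stand-in scalars; in practice `loss` would be the empirical risk of the
# posterior and `kl` the KL divergence to the prior, both produced by the model.
loss = torch.tensor(0.1, requires_grad=True)
kl = torch.tensor(10.0, requires_grad=True)

bound = objective.calculate(loss, kl, num_samples=10_000)
bound.backward()  # gradients flow through the loss, sqrt, and ratio terms
print(bound.item())  # ~0.1174 with these illustrative values
```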