core.bound.KLBound
import math

from torch import Tensor

from core.bound import AbstractBound
from core.utils.kl import inv_kl


class KLBound(AbstractBound):
    """
    PAC-Bayes KL (Seeger/Maurer-style) bound.

    The bound is computed in two inversion steps using ``inv_kl``
    (the inverse of the binary KL divergence):

    1. Invert the Monte Carlo estimation error to upper-bound the
       empirical risk from the sampled average loss.
    2. Invert the PAC-Bayes-kl inequality, using the KL divergence
       between posterior and prior, to upper-bound the true risk.
    """

    def __init__(self, bound_delta: float, loss_delta: float):
        """
        Args:
            bound_delta (float): Confidence parameter delta for the
                PAC-Bayes bound over the bound dataset.
            loss_delta (float): Confidence parameter delta for the
                Monte Carlo estimate of the empirical loss.
        """
        # Stored by the base class as self._bound_delta / self._loss_delta
        # (both are read in calculate()).
        super().__init__(bound_delta, loss_delta)

    def calculate(
        self,
        avg_loss: float,
        kl: Tensor | float,
        num_samples_bound: int,
        num_samples_loss: int,
    ) -> tuple[Tensor | float, Tensor | float]:
        """
        Calculates the PAC Bayes bound.

        Args:
            avg_loss (float): The loss averaged using Monte Carlo sampling.
            kl (Tensor | float): The Kullback-Leibler divergence between
                prior and posterior distributions.
            num_samples_bound (int): The number of data samples in the
                bound dataset.
            num_samples_loss (int): The number of Monte Carlo samples.

        Returns:
            tuple[Tensor | float, Tensor | float]: A tuple containing the
                calculated PAC Bayes bound (true-risk upper bound) and the
                upper bound of the empirical risk.
        """
        # Step 1: bound the empirical risk, accounting for the Monte
        # Carlo sampling error at confidence 1 - loss_delta.
        empirical_risk = inv_kl(
            avg_loss, math.log(2 / self._loss_delta) / num_samples_loss
        )
        # Step 2: invert the PAC-Bayes-kl inequality; the log(2*sqrt(n)/delta)
        # term is the standard Maurer union/concentration correction.
        risk = inv_kl(
            empirical_risk,
            (kl + math.log((2 * math.sqrt(num_samples_bound)) / self._bound_delta))
            / num_samples_bound,
        )
        return risk, empirical_risk