core.objective.FClassicObjective
```python
import numpy as np
import torch
from torch import Tensor

from core.objective import AbstractObjective


class FClassicObjective(AbstractObjective):
    """
    The "f-classic" objective from Perez-Ortiz et al. (2021), which
    combines empirical loss with a square-root bounding term involving KL and delta.

    Typically used to ensure a PAC-Bayes bound is minimized during training.
    """

    def __init__(self, kl_penalty: float, delta: float):
        """
        Args:
            kl_penalty (float): Coefficient for scaling KL divergence.
            delta (float): Confidence parameter for the PAC-Bayes bound.
        """
        self._kl_penalty = kl_penalty
        self._delta = delta  # confidence value for the training objective

    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        """
        Compute the f-classic objective.

        Args:
            loss (Tensor): Empirical risk (e.g., average loss on a mini-batch).
            kl (Tensor): KL divergence between posterior and prior.
            num_samples (float): Number of samples or an equivalent scaling factor.

        Returns:
            Tensor: A scalar objective = loss + sqrt( (KL * kl_penalty + ln(2 sqrt(n)/delta)) / (2n) ).
        """
        kl = kl * self._kl_penalty
        kl_ratio = torch.div(
            kl + np.log((2 * np.sqrt(num_samples)) / self._delta), 2 * num_samples
        )
        return loss + torch.sqrt(kl_ratio)
```
The "f-classic" objective from Perez-Ortiz et al. (2021), which combines empirical loss with a square-root bounding term involving KL and delta.
Typically used to ensure a PAC-Bayes bound is minimized during training.
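In symbols, this is the formula from the Returns section of calculate below: writing L̂ for the empirical loss, KL for the divergence between posterior and prior, λ for kl_penalty, δ for delta, and n for num_samples,

```latex
f_{\text{classic}} = \hat{L} + \sqrt{\frac{\lambda \cdot \mathrm{KL} + \ln\!\left(2\sqrt{n}/\delta\right)}{2n}}
```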
FClassicObjective(kl_penalty: float, delta: float)
Arguments:
- kl_penalty (float): Coefficient for scaling KL divergence.
- delta (float): Confidence parameter for the PAC-Bayes bound.
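A minimal construction sketch; the specific values are illustrative assumptions, not recommendations:

```python
from core.objective.FClassicObjective import FClassicObjective  # import path assumed from this page's heading

# kl_penalty = 1.0 leaves the KL term unscaled; delta = 0.025 asks for a
# bound that holds with probability at least 1 - 0.025.
objective = FClassicObjective(kl_penalty=1.0, delta=0.025)
```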
def calculate(self, loss: torch.Tensor, kl: torch.Tensor, num_samples: float) -> torch.Tensor
Compute the f-classic objective.
Arguments:
- loss (Tensor): Empirical risk (e.g., average loss on a mini-batch).
- kl (Tensor): KL divergence between posterior and prior.
- num_samples (float): Number of samples or an equivalent scaling factor.
Returns:
Tensor: A scalar objective = loss + sqrt( (KL * kl_penalty + ln(2 sqrt(n)/delta)) / (2n) ).
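An end-to-end sketch of a calculate call. The loss, KL, and sample-count values are made-up stand-ins for quantities a real training loop would supply:

```python
import torch

from core.objective.FClassicObjective import FClassicObjective  # import path assumed from this page's heading

objective = FClassicObjective(kl_penalty=1.0, delta=0.025)

# Stand-in inputs: a mini-batch empirical loss, a posterior/prior KL
# divergence, and n = 60000 training samples.
loss = torch.tensor(0.35)
kl = torch.tensor(120.0)
num_samples = 60000.0

bound = objective.calculate(loss, kl, num_samples)
# bound == loss + sqrt((kl * 1.0 + ln(2 * sqrt(n) / 0.025)) / (2 * n))
print(bound.item())
```

Since calculate is built from differentiable torch operations, the returned scalar can be backpropagated directly when loss and kl carry gradients.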