core.objective.FQuadObjective
```python
import numpy as np
import torch
from torch import Tensor

from core.objective import AbstractObjective


class FQuadObjective(AbstractObjective):
    """
    An "f-quad" objective from Perez-Ortiz et al. (2021), built around
    a quadratic expression derived from the PAC-Bayes bound.
    """

    def __init__(self, kl_penalty: float, delta: float):
        """
        Args:
            kl_penalty (float): Coefficient to scale the KL term.
            delta (float): Confidence parameter for the PAC-Bayes bound.
        """
        self._kl_penalty = kl_penalty
        self._delta = delta  # confidence value for the training objective

    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        """
        Compute the f-quad objective.

        This objective computes
            ( sqrt(loss + ratio) + sqrt(ratio) )^2
        where ratio = (KL + ln(2 sqrt(n) / delta)) / (2n), n is num_samples,
        and KL is the KL divergence after scaling by kl_penalty.

        Args:
            loss (Tensor): Empirical loss.
            kl (Tensor): KL divergence.
            num_samples (float): Number of training samples, n.

        Returns:
            Tensor: The scalar objective value.
        """
        # Scale the KL term, then form the ratio from the PAC-Bayes bound.
        kl = kl * self._kl_penalty
        kl_ratio = torch.div(
            kl + np.log((2 * np.sqrt(num_samples)) / self._delta), 2 * num_samples
        )
        # ( sqrt(loss + ratio) + sqrt(ratio) )^2
        first_term = torch.sqrt(loss + kl_ratio)
        second_term = torch.sqrt(kl_ratio)
        return torch.pow(first_term + second_term, 2)
```
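For orientation, here is a minimal usage sketch. The import path is inferred from the module layout above, all numeric values are purely illustrative, and it assumes `AbstractObjective` leaves `FQuadObjective` directly instantiable:

```python
import torch

from core.objective import FQuadObjective  # path assumed from the layout above

# Illustrative, hypothetical values; kl_penalty=1.0 leaves the KL term unscaled.
objective = FQuadObjective(kl_penalty=1.0, delta=0.025)

loss = torch.tensor(0.10)   # empirical loss of the randomized predictor
kl = torch.tensor(5000.0)   # KL(posterior || prior)

bound = objective.calculate(loss, kl, num_samples=60000)
print(bound.item())  # always >= loss, since the ratio term is non-negative
```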
A "f-quad" objective from Perez-Ortiz et al. (2021), which involves a quadratic expression derived from the PAC-Bayes bound.
FQuadObjective(kl_penalty: float, delta: float)
Arguments:
- kl_penalty (float): Coefficient to scale the KL term.
- delta (float): Confidence parameter for the PAC-Bayes bound.
def calculate(self, loss: torch.Tensor, kl: torch.Tensor, num_samples: float) -> torch.Tensor:
Compute the f-quad objective.
This objective computes
    ( sqrt(loss + ratio) + sqrt(ratio) )^2
where ratio = (KL + ln(2 sqrt(n) / delta)) / (2n), n is num_samples, and KL is the KL divergence after scaling by kl_penalty.
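In standard notation (a direct transcription of the formula above, writing L-hat for loss, n for num_samples, and delta for the confidence parameter):

```latex
f_{\mathrm{quad}}
  = \Bigl( \sqrt{\hat{L} + r} + \sqrt{r} \Bigr)^{2},
\qquad
r = \frac{\mathrm{KL} + \ln \frac{2\sqrt{n}}{\delta}}{2n}
```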
Arguments:
- loss (Tensor): Empirical loss.
- kl (Tensor): KL divergence.
- num_samples (float): Number of training samples, n, used in the bound.
Returns:
Tensor: The scalar objective value.
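As a sanity check on the formula, a standalone NumPy recomputation with the same hypothetical numbers as the usage sketch above:

```python
import numpy as np

# Hypothetical illustrative values, matching the usage sketch above.
loss, kl, n, delta = 0.10, 5000.0, 60000, 0.025

ratio = (kl + np.log(2 * np.sqrt(n) / delta)) / (2 * n)
fquad = (np.sqrt(loss + ratio) + np.sqrt(ratio)) ** 2
print(round(ratio, 4), round(fquad, 4))  # roughly 0.0417 and 0.3374
```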