core.objective.FQuadObjective

 1import numpy as np
 2import torch
 3from torch import Tensor
 4
 5from core.objective import AbstractObjective
 6
 7
 8class FQuadObjective(AbstractObjective):
 9    """
10    A "f-quad" objective from Perez-Ortiz et al. (2021), which involves
11    a quadratic expression derived from the PAC-Bayes bound.
12    """
13
14    def __init__(self, kl_penalty: float, delta: float):
15        """
16        Args:
17            kl_penalty (float): Coefficient to scale the KL term.
18            delta (float): Confidence parameter for the PAC-Bayes bound.
19        """
20        self._kl_penalty = kl_penalty
21        self._delta = delta  # confidence value for the training objective
22
23    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
24        """
25        Compute the f-quad objective.
26
27        This objective calculates:
28            ( sqrt(loss + ratio) + sqrt(ratio) )^2
29        where ratio = (KL + ln(2 sqrt(n)/delta)) / (2n).
30
31        Args:
32            loss (Tensor): Empirical loss.
33            kl (Tensor): KL divergence.
34            num_samples (float): Dataset size or similar factor.
35
36        Returns:
37            Tensor: The scalar objective value.
38        """
39        kl = kl * self._kl_penalty
40        kl_ratio = torch.div(
41            kl + np.log((2 * np.sqrt(num_samples)) / self._delta), 2 * num_samples
42        )
43        first_term = torch.sqrt(loss + kl_ratio)
44        second_term = torch.sqrt(kl_ratio)
45        return torch.pow(first_term + second_term, 2)
class FQuadObjective(core.objective.AbstractObjective.AbstractObjective):
 9class FQuadObjective(AbstractObjective):
10    """
11    A "f-quad" objective from Perez-Ortiz et al. (2021), which involves
12    a quadratic expression derived from the PAC-Bayes bound.
13    """
14
15    def __init__(self, kl_penalty: float, delta: float):
16        """
17        Args:
18            kl_penalty (float): Coefficient to scale the KL term.
19            delta (float): Confidence parameter for the PAC-Bayes bound.
20        """
21        self._kl_penalty = kl_penalty
22        self._delta = delta  # confidence value for the training objective
23
24    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
25        """
26        Compute the f-quad objective.
27
28        This objective calculates:
29            ( sqrt(loss + ratio) + sqrt(ratio) )^2
30        where ratio = (KL + ln(2 sqrt(n)/delta)) / (2n).
31
32        Args:
33            loss (Tensor): Empirical loss.
34            kl (Tensor): KL divergence.
35            num_samples (float): Dataset size or similar factor.
36
37        Returns:
38            Tensor: The scalar objective value.
39        """
40        kl = kl * self._kl_penalty
41        kl_ratio = torch.div(
42            kl + np.log((2 * np.sqrt(num_samples)) / self._delta), 2 * num_samples
43        )
44        first_term = torch.sqrt(loss + kl_ratio)
45        second_term = torch.sqrt(kl_ratio)
46        return torch.pow(first_term + second_term, 2)

An "f-quad" objective from Perez-Ortiz et al. (2021), which involves a quadratic expression derived from the PAC-Bayes bound.

FQuadObjective(kl_penalty: float, delta: float)
15    def __init__(self, kl_penalty: float, delta: float):
16        """
17        Args:
18            kl_penalty (float): Coefficient to scale the KL term.
19            delta (float): Confidence parameter for the PAC-Bayes bound.
20        """
21        self._kl_penalty = kl_penalty
22        self._delta = delta  # confidence value for the training objective
Arguments:
  • kl_penalty (float): Coefficient to scale the KL term.
  • delta (float): Confidence parameter for the PAC-Bayes bound.
def calculate( self, loss: torch.Tensor, kl: torch.Tensor, num_samples: float) -> torch.Tensor:
24    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
25        """
26        Compute the f-quad objective.
27
28        This objective calculates:
29            ( sqrt(loss + ratio) + sqrt(ratio) )^2
30        where ratio = (KL + ln(2 sqrt(n)/delta)) / (2n).
31
32        Args:
33            loss (Tensor): Empirical loss.
34            kl (Tensor): KL divergence.
35            num_samples (float): Dataset size or similar factor.
36
37        Returns:
38            Tensor: The scalar objective value.
39        """
40        kl = kl * self._kl_penalty
41        kl_ratio = torch.div(
42            kl + np.log((2 * np.sqrt(num_samples)) / self._delta), 2 * num_samples
43        )
44        first_term = torch.sqrt(loss + kl_ratio)
45        second_term = torch.sqrt(kl_ratio)
46        return torch.pow(first_term + second_term, 2)

Compute the f-quad objective.

This objective calculates:

( sqrt(loss + ratio) + sqrt(ratio) )^2

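where ratio = (kl_penalty · KL + ln(2 sqrt(n)/delta)) / (2n) and n = num_samples; the KL divergence is scaled by kl_penalty before it enters the ratio.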

Arguments:
  • loss (Tensor): Empirical loss.
  • kl (Tensor): KL divergence.
  • num_samples (float): Number of samples n used in the bound (e.g. the dataset size).
Returns:

Tensor: The scalar objective value.