core.objective.FClassicObjective

import numpy as np
import torch
from torch import Tensor

from core.objective import AbstractObjective


class FClassicObjective(AbstractObjective):
    """
    The "f-classic" objective from Perez-Ortiz et al. (2021), which
    combines empirical loss with a square-root bounding term involving KL and delta.

    Typically used to ensure a PAC-Bayes bound is minimized during training.
    """

    def __init__(self, kl_penalty: float, delta: float):
        """
        Args:
            kl_penalty (float): Coefficient for scaling KL divergence.
            delta (float): Confidence parameter for the PAC-Bayes bound.
        """
        self._kl_penalty = kl_penalty
        self._delta = delta  # confidence value for the training objective

    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        """
        Compute the f-classic objective.

        Args:
            loss (Tensor): Empirical risk (e.g., average loss on a mini-batch).
            kl (Tensor): KL divergence between posterior and prior.
            num_samples (float): Number of samples or an equivalent scaling factor.

        Returns:
            Tensor: A scalar objective = loss + sqrt( (KL * kl_penalty + ln(2 sqrt(n)/delta)) / (2n) ).
        """
        kl = kl * self._kl_penalty
        kl_ratio = torch.div(
            kl + np.log((2 * np.sqrt(num_samples)) / self._delta), 2 * num_samples
        )
        return loss + torch.sqrt(kl_ratio)
class FClassicObjective(core.objective.AbstractObjective.AbstractObjective):

The "f-classic" objective from Perez-Ortiz et al. (2021), which combines empirical loss with a square-root bounding term involving KL and delta.

Typically used to ensure a PAC-Bayes bound is minimized during training.
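
For orientation, here is a minimal usage sketch. The hyperparameter values, the tensor inputs, and the exact import path are illustrative assumptions, not values or conventions prescribed by this module.

import torch

from core.objective.FClassicObjective import FClassicObjective  # import path assumed from the module name above

# Illustrative hyperparameters: kl_penalty scales the KL term, delta is the
# confidence parameter of the bound (these values are assumptions, not defaults).
objective = FClassicObjective(kl_penalty=1.0, delta=0.025)

# Placeholder inputs: in practice `loss` is the empirical risk on a mini-batch
# and `kl` is the KL divergence between the posterior and the prior.
loss = torch.tensor(0.30, requires_grad=True)
kl = torch.tensor(10.0, requires_grad=True)

bound = objective.calculate(loss, kl, num_samples=60000)
bound.backward()  # the objective is differentiable, so it can drive a training step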

FClassicObjective(kl_penalty: float, delta: float)
Arguments:
  • kl_penalty (float): Coefficient for scaling KL divergence.
  • delta (float): Confidence parameter for the PAC-Bayes bound.
def calculate(self, loss: torch.Tensor, kl: torch.Tensor, num_samples: float) -> torch.Tensor:

Compute the f-classic objective.

Arguments:
  • loss (Tensor): Empirical risk (e.g., average loss on a mini-batch).
  • kl (Tensor): KL divergence between posterior and prior.
  • num_samples (float): Number of samples or an equivalent scaling factor.
Returns:
  • Tensor: A scalar objective = loss + sqrt( (kl_penalty * KL + ln(2 * sqrt(n) / delta)) / (2 * n) ).
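
As a hedged sanity check of the returned formula, the snippet below recomputes the bounding term by hand with made-up numbers (loss = 0.3, KL = 10, kl_penalty = 1, n = 60000, delta = 0.025); the values are illustrative only.

import numpy as np

loss, kl, kl_penalty = 0.3, 10.0, 1.0
n, delta = 60000, 0.025  # assumed sample count and confidence parameter

# Bounding term: sqrt((kl_penalty * KL + ln(2 * sqrt(n) / delta)) / (2 * n))
penalty = np.sqrt((kl_penalty * kl + np.log(2 * np.sqrt(n) / delta)) / (2 * n))

print(round(penalty, 4))         # ~0.0129: the square-root term is small for large n
print(round(loss + penalty, 4))  # ~0.3129: should match FClassicObjective.calculate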