core.objective.TolstikhinObjective

import numpy as np
import torch
from torch import Tensor

from core.objective import AbstractObjective


class TolstikhinObjective(AbstractObjective):
    """
    Objective related to Tolstikhin and Seldin (2013), combining the
    empirical loss, a square-root term involving the KL divergence, and an
    additional additive term.
    """

    def __init__(self, kl_penalty: float, delta: float):
        """
        Args:
            kl_penalty (float): The coefficient multiplying the KL term.
            delta (float): Confidence parameter in the PAC-Bayes or related bound.
        """
        self._kl_penalty = kl_penalty
        self._delta = delta  # confidence value for the training objective

    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        """
        Compute the Tolstikhin objective:
            loss + sqrt(2 * loss * ratio) + 2 * ratio
        where ratio = (kl_penalty * KL + ln(2n) - ln(delta)) / (2n).

        Args:
            loss (Tensor): Empirical loss.
            kl (Tensor): KL divergence.
            num_samples (float): Number of data samples for normalization.

        Returns:
            Tensor: The scalar objective value.
        """
        kl = kl * self._kl_penalty
        # Shared ratio for the square-root and additive terms:
        # (kl_penalty * KL + ln(2n) - ln(delta)) / (2n).
        ratio = torch.div(
            kl + np.log(2 * num_samples) - np.log(self._delta), 2 * num_samples
        )
        return loss + torch.sqrt(2 * loss * ratio) + 2 * ratio
class TolstikhinObjective(core.objective.AbstractObjective.AbstractObjective):

Objective related to Tolstikhin and Seldin (2013), combining the empirical loss, a square-root term involving the KL divergence, and an additional additive term.
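
A minimal usage sketch (the import path and all numeric values are assumptions for illustration, not part of the documented API):

import torch

from core.objective.TolstikhinObjective import TolstikhinObjective

# delta = 0.025 corresponds to a confidence level of 1 - delta = 97.5%.
objective = TolstikhinObjective(kl_penalty=1.0, delta=0.025)

train_loss = torch.tensor(0.10)  # empirical loss of the posterior
kl = torch.tensor(50.0)          # KL(posterior || prior)
bound = objective.calculate(train_loss, kl, num_samples=60000)

Because the computation uses only differentiable torch operations (torch.div, torch.sqrt), the returned value can serve directly as a training objective and be backpropagated through.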

TolstikhinObjective(kl_penalty: float, delta: float)
Arguments:
  • kl_penalty (float): The coefficient multiplying the KL term.
  • delta (float): Confidence parameter in the PAC-Bayes or related bound.
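
Note that kl_penalty enters calculate only as a rescaling of the KL term before the ratio is formed, so halving the penalty is equivalent to halving the KL itself. A small check (values are illustrative):

import torch

from core.objective.TolstikhinObjective import TolstikhinObjective

loss, kl = torch.tensor(0.2), torch.tensor(20.0)
obj_half = TolstikhinObjective(kl_penalty=0.5, delta=0.05)
obj_full = TolstikhinObjective(kl_penalty=1.0, delta=0.05)

# Scaling the penalty and scaling the KL agree exactly.
assert torch.isclose(
    obj_half.calculate(loss, kl, 1000),
    obj_full.calculate(loss, kl * 0.5, 1000),
)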
def calculate(self, loss: torch.Tensor, kl: torch.Tensor, num_samples: float) -> torch.Tensor:

Compute the Tolstikhin objective:

loss + sqrt(2 * loss * ratio) + 2 * ratio

where ratio = (kl_penalty * KL + ln(2n) - ln(delta)) / (2n).

Arguments:
  • loss (Tensor): Empirical loss.
  • kl (Tensor): KL divergence.
  • num_samples (float): Number of data samples for normalization.
Returns:
  • Tensor: The scalar objective value.
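
As a worked example of the formula (computed standalone with illustrative numbers, rather than through the class):

import numpy as np
import torch

loss, kl = torch.tensor(0.1), torch.tensor(50.0)
num_samples, delta = 60000, 0.025  # kl_penalty = 1 is implicit here

# ratio = (KL + ln(2n) - ln(delta)) / (2n)
ratio = (kl + np.log(2 * num_samples) - np.log(delta)) / (2 * num_samples)
objective = loss + torch.sqrt(2 * loss * ratio) + 2 * ratio
# ratio ≈ 5.45e-4; sqrt term ≈ 0.0104; additive term ≈ 0.0011;
# so objective ≈ 0.112.
print(objective)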