core.objective.AbstractObjective

from abc import ABC, abstractmethod

from torch import Tensor


class AbstractObjective(ABC):
    """
    Base class for PAC-Bayes training objectives.

    An objective typically combines:
      - Empirical loss (e.g., negative log-likelihood)
      - KL divergence between posterior and prior
      - Additional terms for confidence or other bounding factors
    """

    @abstractmethod
    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        """
        Compute the combined objective scalar to be backpropagated.

        Args:
            loss (Tensor): Empirical loss, e.g. cross-entropy on a batch.
            kl (Tensor): KL divergence between the current posterior and prior.
            num_samples (float): Number of samples used or total dataset size,
                used for scaling KL or other terms.

        Returns:
            Tensor: A scalar tensor that includes the loss, KL penalty, and any other terms.
        """
        pass
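The abstract-base-class mechanics above can be sketched as follows. To stay runnable without the package (whose import path is not shown here) or torch, the snippet redefines a stand-in `AbstractObjective` with the same `calculate` signature, minus the type hints. `KLPenaltyObjective` and `IncompleteObjective` are hypothetical names, and scaling the KL term by `1 / num_samples` is just one simple choice, not necessarily what the library's concrete objectives do.

```python
from abc import ABC, abstractmethod


class AbstractObjective(ABC):
    """Stand-in mirroring the documented interface (type hints omitted)."""

    @abstractmethod
    def calculate(self, loss, kl, num_samples):
        pass


class KLPenaltyObjective(AbstractObjective):
    """Hypothetical subclass: empirical loss plus KL scaled by 1 / num_samples."""

    def calculate(self, loss, kl, num_samples):
        return loss + kl / num_samples


class IncompleteObjective(AbstractObjective):
    """Hypothetical subclass that forgets to override calculate."""


# Neither the base class nor an incomplete subclass can be instantiated:
# ABC raises TypeError while any @abstractmethod is left unimplemented.
for cls in (AbstractObjective, IncompleteObjective):
    try:
        cls()
    except TypeError as err:
        print(f"{cls.__name__}: {err}")

# A complete subclass works; loss and kl may be floats or torch tensors,
# since only + and / are applied to them.
objective = KLPenaltyObjective()
print(objective.calculate(0.3, 5.0, 100.0))
```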
class AbstractObjective(abc.ABC):

Base class for PAC-Bayes training objectives.

An objective typically combines:
  • Empirical loss (e.g., negative log-likelihood)
  • KL divergence between posterior and prior
  • Additional terms for confidence or other bounding factors
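For concreteness, one standard way these three ingredients fit together is the McAllester-style PAC-Bayes bound in Maurer's form, stated here as an illustration only; this page does not say which bounds the library's concrete objectives implement. For a loss bounded in [0, 1], with probability at least 1 - δ over an i.i.d. sample S of size n, every posterior Q satisfies the inequality below, and minimizing its right-hand side over Q gives a training objective. Here n plays the role of `num_samples`; the confidence level δ is an assumption not documented above.

```latex
\mathcal{L}(Q) \;\le\;
  \underbrace{\hat{\mathcal{L}}_S(Q)}_{\text{empirical loss}}
  + \sqrt{\frac{\underbrace{\mathrm{KL}(Q \,\|\, P)}_{\text{KL term}}
      + \underbrace{\ln\frac{2\sqrt{n}}{\delta}}_{\text{confidence term}}}{2n}}
```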
@abstractmethod
def calculate(self, loss: torch.Tensor, kl: torch.Tensor, num_samples: float) -> torch.Tensor:

Compute the combined objective scalar to be backpropagated.

Arguments:
  • loss (Tensor): Empirical loss, e.g. cross-entropy on a batch.
  • kl (Tensor): KL divergence between the current posterior and prior.
  • num_samples (float): Number of samples used or total dataset size, used for scaling KL or other terms.
Returns:
  • Tensor: A scalar tensor that includes the loss, KL penalty, and any other terms.
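A minimal sketch of a concrete `calculate`, assuming a McAllester-style objective `loss + sqrt((kl + ln(2 * sqrt(n) / delta)) / (2 * n))` with `n = num_samples`. `McAllesterObjective` and its `delta` parameter are hypothetical, and the stand-in base class replaces the real one so the snippet runs without torch installed (`Tensor` is imported only under `TYPE_CHECKING`). The underlying bound assumes a loss in [0, 1], so feeding it raw cross-entropy would only be a heuristic.

```python
from __future__ import annotations

import math
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # torch is needed only by type checkers in this sketch
    from torch import Tensor


class AbstractObjective(ABC):
    """Stand-in mirroring the documented interface."""

    @abstractmethod
    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        pass


class McAllesterObjective(AbstractObjective):
    """Hypothetical objective: loss + sqrt((KL + ln(2*sqrt(n)/delta)) / (2n))."""

    def __init__(self, delta: float = 0.05) -> None:
        if not 0.0 < delta < 1.0:
            raise ValueError("delta must lie in (0, 1)")
        self.delta = delta

    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        # Confidence term: a plain float that depends only on n and delta.
        confidence = math.log(2.0 * math.sqrt(num_samples) / self.delta)
        # Only +, / and ** touch loss and kl, so torch tensors keep their
        # autograd graph here, and plain floats work too.
        complexity = ((kl + confidence) / (2.0 * num_samples)) ** 0.5
        return loss + complexity


objective = McAllesterObjective(delta=0.05)
print(objective.calculate(loss=0.2, kl=10.0, num_samples=10_000.0))
```

With tensors for `loss` and `kl`, the returned value is a tensor that can be backpropagated; `num_samples` and `delta` stay plain floats, so the confidence term contributes no gradient.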