core.objective.AbstractObjective
```python
from abc import ABC, abstractmethod

from torch import Tensor


class AbstractObjective(ABC):
    """
    Base class for PAC-Bayes training objectives.

    An objective typically combines:
    - Empirical loss (e.g., negative log-likelihood)
    - KL divergence between posterior and prior
    - Additional terms for confidence or other bounding factors
    """

    @abstractmethod
    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        """
        Compute the combined objective scalar to be backpropagated.

        Args:
            loss (Tensor): Empirical loss, e.g., cross-entropy on a batch.
            kl (Tensor): KL divergence between the current posterior and prior.
            num_samples (float): Number of samples drawn or the total dataset
                size; used to scale the KL and other terms.

        Returns:
            Tensor: A scalar tensor that includes the loss, KL penalty, and
                any other terms.
        """
        pass
```
class AbstractObjective(abc.ABC):
Base class for PAC-Bayes training objectives.
An objective typically combines:
- Empirical loss (e.g., negative log-likelihood)
- KL divergence between posterior and prior
- Additional terms for confidence or other bounding factors
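As a concrete illustration, here is a minimal sketch of a subclass implementing a McAllester-style bound: the empirical loss plus a square-root complexity term built from the KL and a confidence penalty. The class name `McAllesterObjective` and the `delta` parameter are illustrative assumptions, not part of this module; the sketch also assumes `AbstractObjective` is importable from `core.objective`.

```python
import math

import torch
from torch import Tensor

from core.objective import AbstractObjective


class McAllesterObjective(AbstractObjective):
    """Hypothetical example: a McAllester-style PAC-Bayes objective.

    Minimizes  loss + sqrt((kl + log(2 * sqrt(n) / delta)) / (2 * n)),
    where delta is the confidence parameter of the bound.
    """

    def __init__(self, delta: float = 0.05):
        self.delta = delta

    def calculate(self, loss: Tensor, kl: Tensor, num_samples: float) -> Tensor:
        # Complexity term: KL plus the confidence penalty, scaled by 2n.
        penalty = (
            kl + math.log(2.0 * math.sqrt(num_samples) / self.delta)
        ) / (2.0 * num_samples)
        return loss + torch.sqrt(penalty)
```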
@abstractmethod
def calculate(self, loss: torch.Tensor, kl: torch.Tensor, num_samples: float) -> torch.Tensor:
Compute the combined objective scalar to be backpropagated.
Arguments:
- loss (Tensor): Empirical loss, e.g. cross-entropy on a batch.
- kl (Tensor): KL divergence between the current posterior and prior.
- num_samples (float): Number of samples drawn or the total dataset size; used to scale the KL and other terms.
Returns:
Tensor: A scalar tensor that includes the loss, KL penalty, and any other terms.
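The following usage sketch shows where `calculate` fits in a training step. It is a hedged example, not the library's documented training loop: `model.kl_divergence()` is an assumed helper returning the posterior-prior KL as a tensor, and all other names are illustrative.

```python
import torch
import torch.nn.functional as F

from core.objective import AbstractObjective


def train_step(model, optimizer, objective: AbstractObjective,
               inputs: torch.Tensor, targets: torch.Tensor,
               dataset_size: int) -> float:
    # Hypothetical training step: combine empirical loss and KL into the
    # objective scalar, then backpropagate that scalar.
    optimizer.zero_grad()
    logits = model(inputs)
    loss = F.cross_entropy(logits, targets)   # empirical loss on the batch
    kl = model.kl_divergence()                # assumed helper: KL(posterior || prior)
    bound = objective.calculate(loss, kl, float(dataset_size))
    bound.backward()                          # backpropagate the combined scalar
    optimizer.step()
    return bound.item()
```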