core.utils.kl
import math
from abc import ABC, abstractmethod

import torch


class KLDivergenceInterface(ABC):
    """
    An abstract base class for computing the Kullback-Leibler (KL) divergence.
    """

    @abstractmethod
    def compute_kl(self, *args, **kwargs) -> torch.Tensor:
        """
        Computes the Kullback-Leibler (KL) divergence between two probability distributions.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            torch.Tensor: The computed KL divergence.
        """
        pass


def inv_kl(qs, ks):
    """
    Inversion of the binary KL divergence from (Not) Bounding the True Error
    by John Langford and Rich Caruana.

    Uses bisection to find the largest ``p`` between ``qs`` and just below 1
    such that ``kl(qs || p) <= ks``. Here
    ``kl(q || p) = q*log(q/p) + (1-q)*log((1-q)/(1-p))`` is the KL divergence
    between Bernoulli(q) and Bernoulli(p).

    Parameters:
        qs (float): Empirical risk; must lie in ``[0, 1]``.
        ks (float): Upper bound allowed for ``kl(qs || p)``.

    Returns:
        float: The inverted value ``p``. The search stops once the bracketing
        interval is narrower than ``1e-5`` relative to its upper end. If
        ``ks <= 0``, ``qs`` itself is returned, because ``kl(qs || p)`` is
        non-negative and is zero only at ``p == qs``.

    Raises:
        ValueError: If ``qs`` is outside ``[0, 1]`` (including NaN).
    """
    if not 0 <= qs <= 1:
        raise ValueError(f"qs must lie in [0, 1], got {qs}")
    # A non-positive budget admits no p other than qs. Without this guard,
    # qs == 0 with ks < 0 shrinks the upper bracket to 0.0 and the
    # relative-width test below divides by zero.
    if ks <= 0:
        return qs

    def _kl(p):
        # Binary KL divergence kl(qs || p). The endpoint cases qs == 0 and
        # qs == 1 drop the term that would otherwise evaluate log(0).
        if qs == 0:
            return (1 - qs) * math.log((1 - qs) / (1 - p))
        if qs == 1:
            return qs * math.log(qs / p)
        return qs * math.log(qs / p) + (1 - qs) * math.log((1 - qs) / (1 - p))

    lo = qs
    # Keep the upper bracket strictly below 1 so log(... / (1 - p)) stays finite.
    hi = 1 - 1e-10
    while True:
        p = (lo + hi) * 0.5
        if ks - _kl(p) < 0:
            # p exceeds the divergence budget: the answer lies below p.
            hi = p
        else:
            lo = p
        if (hi - lo) / hi < 1e-5:
            break
    return p
class KLDivergenceInterface(abc.ABC):
class KLDivergenceInterface(ABC):
    """
    Interface for objects that compute a Kullback-Leibler (KL) divergence.

    Concrete subclasses must override :meth:`compute_kl`. Because that method
    is abstract, this class itself cannot be instantiated.
    """

    @abstractmethod
    def compute_kl(self, *args, **kwargs) -> torch.Tensor:
        """
        Compute the KL divergence between two probability distributions.

        Args:
            *args: Positional arguments; their meaning is defined by the
                concrete implementation.
            **kwargs: Keyword arguments; their meaning is defined by the
                concrete implementation.

        Returns:
            torch.Tensor: The KL divergence produced by the implementation.
        """
        ...
An abstract base class for computing the Kullback-Leibler (KL) divergence.
@abstractmethod
def compute_kl(self, *args, **kwargs) -> torch.Tensor:
@abstractmethod
def compute_kl(self, *args, **kwargs) -> torch.Tensor:
    """
    Compute the KL divergence between two probability distributions.

    Args:
        *args: Positional arguments; their meaning is defined by the
            concrete implementation.
        **kwargs: Keyword arguments; their meaning is defined by the
            concrete implementation.

    Returns:
        torch.Tensor: The KL divergence produced by the implementation.
    """
    ...
Computes the Kullback-Leibler (KL) divergence between two probability distributions.
Arguments:
- *args: Variable length argument list.
- **kwargs: Arbitrary keyword arguments.
Returns:
torch.Tensor: The computed KL Divergence.
def inv_kl(qs, ks):
def inv_kl(qs, ks):
    """
    Inversion of the binary KL divergence from (Not) Bounding the True Error
    by John Langford and Rich Caruana.

    Uses bisection to find the largest ``p`` between ``qs`` and just below 1
    such that ``kl(qs || p) <= ks``. Here
    ``kl(q || p) = q*log(q/p) + (1-q)*log((1-q)/(1-p))`` is the KL divergence
    between Bernoulli(q) and Bernoulli(p).

    Parameters:
        qs (float): Empirical risk; must lie in ``[0, 1]``.
        ks (float): Upper bound allowed for ``kl(qs || p)``.

    Returns:
        float: The inverted value ``p``. The search stops once the bracketing
        interval is narrower than ``1e-5`` relative to its upper end. If
        ``ks <= 0``, ``qs`` itself is returned, because ``kl(qs || p)`` is
        non-negative and is zero only at ``p == qs``.

    Raises:
        ValueError: If ``qs`` is outside ``[0, 1]`` (including NaN).
    """
    if not 0 <= qs <= 1:
        raise ValueError(f"qs must lie in [0, 1], got {qs}")
    # A non-positive budget admits no p other than qs. Without this guard,
    # qs == 0 with ks < 0 shrinks the upper bracket to 0.0 and the
    # relative-width test below divides by zero.
    if ks <= 0:
        return qs

    def _kl(p):
        # Binary KL divergence kl(qs || p). The endpoint cases qs == 0 and
        # qs == 1 drop the term that would otherwise evaluate log(0).
        if qs == 0:
            return (1 - qs) * math.log((1 - qs) / (1 - p))
        if qs == 1:
            return qs * math.log(qs / p)
        return qs * math.log(qs / p) + (1 - qs) * math.log((1 - qs) / (1 - p))

    lo = qs
    # Keep the upper bracket strictly below 1 so log(... / (1 - p)) stays finite.
    hi = 1 - 1e-10
    while True:
        p = (lo + hi) * 0.5
        if ks - _kl(p) < 0:
            # p exceeds the divergence budget: the answer lies below p.
            hi = p
        else:
            lo = p
        if (hi - lo) / hi < 1e-5:
            break
    return p
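Inversion of the binary KL divergence from *(Not) Bounding the True Error* by John Langford and Rich Caruana. Using bisection, it finds the largest p between qs and just below 1 such that kl(qs || p) ≤ ks, where kl(q || p) = q·log(q/p) + (1 − q)·log((1 − q)/(1 − p)).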
Arguments:
- qs (float): Empirical risk.
- ks (float): Upper bound allowed for the binary KL divergence kl(qs || p).
Returns:
float: The largest p satisfying kl(qs || p) ≤ ks, located by bisection to a relative tolerance of 1e-5.