core.utils.kl

 1import math
 2from abc import ABC, abstractmethod
 3
 4import torch
 5
 6
 7class KLDivergenceInterface(ABC):
 8    """
 9    An abstract base class for computing Kullback-Leibler Divergence (KL Divergence).
10    """
11
12    @abstractmethod
13    def compute_kl(self, *args, **kwargs) -> torch.Tensor:
14        """
15        Computes the Kullback-Leibler Divergence (KL Divergence) between two probability distributions.
16
17        Args:
18            *args: Variable length argument list.
19            **kwargs: Arbitrary keyword arguments.
20
21        Returns:
22            torch.Tensor: The computed KL Divergence.
23        """
24        pass
25
26
def inv_kl(qs, ks):
    """Invert the binary KL divergence by bisection.

    From "(Not) Bounding the True Error" by John Langford and Rich Caruana:
    given an empirical risk ``qs`` and a slack term ``ks``, find p in
    [qs, 1) such that kl(qs || p) ~= ks, where kl is the KL divergence
    between Bernoulli(qs) and Bernoulli(p).

    Parameters:
        qs (float): Empirical risk, expected in [0, 1].
        ks (float): Second term for the binary KL divergence inversion
            (non-negative slack).

    Returns:
        float: The computed inversion of the binary KL divergence,
        accurate to ~1e-5 relative tolerance.
    """
    # Zero (or negative) slack inverts to p == qs exactly. Returning early
    # also avoids a failure mode of the bisection: for qs == 0 and ks == 0
    # the left bracket never moves, the relative-width test never fires,
    # and the right bracket underflows to 0, dividing by zero.
    if ks <= 0:
        return qs
    izq = qs  # left bracket: kl(qs || izq) <= ks
    dch = 1 - 1e-10  # right bracket, kept strictly below 1 so log(1 - p) is finite
    while True:
        p = (izq + dch) * 0.5
        # Residual ks - kl(qs || p), using the convention 0 * log(0) = 0
        # to handle the endpoint risks qs == 0 and qs == 1.
        if qs == 0:
            ikl = ks - (1 - qs) * math.log((1 - qs) / (1 - p))
        elif qs == 1:
            ikl = ks - qs * math.log(qs / p)
        else:
            ikl = ks - (qs * math.log(qs / p) + (1 - qs) * math.log((1 - qs) / (1 - p)))
        if ikl < 0:
            dch = p  # kl exceeds the slack: shrink from the right
        else:
            izq = p  # still within the slack: advance the left bracket
        if (dch - izq) / dch < 1e-5:
            break
    return p
class KLDivergenceInterface(abc.ABC):
 8class KLDivergenceInterface(ABC):
 9    """
10    An abstract base class for computing Kullback-Leibler Divergence (KL Divergence).
11    """
12
13    @abstractmethod
14    def compute_kl(self, *args, **kwargs) -> torch.Tensor:
15        """
16        Computes the Kullback-Leibler Divergence (KL Divergence) between two probability distributions.
17
18        Args:
19            *args: Variable length argument list.
20            **kwargs: Arbitrary keyword arguments.
21
22        Returns:
23            torch.Tensor: The computed KL Divergence.
24        """
25        pass

An abstract base class for computing Kullback-Leibler Divergence (KL Divergence).

@abstractmethod
def compute_kl(self, *args, **kwargs) -> torch.Tensor:
13    @abstractmethod
14    def compute_kl(self, *args, **kwargs) -> torch.Tensor:
15        """
16        Computes the Kullback-Leibler Divergence (KL Divergence) between two probability distributions.
17
18        Args:
19            *args: Variable length argument list.
20            **kwargs: Arbitrary keyword arguments.
21
22        Returns:
23            torch.Tensor: The computed KL Divergence.
24        """
25        pass

Computes the Kullback-Leibler Divergence (KL Divergence) between two probability distributions.

Arguments:
  • *args: Variable length argument list.
  • **kwargs: Arbitrary keyword arguments.
Returns:

torch.Tensor: The computed KL Divergence.

def inv_kl(qs, ks):
28def inv_kl(qs, ks):
29    """
30    Inversion of the binary KL divergence from (Not) Bounding the True Error by John Langford and Rich Caruana.
31
32    Parameters:
33        qs (float): Empirical risk.
34        ks (float): Second term for the binary KL divergence inversion.
35
36    Returns:
37        float: The computed inversion of the binary KL divergence.
38    """
39    ikl = 0
40    izq = qs
41    dch = 1 - 1e-10
42    while True:
43        p = (izq + dch) * 0.5
44        if qs == 0:
45            ikl = ks - (0 + (1 - qs) * math.log((1 - qs) / (1 - p)))
46        elif qs == 1:
47            ikl = ks - (qs * math.log(qs / p) + 0)
48        else:
49            ikl = ks - (qs * math.log(qs / p) + (1 - qs) * math.log((1 - qs) / (1 - p)))
50        if ikl < 0:
51            dch = p
52        else:
53            izq = p
54        if (dch - izq) / dch < 1e-5:
55            break
56    return p

Inversion of the binary KL divergence from (Not) Bounding the True Error by John Langford and Rich Caruana.

Arguments:
  • qs (float): Empirical risk.
  • ks (float): Second term for the binary KL divergence inversion.
Returns:

float: The computed inversion of the binary KL divergence.