core.layer.utils

  1from collections.abc import Callable, Iterator
  2
  3import torch.nn as nn
  4
  5from core.layer import supported_layers
  6
# Hierarchical name of a submodule: one child name per nesting level,
# e.g. ("1", "0") for the first child of the second child of the root.
LayerNameT = tuple[str, ...]
  8
  9
 10def get_layers(
 11    model: nn.Module, is_layer_func: Callable[[nn.Module], bool], names: list[str]
 12) -> Iterator[tuple[LayerNameT, nn.Module]]:
 13    """
 14    Recursively traverse a PyTorch model to find layers matching a given criterion.
 15
 16    This function performs a depth-first search over children of `model`.
 17    If `is_layer_func(layer)` is True, yield the path of layer names (as a tuple)
 18    and the layer object.
 19
 20    Args:
 21        model (nn.Module): The PyTorch model or submodule to traverse.
 22        is_layer_func (Callable): A predicate function that returns True if a layer
 23            matches the criterion (e.g., belongs to a certain set of layer types).
 24        names (List[str]): Accumulates the hierarchical names as we recurse.
 25
 26    Yields:
 27        Iterator[Tuple[LayerNameT, nn.Module]]: Tuples of (layer_name_path, layer_object).
 28    """
 29    for name, layer in model.named_children():
 30        if layer is not None:
 31            yield from get_layers(layer, is_layer_func, names + [name])
 32        if is_layer_func(layer):
 33            yield tuple(names + [name]), layer
 34
 35
 36def is_torch_layer(layer: nn.Module) -> bool:
 37    """
 38    Check if the given layer is a supported PyTorch layer in the framework.
 39
 40    Args:
 41        layer (nn.Module): A PyTorch layer or module.
 42
 43    Returns:
 44        bool: True if the layer's type is one of the framework's supported mappings.
 45    """
 46    return any(
 47        isinstance(layer, torch_layer)
 48        for torch_layer in supported_layers.LAYER_MAPPING
 49    )
 50
 51
 52def get_torch_layers(model: nn.Module) -> Iterator[tuple[LayerNameT, nn.Module]]:
 53    """
 54    Yield all supported PyTorch layers in the model.
 55
 56    Args:
 57        model (nn.Module): The PyTorch model to traverse.
 58
 59    Returns:
 60        Iterator[Tuple[LayerNameT, nn.Module]]: Each tuple is (path_of_names, layer).
 61    """
 62    return get_layers(model, is_torch_layer, names=[])
 63
 64
 65def is_bayesian_torch_layer(layer: nn.Module) -> bool:
 66    """
 67    Check if the layer belongs to a BayesianTorch-style layer,
 68    identified by having 'mu_weight', 'rho_weight', 'mu_bias', 'rho_bias'
 69    or the 'kernel' equivalent attributes.
 70
 71    Args:
 72        layer (nn.Module): A PyTorch module.
 73
 74    Returns:
 75        bool: True if the layer has BayesianTorch parameter attributes.
 76    """
 77    return (
 78        hasattr(layer, "mu_weight")
 79        and hasattr(layer, "rho_weight")
 80        and hasattr(layer, "mu_bias")
 81        and hasattr(layer, "rho_bias")
 82    ) or (
 83        hasattr(layer, "mu_kernel")
 84        and hasattr(layer, "rho_kernel")
 85        and hasattr(layer, "mu_bias")
 86        and hasattr(layer, "rho_bias")
 87    )
 88
 89
 90def get_bayesian_torch_layers(
 91    model: nn.Module,
 92) -> Iterator[tuple[LayerNameT, nn.Module]]:
 93    """
 94    Yield all layers in the model recognized as BayesianTorch layers,
 95    i.e., those containing 'mu_weight', 'rho_weight', etc.
 96
 97    Args:
 98        model (nn.Module): The PyTorch model to traverse.
 99
100    Returns:
101        Iterator[Tuple[LayerNameT, nn.Module]]: (layer_name_path, layer_object) for each Bayesian layer.
102    """
103    return get_layers(model, is_bayesian_torch_layer, names=[])
# Hierarchical name of a submodule: one child name per nesting level,
# e.g. ("1", "0") for the first child of the second child of the root.
LayerNameT = tuple[str, ...]
def get_layers( model: torch.nn.modules.module.Module, is_layer_func: Callable[[torch.nn.modules.module.Module], bool], names: list[str]) -> Iterator[tuple[tuple[str, ...], torch.nn.modules.module.Module]]:
def get_layers(
    model: nn.Module, is_layer_func: Callable[[nn.Module], bool], names: list[str]
) -> Iterator[tuple[LayerNameT, nn.Module]]:
    """
    Walk the submodule tree of `model` and yield every module accepted by a predicate.

    The walk is depth-first over `model.named_children()`. Each child is
    descended into *before* the predicate is applied to it, so matching
    descendants are yielded ahead of their matching ancestors (post-order).
    The root `model` itself is never passed to the predicate.

    Args:
        model (nn.Module): Module whose descendants are inspected.
        is_layer_func (Callable): Predicate; a module is yielded when it returns
            a truthy value for that module.
        names (List[str]): Name prefix for the paths produced at this level. It is
            copied, never modified; top-level callers pass an empty list.

    Yields:
        Iterator[Tuple[LayerNameT, nn.Module]]: `(path, module)` pairs, where
        `path` is `names` extended with the child name at every nesting level.
    """
    for child_name, child in model.named_children():
        # A fresh list per child, so the caller's `names` is never mutated.
        child_path = names + [child_name]
        # Recurse before testing the child itself: descendants come first.
        if child is not None:
            yield from get_layers(child, is_layer_func, child_path)
        if not is_layer_func(child):
            continue
        yield tuple(child_path), child

Recursively traverse a PyTorch model to find layers matching a given criterion.

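This function performs a depth-first search over the children of `model`. If `is_layer_func(layer)` is True, it yields the path of layer names (as a tuple) together with the layer object. A module's matching descendants are yielded before the module itself, and the root `model` is never tested.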

Arguments:
  • model (nn.Module): The PyTorch model or submodule to traverse.
  • is_layer_func (Callable): A predicate function that returns True if a layer matches the criterion (e.g., belongs to a certain set of layer types).
  • names (List[str]): Accumulates the hierarchical names as we recurse.
Yields:

Iterator[Tuple[LayerNameT, nn.Module]]: Tuples of (layer_name_path, layer_object).

def is_torch_layer(layer: torch.nn.modules.module.Module) -> bool:
def is_torch_layer(layer: nn.Module) -> bool:
    """
    Tell whether `layer` is one of the PyTorch layer types the framework supports.

    Args:
        layer (nn.Module): Module to classify.

    Returns:
        bool: True as soon as `layer` is an instance of (a subclass of) any of the
        types obtained by iterating `supported_layers.LAYER_MAPPING`; otherwise False.
    """
    for supported_type in supported_layers.LAYER_MAPPING:
        if isinstance(layer, supported_type):
            return True
    return False

Check if the given layer is a supported PyTorch layer in the framework.

Arguments:
  • layer (nn.Module): A PyTorch layer or module.
Returns:

bool: True if the layer is an instance of one of the layer types in the framework's supported mapping (`supported_layers.LAYER_MAPPING`), including subclasses of those types.

def get_torch_layers( model: torch.nn.modules.module.Module) -> Iterator[tuple[tuple[str, ...], torch.nn.modules.module.Module]]:
def get_torch_layers(model: nn.Module) -> Iterator[tuple[LayerNameT, nn.Module]]:
    """
    Lazily produce every framework-supported PyTorch layer nested inside `model`.

    Args:
        model (nn.Module): Root module to search; the root itself is not tested.

    Returns:
        Iterator[Tuple[LayerNameT, nn.Module]]: `(path_of_names, layer)` pairs,
        produced by `get_layers` with `is_torch_layer` as the predicate.
    """
    return get_layers(model=model, is_layer_func=is_torch_layer, names=[])

Yield all supported PyTorch layers in the model.

Arguments:
  • model (nn.Module): The PyTorch model to traverse.
Returns:

Iterator[Tuple[LayerNameT, nn.Module]]: Each tuple is (path_of_names, layer).

def is_bayesian_torch_layer(layer: torch.nn.modules.module.Module) -> bool:
def is_bayesian_torch_layer(layer: nn.Module) -> bool:
    """
    Decide whether `layer` looks like a BayesianTorch layer.

    A layer qualifies when it exposes the full set of variational attributes
    in one of two naming schemes:
    'mu_weight', 'rho_weight', 'mu_bias', 'rho_bias', or
    'mu_kernel', 'rho_kernel', 'mu_bias', 'rho_bias'.
    Only attribute presence is checked (via `hasattr`); values are not inspected.

    Args:
        layer (nn.Module): Module to classify.

    Returns:
        bool: True if either attribute set is fully present.
    """
    weight_attrs = ("mu_weight", "rho_weight", "mu_bias", "rho_bias")
    kernel_attrs = ("mu_kernel", "rho_kernel", "mu_bias", "rho_bias")
    # Schemes are tried in order; each stops at its first missing attribute.
    return any(
        all(hasattr(layer, attr_name) for attr_name in scheme)
        for scheme in (weight_attrs, kernel_attrs)
    )

Check whether the layer is a BayesianTorch-style layer. Such a layer has all of 'mu_weight', 'rho_weight', 'mu_bias' and 'rho_bias', or the 'kernel' equivalents ('mu_kernel' and 'rho_kernel' together with 'mu_bias' and 'rho_bias').

Arguments:
  • layer (nn.Module): A PyTorch module.
Returns:

bool: True if the layer has BayesianTorch parameter attributes.

def get_bayesian_torch_layers( model: torch.nn.modules.module.Module) -> Iterator[tuple[tuple[str, ...], torch.nn.modules.module.Module]]:
 91def get_bayesian_torch_layers(
 92    model: nn.Module,
 93) -> Iterator[tuple[LayerNameT, nn.Module]]:
 94    """
 95    Yield all layers in the model recognized as BayesianTorch layers,
 96    i.e., those containing 'mu_weight', 'rho_weight', etc.
 97
 98    Args:
 99        model (nn.Module): The PyTorch model to traverse.
100
101    Returns:
102        Iterator[Tuple[LayerNameT, nn.Module]]: (layer_name_path, layer_object) for each Bayesian layer.
103    """
104    return get_layers(model, is_bayesian_torch_layer, names=[])

Yield all layers in the model recognized as BayesianTorch layers, i.e., those having 'mu_weight' and 'rho_weight' (or 'mu_kernel' and 'rho_kernel') together with 'mu_bias' and 'rho_bias'.

Arguments:
  • model (nn.Module): The PyTorch model to traverse.
Returns:

Iterator[Tuple[LayerNameT, nn.Module]]: (layer_name_path, layer_object) for each Bayesian layer.