core.layer.utils
from collections.abc import Callable, Iterator

import torch.nn as nn

from core.layer import supported_layers

LayerNameT = tuple[str, ...]


def get_layers(
    model: nn.Module, is_layer_func: Callable[[nn.Module], bool], names: list[str]
) -> Iterator[tuple[LayerNameT, nn.Module]]:
    """
    Recursively traverse a PyTorch model to find layers matching a given criterion.

    The traversal is depth-first and post-order: every match found inside a
    child module is yielded before the child itself is tested. For each layer
    where `is_layer_func(layer)` is True, yield the path of layer names (as a
    tuple) and the layer object.

    Args:
        model (nn.Module): The PyTorch model or submodule to traverse.
        is_layer_func (Callable): A predicate function that returns True if a layer
            matches the criterion (e.g., belongs to a certain set of layer types).
        names (list[str]): Name path of `model` relative to the root of the
            traversal; pass `[]` when starting at the root. It is never mutated.

    Yields:
        tuple[LayerNameT, nn.Module]: Tuples of (layer_name_path, layer_object).
    """
    for name, layer in model.named_children():
        # `named_children()` already skips children registered as None, so no
        # None guard is needed before recursing or testing the predicate.
        path = [*names, name]
        yield from get_layers(layer, is_layer_func, path)
        if is_layer_func(layer):
            yield tuple(path), layer


def is_torch_layer(layer: nn.Module) -> bool:
    """
    Check if the given layer is a supported PyTorch layer in the framework.

    Args:
        layer (nn.Module): A PyTorch layer or module.

    Returns:
        bool: True if `layer` is an instance (subclasses included) of any type
        obtained by iterating `supported_layers.LAYER_MAPPING`.
    """
    return any(
        isinstance(layer, torch_layer)
        for torch_layer in supported_layers.LAYER_MAPPING
    )


def get_torch_layers(model: nn.Module) -> Iterator[tuple[LayerNameT, nn.Module]]:
    """
    Yield all supported PyTorch layers in the model.

    Args:
        model (nn.Module): The PyTorch model to traverse.

    Returns:
        Iterator[tuple[LayerNameT, nn.Module]]: Each tuple is (path_of_names, layer).
    """
    return get_layers(model, is_torch_layer, names=[])


def is_bayesian_torch_layer(layer: nn.Module) -> bool:
    """
    Check whether the layer is a BayesianTorch-style layer.

    A layer qualifies when it has both 'mu_bias' and 'rho_bias' together with
    either 'mu_weight'/'rho_weight' or their 'kernel' equivalents
    ('mu_kernel'/'rho_kernel').

    Args:
        layer (nn.Module): A PyTorch module.

    Returns:
        bool: True if the layer has BayesianTorch parameter attributes.
    """

    def has_all(*attrs: str) -> bool:
        # True only if every listed attribute is reachable on `layer`.
        return all(hasattr(layer, attr) for attr in attrs)

    return has_all("mu_weight", "rho_weight", "mu_bias", "rho_bias") or has_all(
        "mu_kernel", "rho_kernel", "mu_bias", "rho_bias"
    )


def get_bayesian_torch_layers(
    model: nn.Module,
) -> Iterator[tuple[LayerNameT, nn.Module]]:
    """
    Yield all layers in the model recognized as BayesianTorch layers,
    i.e., those accepted by `is_bayesian_torch_layer`.

    Args:
        model (nn.Module): The PyTorch model to traverse.

    Returns:
        Iterator[tuple[LayerNameT, nn.Module]]: (layer_name_path, layer_object) for each Bayesian layer.
    """
    return get_layers(model, is_bayesian_torch_layer, names=[])
def get_layers(
    model: nn.Module, is_layer_func: Callable[[nn.Module], bool], names: list[str]
) -> Iterator[tuple[LayerNameT, nn.Module]]:
    """
    Recursively traverse a PyTorch model to find layers matching a given criterion.

    The traversal is depth-first and post-order: every match found inside a
    child module is yielded before the child itself is tested. For each layer
    where `is_layer_func(layer)` is True, yield the path of layer names (as a
    tuple) and the layer object.

    Args:
        model (nn.Module): The PyTorch model or submodule to traverse.
        is_layer_func (Callable): A predicate function that returns True if a layer
            matches the criterion (e.g., belongs to a certain set of layer types).
        names (list[str]): Name path of `model` relative to the root of the
            traversal; pass `[]` when starting at the root. It is never mutated.

    Yields:
        tuple[LayerNameT, nn.Module]: Tuples of (layer_name_path, layer_object).
    """
    for name, layer in model.named_children():
        # `named_children()` already skips children registered as None, so no
        # None guard is needed before recursing or testing the predicate.
        path = [*names, name]
        yield from get_layers(layer, is_layer_func, path)
        if is_layer_func(layer):
            yield tuple(path), layer
Recursively traverse a PyTorch model to find layers matching a given criterion.
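This function performs a depth-first, post-order traversal of the children of model: matches found inside a submodule are yielded before the submodule itself is tested.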
If is_layer_func(layer) is True, yield the path of layer names (as a tuple)
and the layer object.
Arguments:
- model (nn.Module): The PyTorch model or submodule to traverse.
- is_layer_func (Callable): A predicate function that returns True if a layer matches the criterion (e.g., belongs to a certain set of layer types).
- names (List[str]): Accumulates the hierarchical names as we recurse.
Yields:
Iterator[Tuple[LayerNameT, nn.Module]]: Tuples of (layer_name_path, layer_object).
def is_torch_layer(layer: nn.Module) -> bool:
    """
    Tell whether `layer` is a PyTorch layer type the framework supports.

    Args:
        layer (nn.Module): A PyTorch layer or module.

    Returns:
        bool: True as soon as `layer` is an instance (subclasses included) of
        any entry obtained by iterating `supported_layers.LAYER_MAPPING`;
        False if none match.
    """
    for supported_type in supported_layers.LAYER_MAPPING:
        if isinstance(layer, supported_type):
            return True
    return False
Check if the given layer is a supported PyTorch layer in the framework.
Arguments:
- layer (nn.Module): A PyTorch layer or module.
Returns:
bool: True if the layer is an instance of any layer type in the framework's supported mappings (subclasses included).
def get_torch_layers(model: nn.Module) -> Iterator[tuple[LayerNameT, nn.Module]]:
    """
    Yield every supported PyTorch layer found in `model`.

    This is a thin wrapper around `get_layers`. It uses `is_torch_layer` as the
    predicate and starts from an empty name path.

    Args:
        model (nn.Module): The PyTorch model to traverse.

    Returns:
        Iterator[tuple[LayerNameT, nn.Module]]: (path_of_names, layer) pairs.
    """
    root_path: list[str] = []
    return get_layers(model, is_layer_func=is_torch_layer, names=root_path)
Yield all supported PyTorch layers in the model.
Arguments:
- model (nn.Module): The PyTorch model to traverse.
Returns:
Iterator[Tuple[LayerNameT, nn.Module]]: Each tuple is (path_of_names, layer).
def is_bayesian_torch_layer(layer: nn.Module) -> bool:
    """
    Decide whether `layer` looks like a BayesianTorch layer.

    A layer qualifies when it exposes 'mu_bias' and 'rho_bias' together with
    either 'mu_weight'/'rho_weight' or 'mu_kernel'/'rho_kernel'.

    Args:
        layer (nn.Module): A PyTorch module.

    Returns:
        bool: True if one of the two attribute sets is fully present.
    """

    def has_all(*attrs: str) -> bool:
        # Checked left to right, stopping at the first missing attribute.
        return all(hasattr(layer, attr) for attr in attrs)

    weight_style = has_all("mu_weight", "rho_weight", "mu_bias", "rho_bias")
    if weight_style:
        return True
    return has_all("mu_kernel", "rho_kernel", "mu_bias", "rho_bias")
Check whether the layer is a BayesianTorch-style layer, identified by having 'mu_bias' and 'rho_bias' together with either 'mu_weight' and 'rho_weight' or their 'kernel' equivalents ('mu_kernel' and 'rho_kernel').
Arguments:
- layer (nn.Module): A PyTorch module.
Returns:
bool: True if the layer has BayesianTorch parameter attributes.
def get_bayesian_torch_layers(
    model: nn.Module,
) -> Iterator[tuple[LayerNameT, nn.Module]]:
    """
    Yield every layer of `model` that `is_bayesian_torch_layer` accepts.

    This is a thin wrapper around `get_layers` that starts from an empty name
    path.

    Args:
        model (nn.Module): The PyTorch model to traverse.

    Returns:
        Iterator[tuple[LayerNameT, nn.Module]]: (layer_name_path, layer_object) for each Bayesian layer.
    """
    root_path: list[str] = []
    return get_layers(model, is_layer_func=is_bayesian_torch_layer, names=root_path)
Yield all layers in the model recognized as BayesianTorch layers, i.e., those containing 'mu_weight', 'rho_weight', etc.
Arguments:
- model (nn.Module): The PyTorch model to traverse.
Returns:
Iterator[Tuple[LayerNameT, nn.Module]]: (layer_name_path, layer_object) for each Bayesian layer.