core.split_strategy.FaultySplitStrategy
from dataclasses import dataclass

import numpy as np
import torch
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler

from core.split_strategy import PBPSplitStrategy


@dataclass
class FaultySplitStrategy(PBPSplitStrategy):
    """
    A specialized (and potentially buggy) subclass of PBPSplitStrategy that demonstrates
    alternative splitting logic or partial overlaps between dataset subsets.

    Fields:
        posterior_loader (DataLoader): DataLoader for posterior training.
        prior_loader (DataLoader): DataLoader for prior training.
        val_loader (DataLoader): DataLoader for validation set.
        test_loader (DataLoader): DataLoader for test set.
        test_1batch (DataLoader): DataLoader for test set (one big batch).
        bound_loader (DataLoader): DataLoader for bound evaluation set.
        bound_loader_1batch (DataLoader): DataLoader for bound evaluation set (one big batch).
    """

    # Posterior training
    posterior_loader: data.dataloader.DataLoader = None
    # Prior training
    prior_loader: data.dataloader.DataLoader = None
    # Evaluation
    val_loader: data.dataloader.DataLoader = None
    test_loader: data.dataloader.DataLoader = None
    test_1batch: data.dataloader.DataLoader = None
    # Bounds evaluation
    bound_loader: data.dataloader.DataLoader = None
    bound_loader_1batch: data.dataloader.DataLoader = None

    def __init__(
        self,
        prior_type: str,
        train_percent: float,
        val_percent: float,
        prior_percent: float,
        self_certified: bool,
    ):
        """
        Initialize the FaultySplitStrategy with user-defined percentages
        and flags for how to partition the dataset.

        Args:
            prior_type (str): A string indicating how the prior is handled (e.g. "not_learnt", "learnt").
            train_percent (float): Fraction of dataset to use for training.
            val_percent (float): Fraction of dataset to use for validation.
            prior_percent (float): Fraction of dataset to use for prior training.
            self_certified (bool): If True, indicates self-certified approach to data splitting.
        """
        # All state handling is delegated to the PBPSplitStrategy base class.
        super().__init__(
            prior_type, train_percent, val_percent, prior_percent, self_certified
        )

    def _split_not_learnt(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split the data for the case when the prior is not learned from data.

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The test dataset.
            split_config (Dict): A configuration dictionary containing keys like 'batch_size', 'seed', etc.
            loader_kwargs (Dict): Additional keyword arguments for DataLoader (e.g., num_workers).
        """
        batch_size = split_config["batch_size"]
        training_percent = self._train_percent
        val_percent = self._val_percent

        # NOTE(review): relies on the datasets exposing a `.data` attribute
        # (true for e.g. torchvision MNIST/CIFAR) — not all Datasets have one.
        train_size = len(train_dataset.data)
        test_size = len(test_dataset.data)

        indices = list(range(train_size))
        split = int(np.round((training_percent) * train_size))
        # Seed NumPy's global RNG so the shuffle below is reproducible per config.
        np.random.seed(split_config["seed"])
        np.random.shuffle(indices)

        if val_percent > 0.0:
            # compute number of data points
            # NOTE(review): this rebinds `indices` to an UNSHUFFLED range(split),
            # discarding the shuffle performed above — the train/val partition is
            # therefore the first `split` samples in original dataset order.
            # Likely one of this class's deliberate faults.
            indices = list(range(split))
            split_val = int(np.round((val_percent) * split))
            train_idx, val_idx = indices[split_val:], indices[:split_val]
        else:
            train_idx = indices[:split]
            val_idx = None

        train_sampler = SubsetRandomSampler(train_idx)
        # NOTE(review): constructed even when val_idx is None; it is only ever
        # used below when val_idx is non-empty, so the None never surfaces.
        val_sampler = SubsetRandomSampler(val_idx)

        self.posterior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            **loader_kwargs,
        )
        # self.prior_loader = None
        # Truthiness check: skipped when val_idx is None (no validation split)
        # or an empty list (val_percent rounded down to zero samples).
        if val_idx:
            # NOTE(review): unlike the other loaders, this one omits
            # **loader_kwargs — probably an inconsistency to confirm.
            self.val_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=val_sampler,
                shuffle=False,
            )
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
        )
        # Same test set delivered as a single full-size batch.
        self.test_1batch = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
        )
        # NOTE(review): bound evaluation reuses `train_sampler`, so the bound
        # set fully overlaps the posterior training set — this matches the
        # "partial overlaps" behavior advertised in the class docstring.
        self.bound_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            **loader_kwargs,
        )
        # Same bound set as one single batch (batch_size == number of indices).
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=len(train_idx),
            sampler=train_sampler,
            **loader_kwargs,
        )

    def _split_learnt_self_certified(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split logic when the prior is learned from data and we use a self-certified approach.

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The test dataset.
            split_config (Dict): A dictionary of split settings (batch size, seed, etc.).
            loader_kwargs (Dict): Keyword arguments for DataLoader initialization.
        """
        batch_size = split_config["batch_size"]
        training_percent = self._train_percent
        val_percent = self._val_percent
        prior_percent = self._prior_percent

        # Self-certified: train and test data are pooled into one population.
        n = len(train_dataset.data) + len(test_dataset.data)

        # reduce training data if needed
        new_num_train = int(np.round((training_percent) * n))
        indices = list(range(new_num_train))
        split = int(np.round((prior_percent) * new_num_train))
        np.random.seed(split_config["seed"])
        np.random.shuffle(indices)

        all_train_sampler = SubsetRandomSampler(indices)
        if val_percent > 0.0:
            bound_idx = indices[split:]
            # NOTE(review): prior/val indices come from an UNSHUFFLED
            # range(split), not from the shuffled `indices`, so they can
            # overlap `bound_idx` — consistent with this class's deliberate
            # overlap demonstrations.
            indices_prior = list(range(split))
            # NOTE(review): constructed but never used.
            _all_prior_sampler = SubsetRandomSampler(indices_prior)
            split_val = int(np.round((val_percent) * split))
            prior_idx, val_idx = indices_prior[split_val:], indices_prior[:split_val]
        else:
            bound_idx, prior_idx = indices[split:], indices[:split]
            val_idx = None

        bound_sampler = SubsetRandomSampler(bound_idx)
        prior_sampler = SubsetRandomSampler(prior_idx)
        # NOTE(review): may receive None; only used when val_idx is non-empty.
        val_sampler = SubsetRandomSampler(val_idx)

        # Combined pool that all samplers above index into.
        final_dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])

        self.posterior_loader = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=batch_size,
            sampler=all_train_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        self.prior_loader = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=batch_size,
            sampler=prior_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        if val_idx:
            self.val_loader = torch.utils.data.DataLoader(
                final_dataset,
                batch_size=batch_size,
                sampler=val_sampler,
                shuffle=False,
                **loader_kwargs,
            )
        # No held-out test set in the self-certified setting.
        # self.test_loader = None
        # self.test_1batch = None
        self.bound_loader = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=batch_size,
            sampler=bound_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        # Same bound set as one single batch.
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=len(bound_idx),
            sampler=bound_sampler,
            **loader_kwargs,
        )

    def _split_learnt_not_self_certified(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split logic for a learned prior without self-certification.

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The test dataset.
            split_config (Dict): Dictionary with split hyperparameters (batch size, seed, etc.).
            loader_kwargs (Dict): Additional arguments for torch.utils.data.DataLoader.
        """
        batch_size = split_config["batch_size"]
        training_percent = self._train_percent
        val_percent = self._val_percent
        prior_percent = self._prior_percent

        # NOTE(review): relies on datasets exposing `.data` (see _split_not_learnt).
        train_size = len(train_dataset.data)
        test_size = len(test_dataset.data)

        new_num_train = int(np.round((training_percent) * train_size))
        indices = list(range(new_num_train))
        split = int(np.round((prior_percent) * new_num_train))
        np.random.seed(split_config["seed"])
        np.random.shuffle(indices)

        all_train_sampler = SubsetRandomSampler(indices)
        # train_idx, valid_idx = indices[split:], indices[:split]
        if val_percent > 0.0:
            bound_idx = indices[split:]
            # NOTE(review): unshuffled range(split) — prior/val indices can
            # overlap bound_idx (see _split_learnt_self_certified).
            indices_prior = list(range(split))
            # NOTE(review): constructed but never used.
            _all_prior_sampler = SubsetRandomSampler(indices_prior)
            split_val = int(np.round((val_percent) * split))
            prior_idx, val_idx = indices_prior[split_val:], indices_prior[:split_val]
        else:
            bound_idx, prior_idx = indices[split:], indices[:split]
            val_idx = None

        bound_sampler = SubsetRandomSampler(bound_idx)
        prior_sampler = SubsetRandomSampler(prior_idx)
        # NOTE(review): may receive None; only used when val_idx is non-empty.
        val_sampler = SubsetRandomSampler(val_idx)

        self.posterior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=all_train_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        self.prior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=prior_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        if val_idx:
            self.val_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=val_sampler,
                shuffle=False,
                **loader_kwargs,
            )
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
        )
        # Full test set as a single batch.
        self.test_1batch = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
        )
        # NOTE(review): unlike the sibling loaders, this one omits
        # **loader_kwargs — probably an inconsistency to confirm.
        self.bound_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=bound_sampler,
            shuffle=False,
        )
        # Same bound set as one single batch.
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=len(bound_idx),
            sampler=bound_sampler,
            **loader_kwargs,
        )
@dataclass
class FaultySplitStrategy(PBPSplitStrategy):
    """
    A specialized (and potentially buggy) subclass of PBPSplitStrategy that demonstrates
    alternative splitting logic or partial overlaps between dataset subsets.

    Fields:
        posterior_loader (DataLoader): DataLoader for posterior training.
        prior_loader (DataLoader): DataLoader for prior training.
        val_loader (DataLoader): DataLoader for validation set.
        test_loader (DataLoader): DataLoader for test set.
        test_1batch (DataLoader): DataLoader for test set (one big batch).
        bound_loader (DataLoader): DataLoader for bound evaluation set.
        bound_loader_1batch (DataLoader): DataLoader for bound evaluation set (one big batch).
    """

    # Posterior training
    posterior_loader: data.dataloader.DataLoader = None
    # Prior training
    prior_loader: data.dataloader.DataLoader = None
    # Evaluation
    val_loader: data.dataloader.DataLoader = None
    test_loader: data.dataloader.DataLoader = None
    test_1batch: data.dataloader.DataLoader = None
    # Bounds evaluation
    bound_loader: data.dataloader.DataLoader = None
    bound_loader_1batch: data.dataloader.DataLoader = None

    def __init__(
        self,
        prior_type: str,
        train_percent: float,
        val_percent: float,
        prior_percent: float,
        self_certified: bool,
    ):
        """
        Initialize the FaultySplitStrategy with user-defined percentages
        and flags for how to partition the dataset.

        Args:
            prior_type (str): A string indicating how the prior is handled (e.g. "not_learnt", "learnt").
            train_percent (float): Fraction of dataset to use for training.
            val_percent (float): Fraction of dataset to use for validation.
            prior_percent (float): Fraction of dataset to use for prior training.
            self_certified (bool): If True, indicates self-certified approach to data splitting.
        """
        # All state handling is delegated to the PBPSplitStrategy base class.
        super().__init__(
            prior_type, train_percent, val_percent, prior_percent, self_certified
        )

    def _split_not_learnt(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split the data for the case when the prior is not learned from data.

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The test dataset.
            split_config (Dict): A configuration dictionary containing keys like 'batch_size', 'seed', etc.
            loader_kwargs (Dict): Additional keyword arguments for DataLoader (e.g., num_workers).
        """
        batch_size = split_config["batch_size"]
        training_percent = self._train_percent
        val_percent = self._val_percent

        # NOTE(review): relies on the datasets exposing a `.data` attribute
        # (true for e.g. torchvision MNIST/CIFAR) — not all Datasets have one.
        train_size = len(train_dataset.data)
        test_size = len(test_dataset.data)

        indices = list(range(train_size))
        split = int(np.round((training_percent) * train_size))
        # Seed NumPy's global RNG so the shuffle below is reproducible per config.
        np.random.seed(split_config["seed"])
        np.random.shuffle(indices)

        if val_percent > 0.0:
            # compute number of data points
            # NOTE(review): this rebinds `indices` to an UNSHUFFLED range(split),
            # discarding the shuffle above — the train/val partition is the
            # first `split` samples in original dataset order. Likely one of
            # this class's deliberate faults.
            indices = list(range(split))
            split_val = int(np.round((val_percent) * split))
            train_idx, val_idx = indices[split_val:], indices[:split_val]
        else:
            train_idx = indices[:split]
            val_idx = None

        train_sampler = SubsetRandomSampler(train_idx)
        # NOTE(review): constructed even when val_idx is None; it is only ever
        # used below when val_idx is non-empty, so the None never surfaces.
        val_sampler = SubsetRandomSampler(val_idx)

        self.posterior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            **loader_kwargs,
        )
        # self.prior_loader = None
        # Truthiness check: skipped when val_idx is None or an empty list.
        if val_idx:
            # NOTE(review): omits **loader_kwargs, unlike the sibling loaders.
            self.val_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=val_sampler,
                shuffle=False,
            )
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
        )
        # Same test set delivered as a single full-size batch.
        self.test_1batch = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
        )
        # NOTE(review): bound evaluation reuses `train_sampler`, so the bound
        # set fully overlaps the posterior training set — matching the
        # "partial overlaps" behavior advertised in the class docstring.
        self.bound_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            **loader_kwargs,
        )
        # Same bound set as one single batch.
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=len(train_idx),
            sampler=train_sampler,
            **loader_kwargs,
        )

    def _split_learnt_self_certified(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split logic when the prior is learned from data and we use a self-certified approach.

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The test dataset.
            split_config (Dict): A dictionary of split settings (batch size, seed, etc.).
            loader_kwargs (Dict): Keyword arguments for DataLoader initialization.
        """
        batch_size = split_config["batch_size"]
        training_percent = self._train_percent
        val_percent = self._val_percent
        prior_percent = self._prior_percent

        # Self-certified: train and test data are pooled into one population.
        n = len(train_dataset.data) + len(test_dataset.data)

        # reduce training data if needed
        new_num_train = int(np.round((training_percent) * n))
        indices = list(range(new_num_train))
        split = int(np.round((prior_percent) * new_num_train))
        np.random.seed(split_config["seed"])
        np.random.shuffle(indices)

        all_train_sampler = SubsetRandomSampler(indices)
        if val_percent > 0.0:
            bound_idx = indices[split:]
            # NOTE(review): prior/val indices come from an UNSHUFFLED
            # range(split), not from the shuffled `indices`, so they can
            # overlap `bound_idx`.
            indices_prior = list(range(split))
            # NOTE(review): constructed but never used.
            _all_prior_sampler = SubsetRandomSampler(indices_prior)
            split_val = int(np.round((val_percent) * split))
            prior_idx, val_idx = indices_prior[split_val:], indices_prior[:split_val]
        else:
            bound_idx, prior_idx = indices[split:], indices[:split]
            val_idx = None

        bound_sampler = SubsetRandomSampler(bound_idx)
        prior_sampler = SubsetRandomSampler(prior_idx)
        # NOTE(review): may receive None; only used when val_idx is non-empty.
        val_sampler = SubsetRandomSampler(val_idx)

        # Combined pool that all samplers above index into.
        final_dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])

        self.posterior_loader = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=batch_size,
            sampler=all_train_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        self.prior_loader = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=batch_size,
            sampler=prior_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        if val_idx:
            self.val_loader = torch.utils.data.DataLoader(
                final_dataset,
                batch_size=batch_size,
                sampler=val_sampler,
                shuffle=False,
                **loader_kwargs,
            )
        # No held-out test set in the self-certified setting.
        # self.test_loader = None
        # self.test_1batch = None
        self.bound_loader = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=batch_size,
            sampler=bound_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        # Same bound set as one single batch.
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=len(bound_idx),
            sampler=bound_sampler,
            **loader_kwargs,
        )

    def _split_learnt_not_self_certified(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split logic for a learned prior without self-certification.

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The test dataset.
            split_config (Dict): Dictionary with split hyperparameters (batch size, seed, etc.).
            loader_kwargs (Dict): Additional arguments for torch.utils.data.DataLoader.
        """
        batch_size = split_config["batch_size"]
        training_percent = self._train_percent
        val_percent = self._val_percent
        prior_percent = self._prior_percent

        # NOTE(review): relies on datasets exposing `.data` (see _split_not_learnt).
        train_size = len(train_dataset.data)
        test_size = len(test_dataset.data)

        new_num_train = int(np.round((training_percent) * train_size))
        indices = list(range(new_num_train))
        split = int(np.round((prior_percent) * new_num_train))
        np.random.seed(split_config["seed"])
        np.random.shuffle(indices)

        all_train_sampler = SubsetRandomSampler(indices)
        # train_idx, valid_idx = indices[split:], indices[:split]
        if val_percent > 0.0:
            bound_idx = indices[split:]
            # NOTE(review): unshuffled range(split) — prior/val indices can
            # overlap bound_idx (see _split_learnt_self_certified).
            indices_prior = list(range(split))
            # NOTE(review): constructed but never used.
            _all_prior_sampler = SubsetRandomSampler(indices_prior)
            split_val = int(np.round((val_percent) * split))
            prior_idx, val_idx = indices_prior[split_val:], indices_prior[:split_val]
        else:
            bound_idx, prior_idx = indices[split:], indices[:split]
            val_idx = None

        bound_sampler = SubsetRandomSampler(bound_idx)
        prior_sampler = SubsetRandomSampler(prior_idx)
        # NOTE(review): may receive None; only used when val_idx is non-empty.
        val_sampler = SubsetRandomSampler(val_idx)

        self.posterior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=all_train_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        self.prior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=prior_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        if val_idx:
            self.val_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=val_sampler,
                shuffle=False,
                **loader_kwargs,
            )
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
        )
        # Full test set as a single batch.
        self.test_1batch = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
        )
        # NOTE(review): unlike the sibling loaders, this one omits
        # **loader_kwargs — probably an inconsistency to confirm.
        self.bound_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=bound_sampler,
            shuffle=False,
        )
        # Same bound set as one single batch.
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=len(bound_idx),
            sampler=bound_sampler,
            **loader_kwargs,
        )
A specialized (and potentially buggy) subclass of PBPSplitStrategy that demonstrates alternative splitting logic or partial overlaps between dataset subsets.
Fields:
posterior_loader (DataLoader): DataLoader for posterior training. prior_loader (DataLoader): DataLoader for prior training. val_loader (DataLoader): DataLoader for validation set. test_loader (DataLoader): DataLoader for test set. test_1batch (DataLoader): DataLoader for test set (one big batch). bound_loader (DataLoader): DataLoader for bound evaluation set. bound_loader_1batch (DataLoader): DataLoader for bound evaluation set (one big batch).
FaultySplitStrategy( prior_type: str, train_percent: float, val_percent: float, prior_percent: float, self_certified: bool)
def __init__(
    self,
    prior_type: str,
    train_percent: float,
    val_percent: float,
    prior_percent: float,
    self_certified: bool,
):
    """
    Initialize the FaultySplitStrategy with user-defined percentages
    and flags for how to partition the dataset.

    Args:
        prior_type (str): A string indicating how the prior is handled (e.g. "not_learnt", "learnt").
        train_percent (float): Fraction of dataset to use for training.
        val_percent (float): Fraction of dataset to use for validation.
        prior_percent (float): Fraction of dataset to use for prior training.
        self_certified (bool): If True, indicates self-certified approach to data splitting.
    """
    # All state handling is delegated to the PBPSplitStrategy base class.
    super().__init__(
        prior_type, train_percent, val_percent, prior_percent, self_certified
    )
Initialize the FaultySplitStrategy with user-defined percentages and flags for how to partition the dataset.
Arguments:
- prior_type (str): A string indicating how the prior is handled (e.g. "not_learnt", "learnt").
- train_percent (float): Fraction of dataset to use for training.
- val_percent (float): Fraction of dataset to use for validation.
- prior_percent (float): Fraction of dataset to use for prior training.
- self_certified (bool): If True, indicates self-certified approach to data splitting.