core.split_strategy.PBPSplitStrategy
from dataclasses import dataclass
from typing import Union

import numpy as np
import torch
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler

from core.split_strategy import AbstractSplitStrategy


@dataclass
class PBPSplitStrategy(AbstractSplitStrategy):
    """
    A split strategy implementing a Prior-Posterior-Bound (PBP) partition of the dataset.

    This strategy supports data splits for:
        - Posterior training
        - Prior training
        - Validation
        - Test
        - Bound evaluation (data for calculating PAC-Bayes bounds)

    By changing the internal splitting methods, one can adapt different scenarios such as:
        - 'not_learnt': The prior is not trained.
        - 'learnt': The prior is trained on some portion of the data.
        - 'learnt_with_test': Similar to 'learnt', but includes an explicit test subset.
    """

    # Posterior training
    posterior_loader: data.dataloader.DataLoader = None
    # Prior training
    prior_loader: data.dataloader.DataLoader = None
    # Evaluation
    val_loader: data.dataloader.DataLoader = None
    test_loader: data.dataloader.DataLoader = None
    test_1batch: data.dataloader.DataLoader = None
    # Bounds evaluation
    bound_loader: data.dataloader.DataLoader = None
    bound_loader_1batch: data.dataloader.DataLoader = None

    def __init__(
        self,
        prior_type: str,
        train_percent: float,
        val_percent: float,
        prior_percent: float,
        self_certified: bool,
    ):
        """
        Initialize the PBPSplitStrategy with user-defined parameters for how to partition the data.

        Args:
            prior_type (str): Indicates whether the prior is "not_learnt", "learnt",
                or "learnt_with_test".
            train_percent (float): Proportion of data used for training.
            val_percent (float): Proportion of data used for validation.
            prior_percent (float): Proportion of data used specifically for training the prior.
            self_certified (bool): If True, indicates self-certified splitting approach.
        """
        self._prior_type = prior_type
        self._train_percent = train_percent
        self._val_percent = val_percent
        self._prior_percent = prior_percent
        self._self_certified = self_certified

    @staticmethod
    def _shuffled_indices(size: int, seed: int) -> list:
        """Return the indices [0, size) shuffled deterministically with *seed*.

        NOTE(review): this seeds NumPy's *global* RNG, as the original code did;
        any later np.random call in the process is affected.
        """
        indices = list(range(size))
        np.random.seed(seed)
        np.random.shuffle(indices)
        return indices

    def _prior_bound_val_split(self, train_indices: list, total_split: int) -> tuple:
        """
        Partition *train_indices* into (bound_idx, prior_idx, val_idx).

        The first `prior_percent * total_split` indices feed the prior (optionally
        carving off `val_percent` of them for validation); the remainder is used
        for bound evaluation. `val_idx` is None when no validation split is requested.
        """
        prior_split = int(np.round(self._prior_percent * total_split))
        bound_idx = train_indices[prior_split:]
        prior_val_idx = train_indices[:prior_split]
        if self._val_percent > 0.0:
            val_split = int(np.round(self._val_percent * prior_split))
            prior_idx = prior_val_idx[val_split:]
            val_idx = prior_val_idx[:val_split]
        else:
            prior_idx = prior_val_idx
            val_idx = None
        return bound_idx, prior_idx, val_idx

    def _split_not_learnt(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split data for the scenario where the prior is not learned (e.g., a fixed prior).

        The posterior and the bound are evaluated on the same (shuffled) training
        subset; the test set is used as-is.

        Args:
            train_dataset (Dataset): The dataset for training and possibly validation.
            test_dataset (Dataset): The dataset for testing.
            split_config (Dict): Dictionary with keys like 'batch_size', 'seed', etc.
            loader_kwargs (Dict): Extra keyword arguments for DataLoader initialization.
        """
        batch_size = split_config["batch_size"]
        seed = split_config["seed"]

        train_size = len(train_dataset.data)
        test_size = len(test_dataset.data)
        train_indices = self._shuffled_indices(train_size, seed)

        # Take a fraction of the training dataset.
        training_split = int(np.round(self._train_percent * train_size))
        train_indices = train_indices[:training_split]
        if self._val_percent > 0.0:
            val_split = int(np.round(self._val_percent * training_split))
            train_idx = train_indices[val_split:]
            val_idx = train_indices[:val_split]
        else:
            train_idx = train_indices
            val_idx = None

        train_sampler = SubsetRandomSampler(train_idx)

        self.posterior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            **loader_kwargs,
        )
        # No prior loader in this scenario: the prior is fixed, not trained.
        if val_idx:
            # Fix: pass loader_kwargs here too, consistent with every other loader.
            self.val_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=SubsetRandomSampler(val_idx),
                **loader_kwargs,
            )
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
        )
        self.test_1batch = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
        )
        self.bound_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            **loader_kwargs,
        )
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=len(train_idx),
            sampler=train_sampler,
            **loader_kwargs,
        )

    def _split_learnt_self_certified(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split data when the prior is learned and we use a self-certified approach
        (train and test data are concatenated and re-partitioned; no test loaders).

        Args:
            train_dataset (Dataset): Training dataset.
            test_dataset (Dataset): Test dataset.
            split_config (Dict): Contains config like 'batch_size', 'seed'.
            loader_kwargs (Dict): Extra arguments for DataLoader.
        """
        batch_size = split_config["batch_size"]
        seed = split_config["seed"]

        train_test_dataset = torch.utils.data.ConcatDataset(
            [train_dataset, test_dataset]
        )
        train_test_size = len(train_dataset.data) + len(test_dataset.data)
        train_indices = self._shuffled_indices(train_test_size, seed)

        # Take a fraction of the combined dataset.
        training_test_split = int(np.round(self._train_percent * train_test_size))
        train_indices = train_indices[:training_test_split]

        bound_idx, prior_idx, val_idx = self._prior_bound_val_split(
            train_indices, training_test_split
        )

        train_test_sampler = SubsetRandomSampler(train_indices)
        bound_sampler = SubsetRandomSampler(bound_idx)
        prior_sampler = SubsetRandomSampler(prior_idx)

        self.posterior_loader = torch.utils.data.DataLoader(
            train_test_dataset,
            batch_size=batch_size,
            sampler=train_test_sampler,
            **loader_kwargs,
        )
        self.prior_loader = torch.utils.data.DataLoader(
            train_test_dataset,
            batch_size=batch_size,
            sampler=prior_sampler,
            **loader_kwargs,
        )
        if val_idx:
            self.val_loader = torch.utils.data.DataLoader(
                train_test_dataset,
                batch_size=batch_size,
                sampler=SubsetRandomSampler(val_idx),
                **loader_kwargs,
            )
        # self-certified without an explicit test set: test loaders stay None.
        self.bound_loader = torch.utils.data.DataLoader(
            train_test_dataset,
            batch_size=batch_size,
            sampler=bound_sampler,
            **loader_kwargs,
        )
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            train_test_dataset,
            batch_size=len(bound_idx),
            sampler=bound_sampler,
            **loader_kwargs,
        )

    def _split_learnt_self_certified_with_test(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Similar to `_split_learnt_self_certified`, but explicitly keeps a separate
        test subset (the indices left over after taking the training fraction).

        Args:
            train_dataset (Dataset): Training portion of the data.
            test_dataset (Dataset): Test portion of the data.
            split_config (Dict): Contains parameters such as 'batch_size', 'seed', etc.
            loader_kwargs (Dict): Extra DataLoader arguments.
        """
        batch_size = split_config["batch_size"]
        seed = split_config["seed"]

        train_test_dataset = torch.utils.data.ConcatDataset(
            [train_dataset, test_dataset]
        )
        train_test_size = len(train_dataset.data) + len(test_dataset.data)
        train_test_indices = self._shuffled_indices(train_test_size, seed)

        # Take a fraction of the combined dataset; the remainder becomes the test split.
        training_test_split = int(np.round(self._train_percent * train_test_size))
        train_indices = train_test_indices[:training_test_split]
        test_indices = train_test_indices[training_test_split:]

        bound_idx, prior_idx, val_idx = self._prior_bound_val_split(
            train_indices, training_test_split
        )

        train_sampler = SubsetRandomSampler(train_indices)
        bound_sampler = SubsetRandomSampler(bound_idx)
        prior_sampler = SubsetRandomSampler(prior_idx)
        test_sampler = SubsetRandomSampler(test_indices)

        self.posterior_loader = torch.utils.data.DataLoader(
            train_test_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            **loader_kwargs,
        )
        self.prior_loader = torch.utils.data.DataLoader(
            train_test_dataset,
            batch_size=batch_size,
            sampler=prior_sampler,
            **loader_kwargs,
        )
        if val_idx:
            self.val_loader = torch.utils.data.DataLoader(
                train_test_dataset,
                batch_size=batch_size,
                sampler=SubsetRandomSampler(val_idx),
                **loader_kwargs,
            )
        if len(test_indices) > 0:
            self.test_loader = torch.utils.data.DataLoader(
                train_test_dataset,
                batch_size=batch_size,
                sampler=test_sampler,
                **loader_kwargs,
            )
            # Fix: the declared dataclass field is `test_1batch`; the original
            # assigned to an undeclared `test_loader_1batch` attribute.
            self.test_1batch = torch.utils.data.DataLoader(
                train_test_dataset,
                batch_size=len(test_indices),
                sampler=test_sampler,
                **loader_kwargs,
            )
        else:
            self.test_loader = None
            self.test_1batch = None
        self.bound_loader = torch.utils.data.DataLoader(
            train_test_dataset,
            batch_size=batch_size,
            sampler=bound_sampler,
            **loader_kwargs,
        )
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            train_test_dataset,
            batch_size=len(bound_idx),
            sampler=bound_sampler,
            **loader_kwargs,
        )

    def _split_learnt_not_self_certified(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split data when the prior is learned but not using a self-certified approach:
        only the training dataset is partitioned; the test dataset is used as-is.

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The testing dataset.
            split_config (Dict): Dictionary with keys (e.g., batch_size, seed, etc.).
            loader_kwargs (Dict): Additional params for DataLoader creation.
        """
        batch_size = split_config["batch_size"]
        seed = split_config["seed"]

        train_size = len(train_dataset.data)
        test_size = len(test_dataset.data)
        # NOTE: shuffling here fixes *which* indices land in each split;
        # SubsetRandomSampler additionally shuffles the order within each split.
        train_indices = self._shuffled_indices(train_size, seed)

        training_split = int(np.round(self._train_percent * train_size))
        train_indices = train_indices[:training_split]

        bound_idx, prior_idx, val_idx = self._prior_bound_val_split(
            train_indices, training_split
        )

        train_sampler = SubsetRandomSampler(train_indices)
        bound_sampler = SubsetRandomSampler(bound_idx)
        prior_sampler = SubsetRandomSampler(prior_idx)

        self.posterior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            **loader_kwargs,
        )
        self.prior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=prior_sampler,
            **loader_kwargs,
        )
        if val_idx:
            self.val_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=SubsetRandomSampler(val_idx),
                **loader_kwargs,
            )
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
        )
        self.test_1batch = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
        )
        # Fix: pass loader_kwargs here too, consistent with every other loader.
        self.bound_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=bound_sampler,
            **loader_kwargs,
        )
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=len(bound_idx),
            sampler=bound_sampler,
            **loader_kwargs,
        )

    def split(
        self, dataset_loader: Union["MNISTLoader", "CIFAR10Loader"], split_config: dict
    ) -> None:
        """
        Public method to perform the split operation on a dataset loader,
        setting up DataLoaders for prior, posterior, validation, testing,
        and bound evaluation.

        Args:
            dataset_loader (Union[MNISTLoader, CIFAR10Loader]): A dataset loader instance
                providing `load(dataset_loader_seed)` to retrieve train/test datasets.
            split_config (Dict): Configuration parameters for splitting
                (e.g., batch_size, seed).

        Raises:
            ValueError: If `prior_type` is not one of "not_learnt", "learnt",
                "learnt_with_test", or if "learnt_with_test" is requested without
                self-certification.
        """
        dataset_loader_seed = split_config["dataset_loader_seed"]
        train_dataset, test_dataset = dataset_loader.load(dataset_loader_seed)

        loader_kwargs = (
            {"num_workers": 1, "pin_memory": True} if torch.cuda.is_available() else {}
        )

        if self._prior_type == "not_learnt":
            self._split_not_learnt(
                train_dataset=train_dataset,
                test_dataset=test_dataset,
                split_config=split_config,
                loader_kwargs=loader_kwargs,
            )
        elif self._prior_type == "learnt":
            if self._self_certified:
                self._split_learnt_self_certified(
                    train_dataset=train_dataset,
                    test_dataset=test_dataset,
                    split_config=split_config,
                    loader_kwargs=loader_kwargs,
                )
            else:
                self._split_learnt_not_self_certified(
                    train_dataset=train_dataset,
                    test_dataset=test_dataset,
                    split_config=split_config,
                    loader_kwargs=loader_kwargs,
                )
        elif self._prior_type == "learnt_with_test":
            if self._self_certified:
                self._split_learnt_self_certified_with_test(
                    train_dataset=train_dataset,
                    test_dataset=test_dataset,
                    split_config=split_config,
                    loader_kwargs=loader_kwargs,
                )
            else:
                raise ValueError(f"Invalid prior_type: {self._prior_type}")
        else:
            raise ValueError(f"Invalid prior_type: {self._prior_type}")
13@dataclass 14class PBPSplitStrategy(AbstractSplitStrategy): 15 """ 16 A split strategy implementing a Prior-Posterior-Bound (PBP) partition of the dataset. 17 18 This strategy supports data splits for: 19 - Posterior training 20 - Prior training 21 - Validation 22 - Test 23 - Bound evaluation (data for calculating PAC-Bayes bounds) 24 25 By changing the internal splitting methods, one can adapt different scenarios such as: 26 - 'not_learnt': The prior is not trained. 27 - 'learnt': The prior is trained on some portion of the data. 28 - 'learnt_with_test': Similar to 'learnt', but includes an explicit test subset. 29 """ 30 31 # Posterior training 32 posterior_loader: data.dataloader.DataLoader = None 33 # Prior training 34 prior_loader: data.dataloader.DataLoader = None 35 # Evaluation 36 val_loader: data.dataloader.DataLoader = None 37 test_loader: data.dataloader.DataLoader = None 38 test_1batch: data.dataloader.DataLoader = None 39 # Bounds evaluation 40 bound_loader: data.dataloader.DataLoader = None 41 bound_loader_1batch: data.dataloader.DataLoader = None 42 43 def __init__( 44 self, 45 prior_type: str, 46 train_percent: float, 47 val_percent: float, 48 prior_percent: float, 49 self_certified: bool, 50 ): 51 """ 52 Initialize the PBPSplitStrategy with user-defined parameters for how to partition the data. 53 54 Args: 55 prior_type (str): Indicates whether the prior is "not_learnt", "learnt", or "learnt_with_test". 56 train_percent (float): Proportion of data used for training. 57 val_percent (float): Proportion of data used for validation. 58 prior_percent (float): Proportion of data used specifically for training the prior. 59 self_certified (bool): If True, indicates self-certified splitting approach. 
60 """ 61 self._prior_type = prior_type 62 self._train_percent = train_percent 63 self._val_percent = val_percent 64 self._prior_percent = prior_percent 65 self._self_certified = self_certified 66 67 def _split_not_learnt( 68 self, 69 train_dataset: data.Dataset, 70 test_dataset: data.Dataset, 71 split_config: dict, 72 loader_kwargs: dict, 73 ) -> None: 74 """ 75 Split data for the scenario where the prior is not learned (e.g., a fixed prior). 76 77 Args: 78 train_dataset (Dataset): The dataset for training and possibly validation. 79 test_dataset (Dataset): The dataset for testing. 80 split_config (Dict): Dictionary with keys like 'batch_size', 'seed', etc. 81 loader_kwargs (Dict): Extra keyword arguments for DataLoader initialization. 82 """ 83 batch_size = split_config["batch_size"] 84 training_percent = self._train_percent 85 val_percent = self._val_percent 86 seed = split_config["seed"] 87 88 train_size = len(train_dataset.data) 89 test_size = len(test_dataset.data) 90 train_indices = list(range(train_size)) 91 np.random.seed(seed) 92 np.random.shuffle(train_indices) 93 94 # take fraction of a training dataset 95 training_split = int(np.round(training_percent * train_size)) 96 train_indices = train_indices[:training_split] 97 if val_percent > 0.0: 98 val_split = int(np.round(val_percent * training_split)) 99 train_idx = train_indices[val_split:] 100 val_idx = train_indices[:val_split] 101 else: 102 train_idx = train_indices 103 val_idx = None 104 105 train_sampler = SubsetRandomSampler(train_idx) 106 val_sampler = SubsetRandomSampler(val_idx) 107 108 self.posterior_loader = torch.utils.data.DataLoader( 109 train_dataset, 110 batch_size=batch_size, 111 sampler=train_sampler, 112 **loader_kwargs, 113 ) 114 # self.prior_loader = None 115 if val_idx: 116 self.val_loader = torch.utils.data.DataLoader( 117 train_dataset, 118 batch_size=batch_size, 119 sampler=val_sampler, 120 shuffle=False, 121 ) 122 self.test_loader = torch.utils.data.DataLoader( 123 test_dataset, 
batch_size=batch_size, shuffle=True, **loader_kwargs 124 ) 125 self.test_1batch = torch.utils.data.DataLoader( 126 test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs 127 ) 128 self.bound_loader = torch.utils.data.DataLoader( 129 train_dataset, 130 batch_size=batch_size, 131 sampler=train_sampler, 132 **loader_kwargs, 133 ) 134 self.bound_loader_1batch = torch.utils.data.DataLoader( 135 train_dataset, 136 batch_size=len(train_idx), 137 sampler=train_sampler, 138 **loader_kwargs, 139 ) 140 141 def _split_learnt_self_certified( 142 self, 143 train_dataset: data.Dataset, 144 test_dataset: data.Dataset, 145 split_config: dict, 146 loader_kwargs: dict, 147 ) -> None: 148 """ 149 Split data when the prior is learned and we use a self-certified approach (all data combined). 150 151 Args: 152 train_dataset (Dataset): Training dataset. 153 test_dataset (Dataset): Test dataset. 154 split_config (Dict): Contains config like 'batch_size', 'seed'. 155 loader_kwargs (Dict): Extra arguments for DataLoader. 
156 """ 157 batch_size = split_config["batch_size"] 158 training_percent = self._train_percent 159 val_percent = self._val_percent 160 prior_percent = self._prior_percent 161 seed = split_config["seed"] 162 163 train_test_dataset = torch.utils.data.ConcatDataset( 164 [train_dataset, test_dataset] 165 ) 166 train_test_size = len(train_dataset.data) + len(test_dataset.data) 167 train_indices = list(range(train_test_size)) 168 np.random.seed(seed) 169 np.random.shuffle(train_indices) 170 # take fraction of a training dataset 171 training_test_split = int(np.round(training_percent * train_test_size)) 172 train_indices = train_indices[:training_test_split] 173 174 if val_percent > 0.0: 175 prior_split = int(np.round(prior_percent * training_test_split)) 176 bound_idx, prior_val_idx = ( 177 train_indices[prior_split:], 178 train_indices[:prior_split], 179 ) 180 val_split = int(np.round(val_percent * prior_split)) 181 prior_idx, val_idx = ( 182 prior_val_idx[val_split:], 183 prior_val_idx[:val_split], 184 ) 185 else: 186 prior_split = int(np.round(prior_percent * training_test_split)) 187 bound_idx, prior_idx = ( 188 train_indices[prior_split:], 189 train_indices[:prior_split], 190 ) 191 val_idx = None 192 193 train_test_sampler = SubsetRandomSampler(train_indices) 194 bound_sampler = SubsetRandomSampler(bound_idx) 195 prior_sampler = SubsetRandomSampler(prior_idx) 196 val_sampler = SubsetRandomSampler(val_idx) 197 198 self.posterior_loader = torch.utils.data.DataLoader( 199 train_test_dataset, 200 batch_size=batch_size, 201 sampler=train_test_sampler, 202 shuffle=False, 203 **loader_kwargs, 204 ) 205 self.prior_loader = torch.utils.data.DataLoader( 206 train_test_dataset, 207 batch_size=batch_size, 208 sampler=prior_sampler, 209 shuffle=False, 210 **loader_kwargs, 211 ) 212 if val_idx: 213 self.val_loader = torch.utils.data.DataLoader( 214 train_test_dataset, 215 batch_size=batch_size, 216 sampler=val_sampler, 217 shuffle=False, 218 **loader_kwargs, 219 ) 220 # 
self.test_loader = None 221 # self.test_1batch = None 222 self.bound_loader = torch.utils.data.DataLoader( 223 train_test_dataset, 224 batch_size=batch_size, 225 sampler=bound_sampler, 226 shuffle=False, 227 **loader_kwargs, 228 ) 229 self.bound_loader_1batch = torch.utils.data.DataLoader( 230 train_test_dataset, 231 batch_size=len(bound_idx), 232 sampler=bound_sampler, 233 **loader_kwargs, 234 ) 235 236 def _split_learnt_self_certified_with_test( 237 self, 238 train_dataset: data.Dataset, 239 test_dataset: data.Dataset, 240 split_config: dict, 241 loader_kwargs: dict, 242 ) -> None: 243 """ 244 Similar to `_split_learnt_self_certified`, but explicitly keeps a separate test set. 245 246 Args: 247 train_dataset (Dataset): Training portion of the data. 248 test_dataset (Dataset): Test portion of the data. 249 split_config (Dict): Contains parameters such as 'batch_size', 'seed', etc. 250 loader_kwargs (Dict): Extra DataLoader arguments. 251 """ 252 batch_size = split_config["batch_size"] 253 training_percent = self._train_percent 254 val_percent = self._val_percent 255 prior_percent = self._prior_percent 256 seed = split_config["seed"] 257 258 train_test_dataset = torch.utils.data.ConcatDataset( 259 [train_dataset, test_dataset] 260 ) 261 train_test_size = len(train_dataset.data) + len(test_dataset.data) 262 train_test_indices = list(range(train_test_size)) 263 np.random.seed(seed) 264 np.random.shuffle(train_test_indices) 265 # take fraction of a training dataset 266 training_test_split = int(np.round(training_percent * train_test_size)) 267 train_indices = train_test_indices[:training_test_split] 268 test_indices = train_test_indices[training_test_split:] 269 270 if val_percent > 0.0: 271 prior_split = int(np.round(prior_percent * training_test_split)) 272 bound_idx, prior_val_idx = ( 273 train_indices[prior_split:], 274 train_indices[:prior_split], 275 ) 276 val_split = int(np.round(val_percent * prior_split)) 277 prior_idx, val_idx = ( 278 
prior_val_idx[val_split:], 279 prior_val_idx[:val_split], 280 ) 281 else: 282 prior_split = int(np.round(prior_percent * training_test_split)) 283 bound_idx, prior_idx = ( 284 train_indices[prior_split:], 285 train_indices[:prior_split], 286 ) 287 val_idx = None 288 289 train_sampler = SubsetRandomSampler(train_indices) 290 bound_sampler = SubsetRandomSampler(bound_idx) 291 prior_sampler = SubsetRandomSampler(prior_idx) 292 val_sampler = SubsetRandomSampler(val_idx) 293 test_sampler = SubsetRandomSampler(test_indices) 294 295 self.posterior_loader = torch.utils.data.DataLoader( 296 train_test_dataset, 297 batch_size=batch_size, 298 sampler=train_sampler, 299 shuffle=False, 300 **loader_kwargs, 301 ) 302 self.prior_loader = torch.utils.data.DataLoader( 303 train_test_dataset, 304 batch_size=batch_size, 305 sampler=prior_sampler, 306 shuffle=False, 307 **loader_kwargs, 308 ) 309 if val_idx: 310 self.val_loader = torch.utils.data.DataLoader( 311 train_test_dataset, 312 batch_size=batch_size, 313 sampler=val_sampler, 314 shuffle=False, 315 **loader_kwargs, 316 ) 317 if len(test_indices) > 0: 318 self.test_loader = torch.utils.data.DataLoader( 319 train_test_dataset, 320 batch_size=batch_size, 321 sampler=test_sampler, 322 shuffle=False, 323 **loader_kwargs, 324 ) 325 self.test_loader_1batch = torch.utils.data.DataLoader( 326 train_test_dataset, 327 batch_size=len(test_indices), 328 sampler=test_sampler, 329 **loader_kwargs, 330 ) 331 else: 332 self.test_loader = None 333 self.test_loader_1batch = None 334 self.bound_loader = torch.utils.data.DataLoader( 335 train_test_dataset, 336 batch_size=batch_size, 337 sampler=bound_sampler, 338 shuffle=False, 339 **loader_kwargs, 340 ) 341 self.bound_loader_1batch = torch.utils.data.DataLoader( 342 train_test_dataset, 343 batch_size=len(bound_idx), 344 sampler=bound_sampler, 345 **loader_kwargs, 346 ) 347 348 def _split_learnt_not_self_certified( 349 self, 350 train_dataset: data.Dataset, 351 test_dataset: data.Dataset, 352 
split_config: dict, 353 loader_kwargs: dict, 354 ) -> None: 355 """ 356 Split data when the prior is learned but not using a self-certified approach. 357 358 Args: 359 train_dataset (Dataset): The training dataset. 360 test_dataset (Dataset): The testing dataset. 361 split_config (Dict): Dictionary with keys (e.g., batch_size, seed, etc.). 362 loader_kwargs (Dict): Additional params for DataLoader creation. 363 """ 364 batch_size = split_config["batch_size"] 365 training_percent = self._train_percent 366 val_percent = self._val_percent 367 prior_percent = self._prior_percent 368 seed = split_config["seed"] 369 370 train_size = len(train_dataset.data) 371 test_size = len(test_dataset.data) 372 train_indices = list(range(train_size)) 373 # TODO: no need to shuffle because of SubsetRandomSampler 374 np.random.seed(seed) 375 np.random.shuffle(train_indices) 376 377 training_split = int(np.round(training_percent * train_size)) 378 train_indices = train_indices[:training_split] 379 380 if val_percent > 0.0: 381 prior_split = int(np.round(prior_percent * training_split)) 382 bound_idx, prior_val_idx = ( 383 train_indices[prior_split:], 384 train_indices[:prior_split], 385 ) 386 val_split = int(np.round(val_percent * prior_split)) 387 prior_idx, val_idx = ( 388 prior_val_idx[val_split:], 389 prior_val_idx[:val_split], 390 ) 391 else: 392 prior_split = int(np.round(prior_percent * training_split)) 393 bound_idx, prior_idx = ( 394 train_indices[prior_split:], 395 train_indices[:prior_split], 396 ) 397 val_idx = None 398 399 train_sampler = SubsetRandomSampler(train_indices) 400 bound_sampler = SubsetRandomSampler(bound_idx) 401 prior_sampler = SubsetRandomSampler(prior_idx) 402 val_sampler = SubsetRandomSampler(val_idx) 403 404 self.posterior_loader = torch.utils.data.DataLoader( 405 train_dataset, 406 batch_size=batch_size, 407 sampler=train_sampler, 408 shuffle=False, 409 **loader_kwargs, 410 ) 411 self.prior_loader = torch.utils.data.DataLoader( 412 train_dataset, 413 
batch_size=batch_size, 414 sampler=prior_sampler, 415 shuffle=False, 416 **loader_kwargs, 417 ) 418 if val_idx: 419 self.val_loader = torch.utils.data.DataLoader( 420 train_dataset, 421 batch_size=batch_size, 422 sampler=val_sampler, 423 shuffle=False, 424 **loader_kwargs, 425 ) 426 self.test_loader = torch.utils.data.DataLoader( 427 test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs 428 ) 429 self.test_1batch = torch.utils.data.DataLoader( 430 test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs 431 ) 432 self.bound_loader = torch.utils.data.DataLoader( 433 train_dataset, 434 batch_size=batch_size, 435 sampler=bound_sampler, 436 shuffle=False, 437 ) 438 self.bound_loader_1batch = torch.utils.data.DataLoader( 439 train_dataset, 440 batch_size=len(bound_idx), 441 sampler=bound_sampler, 442 **loader_kwargs, 443 ) 444 445 def split( 446 self, dataset_loader: Union["MNISTLoader", "CIFAR10Loader"], split_config: dict 447 ) -> None: 448 """ 449 Public method to perform the split operation on a dataset loader, 450 setting up DataLoaders for prior, posterior, validation, testing, and bound evaluation. 451 452 Args: 453 dataset_loader (Union[MNISTLoader, CIFAR10Loader]): A dataset loader instance 454 providing `load(dataset_loader_seed)` to retrieve train/test datasets. 455 split_config (Dict): Configuration parameters for splitting (e.g., batch_size, seed). 
456 """ 457 dataset_loader_seed = split_config["dataset_loader_seed"] 458 train_dataset, test_dataset = dataset_loader.load(dataset_loader_seed) 459 460 loader_kwargs = ( 461 {"num_workers": 1, "pin_memory": True} if torch.cuda.is_available() else {} 462 ) 463 464 if self._prior_type == "not_learnt": 465 self._split_not_learnt( 466 train_dataset=train_dataset, 467 test_dataset=test_dataset, 468 split_config=split_config, 469 loader_kwargs=loader_kwargs, 470 ) 471 elif self._prior_type == "learnt": 472 if self._self_certified: 473 self._split_learnt_self_certified( 474 train_dataset=train_dataset, 475 test_dataset=test_dataset, 476 split_config=split_config, 477 loader_kwargs=loader_kwargs, 478 ) 479 else: 480 self._split_learnt_not_self_certified( 481 train_dataset=train_dataset, 482 test_dataset=test_dataset, 483 split_config=split_config, 484 loader_kwargs=loader_kwargs, 485 ) 486 elif self._prior_type == "learnt_with_test": 487 if self._self_certified: 488 self._split_learnt_self_certified_with_test( 489 train_dataset=train_dataset, 490 test_dataset=test_dataset, 491 split_config=split_config, 492 loader_kwargs=loader_kwargs, 493 ) 494 else: 495 raise ValueError(f"Invalid prior_type: {self._prior_type}") 496 else: 497 raise ValueError(f"Invalid prior_type: {self._prior_type}")
A split strategy implementing a Prior-Posterior-Bound (PBP) partition of the dataset.
This strategy supports data splits for:
- Posterior training
- Prior training
- Validation
- Test
- Bound evaluation (data for calculating PAC-Bayes bounds)
By changing the internal splitting methods, one can adapt different scenarios such as:
- 'not_learnt': The prior is not trained.
- 'learnt': The prior is trained on some portion of the data.
- 'learnt_with_test': Similar to 'learnt', but includes an explicit test subset.
PBPSplitStrategy(prior_type: str, train_percent: float, val_percent: float, prior_percent: float, self_certified: bool)
43 def __init__( 44 self, 45 prior_type: str, 46 train_percent: float, 47 val_percent: float, 48 prior_percent: float, 49 self_certified: bool, 50 ): 51 """ 52 Initialize the PBPSplitStrategy with user-defined parameters for how to partition the data. 53 54 Args: 55 prior_type (str): Indicates whether the prior is "not_learnt", "learnt", or "learnt_with_test". 56 train_percent (float): Proportion of data used for training. 57 val_percent (float): Proportion of data used for validation. 58 prior_percent (float): Proportion of data used specifically for training the prior. 59 self_certified (bool): If True, indicates self-certified splitting approach. 60 """ 61 self._prior_type = prior_type 62 self._train_percent = train_percent 63 self._val_percent = val_percent 64 self._prior_percent = prior_percent 65 self._self_certified = self_certified
Initialize the PBPSplitStrategy with user-defined parameters for how to partition the data.
Arguments:
- prior_type (str): Indicates whether the prior is "not_learnt", "learnt", or "learnt_with_test".
- train_percent (float): Proportion of data used for training.
- val_percent (float): Proportion of data used for validation.
- prior_percent (float): Proportion of data used specifically for training the prior.
- self_certified (bool): If True, indicates self-certified splitting approach.
def split(self, dataset_loader: Union['MNISTLoader', 'CIFAR10Loader'], split_config: dict) -> None:
445 def split( 446 self, dataset_loader: Union["MNISTLoader", "CIFAR10Loader"], split_config: dict 447 ) -> None: 448 """ 449 Public method to perform the split operation on a dataset loader, 450 setting up DataLoaders for prior, posterior, validation, testing, and bound evaluation. 451 452 Args: 453 dataset_loader (Union[MNISTLoader, CIFAR10Loader]): A dataset loader instance 454 providing `load(dataset_loader_seed)` to retrieve train/test datasets. 455 split_config (Dict): Configuration parameters for splitting (e.g., batch_size, seed). 456 """ 457 dataset_loader_seed = split_config["dataset_loader_seed"] 458 train_dataset, test_dataset = dataset_loader.load(dataset_loader_seed) 459 460 loader_kwargs = ( 461 {"num_workers": 1, "pin_memory": True} if torch.cuda.is_available() else {} 462 ) 463 464 if self._prior_type == "not_learnt": 465 self._split_not_learnt( 466 train_dataset=train_dataset, 467 test_dataset=test_dataset, 468 split_config=split_config, 469 loader_kwargs=loader_kwargs, 470 ) 471 elif self._prior_type == "learnt": 472 if self._self_certified: 473 self._split_learnt_self_certified( 474 train_dataset=train_dataset, 475 test_dataset=test_dataset, 476 split_config=split_config, 477 loader_kwargs=loader_kwargs, 478 ) 479 else: 480 self._split_learnt_not_self_certified( 481 train_dataset=train_dataset, 482 test_dataset=test_dataset, 483 split_config=split_config, 484 loader_kwargs=loader_kwargs, 485 ) 486 elif self._prior_type == "learnt_with_test": 487 if self._self_certified: 488 self._split_learnt_self_certified_with_test( 489 train_dataset=train_dataset, 490 test_dataset=test_dataset, 491 split_config=split_config, 492 loader_kwargs=loader_kwargs, 493 ) 494 else: 495 raise ValueError(f"Invalid prior_type: {self._prior_type}") 496 else: 497 raise ValueError(f"Invalid prior_type: {self._prior_type}")
Public method to perform the split operation on a dataset loader, setting up DataLoaders for prior, posterior, validation, testing, and bound evaluation.
Arguments:
- dataset_loader (Union[MNISTLoader, CIFAR10Loader]): A dataset loader instance
providing `load(dataset_loader_seed)` to retrieve train/test datasets.
- split_config (Dict): Configuration parameters for splitting (e.g., batch_size, seed).