core.split_strategy.PBPSplitStrategy

  1from dataclasses import dataclass
  2from typing import Union
  3
  4import numpy as np
  5import torch
  6from torch.utils import data
  7from torch.utils.data.sampler import SubsetRandomSampler
  8
  9from core.split_strategy import AbstractSplitStrategy
 10
 11
@dataclass
class PBPSplitStrategy(AbstractSplitStrategy):
    """
    A split strategy implementing a Prior-Posterior-Bound (PBP) partition of the dataset.

    This strategy supports data splits for:
      - Posterior training
      - Prior training
      - Validation
      - Test
      - Bound evaluation (data for calculating PAC-Bayes bounds)

    By changing the internal splitting methods, one can adapt different scenarios such as:
    - 'not_learnt': The prior is not trained.
    - 'learnt': The prior is trained on some portion of the data.
    - 'learnt_with_test': Similar to 'learnt', but includes an explicit test subset.

    All loaders below default to None; `split` populates only the loaders that
    apply to the configured scenario.
    """

    # Posterior training: batches over the posterior's training subset.
    posterior_loader: data.dataloader.DataLoader = None
    # Prior training: populated only in the 'learnt*' scenarios.
    prior_loader: data.dataloader.DataLoader = None
    # Evaluation: val_loader is populated only when val_percent > 0.
    val_loader: data.dataloader.DataLoader = None
    test_loader: data.dataloader.DataLoader = None
    # The whole test set delivered as a single batch.
    test_1batch: data.dataloader.DataLoader = None
    # Bounds evaluation: the data used to compute PAC-Bayes bounds.
    bound_loader: data.dataloader.DataLoader = None
    # The whole bound subset delivered as a single batch.
    bound_loader_1batch: data.dataloader.DataLoader = None
 41
 42    def __init__(
 43        self,
 44        prior_type: str,
 45        train_percent: float,
 46        val_percent: float,
 47        prior_percent: float,
 48        self_certified: bool,
 49    ):
 50        """
 51        Initialize the PBPSplitStrategy with user-defined parameters for how to partition the data.
 52
 53        Args:
 54            prior_type (str): Indicates whether the prior is "not_learnt", "learnt", or "learnt_with_test".
 55            train_percent (float): Proportion of data used for training.
 56            val_percent (float): Proportion of data used for validation.
 57            prior_percent (float): Proportion of data used specifically for training the prior.
 58            self_certified (bool): If True, indicates self-certified splitting approach.
 59        """
 60        self._prior_type = prior_type
 61        self._train_percent = train_percent
 62        self._val_percent = val_percent
 63        self._prior_percent = prior_percent
 64        self._self_certified = self_certified
 65
 66    def _split_not_learnt(
 67        self,
 68        train_dataset: data.Dataset,
 69        test_dataset: data.Dataset,
 70        split_config: dict,
 71        loader_kwargs: dict,
 72    ) -> None:
 73        """
 74        Split data for the scenario where the prior is not learned (e.g., a fixed prior).
 75
 76        Args:
 77            train_dataset (Dataset): The dataset for training and possibly validation.
 78            test_dataset (Dataset): The dataset for testing.
 79            split_config (Dict): Dictionary with keys like 'batch_size', 'seed', etc.
 80            loader_kwargs (Dict): Extra keyword arguments for DataLoader initialization.
 81        """
 82        batch_size = split_config["batch_size"]
 83        training_percent = self._train_percent
 84        val_percent = self._val_percent
 85        seed = split_config["seed"]
 86
 87        train_size = len(train_dataset.data)
 88        test_size = len(test_dataset.data)
 89        train_indices = list(range(train_size))
 90        np.random.seed(seed)
 91        np.random.shuffle(train_indices)
 92
 93        # take fraction of a training dataset
 94        training_split = int(np.round(training_percent * train_size))
 95        train_indices = train_indices[:training_split]
 96        if val_percent > 0.0:
 97            val_split = int(np.round(val_percent * training_split))
 98            train_idx = train_indices[val_split:]
 99            val_idx = train_indices[:val_split]
100        else:
101            train_idx = train_indices
102            val_idx = None
103
104        train_sampler = SubsetRandomSampler(train_idx)
105        val_sampler = SubsetRandomSampler(val_idx)
106
107        self.posterior_loader = torch.utils.data.DataLoader(
108            train_dataset,
109            batch_size=batch_size,
110            sampler=train_sampler,
111            **loader_kwargs,
112        )
113        # self.prior_loader = None
114        if val_idx:
115            self.val_loader = torch.utils.data.DataLoader(
116                train_dataset,
117                batch_size=batch_size,
118                sampler=val_sampler,
119                shuffle=False,
120            )
121        self.test_loader = torch.utils.data.DataLoader(
122            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
123        )
124        self.test_1batch = torch.utils.data.DataLoader(
125            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
126        )
127        self.bound_loader = torch.utils.data.DataLoader(
128            train_dataset,
129            batch_size=batch_size,
130            sampler=train_sampler,
131            **loader_kwargs,
132        )
133        self.bound_loader_1batch = torch.utils.data.DataLoader(
134            train_dataset,
135            batch_size=len(train_idx),
136            sampler=train_sampler,
137            **loader_kwargs,
138        )
139
140    def _split_learnt_self_certified(
141        self,
142        train_dataset: data.Dataset,
143        test_dataset: data.Dataset,
144        split_config: dict,
145        loader_kwargs: dict,
146    ) -> None:
147        """
148        Split data when the prior is learned and we use a self-certified approach (all data combined).
149
150        Args:
151            train_dataset (Dataset): Training dataset.
152            test_dataset (Dataset): Test dataset.
153            split_config (Dict): Contains config like 'batch_size', 'seed'.
154            loader_kwargs (Dict): Extra arguments for DataLoader.
155        """
156        batch_size = split_config["batch_size"]
157        training_percent = self._train_percent
158        val_percent = self._val_percent
159        prior_percent = self._prior_percent
160        seed = split_config["seed"]
161
162        train_test_dataset = torch.utils.data.ConcatDataset(
163            [train_dataset, test_dataset]
164        )
165        train_test_size = len(train_dataset.data) + len(test_dataset.data)
166        train_indices = list(range(train_test_size))
167        np.random.seed(seed)
168        np.random.shuffle(train_indices)
169        # take fraction of a training dataset
170        training_test_split = int(np.round(training_percent * train_test_size))
171        train_indices = train_indices[:training_test_split]
172
173        if val_percent > 0.0:
174            prior_split = int(np.round(prior_percent * training_test_split))
175            bound_idx, prior_val_idx = (
176                train_indices[prior_split:],
177                train_indices[:prior_split],
178            )
179            val_split = int(np.round(val_percent * prior_split))
180            prior_idx, val_idx = (
181                prior_val_idx[val_split:],
182                prior_val_idx[:val_split],
183            )
184        else:
185            prior_split = int(np.round(prior_percent * training_test_split))
186            bound_idx, prior_idx = (
187                train_indices[prior_split:],
188                train_indices[:prior_split],
189            )
190            val_idx = None
191
192        train_test_sampler = SubsetRandomSampler(train_indices)
193        bound_sampler = SubsetRandomSampler(bound_idx)
194        prior_sampler = SubsetRandomSampler(prior_idx)
195        val_sampler = SubsetRandomSampler(val_idx)
196
197        self.posterior_loader = torch.utils.data.DataLoader(
198            train_test_dataset,
199            batch_size=batch_size,
200            sampler=train_test_sampler,
201            shuffle=False,
202            **loader_kwargs,
203        )
204        self.prior_loader = torch.utils.data.DataLoader(
205            train_test_dataset,
206            batch_size=batch_size,
207            sampler=prior_sampler,
208            shuffle=False,
209            **loader_kwargs,
210        )
211        if val_idx:
212            self.val_loader = torch.utils.data.DataLoader(
213                train_test_dataset,
214                batch_size=batch_size,
215                sampler=val_sampler,
216                shuffle=False,
217                **loader_kwargs,
218            )
219        # self.test_loader = None
220        # self.test_1batch = None
221        self.bound_loader = torch.utils.data.DataLoader(
222            train_test_dataset,
223            batch_size=batch_size,
224            sampler=bound_sampler,
225            shuffle=False,
226            **loader_kwargs,
227        )
228        self.bound_loader_1batch = torch.utils.data.DataLoader(
229            train_test_dataset,
230            batch_size=len(bound_idx),
231            sampler=bound_sampler,
232            **loader_kwargs,
233        )
234
235    def _split_learnt_self_certified_with_test(
236        self,
237        train_dataset: data.Dataset,
238        test_dataset: data.Dataset,
239        split_config: dict,
240        loader_kwargs: dict,
241    ) -> None:
242        """
243        Similar to `_split_learnt_self_certified`, but explicitly keeps a separate test set.
244
245        Args:
246            train_dataset (Dataset): Training portion of the data.
247            test_dataset (Dataset): Test portion of the data.
248            split_config (Dict): Contains parameters such as 'batch_size', 'seed', etc.
249            loader_kwargs (Dict): Extra DataLoader arguments.
250        """
251        batch_size = split_config["batch_size"]
252        training_percent = self._train_percent
253        val_percent = self._val_percent
254        prior_percent = self._prior_percent
255        seed = split_config["seed"]
256
257        train_test_dataset = torch.utils.data.ConcatDataset(
258            [train_dataset, test_dataset]
259        )
260        train_test_size = len(train_dataset.data) + len(test_dataset.data)
261        train_test_indices = list(range(train_test_size))
262        np.random.seed(seed)
263        np.random.shuffle(train_test_indices)
264        # take fraction of a training dataset
265        training_test_split = int(np.round(training_percent * train_test_size))
266        train_indices = train_test_indices[:training_test_split]
267        test_indices = train_test_indices[training_test_split:]
268
269        if val_percent > 0.0:
270            prior_split = int(np.round(prior_percent * training_test_split))
271            bound_idx, prior_val_idx = (
272                train_indices[prior_split:],
273                train_indices[:prior_split],
274            )
275            val_split = int(np.round(val_percent * prior_split))
276            prior_idx, val_idx = (
277                prior_val_idx[val_split:],
278                prior_val_idx[:val_split],
279            )
280        else:
281            prior_split = int(np.round(prior_percent * training_test_split))
282            bound_idx, prior_idx = (
283                train_indices[prior_split:],
284                train_indices[:prior_split],
285            )
286            val_idx = None
287
288        train_sampler = SubsetRandomSampler(train_indices)
289        bound_sampler = SubsetRandomSampler(bound_idx)
290        prior_sampler = SubsetRandomSampler(prior_idx)
291        val_sampler = SubsetRandomSampler(val_idx)
292        test_sampler = SubsetRandomSampler(test_indices)
293
294        self.posterior_loader = torch.utils.data.DataLoader(
295            train_test_dataset,
296            batch_size=batch_size,
297            sampler=train_sampler,
298            shuffle=False,
299            **loader_kwargs,
300        )
301        self.prior_loader = torch.utils.data.DataLoader(
302            train_test_dataset,
303            batch_size=batch_size,
304            sampler=prior_sampler,
305            shuffle=False,
306            **loader_kwargs,
307        )
308        if val_idx:
309            self.val_loader = torch.utils.data.DataLoader(
310                train_test_dataset,
311                batch_size=batch_size,
312                sampler=val_sampler,
313                shuffle=False,
314                **loader_kwargs,
315            )
316        if len(test_indices) > 0:
317            self.test_loader = torch.utils.data.DataLoader(
318                train_test_dataset,
319                batch_size=batch_size,
320                sampler=test_sampler,
321                shuffle=False,
322                **loader_kwargs,
323            )
324            self.test_loader_1batch = torch.utils.data.DataLoader(
325                train_test_dataset,
326                batch_size=len(test_indices),
327                sampler=test_sampler,
328                **loader_kwargs,
329            )
330        else:
331            self.test_loader = None
332            self.test_loader_1batch = None
333        self.bound_loader = torch.utils.data.DataLoader(
334            train_test_dataset,
335            batch_size=batch_size,
336            sampler=bound_sampler,
337            shuffle=False,
338            **loader_kwargs,
339        )
340        self.bound_loader_1batch = torch.utils.data.DataLoader(
341            train_test_dataset,
342            batch_size=len(bound_idx),
343            sampler=bound_sampler,
344            **loader_kwargs,
345        )
346
347    def _split_learnt_not_self_certified(
348        self,
349        train_dataset: data.Dataset,
350        test_dataset: data.Dataset,
351        split_config: dict,
352        loader_kwargs: dict,
353    ) -> None:
354        """
355        Split data when the prior is learned but not using a self-certified approach.
356
357        Args:
358            train_dataset (Dataset): The training dataset.
359            test_dataset (Dataset): The testing dataset.
360            split_config (Dict): Dictionary with keys (e.g., batch_size, seed, etc.).
361            loader_kwargs (Dict): Additional params for DataLoader creation.
362        """
363        batch_size = split_config["batch_size"]
364        training_percent = self._train_percent
365        val_percent = self._val_percent
366        prior_percent = self._prior_percent
367        seed = split_config["seed"]
368
369        train_size = len(train_dataset.data)
370        test_size = len(test_dataset.data)
371        train_indices = list(range(train_size))
372        # TODO: no need to shuffle because of SubsetRandomSampler
373        np.random.seed(seed)
374        np.random.shuffle(train_indices)
375
376        training_split = int(np.round(training_percent * train_size))
377        train_indices = train_indices[:training_split]
378
379        if val_percent > 0.0:
380            prior_split = int(np.round(prior_percent * training_split))
381            bound_idx, prior_val_idx = (
382                train_indices[prior_split:],
383                train_indices[:prior_split],
384            )
385            val_split = int(np.round(val_percent * prior_split))
386            prior_idx, val_idx = (
387                prior_val_idx[val_split:],
388                prior_val_idx[:val_split],
389            )
390        else:
391            prior_split = int(np.round(prior_percent * training_split))
392            bound_idx, prior_idx = (
393                train_indices[prior_split:],
394                train_indices[:prior_split],
395            )
396            val_idx = None
397
398        train_sampler = SubsetRandomSampler(train_indices)
399        bound_sampler = SubsetRandomSampler(bound_idx)
400        prior_sampler = SubsetRandomSampler(prior_idx)
401        val_sampler = SubsetRandomSampler(val_idx)
402
403        self.posterior_loader = torch.utils.data.DataLoader(
404            train_dataset,
405            batch_size=batch_size,
406            sampler=train_sampler,
407            shuffle=False,
408            **loader_kwargs,
409        )
410        self.prior_loader = torch.utils.data.DataLoader(
411            train_dataset,
412            batch_size=batch_size,
413            sampler=prior_sampler,
414            shuffle=False,
415            **loader_kwargs,
416        )
417        if val_idx:
418            self.val_loader = torch.utils.data.DataLoader(
419                train_dataset,
420                batch_size=batch_size,
421                sampler=val_sampler,
422                shuffle=False,
423                **loader_kwargs,
424            )
425        self.test_loader = torch.utils.data.DataLoader(
426            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
427        )
428        self.test_1batch = torch.utils.data.DataLoader(
429            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
430        )
431        self.bound_loader = torch.utils.data.DataLoader(
432            train_dataset,
433            batch_size=batch_size,
434            sampler=bound_sampler,
435            shuffle=False,
436        )
437        self.bound_loader_1batch = torch.utils.data.DataLoader(
438            train_dataset,
439            batch_size=len(bound_idx),
440            sampler=bound_sampler,
441            **loader_kwargs,
442        )
443
444    def split(
445        self, dataset_loader: Union["MNISTLoader", "CIFAR10Loader"], split_config: dict
446    ) -> None:
447        """
448        Public method to perform the split operation on a dataset loader,
449        setting up DataLoaders for prior, posterior, validation, testing, and bound evaluation.
450
451        Args:
452            dataset_loader (Union[MNISTLoader, CIFAR10Loader]): A dataset loader instance
453                providing `load(dataset_loader_seed)` to retrieve train/test datasets.
454            split_config (Dict): Configuration parameters for splitting (e.g., batch_size, seed).
455        """
456        dataset_loader_seed = split_config["dataset_loader_seed"]
457        train_dataset, test_dataset = dataset_loader.load(dataset_loader_seed)
458
459        loader_kwargs = (
460            {"num_workers": 1, "pin_memory": True} if torch.cuda.is_available() else {}
461        )
462
463        if self._prior_type == "not_learnt":
464            self._split_not_learnt(
465                train_dataset=train_dataset,
466                test_dataset=test_dataset,
467                split_config=split_config,
468                loader_kwargs=loader_kwargs,
469            )
470        elif self._prior_type == "learnt":
471            if self._self_certified:
472                self._split_learnt_self_certified(
473                    train_dataset=train_dataset,
474                    test_dataset=test_dataset,
475                    split_config=split_config,
476                    loader_kwargs=loader_kwargs,
477                )
478            else:
479                self._split_learnt_not_self_certified(
480                    train_dataset=train_dataset,
481                    test_dataset=test_dataset,
482                    split_config=split_config,
483                    loader_kwargs=loader_kwargs,
484                )
485        elif self._prior_type == "learnt_with_test":
486            if self._self_certified:
487                self._split_learnt_self_certified_with_test(
488                    train_dataset=train_dataset,
489                    test_dataset=test_dataset,
490                    split_config=split_config,
491                    loader_kwargs=loader_kwargs,
492                )
493            else:
494                raise ValueError(f"Invalid prior_type: {self._prior_type}")
495        else:
496            raise ValueError(f"Invalid prior_type: {self._prior_type}")
@dataclass
class PBPSplitStrategy(core.split_strategy.AbstractSplitStrategy.AbstractSplitStrategy):
 13@dataclass
 14class PBPSplitStrategy(AbstractSplitStrategy):
 15    """
 16    A split strategy implementing a Prior-Posterior-Bound (PBP) partition of the dataset.
 17
 18    This strategy supports data splits for:
 19      - Posterior training
 20      - Prior training
 21      - Validation
 22      - Test
 23      - Bound evaluation (data for calculating PAC-Bayes bounds)
 24
 25    By changing the internal splitting methods, one can adapt different scenarios such as:
 26    - 'not_learnt': The prior is not trained.
 27    - 'learnt': The prior is trained on some portion of the data.
 28    - 'learnt_with_test': Similar to 'learnt', but includes an explicit test subset.
 29    """
 30
 31    # Posterior training
 32    posterior_loader: data.dataloader.DataLoader = None
 33    # Prior training
 34    prior_loader: data.dataloader.DataLoader = None
 35    # Evaluation
 36    val_loader: data.dataloader.DataLoader = None
 37    test_loader: data.dataloader.DataLoader = None
 38    test_1batch: data.dataloader.DataLoader = None
 39    # Bounds evaluation
 40    bound_loader: data.dataloader.DataLoader = None
 41    bound_loader_1batch: data.dataloader.DataLoader = None
 42
 43    def __init__(
 44        self,
 45        prior_type: str,
 46        train_percent: float,
 47        val_percent: float,
 48        prior_percent: float,
 49        self_certified: bool,
 50    ):
 51        """
 52        Initialize the PBPSplitStrategy with user-defined parameters for how to partition the data.
 53
 54        Args:
 55            prior_type (str): Indicates whether the prior is "not_learnt", "learnt", or "learnt_with_test".
 56            train_percent (float): Proportion of data used for training.
 57            val_percent (float): Proportion of data used for validation.
 58            prior_percent (float): Proportion of data used specifically for training the prior.
 59            self_certified (bool): If True, indicates self-certified splitting approach.
 60        """
 61        self._prior_type = prior_type
 62        self._train_percent = train_percent
 63        self._val_percent = val_percent
 64        self._prior_percent = prior_percent
 65        self._self_certified = self_certified
 66
 67    def _split_not_learnt(
 68        self,
 69        train_dataset: data.Dataset,
 70        test_dataset: data.Dataset,
 71        split_config: dict,
 72        loader_kwargs: dict,
 73    ) -> None:
 74        """
 75        Split data for the scenario where the prior is not learned (e.g., a fixed prior).
 76
 77        Args:
 78            train_dataset (Dataset): The dataset for training and possibly validation.
 79            test_dataset (Dataset): The dataset for testing.
 80            split_config (Dict): Dictionary with keys like 'batch_size', 'seed', etc.
 81            loader_kwargs (Dict): Extra keyword arguments for DataLoader initialization.
 82        """
 83        batch_size = split_config["batch_size"]
 84        training_percent = self._train_percent
 85        val_percent = self._val_percent
 86        seed = split_config["seed"]
 87
 88        train_size = len(train_dataset.data)
 89        test_size = len(test_dataset.data)
 90        train_indices = list(range(train_size))
 91        np.random.seed(seed)
 92        np.random.shuffle(train_indices)
 93
 94        # take fraction of a training dataset
 95        training_split = int(np.round(training_percent * train_size))
 96        train_indices = train_indices[:training_split]
 97        if val_percent > 0.0:
 98            val_split = int(np.round(val_percent * training_split))
 99            train_idx = train_indices[val_split:]
100            val_idx = train_indices[:val_split]
101        else:
102            train_idx = train_indices
103            val_idx = None
104
105        train_sampler = SubsetRandomSampler(train_idx)
106        val_sampler = SubsetRandomSampler(val_idx)
107
108        self.posterior_loader = torch.utils.data.DataLoader(
109            train_dataset,
110            batch_size=batch_size,
111            sampler=train_sampler,
112            **loader_kwargs,
113        )
114        # self.prior_loader = None
115        if val_idx:
116            self.val_loader = torch.utils.data.DataLoader(
117                train_dataset,
118                batch_size=batch_size,
119                sampler=val_sampler,
120                shuffle=False,
121            )
122        self.test_loader = torch.utils.data.DataLoader(
123            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
124        )
125        self.test_1batch = torch.utils.data.DataLoader(
126            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
127        )
128        self.bound_loader = torch.utils.data.DataLoader(
129            train_dataset,
130            batch_size=batch_size,
131            sampler=train_sampler,
132            **loader_kwargs,
133        )
134        self.bound_loader_1batch = torch.utils.data.DataLoader(
135            train_dataset,
136            batch_size=len(train_idx),
137            sampler=train_sampler,
138            **loader_kwargs,
139        )
140
141    def _split_learnt_self_certified(
142        self,
143        train_dataset: data.Dataset,
144        test_dataset: data.Dataset,
145        split_config: dict,
146        loader_kwargs: dict,
147    ) -> None:
148        """
149        Split data when the prior is learned and we use a self-certified approach (all data combined).
150
151        Args:
152            train_dataset (Dataset): Training dataset.
153            test_dataset (Dataset): Test dataset.
154            split_config (Dict): Contains config like 'batch_size', 'seed'.
155            loader_kwargs (Dict): Extra arguments for DataLoader.
156        """
157        batch_size = split_config["batch_size"]
158        training_percent = self._train_percent
159        val_percent = self._val_percent
160        prior_percent = self._prior_percent
161        seed = split_config["seed"]
162
163        train_test_dataset = torch.utils.data.ConcatDataset(
164            [train_dataset, test_dataset]
165        )
166        train_test_size = len(train_dataset.data) + len(test_dataset.data)
167        train_indices = list(range(train_test_size))
168        np.random.seed(seed)
169        np.random.shuffle(train_indices)
170        # take fraction of a training dataset
171        training_test_split = int(np.round(training_percent * train_test_size))
172        train_indices = train_indices[:training_test_split]
173
174        if val_percent > 0.0:
175            prior_split = int(np.round(prior_percent * training_test_split))
176            bound_idx, prior_val_idx = (
177                train_indices[prior_split:],
178                train_indices[:prior_split],
179            )
180            val_split = int(np.round(val_percent * prior_split))
181            prior_idx, val_idx = (
182                prior_val_idx[val_split:],
183                prior_val_idx[:val_split],
184            )
185        else:
186            prior_split = int(np.round(prior_percent * training_test_split))
187            bound_idx, prior_idx = (
188                train_indices[prior_split:],
189                train_indices[:prior_split],
190            )
191            val_idx = None
192
193        train_test_sampler = SubsetRandomSampler(train_indices)
194        bound_sampler = SubsetRandomSampler(bound_idx)
195        prior_sampler = SubsetRandomSampler(prior_idx)
196        val_sampler = SubsetRandomSampler(val_idx)
197
198        self.posterior_loader = torch.utils.data.DataLoader(
199            train_test_dataset,
200            batch_size=batch_size,
201            sampler=train_test_sampler,
202            shuffle=False,
203            **loader_kwargs,
204        )
205        self.prior_loader = torch.utils.data.DataLoader(
206            train_test_dataset,
207            batch_size=batch_size,
208            sampler=prior_sampler,
209            shuffle=False,
210            **loader_kwargs,
211        )
212        if val_idx:
213            self.val_loader = torch.utils.data.DataLoader(
214                train_test_dataset,
215                batch_size=batch_size,
216                sampler=val_sampler,
217                shuffle=False,
218                **loader_kwargs,
219            )
220        # self.test_loader = None
221        # self.test_1batch = None
222        self.bound_loader = torch.utils.data.DataLoader(
223            train_test_dataset,
224            batch_size=batch_size,
225            sampler=bound_sampler,
226            shuffle=False,
227            **loader_kwargs,
228        )
229        self.bound_loader_1batch = torch.utils.data.DataLoader(
230            train_test_dataset,
231            batch_size=len(bound_idx),
232            sampler=bound_sampler,
233            **loader_kwargs,
234        )
235
236    def _split_learnt_self_certified_with_test(
237        self,
238        train_dataset: data.Dataset,
239        test_dataset: data.Dataset,
240        split_config: dict,
241        loader_kwargs: dict,
242    ) -> None:
243        """
244        Similar to `_split_learnt_self_certified`, but explicitly keeps a separate test set.
245
246        Args:
247            train_dataset (Dataset): Training portion of the data.
248            test_dataset (Dataset): Test portion of the data.
249            split_config (Dict): Contains parameters such as 'batch_size', 'seed', etc.
250            loader_kwargs (Dict): Extra DataLoader arguments.
251        """
252        batch_size = split_config["batch_size"]
253        training_percent = self._train_percent
254        val_percent = self._val_percent
255        prior_percent = self._prior_percent
256        seed = split_config["seed"]
257
258        train_test_dataset = torch.utils.data.ConcatDataset(
259            [train_dataset, test_dataset]
260        )
261        train_test_size = len(train_dataset.data) + len(test_dataset.data)
262        train_test_indices = list(range(train_test_size))
263        np.random.seed(seed)
264        np.random.shuffle(train_test_indices)
265        # take fraction of a training dataset
266        training_test_split = int(np.round(training_percent * train_test_size))
267        train_indices = train_test_indices[:training_test_split]
268        test_indices = train_test_indices[training_test_split:]
269
270        if val_percent > 0.0:
271            prior_split = int(np.round(prior_percent * training_test_split))
272            bound_idx, prior_val_idx = (
273                train_indices[prior_split:],
274                train_indices[:prior_split],
275            )
276            val_split = int(np.round(val_percent * prior_split))
277            prior_idx, val_idx = (
278                prior_val_idx[val_split:],
279                prior_val_idx[:val_split],
280            )
281        else:
282            prior_split = int(np.round(prior_percent * training_test_split))
283            bound_idx, prior_idx = (
284                train_indices[prior_split:],
285                train_indices[:prior_split],
286            )
287            val_idx = None
288
289        train_sampler = SubsetRandomSampler(train_indices)
290        bound_sampler = SubsetRandomSampler(bound_idx)
291        prior_sampler = SubsetRandomSampler(prior_idx)
292        val_sampler = SubsetRandomSampler(val_idx)
293        test_sampler = SubsetRandomSampler(test_indices)
294
295        self.posterior_loader = torch.utils.data.DataLoader(
296            train_test_dataset,
297            batch_size=batch_size,
298            sampler=train_sampler,
299            shuffle=False,
300            **loader_kwargs,
301        )
302        self.prior_loader = torch.utils.data.DataLoader(
303            train_test_dataset,
304            batch_size=batch_size,
305            sampler=prior_sampler,
306            shuffle=False,
307            **loader_kwargs,
308        )
309        if val_idx:
310            self.val_loader = torch.utils.data.DataLoader(
311                train_test_dataset,
312                batch_size=batch_size,
313                sampler=val_sampler,
314                shuffle=False,
315                **loader_kwargs,
316            )
317        if len(test_indices) > 0:
318            self.test_loader = torch.utils.data.DataLoader(
319                train_test_dataset,
320                batch_size=batch_size,
321                sampler=test_sampler,
322                shuffle=False,
323                **loader_kwargs,
324            )
325            self.test_loader_1batch = torch.utils.data.DataLoader(
326                train_test_dataset,
327                batch_size=len(test_indices),
328                sampler=test_sampler,
329                **loader_kwargs,
330            )
331        else:
332            self.test_loader = None
333            self.test_loader_1batch = None
334        self.bound_loader = torch.utils.data.DataLoader(
335            train_test_dataset,
336            batch_size=batch_size,
337            sampler=bound_sampler,
338            shuffle=False,
339            **loader_kwargs,
340        )
341        self.bound_loader_1batch = torch.utils.data.DataLoader(
342            train_test_dataset,
343            batch_size=len(bound_idx),
344            sampler=bound_sampler,
345            **loader_kwargs,
346        )
347
348    def _split_learnt_not_self_certified(
349        self,
350        train_dataset: data.Dataset,
351        test_dataset: data.Dataset,
352        split_config: dict,
353        loader_kwargs: dict,
354    ) -> None:
355        """
356        Split data when the prior is learned but not using a self-certified approach.
357
358        Args:
359            train_dataset (Dataset): The training dataset.
360            test_dataset (Dataset): The testing dataset.
361            split_config (Dict): Dictionary with keys (e.g., batch_size, seed, etc.).
362            loader_kwargs (Dict): Additional params for DataLoader creation.
363        """
364        batch_size = split_config["batch_size"]
365        training_percent = self._train_percent
366        val_percent = self._val_percent
367        prior_percent = self._prior_percent
368        seed = split_config["seed"]
369
370        train_size = len(train_dataset.data)
371        test_size = len(test_dataset.data)
372        train_indices = list(range(train_size))
373        # TODO: no need to shuffle because of SubsetRandomSampler
374        np.random.seed(seed)
375        np.random.shuffle(train_indices)
376
377        training_split = int(np.round(training_percent * train_size))
378        train_indices = train_indices[:training_split]
379
380        if val_percent > 0.0:
381            prior_split = int(np.round(prior_percent * training_split))
382            bound_idx, prior_val_idx = (
383                train_indices[prior_split:],
384                train_indices[:prior_split],
385            )
386            val_split = int(np.round(val_percent * prior_split))
387            prior_idx, val_idx = (
388                prior_val_idx[val_split:],
389                prior_val_idx[:val_split],
390            )
391        else:
392            prior_split = int(np.round(prior_percent * training_split))
393            bound_idx, prior_idx = (
394                train_indices[prior_split:],
395                train_indices[:prior_split],
396            )
397            val_idx = None
398
399        train_sampler = SubsetRandomSampler(train_indices)
400        bound_sampler = SubsetRandomSampler(bound_idx)
401        prior_sampler = SubsetRandomSampler(prior_idx)
402        val_sampler = SubsetRandomSampler(val_idx)
403
404        self.posterior_loader = torch.utils.data.DataLoader(
405            train_dataset,
406            batch_size=batch_size,
407            sampler=train_sampler,
408            shuffle=False,
409            **loader_kwargs,
410        )
411        self.prior_loader = torch.utils.data.DataLoader(
412            train_dataset,
413            batch_size=batch_size,
414            sampler=prior_sampler,
415            shuffle=False,
416            **loader_kwargs,
417        )
418        if val_idx:
419            self.val_loader = torch.utils.data.DataLoader(
420                train_dataset,
421                batch_size=batch_size,
422                sampler=val_sampler,
423                shuffle=False,
424                **loader_kwargs,
425            )
426        self.test_loader = torch.utils.data.DataLoader(
427            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
428        )
429        self.test_1batch = torch.utils.data.DataLoader(
430            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
431        )
432        self.bound_loader = torch.utils.data.DataLoader(
433            train_dataset,
434            batch_size=batch_size,
435            sampler=bound_sampler,
436            shuffle=False,
437        )
438        self.bound_loader_1batch = torch.utils.data.DataLoader(
439            train_dataset,
440            batch_size=len(bound_idx),
441            sampler=bound_sampler,
442            **loader_kwargs,
443        )
444
445    def split(
446        self, dataset_loader: Union["MNISTLoader", "CIFAR10Loader"], split_config: dict
447    ) -> None:
448        """
449        Public method to perform the split operation on a dataset loader,
450        setting up DataLoaders for prior, posterior, validation, testing, and bound evaluation.
451
452        Args:
453            dataset_loader (Union[MNISTLoader, CIFAR10Loader]): A dataset loader instance
454                providing `load(dataset_loader_seed)` to retrieve train/test datasets.
455            split_config (Dict): Configuration parameters for splitting (e.g., batch_size, seed).
456        """
457        dataset_loader_seed = split_config["dataset_loader_seed"]
458        train_dataset, test_dataset = dataset_loader.load(dataset_loader_seed)
459
460        loader_kwargs = (
461            {"num_workers": 1, "pin_memory": True} if torch.cuda.is_available() else {}
462        )
463
464        if self._prior_type == "not_learnt":
465            self._split_not_learnt(
466                train_dataset=train_dataset,
467                test_dataset=test_dataset,
468                split_config=split_config,
469                loader_kwargs=loader_kwargs,
470            )
471        elif self._prior_type == "learnt":
472            if self._self_certified:
473                self._split_learnt_self_certified(
474                    train_dataset=train_dataset,
475                    test_dataset=test_dataset,
476                    split_config=split_config,
477                    loader_kwargs=loader_kwargs,
478                )
479            else:
480                self._split_learnt_not_self_certified(
481                    train_dataset=train_dataset,
482                    test_dataset=test_dataset,
483                    split_config=split_config,
484                    loader_kwargs=loader_kwargs,
485                )
486        elif self._prior_type == "learnt_with_test":
487            if self._self_certified:
488                self._split_learnt_self_certified_with_test(
489                    train_dataset=train_dataset,
490                    test_dataset=test_dataset,
491                    split_config=split_config,
492                    loader_kwargs=loader_kwargs,
493                )
494            else:
495                raise ValueError(f"Invalid prior_type: {self._prior_type}")
496        else:
497            raise ValueError(f"Invalid prior_type: {self._prior_type}")

A split strategy implementing a Prior-Posterior-Bound (PBP) partition of the dataset.

This strategy supports data splits for:
  • Posterior training
  • Prior training
  • Validation
  • Test
  • Bound evaluation (data for calculating PAC-Bayes bounds)

By changing the internal splitting methods, one can adapt different scenarios such as:

  • 'not_learnt': The prior is not trained.
  • 'learnt': The prior is trained on some portion of the data.
  • 'learnt_with_test': Similar to 'learnt', but includes an explicit test subset.
PBPSplitStrategy(prior_type: str, train_percent: float, val_percent: float, prior_percent: float, self_certified: bool)
43    def __init__(
44        self,
45        prior_type: str,
46        train_percent: float,
47        val_percent: float,
48        prior_percent: float,
49        self_certified: bool,
50    ):
51        """
52        Initialize the PBPSplitStrategy with user-defined parameters for how to partition the data.
53
54        Args:
55            prior_type (str): Indicates whether the prior is "not_learnt", "learnt", or "learnt_with_test".
56            train_percent (float): Proportion of data used for training.
57            val_percent (float): Proportion of data used for validation.
58            prior_percent (float): Proportion of data used specifically for training the prior.
59            self_certified (bool): If True, indicates self-certified splitting approach.
60        """
61        self._prior_type = prior_type
62        self._train_percent = train_percent
63        self._val_percent = val_percent
64        self._prior_percent = prior_percent
65        self._self_certified = self_certified

Initialize the PBPSplitStrategy with user-defined parameters for how to partition the data.

Arguments:
  • prior_type (str): Indicates whether the prior is "not_learnt", "learnt", or "learnt_with_test".
  • train_percent (float): Proportion of data used for training.
  • val_percent (float): Proportion of data used for validation.
  • prior_percent (float): Proportion of data used specifically for training the prior.
  • self_certified (bool): If True, indicates self-certified splitting approach.
posterior_loader: torch.utils.data.dataloader.DataLoader = None
prior_loader: torch.utils.data.dataloader.DataLoader = None
val_loader: torch.utils.data.dataloader.DataLoader = None
test_loader: torch.utils.data.dataloader.DataLoader = None
test_1batch: torch.utils.data.dataloader.DataLoader = None
bound_loader: torch.utils.data.dataloader.DataLoader = None
bound_loader_1batch: torch.utils.data.dataloader.DataLoader = None
def split(self, dataset_loader: Union["MNISTLoader", "CIFAR10Loader"], split_config: dict) -> None:
445    def split(
446        self, dataset_loader: Union["MNISTLoader", "CIFAR10Loader"], split_config: dict
447    ) -> None:
448        """
449        Public method to perform the split operation on a dataset loader,
450        setting up DataLoaders for prior, posterior, validation, testing, and bound evaluation.
451
452        Args:
453            dataset_loader (Union[MNISTLoader, CIFAR10Loader]): A dataset loader instance
454                providing `load(dataset_loader_seed)` to retrieve train/test datasets.
455            split_config (Dict): Configuration parameters for splitting (e.g., batch_size, seed).
456        """
457        dataset_loader_seed = split_config["dataset_loader_seed"]
458        train_dataset, test_dataset = dataset_loader.load(dataset_loader_seed)
459
460        loader_kwargs = (
461            {"num_workers": 1, "pin_memory": True} if torch.cuda.is_available() else {}
462        )
463
464        if self._prior_type == "not_learnt":
465            self._split_not_learnt(
466                train_dataset=train_dataset,
467                test_dataset=test_dataset,
468                split_config=split_config,
469                loader_kwargs=loader_kwargs,
470            )
471        elif self._prior_type == "learnt":
472            if self._self_certified:
473                self._split_learnt_self_certified(
474                    train_dataset=train_dataset,
475                    test_dataset=test_dataset,
476                    split_config=split_config,
477                    loader_kwargs=loader_kwargs,
478                )
479            else:
480                self._split_learnt_not_self_certified(
481                    train_dataset=train_dataset,
482                    test_dataset=test_dataset,
483                    split_config=split_config,
484                    loader_kwargs=loader_kwargs,
485                )
486        elif self._prior_type == "learnt_with_test":
487            if self._self_certified:
488                self._split_learnt_self_certified_with_test(
489                    train_dataset=train_dataset,
490                    test_dataset=test_dataset,
491                    split_config=split_config,
492                    loader_kwargs=loader_kwargs,
493                )
494            else:
495                raise ValueError(f"Invalid prior_type: {self._prior_type}")
496        else:
497            raise ValueError(f"Invalid prior_type: {self._prior_type}")

Public method to perform the split operation on a dataset loader, setting up DataLoaders for prior, posterior, validation, testing, and bound evaluation.

Arguments:
  • dataset_loader (Union[MNISTLoader, CIFAR10Loader]): A dataset loader instance providing load(dataset_loader_seed) to retrieve train/test datasets.
  • split_config (Dict): Configuration parameters for splitting (e.g., batch_size, seed).