core.split_strategy.FaultySplitStrategy

  1from dataclasses import dataclass
  2
  3import numpy as np
  4import torch
  5from torch.utils import data
  6from torch.utils.data.sampler import SubsetRandomSampler
  7
  8from core.split_strategy import PBPSplitStrategy
  9
 10
@dataclass
class FaultySplitStrategy(PBPSplitStrategy):
    """
    A specialized (and potentially buggy) subclass of PBPSplitStrategy that demonstrates
    alternative splitting logic or partial overlaps between dataset subsets.

    The faults are part of the point of this class: see the NOTE(review) comments
    inside the ``_split_*`` methods for where its behavior deliberately deviates
    from a clean split (discarded shuffles, overlapping subsets, inconsistently
    forwarded loader kwargs).

    Fields:
        posterior_loader (DataLoader): DataLoader for posterior training.
        prior_loader (DataLoader): DataLoader for prior training.
        val_loader (DataLoader): DataLoader for validation set.
        test_loader (DataLoader): DataLoader for test set.
        test_1batch (DataLoader): DataLoader for test set (one big batch).
        bound_loader (DataLoader): DataLoader for bound evaluation set.
        bound_loader_1batch (DataLoader): DataLoader for bound evaluation set (one big batch).
    """

    # All loaders default to None and are populated by whichever _split_* method
    # runs; loaders not set by that method remain None (e.g. test_loader in the
    # self-certified split).
    # Posterior training
    posterior_loader: data.dataloader.DataLoader = None
    # Prior training
    prior_loader: data.dataloader.DataLoader = None
    # Evaluation (val_loader stays None when val_percent == 0)
    val_loader: data.dataloader.DataLoader = None
    test_loader: data.dataloader.DataLoader = None
    test_1batch: data.dataloader.DataLoader = None
    # Bounds evaluation
    bound_loader: data.dataloader.DataLoader = None
    bound_loader_1batch: data.dataloader.DataLoader = None

    def __init__(
        self,
        prior_type: str,
        train_percent: float,
        val_percent: float,
        prior_percent: float,
        self_certified: bool,
    ):
        """
        Initialize the FaultySplitStrategy with user-defined percentages
        and flags for how to partition the dataset.

        All state handling is delegated to the PBPSplitStrategy base class;
        this subclass only overrides the splitting methods.

        Args:
            prior_type (str): A string indicating how the prior is handled (e.g. "not_learnt", "learnt").
            train_percent (float): Fraction of dataset to use for training.
            val_percent (float): Fraction of dataset to use for validation.
            prior_percent (float): Fraction of dataset to use for prior training.
            self_certified (bool): If True, indicates self-certified approach to data splitting.
        """
        super().__init__(
            prior_type, train_percent, val_percent, prior_percent, self_certified
        )

    def _split_not_learnt(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split the data for the case when the prior is not learned from data.

        Populates posterior/val/test/bound loaders; prior_loader is left None
        (no prior training set is needed when the prior is fixed). The bound
        loaders reuse the training sampler, so the bound set equals the
        training set here.

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The test dataset.
            split_config (Dict): A configuration dictionary containing keys like 'batch_size', 'seed', etc.
            loader_kwargs (Dict): Additional keyword arguments for DataLoader (e.g., num_workers).
        """
        batch_size = split_config["batch_size"]
        training_percent = self._train_percent
        val_percent = self._val_percent

        # NOTE(review): assumes both datasets expose a `.data` attribute
        # (torchvision-style); confirm for other dataset types.
        train_size = len(train_dataset.data)
        test_size = len(test_dataset.data)

        indices = list(range(train_size))
        split = int(np.round((training_percent) * train_size))
        # Seed the global numpy RNG so the shuffle is reproducible per config.
        np.random.seed(split_config["seed"])
        np.random.shuffle(indices)

        if val_percent > 0.0:
            # compute number of data points
            # NOTE(review): this rebinding discards the shuffle above — the
            # train/val split is taken from the deterministic prefix
            # 0..split-1, not from the shuffled indices. Likely one of the
            # intentional faults this class demonstrates.
            indices = list(range(split))
            split_val = int(np.round((val_percent) * split))
            train_idx, val_idx = indices[split_val:], indices[:split_val]
        else:
            train_idx = indices[:split]
            val_idx = None

        train_sampler = SubsetRandomSampler(train_idx)
        # NOTE(review): constructed even when val_idx is None; harmless only
        # because it is used exclusively inside the `if val_idx:` branch below.
        val_sampler = SubsetRandomSampler(val_idx)

        self.posterior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            **loader_kwargs,
        )
        # self.prior_loader = None
        if val_idx:
            # NOTE(review): unlike every other loader in this method, this one
            # does not forward **loader_kwargs — confirm whether intentional.
            self.val_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=val_sampler,
                shuffle=False,
            )
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
        )
        # Whole test set delivered as a single batch.
        self.test_1batch = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
        )
        # Bound loaders reuse train_sampler: bound evaluation runs over the
        # same indices used for posterior training.
        self.bound_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            **loader_kwargs,
        )
        # Single batch covering every training index.
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=len(train_idx),
            sampler=train_sampler,
            **loader_kwargs,
        )

    def _split_learnt_self_certified(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split logic when the prior is learned from data and we use a self-certified approach.

        Train and test datasets are concatenated and all loaders draw from the
        combined pool; no test loaders are created (self-certification uses
        the bound instead of a held-out test set).

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The test dataset.
            split_config (Dict): A dictionary of split settings (batch size, seed, etc.).
            loader_kwargs (Dict): Keyword arguments for DataLoader initialization.
        """
        batch_size = split_config["batch_size"]
        training_percent = self._train_percent
        val_percent = self._val_percent
        prior_percent = self._prior_percent

        # Total pool size: train and test combined (self-certified setting).
        n = len(train_dataset.data) + len(test_dataset.data)

        # reduce training data if needed
        new_num_train = int(np.round((training_percent) * n))
        indices = list(range(new_num_train))
        split = int(np.round((prior_percent) * new_num_train))
        np.random.seed(split_config["seed"])
        np.random.shuffle(indices)

        # Posterior training samples from ALL selected indices (prior subset
        # included).
        all_train_sampler = SubsetRandomSampler(indices)
        if val_percent > 0.0:
            bound_idx = indices[split:]
            # NOTE(review): prior/val indices are taken from the unshuffled
            # prefix 0..split-1, while bound_idx comes from the shuffled list —
            # the two can overlap. This is presumably the "partial overlap"
            # fault the class docstring advertises.
            indices_prior = list(range(split))
            _all_prior_sampler = SubsetRandomSampler(indices_prior)  # unused
            split_val = int(np.round((val_percent) * split))
            prior_idx, val_idx = indices_prior[split_val:], indices_prior[:split_val]
        else:
            bound_idx, prior_idx = indices[split:], indices[:split]
            val_idx = None

        bound_sampler = SubsetRandomSampler(bound_idx)
        prior_sampler = SubsetRandomSampler(prior_idx)
        # NOTE(review): built even when val_idx is None; only used when
        # val_idx is truthy below.
        val_sampler = SubsetRandomSampler(val_idx)

        final_dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])

        self.posterior_loader = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=batch_size,
            sampler=all_train_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        self.prior_loader = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=batch_size,
            sampler=prior_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        if val_idx:
            self.val_loader = torch.utils.data.DataLoader(
                final_dataset,
                batch_size=batch_size,
                sampler=val_sampler,
                shuffle=False,
                **loader_kwargs,
            )
        # Intentionally no test loaders in the self-certified setting:
        # self.test_loader = None
        # self.test_1batch = None
        self.bound_loader = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=batch_size,
            sampler=bound_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        # Single batch covering the entire bound set.
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=len(bound_idx),
            sampler=bound_sampler,
            **loader_kwargs,
        )

    def _split_learnt_not_self_certified(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split logic for a learned prior without self-certification.

        Same index bookkeeping as the self-certified variant, but loaders draw
        only from train_dataset and a held-out test set is kept (test_loader /
        test_1batch are populated).

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The test dataset.
            split_config (Dict): Dictionary with split hyperparameters (batch size, seed, etc.).
            loader_kwargs (Dict): Additional arguments for torch.utils.data.DataLoader.
        """
        batch_size = split_config["batch_size"]
        training_percent = self._train_percent
        val_percent = self._val_percent
        prior_percent = self._prior_percent

        # NOTE(review): assumes `.data` exists on both datasets — see
        # _split_not_learnt.
        train_size = len(train_dataset.data)
        test_size = len(test_dataset.data)

        new_num_train = int(np.round((training_percent) * train_size))
        indices = list(range(new_num_train))
        split = int(np.round((prior_percent) * new_num_train))
        np.random.seed(split_config["seed"])
        np.random.shuffle(indices)

        all_train_sampler = SubsetRandomSampler(indices)
        # train_idx, valid_idx = indices[split:], indices[:split]
        if val_percent > 0.0:
            bound_idx = indices[split:]
            # NOTE(review): same prefix-vs-shuffled overlap as in
            # _split_learnt_self_certified — prior/val indices may intersect
            # bound_idx.
            indices_prior = list(range(split))
            _all_prior_sampler = SubsetRandomSampler(indices_prior)  # unused
            split_val = int(np.round((val_percent) * split))
            prior_idx, val_idx = indices_prior[split_val:], indices_prior[:split_val]
        else:
            bound_idx, prior_idx = indices[split:], indices[:split]
            val_idx = None

        bound_sampler = SubsetRandomSampler(bound_idx)
        prior_sampler = SubsetRandomSampler(prior_idx)
        # NOTE(review): constructed with None when val_percent == 0; unused in
        # that case.
        val_sampler = SubsetRandomSampler(val_idx)

        self.posterior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=all_train_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        self.prior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=prior_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        if val_idx:
            self.val_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=val_sampler,
                shuffle=False,
                **loader_kwargs,
            )
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
        )
        # Whole test set as one batch.
        self.test_1batch = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
        )
        # NOTE(review): this loader does not forward **loader_kwargs while its
        # 1-batch twin below does — confirm whether intentional.
        self.bound_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=bound_sampler,
            shuffle=False,
        )
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=len(bound_idx),
            sampler=bound_sampler,
            **loader_kwargs,
        )
@dataclass
class FaultySplitStrategy(PBPSplitStrategy):
@dataclass
class FaultySplitStrategy(PBPSplitStrategy):
    """
    A specialized (and potentially buggy) subclass of PBPSplitStrategy that demonstrates
    alternative splitting logic or partial overlaps between dataset subsets.

    The faults are part of the point of this class: see the NOTE(review) comments
    inside the ``_split_*`` methods for where its behavior deliberately deviates
    from a clean split (discarded shuffles, overlapping subsets, inconsistently
    forwarded loader kwargs).

    Fields:
        posterior_loader (DataLoader): DataLoader for posterior training.
        prior_loader (DataLoader): DataLoader for prior training.
        val_loader (DataLoader): DataLoader for validation set.
        test_loader (DataLoader): DataLoader for test set.
        test_1batch (DataLoader): DataLoader for test set (one big batch).
        bound_loader (DataLoader): DataLoader for bound evaluation set.
        bound_loader_1batch (DataLoader): DataLoader for bound evaluation set (one big batch).
    """

    # All loaders default to None and are populated by whichever _split_* method
    # runs; loaders not set by that method remain None (e.g. test_loader in the
    # self-certified split).
    # Posterior training
    posterior_loader: data.dataloader.DataLoader = None
    # Prior training
    prior_loader: data.dataloader.DataLoader = None
    # Evaluation (val_loader stays None when val_percent == 0)
    val_loader: data.dataloader.DataLoader = None
    test_loader: data.dataloader.DataLoader = None
    test_1batch: data.dataloader.DataLoader = None
    # Bounds evaluation
    bound_loader: data.dataloader.DataLoader = None
    bound_loader_1batch: data.dataloader.DataLoader = None

    def __init__(
        self,
        prior_type: str,
        train_percent: float,
        val_percent: float,
        prior_percent: float,
        self_certified: bool,
    ):
        """
        Initialize the FaultySplitStrategy with user-defined percentages
        and flags for how to partition the dataset.

        All state handling is delegated to the PBPSplitStrategy base class;
        this subclass only overrides the splitting methods.

        Args:
            prior_type (str): A string indicating how the prior is handled (e.g. "not_learnt", "learnt").
            train_percent (float): Fraction of dataset to use for training.
            val_percent (float): Fraction of dataset to use for validation.
            prior_percent (float): Fraction of dataset to use for prior training.
            self_certified (bool): If True, indicates self-certified approach to data splitting.
        """
        super().__init__(
            prior_type, train_percent, val_percent, prior_percent, self_certified
        )

    def _split_not_learnt(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split the data for the case when the prior is not learned from data.

        Populates posterior/val/test/bound loaders; prior_loader is left None
        (no prior training set is needed when the prior is fixed). The bound
        loaders reuse the training sampler, so the bound set equals the
        training set here.

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The test dataset.
            split_config (Dict): A configuration dictionary containing keys like 'batch_size', 'seed', etc.
            loader_kwargs (Dict): Additional keyword arguments for DataLoader (e.g., num_workers).
        """
        batch_size = split_config["batch_size"]
        training_percent = self._train_percent
        val_percent = self._val_percent

        # NOTE(review): assumes both datasets expose a `.data` attribute
        # (torchvision-style); confirm for other dataset types.
        train_size = len(train_dataset.data)
        test_size = len(test_dataset.data)

        indices = list(range(train_size))
        split = int(np.round((training_percent) * train_size))
        # Seed the global numpy RNG so the shuffle is reproducible per config.
        np.random.seed(split_config["seed"])
        np.random.shuffle(indices)

        if val_percent > 0.0:
            # compute number of data points
            # NOTE(review): this rebinding discards the shuffle above — the
            # train/val split is taken from the deterministic prefix
            # 0..split-1, not from the shuffled indices. Likely one of the
            # intentional faults this class demonstrates.
            indices = list(range(split))
            split_val = int(np.round((val_percent) * split))
            train_idx, val_idx = indices[split_val:], indices[:split_val]
        else:
            train_idx = indices[:split]
            val_idx = None

        train_sampler = SubsetRandomSampler(train_idx)
        # NOTE(review): constructed even when val_idx is None; harmless only
        # because it is used exclusively inside the `if val_idx:` branch below.
        val_sampler = SubsetRandomSampler(val_idx)

        self.posterior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            **loader_kwargs,
        )
        # self.prior_loader = None
        if val_idx:
            # NOTE(review): unlike every other loader in this method, this one
            # does not forward **loader_kwargs — confirm whether intentional.
            self.val_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=val_sampler,
                shuffle=False,
            )
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
        )
        # Whole test set delivered as a single batch.
        self.test_1batch = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
        )
        # Bound loaders reuse train_sampler: bound evaluation runs over the
        # same indices used for posterior training.
        self.bound_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            **loader_kwargs,
        )
        # Single batch covering every training index.
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=len(train_idx),
            sampler=train_sampler,
            **loader_kwargs,
        )

    def _split_learnt_self_certified(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split logic when the prior is learned from data and we use a self-certified approach.

        Train and test datasets are concatenated and all loaders draw from the
        combined pool; no test loaders are created (self-certification uses
        the bound instead of a held-out test set).

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The test dataset.
            split_config (Dict): A dictionary of split settings (batch size, seed, etc.).
            loader_kwargs (Dict): Keyword arguments for DataLoader initialization.
        """
        batch_size = split_config["batch_size"]
        training_percent = self._train_percent
        val_percent = self._val_percent
        prior_percent = self._prior_percent

        # Total pool size: train and test combined (self-certified setting).
        n = len(train_dataset.data) + len(test_dataset.data)

        # reduce training data if needed
        new_num_train = int(np.round((training_percent) * n))
        indices = list(range(new_num_train))
        split = int(np.round((prior_percent) * new_num_train))
        np.random.seed(split_config["seed"])
        np.random.shuffle(indices)

        # Posterior training samples from ALL selected indices (prior subset
        # included).
        all_train_sampler = SubsetRandomSampler(indices)
        if val_percent > 0.0:
            bound_idx = indices[split:]
            # NOTE(review): prior/val indices are taken from the unshuffled
            # prefix 0..split-1, while bound_idx comes from the shuffled list —
            # the two can overlap. This is presumably the "partial overlap"
            # fault the class docstring advertises.
            indices_prior = list(range(split))
            _all_prior_sampler = SubsetRandomSampler(indices_prior)  # unused
            split_val = int(np.round((val_percent) * split))
            prior_idx, val_idx = indices_prior[split_val:], indices_prior[:split_val]
        else:
            bound_idx, prior_idx = indices[split:], indices[:split]
            val_idx = None

        bound_sampler = SubsetRandomSampler(bound_idx)
        prior_sampler = SubsetRandomSampler(prior_idx)
        # NOTE(review): built even when val_idx is None; only used when
        # val_idx is truthy below.
        val_sampler = SubsetRandomSampler(val_idx)

        final_dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])

        self.posterior_loader = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=batch_size,
            sampler=all_train_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        self.prior_loader = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=batch_size,
            sampler=prior_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        if val_idx:
            self.val_loader = torch.utils.data.DataLoader(
                final_dataset,
                batch_size=batch_size,
                sampler=val_sampler,
                shuffle=False,
                **loader_kwargs,
            )
        # Intentionally no test loaders in the self-certified setting:
        # self.test_loader = None
        # self.test_1batch = None
        self.bound_loader = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=batch_size,
            sampler=bound_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        # Single batch covering the entire bound set.
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            final_dataset,
            batch_size=len(bound_idx),
            sampler=bound_sampler,
            **loader_kwargs,
        )

    def _split_learnt_not_self_certified(
        self,
        train_dataset: data.Dataset,
        test_dataset: data.Dataset,
        split_config: dict,
        loader_kwargs: dict,
    ) -> None:
        """
        Split logic for a learned prior without self-certification.

        Same index bookkeeping as the self-certified variant, but loaders draw
        only from train_dataset and a held-out test set is kept (test_loader /
        test_1batch are populated).

        Args:
            train_dataset (Dataset): The training dataset.
            test_dataset (Dataset): The test dataset.
            split_config (Dict): Dictionary with split hyperparameters (batch size, seed, etc.).
            loader_kwargs (Dict): Additional arguments for torch.utils.data.DataLoader.
        """
        batch_size = split_config["batch_size"]
        training_percent = self._train_percent
        val_percent = self._val_percent
        prior_percent = self._prior_percent

        # NOTE(review): assumes `.data` exists on both datasets — see
        # _split_not_learnt.
        train_size = len(train_dataset.data)
        test_size = len(test_dataset.data)

        new_num_train = int(np.round((training_percent) * train_size))
        indices = list(range(new_num_train))
        split = int(np.round((prior_percent) * new_num_train))
        np.random.seed(split_config["seed"])
        np.random.shuffle(indices)

        all_train_sampler = SubsetRandomSampler(indices)
        # train_idx, valid_idx = indices[split:], indices[:split]
        if val_percent > 0.0:
            bound_idx = indices[split:]
            # NOTE(review): same prefix-vs-shuffled overlap as in
            # _split_learnt_self_certified — prior/val indices may intersect
            # bound_idx.
            indices_prior = list(range(split))
            _all_prior_sampler = SubsetRandomSampler(indices_prior)  # unused
            split_val = int(np.round((val_percent) * split))
            prior_idx, val_idx = indices_prior[split_val:], indices_prior[:split_val]
        else:
            bound_idx, prior_idx = indices[split:], indices[:split]
            val_idx = None

        bound_sampler = SubsetRandomSampler(bound_idx)
        prior_sampler = SubsetRandomSampler(prior_idx)
        # NOTE(review): constructed with None when val_percent == 0; unused in
        # that case.
        val_sampler = SubsetRandomSampler(val_idx)

        self.posterior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=all_train_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        self.prior_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=prior_sampler,
            shuffle=False,
            **loader_kwargs,
        )
        if val_idx:
            self.val_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=val_sampler,
                shuffle=False,
                **loader_kwargs,
            )
        self.test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=True, **loader_kwargs
        )
        # Whole test set as one batch.
        self.test_1batch = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_size, shuffle=True, **loader_kwargs
        )
        # NOTE(review): this loader does not forward **loader_kwargs while its
        # 1-batch twin below does — confirm whether intentional.
        self.bound_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=bound_sampler,
            shuffle=False,
        )
        self.bound_loader_1batch = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=len(bound_idx),
            sampler=bound_sampler,
            **loader_kwargs,
        )

A specialized (and potentially buggy) subclass of PBPSplitStrategy that demonstrates alternative splitting logic or partial overlaps between dataset subsets.

Fields:

  • posterior_loader (DataLoader): DataLoader for posterior training.
  • prior_loader (DataLoader): DataLoader for prior training.
  • val_loader (DataLoader): DataLoader for validation set.
  • test_loader (DataLoader): DataLoader for test set.
  • test_1batch (DataLoader): DataLoader for test set (one big batch).
  • bound_loader (DataLoader): DataLoader for bound evaluation set.
  • bound_loader_1batch (DataLoader): DataLoader for bound evaluation set (one big batch).

FaultySplitStrategy(prior_type: str, train_percent: float, val_percent: float, prior_percent: float, self_certified: bool)
40    def __init__(
41        self,
42        prior_type: str,
43        train_percent: float,
44        val_percent: float,
45        prior_percent: float,
46        self_certified: bool,
47    ):
48        """
49        Initialize the FaultySplitStrategy with user-defined percentages
50        and flags for how to partition the dataset.
51
52        Args:
53            prior_type (str): A string indicating how the prior is handled (e.g. "not_learnt", "learnt").
54            train_percent (float): Fraction of dataset to use for training.
55            val_percent (float): Fraction of dataset to use for validation.
56            prior_percent (float): Fraction of dataset to use for prior training.
57            self_certified (bool): If True, indicates self-certified approach to data splitting.
58        """
59        super().__init__(
60            prior_type, train_percent, val_percent, prior_percent, self_certified
61        )

Initialize the FaultySplitStrategy with user-defined percentages and flags for how to partition the dataset.

Arguments:
  • prior_type (str): A string indicating how the prior is handled (e.g. "not_learnt", "learnt").
  • train_percent (float): Fraction of dataset to use for training.
  • val_percent (float): Fraction of dataset to use for validation.
  • prior_percent (float): Fraction of dataset to use for prior training.
  • self_certified (bool): If True, indicates self-certified approach to data splitting.
posterior_loader: torch.utils.data.dataloader.DataLoader = None
prior_loader: torch.utils.data.dataloader.DataLoader = None
val_loader: torch.utils.data.dataloader.DataLoader = None
test_loader: torch.utils.data.dataloader.DataLoader = None
test_1batch: torch.utils.data.dataloader.DataLoader = None
bound_loader: torch.utils.data.dataloader.DataLoader = None
bound_loader_1batch: torch.utils.data.dataloader.DataLoader = None