@@ -255,14 +255,10 @@ class ProportionalSplit(SplittingStrategy):
255255 """
256256 Splitting strategy for MergedDataset: every sample is assigned to a split —
257257 each dataset's full pool is divided across train / val / test by the
257257 split-size ratio; target proportions are enforced at training time instead.
258-
259- ##NOTE: ASK STEFAAN
260- Sampling is without replacement — no sample appears in more than one split.
261-
262258 Args:
263259 proportions: mapping from dataset name to relative weight.
264- Values are normalised to sum=1 internally, so
265- {"md17": 1, "rmd17": 1 } == {"md17": 0.5 , "rmd17": 0.5 }.
260+ Normalised to sum=1 internally, so
261+ {"md17": 1, "rmd17": 9 } == {"md17": 0.1 , "rmd17": 0.9 }.
266262 seed: random seed for reproducible sampling.
267263 """
268264
@@ -274,16 +270,14 @@ def __init__(self, proportions: Dict[str, float], seed: int = 42) -> None:
274270 def split (self , dataset , * split_sizes ) -> List [List [int ]]:
275271 """
276272 Args:
277- dataset: a MergedDataset instance.
273+ dataset: a MergedDataset instance (duck-typed via plan + datasets) .
278274 *split_sizes: sizes for each split (absolute or fractional),
279275 forwarded directly from AtomsDataModuleV2.
280276
281277 Returns:
282278 List of index lists into dataset.plan, one per split.
283279 """
284- from schnetpack .datasets .merge_db import MergedDataset
285-
286- if not isinstance (dataset , MergedDataset ):
280+ if not hasattr (dataset , "plan" ) or not hasattr (dataset , "datasets" ):
287281 raise ValueError (
288282 "ProportionalSplit only works with MergedDataset instances."
289283 )
@@ -299,38 +293,31 @@ def split(self, dataset, *split_sizes) -> List[List[int]]:
299293 total = float (sum (self .proportions [n ] for n in dataset_names ))
300294 norm = {n : self .proportions [n ] / total for n in dataset_names }
301295
302- # Resolve fractional sizes to absolute counts
296+ # Resolve fractional/absolute split sizes to absolute counts
303297 abs_sizes = absolute_split_sizes (len (dataset ), list (split_sizes ))
304298
305- # Per-dataset counts for each split
306- counts_per_split = [
307- self ._counts_from_proportions (
308- size , norm , dataset_names
309- ) ## largest-remainder method for safety
310- for size in abs_sizes
311- ]
312-
313299 # Build per-name pools: positions in dataset.plan
314300 plan_indices_by_name : Dict [str , List [int ]] = {n : [] for n in dataset_names }
315301 for pos , (dataset_name , _ ) in enumerate (dataset .plan ):
316302 plan_indices_by_name [dataset_name ].append (pos )
317303
318- # Validate we have enough samples per dataset across all splits
319- for name in dataset_names :
320- needed = sum (c [name ] for c in counts_per_split )
321- available = len (plan_indices_by_name [name ])
322- if needed > available :
323- raise ValueError (
324- f"Not enough samples in '{ name } ': "
325- f"need { needed } , have { available } ."
326- )
304+ counts_per_split = self ._proportional_counts (
305+ plan_indices_by_name , abs_sizes , norm , dataset_names
306+ )
327307
328308 # Sample without replacement then slice into splits
329309 result : List [List [int ]] = [[] for _ in abs_sizes ]
330310
331311 for name in dataset_names :
332312 pool = np .array (plan_indices_by_name [name ])
333313 total_needed = sum (c [name ] for c in counts_per_split )
314+
315+ if total_needed > len (pool ):
316+ raise ValueError (
317+ f"Not enough samples in '{ name } ': "
318+ f"need { total_needed } , have { len (pool )} ."
319+ )
320+
334321 chosen = rng .choice (pool , size = total_needed , replace = False )
335322
336323 offset = 0
@@ -339,75 +326,42 @@ def split(self, dataset, *split_sizes) -> List[List[int]]:
339326 result [split_idx ].extend (chosen [offset : offset + n ].tolist ())
340327 offset += n
341328
342- # Shuffle each split so datasets are interleaved, not blocked by source
329+ # Shuffle so datasets are interleaved within each split
343330 for split_indices in result :
344331 rng .shuffle (split_indices )
345332
346333 return result
347334
348335 @staticmethod
349- def _counts_from_proportions (
350- split_size : int ,
351- proportions : Dict [str , float ],
352- names : List [str ],
353- ) -> Dict [str , int ]:
354- """Largest-remainder allocation of split_size across datasets."""
355- raw = {n : proportions [n ] * split_size for n in names }
356- base = {n : int (np .floor (raw [n ])) for n in names }
357- remainder = split_size - sum (base .values ())
358-
359- if remainder > 0 :
360- order = sorted (names , key = lambda n : raw [n ] - base [n ], reverse = True )
361- for i in range (remainder ):
362- base [order [i % len (order )]] += 1
363-
364- return base
365-
366-
367- """
368- - MD17: 200,000 samples
369- - rMD17: 80,000 samples
370- - Total merged: 300,000
371- - `num_train=0.8` → 240,000 train samples 120
372- - `num_val=0.1` → 30,000 val samples
373- - `num_test=0.1` → 30,000 test samples
374- - Proportions: `{"md17": 0.7, "rmd17": 0.3}`
375-
376- Step 1 — Normalise proportions**
377-
378- md17: 0.7 / (0.7+0.3) = 0.7
379- rmd17: 0.3 / (0.7+0.3) = 0.3
380-
381- Step 2 — Figure out how many samples per dataset per split
382-
383- For train (240,000 total):
384- md17: 0.7 x 240,000 = 168,000
385- rmd17: 0.3 x 240,000 = 72,000
386-
387- For val (30,000 total):
388- md17: 0.7 x 30,000 = 21,000
389- rmd17: 0.3 x 30,000 = 9,000
390-
391- For test (30,000 total):
392- md17: 0.7 x 30,000 = 21,000
393- rmd17: 0.3 x 30,000 = 9,000
394-
395- Step 3 — Check availability
396- md17 needs: 168,000 + 21,000 + 21,000 = 210,000 — have 200,000 → raises error
397- rmd17 needs: 72,000 + 9,000 + 9,000 = 90,000 — have 100,000
336+ def _proportional_counts (
337+ plan_indices_by_name : Dict [str , List [int ]],
338+ abs_sizes : List [int ],
339+ norm : Dict [str , float ],
340+ dataset_names : List [str ],
341+ ) -> List [Dict [str , int ]]:
342+ """
343+ Partition each dataset's full pool across the splits in the same
344+ train/val/test ratio as abs_sizes (largest-remainder rounding), so every
345+ sample lands in exactly one split. ``norm`` is accepted but unused here —
345+ DatasetBalancedSampler enforces the target proportions during training.
345+ NOTE(review): since whole pools are consumed, absolute num_train/num_val/
345+ num_test counts are honored only as a ratio — confirm this is intended.
346+ """
347+ total = sum (abs_sizes )
348+ split_ratios = [s / total for s in abs_sizes ]
398349
399- Step 4 — Build per-name index pools
400- plan_indices_by_name = {
401- "md17": [0, 1, 2, ..., 199999], # positions in the plan
402- "rmd17": [200000, 200001, ..., 299999],
403- }
350+ counts_per_split : List [Dict [str , int ]] = [{} for _ in abs_sizes ]
404351
405- Step 5 — Sample without replacement
406- md17 chosen = rng.choice(200000 indices, size=210000) # error, not enough
407- rmd17 chosen = rng.choice(100000 indices, size=90000, replace=False)
408- → first 72000 go to train
409- → next 9000 go to val
410- → last 9000 go to test
352+ for name in dataset_names :
353+ available = len (plan_indices_by_name [name ])
354+ # split this dataset's pool by the same train/val/test ratio
355+ raw = [r * available for r in split_ratios ]
356+ base = [int (np .floor (r )) for r in raw ]
357+ remainder = available - sum (base )
358+ # distribute remainder by largest fractional part
359+ order = sorted (
360+ range (len (abs_sizes )), key = lambda i : raw [i ] - base [i ], reverse = True
361+ )
362+ for i in range (remainder ):
363+ base [order [i % len (order )]] += 1
364+ for split_idx , count in enumerate (base ):
365+ counts_per_split [split_idx ][name ] = count
411366
412- Step 6 — Shuffle each split
413- """
367+ return counts_per_split
0 commit comments