@@ -1,8 +1,9 @@
-from typing import Optional, List, Dict, Tuple, Union
+from typing import Optional, List, Dict, Union
 import math
 import torch
 import numpy as np
 
+
 __all__ = [
     "SplittingStrategy",
     "RandomSplit",
@@ -280,7 +281,6 @@ def split(self, dataset, *split_sizes) -> List[List[int]]:
         Returns:
             List of index lists into dataset.plan, one per split.
         """
-        # Import here to avoid circular import (MergedDataset imports splitting)
         from schnetpack.datasets.merge_db import MergedDataset
 
         if not isinstance(dataset, MergedDataset):
@@ -302,9 +302,11 @@ def split(self, dataset, *split_sizes) -> List[List[int]]:
         # Resolve fractional sizes to absolute counts
         abs_sizes = absolute_split_sizes(len(dataset), list(split_sizes))
 
-        # Per-dataset counts for each split via largest-remainder method
+        # Per-dataset counts for each split
         counts_per_split = [
-            self._counts_from_proportions(size, norm, dataset_names)
+            self._counts_from_proportions(
+                size, norm, dataset_names
+            )  # largest-remainder method for safety
             for size in abs_sizes
         ]
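
The `absolute_split_sizes` call above turns fractional split sizes into absolute counts before any per-dataset apportionment happens. A minimal sketch of that resolution, assuming fraction-vs-integer semantics (`resolve_split_sizes` is a hypothetical stand-in, not schnetpack's actual helper):

```python
from typing import List, Union


def resolve_split_sizes(dsize: int, split_sizes: List[Union[int, float]]) -> List[int]:
    """Hypothetical stand-in: fractions in (0, 1) scale by the dataset
    length; integer sizes pass through unchanged."""
    return [
        int(s * dsize) if isinstance(s, float) and 0 < s < 1 else int(s)
        for s in split_sizes
    ]


print(resolve_split_sizes(300_000, [0.8, 0.1, 0.1]))  # [240000, 30000, 30000]
```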
@@ -360,3 +362,52 @@ def _counts_from_proportions(
             base[order[i % len(order)]] += 1
 
         return base
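
This hunk only shows the tail of `_counts_from_proportions`, but `base[order[i % len(order)]] += 1` is the hallmark of largest-remainder apportionment: floor each share, then hand the leftover units out in order of descending fractional remainder. A self-contained sketch under that assumption (not the method's verbatim body):

```python
import math
from typing import Dict


def counts_largest_remainder(total: int, proportions: Dict[str, float]) -> Dict[str, int]:
    """Apportion `total` across names so the counts sum exactly to `total`."""
    exact = {name: p * total for name, p in proportions.items()}
    base = {name: math.floor(x) for name, x in exact.items()}
    leftover = total - sum(base.values())
    # Names ranked by descending fractional remainder; the modulo wrap
    # mirrors the `i % len(order)` indexing in the hunk above.
    order = sorted(base, key=lambda n: exact[n] - base[n], reverse=True)
    for i in range(leftover):
        base[order[i % len(order)]] += 1
    return base


print(counts_largest_remainder(30_000, {"md17": 0.7, "rmd17": 0.3}))
# {'md17': 21000, 'rmd17': 9000}
```

The floor-then-top-up pass is presumably the "for safety" in the new comment: in binary floating point `0.7 * 30_000` floors to 20,999, and the remainder pass restores the counts so each split still sums to its requested size.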
+
+
+"""
+Worked example: proportional splitting of a merged MD17 + rMD17 dataset.
+
+- MD17: 200,000 samples
+- rMD17: 100,000 samples
+- Total merged: 300,000
+- `num_train=0.8` → 240,000 train samples
+- `num_val=0.1` → 30,000 val samples
+- `num_test=0.1` → 30,000 test samples
+- Proportions: `{"md17": 0.7, "rmd17": 0.3}`
+
+Step 1: Normalise proportions
+
+    md17:  0.7 / (0.7 + 0.3) = 0.7
+    rmd17: 0.3 / (0.7 + 0.3) = 0.3
+
+Step 2: Per-dataset counts for each split
+
+    For train (240,000 total):
+        md17:  0.7 x 240,000 = 168,000
+        rmd17: 0.3 x 240,000 = 72,000
+
+    For val (30,000 total):
+        md17:  0.7 x 30,000 = 21,000
+        rmd17: 0.3 x 30,000 = 9,000
+
+    For test (30,000 total):
+        md17:  0.7 x 30,000 = 21,000
+        rmd17: 0.3 x 30,000 = 9,000
+
+Step 3: Check availability
+
+    md17 needs:  168,000 + 21,000 + 21,000 = 210,000, but only 200,000 exist → raises an error
+    rmd17 needs: 72,000 + 9,000 + 9,000 = 90,000, and 100,000 exist → fine
+
+Step 4: Build per-name index pools
+
+    plan_indices_by_name = {
+        "md17": [0, 1, 2, ..., 199999],          # positions in the plan
+        "rmd17": [200000, 200001, ..., 299999],
+    }
+
+Step 5: Sample without replacement
+
+    md17 chosen  = rng.choice(200000 indices, size=210000)  # error, not enough
+    rmd17 chosen = rng.choice(100000 indices, size=90000, replace=False)
+        → first 72,000 go to train
+        → next 9,000 go to val
+        → last 9,000 go to test
+
+Step 6: Shuffle each split
+"""
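
The walkthrough maps onto a short runnable sketch. This reproduces Steps 3 through 6 for the rmd17 pool only, since md17 fails the Step 3 check as noted; the variable names are illustrative, not the splitter's actual internals:

```python
import numpy as np

rng = np.random.default_rng(42)

# Step 4: rmd17's positions in the merged plan (md17 occupies 0..199999)
pool = np.arange(200_000, 300_000)
need = {"train": 72_000, "val": 9_000, "test": 9_000}

# Step 3: availability check -- md17's analogous check (210,000 > 200,000) raises
total = sum(need.values())
if total > len(pool):
    raise ValueError(f"rmd17 needs {total} samples but only {len(pool)} exist")

# Step 5: one draw without replacement, then contiguous slices per split
chosen = rng.choice(pool, size=total, replace=False)
splits, start = {}, 0
for name, n in need.items():
    splits[name] = chosen[start : start + n]
    start += n

# Step 6: shuffle each split in place
for part in splits.values():
    rng.shuffle(part)

print({k: len(v) for k, v in splits.items()})
# {'train': 72000, 'val': 9000, 'test': 9000}
```

Strictly, `rng.choice(..., replace=False)` already returns its winners in random order, so the Step 6 shuffle is belt and braces here rather than a correctness requirement.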