Skip to content

Commit 44bdbcf

Browse files
committed
refactor: clean up ProportionalSplit implementation and enhance MergedDataset initialization
1 parent be816b2 commit 44bdbcf

2 files changed

Lines changed: 61 additions & 8 deletions

File tree

src/schnetpack/data/splitting.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import Optional, List, Dict, Tuple, Union
1+
from typing import Optional, List, Dict, Union
22
import math
33
import torch
44
import numpy as np
55

6+
67
__all__ = [
78
"SplittingStrategy",
89
"RandomSplit",
@@ -280,7 +281,6 @@ def split(self, dataset, *split_sizes) -> List[List[int]]:
280281
Returns:
281282
List of index lists into dataset.plan, one per split.
282283
"""
283-
# Import here to avoid circular import (MergedDataset imports splitting)
284284
from schnetpack.datasets.merge_db import MergedDataset
285285

286286
if not isinstance(dataset, MergedDataset):
@@ -302,9 +302,11 @@ def split(self, dataset, *split_sizes) -> List[List[int]]:
302302
# Resolve fractional sizes to absolute counts
303303
abs_sizes = absolute_split_sizes(len(dataset), list(split_sizes))
304304

305-
# Per-dataset counts for each split via largest-remainder method
305+
# Per-dataset counts for each split
306306
counts_per_split = [
307-
self._counts_from_proportions(size, norm, dataset_names)
307+
self._counts_from_proportions(
308+
size, norm, dataset_names
309+
) ## largest-remainder method for safety
308310
for size in abs_sizes
309311
]
310312

@@ -360,3 +362,52 @@ def _counts_from_proportions(
360362
base[order[i % len(order)]] += 1
361363

362364
return base
365+
366+
367+
"""
368+
- MD17: 200,000 samples
369+
- rMD17: 100,000 samples
370+
- Total merged: 300,000
371+
- `num_train=0.8` → 240,000 train samples
372+
- `num_val=0.1` → 30,000 val samples
373+
- `num_test=0.1` → 30,000 test samples
374+
- Proportions: `{"md17": 0.7, "rmd17": 0.3}`
375+
376+
Step 1 — Normalise proportions
377+
378+
md17: 0.7 / (0.7+0.3) = 0.7
379+
rmd17: 0.3 / (0.7+0.3) = 0.3
380+
381+
Step 2 — Figure out how many samples per dataset per split
382+
383+
For train (240,000 total):
384+
md17: 0.7 x 240,000 = 168,000
385+
rmd17: 0.3 x 240,000 = 72,000
386+
387+
For val (30,000 total):
388+
md17: 0.7 x 30,000 = 21,000
389+
rmd17: 0.3 x 30,000 = 9,000
390+
391+
For test (30,000 total):
392+
md17: 0.7 x 30,000 = 21,000
393+
rmd17: 0.3 x 30,000 = 9,000
394+
395+
Step 3 — Check availability
396+
md17 needs: 168,000 + 21,000 + 21,000 = 210,000 — have 200,000 → raises error
397+
rmd17 needs: 72,000 + 9,000 + 9,000 = 90,000 — have 100,000
398+
399+
Step 4 — Build per-name index pools
400+
plan_indices_by_name = {
401+
"md17": [0, 1, 2, ..., 199999], # positions in the plan
402+
"rmd17": [200000, 200001, ..., 299999],
403+
}
404+
405+
Step 5 — Sample without replacement
406+
md17 chosen = rng.choice(200000 indices, size=210000) # error, not enough
407+
rmd17 chosen = rng.choice(100000 indices, size=90000, replace=False)
408+
→ first 72000 go to train
409+
→ next 9000 go to val
410+
→ last 9000 go to test
411+
412+
Step 6 — Shuffle each split
413+
"""

src/schnetpack/datasets/merge_db.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def __init__(
8989
if not datasets:
9090
raise AtomsDataError("datasets must not be empty.")
9191

92-
self._validate_compatibility(datasets)
93-
self._warn_if_component_transforms(datasets)
92+
self._validate_compatibility(datasets) ## removed
93+
self._warn_if_component_transforms(datasets) ## Not necessary
9494

9595
self.datasets = datasets
9696
self.add_source_index = add_source_index
@@ -99,7 +99,9 @@ def __init__(
9999
self._dataset_ids: Dict[str, int] = {name: i for i, name in enumerate(datasets)}
100100

101101
self.plan: List[Tuple[str, int]] = [
102-
(name, idx) for name, ds in datasets.items() for idx in range(len(ds))
102+
(name, idx)
103+
for name, ds in datasets.items()
104+
for idx in range(len(ds)) # ('rmd17',5)
103105
]
104106

105107
self.transforms: List[Transform] = list(transforms or [])
@@ -264,7 +266,7 @@ def __len__(self) -> int:
264266
return len(self.plan)
265267

266268
def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
267-
dataset_name, index = self.plan[i]
269+
dataset_name, index = self.plan[i] # ("rmd17", 5)
268270
component_ds = self.datasets[dataset_name]
269271

270272
saved_transforms = component_ds.transforms

0 commit comments

Comments
 (0)