Commit 4485729

refactor: enhance ProportionalSplit implementation
1 parent: 076a9e3

1 file changed: src/schnetpack/data/splitting.py
Lines changed: 45 additions & 91 deletions
@@ -255,14 +255,10 @@ class ProportionalSplit(SplittingStrategy):
     """
     Splitting strategy for MergedDataset that preserves a fixed per-dataset
     proportion in every split (train / val / test).
-
-    ##NOTE: ASK STEFAAN
-    Sampling is without replacement — no sample appears in more than one split.
-
     Args:
         proportions: mapping from dataset name to relative weight.
-            Values are normalised to sum=1 internally, so
-            {"md17": 1, "rmd17": 1} == {"md17": 0.5, "rmd17": 0.5}.
+            Normalised to sum=1 internally, so
+            {"md17": 1, "rmd17": 9} == {"md17": 0.1, "rmd17": 0.9}.
         seed: random seed for reproducible sampling.
     """

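For reference, the normalisation the updated docstring describes is just division by the summed weights. A minimal standalone sketch (illustrative values, mirroring the norm computation inside split()):

proportions = {"md17": 1.0, "rmd17": 9.0}
total = float(sum(proportions.values()))
norm = {name: w / total for name, w in proportions.items()}
print(norm)  # {'md17': 0.1, 'rmd17': 0.9}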
@@ -274,16 +270,14 @@ def __init__(self, proportions: Dict[str, float], seed: int = 42) -> None:
     def split(self, dataset, *split_sizes) -> List[List[int]]:
         """
         Args:
-            dataset: a MergedDataset instance.
+            dataset: a MergedDataset instance (duck-typed via plan + datasets).
             *split_sizes: sizes for each split (absolute or fractional),
                 forwarded directly from AtomsDataModuleV2.

         Returns:
             List of index lists into dataset.plan, one per split.
         """
-        from schnetpack.datasets.merge_db import MergedDataset
-
-        if not isinstance(dataset, MergedDataset):
+        if not hasattr(dataset, "plan") or not hasattr(dataset, "datasets"):
             raise ValueError(
                 "ProportionalSplit only works with MergedDataset instances."
             )
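The isinstance check is replaced by a duck-typed one: any object exposing plan and datasets is accepted. A hypothetical test stub satisfying that contract, assuming plan holds (dataset_name, local_index) pairs as iterated later in split() and that datasets maps names to the underlying datasets:

class FakeMerged:
    """Hypothetical stand-in for MergedDataset (not part of the commit)."""

    def __init__(self):
        # plan: (dataset_name, local_index) pairs, scanned by split()
        self.plan = [("md17", i) for i in range(8)] + [("rmd17", i) for i in range(2)]
        # dataset contents are never touched by split(); only the names matter here
        self.datasets = {"md17": None, "rmd17": None}

    def __len__(self):
        # needed by absolute_split_sizes(len(dataset), ...)
        return len(self.plan)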
@@ -299,38 +293,31 @@ def split(self, dataset, *split_sizes) -> List[List[int]]:
299293
total = float(sum(self.proportions[n] for n in dataset_names))
300294
norm = {n: self.proportions[n] / total for n in dataset_names}
301295

302-
# Resolve fractional sizes to absolute counts
296+
# Resolve fractional/absolute split sizes to absolute counts
303297
abs_sizes = absolute_split_sizes(len(dataset), list(split_sizes))
304298

305-
# Per-dataset counts for each split
306-
counts_per_split = [
307-
self._counts_from_proportions(
308-
size, norm, dataset_names
309-
) ## largest-remainder method for safety
310-
for size in abs_sizes
311-
]
312-
313299
# Build per-name pools: positions in dataset.plan
314300
plan_indices_by_name: Dict[str, List[int]] = {n: [] for n in dataset_names}
315301
for pos, (dataset_name, _) in enumerate(dataset.plan):
316302
plan_indices_by_name[dataset_name].append(pos)
317303

318-
# Validate we have enough samples per dataset across all splits
319-
for name in dataset_names:
320-
needed = sum(c[name] for c in counts_per_split)
321-
available = len(plan_indices_by_name[name])
322-
if needed > available:
323-
raise ValueError(
324-
f"Not enough samples in '{name}': "
325-
f"need {needed}, have {available}."
326-
)
304+
counts_per_split = self._proportional_counts(
305+
plan_indices_by_name, abs_sizes, norm, dataset_names
306+
)
327307

328308
# Sample without replacement then slice into splits
329309
result: List[List[int]] = [[] for _ in abs_sizes]
330310

331311
for name in dataset_names:
332312
pool = np.array(plan_indices_by_name[name])
333313
total_needed = sum(c[name] for c in counts_per_split)
314+
315+
if total_needed > len(pool):
316+
raise ValueError(
317+
f"Not enough samples in '{name}': "
318+
f"need {total_needed}, have {len(pool)}."
319+
)
320+
334321
chosen = rng.choice(pool, size=total_needed, replace=False)
335322

336323
offset = 0
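The draw-once-then-slice pattern above is what guarantees sampling without replacement across splits: each dataset's whole quota is drawn in a single rng.choice call, and the result is cut into consecutive chunks. A self-contained sketch with made-up counts, assuming rng is a NumPy Generator:

import numpy as np

rng = np.random.default_rng(42)
pool = np.arange(100)      # plan positions belonging to one dataset
counts = [70, 15, 15]      # that dataset's train/val/test quota

chosen = rng.choice(pool, size=sum(counts), replace=False)

offset, splits = 0, []
for n in counts:
    splits.append(chosen[offset : offset + n].tolist())
    offset += n

# every drawn position lands in exactly one split
assert sum(map(len, splits)) == len({i for s in splits for i in s}) == 100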
@@ -339,75 +326,42 @@ def split(self, dataset, *split_sizes) -> List[List[int]]:
                 result[split_idx].extend(chosen[offset : offset + n].tolist())
                 offset += n

-        # Shuffle each split so datasets are interleaved, not blocked by source
+        # Shuffle so datasets are interleaved within each split
         for split_indices in result:
             rng.shuffle(split_indices)

         return result

     @staticmethod
-    def _counts_from_proportions(
-        split_size: int,
-        proportions: Dict[str, float],
-        names: List[str],
-    ) -> Dict[str, int]:
-        """Largest-remainder allocation of split_size across datasets."""
-        raw = {n: proportions[n] * split_size for n in names}
-        base = {n: int(np.floor(raw[n])) for n in names}
-        remainder = split_size - sum(base.values())
-
-        if remainder > 0:
-            order = sorted(names, key=lambda n: raw[n] - base[n], reverse=True)
-            for i in range(remainder):
-                base[order[i % len(order)]] += 1
-
-        return base
-
-
-    """
-    - MD17: 200,000 samples
-    - rMD17: 100,000 samples
-    - Total merged: 300,000
-    - `num_train=0.8` → 240,000 train samples
-    - `num_val=0.1` → 30,000 val samples
-    - `num_test=0.1` → 30,000 test samples
-    - Proportions: `{"md17": 0.7, "rmd17": 0.3}`
-
-    Step 1 — Normalise proportions
-
-        md17: 0.7 / (0.7+0.3) = 0.7
-        rmd17: 0.3 / (0.7+0.3) = 0.3
-
-    Step 2 — Figure out how many samples per dataset per split
-
-        For train (240,000 total):
-            md17: 0.7 x 240,000 = 168,000
-            rmd17: 0.3 x 240,000 = 72,000
-
-        For val (30,000 total):
-            md17: 0.7 x 30,000 = 21,000
-            rmd17: 0.3 x 30,000 = 9,000
-
-        For test (30,000 total):
-            md17: 0.7 x 30,000 = 21,000
-            rmd17: 0.3 x 30,000 = 9,000
-
-    Step 3 — Check availability
-        md17 needs: 168,000 + 21,000 + 21,000 = 210,000 — have 200,000 → raises error
-        rmd17 needs: 72,000 + 9,000 + 9,000 = 90,000 — have 100,000
+    def _proportional_counts(
+        plan_indices_by_name: Dict[str, List[int]],
+        abs_sizes: List[int],
+        norm: Dict[str, float],
+        dataset_names: List[str],
+    ) -> List[Dict[str, int]]:
+        """
+        Split each dataset's available pool across train/val/test
+        using the same ratio as abs_sizes, ignoring target proportions.
+        DatasetBalancedSampler handles the proportion imbalance during training.
+        """
+        total = sum(abs_sizes)
+        split_ratios = [s / total for s in abs_sizes]

-    Step 4 — Build per-name index pools
-        plan_indices_by_name = {
-            "md17": [0, 1, 2, ..., 199999],  # positions in the plan
-            "rmd17": [200000, 200001, ..., 299999],
-        }
+        counts_per_split: List[Dict[str, int]] = [{} for _ in abs_sizes]

-    Step 5 — Sample without replacement
-        md17 chosen = rng.choice(200000 indices, size=210000)  # error, not enough
-        rmd17 chosen = rng.choice(100000 indices, size=90000, replace=False)
-        → first 72000 go to train
-        → next 9000 go to val
-        → last 9000 go to test
+        for name in dataset_names:
+            available = len(plan_indices_by_name[name])
+            # split this dataset's pool by the same train/val/test ratio
+            raw = [r * available for r in split_ratios]
+            base = [int(np.floor(r)) for r in raw]
+            remainder = available - sum(base)
+            # distribute remainder by largest fractional part
+            order = sorted(
+                range(len(abs_sizes)), key=lambda i: raw[i] - base[i], reverse=True
+            )
+            for i in range(remainder):
+                base[order[i % len(order)]] += 1
+            for split_idx, count in enumerate(base):
+                counts_per_split[split_idx][name] = count

-    Step 6 — Shuffle each split
-    """
+        return counts_per_split
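The new _proportional_counts keeps the largest-remainder rounding from the deleted helper but applies it to each dataset's available pool rather than to a proportion-derived target, leaving the balancing to DatasetBalancedSampler as its docstring notes. A worked sketch of that rounding with illustrative numbers:

import numpy as np

available = 7                    # pool size for one dataset
split_ratios = [0.8, 0.1, 0.1]   # train/val/test ratio derived from abs_sizes

raw = [r * available for r in split_ratios]   # [5.6, 0.7, 0.7]
base = [int(np.floor(r)) for r in raw]        # [5, 0, 0]; two samples unassigned
remainder = available - sum(base)             # 2
# hand the leftovers to the splits with the largest fractional parts
order = sorted(range(len(raw)), key=lambda i: raw[i] - base[i], reverse=True)
for i in range(remainder):
    base[order[i % len(order)]] += 1

print(base)  # [5, 1, 1]; sums exactly to available

Every sample in the pool ends up assigned, so the per-split counts always sum to the pool size, never over- or under-allocating.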
