Commit 67a76e1

Bootstrap HF Datasets (#20)
* hf datasets appear to be working
* Fixed mypy and ruff
* data sources fixed, but skipped data source test needs to be fixed
* bug fixes
* add stop token to generate
* The example showed pip still, we use uv
* fix mypy
* CR feedback
* fix lint issues

---------

Co-authored-by: dariocazzani <dariocazzani@gmail.com>
1 parent 5953358 commit 67a76e1

17 files changed

Lines changed: 1124 additions & 286 deletions


.github/workflows/lint.yml

Lines changed: 4 additions & 1 deletion
@@ -1,5 +1,8 @@
 name: Lint
-on: [push, pull_request]
+on:
+  push:
+    branches: [main]
+  pull_request:
 jobs:
   lint:
     runs-on: ubuntu-latest

examples/simple.py

Lines changed: 13 additions & 18 deletions
@@ -23,13 +23,13 @@
 # Import ScratchGPT components
 from scratchgpt import (
     CharTokenizer,
-    FileDataSource,
     ScratchGPTArchitecture,
     ScratchGPTConfig,
     ScratchGPTTraining,
     Trainer,
     TransformerLanguageModel,
 )
+from scratchgpt.data import create_data_source
 
 
 def download_darwin_text(data_file: Path) -> None:
@@ -67,24 +67,21 @@ def create_simple_config() -> ScratchGPTConfig:
         random_seed=1337,
     )
 
-    return ScratchGPTConfig(
-        architecture=architecture,
-        training=training
-    )
+    return ScratchGPTConfig(architecture=architecture, training=training)
 
 
 def prepare_text_for_tokenizer(data_file: Path) -> str:
     """Read the text file for tokenization."""
     print(f"Reading text from: {data_file}")
 
-    with open(data_file, encoding='utf-8') as f:
+    with open(data_file, encoding="utf-8") as f:
         text = f.read()
 
     print(f"Text length: {len(text):,} characters")
     return text
 
 
-def main():
+def main() -> None:
     print("ScratchGPT Simple Training Example")
     print("=" * 50)
 
@@ -104,7 +101,7 @@ def main():
     print(f"Vocabulary size: {tokenizer.vocab_size}")
 
     # Alternative: Use a pre-trained tokenizer like GPT-2
-    # This requires: pip install 'scratchgpt[hf-tokenizers]'
+    # This requires: uv sync --extra hf-tokenizers
     #
     # from scratchgpt import HuggingFaceTokenizer
     # tokenizer = HuggingFaceTokenizer.from_hub("gpt2")
@@ -118,8 +115,10 @@ def main():
     # Step 3: Create configuration
     config = create_simple_config()
     config.architecture.vocab_size = tokenizer.vocab_size
-    print(f"Model configuration: {config.architecture.embedding_size}D embeddings, "
-          f"{config.architecture.num_blocks} blocks, {config.architecture.num_heads} heads")
+    print(
+        f"Model configuration: {config.architecture.embedding_size}D embeddings, "
+        f"{config.architecture.num_blocks} blocks, {config.architecture.num_heads} heads"
+    )
 
     # Step 4: Setup model and training
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -130,22 +129,22 @@ def main():
     print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
 
     optimizer = AdamW(model.parameters(), lr=config.training.learning_rate)
-    data_source = FileDataSource(data_file)
+    data_source = create_data_source(str(data_file))
 
     # Step 5: Create trainer and start training
     trainer = Trainer(
         model=model,
         config=config.training,
         optimizer=optimizer,
         experiment_path=experiment_dir,
-        device=device
+        device=device,
     )
 
     print("\nStarting training...")
     print("(Press Ctrl-C to stop training early and see text generation)")
 
     try:
-        trainer.train(data=data_source, tokenizer=tokenizer)
+        trainer.train(data_source=data_source, tokenizer=tokenizer)
         print("\nTraining completed successfully!")
     except KeyboardInterrupt:
         print("\n\nTraining interrupted by user. Moving to text generation with current model state...")
@@ -154,11 +153,7 @@ def main():
     print("\nTesting text generation:")
     model.eval()
 
-    test_prompts = [
-        "Natural selection",
-        "The origin of species",
-        "Darwin observed"
-    ]
+    test_prompts = ["Natural selection", "The origin of species", "Darwin observed"]
 
     for prompt in test_prompts:
         print(f"\nPrompt: '{prompt}'")
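The example's migration generalizes to any script still using the removed file-based sources. A minimal before/after sketch of the two changed call sites (note the `trainer.train` keyword rename from `data` to `data_source`):

    # Before this commit (removed API):
    #   data_source = FileDataSource(data_file)
    #   trainer.train(data=data_source, tokenizer=tokenizer)

    # After this commit:
    from scratchgpt.data import create_data_source

    data_source = create_data_source(str(data_file))  # path to a local text file
    trainer.train(data_source=data_source, tokenizer=tokenizer)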

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -24,6 +24,7 @@ classifiers = [
 license = {file = "LICENSE"}
 
 dependencies = [
+    "datasets>=4.0.0",
     "numpy>=2.3.2",
     "ptflops>=0.7.5",
     "pydantic-settings>=2.10.1",
@@ -69,7 +70,7 @@ strict = true
 exclude = [".venv"]
 
 [[tool.mypy.overrides]]
-module = ["ptflops", "tokenizers.*", "huggingface_hub.*"]
+module = ["ptflops", "tokenizers.*", "huggingface_hub.*", "datasets.*"]
 ignore_missing_imports = true
 
 [tool.ruff]

scratchgpt/__init__.py

Lines changed: 4 additions & 13 deletions
@@ -7,13 +7,8 @@
     ScratchGPTConfig,
     ScratchGPTTraining,
 )
-from scratchgpt.data.datasource import (
-    ByteSizableDataSource,
-    DataSource,
-    FileDataSource,
-    FolderDataSource,
-    LineByLineFileDataSource,
-)
+from scratchgpt.data.datasource import DataSource
+from scratchgpt.data.hf_datasource import HFDataSource
 from scratchgpt.model.model import TransformerLanguageModel
 from scratchgpt.model_io import (
     ModelLoadFailedError,
@@ -32,7 +27,7 @@
 )
 from scratchgpt.tokenizer.char_tokenizer import CharTokenizer, Utf8Tokenizer
 from scratchgpt.tokenizer.hf_tokenizer import HuggingFaceTokenizer
-from scratchgpt.training.trainer import Trainer, get_dtype_for_vocab_size
+from scratchgpt.training.trainer import Trainer
 
 __all__ = [
     # Core Model and Config
@@ -42,10 +37,7 @@
     "ScratchGPTTraining",
     # Data Sources
     "DataSource",
-    "ByteSizableDataSource",
-    "FileDataSource",
-    "FolderDataSource",
-    "LineByLineFileDataSource",
+    "HFDataSource",
     # Model I/O
     "load_model",
     "load_tokenizer",
@@ -64,5 +56,4 @@
     "HuggingFaceTokenizer",
     # Training
     "Trainer",
-    "get_dtype_for_vocab_size",
 ]
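For downstream users, the package root now exposes only the protocol and the HF-backed source; a quick sketch of what keeps working versus what this commit removes:

    # Still importable from the package root:
    from scratchgpt import DataSource, HFDataSource, Trainer

    # Removed from the public API (these now raise ImportError):
    # from scratchgpt import FileDataSource, FolderDataSource
    # from scratchgpt import LineByLineFileDataSource, get_dtype_for_vocab_size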

scratchgpt/config.py

Lines changed: 18 additions & 2 deletions
@@ -1,7 +1,7 @@
 import math
-from typing import Annotated
+from typing import Annotated, Self
 
-from pydantic import AfterValidator, Field
+from pydantic import AfterValidator, Field, model_validator
 from pydantic_settings import (
     BaseSettings,
     PydanticBaseSettingsSource,
@@ -18,6 +18,10 @@ def ensure_split_is_valid(v: tuple[float, float]) -> tuple[float, float]:
     is_valid_split = math.isclose(splits_sum, 1.0)
     if not is_valid_split:
         raise ValueError("Invalid data 'split'")
+
+    val_split = v[1]
+    if val_split == 0.0:
+        raise ValueError("You can't have 0 sized validation split.")
     return v
 
 
@@ -36,6 +40,18 @@ class ScratchGPTArchitecture(BaseSettings):
     num_blocks: int = 6
     vocab_size: int | None = None
 
+    @model_validator(mode="after")
+    def validate_embedding_and_heads(self) -> Self:
+        """
+        Ensures that the embedding_size is perfectly divisible by the number of attention heads.
+        """
+        if self.embedding_size % self.num_heads != 0:
+            raise ValueError(
+                f"Incompatible model architecture: embedding_size ({self.embedding_size}) "
+                f"must be divisible by num_heads ({self.num_heads})."
+            )
+        return self
+
     model_config = SettingsConfigDict(
         env_prefix="ARCHITECTURE_",
         extra="allow",

scratchgpt/core/types.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from torch import Tensor
 from torch.utils.data import DataLoader
 
-TensorTupleLoader = DataLoader[tuple[Tensor, Tensor]]
+DictTensorLoader = DataLoader[dict[str, Tensor]]
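The renamed alias tracks the batch-format change from `(input, target)` tuples to string-keyed dicts. A hedged sketch of how a `DictTensorLoader` would be consumed (the `"input_ids"`/`"labels"` key names are an assumption, not something this diff confirms):

    from torch import Tensor

    from scratchgpt.core.types import DictTensorLoader

    def run_epoch(loader: DictTensorLoader) -> None:
        for batch in loader:  # each batch is a dict[str, Tensor]
            inputs: Tensor = batch["input_ids"]  # hypothetical key name
            targets: Tensor = batch["labels"]    # hypothetical key name
            ...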

scratchgpt/data/__init__.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+from typing import Any
+
+from scratchgpt.data.datasource import DataSource
+from scratchgpt.data.hf_datasource import HFDataSource
+
+
+def create_data_source(
+    path_or_name: str,
+    split: str = "train",
+    streaming: bool = False,
+    text_column: str = "text",
+    **kwargs: Any,
+) -> DataSource:
+    """
+    Create a data source from a path or dataset name.
+
+    Examples:
+        # HuggingFace Hub dataset
+        >>> ds = create_data_source("wikitext-2-raw-v1")
+
+        # Local text file
+        >>> ds = create_data_source("data.txt")
+
+        # Local CSV file
+        >>> ds = create_data_source("data.csv", text_column="content")
+
+        # Folder of text files
+        >>> ds = create_data_source("./texts/")
+
+        # Streaming large dataset
+        >>> ds = create_data_source("openwebtext", streaming=True)
+
+    Args:
+        path_or_name: HF Hub dataset name or path to local data
+        split: Dataset split to use
+        streaming: Whether to use streaming mode
+        text_column: Column name containing text
+        **kwargs: Additional arguments for HFDataSource
+
+    Returns:
+        DataSource instance
+    """
+    return HFDataSource(
+        path_or_name=path_or_name,
+        split=split,
+        streaming=streaming,
+        text_column=text_column,
+        **kwargs,
+    )
+
+
+__all__ = ["DataSource", "HFDataSource", "create_data_source"]
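Every path routes through `HFDataSource`, so the factory is convenience sugar over a single constructor. By the function body above, these two calls construct the same source:

    from scratchgpt.data import HFDataSource, create_data_source

    ds_a = create_data_source("data.csv", text_column="content")
    ds_b = HFDataSource(
        path_or_name="data.csv",
        split="train",
        streaming=False,
        text_column="content",
    )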

scratchgpt/data/datasource.py

Lines changed: 18 additions & 82 deletions
@@ -1,90 +1,26 @@
-from collections.abc import Iterator
-from pathlib import Path
-from typing import Protocol, runtime_checkable
+from typing import Protocol
+
+from scratchgpt.core.types import DictTensorLoader
+from scratchgpt.tokenizer.base_tokenizer import Tokenizer
 
 
-@runtime_checkable
 class DataSource(Protocol):
     """
-    An interface for providing raw data to the Trainer.
+    A protocol for classes that can provide training and validation DataLoaders.
 
-    A DataSource is an iterable object that yields individual,
-    untokenized training samples as strings.
+    This uses structural subtyping. Any class that implements a matching
+    `get_dataloaders` method will be considered a valid DataSource.
     """
 
-    def __iter__(self) -> Iterator[str]:
-        """Returns an iterator over the raw text samples."""
+    def get_dataloaders(
+        self,
+        tokenizer: Tokenizer,
+        block_size: int,
+        batch_size: int,
+        splits: tuple[float, float],
+        random_seed: int,
+    ) -> tuple[DictTensorLoader, DictTensorLoader | None]:
+        """
+        Processes data and returns train and validation DataLoaders.
+        """
         ...
-
-
-@runtime_checkable
-class ByteSizableDataSource(DataSource, Protocol):
-    """An optional extension for DataSources that can report their total size in bytes."""
-
-    def total_bytes(self) -> int:
-        """Returns the total size of the data source in bytes."""
-        ...
-
-
-class FileDataSource(ByteSizableDataSource):
-    """Yields the entire content of a single text file as one sample."""
-
-    def __init__(self, file_path: Path):
-        if not file_path.is_file():
-            raise FileNotFoundError(f"Source file not found at: {file_path}")
-        self._file_path = file_path
-
-    def __len__(self) -> int:
-        return 1
-
-    def __iter__(self) -> Iterator[str]:
-        with open(self._file_path, encoding="utf-8", errors="ignore") as f:
-            yield f.read()
-
-    def total_bytes(self) -> int:
-        return self._file_path.stat().st_size
-
-
-class FolderDataSource(ByteSizableDataSource):
-    """Iterates through a directory and yields the content of each file."""
-
-    def __init__(self, folder_path: Path):
-        if not folder_path.is_dir():
-            raise NotADirectoryError(f"Source path is not a directory: {folder_path}")
-
-        self._file_paths = [p for p in folder_path.rglob("*") if p.is_file() and not p.name.startswith(".")]
-        print(f"✅ Found {len(self._file_paths)} files to process in {folder_path}.")
-
-    def __len__(self) -> int:
-        return len(self._file_paths)
-
-    def __iter__(self) -> Iterator[str]:
-        for file_path in self._file_paths:
-            with open(file_path, encoding="utf-8", errors="ignore") as f:
-                yield from f
-
-    def total_bytes(self) -> int:
-        return sum(p.stat().st_size for p in self._file_paths)
-
-
-class LineByLineFileDataSource(ByteSizableDataSource):
-    """Reads a text file and yields each line as a separate sample."""
-
-    def __init__(self, file_path: Path):
-        if not file_path.is_file():
-            raise FileNotFoundError(f"Source file not found at: {file_path}")
-        self._file_path = file_path
-
-        print("Pre-counting lines for progress bar...")
-        with open(self._file_path, encoding="utf-8", errors="ignore") as f:
-            self._line_count = sum(1 for _ in f)
-
-    def __len__(self) -> int:
-        return self._line_count
-
-    def __iter__(self) -> Iterator[str]:
-        with open(self._file_path, encoding="utf-8", errors="ignore") as f:
-            yield from f
-
-    def total_bytes(self) -> int:
-        return self._file_path.stat().st_size