diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..cd4f780 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,15 @@ +name: Tests +on: [pull_request] +jobs: + unit_tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v4 + with: + version: "latest" + - uses: actions/setup-python@v4 + with: + python-version: "3.12" + - run: uv sync --group dev + - run: uv run pytest tests diff --git a/README.md b/README.md index 120826a..005b994 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,8 @@ repo is educational, so the aim is to keep the code as legible as possible. [x] Switch to uv [x] Make it easy to modify with a config file -[] Extract the loss calculation from the model -[] Rename main to train +[x] Extract the loss calculation from the model +[x] Rename main to train [] Create an easy to use interface [] Create or check tokenizer interface [] Make it into a package diff --git a/pyproject.toml b/pyproject.toml index addec80..bd24064 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ dev = [ ] [tool.pytest.ini_options] -asyncio_mode = "auto" [tool.mypy] python_version = "3.12" @@ -73,6 +72,6 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project.scripts] -train = "scratchgpt.main:main" +train = "scratchgpt.train:main" infer = "scratchgpt.infer:main" tiktoken = "scratchgpt.tokenizer.tiktoken:main" diff --git a/scratchgpt/dataloader.py b/scratchgpt/dataloader.py index 7d1573e..bb4466c 100644 --- a/scratchgpt/dataloader.py +++ b/scratchgpt/dataloader.py @@ -1,10 +1,12 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Literal, override +from typing import override +import numpy as np import torch from torch import Tensor from torch.utils.data import Dataset +from tqdm import tqdm from .tokenizer.base_tokenizer import Tokenizer @@ -40,17 +42,13 @@ def __init__(self, dir_path: Path) -> None: raise 
ValueError(f"Directory path {dir_path} is not a directory") self._data = "" + file_paths = list(dir_path.rglob("*")) print(f"Loading data from {dir_path}") - total_read: int = 0 - for idx, file_path in enumerate(dir_path.rglob("*")): + for file_path in tqdm(file_paths, desc="Reading data files", unit="file"): if file_path.is_file() and not file_path.name.startswith("."): with open(file_path, encoding="utf-8") as f: self._data += f.read() + "\n" - if idx % 500 == 1: - total_read += 500 - print(f"Read {total_read} files") - print("Data Loaded") @override @@ -64,28 +62,12 @@ def __init__( text_provider: TextProvider, tokenizer: Tokenizer, block_size: int, - split: Literal["train", "validation", "test"], - train_ratio: float = 0.8, - val_ratio: float = 0.1, ) -> None: self.tokenizer = tokenizer self.block_size = block_size self.data = torch.tensor(self.tokenizer.encode(text_provider.get_text()), dtype=torch.long) - total_size = len(self.data) - train_size = int(total_size * train_ratio) - val_size = int(total_size * val_ratio) - - if split == "train": - self.data = self.data[:train_size] - elif split == "validation": - self.data = self.data[train_size : train_size + val_size] - elif split == "test": - self.data = self.data[train_size + val_size :] - else: - raise ValueError(f"Invalid split: {split}. 
class PretokenizedDataset(Dataset[tuple[Tensor, Tensor]]):
    """
    Next-token-prediction dataset backed by a flat binary file of token ids.

    Item ``i`` is the pair
    ``(tokens[i : i + block_size], tokens[i + 1 : i + block_size + 1])``,
    i.e. an input window and the same window shifted one token to the right.
    Both tensors are returned as ``torch.long``.
    """

    def __init__(
        self,
        token_file: Path,
        block_size: int,
        dtype: np.dtype,
    ) -> None:
        """
        Map ``token_file`` into memory and expose it as a tensor.

        Args:
            token_file: Path to a raw binary file of token ids.
            block_size: Number of tokens in each input/target window.
            dtype: NumPy dtype the token ids were written with.
        """
        super().__init__()
        self.block_size = block_size

        if Path(token_file).stat().st_size == 0:
            # np.memmap refuses to map a zero-byte file, but an empty token file
            # is a legitimate result of preprocessing an empty folder. Treat it
            # as an empty dataset instead of failing with an obscure error.
            all_tokens = np.empty(0, dtype=dtype)
        else:
            # mode="c" is copy-on-write: the array is writable in memory (which
            # torch.from_numpy expects) while the file on disk is never modified.
            all_tokens = np.memmap(token_file, dtype=dtype, mode="c")
        self.data = torch.from_numpy(all_tokens)

    def __len__(self) -> int:
        """Return the number of complete (block, target) pairs available."""
        # Clamp at zero so a file shorter than one block yields an empty dataset.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        """
        Return the input window starting at ``idx`` and its shifted target.

        Negative indices count from the end, as with built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            # Without this check slicing silently returns truncated tensors, and
            # plain iteration over the dataset would never terminate.
            raise IndexError(f"Index {idx} is out of range for dataset of length {length}")

        end = idx + self.block_size
        block = self.data[idx:end]
        target = self.data[idx + 1 : end + 1]

        return block.long(), target.long()
class TokenizerPreprocessor(Preprocessor):
    """
    Default pre-processor. Tokenizes a text stream chunk by chunk and writes
    the resulting token ids to a binary stream as a flat array.

    The storage dtype is the smallest unsigned integer type whose range is
    strictly larger than the tokenizer's ``vocab_size``.
    """

    def __init__(self, tokenizer: Tokenizer) -> None:
        """
        Args:
            tokenizer: Tokenizer used to encode text; its ``vocab_size``
                determines ``self.dtype``.
        """
        self.tokenizer = tokenizer
        vocab_size = self.tokenizer.vocab_size
        if vocab_size < 2**8:
            self.dtype: DTypeLike = np.uint8
        elif vocab_size < 2**16:
            self.dtype = np.uint16
        elif vocab_size < 2**32:
            self.dtype = np.uint32
        else:
            self.dtype = np.uint64
        print(f"Preprocessor initialized. Selected {np.dtype(self.dtype).name} for token storage.")

    def __call__(
        self,
        source: io.TextIOBase,
        sink: io.BufferedIOBase,
        chunk_size: int = 10 * 1024 * 1024,
        pbar: SupportsUpdate | None = None,
    ) -> None:
        """
        Read ``source`` in chunks, tokenize each chunk, and append the raw
        token bytes to ``sink``.

        Args:
            source: Text stream to read from.
            sink: Binary stream that receives the token ids as ``self.dtype``.
            chunk_size: Maximum number of characters read per iteration.
            pbar: Optional progress reporter; it is advanced by the UTF-8 byte
                length of every chunk processed.
        """
        # NOTE(review): each chunk is encoded on its own, so a tokenizer whose
        # tokens can span a chunk boundary may produce different ids than it
        # would for the whole text — confirm this is acceptable for the
        # tokenizers in use.
        while chunk := source.read(chunk_size):
            tokens = self.tokenizer.encode(chunk)
            token_array = np.array(tokens, dtype=self.dtype)
            sink.write(token_array.tobytes())
            # Compare against None rather than relying on truthiness: a tqdm
            # bar is falsy when its total is 0 and raises TypeError from
            # __bool__ when it has neither a total nor an iterable.
            if pbar is not None:
                pbar.update(len(chunk.encode("utf-8", errors="ignore")))
+ """ + + def __init__(self, tokenizer: Tokenizer) -> None: + self._preprocessor = TokenizerPreprocessor(tokenizer) + + def __call__(self, input_path: Path, output_path: Path, chunk_size: int = 10 * 1024 * 1024) -> None: + if not input_path.is_file(): + raise ValueError(f"Input path must be a file: {input_path}") + if output_path.exists(): + raise FileExistsError(f"Output path already exists: {output_path}") + + total_size = input_path.stat().st_size + + with ( + open(input_path, encoding="utf-8", errors="ignore") as source, + open(output_path, "wb") as sink, + tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Tokenizing {input_path.name}") as pbar, + ): + self._preprocessor(source, sink, chunk_size, pbar) + + print(f"Successfully preprocessed '{input_path}' to '{output_path}'") + + +class Folder2FileTokenizerPreprocessor: + """ + Orchestrates preprocessing for a directory of source files to a single destination file. + """ + + def __init__(self, tokenizer: Tokenizer) -> None: + self._preprocessor = TokenizerPreprocessor(tokenizer) + + def __call__(self, input_path: Path, output_path: Path, chunk_size: int = 10 * 1024 * 1024) -> None: + if not input_path.is_dir(): + raise ValueError(f"Input path must be a directory: {input_path}") + if output_path.exists(): + raise FileExistsError(f"Output path already exists: {output_path}") + + files_to_process = [p for p in input_path.rglob("*") if p.is_file() and not p.name.startswith(".")] + total_size = sum(p.stat().st_size for p in files_to_process) + + print(f"Found {len(files_to_process)} files to process.") + + with ( + open(output_path, "wb") as sink, + tqdm(total=total_size, unit="B", unit_scale=True, desc=f"Tokenizing Folder '{input_path.name}'") as pbar, + ): + for file_path in files_to_process: + pbar.set_postfix_str(f"Processing: {file_path.name}", refresh=True) + with open(file_path, encoding="utf-8", errors="ignore") as source: + self._preprocessor(source, sink, chunk_size, pbar) + + print(f"\nSuccessfully 
def parse_splits(value: str) -> list[float]:
    """
    Custom argparse type to validate and parse training splits.

    Splits are given as a semicolon-separated string of three floats
    (train, validation, test), each within [0.0, 1.0], that sum to 1.0.

    Args:
        value: Raw command-line value, e.g. ``"0.8;0.1;0.1"``.

    Returns:
        The three ratios in train, validation, test order.

    Raises:
        argparse.ArgumentTypeError: If the value is not three numbers, any
            ratio lies outside [0.0, 1.0], or the ratios do not sum to 1.0.
    """
    try:
        splits = [float(x) for x in value.split(";")]
        if len(splits) != 3:
            raise ValueError("Exactly three split values for train, validation, and test are required.")
        # A sum check alone accepts inputs such as "1.5;-0.5;0.0"; reject them
        # here so the failure is reported at the CLI rather than during the
        # dataset split. NaN fails this comparison as well.
        if any(not 0.0 <= ratio <= 1.0 for ratio in splits):
            raise ValueError("Each split value must be between 0.0 and 1.0.")
        total = sum(splits)
        if not math.isclose(total, 1.0):
            # Show enough digits that a near miss is not reported as "1.00".
            raise ValueError(f"Split values must sum to 1.0, but they sum to {total:.6g}.")
    except (ValueError, TypeError) as e:
        raise argparse.ArgumentTypeError(
            f"Invalid split format '{value}'. Use 'train;val;test' format (e.g., '0.8;0.1;0.1'). Error: {e}"
        ) from e
    return splits
+ """ + cached_data_path = args.experiment / "preprocessed_data.bin" + + if args.train_source.suffix == ".bin": + print(f"Loading pre-tokenized data directly from {args.train_source}") + if not args.dtype: + raise ValueError("--dtype must be specified when using a .bin file.") + return PretokenizedDataset( + token_file=args.train_source, + block_size=config.architecture.block_size, + dtype=np.dtype(args.dtype), + ) + + # For raw text, determine the best dtype based on the tokenizer's vocab size. + dtype = get_dtype_for_vocab_size(tokenizer.vocab_size) + + if cached_data_path.exists(): + print(f"Found cached preprocessed data at {cached_data_path}. Loading it.") + return PretokenizedDataset( + token_file=cached_data_path, + block_size=config.architecture.block_size, + dtype=dtype, + ) + + print(f"No cached data found. Preprocessing '{args.train_source}' now.") + if args.train_source.is_dir(): + preprocessor: FilePreprocessor = Folder2FileTokenizerPreprocessor(tokenizer) + else: + preprocessor = File2FileTokenizerPreprocessor(tokenizer) + + preprocessor(input_path=args.train_source, output_path=cached_data_path) + + print(f"Loading the newly preprocessed data from {cached_data_path}") + return PretokenizedDataset( + token_file=cached_data_path, + block_size=config.architecture.block_size, + dtype=dtype, + ) def main() -> None: @@ -140,22 +230,31 @@ def main() -> None: config = load_or_create_config(args.experiment) + if not os.path.exists(args.experiment): + os.makedirs(args.experiment, exist_ok=True) + torch.manual_seed(config.training.random_seed) print(f"Set random seed to: {config.training.random_seed}") device = torch.device(args.device) print(f"Using the device: {device}") - text_provider = get_text_provider(args.train_source) - tokenizer = get_tokenizer(args.experiment) config.architecture.vocab_size = tokenizer.vocab_size rpprint(config.model_dump(), indent_guides=True, expand_all=True) - train_dataset = TextDataset(text_provider, tokenizer, 
config.architecture.block_size, "train", 0.9) - val_dataset = TextDataset(text_provider, tokenizer, config.architecture.block_size, "validation", 0.1) + full_dataset = prepare_dataset(args, tokenizer, config) + print(f"Splitting dataset into train/validation/test with ratios: {args.splits}") + train_dataset, val_dataset, test_dataset = random_split( + dataset=full_dataset, + lengths=args.splits, + generator=torch.Generator().manual_seed(config.training.random_seed), + ) + print(f"Train dataset size: {len(train_dataset)}") + print(f"Validation dataset size: {len(val_dataset)}") + print(f"Test dataset size: {len(test_dataset)}") - print("Loading train and validation loaders") + print("Loading train, validation, and test loaders...") cpu_count = os.cpu_count() or 4 train_dataloader = DataLoader( train_dataset, @@ -173,6 +272,16 @@ def main() -> None: shuffle=False, ) + test_dataloader = None + if len(test_dataset) > 0: + test_dataloader = DataLoader( + test_dataset, + config.training.batch_size, + pin_memory=True, + num_workers=int(cpu_count / 2), + shuffle=False, + ) + print("Loaders initialized") best_model_path = get_best_model_weights_path(args.experiment) @@ -189,9 +298,6 @@ def main() -> None: best_val_loss = float("inf") - if not os.path.exists(args.experiment): - os.makedirs(args.experiment, exist_ok=True) - save_tokenizer(args.experiment, tokenizer) model_config = f"{args.experiment}/scratch_gpt.yaml" print(f"Saving this models config to {model_config}") @@ -229,6 +335,21 @@ def main() -> None: torch.save(model.state_dict(), latest_model_path) print("Trying my best here") + if test_dataloader: + print("\n--- Running Final Test Evaluation ---") + print(f"Loading best model weights from {best_model_path}") + model = load_model(best_model_path, model, device) + + test_loss_mean, test_loss_std = run_epoch( + model=model, + dataloader=test_dataloader, + device=device, + stage="test", + ) + print("=" * 40) + print(f"🔬 Final Test Loss: {test_loss_mean:.4f} ± 
{test_loss_std:.4f}") + print("=" * 40) + prompt = input("Tell me your prompt: ") context = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device) generated = model.generate(context, max_new_tokens=500) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py new file mode 100644 index 0000000..6aea770 --- /dev/null +++ b/tests/test_preprocess.py @@ -0,0 +1,308 @@ +import io +import tempfile +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +import torch + +from scratchgpt.dataloader import PretokenizedDataset +from scratchgpt.preprocess import ( + File2FileTokenizerPreprocessor, + Folder2FileTokenizerPreprocessor, + TokenizerPreprocessor, +) +from scratchgpt.tokenizer.base_tokenizer import Tokenizer + + +class MockTokenizer(Tokenizer): + """A controlled tokenizer for predictable testing.""" + + def __init__(self, vocab_size: int = 256): + self._vocab_size = vocab_size + self.mapping = {chr(ord("a") + i): i + 1 for i in range(26)} + self.mapping[" "] = 27 + self.mapping["\n"] = 28 + self.mapping["€"] = 29 + + def encode(self, text: str) -> list[int]: + return [self.mapping.get(char, 0) for char in text] + + def decode(self, encoding: list[int]) -> str: + raise NotImplementedError + + @property + def vocab_size(self) -> int: + return self._vocab_size + + @property + def vocabulary(self) -> list[str]: + raise NotImplementedError + + +class NumberTokenizer(Tokenizer): + """A controlled tokenizer for testing with sequences of numbers.""" + + def __init__(self, vocab_size: int): + self._vocab_size = vocab_size + + def encode(self, text: str) -> list[int]: + """Encodes a space-separated string of numbers into a list of ints.""" + return [int(x) for x in text.split()] + + def decode(self, encoding: list[int]) -> str: + raise NotImplementedError + + @property + def vocab_size(self) -> int: + return self._vocab_size + + @property + def vocabulary(self) -> list[str]: + raise NotImplementedError + + 
class TestTokenizerPreprocessor(unittest.TestCase):
    """Unit tests for the stream-to-stream ``TokenizerPreprocessor``."""

    def test_happy_case_tokenization(self) -> None:
        """Test standard tokenization with a simple string."""
        tokenizer = MockTokenizer()
        preprocessor = TokenizerPreprocessor(tokenizer)
        source = io.StringIO("ab c")
        sink = io.BytesIO()

        preprocessor(source, sink)

        sink.seek(0)
        result = np.frombuffer(sink.read(), dtype=preprocessor.dtype)
        expected = np.array([1, 2, 27, 3], dtype=preprocessor.dtype)
        np.testing.assert_array_equal(result, expected)

    def test_dtype_selection(self) -> None:
        """Ensure correct numpy dtype is chosen based on vocab size."""
        # Cover both sides of each threshold so an off-by-one in the
        # comparison operators is caught.
        cases = [
            (255, np.uint8),
            (256, np.uint16),
            (65535, np.uint16),
            (65536, np.uint32),
        ]
        for vocab_size, expected_dtype in cases:
            with self.subTest(vocab_size=vocab_size):
                preprocessor = TokenizerPreprocessor(MockTokenizer(vocab_size=vocab_size))
                self.assertEqual(preprocessor.dtype, expected_dtype)

    def test_empty_input(self) -> None:
        """Test that an empty source results in an empty sink."""
        preprocessor = TokenizerPreprocessor(MockTokenizer())
        source = io.StringIO("")
        sink = io.BytesIO()

        preprocessor(source, sink)

        self.assertEqual(sink.getvalue(), b"")

    def test_chunking_and_multibyte_chars(self) -> None:
        """Ensure correct processing with small chunks and unicode."""
        preprocessor = TokenizerPreprocessor(MockTokenizer())
        text = "a€b"  # '€' is a multi-byte character
        source = io.StringIO(text)
        sink = io.BytesIO()

        # Chunk size of 1 character
        preprocessor(source, sink, chunk_size=1)

        sink.seek(0)
        result = np.frombuffer(sink.read(), dtype=preprocessor.dtype)
        expected = np.array([1, 29, 2], dtype=preprocessor.dtype)
        np.testing.assert_array_equal(result, expected)

    def test_progress_bar_update(self) -> None:
        """Verify that the progress bar is updated."""
        # The bar is injected directly, so there is no need to patch tqdm.
        mock_pbar = MagicMock()

        preprocessor = TokenizerPreprocessor(MockTokenizer())
        source = io.StringIO("abc")
        sink = io.BytesIO()

        preprocessor(source, sink, pbar=mock_pbar)

        # 'abc' is 3 bytes in utf-8
        mock_pbar.update.assert_called_once_with(3)

    def test_progress_bar_counts_utf8_bytes(self) -> None:
        """Verify that progress is reported in UTF-8 bytes, not characters."""
        mock_pbar = MagicMock()

        preprocessor = TokenizerPreprocessor(MockTokenizer())
        source = io.StringIO("a€b")
        sink = io.BytesIO()

        preprocessor(source, sink, chunk_size=1, pbar=mock_pbar)

        # One update per single-character chunk; '€' is 3 bytes in utf-8.
        reported = [recorded.args[0] for recorded in mock_pbar.update.call_args_list]
        self.assertEqual(reported, [1, 3, 1])
input_file = self.test_path / "input.txt" + output_file = self.test_path / "output.bin" + input_file.touch() + output_file.touch() + with self.assertRaises(FileExistsError): + preprocessor(input_file, output_file) + + # --- Folder2FileTokenizerPreprocessor Tests --- + + @patch("scratchgpt.preprocess.tqdm") + def test_folder2file_happy_case(self, mock_tqdm: MagicMock) -> None: + """Test successful preprocessing of a directory.""" + preprocessor = Folder2FileTokenizerPreprocessor(MockTokenizer()) + + # Setup directory structure + (self.test_path / "sub").mkdir() + (self.test_path / "file1.txt").write_text("a b", encoding="utf-8") + (self.test_path / "file2.txt").write_text(" c d", encoding="utf-8") + (self.test_path / "sub" / "file3.txt").write_text(" e", encoding="utf-8") + # This file should be ignored + (self.test_path / ".ignored.txt").touch() + + output_file = self.test_path / "output.bin" + preprocessor(self.test_path, output_file) + + self.assertTrue(output_file.exists()) + result = np.fromfile(output_file, dtype=preprocessor._preprocessor.dtype) + # Order is not guaranteed, so we sort both arrays + result.sort() + expected = np.array([1, 27, 2, 27, 3, 27, 4, 27, 5], dtype=preprocessor._preprocessor.dtype) + expected.sort() + np.testing.assert_array_equal(result, expected) + + def test_folder2file_error_input_is_file(self) -> None: + """Ensure error is raised if input path is a file.""" + preprocessor = Folder2FileTokenizerPreprocessor(MockTokenizer()) + input_file = self.test_path / "input.txt" + input_file.touch() + with self.assertRaises(ValueError): + preprocessor(input_file, self.test_path / "output.bin") + + def test_folder2file_empty_folder(self) -> None: + """Test that an empty folder produces an empty output file.""" + preprocessor = Folder2FileTokenizerPreprocessor(MockTokenizer()) + output_file = self.test_path / "output.bin" + preprocessor(self.test_path, output_file) + self.assertTrue(output_file.exists()) + 
class TestDatasetIntegration(unittest.TestCase):
    """End-to-end tests: preprocess text to a token file, then load it as a dataset."""

    def setUp(self) -> None:
        """Create a temporary directory and a predictable tokenizer."""
        self.test_dir = tempfile.TemporaryDirectory()
        self.test_path = Path(self.test_dir.name)
        self.tokenizer = NumberTokenizer(vocab_size=500)

        # Common setup: create a preprocessed file with 100 tokens (0-99)
        self.block_size = 10
        self.num_tokens = 100
        self.token_file = self.test_path / "tokens.bin"
        preprocessor = File2FileTokenizerPreprocessor(self.tokenizer)
        input_text = " ".join(map(str, range(self.num_tokens)))
        input_file = self.test_path / "input.txt"
        # The preprocessor reads its input as UTF-8, so write it the same way
        # instead of relying on the platform default encoding.
        input_file.write_text(input_text, encoding="utf-8")
        preprocessor(input_file, self.token_file)

        # Read the file back with the dtype it was written with rather than a
        # hard-coded one, so the fixture cannot drift from the preprocessor.
        self.dtype = np.dtype(preprocessor._preprocessor.dtype)

    def tearDown(self) -> None:
        """Clean up the temporary directory."""
        self.test_dir.cleanup()

    def test_dataset_len_and_getitem(self) -> None:
        """Verify the full dataset's length and item retrieval."""
        dataset = PretokenizedDataset(self.token_file, self.block_size, dtype=self.dtype)

        # Check __len__
        expected_len = self.num_tokens - self.block_size
        self.assertEqual(len(dataset), expected_len)

        # Check __getitem__
        block, target = dataset[0]

        # Verify content
        expected_block = torch.arange(0, self.block_size, dtype=torch.int64)
        self.assertTrue(torch.equal(block, expected_block))

        # Verify that the dtype is converted to long (int64)
        self.assertEqual(block.dtype, torch.long)
        self.assertEqual(target.dtype, torch.long)

    def test_integration_with_random_split(self) -> None:
        """Verify the dataset works correctly with torch.utils.data.random_split."""
        from torch.utils.data import random_split

        full_dataset = PretokenizedDataset(self.token_file, self.block_size, dtype=self.dtype)

        # Use a generator for a deterministic split
        generator = torch.Generator().manual_seed(42)
        train_set, val_set, test_set = random_split(full_dataset, [0.8, 0.1, 0.1], generator=generator)

        # Verify subset lengths (Note: random_split provides Subset objects)
        self.assertEqual(len(train_set), 72)
        self.assertEqual(len(val_set), 9)
        self.assertEqual(len(test_set), 9)

        # Check an item from a subset to ensure it proxies correctly
        block, target = train_set[0]  # Get the first item from the training Subset

        self.assertEqual(block.shape, (self.block_size,))
        self.assertEqual(target.shape, (self.block_size,))
        self.assertEqual(block.dtype, torch.long)

    def test_dataset_len_when_data_smaller_than_block_size(self) -> None:
        """Test the edge case where token count is less than block_size."""
        token_file = self.test_path / "small_tokens.bin"
        preprocessor = File2FileTokenizerPreprocessor(self.tokenizer)

        # Create a file with only 5 tokens
        input_text = " ".join(map(str, range(5)))
        input_file = self.test_path / "small_input.txt"
        input_file.write_text(input_text, encoding="utf-8")
        preprocessor(input_file, token_file)

        # Use a block_size larger than the number of tokens
        dataset = PretokenizedDataset(token_file, block_size=10, dtype=np.dtype(np.uint16))

        # The length should be 0, not a negative number
        self.assertEqual(len(dataset), 0)