Training a bilinear layer on modular arithmetic, with visualization of interaction matrices and top eigenvector components.
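The bilinear model itself lives in core/ (work in progress). As a rough sketch of the idea, a bilinear layer replaces the usual nonlinearity with an elementwise product of two linear projections; the layer sizes below are illustrative, not the project's actual configuration:

```python
import torch


class BilinearLayer(torch.nn.Module):
    """Minimal bilinear layer: y = V((W x) * (U x)), no nonlinearity."""

    def __init__(self, d_in: int, d_hidden: int, d_out: int):
        super().__init__()
        self.w = torch.nn.Linear(d_in, d_hidden, bias=False)
        self.u = torch.nn.Linear(d_in, d_hidden, bias=False)
        self.v = torch.nn.Linear(d_hidden, d_out, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Elementwise product of two linear projections, then a linear readout
        return self.v(self.w(x) * self.u(x))


layer = BilinearLayer(d_in=226, d_hidden=64, d_out=113)
x = torch.randn(8, 226)
out = layer(x)  # shape (8, 113)
```

Because the forward pass is purely quadratic in the input, each output class corresponds to an interaction matrix whose eigenvectors can be visualized directly.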
This project uses modern Python tooling for a streamlined development experience:
- uv - Fast Python package manager
- ruff - Lightning-fast linting and formatting
- ty - Type checking
- pytest - Testing framework
- pre-commit - Git hooks for code quality
Prerequisites:
- Python 3.13
- uv installed
- Clone the repository:

```bash
git clone https://github.com/d0rbu/bilinear-modular-arthimetic.git
cd bilinear-modular-arthimetic
```

- Install dependencies:

```bash
uv sync --dev
```

- Install pre-commit hooks:

```bash
uv run pre-commit install
```

Run the test suite:

```bash
uv run pytest
```

Check for linting issues:

```bash
uv run ruff check .
```

Auto-fix linting issues:

```bash
uv run ruff check --fix .
```

Format code:

```bash
uv run ruff format .
```

Run type checking:

```bash
uvx ty check
```

Pre-commit hooks are automatically run before each commit. They will:
- Run ruff linting with auto-fixes
- Run ruff formatting
- Run ty type checking
To manually run all pre-commit hooks:
```bash
uv run pre-commit run --all-files
```

GitHub Actions automatically runs the following checks on every push and pull request, split into separate jobs:
- Linting and Formatting: Ruff linting and formatting verification
- Type Checking: ty type checking
- Tests: pytest test suite
All jobs run on Python 3.13.
Dataset generation and model training code is located in the core/ directory (work in progress).
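Since the training code is still in progress, here is a hedged sketch of what a training loop might look like; the dataset construction, model layout, and hyperparameters below are illustrative stand-ins, not the project's actual API:

```python
import torch

mod = 113
# Synthetic stand-in for the real dataset helpers: all (a, b) pairs, one-hot [a, b]
a = torch.arange(mod).repeat_interleave(mod)
b = torch.arange(mod).repeat(mod)
x = torch.cat([torch.nn.functional.one_hot(a, mod),
               torch.nn.functional.one_hot(b, mod)], dim=1).float()
y = (a + b) % mod

d_hidden = 64
w = torch.nn.Linear(2 * mod, d_hidden, bias=False)
u = torch.nn.Linear(2 * mod, d_hidden, bias=False)
v = torch.nn.Linear(d_hidden, mod, bias=False)
params = [*w.parameters(), *u.parameters(), *v.parameters()]
opt = torch.optim.AdamW(params, lr=1e-3, weight_decay=1.0)
loss_fn = torch.nn.CrossEntropyLoss()

for step in range(3):  # real runs would train for thousands of epochs
    opt.zero_grad()
    logits = v(w(x) * u(x))  # bilinear forward pass
    loss = loss_fn(logits, y)
    loss.backward()
    opt.step()
```

Heavy weight decay is a common choice in modular-arithmetic grokking setups, but the actual training configuration here is undecided.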
Once you have a trained model checkpoint, you can visualize the interaction matrices:
```bash
# Visualize with default settings (output indices 0, 1, 112 for mod 113)
uv run python -m bilinear_modular.viz.interaction_matrices visualize checkpoints/model_epoch_2000.pt

# Visualize specific output classes
uv run python -m bilinear_modular.viz.interaction_matrices visualize checkpoints/model_epoch_2000.pt --output-indices 0 5 10 50 112

# Change number of eigenvectors to plot
uv run python -m bilinear_modular.viz.interaction_matrices visualize checkpoints/model_epoch_2000.pt --num-eigenvectors 10

# Save to a different directory
uv run python -m bilinear_modular.viz.interaction_matrices visualize checkpoints/model_epoch_2000.pt --output-dir figures/experiment_1
```

See src/bilinear_modular/viz/README.md for detailed documentation on the visualization module.
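The actual computation lives in src/bilinear_modular/viz/interaction_matrices.py. Conceptually, for a bilinear layer with input projections W, U and readout V, each output class k computes a quadratic form over the input, so its behavior is captured by a symmetric interaction matrix. A sketch with random weights standing in for a trained checkpoint (the weight names and shapes here are assumptions, not the package's internals):

```python
import torch

d_in, d_hidden, d_out = 226, 64, 113
# Random weights standing in for a trained checkpoint
W = torch.randn(d_hidden, d_in)
U = torch.randn(d_hidden, d_in)
V = torch.randn(d_out, d_hidden)

k = 0  # output class to inspect
# Interaction matrix for class k: B_k = sum_h V[k, h] * outer(W_h, U_h)
B = torch.einsum("h,hi,hj->ij", V[k], W, U)
# Only the symmetric part contributes to x^T B x, so symmetrize first
B_sym = 0.5 * (B + B.T)
eigvals, eigvecs = torch.linalg.eigh(B_sym)
# eigh returns eigenvalues in ascending order; take the top few by magnitude
top = eigvals.abs().argsort(descending=True)[:5]
top_components = eigvecs[:, top]
```

Plotting the columns of `top_components` against the input positions is one way to surface the periodic structure the model learns.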
Generate a modular arithmetic dataset for a given modulus:
```python
from bilinear_modular import generate_dataset, ModularArithmeticDataset

# Generate dataset for mod 113 (creates all a+b combinations)
dataset = generate_dataset(mod_basis=113)

# Dataset info
print(f"Total samples: {len(dataset)}")           # 113 * 113 = 12769
print(f"Training samples: {dataset.train_size}")  # 80% = 10215
print(f"Validation samples: {dataset.val_size}")  # 20% = 2554

# Get training batches (returns torch tensors)
inputs, targets = dataset.get_train_batch(batch_size=128)
# inputs: (128, 226) - one-hot encoded [a, b]
# targets: (128, 113) - one-hot encoded c where c = (a + b) % 113

# Get all training data
all_train_inputs, all_train_targets = dataset.get_all_train()

# Use as iterator for training loops
dataset.batch_size = 128
dataset.train()  # Set to training mode
for inputs, targets in dataset:
    # Your training code here
    pass

# Load an existing dataset
dataset = ModularArithmeticDataset(mod_basis=113)
```

For a complete example, see examples/generate_dataset_example.py.
- Automatic caching: Datasets are saved to `data/{mod_basis}/` as .pt files for reuse
- Pure PyTorch: All data stored and returned as PyTorch tensors (no numpy)
- One-hot encoding: Optional one-hot encoding of inputs and outputs
- Efficient batching: Simple API for getting training/validation batches
- Iterator protocol: Supports `__iter__` and `__next__` for easy training loops
- Reproducible splits: Consistent 80/20 train/val split with fixed seed
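The automatic caching behavior can be pictured with a small sketch; the function name and the exact file layout inside `data/{mod_basis}/` below are illustrative, not the package's internals:

```python
import os

import torch


def load_or_generate(mod_basis: int, root: str = "data") -> dict[str, torch.Tensor]:
    """Illustrative caching pattern: regenerate only if no cached .pt file exists."""
    path = os.path.join(root, str(mod_basis), "dataset.pt")
    if os.path.exists(path):
        return torch.load(path)
    # Enumerate all (a, b) pairs and one-hot encode the concatenated inputs
    a = torch.arange(mod_basis).repeat_interleave(mod_basis)
    b = torch.arange(mod_basis).repeat(mod_basis)
    data = {
        "inputs": torch.cat([torch.nn.functional.one_hot(a, mod_basis),
                             torch.nn.functional.one_hot(b, mod_basis)], dim=1).float(),
        "targets": (a + b) % mod_basis,
    }
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(data, path)
    return data


data = load_or_generate(7)  # inputs: (49, 14), targets: (49,)
```

Keeping the cache keyed by modulus means switching `mod_basis` never serves stale tensors.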
```
.
├── src/
│   └── bilinear_modular/            # Main package
│       ├── __init__.py
│       ├── core/
│       │   ├── __init__.py
│       │   └── dataset.py           # Dataset generation and loading
│       └── viz/                     # Visualization tools
│           ├── __init__.py
│           ├── interaction_matrices.py
│           └── README.md
├── fig/                             # Output directory for figures
├── tests/                           # Test files
│   ├── test_placeholder.py
│   └── test_visualization.py
├── examples/                        # Example scripts
│   └── generate_dataset_example.py
├── data/                            # Generated datasets (gitignored)
├── .github/
│   └── workflows/
│       └── ci.yml                   # GitHub Actions CI
├── .pre-commit-config.yaml          # Pre-commit hooks config
├── pyproject.toml                   # Project configuration
└── README.md
```
- Make sure all tests pass: `uv run pytest`
- Ensure code is properly formatted: `uv run ruff format .`
- Check for linting issues: `uv run ruff check .`
- Verify type checking passes: `uvx ty check`
Pre-commit hooks will automatically run these checks before each commit.