diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 27509ca..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..c7bbad6 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,26 @@ +name: CI + +on: + push: + branches: [master] + pull_request: + +jobs: + smoke-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install numpy mido pytest + pip install torch --index-url https://download.pytorch.org/whl/cpu + + - name: Run smoke tests + run: pytest tests/ -v diff --git a/.gitignore b/.gitignore index 7bbc71c..a3cc02a 100755 --- a/.gitignore +++ b/.gitignore @@ -99,3 +99,12 @@ ENV/ # mypy .mypy_cache/ + +# macOS +.DS_Store + +# Dembow artifacts +*.pt +generated/ +# ...but the bundled pretrained model ships with the package +!dembow/assets/*.pt diff --git a/README.md b/README.md index 2773f93..01ee2ca 100755 --- a/README.md +++ b/README.md @@ -1,63 +1,177 @@ # Dembow -## The first A.I that generates reggaeton hits. +## The first A.I. that generates reggaeton hits. πŸ”₯ ![Denbow.jpg](denbow.jpg) -## Machine Learning Techniques -Using TensorFlow to generate short sequences of music with a [Restricted Boltzmann Machine](http://deeplearning4j.org/restrictedboltzmannmachine.html). -Do you want to go deep?, see the original technical idea: [How to build an RBM neural network in tensorflow](http://danshiebler.com/2016-08-10-musical-tensorflow-part-one-the-rbm/). +Dembow learns from a corpus of reggaeton MIDI and writes new tracks of its own. +It began life in 2016 as a Restricted Boltzmann Machine over a binary piano roll; +it is now a **decoder-only Transformer over an event-based music language** β€” the +same recipe behind modern symbolic-music models. 
+## How it works +Dembow treats music the way a language model treats text. Every song is +tokenized into a stream of musical **events** (REMI-style): -## Getting Started +``` +BOS BAR POS_0 INST_drums DRUM_kick DUR_1 VEL_5 + POS_0 INST_bass PITCH_36 DUR_4 VEL_6 + POS_4 INST_drums DRUM_snare DUR_1 VEL_5 ... + BAR ... EOS +``` + +Each note carries its **instrument group** (drums / bass / mid / high), **pitch**, +**duration**, and **velocity** — so the model can write expressive, +multi-instrument arrangements, not a flat on/off grid. A small Transformer then +learns to predict the next event from everything before it, using masked +self-attention to capture phrasing, repetition, and the way the drums and bass +lock into the dembow groove. + +Generation is autoregressive with **temperature + nucleus (top-p) sampling**, the +standard modern decoding strategy. + +## What changed from the original -1. Install [Tensorflow](https://www.tensorflow.org/). If you have trouble running Tensorflow installation it may help: +| Then (2016) | Now | +| --- | --- | +| Python 2, TensorFlow 1.x | Python 3, **PyTorch** | +| Restricted Boltzmann Machine | **Decoder-only Transformer** | +| Binary piano roll (on/off only) | **Event tokens**: pitch + duration + velocity | +| All tracks flattened into one roll | **Multi-instrument** (drums / bass / mid / high) | +| No sense of time | **Self-attention** over the whole sequence | +| Trained on ~76 raw files | **Pitch-augmented** corpus (×7) for generalization | +| `python-midi` (Py2, dead) | [`mido`](https://mido.readthedocs.io) | +| Threw the weights away | Saves & loads checkpoints | +| One-shot script | A real CLI + installable package + CI | + +## Getting started ```sh -sudo easy_install pip -sudo pip install --upgrade virtualenv - export PIP_REQUIRE_VIRTUALENV=false +python -m venv venv +source venv/bin/activate +pip install -r requirements.txt # numpy, mido, torch +# or install the package + `dembow` command: +pip install -e . 
+``` + +## Make magic happen +A **pretrained model ships with the package**, so you can generate immediately — +no training required: + +```sh +dembow generate # uses the bundled model -> generated/dembow_*.mid +dembow generate --render # also write .wav so you can actually hear it ``` -2. Install [Anaconda and dependencies](https://www.continuum.io/downloads) +Train your own (better, especially with more data): -3. Create virtualenv ```sh -virtualenv venv +dembow train # -> dembow.pt +dembow generate # picks up your dembow.pt automatically +``` +Without installing: + +```sh +python -m dembow.cli train +python -m dembow.cli generate ``` -4. Activate venv +Or light the fire (train + generate in one go): + ```sh -source venv/bin/activate +python fire.py ``` -5. Install python_midi module in normal procedure +## Hardware presets + +Transformers are slow to train on CPU, so `train` picks a preset automatically +(small model + early stopping on CPU, a bigger one on GPU). Override it, or any +individual flag: + ```sh -git clone git@github.com:vishnubob/python-midi.git -cd python-midi -python setup.py install +dembow train --preset cpu # small + fast (auto-selected when no GPU) +dembow train --preset gpu # bigger model, more epochs, more augmentation +dembow train --preset gpu --d-model 320 --n-layers 6 # flags override the preset ``` -6. Install remaining dependencies with pip. -- matplotlib -- numpy -- pandas -- msgpack -- glob -- tqdm +## Training quality: validation + early stopping + +Because the corpus is tiny, overfitting is the main risk. Training holds out a +fraction of **songs** (not windows — so pitch-augmented copies can't leak), +reports validation loss each epoch, **saves the checkpoint with the best val +loss**, and stops early when it plateaus: ```sh -pip install [dependencies] +dembow train --val-frac 0.1 --patience 8 +# epoch 12/40 train 1.49 val 1.71 *best (saved) +# ... +# Early stopping at epoch 23 (no val improvement for 8 epochs) ``` -7. 
Make magic happen. First train your model with custom parameters and then wait the output. +## Generation + ```sh -python fire.py +dembow generate \ + --num-samples 8 \ + --max-new-tokens 1200 \ # longer songs + --temperature 0.9 \ # <1 tighter & more repetitive, >1 wilder + --top-p 0.92 \ # nucleus sampling threshold + --repetition-penalty 1.15 \ # discourage degenerate loops (1.0 = off) + --no-repeat-ngram 0 \ # hard-ban repeated token n-grams (0 = off) + --prime-bars 2 \ # real bars used to kick off each song + --seed-dir none # cold start instead of priming from a real song +``` + +`--repetition-penalty` gently down-weights recently used tokens so the model +doesn't get stuck looping — while still allowing the musical repetition that +makes a groove a groove. + +### Hearing it (audio) + +```sh +dembow generate --render # writes .wav next to each .mid +dembow generate --render --soundfont my.sf2 # better quality via FluidSynth +``` + +`--render` turns each generated song into audio. If [FluidSynth](https://www.fluidsynth.org/) +and a SoundFont are installed it uses them for realistic instruments; otherwise +it falls back to a small built-in synth so rendering works with no extra setup. + +## Hear it without training + +A few example outputs from a small demo model live in [`examples/`](examples/) +so you can listen before training your own. + +**Honest note on quality.** The corpus is only ~76 short MIDI files, so even a +Transformer is data-limited — it captures the *feel* (groove, instrumentation, +key) more than polished, hook-worthy songwriting. The single biggest lever is +**more clean MIDI** in `reggaeton_samples/` (see +[`reggaeton_samples/SOURCES.md`](reggaeton_samples/SOURCES.md) for where to find +it). Pitch augmentation and priming from real songs help it stay in the pocket +meanwhile. 
+ +## Project layout + +``` +dembow/ + tokenizer.py MIDI <-> event tokens (the REMI-style music language) + model.py the decoder-only Transformer + data.py corpus loading, pitch augmentation, song-level train/val split + train.py training loop, validation, early stopping, best-checkpoint + generate.py sample new songs (temperature / top-p / repetition control) + render.py MIDI -> audio (FluidSynth, or a builtin dependency-free synth) + cli.py the `dembow` command (with cpu/gpu presets) + assets/ a bundled pretrained model so generation works out of the box +fire.py one-shot entry point +reggaeton_samples/ the MIDI corpus (+ SOURCES.md: where to find more) +examples/ a few generated outputs from a small demo model +tests/ a fast end-to-end smoke test ``` -Depends of the technical capabilities of your computer, it can take from 5 to 10 minutes. ## Contribute -We need your help feeding and training our current model. If you have reggeaton samples feel free to contribute. +We still need your help feeding the model. If you have reggaeton MIDI, drop it in +`reggaeton_samples/` and open a pull request β€” more data is the single best way +to make Dembow sound like a hit. diff --git a/dembow/__init__.py b/dembow/__init__.py new file mode 100644 index 0000000..ee9bbcd --- /dev/null +++ b/dembow/__init__.py @@ -0,0 +1,25 @@ +"""Dembow -- a Transformer that generates reggaeton. + +The first A.I. that generates reggaeton hits, rebuilt around a modern, +event-based music language model. + +The 2016 original trained a Restricted Boltzmann Machine on a binary piano roll. +This version replaces that entirely: songs are tokenized into a REMI-style stream +of musical events (bar, position, instrument, pitch, duration, velocity) and a +decoder-only Transformer learns to generate them one token at a time -- the same +recipe used by modern symbolic-music models. 
def _resolve_preset(name: str) -> dict:
    """Return the hyperparameter preset registered under ``name``.

    ``"auto"`` resolves to ``"gpu"`` when torch reports an available CUDA
    device and to ``"cpu"`` otherwise -- including when torch cannot be
    imported or the probe itself raises. Any other name is looked up in
    ``PRESETS`` as given.
    """
    if name != "auto":
        return PRESETS[name]

    resolved = "cpu"  # stays in effect whenever the CUDA probe cannot run
    try:
        import torch

        if torch.cuda.is_available():
            resolved = "gpu"
    except Exception:
        # Best effort: any failure while probing means "no usable GPU".
        pass
    return PRESETS[resolved]
that generates reggaeton hits.", + ) + parser.add_argument("--version", action="version", version=f"dembow {__version__}") + sub = parser.add_subparsers(dest="command", required=True) + + t = sub.add_parser("train", help="train the Transformer on a folder of MIDI files") + t.add_argument("--preset", choices=["auto", "cpu", "gpu"], default="auto", + help="hardware preset for model size / epochs (default auto); explicit flags override it") + t.add_argument("--data-dir", default="reggaeton_samples") + t.add_argument("--checkpoint", default="dembow.pt") + # These default to None so the preset can fill them in unless set explicitly. + t.add_argument("--seq-len", type=int, default=None) + t.add_argument("--d-model", type=int, default=None) + t.add_argument("--n-layers", type=int, default=None) + t.add_argument("--n-heads", type=int, default=None) + t.add_argument("--num-epochs", type=int, default=None) + t.add_argument("--batch-size", type=int, default=None) + t.add_argument("--augment", type=int, default=None, help="pitch-shift augmentation range in semitones (0 disables)") + t.add_argument("--lr", type=float, default=3e-4) + t.add_argument("--val-frac", type=float, default=0.1, help="fraction of songs held out for validation") + t.add_argument("--patience", type=int, default=8, help="early-stop after N epochs without val improvement") + t.add_argument("--seed", type=int, default=0) + t.add_argument("--device", default=None) + + g = sub.add_parser("generate", help="sample new reggaeton from a trained model") + g.add_argument("--checkpoint", default="dembow.pt") + g.add_argument("--output-dir", default="generated") + g.add_argument("--num-samples", type=int, default=5) + g.add_argument("--max-new-tokens", type=int, default=800) + g.add_argument("--temperature", type=float, default=1.0, help="<1 tighter, >1 wilder") + g.add_argument("--top-p", type=float, default=0.92, help="nucleus sampling threshold") + g.add_argument("--top-k", type=int, default=None) + 
g.add_argument("--repetition-penalty", type=float, default=1.15, help="discourage looping (1.0 = off)") + g.add_argument("--no-repeat-ngram", type=int, default=0, help="hard-ban repeated token n-grams of this size (0 = off)") + g.add_argument("--prime-bars", type=int, default=2, help="real bars used to prime each song ('--seed-dir none' for cold start)") + g.add_argument("--seed-dir", default="reggaeton_samples") + g.add_argument("--tempo-bpm", type=float, default=95.0) + g.add_argument("--render", action="store_true", help="also render each song to .wav audio") + g.add_argument("--soundfont", default=None, help="SoundFont (.sf2) for FluidSynth rendering; omit to use the builtin synth") + g.add_argument("--random-seed", type=int, default=0) + g.add_argument("--device", default=None) + + return parser + + +def main(argv=None) -> None: + args = build_parser().parse_args(argv) + + if args.command == "train": + from .train import train + + preset = _resolve_preset(args.preset) + pick = lambda name: getattr(args, name) if getattr(args, name) is not None else preset[name] + train( + data_dir=args.data_dir, + checkpoint=args.checkpoint, + seq_len=pick("seq_len"), + d_model=pick("d_model"), + n_layers=pick("n_layers"), + n_heads=pick("n_heads"), + num_epochs=pick("num_epochs"), + batch_size=pick("batch_size"), + augment=pick("augment"), + lr=args.lr, + val_frac=args.val_frac, + patience=args.patience, + seed=args.seed, + device=args.device, + ) + elif args.command == "generate": + from .generate import generate + + seed_dir = None if str(args.seed_dir).lower() == "none" else args.seed_dir + generate( + checkpoint=args.checkpoint, + output_dir=args.output_dir, + num_samples=args.num_samples, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + no_repeat_ngram_size=args.no_repeat_ngram, + prime_bars=args.prime_bars, + seed_dir=seed_dir, + tempo_bpm=args.tempo_bpm, + 
def find_midi_files(directory: str) -> List[str]:
    """Return the sorted paths of the MIDI files directly inside ``directory``.

    A path counts as MIDI when it ends in ``.mid`` or ``.midi``, compared
    case-insensitively. The search is not recursive, and a directory that is
    empty or does not exist yields an empty list.
    """
    midi_suffixes = (".mid", ".midi")
    matches = []
    for candidate in glob.glob(os.path.join(directory, "*")):
        if candidate.lower().endswith(midi_suffixes):
            matches.append(candidate)
    matches.sort()
    return matches
def build_windows(songs: List[List[int]], seq_len: int) -> np.ndarray:
    """Slice token streams into fixed-length training windows.

    Each window is one training example; the model learns to predict every
    token from the ones before it.

    * A song no longer than ``seq_len`` becomes a single window, right-padded
      with ``PAD``.
    * A longer song yields overlapping windows taken every ``seq_len // 2``
      tokens, plus one extra window aligned to the end of the song whenever
      the strided windows would otherwise stop short of it. Without that
      window the closing tokens of most long songs would never appear in the
      training data.

    Args:
        songs: Token-id sequences, one list per song.
        seq_len: Length of every returned window.

    Returns:
        An ``int64`` array of shape ``(num_windows, seq_len)``; the shape is
        ``(0, seq_len)`` when there is nothing to slice.
    """
    # seq_len == 1 would give a stride of 0, which range() rejects.
    stride = max(1, seq_len // 2)
    examples = []
    for ids in songs:
        if len(ids) <= seq_len:
            examples.append(ids + [PAD] * (seq_len - len(ids)))
            continue
        last_start = len(ids) - seq_len
        for start in range(0, last_start + 1, stride):
            examples.append(ids[start : start + seq_len])
        if last_start % stride:
            # The strided windows stopped before the end: add the tail window.
            examples.append(ids[last_start:])
    if not examples:
        return np.zeros((0, seq_len), dtype=np.int64)
    return np.array(examples, dtype=np.int64)
def generate(
    checkpoint: str = "dembow.pt",
    output_dir: str = "generated",
    num_samples: int = 5,
    max_new_tokens: int = 800,
    temperature: float = 1.0,
    top_p: float = 0.92,
    top_k: Optional[int] = None,
    repetition_penalty: float = 1.15,
    no_repeat_ngram_size: int = 0,
    prime_bars: int = 2,
    seed_dir: Optional[str] = "reggaeton_samples",
    tempo_bpm: float = 95.0,
    render: bool = False,
    soundfont: Optional[str] = None,
    random_seed: int = 0,
    device: str | None = None,
) -> List[str]:
    """Sample songs from the model and write them out as MIDI files.

    Seeds the torch and NumPy RNGs, loads the checkpoint resolved by
    ``_resolve_checkpoint``, builds one priming prompt from ``seed_dir`` and
    draws ``num_samples`` continuations with the given sampling settings.
    Every kept sample is saved as ``<output_dir>/dembow_<i>.mid``; with
    ``render`` set it is also rendered to a ``.wav`` next to the MIDI file.

    Returns:
        Paths of the MIDI files actually written. The list can be shorter
        than ``num_samples`` because samples that decode to fewer than two
        tracks are skipped (and reported).
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    os.makedirs(output_dir, exist_ok=True)

    model = MusicTransformer.load(_resolve_checkpoint(checkpoint), device=device)
    # The priming prompt is the same for every sample, so build it once.
    prompt = torch.tensor(_prime_tokens(seed_dir, prime_bars), dtype=torch.long)

    written = []
    skipped = 0
    for i in range(num_samples):
        ids = model.generate(
            prompt, max_new_tokens=max_new_tokens, temperature=temperature,
            top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size, eos_id=EOS, device=device,
        )
        midi_file = decode(ids, tempo_bpm=tempo_bpm)
        # Keep only outputs with at least two tracks.
        # NOTE(review): presumably decode() emits a meta track first, making
        # this "at least one note track" -- confirm against decode().
        if len(midi_file.tracks) < 2:
            skipped += 1
            continue
        out_path = os.path.join(output_dir, f"dembow_{i}.mid")
        midi_file.save(out_path)
        written.append(out_path)
        if render:
            from .render import render_to_wav

            render_to_wav(out_path, os.path.splitext(out_path)[0] + ".wav", soundfont=soundfont)

    if skipped:
        print(f"Skipped {skipped} sample(s) that decoded to fewer than two MIDI tracks.")
    print(f"Wrote {len(written)} MIDI file(s) to '{output_dir}/'" + (" (+ .wav audio)" if render else ""))
    return written
class MusicTransformer(nn.Module):
    """Decoder-only Transformer language model over music-event tokens.

    Token embeddings plus learned position embeddings feed a stack of
    pre-norm ``nn.TransformerEncoderLayer`` blocks run under a causal mask,
    followed by a final LayerNorm and an output projection whose weight is
    tied to the token embedding.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_len, config.d_model)
        self.drop = nn.Dropout(config.dropout)
        layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.n_heads,
            dim_feedforward=config.d_ff,
            dropout=config.dropout,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.blocks = nn.TransformerEncoder(layer, num_layers=config.n_layers, enable_nested_tensor=False)
        self.norm = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: input embedding and output projection share weights.
        self.head.weight = self.token_emb.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return next-token logits for a ``(batch, time)`` tensor of token ids.

        The time dimension must not exceed ``config.max_len``, because the
        position embedding only has that many rows. The result has shape
        ``(batch, time, vocab_size)``.
        """
        _, t = x.shape
        pos = torch.arange(t, device=x.device).unsqueeze(0)
        h = self.drop(self.token_emb(x) + self.pos_emb(pos))
        # True above the diagonal: position i may not attend to positions > i.
        mask = torch.triu(torch.ones(t, t, device=x.device, dtype=torch.bool), diagonal=1)
        h = self.blocks(h, mask=mask)
        return self.head(self.norm(h))

    @torch.no_grad()
    def generate(
        self,
        prompt: torch.Tensor,
        max_new_tokens: int = 600,
        temperature: float = 1.0,
        top_k: int | None = None,
        top_p: float | None = 0.92,
        repetition_penalty: float = 1.15,
        recent_window: int = 96,
        no_repeat_ngram_size: int = 0,
        eos_id: int | None = None,
        device: str | torch.device = "cpu",
    ) -> list[int]:
        """Autoregressively sample a continuation of ``prompt``.

        Uses temperature + nucleus (top-p) sampling, plus repetition control to
        keep the model from collapsing into a degenerate loop.
        ``repetition_penalty`` down-weights tokens seen in the last
        ``recent_window`` steps; ``no_repeat_ngram_size`` hard-bans exact
        n-gram repeats when > 0. ``top_k`` is clamped to the vocabulary size,
        and ``None`` or a non-positive value disables top-k filtering.

        Returns:
            The prompt ids followed by the sampled ids. Sampling stops after
            ``max_new_tokens`` tokens or right after ``eos_id`` is produced
            (the EOS token is included).
        """
        self.eval()
        ids = prompt.to(device).tolist()
        for _ in range(max_new_tokens):
            # Only the last max_len tokens fit the position embedding.
            context = torch.tensor(ids[-self.config.max_len:], device=device).unsqueeze(0)
            logits = self.forward(context)[0, -1]

            if repetition_penalty and repetition_penalty != 1.0:
                for t in set(ids[-recent_window:]):
                    # Divide positive logits and multiply negative ones so the
                    # penalty always lowers the token's probability.
                    logits[t] = logits[t] / repetition_penalty if logits[t] > 0 else logits[t] * repetition_penalty
            if no_repeat_ngram_size > 0:
                for t in _banned_ngram_tokens(ids, no_repeat_ngram_size):
                    logits[t] = -float("inf")

            logits = logits / max(temperature, 1e-6)

            if top_k is not None and top_k > 0:
                # torch.topk raises when k exceeds the number of logits.
                k = min(top_k, logits.size(-1))
                kth = torch.topk(logits, k).values[-1]
                logits[logits < kth] = -float("inf")
            if top_p is not None:
                ordered, idx = torch.sort(logits, descending=True)
                probs = torch.softmax(ordered, dim=-1)
                cutoff = torch.cumsum(probs, dim=-1) > top_p
                # Shift right by one so the token that crosses the threshold
                # is kept, and the most likely token is never removed.
                cutoff[1:] = cutoff[:-1].clone()
                cutoff[0] = False
                ordered[cutoff] = -float("inf")
                logits = torch.full_like(logits, -float("inf")).scatter(0, idx, ordered)

            probs = torch.softmax(logits, dim=-1)
            nxt = int(torch.multinomial(probs, 1))
            ids.append(nxt)
            if eos_id is not None and nxt == eos_id:
                break
        return ids

    def save(self, path: str) -> None:
        """Write the config and weights to ``path`` as one ``torch.save`` dict."""
        torch.save({"model_type": "transformer", "config": asdict(self.config), "state_dict": self.state_dict()}, path)

    @classmethod
    def load(cls, path: str, device: str | torch.device = "cpu") -> "MusicTransformer":
        """Rebuild a model from a checkpoint written by ``save`` and move it to ``device``."""
        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # only load checkpoints from a trusted source. save() stores just
        # plain containers and tensors, so weights_only=True may be enough --
        # verify against every checkpoint producer before switching.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        model = cls(ModelConfig(**ckpt["config"]))
        model.load_state_dict(ckpt["state_dict"])
        model.to(device)
        return model
**Built-in NumPy synth** -- a tiny, dependency-free oscillator/noise synth so + rendering works everywhere even without FluidSynth installed. It won't sound + like a studio, but you'll hear the groove. +""" + +from __future__ import annotations + +import os +import shutil +import struct +import subprocess +import wave +from typing import List, Optional + +import mido +import numpy as np + +_SOUNDFONT_PATHS = [ + "/usr/share/sounds/sf2/FluidR3_GM.sf2", + "/usr/share/sounds/sf2/default-GM.sf2", + "/usr/share/soundfonts/FluidR3_GM.sf2", + "/usr/share/soundfonts/default.sf2", +] + + +def _find_soundfont(explicit: Optional[str]) -> Optional[str]: + if explicit and os.path.exists(explicit): + return explicit + return next((p for p in _SOUNDFONT_PATHS if os.path.exists(p)), None) + + +def _render_fluidsynth(midi_path: str, wav_path: str, soundfont: str, sample_rate: int) -> bool: + try: + subprocess.run( + ["fluidsynth", "-ni", "-F", wav_path, "-r", str(sample_rate), soundfont, midi_path], + check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, + ) + return os.path.exists(wav_path) + except Exception: + return False + + +def _adsr(n: int, sample_rate: int, release: float = 0.05) -> np.ndarray: + """A simple attack/release envelope to avoid clicks.""" + env = np.ones(n) + attack = min(n, int(0.005 * sample_rate)) + rel = min(n, int(release * sample_rate)) + if attack: + env[:attack] = np.linspace(0, 1, attack) + if rel: + env[-rel:] *= np.linspace(1, 0, rel) + return env + + +def _tone(freq: float, dur: float, sample_rate: int, harmonics=(1.0, 0.4, 0.2)) -> np.ndarray: + n = max(1, int(dur * sample_rate)) + t = np.arange(n) / sample_rate + wave_out = sum(amp * np.sin(2 * np.pi * freq * h * t) for h, amp in enumerate(harmonics, start=1)) + return wave_out * _adsr(n, sample_rate) + + +def _drum(note: int, sample_rate: int) -> np.ndarray: + """Percussion: pitched sine for the kick, shaped noise for snare/hats.""" + if note <= 37: # kick + n = int(0.18 * 
sample_rate) + t = np.arange(n) / sample_rate + freq = 110 * np.exp(-30 * t) + 45 # pitch drop + return np.sin(2 * np.pi * freq * t) * np.exp(-12 * t) + if note in (38, 39, 40): # snare / clap + n = int(0.12 * sample_rate) + return np.random.randn(n) * np.exp(-22 * np.arange(n) / sample_rate) + # hats / cymbals: short bright noise + n = int(0.05 * sample_rate) + return np.random.randn(n) * np.exp(-60 * np.arange(n) / sample_rate) + + +def _builtin_synth(midi_file: mido.MidiFile, sample_rate: int) -> np.ndarray: + """Synthesize a MIDI file to a mono float waveform with no external deps.""" + # Collect notes as (start_sec, dur_sec, channel, note, velocity). + notes = [] + open_notes = {} + abs_time = 0.0 + for msg in midi_file: # iterating a MidiFile yields delta time in seconds + abs_time += msg.time + if msg.type == "note_on" and msg.velocity > 0: + if msg.channel == 9: + notes.append((abs_time, 0.2, 9, msg.note, msg.velocity)) + else: + open_notes[(msg.channel, msg.note)] = (abs_time, msg.velocity) + elif msg.type in ("note_off", "note_on"): + key = (msg.channel, msg.note) + if key in open_notes: + start, vel = open_notes.pop(key) + notes.append((start, max(0.05, abs_time - start), msg.channel, msg.note, vel)) + + if not notes: + return np.zeros(sample_rate, dtype=np.float32) + + total = max(s + d for s, d, *_ in notes) + 0.3 + buf = np.zeros(int(total * sample_rate) + sample_rate, dtype=np.float32) + for start, dur, channel, note, vel in notes: + idx = int(start * sample_rate) + amp = (vel / 127.0) * 0.25 + if channel == 9: + seg = _drum(note, sample_rate) * amp + else: + freq = 440.0 * 2 ** ((note - 69) / 12.0) + gain = 1.4 if channel == 1 else 1.0 # bass a touch louder + seg = _tone(freq, dur, sample_rate) * amp * gain + end = min(len(buf), idx + len(seg)) + buf[idx:end] += seg[: end - idx] + + peak = np.max(np.abs(buf)) + if peak > 0: + buf = buf / peak * 0.95 + return buf[: int(total * sample_rate)] + + +def _write_wav(samples: np.ndarray, path: str, 
sample_rate: int) -> None: + pcm = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16) + with wave.open(path, "w") as w: + w.setnchannels(1) + w.setsampwidth(2) + w.setframerate(sample_rate) + w.writeframes(struct.pack(f"<{len(pcm)}h", *pcm.tolist())) + + +def render_to_wav( + midi_path: str, + wav_path: str, + soundfont: Optional[str] = None, + sample_rate: int = 22050, +) -> str: + """Render a MIDI file to WAV. Uses FluidSynth if available, else a builtin synth.""" + sf = _find_soundfont(soundfont) + if shutil.which("fluidsynth") and sf and _render_fluidsynth(midi_path, wav_path, sf, sample_rate): + return wav_path + midi_file = mido.MidiFile(midi_path) + _write_wav(_builtin_synth(midi_file, sample_rate), wav_path, sample_rate) + return wav_path diff --git a/dembow/tokenizer.py b/dembow/tokenizer.py new file mode 100644 index 0000000..227008c --- /dev/null +++ b/dembow/tokenizer.py @@ -0,0 +1,224 @@ +"""Event-based (REMI-style) tokenization of MIDI. + +The old project represented music as a binary piano roll -- a grid of on/off +bits with no duration, velocity, or instrument identity. This replaces it with a +sequence of musical *events*, the representation modern symbolic-music models use +(Music Transformer, REMI, MMM). A song becomes a flat sequence of tokens:: + + BOS BAR POS_0 INST_drums PITCH_? DUR_1 VEL_5 INST_bass PITCH_36 DUR_4 VEL_6 ... + BAR POS_4 ... EOS + +Each note carries its instrument group, pitch, duration, and velocity, so the +Transformer can learn expressive, multi-instrument arrangements -- not just a +flat grid. Timing is quantized to a 16th-note grid (16 positions per 4/4 bar). 
+""" + +from __future__ import annotations + +from typing import Dict, List, Tuple + +import mido + +# -- musical vocabulary ---------------------------------------------------- + +STEPS_PER_BAR = 16 # 16th-note grid, 4/4 +MEL_LOW, MEL_HIGH = 24, 96 # melodic pitch range (C1..C7) +N_MEL = MEL_HIGH - MEL_LOW + 1 +N_VEL_BINS = 8 +DUR_BUCKETS = [1, 2, 3, 4, 6, 8, 12, 16, 24, 32] +INST_GROUPS = ["drums", "bass", "mid", "high"] +DRUM_CHANNEL = 9 + +# General MIDI percussion -> a small set of drum classes (the dembow toolkit). +DRUM_CLASSES = ["kick", "snare", "clap", "closed_hat", "open_hat", "tom", "crash", "ride", "perc"] +_DRUM_MAP = { + 35: 0, 36: 0, 38: 1, 40: 1, 37: 2, 39: 2, 42: 3, 44: 3, 46: 4, + 41: 5, 43: 5, 45: 5, 47: 5, 48: 5, 50: 5, 49: 6, 57: 6, 55: 6, 51: 7, 59: 7, 53: 7, +} +_DRUM_REPR = [36, 38, 39, 42, 46, 45, 49, 51, 54] # play each class back as this GM note +# Instrument group -> GM program for playback (drums use channel 9). +_INST_PROGRAM = {"bass": 38, "mid": 0, "high": 81} + + +def _drum_class(note: int) -> int: + return _DRUM_MAP.get(note, 8) + + +def _vel_bin(vel: int) -> int: + return min(N_VEL_BINS - 1, max(0, vel // (128 // N_VEL_BINS))) + + +def _vel_value(bin_idx: int) -> int: + return min(127, bin_idx * (128 // N_VEL_BINS) + 24) + + +def _dur_bucket(steps: int) -> int: + best = 0 + for i, b in enumerate(DUR_BUCKETS): + if abs(b - steps) < abs(DUR_BUCKETS[best] - steps): + best = i + return best + + +def _pitch_group(pitch: int) -> str: + if pitch < 48: + return "bass" + if pitch < 72: + return "mid" + return "high" + + +class Vocab: + """Builds and holds the token <-> id mapping.""" + + def __init__(self): + tokens: List[str] = ["PAD", "BOS", "EOS", "BAR"] + tokens += [f"POS_{p}" for p in range(STEPS_PER_BAR)] + tokens += [f"INST_{g}" for g in INST_GROUPS] + tokens += [f"PITCH_{n}" for n in range(MEL_LOW, MEL_HIGH + 1)] + tokens += [f"DRUM_{c}" for c in range(len(DRUM_CLASSES))] + tokens += [f"DUR_{i}" for i in range(len(DUR_BUCKETS))] + 
tokens += [f"VEL_{v}" for v in range(N_VEL_BINS)] + self.itos: List[str] = tokens + self.stoi: Dict[str, int] = {t: i for i, t in enumerate(tokens)} + + def __len__(self) -> int: + return len(self.itos) + + def __getitem__(self, token: str) -> int: + return self.stoi[token] + + +VOCAB = Vocab() +PAD, BOS, EOS, BAR = VOCAB["PAD"], VOCAB["BOS"], VOCAB["EOS"], VOCAB["BAR"] + + +# -- parsing --------------------------------------------------------------- + +def _read_notes(path: str) -> List[Tuple[int, str, int, int, int]]: + """Return notes as ``(start_step, group, sound_id, dur_steps, vel)``.""" + midi_file = mido.MidiFile(path) + ticks_per_step = midi_file.ticks_per_beat / 4.0 + notes = [] + open_notes: Dict[Tuple[int, int], Tuple[int, int]] = {} + + abs_tick = 0 + for msg in mido.merge_tracks(midi_file.tracks): + abs_tick += msg.time + step = int(round(abs_tick / ticks_per_step)) + if msg.type == "note_on" and msg.velocity > 0: + if msg.channel == DRUM_CHANNEL: + # Percussion: instantaneous hit, no sustain. + notes.append((step, "drums", _drum_class(msg.note), 1, msg.velocity)) + else: + open_notes[(msg.channel, msg.note)] = (step, msg.velocity) + elif msg.type == "note_off" or (msg.type == "note_on" and msg.velocity == 0): + key = (msg.channel, msg.note) + if key in open_notes and msg.channel != DRUM_CHANNEL: + start, vel = open_notes.pop(key) + dur = max(1, step - start) + notes.append((start, _pitch_group(msg.note), msg.note, dur, vel)) + return notes + + +def encode(path: str, transpose: int = 0) -> List[int]: + """Encode a MIDI file into a list of token ids (optionally transposed).""" + notes = _read_notes(path) + # Apply transposition to pitched notes only. 
+ shifted = [] + for start, group, sound, dur, vel in notes: + if group != "drums": + sound += transpose + if sound < MEL_LOW or sound > MEL_HIGH: + continue + shifted.append((start, group, sound, dur, vel)) + shifted.sort(key=lambda n: (n[0], INST_GROUPS.index(n[1]), n[2])) + + ids = [BOS] + cur_bar, cur_pos = -1, -1 + for start, group, sound, dur, vel in shifted: + bar, pos = divmod(start, STEPS_PER_BAR) + while cur_bar < bar: + ids.append(BAR) + cur_bar += 1 + cur_pos = -1 + if pos != cur_pos: + ids.append(VOCAB[f"POS_{pos}"]) + cur_pos = pos + ids.append(VOCAB[f"INST_{group}"]) + if group == "drums": + ids.append(VOCAB[f"DRUM_{sound}"]) + else: + ids.append(VOCAB[f"PITCH_{sound}"]) + ids.append(VOCAB[f"DUR_{_dur_bucket(dur)}"]) + ids.append(VOCAB[f"VEL_{_vel_bin(vel)}"]) + ids.append(EOS) + return ids + + +# -- decoding -------------------------------------------------------------- + +def decode(ids: List[int], tempo_bpm: float = 95.0, ticks_per_beat: int = 480) -> mido.MidiFile: + """Turn a token-id sequence back into a multi-track MIDI file.""" + ticks_per_step = ticks_per_beat // 4 + midi_file = mido.MidiFile(ticks_per_beat=ticks_per_beat) + + meta = mido.MidiTrack() + meta.append(mido.MetaMessage("set_tempo", tempo=mido.bpm2tempo(tempo_bpm), time=0)) + midi_file.tracks.append(meta) + + # Collect note events per instrument group as (tick, kind, note, vel). 
+ events: Dict[str, List[Tuple[int, str, int, int]]] = {g: [] for g in INST_GROUPS} + + bar, pos = -1, 0 + group = sound = dur = vel = None + for tid in ids: + tok = VOCAB.itos[tid] + if tok in ("BOS", "PAD"): + continue + if tok == "EOS": + break + if tok == "BAR": + bar += 1 + pos = 0 + elif tok.startswith("POS_"): + pos = int(tok[4:]) + elif tok.startswith("INST_"): + group, sound, dur, vel = tok[5:], None, None, None + elif tok.startswith("PITCH_"): + sound = int(tok[6:]) + elif tok.startswith("DRUM_"): + sound = int(tok[5:]) + elif tok.startswith("DUR_"): + dur = DUR_BUCKETS[int(tok[4:])] + elif tok.startswith("VEL_"): + vel = _vel_value(int(tok[4:])) + if group is not None and sound is not None: + start_step = max(0, bar) * STEPS_PER_BAR + pos + tick = start_step * ticks_per_step + if group == "drums": + note = _DRUM_REPR[sound % len(_DRUM_REPR)] + events["drums"].append((tick, "note_on", note, vel)) + events["drums"].append((tick + ticks_per_step // 2, "note_off", note, 0)) + else: + length = (dur or 2) * ticks_per_step + events[group].append((tick, "note_on", sound, vel)) + events[group].append((tick + length, "note_off", sound, 0)) + group = sound = dur = vel = None + + for g in INST_GROUPS: + evs = events[g] + if not evs: + continue + track = mido.MidiTrack() + channel = DRUM_CHANNEL if g == "drums" else INST_GROUPS.index(g) + if g != "drums": + track.append(mido.Message("program_change", program=_INST_PROGRAM[g], channel=channel, time=0)) + evs.sort(key=lambda e: e[0]) + last = 0 + for tick, kind, note, v in evs: + track.append(mido.Message(kind, note=note, velocity=v, channel=channel, time=tick - last)) + last = tick + midi_file.tracks.append(track) + + return midi_file diff --git a/dembow/train.py b/dembow/train.py new file mode 100644 index 0000000..80aff1c --- /dev/null +++ b/dembow/train.py @@ -0,0 +1,119 @@ +"""Train the Dembow music Transformer on a folder of reggaeton MIDI.""" + +from __future__ import annotations + +import numpy as np +import 
torch +import torch.nn.functional as F + +from .data import build_windows, encode_files, find_midi_files, split_files +from .model import ModelConfig, MusicTransformer +from .tokenizer import PAD, VOCAB + + +@torch.no_grad() +def _eval_loss(model: MusicTransformer, data: torch.Tensor, batch_size: int) -> float: + """Average next-token loss over a held-out set.""" + model.eval() + losses = [] + for start in range(0, data.shape[0], batch_size): + batch = data[start : start + batch_size] + logits = model(batch[:, :-1]) + loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)), batch[:, 1:].reshape(-1), ignore_index=PAD + ) + losses.append(float(loss)) + model.train() + return float(np.mean(losses)) if losses else float("nan") + + +def train( + data_dir: str = "reggaeton_samples", + checkpoint: str = "dembow.pt", + seq_len: int = 384, + d_model: int = 256, + n_layers: int = 4, + n_heads: int = 4, + num_epochs: int = 80, + batch_size: int = 16, + lr: float = 3e-4, + augment: int = 3, + val_frac: float = 0.1, + patience: int = 8, + seed: int = 0, + device: str | None = None, +) -> MusicTransformer: + """Train the Transformer and save the best checkpoint (by validation loss).""" + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + torch.manual_seed(seed) + np.random.seed(seed) + + transpositions = list(range(-augment, augment + 1)) if augment else [0] + files = find_midi_files(data_dir) + train_files, val_files = split_files(files, val_frac, seed=seed) + print(f"Loading songs from '{data_dir}' (pitch augmentation: {transpositions})") + print(f" {len(train_files)} train songs, {len(val_files)} val songs") + + # Train data is augmented; validation is held-out songs at original pitch only. 
+ train_windows = build_windows(encode_files(train_files, transpositions), seq_len) + val_windows = build_windows(encode_files(val_files, [0]), seq_len) + if train_windows.shape[0] == 0: + raise SystemExit("No usable training windows -- nothing to train on.") + print(f" {train_windows.shape[0]} train windows, {val_windows.shape[0]} val windows (vocab {len(VOCAB)})") + + config = ModelConfig( + vocab_size=len(VOCAB), d_model=d_model, n_layers=n_layers, + n_heads=n_heads, max_len=seq_len, + ) + model = MusicTransformer(config).to(device) + n_params = sum(p.numel() for p in model.parameters()) + print(f"Model: {n_params/1e6:.2f}M parameters on {device}") + + optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) + train_data = torch.from_numpy(train_windows).to(device) + val_data = torch.from_numpy(val_windows).to(device) if val_windows.shape[0] else None + n = train_data.shape[0] + + print(f"Training for up to {num_epochs} epochs (early-stop patience {patience}) ...") + best_val = float("inf") + best_epoch = 0 + since_improved = 0 + model.train() + for epoch in range(1, num_epochs + 1): + perm = torch.randperm(n, device=device) + losses = [] + for start in range(0, n, batch_size): + batch = train_data[perm[start : start + batch_size]] + logits = model(batch[:, :-1]) + loss = F.cross_entropy( + logits.reshape(-1, logits.size(-1)), batch[:, 1:].reshape(-1), ignore_index=PAD + ) + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + losses.append(loss.item()) + + train_loss = float(np.mean(losses)) + if val_data is not None: + val_loss = _eval_loss(model, val_data, batch_size) + improved = val_loss < best_val - 1e-4 + if improved: + best_val, best_epoch, since_improved = val_loss, epoch, 0 + model.save(checkpoint) # keep the best model, not the last + else: + since_improved += 1 + print(f" epoch {epoch:4d}/{num_epochs} train {train_loss:.4f} val {val_loss:.4f}" + f"{' *best (saved)' 
if improved else ''}") + if since_improved >= patience: + print(f"Early stopping at epoch {epoch} (no val improvement for {patience} epochs)") + break + else: + print(f" epoch {epoch:4d}/{num_epochs} train {train_loss:.4f}") + model.save(checkpoint) + + if val_data is not None: + print(f"Best model: epoch {best_epoch}, val loss {best_val:.4f} -> '{checkpoint}'") + else: + print(f"Saved trained model to '{checkpoint}'") + return model diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..abcd813 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,18 @@ +# Example outputs + +These `.mid` files were generated by Dembow so you can hear it without training +your own model first. Open them in any DAW or MIDI player. + +They come from a deliberately **small demo model** (`--d-model 96 --n-layers 2`, +no pitch augmentation, ~16 epochs, validation loss ≈ 1.44) trained on the bundled +`reggaeton_samples/` corpus, then generated with: + +```sh +dembow generate --num-samples 3 --max-new-tokens 900 \ + --temperature 0.95 --repetition-penalty 1.2 +``` + +Each file is a multi-track arrangement — drums (channel 10) plus bass / mid / +high melodic parts. A model trained longer, with the `gpu` preset, and on **more +data** (see [`../reggaeton_samples/SOURCES.md`](../reggaeton_samples/SOURCES.md)) +will sound noticeably better. 
diff --git a/examples/dembow_0.mid b/examples/dembow_0.mid new file mode 100644 index 0000000..956caa9 Binary files /dev/null and b/examples/dembow_0.mid differ diff --git a/examples/dembow_1.mid b/examples/dembow_1.mid new file mode 100644 index 0000000..4e0ce39 Binary files /dev/null and b/examples/dembow_1.mid differ diff --git a/examples/dembow_2.mid b/examples/dembow_2.mid new file mode 100644 index 0000000..a1e2d48 Binary files /dev/null and b/examples/dembow_2.mid differ diff --git a/fire.py b/fire.py index cbdbeba..3db665e 100755 --- a/fire.py +++ b/fire.py @@ -1,129 +1,26 @@ -# Based on Daniel Johnson's code in https://github.com/hexahedria/biaxial-rnn-music-composition +#!/usr/bin/env python3 +"""Light the fire. 🔥 -import numpy as np -import pandas as pd -import msgpack -import glob -import tensorflow as tf -from tensorflow.python.ops import control_flow_ops -from tqdm import tqdm +Trains the Dembow Transformer on the reggaeton corpus and immediately generates +a few songs. 
For more control use the CLI: + dembow train + dembow generate -import midi_manipulation +or, without installing: -def get_songs(path): - files = glob.glob('{}/*.mid*'.format(path)) - songs = [] - for f in tqdm(files): - try: - song = np.array(midi_manipulation.midiToNoteStateMatrix(f)) - if np.array(song).shape[0] > 50: - songs.append(song) - except Exception as e: - raise e - return songs + python -m dembow.cli train + python -m dembow.cli generate +""" -songs = get_songs('reggaeton_samples') -print "{} songs processed".format(len(songs)) +from dembow.train import train +from dembow.generate import generate -### HyperParameters -# First, let's take a look at the hyperparameters of our model: +def main(): + train(data_dir="reggaeton_samples", checkpoint="dembow.pt") + generate(checkpoint="dembow.pt", output_dir="generated") -lowest_note = midi_manipulation.lowerBound #the index of the lowest note on the piano roll -highest_note = midi_manipulation.upperBound #the index of the highest note on the piano roll -note_range = highest_note-lowest_note #the note range -num_timesteps = 15 #This is the number of timesteps that we will create at a time -n_visible = 2*note_range*num_timesteps #This is the size of the visible layer. -n_hidden = 50 #This is the size of the hidden layer - -num_epochs = 200 #The number of training epochs that we are going to run. For each epoch we go through the entire data set. -batch_size = 100 #The number of training examples that we are going to send through the RBM at a time. 
-lr = tf.constant(0.005, tf.float32) #The learning rate of our model - -### Variables: -# Next, let's look at the variables we're going to use: - -x = tf.placeholder(tf.float32, [None, n_visible], name="x") #The placeholder variable that holds our data -W = tf.Variable(tf.random_normal([n_visible, n_hidden], 0.01), name="W") #The weight matrix that stores the edge weights -bh = tf.Variable(tf.zeros([1, n_hidden], tf.float32, name="bh")) #The bias vector for the hidden layer -bv = tf.Variable(tf.zeros([1, n_visible], tf.float32, name="bv")) #The bias vector for the visible layer - - -#### Helper functions. - -#This function lets us easily sample from a vector of probabilities -def sample(probs): - #Takes in a vector of probabilities, and returns a random vector of 0s and 1s sampled from the input vector - return tf.floor(probs + tf.random_uniform(tf.shape(probs), 0, 1)) - -#This function runs the gibbs chain. We will call this function in two places: -# - When we define the training update step -# - When we sample our music segments from the trained RBM -def gibbs_sample(k): - #Runs a k-step gibbs chain to sample from the probability distribution of the RBM defined by W, bh, bv - def gibbs_step(count, k, xk): - #Runs a single gibbs step. 
The visible values are initialized to xk - hk = sample(tf.sigmoid(tf.matmul(xk, W) + bh)) #Propagate the visible values to sample the hidden values - xk = sample(tf.sigmoid(tf.matmul(hk, tf.transpose(W)) + bv)) #Propagate the hidden values to sample the visible values - return count+1, k, xk - - #Run gibbs steps for k iterations - ct = tf.constant(0) #counter - [_, _, x_sample] = control_flow_ops.while_loop(lambda count, num_iter, *args: count < num_iter, - gibbs_step, [ct, tf.constant(k), x]) - #This is not strictly necessary in this implementation, but if you want to adapt this code to use one of TensorFlow's - #optimizers, you need this in order to stop tensorflow from propagating gradients back through the gibbs step - x_sample = tf.stop_gradient(x_sample) - return x_sample - -### Training Update Code -# Now we implement the contrastive divergence algorithm. First, we get the samples of x and h from the probability distribution -#The sample of x -x_sample = gibbs_sample(1) -#The sample of the hidden nodes, starting from the visible state of x -h = sample(tf.sigmoid(tf.matmul(x, W) + bh)) -#The sample of the hidden nodes, starting from the visible state of x_sample -h_sample = sample(tf.sigmoid(tf.matmul(x_sample, W) + bh)) - -#Next, we update the values of W, bh, and bv, based on the difference between the samples that we drew and the original values -size_bt = tf.cast(tf.shape(x)[0], tf.float32) -W_adder = tf.multiply(lr/size_bt, tf.subtract(tf.matmul(tf.transpose(x), h), tf.matmul(tf.transpose(x_sample), h_sample))) -bv_adder = tf.multiply(lr/size_bt, tf.reduce_sum(tf.subtract(x, x_sample), 0, True)) -bh_adder = tf.multiply(lr/size_bt, tf.reduce_sum(tf.subtract(h, h_sample), 0, True)) -#When we do sess.run(updt), TensorFlow will run all 3 update steps -updt = [W.assign_add(W_adder), bv.assign_add(bv_adder), bh.assign_add(bh_adder)] - - -### Run the graph! -# Now it's time to start a session and run the graph! 
- -with tf.Session() as sess: - #First, we train the model - #initialize the variables of the model - init = tf.global_variables_initializer() - sess.run(init) - #Run through all of the training data num_epochs times - for epoch in tqdm(range(num_epochs)): - for song in songs: - #The songs are stored in a time x notes format. The size of each song is timesteps_in_song x 2*note_range - #Here we reshape the songs so that each training example is a vector with num_timesteps x 2*note_range elements - song = np.array(song) - song = song[:int(np.floor(song.shape[0]/num_timesteps)*num_timesteps)] - song = np.reshape(song, [song.shape[0]/num_timesteps, song.shape[1]*num_timesteps]) - #Train the RBM on batch_size examples at a time - for i in range(1, len(song), batch_size): - tr_x = song[i:i+batch_size] - sess.run(updt, feed_dict={x: tr_x}) - - #Now the model is fully trained, so let's make some music! - #Run a gibbs chain where the visible nodes are initialized to 0 - sample = gibbs_sample(1).eval(session=sess, feed_dict={x: np.zeros((10, n_visible))}) - for i in range(sample.shape[0]): - if not any(sample[i,:]): - continue - #Here we reshape the vector to be time x notes, and then save the vector as a midi file - S = np.reshape(sample[i,:], (num_timesteps, 2*note_range)) - midi_manipulation.noteStateMatrixToMidi(S, "generated_chord_{}".format(i)) - \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/midi_manipulation.py b/midi_manipulation.py deleted file mode 100755 index 7453f94..0000000 --- a/midi_manipulation.py +++ /dev/null @@ -1,112 +0,0 @@ -import midi -import numpy as np - - -lowerBound = 24 -upperBound = 102 -span = upperBound-lowerBound - - -def midiToNoteStateMatrix(midifile, squash=True, span=span): - pattern = midi.read_midifile(midifile) - - timeleft = [track[0].tick for track in pattern] - - posns = [0 for track in pattern] - - statematrix = [] - time = 0 - - state = [[0,0] for x in range(span)] - statematrix.append(state) - 
condition = True - while condition: - if time % (pattern.resolution / 4) == (pattern.resolution / 8): - # Crossed a note boundary. Create a new state, defaulting to holding notes - oldstate = state - state = [[oldstate[x][0],0] for x in range(span)] - statematrix.append(state) - for i in range(len(timeleft)): #For each track - if not condition: - break - while timeleft[i] == 0: - track = pattern[i] - pos = posns[i] - - evt = track[pos] - if isinstance(evt, midi.NoteEvent): - if (evt.pitch < lowerBound) or (evt.pitch >= upperBound): - pass - # print "Note {} at time {} out of bounds (ignoring)".format(evt.pitch, time) - else: - if isinstance(evt, midi.NoteOffEvent) or evt.velocity == 0: - state[evt.pitch-lowerBound] = [0, 0] - else: - state[evt.pitch-lowerBound] = [1, 1] - elif isinstance(evt, midi.TimeSignatureEvent): - if evt.numerator not in (2, 4): - # We don't want to worry about non-4 time signatures. Bail early! - # print "Found time signature event {}. Bailing!".format(evt) - out = statematrix - condition = False - break - try: - timeleft[i] = track[pos + 1].tick - posns[i] += 1 - except IndexError: - timeleft[i] = None - - if timeleft[i] is not None: - timeleft[i] -= 1 - - if all(t is None for t in timeleft): - break - - time += 1 - - S = np.array(statematrix) - statematrix = np.hstack((S[:, :, 0], S[:, :, 1])) - statematrix = np.asarray(statematrix).tolist() - return statematrix - -def noteStateMatrixToMidi(statematrix, name="example", span=span): - statematrix = np.array(statematrix) - if not len(statematrix.shape) == 3: - statematrix = np.dstack((statematrix[:, :span], statematrix[:, span:])) - statematrix = np.asarray(statematrix) - pattern = midi.Pattern() - track = midi.Track() - pattern.append(track) - - span = upperBound-lowerBound - tickscale = 55 - - lastcmdtime = 0 - prevstate = [[0,0] for x in range(span)] - for time, state in enumerate(statematrix + [prevstate[:]]): - offNotes = [] - onNotes = [] - for i in range(span): - n = state[i] - p = 
prevstate[i] - if p[0] == 1: - if n[0] == 0: - offNotes.append(i) - elif n[1] == 1: - offNotes.append(i) - onNotes.append(i) - elif n[0] == 1: - onNotes.append(i) - for note in offNotes: - track.append(midi.NoteOffEvent(tick=(time-lastcmdtime)*tickscale, pitch=note+lowerBound)) - lastcmdtime = time - for note in onNotes: - track.append(midi.NoteOnEvent(tick=(time-lastcmdtime)*tickscale, velocity=40, pitch=note+lowerBound)) - lastcmdtime = time - - prevstate = state - - eot = midi.EndOfTrackEvent(tick=1) - track.append(eot) - - midi.write_midifile("{}.mid".format(name), pattern) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c9c6693 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,31 @@ +[build-system] +requires = ["setuptools>=61"] +build-backend = "setuptools.build_meta" + +[project] +name = "dembow" +version = "2.2.0" +description = "The first A.I. that generates reggaeton hits -- a Transformer for dembow." +readme = "README.md" +requires-python = ">=3.9" +license = { text = "MIT" } +dependencies = [ + "numpy>=1.24", + "mido>=1.3", + "torch>=2.0", +] + +[project.scripts] +dembow = "dembow.cli:main" + +[tool.setuptools] +packages = ["dembow"] + +[tool.setuptools.package-data] +dembow = ["assets/*.pt"] + +[tool.pytest.ini_options] +# Put the repo root on sys.path so `pytest tests/` finds the package without an +# install step (used both locally and in CI). +pythonpath = ["."] +testpaths = ["tests"] diff --git a/reggaeton_samples/SOURCES.md b/reggaeton_samples/SOURCES.md new file mode 100644 index 0000000..6783aa1 --- /dev/null +++ b/reggaeton_samples/SOURCES.md @@ -0,0 +1,41 @@ +# Where to find more reggaeton MIDI + +Dembow's quality is limited mostly by data — ~76 short files is small for a +Transformer. More clean reggaeton/dembow MIDI is the single biggest improvement +you can make. Drop new `.mid` / `.midi` files into this folder and retrain. 
+ +## What helps most +- **Reggaeton / dembow / Latin-trap MIDI** with a real drum track on channel 10 + (the dembow kick/snare is the genre's signature). +- **Multi-track arrangements** (drums + bass + melody) — Dembow groups parts into + drums / bass / mid / high, so layered files teach it more than single melodies. +- **4/4, straightforward tempo.** Quantized files quantize cleanly to the 16th grid. + +## Free / open MIDI sources +- **FreeMIDI.org**, **BitMidi.com**, **MidiWorld** — large general libraries; search + artist names (Daddy Yankee, Don Omar, Wisin & Yandel, Tego Calderón, Aventura). +- **The Lakh MIDI Dataset (LMD)** — ~176k MIDI files; filter to Latin/reggaeton by + matching titles/artists. Great for bulk augmentation. +- **MetaMIDI Dataset** — large, with genre metadata you can filter on. +- **Groove MIDI Dataset (Magenta)** — expressive *drum* performances; excellent for + teaching the groove even though it isn't reggaeton-specific. +- **Hooktheory / TheoryTab** — chord+melody data (export to MIDI) for harmony. + +## Make your own +- **Transcribe audio to MIDI** with [Spotify Basic Pitch](https://basicpitch.spotify.com/) + or [Magenta MT3](https://github.com/magenta/mt3) from reggaeton stems/acapellas. +- **Export from a DAW** (Ableton, FL Studio, Logic) — reggaeton MIDI packs are widely + sold; bounce the MIDI clips out. +- **Program dembow patterns** by hand in any DAW and export. + +## Cleaning tips +- Keep drums on channel 10 (MIDI channel index 9) so the tokenizer routes them as drums. +- Remove long silent intros/outros; Dembow trains on bars of actual music. +- One song per file; quantize to 16ths if the timing is loose. + +## ⚠️ Licensing +MIDI transcriptions of copyrighted songs carry the rights of the underlying +composition. Use them for **personal experimentation / research**. Don't +redistribute copyrighted MIDI or publish generated tracks commercially without +clearing rights. 
Prefer openly licensed datasets (LMD, Groove MIDI, MetaMIDI) for +anything you intend to share. diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0226c0f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +numpy>=1.24 +mido>=1.3 +torch>=2.0 diff --git a/tests/test_smoke.py b/tests/test_smoke.py new file mode 100644 index 0000000..ce46f21 --- /dev/null +++ b/tests/test_smoke.py @@ -0,0 +1,154 @@ +"""Fast end-to-end smoke test: tokenizer round-trip, a tiny train, generation. + +Run with: pytest tests/ (or python tests/test_smoke.py) +""" + +import os +import tempfile + +import numpy as np + +from dembow.data import build_windows, encode_corpus, find_midi_files, split_files +from dembow.model import ModelConfig, MusicTransformer +from dembow.tokenizer import BAR, BOS, EOS, VOCAB, decode, encode + +SAMPLES = os.path.join(os.path.dirname(__file__), "..", "reggaeton_samples") + + +def test_finds_uppercase_midi(): + files = find_midi_files(SAMPLES) + assert files + # The corpus contains uppercase .MID files the old glob missed. + assert any(f.endswith(".MID") for f in files) + + +def test_tokenizer_roundtrip(): + files = find_midi_files(SAMPLES) + ids = encode(files[0]) + assert ids[0] == BOS and ids[-1] == EOS + assert all(0 <= t < len(VOCAB) for t in ids) + assert BAR in ids + + midi_file = decode(ids) + assert len(midi_file.tracks) >= 2 # meta + at least one instrument + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "rt.mid") + midi_file.save(path) + # The decoded file must itself be re-encodable. + assert len(encode(path)) > 4 + + +def test_transpose_augmentation_shifts_pitch(): + files = find_midi_files(SAMPLES) + base = encode(files[0], transpose=0) + up = encode(files[0], transpose=2) + # Augmentation should change the token stream (different pitches). 
+ assert base != up + + +def test_build_windows(): + songs = encode_corpus(SAMPLES, transpositions=[0]) + windows = build_windows(songs, seq_len=128) + assert windows.ndim == 2 and windows.shape[1] == 128 + assert windows.shape[0] > 0 + + +def test_train_and_generate_tiny(): + import torch + + songs = encode_corpus(SAMPLES, transpositions=[0]) + windows = build_windows(songs, seq_len=128) + config = ModelConfig(vocab_size=len(VOCAB), d_model=64, n_layers=2, n_heads=2, max_len=128) + model = MusicTransformer(config) + opt = torch.optim.AdamW(model.parameters(), lr=1e-3) + + data = torch.from_numpy(windows[:16]) + first = last = None + for _ in range(8): + logits = model(data[:, :-1]) + loss = torch.nn.functional.cross_entropy( + logits.reshape(-1, logits.size(-1)), data[:, 1:].reshape(-1) + ) + opt.zero_grad() + loss.backward() + opt.step() + val = loss.item() + first = first if first is not None else val + last = val + assert last < first # the model is learning + + ids = model.generate(torch.tensor([BOS]), max_new_tokens=120, eos_id=EOS) + assert len(ids) > 1 and all(0 <= t < len(VOCAB) for t in ids) + midi_file = decode(ids) # generated tokens must decode without error + assert len(midi_file.tracks) >= 1 + + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "m.pt") + model.save(path) + loaded = MusicTransformer.load(path) + assert loaded.config.d_model == 64 + + +def test_split_files_is_disjoint(): + files = find_midi_files(SAMPLES) + train, val = split_files(files, val_frac=0.2, seed=0) + assert len(val) == int(len(files) * 0.2) + assert len(train) + len(val) == len(files) + assert set(train).isdisjoint(set(val)) # no song leaks across the split + + +def test_generate_repetition_controls(): + import torch + + songs = encode_corpus(SAMPLES, transpositions=[0]) + windows = build_windows(songs, seq_len=128) + model = MusicTransformer(ModelConfig(vocab_size=len(VOCAB), d_model=64, n_layers=2, n_heads=2, max_len=128)) + data = 
torch.from_numpy(windows[:16]) + opt = torch.optim.AdamW(model.parameters(), lr=1e-3) + for _ in range(3): + logits = model(data[:, :-1]) + loss = torch.nn.functional.cross_entropy(logits.reshape(-1, logits.size(-1)), data[:, 1:].reshape(-1)) + opt.zero_grad(); loss.backward(); opt.step() + + # Both repetition controls should run and yield valid tokens. + ids = model.generate( + torch.tensor([BOS]), max_new_tokens=80, repetition_penalty=1.3, + no_repeat_ngram_size=4, eos_id=EOS, + ) + assert len(ids) > 1 and all(0 <= t < len(VOCAB) for t in ids) + decode(ids) # must still decode + + +def test_bundled_pretrained_model_loads(): + from dembow.generate import BUNDLED_CHECKPOINT + from dembow.model import MusicTransformer + + assert os.path.exists(BUNDLED_CHECKPOINT), "a pretrained model should ship with the package" + model = MusicTransformer.load(BUNDLED_CHECKPOINT) + assert model.config.vocab_size == len(VOCAB) + + +def test_render_to_wav_builtin(): + import wave + + from dembow.render import render_to_wav + + files = find_midi_files(SAMPLES) + with tempfile.TemporaryDirectory() as d: + wav = render_to_wav(files[0], os.path.join(d, "out.wav")) + assert os.path.exists(wav) + with wave.open(wav) as w: + assert w.getnframes() > 0 # produced audible audio + + +if __name__ == "__main__": + test_finds_uppercase_midi() + test_tokenizer_roundtrip() + test_transpose_augmentation_shifts_pitch() + test_build_windows() + test_train_and_generate_tiny() + test_split_files_is_disjoint() + test_generate_repetition_controls() + test_bundled_pretrained_model_loads() + test_render_to_wav_builtin() + print("All smoke tests passed.")