
Commit 9e26bf2

Expand data workflows and Nemotron serving/training support
1 parent b72cf53 commit 9e26bf2

43 files changed

Lines changed: 3695 additions & 374 deletions

Only a subset of the 43 changed files is shown below; the rest are hidden in this view.

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -36,6 +36,10 @@ htmlcov/
 
 # NeMoCode local config
 .nemocode.yaml
+.nemocode/data/
+.nemocode/adapters/
+.nemocode/logs/
+.nemocode/models/
 
 # Sessions
 ~/.local/share/nemocode/

scripts/train_lora.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: MIT

"""LoRA fine-tuning for Nemotron 3 Nano 4B on DGX Spark.

Uses Hugging Face transformers + PEFT for native training on GB10.
The BF16 checkpoint (~8GB) fits easily in Spark's 128GB unified memory.

    python scripts/train_lora.py
"""

from __future__ import annotations

import json
import sys
from pathlib import Path

# Training config
MODEL_ID = "nvidia/NVIDIA-Nemotron-3-Nano-4B-BF16"
DATASET_PATH = ".nemocode/data/sft_generated.jsonl"
OUTPUT_DIR = ".nemocode/adapters/nemotron-nano-4b-lora"
LORA_RANK = 16
LORA_ALPHA = 32
EPOCHS = 2
BATCH_SIZE = 1
GRADIENT_ACCUMULATION = 8  # effective batch = 8
LEARNING_RATE = 2e-4
MAX_SEQ_LEN = 8192
LOGGING_STEPS = 10
SAVE_STEPS = 50


def load_dataset(path: str) -> list[dict]:
    """Load SFT JSONL dataset."""
    records = []
    with open(path) as f:
        for line in f:
            if line.strip():
                records.append(json.loads(line))
    print(f"Loaded {len(records)} records from {path}")
    return records


def main():
    import torch

    if not torch.cuda.is_available():
        print("ERROR: CUDA not available.")
        print("Make sure you're using a venv with CUDA-enabled PyTorch.")
        sys.exit(1)

    print(f"PyTorch: {torch.__version__}")
    print(f"CUDA: {torch.cuda.get_device_name(0)}")
    mem = torch.cuda.get_device_properties(0).total_memory
    print(f"Memory: {mem / 1e9:.1f} GB")

    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model, TaskType
    from trl import SFTTrainer, SFTConfig

    # Load dataset
    if not Path(DATASET_PATH).exists():
        print(f"ERROR: Dataset not found at {DATASET_PATH}")
        print("Generate it first:")
        print("  nemo data export-seeds")
        print("  nemo data generate")
        sys.exit(1)

    dataset_records = load_dataset(DATASET_PATH)

    # Format for SFT: each record has a messages array
    def format_messages(record):
        """Convert messages to chat template format."""
        return {"messages": record["messages"]}

    formatted = [format_messages(r) for r in dataset_records]

    # Split 90/10 train/eval
    split_idx = int(len(formatted) * 0.9)
    train_data = formatted[:split_idx]
    eval_data = formatted[split_idx:]
    print(f"Train: {len(train_data)}, Eval: {len(eval_data)}")

    # Load tokenizer
    print(f"\nLoading tokenizer for {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load Nemotron 3 Nano 4B BF16 (~8GB, fits easily in Spark's 128GB)
    print(f"Loading model {MODEL_ID}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="eager",
    )
    model.config.use_cache = False

    # Configure LoRA
    print(f"Applying LoRA (rank={LORA_RANK}, alpha={LORA_ALPHA})...")
    lora_config = LoraConfig(
        r=LORA_RANK,
        lora_alpha=LORA_ALPHA,
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Training config
    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)

    from datasets import Dataset
    train_dataset = Dataset.from_list(train_data)
    eval_dataset = Dataset.from_list(eval_data)

    training_args = SFTConfig(
        output_dir=str(output_dir),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        logging_steps=LOGGING_STEPS,
        save_steps=SAVE_STEPS,
        save_total_limit=2,
        bf16=True,
        max_length=MAX_SEQ_LEN,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        optim="adamw_torch_fused",
        report_to="none",
        eval_strategy="steps",
        eval_steps=SAVE_STEPS,
        dataset_text_field=None,  # Using messages format
    )

    # Create trainer
    print("\nStarting training...")
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
    )

    # Train
    trainer.train()

    # Save the LoRA adapter
    print(f"\nSaving LoRA adapter to {output_dir}...")
    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    print(f"\nDone! LoRA adapter saved to {output_dir}")
    print("To serve with vLLM:")
    print(f"  vllm serve {MODEL_ID} --enable-lora --lora-modules nemocode={output_dir}")
    print("Or use nemo serve:")
    print(f"  nemo serve start --model nemotron-3-nano-4b --adapter {output_dir}")


if __name__ == "__main__":
    main()
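Once the run completes, a quick way to sanity-check the adapter before wiring it into vLLM or nemo serve is to load it back with PEFT and generate from a chat prompt. The sketch below is not part of this commit; it assumes the adapter was written to the OUTPUT_DIR above and that the Nemotron tokenizer ships a chat template, and the prompt text is illustrative.

#!/usr/bin/env python3
# Illustrative sanity check for the trained adapter (not part of this commit).
import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM

ADAPTER_DIR = ".nemocode/adapters/nemotron-nano-4b-lora"  # OUTPUT_DIR from train_lora.py

# The adapter directory also contains the tokenizer saved at the end of training.
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR, trust_remote_code=True)

# AutoPeftModelForCausalLM reads adapter_config.json, loads the base model,
# then applies the LoRA weights on top of it.
model = AutoPeftModelForCausalLM.from_pretrained(
    ADAPTER_DIR,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

messages = [{"role": "user", "content": "Write a Python function that parses a JSONL file."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output = model.generate(input_ids, max_new_tokens=256)
# Print only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))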
