Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions .github/workflows/notebook-validation.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
name: Validate notebooks

# Runs for pull requests and pushes that target master and touch a notebook,
# the validator script, or this workflow file. Can also be started manually.
on:
  pull_request:
    branches: [master]
    paths:
      - "templates/**/*.ipynb"
      - "solutions/**/*.ipynb"
      - "scripts/validate_notebooks.py"
      - ".github/workflows/notebook-validation.yml"
  push:
    branches: [master]
    paths:
      - "templates/**/*.ipynb"
      - "solutions/**/*.ipynb"
      - "scripts/validate_notebooks.py"
      - ".github/workflows/notebook-validation.yml"
  workflow_dispatch:

# The job only checks out and reads the repository, so limit the token
# to read access instead of inheriting the repository default.
permissions:
  contents: read

jobs:
  validate-notebooks:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      # nbformat is the only third-party package the validator imports.
      - name: Install nbformat
        run: pip install nbformat

      # Runs in check mode (no --fix): exits non-zero on any schema error.
      - name: Validate notebooks
        run: python scripts/validate_notebooks.py
13 changes: 6 additions & 7 deletions scripts/add_colab_badges.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
"""Add 'Open in Colab' badges to all template and solution notebooks."""
"""Add 'Open in Colab' links to all template and solution notebooks."""

import json
from pathlib import Path
Expand All @@ -9,7 +9,6 @@
ROOT = Path(__file__).resolve().parent.parent
TEMPLATES_DIR = ROOT / "templates"
SOLUTIONS_DIR = ROOT / "solutions"
BADGE_IMG = "https://colab.research.google.com/assets/colab-badge.svg"


def colab_url(filename: str, folder: str) -> str:
Expand All @@ -19,8 +18,8 @@ def colab_url(filename: str, folder: str) -> str:
)


def badge_markdown(filename: str, folder: str) -> str:
return f"[![Open In Colab]({BADGE_IMG})]({colab_url(filename, folder)})"
def colab_markdown(filename: str, folder: str) -> str:
return f"[Open in Colab]({colab_url(filename, folder)})"


def process_notebook(path: Path, folder: str) -> bool:
Expand All @@ -33,11 +32,11 @@ def process_notebook(path: Path, folder: str) -> bool:

source_lines = cells[0]["source"]
flat = "".join(source_lines) if isinstance(source_lines, list) else source_lines
if "colab-badge.svg" in flat:
if "colab.research.google.com/github/" in flat:
return False

badge = badge_markdown(path.name, folder)
cells[0]["source"] = [badge + "\n\n"] + (
link = colab_markdown(path.name, folder)
cells[0]["source"] = [link + "\n\n"] + (
source_lines if isinstance(source_lines, list) else [source_lines]
)

Expand Down
184 changes: 184 additions & 0 deletions scripts/validate_notebooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""Validate and optionally repair notebook cell schemas."""

from __future__ import annotations

import argparse
import hashlib
import json
from pathlib import Path
import re
from typing import Any


# Repository root: two directory levels above this file.
ROOT = Path(__file__).resolve().parent.parent
# Glob patterns, relative to ROOT, of the notebooks to validate.
NOTEBOOK_GLOBS = ("templates/*.ipynb", "solutions/*.ipynb")
# Fields that are reported (and removed by --fix) when found on non-code cells.
CODE_ONLY_FIELDS = ("outputs", "execution_count")
# A valid cell id is 1-64 characters of letters, digits, "-" or "_".
# The length cap mirrors the nbformat 4.5 schema so over-long ids are caught
# (and repairable with --fix) here instead of failing later in nbformat.
# "\Z" is used instead of "$" so an id with a trailing newline is rejected.
CELL_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}\Z")


def notebook_paths() -> list[Path]:
    """Return every notebook matched by ``NOTEBOOK_GLOBS`` under ``ROOT``, sorted."""
    return sorted(
        match
        for pattern in NOTEBOOK_GLOBS
        for match in ROOT.glob(pattern)
    )


def load_notebook(path: Path) -> dict[str, Any]:
    """Read ``path`` as UTF-8 text and return the decoded JSON document."""
    raw = path.read_text(encoding="utf-8")
    return json.loads(raw)


def write_notebook(path: Path, notebook: dict[str, Any]) -> None:
    """Serialize ``notebook`` to ``path`` as UTF-8 JSON.

    The output uses a one-space indent, keeps non-ASCII characters unescaped,
    and ends with a single trailing newline.
    """
    serialized = json.dumps(notebook, ensure_ascii=False, indent=1)
    path.write_text(serialized + "\n", encoding="utf-8")


def source_text(cell: dict[str, Any]) -> str:
    """Return the cell's ``source`` as one string.

    A list-valued source is concatenated item by item; any other value is
    converted with ``str``. A cell without a ``source`` key yields ``""``.
    """
    raw = cell.get("source", "")
    if not isinstance(raw, list):
        return str(raw)
    return "".join(map(str, raw))


def stable_cell_id(path: Path, index: int, cell: dict[str, Any], used: set[str]) -> str:
    """Derive a deterministic cell id that is not already in ``used``.

    The id is ``cell-`` plus the first 12 hex digits of a SHA-1 over the
    notebook's path relative to ``ROOT``, the cell index, the cell type and
    the cell source. If that id is taken, ``-1``, ``-2``, ... is appended
    until a free one is found.

    Args:
        path: Notebook file; must be located under ``ROOT``.
        index: Position of the cell within the notebook.
        cell: The cell dictionary.
        used: Ids already taken. This function does not modify it; the
            caller is expected to record the returned id.

    Returns:
        An id string that is not a member of ``used``.

    Raises:
        ValueError: If ``path`` is not under ``ROOT``.
    """
    # Use forward slashes in the seed so the same notebook hashes to the same
    # id on every operating system (str(Path) would use "\\" on Windows).
    relative = path.relative_to(ROOT).as_posix()
    seed = f"{relative}:{index}:{cell.get('cell_type', '')}:{source_text(cell)}"
    digest = hashlib.sha1(seed.encode("utf-8")).hexdigest()[:12]
    base = f"cell-{digest}"
    if base not in used:
        return base

    suffix = 1
    while f"{base}-{suffix}" in used:
        suffix += 1
    return f"{base}-{suffix}"


def sanitize_notebook(path: Path, fix: bool) -> list[str]:
    """Check one notebook's cells and optionally repair them in place.

    Checks performed per cell:
      * the ``id`` is a non-empty string matching ``CELL_ID_RE``;
      * the ``id`` is not repeated within the notebook;
      * cell ids are only present when ``nbformat_minor`` is at least 5
        (reported once per notebook);
      * non-code cells carry none of ``CODE_ONLY_FIELDS``.

    Args:
        path: Notebook file; must be located under ``ROOT``.
        fix: When true, missing, invalid and duplicate ids are replaced with
            generated ones, code-only fields are removed from non-code
            cells, and the notebook is rewritten if anything changed.

    Returns:
        One message per problem found. In fix mode these are the problems
        that were repaired.
    """
    notebook = load_notebook(path)
    rel = path.relative_to(ROOT)  # used as the prefix of every message
    errors: list[str] = []
    changed = False
    used_ids: set[str] = set()
    needs_cell_ids = False
    nbformat_minor = int(notebook.get("nbformat_minor", 0))
    version_error_added = False

    for index, cell in enumerate(notebook.get("cells", [])):
        cell_id = cell.get("id")

        id_problem = None
        if not isinstance(cell_id, str) or not cell_id:
            id_problem = "missing cell id"
        elif not CELL_ID_RE.match(cell_id):
            id_problem = "invalid cell id"

        if id_problem is not None:
            errors.append(f"{rel} cell {index}: {id_problem}")
            needs_cell_ids = True
            if fix:
                cell_id = cell["id"] = stable_cell_id(path, index, cell, used_ids)
                changed = True

        # Only real (non-empty) ids take part in duplicate tracking; an empty
        # id was already reported as missing and must not also be counted as
        # a duplicate of another empty id.
        if isinstance(cell_id, str) and cell_id:
            if cell_id in used_ids:
                errors.append(f"{rel} cell {index}: duplicate cell id")
                if fix:
                    cell_id = cell["id"] = stable_cell_id(path, index, cell, used_ids)
                    changed = True
            used_ids.add(cell_id)

            # Report the version mismatch once per notebook, not once per cell.
            if nbformat_minor < 5 and not version_error_added:
                errors.append(f"{rel}: cell ids require nbformat_minor >= 5")
                version_error_added = True
                if fix:
                    needs_cell_ids = True
                    changed = True

        if cell.get("cell_type") == "code":
            continue

        for field in CODE_ONLY_FIELDS:
            if field in cell:
                errors.append(
                    f"{rel} cell {index}: non-code cell contains '{field}'"
                )
                if fix:
                    del cell[field]
                    changed = True

    if changed:
        if needs_cell_ids:
            # Cell ids are only valid from nbformat 4.5 onwards.
            notebook["nbformat"] = 4
            notebook["nbformat_minor"] = max(nbformat_minor, 5)
        write_notebook(path, notebook)

    return errors


def validate_with_nbformat(path: Path) -> str | None:
    """Validate ``path`` with the ``nbformat`` package, if it is installed.

    Returns:
        ``None`` when the notebook validates or when ``nbformat`` cannot be
        imported; otherwise a message made of the path relative to ``ROOT``
        and the exception text.
    """
    try:
        import nbformat
    except ImportError:
        # nbformat is treated as optional: without it this check is skipped.
        return None

    try:
        nbformat.validate(nbformat.read(path, as_version=4))
    except Exception as exc:  # pragma: no cover - message is for CLI output
        return f"{path.relative_to(ROOT)}: {exc}"
    return None


def main() -> int:
    """Command-line entry point.

    Runs the schema checks on every notebook (repairing them when ``--fix``
    is given), then the nbformat validation.

    Returns:
        ``1`` if schema errors were found without ``--fix`` or if nbformat
        validation reported errors; ``0`` otherwise.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--fix",
        action="store_true",
        help="repair missing/invalid ids and remove code-only fields from non-code cells",
    )
    args = parser.parse_args()

    # Scan the filesystem once; both passes below work on the same files.
    paths = notebook_paths()

    schema_errors: list[str] = []
    for path in paths:
        schema_errors.extend(sanitize_notebook(path, args.fix))

    if schema_errors and not args.fix:
        print("Notebook schema errors:")
        print("\n".join(schema_errors))
        print("\nRun scripts/validate_notebooks.py --fix to repair them.")
        return 1

    # Second pass runs after any repairs, so it sees the rewritten files.
    nbformat_errors = [
        error for path in paths if (error := validate_with_nbformat(path))
    ]

    if nbformat_errors:
        print("nbformat validation errors:")
        print("\n".join(nbformat_errors))
        return 1

    if schema_errors:
        print(f"Fixed {len(schema_errors)} notebook schema issue(s).")
    else:
        print("All notebooks passed validation.")

    return 0


# When run as a script, use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
17 changes: 11 additions & 6 deletions solutions/01_relu_solution.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"id": "0556419b",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/01_relu_solution.ipynb)\n",
"[Open in Colab](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/01_relu_solution.ipynb)\n",
"\n",
"# 🟢 Solution: Implement ReLU\n",
"\n",
Expand All @@ -26,7 +26,8 @@
" get_ipython().run_line_magic('pip', 'install -q torch-judge')\n",
"except ImportError:\n",
" pass\n"
]
],
"id": "cell-8b01cc47c4eb"
},
{
"cell_type": "code",
Expand All @@ -35,7 +36,8 @@
"outputs": [],
"source": [
"import torch"
]
],
"id": "cell-42aacb4e5964"
},
{
"cell_type": "code",
Expand All @@ -47,7 +49,8 @@
"\n",
"def relu(x: torch.Tensor) -> torch.Tensor:\n",
" return x * (x > 0).float()"
]
],
"id": "cell-77d0ad2d5301"
},
{
"cell_type": "code",
Expand All @@ -59,7 +62,8 @@
"x = torch.tensor([-2., -1., 0., 1., 2.])\n",
"print(\"Input: \", x)\n",
"print(\"Output:\", relu(x))"
]
],
"id": "cell-ee0dd6b7c97c"
},
{
"cell_type": "code",
Expand All @@ -70,7 +74,8 @@
"# Run judge\n",
"from torch_judge import check\n",
"check(\"relu\")"
]
],
"id": "cell-a93501c6f94f"
}
],
"metadata": {
Expand Down
19 changes: 12 additions & 7 deletions solutions/02_softmax_solution.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/02_softmax_solution.ipynb)\n\n",
"[Open in Colab](https://colab.research.google.com/github/duoan/TorchCode/blob/master/solutions/02_softmax_solution.ipynb)\n\n",
"# 🟢 Solution: Implement Softmax\n",
"\n",
"Reference solution for the numerically-stable Softmax function.\n",
"\n",
"$$\\text{softmax}(x_i) = \\frac{e^{x_i - \\max(x)}}{\\sum_j e^{x_j - \\max(x)}}$$"
],
"outputs": []
"id": "cell-1b2118d19858"
},
{
"cell_type": "code",
Expand All @@ -38,7 +38,8 @@
" pass\n"
],
"outputs": [],
"execution_count": null
"execution_count": null,
"id": "cell-23f1a6c527b4"
},
{
"cell_type": "code",
Expand All @@ -47,7 +48,8 @@
"import torch"
],
"outputs": [],
"execution_count": null
"execution_count": null,
"id": "cell-01a8eba71de9"
},
{
"cell_type": "code",
Expand All @@ -61,7 +63,8 @@
" return e_x / e_x.sum(dim=dim, keepdim=True)"
],
"outputs": [],
"execution_count": null
"execution_count": null,
"id": "cell-455503eacc0f"
},
{
"cell_type": "code",
Expand All @@ -74,7 +77,8 @@
"print(\"Ref: \", torch.softmax(x, dim=-1))"
],
"outputs": [],
"execution_count": null
"execution_count": null,
"id": "cell-10c0561ea637"
},
{
"cell_type": "code",
Expand All @@ -85,7 +89,8 @@
"check(\"softmax\")"
],
"outputs": [],
"execution_count": null
"execution_count": null,
"id": "cell-ff8e53cdd120"
}
]
}
Loading
Loading