diff --git a/HACKATHON-REFERENCE.md b/HACKATHON-REFERENCE.md new file mode 100644 index 000000000..98a355b54 --- /dev/null +++ b/HACKATHON-REFERENCE.md @@ -0,0 +1,29 @@ +# Hackathon Reference (Read-Me-First) + +This branch is a **readability-first snapshot** of work added on top of Scope. It is intended for code review / judging and is **not guaranteed to be runnable as-is**. + +## Where to Look + +- **Realtime control plane (new module)**: [`src/scope/realtime/`](./src/scope/realtime/) + - Event semantics + deterministic chunk-boundary application: [`src/scope/realtime/control_bus.py`](./src/scope/realtime/control_bus.py) + - Prompt sequencing: [`src/scope/realtime/prompt_playlist.py`](./src/scope/realtime/prompt_playlist.py) + - Driver glue: [`src/scope/realtime/generator_driver.py`](./src/scope/realtime/generator_driver.py), [`src/scope/realtime/pipeline_adapter.py`](./src/scope/realtime/pipeline_adapter.py) + +- **CLI tools**: [`src/scope/cli/`](./src/scope/cli/) + - Main CLI entry: [`src/scope/cli/video_cli.py`](./src/scope/cli/video_cli.py) + - Stream Deck integration: [`src/scope/cli/streamdeck_control.py`](./src/scope/cli/streamdeck_control.py) + +- **Server-side recording**: [`src/scope/server/session_recorder.py`](./src/scope/server/session_recorder.py) + +- **Input + control-map generation** (depth/edges/composite conditioning): [`src/scope/server/frame_processor.py`](./src/scope/server/frame_processor.py) + - Vendored depth model used by the control-map pipeline: [`src/scope/vendored/video_depth_anything/`](./src/scope/vendored/video_depth_anything/) + +- **VACE integration + chunk-stability work**: [`src/scope/core/pipelines/wan2_1/vace/`](./src/scope/core/pipelines/wan2_1/vace/) + +- **NDI input support**: [`src/scope/server/ndi/`](./src/scope/server/ndi/) + +## What’s Intentionally Not Included + +This branch is intentionally scoped to **feature work + readability**. Hardware-specific performance codepaths and low-level optimization infrastructure are out of scope for this public snapshot. + +See [`PERF-NOTES.md`](./PERF-NOTES.md) for a high-level description of performance work (without code). diff --git a/PERF-NOTES.md b/PERF-NOTES.md new file mode 100644 index 000000000..80a831562 --- /dev/null +++ b/PERF-NOTES.md @@ -0,0 +1,71 @@ +# Perf Notes (High Level) + +This is a **high-level summary + journey log** of performance work done while building a realtime video pipeline. It is intentionally written without low-level implementation details. + +Code map / entrypoints: [HACKATHON-REFERENCE.md](./HACKATHON-REFERENCE.md) + +## Goals + +- Reduce end-to-end chunk latency and stabilize throughput (avoid periodic stalls). +- Keep output temporally stable across chunk boundaries (cache correctness is as important as raw speed). +- Make performance/debuggability observable (what backend ran, what shapes ran, when caches reset). + +## Starting Point → Current + +- Starting point: ~11 FPS (early end-to-end baseline with stable output). +- Best observed baseline throughput after core optimizations: ~33 FPS (settings-dependent; after warmup). +- Current “performable” mode: ~23 FPS at 448×448 (B200/B300-class GPUs; includes realtime control/conditioning overhead). + +## How We Measured (Practical) + +- Measured the system as three rates: **input FPS** (camera/NDI/WebRTC ingest), **pipeline FPS** (generation), and **output pacing FPS** (what viewers actually see). +- Used chunk boundaries as the primary unit of “state commits” (cache resets, parameter application, replay determinism). 
+- Avoided benchmarking under GPU contention (server still running, another job holding the device), because it makes results noisy and misleading. + +## Performance Journey (What Moved the Needle) + +### 1) Remove Hidden Caps (Pacing, Contention, Fallbacks) + +- Used the measurement split above (input vs pipeline vs pacing) to quickly detect input-limited and output-limited runs. +- Routinely checked for GPU contention (a background server or another job can cut throughput dramatically). +- Made backend selection observable so “silent fallbacks” don’t masquerade as model regressions. + +### 2) Make The Hot Path GPU-Efficient + +- Integrated a fused attention backend (e.g., FlashAttention 4) where available, with safe fallbacks. +- Focused on the end-to-end critical path: attention + MLP + decode, not just one microkernel. +- Prioritized reducing synchronization points and avoiding accidental host/device round trips. + +### 3) Fix Data Movement Before Micro-Optimizing Kernels + +- Hunted down implicit copies / contiguity fixes / view-to-contiguous transitions in hot paths (especially decode/resize/resample style code). +- Preferred stable shapes and stable layouts across chunks so caches and compiled graphs can actually be reused. + +### 4) Selective Compilation (When It Helps, When It Hurts) + +- Used `torch.compile` selectively on stable subgraphs and avoided compile on paths that are shape-volatile or stateful across invocations. +- Accepted that compilation has warmup cost; measured steady-state after warmup. +- Watched for cudagraph / reuse interactions that can surface as “reused output” failures when state persists between calls. + +### 5) Cache Hygiene + Transition Semantics (Correctness + Perf) + +- Treated chunk boundaries as the primary “state commit” point: cache resets, parameter application, and replay all happen there. +- Made transitions explicit: + - **Hard cut** = intentional cache reset. + - **Soft cut** = controlled transition over multiple chunk boundaries. +- Avoided mixing independent encode/decode streams through a shared temporal cache (a common source of boundary artifacts). + +### 6) Keep Preprocessing Off The Critical Path + +- Depth/control-map generation needs to be fast and predictable, or it becomes the bottleneck (even if generation is fast). +- Prefer asynchronous/pre-buffered preprocessing so occasional slow frames don’t stall the whole pipeline. + +### 7) Precision / Quantization Tradeoffs + +- Explored mixed precision and (where appropriate) FP8-style quantization to reduce memory bandwidth pressure. +- Kept correctness guardrails so visual quality regressions are obvious and attributable. + +## Takeaways + +- Most “FPS regressions” weren’t one kernel getting slower — they were fallbacks, extra copies, contention, or a cache/compile mode mismatch. +- Optimizations only stick if they’re observable (backend reporting) and repeatable (benchmark hygiene). diff --git a/README.md b/README.md index 5c9196a39..7fe40952a 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,14 @@ Scope is a tool for running and customizing real-time, interactive generative AI 🚧 This project is currently in **beta**. 🚧 +## Hackathon Snapshot (`competition-vace`) + +This fork/branch is a **hackathon submission snapshot** of additional work on top of Scope, optimized for readability and review. 
+ +- Start here: [HACKATHON-REFERENCE.md](./HACKATHON-REFERENCE.md) +- High-level performance notes (no code): [PERF-NOTES.md](./PERF-NOTES.md) +- Note: this branch is not guaranteed to be runnable as-is. + ## Table of Contents - [Table of Contents](#table-of-contents) diff --git a/src/scope/cli/streamdeck_control.py b/src/scope/cli/streamdeck_control.py new file mode 100644 index 000000000..2f6179b5f --- /dev/null +++ b/src/scope/cli/streamdeck_control.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python3 +"""Stream Deck controller for Scope - sends style commands to remote server. + +Usage: + VIDEO_API_URL=http://your-gpu-server:8000 uv run python -m scope.cli.streamdeck_control + +Or: + uv run python -m scope.cli.streamdeck_control --url http://your-gpu-server:8000 +""" + +from __future__ import annotations + +import argparse +import os +import sys +import time +from io import BytesIO + +import httpx +from PIL import Image, ImageDraw, ImageFont + +# Button layout (15-key Stream Deck, 3 rows x 5 cols) +# Key indices go left-to-right, top-to-bottom: +# [0] [1] [2] [3] [4] Row 0: HIDARI YETI TMNT RAT KAIJU +# [5] [6] [7] [8] [9] Row 1: [empty row] +# [10] [11] [12] [13] [14] Row 2: STEP HARD SOFT PLAY [empty] + +STYLES = ["hidari", "yeti", "tmnt", "rat", "kaiju"] + +# Key index mapping (0-14, left-to-right, top-to-bottom) +STYLE_KEYS = {0: "hidari", 1: "yeti", 2: "tmnt", 3: "rat", 4: "kaiju"} +ACTION_KEYS = { + 10: "step", # Bottom row, first + 11: "hard_cut", # Bottom row, second + 12: "soft_cut", # Bottom row, third + 13: "play_pause", # Bottom row, fourth +} + + +def create_button_image( + deck, text: str, bg_color: str = "#1a1a2e", text_color: str = "#ffffff", active: bool = False +) -> bytes: + """Create a button image with text.""" + # Get the button size for this deck + image_format = deck.key_image_format() + size = (image_format["size"][0], image_format["size"][1]) + + # Create image + if active: + bg_color = "#4a9eff" # Highlight active style + img = Image.new("RGB", size, bg_color) + draw = ImageDraw.Draw(img) + + # Try to use a nice font, fall back to default + font_size = size[0] // 5 + try: + font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", font_size) + except OSError: + font = ImageFont.load_default() + + # Center the text + bbox = draw.textbbox((0, 0), text, font=font) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + x = (size[0] - text_width) // 2 + y = (size[1] - text_height) // 2 + + draw.text((x, y), text, font=font, fill=text_color) + + # Rotate 180° - Stream Deck Original has flipped orientation + img = img.rotate(180) + + # Convert to the format the deck expects + img_bytes = BytesIO() + img.save(img_bytes, format="JPEG") + return img_bytes.getvalue() + + +class StreamDeckController: + """Controls Scope via Stream Deck button presses.""" + + def __init__(self, api_url: str): + self.api_url = api_url.rstrip("/") + self.client = httpx.Client(timeout=5.0) + self.deck = None + self.current_style: str | None = None + self.is_paused: bool = False + + def connect(self) -> bool: + """Connect to Stream Deck.""" + from StreamDeck.DeviceManager import DeviceManager + + decks = DeviceManager().enumerate() + if not decks: + print("No Stream Deck found!") + return False + + self.deck = decks[0] + self.deck.open() + try: + self.deck.reset() + except Exception as e: + print(f"Warning: Could not reset deck ({e}), continuing anyway...") + print(f"Connected: {self.deck.deck_type()} ({self.deck.key_count()} keys)") + return True + + def update_buttons(self): + 
"""Update all button images.""" + if not self.deck: + return + + # Style buttons (keys 0-3) + for key, style in STYLE_KEYS.items(): + active = style == self.current_style + img = create_button_image(self.deck, style[:6].upper(), active=active) + self.deck.set_key_image(key, img) + + # Action buttons (bottom row: 10, 11, 12, 13) + self.deck.set_key_image(10, create_button_image(self.deck, "STEP", bg_color="#2d3436")) + self.deck.set_key_image(11, create_button_image(self.deck, "HARD", bg_color="#d63031")) + self.deck.set_key_image(12, create_button_image(self.deck, "SOFT", bg_color="#fdcb6e", text_color="#000000")) + self.deck.set_key_image(13, create_button_image(self.deck, "PLAY" if self.is_paused else "PAUSE", bg_color="#2d3436")) + + # Clear unused keys + for key in range(15): + if key not in STYLE_KEYS and key not in [10, 11, 12, 13]: + self.deck.set_key_image(key, create_button_image(self.deck, "", bg_color="#0d0d0d")) + + def fetch_state(self): + """Fetch current state from server.""" + try: + r = self.client.get(f"{self.api_url}/api/v1/realtime/state") + if r.status_code == 200: + state = r.json() + self.current_style = state.get("active_style") + self.is_paused = state.get("paused", False) + return True + except httpx.RequestError as e: + print(f"Failed to fetch state: {e}") + return False + + def set_style(self, style: str): + """Set the active style.""" + try: + r = self.client.put(f"{self.api_url}/api/v1/realtime/style", json={"name": style}) + if r.status_code == 200: + print(f"Style: {style}") + self.current_style = style + self.update_buttons() + else: + print(f"Failed to set style: {r.status_code}") + except httpx.RequestError as e: + print(f"Error: {e}") + + def toggle_pause(self): + """Toggle pause/play.""" + try: + endpoint = "/api/v1/realtime/run" if self.is_paused else "/api/v1/realtime/pause" + r = self.client.post(f"{self.api_url}{endpoint}") + if r.status_code == 200: + self.is_paused = not self.is_paused + print("Paused" if self.is_paused else "Running") + self.update_buttons() + except httpx.RequestError as e: + print(f"Error: {e}") + + def step(self): + """Step one frame.""" + try: + r = self.client.post(f"{self.api_url}/api/v1/realtime/step") + if r.status_code == 200: + print("Stepped") + except httpx.RequestError as e: + print(f"Error: {e}") + + def hard_cut(self): + """Trigger hard cut (reset cache).""" + try: + r = self.client.post(f"{self.api_url}/api/v1/realtime/hard-cut") + if r.status_code == 200: + print("Hard cut!") + except httpx.RequestError as e: + print(f"Error: {e}") + + def soft_cut(self): + """Trigger soft cut.""" + try: + r = self.client.post(f"{self.api_url}/api/v1/realtime/soft-cut") + if r.status_code == 200: + print("Soft cut") + except httpx.RequestError as e: + print(f"Error: {e}") + + def on_key(self, deck, key: int, pressed: bool): + """Handle key press.""" + if not pressed: # Only act on press, not release + return + + if key in STYLE_KEYS: + self.set_style(STYLE_KEYS[key]) + elif key == 10: + self.step() + elif key == 11: + self.hard_cut() + elif key == 12: + self.soft_cut() + elif key == 13: + self.toggle_pause() + + def run(self): + """Main loop.""" + if not self.connect(): + return 1 + + # Fetch initial state + if self.fetch_state(): + print(f"Current style: {self.current_style}, Paused: {self.is_paused}") + else: + print("Warning: Could not fetch initial state (server may be offline)") + + self.update_buttons() + self.deck.set_key_callback(self.on_key) + + print("\nStream Deck ready! 
Press Ctrl+C to exit.") + print(" Row 1: HIDARI | YETI | TMNT | RAT | KAIJU") + print(" Row 3: STEP | HARD | SOFT | PLAY/PAUSE") + + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print("\nShutting down...") + finally: + if self.deck: + try: + self.deck.reset() + except Exception: + pass # Ignore reset errors during cleanup + self.deck.close() + + return 0 + + +def main(): + parser = argparse.ArgumentParser(description="Stream Deck controller for Scope") + parser.add_argument( + "--url", + default=os.environ.get("VIDEO_API_URL", "http://localhost:8000"), + help="Scope server URL (default: VIDEO_API_URL env or http://localhost:8000)", + ) + args = parser.parse_args() + + print(f"Connecting to: {args.url}") + controller = StreamDeckController(args.url) + sys.exit(controller.run()) + + +if __name__ == "__main__": + main() diff --git a/src/scope/cli/video_cli.py b/src/scope/cli/video_cli.py new file mode 100644 index 000000000..14dbec175 --- /dev/null +++ b/src/scope/cli/video_cli.py @@ -0,0 +1,2339 @@ +"""CLI interface for Scope realtime video generation. + +Designed for agent automation. All commands return JSON. +""" + +import json +import sys +import time +from pathlib import Path +from typing import Any + +import click +import httpx + +DEFAULT_URL = "http://localhost:8000" + + +def get_client(ctx) -> httpx.Client: + return httpx.Client(base_url=ctx.obj["url"], timeout=60.0) + + +def output(data: dict, ctx): + if ctx.obj.get("pretty"): + click.echo(json.dumps(data, indent=2)) + else: + click.echo(json.dumps(data)) + + +def handle_error(response: httpx.Response): + if response.status_code >= 400: + try: + error = response.json() + except Exception: + error = {"error": response.text} + click.echo(json.dumps(error), err=True) + sys.exit(1) + + +# Default content directory for playlists +CONTENT_DIR = Path(__file__).parent.parent.parent.parent / "content" / "playlists" + + +def _resolve_playlist_path(file_or_name: str) -> Path | None: + """Resolve a playlist name or file path to an actual file.""" + direct = Path(file_or_name) + if direct.exists() and direct.is_file(): + return direct + + playlist_dir = CONTENT_DIR / file_or_name + if playlist_dir.exists() and playlist_dir.is_dir(): + captioning_dir = playlist_dir / "Captioning" + if captioning_dir.exists(): + for txt_file in captioning_dir.glob("*_captions.txt"): + return txt_file + for pattern in ["captions.txt", f"{file_or_name}.txt"]: + candidate = playlist_dir / pattern + if candidate.exists(): + return candidate + return None + + +@click.group() +@click.option("--url", envvar="VIDEO_API_URL", default=DEFAULT_URL, help="API base URL") +@click.option("--pretty/--no-pretty", default=True, help="Pretty print JSON output") +@click.pass_context +def cli(ctx, url, pretty): + """Scope realtime video CLI - control video generation via REST API.""" + ctx.ensure_object(dict) + ctx.obj["url"] = url + ctx.obj["pretty"] = pretty + + +# --- State --- + + +@cli.command() +@click.pass_context +def state(ctx): + """Get current session state.""" + with get_client(ctx) as client: + r = client.get("/api/v1/realtime/state") + handle_error(r) + output(r.json(), ctx) + + +# --- Generation --- + + +@cli.command() +@click.pass_context +def step(ctx): + """Generate one chunk.""" + with get_client(ctx) as client: + r = client.post("/api/v1/realtime/step") + handle_error(r) + output(r.json(), ctx) + + +@cli.command() +@click.option("--chunks", type=int, default=None, help="Number of chunks to generate") +@click.pass_context +def run(ctx, chunks): + 
"""Start or run generation.""" + with get_client(ctx) as client: + params = {"chunks": chunks} if chunks else {} + r = client.post("/api/v1/realtime/run", params=params) + handle_error(r) + output(r.json(), ctx) + + +@cli.command() +@click.pass_context +def pause(ctx): + """Pause generation.""" + with get_client(ctx) as client: + r = client.post("/api/v1/realtime/pause") + handle_error(r) + output(r.json(), ctx) + + +# --- Prompt --- + + +@cli.command() +@click.argument("text", required=False) +@click.option("--get", "get_only", is_flag=True, help="Only get current prompt") +@click.pass_context +def prompt(ctx, text, get_only): + """Set or get prompt.""" + with get_client(ctx) as client: + if get_only or text is None: + # Get current state which includes prompt + r = client.get("/api/v1/realtime/state") + handle_error(r) + data = r.json() + output({"prompt": data.get("prompt")}, ctx) + else: + r = client.put("/api/v1/realtime/prompt", json={"prompt": text}) + handle_error(r) + output(r.json(), ctx) + + +@cli.command() +@click.option( + "--direction", + "-d", + default=None, + help='Steering direction (e.g., "sadder", "more dynamic")', +) +@click.option( + "--intensity", + "-i", + default=0.3, + type=float, + show_default=True, + help="Variation intensity 0-1", +) +@click.option( + "--count", + "-n", + default=3, + type=int, + show_default=True, + help="Number of variations", +) +@click.option( + "--mode", + "-m", + default="attentional", + type=click.Choice(["attentional", "semantic"]), + show_default=True, + help="Jiggle mode", +) +@click.option( + "--apply", + "apply_index", + type=int, + default=None, + help="Immediately apply variation at index (0-based)", +) +@click.option("--soft-cut", is_flag=True, help="Use soft cut when applying") +@click.option("--soft-cut-bias", type=float, default=None, help="Soft cut temp_bias (server default if omitted)") +@click.option("--soft-cut-chunks", type=int, default=None, help="Soft cut num_chunks (server default if omitted)") +@click.option("--hard-cut", is_flag=True, help="Use hard cut when applying") +@click.option("--prompt", "-p", default=None, help="Prompt to jiggle (default: current active)") +@click.option( + "--print", + "print_human", + is_flag=True, + help="Print a compact list to stderr (JSON still on stdout)", +) +@click.pass_context +def jiggle( + ctx, + direction, + intensity, + count, + mode, + apply_index, + soft_cut, + soft_cut_bias, + soft_cut_chunks, + hard_cut, + prompt, + print_human, +): + """Generate prompt variations.""" + with get_client(ctx) as client: + r = client.post( + "/api/v1/prompt/jiggle", + json={ + "prompt": prompt, + "intensity": intensity, + "count": count, + "direction": direction, + "mode": mode, + }, + ) + handle_error(r) + response = r.json() + + if print_human: + click.echo(f"\nOriginal:\n{response.get('original_prompt', '')}\n", err=True) + click.echo(f"Variations ({mode}):", err=True) + for i, v in enumerate(response.get("variations", [])): + click.echo(f" [{i}] {v}", err=True) + + applied = None + if apply_index is not None: + variations = response.get("variations", []) + if not isinstance(variations, list): + variations = [] + if apply_index < 0 or apply_index >= len(variations): + click.echo( + json.dumps( + { + "error": "apply index out of range", + "apply_index": apply_index, + "variations_count": len(variations), + } + ), + err=True, + ) + sys.exit(1) + + selected = variations[apply_index] + if hard_cut: + r2 = client.post("/api/v1/realtime/hard-cut", json={"prompt": selected}) + elif soft_cut: + 
payload: dict[str, object] = {"prompt": selected} + if soft_cut_bias is not None: + payload["temp_bias"] = soft_cut_bias + if soft_cut_chunks is not None: + payload["num_chunks"] = soft_cut_chunks + r2 = client.post("/api/v1/realtime/soft-cut", json=payload) + else: + r2 = client.put("/api/v1/realtime/prompt", json={"prompt": selected}) + handle_error(r2) + applied = {"index": apply_index, "prompt": selected, "response": r2.json()} + + output({**response, "applied": applied}, ctx) + + +# --- Frame --- + + +@cli.command() +@click.option("--out", type=click.Path(), help="Output file path") +@click.pass_context +def frame(ctx, out): + """Get current frame.""" + with get_client(ctx) as client: + if out: + r = client.get("/api/v1/realtime/frame/latest") + handle_error(r) + Path(out).write_bytes(r.content) + output({"saved": out, "size_bytes": len(r.content)}, ctx) + else: + # Just report that frame exists + r = client.get("/api/v1/realtime/state") + handle_error(r) + output({"chunk_index": r.json().get("chunk_index")}, ctx) + + +# --- World State --- + + +@cli.command() +@click.argument("json_data", required=False) +@click.option("--get", "get_only", is_flag=True, help="Only get current world state") +@click.pass_context +def world(ctx, json_data, get_only): + """Set or get WorldState. + + Examples: + video-cli world # Get current world state + video-cli world '{"action":"run"}' # Set world state + """ + with get_client(ctx) as client: + if get_only or json_data is None: + r = client.get("/api/v1/realtime/state") + handle_error(r) + data = r.json() + output({"world_state": data.get("world_state")}, ctx) + else: + try: + world_state = json.loads(json_data) + except json.JSONDecodeError as e: + click.echo(json.dumps({"error": f"Invalid JSON: {e}"}), err=True) + sys.exit(1) + r = client.put("/api/v1/realtime/world", json={"world_state": world_state}) + handle_error(r) + output(r.json(), ctx) + + +# --- Parameters --- + + +@cli.command() +@click.argument("json_data", required=False) +@click.pass_context +def params(ctx, json_data): + """Update realtime parameters via the generic parameters endpoint. 
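+
+    The JSON object is forwarded as-is to POST /api/v1/realtime/parameters, so any
+    key that endpoint accepts can be set here (for example vace_context_scale,
+    kv_cache_attention_bias, noise_scale, or reset_cache). Pass "-" as the payload
+    to read JSON from stdin.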
+ + Examples: + video-cli params '{"vace_context_scale": 0.5}' + video-cli params '{"kv_cache_attention_bias": 0.3, "noise_scale": 0.98}' + echo '{"reset_cache": true}' | video-cli params - + """ + if not json_data: + click.echo( + json.dumps( + { + "error": "Missing JSON payload", + "examples": [ + """video-cli params '{"vace_context_scale": 0.5}'""", + """video-cli params '{"kv_cache_attention_bias": 0.3}'""", + """echo '{\"reset_cache\": true}' | video-cli params -""", + ], + } + ), + err=True, + ) + sys.exit(2) + + if json_data.strip() == "-": + json_data = sys.stdin.read() + + try: + payload = json.loads(json_data) + except json.JSONDecodeError as e: + click.echo(json.dumps({"error": f"Invalid JSON: {e}"}), err=True) + sys.exit(1) + + if not isinstance(payload, dict): + click.echo(json.dumps({"error": "Payload must be a JSON object"}), err=True) + sys.exit(1) + + with get_client(ctx) as client: + r = client.post("/api/v1/realtime/parameters", json=payload) + handle_error(r) + output(r.json(), ctx) + + +# --- Style --- + + +@cli.group() +@click.pass_context +def style(ctx): + """Manage active style.""" + pass + + +@style.command("list") +@click.pass_context +def style_list(ctx): + """List available styles.""" + with get_client(ctx) as client: + r = client.get("/api/v1/realtime/style/list") + handle_error(r) + output(r.json(), ctx) + + +@style.command("set") +@click.argument("name") +@click.pass_context +def style_set(ctx, name): + """Set active style by name.""" + with get_client(ctx) as client: + r = client.put("/api/v1/realtime/style", json={"name": name}) + handle_error(r) + output(r.json(), ctx) + + +@style.command("get") +@click.pass_context +def style_get(ctx): + """Get currently active style.""" + with get_client(ctx) as client: + r = client.get("/api/v1/realtime/state") + handle_error(r) + data = r.json() + output( + { + "active_style": data.get("active_style"), + "compiled_prompt": data.get("compiled_prompt"), + }, + ctx, + ) + + +@style.command("blend") +@click.argument("mode", required=False, type=click.Choice(["on", "off"])) +@click.pass_context +def style_blend(ctx, mode): + """Toggle or check style blend mode. + + When blend mode is ON, style switches don't reset the KV cache, + creating interesting visual artifacts during transitions. + + Examples: + video-cli style blend # Get current blend mode + video-cli style blend on # Enable blend mode + video-cli style blend off # Disable blend mode (clean transitions) + """ + with get_client(ctx) as client: + if mode is None: + # Get current blend mode + r = client.get("/api/v1/realtime/style/blend-mode") + handle_error(r) + output(r.json(), ctx) + else: + # Set blend mode + enabled = mode == "on" + r = client.put("/api/v1/realtime/style/blend-mode", json={"enabled": enabled}) + handle_error(r) + output(r.json(), ctx) + + +@style.command("nav") +@click.pass_context +def style_nav(ctx): + """Interactive style navigation mode. + + Controls: + j, ↓ Move selection down + k, ↑ Move selection up + ENTER, SPACE Activate selected style + b Toggle blend mode + r Refresh display + q, ESC Quit + + Run this in a second terminal alongside 'video-cli playlist nav' for + full control over both prompts and styles during live performance. + """ + import os + import select + import termios + import tty + + def get_char_nonblocking(timeout=0.2): + """Read a char with timeout. 
Returns None if no input.""" + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(fd) + if select.select([fd], [], [], timeout)[0]: + ch = os.read(fd, 1).decode("utf-8", errors="ignore") + if ch == "\x1b": + extra = "" + for _ in range(5): + if select.select([fd], [], [], 0.05)[0]: + byte = os.read(fd, 1).decode("utf-8", errors="ignore") + extra += byte + if len(extra) >= 2 and extra[0] == "[" and extra[-1] in "ABCD": + break + else: + break + ch = ch + extra + return ch + return None + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + + def display_styles(client, styles, selected_idx, active_style, blend_mode): + """Display styles list with selection and active indicators.""" + import shutil + + term_width = shutil.get_terminal_size().columns + + click.echo("\n" + "=" * term_width) + status = f" Styles: {len(styles)} available" + if active_style: + status += f" [Active: {active_style}]" + if blend_mode: + status += " [🌀 BLEND]" + click.echo(status) + click.echo("=" * term_width) + + for i, style_info in enumerate(styles): + name = style_info.get("name", "unknown") + trigger = style_info.get("trigger_words", []) + trigger_str = trigger[0] if trigger else "" + + # Markers: > for selected, * for active + is_selected = i == selected_idx + is_active = name == active_style + + if is_selected and is_active: + marker = "▶*" + click.echo( + click.style(f"{marker} {name:<15} {trigger_str}", fg="green", bold=True) + ) + elif is_selected: + marker = "▶ " + click.echo( + click.style(f"{marker} {name:<15} {trigger_str}", fg="cyan", bold=True) + ) + elif is_active: + marker = " *" + click.echo( + click.style(f"{marker} {name:<15} {trigger_str}", fg="green") + ) + else: + marker = " " + click.echo(f"{marker} {name:<15} {trigger_str}") + + click.echo("=" * term_width) + click.echo(" j/k nav | ENTER activate | b blend | r refresh | q quit") + click.echo("=" * term_width + "\n") + + click.echo("\nStyle Navigation Mode") + click.echo("Press q or ESC to quit\n") + + selected_idx = 0 + blend_mode = False + + with get_client(ctx) as client: + # Fetch styles + r = client.get("/api/v1/realtime/style/list") + if r.status_code != 200: + click.echo("Failed to fetch styles") + return + styles = r.json().get("styles", []) + if not styles: + click.echo("No styles available") + return + + # Fetch current active style + r = client.get("/api/v1/realtime/state") + active_style = None + if r.status_code == 200: + active_style = r.json().get("active_style") + + # Set initial selection to active style if found + for i, s in enumerate(styles): + if s.get("name") == active_style: + selected_idx = i + break + + # Fetch blend mode + try: + r = client.get("/api/v1/realtime/style/blend-mode") + if r.status_code == 200: + blend_mode = r.json().get("blend_mode", False) + except Exception: + pass + + display_styles(client, styles, selected_idx, active_style, blend_mode) + + while True: + try: + ch = get_char_nonblocking(timeout=0.2) + + if ch is not None: + # Quit + if ch in ("q", "Q", "\x03"): + click.echo("\nExiting style navigation mode.\n") + break + elif ch == "\x1b" and len(ch) == 1: + click.echo("\nExiting style navigation mode.\n") + break + + # Navigation + elif ch in ("j", "\x1b[B"): # Down + selected_idx = (selected_idx + 1) % len(styles) + display_styles(client, styles, selected_idx, active_style, blend_mode) + + elif ch in ("k", "\x1b[A"): # Up + selected_idx = (selected_idx - 1) % len(styles) + display_styles(client, styles, selected_idx, active_style, blend_mode) + + # 
Activate selected style + elif ch in ("\r", "\n", " "): + style_name = styles[selected_idx].get("name") + r = client.put("/api/v1/realtime/style", json={"name": style_name}) + if r.status_code == 200: + active_style = style_name + click.echo(f" ✓ Activated: {style_name}") + else: + click.echo(f" ✗ Failed to activate: {style_name}") + display_styles(client, styles, selected_idx, active_style, blend_mode) + + # Toggle blend mode + elif ch == "b": + blend_mode = not blend_mode + try: + r = client.put( + "/api/v1/realtime/style/blend-mode", + json={"enabled": blend_mode} + ) + if r.status_code == 200: + status = "ON" if blend_mode else "OFF" + click.echo(f" 🌀 Blend mode: {status}") + else: + blend_mode = not blend_mode + click.echo(" Failed to set blend mode") + except Exception as e: + blend_mode = not blend_mode + click.echo(f" Error: {e}") + display_styles(client, styles, selected_idx, active_style, blend_mode) + + # Refresh + elif ch == "r": + # Re-fetch styles and state + r = client.get("/api/v1/realtime/style/list") + if r.status_code == 200: + styles = r.json().get("styles", []) + r = client.get("/api/v1/realtime/state") + if r.status_code == 200: + active_style = r.json().get("active_style") + display_styles(client, styles, selected_idx, active_style, blend_mode) + + except KeyboardInterrupt: + click.echo("\nExiting style navigation mode.\n") + break + except Exception as e: + click.echo(f"\nError: {e}\n") + break + + +# --- Snapshots (placeholder - uses existing WebRTC API indirectly) --- + + +@cli.command() +@click.pass_context +def snapshot(ctx): + """Create snapshot (requires WebRTC session to handle response).""" + # Note: Snapshot creation goes through update_parameters which + # sends response via WebRTC data channel. For full REST support, + # we'd need to add dedicated snapshot endpoints. 
+ output( + { + "status": "not_implemented", + "message": "Snapshot creation requires dedicated REST endpoint", + }, + ctx, + ) + + +@cli.command() +@click.argument("snapshot_id") +@click.pass_context +def restore(ctx, snapshot_id): + """Restore from snapshot (requires WebRTC session to handle response).""" + output( + { + "status": "not_implemented", + "message": "Snapshot restore requires dedicated REST endpoint", + }, + ctx, + ) + + +# --- VACE --- + + +@cli.group() +@click.pass_context +def vace(ctx): + """Manage VACE (video-to-video editing) settings.""" + pass + + +@vace.command("control-map") +@click.argument("mode", required=False, type=click.Choice(["none", "canny", "pidinet", "depth", "composite", "external"])) +@click.option("--low", type=int, help="Canny low threshold (default: adaptive)") +@click.option("--high", type=int, help="Canny high threshold (default: adaptive)") +@click.option("--safe/--no-safe", default=None, help="PiDiNet safe mode (cleaner edges)") +@click.option("--filter/--no-filter", default=None, help="PiDiNet apply filter") +@click.option("--edge-strength", type=float, help="Composite: edge strength 0-1 (default: 0.6)") +@click.option("--edge-source", type=click.Choice(["canny", "pidinet"]), help="Composite: edge source") +@click.option("--edge-thickness", type=int, help="Composite: edge thickness in pixels (default: 8)") +@click.option("--sharpness", type=float, help="Composite: soft max sharpness (default: 10.0)") +@click.option("--depth-input-size", type=int, help="Depth: VDA input_size (default: 518; lower is faster)") +@click.option("--depth-fp32/--depth-no-fp32", default=None, help="Depth: force FP32 (default: autocast)") +@click.option("--depth-temporal-mode", type=click.Choice(["stream", "stateless"]), help="Depth: temporal mode (stream=stable, stateless=no trails)") +@click.option("--depth-contrast", type=float, help="Depth: contrast/gamma (default: 1.0; >1.0 increases mid-tone contrast for close-ups)") +@click.option("--temporal-ema", type=float, help="Temporal EMA smoothing 0.0-0.95 (0.0=none, 0.5=smooth, 0.9=very smooth)") +@click.option("--worker/--no-worker", default=None, help="Enable/disable control-map worker (preview; generation with --control-buffer)") +@click.option("--worker-allow-heavy/--no-worker-allow-heavy", default=None, help="Allow heavy modes (depth/pidinet/composite) in worker") +@click.option("--worker-max-fps", type=float, help="Cap worker FPS (0 disables cap)") +@click.option("--control-buffer/--no-control-buffer", default=None, help="Enable/disable control buffer (use worker outputs for generation)") +@click.pass_context +def vace_control_map(ctx, mode, low, high, safe, filter, edge_strength, edge_source, edge_thickness, sharpness, depth_input_size, depth_fp32, depth_temporal_mode, depth_contrast, temporal_ema, worker, worker_allow_heavy, worker_max_fps, control_buffer): + """Get or set VACE control map mode. 
+ + Control maps transform webcam/video input before VACE conditioning: + - "none": Use raw video frames (default) + - "canny": Apply Canny edge detection (fast, CPU-based) + - "pidinet": Neural edge detection (higher quality, requires controlnet_aux) + - "depth": Apply VDA depth estimation (depth/layout control) + - "composite": Depth + edges fused with soft max (best composition lock) + - "external": Passthrough mode (input frames are already control maps) + + Examples: + video-cli vace control-map # Get current mode + video-cli vace control-map canny # Enable Canny edge detection + video-cli vace control-map pidinet # Enable PiDiNet neural edges + video-cli vace control-map depth # Enable depth estimation + video-cli vace control-map depth --depth-input-size 320 # Faster depth (lower input_size) + video-cli vace control-map depth --depth-temporal-mode stream # Stable depth (can trail) + video-cli vace control-map depth --depth-temporal-mode stateless # No trails/ghosting + video-cli vace control-map composite # Enable depth+edges composite + video-cli vace control-map external # Passthrough (precomputed control maps) + video-cli vace control-map none # Disable (use raw frames) + video-cli vace control-map canny --low 50 --high 150 # Custom thresholds + video-cli vace control-map pidinet --no-safe # PiDiNet with more detail + video-cli vace control-map composite --edge-strength 0.7 # Stronger edges + video-cli vace control-map depth --temporal-ema 0.5 # Smooth depth maps + video-cli vace control-map depth --depth-contrast 1.5 # More contrast for webcam close-ups + video-cli vace control-map depth --worker # Enable background worker (preview) + video-cli vace control-map depth --worker --control-buffer # Worker-assisted generation + video-cli vace control-map depth --worker --control-buffer --worker-allow-heavy # Allow heavy modes in worker + video-cli vace control-map depth --worker-max-fps 15 # Cap worker FPS + """ + with get_client(ctx) as client: + if mode is None: + # Get current control map mode + r = client.get("/api/v1/realtime/vace/control-map-mode") + handle_error(r) + output(r.json(), ctx) + else: + # Set control map mode + payload = {"mode": mode} + # Canny options + if low is not None: + payload["canny_low_threshold"] = low + if high is not None: + payload["canny_high_threshold"] = high + # PiDiNet options + if safe is not None: + payload["pidinet_safe"] = safe + if filter is not None: + payload["pidinet_filter"] = filter + # Composite options + if edge_strength is not None: + payload["composite_edge_strength"] = edge_strength + if edge_source is not None: + payload["composite_edge_source"] = edge_source + if edge_thickness is not None: + payload["composite_edge_thickness"] = edge_thickness + if sharpness is not None: + payload["composite_sharpness"] = sharpness + # Depth options + if depth_input_size is not None: + payload["depth_input_size"] = depth_input_size + if depth_fp32 is not None: + payload["depth_fp32"] = depth_fp32 + if depth_temporal_mode is not None: + payload["depth_temporal_mode"] = depth_temporal_mode + if depth_contrast is not None: + payload["depth_contrast"] = depth_contrast + # Worker options + if worker is not None: + payload["worker_enabled"] = worker + if worker_allow_heavy is not None: + payload["worker_allow_heavy"] = worker_allow_heavy + if worker_max_fps is not None: + payload["worker_max_fps"] = worker_max_fps + if control_buffer is not None: + payload["control_buffer_enabled"] = control_buffer + # Temporal smoothing + if temporal_ema is not None: 
+ payload["temporal_ema"] = temporal_ema + r = client.put("/api/v1/realtime/vace/control-map-mode", json=payload) + handle_error(r) + output(r.json(), ctx) + + +@vace.command("external-stale-ms") +@click.argument("ms", required=False, type=float) +@click.pass_context +def vace_external_stale_ms(ctx, ms): + """Get or set external control-map staleness threshold (ms). + + When in `vace_control_map_mode=external` + NDI input, Scope stalls generation + if the newest NDI control frame is older than this threshold. + """ + with get_client(ctx) as client: + if ms is None: + r = client.get("/api/v1/realtime/debug/fps") + handle_error(r) + fp = r.json().get("frame_processor") or {} + ndi = fp.get("ndi") if isinstance(fp, dict) else None + if not isinstance(ndi, dict): + ndi = {} + output({"vace_external_stale_ms": ndi.get("external_stale_ms")}, ctx) + return + + ms = max(0.0, float(ms)) + r = client.post( + "/api/v1/realtime/parameters", + json={"vace_external_stale_ms": ms}, + ) + handle_error(r) + output({"status": "ok", "vace_external_stale_ms": ms}, ctx) + + +@vace.command("external-resume-hard-cut") +@click.argument("enabled", required=False, type=click.Choice(["on", "off"])) +@click.pass_context +def vace_external_resume_hard_cut(ctx, enabled): + """Get or set external-control resume hard cut behavior. + + When enabled, Scope forces `reset_cache=True` once when external control + transitions from stale -> fresh. Disable this to preserve temporal coherence + across brief external-control dropouts. + """ + with get_client(ctx) as client: + if enabled is None: + r = client.get("/api/v1/realtime/debug/fps") + handle_error(r) + fp = r.json().get("frame_processor") or {} + ndi = fp.get("ndi") if isinstance(fp, dict) else None + if not isinstance(ndi, dict): + ndi = {} + output( + {"vace_external_resume_hard_cut": ndi.get("external_resume_hard_cut")}, + ctx, + ) + return + + val = enabled == "on" + r = client.post( + "/api/v1/realtime/parameters", + json={"vace_external_resume_hard_cut": val}, + ) + handle_error(r) + output({"status": "ok", "vace_external_resume_hard_cut": val}, ctx) + + +# --- Playlist --- + + +@cli.group() +@click.pass_context +def playlist(ctx): + """Manage prompt playlist from caption files.""" + pass + + +@playlist.command("load") +@click.argument("file_path", type=click.Path(exists=True)) +@click.option("--swap", nargs=2, help="Trigger swap: OLD NEW") +@click.pass_context +def playlist_load(ctx, file_path, swap): + """Load prompts from a caption file. 
+ + Examples: + video-cli playlist load captions.txt + video-cli playlist load captions.txt --swap "1988 Cel Animation" "Rankin/Bass Animagic Stop-Motion" + """ + with get_client(ctx) as client: + payload = {"file_path": str(Path(file_path).absolute())} + if swap: + payload["old_trigger"] = swap[0] + payload["new_trigger"] = swap[1] + r = client.post("/api/v1/realtime/playlist/load", json=payload) + handle_error(r) + output(r.json(), ctx) + + +@playlist.command("status") +@click.pass_context +def playlist_status(ctx): + """Get current playlist state.""" + with get_client(ctx) as client: + r = client.get("/api/v1/realtime/playlist") + handle_error(r) + output(r.json(), ctx) + + +@playlist.command("switch") +@click.option("--nav", is_flag=True, help="Enter nav mode after loading") +@click.option("--context", "-c", type=int, default=5, help="Lines of context for nav mode (default: 5)") +@click.pass_context +def playlist_switch(ctx, nav, context): + """Interactive playlist switcher.""" + import shutil + import termios + import tty + + if not CONTENT_DIR.exists(): + click.echo("Content directory not found") + return + + playlists = [] + for playlist_dir in sorted(CONTENT_DIR.iterdir()): + if playlist_dir.is_dir(): + resolved = _resolve_playlist_path(playlist_dir.name) + if resolved: + try: + prompt_count = sum(1 for line in resolved.read_text().splitlines() if line.strip()) + except Exception: + prompt_count = 0 + playlists.append({"name": playlist_dir.name, "path": resolved, "count": prompt_count}) + + if not playlists: + click.echo("No playlists found") + return + + def display(): + term_width = shutil.get_terminal_size().columns + click.clear() + click.echo("\n" + "=" * term_width) + click.echo(" Playlist Switcher") + click.echo("=" * term_width + "\n") + for i, pl in enumerate(playlists[:9], 1): + click.echo(f" [{i}] {pl['name']:20s} ({pl['count']} prompts)") + click.echo("\n" + "=" * term_width) + click.echo(" Press 1-9 to load, q to quit") + click.echo("=" * term_width + "\n") + + def get_char(): + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(fd) + return sys.stdin.read(1) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + + display() + while True: + ch = get_char() + if ch in ("q", "Q", "\x1b"): + click.echo("\nCancelled.") + return + if ch.isdigit() and 1 <= int(ch) <= len(playlists): + selected = playlists[int(ch) - 1] + click.echo(f"\nLoading {selected['name']}...") + with get_client(ctx) as client: + r = client.post("/api/v1/realtime/playlist/load", json={"file_path": str(selected["path"].absolute())}) + if r.status_code >= 400: + click.echo(f"Error: {r.text}") + return + click.echo(f"Loaded {selected['name']} ({selected['count']} prompts)") + if nav: + click.echo("Entering nav mode...\n") + ctx.invoke(playlist_nav, context=context) + return + + +@playlist.command("source-trigger") +@click.argument("trigger") +@click.pass_context +def playlist_source_trigger(ctx, trigger): + """Set the source trigger phrase for auto-swap. + + The source trigger is what trigger phrase exists in your original prompts. + This enables auto-trigger-swap when switching styles. 
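+
+    A one-off swap can also be done at load time with
+    `video-cli playlist load --swap OLD NEW`.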
+ + Examples: + # If your prompts contain "Hidari Animation": + video-cli playlist source-trigger "Hidari Animation" + + # Now when you switch styles, the trigger will auto-swap: + video-cli style set yeti # "Hidari Animation" -> "Yeti Animation" + """ + with get_client(ctx) as client: + r = client.put("/api/v1/realtime/playlist/source-trigger", json={"trigger": trigger}) + handle_error(r) + output(r.json(), ctx) + + +@playlist.command("preview") +@click.option("--context", "-c", type=int, default=2, help="Lines of context around current") +@click.pass_context +def playlist_preview(ctx, context): + """Preview prompts around current position.""" + with get_client(ctx) as client: + r = client.get("/api/v1/realtime/playlist/preview", params={"context": context}) + handle_error(r) + output(r.json(), ctx) + + +@playlist.command("next") +@click.option("--apply/--no-apply", default=True, help="Apply prompt after navigating") +@click.pass_context +def playlist_next(ctx, apply): + """Move to next prompt.""" + with get_client(ctx) as client: + r = client.post("/api/v1/realtime/playlist/next", params={"apply": apply}) + handle_error(r) + output(r.json(), ctx) + + +@playlist.command("prev") +@click.option("--apply/--no-apply", default=True, help="Apply prompt after navigating") +@click.pass_context +def playlist_prev(ctx, apply): + """Move to previous prompt.""" + with get_client(ctx) as client: + r = client.post("/api/v1/realtime/playlist/prev", params={"apply": apply}) + handle_error(r) + output(r.json(), ctx) + + +@playlist.command("goto") +@click.argument("index", type=int) +@click.option("--apply/--no-apply", default=True, help="Apply prompt after navigating") +@click.pass_context +def playlist_goto(ctx, index, apply): + """Go to a specific prompt index.""" + with get_client(ctx) as client: + r = client.post("/api/v1/realtime/playlist/goto", json={"index": index}, params={"apply": apply}) + handle_error(r) + output(r.json(), ctx) + + +@playlist.command("apply") +@click.pass_context +def playlist_apply(ctx): + """Apply current prompt to generation.""" + with get_client(ctx) as client: + r = client.post("/api/v1/realtime/playlist/apply") + handle_error(r) + output(r.json(), ctx) + + +@playlist.command("nav") +@click.option("--context", "-c", type=int, default=5, help="Lines of context around current (default: 5 = 11 total)") +@click.pass_context +def playlist_nav(ctx, context): + """Interactive navigation mode with autoplay. 
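+
+    Requires a loaded playlist; load one first, e.g.:
+
+        video-cli playlist load captions.txt
+        video-cli playlist nav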
+ + Controls: + →, n, l, SPACE Next prompt + ←, p Previous prompt (stops autoplay) + m Toggle bookmark on current prompt (★) + N Jump to next bookmarked prompt + P Jump to previous bookmarked prompt + B Toggle bookmark filter (show only bookmarked) + o Toggle autoplay (default 5s interval) + h/H Toggle hard cut mode (reset cache on each transition) + s Toggle soft cut mode (temporary KV-bias override) + t Toggle embedding transition mode (temporal interpolation) + T Toggle transition method (linear ↔ slerp) + x One-shot hard cut (doesn't change mode) + z Randomize seed + Z Soft cut + new seed (same prompt, new variation) + S Set specific seed (prompts for number) + * Bookmark current seed + # Show seed info (current, history, bookmarks) + b Toggle blend mode (style switch without cache reset) + 1-5 Set soft cut bias (when soft cut active) + !@#$% Set soft cut duration in chunks (when soft cut active) + 6-0 Set transition chunks (1-5) (when transition active) + +/- Adjust autoplay speed (1-30s) + g Go to index (prompts for number) + a Apply current prompt + j Jiggle active prompt (generate variations) + r Refresh display + q, ESC Quit + + Jiggle mode: + 1-4 Apply variation (respects current hard/soft/transition mode) + j Regenerate variations (attentional mode) + J Regenerate variations (semantic mode, requires direction) + d Set/clear direction (blank clears) + ESC Cancel and return to playlist + + Changes are auto-applied by default. + Hard cut mode resets the KV cache on each prompt change for clean scene transitions. + Soft cut mode temporarily lowers kv_cache_attention_bias for faster adaptation without a full reset. + Embedding transition mode interpolates embeddings over N chunks for smoother prompt morphs. + Hard cut and embedding transition are mutually exclusive (hard cut resets state). + One-shot hard cut (x) does a single cache reset without changing your current mode. + """ + import builtins + import os + import select + import shutil + import termios + import time + import textwrap + import tty + + def get_char_nonblocking(timeout=0.2): + """Read a char with timeout. Returns None if no input.""" + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(fd) + if select.select([fd], [], [], timeout)[0]: + ch = os.read(fd, 1).decode("utf-8", errors="ignore") + if ch == "\x1b": + extra = "" + for _ in range(5): + if select.select([fd], [], [], 0.05)[0]: + byte = os.read(fd, 1).decode("utf-8", errors="ignore") + extra += byte + if len(extra) >= 2 and extra[0] == "[" and extra[-1] in "ABCD": + break + else: + break + ch = ch + extra + return ch + return None + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + + def display_preview( + client, + autoplay=False, + interval=5.0, + hard_cut=False, + soft_cut=False, + soft_cut_bias=0.20, + soft_cut_chunks=4, + transition=False, + transition_chunks=4, + transition_method="slerp", + blend_mode=True, + context_lines=8, + bookmark_filter=False, + ): + """Fetch and display preview.""" + import shutil + + term_width = shutil.get_terminal_size().columns + + params = {"context": context_lines} + if bookmark_filter: + params["bookmarks_only"] = True + r = client.get("/api/v1/realtime/playlist/preview", params=params) + if r.status_code != 200: + return None + data = r.json() + if data.get("status") == "no_playlist": + click.echo("\n No playlist loaded. 
Use: video-cli playlist load \n") + return None + + click.echo("\n" + "=" * term_width) + bookmarked_indices = data.get("bookmarked_indices", []) + bookmark_count = len(bookmarked_indices) + status = f" Playlist: {data.get('total', 0)} prompts" + if bookmark_count > 0: + status += f" [★ {bookmark_count} bookmarked]" + if bookmark_filter: + status += " [FILTERED]" + if autoplay: + status += f" [▶ AUTO {interval}s]" + if blend_mode: + status += " [🌀 BLEND]" + if hard_cut: + status += " [✂ HARD CUT]" + if soft_cut: + status += f" [~ SOFT b={soft_cut_bias:.2f} c={soft_cut_chunks}]" + if transition: + method_label = "SLERP" if transition_method == "slerp" else "LERP" + status += f" [⟷ {method_label} c={transition_chunks}]" + click.echo(status) + click.echo("=" * term_width) + + # Calculate prompt display width + prompt_width = term_width - 12 # Extra space for bookmark marker + + # Get prompts to show + prompts_to_show = data.get("prompts", []) + + # Cap display (only when not in bookmark filter mode - we want to see all bookmarks) + if not bookmark_filter: + max_display = context_lines * 2 + 1 + if len(prompts_to_show) > max_display: + prompts_to_show = prompts_to_show[:max_display] + + for item in prompts_to_show: + is_current = item.get("current") + is_bookmarked = item.get("bookmarked", False) + marker = "▶ " if is_current else " " + bookmark = "★" if is_bookmarked else " " + idx = item.get("index", 0) + prompt = item.get("prompt", "")[:prompt_width] + if is_current: + click.echo( + click.style(f"{marker}{bookmark}[{idx:3d}] {prompt}", fg="green", bold=True) + ) + elif is_bookmarked: + click.echo( + click.style(f"{marker}{bookmark}[{idx:3d}] {prompt}", fg="yellow") + ) + else: + click.echo(f"{marker}{bookmark}[{idx:3d}] {prompt}") + + click.echo("=" * term_width) + click.echo( + " ←/→ nav | N/P jump★ | m mark | B filter★ | g goto | a apply | o auto | +/- speed" + ) + click.echo( + " h hard | s soft | t trans | T slerp | x cut! 
| b blend | z seed | j jiggle | D dir | q quit" + ) + click.echo("=" * term_width + "\n") + return data + + STASH_CAPTIONS_PATH = CONTENT_DIR / "_stash" / "Captioning" / "stash_captions.txt" + + def _normalize_stash_prompt(text: str) -> str: + """Stash file is strictly one prompt per line.""" + return " ".join(text.replace("\r", "\n").replace("\n", " ").split()).strip() + + def append_to_jiggle_stash(prompt_text: str) -> None: + """Append an applied jiggle variation to the stash playlist.""" + line = _normalize_stash_prompt(prompt_text) + if not line: + return + STASH_CAPTIONS_PATH.parent.mkdir(parents=True, exist_ok=True) + with STASH_CAPTIONS_PATH.open("a") as f: + f.write(line + "\n") + + def fetch_jiggle_candidates( + client: httpx.Client, + prompt_text: str, + *, + count: int = 4, + direction: str | None = None, + mode: str = "attentional", + intensity: float = 0.3, + timeout_s: float = 10.0, + ) -> dict[str, object]: + """Fetch jiggle candidates from API using a single count=N call.""" + body: dict[str, object] = { + "prompt": prompt_text, + "count": count, + "intensity": intensity, + "mode": mode, + } + if direction: + body["direction"] = direction + try: + r = client.post("/api/v1/prompt/jiggle", json=body, timeout=timeout_s) + except Exception as e: + return {"ok": False, "error": str(e)} + + if r.status_code != 200: + return {"ok": False, "error": r.text, "status_code": r.status_code} + + try: + data = r.json() + except Exception: + return {"ok": False, "error": r.text} + + status = data.get("status") + original_prompt = data.get("original_prompt") + variations = data.get("variations") + if not isinstance(original_prompt, str): + original_prompt = prompt_text + if not isinstance(variations, list): + variations = [] + variations = [v for v in variations if isinstance(v, str) and v.strip()] + return { + "ok": True, + "status": status, + "original_prompt": original_prompt, + "variations": variations, + } + + def display_jiggle_view( + original: str, + candidates: list[str], + *, + direction: str | None, + intensity: float, + mode: str, + preview_index: int | None = None, # Currently previewed candidate (0-3) + ) -> None: + """Render a full-screen jiggle candidate view.""" + term_width = shutil.get_terminal_size().columns + click.clear() + + header = f"JIGGLE MODE [{mode}] [intensity={intensity:.1f}]" + if direction: + header += f' [dir="{direction[:20]}"]' + click.echo(click.style(header, fg="cyan", bold=True)) + click.echo("=" * min(term_width, 70)) + + click.echo("Original:") + for line in textwrap.wrap(original, width=max(20, min(term_width - 4, 66)))[:2]: + click.echo(f" {line}") + click.echo("-" * min(term_width, 70)) + + for i, candidate in enumerate(candidates[:4], 1): + is_previewing = (preview_index == i - 1) + marker = " ◀──" if is_previewing else "" + style = {"fg": "green", "bold": True} if is_previewing else {} + click.echo(click.style(f"[{i}]{marker}", **style)) + for line in textwrap.wrap(candidate, width=max(20, min(term_width - 4, 66)))[:2]: + click.echo(click.style(f" {line}", **style)) + + click.echo("=" * min(term_width, 70)) + click.echo("1-4: preview | Enter: confirm | j: regen | d: direction | J: semantic") + click.echo("ESC: cancel | q: quit") + + def _apply_jiggle_prompt( + client: httpx.Client, + prompt_text: str, + *, + hard_cut: bool, + soft_cut: bool, + soft_cut_bias: float, + soft_cut_chunks: int, + transition: bool, + transition_chunks: int, + transition_method: str, + ) -> httpx.Response: + """Apply a prompt respecting current transition mode 
toggles.""" + if hard_cut: + return client.post("/api/v1/realtime/hard-cut", json={"prompt": prompt_text}) + + if transition: + prompt_payload = [{"text": prompt_text, "weight": 1.0}] + method = transition_method if transition_method in ("linear", "slerp") else "linear" + msg: dict[str, object] = { + "transition": { + "target_prompts": prompt_payload, + "num_steps": transition_chunks, + "temporal_interpolation_method": method, + } + } + if soft_cut: + msg["_rcp_soft_transition"] = { + "temp_bias": soft_cut_bias, + "num_chunks": soft_cut_chunks, + } + return client.post("/api/v1/realtime/parameters", json=msg) + + if soft_cut: + return client.post( + "/api/v1/realtime/soft-cut", + json={ + "prompt": prompt_text, + "temp_bias": soft_cut_bias, + "num_chunks": soft_cut_chunks, + }, + ) + + return client.put("/api/v1/realtime/prompt", json={"prompt": prompt_text}) + + click.echo("\nPlaylist Navigation Mode") + click.echo("Press q or ESC to quit\n") + + # Autoplay state + autoplay = False + autoplay_interval = 5.0 + last_advance = time.time() + + # Hard cut state - when enabled, all transitions reset the KV cache + hard_cut = False + + # Soft cut state - when enabled, transitions temporarily lower KV cache bias + soft_cut = False + soft_cut_bias = 0.20 # Default temp bias + soft_cut_chunks = 4 # Default duration + + # Transition (embedding interpolation) state + transition = False + transition_chunks = 4 # Default number of chunks to interpolate over + transition_method = "slerp" # linear or slerp + + # Blend mode state - when enabled, style switches don't reset cache (artistic artifacts) + blend_mode = True + + # Bookmark filter - when enabled, only show bookmarked prompts in display + bookmark_filter = False + + # Jiggle mode state (prompt variations) + jiggle_mode = False + jiggle_candidates: list[str] = [] + jiggle_original = "" + jiggle_direction: str | None = None # Persists across jiggle sessions + jiggle_intensity = 0.3 + jiggle_view_mode = "attentional" + jiggle_preview_index: int | None = None # Currently previewed candidate (0-3) + was_autoplay = False # Restore autoplay on exit + + with get_client(ctx) as client: + # Fetch current blend mode from server + try: + r = client.get("/api/v1/realtime/style/blend-mode") + if r.status_code == 200: + blend_mode = r.json().get("blend_mode", False) + except Exception: + pass # Use default if server not ready + + if display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode + ) is None: + return + + while True: + try: + ch = get_char_nonblocking(timeout=0.2) + + if ch is not None: + # Quit + if ch in ("q", "Q", "\x03"): + click.echo("\nExiting navigation mode.\n") + break + elif ch == "\x1b" and not jiggle_mode: + click.echo("\nExiting navigation mode.\n") + break + + # === JIGGLE MODE HANDLERS (must run before other mode handlers) === + elif jiggle_mode: + # Cancel jiggle (ESC only; keep q as quit) + # If previewing, restore original prompt before exiting + if ch == "\x1b": + cancel_msg = " Cancelled" + if jiggle_preview_index is not None: + # Restore original prompt + try: + _apply_jiggle_prompt( + client, + jiggle_original, + hard_cut=hard_cut, + soft_cut=soft_cut, + soft_cut_bias=soft_cut_bias, + soft_cut_chunks=soft_cut_chunks, + transition=transition, + transition_chunks=transition_chunks, + transition_method=transition_method, + ) + cancel_msg = " Cancelled (restored original)" + except Exception: + cancel_msg = " Cancelled 
(failed to restore original)" + jiggle_mode = False + jiggle_candidates = [] + jiggle_preview_index = None + autoplay = was_autoplay + click.clear() + display_preview( + client, + autoplay, + autoplay_interval, + hard_cut, + soft_cut, + soft_cut_bias, + soft_cut_chunks, + transition, + transition_chunks, + transition_method, + blend_mode, + context, + bookmark_filter, + ) + click.echo(cancel_msg) + continue + + # Preview candidate 1-4 (apply but stay in jiggle mode) + if ch in "1234": + idx = int(ch) - 1 + if idx < len(jiggle_candidates): + selected = jiggle_candidates[idx] + try: + r2 = _apply_jiggle_prompt( + client, + selected, + hard_cut=hard_cut, + soft_cut=soft_cut, + soft_cut_bias=soft_cut_bias, + soft_cut_chunks=soft_cut_chunks, + transition=transition, + transition_chunks=transition_chunks, + transition_method=transition_method, + ) + if r2.status_code != 200: + click.echo(f" Failed to preview [{ch}]: {r2.text}") + else: + jiggle_preview_index = idx + display_jiggle_view( + jiggle_original, + jiggle_candidates, + direction=jiggle_direction, + intensity=jiggle_intensity, + mode=jiggle_view_mode, + preview_index=jiggle_preview_index, + ) + click.echo(f" Previewing [{ch}] - press Enter to confirm, or try another") + except Exception as e: + click.echo(f" Failed to preview [{ch}]: {e}") + else: + click.echo(f" No candidate [{ch}]") + continue + + # Confirm current preview (Enter) + if ch in ("\r", "\n"): + if jiggle_preview_index is not None: + selected = jiggle_candidates[jiggle_preview_index] + # Save to stash for later recall + stash_msg = "" + try: + append_to_jiggle_stash(selected) + except Exception as e: + stash_msg = f" (stash save failed: {e})" + else: + stash_msg = " (saved to stash)" + + jiggle_mode = False + jiggle_candidates = [] + jiggle_preview_index = None + autoplay = was_autoplay + click.clear() + display_preview( + client, + autoplay, + autoplay_interval, + hard_cut, + soft_cut, + soft_cut_bias, + soft_cut_chunks, + transition, + transition_chunks, + transition_method, + blend_mode, + context, + bookmark_filter, + ) + click.echo(f" ✓ Confirmed{stash_msg}") + else: + click.echo(" No preview selected - press 1-4 to preview first") + continue + + # Regenerate (attentional) - clears preview + if ch == "j": + jiggle_preview_index = None + jiggle_view_mode = "attentional" + click.echo(" Regenerating...") + result = fetch_jiggle_candidates( + client, + jiggle_original, + count=4, + direction=jiggle_direction, + mode=jiggle_view_mode, + intensity=jiggle_intensity, + ) + if not result.get("ok"): + click.echo(f" Failed to regenerate: {result.get('error', 'unknown error')}") + else: + jiggle_candidates = list(result.get("variations", [])) # type: ignore[list-item] + display_jiggle_view( + str(result.get("original_prompt", jiggle_original)), + jiggle_candidates, + direction=jiggle_direction, + intensity=jiggle_intensity, + mode=jiggle_view_mode, + preview_index=jiggle_preview_index, + ) + continue + + # Regenerate (semantic) - clears preview + if ch == "J": + if not jiggle_direction: + click.echo(" Semantic mode requires direction - press 'd' first") + continue + jiggle_preview_index = None + jiggle_view_mode = "semantic" + click.echo(" Regenerating (semantic)...") + result = fetch_jiggle_candidates( + client, + jiggle_original, + count=4, + direction=jiggle_direction, + mode=jiggle_view_mode, + intensity=jiggle_intensity, + ) + if not result.get("ok"): + click.echo(f" Failed to regenerate: {result.get('error', 'unknown error')}") + else: + jiggle_candidates = 
list(result.get("variations", [])) # type: ignore[list-item] + display_jiggle_view( + str(result.get("original_prompt", jiggle_original)), + jiggle_candidates, + direction=jiggle_direction, + intensity=jiggle_intensity, + mode=jiggle_view_mode, + preview_index=jiggle_preview_index, + ) + continue + + # Direction input + regen (attentional) - clears preview + if ch == "d": + click.echo("\nDirection (blank to clear): ", nl=False) + try: + direction_input = builtins.input().strip() + except EOFError: + direction_input = "" + jiggle_direction = direction_input or None + jiggle_preview_index = None + jiggle_view_mode = "attentional" + click.echo(" Regenerating...") + result = fetch_jiggle_candidates( + client, + jiggle_original, + count=4, + direction=jiggle_direction, + mode=jiggle_view_mode, + intensity=jiggle_intensity, + ) + if not result.get("ok"): + click.echo(f" Failed to regenerate: {result.get('error', 'unknown error')}") + else: + jiggle_candidates = list(result.get("variations", [])) # type: ignore[list-item] + display_jiggle_view( + str(result.get("original_prompt", jiggle_original)), + jiggle_candidates, + direction=jiggle_direction, + intensity=jiggle_intensity, + mode=jiggle_view_mode, + preview_index=jiggle_preview_index, + ) + continue + + click.echo(" 1-4: preview | Enter: confirm | j: regen | d: direction | J: semantic | ESC: cancel") + continue + + # Enter jiggle mode + elif ch == "j": + # Jiggle the active (currently generating) prompt, not the playlist line. + try: + state_r = client.get("/api/v1/realtime/state") + except Exception as e: + click.echo(f" Failed to fetch active prompt: {e}") + continue + if state_r.status_code != 200: + click.echo(f" Failed to fetch active prompt: {state_r.text}") + continue + try: + state = state_r.json() + except Exception: + click.echo(" Failed to parse active prompt state") + continue + + active_prompt = state.get("compiled_prompt") or state.get("prompt") + if not isinstance(active_prompt, str) or not active_prompt.strip(): + click.echo(" No active prompt to jiggle") + continue + + click.echo(" Generating variations...") + result = fetch_jiggle_candidates( + client, + active_prompt, + count=4, + direction=jiggle_direction, + mode="attentional", + intensity=jiggle_intensity, + ) + if not result.get("ok"): + click.echo(f" Failed to generate variations: {result.get('error', 'unknown error')}") + continue + + if result.get("status") == "unchanged": + click.echo(" [Jiggle unavailable - GEMINI_API_KEY not set]") + continue + + candidates = list(result.get("variations", [])) # type: ignore[list-item] + if not candidates: + click.echo(" No variations generated") + continue + + was_autoplay = autoplay + autoplay = False + jiggle_mode = True + jiggle_original = str(result.get("original_prompt", active_prompt)) + jiggle_candidates = candidates + jiggle_view_mode = "attentional" + jiggle_preview_index = None # No preview yet + display_jiggle_view( + jiggle_original, + jiggle_candidates, + direction=jiggle_direction, + intensity=jiggle_intensity, + mode=jiggle_view_mode, + preview_index=jiggle_preview_index, + ) + continue + + # Set jiggle direction before entering jiggle mode + elif ch == "D": + click.echo("\nJiggle direction (blank to clear): ", nl=False) + try: + direction_input = builtins.input().strip() + except EOFError: + direction_input = "" + jiggle_direction = direction_input or None + if jiggle_direction: + click.echo(f' Direction set: "{jiggle_direction}" - press j to jiggle') + else: + click.echo(" Direction cleared") + continue + + # 
Toggle autoplay + elif ch == "o": + autoplay = not autoplay + last_advance = time.time() + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + + # Toggle hard cut mode (mutually exclusive with soft cut) + elif ch in ("H", "h"): + hard_cut = not hard_cut + if hard_cut: + soft_cut = False # Mutually exclusive + transition = False # Transition requires continuity + status = "ON - transitions will reset cache" if hard_cut else "OFF" + click.echo(f" ✂ Hard cut: {status}") + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + + # Toggle soft cut mode (mutually exclusive with hard cut) + elif ch == "s": + soft_cut = not soft_cut + if soft_cut: + hard_cut = False # Mutually exclusive + if soft_cut: + status = f"ON (bias={soft_cut_bias}, chunks={soft_cut_chunks})" + else: + status = "OFF" + click.echo(f" ~ Soft cut: {status}") + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + + # Bias adjustment (1-5 keys when soft_cut active) + elif soft_cut and ch in "12345": + bias_map = {"1": 0.05, "2": 0.1, "3": 0.15, "4": 0.2, "5": 0.25} + soft_cut_bias = bias_map[ch] + click.echo(f" ~ Soft cut bias: {soft_cut_bias}") + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + + # Chunk adjustment (Shift+1-5 = !, @, #, $, % when soft_cut active) + elif soft_cut and ch in "!@#$%": + chunk_map = {"!": 1, "@": 2, "#": 3, "$": 4, "%": 5} + soft_cut_chunks = chunk_map[ch] + click.echo(f" ~ Soft cut chunks: {soft_cut_chunks}") + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + + # Toggle transition (embedding interpolation) mode + elif ch == "t": + transition = not transition + if transition: + hard_cut = False # Transition requires continuity + status = f"ON (chunks={transition_chunks}, method={transition_method})" + else: + status = "OFF" + click.echo(f" ⟷ Transition: {status}") + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + + # Toggle transition method (linear/slerp) + elif ch == "T": + transition_method = "slerp" if transition_method == "linear" else "linear" + click.echo(f" ⟷ Transition method: {transition_method}") + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + + # Transition chunks adjustment (6-0 when transition active) + elif transition and ch in "67890": + chunk_map = {"6": 1, "7": 2, "8": 3, "9": 4, "0": 5} + transition_chunks = chunk_map[ch] + click.echo(f" ⟷ Transition chunks: {transition_chunks}") + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, 
transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + + # Toggle blend mode (style switch without cache reset) + elif ch == "b": + blend_mode = not blend_mode + # Sync with server + try: + r = client.put( + "/api/v1/realtime/style/blend-mode", + json={"enabled": blend_mode} + ) + if r.status_code == 200: + status = "ON - style switches will create blend artifacts" if blend_mode else "OFF - clean transitions" + click.echo(f" 🌀 Blend mode: {status}") + else: + blend_mode = not blend_mode # Revert on failure + click.echo(" Failed to set blend mode") + except Exception as e: + blend_mode = not blend_mode + click.echo(f" Error setting blend mode: {e}") + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + + # Adjust speed + elif ch in ("+", "=", "]"): + autoplay_interval = max(1.0, autoplay_interval - 1.0) + click.echo(f" Interval: {autoplay_interval}s") + elif ch in ("-", "_", "["): + autoplay_interval = min(30.0, autoplay_interval + 1.0) + click.echo(f" Interval: {autoplay_interval}s") + + # Next (uses bookmark nav when filter is ON) + elif ch in ("\x1b[C", "n", "l", " "): + params = {"apply": True} + if hard_cut: + params["hard_cut"] = True + if soft_cut: + params["soft_cut"] = True + params["soft_cut_bias"] = soft_cut_bias + params["soft_cut_chunks"] = soft_cut_chunks + if transition: + params["transition"] = True + params["transition_chunks"] = transition_chunks + params["transition_method"] = transition_method + endpoint = "/api/v1/realtime/playlist/bookmark/next" if bookmark_filter else "/api/v1/realtime/playlist/next" + r = client.post(endpoint, params=params) + if r.status_code == 200: + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + elif r.status_code == 400 and "No bookmarks" in r.text: + click.echo(" ⚠ No bookmarks set. Press 'm' to bookmark prompts.") + last_advance = time.time() + + # Previous (stops autoplay, uses bookmark nav when filter is ON) + elif ch in ("\x1b[D", "p"): + params = {"apply": True} + if hard_cut: + params["hard_cut"] = True + if soft_cut: + params["soft_cut"] = True + params["soft_cut_bias"] = soft_cut_bias + params["soft_cut_chunks"] = soft_cut_chunks + if transition: + params["transition"] = True + params["transition_chunks"] = transition_chunks + params["transition_method"] = transition_method + endpoint = "/api/v1/realtime/playlist/bookmark/prev" if bookmark_filter else "/api/v1/realtime/playlist/prev" + r = client.post(endpoint, params=params) + if r.status_code == 200: + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + elif r.status_code == 400 and "No bookmarks" in r.text: + click.echo(" ⚠ No bookmarks set. 
Press 'm' to bookmark prompts.") + last_advance = time.time() + if autoplay: + autoplay = False + click.echo(" ⏸ Autoplay stopped") + + # Goto + elif ch == "g": + click.echo("\nGoto index: ", nl=False) + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + try: + idx_str = builtins.input() + idx = int(idx_str) + params = {"apply": True} + if hard_cut: + params["hard_cut"] = True + if soft_cut: + params["soft_cut"] = True + params["soft_cut_bias"] = soft_cut_bias + params["soft_cut_chunks"] = soft_cut_chunks + if transition: + params["transition"] = True + params["transition_chunks"] = transition_chunks + params["transition_method"] = transition_method + r = client.post( + "/api/v1/realtime/playlist/goto", + json={"index": idx}, + params=params, + ) + if r.status_code == 200: + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method + ) + last_advance = time.time() + except ValueError: + click.echo("Invalid index") + except EOFError: + pass + + # Apply (with hard/soft cut/transition if enabled) + elif ch == "a": + params = {} + if hard_cut: + params["hard_cut"] = True + if soft_cut: + params["soft_cut"] = True + params["soft_cut_bias"] = soft_cut_bias + params["soft_cut_chunks"] = soft_cut_chunks + if transition: + params["transition"] = True + params["transition_chunks"] = transition_chunks + params["transition_method"] = transition_method + r = client.post("/api/v1/realtime/playlist/apply", params=params) + if r.status_code == 200: + msg = "✓ Prompt applied" + if hard_cut: + msg += " (hard cut)" + if soft_cut: + msg += f" (soft cut b={soft_cut_bias})" + if transition: + msg += f" (transition c={transition_chunks})" + click.echo(f" {msg}") + + # One-shot hard cut (doesn't change mode) + elif ch == "x": + r = client.post("/api/v1/realtime/hard-cut") + if r.status_code == 200: + click.echo(" ✂ One-shot hard cut applied") + + # Randomize seed + elif ch == "z": + import random + new_seed = random.randint(0, 2**32 - 1) + r = client.post("/api/v1/realtime/parameters", json={"base_seed": new_seed}) + if r.status_code == 200: + click.echo(f" 🎲 New seed: {new_seed}") + else: + click.echo(f" Error setting seed: {r.text}") + + # Soft cut with new seed (same prompt, new variation) + elif ch == "Z": + import random + new_seed = random.randint(0, 2**32 - 1) + # Set new seed + r1 = client.post("/api/v1/realtime/parameters", json={"base_seed": new_seed}) + # Apply current prompt with soft cut + params = { + "soft_cut": True, + "soft_cut_bias": soft_cut_bias, + "soft_cut_chunks": soft_cut_chunks, + } + r2 = client.post("/api/v1/realtime/playlist/apply", params=params) + if r1.status_code == 200 and r2.status_code == 200: + click.echo(f" 🎲~ Soft cut with new seed: {new_seed}") + else: + click.echo(f" Error: seed={r1.status_code}, apply={r2.status_code}") + + # Set specific seed + elif ch == "S": + click.echo(" Enter seed: ", nl=False) + # Temporarily restore terminal for input + import termios + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + seed_str = builtins.input() + if seed_str.strip(): + try: + new_seed = int(seed_str.strip()) + r = client.post("/api/v1/realtime/parameters", json={"base_seed": new_seed}) + if r.status_code == 200: + click.echo(f" 🎲 Seed set: {new_seed}") + else: + click.echo(f" Error: {r.text}") + except ValueError: + click.echo(" 
Invalid seed (must be integer)") + finally: + pass # Will re-enter raw mode on next iteration + + # Bookmark current seed + elif ch == "*": + r = client.post("/api/v1/realtime/seed/bookmark") + if r.status_code == 200: + data = r.json() + click.echo(f" ⭐ Bookmarked seed: {data.get('bookmarked')}") + else: + click.echo(f" Error: {r.text}") + + # Show seed info (current, history, bookmarks) + elif ch == "#": + r = client.get("/api/v1/realtime/seed") + if r.status_code == 200: + data = r.json() + click.echo(f" Current: {data.get('current')}") + click.echo(f" History: {data.get('history', [])}") + click.echo(f" Bookmarks: {data.get('bookmarks', [])}") + else: + click.echo(f" Error: {r.text}") + + # Toggle prompt bookmark + elif ch == "m": + r = client.post("/api/v1/realtime/playlist/bookmark") + if r.status_code == 200: + data = r.json() + if data.get("status") == "bookmarked": + click.echo(f" ★ Bookmarked prompt {data.get('current_index')}") + else: + click.echo(f" ☆ Unbookmarked prompt {data.get('current_index')}") + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + else: + click.echo(f" Error: {r.text}") + + # Next bookmarked prompt + elif ch == "N": + params = {"apply": True} + if hard_cut: + params["hard_cut"] = True + if soft_cut: + params["soft_cut"] = True + params["soft_cut_bias"] = soft_cut_bias + params["soft_cut_chunks"] = soft_cut_chunks + if transition: + params["transition"] = True + params["transition_chunks"] = transition_chunks + params["transition_method"] = transition_method + r = client.post("/api/v1/realtime/playlist/bookmark/next", params=params) + if r.status_code == 200: + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + elif r.status_code == 400 and "No bookmarks" in r.text: + click.echo(" ⚠ No bookmarks set. Press 'm' to bookmark prompts.") + else: + click.echo(f" Error: {r.text}") + + # Previous bookmarked prompt + elif ch == "P": + params = {"apply": True} + if hard_cut: + params["hard_cut"] = True + if soft_cut: + params["soft_cut"] = True + params["soft_cut_bias"] = soft_cut_bias + params["soft_cut_chunks"] = soft_cut_chunks + if transition: + params["transition"] = True + params["transition_chunks"] = transition_chunks + params["transition_method"] = transition_method + r = client.post("/api/v1/realtime/playlist/bookmark/prev", params=params) + if r.status_code == 200: + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + elif r.status_code == 400 and "No bookmarks" in r.text: + click.echo(" ⚠ No bookmarks set. 
Press 'm' to bookmark prompts.") + else: + click.echo(f" Error: {r.text}") + + # Toggle bookmark filter view + elif ch == "B": + bookmark_filter = not bookmark_filter + if bookmark_filter: + click.echo(" ★ Bookmark filter ON - showing only bookmarked prompts") + else: + click.echo(" ★ Bookmark filter OFF - showing all prompts") + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + + # Refresh + elif ch == "r": + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + + # Autoplay advance + if autoplay and (time.time() - last_advance) >= autoplay_interval: + params = {"apply": True} + if hard_cut: + params["hard_cut"] = True + if soft_cut: + params["soft_cut"] = True + params["soft_cut_bias"] = soft_cut_bias + params["soft_cut_chunks"] = soft_cut_chunks + if transition: + params["transition"] = True + params["transition_chunks"] = transition_chunks + params["transition_method"] = transition_method + r = client.post("/api/v1/realtime/playlist/next", params=params) + if r.status_code == 200: + data = r.json() + if not data.get("has_next", False): + autoplay = False + click.echo(" ⏹ End of playlist") + display_preview( + client, autoplay, autoplay_interval, hard_cut, + soft_cut, soft_cut_bias, soft_cut_chunks, + transition, transition_chunks, transition_method, blend_mode, context, + bookmark_filter + ) + last_advance = time.time() + + except KeyboardInterrupt: + click.echo("\nExiting navigation mode.\n") + break + + +# --- Input Sources --- + + +@cli.group() +@click.pass_context +def input(ctx): + """Manage video input sources (NDI, Spout, WebRTC).""" + pass + + +@input.group() +@click.pass_context +def ndi(ctx): + """NDI input control.""" + pass + + +@ndi.command("enable") +@click.argument("source", default="") +@click.option("--extra-ips", "-e", multiple=True, help="Extra IPs to probe (e.g., Tailscale IPs)") +@click.pass_context +def ndi_enable(ctx, source, extra_ips): + """Enable NDI receiver. + + SOURCE is a substring to match against available NDI sources. + Use --extra-ips for Tailscale or cross-subnet discovery. 
+ + Examples: + video-cli input ndi enable DepthOutput -e 100.70.189.4 + video-cli input ndi enable "QUIXOTRON" + """ + with get_client(ctx) as client: + payload = { + "ndi_receiver": { + "enabled": True, + "source": source, + } + } + if extra_ips: + payload["ndi_receiver"]["extra_ips"] = list(extra_ips) + r = client.post("/api/v1/realtime/parameters", json=payload) + handle_error(r) + result = {"status": "ndi_enabled", "source": source or "(any)"} + if extra_ips: + result["extra_ips"] = list(extra_ips) + output(result, ctx) + + +@ndi.command("disable") +@click.pass_context +def ndi_disable(ctx): + """Disable NDI receiver.""" + with get_client(ctx) as client: + payload = {"ndi_receiver": {"enabled": False}} + r = client.post("/api/v1/realtime/parameters", json=payload) + handle_error(r) + output({"status": "ndi_disabled"}, ctx) + + +@ndi.command("list") +@click.option("--extra-ips", "-e", multiple=True, help="Extra IPs to probe") +@click.option("--timeout", "-t", default=3000, help="Discovery timeout in ms") +@click.pass_context +def ndi_list(ctx, extra_ips, timeout): + """List available NDI sources.""" + try: + from scope.server.ndi.finder import list_sources + except ImportError: + output({"error": "NDI module not available"}, ctx) + return + + sources = list_sources( + timeout_ms=timeout, + extra_ips=list(extra_ips) if extra_ips else None, + show_local_sources=True, + ) + result = { + "sources": [{"name": s.name, "url": s.url_address} for s in sources], + "count": len(sources), + } + output(result, ctx) + + +@ndi.command("probe") +@click.argument("source", default="") +@click.option("--extra-ips", "-e", multiple=True, help="Extra IPs to probe (e.g., Tailscale IPs)") +@click.option("--discover-timeout", default=3000, help="Discovery timeout in ms") +@click.option("--capture-timeout", default=2000, help="Capture timeout in ms") +@click.pass_context +def ndi_probe(ctx, source, extra_ips, discover_timeout, capture_timeout): + """Connect to an NDI source and attempt to capture a single video frame. + + This is useful for debugging "source shows up in discovery but no frames arrive". 
+ """ + try: + from scope.server.ndi.receiver import NDIReceiver + except ImportError: + output({"error": "NDI module not available"}, ctx) + return + + receiver = NDIReceiver(recv_name="ScopeNDIProbe") + if not receiver.create(): + output({"error": "NDI receiver create() failed"}, ctx) + return + + extra_ips_list = list(extra_ips) if extra_ips else None + try: + src = receiver.connect_discovered( + source_substring=source, + extra_ips=extra_ips_list, + timeout_ms=int(discover_timeout), + ) + + start = time.monotonic() + deadline = start + max(0.0, float(capture_timeout)) / 1000.0 + frame = None + while time.monotonic() < deadline: + remaining_ms = max(0, int((deadline - time.monotonic()) * 1000)) + frame = receiver.receive_latest_rgb24(timeout_ms=min(250, remaining_ms)) + if frame is not None: + break + + elapsed_ms = int((time.monotonic() - start) * 1000) + result: dict[str, Any] = { + "source": {"name": src.name, "url": src.url_address}, + "connections": receiver.get_no_connections(), + "elapsed_ms": elapsed_ms, + "frame": None, + } + if frame is not None: + result["frame"] = {"shape": list(frame.shape), "dtype": str(frame.dtype)} + + output(result, ctx) + except Exception as e: + output({"error": str(e)}, ctx) + finally: + try: + receiver.release() + except Exception: + pass + + +@input.command("status") +@click.pass_context +def input_status(ctx): + """Get current input source status.""" + with get_client(ctx) as client: + # Try to get state (requires WebRTC session) + r = client.get("/api/v1/realtime/state") + if r.status_code == 200: + data = r.json() + result = { + "active_source": data.get("active_input_source", "unknown"), + "ndi_frames_received": data.get("ndi_frames_received", 0), + "ndi_frames_dropped": data.get("ndi_frames_dropped", 0), + } + output(result, ctx) + else: + # Fallback: just report that we can't get status without session + output({"status": "no_session", "message": "Connect via WebRTC to see input status"}, ctx) + + +def main(): + cli() + + +if __name__ == "__main__": + main() diff --git a/src/scope/core/pipelines/wan2_1/vace/blocks/vace_encoding.py b/src/scope/core/pipelines/wan2_1/vace/blocks/vace_encoding.py index 4e0cd2325..11c236213 100644 --- a/src/scope/core/pipelines/wan2_1/vace/blocks/vace_encoding.py +++ b/src/scope/core/pipelines/wan2_1/vace/blocks/vace_encoding.py @@ -15,6 +15,7 @@ - Standard path: vace_encode_frames -> vace_encode_masks -> vace_latent """ +import os import logging from typing import Any @@ -30,12 +31,7 @@ OutputParam, ) -from ..utils.encoding import ( - load_and_prepare_reference_images, - vace_encode_frames, - vace_encode_masks, - vace_latent, -) +from ..utils.encoding import load_and_prepare_reference_images logger = logging.getLogger(__name__) @@ -62,22 +58,6 @@ class VaceEncodingBlock(ModularPipelineBlocks): VACE has one operation with multiple potential triggers, better handled internally. """ - def __init__(self): - super().__init__() - # Explicit encoder caches for TAE VACE dual-encode (inactive + reactive) - # These persist across chunks to maintain temporal continuity within each stream - # while preventing MemBlock memory pollution between streams. - # Lazily initialized on first use since we need the VAE to create caches. 
- self._inactive_cache = None - self._reactive_cache = None - self._caches_initialized = False - - def clear_encoder_caches(self): - """Clear encoder caches for a new video sequence.""" - self._inactive_cache = None - self._reactive_cache = None - self._caches_initialized = False - @property def expected_components(self) -> list[ComponentSpec]: return [ @@ -112,20 +92,16 @@ def inputs(self) -> list[InputParam]: description="VACE conditioning input frames [B, C, F, H, W]: source RGB video frames OR conditioning maps (depth, flow, pose, scribble, etc.). 12 frames per chunk, can be combined with vace_ref_images", ), InputParam( - "vace_input_masks", - default=None, - type_hint=torch.Tensor | None, - description="Spatial control masks [B, 1, F, H, W]: defines WHERE to apply conditioning (white=generate, black=preserve). Defaults to ones (all white) when None. Works with any vace_input_frames type.", - ), - InputParam( - "first_frame_image", + "video", default=None, - description="Path to first frame reference image for extension mode. When provided alone, enables 'firstframe' mode (ref at start, generate after). When provided with last_frame_image, enables 'firstlastframe' mode (refs at both ends).", + type_hint=list[torch.Tensor] | torch.Tensor | None, + description="Optional input video frames [B, C, F, H, W] used for latent-init V2V. When provided alongside vace_input_frames (hybrid mode), VACE encoding must be stateless to avoid clobbering the VAE streaming cache.", ), InputParam( - "last_frame_image", + "vace_input_masks", default=None, - description="Path to last frame reference image for extension mode. When provided alone, enables 'lastframe' mode (generate before, ref at end). When provided with first_frame_image, enables 'firstlastframe' mode (refs at both ends).", + type_hint=torch.Tensor | None, + description="Spatial control masks [B, 1, F, H, W]: defines WHERE to apply conditioning (white=generate, black=preserve). Defaults to ones (all white) when None. 
Works with any vace_input_frames type.", ), InputParam( "height", @@ -166,40 +142,22 @@ def __call__(self, components, state: PipelineState) -> tuple[Any, PipelineState vace_ref_images = block_state.vace_ref_images vace_input_frames = block_state.vace_input_frames - first_frame_image: str | None = block_state.first_frame_image - last_frame_image: str | None = block_state.last_frame_image current_start = block_state.current_start_frame - # If no inputs provided, skip VACE conditioning - has_ref_images = vace_ref_images is not None and len(vace_ref_images) > 0 - has_input_frames = vace_input_frames is not None - has_first_frame = first_frame_image is not None - has_last_frame = last_frame_image is not None - has_extension = has_first_frame or has_last_frame - - if not has_ref_images and not has_input_frames and not has_extension: + # If neither input is provided, skip VACE conditioning + if ( + vace_ref_images is None or len(vace_ref_images) == 0 + ) and vace_input_frames is None: block_state.vace_context = None block_state.vace_ref_images = None self.set_block_state(state, block_state) return components, state # Determine encoding path based on what's provided (implicit mode detection) - if has_extension: - # Extension mode: Generate frames before/after reference frame(s) - # Mode is inferred from which frame images are provided - if has_first_frame and has_last_frame: - extension_mode = "firstlastframe" - elif has_first_frame: - extension_mode = "firstframe" - else: - extension_mode = "lastframe" + has_ref_images = vace_ref_images is not None and len(vace_ref_images) > 0 + has_input_frames = vace_input_frames is not None - block_state.vace_context, block_state.vace_ref_images = ( - self._encode_extension_mode( - components, block_state, current_start, extension_mode - ) - ) - elif has_input_frames: + if has_input_frames: # Standard VACE path: conditioning input (depth, flow, pose, etc.) # with optional reference images block_state.vace_context, block_state.vace_ref_images = ( @@ -286,204 +244,6 @@ def _encode_reference_only(self, components, block_state, current_start): # Return original paths, not tensors, so they can be reused in subsequent chunks return vace_context, ref_image_paths - def _encode_extension_mode( - self, components, block_state, current_start, extension_mode: str - ): - """ - Encode VACE context with reference frames and dummy frames for temporal extension. - - Loads reference image based on extension_mode (inferred from provided images), - replicates it across a temporal group, fills remaining frames with zeros (dummy frames), - and encodes with masks indicating which frames to inpaint (1=dummy, 0=reference). 
- - Args: - extension_mode: Inferred mode ('firstframe', 'lastframe', or 'firstlastframe') - """ - first_frame_image = block_state.first_frame_image - last_frame_image = block_state.last_frame_image - - # Load reference images based on mode - if extension_mode == "firstframe": - images_to_load = [first_frame_image] - elif extension_mode == "lastframe": - images_to_load = [last_frame_image] - elif extension_mode == "firstlastframe": - # Load BOTH images for firstlastframe mode - images_to_load = [first_frame_image, last_frame_image] - - prepared_refs = load_and_prepare_reference_images( - images_to_load, - block_state.height, - block_state.width, - components.config.device, - ) - - vae = components.vae - vae_dtype = next(vae.parameters()).dtype - - num_frames = ( - components.config.num_frame_per_block - * components.config.vae_temporal_downsample_factor - ) - - # Determine ref placement - ref_at_start = extension_mode in ("firstframe", "firstlastframe") - ref_at_end = extension_mode in ("lastframe", "firstlastframe") - - frames, masks = self._build_extension_frames_and_masks( - prepared_refs=prepared_refs, - num_frames=num_frames, - temporal_group_size=components.config.vae_temporal_downsample_factor, - ref_at_start=ref_at_start, - ref_at_end=ref_at_end, - device=components.config.device, - dtype=vae_dtype, - height=block_state.height, - width=block_state.width, - ) - - frames_to_encode = [frames] - masks_to_encode = [masks] - - z0 = vace_encode_frames( - vae=vae, - frames=frames_to_encode, - ref_images=[None], - masks=masks_to_encode, - pad_to_96=False, - use_cache=False, - ) - - vae_stride = ( - components.config.vae_temporal_downsample_factor, - components.config.vae_spatial_downsample_factor, - components.config.vae_spatial_downsample_factor, - ) - m0 = vace_encode_masks( - masks=masks_to_encode, - ref_images=[None], - vae_stride=vae_stride, - ) - - vace_context = vace_latent(z0, m0) - - logger.info( - f"_encode_extension_mode: mode={extension_mode}, current_start={current_start}, " - f"num_frames={num_frames}, vace_context_shape={vace_context[0].shape}" - ) - - return vace_context, prepared_refs - - def _build_extension_frames_and_masks( - self, - prepared_refs: list[torch.Tensor], - num_frames: int, - temporal_group_size: int, - ref_at_start: bool, - ref_at_end: bool, - device: torch.device, - dtype: torch.dtype, - height: int, - width: int, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Build frames and masks for extension mode with reference frame replication. - - Args: - prepared_refs: List of prepared reference images [C, 1, H, W]. - For firstframe/lastframe: single image. - For firstlastframe: [first_image, last_image]. 
- num_frames: Total number of frames to generate - temporal_group_size: Number of frames in a temporal VAE group - ref_at_start: True to place reference at start (firstframe, firstlastframe) - ref_at_end: True to place reference at end (lastframe, firstlastframe) - device: Target device - dtype: Target dtype for frames - height: Frame height - width: Frame width - - Returns: - Tuple of (frames, masks) where: - - frames: [C, F, H, W] tensor with reference frames and dummy frames - - masks: [1, F, H, W] tensor with 0s for reference frames, 1s for dummy frames - """ - - # Helper to create ref and mask tensors - def make_ref_block( - ref_tensor: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Replicate ref across temporal group, return (frames, masks).""" - ref_replicated = ref_tensor.repeat(1, temporal_group_size, 1, 1) - ref_mask = torch.zeros( - (1, temporal_group_size, height, width), - device=device, - dtype=torch.float32, - ) - return ref_replicated, ref_mask - - if ref_at_start and ref_at_end: - # firstlastframe: [ref_first, ref_first, ..., dummy, dummy, ..., ref_last, ref_last, ...] - # Two temporal groups for refs, rest for dummy - num_dummy_frames = num_frames - 2 * temporal_group_size - if num_dummy_frames < 0: - raise ValueError( - f"Not enough frames for firstlastframe mode: need at least {2 * temporal_group_size} frames, got {num_frames}" - ) - - first_ref_frames, first_ref_mask = make_ref_block(prepared_refs[0]) - last_ref_frames, last_ref_mask = make_ref_block(prepared_refs[1]) - - dummy_frames = torch.zeros( - (3, num_dummy_frames, height, width), device=device, dtype=dtype - ) - dummy_mask = torch.ones( - (1, num_dummy_frames, height, width), - device=device, - dtype=torch.float32, - ) - - frames = torch.cat([first_ref_frames, dummy_frames, last_ref_frames], dim=1) - masks = torch.cat([first_ref_mask, dummy_mask, last_ref_mask], dim=1) - - elif ref_at_start: - # firstframe: [ref, ref, ref, zeros, zeros, ...] - num_dummy_frames = num_frames - temporal_group_size - ref_frames, ref_mask = make_ref_block(prepared_refs[0]) - - dummy_frames = torch.zeros( - (3, num_dummy_frames, height, width), device=device, dtype=dtype - ) - dummy_mask = torch.ones( - (1, num_dummy_frames, height, width), - device=device, - dtype=torch.float32, - ) - - frames = torch.cat([ref_frames, dummy_frames], dim=1) - masks = torch.cat([ref_mask, dummy_mask], dim=1) - - elif ref_at_end: - # lastframe: [zeros, zeros, ..., ref, ref, ref] - num_dummy_frames = num_frames - temporal_group_size - ref_frames, ref_mask = make_ref_block(prepared_refs[0]) - - dummy_frames = torch.zeros( - (3, num_dummy_frames, height, width), device=device, dtype=dtype - ) - dummy_mask = torch.ones( - (1, num_dummy_frames, height, width), - device=device, - dtype=torch.float32, - ) - - frames = torch.cat([dummy_frames, ref_frames], dim=1) - masks = torch.cat([dummy_mask, ref_mask], dim=1) - - else: - raise ValueError("At least one of ref_at_start or ref_at_end must be True") - - return frames, masks - def _encode_with_conditioning(self, components, block_state, current_start): """ Encode VACE input using the standard VACE path, with optional reference images. 
@@ -521,8 +281,8 @@ def _encode_with_conditioning(self, components, block_state, current_start): # Validate resolution if height != block_state.height or width != block_state.width: raise ValueError( - f"VaceEncodingBlock._encode_with_conditioning: Input resolution {width}x{height} " - f"does not match target resolution {block_state.width}x{block_state.height}" + f"VaceEncodingBlock._encode_with_conditioning: Input resolution {height}x{width} " + f"does not match target resolution {block_state.height}x{block_state.width}" ) # Check if we have reference images too (for combined guidance) @@ -537,6 +297,10 @@ def _encode_with_conditioning(self, components, block_state, current_start): # Import vace_utils for standard encoding path from ..utils.encoding import vace_encode_frames, vace_encode_masks, vace_latent + # If we're also encoding `video` for latent-init V2V, avoid clobbering the VAE's + # streaming encoder cache by doing VACE encoding statelessly. + use_cache = block_state.video is None + # Ensure 3-channel input for VAE (conditioning maps should already be 3-channel RGB) if channels == 1: input_frames_data = input_frames_data.repeat(1, 3, 1, 1, 1) @@ -548,21 +312,26 @@ def _encode_with_conditioning(self, components, block_state, current_start): vae_dtype = next(vae.parameters()).dtype input_frames_data = input_frames_data.to(dtype=vae_dtype) - # Convert to list of [C, F, H, W] for vace_encode_frames - input_frames_list = [input_frames_data[b] for b in range(batch_size)] - # Get vace_input_masks from block_state or default to ones (all white) input_masks_data = block_state.vace_input_masks + full_mask = input_masks_data is None + use_full_mask_fastpath = full_mask and os.getenv("SCOPE_VACE_FULL_MASK_FASTPATH", "0") == "1" if input_masks_data is None: - # Default to ones (all white) - apply conditioning everywhere - input_masks_list = [ - torch.ones( - (1, num_frames, height, width), - dtype=vae_dtype, - device=input_frames_data.device, - ) - for _ in range(batch_size) - ] + if use_full_mask_fastpath: + # Avoid materializing full-resolution ones masks when we know the mask is + # implicitly "all white" (conditioning everywhere). The downstream encoding + # utilities can synthesize the latent-resolution mask directly. 
+ input_masks_list = None + else: + # Default to ones (all white) - apply conditioning everywhere + input_masks_list = [ + torch.ones( + (1, num_frames, height, width), + dtype=vae_dtype, + device=input_frames_data.device, + ) + for _ in range(batch_size) + ] else: # Validate vace_input_masks shape if input_masks_data.dim() != 5: @@ -606,31 +375,31 @@ def _encode_with_conditioning(self, components, block_state, current_start): # Wrap in list for batch dimension ref_images = [prepared_refs] - # Lazily initialize encoder caches on first use (need VAE to create them) - # These caches persist across chunks for temporal continuity - # WanVAE.create_encoder_cache() returns None (no MemBlock issue) - # TAEWrapper.create_encoder_cache() returns TAEEncoderCache - if not self._caches_initialized: - self._inactive_cache = vae.create_encoder_cache() - self._reactive_cache = vae.create_encoder_cache() - self._caches_initialized = True - # Standard VACE encoding path (matching wan_vace.py lines 339-341) # z0 = vace_encode_frames(vace_input_frames, vace_ref_images, masks=vace_input_masks) # When masks are provided, set pad_to_96=False because mask encoding (64 channels) will be added later - # Pass explicit caches to prevent TAE MemBlock memory pollution between inactive/reactive streams z0 = vace_encode_frames( vae, - input_frames_list, + input_frames_data, ref_images, masks=input_masks_list, pad_to_96=False, - inactive_cache=self._inactive_cache, - reactive_cache=self._reactive_cache, + use_cache=use_cache, + full_mask=full_mask, ) # m0 = vace_encode_masks(input_masks, ref_images) - m0 = vace_encode_masks(input_masks_list, ref_images) + m0 = vace_encode_masks( + input_masks_list, + ref_images, + full_mask=full_mask, + batch_size=batch_size, + num_frames=num_frames, + height=height, + width=width, + device=input_frames_data.device, + dtype=vae_dtype, + ) # z = vace_latent(z0, m0) z = vace_latent(z0, m0) diff --git a/src/scope/core/pipelines/wan2_1/vace/mixin.py b/src/scope/core/pipelines/wan2_1/vace/mixin.py index a99a2e96c..604113e48 100644 --- a/src/scope/core/pipelines/wan2_1/vace/mixin.py +++ b/src/scope/core/pipelines/wan2_1/vace/mixin.py @@ -79,11 +79,14 @@ def _init_vace( vace_path = _get_config_value(config, "vace_path") vace_in_dim = _get_config_value(config, "vace_in_dim", 96) + vace_layers = _get_config_value(config, "vace_layers", None) # Get vace_in_dim from base_model_kwargs if present base_model_kwargs = _get_config_value(config, "base_model_kwargs") if base_model_kwargs and "vace_in_dim" in base_model_kwargs: vace_in_dim = base_model_kwargs["vace_in_dim"] + if base_model_kwargs and "vace_layers" in base_model_kwargs: + vace_layers = base_model_kwargs["vace_layers"] self.vace_path = vace_path self.vace_in_dim = vace_in_dim @@ -99,7 +102,9 @@ def _init_vace( # Wrap model with VACE start = time.time() - vace_wrapped_model = CausalVaceWanModel(model, vace_in_dim=vace_in_dim) + vace_wrapped_model = CausalVaceWanModel( + model, vace_in_dim=vace_in_dim, vace_layers=vace_layers + ) logger.info( f"_init_vace: Wrapped model with VACE in {time.time() - start:.3f}s" ) diff --git a/src/scope/core/pipelines/wan2_1/vace/models/attention_blocks.py b/src/scope/core/pipelines/wan2_1/vace/models/attention_blocks.py index 0976f200b..d17973086 100644 --- a/src/scope/core/pipelines/wan2_1/vace/models/attention_blocks.py +++ b/src/scope/core/pipelines/wan2_1/vace/models/attention_blocks.py @@ -21,12 +21,17 @@ def create_vace_attention_block_class(base_attention_block_class): A VaceWanAttentionBlock class that 
inherits from the given base """ + base_module = getattr(base_attention_block_class, "__module__", "") + needs_scratch_kv_cache = "krea_realtime_video" in base_module + class VaceWanAttentionBlock(base_attention_block_class): """VACE attention block with zero-initialized projection layers for hint injection.""" def __init__(self, *args, block_id=0, **kwargs): super().__init__(*args, **kwargs) self.block_id = block_id + self._vace_kv_cache = None + self._vace_kv_cache_key = None # Initialize projection layers for hint accumulation # Duck typing: assume self.dim exists from base class @@ -39,6 +44,50 @@ def __init__(self, *args, block_id=0, **kwargs): nn.init.zeros_(self.after_proj.weight) nn.init.zeros_(self.after_proj.bias) + def _get_or_create_vace_kv_cache(self, x: torch.Tensor) -> dict: + """Return a scratch KV cache for running the block with kv_cache enabled. + + Krea's self-attn implementation expects kv_cache to be present even in + non-streaming contexts. VACE blocks use kv_cache only as a scratch buffer. + """ + if not hasattr(self, "self_attn"): + raise RuntimeError( + "VaceWanAttentionBlock: expected base block to define `self_attn`" + ) + + batch = int(x.shape[0]) + seq_len = int(x.shape[1]) + num_heads = int(getattr(self.self_attn, "num_heads", getattr(self, "num_heads"))) + head_dim = int(getattr(self.self_attn, "head_dim", self.dim // num_heads)) + + key = (batch, seq_len, num_heads, head_dim, str(x.dtype), str(x.device)) + if self._vace_kv_cache is None or self._vace_kv_cache_key != key: + self._vace_kv_cache_key = key + self._vace_kv_cache = { + "k": torch.empty( + (batch, seq_len, num_heads, head_dim), + dtype=x.dtype, + device=x.device, + ).contiguous(), + "v": torch.empty( + (batch, seq_len, num_heads, head_dim), + dtype=x.dtype, + device=x.device, + ).contiguous(), + # Indices are used in Python control flow / slicing; keep them on CPU. + "global_end_index": torch.tensor([0], dtype=torch.long), + "local_end_index": torch.tensor([0], dtype=torch.long), + } + else: + for name in ("global_end_index", "local_end_index"): + t = self._vace_kv_cache.get(name) + if isinstance(t, torch.Tensor): + t.fill_(0) + else: + self._vace_kv_cache[name] = torch.tensor([0], dtype=torch.long) + + return self._vace_kv_cache + def forward_vace( self, c, @@ -78,7 +127,18 @@ def forward_vace( c = all_c.pop(-1) # Run standard transformer block on current context - # VACE blocks don't use caching since they process reference images once + # Most Wan2.1 attention blocks can run without kv_cache for VACE hint generation. + # Some attention block implementations expect kv_cache to be present, so we provide + # a scratch kv_cache to keep the forward path intact. 
+ kv_cache = self._get_or_create_vace_kv_cache(c) if needs_scratch_kv_cache else None + forward_kwargs = { + "kv_cache": kv_cache, + "crossattn_cache": None, + "current_start": 0, + } + # kv_cache_attention_bias is Krea-specific, don't pass to other pipelines + if needs_scratch_kv_cache: + forward_kwargs["kv_cache_attention_bias"] = 1.0 c = super().forward( c, e, @@ -88,9 +148,7 @@ def forward_vace( context, context_lens, block_mask, - kv_cache=None, - crossattn_cache=None, - current_start=0, + **forward_kwargs, ) # Handle case where block returns tuple (shouldn't happen with kv_cache=None) diff --git a/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py b/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py index 662b6b23d..9a5a74c01 100644 --- a/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py +++ b/src/scope/core/pipelines/wan2_1/vace/models/causal_vace_model.py @@ -1,7 +1,6 @@ # Modified from https://github.com/ali-vilab/VACE/blob/48eb44f1c4be87cc65a98bff985a26976841e9f3/vace/models/wan/modules/model.py # Adapted for causal/autoregressive generation with factory pattern # Pipeline-agnostic using duck typing - works with any CausalWanModel -import inspect import math import torch @@ -87,6 +86,14 @@ def __init__( # Get the original block class BEFORE replacing blocks self._original_block_class = type(causal_wan_model.blocks[0]) + import inspect + + block_forward_params = inspect.signature(self._original_block_class.forward).parameters + self._block_forward_accepts_cache_start = "cache_start" in block_forward_params + self._block_forward_accepts_current_end = "current_end" in block_forward_params + self._block_forward_accepts_kv_cache_attention_bias = ( + "kv_cache_attention_bias" in block_forward_params + ) # Create factory-generated classes for this pipeline's block type self._BaseWanAttentionBlock = create_base_attention_block_class( @@ -109,10 +116,46 @@ def __init__( kernel_size=self.patch_size, stride=self.patch_size, ) + # Cache: VACE patch-embedded context is constant across denoise steps within a chunk. + # Cache the patch embedding + flatten/pad work to avoid re-running it per timestep. + self._cached_vace_context_key: tuple | None = None + self._cached_vace_context_tokens: torch.Tensor | None = None + + def _prepare_vace_context_tokens( + self, vace_context: list[torch.Tensor], seq_len: int + ) -> torch.Tensor: + # Embed VACE context + c = [self._vace_patch_embed(u) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] - # Cache block forward signature for dynamic parameter filtering - # This allows the VACE model to work with any CausalWanModel implementation - self._block_forward_params = self._get_block_forward_params() + # Pad to seq_len + c = torch.cat( + [ + torch.cat( + [u, u.new_zeros(1, max(0, seq_len - u.size(1)), u.size(2))], dim=1 + ) + for u in c + ] + ) + return c + + def _get_cached_vace_context_tokens( + self, vace_context: list[torch.Tensor], seq_len: int + ) -> torch.Tensor: + # Invalidate when the backing tensors change or the patch-embedding weights change. 
+ bias = self.vace_patch_embedding.bias + key = ( + int(seq_len), + int(getattr(self.vace_patch_embedding.weight, "_version", 0)), + int(getattr(bias, "_version", 0)) if bias is not None else None, + tuple(id(u) for u in vace_context), + ) + if key == self._cached_vace_context_key and self._cached_vace_context_tokens is not None: + return self._cached_vace_context_tokens + c = self._prepare_vace_context_tokens(vace_context, seq_len) + self._cached_vace_context_key = key + self._cached_vace_context_tokens = c + return c def _get_block_init_kwargs(self): """Get initialization kwargs for creating new blocks. @@ -135,6 +178,8 @@ def _get_block_init_kwargs(self): } # Add pipeline-specific kwargs based on what the original block class expects + import inspect + sig = inspect.signature(self._original_block_class.__init__) params = sig.parameters @@ -147,71 +192,15 @@ def _get_block_init_kwargs(self): return kwargs - def _get_block_forward_params(self): - """Get the set of parameter names accepted by the block's forward method. - - Inspects the original block class's forward signature to determine which - parameters should be passed through to blocks. This allows the VACE model - to work with any CausalWanModel implementation without hardcoding parameter names. - - Returns: - set: Parameter names accepted by block.forward(), or None if the block - accepts **kwargs (VAR_KEYWORD) and can handle any parameters. - """ - sig = inspect.signature(self._original_block_class.forward) - - # If block accepts **kwargs, return None to indicate all params are accepted - has_var_keyword = any( - p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() - ) - if has_var_keyword: - return None - - return set(sig.parameters.keys()) - - def _filter_block_kwargs(self, block_kwargs, block_index): - """Filter and prepare kwargs for a specific block. - - Handles two types of parameters: - 1. Per-block indexed: Lists with length matching num_blocks (e.g., kv_bank) - These get indexed with block_index. - 2. Shared: Scalar/other values passed to all blocks as-is - - Only includes parameters that the block's forward method accepts. - - Args: - block_kwargs: Dict of additional kwargs from _forward_inference - block_index: Index of the current block - - Returns: - Dict of kwargs filtered and prepared for this specific block - """ - if not block_kwargs: - return {} - - filtered = {} - for key, value in block_kwargs.items(): - # Skip if block doesn't accept this parameter - if ( - self._block_forward_params is not None - and key not in self._block_forward_params - ): - continue - - # Check if this is a per-block indexed parameter (list matching block count) - if isinstance(value, list | tuple) and len(value) == self.num_layers: - filtered[key] = value[block_index] - else: - filtered[key] = value - - return filtered - def _replace_blocks_with_hint_injection_support(self): """Replace blocks with BaseWanAttentionBlock to support hint injection. Creates new block instances of the factory-generated class and copies weights from the original blocks. Uses proper inheritance (not composition), so state_dict paths are preserved. + + Memory-optimized: replaces blocks one at a time to avoid doubling memory + usage when wrapping large models (e.g. 14B). 
""" original_blocks = self.causal_wan_model.blocks @@ -222,35 +211,46 @@ def _replace_blocks_with_hint_injection_support(self): # Get initialization kwargs block_kwargs = self._get_block_init_kwargs() - # Create new blocks with hint injection support + # Replace blocks one-at-a-time to minimize peak memory usage. new_blocks = nn.ModuleList() for i in range(self.num_layers): block_id = self.vace_layers_mapping[i] if i in self.vace_layers else None - new_block = self._BaseWanAttentionBlock( - **block_kwargs, - block_id=block_id, - ) - new_blocks.append(new_block) + orig_block = original_blocks[i] - # Set to eval mode and move to correct device/dtype - new_blocks.eval() - new_blocks.to(device=orig_device, dtype=orig_dtype) + with torch.device("cpu"): + new_block = self._BaseWanAttentionBlock( + **block_kwargs, + block_id=block_id, + ) - # Copy weights from original blocks - for _i, (orig_block, new_block) in enumerate( - zip(original_blocks, new_blocks, strict=False) - ): orig_state = orig_block.state_dict() new_state = new_block.state_dict() saved_block_id = new_block.block_id for key in orig_state.keys(): if key in new_state: - new_state[key] = orig_state[key].clone() + new_state[key] = orig_state[key].detach().to("cpu") new_block.load_state_dict(new_state, strict=False, assign=True) new_block.block_id = saved_block_id + # Drop the original block reference early so its parameters can be freed. + # This avoids a full-model 2x peak during wrapping. + original_blocks[i] = nn.Identity() + del orig_block + del orig_state + del new_state + + new_block = new_block.to(device=orig_device, dtype=orig_dtype) + new_block.eval() + new_blocks.append(new_block) + + if torch.cuda.is_available() and i % 10 == 0: + torch.cuda.empty_cache() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # Replace blocks in wrapped model self.causal_wan_model.blocks = new_blocks @@ -258,25 +258,28 @@ def _replace_blocks_with_hint_injection_support(self): self.blocks = new_blocks def _create_vace_blocks(self): - """Create VACE blocks for parallel processing of reference images.""" - # Get device and dtype from existing blocks + """Create VACE blocks for parallel processing of reference images. + + Create on CPU by default; the owning pipeline can move these to the + target (device, dtype) before loading VACE weights. + """ + # Get dtype from existing blocks orig_dtype = next(self.blocks[0].parameters()).dtype - orig_device = next(self.blocks[0].parameters()).device # Get initialization kwargs block_kwargs = self._get_block_init_kwargs() - # Create VACE blocks + # Create VACE blocks on CPU to minimize peak memory usage during init. 
vace_blocks = nn.ModuleList() - for block_id in range(len(self.vace_layers)): - vace_block = self._VaceWanAttentionBlock( - **block_kwargs, - block_id=block_id, - ) - vace_blocks.append(vace_block) + with torch.device("cpu"): + for block_id in range(len(self.vace_layers)): + vace_block = self._VaceWanAttentionBlock( + **block_kwargs, + block_id=block_id, + ) + vace_blocks.append(vace_block) - # Move to correct device/dtype - vace_blocks.to(device=orig_device, dtype=orig_dtype) + vace_blocks.to(dtype=orig_dtype) self.vace_blocks = vace_blocks @@ -295,25 +298,7 @@ def forward_vace( crossattn_cache, ): """Process VACE context to generate hints.""" - # Get target dtype from vace_patch_embedding parameters - target_dtype = next(self.vace_patch_embedding.parameters()).dtype - - # Convert all VACE context to model dtype first - vace_context_converted = [u.to(dtype=target_dtype) for u in vace_context] - - # Embed VACE context - c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context_converted] - c = [u.flatten(2).transpose(1, 2) for u in c] - - # Pad to seq_len - c = torch.cat( - [ - torch.cat( - [u, u.new_zeros(1, max(0, seq_len - u.size(1)), u.size(2))], dim=1 - ) - for u in c - ] - ) + c = self._get_cached_vace_context_tokens(vace_context, seq_len) # Process through VACE blocks for _block_idx, block in enumerate(self.vace_blocks): @@ -334,6 +319,49 @@ def forward_vace( hints = torch.unbind(c)[:-1] return hints + def _patch_embed(self, u: torch.Tensor) -> torch.Tensor: + """Patch-embed a single latent sample, preferring pipeline fastpaths. + + Some pipelines (e.g. krea_realtime_video) provide a Conv3d(t=1) → Conv2d + fastpath to avoid slow Conv3d implementations on some backends. When present, prefer it. + """ + patch_embed = getattr(self.causal_wan_model, "_patch_embed", None) + if callable(patch_embed): + return patch_embed(u) + return self.causal_wan_model.patch_embedding(u.unsqueeze(0)) + + def _vace_patch_embed(self, u: torch.Tensor) -> torch.Tensor: + """Patch-embed a single VACE context sample. + + VACE uses a Conv3d patch embedding with the same (t,h,w) patch size as the + base model. When t==1, apply an equivalent Conv2d per frame to avoid slow + Conv3d paths on some backends. 
+ """ + u = u.unsqueeze(0) # [1, C, F, H, W] + try: + t_patch, h_patch, w_patch = self.patch_size + except Exception: + return self.vace_patch_embedding(u) + + if int(t_patch) != 1: + return self.vace_patch_embedding(u) + + b, c, f, h, w = u.shape + u2 = u.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w) # [B*F, C, H, W] + + out2 = torch.nn.functional.conv2d( + u2, + self.vace_patch_embedding.weight.squeeze(2), + bias=self.vace_patch_embedding.bias, + stride=(int(h_patch), int(w_patch)), + padding=0, + ) + + out = out2.reshape(b, f, out2.shape[1], out2.shape[2], out2.shape[3]).permute( + 0, 2, 1, 3, 4 + ) + return out + def _forward_inference( self, x, @@ -347,6 +375,9 @@ def _forward_inference( kv_cache=None, crossattn_cache=None, current_start=0, + current_end=0, + cache_start=0, + kv_cache_attention_bias=1.0, **block_kwargs, ): """Forward pass with optional VACE conditioning.""" @@ -361,7 +392,7 @@ def _forward_inference( x = [torch.cat([u, v], dim=0) for u, v in zip(x, y, strict=False)] # Embeddings - x = [self.causal_wan_model.patch_embedding(u.unsqueeze(0)) for u in x] + x = [self._patch_embed(u) for u in x] grid_sizes = torch.stack( [torch.tensor(u.shape[2:], dtype=torch.long) for u in x] ) @@ -421,8 +452,8 @@ def _forward_inference( crossattn_cache, ) - # Base arguments for transformer blocks (shared across all blocks) - base_kwargs = { + # Arguments for transformer blocks + kwargs = { "e": e0, "seq_lens": seq_lens, "grid_sizes": grid_sizes, @@ -443,23 +474,24 @@ def custom_forward(*inputs, **kwargs): # Process through blocks cache_update_infos = [] for block_index, block in enumerate(self.blocks): - # Build per-block kwargs: - # - kv_cache/crossattn_cache are always per-block indexed - # - Additional block_kwargs are dynamically filtered based on block's signature - # and automatically indexed if they're per-block lists - filtered_block_kwargs = self._filter_block_kwargs(block_kwargs, block_index) - per_block_kwargs = { + block_call_kwargs = { "kv_cache": kv_cache[block_index], "current_start": current_start, - **filtered_block_kwargs, + **block_kwargs, } + if self._block_forward_accepts_current_end: + block_call_kwargs["current_end"] = current_end + if self._block_forward_accepts_cache_start: + block_call_kwargs["cache_start"] = cache_start + if self._block_forward_accepts_kv_cache_attention_bias: + block_call_kwargs["kv_cache_attention_bias"] = kv_cache_attention_bias if torch.is_grad_enabled() and self.causal_wan_model.gradient_checkpointing: - kwargs = {**base_kwargs, **per_block_kwargs} result = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, **kwargs, + **block_call_kwargs, use_reentrant=False, ) if kv_cache is not None and isinstance(result, tuple): @@ -468,9 +500,8 @@ def custom_forward(*inputs, **kwargs): else: x = result else: - per_block_kwargs["crossattn_cache"] = crossattn_cache[block_index] - kwargs = {**base_kwargs, **per_block_kwargs} - result = block(x, **kwargs) + block_call_kwargs["crossattn_cache"] = crossattn_cache[block_index] + result = block(x, **kwargs, **block_call_kwargs) if kv_cache is not None and isinstance(result, tuple): x, block_cache_update_info = result cache_update_infos.append((block_index, block_cache_update_info)) @@ -478,9 +509,7 @@ def custom_forward(*inputs, **kwargs): x = result if kv_cache is not None and cache_update_infos: - self.causal_wan_model._apply_cache_updates( - kv_cache, cache_update_infos, **block_kwargs - ) + self.causal_wan_model._apply_cache_updates(kv_cache, cache_update_infos) x = 
self.causal_wan_model.head( x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2) diff --git a/src/scope/core/pipelines/wan2_1/vace/utils/encoding.py b/src/scope/core/pipelines/wan2_1/vace/utils/encoding.py index 0823e7af3..4007c99bb 100644 --- a/src/scope/core/pipelines/wan2_1/vace/utils/encoding.py +++ b/src/scope/core/pipelines/wan2_1/vace/utils/encoding.py @@ -11,12 +11,45 @@ - Standard path: vace_encode_frames -> vace_encode_masks -> vace_latent """ +import os + import torch import torch.nn.functional as F import torchvision.transforms.functional as TF from PIL import Image +_ZERO_INACTIVE_LATENT_CACHE: dict[tuple, torch.Tensor] = {} +_FULL_MASK_ENCODED_MASK_CACHE: dict[tuple, torch.Tensor] = {} + + +def _get_zero_inactive_latent( + vae, + *, + channels: int, + frames: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + key = (id(vae), channels, frames, height, width, str(device), str(dtype)) + cached = _ZERO_INACTIVE_LATENT_CACHE.get(key) + if cached is not None: + return cached + + zeros_pixel = torch.zeros( + (1, channels, frames, height, width), + device=device, + dtype=dtype, + ) + zeros_latent_out = vae.encode_to_latent(zeros_pixel, use_cache=False) + zeros_latent = zeros_latent_out[0].permute(1, 0, 2, 3).contiguous() + + _ZERO_INACTIVE_LATENT_CACHE[key] = zeros_latent + return zeros_latent + + def vace_encode_frames( vae, frames, @@ -24,14 +57,14 @@ def vace_encode_frames( masks=None, pad_to_96=True, use_cache=True, - inactive_cache=None, - reactive_cache=None, + *, + full_mask: bool = False, ): """ Encode frames and reference images via VAE for VACE conditioning. Args: - vae: VAE model wrapper (TAEWrapper or WanVAEWrapper) + vae: VAE model wrapper frames: List of video frames [B, C, F, H, W] or single frame [C, F, H, W] ref_images: List of reference images, one list per batch element Each element is a list of reference images [C, 1, H, W] @@ -39,28 +72,26 @@ def vace_encode_frames( pad_to_96: Whether to pad to 96 channels (default True). Set False when masks will be added later. use_cache: Whether to use streaming encode cache for frames (default True). Set False for one-off encoding (e.g., reference images only mode). - When masks are provided, caching is handled automatically based on - mask content: conditioning mode (all-1s masks) uses cache for both - streams, while extension/inpainting mode (mixed masks) skips cache - for reactive to weaken temporal blending. - inactive_cache: Explicit encoder cache for inactive stream (TAE only). - Create via vae.create_encoder_cache(). Reuse across chunks - for temporal continuity. If None, uses VAE's default cache. - reactive_cache: Explicit encoder cache for reactive stream (TAE only). - Must be separate from inactive_cache to prevent memory pollution. + full_mask: If True, masks represent an all-ones full-frame mask (white=generate). + Enables optional fastpaths for VACE when `SCOPE_VACE_FULL_MASK_FASTPATH=1`. Returns: List of concatenated latents [ref_latents + frame_latents] - - Note: - For TAE with masked encoding (depth/flow/pose/inpainting), you MUST provide - separate inactive_cache and reactive_cache to prevent MemBlock memory pollution. - WanVAE ignores these caches as its CausalConv3d doesn't have this issue. 
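+
+    Example (illustrative sketch; assumes `vae` exposes `encode_to_latent` and
+    `frame` is a single [C, F, H, W] tensor):
+        os.environ["SCOPE_VACE_FULL_MASK_FASTPATH"] = "1"
+        latents = vace_encode_frames(vae, [frame], ref_images=None, full_mask=True)
+        # Each returned latent stacks the cached all-zero "inactive" latent with the
+        # freshly encoded "reactive" latent along the channel dim.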
""" + use_full_mask_fastpath = full_mask and os.getenv("SCOPE_VACE_FULL_MASK_FASTPATH", "0") == "1" + uses_mask_path = masks is not None or use_full_mask_fastpath + + frames_stacked: torch.Tensor | None = None + if isinstance(frames, torch.Tensor): + frames_stacked = frames + batch_size = int(frames_stacked.shape[0]) + else: + batch_size = len(frames) + if ref_images is None: - ref_images = [None] * len(frames) + ref_images = [None] * batch_size else: - assert len(frames) == len(ref_images) + assert batch_size == len(ref_images) # Get VAE dtype for consistent encoding vae_dtype = next(vae.parameters()).dtype @@ -68,47 +99,78 @@ def vace_encode_frames( # Encode frames (with optional masking) # Note: WanVAEWrapper expects [B, C, F, H, W] and returns [B, F, C, H, W] if masks is None: - # Single encode path - no masks, just encode frames directly - # Stack list of [C, F, H, W] -> [B, C, F, H, W] - frames_stacked = torch.stack(frames, dim=0) - frames_stacked = frames_stacked.to(dtype=vae_dtype) - # Use provided cache setting (use_cache=False for reference-only mode with dummy frames) - latents_out = vae.encode_to_latent(frames_stacked, use_cache=use_cache) - # Convert [B, F, C, H, W] -> list of [C, F, H, W] (transpose to channel-first) - latents = [lat.permute(1, 0, 2, 3) for lat in latents_out] + if use_full_mask_fastpath: + # Full-mask (all ones) means the inactive branch is identically the encoding of + # all-zero pixels. Avoid materializing full-resolution masks and avoid re-encoding + # the inactive branch every chunk. + if frames_stacked is None: + frames_stacked = torch.stack(frames, dim=0) + frames_stacked = frames_stacked.to(dtype=vae_dtype) + + reactive_out = vae.encode_to_latent(frames_stacked, use_cache=use_cache) + reactive_transposed = [lat.permute(1, 0, 2, 3) for lat in reactive_out] + + zero_latent = _get_zero_inactive_latent( + vae, + channels=frames_stacked.shape[1], + frames=frames_stacked.shape[2], + height=frames_stacked.shape[3], + width=frames_stacked.shape[4], + device=frames_stacked.device, + dtype=vae_dtype, + ) + latents = [torch.cat((zero_latent, c), dim=0) for c in reactive_transposed] + else: + # Stack list of [C, F, H, W] -> [B, C, F, H, W] + if frames_stacked is None: + frames_stacked = torch.stack(frames, dim=0) + frames_stacked = frames_stacked.to(dtype=vae_dtype) + # Use provided cache setting (use_cache=False for reference-only mode with dummy frames) + latents_out = vae.encode_to_latent(frames_stacked, use_cache=use_cache) + # Convert [B, F, C, H, W] -> list of [C, F, H, W] (transpose to channel-first) + latents = [lat.permute(1, 0, 2, 3) for lat in latents_out] else: - # Dual encode path for masked video generation - # Each stream needs its own cache to prevent TAE's MemBlock memory pollution - masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks] - inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks, strict=False)] - reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks, strict=False)] - - inactive_stacked = torch.stack(inactive, dim=0).to(dtype=vae_dtype) - reactive_stacked = torch.stack(reactive, dim=0).to(dtype=vae_dtype) - - # Auto-detect mode based on mask content and handle caching appropriately: - # - Conditioning mode (mask all 1s): inactive=zeros, reactive=content → both use cache - # - Extension/inpainting mode (mixed mask): reactive skips cache to weaken temporal blending - is_conditioning_mode = all((m > 0.5).all() for m in masks) - - # Encode with separate caches for temporal continuity without cross-contamination - 
inactive_out = vae.encode_to_latent( - inactive_stacked, use_cache=True, encoder_cache=inactive_cache - ) - reactive_out = vae.encode_to_latent( - reactive_stacked, - use_cache=is_conditioning_mode, - encoder_cache=reactive_cache, - ) - - # Transpose [B, F, C, H, W] -> [B, C, F, H, W] and concatenate along channel dim - inactive_transposed = [lat.permute(1, 0, 2, 3) for lat in inactive_out] - reactive_transposed = [lat.permute(1, 0, 2, 3) for lat in reactive_out] - - latents = [ - torch.cat((u, c), dim=0) - for u, c in zip(inactive_transposed, reactive_transposed, strict=False) - ] + if use_full_mask_fastpath: + # For the common "mask omitted" case we currently default to an all-ones + # mask, which makes the inactive branch identically zero. We can avoid + # re-encoding that inactive branch every chunk by caching a single + # zero-latent (per shape/device/dtype) and only encoding the reactive + # (full) frames. + if frames_stacked is None: + frames_stacked = torch.stack(frames, dim=0) + frames_stacked = frames_stacked.to(dtype=vae_dtype) + reactive_out = vae.encode_to_latent(frames_stacked, use_cache=use_cache) + reactive_transposed = [lat.permute(1, 0, 2, 3) for lat in reactive_out] + + zero_latent = _get_zero_inactive_latent( + vae, + channels=frames_stacked.shape[1], + frames=frames_stacked.shape[2], + height=frames_stacked.shape[3], + width=frames_stacked.shape[4], + device=frames_stacked.device, + dtype=vae_dtype, + ) + latents = [torch.cat((zero_latent, c), dim=0) for c in reactive_transposed] + else: + masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks] + inactive = [ + i * (1 - m) + 0 * m for i, m in zip(frames, masks, strict=False) + ] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks, strict=False)] + inactive_stacked = torch.stack(inactive, dim=0).to(dtype=vae_dtype) + reactive_stacked = torch.stack(reactive, dim=0).to(dtype=vae_dtype) + # Default to cache=True for streaming consistency, but allow stateless + # encoding (use_cache=False) for hybrid modes that also encode `video`. 
+ inactive_out = vae.encode_to_latent(inactive_stacked, use_cache=use_cache) + reactive_out = vae.encode_to_latent(reactive_stacked, use_cache=use_cache) + # Transpose [B, F, C, H, W] -> [B, C, F, H, W] and concatenate along channel dim + inactive_transposed = [lat.permute(1, 0, 2, 3) for lat in inactive_out] + reactive_transposed = [lat.permute(1, 0, 2, 3) for lat in reactive_out] + latents = [ + torch.cat((u, c), dim=0) + for u, c in zip(inactive_transposed, reactive_transposed, strict=False) + ] # Concatenate reference images if provided cat_latents = [] @@ -124,7 +186,7 @@ def vace_encode_frames( # Get first batch element and transpose: [num_refs, C, H, W] -> [C, num_refs, H, W] ref_latent_batch = ref_latent_out[0].permute(1, 0, 2, 3) - if masks is not None: + if uses_mask_path: # Pad reference latents with zeros for mask channel zeros = torch.zeros_like(ref_latent_batch) ref_latent_batch = torch.cat((ref_latent_batch, zeros), dim=0) @@ -152,7 +214,64 @@ def vace_encode_frames( return cat_latents -def vace_encode_masks(masks, ref_images=None, vae_stride=(4, 8, 8)): +def _get_full_mask_encoded_mask( + *, + num_frames: int, + height: int, + width: int, + ref_length: int, + device: torch.device, + dtype: torch.dtype, + vae_stride: tuple[int, int, int], +) -> torch.Tensor: + key = ( + num_frames, + height, + width, + ref_length, + str(device), + str(dtype), + tuple(int(v) for v in vae_stride), + ) + cached = _FULL_MASK_ENCODED_MASK_CACHE.get(key) + if cached is not None: + return cached + + new_depth = int((num_frames + (vae_stride[0] - 1)) // vae_stride[0]) + latent_height = 2 * (int(height) // (vae_stride[1] * 2)) + latent_width = 2 * (int(width) // (vae_stride[2] * 2)) + mask_channels = int(vae_stride[1]) * int(vae_stride[2]) + + base = torch.ones( + (mask_channels, new_depth, latent_height, latent_width), + device=device, + dtype=dtype, + ) + if ref_length > 0: + pad = torch.zeros( + (mask_channels, int(ref_length), latent_height, latent_width), + device=device, + dtype=dtype, + ) + base = torch.cat((pad, base), dim=1) + + _FULL_MASK_ENCODED_MASK_CACHE[key] = base + return base + + +def vace_encode_masks( + masks, + ref_images=None, + vae_stride=(4, 8, 8), + *, + full_mask: bool = False, + batch_size: int | None = None, + num_frames: int | None = None, + height: int | None = None, + width: int | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +): """ Encode masks for VACE context at VAE latent resolution. 
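+
+    Example (illustrative sketch of the full-mask fastpath; requires
+    SCOPE_VACE_FULL_MASK_FASTPATH=1 in the environment):
+
+        masks_latent = vace_encode_masks(
+            None, ref_images=None, full_mask=True, batch_size=1,
+            num_frames=12, height=448, width=448,
+            device=torch.device("cuda"), dtype=torch.bfloat16,
+        )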
@@ -164,6 +283,64 @@ def vace_encode_masks(masks, ref_images=None, vae_stride=(4, 8, 8)): Returns: List of encoded masks at latent resolution """ + use_full_mask_fastpath = full_mask and os.getenv("SCOPE_VACE_FULL_MASK_FASTPATH", "0") == "1" + + if use_full_mask_fastpath: + inferred_batch = batch_size + if inferred_batch is None: + if masks is not None: + inferred_batch = len(masks) + elif ref_images is not None: + inferred_batch = len(ref_images) + if inferred_batch is None: + raise ValueError( + "vace_encode_masks: batch_size is required when masks=None and ref_images=None" + ) + + inferred_num_frames = num_frames + inferred_height = height + inferred_width = width + inferred_device = device + inferred_dtype = dtype + if masks is not None: + sample = masks[0] + if isinstance(sample, torch.Tensor) and sample.ndim == 4: + _, inferred_num_frames, inferred_height, inferred_width = sample.shape + inferred_device = sample.device + inferred_dtype = sample.dtype + + if ( + inferred_num_frames is None + or inferred_height is None + or inferred_width is None + or inferred_device is None + or inferred_dtype is None + ): + raise ValueError( + "vace_encode_masks: num_frames/height/width/device/dtype are required for full_mask fastpath" + ) + + if ref_images is None: + ref_images = [None] * int(inferred_batch) + else: + assert int(inferred_batch) == len(ref_images) + + result_masks = [] + for refs in ref_images: + ref_len = len(refs) if refs is not None else 0 + result_masks.append( + _get_full_mask_encoded_mask( + num_frames=int(inferred_num_frames), + height=int(inferred_height), + width=int(inferred_width), + ref_length=int(ref_len), + device=inferred_device, + dtype=inferred_dtype, + vae_stride=tuple(int(v) for v in vae_stride), + ) + ) + return result_masks + if ref_images is None: ref_images = [None] * len(masks) else: diff --git a/src/scope/realtime/__init__.py b/src/scope/realtime/__init__.py new file mode 100644 index 000000000..98cdf9f1c --- /dev/null +++ b/src/scope/realtime/__init__.py @@ -0,0 +1,94 @@ +"""Realtime control plane for video generation. + +This module provides the control layer for the realtime video generation system, +separating control semantics from the underlying pipeline implementation. 
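+
+Typical import (sketch):
+
+    from scope.realtime import ControlBus, GeneratorDriver, PipelineAdapter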
+ +Key components: +- ControlState: Immediate control surface for the generator +- ControlBus: Event queue with chunk-boundary semantics +- PipelineAdapter: Maps ControlState to pipeline kwargs, handles continuity +- GeneratorDriver: Tick loop that owns the pipeline and applies control events +""" + +from scope.realtime.control_state import ( + CompiledPrompt, + ControlState, + GenerationMode, +) +from scope.realtime.control_bus import ( + ApplyMode, + ControlBus, + ControlEvent, + EventType, + pause_event, + prompt_event, + world_state_event, +) +from scope.realtime.pipeline_adapter import PipelineAdapter +from scope.realtime.generator_driver import ( + DriverState, + GenerationResult, + GeneratorDriver, +) +from scope.realtime.style_manifest import StyleManifest, StyleRegistry +from scope.realtime.world_state import ( + BeatType, + CameraIntent, + CharacterState, + PropState, + WorldState, + create_simple_world, + create_character_scene, +) +from scope.realtime.prompt_compiler import ( + CompiledPrompt as StyleCompiledPrompt, + PromptCompiler, + TemplateCompiler, + LLMCompiler, + CachedCompiler, + InstructionSheet, + create_compiler, +) +from scope.realtime.prompt_playlist import PromptPlaylist + +__all__ = [ + # control_state + "CompiledPrompt", + "ControlState", + "GenerationMode", + # control_bus + "ApplyMode", + "ControlBus", + "ControlEvent", + "EventType", + "pause_event", + "prompt_event", + "world_state_event", + # pipeline_adapter + "PipelineAdapter", + # generator_driver + "DriverState", + "GenerationResult", + "GeneratorDriver", + # style_manifest + "StyleManifest", + "StyleRegistry", + # world_state + "BeatType", + "CameraIntent", + "CharacterState", + "PropState", + "WorldState", + "create_simple_world", + "create_character_scene", + # prompt_compiler + "StyleCompiledPrompt", + "PromptCompiler", + "TemplateCompiler", + "LLMCompiler", + "CachedCompiler", + "InstructionSheet", + "create_compiler", + # prompt_playlist + "PromptPlaylist", +] diff --git a/src/scope/realtime/control_bus.py b/src/scope/realtime/control_bus.py new file mode 100644 index 000000000..a51077abe --- /dev/null +++ b/src/scope/realtime/control_bus.py @@ -0,0 +1,231 @@ +"""Control bus - event queue with chunk-boundary semantics. + +All control is chunk-transactional: +- Generator state is mutated only at chunk boundaries (between pipeline calls) +- Events are applied in deterministic order and recorded with chunk index +- "Immediate" events only apply immediately when paused (safe without generating) + +Deterministic application order at each boundary: +1. Lifecycle (STOP, PAUSE, RESUME, STEP) +2. Snapshot/restore (RESTORE_SNAPSHOT, SNAPSHOT_REQUEST) +3. Style (SET_STYLE_MANIFEST) - rebind compiler +4. World (SET_WORLD_STATE) - then recompile if compiler active +5. Prompt/transition (SET_PROMPT) - direct override; may include transition +6. Runtime params (SET_DENOISE_STEPS, SET_SEED, SET_LORA_SCALES, ...) 
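+
+Example (illustrative):
+
+    bus = ControlBus()
+    bus.enqueue(EventType.SET_PROMPT, {"prompts": [{"text": "a fox at dusk", "weight": 1.0}]})
+    bus.enqueue(EventType.PAUSE)
+    events = bus.drain_pending(is_paused=False, chunk_index=7)
+    # -> [PAUSE, SET_PROMPT]: lifecycle events are applied before prompt updates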
+""" + +import time +from collections import deque +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + + +class EventType(Enum): + """Types of control events.""" + + # Prompt and style + SET_PROMPT = "set_prompt" + SET_WORLD_STATE = "set_world_state" + SET_STYLE_MANIFEST = "set_style_manifest" + SET_LORA_SCALES = "set_lora_scales" + + # Generation parameters + SET_DENOISE_STEPS = "set_denoise_steps" + SET_SEED = "set_seed" + + # Lifecycle + PAUSE = "pause" + RESUME = "resume" + STEP = "step" + STOP = "stop" + + # Branching + SNAPSHOT_REQUEST = "snapshot_request" + FORK_REQUEST = "fork_request" + ROLLOUT_REQUEST = "rollout_request" + SELECT_BRANCH = "select_branch" + RESTORE_SNAPSHOT = "restore_snapshot" + + +class ApplyMode(Enum): + """When to apply an event.""" + + NEXT_BOUNDARY = "next_boundary" # Apply at start of next chunk + IMMEDIATE_IF_PAUSED = "immediate" # Apply now if paused, else next boundary + + +# Deterministic ordering for event types at chunk boundaries +# Lower number = applied first +EVENT_TYPE_ORDER: dict[EventType, int] = { + # Lifecycle first + EventType.STOP: 0, + EventType.PAUSE: 1, + EventType.RESUME: 2, + EventType.STEP: 3, + # Snapshot/restore second + EventType.RESTORE_SNAPSHOT: 10, + EventType.SNAPSHOT_REQUEST: 11, + # Style third + EventType.SET_STYLE_MANIFEST: 20, + # World fourth + EventType.SET_WORLD_STATE: 30, + # Prompt/transition fifth + EventType.SET_PROMPT: 40, + # Runtime params last + EventType.SET_DENOISE_STEPS: 50, + EventType.SET_SEED: 51, + EventType.SET_LORA_SCALES: 52, + # Branching requests (processed after state updates) + EventType.FORK_REQUEST: 60, + EventType.ROLLOUT_REQUEST: 61, + EventType.SELECT_BRANCH: 62, +} + + +@dataclass +class ControlEvent: + """A single control event with timing and application semantics.""" + + type: EventType + payload: dict = field(default_factory=dict) + apply_mode: ApplyMode = ApplyMode.NEXT_BOUNDARY + timestamp: float = field(default_factory=time.time) + event_id: str = field(default_factory=lambda: str(time.time_ns())) + + # For debugging/replay + source: str = "api" # "api", "vlm", "timeline", "dev_console" + + # Set when event is applied (for history tracking) + applied_chunk_index: Optional[int] = None + + +@dataclass +class ControlBus: + """Timestamped event queue with chunk-boundary semantics. + + Events are queued immediately but applied at chunk boundaries, + ensuring the generator always sees consistent state. + """ + + pending: deque[ControlEvent] = field(default_factory=deque) + history: list[ControlEvent] = field(default_factory=list) + max_history: int = 1000 + + def enqueue( + self, + event_type: EventType, + payload: Optional[dict] = None, + apply_mode: ApplyMode = ApplyMode.NEXT_BOUNDARY, + source: str = "api", + ) -> ControlEvent: + """Add an event to the queue.""" + event = ControlEvent( + type=event_type, + payload=payload or {}, + apply_mode=apply_mode, + source=source, + ) + self.pending.append(event) + return event + + def drain_pending( + self, is_paused: bool = False, chunk_index: Optional[int] = None + ) -> list[ControlEvent]: + """Get all events that should be applied now, in deterministic order. + + Called at chunk boundaries (or immediately if checking for pause-mode events). 
+ + Args: + is_paused: Whether the driver is currently paused + chunk_index: Current chunk index (for history tracking) + + Returns: + Events to apply, sorted by deterministic order + """ + to_apply = [] + remaining = deque() + + for event in self.pending: + should_apply = event.apply_mode == ApplyMode.NEXT_BOUNDARY or ( + event.apply_mode == ApplyMode.IMMEDIATE_IF_PAUSED and is_paused + ) + + if should_apply: + # Record when this event was applied + event.applied_chunk_index = chunk_index + to_apply.append(event) + self._add_to_history(event) + else: + remaining.append(event) + + self.pending = remaining + + # Sort by deterministic order: type order, then timestamp, then event_id + to_apply.sort( + key=lambda e: ( + EVENT_TYPE_ORDER.get(e.type, 999), + e.timestamp, + e.event_id, + ) + ) + + return to_apply + + def _add_to_history(self, event: ControlEvent): + """Store event in history for debugging/replay.""" + self.history.append(event) + if len(self.history) > self.max_history: + self.history = self.history[-self.max_history :] + + def get_history( + self, + since_timestamp: float = 0, + event_types: Optional[list[EventType]] = None, + ) -> list[ControlEvent]: + """Query event history for debugging or replay.""" + filtered = [e for e in self.history if e.timestamp >= since_timestamp] + if event_types: + filtered = [e for e in filtered if e.type in event_types] + return filtered + + def clear_pending(self): + """Clear all pending events (e.g., on stop).""" + self.pending.clear() + + +# Convenience functions for common event patterns + + +def prompt_event( + prompts: list[dict], + transition: Optional[dict] = None, + source: str = "api", +) -> ControlEvent: + """Create a prompt update event.""" + payload = {"prompts": prompts} + if transition is not None: + payload["transition"] = transition + return ControlEvent( + type=EventType.SET_PROMPT, + payload=payload, + source=source, + ) + + +def world_state_event(updates: dict, source: str = "api") -> ControlEvent: + """Create a world state update event.""" + return ControlEvent( + type=EventType.SET_WORLD_STATE, + payload=updates, + source=source, + ) + + +def pause_event(source: str = "api") -> ControlEvent: + """Create a pause event (applies at next chunk boundary).""" + return ControlEvent( + type=EventType.PAUSE, + apply_mode=ApplyMode.NEXT_BOUNDARY, + source=source, + ) diff --git a/src/scope/realtime/control_state.py b/src/scope/realtime/control_state.py new file mode 100644 index 000000000..22ab15394 --- /dev/null +++ b/src/scope/realtime/control_state.py @@ -0,0 +1,104 @@ +"""Control state dataclasses for the realtime control plane. + +ControlState is the immediate control surface for the generator. It can be +populated by the PromptCompiler (from WorldState + StyleManifest) or directly +via the Dev Console. + +CompiledPrompt is the output of the PromptCompiler - what gets sent to the +pipeline after translation through a StyleManifest. +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + + +class GenerationMode(Enum): + """Video generation input mode.""" + + T2V = "text_to_video" + V2V = "video_to_video" + + +@dataclass +class CompiledPrompt: + """Output of PromptCompiler - ready for pipeline consumption. 
+ + Attributes: + positive: List of prompt dicts [{"text": "...", "weight": 1.0}] + negative: Negative prompt (NOTE: not consumed by Scope/KREA pipeline) + lora_scales: List of LoRA scale dicts [{"path": "...", "scale": 0.8}] + """ + + positive: list[dict] = field(default_factory=list) + negative: str = "" + lora_scales: list[dict] = field(default_factory=list) + + +@dataclass +class ControlState: + """Immediate control surface for the generator. + + This is the state that directly maps to pipeline kwargs. It can be: + - Populated by PromptCompiler from WorldState + StyleManifest + - Directly overridden via Dev Console for prompt iteration + """ + + # Prompts (output of PromptCompiler, or direct override) + # Shape: [{"text": "...", "weight": 1.0}] + prompts: list[dict] = field(default_factory=list) + + # NOTE: The current Scope/KREA realtime pipeline does not consume negative prompts. + # Keep this field for forward-compatibility with other backends. + negative_prompt: str = "" + + # LoRA configuration (runtime updates via lora_scales; edge-triggered) + # Shape: [{"path": "...", "scale": 0.8}] + lora_scales: list[dict] = field(default_factory=list) + + # Generation parameters + mode: GenerationMode = GenerationMode.T2V + num_frame_per_block: int = 3 # Must match pipeline config + denoising_step_list: list[int] = field( + default_factory=lambda: [1000, 750, 500, 250] + ) + + # Determinism + base_seed: int = 42 + branch_seed_offset: int = 0 # For deterministic branching + + # KV cache behavior (0.3 is KREA default - higher = more stable, less responsive) + kv_cache_attention_bias: float = 0.3 + + # Prompt transitions (pipeline-native; optional) + # Shape matches Scope's `transition` contract: + # {"target_prompts": [...], "num_steps": 4, "temporal_interpolation_method": "linear"} + transition: Optional[dict] = None + + # Pipeline state tracking + current_start_frame: int = 0 + + def effective_seed(self) -> int: + """Compute the effective seed including branch offset.""" + return self.base_seed + self.branch_seed_offset + + def to_pipeline_kwargs(self) -> dict: + """Convert to kwargs for pipeline call. + + NOTE: This produces the BASE kwargs. The PipelineAdapter is responsible + for: + - Adding `init_cache` (driver decides) + - Edge-triggering `lora_scales` (only when changed) + - NOT including `negative_prompt` (not consumed by Scope/KREA) + """ + kwargs = { + "prompts": self.prompts, + "num_frame_per_block": self.num_frame_per_block, + "denoising_step_list": self.denoising_step_list, + "base_seed": self.effective_seed(), + "kv_cache_attention_bias": self.kv_cache_attention_bias, + } + # Include transition if present + if self.transition is not None: + kwargs["transition"] = self.transition + return kwargs diff --git a/src/scope/realtime/gemini_client.py b/src/scope/realtime/gemini_client.py new file mode 100644 index 000000000..c4718e0df --- /dev/null +++ b/src/scope/realtime/gemini_client.py @@ -0,0 +1,508 @@ +""" +Gemini integration for LLMCompiler. 
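+
+Requires GEMINI_API_KEY in the environment; without it, init_client() logs a warning
+and returns None (callers either fall back to the original prompt or raise, depending
+on the component).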
+ +Provides: +- GeminiCompiler: Implements llm_callable signature for prompt compilation +- GeminiWorldChanger: Natural language WorldState updates +- GeminiPromptJiggler: Prompt variation generator +""" + +from __future__ import annotations + +import logging +import os +import time +import hashlib +import threading +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + from scope.realtime.world_state import WorldState + +logger = logging.getLogger(__name__) + +# Default model - Gemini 3 Flash Preview +DEFAULT_MODEL = "gemini-3-flash-preview" + +# Rate limiting +_last_call_time = 0.0 +_min_call_interval = 0.1 # 10 calls/sec max +_rate_limit_lock = threading.Lock() + + +def _rate_limit(): + """Simple rate limiter to avoid hitting API limits.""" + global _last_call_time + with _rate_limit_lock: + now = time.time() + elapsed = now - _last_call_time + if elapsed < _min_call_interval: + time.sleep(_min_call_interval - elapsed) + _last_call_time = time.time() + + +# Prompt jiggle cache (short TTL to reduce redundant LLM calls in live usage) +_JIGGLE_CACHE_TTL_SEC = 10.0 +_JIGGLE_CACHE_MAX_SIZE = 128 +_JIGGLE_SYSTEM_PROMPT_VERSION = "v1" +_jiggle_cache: dict[str, tuple[float, list[str]]] = {} +_jiggle_cache_lock = threading.Lock() + + +def _jiggle_cache_key( + prompt: str, + model: str, + mode: str, + direction: str | None, + intensity: float, + count: int, +) -> str: + dir_norm = (direction or "").strip().lower() + prompt_hash = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:12] + return ( + f"{_JIGGLE_SYSTEM_PROMPT_VERSION}:{model}:{mode}:{dir_norm}:{intensity:.2f}:{count}:{prompt_hash}" + ) + + +def _get_jiggle_cached(key: str) -> list[str] | None: + now = time.time() + with _jiggle_cache_lock: + entry = _jiggle_cache.get(key) + if entry is None: + return None + ts, variations = entry + if now - ts > _JIGGLE_CACHE_TTL_SEC: + _jiggle_cache.pop(key, None) + return None + return variations + + +def _set_jiggle_cached(key: str, variations: list[str]) -> None: + now = time.time() + with _jiggle_cache_lock: + if len(_jiggle_cache) >= _JIGGLE_CACHE_MAX_SIZE: + oldest_key = min(_jiggle_cache, key=lambda k: _jiggle_cache[k][0]) + _jiggle_cache.pop(oldest_key, None) + _jiggle_cache[key] = (now, variations) + + +def init_client(): + """ + Initialize the Gemini client. + + Reads GEMINI_API_KEY from environment. + + Returns: + google.genai.Client or None if no API key available + """ + api_key = os.environ.get("GEMINI_API_KEY") + if not api_key: + logger.warning("GEMINI_API_KEY not set - Gemini features will be unavailable") + return None + + try: + from google import genai + + return genai.Client(api_key=api_key) + except ImportError: + logger.error("google-genai package not installed") + return None + except Exception as e: + logger.error(f"Failed to initialize Gemini client: {e}") + return None + + +class GeminiCompiler: + """ + Gemini-based LLM compiler. + + Implements the llm_callable signature: (system_prompt, user_message) -> str + """ + + def __init__( + self, + model: str = DEFAULT_MODEL, + temperature: float = 0.4, + max_output_tokens: int = 256, + ): + self.model = model + self.temperature = temperature + self.max_output_tokens = max_output_tokens + self._client = None + + @property + def client(self): + """Lazy-initialize client on first use.""" + if self._client is None: + self._client = init_client() + return self._client + + def __call__(self, system_prompt: str, user_message: str) -> str: + """ + Call Gemini to generate a prompt. 
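+
+        Example (sketch; prompt strings are illustrative):
+            compiler = GeminiCompiler()
+            text = compiler(
+                system_prompt="You write concise video prompts.",
+                user_message="Scene: a fox running through snow at dusk.",
+            )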
+ + Matches LLMCompiler's llm_callable signature. + + Args: + system_prompt: System instructions with vocab, examples, etc. + user_message: The WorldState context and request + + Returns: + Generated prompt text + + Raises: + RuntimeError: If no API key or client unavailable + """ + if self.client is None: + raise RuntimeError("Gemini client not available - check GEMINI_API_KEY") + + _rate_limit() + + try: + from google.genai import types + + response = self.client.models.generate_content( + model=self.model, + contents=[user_message], + config=types.GenerateContentConfig( + system_instruction=system_prompt, + temperature=self.temperature, + max_output_tokens=self.max_output_tokens, + response_mime_type="text/plain", + ), + ) + return response.text.strip() + + except Exception as e: + logger.error(f"Gemini compilation failed: {e}") + raise + + +class GeminiWorldChanger: + """ + Natural language WorldState editor. + + Takes an instruction like "make Rooster angry" and returns updated WorldState. + """ + + SYSTEM_PROMPT = """You are a WorldState editor for an animation system. + +Given the current WorldState as JSON and a natural language instruction, +output ONLY valid JSON representing the updated WorldState. + +Rules: +- Make minimal changes - only modify what the instruction specifies +- Preserve all fields not mentioned in the instruction +- Valid emotions: happy, sad, angry, frustrated, shocked, neutral, determined, confused, surprised +- Valid beat types: setup, escalation, climax, payoff, reset, transition +- Valid camera intents: establishing, close_up, medium, wide, low_angle, high_angle, tracking, static +- Character actions should be short action verbs or phrases + +Output ONLY the JSON - no markdown, no explanation, no comments.""" + + def __init__( + self, + model: str = DEFAULT_MODEL, + temperature: float = 0.3, + ): + self.model = model + self.temperature = temperature + self._client = None + + @property + def client(self): + if self._client is None: + self._client = init_client() + return self._client + + def change(self, world_state: WorldState, instruction: str) -> WorldState: + """ + Apply a natural language instruction to update WorldState. + + Args: + world_state: Current world state + instruction: Natural language instruction (e.g., "make Rooster angry") + + Returns: + Updated WorldState + + Raises: + ValueError: If LLM returns invalid JSON + RuntimeError: If Gemini unavailable + """ + from .world_state import WorldState + + if self.client is None: + raise RuntimeError("Gemini client not available - check GEMINI_API_KEY") + + _rate_limit() + + current_json = world_state.model_dump_json(indent=2) + user_message = f"""Current WorldState: +{current_json} + +Instruction: {instruction} + +Output the updated WorldState JSON:""" + + try: + from google.genai import types + + response = self.client.models.generate_content( + model=self.model, + contents=[user_message], + config=types.GenerateContentConfig( + system_instruction=self.SYSTEM_PROMPT, + temperature=self.temperature, + max_output_tokens=2048, + response_mime_type="application/json", + ), + ) + + response_text = response.text.strip() + + # Parse the JSON response + return WorldState.model_validate_json(response_text) + + except Exception as e: + logger.error(f"WorldState change failed: {e}") + raise ValueError(f"Failed to parse LLM response: {e}") from e + + +class GeminiPromptJiggler: + """ + Generates variations of a prompt while preserving meaning. + + Useful for adding visual variety without changing the scene. 
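+
+    Example (sketch):
+        jiggler = GeminiPromptJiggler()
+        variants = jiggler.jiggle_multi("a fox running at dusk", count=3, intensity=0.4)
+        # Falls back to repeating the original prompt if Gemini is unavailable.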
+ """ + + SYSTEM_PROMPT = """You are a prompt variation generator for video generation. + +Given a prompt, create a subtle variation that: +- Preserves all core elements (characters, action, camera, style triggers) +- Adjusts word order, synonyms, or emphasis +- Stays within the same token budget +- Maintains the same visual intent + + Output ONLY the varied prompt - no explanation, no quotes.""" + + ATTENTIONAL_SYSTEM_PROMPT = """You are a prompt variation generator for video generation. + +Given a prompt, create EXACTLY {count} distinct variations that: +- Preserve ALL semantic content (characters, actions, setting, mood, style triggers) +- Adjust word order, emphasis, or phrasing +- Shift what the model will attend to first +- Stay within a similar token budget +- Each variation must be meaningfully different +{direction_clause} + +Output ONLY the varied prompts, one per line. No numbering, bullets, quotes, or explanation.""" + + SEMANTIC_SYSTEM_PROMPT = """You are a prompt variation generator for video generation. + +Given a prompt and a direction, create EXACTLY {count} distinct variations that: +- Shift the meaning/mood/camera in the direction specified +- Preserve core characters and setting unless the direction says otherwise +- Make meaningful semantic changes, not just word swaps +- Each variation must be meaningfully different + +Direction: {direction} + +Output ONLY the varied prompts, one per line. No numbering, bullets, quotes, or explanation.""" + + def __init__( + self, + model: str = DEFAULT_MODEL, + temperature: float = 0.7, + ): + self.model = model + self.temperature = temperature + self._client = None + + @property + def client(self): + if self._client is None: + self._client = init_client() + return self._client + + def jiggle(self, prompt: str, intensity: float = 0.3) -> str: + """ + Generate a variation of the prompt. 
+ + Args: + prompt: Original prompt text + intensity: How different the variation should be (0-1) + + Returns: + Varied prompt text + """ + if self.client is None: + # Graceful fallback - return original + logger.warning("Gemini unavailable, returning original prompt") + return prompt + + _rate_limit() + + # Scale temperature with intensity + adjusted_temp = 0.3 + (intensity * 0.7) # Range: 0.3 to 1.0 + + user_message = f"""Original prompt: +{prompt} + +Generate a subtle variation (intensity: {intensity:.1f}):""" + + try: + from google.genai import types + + response = self.client.models.generate_content( + model=self.model, + contents=[user_message], + config=types.GenerateContentConfig( + system_instruction=self.SYSTEM_PROMPT, + temperature=adjusted_temp, + max_output_tokens=256, + response_mime_type="text/plain", + ), + ) + return response.text.strip() + + except Exception as e: + logger.warning(f"Prompt jiggle failed, returning original: {e}") + return prompt + + def _parse_variations(self, raw: str, count: int, original: str) -> list[str]: + lines = raw.splitlines() + variations: list[str] = [] + original_stripped = original.strip() + + prefixes = [ + "1.", + "2.", + "3.", + "4.", + "5.", + "6.", + "7.", + "8.", + "9.", + "10.", + "-", + "•", + "*", + "—", + ] + + for line in lines: + candidate = line.strip() + if not candidate: + continue + + for prefix in prefixes: + if candidate.startswith(prefix): + candidate = candidate[len(prefix) :].strip() + break + + candidate = candidate.strip().strip('"').strip("'").strip() + if not candidate: + continue + if candidate == original_stripped: + continue + variations.append(candidate) + + # Dedupe while preserving order + seen: set[str] = set() + unique: list[str] = [] + for variation in variations: + if variation in seen: + continue + seen.add(variation) + unique.append(variation) + + if len(unique) < count: + unique.extend([original_stripped] * (count - len(unique))) + return unique[:count] + + def jiggle_multi( + self, + prompt: str, + count: int = 3, + intensity: float = 0.3, + direction: str | None = None, + mode: Literal["attentional", "semantic"] = "attentional", + ) -> list[str]: + """Generate multiple variations of the prompt in a single LLM call.""" + if count < 1: + raise ValueError("count must be >= 1") + + direction_clean = direction.strip() if direction else None + if mode == "semantic" and not direction_clean: + raise ValueError("direction is required when mode='semantic'") + + if self.client is None: + logger.warning("Gemini unavailable, returning original prompt") + return [prompt] * count + + cache_key = _jiggle_cache_key( + prompt=prompt, + model=self.model, + mode=mode, + direction=direction_clean, + intensity=float(intensity), + count=int(count), + ) + cached = _get_jiggle_cached(cache_key) + if cached is not None: + if len(cached) >= count: + return cached[:count] + return cached + [prompt] * (count - len(cached)) + + _rate_limit() + + adjusted_temp = 0.3 + (intensity * 0.7) # Range: 0.3 to 1.0 + max_output_tokens = min(2048, 256 * count) + + if mode == "semantic": + system_prompt = self.SEMANTIC_SYSTEM_PROMPT.format( + count=count, + direction=direction_clean, + ) + else: + direction_clause = f"\nDirection: {direction_clean}" if direction_clean else "" + system_prompt = self.ATTENTIONAL_SYSTEM_PROMPT.format( + count=count, + direction_clause=direction_clause, + ) + + user_message = f"""Original prompt: +{prompt} + +Generate EXACTLY {count} variations (intensity: {intensity:.1f}):""" + + try: + from google.genai import 
types + + response = self.client.models.generate_content( + model=self.model, + contents=[user_message], + config=types.GenerateContentConfig( + system_instruction=system_prompt, + temperature=adjusted_temp, + max_output_tokens=max_output_tokens, + response_mime_type="text/plain", + ), + ) + raw = (response.text or "").strip() + variations = self._parse_variations(raw, count=count, original=prompt) + + if any(v != prompt for v in variations): + _set_jiggle_cached(cache_key, variations) + + return variations + except Exception as e: + logger.warning(f"Prompt jiggle_multi failed, returning original: {e}") + return [prompt] * count + + +def is_gemini_available() -> bool: + """Check if Gemini is configured and available.""" + return bool(os.environ.get("GEMINI_API_KEY")) diff --git a/src/scope/realtime/generator_driver.py b/src/scope/realtime/generator_driver.py new file mode 100644 index 000000000..bfa8f7939 --- /dev/null +++ b/src/scope/realtime/generator_driver.py @@ -0,0 +1,269 @@ +"""Generator driver - tick loop that owns the pipeline and applies control events. + +The GeneratorDriver is the core loop that: +- Owns the pipeline and PipelineAdapter +- Applies control events at chunk boundaries in deterministic order +- Produces GenerationResult with frames and state snapshots +- Supports pause, resume, step, and snapshot/restore +""" + +import asyncio +import time +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Optional + +from scope.realtime.control_bus import ControlBus, EventType +from scope.realtime.control_state import ControlState, GenerationMode +from scope.realtime.pipeline_adapter import PipelineAdapter, PipelineProtocol + + +class DriverState(Enum): + """State of the GeneratorDriver.""" + + STOPPED = "stopped" + RUNNING = "running" + PAUSED = "paused" + STEPPING = "stepping" + + +@dataclass +class GenerationResult: + """Result of generating one chunk of frames.""" + + frames: Any # Tensor or numpy array + chunk_index: int + control_state_snapshot: dict + timing_ms: float + + +class GeneratorDriver: + """Tick loop that owns the pipeline and applies control events. + + The driver maintains: + - ControlState: current control surface + - ControlBus: queue of pending events + - PipelineAdapter: maps control to pipeline kwargs + + Events are applied at chunk boundaries in deterministic order. + """ + + def __init__( + self, + pipeline: Optional[PipelineProtocol] = None, + on_chunk: Optional[Callable[[GenerationResult], None]] = None, + on_state_change: Optional[Callable[[DriverState], None]] = None, + ): + """Initialize the driver. + + Args: + pipeline: The Scope/KREA pipeline instance (can be None for testing) + on_chunk: Callback for each generated chunk + on_state_change: Callback for driver state changes + """ + self.pipeline = pipeline + self.adapter = PipelineAdapter(pipeline) + self.control_bus = ControlBus() + self.on_chunk = on_chunk or (lambda _: None) + self.on_state_change = on_state_change or (lambda _: None) + + self.state = DriverState.STOPPED + self.control_state = ControlState() + self.chunk_index = 0 + + self._run_task: Optional[asyncio.Task] = None + self._is_prepared: bool = False # Controls init_cache on first call / resets + + def _apply_control_events(self) -> bool: + """Drain and apply queued ControlBus events at a chunk boundary. 
+ + Returns: + True if generation should continue, False if stopped + """ + events = self.control_bus.drain_pending( + is_paused=(self.state == DriverState.PAUSED), + chunk_index=self.chunk_index, + ) + + for event in events: + if event.type == EventType.STOP: + self.stop() + return False + + elif event.type == EventType.PAUSE: + self.state = DriverState.PAUSED + self.on_state_change(self.state) + + elif event.type == EventType.RESUME: + if self.state == DriverState.PAUSED: + self.state = DriverState.RUNNING + self.on_state_change(self.state) + + elif event.type == EventType.SET_PROMPT: + # Direct override: payload may contain prompts and/or transition + for key, value in event.payload.items(): + if hasattr(self.control_state, key): + setattr(self.control_state, key, value) + + elif event.type == EventType.SET_LORA_SCALES: + self.control_state.lora_scales = event.payload.get("lora_scales", []) + + elif event.type == EventType.SET_DENOISE_STEPS: + self.control_state.denoising_step_list = event.payload.get( + "denoising_step_list", self.control_state.denoising_step_list + ) + + elif event.type == EventType.SET_SEED: + if "base_seed" in event.payload: + self.control_state.base_seed = int(event.payload["base_seed"]) + if "branch_seed_offset" in event.payload: + self.control_state.branch_seed_offset = int( + event.payload["branch_seed_offset"] + ) + + elif event.type == EventType.RESTORE_SNAPSHOT: + snapshot = event.payload.get("snapshot") + if snapshot: + self.restore(snapshot) + + return True + + async def run(self): + """Main generation loop.""" + self.state = DriverState.RUNNING + self.on_state_change(self.state) + + while self.state == DriverState.RUNNING: + await self._generate_chunk() + await asyncio.sleep(0) # Yield to event loop + + async def step(self) -> Optional[GenerationResult]: + """Generate exactly one chunk (for Dev Console). + + Returns: + GenerationResult if successful, None if stopped/paused + """ + self.state = DriverState.STEPPING + self.on_state_change(self.state) + + result = await self._generate_chunk() + + if self.state == DriverState.STEPPING: + self.state = DriverState.PAUSED + self.on_state_change(self.state) + + return result + + async def _generate_chunk(self) -> Optional[GenerationResult]: + """Generate one chunk of frames.""" + # Apply control events at chunk boundary + should_continue = self._apply_control_events() + if not should_continue: + return None + if self.state not in (DriverState.RUNNING, DriverState.STEPPING): + return None + + start_time = time.perf_counter() + + # Call pipeline (if available) + output = None + if self.pipeline is not None: + output = self.pipeline( + **self.adapter.kwargs_for_call( + self.control_state, + init_cache=(not self._is_prepared), + ) + ) + + elapsed_ms = (time.perf_counter() - start_time) * 1000 + + self._is_prepared = True + + # Mirror pipeline-owned start frame into control_state for UI/snapshots + if self.pipeline is not None and hasattr(self.pipeline, "state"): + frame = self.pipeline.state.get("current_start_frame") + if frame is not None: + self.control_state.current_start_frame = frame + + self.chunk_index += 1 + + result = GenerationResult( + frames=output, + chunk_index=self.chunk_index, + control_state_snapshot=self._snapshot_control_state(), + timing_ms=elapsed_ms, + ) + + self.on_chunk(result) + return result + + def pause(self): + """Pause generation.""" + self.state = DriverState.PAUSED + self.on_state_change(self.state) + + def resume(self): + """Resume generation. 
Guards against spawning multiple loops.""" + if self.state != DriverState.PAUSED: + return # Can only resume from paused + if self._run_task and not self._run_task.done(): + return # Already have an active loop + self._run_task = asyncio.create_task(self.run()) + + def stop(self): + """Stop generation and cancel any running task.""" + self.state = DriverState.STOPPED + self.on_state_change(self.state) + if self._run_task and not self._run_task.done(): + self._run_task.cancel() + self._run_task = None + + def snapshot(self) -> dict: + """Create a restorable snapshot of current state. + + Includes generator continuity buffers needed for seamless continuation. + """ + return { + "control_state": self._snapshot_control_state(), + "chunk_index": self.chunk_index, + "generator_continuity": self.adapter.capture_continuity(), + } + + def restore(self, snapshot: dict): + """Restore from a snapshot. + + If generator_continuity is present and valid, this produces seamless + continuation. Otherwise, it's a hard cut (acceptable for branching). + """ + # Restore control state + ctrl_data = snapshot.get("control_state", {}) + for key, value in ctrl_data.items(): + if hasattr(self.control_state, key): + if key == "mode": + if isinstance(value, str): + value = GenerationMode(value) + setattr(self.control_state, key, value) + + # Restore chunk index + self.chunk_index = snapshot.get("chunk_index", 0) + + # Restore generator continuity (if available) + if "generator_continuity" in snapshot: + self.adapter.restore_continuity(snapshot["generator_continuity"]) + + # Ensure next generate does not wipe continuity by forcing init_cache + self._is_prepared = True + + def _snapshot_control_state(self) -> dict: + """Create a serializable snapshot of ControlState.""" + return { + "prompts": self.control_state.prompts, + "negative_prompt": self.control_state.negative_prompt, + "lora_scales": self.control_state.lora_scales, + "base_seed": self.control_state.base_seed, + "branch_seed_offset": self.control_state.branch_seed_offset, + "current_start_frame": self.control_state.current_start_frame, + "denoising_step_list": self.control_state.denoising_step_list, + "mode": self.control_state.mode.value, + "kv_cache_attention_bias": self.control_state.kv_cache_attention_bias, + } diff --git a/src/scope/realtime/pipeline_adapter.py b/src/scope/realtime/pipeline_adapter.py new file mode 100644 index 000000000..34a8ce268 --- /dev/null +++ b/src/scope/realtime/pipeline_adapter.py @@ -0,0 +1,126 @@ +"""Pipeline adapter - maps ControlState to pipeline kwargs and handles continuity. + +The Scope/KREA pipeline stores both control inputs and continuity buffers inside +`pipeline.state` (a PipelineState key-value store). This adapter provides: + +1. Convert ControlState → exact kwargs for KreaRealtimeVideoPipeline.__call__() +2. Extract and restore continuity buffers from pipeline.state using known keys +3. Edge-trigger runtime changes with side effects (notably lora_scales) +""" + +from typing import Any, Optional, Protocol + +from scope.realtime.control_state import ControlState + + +class PipelineStateProtocol(Protocol): + """Protocol for pipeline.state access.""" + + def get(self, key: str) -> Any: ... + def set(self, key: str, value: Any) -> None: ... + + +class PipelineProtocol(Protocol): + """Protocol for the pipeline object.""" + + state: PipelineStateProtocol + + def __call__(self, **kwargs: Any) -> Any: ... + + +class PipelineAdapter: + """Adapter between ControlState and the Scope/KREA pipeline. 
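+
+    Typical flow per chunk (sketch, mirroring GeneratorDriver._generate_chunk;
+    `is_first_chunk` is an illustrative name):
+
+        kwargs = adapter.kwargs_for_call(control_state, init_cache=is_first_chunk)
+        output = pipeline(**kwargs)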
+ + Responsibilities: + - Convert ControlState to pipeline kwargs + - Edge-trigger lora_scales (only include when changed to avoid cache resets) + - Capture/restore continuity buffers from pipeline.state + """ + + # Keys in pipeline.state that hold continuity buffers + # These are produced/consumed by pipeline blocks and needed for seamless continuation + CONTINUITY_KEYS = [ + "current_start_frame", + "first_context_frame", + "context_frame_buffer", + "decoded_frame_buffer", + "context_frame_buffer_max_size", + "decoded_frame_buffer_max_size", + ] + + def __init__(self, pipeline: Optional[PipelineProtocol] = None): + """Initialize the adapter. + + Args: + pipeline: The Scope/KREA pipeline instance. Can be None for testing. + """ + self.pipeline = pipeline + self._last_lora_scales: Optional[list[dict]] = None + + def kwargs_for_call(self, control: ControlState, *, init_cache: bool) -> dict: + """Convert ControlState to pipeline kwargs. + + Args: + control: The current ControlState + init_cache: Whether to initialize/reset the cache + + Returns: + Dict of kwargs for pipeline.__call__() + + Note: + - lora_scales is edge-triggered: only included when it changes + - negative_prompt is NOT included (not consumed by Scope/KREA) + - init_cache is always explicit (driver decides) + """ + kwargs = control.to_pipeline_kwargs() + kwargs["init_cache"] = init_cache + + # Edge-trigger: only include lora_scales when it changes + # In Scope/KREA, providing lora_scales may force init_cache=True + # when manage_cache is enabled + current_scales = control.lora_scales + last_scales = self._last_lora_scales or [] + + if current_scales != last_scales: + if current_scales: + kwargs["lora_scales"] = current_scales + self._last_lora_scales = list(current_scales) if current_scales else [] + else: + # Ensure lora_scales is not in kwargs when unchanged + kwargs.pop("lora_scales", None) + + return kwargs + + def capture_continuity(self) -> dict: + """Capture continuity buffers from pipeline.state. + + Returns: + Dict of continuity buffers keyed by state key name. + Only includes keys that have non-None values. + """ + if self.pipeline is None: + return {} + + st = self.pipeline.state + return {k: st.get(k) for k in self.CONTINUITY_KEYS if st.get(k) is not None} + + def restore_continuity(self, continuity: dict): + """Restore continuity buffers to pipeline.state. + + Args: + continuity: Dict of continuity buffers from capture_continuity() + """ + if self.pipeline is None: + return + + st = self.pipeline.state + for k, v in continuity.items(): + st.set(k, v) + + def reset_lora_tracking(self): + """Reset lora_scales tracking. + + Call this after a full pipeline reset to ensure the next call + includes lora_scales even if it matches the previous value. + """ + self._last_lora_scales = None diff --git a/src/scope/realtime/prompt_compiler.py b/src/scope/realtime/prompt_compiler.py new file mode 100644 index 000000000..1a9fd30c4 --- /dev/null +++ b/src/scope/realtime/prompt_compiler.py @@ -0,0 +1,599 @@ +""" +PromptCompiler - Translates WorldState + StyleManifest into prompt strings. + +The compiler is pluggable: +- LLMCompiler: Uses an LLM (Gemini Flash, etc.) 
with instruction sheets +- TemplateCompiler: Deterministic vocab substitution (for testing/fallback) +- CachedCompiler: Wraps another compiler with memoization + +Factory function: +- create_compiler(): Creates the appropriate compiler based on config +""" + +import hashlib +import logging +import os +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from .style_manifest import StyleManifest +from .world_state import WorldState + +logger = logging.getLogger(__name__) + + +@dataclass +class PromptEntry: + """A single prompt entry with text and weight.""" + + text: str + weight: float = 1.0 + + def to_dict(self) -> dict[str, Any]: + return {"text": self.text, "weight": self.weight} + + +@dataclass +class LoRAScaleUpdate: + """A single LoRA scale update in Scope runtime format.""" + + path: str + scale: float + + def to_dict(self) -> dict[str, Any]: + return {"path": self.path, "scale": self.scale} + + +@dataclass +class CompiledPrompt: + """Result of prompt compilation - pipeline-ready format.""" + + # Pipeline-ready format: list of {"text": ..., "weight": ...} + prompts: list[PromptEntry] = field(default_factory=list) + negative_prompt: str = "" + + # LoRA configuration from style + lora_scales: list[LoRAScaleUpdate] = field(default_factory=list) + + # Metadata about the compilation + style_name: str = "" + compiler_type: str = "" + token_count: int | None = None + truncated: bool = False + + # For debugging/iteration + world_state_hash: str = "" + manifest_hash: str = "" # For cache invalidation + raw_llm_response: str | None = None + + @property + def prompt(self) -> str: + """Convenience: return first prompt text (for simple cases).""" + if self.prompts: + return self.prompts[0].text + return "" + + def to_pipeline_kwargs(self) -> dict[str, Any]: + """Convert to kwargs suitable for pipeline call.""" + kwargs = { + "prompts": [p.to_dict() for p in self.prompts], + } + if self.negative_prompt: + kwargs["negative_prompt"] = self.negative_prompt + if self.lora_scales: + kwargs["lora_scales"] = [u.to_dict() for u in self.lora_scales] + return kwargs + + +class PromptCompiler(ABC): + """ + Abstract base class for prompt compilers. + + Subclasses implement the actual compilation logic. + """ + + @abstractmethod + def compile( + self, + world_state: WorldState, + style: StyleManifest, + **kwargs: Any, + ) -> CompiledPrompt: + """ + Compile a WorldState into a prompt using the given style. 
+ + Args: + world_state: The world state to compile + style: The style manifest with vocab and constraints + **kwargs: Additional compiler-specific options + + Returns: + CompiledPrompt with the generated prompt string + """ + pass + + def _compute_state_hash(self, world_state: WorldState, style: StyleManifest) -> str: + """Compute a hash for cache keys (WorldState only).""" + state_str = world_state.model_dump_json() + return hashlib.md5(state_str.encode()).hexdigest()[:12] + + def _compute_manifest_hash(self, style: StyleManifest) -> str: + """Compute a hash for manifest content (for cache invalidation).""" + # Include all vocab content, not just name + manifest_str = style.model_dump_json() + return hashlib.md5(manifest_str.encode()).hexdigest()[:12] + + def _compute_cache_key(self, world_state: WorldState, style: StyleManifest) -> str: + """Compute full cache key including manifest content.""" + state_hash = self._compute_state_hash(world_state, style) + manifest_hash = self._compute_manifest_hash(style) + return f"{style.name}:{manifest_hash}:{state_hash}" + + +class TemplateCompiler(PromptCompiler): + """ + Deterministic template-based compiler. + + Uses simple vocab substitution without LLM calls. + Good for testing and as a fallback. + """ + + # Common action normalizations (user can extend via style.custom_vocab["action_aliases"]) + ACTION_ALIASES = { + "walking": "walk", + "running": "run", + "standing": "idle", + "sitting": "sit", + "jumping": "jump", + } + + def _normalize_action(self, action: str, style: StyleManifest) -> str: + """Normalize action key to canonical form.""" + # Check style-specific aliases first + aliases = style.custom_vocab.get("action_aliases", {}) + if action in aliases: + return aliases[action] + # Fall back to built-in aliases + return self.ACTION_ALIASES.get(action, action) + + def compile( + self, + world_state: WorldState, + style: StyleManifest, + **kwargs: Any, + ) -> CompiledPrompt: + """Compile using template substitution.""" + parts = [] + + # 1. Trigger words (always first) + if style.trigger_words: + parts.extend(style.trigger_words) + + # 2. Action/motion (with normalization) + if world_state.action: + action_key = self._normalize_action(world_state.action, style) + motion = style.get_vocab("motion", action_key, world_state.action) + parts.append(motion) + + # 3. Character emotions + for char in world_state.characters: + if char.emotion and char.emotion != "neutral": + emotion = style.get_vocab("emotion", char.emotion, char.emotion) + parts.append(f"{char.name} {emotion}") + + # 4. Camera + camera_key = world_state.camera.value + camera = style.get_vocab("camera", camera_key, camera_key) + parts.append(camera) + + # 5. Lighting (based on time of day or mood) + if world_state.time_of_day: + lighting = style.get_vocab("lighting", world_state.time_of_day) + if lighting != world_state.time_of_day: + parts.append(lighting) + + # 6. Beat modifier + beat_key = world_state.beat.value + beat = style.get_vocab("beat", beat_key) + if beat != beat_key: + parts.append(beat) + + # 7. 
Scene description (truncated if needed) + if world_state.scene_description: + parts.append(world_state.scene_description) + + # Join and basic cleanup + prompt_text = ", ".join(p for p in parts if p) + + # Build lora_scales from style if lora_path is set + lora_scales: list[LoRAScaleUpdate] = [] + if style.lora_path: + lora_scales.append( + LoRAScaleUpdate(path=style.lora_path, scale=style.lora_default_scale) + ) + + return CompiledPrompt( + prompts=[PromptEntry(text=prompt_text, weight=1.0)], + negative_prompt=style.default_negative, + lora_scales=lora_scales, + style_name=style.name, + compiler_type="template", + world_state_hash=self._compute_state_hash(world_state, style), + manifest_hash=self._compute_manifest_hash(style), + ) + + +@dataclass +class InstructionSheet: + """ + An instruction sheet for LLM-based compilation. + + Contains the system prompt and few-shot examples. + """ + + name: str + system_prompt: str + examples: list[dict[str, str]] = field(default_factory=list) + # Each example: {"world_state": "...", "output": "..."} + + @classmethod + def from_markdown(cls, path: str | Path) -> "InstructionSheet": + """ + Load an instruction sheet from a markdown file. + + Expected format: + ``` + # Sheet Name + + ## System Prompt + Your instructions here... + + ## Examples + + ### Example 1 + **Input:** + WorldState description... + + **Output:** + Expected prompt output... + ``` + """ + path = Path(path) + content = path.read_text() + + # Parse markdown (simplified parser) + lines = content.split("\n") + name = "" + system_prompt = "" + examples = [] + + current_section = None + current_example = {} + buffer = [] + + for line in lines: + if line.startswith("# ") and not name: + name = line[2:].strip() + elif line.startswith("## System Prompt"): + current_section = "system" + buffer = [] + elif line.startswith("## Examples"): + if buffer and current_section == "system": + system_prompt = "\n".join(buffer).strip() + current_section = "examples" + buffer = [] + elif line.startswith("### Example"): + # Save previous example if complete + if current_example and buffer and current_section == "output": + current_example["output"] = "\n".join(buffer).strip() + if current_example and "world_state" in current_example and "output" in current_example: + examples.append(current_example) + current_example = {} + buffer = [] + current_section = "example_start" + elif line.startswith("**Input:**"): + buffer = [] + current_section = "input" + elif line.startswith("**Output:**"): + current_example["world_state"] = "\n".join(buffer).strip() + buffer = [] + current_section = "output" + elif current_section in ("system", "input", "output"): + buffer.append(line) + + # Capture final content + if buffer: + if current_section == "system": + system_prompt = "\n".join(buffer).strip() + elif current_section == "output": + current_example["output"] = "\n".join(buffer).strip() + + if current_example and "world_state" in current_example and "output" in current_example: + examples.append(current_example) + + return cls( + name=name or path.stem, + system_prompt=system_prompt, + examples=examples, + ) + + +class LLMCompiler(PromptCompiler): + """ + LLM-based prompt compiler. + + Uses an LLM (via a callable) to translate WorldState into prompts, + guided by instruction sheets and style vocab. 
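+
+    Example (illustrative; my_llm is a placeholder for any
+    (system_prompt, user_message) -> str callable, and world_state/style are
+    existing WorldState/StyleManifest instances):
+
+        def my_llm(system_prompt: str, user_message: str) -> str:
+            ...  # call whichever LLM backend you use
+
+        compiler = LLMCompiler(llm_callable=my_llm)
+        compiled = compiler.compile(world_state, style)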
+ """ + + def __init__( + self, + llm_callable: callable, + instruction_sheet: InstructionSheet | None = None, + ): + """ + Args: + llm_callable: A function that takes (system_prompt, user_message) -> str + instruction_sheet: Optional instruction sheet with system prompt and examples + """ + self.llm_callable = llm_callable + self.instruction_sheet = instruction_sheet + + def compile( + self, + world_state: WorldState, + style: StyleManifest, + **kwargs: Any, + ) -> CompiledPrompt: + """Compile using LLM.""" + # Build system prompt + system_prompt = self._build_system_prompt(style) + + # Build user message with WorldState + user_message = self._build_user_message(world_state, style) + + # Call LLM + try: + raw_response = self.llm_callable(system_prompt, user_message) + prompt_text = self._parse_response(raw_response) + except Exception as e: + logger.error(f"LLM compilation failed: {e}") + # Fall back to template compilation + fallback = TemplateCompiler() + result = fallback.compile(world_state, style, **kwargs) + result.compiler_type = "template_fallback" + return result + + # Build lora_scales from style + lora_scales: list[LoRAScaleUpdate] = [] + if style.lora_path: + lora_scales.append( + LoRAScaleUpdate(path=style.lora_path, scale=style.lora_default_scale) + ) + + return CompiledPrompt( + prompts=[PromptEntry(text=prompt_text, weight=1.0)], + negative_prompt=style.default_negative, + lora_scales=lora_scales, + style_name=style.name, + compiler_type="llm", + world_state_hash=self._compute_state_hash(world_state, style), + manifest_hash=self._compute_manifest_hash(style), + raw_llm_response=raw_response, + ) + + def _build_system_prompt(self, style: StyleManifest) -> str: + """Build the system prompt with vocab context and few-shot examples.""" + parts = [] + + # Base instruction sheet + if self.instruction_sheet: + parts.append(self.instruction_sheet.system_prompt) + else: + parts.append( + "You are a prompt compiler for video generation. " + "Translate the given world state into an effective prompt." + ) + + # Add vocab context + parts.append("\n\n## Available Vocabulary\n") + for category, vocab in style.get_all_vocab().items(): + if vocab: + parts.append(f"\n### {category.title()}") + for key, value in vocab.items(): + parts.append(f"- {key}: {value}") + + # Add constraints + parts.append(f"\n\n## Constraints") + parts.append(f"- Trigger words (must include): {', '.join(style.trigger_words)}") + parts.append(f"- Max tokens: {style.max_prompt_tokens}") + parts.append(f"- Priority order: {', '.join(style.priority_order)}") + + # Add few-shot examples from instruction sheet + if self.instruction_sheet and self.instruction_sheet.examples: + parts.append("\n\n## Examples\n") + for i, example in enumerate(self.instruction_sheet.examples, 1): + parts.append(f"\n### Example {i}") + if "world_state" in example: + parts.append(f"**Input:**\n{example['world_state']}") + if "output" in example: + parts.append(f"\n**Output:**\n{example['output']}") + + return "\n".join(parts) + + def _build_user_message(self, world_state: WorldState, style: StyleManifest) -> str: + """Build the user message with WorldState context.""" + context = world_state.to_context_dict() + + parts = ["## World State\n"] + for key, value in context.items(): + # Include all values except None and empty strings/lists + # This preserves visible=False, tension=0.0, etc. 
+ if value is not None and value != "" and value != []: + parts.append(f"- {key}: {value}") + + parts.append("\n\nGenerate the prompt:") + + return "\n".join(parts) + + def _parse_response(self, response: str) -> str: + """Parse the LLM response to extract the prompt.""" + # Simple extraction - assume the response IS the prompt + # Could be enhanced to handle structured output + return response.strip() + + +class CachedCompiler(PromptCompiler): + """ + Caching wrapper for any PromptCompiler. + + Memoizes results based on WorldState + Style manifest content hash. + """ + + def __init__( + self, + inner_compiler: PromptCompiler, + max_cache_size: int = 100, + ): + self.inner = inner_compiler + self.max_cache_size = max_cache_size + self._cache: dict[str, CompiledPrompt] = {} + self._cache_order: list[str] = [] # For LRU eviction + + def compile( + self, + world_state: WorldState, + style: StyleManifest, + **kwargs: Any, + ) -> CompiledPrompt: + """Compile with caching.""" + # Use full cache key that includes manifest content hash + cache_key = self._compute_cache_key(world_state, style) + + if cache_key in self._cache: + # Move to end of LRU order + self._cache_order.remove(cache_key) + self._cache_order.append(cache_key) + return self._cache[cache_key] + + # Compile and cache + result = self.inner.compile(world_state, style, **kwargs) + + # LRU eviction + if len(self._cache) >= self.max_cache_size: + oldest_key = self._cache_order.pop(0) + del self._cache[oldest_key] + + self._cache[cache_key] = result + self._cache_order.append(cache_key) + + return result + + def clear_cache(self) -> None: + """Clear the cache.""" + self._cache.clear() + self._cache_order.clear() + + +def _load_instruction_sheet(style: StyleManifest) -> InstructionSheet | None: + """ + Load instruction sheet for a style. + + Looks for instructions.md in the style directory. + """ + if not style.name: + return None + + from .style_manifest import get_style_dirs + + candidates: list[Path] = [] + + raw = (style.instruction_sheet_path or "").strip() + if raw: + p = Path(raw).expanduser() + if p.is_absolute(): + candidates.append(p) + else: + # Try both "instructions.md" and "style/instructions.md" style paths. + for base in get_style_dirs(): + candidates.append(base / style.name / p) + candidates.append(base / p) + else: + for base in get_style_dirs(): + candidates.append(base / style.name / "instructions.md") + + for instruction_path in candidates: + if not instruction_path.exists(): + continue + try: + return InstructionSheet.from_markdown(instruction_path) + except Exception as e: + logger.warning( + "Failed to load instruction sheet %s: %s", instruction_path, e + ) + return None + + return None + + +def create_compiler( + style: StyleManifest, + mode: str = "auto", +) -> PromptCompiler: + """ + Create a PromptCompiler for the given style. 
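+
+    Example (illustrative; assumes an existing StyleManifest):
+
+        compiler = create_compiler(style, mode="template")  # deterministic, no API key needed
+        compiler = create_compiler(style)                    # auto: Gemini when GEMINI_API_KEY is set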
+ + Args: + style: The style manifest to compile prompts for + mode: Compiler mode - "gemini", "template", or "auto" + "auto" uses Gemini if GEMINI_API_KEY is set, else template + + Returns: + A PromptCompiler instance (possibly wrapped in CachedCompiler) + + Environment variables: + SCOPE_LLM_COMPILER: Override mode ("gemini", "template", "auto") + GEMINI_API_KEY: Required for Gemini mode + """ + # Check env override + env_mode = os.getenv("SCOPE_LLM_COMPILER", "auto") + if env_mode in ("gemini", "template"): + mode = env_mode + + # Template mode - simple and fast + if mode == "template": + logger.info(f"Using TemplateCompiler for style '{style.name}'") + return TemplateCompiler() + + # Gemini mode - check availability + if mode == "gemini": + from .gemini_client import GeminiCompiler, is_gemini_available + + if not is_gemini_available(): + logger.warning( + "SCOPE_LLM_COMPILER=gemini but GEMINI_API_KEY not set, " + "falling back to TemplateCompiler" + ) + return TemplateCompiler() + + instruction_sheet = _load_instruction_sheet(style) + inner = LLMCompiler( + llm_callable=GeminiCompiler(), + instruction_sheet=instruction_sheet, + ) + logger.info( + f"Using LLMCompiler (Gemini) for style '{style.name}' " + f"with instruction_sheet={instruction_sheet.name if instruction_sheet else 'None'}" + ) + return CachedCompiler(inner) + + # Auto mode - use Gemini if available + from .gemini_client import is_gemini_available + + if is_gemini_available(): + return create_compiler(style, mode="gemini") + + logger.info(f"Using TemplateCompiler for style '{style.name}' (auto mode, no API key)") + return TemplateCompiler() diff --git a/src/scope/realtime/prompt_playlist.py b/src/scope/realtime/prompt_playlist.py new file mode 100644 index 000000000..ced6c2792 --- /dev/null +++ b/src/scope/realtime/prompt_playlist.py @@ -0,0 +1,410 @@ +""" +PromptPlaylist - Load and navigate through a list of prompts from caption files. + +Features: +- Load prompts from text files (one per line) +- Trigger phrase swapping (e.g., "1988 Cel Animation" -> "Rankin/Bass Animagic Stop-Motion") +- Navigation: next, prev, goto, current +- Optional shuffle and loop modes +""" + +import json +import logging +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +def _apply_trigger_swap(prompts: list[str], old_trigger: str, new_trigger: str) -> list[str]: + """Replace old trigger phrase with new one in all prompts. + + Uses case-insensitive matching for flexibility. 
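+
+    Example (illustrative; trigger phrases borrowed from the module docstring,
+    prompt text is made up):
+
+        _apply_trigger_swap(
+            ["1988 cel animation, a fox walks through snow"],
+            "1988 Cel Animation",
+            "Rankin/Bass Animagic Stop-Motion",
+        )
+        # -> ["Rankin/Bass Animagic Stop-Motion, a fox walks through snow"]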
+ """ + result = [] + pattern = re.compile(re.escape(old_trigger), re.IGNORECASE) + for prompt in prompts: + swapped = pattern.sub(new_trigger, prompt) + result.append(swapped) + return result + + +@dataclass +class PromptPlaylist: + """A navigable playlist of prompts loaded from a caption file.""" + + source_file: str = "" + prompts: list[str] = field(default_factory=list) + current_index: int = 0 + + # Trigger swapping: (old_trigger, new_trigger) + trigger_swap: tuple[str, str] | None = None + + # Original prompts (before any trigger swap) - enables re-swapping + original_prompts: list[str] = field(default_factory=list) + + # Source trigger: what trigger phrase is in the original prompts (from file) + # This is what we search for when doing swaps + source_trigger: str | None = None + + # Current trigger applied to prompts (what we've swapped TO) + current_trigger: str | None = None + + # Metadata + original_count: int = 0 + + # Prompt bookmarks - indices of bookmarked prompts for quick navigation + bookmarked_indices: set[int] = field(default_factory=set) + + @classmethod + def from_file( + cls, + path: str | Path, + trigger_swap: tuple[str, str] | None = None, + skip_empty: bool = True, + ) -> "PromptPlaylist": + """ + Load prompts from a text file (one prompt per line). + + Args: + path: Path to the caption file + trigger_swap: Optional (old, new) trigger phrase to swap + skip_empty: Whether to skip empty lines + + Returns: + PromptPlaylist instance + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Caption file not found: {path}") + + lines = path.read_text().strip().split("\n") + original_count = len(lines) + + # First pass: collect original prompts (no swap applied) + original_prompts = [] + for line in lines: + line = line.strip() + if skip_empty and not line: + continue + original_prompts.append(line) + + # Second pass: apply trigger swap if configured + prompts = list(original_prompts) # copy + source_trigger = None + current_trigger = None + if trigger_swap: + old_trigger, new_trigger = trigger_swap + source_trigger = old_trigger # What's in the file + current_trigger = new_trigger # What we're swapping to + prompts = _apply_trigger_swap(prompts, old_trigger, new_trigger) + + logger.info( + f"Loaded {len(prompts)} prompts from {path.name}" + + (f" (swapped '{trigger_swap[0]}' -> '{trigger_swap[1]}')" if trigger_swap else "") + ) + + playlist = cls( + source_file=str(path), + prompts=prompts, + current_index=0, + trigger_swap=trigger_swap, + original_prompts=original_prompts, + source_trigger=source_trigger, + current_trigger=current_trigger, + original_count=original_count, + ) + + # Load any saved bookmarks + playlist._load_bookmarks() + + return playlist + + @property + def current(self) -> str: + """Get the current prompt.""" + if not self.prompts: + return "" + return self.prompts[self.current_index] + + @property + def total(self) -> int: + """Total number of prompts.""" + return len(self.prompts) + + @property + def has_next(self) -> bool: + """Check if there's a next prompt.""" + return self.current_index < len(self.prompts) - 1 + + @property + def has_prev(self) -> bool: + """Check if there's a previous prompt.""" + return self.current_index > 0 + + def next(self) -> str: + """Move to next prompt and return it.""" + if self.has_next: + self.current_index += 1 + return self.current + + def prev(self) -> str: + """Move to previous prompt and return it.""" + if self.has_prev: + self.current_index -= 1 + return self.current + + def goto(self, 
index: int) -> str: + """Go to a specific prompt index.""" + if self.prompts: + self.current_index = max(0, min(index, len(self.prompts) - 1)) + return self.current + + def first(self) -> str: + """Go to first prompt.""" + return self.goto(0) + + def last(self) -> str: + """Go to last prompt.""" + return self.goto(len(self.prompts) - 1) + + # Bookmark methods + def bookmark_current(self) -> bool: + """Bookmark the current prompt index. Returns True if newly added.""" + if self.current_index in self.bookmarked_indices: + return False + self.bookmarked_indices.add(self.current_index) + logger.info(f"Bookmarked prompt {self.current_index}") + self._save_bookmarks() + return True + + def unbookmark_current(self) -> bool: + """Remove bookmark from current prompt index. Returns True if removed.""" + if self.current_index not in self.bookmarked_indices: + return False + self.bookmarked_indices.discard(self.current_index) + logger.info(f"Unbookmarked prompt {self.current_index}") + self._save_bookmarks() + return True + + def toggle_bookmark(self) -> bool: + """Toggle bookmark on current prompt. Returns True if now bookmarked.""" + if self.current_index in self.bookmarked_indices: + self.unbookmark_current() + return False + else: + self.bookmark_current() + return True + + def is_bookmarked(self, index: int | None = None) -> bool: + """Check if an index (or current) is bookmarked.""" + idx = index if index is not None else self.current_index + return idx in self.bookmarked_indices + + def next_bookmarked(self) -> str | None: + """Move to the next bookmarked prompt. Returns None if no bookmarks ahead.""" + if not self.bookmarked_indices: + return None + sorted_bookmarks = sorted(self.bookmarked_indices) + for idx in sorted_bookmarks: + if idx > self.current_index: + self.current_index = idx + logger.info(f"Jumped to bookmarked prompt {idx}") + return self.current + # Wrap around to first bookmark + if sorted_bookmarks: + self.current_index = sorted_bookmarks[0] + logger.info(f"Wrapped to first bookmarked prompt {sorted_bookmarks[0]}") + return self.current + return None + + def prev_bookmarked(self) -> str | None: + """Move to the previous bookmarked prompt. Returns None if no bookmarks behind.""" + if not self.bookmarked_indices: + return None + sorted_bookmarks = sorted(self.bookmarked_indices, reverse=True) + for idx in sorted_bookmarks: + if idx < self.current_index: + self.current_index = idx + logger.info(f"Jumped to bookmarked prompt {idx}") + return self.current + # Wrap around to last bookmark + if sorted_bookmarks: + self.current_index = sorted_bookmarks[0] + logger.info(f"Wrapped to last bookmarked prompt {sorted_bookmarks[0]}") + return self.current + return None + + def clear_bookmarks(self) -> int: + """Clear all bookmarks. 
Returns count of cleared bookmarks.""" + count = len(self.bookmarked_indices) + self.bookmarked_indices.clear() + logger.info(f"Cleared {count} bookmarks") + self._save_bookmarks() + return count + + def _get_bookmarks_path(self) -> Path | None: + """Get the path for the bookmarks sidecar file.""" + if not self.source_file: + return None + source = Path(self.source_file) + return source.parent / f"{source.stem}.bookmarks.json" + + def _save_bookmarks(self) -> bool: + """Save bookmarks to sidecar JSON file.""" + bookmarks_path = self._get_bookmarks_path() + if not bookmarks_path: + return False + try: + data = { + "source_file": self.source_file, + "bookmarked_indices": sorted(self.bookmarked_indices), + } + bookmarks_path.write_text(json.dumps(data, indent=2)) + logger.info(f"Saved {len(self.bookmarked_indices)} bookmarks to {bookmarks_path}") + return True + except Exception as e: + logger.warning(f"Failed to save bookmarks: {e}") + return False + + def _load_bookmarks(self) -> bool: + """Load bookmarks from sidecar JSON file if it exists.""" + bookmarks_path = self._get_bookmarks_path() + if not bookmarks_path or not bookmarks_path.exists(): + return False + try: + data = json.loads(bookmarks_path.read_text()) + indices = data.get("bookmarked_indices", []) + # Filter to valid indices + valid_indices = {i for i in indices if 0 <= i < len(self.prompts)} + self.bookmarked_indices = valid_indices + logger.info(f"Loaded {len(valid_indices)} bookmarks from {bookmarks_path}") + return True + except Exception as e: + logger.warning(f"Failed to load bookmarks: {e}") + return False + + def set_source_trigger(self, trigger: str) -> None: + """Set the source trigger phrase (what's in the original prompts). + + Call this if the playlist was loaded without --swap and you want to + enable auto-trigger-swap on style changes. + """ + self.source_trigger = trigger + logger.info(f"Source trigger set to: '{trigger}'") + + def swap_trigger(self, new_trigger: str) -> bool: + """Swap the trigger phrase to a new one. + + Always swaps from source_trigger (what's in original file) to new_trigger. + Returns True if swap was applied, False if no change needed. + """ + if not self.original_prompts: + logger.warning("Cannot swap trigger: no original prompts stored") + return False + + if new_trigger == self.current_trigger: + logger.debug(f"Trigger already set to '{new_trigger}', skipping swap") + return False + + if not self.source_trigger: + # No source trigger known - can't do a swap + # This happens if playlist was loaded without trigger_swap parameter + logger.warning( + "Cannot swap trigger: source_trigger not set. " + "Load playlist with --swap to specify the source trigger phrase." 
+ ) + return False + + # Always swap from source (what's in file) to new trigger + self.prompts = _apply_trigger_swap( + self.original_prompts, self.source_trigger, new_trigger + ) + logger.info(f"Swapped trigger: '{self.source_trigger}' -> '{new_trigger}'") + + self.current_trigger = new_trigger + return True + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict for API responses.""" + return { + "source_file": self.source_file, + "current_index": self.current_index, + "total": self.total, + "current_prompt": self.current, + "has_next": self.has_next, + "has_prev": self.has_prev, + "trigger_swap": list(self.trigger_swap) if self.trigger_swap else None, + "source_trigger": self.source_trigger, + "current_trigger": self.current_trigger, + "bookmarked_indices": sorted(self.bookmarked_indices), + "is_bookmarked": self.is_bookmarked(), + } + + def preview(self, context: int = 2, max_prompt_len: int = 0) -> dict[str, Any]: + """Get a preview window around current position. + + Args: + context: Number of prompts to show before/after current + max_prompt_len: Max length for prompts (0 = no truncation) + """ + if not self.prompts: + return {"prompts": [], "current_index": 0} + + start = max(0, self.current_index - context) + end = min(len(self.prompts), self.current_index + context + 1) + + items = [] + for i in range(start, end): + prompt = self.prompts[i] + # Only truncate if max_prompt_len is set + if max_prompt_len > 0 and len(prompt) > max_prompt_len: + prompt = prompt[:max_prompt_len - 3] + "..." + items.append({ + "index": i, + "prompt": prompt, + "current": i == self.current_index, + "bookmarked": i in self.bookmarked_indices, + }) + + return { + "prompts": items, + "current_index": self.current_index, + "total": self.total, + "bookmarked_indices": sorted(self.bookmarked_indices), + } + + def preview_bookmarks(self, max_prompt_len: int = 0) -> dict[str, Any]: + """Get all bookmarked prompts (for filtered view). + + Args: + max_prompt_len: Max length for prompts (0 = no truncation) + """ + if not self.prompts: + return {"prompts": [], "current_index": 0, "total": 0, "bookmarked_indices": []} + + items = [] + # Include all bookmarked indices plus current (even if not bookmarked) + indices_to_show = sorted(self.bookmarked_indices | {self.current_index}) + + for i in indices_to_show: + if i < 0 or i >= len(self.prompts): + continue + prompt = self.prompts[i] + if max_prompt_len > 0 and len(prompt) > max_prompt_len: + prompt = prompt[:max_prompt_len - 3] + "..." + items.append({ + "index": i, + "prompt": prompt, + "current": i == self.current_index, + "bookmarked": i in self.bookmarked_indices, + }) + + return { + "prompts": items, + "current_index": self.current_index, + "total": self.total, + "bookmarked_indices": sorted(self.bookmarked_indices), + } diff --git a/src/scope/realtime/style_manifest.py b/src/scope/realtime/style_manifest.py new file mode 100644 index 000000000..59031e25c --- /dev/null +++ b/src/scope/realtime/style_manifest.py @@ -0,0 +1,369 @@ +""" +StyleManifest - LoRA-specific vocabulary and metadata. + +A StyleManifest captures everything needed to translate abstract world concepts +into effective prompt tokens for a specific LoRA/style. 
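+
+Illustrative manifest.yaml (field names match StyleManifest below; the values
+are placeholders):
+
+    name: rankin_bass
+    description: Stop-motion puppet look
+    lora_path: rat_21_step5500.safetensors
+    lora_default_scale: 0.85
+    trigger_words: ["Rankin/Bass Animagic Stop-Motion"]
+    motion_vocab:
+      walk: "puppet walking with slight stop-motion jitter"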
+""" + +import logging +import os +from pathlib import Path +from typing import Any + +import yaml +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Style Directory Resolution +# ============================================================================= + + +def get_style_dirs() -> list[Path]: + """ + Return style directories in precedence order (later wins on name conflicts). + + Priority: + 1. SCOPE_STYLES_DIRS env var (os.pathsep-separated: ":" on Linux, ";" on Windows) + 2. Default: ./styles (repo built-ins), ~/.daydream-scope/styles (user overrides) + """ + if custom := os.environ.get("SCOPE_STYLES_DIRS"): + return [Path(p).expanduser().resolve() for p in custom.split(os.pathsep) if p] + + return [ + Path("styles").resolve(), # repo/dev built-ins + Path.home() / ".daydream-scope" / "styles", # user overrides + ] + + +# ============================================================================= +# LoRA Path Canonicalization +# ============================================================================= + + +def get_lora_dir() -> Path: + """Get the canonical LoRA directory path.""" + from scope.server.models_config import get_models_dir + + return get_models_dir() / "lora" + + +def canonicalize_lora_path(raw: str | None) -> str | None: + """ + Canonicalize a LoRA path for consistent matching. + + Resolution rules: + - None/empty → None + - Already absolute → resolve and return + - Bare filename (no /) → resolve under models/lora + - Relative path with / → resolve under models/lora + + This ensures the same canonical string is used for: + - Pipeline preload: loras=[{"path": , ...}] + - Runtime updates: lora_scales=[{"path": , ...}] + + Args: + raw: Raw path from manifest (e.g., "rat_21_step5500.safetensors") + + Returns: + Canonical absolute path string, or None if input was empty + """ + if not raw: + return None + + p = Path(raw).expanduser() + + # Already absolute + if p.is_absolute(): + return str(p.resolve()) + + # Relative path: resolve under lora directory + lora_dir = get_lora_dir() + return str((lora_dir / p).resolve()) + + +class StyleManifest(BaseModel): + """ + Metadata and vocabulary for a specific LoRA/style. + + The vocab dictionaries map abstract concepts to effective prompt tokens + discovered through experimentation with this specific LoRA. 
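+
+    Example lookup (illustrative; assumes motion_vocab has a "walk" entry and
+    no "default" entry):
+
+        manifest.get_vocab("motion", "walk")      # -> the mapped prompt tokens
+        manifest.get_vocab("motion", "moonwalk")  # -> "moonwalk" (falls back to the key)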
+ """ + + # Identity + name: str + description: str = "" + + # LoRA configuration + lora_path: str | None = None + lora_default_scale: float = 0.85 + + # Trigger words (always included in prompt) + trigger_words: list[str] = Field(default_factory=list) + + # Vocabulary mappings: abstract concept → effective tokens + # These are populated from your prompt experiments + material_vocab: dict[str, str] = Field(default_factory=dict) + motion_vocab: dict[str, str] = Field(default_factory=dict) + camera_vocab: dict[str, str] = Field(default_factory=dict) + lighting_vocab: dict[str, str] = Field(default_factory=dict) + emotion_vocab: dict[str, str] = Field(default_factory=dict) + beat_vocab: dict[str, str] = Field(default_factory=dict) + + # Custom vocab categories (extensible) + custom_vocab: dict[str, dict[str, str]] = Field(default_factory=dict) + + # Prompt constraints + default_negative: str = "" + max_prompt_tokens: int = 77 + + # Priority order for token budget allocation + # Earlier items are kept when truncating + priority_order: list[str] = Field( + default_factory=lambda: [ + "trigger", + "action", + "material", + "camera", + "mood", + ] + ) + + # Path to instruction sheet (markdown/text with LLM instructions) + instruction_sheet_path: str | None = None + + # Additional metadata + metadata: dict[str, Any] = Field(default_factory=dict) + + def get_vocab(self, category: str, key: str, default: str | None = None) -> str: + """ + Look up a vocab term by category and key. + + Args: + category: Vocab category (material, motion, camera, etc.) + key: The abstract term to look up + default: Fallback if not found + + Returns: + The effective prompt tokens, or default/key if not found + """ + vocab_dict = getattr(self, f"{category}_vocab", None) + if vocab_dict is None: + vocab_dict = self.custom_vocab.get(category, {}) + + result = vocab_dict.get(key) + if result is not None: + return result + + # Check for "default" key in vocab + if "default" in vocab_dict: + return vocab_dict["default"] + + return default if default is not None else key + + def get_all_vocab(self) -> dict[str, dict[str, str]]: + """Return all vocab dictionaries merged.""" + all_vocab = { + "material": self.material_vocab, + "motion": self.motion_vocab, + "camera": self.camera_vocab, + "lighting": self.lighting_vocab, + "emotion": self.emotion_vocab, + "beat": self.beat_vocab, + } + all_vocab.update(self.custom_vocab) + return {k: v for k, v in all_vocab.items() if v} + + @classmethod + def from_yaml(cls, path: str | Path) -> "StyleManifest": + """Load a StyleManifest from a YAML file.""" + path = Path(path) + with open(path) as f: + data = yaml.safe_load(f) + return cls(**data) + + def to_yaml(self, path: str | Path) -> None: + """Save this StyleManifest to a YAML file.""" + path = Path(path) + with open(path, "w") as f: + yaml.dump(self.model_dump(exclude_none=True), f, default_flow_style=False) + + +class StyleRegistry: + """ + Registry for loading and caching StyleManifests. 
+ + Manifests can be loaded from: + - Individual YAML files + - A directory of manifests + - Programmatic registration + """ + + def __init__(self): + self._manifests: dict[str, StyleManifest] = {} + self._default_style: str | None = None + + def register(self, manifest: StyleManifest) -> None: + """Register a manifest by name.""" + self._manifests[manifest.name] = manifest + if self._default_style is None: + self._default_style = manifest.name + + def load_from_file(self, path: str | Path) -> StyleManifest: + """Load and register a manifest from a YAML file.""" + manifest = StyleManifest.from_yaml(path) + self.register(manifest) + return manifest + + def load_from_directory(self, directory: str | Path) -> list[StyleManifest]: + """ + Load all manifest.yaml files from a directory tree. + + Manifests are loaded in sorted path order for deterministic behavior. + """ + directory = Path(directory) + if not directory.exists(): + return [] + + manifests = [] + # Sort for deterministic ordering + for manifest_path in sorted(directory.rglob("manifest.yaml")): + try: + manifest = self.load_from_file(manifest_path) + manifests.append(manifest) + except Exception as e: + # Log but don't fail on individual manifest errors + logger.warning(f"Failed to load {manifest_path}: {e}") + return manifests + + def load_from_style_dirs(self) -> list[StyleManifest]: + """ + Load styles from all configured style directories. + + Later directories win on name conflicts (user overrides repo). + """ + all_manifests = [] + for style_dir in get_style_dirs(): + manifests = self.load_from_directory(style_dir) + all_manifests.extend(manifests) + if manifests: + logger.info(f"Loaded {len(manifests)} styles from {style_dir}") + return all_manifests + + def get_all_lora_paths(self, skip_missing: bool = True) -> list[str]: + """ + Get canonical LoRA paths for all registered styles. + + Args: + skip_missing: If True, skip LoRAs that don't exist on disk (with warning) + + Returns: + List of unique canonical LoRA paths + """ + seen: dict[str, str] = {} # canonical_path -> style_name (for logging) + + for style_name in self.list_styles(): + manifest = self.get(style_name) + if not manifest or not manifest.lora_path: + continue + + canonical = canonicalize_lora_path(manifest.lora_path) + if not canonical: + continue + + # Check if file exists + if skip_missing and not Path(canonical).exists(): + logger.warning( + f"Style '{style_name}': LoRA not found at {canonical}, skipping" + ) + continue + + if canonical in seen: + logger.debug( + f"Style '{style_name}' shares LoRA with '{seen[canonical]}': {canonical}" + ) + else: + seen[canonical] = style_name + + return list(seen.keys()) + + def build_lora_scales_for_style( + self, active_style: str | None + ) -> list[dict[str, Any]]: + """ + Build lora_scales list for a style switch. + + Sets the active style's LoRA to its default scale, all others to 0.0. + Uses canonical paths for consistent matching with preloaded LoRAs. + + Args: + active_style: Name of the style to activate (None = all at 0.0) + + Returns: + List of {"path": , "scale": } dicts + """ + scales_by_path: dict[str, float] = {} + + # First, set all known style LoRA paths to 0.0 (deduped by canonical path). 
+ for style_name in self.list_styles(): + manifest = self.get(style_name) + if not manifest or not manifest.lora_path: + continue + + canonical = canonicalize_lora_path(manifest.lora_path) + if not canonical: + continue + + scales_by_path.setdefault(canonical, 0.0) + + # Then, ensure the active style wins even if multiple styles share a path. + if active_style: + manifest = self.get(active_style) + if manifest and manifest.lora_path: + canonical = canonicalize_lora_path(manifest.lora_path) + if canonical: + scales_by_path[canonical] = manifest.lora_default_scale + + return [{"path": p, "scale": s} for p, s in scales_by_path.items()] + + def get(self, name: str) -> StyleManifest | None: + """Get a manifest by name.""" + return self._manifests.get(name) + + def get_trigger_word(self, style_name: str) -> str | None: + """Get the primary trigger word for a style. + + Args: + style_name: Name of the style + + Returns: + First trigger word, or None if style not found or has no triggers + """ + manifest = self.get(style_name) + if manifest and manifest.trigger_words: + return manifest.trigger_words[0] + return None + + def get_default(self) -> StyleManifest | None: + """Get the default style manifest.""" + if self._default_style: + return self._manifests.get(self._default_style) + return None + + def set_default(self, name: str) -> None: + """Set the default style by name.""" + if name not in self._manifests: + raise ValueError(f"Style '{name}' not found in registry") + self._default_style = name + + def list_styles(self) -> list[str]: + """List all registered style names.""" + return list(self._manifests.keys()) + + def __contains__(self, name: str) -> bool: + return name in self._manifests + + def __len__(self) -> int: + return len(self._manifests) diff --git a/src/scope/realtime/world_state.py b/src/scope/realtime/world_state.py new file mode 100644 index 000000000..6f1c2c199 --- /dev/null +++ b/src/scope/realtime/world_state.py @@ -0,0 +1,232 @@ +""" +WorldState - Domain-agnostic representation of what's happening. + +WorldState captures the "truth" of the scene without any style/LoRA knowledge. +It's the input to prompt compilation, not the output. +""" + +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field + + +class BeatType(str, Enum): + """Narrative beat types that affect pacing and framing.""" + + SETUP = "setup" + ESCALATION = "escalation" + CLIMAX = "climax" + PAYOFF = "payoff" + RESET = "reset" + TRANSITION = "transition" + + +class CameraIntent(str, Enum): + """Abstract camera intentions (style-agnostic).""" + + ESTABLISHING = "establishing" + CLOSE_UP = "close_up" + MEDIUM = "medium" + WIDE = "wide" + LOW_ANGLE = "low_angle" + HIGH_ANGLE = "high_angle" + TRACKING = "tracking" + STATIC = "static" + + +class CharacterState(BaseModel): + """Internal state of a character.""" + + name: str + emotion: str = "neutral" + action: str = "idle" + intensity: float = Field(default=0.5, ge=0.0, le=1.0) + + # What the character knows/believes (for future narrative logic) + knowledge: dict[str, Any] = Field(default_factory=dict) + + # Relationships to other characters + relationships: dict[str, str] = Field(default_factory=dict) + + +class PropState(BaseModel): + """State of a prop in the scene.""" + + name: str + location: str = "" + visible: bool = True + material: str | None = None + state: str = "normal" # e.g., "broken", "glowing", "wet" + + +class WorldState(BaseModel): + """ + Complete state of the world at a moment in time. 
+ + This is style-agnostic - it describes WHAT is happening, + not HOW it should be rendered. The StyleManifest + PromptCompiler + translate this into style-specific prompt tokens. + """ + + # Scene context + scene_description: str = "" + location: str = "" + time_of_day: str = "" + weather: str = "" + + # Narrative state + beat: BeatType = BeatType.SETUP + tension: float = Field(default=0.5, ge=0.0, le=1.0) + pacing: float = Field(default=0.5, ge=0.0, le=1.0) # 0=slow, 1=frantic + + # Camera + camera: CameraIntent = CameraIntent.MEDIUM + focus_target: str = "" # What/who the camera focuses on + + # Characters + characters: list[CharacterState] = Field(default_factory=list) + + # Props + props: list[PropState] = Field(default_factory=list) + + # Current action/event description + action: str = "" + + # Abstract mood/atmosphere (0-1 scales) + mood: dict[str, float] = Field(default_factory=dict) + # e.g., {"comedy": 0.8, "tension": 0.2, "warmth": 0.6} + + # Free-form tags for additional context + tags: list[str] = Field(default_factory=list) + + # Custom fields for domain-specific data + custom: dict[str, Any] = Field(default_factory=dict) + + def get_character(self, name: str) -> CharacterState | None: + """Get a character by name.""" + for char in self.characters: + if char.name == name: + return char + return None + + def get_prop(self, name: str) -> PropState | None: + """Get a prop by name.""" + for prop in self.props: + if prop.name == name: + return prop + return None + + def is_empty(self) -> bool: + """Check if this WorldState has no meaningful content. + + An empty WorldState means no scene has been described - just defaults. + In performance mode, style switches should preserve the current prompt + rather than recompiling an empty WorldState to just trigger words. + """ + return ( + not self.scene_description + and not self.action + and not self.characters + and not self.focus_target + and not self.location + and not self.tags + and not self.mood + and not self.custom + ) + + def get_mood(self, key: str, default: float = 0.5) -> float: + """Get a mood value.""" + return self.mood.get(key, default) + + def set_mood(self, key: str, value: float) -> None: + """Set a mood value (clamped to 0-1).""" + self.mood[key] = max(0.0, min(1.0, value)) + + def to_context_dict(self) -> dict[str, Any]: + """ + Convert to a flat dictionary suitable for LLM context. + + This is what gets passed to the PromptCompiler/LLM. 
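+
+        Illustrative result for a minimal one-character scene (values are
+        examples; keys mirror the code below):
+
+            {
+                "location": "kitchen", "time_of_day": "night",
+                "beat": "setup", "camera": "close_up", "action": "sneaking",
+                "char_0_name": "fox", "char_0_emotion": "nervous",
+                "char_0_action": "sneaking", "char_0_intensity": 0.7,
+                ...
+            }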
+ """ + context = { + "scene": self.scene_description, + "location": self.location, + "time_of_day": self.time_of_day, + "weather": self.weather, + "beat": self.beat.value, + "tension": self.tension, + "pacing": self.pacing, + "camera": self.camera.value, + "focus": self.focus_target, + "action": self.action, + "tags": self.tags, + } + + # Add characters + for i, char in enumerate(self.characters): + prefix = f"char_{i}" + context[f"{prefix}_name"] = char.name + context[f"{prefix}_emotion"] = char.emotion + context[f"{prefix}_action"] = char.action + context[f"{prefix}_intensity"] = char.intensity + + # Add props + for i, prop in enumerate(self.props): + prefix = f"prop_{i}" + context[f"{prefix}_name"] = prop.name + context[f"{prefix}_location"] = prop.location + context[f"{prefix}_visible"] = prop.visible + context[f"{prefix}_state"] = prop.state + if prop.material: + context[f"{prefix}_material"] = prop.material + + # Add mood values + for key, value in self.mood.items(): + context[f"mood_{key}"] = value + + # Add custom fields + context.update(self.custom) + + return context + + +# Convenience factory functions + + +def create_simple_world( + action: str, + emotion: str = "neutral", + camera: CameraIntent = CameraIntent.MEDIUM, + tension: float = 0.5, +) -> WorldState: + """Create a simple WorldState with just the basics.""" + return WorldState( + action=action, + camera=camera, + tension=tension, + characters=[CharacterState(name="subject", emotion=emotion, action=action)], + ) + + +def create_character_scene( + character_name: str, + emotion: str, + action: str, + location: str = "", + beat: BeatType = BeatType.SETUP, +) -> WorldState: + """Create a WorldState focused on a single character.""" + return WorldState( + location=location, + beat=beat, + action=action, + focus_target=character_name, + characters=[ + CharacterState( + name=character_name, + emotion=emotion, + action=action, + ) + ], + ) diff --git a/src/scope/server/frame_processor.py b/src/scope/server/frame_processor.py index ca02c0869..47beab178 100644 --- a/src/scope/server/frame_processor.py +++ b/src/scope/server/frame_processor.py @@ -1,18 +1,48 @@ +import asyncio +import copy import logging +import os import queue import threading import time -from collections import deque +import uuid +from collections import OrderedDict, deque +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path from typing import Any +import cv2 +import numpy as np import torch -from aiortc.mediastreams import VideoFrame +from scope.realtime import ( + CompiledPrompt, + StyleManifest, + StyleRegistry, + TemplateCompiler, + WorldState, + create_compiler, +) + +try: + from aiortc.mediastreams import VideoFrame +except ImportError: # pragma: no cover + VideoFrame = Any # type: ignore[misc,assignment] + +from scope.realtime.control_bus import ControlBus, EventType + +from .models_config import get_model_file_path from .pipeline_manager import PipelineManager, PipelineNotAvailableException +from .session_recorder import SessionRecorder logger = logging.getLogger(__name__) +def _is_env_true(name: str, default: str = "0") -> bool: + return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "on") + + # Multiply the # of output frames from pipeline by this to get the max size of the output queue OUTPUT_QUEUE_MAX_SIZE_FACTOR = 3 @@ -26,6 +56,750 @@ INPUT_FPS_SAMPLE_SIZE = 30 # Number of frame intervals to track INPUT_FPS_MIN_SAMPLES = 5 # Minimum samples needed before using input FPS +# Snapshot 
constants +MAX_SNAPSHOTS = 10 # Maximum number of snapshots to keep (LRU eviction) + +# Continuity keys from pipeline.state that define generation continuity +CONTINUITY_KEYS = [ + "current_start_frame", + "first_context_frame", + "context_frame_buffer", + "decoded_frame_buffer", + "context_frame_buffer_max_size", + "decoded_frame_buffer_max_size", +] + + +# VACE control map modes +VACE_CONTROL_MAP_MODES = ["none", "canny", "pidinet", "depth", "composite", "external"] + +# Control-map preview worker policy: +# - "canny" is CPU-only and usually cheap enough to run at input FPS for preview. +# - "pidinet" / "depth" / "composite" can be GPU-heavy; running them concurrently with +# generation can materially reduce end-to-end FPS. +VACE_CONTROL_MAP_WORKER_HEAVY_MODES = {"pidinet", "depth", "composite"} + + +def apply_canny_edges( + frames: list[torch.Tensor], + low_threshold: int | None = None, + high_threshold: int | None = None, + blur_kernel_size: int = 5, + blur_sigma: float = 1.4, + adaptive_thresholds: bool = True, + dilate_edges: bool = False, + dilate_kernel_size: int = 2, +) -> list[torch.Tensor]: + """Apply Canny edge detection to video frames for VACE control. + + Args: + frames: List of tensors, each (1, H, W, C) `uint8` or float in [0, 255]. + low_threshold: Canny low threshold. If None and adaptive_thresholds=True, + computed as 0.66 * median pixel value. + high_threshold: Canny high threshold. If None and adaptive_thresholds=True, + computed as 1.33 * median pixel value. + blur_kernel_size: Gaussian blur kernel size (must be odd). Set to 0 to disable. + blur_sigma: Gaussian blur sigma. Higher = more blur. + adaptive_thresholds: If True and thresholds are None, compute thresholds + based on image statistics (median-based). Produces cleaner edges. + dilate_edges: If True, dilate edges to make them thicker/more visible. + dilate_kernel_size: Size of dilation kernel. + + Returns: + List of edge tensors, each (1, H, W, 3) uint8 in [0, 255]. 
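+
+    Example (illustrative; frames is a list of (1, H, W, 3) uint8 RGB tensors):
+
+        edges = apply_canny_edges(frames)  # adaptive (median-based) thresholds
+        edges = apply_canny_edges(frames, 100, 200, adaptive_thresholds=False)  # fixed thresholds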
+ """ + result = [] + + # Prepare dilation kernel if needed + dilate_kernel = None + if dilate_edges and dilate_kernel_size > 0: + dilate_kernel = np.ones((dilate_kernel_size, dilate_kernel_size), np.uint8) + + for frame in frames: + # frame is (1, H, W, C) uint8 or float in [0, 255] + img_t = frame.squeeze(0) + if img_t.dtype == torch.uint8: + img = img_t.numpy() + else: + img = img_t.clamp(0, 255).to(dtype=torch.uint8).numpy() + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + + # Apply Gaussian blur to reduce noise (improves edge quality significantly) + if blur_kernel_size > 0: + # Ensure kernel size is odd + k = blur_kernel_size if blur_kernel_size % 2 == 1 else blur_kernel_size + 1 + gray = cv2.GaussianBlur(gray, (k, k), blur_sigma) + + # Compute adaptive thresholds if not provided + if adaptive_thresholds and (low_threshold is None or high_threshold is None): + # Median-based adaptive thresholds (common technique for Canny) + median_val = np.median(gray) + low_t = int(max(0, 0.66 * median_val)) if low_threshold is None else low_threshold + high_t = int(min(255, 1.33 * median_val)) if high_threshold is None else high_threshold + else: + # Use provided thresholds or defaults + low_t = low_threshold if low_threshold is not None else 100 + high_t = high_threshold if high_threshold is not None else 200 + + edges = cv2.Canny(gray, low_t, high_t) + + # Optional dilation to thicken edges + if dilate_kernel is not None: + edges = cv2.dilate(edges, dilate_kernel, iterations=1) + + # Convert back to 3-channel RGB + edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB) + result.append(torch.from_numpy(edges_rgb).unsqueeze(0)) + return result + + +def soft_max_fusion( + depth: torch.Tensor, + edges: torch.Tensor, + edge_strength: float = 0.6, + sharpness: float = 10.0, +) -> torch.Tensor: + """Fuse depth and edges using soft max for smooth transitions. + + Soft max avoids hard discontinuities at transition boundaries. + Edges "punch through" where strong, depth dominates elsewhere. + + Args: + depth: Depth tensor normalized to [0, 1], shape (H, W) or (1, H, W, C). + edges: Edge tensor normalized to [0, 1], shape (H, W) or (1, H, W, C). + edge_strength: Scale factor for edges (0.5-0.7 recommended). + sharpness: Controls transition sharpness (higher = sharper). + + Returns: + Fused tensor in [0, 1], same shape as input. + """ + # Inputs should already be normalized, but clamp defensively. + depth = torch.clamp(depth, 0.0, 1.0) + edges = torch.clamp(edges, 0.0, 1.0) + + edge_strength = max(0.0, float(edge_strength)) + sharpness = float(sharpness) + if sharpness <= 0: + return torch.clamp(depth, 0.0, 1.0) + + scaled_edges = edges * edge_strength + + # Soft max (stable): logaddexp(a, b) / sharpness. + d_scaled = torch.clamp(sharpness * depth, -20, 20) + e_scaled = torch.clamp(sharpness * scaled_edges, -20, 20) + fused = torch.logaddexp(d_scaled, e_scaled) / sharpness + + # Stabilize to [0, 1] without per-frame min/max normalization. + # For a=b=0, fused=log(2)/sharpness; for depth=1 and edges=1, fused reaches a fixed max. + fused_min = float(np.log(2.0) / sharpness) + fused_max = float(np.logaddexp(sharpness, sharpness * edge_strength) / sharpness) + denom = fused_max - fused_min + if denom <= 1e-6: + return torch.clamp(fused, 0.0, 1.0) + return torch.clamp((fused - fused_min) / denom, 0.0, 1.0) + + +class VDADepthControlMapGenerator: + """Video Depth Anything streaming depth generator for VACE control maps. + + This class owns the VDA model and its streaming cache state. 
It should be + owned by a FrameProcessor (per-session) and reset on hard cuts. + + The model is lazy-loaded on first use to avoid GPU memory allocation + until depth mode is actually enabled. + """ + + # Model configs from VDA repo + MODEL_CONFIGS = { + "vits": { + "encoder": "vits", + "features": 64, + "out_channels": [48, 96, 192, 384], + }, + "vitb": { + "encoder": "vitb", + "features": 128, + "out_channels": [96, 192, 384, 768], + }, + "vitl": { + "encoder": "vitl", + "features": 256, + "out_channels": [256, 512, 1024, 1024], + }, + } + + def __init__( + self, + encoder: str = "vits", + checkpoint_path: str | Path | None = None, + device: str = "cuda", + input_size: int = 518, + ): + self.encoder = encoder + self.checkpoint_path = Path( + checkpoint_path + if checkpoint_path is not None + else get_model_file_path("vda/video_depth_anything_vits.pth") + ) + self.device = device + self.input_size = input_size + + self._model = None + self._model_loaded = False + + # Stabilization state (running quantiles for normalization) + self._q_lo: float | torch.Tensor | None = None + self._q_hi: float | torch.Tensor | None = None + self._quantile_momentum = 0.95 # EMA momentum for quantile updates + + # Contrast adjustment: gamma curve applied after normalization. + # Values > 1.0 increase mid-tone contrast (emphasize subtle depth differences). + # Values < 1.0 reduce contrast. Default 1.0 = no change. + self._depth_contrast: float = 1.0 + + @property + def depth_contrast(self) -> float: + """Get current depth contrast (gamma) value.""" + return self._depth_contrast + + @depth_contrast.setter + def depth_contrast(self, value: float) -> None: + """Set depth contrast (gamma). Values > 1.0 increase contrast.""" + self._depth_contrast = max(0.1, min(5.0, float(value))) # Clamp to sane range + + def _load_model(self): + """Lazy-load the VDA model on first use.""" + if self._model_loaded: + return + + if not self.checkpoint_path.exists(): + raise FileNotFoundError( + f"VDA checkpoint not found: {self.checkpoint_path} " + "(expected under DAYDREAM_SCOPE_MODELS_DIR)." 
+ ) + + try: + from scope.vendored.video_depth_anything import VideoDepthAnything + + config = self.MODEL_CONFIGS[self.encoder] + self._model = VideoDepthAnything(**config) + self._model.load_state_dict( + torch.load(self.checkpoint_path, map_location="cpu"), + strict=True, + ) + self._model = self._model.to(self.device).eval() + + if _is_env_true("SCOPE_VDA_COMPILE", default="0"): + fullgraph = _is_env_true("SCOPE_VDA_TORCH_COMPILE_FULLGRAPH", default="0") + dynamic = _is_env_true("SCOPE_VDA_TORCH_COMPILE_DYNAMIC", default="0") + compiled_any = False + try: + self._model.forward_features = torch.compile( + self._model.forward_features, + fullgraph=fullgraph, + dynamic=dynamic, + ) + compiled_any = True + except Exception as e: + logger.warning("VDA torch.compile forward_features failed: %s", e) + + try: + self._model.forward_depth = torch.compile( + self._model.forward_depth, + fullgraph=fullgraph, + dynamic=dynamic, + ) + compiled_any = True + except Exception as e: + logger.warning("VDA torch.compile forward_depth failed: %s", e) + + if compiled_any: + logger.info( + "VDA torch.compile enabled (fullgraph=%s dynamic=%s)", + fullgraph, + dynamic, + ) + + self._model_loaded = True + logger.info( + "VDA model loaded: encoder=%s device=%s checkpoint=%s", + self.encoder, + self.device, + self.checkpoint_path, + ) + except Exception as e: + logger.error(f"Failed to load VDA model: {e}") + raise + + def _reset_model_temporal_cache(self) -> None: + """Reset only the VDA streaming cache (not normalization state).""" + if self._model is None: + return + # Reset VDA streaming cache + self._model.transform = None + self._model.frame_id_list = [] + self._model.frame_cache_list = [] + self._model.id = -1 + + def reset_cache(self): + """Reset streaming cache and stabilization state. + + Call this on hard cuts (when init_cache=True) to prevent blending + depth across discontinuities. + """ + self._reset_model_temporal_cache() + + # Reset stabilization state + self._q_lo = None + self._q_hi = None + logger.debug("VDA depth cache reset") + + def process_frames( + self, + frames: list[torch.Tensor], + hard_cut: bool = False, + *, + input_size: int | None = None, + fp32: bool | None = None, + temporal_mode: str | None = None, + output_device: str = "cpu", + ) -> list[torch.Tensor]: + """Process frames through VDA and return depth maps as control frames. + + Args: + frames: List of tensors, each (1, H, W, C) uint8 or float in [0, 255]. + hard_cut: If True, reset cache before processing. + input_size: Override VDA resize target (lower is faster). If changed mid-stream, + the generator treats it like a hard cut and resets streaming state. + fp32: If True, force FP32 (disables autocast). Default is False (autocast enabled). + temporal_mode: Controls whether VDA uses its streaming temporal cache. + - "stream" (default): uses temporal cache (more stable, can trail/ghost on fast motion) + - "stateless": disables temporal cache (no trails, can be noisier) + output_device: Device for returned control frames ("cpu" or "cuda"). Default: "cpu". + + Returns: + List of depth tensors, each (1, H, W, 3) uint8 in [0, 255]. + Depth is normalized per-session using running quantiles. + """ + self._load_model() + + if input_size is not None: + input_size_int = int(input_size) + if input_size_int <= 0: + raise ValueError(f"input_size must be > 0, got {input_size_int}") + # VDA's transform is initialized only on the first frame; changing the + # input_size mid-stream requires a reset so the transform + caches are consistent. 
+ if input_size_int != self.input_size: + self.input_size = input_size_int + hard_cut = True + + fp32_flag = bool(fp32) if fp32 is not None else False + + temporal_mode_env = os.getenv("SCOPE_VACE_DEPTH_TEMPORAL_MODE", "").strip() + temporal_mode_norm = ( + (temporal_mode if temporal_mode is not None else temporal_mode_env) or "stream" + ).strip().lower() + if temporal_mode_norm in ("stream", "streaming", "temporal", "1", "true", "yes", "on"): + use_temporal_cache = True + elif temporal_mode_norm in ( + "stateless", + "single", + "single_frame", + "no_cache", + "0", + "false", + "no", + "off", + ): + use_temporal_cache = False + else: + logger.warning( + "Invalid vace_depth_temporal_mode=%r; expected 'stream' or 'stateless' (falling back to 'stream')", + temporal_mode_norm, + ) + use_temporal_cache = True + + if hard_cut: + self.reset_cache() + + output_device_norm = (output_device or "cpu").strip().lower() + output_on_gpu = output_device_norm not in ("", "cpu") + + result = [] + quantile_stride = int(os.getenv("SCOPE_VACE_DEPTH_QUANTILE_STRIDE", "4") or "4") + if quantile_stride < 1: + quantile_stride = 1 + for frame in frames: + # frame is (1, H, W, C) uint8 or float in [0, 255] + # VDA expects RGB numpy array (H, W, 3) uint8 + img_t = frame.squeeze(0) + if img_t.dtype == torch.uint8: + img = img_t.numpy() + else: + img = img_t.clamp(0, 255).to(dtype=torch.uint8).numpy() + + # Infer depth (H, W): numpy array by default, or a torch tensor when output_on_gpu. + try: + with torch.no_grad(): + if not use_temporal_cache: + # Stateless mode: force VDA to treat every frame as a "first frame" by + # resetting its temporal cache before each inference call. + self._reset_model_temporal_cache() + depth = self._model.infer_video_depth_one( + img, + input_size=self.input_size, + device=self.device, + fp32=fp32_flag, + return_torch=output_on_gpu, + ) + except AssertionError: + # VDA streaming asserts frame size consistency; treat size changes + # as a hard cut and re-run as "first frame". + self.reset_cache() + with torch.no_grad(): + if not use_temporal_cache: + self._reset_model_temporal_cache() + depth = self._model.infer_video_depth_one( + img, + input_size=self.input_size, + device=self.device, + fp32=fp32_flag, + return_torch=output_on_gpu, + ) + + if not output_on_gpu: + # ------------------------------- + # CPU output path (legacy) + # ------------------------------- + depth_np = depth + + # Stabilize using running quantiles (avoid per-frame normalization) + depth_sample = depth_np + if quantile_stride > 1: + depth_sample = depth_np[::quantile_stride, ::quantile_stride] + q_lo_frame, q_hi_frame = np.percentile(depth_sample, [2, 98]) + + # Initialize quantiles on first frame to avoid startup saturation. 
+ if self._q_lo is None or self._q_hi is None: + self._q_lo = float(q_lo_frame) + self._q_hi = float(q_hi_frame) + else: + q_lo_prev = ( + float(self._q_lo.detach().float().cpu().item()) + if isinstance(self._q_lo, torch.Tensor) + else float(self._q_lo) + ) + q_hi_prev = ( + float(self._q_hi.detach().float().cpu().item()) + if isinstance(self._q_hi, torch.Tensor) + else float(self._q_hi) + ) + self._q_lo = ( + self._quantile_momentum * q_lo_prev + + (1 - self._quantile_momentum) * float(q_lo_frame) + ) + self._q_hi = ( + self._quantile_momentum * q_hi_prev + + (1 - self._quantile_momentum) * float(q_hi_frame) + ) + + # Normalize to [0, 1] using running quantiles + q_lo = float(self._q_lo) + q_hi = float(self._q_hi) + depth_range = max(q_hi - q_lo, 1e-6) + depth_norm = np.clip((depth_np - q_lo) / depth_range, 0.0, 1.0) + + # Apply contrast (gamma) curve: values > 1.0 increase mid-tone contrast + if self._depth_contrast != 1.0: + depth_norm = np.power(depth_norm, 1.0 / self._depth_contrast) + + # Convert to [0, 255] and replicate to 3 channels + depth_uint8 = (depth_norm * 255.0).astype(np.uint8) + depth_rgb = np.repeat(depth_uint8[:, :, None], 3, axis=2) + + result.append(torch.from_numpy(depth_rgb).unsqueeze(0)) + continue + + # ------------------------------- + # GPU output path + # ------------------------------- + if not isinstance(depth, torch.Tensor): + raise TypeError( + f"Expected torch depth when output_device={output_device!r}, got: {type(depth)}" + ) + + depth_t = depth + if output_device_norm not in ("cuda", "gpu") and not output_device_norm.startswith( + "cuda:" + ): + # Only "cuda*" devices are supported for output_on_gpu (caller requested non-cpu). + raise ValueError( + f"Unsupported output_device={output_device!r}; expected 'cpu' or 'cuda'" + ) + + use_gpu_quantiles = os.getenv("SCOPE_VACE_DEPTH_GPU_QUANTILES", "").lower() in ( + "1", + "true", + "yes", + "on", + ) + if use_gpu_quantiles: + # Keep the normalization path on GPU to avoid the `.cpu().numpy()` + # sync point when output_device=cuda. + depth_sample_t = depth_t + if quantile_stride > 1: + depth_sample_t = depth_sample_t[::quantile_stride, ::quantile_stride] + + sample_flat = depth_sample_t.detach().to(dtype=torch.float32).flatten() + q_lo_frame_t = torch.quantile(sample_flat, 0.02) + q_hi_frame_t = torch.quantile(sample_flat, 0.98) + + # Ensure cached quantiles are torch scalars on the right device. + if self._q_lo is None: + self._q_lo = q_lo_frame_t + elif not isinstance(self._q_lo, torch.Tensor): + self._q_lo = torch.tensor( + float(self._q_lo), + device=depth_t.device, + dtype=torch.float32, + ) + + if self._q_hi is None: + self._q_hi = q_hi_frame_t + elif not isinstance(self._q_hi, torch.Tensor): + self._q_hi = torch.tensor( + float(self._q_hi), + device=depth_t.device, + dtype=torch.float32, + ) + + self._q_lo = ( + self._quantile_momentum * self._q_lo + + (1 - self._quantile_momentum) * q_lo_frame_t + ) + self._q_hi = ( + self._quantile_momentum * self._q_hi + + (1 - self._quantile_momentum) * q_hi_frame_t + ) + + depth_range_t = (self._q_hi - self._q_lo).clamp_min(1e-6) + depth_norm_t = ((depth_t - self._q_lo) / depth_range_t).clamp(0.0, 1.0) + else: + # Stabilize quantiles on CPU using a downsampled grid to minimize sync cost. 
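+                # With the default quantile_stride of 4, only ~1/16 of the depth pixels
+                # are copied to the host for np.percentile, keeping the sync point cheap.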
+ depth_sample_t = depth_t + if quantile_stride > 1: + depth_sample_t = depth_sample_t[::quantile_stride, ::quantile_stride] + q_lo_frame, q_hi_frame = np.percentile( + depth_sample_t.detach().float().cpu().numpy(), [2, 98] + ) + + if self._q_lo is None or self._q_hi is None: + self._q_lo = float(q_lo_frame) + self._q_hi = float(q_hi_frame) + else: + q_lo_prev = ( + float(self._q_lo.detach().float().cpu().item()) + if isinstance(self._q_lo, torch.Tensor) + else float(self._q_lo) + ) + q_hi_prev = ( + float(self._q_hi.detach().float().cpu().item()) + if isinstance(self._q_hi, torch.Tensor) + else float(self._q_hi) + ) + self._q_lo = ( + self._quantile_momentum * q_lo_prev + + (1 - self._quantile_momentum) * float(q_lo_frame) + ) + self._q_hi = ( + self._quantile_momentum * q_hi_prev + + (1 - self._quantile_momentum) * float(q_hi_frame) + ) + + q_lo = float(self._q_lo) + q_hi = float(self._q_hi) + depth_range = max(q_hi - q_lo, 1e-6) + + depth_norm_t = ((depth_t - q_lo) / depth_range).clamp(0.0, 1.0) + + # Apply contrast (gamma) curve: values > 1.0 increase mid-tone contrast + if self._depth_contrast != 1.0: + depth_norm_t = torch.pow(depth_norm_t, 1.0 / self._depth_contrast) + + depth_uint8_t = (depth_norm_t * 255.0).to(dtype=torch.uint8) + depth_rgb_t = depth_uint8_t.unsqueeze(-1).repeat(1, 1, 3) + + result.append(depth_rgb_t.unsqueeze(0)) + + return result + + +class PiDiNetEdgeGenerator: + """PiDiNet neural edge detector for high-quality VACE control maps. + + Uses the PiDiNet model from controlnet_aux for learned edge detection + that produces cleaner, more semantically meaningful edges than Canny. + + Requires: controlnet_aux. If missing, run `uv sync` or `pip install controlnet_aux`. + """ + + def __init__( + self, + device: str = "cuda", + safe_mode: bool = True, + ): + """Initialize PiDiNet edge generator. + + Args: + device: Device to run model on ("cuda" or "cpu"). + safe_mode: If True, use safe/cleaner edge detection mode. + """ + self.device = device + self.safe_mode = safe_mode + self._model = None + self._model_loaded = False + + def _load_model(self): + """Lazy-load the PiDiNet model on first use.""" + if self._model_loaded: + return + + try: + from controlnet_aux import PidiNetDetector + + self._model = PidiNetDetector.from_pretrained("lllyasviel/Annotators") + # Move to device if possible (some models support .to()) + if hasattr(self._model, "to"): + self._model = self._model.to(self.device) + self._model_loaded = True + logger.info( + "PiDiNet model loaded: device=%s safe_mode=%s", + self.device, + self.safe_mode, + ) + except ImportError as e: + raise ImportError( + "controlnet_aux not installed. Install with: pip install controlnet_aux " + "(or run: uv sync)." + ) from e + except Exception as e: + logger.error(f"Failed to load PiDiNet model: {e}") + raise + + def process_frames( + self, frames: list[torch.Tensor], apply_filter: bool = True + ) -> list[torch.Tensor]: + """Process frames through PiDiNet and return edge maps. + + Args: + frames: List of tensors, each (1, H, W, C) `uint8` or float in [0, 255]. + apply_filter: If True, apply post-processing filter for cleaner edges. + + Returns: + List of edge tensors, each (1, H, W, 3) float in [0, 255]. 
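+
+        Example (illustrative only; requires controlnet_aux to be installed):
+
+            gen = PiDiNetEdgeGenerator(device="cuda", safe_mode=True)
+            frame = torch.zeros(1, 480, 640, 3, dtype=torch.uint8)  # (1, H, W, C)
+            edges = gen.process_frames([frame])[0]  # (1, 480, 640, 3) float in [0, 255]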
+ """ + self._load_model() + + from PIL import Image + + result = [] + for frame in frames: + # frame is (1, H, W, C) uint8 or float in [0, 255] + img_t = frame.squeeze(0) + if img_t.dtype == torch.uint8: + img = img_t.numpy() + else: + img = img_t.clamp(0, 255).to(dtype=torch.uint8).numpy() + + # Convert to PIL Image for controlnet_aux + pil_img = Image.fromarray(img) + + # Run PiDiNet detection + edge_pil = self._model( + pil_img, + detect_resolution=min(pil_img.width, pil_img.height), + image_resolution=min(pil_img.width, pil_img.height), + safe=self.safe_mode, + apply_filter=apply_filter, + ) + + # Convert back to numpy/tensor + edge_np = np.array(edge_pil) + + # Ensure 3-channel output + if len(edge_np.shape) == 2: + edge_np = cv2.cvtColor(edge_np, cv2.COLOR_GRAY2RGB) + elif edge_np.shape[2] == 4: + edge_np = cv2.cvtColor(edge_np, cv2.COLOR_RGBA2RGB) + + # Ensure output matches input resolution (required for composite mode). + if edge_np.shape[0] != img.shape[0] or edge_np.shape[1] != img.shape[1]: + edge_np = cv2.resize( + edge_np, + (img.shape[1], img.shape[0]), + interpolation=cv2.INTER_LINEAR, + ) + + result.append(torch.from_numpy(edge_np).unsqueeze(0).float()) + + return result + + +@dataclass +class Snapshot: + """Server-side snapshot of generation state at a chunk boundary. + + Snapshots are stored in-memory and contain cloned GPU tensors. + Clients receive only snapshot_id + metadata, not the actual tensor data. + """ + + snapshot_id: str + chunk_index: int + created_at: float + + # Continuity state (cloned tensors from pipeline.state) + current_start_frame: int = 0 + first_context_frame: torch.Tensor | None = None + context_frame_buffer: torch.Tensor | None = None + decoded_frame_buffer: torch.Tensor | None = None + context_frame_buffer_max_size: int = 0 + decoded_frame_buffer_max_size: int = 0 + + # Control state (deep copy of parameters) + parameters: dict[str, Any] = field(default_factory=dict) + paused: bool = False + video_mode: bool = False + + # Style layer state (minimal, deterministic) + world_state_json: str | None = None # JSON string for thread-safe restore + active_style_name: str | None = None # Used for edge-triggering LoRA updates + compiled_prompt_text: str | None = None # For debugging + + # Compatibility metadata (for future validation) + pipeline_id: str | None = None + resolution: tuple[int, int] | None = None + + +class _FrameWithID: + """Attach a monotonically increasing frame_id to any input frame-like object. + + The wrapped object must expose: to_ndarray(format="rgb24") -> np.ndarray + Phase 2.1b: Used for frame ID tracking in control buffer architecture. + """ + + __slots__ = ["frame_id", "_frame"] + + def __init__(self, frame: Any, frame_id: int): + self._frame = frame + self.frame_id = int(frame_id) + + def to_ndarray(self, format: str = "rgb24"): + return self._frame.to_ndarray(format=format) + + def __getattr__(self, name: str) -> Any: + # Delegate any other access (pts, time_base, etc.) to the underlying frame. + return getattr(self._frame, name) + class _SpoutFrame: """Lightweight wrapper for Spout frames to match VideoFrame interface.""" @@ -39,6 +813,482 @@ def to_ndarray(self, format="rgb24"): return self._data +class _NDIFrame: + """Lightweight wrapper for NDI frames to match VideoFrame interface.""" + + __slots__ = ["_data"] + + def __init__(self, data): + self._data = data + + def to_ndarray(self, format="rgb24"): + return self._data + + +class ControlMapWorker: + """Background worker that generates control maps from raw frames. 
+ + This worker runs independently of chunk processing, receiving raw frames and + producing control maps continuously. + + - Phase 2.1a: feeds the MJPEG preview stream at input FPS. + - Phase 2.1b (when control buffer is enabled): also writes per-frame control + maps into a ring buffer keyed by `frame_id`, so generation can sample + precomputed control maps instead of doing chunk-time control-map compute. + + Note: In `vace_control_map_mode="external"`, incoming frames are already + control maps and are not re-processed by this worker. + """ + + def __init__( + self, + latest_control_frame_lock: threading.Lock, + parameters_getter: callable, + max_queue_size: int = 60, + control_buffer_size: int = 120, + ): + """Initialize control map worker. + + Args: + latest_control_frame_lock: Lock for updating latest_control_frame_cpu. + parameters_getter: Callable that returns current parameters dict. + max_queue_size: Max frames to buffer (small = low latency). + control_buffer_size: Ring buffer size (frames) for generation sampling by frame_id. + """ + self._lock = latest_control_frame_lock + self._get_params = parameters_getter + self._input_queue: queue.Queue = queue.Queue(maxsize=max_queue_size) + self._latest_frame: torch.Tensor | None = None + + # Phase 2.1b: generation ring buffer keyed by frame_id + self._control_buffer_maxlen = int(control_buffer_size) + self._control_buffer: OrderedDict[int, torch.Tensor] = OrderedDict() + self._control_buffer_lock = threading.Lock() + self._last_frame_id: int | None = None + + self._thread: threading.Thread | None = None + self._shutdown = threading.Event() + self._running = False + + # Own generator instances (not shared with chunk processor) + self._depth_generator: VDADepthControlMapGenerator | None = None + self._pidinet_generator: PiDiNetEdgeGenerator | None = None + + # Stats + self._frames_processed = 0 + self._last_mode: str | None = None + self._dropped_input_frames = 0 + + # Hard cut request (applied on worker thread to avoid cross-thread generator mutation) + self._hard_cut_requested = False + + # Throttle: cap processing rate to reduce GPU contention. + # Disabled by default (max_fps <= 0). + self._max_fps: float | None = None + self._min_process_interval_s = 0.0 + self._last_process_t = 0.0 + + # Phase 2.1b: worker is the primary producer for generation when buffer enabled, + # so heavy modes should be allowed by default when buffer is on. + # When buffer disabled: default allow_heavy=0 (old behavior, preview-only) + # When buffer enabled: default allow_heavy=1 (worker feeds generation) + # Explicit SCOPE_VACE_CONTROL_MAP_WORKER_ALLOW_HEAVY always wins. 
+ buffer_enabled = _is_env_true("SCOPE_VACE_CONTROL_BUFFER_ENABLED", default="0") + allow_heavy_default = "1" if buffer_enabled else "0" + self._allow_heavy = _is_env_true( + "SCOPE_VACE_CONTROL_MAP_WORKER_ALLOW_HEAVY", default=allow_heavy_default + ) + + def set_max_fps(self, max_fps: float | None) -> None: + if max_fps is None: + self._max_fps = None + self._min_process_interval_s = 0.0 + return + + max_fps_f = float(max_fps) + if max_fps_f <= 0: + self._max_fps = None + self._min_process_interval_s = 0.0 + return + + self._max_fps = max_fps_f + self._min_process_interval_s = 1.0 / max_fps_f + + def set_allow_heavy(self, allow_heavy: bool) -> None: + """Enable/disable GPU-heavy control-map modes in the worker (depth/pidinet/composite).""" + self._allow_heavy = bool(allow_heavy) + + def clear_latest(self, *, lock_held: bool = False) -> None: + """Clear cached preview output so callers fall back to chunk outputs.""" + if lock_held: + self._latest_frame = None + return + with self._lock: + self._latest_frame = None + + def clear_generation_buffer(self) -> None: + """Clear the generation ring buffer (Phase 2.1b).""" + with self._control_buffer_lock: + self._control_buffer.clear() + self._last_frame_id = None + + def _drain_input_queue(self) -> int: + """Best-effort drain of pending frames (used on hard cuts).""" + drained = 0 + while True: + try: + item = self._input_queue.get_nowait() + if item is None: + continue + drained += 1 + except queue.Empty: + break + return drained + + def request_hard_cut(self, *, clear_queue: bool = True, reason: str | None = None) -> None: + """Request a hard cut: + - clears generation buffer + preview immediately, + - optionally drains queued frames, + - resets VDA streaming caches on the worker thread before the next processed frame. 
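+
+        Typical callers in this module: `reset_cache()` requests a hard cut with
+        `clear_queue=True`, while the worker's mode-change path uses `clear_queue=False`.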
+ """ + _ = reason # reserved for future debug logging + self._hard_cut_requested = True + self.clear_latest() + self.clear_generation_buffer() + if clear_queue: + self._drain_input_queue() + + def get_debug_info(self) -> dict[str, object]: + try: + queue_depth = int(self._input_queue.qsize()) + except Exception: + queue_depth = -1 + + with self._control_buffer_lock: + buffer_depth = len(self._control_buffer) + last_frame_id = self._last_frame_id + + return { + "running": bool(self._running), + "last_mode": self._last_mode, + "frames_processed": int(self._frames_processed), + "dropped_input_frames": int(self._dropped_input_frames), + "queue_depth": queue_depth, + "control_buffer_depth": int(buffer_depth), + "last_frame_id": int(last_frame_id) if last_frame_id is not None else None, + "max_fps": float(self._max_fps) if self._max_fps is not None else None, + "allow_heavy": bool(self._allow_heavy), + "heavy_modes": sorted(VACE_CONTROL_MAP_WORKER_HEAVY_MODES), + } + + def start(self): + """Start the control map worker thread.""" + if self._running: + return + self._shutdown.clear() + self._running = True + self._thread = threading.Thread( + target=self._worker_loop, name="ControlMapWorker", daemon=True + ) + self._thread.start() + logger.info("ControlMapWorker started") + + def stop(self): + """Stop the control map worker thread.""" + if not self._running: + return + self._running = False + self._shutdown.set() + # Unblock queue.get() by putting sentinel + try: + self._input_queue.put_nowait(None) + except queue.Full: + pass + if self._thread: + self._thread.join(timeout=2.0) + self._thread = None + # Clear generators to free GPU memory + self._depth_generator = None + self._pidinet_generator = None + self.clear_generation_buffer() + logger.info(f"ControlMapWorker stopped after {self._frames_processed} frames") + + def put(self, frame) -> bool: + """Enqueue a raw frame for control map processing. + + Non-blocking: drops frame if queue is full (preview can skip frames). + + Args: + frame: VideoFrame, _SpoutFrame, or _FrameWithID with to_ndarray() method. + + Returns: + True if queued, False if dropped. + """ + if not self._running: + return False + try: + self._input_queue.put_nowait(frame) + return True + except queue.Full: + self._dropped_input_frames += 1 + # Note: We do NOT trigger hard cut on single drop (thrashing risk). + # Future: could add threshold-based hard cut if drops are sustained. + return False + + def get_latest(self) -> torch.Tensor | None: + """Get the most recent control frame. + + Returns (H, W, 3) uint8 tensor or None. + """ + with self._lock: + if self._latest_frame is not None: + return self._latest_frame.clone() + return None + + def reset_cache(self): + """Reset streaming caches (call on hard cuts). Backward-compatible alias.""" + self.request_hard_cut(clear_queue=True, reason="reset_cache") + + def sample_control_frames(self, frame_ids: list[int]) -> list[torch.Tensor] | None: + """Sample generation control frames by exact frame_id (Phase 2.1b). + + Returns: + List of (1, H, W, 3) tensors aligned to frame_ids, or None if any missing. + """ + if not frame_ids: + return [] + with self._control_buffer_lock: + # Exact-match only (missing policy is handled by FrameProcessor). 
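+            # Illustrative call: sample_control_frames([10, 11, 12]) returns three
+            # aligned clones, or None if any of those ids has not been produced yet
+            # or has already been evicted from the ring buffer.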
+ for fid in frame_ids: + if fid not in self._control_buffer: + return None + return [self._control_buffer[fid].clone() for fid in frame_ids] + + def _worker_loop(self): + """Main worker loop: process frames continuously.""" + while self._running and not self._shutdown.is_set(): + try: + # Block with timeout so we can check shutdown + frame = self._input_queue.get(timeout=0.1) + if frame is None: # Sentinel + continue + + # Apply pending hard cut on worker thread (safe point) + if self._hard_cut_requested: + self._hard_cut_requested = False + if self._depth_generator is not None: + self._depth_generator.reset_cache() + self.clear_generation_buffer() + self.clear_latest() + logger.debug("ControlMapWorker hard cut applied") + + # Get current mode from parameters + params = self._get_params() + mode = params.get("vace_control_map_mode", "none") + + if mode != self._last_mode: + logger.info( + "ControlMapWorker mode changed: %s -> %s", self._last_mode, mode + ) + self._last_mode = mode + # Prevent stale preview frames when switching modes (or disabling control maps). + self.clear_latest() + # Also clear generation buffer and reset streaming caches. + self.request_hard_cut(clear_queue=False, reason="mode_change") + + # Skip when control maps are disabled or passthrough-only. + # "external" mode means frames are already control maps and should not be + # re-processed by the worker. + if mode in ("none", "external"): + continue + + # Default: avoid duplicating GPU-heavy annotators in the preview worker. + # Chunk processing will still generate control frames at chunk rate. + if (not self._allow_heavy) and mode in VACE_CONTROL_MAP_WORKER_HEAVY_MODES: + continue + + if self._min_process_interval_s > 0: + now = time.perf_counter() + if now - self._last_process_t < self._min_process_interval_s: + continue + + # Process frame + gen_frame, preview_frame = self._process_frame(frame, mode, params) + if gen_frame is not None and preview_frame is not None: + with self._lock: + self._latest_frame = preview_frame + + # Phase 2.1b: write to generation ring buffer keyed by frame_id + frame_id = getattr(frame, "frame_id", None) + if frame_id is not None: + with self._control_buffer_lock: + self._control_buffer[int(frame_id)] = gen_frame.detach().cpu() + self._control_buffer.move_to_end(int(frame_id)) + while len(self._control_buffer) > self._control_buffer_maxlen: + self._control_buffer.popitem(last=False) + self._last_frame_id = int(frame_id) + + self._frames_processed += 1 + self._last_process_t = time.perf_counter() + + except queue.Empty: + continue + except Exception as e: + logger.error(f"ControlMapWorker error: {e}", exc_info=True) + time.sleep(0.01) + + def _process_frame( + self, frame, mode: str, params: dict + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[None, None]: + """Process a single frame through the appropriate control map generator. + + Args: + frame: VideoFrame, _SpoutFrame, or _FrameWithID. + mode: Control map mode ("canny", "pidinet", "depth", "composite"). + params: Current parameters dict. + + Returns: + Tuple of (gen_frame, preview_frame): + - gen_frame: (1, H, W, 3) tensor for generation (dtype depends on mode) + - preview_frame: (H, W, 3) uint8 tensor for MJPEG preview + Or (None, None) on error. 
+ """ + try: + # Extract RGB data from frame + rgb = frame.to_ndarray(format="rgb24") + # Create tensor in format expected by generators: (1, H, W, C) uint8 [0, 255] + frame_tensor = torch.from_numpy(rgb).unsqueeze(0) + frames = [frame_tensor] + + if mode == "canny": + low = params.get("vace_canny_low_threshold") + high = params.get("vace_canny_high_threshold") + blur_kernel = params.get("vace_canny_blur_kernel", 5) + blur_sigma = params.get("vace_canny_blur_sigma", 1.4) + adaptive = params.get("vace_canny_adaptive", True) + dilate = params.get("vace_canny_dilate", False) + dilate_size = params.get("vace_canny_dilate_size", 2) + + control_frames = apply_canny_edges( + frames, + low_threshold=low, + high_threshold=high, + blur_kernel_size=blur_kernel, + blur_sigma=blur_sigma, + adaptive_thresholds=adaptive, + dilate_edges=dilate, + dilate_kernel_size=dilate_size, + ) + + elif mode == "pidinet": + if self._pidinet_generator is None: + self._pidinet_generator = PiDiNetEdgeGenerator() + safe_mode = params.get("vace_pidinet_safe", True) + apply_filter = params.get("vace_pidinet_filter", True) + self._pidinet_generator.safe_mode = safe_mode + control_frames = self._pidinet_generator.process_frames( + frames, apply_filter=apply_filter + ) + + elif mode == "depth": + if self._depth_generator is None: + self._depth_generator = VDADepthControlMapGenerator() + # Note: hard_cut should be signaled via reset_cache() externally + depth_input_size = params.get("vace_depth_input_size") + depth_fp32 = params.get("vace_depth_fp32") + depth_temporal_mode = params.get("vace_depth_temporal_mode") + depth_contrast = params.get("vace_depth_contrast") + if depth_contrast is not None: + self._depth_generator.depth_contrast = depth_contrast + control_frames = self._depth_generator.process_frames( + frames, + hard_cut=False, + input_size=depth_input_size, + fp32=depth_fp32, + temporal_mode=depth_temporal_mode, + ) + + elif mode == "composite": + # Composite mode: depth + edges fused with soft max + if self._depth_generator is None: + self._depth_generator = VDADepthControlMapGenerator() + + # Get composite parameters + edge_strength = params.get("composite_edge_strength", 0.6) + edge_thickness = params.get("composite_edge_thickness", 8) + sharpness = params.get("composite_sharpness", 10.0) + edge_source = params.get("composite_edge_source", "canny") + + # Generate depth + depth_input_size = params.get("vace_depth_input_size") + depth_fp32 = params.get("vace_depth_fp32") + depth_temporal_mode = params.get("vace_depth_temporal_mode") + depth_contrast = params.get("vace_depth_contrast") + if depth_contrast is not None: + self._depth_generator.depth_contrast = depth_contrast + depth_frames = self._depth_generator.process_frames( + frames, + hard_cut=False, + input_size=depth_input_size, + fp32=depth_fp32, + temporal_mode=depth_temporal_mode, + ) + + # Generate edges based on source + if edge_source == "pidinet": + if self._pidinet_generator is None: + self._pidinet_generator = PiDiNetEdgeGenerator() + safe_mode = params.get("vace_pidinet_safe", True) + apply_filter = params.get("vace_pidinet_filter", True) + self._pidinet_generator.safe_mode = safe_mode + edge_frames = self._pidinet_generator.process_frames( + frames, apply_filter=apply_filter + ) + else: + # Default to canny with dilation for thickness + edge_frames = apply_canny_edges( + frames, + adaptive_thresholds=True, + dilate_edges=True, + dilate_kernel_size=edge_thickness, + ) + + # Fuse depth + edges + depth_f = depth_frames[0] # (1, H, W, 3) float [0, 255] + 
edge_f = edge_frames[0] + + # Normalize to [0, 1] using first channel (grayscale) + depth_norm = depth_f[:, :, :, 0] / 255.0 + edge_norm = edge_f[:, :, :, 0] / 255.0 + + # Soft max fusion + fused = soft_max_fusion( + depth_norm, + edge_norm, + edge_strength=edge_strength, + sharpness=sharpness, + ) + + # Convert back to (1, H, W, 3) float [0, 255] + fused_uint8 = (fused * 255.0).clamp(0, 255) + fused_rgb = fused_uint8.unsqueeze(-1).expand(-1, -1, -1, 3) + control_frames = [fused_rgb] + + else: + return None, None + + # control_frames[-1] is (1, H, W, 3) (uint8 for canny/depth; float for pidinet/composite) + gen_frame = control_frames[-1] + preview = gen_frame.squeeze(0) + if preview.dtype != torch.uint8: + preview = preview.clamp(0, 255).to(torch.uint8) + return gen_frame, preview + + except Exception as e: + logger.error(f"ControlMapWorker._process_frame error: {e}", exc_info=True) + return None, None + + class FrameProcessor: def __init__( self, @@ -51,9 +1301,95 @@ def __init__( ): self.pipeline_manager = pipeline_manager + output_queue_env = ( + os.getenv("SCOPE_OUTPUT_QUEUE_MAX_FRAMES", "").strip() + or os.getenv("SCOPE_OUTPUT_QUEUE_SIZE", "").strip() + ) + output_queue_max_frames: int | None = None + if output_queue_env: + try: + output_queue_max_frames = max(1, int(output_queue_env)) + max_output_queue_size = output_queue_max_frames + except ValueError: + logger.warning( + "Invalid SCOPE_OUTPUT_QUEUE_MAX_FRAMES/SCOPE_OUTPUT_QUEUE_SIZE=%r; expected int", + output_queue_env, + ) + self.frame_buffer = deque(maxlen=max_buffer_size) self.frame_buffer_lock = threading.Lock() self.output_queue = queue.Queue(maxsize=max_output_queue_size) + self.output_queue_lock = threading.Lock() # Protects queue resize and flush + + # Output queue drop counter (when consumer can't keep up). + # Written by worker thread, read by debug endpoint. + self.output_frames_dropped = 0 + + # Low-latency mode: drop old input frames to reduce lag + # Opt-in via SCOPE_LOW_LATENCY_INPUT=1 env var or parameter + self._low_latency_mode = ( + os.environ.get("SCOPE_LOW_LATENCY_INPUT", "0") == "1" + or (initial_parameters or {}).get("low_latency_input", False) + ) + # Low-latency output: prefer newest generated frames over smooth playback. + # When enabled, the output queue drops the oldest frames to keep latency bounded. + self._low_latency_output_mode = ( + os.environ.get("SCOPE_LOW_LATENCY_OUTPUT", "0") == "1" + or (initial_parameters or {}).get("low_latency_output", False) + ) + # Optional cap on output queue max size (prevents auto-resize from increasing latency). + self._output_queue_maxsize_cap: int | None = ( + output_queue_max_frames if output_queue_max_frames is not None else None + ) + self._low_latency_buffer_factor = 2 # Keep chunk_size * factor frames max + self.input_frames_dropped = 0 # Counter for dropped input frames + + # Non-destructive latest frame buffer for REST /api/frame/latest + self.latest_frame_cpu: torch.Tensor | None = None + self.latest_frame_lock = threading.Lock() + # Monotonic version counter for latest_frame_cpu updates (observer pacing / wait). 
+ self.latest_frame_id = 0 + self._latest_frame_event = asyncio.Event() + try: + self._latest_frame_event_loop: asyncio.AbstractEventLoop | None = asyncio.get_running_loop() + except RuntimeError: + self._latest_frame_event_loop = None + + # Latest control frame for VACE preview (MJPEG streaming) + self.latest_control_frame_cpu: torch.Tensor | None = None + self.latest_control_frame_lock = threading.Lock() + + # Phase 2.1b: Frame ID tracking for control buffer alignment + self._frame_id_lock = threading.Lock() + self._next_frame_id = 0 + + # Phase 2.1b: Feature flag to enable control buffer generation path + # Default OFF for safety - enables scaffolding without changing behavior + self._control_buffer_enabled = _is_env_true( + "SCOPE_VACE_CONTROL_BUFFER_ENABLED", default="0" + ) + + # Control map worker for high-FPS preview (Phase 2.1a) and generation (Phase 2.1b) + self._control_map_worker = ControlMapWorker( + latest_control_frame_lock=self.latest_control_frame_lock, + parameters_getter=lambda: self.parameters, + max_queue_size=int( + os.getenv("SCOPE_VACE_CONTROL_MAP_WORKER_QUEUE_SIZE", "60") or "60" + ), + control_buffer_size=int( + os.getenv("SCOPE_VACE_CONTROL_BUFFER_MAXLEN", "120") or "120" + ), + ) + + # VDA depth generator for VACE control maps (lazy-loaded) + self._depth_generator: VDADepthControlMapGenerator | None = None + + # PiDiNet neural edge generator for VACE control maps (lazy-loaded) + self._pidinet_generator: PiDiNetEdgeGenerator | None = None + + # Temporal EMA state for control map smoothing + self._prev_control_frames: list[torch.Tensor] | None = None + self._prev_control_map_mode: str | None = None # Current parameters used by processing thread self.parameters = initial_parameters or {} @@ -88,6 +1424,45 @@ def __init__( self.paused = False + # Control bus for deterministic event ordering at chunk boundaries + self.control_bus = ControlBus() + self.chunk_index = 0 + + # Style layer: WorldState + StyleManifest + TemplateCompiler + self.world_state: WorldState = WorldState() + self.style_manifest: StyleManifest | None = None + self.style_registry: StyleRegistry = StyleRegistry() + self.prompt_compiler: TemplateCompiler = TemplateCompiler() + self._compiled_prompt: CompiledPrompt | None = None + self._active_style_name: str | None = None # For edge-triggering LoRA updates + + # Step mode: allow generating N chunks even while paused. + # Stored on the worker thread for deterministic semantics. + self._pending_steps = 0 + + # Soft transition state (temporary KV cache bias adjustment) + self._soft_transition_active: bool = False + self._soft_transition_chunks_remaining: int = 0 + self._soft_transition_temp_bias: float | None = None + self._soft_transition_original_bias: float | None = None + self._soft_transition_original_bias_was_set: bool = False + # Soft transition recording latch: record softCut once at the next generated chunk. 
+ self._soft_transition_record_pending: bool = False + + # Style switch behavior: when True, reset cache on style change (clean transition) + # When False, allow blend artifacts during style transitions (artistic effect) + self.reset_cache_on_style_switch: bool = True + + # Session recorder (server-side timeline export) + self.session_recorder = SessionRecorder() + self._last_recording_path: Path | None = None + + # Snapshot store (server-side, in-memory) + # Keys are snapshot_id, values are Snapshot objects with cloned tensors + self.snapshots: dict[str, Snapshot] = {} + self.snapshot_order: list[str] = [] # For LRU eviction (oldest first) + self.snapshot_response_callback: callable | None = None + # Spout integration self.spout_sender = None self.spout_sender_enabled = False @@ -104,10 +1479,104 @@ def __init__( self.spout_receiver_name = "" self.spout_receiver_thread = None + # NDI input + self.ndi_receiver = None + self.ndi_receiver_enabled = False + self.ndi_receiver_source = "" + self.ndi_receiver_extra_ips: list[str] | None = None + self.ndi_receiver_thread = None + + # NDI stats (exposed via debug/status endpoints) + self.ndi_frames_received: int = 0 + self.ndi_frames_dropped: int = 0 + self.ndi_last_frame_ts_s: float = 0.0 + self.ndi_frames_reused: int = 0 + self.ndi_connected_source: str | None = None + self.ndi_connected_url: str | None = None + self.ndi_reconnects: int = 0 + + # "Hold last input frame" for NDI external/passthrough control mode. + # This decouples generation from NDI jitter / lower FPS producers. + self._ndi_hold_last_input_frame: torch.Tensor | None = None + self._ndi_hold_last_input_frame_id: int | None = None + self.external_input_stale: bool = False + self._external_resume_hard_cut_pending: bool = False + + # Input source exclusivity (Phase 0 for external video inputs) + # Prevent mixing WebRTC + Spout (+ future NDI) frames in the same buffer. + self._input_source_lock = threading.Lock() + self._active_input_source: str = "webrtc" # "webrtc" | "spout" | "ndi" + # Input mode is signaled by the frontend at stream start. # This determines whether we wait for video frames or generate immediately. self._video_mode = (initial_parameters or {}).get("input_mode") == "video" + # Hard cut: if reset_cache is requested while waiting for video input, + # flush once (no log spam) but keep reset_cache pending until applied. + self._hard_cut_flushed_pending = False + + def _get_current_effective_prompt(self) -> tuple[str | None, float]: + """Best-effort extraction of the current pipeline-facing prompt. 
+ + Precedence: + 1) transition.target_prompts[0] + 2) parameters["prompts"][0] + 3) pipeline.state["prompts"][0] (fallback) + 4) style layer compiled prompt (multiple shapes) + """ + transition = self.parameters.get("transition") + if isinstance(transition, dict): + targets = transition.get("target_prompts") + if isinstance(targets, list) and targets: + first = targets[0] + if isinstance(first, dict): + return first.get("text"), float(first.get("weight", 1.0)) + + prompts = self.parameters.get("prompts") + if isinstance(prompts, list) and prompts: + first = prompts[0] + if isinstance(first, dict): + return first.get("text"), float(first.get("weight", 1.0)) + if hasattr(first, "text"): + return getattr(first, "text", None), float(getattr(first, "weight", 1.0)) + + pipeline = None + try: + pipeline = self.pipeline_manager.get_pipeline() + except Exception: + pipeline = None + if pipeline is not None and hasattr(pipeline, "state"): + state = getattr(pipeline, "state", None) + state_prompts = None + if hasattr(state, "get"): + state_prompts = state.get("prompts") + if isinstance(state_prompts, list) and state_prompts: + first = state_prompts[0] + if isinstance(first, dict): + return first.get("text"), float(first.get("weight", 1.0)) + + compiled = getattr(self, "_compiled_prompt", None) + if compiled is not None: + cps = getattr(compiled, "prompts", None) + if isinstance(cps, list) and cps: + first = cps[0] + if hasattr(first, "text"): + return getattr(first, "text", None), float(getattr(first, "weight", 1.0)) + if isinstance(first, dict): + return first.get("text"), float(first.get("weight", 1.0)) + + pos = getattr(compiled, "positive", None) + if isinstance(pos, list) and pos: + first = pos[0] + if isinstance(first, dict): + return first.get("text"), float(first.get("weight", 1.0)) + + prompt_str = getattr(compiled, "prompt", None) + if isinstance(prompt_str, str) and prompt_str.strip(): + return prompt_str, 1.0 + + return None, 1.0 + def start(self): if self.running: return @@ -124,9 +1593,83 @@ def start(self): spout_config = self.parameters.pop("spout_receiver") self._update_spout_receiver(spout_config) + # Load style manifests from styles/ directory + try: + self.style_registry.load_from_style_dirs() + if len(self.style_registry) > 0: + logger.info( + "Loaded %d styles: %s", + len(self.style_registry), + self.style_registry.list_styles(), + ) + except Exception as e: + logger.warning("Failed to load styles from style dirs: %s", e) + + if default_style := os.getenv("STYLE_DEFAULT"): + style_name = default_style.strip() + if style_name and self.style_registry.get(style_name): + logger.info("STYLE_DEFAULT=%s: applying initial style", style_name) + self.update_parameters({"_rcp_set_style": style_name}) + elif style_name: + logger.warning( + "STYLE_DEFAULT=%s: style not found in registry, ignoring", + style_name, + ) + self.worker_thread = threading.Thread(target=self.worker_loop, daemon=True) self.worker_thread.start() + # Start control map worker for high-FPS preview (Phase 2.1a) + enable_control_map_worker = _is_env_true( + "SCOPE_VACE_CONTROL_MAP_WORKER", default="1" + ) + auto_disable_above_pixels = os.getenv( + "SCOPE_VACE_CONTROL_MAP_WORKER_AUTO_DISABLE_ABOVE_PIXELS", "" + ).strip() + if enable_control_map_worker and auto_disable_above_pixels: + try: + threshold_pixels = int(auto_disable_above_pixels) + except ValueError: + logger.warning( + "Invalid SCOPE_VACE_CONTROL_MAP_WORKER_AUTO_DISABLE_ABOVE_PIXELS=%r; expected int", + auto_disable_above_pixels, + ) + else: + if 
threshold_pixels > 0: + width, height = self._get_pipeline_dimensions() + pixels = int(width) * int(height) + if pixels >= threshold_pixels: + enable_control_map_worker = False + logger.info( + "ControlMapWorker auto-disabled for high-res: %dx%d (%d px) >= %d px " + "(SCOPE_VACE_CONTROL_MAP_WORKER_AUTO_DISABLE_ABOVE_PIXELS)", + width, + height, + pixels, + threshold_pixels, + ) + + max_fps_env = os.getenv("SCOPE_VACE_CONTROL_MAP_WORKER_MAX_FPS", "").strip() + if max_fps_env: + try: + max_fps = float(max_fps_env) + except ValueError: + logger.warning( + "Invalid SCOPE_VACE_CONTROL_MAP_WORKER_MAX_FPS=%r; expected float", + max_fps_env, + ) + else: + self._control_map_worker.set_max_fps(max_fps) + logger.info( + "ControlMapWorker max_fps=%.2f (SCOPE_VACE_CONTROL_MAP_WORKER_MAX_FPS)", + max_fps, + ) + + if enable_control_map_worker: + self._control_map_worker.start() + else: + logger.info("ControlMapWorker disabled") + logger.info("FrameProcessor started") def stop(self, error_message: str = None): @@ -141,15 +1684,14 @@ def stop(self, error_message: str = None): if threading.current_thread() != self.worker_thread: self.worker_thread.join(timeout=5.0) - while not self.output_queue.empty(): - try: - self.output_queue.get_nowait() - except queue.Empty: - break + self.flush_output_queue() with self.frame_buffer_lock: self.frame_buffer.clear() + # Stop control map worker (Phase 2.1a) + self._control_map_worker.stop() + # Clean up Spout sender self.spout_sender_enabled = False if self.spout_sender_thread and self.spout_sender_thread.is_alive(): @@ -168,6 +1710,10 @@ def stop(self, error_message: str = None): # Clean up Spout receiver self.spout_receiver_enabled = False + if self.spout_receiver_thread and self.spout_receiver_thread.is_alive(): + if threading.current_thread() != self.spout_receiver_thread: + self.spout_receiver_thread.join(timeout=2.0) + self.spout_receiver_thread = None if self.spout_receiver is not None: try: self.spout_receiver.release() @@ -175,32 +1721,192 @@ def stop(self, error_message: str = None): logger.error(f"Error releasing Spout receiver: {e}") self.spout_receiver = None + # Clean up NDI receiver + self.ndi_receiver_enabled = False + if self.ndi_receiver_thread and self.ndi_receiver_thread.is_alive(): + if threading.current_thread() != self.ndi_receiver_thread: + self.ndi_receiver_thread.join(timeout=2.0) + self.ndi_receiver_thread = None + # Receiver resources are released by the receiver thread (see _ndi_receiver_loop()). + self.ndi_receiver = None + if self.get_active_input_source() == "ndi": + self._set_active_input_source("webrtc") + # Clear input frame times with self.input_fps_lock: self.input_frame_times.clear() - logger.info("FrameProcessor stopped") + logger.info("FrameProcessor stopped") + + # Notify callback that frame processor has stopped + if self.notification_callback: + try: + message = {"type": "stream_stopped"} + if error_message: + message["error_message"] = error_message + self.notification_callback(message) + except Exception as e: + logger.error(f"Error in frame processor stop callback: {e}") + + def put(self, frame: VideoFrame) -> bool: + if not self.running: + return False + + if self.get_active_input_source() != "webrtc": + # Ignore WebRTC input when another input source is active (e.g. Spout/NDI). + # Do not call track_input_frame(): input_fps should reflect the active source. 
+ return False + + # Track input frame timestamp for FPS measurement + self.track_input_frame() + + # Phase 2.1b: Assign a stable monotonic frame_id + with self._frame_id_lock: + frame_id = self._next_frame_id + self._next_frame_id += 1 + + wrapped = _FrameWithID(frame, frame_id) + + # Enqueue to control map worker (preview + generation ring buffer) + # Non-blocking: worker may drop frames if it falls behind. + self._control_map_worker.put(wrapped) + + with self.frame_buffer_lock: + self.frame_buffer.append(wrapped) + return True + + def get_active_input_source(self) -> str: + with self._input_source_lock: + return self._active_input_source + + def _set_active_input_source(self, source: str) -> None: + normalized = (source or "").strip().lower() + if normalized not in ("webrtc", "spout", "ndi"): + logger.warning("Ignoring unknown input source %r", source) + return + with self._input_source_lock: + self._active_input_source = normalized + + def flush_output_queue(self) -> int: + """Flush all frames from output queue. + + Thread-safe: uses output_queue_lock to prevent race with queue resize. + + Returns: + Number of frames flushed + """ + count = 0 + with self.output_queue_lock: + while True: + try: + self.output_queue.get_nowait() + count += 1 + except queue.Empty: + break + return count + + def get_latest_frame(self) -> torch.Tensor | None: + """Get the most recent frame without consuming from output queue. + + Returns a clone of the latest frame, or None if no frames produced yet. + Thread-safe: uses latest_frame_lock. + """ + with self.latest_frame_lock: + if self.latest_frame_cpu is not None: + return self.latest_frame_cpu.clone() + return None + + def set_latest_frame_event_loop(self, loop: asyncio.AbstractEventLoop) -> None: + """Set the asyncio loop used to signal latest-frame updates. + + FrameProcessor output is produced on a worker thread. Observers can await + `wait_for_frame()` without polling by having the worker thread signal an + asyncio.Event via `loop.call_soon_threadsafe(...)`. + """ + self._latest_frame_event_loop = loop + + def _signal_latest_frame_available(self) -> None: + """Best-effort signal to wake any `wait_for_frame()` awaiters.""" + loop = self._latest_frame_event_loop + if loop is None: + return + try: + loop.call_soon_threadsafe(self._latest_frame_event.set) + except RuntimeError: + # Loop may be closed during shutdown; observers will fall back to polling/timeouts. + self._latest_frame_event_loop = None + + async def wait_for_frame( + self, + after_id: int, + *, + timeout: float = 0.1, + ) -> tuple[torch.Tensor | None, int]: + """Wait until latest_frame_id advances beyond `after_id`, or timeout. + + Returns: + (latest_frame_clone_or_none, latest_frame_id) + """ + if self._latest_frame_event_loop is None: + try: + self._latest_frame_event_loop = asyncio.get_running_loop() + except RuntimeError: + self._latest_frame_event_loop = None + + deadline = time.monotonic() + max(0.0, float(timeout)) + while True: + with self.latest_frame_lock: + current_id = int(self.latest_frame_id) + if current_id > int(after_id): + frame = ( + self.latest_frame_cpu.clone() if self.latest_frame_cpu is not None else None + ) + return frame, current_id + + remaining = deadline - time.monotonic() + if remaining <= 0: + with self.latest_frame_lock: + current_id = int(self.latest_frame_id) + frame = ( + self.latest_frame_cpu.clone() if self.latest_frame_cpu is not None else None + ) + return frame, current_id + + # Avoid missing a signal by clearing + re-checking before awaiting. 
+ self._latest_frame_event.clear() + with self.latest_frame_lock: + current_id = int(self.latest_frame_id) + if current_id > int(after_id): + continue - # Notify callback that frame processor has stopped - if self.notification_callback: try: - message = {"type": "stream_stopped"} - if error_message: - message["error_message"] = error_message - self.notification_callback(message) - except Exception as e: - logger.error(f"Error in frame processor stop callback: {e}") + await asyncio.wait_for(self._latest_frame_event.wait(), timeout=remaining) + except asyncio.TimeoutError: + # Loop back and return best-effort latest frame. + continue - def put(self, frame: VideoFrame) -> bool: - if not self.running: - return False + def get_latest_control_frame(self) -> torch.Tensor | None: + """Get the most recent VACE control frame for preview. - # Track input frame timestamp for FPS measurement - self.track_input_frame() + Returns a clone of the latest control frame (H, W, 3) uint8, + or None if no control frames generated yet. - with self.frame_buffer_lock: - self.frame_buffer.append(frame) - return True + Phase 2.1a: Prefers worker output (high-FPS) over chunk output. + Falls back to chunk output if worker hasn't produced anything yet. + """ + if (self.parameters.get("vace_control_map_mode") or "none") == "none": + return None + + # Prefer worker output (higher FPS) + worker_frame = self._control_map_worker.get_latest() + if worker_frame is not None: + return worker_frame + + # Fallback to chunk-generated output + with self.latest_control_frame_lock: + if self.latest_control_frame_cpu is not None: + return self.latest_control_frame_cpu.clone() + return None def get(self) -> torch.Tensor | None: if not self.running: @@ -239,11 +1945,145 @@ def get_output_fps(self) -> float: input_fps = self._get_input_fps() pipeline_fps = self.get_current_pipeline_fps() - if input_fps is None: - return pipeline_fps + # In external control mode, input frames are control signals (not a "video clock"). + # Allow output to run at pipeline FPS, reusing control maps as needed. + if self._ndi_external_hold_last_enabled() and self._ndi_hold_last_input_frame is not None: + base_fps = pipeline_fps + elif input_fps is None: + base_fps = pipeline_fps + else: + # Use minimum to respect both input rate and pipeline capacity + base_fps = min(input_fps, pipeline_fps) + + pacing_fps = self._get_output_pacing_fps() + if pacing_fps is not None: + return min(base_fps, pacing_fps) + + return base_fps + + def get_fps_debug(self) -> dict[str, object]: + """Return a debug snapshot of the FPS signals used for WebRTC pacing. + + This is intended for diagnosing cases where the end-to-end server FPS + differs from pipeline-only measurements. 
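+
+        Example (abbreviated; values are illustrative only):
+
+            {
+                "input_fps": 30.0,
+                "pipeline_fps": 23.0,
+                "output_fps": 23.0,
+                "output_fps_bottleneck": "pipeline_fps",
+                ...
+            }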
+ """ + with self.input_fps_lock: + input_samples = len(self.input_frame_times) + input_fps = ( + float(self.current_input_fps) + if input_samples >= INPUT_FPS_MIN_SAMPLES + else None + ) + + pipeline_fps = float(self.get_current_pipeline_fps()) + output_pacing_fps = self._get_output_pacing_fps() + + external_timebase = ( + self._ndi_external_hold_last_enabled() and self._ndi_hold_last_input_frame is not None + ) + + base_output_fps = pipeline_fps + if not external_timebase and input_fps is not None: + base_output_fps = float(min(input_fps, pipeline_fps)) + + output_fps = float(self.get_output_fps()) + + bottleneck = "pipeline_fps" + if output_pacing_fps is not None and float(output_pacing_fps) <= float(base_output_fps): + bottleneck = "output_pacing_fps" + elif external_timebase: + bottleneck = "pipeline_fps" + elif input_fps is None: + bottleneck = "pipeline_fps" + elif input_fps <= pipeline_fps: + bottleneck = "input_fps" + else: + bottleneck = "pipeline_fps" + + with self.output_queue_lock: + output_queue_depth = int(self.output_queue.qsize()) + output_queue_max = int(self.output_queue.maxsize) + + with self.frame_buffer_lock: + frame_buffer_depth = int(len(self.frame_buffer)) + + ndi_last_frame_age_ms: float | None = None + if self.ndi_last_frame_ts_s > 0: + ndi_last_frame_age_ms = (time.monotonic() - float(self.ndi_last_frame_ts_s)) * 1000.0 + + estimated_input_buffer_window_ms: float | None = None + if input_fps is not None and input_fps > 0: + estimated_input_buffer_window_ms = (frame_buffer_depth / input_fps) * 1000.0 + + estimated_output_queue_window_ms: float | None = None + if output_fps > 0: + estimated_output_queue_window_ms = (output_queue_depth / output_fps) * 1000.0 - # Use minimum to respect both input rate and pipeline capacity - return min(input_fps, pipeline_fps) + estimated_server_buffer_window_ms: float | None = None + if ( + estimated_input_buffer_window_ms is not None + and estimated_output_queue_window_ms is not None + ): + estimated_server_buffer_window_ms = ( + estimated_input_buffer_window_ms + estimated_output_queue_window_ms + ) + + try: + parameters_queue_depth = int(self.parameters_queue.qsize()) + except Exception: + parameters_queue_depth = 0 + + backend_report: dict[str, object] | None = None + + return { + "input_fps": input_fps, + "input_fps_samples": input_samples, + "pipeline_fps": pipeline_fps, + "output_base_fps": float(base_output_fps), + "output_fps": output_fps, + "output_pacing_fps": float(output_pacing_fps) + if output_pacing_fps is not None + else None, + "output_fps_bottleneck": bottleneck, + "active_input_source": self.get_active_input_source(), + "frame_buffer_depth": frame_buffer_depth, + # Back-compat for scripts/check_fps.sh + "frame_buffer_size": frame_buffer_depth, + "output_queue_depth": output_queue_depth, + # Back-compat for scripts/check_fps.sh + "output_queue_size": output_queue_depth, + "output_queue_max": output_queue_max, + # Stopgap latency estimates (buffer windows only; excludes compute + network). 
+ "estimated_input_buffer_window_ms": estimated_input_buffer_window_ms, + "estimated_output_queue_window_ms": estimated_output_queue_window_ms, + "estimated_server_buffer_window_ms": estimated_server_buffer_window_ms, + "output_frames_dropped": int(self.output_frames_dropped), + "input_frames_dropped": int(self.input_frames_dropped), + "low_latency_mode": bool(self._low_latency_mode), + "low_latency_output_mode": bool(self._low_latency_output_mode), + "output_queue_maxsize_cap": int(self._output_queue_maxsize_cap) + if self._output_queue_maxsize_cap is not None + else None, + "parameters_queue_depth": parameters_queue_depth, + "control_map_worker": self._control_map_worker.get_debug_info(), + "backend_report": backend_report, + "video_mode": bool(self._video_mode), + "ndi": { + "enabled": bool(self.ndi_receiver_enabled), + "source": str(self.ndi_receiver_source or ""), + "extra_ips": list(self.ndi_receiver_extra_ips) if self.ndi_receiver_extra_ips else None, + "connected_source": str(self.ndi_connected_source) if self.ndi_connected_source else None, + "connected_url": str(self.ndi_connected_url) if self.ndi_connected_url else None, + "reconnects": int(self.ndi_reconnects), + "frames_received": int(self.ndi_frames_received), + "frames_dropped_during_drain": int(self.ndi_frames_dropped), + "last_frame_age_ms": float(ndi_last_frame_age_ms) if ndi_last_frame_age_ms is not None else None, + "frames_reused_total": int(self.ndi_frames_reused), + "external_stale_ms": float(self._get_vace_external_stale_ms()), + "external_input_stale": bool(self.external_input_stale), + "external_resume_hard_cut": bool(self._get_vace_external_resume_hard_cut_enabled()), + }, + } def _get_input_fps(self) -> float | None: """Get the current measured input FPS. @@ -339,8 +2179,114 @@ def _get_pipeline_dimensions(self) -> tuple[int, int]: logger.warning(f"Could not get pipeline dimensions: {e}") return 512, 512 - def update_parameters(self, parameters: dict[str, Any]): - """Update parameters that will be used in the next pipeline call.""" + def _apply_temporal_ema( + self, + control_frames: list[torch.Tensor], + mode: str, + ema: float, + hard_cut: bool = False, + ) -> list[torch.Tensor]: + """Apply temporal EMA smoothing to control frames. 
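+
+        The update applied below is `smoothed = ema * prev + (1 - ema) * current`,
+        with state reset on mode changes and hard cuts.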
+ + Args: + control_frames: List of control frame tensors (1, H, W, 3) float [0, 255] + mode: Current control map mode (for detecting mode changes) + ema: EMA momentum (0.0 = no smoothing, 0.9 = heavy smoothing) + hard_cut: If True, reset EMA state (e.g., on cache reset) + + Returns: + Smoothed control frames + """ + if ema <= 0.0 or ema >= 1.0: + # No smoothing or invalid value + return control_frames + + # Reset on mode change or hard cut + if hard_cut or mode != self._prev_control_map_mode: + self._prev_control_frames = None + self._prev_control_map_mode = mode + + # If no previous frames, just store current and return + if self._prev_control_frames is None: + self._prev_control_frames = [f.clone() for f in control_frames] + return control_frames + + # Apply EMA: smoothed = ema * prev + (1 - ema) * current + smoothed_frames = [] + for i, current in enumerate(control_frames): + if i < len(self._prev_control_frames): + prev = self._prev_control_frames[i] + # Ensure same shape + if prev.shape == current.shape: + smoothed = ema * prev + (1.0 - ema) * current + else: + # Shape mismatch (resolution change), reset + smoothed = current + else: + smoothed = current + smoothed_frames.append(smoothed) + + # Store for next iteration + self._prev_control_frames = [f.clone() for f in smoothed_frames] + + return smoothed_frames + + def _try_sample_control_frames( + self, frame_ids: list[int] | None + ) -> list[torch.Tensor] | None: + """Try to sample control frames from worker buffer with block+timeout policy. + + Phase 2.1b: Attempts to retrieve pre-computed control frames from the + ControlMapWorker's ring buffer. If the frames aren't available yet, + blocks briefly (configurable) then returns None to signal fallback + to chunk-time compute. + + Args: + frame_ids: List of frame IDs to sample, or None if not available. + + Returns: + List of control frame tensors if all frame_ids are available, + or None to signal fallback to chunk-time compute. + """ + if not self._control_buffer_enabled: + return None + if frame_ids is None or not frame_ids: + return None + # Skip if all frame_ids are -1 (stub/unknown) + if all(fid == -1 for fid in frame_ids): + return None + + policy = ( + self.parameters.get("vace_control_buffer_missing_policy") + or os.getenv("SCOPE_VACE_CONTROL_BUFFER_MISSING_POLICY", "block") + or "block" + ).strip().lower() + + timeout_s = float( + self.parameters.get("vace_control_buffer_block_timeout_s") + or os.getenv("SCOPE_VACE_CONTROL_BUFFER_BLOCK_TIMEOUT_S", "0.25") + or "0.25" + ) + + t0 = time.perf_counter() + while True: + control_frames = self._control_map_worker.sample_control_frames(frame_ids) + if control_frames is not None: + return control_frames + if not policy.startswith("block"): + return None + if self.shutdown_event.is_set(): + return None + if (time.perf_counter() - t0) >= timeout_s: + return None + self.shutdown_event.wait(0.005) + + def update_parameters(self, parameters: dict[str, Any]) -> bool: + """Update parameters that will be used in the next pipeline call. + + Returns: + True if the update was queued successfully, False otherwise. 
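+
+        Note: the internal parameters queue uses mailbox semantics; when it is full,
+        the oldest pending update is dropped so the newest control command still lands.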
+ """ # Handle Spout output settings if "spout_sender" in parameters: spout_config = parameters.pop("spout_sender") @@ -351,13 +2297,27 @@ def update_parameters(self, parameters: dict[str, Any]): spout_config = parameters.pop("spout_receiver") self._update_spout_receiver(spout_config) - # Put new parameters in queue (replace any pending update) + # Handle NDI input settings + if "ndi_receiver" in parameters: + ndi_config = parameters.pop("ndi_receiver") + self._update_ndi_receiver(ndi_config) + + # Put new parameters in queue with mailbox semantics: + # If queue is full, drop oldest (not newest) to ensure latest control commands apply try: - # Add new update self.parameters_queue.put_nowait(parameters) except queue.Full: - logger.info("Parameter queue full, dropping parameter update") - return False + # Drop oldest to make room for newest (mailbox semantics) + try: + self.parameters_queue.get_nowait() + except queue.Empty: + pass + try: + self.parameters_queue.put_nowait(parameters) + except queue.Full: + logger.warning("Parameter queue still full after dropping oldest") + return False + return True def _update_spout_sender(self, config: dict): """Update Spout output configuration.""" @@ -464,6 +2424,13 @@ def _update_spout_receiver(self, config: dict): logger.warning("Spout module not available on this platform") return + def stop_receiver_thread() -> None: + self.spout_receiver_enabled = False + if self.spout_receiver_thread and self.spout_receiver_thread.is_alive(): + if threading.current_thread() != self.spout_receiver_thread: + self.spout_receiver_thread.join(timeout=2.0) + self.spout_receiver_thread = None + if enabled and not self.spout_receiver_enabled: # Enable Spout input try: @@ -471,6 +2438,7 @@ def _update_spout_receiver(self, config: dict): if self.spout_receiver.create(): self.spout_receiver_enabled = True self.spout_receiver_name = sender_name + self._set_active_input_source("spout") # Start receiving thread self.spout_receiver_thread = threading.Thread( target=self._spout_receiver_loop, daemon=True @@ -486,15 +2454,17 @@ def _update_spout_receiver(self, config: dict): elif not enabled and self.spout_receiver_enabled: # Disable Spout input - self.spout_receiver_enabled = False + stop_receiver_thread() if self.spout_receiver is not None: self.spout_receiver.release() self.spout_receiver = None + if self.get_active_input_source() == "spout": + self._set_active_input_source("webrtc") logger.info("Spout input disabled") elif enabled and sender_name != self.spout_receiver_name: # Name changed, recreate receiver - self.spout_receiver_enabled = False + stop_receiver_thread() if self.spout_receiver is not None: self.spout_receiver.release() try: @@ -502,22 +2472,120 @@ def _update_spout_receiver(self, config: dict): if self.spout_receiver.create(): self.spout_receiver_enabled = True self.spout_receiver_name = sender_name - # Restart receiving thread if not running - if ( - self.spout_receiver_thread is None - or not self.spout_receiver_thread.is_alive() - ): - self.spout_receiver_thread = threading.Thread( - target=self._spout_receiver_loop, daemon=True - ) - self.spout_receiver_thread.start() + self._set_active_input_source("spout") + # Restart receiving thread + self.spout_receiver_thread = threading.Thread( + target=self._spout_receiver_loop, daemon=True + ) + self.spout_receiver_thread.start() logger.info(f"Spout input changed to: '{sender_name or 'any'}'") else: logger.error("Failed to recreate Spout receiver") self.spout_receiver = None + if self.get_active_input_source() == 
"spout": + self._set_active_input_source("webrtc") except Exception as e: logger.error(f"Error recreating Spout receiver: {e}") self.spout_receiver = None + if self.get_active_input_source() == "spout": + self._set_active_input_source("webrtc") + elif enabled and self.spout_receiver_enabled: + # Receiver enabled but thread may have died; ensure it is running. + if self.spout_receiver is not None and ( + self.spout_receiver_thread is None or not self.spout_receiver_thread.is_alive() + ): + self._set_active_input_source("spout") + self.spout_receiver_thread = threading.Thread( + target=self._spout_receiver_loop, daemon=True + ) + self.spout_receiver_thread.start() + + def _update_ndi_receiver(self, config: dict): + """Update NDI input configuration.""" + enabled = bool(config.get("enabled", False)) + source = str(config.get("source", "") or "") + extra_ips_raw = config.get("extra_ips", None) + extra_ips: list[str] | None = None + if extra_ips_raw is not None: + if isinstance(extra_ips_raw, str): + extra_ips = [s.strip() for s in extra_ips_raw.split(",") if s.strip()] + elif isinstance(extra_ips_raw, list): + extra_ips = [str(s).strip() for s in extra_ips_raw if str(s).strip()] + + def stop_receiver_thread() -> None: + self.ndi_receiver_enabled = False + self._ndi_hold_last_input_frame = None + self._ndi_hold_last_input_frame_id = None + self.external_input_stale = False + self._external_resume_hard_cut_pending = False + self.ndi_connected_source = None + self.ndi_connected_url = None + if self.ndi_receiver_thread and self.ndi_receiver_thread.is_alive(): + if threading.current_thread() != self.ndi_receiver_thread: + self.ndi_receiver_thread.join(timeout=2.0) + self.ndi_receiver_thread = None + + if enabled and not self.ndi_receiver_enabled: + self.ndi_receiver_enabled = True + self.ndi_receiver_source = source + self.ndi_receiver_extra_ips = extra_ips + self._ndi_hold_last_input_frame = None + self._ndi_hold_last_input_frame_id = None + self.external_input_stale = False + self._external_resume_hard_cut_pending = False + self.ndi_frames_reused = 0 + self.ndi_connected_source = None + self.ndi_connected_url = None + self.ndi_reconnects = 0 + self._set_active_input_source("ndi") + + self.ndi_receiver_thread = threading.Thread( + target=self._ndi_receiver_loop, + daemon=True, + ) + self.ndi_receiver_thread.start() + logger.info("NDI input enabled (source=%r, extra_ips=%r)", source, extra_ips) + + elif not enabled and self.ndi_receiver_enabled: + stop_receiver_thread() + if self.get_active_input_source() == "ndi": + self._set_active_input_source("webrtc") + logger.info("NDI input disabled") + + elif enabled and self.ndi_receiver_enabled: + # Config changed: restart receiver. 
+ if source != self.ndi_receiver_source or extra_ips != self.ndi_receiver_extra_ips: + stop_receiver_thread() + self.ndi_receiver_enabled = True + self.ndi_receiver_source = source + self.ndi_receiver_extra_ips = extra_ips + self._ndi_hold_last_input_frame = None + self._ndi_hold_last_input_frame_id = None + self.external_input_stale = False + self._external_resume_hard_cut_pending = False + self.ndi_frames_reused = 0 + self.ndi_connected_source = None + self.ndi_connected_url = None + self.ndi_reconnects = 0 + self._set_active_input_source("ndi") + + self.ndi_receiver_thread = threading.Thread( + target=self._ndi_receiver_loop, + daemon=True, + ) + self.ndi_receiver_thread.start() + logger.info( + "NDI input reconfigured (source=%r, extra_ips=%r)", source, extra_ips + ) + + # Receiver enabled but thread may have died; ensure it is running. + elif self.ndi_receiver_thread is None or not self.ndi_receiver_thread.is_alive(): + self._set_active_input_source("ndi") + self.ndi_receiver_thread = threading.Thread( + target=self._ndi_receiver_loop, + daemon=True, + ) + self.ndi_receiver_thread.start() def _spout_sender_loop(self): """Background thread that sends frames to Spout asynchronously.""" @@ -569,123 +2637,837 @@ def _spout_receiver_loop(self): and self.spout_receiver is not None ): try: - # Update target FPS dynamically from pipeline performance - current_pipeline_fps = self.get_current_pipeline_fps() - if current_pipeline_fps > 0: - target_fps = current_pipeline_fps - frame_interval = 1.0 / target_fps + # Update target FPS dynamically from pipeline performance + current_pipeline_fps = self.get_current_pipeline_fps() + if current_pipeline_fps > 0: + target_fps = current_pipeline_fps + frame_interval = 1.0 / target_fps + + current_time = time.time() + + # Frame rate limiting - don't receive faster than target FPS + time_since_last = current_time - last_frame_time + if time_since_last < frame_interval: + time.sleep(frame_interval - time_since_last) + continue + + if self.get_active_input_source() != "spout": + # Avoid mixing sources: keep receiver loop alive but do not feed the buffer. 
+ time.sleep(0.01) + continue + + # Receive directly as RGB (avoids extra copy from RGBA slice) + rgb_frame = self.spout_receiver.receive(as_rgb=True) + if rgb_frame is not None: + last_frame_time = time.time() + + # Phase 2.1b: track input FPS and assign frame_id + self.track_input_frame() + + with self._frame_id_lock: + frame_id = self._next_frame_id + self._next_frame_id += 1 + + # Wrap in _FrameWithID for buffer sampling alignment + wrapped_frame = _FrameWithID(_SpoutFrame(rgb_frame), frame_id) + + # Enqueue to control map worker (preview + generation buffer) + self._control_map_worker.put(wrapped_frame) + + with self.frame_buffer_lock: + self.frame_buffer.append(wrapped_frame) + + frame_count += 1 + if frame_count % 100 == 0: + logger.debug(f"Spout input received {frame_count} frames") + else: + time.sleep(0.001) # Small sleep when no frame available + + except Exception as e: + logger.error(f"Error in Spout input loop: {e}") + time.sleep(0.01) + + logger.info(f"Spout input thread stopped after {frame_count} frames") + + def _ndi_receiver_loop(self): + """Background thread that receives frames from NDI and adds to buffer.""" + logger.info("NDI input thread started") + + frame_count = 0 + + try: + from scope.server.ndi import NDIReceiver + + receiver = NDIReceiver(recv_name="ScopeNDIRecv") + self.ndi_receiver = receiver + + if not receiver.create(): + logger.error("NDI receiver create() failed") + return + + # Discover + connect (retry until disabled/stopped) + while self.running and self.ndi_receiver_enabled: + try: + src = receiver.connect_discovered( + source_substring=self.ndi_receiver_source, + extra_ips=self.ndi_receiver_extra_ips, + timeout_ms=1500, + ) + logger.info("NDI connected: %s (%s)", src.name, src.url_address) + self.ndi_connected_source = src.name + self.ndi_connected_url = src.url_address + self.ndi_reconnects += 1 + break + except Exception as e: + logger.warning("NDI connect failed: %s", e) + time.sleep(0.5) + + while self.running and self.ndi_receiver_enabled: + if self.get_active_input_source() != "ndi": + time.sleep(0.01) + continue + + try: + rgb_frame = receiver.receive_latest_rgb24(timeout_ms=50) + except Exception as e: + logger.warning("NDI receive error: %s", e) + time.sleep(0.05) + continue + + if rgb_frame is None: + continue + + # Update stats + stats = receiver.get_stats() + self.ndi_frames_received = stats.frames_received + self.ndi_frames_dropped = stats.frames_dropped_during_drain + self.ndi_last_frame_ts_s = stats.last_frame_ts_s + + # Track input frame timestamp for FPS measurement + self.track_input_frame() + + with self._frame_id_lock: + frame_id = self._next_frame_id + self._next_frame_id += 1 + + wrapped_frame = _FrameWithID(_NDIFrame(rgb_frame), frame_id) + + # Enqueue to control map worker (preview + generation ring buffer) + self._control_map_worker.put(wrapped_frame) + + with self.frame_buffer_lock: + self.frame_buffer.append(wrapped_frame) + + frame_count += 1 + finally: + try: + if self.ndi_receiver is not None: + self.ndi_receiver.release() + except Exception as e: + logger.warning("Failed to release NDI receiver: %s", e) + self.ndi_receiver = None + logger.info("NDI input thread stopped after %d frames", frame_count) + + def worker_loop(self): + logger.info("Worker thread started") + + while self.running and not self.shutdown_event.is_set(): + try: + self.process_chunk() + + except PipelineNotAvailableException as e: + logger.debug(f"Pipeline temporarily unavailable: {e}") + # Flush frame buffer to prevent buildup + with 
self.frame_buffer_lock: + if self.frame_buffer: + logger.debug( + f"Flushing {len(self.frame_buffer)} frames due to pipeline unavailability" + ) + self.frame_buffer.clear() + continue + except Exception as e: + if self._is_recoverable(e): + logger.error(f"Error in worker loop: {e}") + continue + else: + logger.error( + f"Non-recoverable error in worker loop: {e}, stopping frame processor" + ) + self.stop(error_message=str(e)) + break + logger.info("Worker thread stopped") + + def process_chunk(self): + start_time = time.time() + + # Legacy safety: ensure we don't persist "paused" inside self.parameters. + # Pause state is tracked separately in self.paused and updated via events. + paused = self.parameters.pop("paused", None) + if paused is not None and paused != self.paused: + self.paused = paused + + # ======================================================================== + # INGEST: Drain ALL pending queue entries (mailbox semantics) + # ======================================================================== + # Intentional behavior change from "drain 1" to "drain all": + # - Old: at most 1 update per chunk (10 rapid updates → 10 chunks to apply) + # - New: all pending updates per chunk (commit at boundary) + merged_updates: dict = {} + while True: + try: + update = self.parameters_queue.get_nowait() + # Last-write-wins merge + merged_updates = {**merged_updates, **update} + except queue.Empty: + break + + # ======================================================================== + # RESERVED KEYS: Handle snapshot/restore commands (not forwarded to pipeline) + # ======================================================================== + # These reserved keys route through parameters_queue for thread safety, + # but are consumed here and never forwarded to the pipeline or events. + if "_rcp_snapshot_request" in merged_updates: + merged_updates.pop("_rcp_snapshot_request") + try: + snapshot = self._create_snapshot() + # Send response via callback if registered + if self.snapshot_response_callback: + self.snapshot_response_callback( + { + "type": "snapshot_response", + "snapshot_id": snapshot.snapshot_id, + "chunk_index": snapshot.chunk_index, + "current_start_frame": snapshot.current_start_frame, + } + ) + except Exception as e: + logger.error(f"Error creating snapshot: {e}") + if self.snapshot_response_callback: + self.snapshot_response_callback( + {"type": "snapshot_response", "error": str(e)} + ) + + if "_rcp_restore_snapshot" in merged_updates: + restore_data = merged_updates.pop("_rcp_restore_snapshot") + snapshot_id = restore_data.get("snapshot_id") if restore_data else None + if snapshot_id: + success = self._restore_snapshot(snapshot_id) + if self.snapshot_response_callback: + self.snapshot_response_callback( + { + "type": "restore_response", + "snapshot_id": snapshot_id, + "success": success, + } + ) + else: + logger.warning("restore_snapshot called without snapshot_id") + if self.snapshot_response_callback: + self.snapshot_response_callback( + { + "type": "restore_response", + "error": "snapshot_id required", + "success": False, + } + ) + + # Step: generate exactly one chunk even while paused. + # Keep a small backlog so step isn't dropped when input frames aren't ready. 
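+        # Example: {"_rcp_step": 3} adds three pending steps; each step generates exactly
+        # one chunk, so a paused session can be advanced chunk-by-chunk without resuming.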
+ if "_rcp_step" in merged_updates: + step_val = merged_updates.pop("_rcp_step") + step_count = 1 + if isinstance(step_val, int) and not isinstance(step_val, bool): + step_count = max(1, step_val) + self._pending_steps += step_count + + # Session recording start/stop (consumed here; never forwarded to pipeline) + if "_rcp_session_recording_start" in merged_updates: + merged_updates.pop("_rcp_session_recording_start", None) + try: + status = ( + self.pipeline_manager.peek_status_info() + if hasattr(self.pipeline_manager, "peek_status_info") + else self.pipeline_manager.get_status_info() + ) + except Exception as e: + logger.warning( + "Session recording start: failed to read pipeline status: %s", e + ) + status = {} + + if status.get("status") != "loaded": + logger.warning( + "Session recording start ignored: pipeline not loaded (status=%s)", + status.get("status"), + ) + else: + pipeline_id = status.get("pipeline_id") + if not pipeline_id: + logger.warning( + "Session recording start ignored: missing pipeline_id in status" + ) + else: + lp = status.get("load_params") or {} + runtime_params: dict[str, Any] = ( + dict(lp) if isinstance(lp, dict) else {"load_params": lp} + ) + + # Include key runtime params for timeline settings/replay + if "kv_cache_attention_bias" in self.parameters: + runtime_params["kv_cache_attention_bias"] = self.parameters.get( + "kv_cache_attention_bias" + ) + if "denoising_step_list" in self.parameters: + runtime_params["denoising_step_list"] = self.parameters.get( + "denoising_step_list" + ) + if "seed" not in runtime_params: + if "seed" in self.parameters: + runtime_params["seed"] = self.parameters.get("seed") + elif "base_seed" in self.parameters: + runtime_params["seed"] = self.parameters.get("base_seed") + + baseline_prompt, baseline_weight = self._get_current_effective_prompt() + try: + self.session_recorder.start( + chunk_index=self.chunk_index, + pipeline_id=pipeline_id, + load_params=runtime_params, + baseline_prompt=baseline_prompt, + baseline_weight=baseline_weight, + ) + self._last_recording_path = None + self._soft_transition_record_pending = bool( + self._soft_transition_active + ) + logger.info( + "Session recording started at chunk=%d", self.chunk_index + ) + except Exception as e: + logger.error("Session recording start failed: %s", e) + + if "_rcp_session_recording_stop" in merged_updates: + merged_updates.pop("_rcp_session_recording_stop", None) + try: + recording = self.session_recorder.stop(chunk_index=self.chunk_index) + except Exception as e: + logger.error("Session recording stop failed: %s", e) + recording = None + + if recording is not None: + ts = datetime.now().strftime("%Y-%m-%d_%H%M%S") + path = ( + Path.home() + / ".daydream-scope" + / "recordings" + / f"session_{ts}.timeline.json" + ) + try: + saved = self.session_recorder.save(recording, path) + self._last_recording_path = saved + logger.info("Session recording saved: %s", saved) + except Exception as e: + logger.error("Failed to save session recording timeline: %s", e) + + # Soft transition: temporarily lower KV cache bias for N chunks + if "_rcp_soft_transition" in merged_updates: + soft_data = merged_updates.pop("_rcp_soft_transition") + if isinstance(soft_data, dict): + temp_bias = soft_data.get("temp_bias", 0.1) + num_chunks = soft_data.get("num_chunks", 2) + + # Handle precedence: if explicit kv_cache_attention_bias in same message, + # treat it as the base bias to restore to (and don't let it override temp) + explicit_bias = merged_updates.pop("kv_cache_attention_bias", None) + 
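+                # Example payload: {"_rcp_soft_transition": {"temp_bias": 0.1, "num_chunks": 2}}
+                # lowers kv_cache_attention_bias to 0.1 for the next two chunks, then restores
+                # (or unsets) the previous value at a chunk boundary.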
+ # Coerce + clamp inputs (avoid log(<=0) downstream) + try: + temp_bias = float(temp_bias) + except (TypeError, ValueError): + temp_bias = 0.1 + temp_bias = max(0.01, min(temp_bias, 1.0)) + + try: + num_chunks = int(num_chunks) + except (TypeError, ValueError): + num_chunks = 2 + num_chunks = max(1, min(num_chunks, 10)) + + if explicit_bias is not None: + try: + explicit_bias = float(explicit_bias) + except (TypeError, ValueError): + explicit_bias = None + if explicit_bias is not None: + explicit_bias = max(0.01, min(explicit_bias, 1.0)) + + # Re-entrancy: don't overwrite original if already in soft transition + if not self._soft_transition_active: + # First trigger: save current bias as original + if explicit_bias is not None: + self._soft_transition_original_bias = explicit_bias + self._soft_transition_original_bias_was_set = True + else: + # Preserve "unset": if the key wasn't present, restore by deleting it. + if "kv_cache_attention_bias" in self.parameters: + self._soft_transition_original_bias = self.parameters.get( + "kv_cache_attention_bias" + ) + self._soft_transition_original_bias_was_set = True + else: + self._soft_transition_original_bias = None + self._soft_transition_original_bias_was_set = False + elif explicit_bias is not None: + # Re-trigger with explicit bias: update restore target + self._soft_transition_original_bias = explicit_bias + self._soft_transition_original_bias_was_set = True + + # (Re)start countdown + self._soft_transition_temp_bias = temp_bias + self._soft_transition_chunks_remaining = num_chunks + self._soft_transition_active = True + + # Apply temporary bias immediately + self.parameters["kv_cache_attention_bias"] = temp_bias + self._soft_transition_record_pending = True + logger.info( + f"Soft transition: bias -> {temp_bias} for {num_chunks} chunks " + f"(will restore to " + f"{self._soft_transition_original_bias if self._soft_transition_original_bias_was_set else ''})" + ) + + # If an explicit bias update arrives while a soft transition is active (and it wasn't + # consumed above), treat it as an override and cancel the soft transition so we + # don't later restore over the user's explicit change. 
+ if self._soft_transition_active and "kv_cache_attention_bias" in merged_updates: + logger.info( + "Soft transition canceled: explicit kv_cache_attention_bias update received" + ) + self._soft_transition_active = False + self._soft_transition_chunks_remaining = 0 + self._soft_transition_temp_bias = None + self._soft_transition_original_bias = None + self._soft_transition_original_bias_was_set = False + self._soft_transition_record_pending = False + + # Track if explicit prompts were set this chunk (for precedence) + explicit_prompts_set = "prompts" in merged_updates + + # Handle world state update (full replace, thread-safe via model_validate) + if "_rcp_world_state" in merged_updates: + world_data = merged_updates.pop("_rcp_world_state") + try: + self.world_state = WorldState.model_validate(world_data) + logger.debug(f"WorldState updated: action={self.world_state.action}") - current_time = time.time() + # Recompile if style active and no explicit prompts + if self.style_manifest and not explicit_prompts_set: + compiled = self.prompt_compiler.compile( + self.world_state, self.style_manifest + ) + self._compiled_prompt = compiled + # Inject compiled prompts into merged_updates for event processing + merged_updates["prompts"] = [p.to_dict() for p in compiled.prompts] + logger.debug(f"Auto-compiled prompt: {compiled.prompt[:80]}...") + # Note: LoRA NOT re-sent here (only on style change) + except Exception as e: + logger.warning(f"Failed to validate WorldState: {e}") - # Frame rate limiting - don't receive faster than target FPS - time_since_last = current_time - last_frame_time - if time_since_last < frame_interval: - time.sleep(frame_interval - time_since_last) - continue + # Handle style change + if "_rcp_set_style" in merged_updates: + style_name = merged_updates.pop("_rcp_set_style") + new_style = self.style_registry.get(style_name) + if new_style: + style_changed = style_name != self._active_style_name - # Receive directly as RGB (avoids extra copy from RGBA slice) - rgb_frame = self.spout_receiver.receive(as_rgb=True) - if rgb_frame is not None: - last_frame_time = time.time() - spout_frame = _SpoutFrame(rgb_frame) + self.style_manifest = new_style + self._active_style_name = style_name + logger.info(f"Active style set to: {style_name}") - with self.frame_buffer_lock: - self.frame_buffer.append(spout_frame) + # Recreate compiler for the new style (may switch to LLM if available) + if style_changed: + try: + self.prompt_compiler = create_compiler(new_style) + except Exception as e: + logger.warning( + f"Failed to create compiler for style {style_name}: {e}, " + "keeping current compiler" + ) - frame_count += 1 - if frame_count % 100 == 0: - logger.debug(f"Spout input received {frame_count} frames") + # Recompile with new style - but only if WorldState has content. + # In performance mode (empty WorldState), preserve the current prompt. 
+ if self.world_state.is_empty(): + logger.info( + "WorldState empty - preserving current prompt (performance mode)" + ) else: - time.sleep(0.001) # Small sleep when no frame available - - except Exception as e: - logger.error(f"Error in Spout input loop: {e}") - time.sleep(0.01) + compiled = self.prompt_compiler.compile( + self.world_state, self.style_manifest + ) + self._compiled_prompt = compiled + + if not explicit_prompts_set: + merged_updates["prompts"] = [p.to_dict() for p in compiled.prompts] + + # LoRA only on style change (edge-trigger) + if style_changed: + # Reset cache for clean transition, or skip for blend artifacts + if self.reset_cache_on_style_switch: + merged_updates["reset_cache"] = True + else: + logger.info( + "Style switch without cache reset (blend mode enabled)" + ) - logger.info(f"Spout input thread stopped after {frame_count} frames") + # Canonicalize paths and dedupe updates (styles may share the same LoRA). + lora_updates = self.style_registry.build_lora_scales_for_style( + style_name + ) + if lora_updates: + merged_updates["lora_scales"] = lora_updates + # When blend mode is enabled, tell pipeline to skip cache reset on LoRA scale change + if not self.reset_cache_on_style_switch: + merged_updates["lora_scales_skip_cache_reset"] = True + logger.info( + "LoRA scales updated for style '%s' (%d paths)", + style_name, + len(lora_updates), + ) + else: + logger.warning(f"Style not found in registry: {style_name}") + + step_requested = self._pending_steps > 0 + + # ======================================================================== + # TRANSLATE: Convert dict updates to typed events for ordering + # ======================================================================== + if merged_updates: + # VACE control-map mode changes (e.g. raw video -> depth) should default + # to a hard cut. Otherwise the KV cache can retain prior video-derived + # appearance information and "leak" it after the switch. + if "vace_control_map_mode" in merged_updates: + prev_mode = self.parameters.get("vace_control_map_mode", "none") + next_mode = merged_updates.get("vace_control_map_mode") or "none" + if next_mode != prev_mode: + # Avoid stale previews across mode switches (e.g. canny -> depth). + with self.latest_control_frame_lock: + self.latest_control_frame_cpu = None + self._control_map_worker.clear_latest(lock_held=True) + + # Allow callers to explicitly override by sending reset_cache. 
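+                    # (e.g. sending "reset_cache": False in the same update keeps the KV cache
+                    # across the mode switch, at the risk of leaking prior appearance)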
+ if "reset_cache" not in merged_updates: + merged_updates["reset_cache"] = True + logger.info( + "VACE control-map mode change: %s -> %s (forcing reset_cache=True)", + prev_mode, + next_mode, + ) - def worker_loop(self): - logger.info("Worker thread started") + if "vace_depth_temporal_mode" in merged_updates: + prev_mode = ( + (self.parameters.get("vace_depth_temporal_mode") or "stream") + .strip() + .lower() + ) + next_mode = ( + (merged_updates.get("vace_depth_temporal_mode") or "stream") + .strip() + .lower() + ) + if next_mode not in ("stream", "stateless"): + logger.warning( + "Ignoring invalid vace_depth_temporal_mode=%r; expected 'stream' or 'stateless'", + next_mode, + ) + elif next_mode != prev_mode: + with self.latest_control_frame_lock: + self.latest_control_frame_cpu = None + self._control_map_worker.clear_latest(lock_held=True) + if "reset_cache" not in merged_updates: + merged_updates["reset_cache"] = True + logger.info( + "VACE depth temporal mode change: %s -> %s (forcing reset_cache=True)", + prev_mode, + next_mode, + ) - while self.running and not self.shutdown_event.is_set(): - try: - self.process_chunk() + if "vace_control_buffer_enabled" in merged_updates: + raw = merged_updates.get("vace_control_buffer_enabled") + if isinstance(raw, str): + next_enabled = raw.strip().lower() in ("1", "true", "yes", "on") + else: + next_enabled = bool(raw) + prev_enabled = bool(self._control_buffer_enabled) + if next_enabled != prev_enabled: + if "reset_cache" not in merged_updates: + merged_updates["reset_cache"] = True + logger.info( + "VACE control buffer enabled: %s -> %s (forcing reset_cache=True)", + prev_enabled, + next_enabled, + ) + self._control_buffer_enabled = next_enabled - except PipelineNotAvailableException as e: - logger.debug(f"Pipeline temporarily unavailable: {e}") - # Flush frame buffer to prevent buildup - with self.frame_buffer_lock: - if self.frame_buffer: - logger.debug( - f"Flushing {len(self.frame_buffer)} frames due to pipeline unavailability" + if "vace_control_map_worker_enabled" in merged_updates: + raw = merged_updates.get("vace_control_map_worker_enabled") + if isinstance(raw, str): + worker_enabled = raw.strip().lower() in ("1", "true", "yes", "on") + else: + worker_enabled = bool(raw) + if worker_enabled: + self._control_map_worker.start() + else: + self._control_map_worker.stop() + with self.latest_control_frame_lock: + self.latest_control_frame_cpu = None + self._control_map_worker.clear_latest(lock_held=True) + + if "vace_control_map_worker_allow_heavy" in merged_updates: + raw = merged_updates.get("vace_control_map_worker_allow_heavy") + if isinstance(raw, str): + allow_heavy = raw.strip().lower() in ("1", "true", "yes", "on") + else: + allow_heavy = bool(raw) + self._control_map_worker.set_allow_heavy(allow_heavy) + + if "vace_control_map_worker_max_fps" in merged_updates: + raw = merged_updates.get("vace_control_map_worker_max_fps") + max_fps: float | None + if raw is None: + max_fps = None + else: + try: + max_fps = float(raw) + except (TypeError, ValueError): + logger.warning( + "Ignoring invalid vace_control_map_worker_max_fps=%r; expected float", + raw, ) - self.frame_buffer.clear() - continue - except Exception as e: - if self._is_recoverable(e): - logger.error(f"Error in worker loop: {e}") - continue + max_fps = None + self._control_map_worker.set_max_fps(max_fps) + + # Handle pause/resume via events + if "paused" in merged_updates: + paused_val = merged_updates.pop("paused") + if paused_val: + self.control_bus.enqueue(EventType.PAUSE) 
else: - logger.error( - f"Non-recoverable error in worker loop: {e}, stopping frame processor" + self.control_bus.enqueue(EventType.RESUME) + + # Handle prompts/transition via events + if "prompts" in merged_updates or "transition" in merged_updates: + payload = {} + if "prompts" in merged_updates: + payload["prompts"] = merged_updates.pop("prompts") + if "transition" in merged_updates: + payload["transition"] = merged_updates.pop("transition") + self.control_bus.enqueue(EventType.SET_PROMPT, payload=payload) + + # Handle lora_scales via events + if "lora_scales" in merged_updates: + lora_payload = {"lora_scales": merged_updates.pop("lora_scales")} + if "lora_scales_skip_cache_reset" in merged_updates: + lora_payload["lora_scales_skip_cache_reset"] = merged_updates.pop( + "lora_scales_skip_cache_reset" ) - self.stop(error_message=str(e)) - break - logger.info("Worker thread stopped") + self.control_bus.enqueue( + EventType.SET_LORA_SCALES, + payload=lora_payload, + ) - def process_chunk(self): - start_time = time.time() - try: - # Check if there are new parameters - new_parameters = self.parameters_queue.get_nowait() - if new_parameters != self.parameters: + # Handle base_seed via events + if "base_seed" in merged_updates: + self.control_bus.enqueue( + EventType.SET_SEED, + payload={"base_seed": merged_updates.pop("base_seed")}, + ) + + # Handle denoising_step_list via events + if "denoising_step_list" in merged_updates: + self.control_bus.enqueue( + EventType.SET_DENOISE_STEPS, + payload={ + "denoising_step_list": merged_updates.pop("denoising_step_list") + }, + ) + + # Update video mode if input_mode parameter changes + if "input_mode" in merged_updates: + self._video_mode = merged_updates.get("input_mode") == "video" + + # Remaining keys merge directly into self.parameters (no event needed) + if merged_updates: + self.parameters = {**self.parameters, **merged_updates} + + # ======================================================================== + # ORDER + APPLY: Apply events in deterministic order + # ======================================================================== + events = self.control_bus.drain_pending( + is_paused=self.paused, chunk_index=self.chunk_index + ) + + applied_prompt_payload: dict[str, Any] | None = None + for event in events: + if event.type == EventType.PAUSE: + self.paused = True + elif event.type == EventType.RESUME: + self.paused = False + elif event.type == EventType.SET_PROMPT: # Clear stale transition when new prompts arrive without transition if ( - "prompts" in new_parameters - and "transition" not in new_parameters + "prompts" in event.payload + and "transition" not in event.payload and "transition" in self.parameters ): self.parameters.pop("transition", None) + # Apply prompt/transition to parameters + if "prompts" in event.payload: + self.parameters["prompts"] = event.payload["prompts"] + if "transition" in event.payload: + self.parameters["transition"] = event.payload["transition"] + applied_prompt_payload = event.payload + elif event.type == EventType.SET_LORA_SCALES: + self.parameters["lora_scales"] = event.payload["lora_scales"] + if event.payload.get("lora_scales_skip_cache_reset"): + self.parameters["lora_scales_skip_cache_reset"] = True + elif event.type == EventType.SET_SEED: + new_seed = event.payload["base_seed"] + self.parameters["base_seed"] = new_seed + # Track seed history (keep last 50) + if not hasattr(self, "_seed_history"): + self._seed_history = [] + self._seed_history.append(new_seed) + if len(self._seed_history) > 50: + 
self._seed_history = self._seed_history[-50:] + logger.info(f"Seed set: {new_seed}") + elif event.type == EventType.SET_DENOISE_STEPS: + denoising_step_list = event.payload.get("denoising_step_list") + if ( + not isinstance(denoising_step_list, list) + or not denoising_step_list + or not all(isinstance(step, int) for step in denoising_step_list) + ): + logger.warning( + "Ignoring invalid denoising_step_list=%r", + denoising_step_list, + ) + else: + self.parameters["denoising_step_list"] = denoising_step_list + logger.info("Set denoising_step_list=%s", denoising_step_list) - # Update video mode if input_mode parameter changes - if "input_mode" in new_parameters: - self._video_mode = new_parameters.get("input_mode") == "video" + # Check if paused after applying events (step overrides pause) + if self.paused and not step_requested: + # Sleep briefly to avoid busy waiting + self.shutdown_event.wait(SLEEP_TIME) + return - # Merge new parameters with existing ones to preserve any missing keys - self.parameters = {**self.parameters, **new_parameters} - except queue.Empty: - pass + # Recorder prompt-edge detection: capture prompt changes applied while paused/video-waiting. + fallback_prompt: str | None = None + fallback_weight: float = 1.0 + if self.session_recorder.is_recording: + prev_prompt = self.session_recorder.last_prompt + cur_prompt, cur_weight = self._get_current_effective_prompt() + if cur_prompt is not None and cur_prompt != prev_prompt: + fallback_prompt = cur_prompt + fallback_weight = float(cur_weight) # Get the current pipeline using sync wrapper pipeline = self.pipeline_manager.get_pipeline() - # Pause or resume the processing - paused = self.parameters.pop("paused", None) - if paused is not None and paused != self.paused: - self.paused = paused - if self.paused: - # Sleep briefly to avoid busy waiting - self.shutdown_event.wait(SLEEP_TIME) - return + external_hold_last = ( + getattr(pipeline, "vace_enabled", False) and self._ndi_external_hold_last_enabled() + ) + if external_hold_last: + # External control staleness policy: + # - Hold-last keeps generation running through short gaps/jitter. + # - If the newest control frame is too old, stall until fresh input arrives. 
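+            # The threshold comes from the `vace_external_stale_ms` parameter (or the
+            # SCOPE_VACE_EXTERNAL_STALE_MS env var, default 500 ms); age is measured as
+            # time.monotonic() minus the timestamp of the last received NDI frame.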
+ stale_ms = self._get_vace_external_stale_ms() + stale_now = False + age_ms: float | None = None + if stale_ms > 0 and self.ndi_last_frame_ts_s > 0: + age_ms = (time.monotonic() - float(self.ndi_last_frame_ts_s)) * 1000.0 + stale_now = age_ms > stale_ms + + if stale_now: + if not self.external_input_stale: + logger.warning( + "External control stale (age_ms=%.1f > stale_ms=%.1f); stalling generation", + age_ms, + stale_ms, + ) + self.external_input_stale = True + self._external_resume_hard_cut_pending = True + self.shutdown_event.wait(SLEEP_TIME) + return + + if self.external_input_stale: + self.external_input_stale = False + resume_hard_cut = self._get_vace_external_resume_hard_cut_enabled() + if ( + resume_hard_cut + and self._external_resume_hard_cut_pending + and "reset_cache" not in self.parameters + ): + self.parameters["reset_cache"] = True + logger.info( + "External control resumed (age_ms=%.1f <= stale_ms=%.1f); forcing reset_cache=True", + age_ms if age_ms is not None else -1.0, + stale_ms, + ) + elif self._external_resume_hard_cut_pending: + logger.info( + "External control resumed (age_ms=%.1f <= stale_ms=%.1f); resume hard cut disabled", + age_ms if age_ms is not None else -1.0, + stale_ms, + ) + self._external_resume_hard_cut_pending = False + else: + self.external_input_stale = False + self._external_resume_hard_cut_pending = False # prepare() will handle any required preparation based on parameters internally - reset_cache = self.parameters.pop("reset_cache", None) - - # Pop lora_scales to prevent re-processing on every frame - lora_scales = self.parameters.pop("lora_scales", None) + reset_cache = self.parameters.get("reset_cache", None) + lora_scales = self.parameters.get("lora_scales", None) + lora_scales_skip_cache_reset = self.parameters.get( + "lora_scales_skip_cache_reset", False + ) + hard_cut_executed = False - # Clear output buffer queue when reset_cache is requested to prevent old frames + # Clear output buffer queue when reset_cache is requested to prevent old frames. + # Keep reset_cache pending until it is actually applied (we might early-return + # while waiting for video input). if reset_cache: - logger.info("Clearing output buffer queue due to reset_cache request") - while not self.output_queue.empty(): + if not self._hard_cut_flushed_pending: + logger.info( + "HARD CUT: reset_cache=True received, will pass init_cache=True to pipeline" + ) + self.flush_output_queue() + + # Phase 2.1b: clear generation control buffer + reset VDA streaming state. + # Do this once per hard cut to avoid repeatedly draining buffers while we wait for video. try: - self.output_queue.get_nowait() - except queue.Empty: - break + self._control_map_worker.request_hard_cut( + clear_queue=True, reason="reset_cache" + ) + except Exception: + logger.warning("ControlMapWorker hard_cut failed", exc_info=True) + + # Also reset local fallback generators so chunk-time fallback is clean. + if self._depth_generator is not None: + self._depth_generator.reset_cache() + self._prev_control_frames = None + self._prev_control_map_mode = None + + self._hard_cut_flushed_pending = True + else: + self._hard_cut_flushed_pending = False requirements = None if hasattr(pipeline, "prepare"): prepare_params = dict(self.parameters.items()) + prepare_params.pop("reset_cache", None) + prepare_params.pop("lora_scales", None) + prepare_params.pop("lora_scales_skip_cache_reset", None) if self._video_mode: # Signal to prepare() that video input is expected. 
# This allows resolve_input_mode() to detect video mode correctly. @@ -695,41 +3477,324 @@ def process_chunk(self): ) video_input = None + frame_ids: list[int] | None = None if requirements is not None: current_chunk_size = requirements.input_size - with self.frame_buffer_lock: - if not self.frame_buffer or len(self.frame_buffer) < current_chunk_size: + hold_last = self._ndi_external_hold_last_enabled() + if hold_last: + # In external/passthrough mode, NDI frames are control maps. Avoid coupling + # generator cadence to NDI arrival cadence by holding the last frame. + with self.frame_buffer_lock: + buffer_len = len(self.frame_buffer) + has_any_frame = buffer_len > 0 or self._ndi_hold_last_input_frame is not None + if not has_any_frame: + # Sleep briefly to avoid busy waiting + self.shutdown_event.wait(SLEEP_TIME) + return + video_input, frame_ids = self._prepare_chunk_hold_last(current_chunk_size) + else: + with self.frame_buffer_lock: + has_enough_frames = bool(self.frame_buffer) and ( + len(self.frame_buffer) >= current_chunk_size + ) + if not has_enough_frames: # Sleep briefly to avoid busy waiting self.shutdown_event.wait(SLEEP_TIME) return - video_input = self.prepare_chunk(current_chunk_size) + video_input, frame_ids = self.prepare_chunk(current_chunk_size) + if len(video_input) < current_chunk_size: + # Buffer state changed underneath us; retry next loop. + self.shutdown_event.wait(SLEEP_TIME) + return + chunk_error: Exception | None = None try: # Pass parameters (excluding prepare-only parameters) call_params = dict(self.parameters.items()) + call_params.pop("reset_cache", None) + call_params.pop("lora_scales", None) + call_params.pop("lora_scales_skip_cache_reset", None) # Pass reset_cache as init_cache to pipeline call_params["init_cache"] = not self.is_prepared if reset_cache is not None: call_params["init_cache"] = reset_cache + hard_cut_executed = bool(reset_cache) # Pass lora_scales only when present (one-time update) if lora_scales is not None: call_params["lora_scales"] = lora_scales - - # Route video input based on VACE status - # We do not support combining latent initialization and VACE conditioning + # When blend mode is enabled, tell pipeline to skip its own cache reset + if lora_scales_skip_cache_reset: + call_params["lora_scales_skip_cache_reset"] = True + + # Pass soft_transition_active to prevent cache reset during soft transitions + if self._soft_transition_active: + call_params["soft_transition_active"] = True + + # Route video input based on VACE status. 
+ # + # Default behavior is mutually exclusive: + # - VACE enabled: treat the incoming stream as conditioning-only (`vace_input_frames`) + # - VACE disabled: treat the incoming stream as latent-init V2V (`video`) + # + # Experimental hybrid mode (opt-in via `vace_hybrid_video_init=True`): + # - Provide BOTH `video` (latent init) and `vace_input_frames` (conditioning) if video_input is not None: vace_enabled = getattr(pipeline, "vace_enabled", False) - vace_use_input_video = self.parameters.get("vace_use_input_video", True) + if vace_enabled: + vace_hybrid_video_init = bool( + self.parameters.get("vace_hybrid_video_init", False) + ) + # VACE V2V editing mode: route to vace_input_frames + # Apply control map transform if enabled + control_map_mode = self.parameters.get( + "vace_control_map_mode", "none" + ) + if control_map_mode == "canny": + # Phase 2.1b: try buffer sampling first + control_frames = self._try_sample_control_frames(frame_ids) + if control_frames is None: + # Fallback: chunk-time compute + low = self.parameters.get("vace_canny_low_threshold", None) + high = self.parameters.get("vace_canny_high_threshold", None) + blur_kernel = self.parameters.get("vace_canny_blur_kernel", 5) + blur_sigma = self.parameters.get("vace_canny_blur_sigma", 1.4) + adaptive = self.parameters.get("vace_canny_adaptive", True) + dilate = self.parameters.get("vace_canny_dilate", False) + dilate_size = self.parameters.get("vace_canny_dilate_size", 2) + control_frames = apply_canny_edges( + video_input, + low_threshold=low, + high_threshold=high, + blur_kernel_size=blur_kernel, + blur_sigma=blur_sigma, + adaptive_thresholds=adaptive, + dilate_edges=dilate, + dilate_kernel_size=dilate_size, + ) + call_params["vace_input_frames"] = control_frames + # Store latest control frame for preview streaming + with self.latest_control_frame_lock: + last = control_frames[-1].squeeze(0) + if last.dtype != torch.uint8: + last = last.clamp(0, 255).to(torch.uint8) + self.latest_control_frame_cpu = last.to(device="cpu") + elif control_map_mode == "pidinet": + # Phase 2.1b: try buffer sampling first + control_frames = self._try_sample_control_frames(frame_ids) + if control_frames is None: + # Fallback: chunk-time compute + if self._pidinet_generator is None: + self._pidinet_generator = PiDiNetEdgeGenerator() + safe_mode = self.parameters.get("vace_pidinet_safe", True) + apply_filter = self.parameters.get( + "vace_pidinet_filter", True + ) + self._pidinet_generator.safe_mode = safe_mode + control_frames = self._pidinet_generator.process_frames( + video_input, apply_filter=apply_filter + ) + call_params["vace_input_frames"] = control_frames + # Store latest control frame for preview streaming + with self.latest_control_frame_lock: + last = control_frames[-1].squeeze(0) + if last.dtype != torch.uint8: + last = last.clamp(0, 255).to(torch.uint8) + self.latest_control_frame_cpu = last.to(device="cpu") + elif control_map_mode == "depth": + hard_cut = reset_cache is not None and reset_cache + # Phase 2.1b: try buffer sampling first + control_frames = self._try_sample_control_frames(frame_ids) + if control_frames is None: + # Fallback: chunk-time compute + if self._depth_generator is None: + self._depth_generator = VDADepthControlMapGenerator() + depth_input_size = self.parameters.get("vace_depth_input_size") + depth_fp32 = self.parameters.get("vace_depth_fp32") + depth_temporal_mode = self.parameters.get( + "vace_depth_temporal_mode" + ) + depth_contrast = self.parameters.get("vace_depth_contrast") + if depth_contrast is not None: 
+ self._depth_generator.depth_contrast = depth_contrast + + depth_output_device = os.getenv( + "SCOPE_VACE_DEPTH_CHUNK_OUTPUT_DEVICE", "cpu" + ).strip() + if depth_output_device.lower() not in ("", "cpu", "cuda") and not depth_output_device.lower().startswith( + "cuda:" + ): + logger.warning( + "Invalid SCOPE_VACE_DEPTH_CHUNK_OUTPUT_DEVICE=%r; expected 'cpu' or 'cuda'", + depth_output_device, + ) + depth_output_device = "cpu" + + control_frames = self._depth_generator.process_frames( + video_input, + hard_cut=hard_cut, + input_size=depth_input_size, + fp32=depth_fp32, + temporal_mode=depth_temporal_mode, + output_device=depth_output_device, + ) + # Apply temporal EMA if enabled (after either path) + temporal_ema = self.parameters.get( + "vace_control_map_temporal_ema", 0.0 + ) + if temporal_ema > 0: + control_frames = self._apply_temporal_ema( + control_frames, "depth", temporal_ema, hard_cut + ) + call_params["vace_input_frames"] = control_frames + # Store latest control frame for preview streaming + with self.latest_control_frame_lock: + last = control_frames[-1].squeeze(0) + if last.dtype != torch.uint8: + last = last.clamp(0, 255).to(torch.uint8) + self.latest_control_frame_cpu = last.to(device="cpu") + elif control_map_mode == "composite": + hard_cut = reset_cache is not None and reset_cache + # Phase 2.1b: try buffer sampling first + control_frames = self._try_sample_control_frames(frame_ids) + if control_frames is None: + # Fallback: chunk-time compute (depth + edges fused) + if self._depth_generator is None: + self._depth_generator = VDADepthControlMapGenerator() + + edge_strength = self.parameters.get( + "composite_edge_strength", 0.6 + ) + edge_thickness = self.parameters.get( + "composite_edge_thickness", 8 + ) + sharpness = self.parameters.get("composite_sharpness", 10.0) + edge_source = self.parameters.get( + "composite_edge_source", "canny" + ) + + depth_input_size = self.parameters.get("vace_depth_input_size") + depth_fp32 = self.parameters.get("vace_depth_fp32") + depth_temporal_mode = self.parameters.get( + "vace_depth_temporal_mode" + ) + depth_contrast = self.parameters.get("vace_depth_contrast") + if depth_contrast is not None: + self._depth_generator.depth_contrast = depth_contrast + + depth_output_device = os.getenv( + "SCOPE_VACE_DEPTH_CHUNK_OUTPUT_DEVICE", "cpu" + ).strip() + if depth_output_device.lower() not in ("", "cpu", "cuda") and not depth_output_device.lower().startswith( + "cuda:" + ): + logger.warning( + "Invalid SCOPE_VACE_DEPTH_CHUNK_OUTPUT_DEVICE=%r; expected 'cpu' or 'cuda'", + depth_output_device, + ) + depth_output_device = "cpu" + + depth_frames = self._depth_generator.process_frames( + video_input, + hard_cut=hard_cut, + input_size=depth_input_size, + fp32=depth_fp32, + temporal_mode=depth_temporal_mode, + output_device=depth_output_device, + ) - if vace_enabled and vace_use_input_video: - # VACE conditioning: route to vace_input_frames - call_params["vace_input_frames"] = video_input + if edge_source == "pidinet": + if self._pidinet_generator is None: + self._pidinet_generator = PiDiNetEdgeGenerator() + safe_mode = self.parameters.get( + "vace_pidinet_safe", True + ) + apply_filter = self.parameters.get( + "vace_pidinet_filter", True + ) + self._pidinet_generator.safe_mode = safe_mode + edge_frames = self._pidinet_generator.process_frames( + video_input, apply_filter=apply_filter + ) + else: + edge_frames = apply_canny_edges( + video_input, + adaptive_thresholds=True, + dilate_edges=True, + dilate_kernel_size=edge_thickness, + ) + + control_frames 
= [] + for depth_f, edge_f in zip(depth_frames, edge_frames, strict=True): + if depth_f.device != edge_f.device: + edge_f = edge_f.to(device=depth_f.device) + depth_norm = depth_f[:, :, :, 0] / 255.0 + edge_norm = edge_f[:, :, :, 0] / 255.0 + + fused = soft_max_fusion( + depth_norm, + edge_norm, + edge_strength=edge_strength, + sharpness=sharpness, + ) + + fused_uint8 = (fused * 255.0).clamp(0, 255) + fused_rgb = fused_uint8.unsqueeze(-1).expand(-1, -1, -1, 3) + control_frames.append(fused_rgb) + + # Apply temporal EMA if enabled (after either path) + temporal_ema = self.parameters.get( + "vace_control_map_temporal_ema", 0.0 + ) + if temporal_ema > 0: + control_frames = self._apply_temporal_ema( + control_frames, "composite", temporal_ema, hard_cut + ) + call_params["vace_input_frames"] = control_frames + # Store latest control frame for preview streaming + with self.latest_control_frame_lock: + last = control_frames[-1].squeeze(0) + if last.dtype != torch.uint8: + last = last.clamp(0, 255).to(torch.uint8) + self.latest_control_frame_cpu = last.to(device="cpu") + elif control_map_mode == "external": + # External/passthrough mode: use video input directly as control signal + # No processing - frames are assumed to already be control maps (e.g., from OBS, TouchDesigner) + call_params["vace_input_frames"] = video_input + # Store latest frame for preview streaming (so user can see what VACE receives) + with self.latest_control_frame_lock: + last = video_input[-1].squeeze(0) + if last.dtype != torch.uint8: + last = last.clamp(0, 255).to(torch.uint8) + self.latest_control_frame_cpu = last.to(device="cpu") + else: + call_params["vace_input_frames"] = video_input + if vace_hybrid_video_init: + # Hybrid: also pass the raw video frames as latent-init base. + # Note: VACE encoding must avoid clobbering the VAE streaming cache when + # both `video` and `vace_input_frames` are present (handled in VACE blocks). + call_params["video"] = video_input else: - # Latent initialization: route to video + # Normal V2V mode: route to video call_params["video"] = video_input - output = pipeline(**call_params) + transition_active_after_call = False + with self.pipeline_manager.locked_pipeline() as locked_pipeline: + output = locked_pipeline(**call_params) + if hasattr(locked_pipeline, "state"): + transition_active_after_call = locked_pipeline.state.get( + "_transition_active", False + ) + + # Consume one-shot updates only after they were passed to the pipeline. 
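+            # If the pipeline call above raised, these keys stay in self.parameters and the
+            # hard cut / LoRA update is retried on the next chunk.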
+ if lora_scales is not None: + self.parameters.pop("lora_scales", None) + self.parameters.pop("lora_scales_skip_cache_reset", None) + if reset_cache is not None: + self.parameters.pop("reset_cache", None) + self._hard_cut_flushed_pending = False + # Also reset control map worker cache (Phase 2.1a) + self._control_map_worker.reset_cache() # Clear vace_ref_images from parameters after use to prevent sending them on subsequent chunks # vace_ref_images should only be sent when explicitly provided in parameter updates @@ -742,12 +3807,20 @@ def process_chunk(self): # Clear transition when complete (blocks signal completion via _transition_active) # Contract: Modular pipelines manage prompts internally; frame_processor manages lifecycle if "transition" in call_params and "transition" in self.parameters: - transition_active = False - if hasattr(pipeline, "state"): - transition_active = pipeline.state.get("_transition_active", False) + transition_active = transition_active_after_call transition = call_params.get("transition") if not transition_active or transition is None: + target_prompts = None + if isinstance(transition, dict): + target_prompts = transition.get("target_prompts") + elif transition is not None and hasattr( + transition, "target_prompts" + ): + target_prompts = getattr(transition, "target_prompts", None) + + if target_prompts is not None: + self.parameters["prompts"] = target_prompts self.parameters.pop("transition", None) processing_time = time.time() - start_time @@ -766,44 +3839,259 @@ def process_chunk(self): .cpu() ) - # Resize output queue to meet target max size - target_output_queue_max_size = num_frames * OUTPUT_QUEUE_MAX_SIZE_FACTOR - if self.output_queue.maxsize < target_output_queue_max_size: - logger.info( - f"Increasing output queue size to {target_output_queue_max_size}, current size {self.output_queue.maxsize}, num_frames {num_frames}" - ) - - # Transfer frames from old queue to new queue - old_queue = self.output_queue - self.output_queue = queue.Queue(maxsize=target_output_queue_max_size) - while not old_queue.empty(): + # Store latest frame for non-destructive REST reads + with self.latest_frame_lock: + self.latest_frame_cpu = output[-1].clone() + self.latest_frame_id += 1 + self._signal_latest_frame_available() + + # Resize output queue to meet target max size. + # + # Lock protects against race with flush_output_queue(). In low-latency output mode, + # we also keep the queue size bounded (cap) and prefer dropping the oldest frames + # over dropping newly-generated frames. + with self.output_queue_lock: + factor = OUTPUT_QUEUE_MAX_SIZE_FACTOR + factor_env = os.getenv("SCOPE_OUTPUT_QUEUE_MAX_SIZE_FACTOR", "").strip() + if factor_env: try: - frame = old_queue.get_nowait() - self.output_queue.put_nowait(frame) - except queue.Empty: - break + factor = max(1, int(factor_env)) + except ValueError: + logger.warning( + "Invalid SCOPE_OUTPUT_QUEUE_MAX_SIZE_FACTOR=%r; expected int", + factor_env, + ) + elif self._low_latency_output_mode: + # Prefer a small default queue (≈ one chunk) when low-latency output + # is enabled. Users can increase via SCOPE_OUTPUT_QUEUE_MAX_SIZE_FACTOR. 
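+                # Example: with a 12-frame chunk, factor=1 caps the queue at roughly 12
+                # frames, so playback stays at most about one chunk behind generation.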
+ factor = 1 + + target_output_queue_max_size = num_frames * factor + if self._output_queue_maxsize_cap is not None: + target_output_queue_max_size = min( + int(self._output_queue_maxsize_cap), + int(target_output_queue_max_size), + ) - for frame in output: - try: - self.output_queue.put_nowait(frame) - except queue.Full: - logger.warning("Output queue full, dropping processed frame") - # Update FPS calculation based on processing time and frame count - self._calculate_pipeline_fps(start_time, num_frames) - continue + if ( + target_output_queue_max_size > 0 + and self.output_queue.maxsize < target_output_queue_max_size + ): + logger.info( + "Increasing output queue size to %s, current size %s, num_frames %s", + target_output_queue_max_size, + self.output_queue.maxsize, + num_frames, + ) + + # Transfer frames from old queue to new queue + old_queue = self.output_queue + self.output_queue = queue.Queue(maxsize=target_output_queue_max_size) + while not old_queue.empty(): + try: + frame = old_queue.get_nowait() + self.output_queue.put_nowait(frame) + except queue.Empty: + break + + for frame in output: + if self._low_latency_output_mode: + # Drop oldest frames to make room for newest output. + while True: + try: + self.output_queue.put_nowait(frame) + break + except queue.Full: + try: + self.output_queue.get_nowait() + self.output_frames_dropped += 1 + except queue.Empty: + # Shouldn't happen, but avoid spinning. + self.output_frames_dropped += 1 + break + else: + try: + self.output_queue.put_nowait(frame) + except queue.Full: + logger.warning("Output queue full, dropping processed frame") + self.output_frames_dropped += 1 + continue # Update FPS calculation based on processing time and frame count self._calculate_pipeline_fps(start_time, num_frames) except Exception as e: + chunk_error = e if self._is_recoverable(e): # Handle recoverable errors with full stack trace and continue processing logger.error(f"Error processing chunk: {e}", exc_info=True) else: raise e + # SessionRecorder: record prompt/transition + hard/soft cuts for this chunk. + # Record only after a successful pipeline call so paused/video-wait churn doesn't + # create phantom segments. + if self.session_recorder.is_recording and chunk_error is None: + wall_time = time.monotonic() + + # Soft cut metadata is recorded ONCE per trigger, at the first generated chunk. 
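+            # Later chunks of the same transition carry no soft-cut fields because the
+            # pending flag is cleared once the metadata has been captured below.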
+ soft_cut_bias = None + soft_cut_chunks = None + soft_restore_bias = None + soft_restore_was_set = False + if ( + self._soft_transition_record_pending + and self._soft_transition_temp_bias is not None + ): + soft_cut_bias = float(self._soft_transition_temp_bias) + soft_cut_chunks = int(self._soft_transition_chunks_remaining) + soft_restore_was_set = bool(self._soft_transition_original_bias_was_set) + soft_restore_bias = ( + float(self._soft_transition_original_bias) + if self._soft_transition_original_bias is not None + and self._soft_transition_original_bias_was_set + else None + ) + self._soft_transition_record_pending = False + + recorded_prompt_event = False + if applied_prompt_payload is not None: + prompt_text = None + prompt_weight = 1.0 + + tr = applied_prompt_payload.get("transition") + if isinstance(tr, dict): + targets = tr.get("target_prompts") + if isinstance(targets, list) and targets: + first = targets[0] + if isinstance(first, dict): + prompt_text = first.get("text") + prompt_weight = float(first.get("weight", 1.0)) + + if prompt_text is None: + prompts = applied_prompt_payload.get("prompts") + if isinstance(prompts, list) and prompts: + first = prompts[0] + if isinstance(first, dict): + prompt_text = first.get("text") + prompt_weight = float(first.get("weight", 1.0)) + + transition_steps = None + transition_method = None + if isinstance(tr, dict): + transition_steps = tr.get("num_steps") + transition_method = tr.get("temporal_interpolation_method") + + if prompt_text is not None: + self.session_recorder.record_event( + chunk_index=self.chunk_index, + wall_time=wall_time, + prompt=prompt_text, + prompt_weight=prompt_weight, + transition_steps=transition_steps, + transition_method=transition_method, + hard_cut=hard_cut_executed, + soft_cut_bias=soft_cut_bias, + soft_cut_chunks=soft_cut_chunks, + soft_cut_restore_bias=soft_restore_bias, + soft_cut_restore_was_set=soft_restore_was_set, + ) + recorded_prompt_event = True + hard_cut_executed = False + soft_cut_bias = None + + # Fallback: prompt changed since last recorded chunk (e.g. 
edits while paused)
+            if not recorded_prompt_event and fallback_prompt is not None:
+                self.session_recorder.record_event(
+                    chunk_index=self.chunk_index,
+                    wall_time=wall_time,
+                    prompt=fallback_prompt,
+                    prompt_weight=fallback_weight,
+                    hard_cut=hard_cut_executed,
+                    soft_cut_bias=soft_cut_bias,
+                    soft_cut_chunks=soft_cut_chunks,
+                    soft_cut_restore_bias=soft_restore_bias,
+                    soft_cut_restore_was_set=soft_restore_was_set,
+                )
+                recorded_prompt_event = True
+                hard_cut_executed = False
+                soft_cut_bias = None
+
+            # Cut-only event (no prompt change): recorder carries forward last prompt
+            if (not recorded_prompt_event) and (
+                hard_cut_executed or soft_cut_bias is not None
+            ):
+                self.session_recorder.record_event(
+                    chunk_index=self.chunk_index,
+                    wall_time=wall_time,
+                    prompt=None,
+                    hard_cut=hard_cut_executed,
+                    soft_cut_bias=soft_cut_bias,
+                    soft_cut_chunks=soft_cut_chunks,
+                    soft_cut_restore_bias=soft_restore_bias,
+                    soft_cut_restore_was_set=soft_restore_was_set,
+                )
+
         self.is_prepared = True

-    def prepare_chunk(self, chunk_size: int) -> list[torch.Tensor]:
+        # Soft transition countdown and auto-restore at chunk boundary
+        if self._soft_transition_active:
+            self._soft_transition_chunks_remaining -= 1
+            if self._soft_transition_chunks_remaining <= 0:
+                if self._soft_transition_original_bias_was_set:
+                    # Restore original bias value
+                    if self._soft_transition_original_bias is not None:
+                        self.parameters["kv_cache_attention_bias"] = (
+                            self._soft_transition_original_bias
+                        )
+                        logger.info(
+                            f"Soft transition complete: restored bias to "
+                            f"{self._soft_transition_original_bias}"
+                        )
+                    else:
+                        self.parameters.pop("kv_cache_attention_bias", None)
+                        logger.info(
+                            "Soft transition complete: restored kv_cache_attention_bias to default (unset)"
+                        )
+                else:
+                    # Restore to "unset" (pipeline/config default) if we didn't get overridden.
+                    current_bias = self.parameters.get("kv_cache_attention_bias")
+                    if (
+                        self._soft_transition_temp_bias is None
+                        or current_bias == self._soft_transition_temp_bias
+                    ):
+                        self.parameters.pop("kv_cache_attention_bias", None)
+                        logger.info(
+                            "Soft transition complete: restored kv_cache_attention_bias to default (unset)"
+                        )
+                    else:
+                        logger.info(
+                            "Soft transition complete: keeping kv_cache_attention_bias override"
+                        )
+
+                self._soft_transition_active = False
+                self._soft_transition_chunks_remaining = 0
+                self._soft_transition_temp_bias = None
+                self._soft_transition_original_bias = None
+                self._soft_transition_original_bias_was_set = False
+                self._soft_transition_record_pending = False
+
+        self.chunk_index += 1
+
+        # Send step response after completing a step-driven chunk generation.
+        if self._pending_steps > 0:
+            self._pending_steps = max(0, self._pending_steps - 1)
+
+        if step_requested and self.snapshot_response_callback:
+            self.snapshot_response_callback(
+                {
+                    "type": "step_response",
+                    "chunk_index": self.chunk_index,
+                    "success": chunk_error is None,
+                    "error": str(chunk_error) if chunk_error is not None else None,
+                }
+            )
+
+    def prepare_chunk(self, chunk_size: int) -> tuple[list[torch.Tensor], list[int]]:
         """
         Sample frames uniformly from the buffer, convert them to tensors, and
         remove processed frames.
@@ -812,44 +4100,370 @@ def process_chunk(self):
         indices and removes all frames up to the last sampled frame to prevent
         buffer buildup.

-        Note:
-            This function must be called with self.frame_buffer_lock held to ensure
-            thread safety. The caller is responsible for acquiring the lock.
+ When low-latency mode is enabled (SCOPE_LOW_LATENCY_INPUT=1), the buffer is + first trimmed to only keep the newest frames, reducing input lag at the cost + of dropping older frames. + + Implementation note: + We hold `frame_buffer_lock` only while sampling + popping frames, then + convert frames to tensors outside the lock to avoid blocking concurrent + input producers (WebRTC/Spout/NDI) on expensive `to_ndarray()` work. - Example: + Example (normal mode): With buffer_len=8 and chunk_size=4: - step = 8/4 = 2.0 - indices = [0, 2, 4, 6] (uniformly distributed) - Returns frames at positions 0, 2, 4, 6 - Removes frames 0-6 from buffer (7 frames total) + Example (low-latency mode, buffer_factor=2): + With buffer_len=30 and chunk_size=4: + - max_keep = 4 * 2 = 8 + - Drop 22 oldest frames first + - Then sample uniformly from remaining 8 frames + Returns: - List of tensor frames, each (1, H, W, C) for downstream preprocess_chunk + Tuple of (tensor_frames, frame_ids): + - tensor_frames: List of (1, H, W, C) tensors for downstream preprocess_chunk + - frame_ids: List of frame IDs for Phase 2.1b control buffer sampling """ - # Calculate uniform sampling step - step = len(self.frame_buffer) / chunk_size - # Generate indices for uniform sampling - indices = [round(i * step) for i in range(chunk_size)] - # Extract VideoFrames at sampled indices - video_frames = [self.frame_buffer[i] for i in indices] - - # Drop all frames up to and including the last sampled frame - last_idx = indices[-1] - for _ in range(last_idx + 1): - self.frame_buffer.popleft() - - # Convert VideoFrames to tensors (keep as uint8, GPU will handle dtype conversion) + with self.frame_buffer_lock: + if not self.frame_buffer or len(self.frame_buffer) < chunk_size: + return [], [] + + # Low-latency mode: trim buffer to reduce lag + if self._low_latency_mode: + max_keep = chunk_size * self._low_latency_buffer_factor + if len(self.frame_buffer) > max_keep: + drop_count = len(self.frame_buffer) - max_keep + for _ in range(drop_count): + self.frame_buffer.popleft() + self.input_frames_dropped += drop_count + + # Calculate uniform sampling step + step = len(self.frame_buffer) / chunk_size + # Generate indices for uniform sampling + indices = [round(i * step) for i in range(chunk_size)] + # Extract VideoFrames at sampled indices + video_frames = [self.frame_buffer[i] for i in indices] + + # Phase 2.1b: carry frame IDs through for control buffer lookup + frame_ids = [int(getattr(f, "frame_id", -1)) for f in video_frames] + + # Drop all frames up to and including the last sampled frame + last_idx = indices[-1] + for _ in range(last_idx + 1): + self.frame_buffer.popleft() + + # Convert VideoFrames to tensors + mirror_input = self.parameters.get("mirror_input", True) # Default: mirrored (selfie mode) tensor_frames = [] for video_frame in video_frames: - # Convert VideoFrame into (1, H, W, C) uint8 tensor on cpu + # Convert VideoFrame into (1, H, W, C) tensor on cpu # The T=1 dimension is expected by preprocess_chunk which rearranges T H W C -> T C H W - # Note: We keep uint8 here and let pipeline preprocess chunk to target dtype on GPU - tensor = torch.from_numpy(video_frame.to_ndarray(format="rgb24")).unsqueeze( - 0 + tensor = ( + torch.from_numpy(video_frame.to_ndarray(format="rgb24")) + .unsqueeze(0) ) + # Apply horizontal flip when mirror mode is DISABLED + # mirror_input=True → selfie mode (natural, no flip - browser already sends mirrored) + # mirror_input=False → raw camera mode (flip to show true orientation) + if not mirror_input: + 
tensor = torch.flip(tensor, dims=[2]) # Flip width dimension (1, H, W, C) + tensor_frames.append(tensor) + + return tensor_frames, frame_ids + + def _ndi_external_hold_last_enabled(self) -> bool: + if self.get_active_input_source() != "ndi": + return False + return (self.parameters.get("vace_control_map_mode") or "none") == "external" + + def _get_vace_external_stale_ms(self) -> float: + raw = self.parameters.get("vace_external_stale_ms") + if raw is None: + raw = os.getenv("SCOPE_VACE_EXTERNAL_STALE_MS", "500") + try: + ms = float(raw) + except (TypeError, ValueError): + ms = 500.0 + return max(0.0, ms) + + def _get_vace_external_resume_hard_cut_enabled(self) -> bool: + raw = self.parameters.get("vace_external_resume_hard_cut") + if raw is None: + raw = os.getenv("SCOPE_VACE_EXTERNAL_RESUME_HARD_CUT", "1") + if isinstance(raw, str): + return raw.strip().lower() in ("1", "true", "yes", "on") + return bool(raw) + + def _get_output_pacing_fps(self) -> float | None: + raw = self.parameters.get("output_pacing_fps") + if raw is None: + raw = os.getenv("SCOPE_OUTPUT_PACING_FPS", "").strip() + + if raw in (None, ""): + return None + + try: + fps = float(raw) + except (TypeError, ValueError): + return None + + if fps <= 0: + return None + + return fps + + def _prepare_chunk_hold_last(self, chunk_size: int) -> tuple[list[torch.Tensor], list[int]]: + """Prepare a chunk, repeating the latest input frame if we're underflowing. + + This is intended for NDI + `vace_control_map_mode="external"` where input + frames are control maps and should not stall the generator. + """ + # Fast path: preserve existing uniform-sampling semantics when enough frames exist. + with self.frame_buffer_lock: + buffer_len = len(self.frame_buffer) + if buffer_len >= chunk_size: + tensor_frames, frame_ids = self.prepare_chunk(chunk_size) + if tensor_frames: + self._ndi_hold_last_input_frame = tensor_frames[-1].detach().clone() + self._ndi_hold_last_input_frame_id = int(frame_ids[-1]) if frame_ids else None + return tensor_frames, frame_ids + + # Underflow: take whatever we have right now (possibly 0), then fill by repeating. + with self.frame_buffer_lock: + video_frames = list(self.frame_buffer) + self.frame_buffer.clear() + + mirror_input = self.parameters.get("mirror_input", True) # Default: mirrored (selfie mode) + tensor_frames: list[torch.Tensor] = [] + frame_ids: list[int] = [] + for video_frame in video_frames: + tensor = torch.from_numpy(video_frame.to_ndarray(format="rgb24")).unsqueeze(0) + if not mirror_input: + tensor = torch.flip(tensor, dims=[2]) # Flip width dimension (1, H, W, C) tensor_frames.append(tensor) + frame_ids.append(int(getattr(video_frame, "frame_id", -1))) + + base_tensor: torch.Tensor | None = None + base_frame_id: int = -1 + if tensor_frames: + base_tensor = tensor_frames[-1] + base_frame_id = frame_ids[-1] + self._ndi_hold_last_input_frame = base_tensor.detach().clone() + self._ndi_hold_last_input_frame_id = int(base_frame_id) + elif self._ndi_hold_last_input_frame is not None: + base_tensor = self._ndi_hold_last_input_frame + if self._ndi_hold_last_input_frame_id is not None: + base_frame_id = int(self._ndi_hold_last_input_frame_id) + else: + return [], [] + + reused = 0 + while len(tensor_frames) < chunk_size: + tensor_frames.append(base_tensor.clone()) + frame_ids.append(base_frame_id) + reused += 1 + + if reused: + self.ndi_frames_reused += reused + + return tensor_frames, frame_ids + + def _create_snapshot(self) -> Snapshot: + """Create a snapshot of current generation state. 
+ + Captures: + - Continuity state from pipeline.state (cloned tensors) + - Control state (deep copy of parameters) + - Metadata (chunk_index, timestamp, resolution) + + Returns: + Snapshot object with unique ID + """ + snapshot_id = str(uuid.uuid4()) + with self.pipeline_manager.locked_pipeline() as pipeline: + # Capture continuity state from pipeline.state + current_start_frame = 0 + first_context_frame = None + context_frame_buffer = None + decoded_frame_buffer = None + context_frame_buffer_max_size = 0 + decoded_frame_buffer_max_size = 0 + + if hasattr(pipeline, "state"): + state = pipeline.state + current_start_frame = state.get("current_start_frame", 0) + context_frame_buffer_max_size = state.get( + "context_frame_buffer_max_size", 0 + ) + decoded_frame_buffer_max_size = state.get( + "decoded_frame_buffer_max_size", 0 + ) + + # Clone tensors to avoid mutation + fcf = state.get("first_context_frame") + if fcf is not None and isinstance(fcf, torch.Tensor): + first_context_frame = fcf.detach().clone() + + cfb = state.get("context_frame_buffer") + if cfb is not None and isinstance(cfb, torch.Tensor): + context_frame_buffer = cfb.detach().clone() + + dfb = state.get("decoded_frame_buffer") + if dfb is not None and isinstance(dfb, torch.Tensor): + decoded_frame_buffer = dfb.detach().clone() + + # Get resolution from pipeline manager + resolution = self._get_pipeline_dimensions() + + # Get pipeline_id if available + pipeline_id = None + try: + status_info = self.pipeline_manager.get_status_info() + pipeline_id = status_info.get("pipeline_id") + except Exception: + pass - return tensor_frames + snapshot = Snapshot( + snapshot_id=snapshot_id, + chunk_index=self.chunk_index, + created_at=time.time(), + current_start_frame=current_start_frame, + first_context_frame=first_context_frame, + context_frame_buffer=context_frame_buffer, + decoded_frame_buffer=decoded_frame_buffer, + context_frame_buffer_max_size=context_frame_buffer_max_size, + decoded_frame_buffer_max_size=decoded_frame_buffer_max_size, + parameters=copy.deepcopy(self.parameters), + paused=self.paused, + video_mode=self._video_mode, + # Style layer state + world_state_json=self.world_state.model_dump_json(), + active_style_name=self.style_manifest.name if self.style_manifest else None, + compiled_prompt_text=( + self._compiled_prompt.prompt if self._compiled_prompt else None + ), + pipeline_id=pipeline_id, + resolution=resolution, + ) + + # Store snapshot with LRU eviction + self.snapshots[snapshot_id] = snapshot + self.snapshot_order.append(snapshot_id) + + # Evict oldest if over limit + while len(self.snapshots) > MAX_SNAPSHOTS: + oldest_id = self.snapshot_order.pop(0) + old_snapshot = self.snapshots.pop(oldest_id, None) + if old_snapshot: + # Release tensor memory explicitly + old_snapshot.first_context_frame = None + old_snapshot.context_frame_buffer = None + old_snapshot.decoded_frame_buffer = None + logger.debug(f"Evicted snapshot {oldest_id} (LRU)") + + logger.info( + f"Created snapshot {snapshot_id} at chunk {self.chunk_index}, " + f"total snapshots: {len(self.snapshots)}" + ) + + return snapshot + + def _restore_snapshot(self, snapshot_id: str) -> bool: + """Restore generation state from a snapshot. 
+ + Restores: + - Continuity state to pipeline.state + - Control state to self.parameters + - Clears output_queue to prevent stale frames + - Sets is_prepared=True to avoid accidental cache reset + + Args: + snapshot_id: ID of snapshot to restore + + Returns: + True if restore succeeded, False if snapshot not found + """ + snapshot = self.snapshots.get(snapshot_id) + if snapshot is None: + logger.warning(f"Snapshot {snapshot_id} not found") + return False + + # LRU: move restored snapshot to end of order (most recently used) + if snapshot_id in self.snapshot_order: + self.snapshot_order.remove(snapshot_id) + self.snapshot_order.append(snapshot_id) + + with self.pipeline_manager.locked_pipeline() as pipeline: + # Restore continuity state to pipeline.state + if hasattr(pipeline, "state"): + state = pipeline.state + state.set("current_start_frame", snapshot.current_start_frame) + state.set( + "context_frame_buffer_max_size", + snapshot.context_frame_buffer_max_size, + ) + state.set( + "decoded_frame_buffer_max_size", + snapshot.decoded_frame_buffer_max_size, + ) + + # Restore tensors back to pipeline.state (or clear when None). + state.set( + "first_context_frame", + snapshot.first_context_frame.detach().clone() + if snapshot.first_context_frame is not None + else None, + ) + state.set( + "context_frame_buffer", + snapshot.context_frame_buffer.detach().clone() + if snapshot.context_frame_buffer is not None + else None, + ) + state.set( + "decoded_frame_buffer", + snapshot.decoded_frame_buffer.detach().clone() + if snapshot.decoded_frame_buffer is not None + else None, + ) + + # Restore control state + self.parameters = copy.deepcopy(snapshot.parameters) + self.paused = snapshot.paused + self._video_mode = snapshot.video_mode + self.chunk_index = snapshot.chunk_index + + # Restore style layer state (thread-safe via model_validate_json) + if snapshot.world_state_json: + self.world_state = WorldState.model_validate_json(snapshot.world_state_json) + if snapshot.active_style_name: + self.style_manifest = self.style_registry.get(snapshot.active_style_name) + self._active_style_name = snapshot.active_style_name + # Recompile after restore + if self.world_state and self.style_manifest: + self._compiled_prompt = self.prompt_compiler.compile( + self.world_state, self.style_manifest + ) + + # Clear output_queue to prevent stale pre-restore frames + self.flush_output_queue() + + # Clear frame_buffer in V2V mode to prevent stale input frames + if self._video_mode: + with self.frame_buffer_lock: + self.frame_buffer.clear() + + # Set is_prepared=True to avoid accidental cache reset on next chunk + self.is_prepared = True + + logger.info( + f"Restored snapshot {snapshot_id} to chunk {snapshot.chunk_index}" + ) + + return True def __enter__(self): self.start() diff --git a/src/scope/server/ndi/__init__.py b/src/scope/server/ndi/__init__.py new file mode 100644 index 000000000..b7401542a --- /dev/null +++ b/src/scope/server/ndi/__init__.py @@ -0,0 +1,9 @@ +from .finder import NDISource, list_sources +from .receiver import NDIReceiver + +__all__ = [ + "NDIReceiver", + "NDISource", + "list_sources", +] + diff --git a/src/scope/server/ndi/_ctypes.py b/src/scope/server/ndi/_ctypes.py new file mode 100644 index 000000000..9fc1da9f3 --- /dev/null +++ b/src/scope/server/ndi/_ctypes.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import ctypes +import os +from ctypes.util import find_library +from dataclasses import dataclass +from pathlib import Path + + +class NDIlib_source_t(ctypes.Structure): + 
_fields_ = [ + ("p_ndi_name", ctypes.c_char_p), + ("p_url_address", ctypes.c_char_p), + ] + + +class NDIlib_find_create_t(ctypes.Structure): + _fields_ = [ + ("show_local_sources", ctypes.c_bool), + ("p_groups", ctypes.c_char_p), + ("p_extra_ips", ctypes.c_char_p), + ] + + +class NDIlib_recv_create_v3_t(ctypes.Structure): + _fields_ = [ + ("source_to_connect_to", NDIlib_source_t), + ("color_format", ctypes.c_int), + ("bandwidth", ctypes.c_int), + ("allow_video_fields", ctypes.c_bool), + ("p_ndi_recv_name", ctypes.c_char_p), + ] + + +class NDIlib_send_create_t(ctypes.Structure): + _fields_ = [ + ("p_ndi_name", ctypes.c_char_p), + ("p_groups", ctypes.c_char_p), + ("clock_video", ctypes.c_bool), + ("clock_audio", ctypes.c_bool), + ] + + +class NDIlib_video_frame_v2_t(ctypes.Structure): + _fields_ = [ + ("xres", ctypes.c_int), + ("yres", ctypes.c_int), + ("FourCC", ctypes.c_uint32), + ("frame_rate_N", ctypes.c_int), + ("frame_rate_D", ctypes.c_int), + ("picture_aspect_ratio", ctypes.c_float), + ("frame_format_type", ctypes.c_int), + ("timecode", ctypes.c_int64), + ("p_data", ctypes.POINTER(ctypes.c_uint8)), + ("line_stride_in_bytes", ctypes.c_int), + ("p_metadata", ctypes.c_char_p), + ("timestamp", ctypes.c_int64), + ] + + +@dataclass(frozen=True) +class NDILib: + lib: ctypes.CDLL + + +def _try_load_libndi(path: str) -> ctypes.CDLL | None: + try: + return ctypes.CDLL(path) + except OSError: + return None + + +def _cyndilib_bundled_libndi_path() -> Path | None: + try: + import cyndilib # type: ignore[import-not-found] + except Exception: + return None + + pkg_dir = Path(cyndilib.__file__).resolve().parent + candidate = pkg_dir / "wrapper" / "bin" / "x86_64-linux-gnu" / "libndi.so" + if candidate.exists(): + return candidate + return None + + +def load_libndi() -> NDILib: + env_path = os.environ.get("SCOPE_NDI_LIB_PATH") + if env_path: + lib = _try_load_libndi(env_path) + if lib is None: + raise RuntimeError(f"Failed to load NDI library at SCOPE_NDI_LIB_PATH={env_path!r}") + return NDILib(lib=lib) + + # Prefer system install if present. + for candidate in (find_library("ndi"), "libndi.so.6", "libndi.so"): + if not candidate: + continue + lib = _try_load_libndi(candidate) + if lib is not None: + return NDILib(lib=lib) + + # Fallback: cyndilib ships a bundled libndi.so (useful in dev containers). + bundled = _cyndilib_bundled_libndi_path() + if bundled is not None: + lib = _try_load_libndi(str(bundled)) + if lib is not None: + return NDILib(lib=lib) + + raise RuntimeError( + "NDI runtime not found. Install the NDI SDK/runtime (libndi.so.6), " + "or set SCOPE_NDI_LIB_PATH, or install cyndilib (dev fallback)." 
+ ) + + +def configure_libndi_prototypes(ndi: NDILib) -> None: + lib = ndi.lib + + # init + lib.NDIlib_initialize.restype = ctypes.c_bool + lib.NDIlib_destroy.argtypes = [] + lib.NDIlib_version.restype = ctypes.c_char_p + + # find + lib.NDIlib_find_create_v2.argtypes = [ctypes.POINTER(NDIlib_find_create_t)] + lib.NDIlib_find_create_v2.restype = ctypes.c_void_p + lib.NDIlib_find_destroy.argtypes = [ctypes.c_void_p] + lib.NDIlib_find_wait_for_sources.argtypes = [ctypes.c_void_p, ctypes.c_uint32] + lib.NDIlib_find_wait_for_sources.restype = ctypes.c_bool + lib.NDIlib_find_get_current_sources.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_uint32), + ] + lib.NDIlib_find_get_current_sources.restype = ctypes.POINTER(NDIlib_source_t) + + # recv + lib.NDIlib_recv_create_v3.argtypes = [ctypes.POINTER(NDIlib_recv_create_v3_t)] + lib.NDIlib_recv_create_v3.restype = ctypes.c_void_p + lib.NDIlib_recv_destroy.argtypes = [ctypes.c_void_p] + lib.NDIlib_recv_connect.argtypes = [ctypes.c_void_p, ctypes.POINTER(NDIlib_source_t)] + lib.NDIlib_recv_capture_v2.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(NDIlib_video_frame_v2_t), + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_uint32, + ] + lib.NDIlib_recv_capture_v2.restype = ctypes.c_int + lib.NDIlib_recv_free_video_v2.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(NDIlib_video_frame_v2_t), + ] + lib.NDIlib_recv_get_no_connections.argtypes = [ctypes.c_void_p] + lib.NDIlib_recv_get_no_connections.restype = ctypes.c_int + + # send + lib.NDIlib_send_create.argtypes = [ctypes.POINTER(NDIlib_send_create_t)] + lib.NDIlib_send_create.restype = ctypes.c_void_p + lib.NDIlib_send_destroy.argtypes = [ctypes.c_void_p] + + lib.NDIlib_send_send_video_v2.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(NDIlib_video_frame_v2_t), + ] + lib.NDIlib_send_send_video_v2.restype = None + + # Send-side get_no_connections takes a timeout_in_ms parameter. + lib.NDIlib_send_get_no_connections.argtypes = [ctypes.c_void_p, ctypes.c_uint32] + lib.NDIlib_send_get_no_connections.restype = ctypes.c_int + + +# Enums (subset) from Processing.NDI.structs.h / Processing.NDI.Recv.h +NDIlib_frame_type_none = 0 +NDIlib_frame_type_video = 1 +NDIlib_frame_type_error = 4 +NDIlib_frame_type_status_change = 100 + +NDIlib_recv_color_format_BGRX_BGRA = 0 +NDIlib_recv_bandwidth_highest = 100 + +# Frame format type enum values (subset). +NDIlib_frame_format_type_interleaved = 0 +NDIlib_frame_format_type_progressive = 1 +NDIlib_frame_format_type_field_0 = 2 +NDIlib_frame_format_type_field_1 = 3 + + +def ndi_fourcc(a: str, b: str, c: str, d: str) -> int: + return (ord(a) | (ord(b) << 8) | (ord(c) << 16) | (ord(d) << 24)) + + +NDIlib_FourCC_type_BGRA = ndi_fourcc("B", "G", "R", "A") +NDIlib_FourCC_type_BGRX = ndi_fourcc("B", "G", "R", "X") + +# When sending, NDI recommends timecode synthesis (INT64_MAX). 
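+# INT64_MAX (2**63 - 1): tells the SDK to synthesize timecodes on send.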
+NDIlib_send_timecode_synthesize = 9223372036854775807 diff --git a/src/scope/server/ndi/finder.py b/src/scope/server/ndi/finder.py new file mode 100644 index 000000000..71529f784 --- /dev/null +++ b/src/scope/server/ndi/finder.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import ctypes +from dataclasses import dataclass +from typing import Sequence + +from ._ctypes import NDIlib_find_create_t, NDIlib_source_t +from .runtime import get_runtime + + +@dataclass(frozen=True) +class NDISource: + name: str + url_address: str | None + + +def list_sources( + *, + timeout_ms: int = 1000, + extra_ips: Sequence[str] | None = None, + show_local_sources: bool = True, +) -> list[NDISource]: + """Discover sources using NDI's find API. + + Notes: + - `extra_ips` should be a list of IPs (comma-separated in the NDI API) for cross-subnet/VPN discovery. + - Strings returned by NDI are copied into Python strings before the finder is destroyed. + """ + runtime = get_runtime() + ndi = runtime.acquire() + try: + extra_ips_str = ",".join([ip.strip() for ip in (extra_ips or []) if ip.strip()]) + settings = NDIlib_find_create_t() + settings.show_local_sources = bool(show_local_sources) + settings.p_groups = None + settings.p_extra_ips = extra_ips_str.encode("utf-8") if extra_ips_str else None + + finder = ndi.lib.NDIlib_find_create_v2(ctypes.byref(settings)) + if not finder: + raise RuntimeError("NDIlib_find_create_v2() failed") + try: + ndi.lib.NDIlib_find_wait_for_sources(finder, int(timeout_ms)) + + count = ctypes.c_uint32(0) + sources_ptr = ndi.lib.NDIlib_find_get_current_sources(finder, ctypes.byref(count)) + + sources: list[NDISource] = [] + for i in range(int(count.value)): + src: NDIlib_source_t = sources_ptr[i] + if not src.p_ndi_name: + continue + name = src.p_ndi_name.decode("utf-8", errors="replace") + url = src.p_url_address.decode("utf-8", errors="replace") if src.p_url_address else None + sources.append(NDISource(name=name, url_address=url)) + return sources + finally: + ndi.lib.NDIlib_find_destroy(finder) + finally: + runtime.release() + diff --git a/src/scope/server/ndi/receiver.py b/src/scope/server/ndi/receiver.py new file mode 100644 index 000000000..de48351a9 --- /dev/null +++ b/src/scope/server/ndi/receiver.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import ctypes +import logging +import time +from dataclasses import dataclass +from typing import Sequence + +import numpy as np + +from ._ctypes import ( + NDIlib_recv_bandwidth_highest, + NDIlib_recv_color_format_BGRX_BGRA, + NDIlib_recv_create_v3_t, + NDIlib_source_t, + NDIlib_video_frame_v2_t, + NDIlib_frame_type_error, + NDIlib_frame_type_none, + NDIlib_frame_type_status_change, + NDIlib_frame_type_video, +) +from .finder import NDISource, list_sources +from .runtime import get_runtime + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class NDIReceiverStats: + frames_received: int = 0 + frames_dropped_during_drain: int = 0 + last_frame_ts_s: float = 0.0 + + +class NDIReceiver: + """Thin NDI receiver wrapper (ctypes over libndi).""" + + def __init__( + self, + *, + recv_name: str = "ScopeNDIRecv", + color_format: int = NDIlib_recv_color_format_BGRX_BGRA, + bandwidth: int = NDIlib_recv_bandwidth_highest, + allow_video_fields: bool = True, + ) -> None: + self._recv_name = recv_name + self._color_format = int(color_format) + self._bandwidth = int(bandwidth) + self._allow_video_fields = bool(allow_video_fields) + + self._runtime = get_runtime() + self._ndi = None + self._recv = None + 
self._connected_url: bytes | None = None + self._last_source: NDISource | None = None + + self._stats = NDIReceiverStats() + + def create(self) -> bool: + if self._recv is not None: + return True + + self._ndi = self._runtime.acquire() + create = NDIlib_recv_create_v3_t() + create.source_to_connect_to = NDIlib_source_t(None, None) + create.color_format = self._color_format + create.bandwidth = self._bandwidth + create.allow_video_fields = self._allow_video_fields + create.p_ndi_recv_name = self._recv_name.encode("utf-8") + + recv = self._ndi.lib.NDIlib_recv_create_v3(ctypes.byref(create)) + if not recv: + self._runtime.release() + self._ndi = None + return False + + self._recv = recv + return True + + def get_no_connections(self) -> int: + if self._recv is None or self._ndi is None: + return 0 + return int(self._ndi.lib.NDIlib_recv_get_no_connections(self._recv)) + + def connect(self, *, url_address: str, source_name: str | None = None) -> None: + if self._recv is None or self._ndi is None: + raise RuntimeError("Receiver not created") + + url_bytes = url_address.encode("utf-8") + self._connected_url = url_bytes + self._last_source = NDISource(name=source_name or "", url_address=url_address) + + src = NDIlib_source_t(None, url_bytes) + self._ndi.lib.NDIlib_recv_connect(self._recv, ctypes.byref(src)) + + def connect_discovered( + self, + *, + source_substring: str, + extra_ips: Sequence[str] | None = None, + timeout_ms: int = 1500, + ) -> NDISource: + sources = list_sources( + timeout_ms=timeout_ms, + extra_ips=extra_ips, + show_local_sources=True, + ) + if not sources: + raise RuntimeError("No NDI sources discovered") + + needle = (source_substring or "").strip().lower() + chosen: NDISource | None = None + if not needle: + chosen = sources[0] + else: + for s in sources: + if needle in s.name.lower(): + chosen = s + break + + if chosen is None: + raise RuntimeError( + f"No NDI source matched {source_substring!r}. Discovered: {[s.name for s in sources]}" + ) + + if not chosen.url_address: + raise RuntimeError(f"NDI source {chosen.name!r} has no url_address; cannot connect reliably") + + self.connect(url_address=chosen.url_address, source_name=chosen.name) + return chosen + + def receive_latest_rgb24(self, *, timeout_ms: int = 5) -> np.ndarray | None: + if self._recv is None or self._ndi is None: + return None + + # Drain-to-latest: keep only the newest available video frame. + dropped = 0 + last_vf: NDIlib_video_frame_v2_t | None = None + + def free_if_needed(vf: NDIlib_video_frame_v2_t | None) -> None: + if vf is None: + return + self._ndi.lib.NDIlib_recv_free_video_v2(self._recv, ctypes.byref(vf)) + + try: + # Block briefly for the first frame. + first = NDIlib_video_frame_v2_t() + ft = int(self._ndi.lib.NDIlib_recv_capture_v2(self._recv, ctypes.byref(first), None, None, int(timeout_ms))) + if ft == NDIlib_frame_type_video: + last_vf = first + elif ft in (NDIlib_frame_type_none, NDIlib_frame_type_status_change): + return None + elif ft == NDIlib_frame_type_error: + raise RuntimeError("NDI receiver capture error (connection lost?)") + else: + return None + + # Drain any queued frames (timeout=0). 
+ while True: + nxt = NDIlib_video_frame_v2_t() + ft = int(self._ndi.lib.NDIlib_recv_capture_v2(self._recv, ctypes.byref(nxt), None, None, 0)) + if ft != NDIlib_frame_type_video: + break + free_if_needed(last_vf) + dropped += 1 + last_vf = nxt + + if last_vf is None: + return None + + h, w = int(last_vf.yres), int(last_vf.xres) + if h <= 0 or w <= 0: + return None + + stride = int(last_vf.line_stride_in_bytes) or (w * 4) + raw = ctypes.string_at(last_vf.p_data, stride * h) + + # Copy/reshape with stride, then drop padding and convert BGRX->RGB. + rows = np.frombuffer(raw, dtype=np.uint8).reshape((h, stride)) + bgrx = rows[:, : w * 4].reshape((h, w, 4)) + rgb = bgrx[:, :, 2::-1].copy() + + self._stats = NDIReceiverStats( + frames_received=self._stats.frames_received + 1, + frames_dropped_during_drain=self._stats.frames_dropped_during_drain + dropped, + last_frame_ts_s=time.monotonic(), + ) + return rgb + finally: + free_if_needed(last_vf) + + def get_stats(self) -> NDIReceiverStats: + return self._stats + + def release(self) -> None: + if self._recv is None: + return + try: + if self._ndi is not None: + self._ndi.lib.NDIlib_recv_destroy(self._recv) + finally: + self._recv = None + self._ndi = None + self._connected_url = None + self._last_source = None + self._runtime.release() + diff --git a/src/scope/server/ndi/runtime.py b/src/scope/server/ndi/runtime.py new file mode 100644 index 000000000..ca0b07f8a --- /dev/null +++ b/src/scope/server/ndi/runtime.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import logging +import threading + +from ._ctypes import NDILib, configure_libndi_prototypes, load_libndi + +logger = logging.getLogger(__name__) + + +class NDIRuntime: + """Process-global NDI SDK lifecycle (refcounted).""" + + def __init__(self) -> None: + self._lock = threading.Lock() + self._refcount = 0 + self._ndi: NDILib | None = None + self._initialized = False + + def acquire(self) -> NDILib: + with self._lock: + if self._ndi is None: + self._ndi = load_libndi() + configure_libndi_prototypes(self._ndi) + + if self._refcount == 0 and not self._initialized: + ok = bool(self._ndi.lib.NDIlib_initialize()) + if not ok: + raise RuntimeError("NDIlib_initialize() failed (unsupported CPU or missing runtime?)") + self._initialized = True + try: + ver = self._ndi.lib.NDIlib_version() + if ver: + logger.info("NDI runtime initialized: %s", ver.decode("utf-8", errors="replace")) + except Exception: + pass + + self._refcount += 1 + return self._ndi + + def release(self) -> None: + with self._lock: + if self._refcount <= 0: + return + self._refcount -= 1 + if self._refcount == 0 and self._initialized and self._ndi is not None: + try: + self._ndi.lib.NDIlib_destroy() + finally: + self._initialized = False + + +_RUNTIME = NDIRuntime() + + +def get_runtime() -> NDIRuntime: + return _RUNTIME + diff --git a/src/scope/server/session_recorder.py b/src/scope/server/session_recorder.py new file mode 100644 index 000000000..e85a98de3 --- /dev/null +++ b/src/scope/server/session_recorder.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import json +import time +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + + +@dataclass +class ControlEvent: + """A single control event during recording.""" + + # Primary timebase + chunk_index: int + + # Secondary timebase: seconds since recording start + wall_time: float + + # Prompt + prompt: str | None = None + prompt_weight: float = 1.0 + + # Transition + 
transition_steps: int | None = None + transition_method: str | None = None # "linear" or "slerp" + + # Cuts + hard_cut: bool = False + soft_cut_bias: float | None = None + soft_cut_chunks: int | None = None + soft_cut_restore_bias: float | None = None # None means "was unset" + soft_cut_restore_was_set: bool = False + + +@dataclass +class SessionRecording: + """Container for a complete recording session.""" + + events: list[ControlEvent] = field(default_factory=list) + + # Chunk timebase (primary) + start_chunk: int = 0 + end_chunk: int | None = None + + # Wall-clock (secondary) + start_wall_time: float | None = None + end_wall_time: float | None = None + + # Pipeline info + pipeline_id: str | None = None + load_params: dict[str, Any] = field(default_factory=dict) + + @property + def is_active(self) -> bool: + return self.start_wall_time is not None and self.end_wall_time is None + + @property + def duration_seconds(self) -> float: + if self.start_wall_time is None: + return 0.0 + end = self.end_wall_time if self.end_wall_time is not None else time.monotonic() + return end - self.start_wall_time + + @property + def duration_chunks(self) -> int: + if self.end_chunk is None: + return 0 + return self.end_chunk - self.start_chunk + + +class SessionRecorder: + """Records control events during a streaming session. + + Intended usage: mutate only from the FrameProcessor worker thread. + FastAPI threads should read status via get_status_snapshot(). + """ + + def __init__(self) -> None: + self._recording: SessionRecording | None = None + self._last_prompt: str | None = None + self._status_snapshot: dict[str, Any] = {"is_recording": False} + + @property + def is_recording(self) -> bool: + recording = self._recording + return recording is not None and recording.is_active + + @property + def last_prompt(self) -> str | None: + return self._last_prompt + + def start( + self, + *, + chunk_index: int, + pipeline_id: str, + load_params: dict[str, Any], + baseline_prompt: str | None = None, + baseline_weight: float = 1.0, + ) -> None: + if not pipeline_id: + raise ValueError("pipeline_id is required for session recording") + + start_wall_time = time.monotonic() + self._recording = SessionRecording( + start_chunk=int(chunk_index), + start_wall_time=start_wall_time, + pipeline_id=pipeline_id, + load_params=dict(load_params or {}), + ) + + self._last_prompt = None + if baseline_prompt is not None: + self._recording.events.append( + ControlEvent( + chunk_index=int(chunk_index), + wall_time=0.0, + prompt=baseline_prompt, + prompt_weight=float(baseline_weight), + ) + ) + self._last_prompt = baseline_prompt + + self._update_status_snapshot() + + def record_event( + self, + *, + chunk_index: int, + wall_time: float, + prompt: str | None = None, + prompt_weight: float = 1.0, + transition_steps: int | None = None, + transition_method: str | None = None, + hard_cut: bool = False, + soft_cut_bias: float | None = None, + soft_cut_chunks: int | None = None, + soft_cut_restore_bias: float | None = None, + soft_cut_restore_was_set: bool = False, + ) -> None: + if not self.is_recording: + return + recording = self._recording + if recording is None or recording.start_wall_time is None: + return + + if prompt is not None: + self._last_prompt = prompt + + effective_prompt = prompt + if prompt is None and (hard_cut or soft_cut_bias is not None): + effective_prompt = self._last_prompt + + relative_time = max(0.0, float(wall_time) - float(recording.start_wall_time)) + + recording.events.append( + ControlEvent( + 
chunk_index=int(chunk_index), + wall_time=relative_time, + prompt=effective_prompt, + prompt_weight=float(prompt_weight), + transition_steps=transition_steps, + transition_method=transition_method, + hard_cut=bool(hard_cut), + soft_cut_bias=soft_cut_bias, + soft_cut_chunks=soft_cut_chunks, + soft_cut_restore_bias=soft_cut_restore_bias, + soft_cut_restore_was_set=bool(soft_cut_restore_was_set), + ) + ) + self._update_status_snapshot() + + def stop(self, *, chunk_index: int) -> SessionRecording | None: + if not self.is_recording: + return None + + recording = self._recording + if recording is None: + return None + + recording.end_chunk = int(chunk_index) + recording.end_wall_time = time.monotonic() + + self._recording = None + self._last_prompt = None + self._update_status_snapshot() + return recording + + def _update_status_snapshot(self) -> None: + recording = self._recording + if recording is None: + self._status_snapshot = {"is_recording": False} + return + + self._status_snapshot = { + "is_recording": recording.is_active, + "start_chunk": recording.start_chunk, + "duration_seconds": recording.duration_seconds, + "events_count": len(recording.events), + } + + def get_status_snapshot(self) -> dict[str, Any]: + return self._status_snapshot + + def export_timeline(self, recording: SessionRecording) -> dict[str, Any]: + segments: list[dict[str, Any]] = [] + + for i, event in enumerate(recording.events): + if event.prompt is None: + continue + + end_chunk = recording.end_chunk if recording.end_chunk is not None else event.chunk_index + end_time = recording.duration_seconds + for next_event in recording.events[i + 1 :]: + if next_event.prompt is None: + continue + end_chunk = next_event.chunk_index + end_time = next_event.wall_time + break + + segment: dict[str, Any] = { + "startTime": float(event.wall_time), + "endTime": float(end_time), + "startChunk": int(event.chunk_index - recording.start_chunk), + "endChunk": int(end_chunk - recording.start_chunk), + "prompts": [ + {"text": event.prompt, "weight": float(event.prompt_weight)} + ], + } + + if event.transition_steps is not None and int(event.transition_steps) > 0: + segment["transitionSteps"] = int(event.transition_steps) + if event.transition_method: + segment["temporalInterpolationMethod"] = event.transition_method + + if event.hard_cut: + segment["initCache"] = True + + if event.soft_cut_bias is not None: + segment["softCut"] = { + "bias": float(event.soft_cut_bias), + "chunks": int(event.soft_cut_chunks or 2), + "restoreBias": ( + float(event.soft_cut_restore_bias) + if event.soft_cut_restore_bias is not None + else None + ), + "restoreWasSet": bool(event.soft_cut_restore_was_set), + } + + segments.append(segment) + + load_params = recording.load_params or {} + height = load_params.get("height") + width = load_params.get("width") + + settings: dict[str, Any] = {"pipelineId": recording.pipeline_id} + if height is not None and width is not None: + settings["resolution"] = {"height": int(height), "width": int(width)} + if "seed" in load_params: + settings["seed"] = load_params.get("seed") + if "kv_cache_attention_bias" in load_params: + settings["kvCacheAttentionBias"] = load_params.get("kv_cache_attention_bias") + + # Export LoRA configuration for replay + loras = load_params.get("loras") + if loras and isinstance(loras, list): + settings["loras"] = [ + { + "path": lora.get("path"), + "scale": float(lora.get("scale", 1.0)), + **({"mergeMode": lora.get("merge_mode")} if lora.get("merge_mode") else {}), + } + for lora in loras + if 
isinstance(lora, dict) and lora.get("path") + ] + lora_merge_mode = load_params.get("lora_merge_mode") + if lora_merge_mode: + settings["loraMergeStrategy"] = lora_merge_mode + + return { + "version": "1.1", + "exportedAt": datetime.now(timezone.utc).isoformat(), + "recording": { + "durationSeconds": float(recording.duration_seconds), + "durationChunks": int(recording.duration_chunks), + "startChunk": int(recording.start_chunk), + "endChunk": int(recording.end_chunk) if recording.end_chunk is not None else None, + }, + "settings": settings, + "prompts": segments, + } + + def save(self, recording: SessionRecording, path: Path) -> Path: + timeline = self.export_timeline(recording) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(timeline, indent=2)) + return path diff --git a/src/scope/vendored/video_depth_anything/LICENSE b/src/scope/vendored/video_depth_anything/LICENSE new file mode 100644 index 000000000..f49a4e16e --- /dev/null +++ b/src/scope/vendored/video_depth_anything/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/src/scope/vendored/video_depth_anything/README.md b/src/scope/vendored/video_depth_anything/README.md new file mode 100644 index 000000000..adf06e0c0 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/README.md @@ -0,0 +1,53 @@ +# Video Depth Anything (Vendored) + +Minimal inference code vendored from [Video-Depth-Anything](https://github.com/DepthAnything/Video-Depth-Anything). 
+ +## Source + +- **Repository:** https://github.com/DepthAnything/Video-Depth-Anything +- **Commit:** 4f5ae23172ba60fd7bc11ef671cca678842c7072 +- **Date vendored:** 2025-12-28 + +## License + +Apache-2.0 (see LICENSE file) + +## Modifications + +- Renamed `utils/` to `vda_utils/` to avoid import conflicts +- Updated import path in `video_depth_stream.py` +- Moved to `src/scope/vendored/` to be included in package + +## Usage + +```python +from scope.vendored.video_depth_anything import VideoDepthAnything + +model_configs = { + 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, +} + +model = VideoDepthAnything(**model_configs['vits']) +model.load_state_dict(torch.load('~/.daydream-scope/models/vda/video_depth_anything_vits.pth', map_location='cpu')) +model = model.to('cuda').eval() + +# Streaming inference (one frame at a time) +depth = model.infer_video_depth_one(frame_rgb, input_size=518, device='cuda') + +# Reset cache on hard cuts +model.transform = None +model.frame_id_list = [] +model.frame_cache_list = [] +model.id = -1 +``` + +## Checkpoints + +Download to `~/.daydream-scope/models/vda/` (or `$DAYDREAM_SCOPE_MODELS_DIR/vda/`): +- `video_depth_anything_vits.pth` (VDA-Small, 28.4M params, Apache-2.0) + +```bash +mkdir -p ~/.daydream-scope/models/vda +wget -O ~/.daydream-scope/models/vda/video_depth_anything_vits.pth \ + https://huggingface.co/depth-anything/Video-Depth-Anything-Small/resolve/main/video_depth_anything_vits.pth +``` diff --git a/src/scope/vendored/video_depth_anything/__init__.py b/src/scope/vendored/video_depth_anything/__init__.py new file mode 100644 index 000000000..1cce74e2d --- /dev/null +++ b/src/scope/vendored/video_depth_anything/__init__.py @@ -0,0 +1,7 @@ +# Video Depth Anything - vendored from https://github.com/DepthAnything/Video-Depth-Anything +# Commit: 4f5ae23172ba60fd7bc11ef671cca678842c7072 +# License: Apache-2.0 + +from .video_depth_stream import VideoDepthAnything + +__all__ = ["VideoDepthAnything"] diff --git a/src/scope/vendored/video_depth_anything/dinov2.py b/src/scope/vendored/video_depth_anything/dinov2.py new file mode 100644 index 000000000..83d250818 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/dinov2.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = 
nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + # w0, h0 = w0 + 0.1, h0 + 0.1 + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + # (int(w0), int(h0)), # to solve the upsampling shape issue + mode="bicubic", + antialias=self.interpolate_antialias + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + 
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def DINOv2(model_name): + model_zoo = { + "vits": vit_small, + "vitb": vit_base, + "vitl": vit_large, + "vitg": vit_giant2 + } + + return model_zoo[model_name]( + img_size=518, + patch_size=14, + init_values=1.0, + ffn_layer="mlp" if model_name != "vitg" else "swiglufused", + block_chunks=0, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1 + ) diff --git a/src/scope/vendored/video_depth_anything/dinov2_layers/__init__.py 
b/src/scope/vendored/video_depth_anything/dinov2_layers/__init__.py new file mode 100644 index 000000000..8120f4bc8 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/dinov2_layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/src/scope/vendored/video_depth_anything/dinov2_layers/attention.py b/src/scope/vendored/video_depth_anything/dinov2_layers/attention.py new file mode 100644 index 000000000..e94fd2c15 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/dinov2_layers/attention.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +import os +import torch +import torch.nn.functional as F +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +def _is_env_true(name: str, default: str = "0") -> bool: + return (os.getenv(name, default) or default).strip().lower() in ("1", "true", "yes", "on") + + +_USE_SDPA_FALLBACK = _is_env_true("SCOPE_VDA_SDPA_FALLBACK", default="1") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + # Fallback to PyTorch SDPA (Flash/MemEff/Math) when xFormers isn't available. + # This is typically much faster than the naive q@k^T path on modern GPUs. 
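            # Shape/scale notes for the SDPA fallback below: q, k, v are reshaped to
            # (B, num_heads, N, head_dim) and passed to F.scaled_dot_product_attention
            # unscaled, since SDPA applies the 1/sqrt(head_dim) factor and the softmax
            # internally. Numerically this matches the explicit softmax(q * scale @ k^T) @ v
            # path in Attention.forward above; only the kernel used to compute it changes.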
+ if _USE_SDPA_FALLBACK and hasattr(F, "scaled_dot_product_attention"): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + q, k, v = qkv.unbind(dim=2) # (B, N, H, D) + q = q.transpose(1, 2) # (B, H, N, D) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + dropout_p = float(self.attn_drop.p) if self.training else 0.0 + try: + out = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + ) + except Exception: + # If SDPA isn't available for this shape/dtype/device, fall back + # to the naive implementation to preserve correctness. + return super().forward(x) + + out = out.transpose(1, 2).reshape(B, N, C) + out = self.proj(out) + out = self.proj_drop(out) + return out + + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + diff --git a/src/scope/vendored/video_depth_anything/dinov2_layers/block.py b/src/scope/vendored/video_depth_anything/dinov2_layers/block.py new file mode 100644 index 000000000..25488f57c --- /dev/null +++ b/src/scope/vendored/video_depth_anything/dinov2_layers/block.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, List, Any, Tuple, Dict + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = 
DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], 
Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/src/scope/vendored/video_depth_anything/dinov2_layers/drop_path.py b/src/scope/vendored/video_depth_anything/dinov2_layers/drop_path.py new file mode 100644 index 000000000..af0562598 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/dinov2_layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
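# Illustrative sketch (not part of the vendored file): what the BlockDiagonalMask
# batching in block.py above achieves, written with plain PyTorch instead of
# xformers. Sequences of different lengths are concatenated along the token axis
# and a block-diagonal boolean mask keeps their attention independent, so one
# kernel launch serves the whole list. Sizes below are arbitrary.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, heads = 8, 2
seqlens = [5, 3]                                     # tokens per sequence
xs = [torch.randn(1, n, dim) for n in seqlens]

def to_bhnd(t):                                      # (1, n, dim) -> (1, heads, n, head_dim)
    return t.reshape(1, t.shape[1], heads, dim // heads).transpose(1, 2)

cat = to_bhnd(torch.cat(xs, dim=1))                  # (1, heads, 8, head_dim)
ids = torch.repeat_interleave(torch.arange(len(seqlens)), torch.tensor(seqlens))
mask = ids[:, None] == ids[None, :]                  # True where both tokens share a sequence

fused = F.scaled_dot_product_attention(cat, cat, cat, attn_mask=mask)
separate = torch.cat(
    [F.scaled_dot_product_attention(to_bhnd(x), to_bhnd(x), to_bhnd(x)) for x in xs], dim=2
)
assert torch.allclose(fused, separate, atol=1e-6)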
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/src/scope/vendored/video_depth_anything/dinov2_layers/layer_scale.py b/src/scope/vendored/video_depth_anything/dinov2_layers/layer_scale.py new file mode 100644 index 000000000..ca5daa52b --- /dev/null +++ b/src/scope/vendored/video_depth_anything/dinov2_layers/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/src/scope/vendored/video_depth_anything/dinov2_layers/mlp.py b/src/scope/vendored/video_depth_anything/dinov2_layers/mlp.py new file mode 100644 index 000000000..5e4b315f9 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/dinov2_layers/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
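# Illustrative sketch (not part of the vendored file): the per-sample behavior of
# drop_path/DropPath above. With drop_prob = 0.25 each sample's residual branch is
# zeroed with probability 0.25 and survivors are rescaled by 1/keep_prob, so the
# expected output equals the input and no eval-time rescaling is needed.
import torch

torch.manual_seed(0)
x = torch.ones(4, 3, 8)                         # (batch, tokens, dim) residual-branch output
keep_prob = 0.75

mask = x.new_empty((x.shape[0], 1, 1)).bernoulli_(keep_prob) / keep_prob
y = x * mask                                    # each sample is either all zeros or scaled by ~1.33
print(mask.flatten())                           # e.g. a mix of 0.0000 and 1.3333 entries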
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/src/scope/vendored/video_depth_anything/dinov2_layers/patch_embed.py b/src/scope/vendored/video_depth_anything/dinov2_layers/patch_embed.py new file mode 100644 index 000000000..574abe411 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/dinov2_layers/patch_embed.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/src/scope/vendored/video_depth_anything/dinov2_layers/swiglu_ffn.py b/src/scope/vendored/video_depth_anything/dinov2_layers/swiglu_ffn.py new file mode 100644 index 000000000..b3324b266 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/dinov2_layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/src/scope/vendored/video_depth_anything/dpt.py b/src/scope/vendored/video_depth_anything/dpt.py new file mode 100644 index 000000000..8c43a1447 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/dpt.py @@ -0,0 +1,160 @@ +# Copyright (2025) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
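# Illustrative sketch (not part of the vendored file): the SwiGLU feed-forward used
# when ffn_layer="swiglufused" (vit_giant2). w12 produces a gate half and a value
# half; the gate goes through SiLU and multiplies the value, then w3 projects back.
# SwiGLUFFNFused also shrinks the requested hidden width by 2/3 and rounds up to a
# multiple of 8 so the parameter count stays comparable to a plain 4x MLP.
import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 1536                                              # vit_giant2 embed_dim
hidden = (int(dim * 4 * 2 / 3) + 7) // 8 * 8            # 4096 after the 2/3 + round-to-8 rule
w12 = nn.Linear(dim, 2 * hidden)
w3 = nn.Linear(hidden, dim)

x = torch.randn(2, 10, dim)
x1, x2 = w12(x).chunk(2, dim=-1)                        # gate and value halves
out = w3(F.silu(x1) * x2)                               # SwiGLU
assert out.shape == x.shape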
+import torch +import torch.nn as nn +import torch.nn.functional as F + +from .util.blocks import FeatureFusionBlock, _make_scratch + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class ConvBlock(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + + self.conv_block = nn.Sequential( + nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_feature), + nn.ReLU(True) + ) + + def forward(self, x): + return self.conv_block(x) + + +class DPTHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = 
self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + + return out + \ No newline at end of file diff --git a/src/scope/vendored/video_depth_anything/dpt_temporal.py b/src/scope/vendored/video_depth_anything/dpt_temporal.py new file mode 100644 index 000000000..d629a1127 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/dpt_temporal.py @@ -0,0 +1,125 @@ +# Copyright (2025) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn.functional as F +import torch.nn as nn +from .dpt import DPTHead +from .motion_module.motion_module import TemporalModule +from easydict import EasyDict + + +class DPTHeadTemporal(DPTHead): + def __init__(self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False, + num_frames=32, + pe='ape' + ): + super().__init__(in_channels, features, use_bn, out_channels, use_clstoken) + + assert num_frames > 0 + motion_module_kwargs = EasyDict(num_attention_heads = 8, + num_transformer_block = 1, + num_attention_blocks = 2, + temporal_max_len = num_frames, + zero_initialize = True, + pos_embedding_type = pe) + + self.motion_modules = nn.ModuleList([ + TemporalModule(in_channels=out_channels[2], + **motion_module_kwargs), + TemporalModule(in_channels=out_channels[3], + **motion_module_kwargs), + TemporalModule(in_channels=features, + **motion_module_kwargs), + TemporalModule(in_channels=features, + **motion_module_kwargs) + ]) + + def forward(self, out_features, patch_h, patch_w, frame_length, micro_batch_size=4, cached_hidden_state_list=None): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)).contiguous() + + B, T = x.shape[0] // frame_length, frame_length + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + B, T = layer_1.shape[0] // frame_length, frame_length + if cached_hidden_state_list is not None: + N = len(cached_hidden_state_list) // len(self.motion_modules) + else: + N = 0 + + layer_3, h0 = self.motion_modules[0](layer_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None, cached_hidden_state_list[0:N] if N else None) + layer_3 = layer_3.permute(0, 2, 1, 3, 4).flatten(0, 1) + layer_4, h1 = self.motion_modules[1](layer_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None, cached_hidden_state_list[N:2*N] if N else None) + layer_4 = layer_4.permute(0, 2, 1, 3, 4).flatten(0, 1) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = 
self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_4, h2 = self.motion_modules[2](path_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None, cached_hidden_state_list[2*N:3*N] if N else None) + path_4 = path_4.permute(0, 2, 1, 3, 4).flatten(0, 1) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_3, h3 = self.motion_modules[3](path_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None, cached_hidden_state_list[3*N:] if N else None) + path_3 = path_3.permute(0, 2, 1, 3, 4).flatten(0, 1) + + batch_size = layer_1_rn.shape[0] + if batch_size <= micro_batch_size or batch_size % micro_batch_size != 0: + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate( + out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True + ) + ori_type = out.dtype + with torch.autocast(device_type="cuda", enabled=False): + out = self.scratch.output_conv2(out.float()) + + output = out.to(ori_type) + else: + ret = [] + for i in range(0, batch_size, micro_batch_size): + path_2 = self.scratch.refinenet2(path_3[i:i + micro_batch_size], layer_2_rn[i:i + micro_batch_size], size=layer_1_rn[i:i + micro_batch_size].shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn[i:i + micro_batch_size]) + out = self.scratch.output_conv1(path_1) + out = F.interpolate( + out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True + ) + ori_type = out.dtype + with torch.autocast(device_type="cuda", enabled=False): + out = self.scratch.output_conv2(out.float()) + ret.append(out.to(ori_type)) + output = torch.cat(ret, dim=0) + + return output, h0 + h1 + h2 + h3 diff --git a/src/scope/vendored/video_depth_anything/motion_module/attention.py b/src/scope/vendored/video_depth_anything/motion_module/attention.py new file mode 100644 index 000000000..41f551ba1 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/motion_module/attention.py @@ -0,0 +1,429 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +try: + import xformers + import xformers.ops + + XFORMERS_AVAILABLE = True +except ImportError: + print("xFormers not available") + XFORMERS_AVAILABLE = False + + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 
+ bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.upcast_efficient_attention = False + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous() + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size).contiguous() + return tensor + + def reshape_heads_to_4d(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous() + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim).contiguous() + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size).contiguous() + return tensor + + def reshape_4d_to_heads(self, tensor): + batch_size, seq_len, head_size, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, dim * head_size).contiguous() + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = 
self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = 
key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + if self.upcast_efficient_attention: + org_dtype = query.dtype + query = query.float() + key = key.float() + value = value.float() + if attention_mask is not None: + attention_mask = attention_mask.float() + hidden_states = self._memory_efficient_attention_split(query, key, value, attention_mask) + + if self.upcast_efficient_attention: + hidden_states = hidden_states.to(org_dtype) + + hidden_states = self.reshape_4d_to_heads(hidden_states) + return hidden_states + + # print("Errror: no xformers") + # raise NotImplementedError + + def _memory_efficient_attention_split(self, query, key, value, attention_mask): + batch_size = query.shape[0] + max_batch_size = 65535 + num_batches = (batch_size + max_batch_size - 1) // max_batch_size + results = [] + for i in range(num_batches): + start_idx = i * max_batch_size + end_idx = min((i + 1) * max_batch_size, batch_size) + query_batch = query[start_idx:end_idx] + key_batch = key[start_idx:end_idx] + value_batch = value[start_idx:end_idx] + if attention_mask is not None: + attention_mask_batch = attention_mask[start_idx:end_idx] + else: + attention_mask_batch = None + result = xformers.ops.memory_efficient_attention(query_batch, key_batch, value_batch, attn_bias=attention_mask_batch) + results.append(result) + full_result = torch.cat(results, dim=0) + return full_result + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 
+ """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class GELU(nn.Module): + r""" + GELU activation function + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + """ + The approximate form of Gaussian Error Linear Unit (GELU) + + For more details, see section 2: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x): + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2).contiguous()) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2).contiguous()) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2) + return 
xq_out.type_as(xq), xk_out.type_as(xk) diff --git a/src/scope/vendored/video_depth_anything/motion_module/motion_module.py b/src/scope/vendored/video_depth_anything/motion_module/motion_module.py new file mode 100644 index 000000000..d4d81a3e0 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/motion_module/motion_module.py @@ -0,0 +1,321 @@ +# This file is originally from AnimateDiff/animatediff/models/motion_module.py at main · guoyww/AnimateDiff +# SPDX-License-Identifier: Apache-2.0 license +# +# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification] +# Original file was released under [ Apache-2.0 license], with the full license text available at [https://github.com/guoyww/AnimateDiff?tab=Apache-2.0-1-ov-file#readme]. +import torch +import torch.nn.functional as F +from torch import nn + +from .attention import CrossAttention, FeedForward, apply_rotary_emb, precompute_freqs_cis + +from einops import rearrange, repeat +import math + +try: + import xformers + import xformers.ops + + XFORMERS_AVAILABLE = True +except ImportError: + print("xFormers not available") + XFORMERS_AVAILABLE = False + + +def zero_module(module): + # Zero out the parameters of a module and return it. + for p in module.parameters(): + p.detach().zero_() + return module + + +class TemporalModule(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads = 8, + num_transformer_block = 2, + num_attention_blocks = 2, + norm_num_groups = 32, + temporal_max_len = 32, + zero_initialize = True, + pos_embedding_type = "ape", + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads, + num_layers=num_transformer_block, + num_attention_blocks=num_attention_blocks, + norm_num_groups=norm_num_groups, + temporal_max_len=temporal_max_len, + pos_embedding_type=pos_embedding_type, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) + + def forward(self, input_tensor, encoder_hidden_states, attention_mask=None, cached_hidden_state_list=None): + hidden_states = input_tensor + hidden_states, output_hidden_state_list = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask, cached_hidden_state_list) + + output = hidden_states + return output, output_hidden_state_list # list of hidden states + + +class TemporalTransformer3DModel(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads, + attention_head_dim, + num_layers, + num_attention_blocks = 2, + norm_num_groups = 32, + temporal_max_len = 32, + pos_embedding_type = "ape", + ): + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_attention_blocks=num_attention_blocks, + temporal_max_len=temporal_max_len, + pos_embedding_type=pos_embedding_type, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, cached_hidden_state_list=None): + assert hidden_states.dim() == 5, f"Expected 
hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + output_hidden_state_list = [] + + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + batch, channel, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim).contiguous() + hidden_states = self.proj_in(hidden_states) + + # Transformer Blocks + if cached_hidden_state_list is not None: + n = len(cached_hidden_state_list) // len(self.transformer_blocks) + else: + n = 0 + for i, block in enumerate(self.transformer_blocks): + hidden_states, hidden_state_list = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, attention_mask=attention_mask, + cached_hidden_state_list=cached_hidden_state_list[i*n:(i+1)*n] if n else None) + output_hidden_state_list.extend(hidden_state_list) + + # output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + + return output, output_hidden_state_list + + +class TemporalTransformerBlock(nn.Module): + def __init__( + self, + dim, + num_attention_heads, + attention_head_dim, + num_attention_blocks = 2, + temporal_max_len = 32, + pos_embedding_type = "ape", + ): + super().__init__() + + self.attention_blocks = nn.ModuleList( + [ + TemporalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + temporal_max_len=temporal_max_len, + pos_embedding_type=pos_embedding_type, + ) + for i in range(num_attention_blocks) + ] + ) + self.norms = nn.ModuleList( + [ + nn.LayerNorm(dim) + for i in range(num_attention_blocks) + ] + ) + + self.ff = FeedForward(dim, dropout=0.0, activation_fn="geglu") + self.ff_norm = nn.LayerNorm(dim) + + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, cached_hidden_state_list=None): + output_hidden_state_list = [] + for i, (attention_block, norm) in enumerate(zip(self.attention_blocks, self.norms)): + norm_hidden_states = norm(hidden_states) + residual_hidden_states, output_hidden_states = attention_block( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + video_length=video_length, + attention_mask=attention_mask, + cached_hidden_states=cached_hidden_state_list[i] if cached_hidden_state_list is not None else None, + ) + hidden_states = residual_hidden_states + hidden_states + output_hidden_state_list.append(output_hidden_states) + + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output, output_hidden_state_list + + +class PositionalEncoding(nn.Module): + def __init__( + self, + d_model, + dropout = 0., + max_len = 32 + ): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:, :x.size(1)].to(x.dtype) + return self.dropout(x) + +class TemporalAttention(CrossAttention): + def 
__init__( + self, + temporal_max_len = 32, + pos_embedding_type = "ape", + *args, **kwargs + ): + super().__init__(*args, **kwargs) + + self.pos_embedding_type = pos_embedding_type + self._use_memory_efficient_attention_xformers = True + + self.pos_encoder = None + self.freqs_cis = None + if self.pos_embedding_type == "ape": + self.pos_encoder = PositionalEncoding( + kwargs["query_dim"], + dropout=0., + max_len=temporal_max_len + ) + + elif self.pos_embedding_type == "rope": + self.freqs_cis = precompute_freqs_cis( + kwargs["query_dim"], + temporal_max_len + ) + + else: + raise NotImplementedError + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, cached_hidden_states=None): + # TODO: support cache for these + assert encoder_hidden_states is None + assert attention_mask is None + + d = hidden_states.shape[1] + d_in = 0 + if cached_hidden_states is None: + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + input_hidden_states = hidden_states # (bxd) f c + else: + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=1) + input_hidden_states = hidden_states + d_in = cached_hidden_states.shape[1] + hidden_states = torch.cat([cached_hidden_states, hidden_states], dim=1) + + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states) + + encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states[:, d_in:, ...]) + dim = query.shape[-1] + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if self.freqs_cis is not None: + seq_len = query.shape[1] + freqs_cis = self.freqs_cis[:seq_len].to(query.device) + query, key = apply_rotary_emb(query, key, freqs_cis) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + + use_memory_efficient = XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers + if use_memory_efficient and (dim // self.heads) % 8 != 0: + # print('Warning: the dim {} cannot be divided by 8. 
Fall into normal attention'.format(dim // self.heads)) + use_memory_efficient = False + + # attention, what we cannot get enough of + if use_memory_efficient: + query = self.reshape_heads_to_4d(query) + key = self.reshape_heads_to_4d(key) + value = self.reshape_heads_to_4d(value) + + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + raise NotImplementedError + # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states, input_hidden_states diff --git a/src/scope/vendored/video_depth_anything/util/blocks.py b/src/scope/vendored/video_depth_anything/util/blocks.py new file mode 100644 index 000000000..0be16c053 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/util/blocks.py @@ -0,0 +1,162 @@ +import torch.nn as nn + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn is True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn is True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn is True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + ): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand is True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1 + ) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output.contiguous(), **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output diff --git a/src/scope/vendored/video_depth_anything/util/transform.py b/src/scope/vendored/video_depth_anything/util/transform.py new file mode 100644 index 000000000..b14aacd44 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/util/transform.py @@ -0,0 +1,158 @@ +import numpy as np +import cv2 + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". 
+ """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) + + # resize sample + sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method) + + if self.__resize_target: + if "depth" in sample: + sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) + + if "mask" in sample: + sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. 
+ """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + return sample \ No newline at end of file diff --git a/src/scope/vendored/video_depth_anything/vda_utils/__init__.py b/src/scope/vendored/video_depth_anything/vda_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/scope/vendored/video_depth_anything/vda_utils/util.py b/src/scope/vendored/video_depth_anything/vda_utils/util.py new file mode 100644 index 000000000..75ff80a84 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/vda_utils/util.py @@ -0,0 +1,74 @@ +# Copyright (2025) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +def compute_scale_and_shift(prediction, target, mask, scale_only=False): + if scale_only: + return compute_scale(prediction, target, mask), 0 + else: + return compute_scale_and_shift_full(prediction, target, mask) + + +def compute_scale(prediction, target, mask): + # system matrix: A = [[a_00, a_01], [a_10, a_11]] + prediction = prediction.astype(np.float32) + target = target.astype(np.float32) + mask = mask.astype(np.float32) + + a_00 = np.sum(mask * prediction * prediction) + a_01 = np.sum(mask * prediction) + a_11 = np.sum(mask) + + # right hand side: b = [b_0, b_1] + b_0 = np.sum(mask * prediction * target) + + x_0 = b_0 / (a_00 + 1e-6) + + return x_0 + +def compute_scale_and_shift_full(prediction, target, mask): + # system matrix: A = [[a_00, a_01], [a_10, a_11]] + prediction = prediction.astype(np.float32) + target = target.astype(np.float32) + mask = mask.astype(np.float32) + + a_00 = np.sum(mask * prediction * prediction) + a_01 = np.sum(mask * prediction) + a_11 = np.sum(mask) + + b_0 = np.sum(mask * prediction * target) + b_1 = np.sum(mask * target) + + x_0 = 1 + x_1 = 0 + + det = a_00 * a_11 - a_01 * a_01 + + if det != 0: + x_0 = (a_11 * b_0 - a_01 * b_1) / det + x_1 = (-a_01 * b_0 + a_00 * b_1) / det + + return x_0, x_1 + + +def get_interpolate_frames(frame_list_pre, frame_list_post): + assert len(frame_list_pre) == len(frame_list_post) + min_w = 0.0 + max_w = 1.0 + step = (max_w - min_w) / (len(frame_list_pre)-1) + post_w_list = [min_w] + [i * step for i in range(1,len(frame_list_pre)-1)] + [max_w] + interpolated_frames = [] + for i in range(len(frame_list_pre)): + interpolated_frames.append(frame_list_pre[i] * (1-post_w_list[i]) + frame_list_post[i] * post_w_list[i]) + return interpolated_frames \ No newline at end of file diff --git a/src/scope/vendored/video_depth_anything/video_depth.py b/src/scope/vendored/video_depth_anything/video_depth.py new file mode 100644 index 000000000..5598d302a --- /dev/null +++ 
b/src/scope/vendored/video_depth_anything/video_depth.py
@@ -0,0 +1,163 @@
+# Copyright (2025) Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torchvision.transforms import Compose
+import cv2
+from tqdm import tqdm
+import numpy as np
+import gc
+
+from .dinov2 import DINOv2
+from .dpt_temporal import DPTHeadTemporal
+from .util.transform import Resize, NormalizeImage, PrepareForNet
+
+from .vda_utils.util import compute_scale_and_shift, get_interpolate_frames
+
+# infer settings, do not change
+INFER_LEN = 32
+OVERLAP = 10
+KEYFRAMES = [0,12,24,25,26,27,28,29,30,31]
+INTERP_LEN = 8
+
+class VideoDepthAnything(nn.Module):
+    def __init__(
+        self,
+        encoder='vitl',
+        features=256,
+        out_channels=[256, 512, 1024, 1024],
+        use_bn=False,
+        use_clstoken=False,
+        num_frames=32,
+        pe='ape',
+        metric=False,
+    ):
+        super(VideoDepthAnything, self).__init__()
+
+        self.intermediate_layer_idx = {
+            'vits': [2, 5, 8, 11],
+            "vitb": [2, 5, 8, 11],
+            'vitl': [4, 11, 17, 23]
+        }
+
+        self.encoder = encoder
+        self.pretrained = DINOv2(model_name=encoder)
+
+        self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe)
+        self.metric = metric
+
+    def forward(self, x):
+        B, T, C, H, W = x.shape
+        patch_h, patch_w = H // 14, W // 14
+        features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True)
+        depth = self.head(features, patch_h, patch_w, T)[0]
+        depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
+        depth = F.relu(depth)
+        return depth.squeeze(1).unflatten(0, (B, T))  # return shape [B, T, H, W]
+
+    def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda', fp32=False):
+        frame_height, frame_width = frames[0].shape[:2]
+        ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
+        if ratio > 1.78:  # we recommend to process video with ratio smaller than 16:9 due to memory limitation
+            input_size = int(input_size * 1.777 / ratio)
+            input_size = round(input_size / 14) * 14
+
+        transform = Compose([
+            Resize(
+                width=input_size,
+                height=input_size,
+                resize_target=False,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=14,
+                resize_method='lower_bound',
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            PrepareForNet(),
+        ])
+
+        frame_list = [frames[i] for i in range(frames.shape[0])]
+        frame_step = INFER_LEN - OVERLAP
+        org_video_len = len(frame_list)
+        append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
+        frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len
+
+        depth_list = []
+        pre_input = None
+        for frame_id in tqdm(range(0, org_video_len, frame_step)):
+            cur_list = []
+            for i in range(INFER_LEN):
+                
cur_list.append(torch.from_numpy(transform({'image': frame_list[frame_id+i].astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0)) + cur_input = torch.cat(cur_list, dim=1).to(device) + if pre_input is not None: + cur_input[:, :OVERLAP, ...] = pre_input[:, KEYFRAMES, ...] + + with torch.no_grad(): + with torch.autocast(device_type=device, enabled=(not fp32)): + depth = self.forward(cur_input) # depth shape: [1, T, H, W] + + depth = depth.to(cur_input.dtype) + depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True) + depth_list += [depth[i][0].cpu().numpy() for i in range(depth.shape[0])] + + pre_input = cur_input + + del frame_list + gc.collect() + + depth_list_aligned = [] + ref_align = [] + align_len = OVERLAP - INTERP_LEN + kf_align_list = KEYFRAMES[:align_len] + + for frame_id in range(0, len(depth_list), INFER_LEN): + if len(depth_list_aligned) == 0: + depth_list_aligned += depth_list[:INFER_LEN] + for kf_id in kf_align_list: + ref_align.append(depth_list[frame_id+kf_id]) + else: + curr_align = [] + for i in range(len(kf_align_list)): + curr_align.append(depth_list[frame_id+i]) + + if self.metric: + scale, shift = 1.0, 0.0 + else: + scale, shift = compute_scale_and_shift(np.concatenate(curr_align), + np.concatenate(ref_align), + np.concatenate(np.ones_like(ref_align)==1)) + + pre_depth_list = depth_list_aligned[-INTERP_LEN:] + post_depth_list = depth_list[frame_id+align_len:frame_id+OVERLAP] + for i in range(len(post_depth_list)): + post_depth_list[i] = post_depth_list[i] * scale + shift + post_depth_list[i][post_depth_list[i]<0] = 0 + depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list) + + for i in range(OVERLAP, INFER_LEN): + new_depth = depth_list[frame_id+i] * scale + shift + new_depth[new_depth<0] = 0 + depth_list_aligned.append(new_depth) + + ref_align = ref_align[:1] + for kf_id in kf_align_list[1:]: + new_depth = depth_list[frame_id+kf_id] * scale + shift + new_depth[new_depth<0] = 0 + ref_align.append(new_depth) + + depth_list = depth_list_aligned + + return np.stack(depth_list[:org_video_len], axis=0), target_fps + diff --git a/src/scope/vendored/video_depth_anything/video_depth_stream.py b/src/scope/vendored/video_depth_anything/video_depth_stream.py new file mode 100644 index 000000000..5dde67b51 --- /dev/null +++ b/src/scope/vendored/video_depth_anything/video_depth_stream.py @@ -0,0 +1,179 @@ +# Copyright (2025) Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +import torch.nn.functional as F +import torch.nn as nn +from torchvision.transforms import Compose +import cv2 +import numpy as np + +from .dinov2 import DINOv2 +from .dpt_temporal import DPTHeadTemporal +from .util.transform import Resize, NormalizeImage, PrepareForNet + +from .vda_utils.util import compute_scale_and_shift, get_interpolate_frames + +# infer settings, do not change +INFER_LEN = 32 +OVERLAP = 10 +INTERP_LEN = 8 + +class VideoDepthAnything(nn.Module): + def __init__( + self, + encoder='vitl', + features=256, + out_channels=[256, 512, 1024, 1024], + use_bn=False, + use_clstoken=False, + num_frames=32, + pe='ape' + ): + super(VideoDepthAnything, self).__init__() + + self.intermediate_layer_idx = { + 'vits': [2, 5, 8, 11], + "vitb": [2, 5, 8, 11], + 'vitl': [4, 11, 17, 23] + } + + self.encoder = encoder + self.pretrained = DINOv2(model_name=encoder) + + self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe) + self.transform = None + self.frame_id_list = [] + self.frame_cache_list = [] + self.gap = (INFER_LEN - OVERLAP) * 2 - 1 - (OVERLAP - INTERP_LEN) + assert self.gap == 41 + self.id = -1 + + def forward(self, x): + return self.forward_depth(self.forward_features(x), x.shape)[0] + + def forward_features(self, x): + features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True) + return features + + def forward_depth(self, features, x_shape, cached_hidden_state_list=None): + B, T, C, H, W = x_shape + patch_h, patch_w = H // 14, W // 14 + depth, cur_cached_hidden_state_list = self.head(features, patch_h, patch_w, T, cached_hidden_state_list=cached_hidden_state_list) + depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True) + depth = F.relu(depth) + return depth.squeeze(1).unflatten(0, (B, T)), cur_cached_hidden_state_list # return shape [B, T, H, W] + + def infer_video_depth_one( + self, + frame, + input_size=518, + device='cuda', + fp32=False, + return_torch: bool = False, + use_temporal_cache: bool = True, + ): + # Note: use_temporal_cache is accepted for API compatibility but VDA + # streaming always uses temporal caching by design. Set to False to + # force single-frame mode (resets cache each frame - slower but stateless). 
+ if not use_temporal_cache: + # Stateless mode: reset cache to force first-frame path each time + self.transform = None + self.frame_cache_list = [None] + self.frame_id_list = [0] + self.id = 0 + + self.id += 1 + + if self.transform is None: # first frame + # Initialize the transform + frame_height, frame_width = frame.shape[:2] + self.frame_height = frame_height + self.frame_width = frame_width + ratio = max(frame_height, frame_width) / min(frame_height, frame_width) + if ratio > 1.78: # we recommend to process video with ratio smaller than 16:9 due to memory limitation + input_size = int(input_size * 1.777 / ratio) + input_size = round(input_size / 14) * 14 + + self.transform = Compose([ + Resize( + width=input_size, + height=input_size, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=14, + resize_method='lower_bound', + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + ]) + + # Inference the first frame + cur_list = [torch.from_numpy(self.transform({'image': frame.astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0)] + cur_input = torch.cat(cur_list, dim=1).to(device) + + with torch.no_grad(): + with torch.autocast(device_type=device, enabled=(not fp32)): + cur_feature = self.forward_features(cur_input) + x_shape = cur_input.shape + depth, cached_hidden_state_list = self.forward_depth(cur_feature, x_shape) + + depth = depth.to(cur_input.dtype) + depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True) + + # Copy multiple cache to simulate the windows + self.frame_cache_list = [cached_hidden_state_list] * INFER_LEN + self.frame_id_list.extend([0] * (INFER_LEN - 1)) + + new_depth_t = depth[0][0] + new_depth = new_depth_t if return_torch else new_depth_t.cpu().numpy() + else: + frame_height, frame_width = frame.shape[:2] + assert frame_height == self.frame_height + assert frame_width == self.frame_width + + # infer feature + cur_input = torch.from_numpy(self.transform({'image': frame.astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0).to(device) + with torch.no_grad(): + with torch.autocast(device_type=device, enabled=(not fp32)): + cur_feature = self.forward_features(cur_input) + x_shape = cur_input.shape + + cur_list = self.frame_cache_list[0:2] + self.frame_cache_list[-INFER_LEN+3:] + ''' + cur_id = self.frame_id_list[0:2] + self.frame_id_list[-INFER_LEN+3:] + print(f"cur_id: {cur_id}") + ''' + assert len(cur_list) == INFER_LEN - 1 + cur_cache = [torch.cat([h[i] for h in cur_list], dim=1) for i in range(len(cur_list[0]))] + + # infer depth + with torch.no_grad(): + with torch.autocast(device_type=device, enabled=(not fp32)): + depth, new_cache = self.forward_depth(cur_feature, x_shape, cached_hidden_state_list=cur_cache) + + depth = depth.to(cur_input.dtype) + depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True) + new_depth_t = depth[-1][0] + new_depth = new_depth_t if return_torch else new_depth_t.cpu().numpy() + + self.frame_cache_list.append(new_cache) + + # adjust the sliding window + self.frame_id_list.append(self.id) + if self.id + INFER_LEN > self.gap + 1: + del self.frame_id_list[1] + del self.frame_cache_list[1] + + return new_depth