diff --git a/.gitignore b/.gitignore index 053780c..10c173a 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,9 @@ weights/ **/.DS_Store docs/ +uv.lock +*.ply +*.glb +*.7z +export_data.bin +lingbot-map-long.pt diff --git a/README.md b/README.md index 2313c69..0c54a2c 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,48 @@ pip install --index-url https://pypi.org/simple flashinfer-python pip install -e ".[vis]" ``` +### 💚 Alternative: `uv` for faster installs (conda alternative) + +The `pyproject.toml` is fully compatible with [`uv`](https://github.com/astral-sh/uv), a fast Python package installer and virtual environment manager written in Rust. It mirrors the conda setup below but resolves and installs dependencies orders of magnitude faster: + +**1. Create uv virtual environment** + +```bash +curl -LsSf https://astral.sh/uv/install.sh | sh +uv venv --python 3.12 +source .venv/bin/activate +``` + +**2. Install core package (installs all dependencies from `pyproject.toml`)** + +```bash +uv sync +``` + +This installs the package in editable mode (`-e .`) plus all listed `[project]` dependencies — including `torch`, `torchvision`, `flashinfer-python`, `Pillow`, `huggingface_hub`, `einops`, `safetensors`, `opencv-python`, `tqdm`, `scipy`, and `flashinfer-cubin`. + +> **Note:** PyTorch version pinning (`2.8.0` with CUDA 12.8) is currently specified in the conda guide for Kaolin compatibility. If you need that exact pin, install it first: +> ```bash +> uv pip install torch==2.8.0 torchvision==0.23.0 \ +> --index-url https://download.pytorch.org/whl/cu128 +> ``` + +**3. Optional — visualization dependencies** + +```bash +uv sync --group vis +``` + +Installs `[project.optional-dependencies.vis]` (`viser`, `trimesh`, `matplotlib`, `onnxruntime-gpu`, `requests`). + +**4. Optional — demo extra (pulls in `vis`)** + +```bash +uv sync --group demo +``` + +> `uv` manages the virtual environment entirely through PEP 621 metadata from `pyproject.toml` — no `environment.yml`, no conda channels needed. It is fully compatible with the existing conda instructions above; pick whichever tooling you prefer. + ## 📦 Model Download | Model Name | Huggingface Repository | ModelScope Repository | Description | diff --git a/demo.py b/demo.py index ac6e51c..779d1ca 100644 --- a/demo.py +++ b/demo.py @@ -24,6 +24,7 @@ import sys import tempfile import time +from pathlib import Path # Must be set before `import torch` / any CUDA init. Reduces the reserved-vs-allocated # memory gap by letting the caching allocator grow segments on demand instead of @@ -88,10 +89,11 @@ def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png resolved_folder = out_dir print(f"Extracted {len(paths)} frames from video ({total_frames} total, interval={interval})") else: + assert image_folder is not None exts = image_ext.split(",") paths = [] for ext in exts: - paths.extend(glob.glob(os.path.join(image_folder, f"*{ext}"))) + paths.extend(glob.glob(str(Path(image_folder) / f"*{ext}"))) paths = sorted(paths) resolved_folder = image_folder @@ -106,7 +108,7 @@ def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png # Image.ROTATE_270 = lossless 90° clockwise (270° counter-clockwise) reordering. for p in tqdm(paths, desc="Rotating images 90° CW"): out_path = os.path.join(rotated_dir, os.path.basename(p)) - Image.open(p).transpose(Image.ROTATE_270).save(out_path) + Image.open(p).transpose(Image.Transpose.ROTATE_270).save(out_path) rotated_paths.append(out_path) paths = rotated_paths resolved_folder = rotated_dir @@ -211,10 +213,15 @@ def _warm_streaming(model, images, scale_frames, warm_stream_n, dtype, warm_stream_n = max(1, min(int(warm_stream_n), num_avail - scale_frames)) kf_int = max(int(keyframe_interval), 1) - # images: [S, 3, H, W] on device already; slice + add batch dim, no copy of - # spatial dims so warmup shape == real inference shape (H, W). - warm_scale = images[:scale_frames].unsqueeze(0).to(dtype) - warm_stream = images[scale_frames:scale_frames + warm_stream_n].unsqueeze(0).to(dtype) + # images may live on CPU; move only the warmup slices to the model device so + # long videos do not become persistent GPU residents before inference starts. + _model_device = next(model.parameters()).device + warm_scale = images[:scale_frames].unsqueeze(0).to( + device=_model_device, dtype=dtype, non_blocking=True + ) + warm_stream = images[scale_frames:scale_frames + warm_stream_n].unsqueeze(0).to( + device=_model_device, dtype=dtype, non_blocking=True + ) for _ in range(passes): model.clean_kv_cache() @@ -387,7 +394,7 @@ def main(): parser.add_argument( "--offload_to_cpu", action=argparse.BooleanOptionalAction, - default=False, + default=True, help="Offload per-frame predictions to CPU during inference to cut GPU peak memory " "(on by default). Use --no-offload_to_cpu to keep outputs on GPU.", ) @@ -459,14 +466,13 @@ def main(): print(f"Casting aggregator to {dtype} (heads kept in fp32)") model.aggregator = model.aggregator.to(dtype=dtype) - images = images.to(device) num_frames = images.shape[0] print(f"Input: {num_frames} frames, shape {tuple(images.shape)}") print(f"Mode: {args.mode}") if torch.cuda.is_available(): torch.cuda.empty_cache() print( - f"GPU mem after load: " + f"GPU mem after model load: " f"alloc={torch.cuda.memory_allocated()/1e9:.2f} GB, " f"reserved={torch.cuda.memory_reserved()/1e9:.2f} GB" ) @@ -498,6 +504,12 @@ def main(): f"(window_size={args.window_size} keyframes, scale={args.num_scale_frames})." ) + if not args.offload_to_cpu and (args.mode == "windowed" or num_frames > 512): + print( + "Warning: --no-offload_to_cpu keeps the full prediction history on the GPU; " + "long sequences can OOM even when KV cache growth is bounded." + ) + # ── Optional: torch.compile + CUDA-graph warmup (streaming only) ──────── if args.compile: if args.mode != "streaming": diff --git a/gct_profile.py b/gct_profile.py index 66597d0..224398a 100644 --- a/gct_profile.py +++ b/gct_profile.py @@ -169,11 +169,11 @@ def fps(ms): ms_lo, ms_mid, ms_hi = avg_ms(p_lo), avg_ms(p_mid), avg_ms(p_hi) print(f"\n [{label}] ({total_frames} total frames: {scale_frames} scale + {n} streaming)") - print(f" ── Global FPS ─────────────────────────────────────") + print(" ── Global FPS ─────────────────────────────────────") print(f" total time: {total_ms / 1000:.2f} s " f"({phase1_ms:.1f} ms phase1 + {total_ms - phase1_ms:.1f} ms phase2)") print(f" per frame : {global_ms_per_frame:6.2f} ms → {global_fps:6.2f} FPS") - print(f" ── Windowed FPS (±30 streaming frames) ────────────") + print(" ── Windowed FPS (±30 streaming frames) ────────────") print(f" frame {scale_frames + p_lo:>5d} (10%): {ms_lo:6.2f} ms → {fps(ms_lo):6.2f} FPS") print(f" frame {scale_frames + p_mid:>5d} (50%): {ms_mid:6.2f} ms → {fps(ms_mid):6.2f} FPS") print(f" frame {scale_frames + p_hi:>5d} (90%): {ms_hi:6.2f} ms → {fps(ms_hi):6.2f} FPS") @@ -182,7 +182,7 @@ def fps(ms): # original script. This naturally skips the cold first streaming frame # (global index = scale_frames), whose ms is dominated by one-time CUDA # graph (re)capture after `clean_kv_cache()` in profile_streaming. - print(f" ── FPS trace (every 100 global frames) ────────────") + print(" ── FPS trace (every 100 global frames) ────────────") first_trace = (100 - scale_frames) % 100 or 100 for i in range(first_trace, n, 100): ms_i = avg_ms(i, window=3) @@ -344,12 +344,12 @@ def _warm(m, passes=1): _warm(model) if args.compile: - print(f" Compiling hot modules...") + print(" Compiling hot modules...") compile_model(model) # Three passes under compile: 1st captures CUDA graphs, 2nd/3rd # replay so the caching allocator and graph-address map converge # on the exact state the subsequent profile will see. - print(f" Warmup compiled (3× dress rehearsal)...") + print(" Warmup compiled (3× dress rehearsal)...") _warm(model, passes=3) else: # No compile → a single dress-rehearsal pass is enough to diff --git a/lingbot_map/models/gct_stream_window.py b/lingbot_map/models/gct_stream_window.py index 0edd971..7a43d36 100644 --- a/lingbot_map/models/gct_stream_window.py +++ b/lingbot_map/models/gct_stream_window.py @@ -9,9 +9,10 @@ """ import logging +import numpy as np import torch import torch.nn as nn -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, Union, cast from tqdm.auto import tqdm from lingbot_map.utils.rotation import quat_to_mat, mat_to_quat @@ -299,7 +300,9 @@ def _aggregate_features( num_frame_for_scale: Optional[int] = None, sliding_window_size: Optional[int] = None, num_frame_per_block: int = 1, - **kwargs, + view_graphs: Optional[torch.Tensor] = None, + causal_graphs: Optional[Union[torch.Tensor, List[np.ndarray]]] = None, + ordered_video: Optional[torch.Tensor] = None, ) -> tuple: """ Run aggregator to get multi-scale features. @@ -313,7 +316,9 @@ def _aggregate_features( Returns: (aggregated_tokens_list, patch_start_idx) """ - aggregated_tokens_list, patch_start_idx = self.aggregator( + del view_graphs, causal_graphs, ordered_video + aggregator = cast(Any, self.aggregator) + aggregated_tokens_list, patch_start_idx = aggregator( images, selected_idx=[4, 11, 17, 23], num_frame_for_scale=num_frame_for_scale, @@ -329,12 +334,14 @@ def clean_kv_cache(self): Call this method when starting a new video sequence to clear cached key-value pairs from previous sequences. """ - if hasattr(self.aggregator, 'clean_kv_cache'): - self.aggregator.clean_kv_cache() + aggregator = cast(Any, self.aggregator) + camera_head = cast(Any, self.camera_head) + if hasattr(aggregator, "clean_kv_cache"): + aggregator.clean_kv_cache() else: logger.warning("Aggregator does not support KV cache cleaning") - if hasattr(self.camera_head, 'kv_cache'): - self.camera_head.clean_kv_cache() + if hasattr(camera_head, "kv_cache"): + camera_head.clean_kv_cache() else: logger.warning("Camera head does not support KV cache cleaning") @@ -348,13 +355,15 @@ def _set_skip_append(self, skip: bool): Args: skip: If True, subsequent forward passes will not append KV to cache. """ - if hasattr(self.aggregator, 'kv_cache') and self.aggregator.kv_cache is not None: - self.aggregator.kv_cache["_skip_append"] = skip + aggregator = cast(Any, self.aggregator) + camera_head = cast(Any, self.camera_head) + if hasattr(aggregator, "kv_cache") and aggregator.kv_cache is not None: + aggregator.kv_cache["_skip_append"] = skip # FlashInfer manager - if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None: - self.aggregator.kv_cache_manager._skip_append = skip - if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None: - for cache_dict in self.camera_head.kv_cache: + if hasattr(aggregator, "kv_cache_manager") and aggregator.kv_cache_manager is not None: + aggregator.kv_cache_manager._skip_append = skip + if camera_head is not None and hasattr(camera_head, "kv_cache") and camera_head.kv_cache is not None: + for cache_dict in cast(List[Dict[str, Any]], camera_head.kv_cache): cache_dict["_skip_append"] = skip # ── Flow-based keyframe helpers ──────────────────────────────────────── @@ -366,14 +375,16 @@ def _set_defer_eviction(self, defer: bool): the most recent append without having to restore evicted frames. """ # FlashInfer manager - if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None: - self.aggregator.kv_cache_manager._defer_eviction = defer + aggregator = cast(Any, self.aggregator) + camera_head = cast(Any, self.camera_head) + if hasattr(aggregator, "kv_cache_manager") and aggregator.kv_cache_manager is not None: + aggregator.kv_cache_manager._defer_eviction = defer # SDPA aggregator cache (dict) - if hasattr(self.aggregator, 'kv_cache') and isinstance(self.aggregator.kv_cache, dict): - self.aggregator.kv_cache["_defer_eviction"] = defer + if hasattr(aggregator, "kv_cache") and isinstance(aggregator.kv_cache, dict): + aggregator.kv_cache["_defer_eviction"] = defer # Camera head SDPA caches - if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None: - for cache_dict in self.camera_head.kv_cache: + if camera_head is not None and hasattr(camera_head, "kv_cache") and camera_head.kv_cache is not None: + for cache_dict in cast(List[Dict[str, Any]], camera_head.kv_cache): cache_dict["_defer_eviction"] = defer def _rollback_last_frame(self): @@ -384,14 +395,16 @@ def _rollback_last_frame(self): Must be called while eviction is still deferred. """ # FlashInfer manager — rollback each transformer block - if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None: - mgr = self.aggregator.kv_cache_manager + aggregator = cast(Any, self.aggregator) + camera_head = cast(Any, self.camera_head) + if hasattr(aggregator, "kv_cache_manager") and aggregator.kv_cache_manager is not None: + mgr = cast(Any, aggregator.kv_cache_manager) for block_idx in range(mgr.num_blocks): mgr.rollback_last_frame(block_idx) # SDPA aggregator cache — trim last frame along dim=2 - if hasattr(self.aggregator, 'kv_cache') and isinstance(self.aggregator.kv_cache, dict): - kv = self.aggregator.kv_cache + if hasattr(aggregator, "kv_cache") and isinstance(aggregator.kv_cache, dict): + kv = cast(Dict[str, Any], aggregator.kv_cache) for key in list(kv.keys()): if key.startswith(("k_", "v_")) and kv[key] is not None and torch.is_tensor(kv[key]): if kv[key].dim() >= 3 and kv[key].shape[2] > 1: @@ -400,17 +413,18 @@ def _rollback_last_frame(self): kv[key] = None # Camera head - if self.camera_head is not None and hasattr(self.camera_head, 'rollback_last_frame'): - self.camera_head.rollback_last_frame() + if camera_head is not None and hasattr(camera_head, "rollback_last_frame"): + camera_head.rollback_last_frame() # Aggregator frame counter (used for 3D RoPE temporal positions) - self.aggregator.total_frames_processed -= 1 + aggregator.total_frames_processed -= 1 def _execute_deferred_eviction(self): """Execute the eviction that was deferred during the last forward pass.""" # FlashInfer manager - if hasattr(self.aggregator, 'kv_cache_manager') and self.aggregator.kv_cache_manager is not None: - mgr = self.aggregator.kv_cache_manager + aggregator = cast(Any, self.aggregator) + if hasattr(aggregator, "kv_cache_manager") and aggregator.kv_cache_manager is not None: + mgr = cast(Any, aggregator.kv_cache_manager) for block_idx in range(mgr.num_blocks): mgr.execute_deferred_eviction( block_idx, @@ -427,10 +441,11 @@ def get_kv_cache_info(self) -> Dict[str, Any]: - num_cached_blocks: Number of blocks with cached KV - cache_memory_mb: Approximate memory usage in MB """ - if not hasattr(self.aggregator, 'kv_cache') or self.aggregator.kv_cache is None: + aggregator = cast(Any, self.aggregator) + if not hasattr(aggregator, "kv_cache") or aggregator.kv_cache is None: return {"num_cached_blocks": 0, "cache_memory_mb": 0.0} - kv_cache = self.aggregator.kv_cache + kv_cache = cast(Dict[str, Any], aggregator.kv_cache) num_cached = sum(1 for k in kv_cache.keys() if k.startswith('k_') and not k.endswith('_special')) # Estimate memory usage @@ -660,7 +675,8 @@ def _to_out(t: torch.Tensor) -> torch.Tensor: # Apply prediction normalization if enabled if self.pred_normalization: - predictions = self._normalize_predictions(predictions) + normalize_predictions = cast(Any, self._normalize_predictions) + predictions = normalize_predictions(predictions) return predictions @@ -676,7 +692,7 @@ def _to_out(t: torch.Tensor) -> torch.Tensor: def _stitch_windows( self, - windows: List[Dict], + windows: List[Dict[str, Any]], window_size: int, overlap: int, ) -> Dict: @@ -717,12 +733,12 @@ def _stitch_windows( end = total if is_last else max(total - overlap, 0) slices.append((0, end) if end > 0 else None) - parts = [ - values[i][:, s:e] - for i, s_e in enumerate(slices) - if s_e is not None - for s, e in [s_e] - ] + parts = [] + for tensor, s_e in zip(values, slices): + if tensor is None or s_e is None: + continue + s, e = s_e + parts.append(cast(torch.Tensor, tensor)[:, s:e]) if parts: stitched[key] = torch.cat(parts, dim=1) else: @@ -1080,11 +1096,52 @@ def _new_lists(): 'frame_type': [], # list of ints: 0=scale, 1=keyframe, 2=non-keyframe } + def _tensor_ref(pred: Dict) -> torch.Tensor: + return next( + v for k in ("pose_enc", "world_points", "depth") + if (v := pred.get(k)) is not None + ) + + merged_predictions: Optional[Dict] = None + prev_warped_window: Optional[Dict] = None + per_window_scales: List[torch.Tensor] = [] + per_window_transforms: List[torch.Tensor] = [] + merged_window_count = 0 + + def _merge_window_pred(window_pred: Dict) -> None: + nonlocal merged_predictions, prev_warped_window, merged_window_count + + ref = _tensor_ref(prev_warped_window or window_pred) + dev, dt, nb = ref.device, ref.dtype, ref.shape[0] + + if prev_warped_window is None: + s_rel = torch.ones(nb, device=dev, dtype=dt) + R_rel = torch.eye(3, device=dev, dtype=dt).unsqueeze(0).expand(nb, -1, -1).clone() + t_rel = torch.zeros(nb, 3, device=dev, dtype=dt) + warped = window_pred + merged_predictions = window_pred + else: + s_rel, R_rel, t_rel = self._pairwise_alignment( + prev_warped_window, window_pred, eff_overlap, nb, dev, dt, + ) + warped = self._warp_predictions(window_pred, R_rel, t_rel, s_rel, nb) + assert merged_predictions is not None + merged_predictions = self._stitch_windows( + [merged_predictions, warped], window_size, eff_overlap + ) + + per_window_scales.append(s_rel.clone()) + T = torch.eye(4, device=dev, dtype=dt).unsqueeze(0).expand(nb, -1, -1).clone() + T[:, :3, :3] = R_rel + T[:, :3, 3] = t_rel + per_window_transforms.append(T) + prev_warped_window = warped + merged_window_count += 1 + # ================================================================ # Flow-based mode: dynamic windows (can't precompute window list) # ================================================================ if use_flow_keyframe: - all_window_predictions: List[Dict] = [] cursor = 0 window_idx = 0 pbar = tqdm(total=S, desc='Windowed inference (flow)', initial=0) @@ -1114,6 +1171,7 @@ def _new_lists(): last_kf_pose_enc = scale_out["pose_enc"][:, -1:] last_kf_local_idx = window_scale - 1 del scale_out + del scale_images cursor += window_scale pbar.update(window_scale) @@ -1167,10 +1225,13 @@ def _new_lists(): _collect_frame(frame_out, w_lists) del frame_out + del frame_image cursor += 1 pbar.update(1) - all_window_predictions.append(_make_window_pred(w_lists)) + window_pred = _make_window_pred(w_lists) + _merge_window_pred(window_pred) + del window_pred window_idx += 1 # Next window starts overlap_size frames back (= scale frames) @@ -1204,7 +1265,6 @@ def _new_lists(): if end_idx == S: break - all_window_predictions: List[Dict] = [] for start, end in tqdm(windows, desc='Windowed inference'): # Slice on whichever device `images` lives on, then move just # this window to the model device. Keeps peak memory at one @@ -1255,20 +1315,24 @@ def _new_lists(): w_lists['frame_type'].append(1 if is_keyframe else 2) del frame_out - all_window_predictions.append(_make_window_pred(w_lists)) + window_pred = _make_window_pred(w_lists) + _merge_window_pred(window_pred) + del window_pred + del window_images - # Store for merge helpers - self._last_window_size = eff_overlap # not used directly, but kept for compat - self._last_overlap_size = eff_overlap - - # Align and stitch windows - predictions = self._align_and_stitch_windows( - all_window_predictions, scale_mode=scale_mode - ) + predictions = merged_predictions or {} + if predictions: + if merged_window_count > 1: + if per_window_scales: + predictions["chunk_scales"] = torch.stack(per_window_scales, dim=1) + if per_window_transforms: + predictions["chunk_transforms"] = torch.stack(per_window_transforms, dim=1) + predictions["alignment_mode"] = "scaled" predictions["images"] = _to_out(images) if self.pred_normalization: - predictions = self._normalize_predictions(predictions) + normalize_predictions = cast(Any, self._normalize_predictions) + predictions = normalize_predictions(predictions) - return predictions \ No newline at end of file + return predictions diff --git a/lingbot_map/vis/__init__.py b/lingbot_map/vis/__init__.py index c0422e6..a30dacd 100644 --- a/lingbot_map/vis/__init__.py +++ b/lingbot_map/vis/__init__.py @@ -11,10 +11,16 @@ - PointCloudViewer: Interactive point cloud viewer with camera visualization - viser_wrapper: Quick visualization wrapper for predictions - predictions_to_glb: Export predictions to GLB 3D format +- predictions_to_ply: Export predictions to PLY point-cloud format - Colorization and utility functions Usage: - from lingbot_map.vis import PointCloudViewer, viser_wrapper, predictions_to_glb + from lingbot_map.vis import ( + PointCloudViewer, + viser_wrapper, + predictions_to_glb, + predictions_to_ply, + ) # Interactive visualization viewer = PointCloudViewer(pred_dict=predictions, port=8080) @@ -26,6 +32,9 @@ # Export to GLB scene = predictions_to_glb(predictions) scene.export("output.glb") + + # Export to PLY + predictions_to_ply(predictions, "output.ply") """ from lingbot_map.vis.point_cloud_viewer import PointCloudViewer @@ -37,7 +46,11 @@ load_or_create_sky_masks, segment_sky, ) -from lingbot_map.vis.glb_export import predictions_to_glb +from lingbot_map.vis.glb_export import ( + predictions_to_glb, + predictions_to_ply, + save_point_cloud_to_ply, +) __all__ = [ # Main viewer @@ -46,6 +59,8 @@ "viser_wrapper", # GLB export "predictions_to_glb", + "predictions_to_ply", + "save_point_cloud_to_ply", # Utilities "CameraState", "colorize", diff --git a/lingbot_map/vis/glb_export.py b/lingbot_map/vis/glb_export.py index b2ccd74..769756f 100644 --- a/lingbot_map/vis/glb_export.py +++ b/lingbot_map/vis/glb_export.py @@ -5,12 +5,15 @@ # LICENSE file in the root directory of this source tree. """ -GLB 3D export utilities for GCT predictions. +3D export utilities for GCT predictions. """ import os import copy -from typing import Optional, Tuple +import importlib +import importlib.util +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Sequence, Tuple import numpy as np import cv2 @@ -25,10 +28,10 @@ _result_map_to_non_sky_conf, ) -try: - import trimesh -except ImportError: - trimesh = None +trimesh: Any = None +if importlib.util.find_spec("trimesh") is not None: + trimesh = importlib.import_module("trimesh") +else: print("trimesh not found. GLB export will not work.") @@ -78,7 +81,172 @@ def predictions_to_glb( conf_thres = 10.0 print("Building GLB scene") + vertices_3d, colors_rgb, extrinsics_matrices, scene_scale = _prepare_export_point_cloud( + predictions, + conf_thres=conf_thres, + filter_by_frames=filter_by_frames, + mask_black_bg=mask_black_bg, + mask_white_bg=mask_white_bg, + mask_sky=mask_sky, + target_dir=target_dir, + prediction_mode=prediction_mode, + ) + + if np.asarray(vertices_3d).size == 0: + vertices_3d = np.array([[1.0, 0.0, 0.0]], dtype=np.float32) + colors_rgb = np.array([[255, 255, 255]], dtype=np.uint8) + + colormap = matplotlib.colormaps.get_cmap("gist_rainbow") + scene_3d = trimesh.Scene() + point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb) + scene_3d.add_geometry(point_cloud_data) + + # Add cameras + if show_cam and len(extrinsics_matrices) > 0: + num_cameras = len(extrinsics_matrices) + for i in range(num_cameras): + world_to_camera = extrinsics_matrices[i] + camera_to_world = np.linalg.inv(world_to_camera) + rgba_color = colormap(i / num_cameras) + current_color = ( + int(255 * rgba_color[0]), + int(255 * rgba_color[1]), + int(255 * rgba_color[2]), + ) + integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale) + + # Align scene + if len(extrinsics_matrices) > 0: + scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices) + + print("GLB Scene built") + return scene_3d + + +def predictions_to_ply( + predictions: dict, + output_path: str, + conf_thres: float = 50.0, + filter_by_frames: str = "all", + mask_black_bg: bool = False, + mask_white_bg: bool = False, + mask_sky: bool = False, + target_dir: Optional[str] = None, + prediction_mode: str = "Predicted Pointmap", +) -> int: + """ + Export filtered prediction points as a PLY point cloud. + + Args: + predictions: Prediction dictionary in the same format as predictions_to_glb. + output_path: Destination ``.ply`` path. + conf_thres: Percentage of low-confidence points to filter out. + filter_by_frames: Frame filter specification ("all" or frame index). + mask_black_bg: Mask out black background pixels. + mask_white_bg: Mask out white background pixels. + mask_sky: Apply sky segmentation mask. + target_dir: Output directory for intermediate files. + prediction_mode: "Predicted Pointmap" or "Predicted Depthmap". + + Returns: + Number of exported points. + """ + vertices_3d, colors_rgb, extrinsics_matrices, _ = _prepare_export_point_cloud( + predictions, + conf_thres=conf_thres, + filter_by_frames=filter_by_frames, + mask_black_bg=mask_black_bg, + mask_white_bg=mask_white_bg, + mask_sky=mask_sky, + target_dir=target_dir, + prediction_mode=prediction_mode, + ) + if len(vertices_3d) > 0 and len(extrinsics_matrices) > 0: + vertices_3d = apply_scene_alignment_to_vertices(vertices_3d, extrinsics_matrices) + save_point_cloud_to_ply(vertices_3d, colors_rgb, output_path) + return int(len(vertices_3d)) + + +def save_point_cloud_to_ply( + vertices: np.ndarray, + colors_rgb: np.ndarray, + output_path: str, + normals: Optional[np.ndarray] = None, + alpha: Optional[np.ndarray] = None, + extra_vertex_properties: Optional[Dict[str, np.ndarray]] = None, + comments: Optional[List[str]] = None, + extra_elements: Optional[List[Tuple[str, Dict[str, np.ndarray]]]] = None, +) -> None: + """Write a point cloud to a binary little-endian PLY file.""" + vertices = np.asarray(vertices, dtype=np.float32) + colors_rgb = _coerce_colors_to_uint8(colors_rgb) + + if vertices.ndim != 2 or vertices.shape[1] != 3: + raise ValueError(f"vertices must have shape (N, 3), got {vertices.shape}") + if colors_rgb.ndim != 2 or colors_rgb.shape[1] != 3: + raise ValueError(f"colors_rgb must have shape (N, 3), got {colors_rgb.shape}") + if len(vertices) != len(colors_rgb): + raise ValueError( + f"vertices/colors length mismatch: {len(vertices)} vs {len(colors_rgb)}" + ) + + out_dir = os.path.dirname(output_path) + if out_dir: + os.makedirs(out_dir, exist_ok=True) + + vertex_properties: "OrderedDict[str, np.ndarray]" = OrderedDict() + vertex_properties["x"] = vertices[:, 0] + vertex_properties["y"] = vertices[:, 1] + vertex_properties["z"] = vertices[:, 2] + vertex_properties["red"] = colors_rgb[:, 0] + vertex_properties["green"] = colors_rgb[:, 1] + vertex_properties["blue"] = colors_rgb[:, 2] + + if normals is not None: + normals = np.asarray(normals, dtype=np.float32) + if normals.ndim != 2 or normals.shape != vertices.shape: + raise ValueError( + f"normals must have shape {vertices.shape}, got {normals.shape}" + ) + vertex_properties["nx"] = normals[:, 0] + vertex_properties["ny"] = normals[:, 1] + vertex_properties["nz"] = normals[:, 2] + + if alpha is not None: + alpha = np.asarray(alpha) + if alpha.ndim != 1 or len(alpha) != len(vertices): + raise ValueError( + f"alpha must have shape ({len(vertices)},), got {alpha.shape}" + ) + vertex_properties["alpha"] = np.clip(alpha, 0, 255).astype(np.uint8) + + if extra_vertex_properties: + for name, values in extra_vertex_properties.items(): + if name in vertex_properties: + raise ValueError(f"duplicate PLY vertex property: {name}") + vertex_properties[name] = _coerce_ply_property_array( + name, values, len(vertices) + ) + + elements: List[Tuple[str, Dict[str, np.ndarray]]] = [("vertex", vertex_properties)] + if extra_elements: + elements.extend(extra_elements) + + _write_ply_elements(output_path, elements, comments) + + +def _prepare_export_point_cloud( + predictions: dict, + conf_thres: float = 50.0, + filter_by_frames: str = "all", + mask_black_bg: bool = False, + mask_white_bg: bool = False, + mask_sky: bool = False, + target_dir: Optional[str] = None, + prediction_mode: str = "Predicted Pointmap", +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]: + """Prepare filtered point-cloud export data shared by GLB and PLY.""" # Parse frame filter selected_frame_idx = None if filter_by_frames != "all" and filter_by_frames != "All": @@ -111,35 +279,29 @@ def predictions_to_glb( images = predictions["images"] camera_matrices = predictions["extrinsic"] - # Apply sky segmentation if enabled if mask_sky and target_dir is not None: pred_world_points_conf = _apply_sky_mask( pred_world_points_conf, target_dir, images ) - # Apply frame filter if selected_frame_idx is not None: pred_world_points = pred_world_points[selected_frame_idx][None] pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None] images = images[selected_frame_idx][None] camera_matrices = camera_matrices[selected_frame_idx][None] - # Prepare vertices and colors - vertices_3d = pred_world_points.reshape(-1, 3) + vertices_3d = np.asarray(pred_world_points).reshape(-1, 3).astype(np.float32, copy=False) - # Handle different image formats - if images.ndim == 4 and images.shape[1] == 3: # NCHW format + if images.ndim == 4 and images.shape[1] == 3: colors_rgb = np.transpose(images, (0, 2, 3, 1)) else: colors_rgb = images - colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8) + colors_rgb = _coerce_colors_to_uint8(colors_rgb.reshape(-1, 3)) - # Apply confidence filtering - conf = pred_world_points_conf.reshape(-1) + conf = np.asarray(pred_world_points_conf).reshape(-1) conf_threshold = np.percentile(conf, conf_thres) if conf_thres > 0 else 0.0 conf_mask = (conf >= conf_threshold) & (conf > 1e-5) - # Apply background masking if mask_black_bg: black_bg_mask = colors_rgb.sum(axis=1) >= 16 conf_mask = conf_mask & black_bg_mask @@ -155,43 +317,30 @@ def predictions_to_glb( vertices_3d = vertices_3d[conf_mask] colors_rgb = colors_rgb[conf_mask] - # Handle empty point cloud - if vertices_3d is None or np.asarray(vertices_3d).size == 0: - vertices_3d = np.array([[1, 0, 0]]) - colors_rgb = np.array([[255, 255, 255]]) - scene_scale = 1 + extrinsics_matrices = np.zeros((len(camera_matrices), 4, 4), dtype=np.float32) + if len(camera_matrices) > 0: + extrinsics_matrices[:, :3, :4] = camera_matrices + extrinsics_matrices[:, 3, 3] = 1 + + if np.asarray(vertices_3d).size == 0: + scene_scale = 1.0 else: lower_percentile = np.percentile(vertices_3d, 5, axis=0) upper_percentile = np.percentile(vertices_3d, 95, axis=0) - scene_scale = np.linalg.norm(upper_percentile - lower_percentile) - - colormap = matplotlib.colormaps.get_cmap("gist_rainbow") - - # Build scene - scene_3d = trimesh.Scene() - point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb) - scene_3d.add_geometry(point_cloud_data) + scene_scale = float(np.linalg.norm(upper_percentile - lower_percentile)) + scene_scale = max(scene_scale, 0.1) - # Prepare camera matrices - num_cameras = len(camera_matrices) - extrinsics_matrices = np.zeros((num_cameras, 4, 4)) - extrinsics_matrices[:, :3, :4] = camera_matrices - extrinsics_matrices[:, 3, 3] = 1 + return vertices_3d, colors_rgb, extrinsics_matrices, scene_scale - # Add cameras - if show_cam: - for i in range(num_cameras): - world_to_camera = extrinsics_matrices[i] - camera_to_world = np.linalg.inv(world_to_camera) - rgba_color = colormap(i / num_cameras) - current_color = tuple(int(255 * x) for x in rgba_color[:3]) - integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale) - # Align scene - scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices) - - print("GLB Scene built") - return scene_3d +def _coerce_colors_to_uint8(colors_rgb: np.ndarray) -> np.ndarray: + """Convert float RGB in [0, 1] or integer RGB in [0, 255] to uint8.""" + colors_rgb = np.asarray(colors_rgb) + if colors_rgb.dtype == np.uint8: + return colors_rgb + if np.issubdtype(colors_rgb.dtype, np.floating): + return (np.clip(colors_rgb, 0.0, 1.0) * 255).astype(np.uint8) + return np.clip(colors_rgb, 0, 255).astype(np.uint8) def _apply_sky_mask( @@ -234,6 +383,9 @@ def _apply_sky_mask( else: sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath) + if sky_mask is None: + print(f"Warning: failed to read sky mask for {image_name}, keeping all pixels") + sky_mask = np.full((H, W), 255, dtype=np.uint8) if sky_mask.shape[0] != H or sky_mask.shape[1] != W: sky_mask = cv2.resize(sky_mask, (W, H), interpolation=cv2.INTER_LINEAR) @@ -247,7 +399,7 @@ def _apply_sky_mask( def integrate_camera_into_scene( scene: "trimesh.Scene", transform: np.ndarray, - face_colors: Tuple[int, int, int], + face_colors: Sequence[int], scene_scale: float, frustum_thickness: float = 1.0, ): @@ -304,7 +456,8 @@ def integrate_camera_into_scene( mesh_faces = compute_camera_faces_multi(camera_cone_shape, len(shell_scales)) camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces) - camera_mesh.visual.face_colors[:, :3] = face_colors + camera_visual = camera_mesh.visual + camera_visual.face_colors[:, :3] = tuple(face_colors[:3]) scene.add_geometry(camera_mesh) @@ -322,16 +475,55 @@ def apply_scene_alignment( Returns: Aligned 3D scene """ + initial_transformation = get_scene_alignment_transform(extrinsics_matrices) + scene_3d.apply_transform(initial_transformation) + return scene_3d + + +def apply_scene_alignment_to_vertices( + vertices: np.ndarray, + extrinsics_matrices: np.ndarray, +) -> np.ndarray: + """Apply the same scene alignment used by GLB export directly to vertices.""" + if len(extrinsics_matrices) == 0: + return np.asarray(vertices, dtype=np.float32) + + transformation = get_scene_alignment_transform(extrinsics_matrices) + return transform_points(transformation, np.asarray(vertices, dtype=np.float32)) + + +def apply_scene_alignment_to_directions( + vectors: np.ndarray, + extrinsics_matrices: np.ndarray, +) -> np.ndarray: + """Apply scene alignment rotation to direction vectors such as normals.""" + vectors = np.asarray(vectors, dtype=np.float32) + if len(extrinsics_matrices) == 0 or vectors.size == 0: + return vectors + + rotation_only = np.eye(4, dtype=np.float32) + rotation_only[:3, :3] = get_scene_alignment_transform(extrinsics_matrices)[:3, :3] + rotated = transform_points(rotation_only, vectors) + norm = np.linalg.norm(rotated, axis=-1, keepdims=True) + valid = norm > 1e-8 + rotated = np.where(valid, rotated / np.where(valid, norm, 1.0), 0.0) + return rotated.astype(np.float32, copy=False) + + +def get_scene_alignment_transform(extrinsics_matrices: np.ndarray) -> np.ndarray: + """Return the world transform used to align GLB/PLY exports.""" + if len(extrinsics_matrices) == 0: + return np.eye(4, dtype=np.float32) + opengl_conversion_matrix = get_opengl_conversion_matrix() align_rotation = np.eye(4) align_rotation[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix() - initial_transformation = ( + transformation = ( np.linalg.inv(extrinsics_matrices[0]) @ opengl_conversion_matrix @ align_rotation ) - scene_3d.apply_transform(initial_transformation) - return scene_3d + return transformation.astype(np.float32, copy=False) def get_opengl_conversion_matrix() -> np.ndarray: @@ -368,6 +560,131 @@ def transform_points( return points[..., :dim].reshape(*initial_shape, dim) +def _write_ply_elements( + output_path: str, + elements: List[Tuple[str, Dict[str, np.ndarray]]], + comments: Optional[List[str]] = None, +) -> None: + """Write one or more binary PLY elements to disk.""" + prepared_elements = [ + _prepare_ply_element(name, properties) + for name, properties in elements + ] + + header_lines = [ + "ply", + "format binary_little_endian 1.0", + ] + for comment in comments or []: + header_lines.append(f"comment {comment}") + for name, structured, property_specs in prepared_elements: + header_lines.append(f"element {name} {len(structured)}") + for property_name, ply_type in property_specs: + header_lines.append(f"property {ply_type} {property_name}") + header_lines.append("end_header") + header = "\n".join(header_lines) + "\n" + + with open(output_path, "wb") as f: + f.write(header.encode("ascii")) + for _, structured, _ in prepared_elements: + structured.tofile(f) + + +def _prepare_ply_element( + element_name: str, + properties: Dict[str, np.ndarray], +) -> Tuple[str, np.ndarray, List[Tuple[str, str]]]: + """Normalize one PLY element into a structured array plus header specs.""" + if not properties: + raise ValueError(f"PLY element {element_name!r} has no properties") + + normalized: "OrderedDict[str, np.ndarray]" = OrderedDict() + element_length: Optional[int] = None + for property_name, values in properties.items(): + values = np.asarray(values) + if values.ndim != 1: + raise ValueError( + f"PLY property {element_name}.{property_name} must be 1D, got {values.shape}" + ) + if element_length is None: + element_length = len(values) + elif len(values) != element_length: + raise ValueError( + f"PLY element {element_name!r} property length mismatch for {property_name}: " + f"expected {element_length}, got {len(values)}" + ) + normalized[property_name] = _normalize_ply_property_dtype(values) + + assert element_length is not None + dtype_fields = [] + property_specs = [] + for property_name, values in normalized.items(): + little_endian_dtype = values.dtype.newbyteorder("<") + dtype_fields.append((property_name, little_endian_dtype)) + property_specs.append((property_name, _numpy_dtype_to_ply_type(little_endian_dtype))) + + structured = np.empty(element_length, dtype=np.dtype(dtype_fields)) + for property_name, values in normalized.items(): + structured[property_name] = values.astype( + structured[property_name].dtype, + copy=False, + ) + + return element_name, structured, property_specs + + +def _coerce_ply_property_array( + property_name: str, + values: np.ndarray, + expected_length: int, +) -> np.ndarray: + """Validate an extra PLY property array against the expected vertex count.""" + values = np.asarray(values) + if values.ndim != 1 or len(values) != expected_length: + raise ValueError( + f"PLY property {property_name!r} must have shape ({expected_length},), " + f"got {values.shape}" + ) + return _normalize_ply_property_dtype(values) + + +def _normalize_ply_property_dtype(values: np.ndarray) -> np.ndarray: + """Cast property arrays to PLY-compatible scalar dtypes.""" + values = np.asarray(values) + if values.dtype == np.bool_: + return values.astype(np.uint8) + if values.dtype == np.float16: + return values.astype(np.float32) + if np.issubdtype(values.dtype, np.floating): + return values.astype(np.float32 if values.dtype.itemsize <= 4 else np.float64) + if np.issubdtype(values.dtype, np.signedinteger): + return values.astype(np.int32 if values.dtype.itemsize > 4 else values.dtype) + if np.issubdtype(values.dtype, np.unsignedinteger): + if values.dtype.itemsize > 4: + return values.astype(np.uint32) + return values.astype(values.dtype) + raise ValueError(f"Unsupported PLY property dtype: {values.dtype}") + + +def _numpy_dtype_to_ply_type(dtype: np.dtype) -> str: + """Map a numpy scalar dtype to a binary PLY header type.""" + dtype = np.dtype(dtype).newbyteorder("=") + dtype_map = { + np.dtype(np.int8): "char", + np.dtype(np.uint8): "uchar", + np.dtype(np.int16): "short", + np.dtype(np.uint16): "ushort", + np.dtype(np.int32): "int", + np.dtype(np.uint32): "uint", + np.dtype(np.float32): "float", + np.dtype(np.float64): "double", + } + try: + return dtype_map[dtype] + except KeyError as exc: + raise ValueError(f"Unsupported PLY dtype: {dtype}") from exc + + def compute_camera_faces(cone_shape: "trimesh.Trimesh") -> np.ndarray: """Computes the faces for the camera mesh.""" faces_list = [] @@ -435,6 +752,8 @@ def segment_sky( Continuous non-sky confidence map in [0, 1] """ image = cv2.imread(image_path) + if image is None: + raise ValueError(f"Failed to read image for sky segmentation: {image_path}") result_map = run_skyseg(onnx_session, _SKYSEG_INPUT_SIZE, image) result_map_original = cv2.resize( result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR diff --git a/lingbot_map/vis/point_cloud_viewer.py b/lingbot_map/vis/point_cloud_viewer.py index a1d698c..5ade5de 100644 --- a/lingbot_map/vis/point_cloud_viewer.py +++ b/lingbot_map/vis/point_cloud_viewer.py @@ -17,12 +17,12 @@ import subprocess import tempfile import shutil -from typing import List, Optional, Dict, Any, Tuple +from typing import List, Optional, Dict, Any, Tuple, cast import numpy as np import torch import cv2 -import matplotlib.cm as cm +from matplotlib import colormaps from tqdm.auto import tqdm import viser @@ -123,10 +123,16 @@ def __init__( self.image_mask = image_mask self.show_camera = show_camera self.on_replay = False - self.vis_pts_list = [] - self.traj_list = [] - self.orig_img_list = [x[0] for x in color_list if len(x) > 0] if color_list else [] self.via_points = [] + self.max_history_frames = max(0, min(self.num_frames - 1, 200)) + self.default_history_frames = min(20, self.max_history_frames) + self.default_max_render_points = 300_000 + self._render_cache_token = None + self._frame_render_cache: Dict[int, Tuple[np.ndarray, np.ndarray]] = {} + self.active_point_handle = None + self.cam_handles = [] + self.frame_visibility_mode = "current" + self.max_drawn_frame_idx = 0 self._setup_gui() self.server.on_client_connect(self._connect_client) @@ -164,6 +170,8 @@ def _process_pred_dict( # Compute world points from depth if not using the precomputed point map if not use_point_map: + if depth_map is None: + raise ValueError("depth predictions are required when use_point_map=False") world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam) conf = depth_conf else: @@ -172,6 +180,8 @@ def _process_pred_dict( # Apply sky segmentation if enabled if mask_sky: + if conf is None: + raise ValueError("confidence predictions are required when mask_sky=True") conf = apply_sky_segmentation( conf, image_folder=image_folder, images=images, sky_mask_dir=sky_mask_dir, @@ -190,7 +200,6 @@ def _process_pred_dict( self.original_images.append(img) # Create lists - apply depth_stride to skip frames for point projection - H, W = world_points.shape[1], world_points.shape[2] pc_list = [] color_list = [] conf_list = [] @@ -388,16 +397,21 @@ def _(_) -> None: button3 = self.server.gui.add_button("4D (Only Show Current Frame)") button4 = self.server.gui.add_button("3D (Show All Frames)") - self.is_render = False - self.fourd = False + button5 = self.server.gui.add_button("Recent Frame History") @button3.on_click def _(event: viser.GuiEvent) -> None: - self.fourd = True + self._set_frame_visibility_mode("current") @button4.on_click def _(event: viser.GuiEvent) -> None: - self.fourd = False + self._set_frame_visibility_mode("all") + + @button5.on_click + def _(event: viser.GuiEvent) -> None: + self._set_frame_visibility_mode("recent") + if hasattr(self, "gui_history_frames"): + self.gui_history_frames.value = self.default_history_frames self.focal_slider = self.server.gui.add_slider( "Focal Length", min=0.1, max=99999, step=1, initial_value=533 @@ -441,8 +455,8 @@ def _(event: viser.GuiEvent) -> None: def _(event: viser.GuiEvent) -> None: self._take_screenshot(event.client) - # GLB export controls - with self.server.gui.add_folder("Export GLB"): + # 3D export controls + with self.server.gui.add_folder("Export 3D"): self.glb_output_path = self.server.gui.add_text( "Output Path", initial_value="export.glb" ) @@ -498,6 +512,14 @@ def _(event: viser.GuiEvent) -> None: hint="Export current filtered point clouds and cameras as GLB.", ) self.glb_status = self.server.gui.add_text("Status", initial_value="Ready") + self.ply_output_path = self.server.gui.add_text( + "PLY Output Path", initial_value="export.ply" + ) + self.ply_export_button = self.server.gui.add_button( + "Export PLY", + hint="Export the current filtered point cloud as a PLY file.", + ) + self.ply_status = self.server.gui.add_text("PLY Status", initial_value="Ready") @self.glb_mode_dropdown.on_update def _(_) -> None: @@ -509,6 +531,10 @@ def _(_) -> None: def _(_) -> None: self._export_glb() + @self.ply_export_button.on_click + def _(_) -> None: + self._export_ply() + # Video saving controls with self.server.gui.add_folder("Video Saving"): self.save_video_button = self.server.gui.add_button("Save Video", disabled=False) @@ -534,19 +560,14 @@ def _(_) -> None: if self.current_frame_image is not None: self.current_frame_image.visible = self.show_video_checkbox.value - self.pc_handles = [] - self.cam_handles = [] - @self.psize_slider.on_update def _(_) -> None: - for handle in self.pc_handles: - handle.point_size = self.psize_slider.value + if self.active_point_handle is not None: + self.active_point_handle.point_size = self.psize_slider.value @self.camsize_slider.on_update def _(_) -> None: - for handle in self.cam_handles: - handle.scale = self.camsize_slider.value - handle.line_thickness = 0.03 * handle.scale + self._regenerate_cameras() @self.downsample_slider.on_update def _(_) -> None: @@ -558,8 +579,7 @@ def _(_) -> None: if self.show_camera: self._regenerate_cameras() else: - for handle in self.cam_handles: - handle.visible = False + self._clear_active_cameras() @self.vis_threshold_slider.on_update def _(_) -> None: @@ -571,43 +591,199 @@ def _(_) -> None: self._regenerate_cameras() def _regenerate_point_clouds(self): - """Regenerate all point clouds with current settings.""" - if not hasattr(self, 'frame_nodes'): - return + """Regenerate the active browser point cloud with current settings.""" + self._invalidate_point_cache() + self._refresh_active_scene() - for handle in self.pc_handles: - try: - handle.remove() - except (KeyError, AttributeError): - pass - self.pc_handles.clear() - self.vis_pts_list.clear() + def _regenerate_cameras(self): + """Regenerate active camera visualizations with current settings.""" + self._clear_active_cameras() + self._refresh_active_scene() + + def _invalidate_point_cache(self) -> None: + """Clear filtered point-cache entries after a render-setting change.""" + self._render_cache_token = None + self._frame_render_cache.clear() + + def _get_render_cache_token(self) -> Tuple[float, int]: + """Return the render-settings token used for cached filtered points.""" + return (float(self.vis_threshold), int(self.downsample_slider.value)) + + def _get_filtered_frame_geometry(self, step: int) -> Tuple[np.ndarray, np.ndarray]: + """Get filtered/downsampled point cloud data for one frame.""" + token = self._get_render_cache_token() + if token != self._render_cache_token: + self._render_cache_token = token + self._frame_render_cache.clear() + + cached = self._frame_render_cache.get(step) + if cached is not None: + return cached + + frame_payload = self._extract_frame_export_samples(step) + pred_pts = frame_payload["points"] + color = frame_payload["colors"] + self._frame_render_cache[step] = (pred_pts, color) + return pred_pts, color - for i, step in enumerate(self.all_steps): - pc = self.pcs[step]["pc"] - color = self.pcs[step]["color"] - conf = self.pcs[step]["conf"] - edge_color = self.pcs[step].get("edge_color", None) - - pred_pts, pc_color = self.parse_pc_data( - pc, color, conf, edge_color, set_border_color=True, - downsample_factor=self.downsample_slider.value - ) + @staticmethod + def _compute_frame_normals(pc: np.ndarray) -> np.ndarray: + """Estimate per-point normals from a structured point-map frame.""" + pc = np.asarray(pc, dtype=np.float32) + if pc.ndim != 3 or pc.shape[-1] != 3 or pc.size == 0: + return np.zeros_like(pc, dtype=np.float32) + + right = np.roll(pc, -1, axis=1) - pc + left = pc - np.roll(pc, 1, axis=1) + down = np.roll(pc, -1, axis=0) - pc + up = pc - np.roll(pc, 1, axis=0) + normals = np.cross(right + left, down + up) + + invalid = ~np.isfinite(pc).all(axis=2) + invalid |= np.roll(invalid, 1, axis=0) + invalid |= np.roll(invalid, -1, axis=0) + invalid |= np.roll(invalid, 1, axis=1) + invalid |= np.roll(invalid, -1, axis=1) + normals[invalid] = 0.0 + normals[[0, -1], :, :] = 0.0 + normals[:, [0, -1], :] = 0.0 + + lengths = np.linalg.norm(normals, axis=2, keepdims=True) + valid = lengths[..., 0] > 1e-8 + normals[valid] /= lengths[valid] + normals[~valid] = 0.0 + return normals.astype(np.float32, copy=False) + + def _extract_frame_export_samples(self, step: int) -> Dict[str, np.ndarray]: + """Collect filtered export samples plus metadata for one frame.""" + frame = self.pcs[step] + pc = np.asarray(frame["pc"], dtype=np.float32) + color = np.asarray(frame["color"]) + conf = frame["conf"] + + points = pc.reshape(-1, 3) + normals = self._compute_frame_normals(pc).reshape(-1, 3) + + if color.size == 0: + colors = np.zeros((len(points), 3), dtype=np.uint8) + elif np.isnan(color).any(): + colors = np.zeros((len(points), 3), dtype=np.float32) + colors[:, 2] = 1.0 + else: + colors = color.reshape(-1, 3) + colors = np.asarray(colors) + if colors.dtype != np.uint8: + colors = (np.clip(colors, 0.0, 1.0) * 255).astype(np.uint8) - self.vis_pts_list.append(pred_pts) - handle = self.server.scene.add_point_cloud( - name=f"/frames/{step}/pred_pts", - points=pred_pts, - colors=pc_color, - point_size=self.psize_slider.value, - ) - self.pc_handles.append(handle) + if conf is None: + confidence = np.ones((len(points),), dtype=np.float32) + else: + confidence = np.asarray(conf, dtype=np.float32).reshape(-1) - def _regenerate_cameras(self): - """Regenerate camera visualizations with current settings.""" - if not hasattr(self, 'frame_nodes'): - return + if pc.ndim >= 2: + pixel_row, pixel_col = np.indices(pc.shape[:2], dtype=np.int32) + pixel_row = pixel_row.reshape(-1) + pixel_col = pixel_col.reshape(-1) + else: + pixel_row = np.zeros((len(points),), dtype=np.int32) + pixel_col = np.zeros((len(points),), dtype=np.int32) + + if len(points) == 0: + return { + "points": np.zeros((0, 3), dtype=np.float32), + "colors": np.zeros((0, 3), dtype=np.uint8), + "normals": np.zeros((0, 3), dtype=np.float32), + "confidence": np.zeros((0,), dtype=np.float32), + "frame_index": np.zeros((0,), dtype=np.int32), + "pixel_row": np.zeros((0,), dtype=np.int32), + "pixel_col": np.zeros((0,), dtype=np.int32), + } + + valid = np.isfinite(points).all(axis=1) + valid &= np.isfinite(confidence) + valid &= np.isfinite(normals).all(axis=1) + + points = points[valid] + colors = colors[valid] + normals = normals[valid] + confidence = confidence[valid] + pixel_row = pixel_row[valid] + pixel_col = pixel_col[valid] + + conf_mask = confidence > float(self.vis_threshold) + points = points[conf_mask] + colors = colors[conf_mask] + normals = normals[conf_mask] + confidence = confidence[conf_mask] + pixel_row = pixel_row[conf_mask] + pixel_col = pixel_col[conf_mask] + + downsample_factor = max(1, int(self.downsample_slider.value)) + if downsample_factor > 1 and len(points) > 0: + indices = np.arange(0, len(points), downsample_factor) + points = points[indices] + colors = colors[indices] + normals = normals[indices] + confidence = confidence[indices] + pixel_row = pixel_row[indices] + pixel_col = pixel_col[indices] + + frame_index = np.full((len(points),), int(step), dtype=np.int32) + return { + "points": np.asarray(points, dtype=np.float32), + "colors": np.asarray(colors, dtype=np.uint8), + "normals": np.asarray(normals, dtype=np.float32), + "confidence": np.asarray(confidence, dtype=np.float32), + "frame_index": frame_index, + "pixel_row": np.asarray(pixel_row, dtype=np.int32), + "pixel_col": np.asarray(pixel_col, dtype=np.int32), + } + def _get_visible_step_indices(self) -> List[int]: + """Get the frame indices currently visible in the browser scene.""" + if not hasattr(self, "gui_timestep"): + return [] + + current_idx = int(self.gui_timestep.value) + mode = getattr(self, "frame_visibility_mode", "current") + if mode == "current": + return [current_idx] + if mode == "all": + retained_end = max(current_idx, int(self.max_drawn_frame_idx)) + return list(range(0, retained_end + 1)) + + history = int(self.gui_history_frames.value) if hasattr(self, "gui_history_frames") else 0 + start_idx = max(0, current_idx - history) + return list(range(start_idx, current_idx + 1)) + + def _set_frame_visibility_mode(self, mode: str) -> None: + """Switch between current-frame, cumulative, and recent-history rendering.""" + self.frame_visibility_mode = mode + if mode == "all" and hasattr(self, "gui_timestep"): + self.max_drawn_frame_idx = int(self.gui_timestep.value) + if hasattr(self, "gui_history_frames"): + self.gui_history_frames.disabled = mode != "recent" + if hasattr(self, "gui_timestep"): + self.update_frame_visibility() + + def _limit_rendered_points( + self, + points: np.ndarray, + colors: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + """Bound the browser payload to keep the UI responsive.""" + if not hasattr(self, "gui_max_render_points"): + return points, colors + + max_points = int(self.gui_max_render_points.value) + if max_points <= 0 or len(points) <= max_points: + return points, colors + + stride = max(1, int(np.ceil(len(points) / max_points))) + return points[::stride], colors[::stride] + + def _clear_active_cameras(self) -> None: + """Remove the currently rendered camera handles from the scene.""" for handle in self.cam_handles: try: handle.remove() @@ -615,77 +791,306 @@ def _regenerate_cameras(self): pass self.cam_handles.clear() - if self.show_camera: - downsample_factor = int(self.camera_downsample_slider.value) - for i, step in enumerate(self.all_steps): - if i % downsample_factor == 0: - self.add_camera(step) - - def _export_glb(self): - """Export current filtered point clouds and cameras as a GLB file.""" - try: - import trimesh - except ImportError: - self.glb_status.value = "Error: pip install trimesh" + def _refresh_active_scene(self) -> None: + """Refresh the bounded working set rendered in the browser.""" + if not hasattr(self, "gui_timestep"): return - self.glb_status.value = "Collecting points..." - print("Exporting GLB...") + visible_indices = self._get_visible_step_indices() + visible_steps = [self.all_steps[i] for i in visible_indices] - # Collect all currently visible, filtered points and colors all_points = [] all_colors = [] - for step in self.all_steps: - pc = self.pcs[step]["pc"] - color = self.pcs[step]["color"] - conf = self.pcs[step]["conf"] - edge_color = self.pcs[step].get("edge_color", None) - - pts, cols = self.parse_pc_data( - pc, color, conf, edge_color, set_border_color=False, - downsample_factor=self.downsample_slider.value, - ) - if len(pts) > 0: - all_points.append(pts) - if cols.dtype != np.uint8: - cols = (np.clip(cols, 0, 1) * 255).astype(np.uint8) - all_colors.append(cols) + for step in visible_steps: + pts, cols = self._get_filtered_frame_geometry(step) + if len(pts) == 0: + continue + all_points.append(pts) + all_colors.append(cols) - if not all_points: - self.glb_status.value = "Error: no points to export" - return + if all_points: + points = np.concatenate(all_points, axis=0) + colors = np.concatenate(all_colors, axis=0) + points, colors = self._limit_rendered_points(points, colors) + else: + points = np.zeros((0, 3), dtype=np.float32) + colors = np.zeros((0, 3), dtype=np.uint8) + + with self.server.atomic(): + if self.active_point_handle is None: + self.active_point_handle = self.server.scene.add_point_cloud( + name="/active/points", + points=points, + colors=colors, + point_size=self.psize_slider.value, + ) + else: + self.active_point_handle.points = points + self.active_point_handle.colors = colors + self.active_point_handle.point_size = self.psize_slider.value + self.active_point_handle.visible = len(points) > 0 - vertices = np.concatenate(all_points, axis=0) - colors_rgb = np.concatenate(all_colors, axis=0) + self._clear_active_cameras() + if not self.show_camera or self.cam_dict is None: + return - # --- Color enhancement --- - colors_float = colors_rgb.astype(np.float32) / 255.0 + downsample_factor = max(1, int(self.camera_downsample_slider.value)) + current_step = self.all_steps[int(self.gui_timestep.value)] + for offset, step in enumerate(visible_steps): + if step != current_step and (offset % downsample_factor) != 0: + continue + self.add_camera(step) - sat_boost = self.glb_saturation_slider.value + def _apply_export_color_adjustments( + self, + colors_rgb: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + """Apply shared Export 3D color/opacity controls.""" + colors_float = np.asarray(colors_rgb, dtype=np.float32) / 255.0 + + sat_boost = float(self.glb_saturation_slider.value) if sat_boost != 1.0: gray = colors_float.mean(axis=1, keepdims=True) colors_float = gray + sat_boost * (colors_float - gray) - bri_boost = self.glb_brightness_slider.value + bri_boost = float(self.glb_brightness_slider.value) if bri_boost != 1.0: colors_float = colors_float * bri_boost - colors_float = np.clip(colors_float, 0.0, 1.0) - - # --- Opacity --- - # Simulate opacity by blending colors toward white (works in all viewers). - # For Spheres mode, also set true alpha for viewers that support it. - alpha = self.glb_opacity_slider.value + alpha = float(self.glb_opacity_slider.value) if alpha < 1.0: - bg = np.ones_like(colors_float) # white background + bg = np.ones_like(colors_float) colors_float = colors_float * alpha + bg * (1.0 - alpha) - colors_float = np.clip(colors_float, 0.0, 1.0) + colors_float = np.clip(colors_float, 0.0, 1.0) colors_u8 = (colors_float * 255).astype(np.uint8) + alpha_u8 = np.full( + (len(colors_u8),), + int(np.clip(round(alpha * 255), 0, 255)), + dtype=np.uint8, + ) + return colors_u8, alpha_u8 + + @staticmethod + def _slice_export_point_payload( + payload: Dict[str, np.ndarray], + indices: np.ndarray, + ) -> Dict[str, np.ndarray]: + """Slice every per-point array in an export payload consistently.""" + point_count = len(payload["points"]) + sliced: Dict[str, np.ndarray] = {} + for key, values in payload.items(): + if isinstance(values, np.ndarray) and len(values) == point_count: + sliced[key] = values[indices] + else: + sliced[key] = values + return sliced + + def _apply_export_mode_to_point_payload( + self, + payload: Dict[str, np.ndarray], + ) -> Dict[str, np.ndarray]: + """Apply Export 3D mode-specific point sampling/radius controls.""" + export_mode = self.glb_mode_dropdown.value + point_radius = 0.0 + updated_payload = dict(payload) + + if export_mode == "Spheres": + point_radius = float(self.glb_sphere_radius_slider.value) + max_pts = max(1, int(self.glb_max_sphere_pts_slider.value)) + if len(updated_payload["points"]) > max_pts: + indices = np.linspace( + 0, + len(updated_payload["points"]) - 1, + max_pts, + dtype=np.int64, + ) + indices = np.unique(indices) + updated_payload = self._slice_export_point_payload(updated_payload, indices) + + updated_payload["render_radius"] = np.full( + (len(updated_payload["points"]),), + point_radius, + dtype=np.float32, + ) + return updated_payload + + def _get_scene_alignment_extrinsics(self) -> np.ndarray: + """Build the reference extrinsics array used by GLB/PLY alignment.""" + if self.cam_dict is None or len(self.all_steps) == 0: + return np.zeros((0, 4, 4), dtype=np.float32) + + step0 = self.all_steps[0] + R0 = self.cam_dict["R"][step0] if "R" in self.cam_dict else np.eye(3) + t0 = self.cam_dict["t"][step0] if "t" in self.cam_dict else np.zeros(3) + c2w_0 = np.eye(4, dtype=np.float32) + c2w_0[:3, :3] = R0 + c2w_0[:3, 3] = t0 + w2c_0 = np.linalg.inv(c2w_0) + return np.expand_dims(w2c_0.astype(np.float32), 0) + + def _build_camera_export_payload(self) -> Optional[Dict[str, np.ndarray]]: + """Prepare aligned camera-pose metadata for companion PLY export.""" + if self.cam_dict is None or len(self.all_steps) == 0: + return None + + from lingbot_map.vis.glb_export import get_scene_alignment_transform + + num_cameras = len(self.all_steps) + camera_to_world = np.repeat(np.eye(4, dtype=np.float32)[None], num_cameras, axis=0) + focal = np.ones((num_cameras,), dtype=np.float32) + principal_x = np.ones((num_cameras,), dtype=np.float32) + principal_y = np.ones((num_cameras,), dtype=np.float32) + + for i, step in enumerate(self.all_steps): + R = self.cam_dict["R"][step] if "R" in self.cam_dict else np.eye(3) + t = self.cam_dict["t"][step] if "t" in self.cam_dict else np.zeros(3) + pp = self.cam_dict["pp"][step] if "pp" in self.cam_dict else (1.0, 1.0) + focal[i] = self.cam_dict["focal"][step] if "focal" in self.cam_dict else 1.0 + principal_x[i] = pp[0] + principal_y[i] = pp[1] + camera_to_world[i, :3, :3] = R + camera_to_world[i, :3, 3] = t + + alignment_transform = get_scene_alignment_transform( + self._get_scene_alignment_extrinsics() + ) + aligned_camera_to_world = np.einsum( + "ij,njk->nik", alignment_transform, camera_to_world + ) + positions = aligned_camera_to_world[:, :3, 3].astype(np.float32) + rotations = aligned_camera_to_world[:, :3, :3].astype(np.float32) + quaternions = np.stack( + [np.asarray(tf.SO3.from_matrix(R).wxyz, dtype=np.float32) for R in rotations], + axis=0, + ) + forward = rotations[:, :, 2].astype(np.float32) + up = (-rotations[:, :, 1]).astype(np.float32) + right = rotations[:, :, 0].astype(np.float32) + + trajectory_distance = np.zeros((num_cameras,), dtype=np.float32) + if num_cameras > 1: + segment_lengths = np.linalg.norm(np.diff(positions, axis=0), axis=1).astype(np.float32) + trajectory_distance[1:] = np.cumsum(segment_lengths) + if num_cameras > 1 and trajectory_distance[-1] > 1e-8: + trajectory_u = trajectory_distance / trajectory_distance[-1] + else: + trajectory_u = np.zeros((num_cameras,), dtype=np.float32) + + aspect = np.where( + np.abs(principal_y) > 1e-6, + principal_x / principal_y, + 1.0, + ).astype(np.float32) + fov = (2.0 * np.arctan(principal_x / np.maximum(focal, 1e-6))).astype(np.float32) + camera_colors = (self.camera_colors[:, :3] * 255).astype(np.uint8) + + return { + "frame_index": np.asarray(self.all_steps, dtype=np.int32), + "tx": positions[:, 0], + "ty": positions[:, 1], + "tz": positions[:, 2], + "forward_x": forward[:, 0], + "forward_y": forward[:, 1], + "forward_z": forward[:, 2], + "up_x": up[:, 0], + "up_y": up[:, 1], + "up_z": up[:, 2], + "right_x": right[:, 0], + "right_y": right[:, 1], + "right_z": right[:, 2], + "qw": quaternions[:, 0], + "qx": quaternions[:, 1], + "qy": quaternions[:, 2], + "qz": quaternions[:, 3], + "r00": rotations[:, 0, 0], + "r01": rotations[:, 0, 1], + "r02": rotations[:, 0, 2], + "r10": rotations[:, 1, 0], + "r11": rotations[:, 1, 1], + "r12": rotations[:, 1, 2], + "r20": rotations[:, 2, 0], + "r21": rotations[:, 2, 1], + "r22": rotations[:, 2, 2], + "focal": focal, + "principal_x": principal_x, + "principal_y": principal_y, + "fov": fov, + "aspect": aspect, + "trajectory_u": trajectory_u.astype(np.float32), + "trajectory_distance": trajectory_distance, + "color_r": camera_colors[:, 0], + "color_g": camera_colors[:, 1], + "color_b": camera_colors[:, 2], + "display_scale": np.full( + (num_cameras,), + float(self.glb_cam_scale_slider.value), + dtype=np.float32, + ), + "frustum_thickness": np.full( + (num_cameras,), + float(self.glb_frustum_thickness_slider.value), + dtype=np.float32, + ), + "trajectory_radius": np.full( + (num_cameras,), + float(self.glb_trajectory_radius_slider.value), + dtype=np.float32, + ), + } + + @staticmethod + def _get_companion_trajectory_ply_path(output_path: str) -> str: + """Derive the companion trajectory PLY path from the main export path.""" + stem, ext = os.path.splitext(output_path) + if ext.lower() != ".ply": + ext = ".ply" + return f"{stem}_trajectory{ext}" + + def _build_export_comments(self) -> List[str]: + """Serialize Export 3D UI state into PLY header comments.""" + return [ + "generated_by lingbot-map PointCloudViewer", + f"export_mode {self.glb_mode_dropdown.value}", + f"include_cameras {int(self.glb_show_cam_checkbox.value)}", + f"show_trajectory {int(self.glb_trajectory_checkbox.value)}", + f"camera_scale {float(self.glb_cam_scale_slider.value):.6f}", + f"frustum_thickness {float(self.glb_frustum_thickness_slider.value):.6f}", + f"trajectory_radius {float(self.glb_trajectory_radius_slider.value):.6f}", + f"sphere_radius {float(self.glb_sphere_radius_slider.value):.6f}", + f"max_sphere_points {int(self.glb_max_sphere_pts_slider.value)}", + f"opacity {float(self.glb_opacity_slider.value):.6f}", + f"saturation_boost {float(self.glb_saturation_slider.value):.6f}", + f"brightness_boost {float(self.glb_brightness_slider.value):.6f}", + f"vis_threshold {float(self.vis_threshold):.6f}", + f"downsample_factor {int(self.downsample_slider.value)}", + ] + + def _export_glb(self): + """Export current filtered point clouds and cameras as a GLB file.""" + try: + import trimesh + except ImportError: + self.glb_status.value = "Error: pip install trimesh" + return + + self.glb_status.value = "Collecting points..." + print("Exporting GLB...") + + payload = self._collect_export_point_payload() + if len(payload["points"]) == 0: + self.glb_status.value = "Error: no points to export" + return + payload = self._apply_export_mode_to_point_payload(payload) + vertices = payload["points"] + colors_rgb = payload["colors"] + + colors_u8, alpha_u8 = self._apply_export_color_adjustments(colors_rgb) + alpha = float(self.glb_opacity_slider.value) colors_rgba = np.concatenate([ colors_u8, - np.full((len(colors_u8), 1), int(alpha * 255), dtype=np.uint8), - ], axis=1) # (N, 4) + alpha_u8[:, None], + ], axis=1) # Compute scene scale for camera sizing lo = np.percentile(vertices, 5, axis=0) @@ -698,15 +1103,7 @@ def _export_glb(self): export_mode = self.glb_mode_dropdown.value if export_mode == "Spheres": self.glb_status.value = "Building spheres..." - max_pts = int(self.glb_max_sphere_pts_slider.value) - radius = self.glb_sphere_radius_slider.value - - # Subsample if too many points - if len(vertices) > max_pts: - idx = np.random.choice(len(vertices), max_pts, replace=False) - idx.sort() - vertices = vertices[idx] - colors_rgba = colors_rgba[idx] + radius = float(self.glb_sphere_radius_slider.value) sphere_template = trimesh.creation.icosphere(subdivisions=1, radius=radius) n_verts_per = len(sphere_template.vertices) @@ -724,10 +1121,11 @@ def _export_glb(self): all_face_colors[f_off:f_off + n_faces_per] = rgba mesh = trimesh.Trimesh(vertices=all_verts, faces=all_faces) - mesh.visual.face_colors = all_face_colors + mesh_visual = cast(Any, mesh.visual) + mesh_visual.face_colors = all_face_colors # Enable alpha blending in glTF material for true transparency if alpha < 1.0: - mesh.visual.material.alphaMode = 'BLEND' + mesh_visual.material.alphaMode = "BLEND" scene_3d.add_geometry(mesh) print(f"Spheres mode: {len(vertices):,} spheres, {len(all_faces):,} faces") else: @@ -755,7 +1153,11 @@ def _export_glb(self): cam_positions.append(np.array(t, dtype=np.float64)) rgba_c = colormap(i / max(num_cameras - 1, 1)) - cam_color = tuple(int(255 * x) for x in rgba_c[:3]) + cam_color = ( + int(255 * rgba_c[0]), + int(255 * rgba_c[1]), + int(255 * rgba_c[2]), + ) integrate_camera_into_scene( scene_3d, c2w, cam_color, effective_cam_scale, @@ -773,16 +1175,9 @@ def _export_glb(self): scene_3d.add_geometry(traj_mesh) # Align scene using first camera extrinsic - if self.cam_dict is not None and len(self.all_steps) > 0: + extrinsics = self._get_scene_alignment_extrinsics() + if len(extrinsics) > 0: from lingbot_map.vis.glb_export import apply_scene_alignment - step0 = self.all_steps[0] - R0 = self.cam_dict["R"][step0] if "R" in self.cam_dict else np.eye(3) - t0 = self.cam_dict["t"][step0] if "t" in self.cam_dict else np.zeros(3) - c2w_0 = np.eye(4) - c2w_0[:3, :3] = R0 - c2w_0[:3, 3] = t0 - w2c_0 = np.linalg.inv(c2w_0) - extrinsics = np.expand_dims(w2c_0, 0) scene_3d = apply_scene_alignment(scene_3d, extrinsics) output_path = self.glb_output_path.value @@ -793,6 +1188,145 @@ def _export_glb(self): self.glb_status.value = f"Saved: {output_path} ({n_pts:,} {mode_str})" print(f"GLB exported to {output_path} ({n_pts:,} {mode_str})") + def _export_ply(self): + """Export the filtered point cloud, plus a companion trajectory PLY.""" + from lingbot_map.vis.glb_export import ( + apply_scene_alignment_to_directions, + apply_scene_alignment_to_vertices, + save_point_cloud_to_ply, + ) + + self.ply_status.value = "Collecting points..." + print("Exporting PLY...") + + payload = self._collect_export_point_payload() + if len(payload["points"]) == 0: + self.ply_status.value = "Error: no points to export" + return + payload = self._apply_export_mode_to_point_payload(payload) + + colors_u8, alpha_u8 = self._apply_export_color_adjustments(payload["colors"]) + vertices = payload["points"] + normals = payload["normals"] + extrinsics = self._get_scene_alignment_extrinsics() + if len(extrinsics) > 0: + vertices = apply_scene_alignment_to_vertices(vertices, extrinsics) + normals = apply_scene_alignment_to_directions(normals, extrinsics) + + camera_payload = self._build_camera_export_payload() + vertex_properties: Dict[str, np.ndarray] = { + "confidence": payload["confidence"], + "frame_index": payload["frame_index"], + "pixel_row": payload["pixel_row"], + "pixel_col": payload["pixel_col"], + "render_radius": payload["render_radius"], + } + + output_path = self.ply_output_path.value + save_point_cloud_to_ply( + vertices, + colors_u8, + output_path, + normals=normals, + alpha=alpha_u8, + extra_vertex_properties=vertex_properties, + comments=self._build_export_comments(), + ) + + saved_trajectory_path = None + saved_trajectory_poses = 0 + if self.glb_show_cam_checkbox.value and camera_payload is not None: + trajectory_path = self._get_companion_trajectory_ply_path(output_path) + trajectory_points = np.stack( + [camera_payload["tx"], camera_payload["ty"], camera_payload["tz"]], + axis=1, + ).astype(np.float32) + trajectory_colors = np.stack( + [camera_payload["color_r"], camera_payload["color_g"], camera_payload["color_b"]], + axis=1, + ).astype(np.uint8) + trajectory_normals = np.stack( + [ + camera_payload["forward_x"], + camera_payload["forward_y"], + camera_payload["forward_z"], + ], + axis=1, + ).astype(np.float32) + trajectory_vertex_properties = { + key: value + for key, value in camera_payload.items() + if key not in {"tx", "ty", "tz", "color_r", "color_g", "color_b", "forward_x", "forward_y", "forward_z"} + } + trajectory_comments = self._build_export_comments() + [ + "trajectory_vertices one_per_camera_pose", + f"source_pointcloud {os.path.basename(output_path)}", + ] + save_point_cloud_to_ply( + trajectory_points, + trajectory_colors, + trajectory_path, + normals=trajectory_normals, + extra_vertex_properties=trajectory_vertex_properties, + comments=trajectory_comments, + ) + saved_trajectory_path = trajectory_path + saved_trajectory_poses = len(camera_payload["frame_index"]) + + if saved_trajectory_path is not None: + self.ply_status.value = ( + f"Saved: {output_path} ({len(vertices):,} points) + " + f"{saved_trajectory_path} ({saved_trajectory_poses:,} poses)" + ) + print( + f"PLY exported to {output_path} ({len(vertices):,} points) " + f"and {saved_trajectory_path} ({saved_trajectory_poses:,} poses)" + ) + else: + self.ply_status.value = f"Saved: {output_path} ({len(vertices):,} points)" + print(f"PLY exported to {output_path} ({len(vertices):,} points)") + + def _collect_export_point_payload(self) -> Dict[str, np.ndarray]: + """Collect the filtered point-cloud payload used by GLB and PLY export.""" + all_points = [] + all_colors = [] + all_normals = [] + all_confidence = [] + all_frame_indices = [] + all_pixel_rows = [] + all_pixel_cols = [] + for step in self.all_steps: + frame_payload = self._extract_frame_export_samples(step) + if len(frame_payload["points"]) == 0: + continue + all_points.append(frame_payload["points"]) + all_colors.append(frame_payload["colors"]) + all_normals.append(frame_payload["normals"]) + all_confidence.append(frame_payload["confidence"]) + all_frame_indices.append(frame_payload["frame_index"]) + all_pixel_rows.append(frame_payload["pixel_row"]) + all_pixel_cols.append(frame_payload["pixel_col"]) + + if not all_points: + return { + "points": np.zeros((0, 3), dtype=np.float32), + "colors": np.zeros((0, 3), dtype=np.uint8), + "normals": np.zeros((0, 3), dtype=np.float32), + "confidence": np.zeros((0,), dtype=np.float32), + "frame_index": np.zeros((0,), dtype=np.int32), + "pixel_row": np.zeros((0,), dtype=np.int32), + "pixel_col": np.zeros((0,), dtype=np.int32), + } + return { + "points": np.concatenate(all_points, axis=0), + "colors": np.concatenate(all_colors, axis=0), + "normals": np.concatenate(all_normals, axis=0), + "confidence": np.concatenate(all_confidence, axis=0), + "frame_index": np.concatenate(all_frame_indices, axis=0), + "pixel_row": np.concatenate(all_pixel_rows, axis=0), + "pixel_col": np.concatenate(all_pixel_cols, axis=0), + } + @staticmethod def _build_trajectory_tube(positions, radius, colormap, num_cameras): """Build a tube mesh following camera trajectory with per-segment color. @@ -852,15 +1386,8 @@ def _build_trajectory_tube(positions, radius, colormap, num_cameras): return trimesh.util.concatenate(segments) def update_frame_visibility(self): - """Show all frames up to the current timestep (or only the current one in 4D mode).""" - if not hasattr(self, 'frame_nodes') or not hasattr(self, 'gui_timestep'): - return - - current_timestep = self.gui_timestep.value - for i, frame_node in enumerate(self.frame_nodes): - frame_node.visible = ( - i <= current_timestep if not self.fourd else i == current_timestep - ) + """Refresh the active browser scene for the current timestep.""" + self._refresh_active_scene() def _move_to_camera(self, frame_idx: int, smooth: bool = True): """Move viewer camera to match reconstructed camera at given frame.""" @@ -1045,7 +1572,7 @@ def read_data(self, pc_list, color_list, conf_list, edge_color_list=None): normalized_indices = np.array(list(range(num_cameras))) / (num_cameras - 1) else: normalized_indices = np.array([0.0]) - cmap = cm.get_cmap('viridis') + cmap = colormaps.get_cmap('viridis') self.camera_colors = cmap(normalized_indices) return pcs, step_list @@ -1095,28 +1622,6 @@ def parse_pc_data( return pred_pts, color - def add_pc(self, step): - """Add point cloud for a frame.""" - pc = self.pcs[step]["pc"] - color = self.pcs[step]["color"] - conf = self.pcs[step]["conf"] - edge_color = self.pcs[step].get("edge_color", None) - - pred_pts, color = self.parse_pc_data( - pc, color, conf, edge_color, set_border_color=True, - downsample_factor=self.downsample_slider.value - ) - - self.vis_pts_list.append(pred_pts) - self.pc_handles.append( - self.server.scene.add_point_cloud( - name=f"/frames/{step}/pred_pts", - points=pred_pts, - colors=color, - point_size=self.psize_slider.value, - ) - ) - def add_camera(self, step): """Add camera visualization for a frame.""" cam = self.cam_dict @@ -1128,28 +1633,18 @@ def add_camera(self, step): q = tf.SO3.from_matrix(R).wxyz fov = 2 * np.arctan(pp[0] / focal) aspect = pp[0] / pp[1] - self.traj_list.append((q, t)) step_index = self.all_steps.index(step) if step in self.all_steps else 0 camera_color = self.camera_colors[step_index] camera_color_rgb = tuple((camera_color[:3] * 255).astype(int)) - self.server.scene.add_frame( - f"/frames/{step}/camera_frame", - wxyz=q, - position=t, - axes_length=0.05, - axes_radius=0.002, - origin_radius=0.002, - ) - frustum_handle = self.server.scene.add_camera_frustum( - name=f"/frames/{step}/camera", + name=f"/active/cameras/{step}", fov=fov, aspect=aspect, wxyz=q, position=t, - scale=0.03, + scale=self.camsize_slider.value, color=camera_color_rgb, ) @@ -1172,9 +1667,25 @@ def animate(self): ) gui_next_frame = self.server.gui.add_button("Next Step", disabled=False) gui_prev_frame = self.server.gui.add_button("Prev Step", disabled=False) - gui_playing = self.server.gui.add_checkbox("Playing", True) + gui_playing = self.server.gui.add_checkbox("Playing", False) gui_framerate = self.server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=20) gui_framerate_options = self.server.gui.add_button_group("FPS options", ("10", "20", "30", "60")) + self.gui_history_frames = self.server.gui.add_slider( + "Visible History Frames", + min=0, + max=self.max_history_frames, + step=1, + initial_value=0, + hint="Frames retained in the browser scene. Capped to keep Chrome responsive.", + ) + self.gui_max_render_points = self.server.gui.add_slider( + "Max Browser Points", + min=50_000, + max=1_000_000, + step=50_000, + initial_value=self.default_max_render_points, + hint="Hard cap on rendered points after history aggregation.", + ) @gui_next_frame.on_click def _(_) -> None: @@ -1194,45 +1705,34 @@ def _(_) -> None: def _(_) -> None: gui_framerate.value = int(gui_framerate_options.value) - prev_timestep = self.gui_timestep.value - @self.gui_timestep.on_update def _(_) -> None: - nonlocal prev_timestep current_timestep = self.gui_timestep.value + if self.frame_visibility_mode == "all": + self.max_drawn_frame_idx = max(self.max_drawn_frame_idx, int(current_timestep)) if self.current_frame_image is not None and hasattr(self, 'original_images'): if current_timestep < len(self.original_images): self.current_frame_image.image = self.original_images[current_timestep] - with self.server.atomic(): - self.frame_nodes[current_timestep].visible = True - self.frame_nodes[prev_timestep].visible = False - self.server.flush() - - prev_timestep = current_timestep - - self.server.scene.add_frame("/frames", show_axes=False) - self.frame_nodes = [] - for i in range(self.num_frames): - step = self.all_steps[i] - self.frame_nodes.append( - self.server.scene.add_frame(f"/frames/{step}", show_axes=False) - ) - self.add_pc(step) - if self.show_camera: - downsample_factor = int(self.camera_downsample_slider.value) - if i % downsample_factor == 0: - self.add_camera(step) + self.update_frame_visibility() + + @self.gui_history_frames.on_update + def _(_) -> None: + self.update_frame_visibility() + + @self.gui_max_render_points.on_update + def _(_) -> None: + self.update_frame_visibility() + + # Start paused on the first frame so large sequences do not immediately + # trigger browser-side playback / cumulative point-cloud rendering. + self.gui_history_frames.disabled = True + self.update_frame_visibility() - prev_timestep = self.gui_timestep.value while True: - if self.on_replay: - pass - else: - if gui_playing.value: - self.gui_timestep.value = (self.gui_timestep.value + 1) % self.num_frames - self.update_frame_visibility() + if (not self.on_replay) and gui_playing.value: + self.gui_timestep.value = (self.gui_timestep.value + 1) % self.num_frames time.sleep(1.0 / gui_framerate.value) diff --git a/pyproject.toml b/pyproject.toml index b99ce1e..9a48703 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,19 +2,24 @@ name = "lingbot-map" version = "0.1.0" description = "LingBot-Map: Geometric Context Transformer for Streaming 3D Reconstruction" -requires-python = ">= 3.10" +requires-python = "~=3.12.0" dependencies = [ + "torch>=2.12.1", + "torchvision>=0.27.1", + "flashinfer-python", "Pillow", "huggingface_hub", "einops", "safetensors", "opencv-python", "tqdm", - "scipy" + "scipy", + "pip>=26.1.2", + "flashinfer-cubin>=0.6.13", ] [project.optional-dependencies] -vis = ["viser>=0.2.23", "trimesh", "matplotlib", "onnxruntime", "requests"] +vis = ["viser>=0.2.23", "trimesh", "matplotlib", "onnxruntime-gpu", "requests"] demo = ["lingbot-map[vis]"] [build-system]