From 2ae8d7b32278c0884b989167d4579da3965c43a2 Mon Sep 17 00:00:00 2001 From: Nynxz Date: Thu, 8 Jan 2026 14:28:03 +1000 Subject: [PATCH 1/9] checkpoint: pre-abstraction image-centric architecture --- __init__.py | 4 + nodes/EmbeddrLoRAStack.py | 59 +++++++++++ nodes/EmbeddrLoadVideo.py | 190 +++++++++++++++++++++++++++++++++++ ui/main.tsx | 1 + ui/nodes/EmbeddrLoRAStack.ts | 118 ++++++++++++++++++++++ 5 files changed, 372 insertions(+) create mode 100644 nodes/EmbeddrLoRAStack.py create mode 100644 nodes/EmbeddrLoadVideo.py create mode 100644 ui/nodes/EmbeddrLoRAStack.ts diff --git a/__init__.py b/__init__.py index bda3f04..a7bcb88 100644 --- a/__init__.py +++ b/__init__.py @@ -11,6 +11,8 @@ from .nodes.EmbeddrFindSimilar import EmbeddrFindSimilarNode from .nodes.EmbeddrFindSimilarText import EmbeddrFindSimilarTextNode from .nodes.EmbeddrUploadVideo import EmbeddrUploadVideo +from .nodes.EmbeddrLoadVideo import EmbeddrLoadVideoNode +from .nodes.EmbeddrLoRAStack import EmbeddrLoRAStack CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.json") @@ -93,6 +95,8 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: EmbeddrMergeIDsNode, EmbeddrSaveToFolderNode, EmbeddrUploadVideo, + EmbeddrLoadVideoNode, + EmbeddrLoRAStack, ] diff --git a/nodes/EmbeddrLoRAStack.py b/nodes/EmbeddrLoRAStack.py new file mode 100644 index 0000000..08f6229 --- /dev/null +++ b/nodes/EmbeddrLoRAStack.py @@ -0,0 +1,59 @@ +import folder_paths +from comfy_api.latest import io + + +class EmbeddrLoRAStack(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + loras = folder_paths.get_filename_list("loras") + + inputs = [ + io.Model.Input("model"), + ] + + # Add 1 slot initially, frontend will handle the rest + inputs.append(io.Combo.Input( + "lora_1", default="None", options=["None"] + loras)) + inputs.append(io.Float.Input( + "strength_1", default=1.0, min=-10.0, max=10.0, step=0.01)) + + return io.Schema( + node_id="embeddr.LoRAStack", + display_name="Embeddr LoRA Stack", + description="Apply multiple LoRAs to a model and clip.", + category="Embeddr", + inputs=inputs, + outputs=[ + io.Model.Output("model"), + ], + ) + + @classmethod + def execute(cls, model, **kwargs): + import comfy.sd + import comfy.utils + + out_model = model + + # Iterate over all provided lora inputs + # We expect keys like lora_1, strength_1, lora_2, strength_2, etc. + + # Find all lora keys + lora_keys = [k for k in kwargs.keys() if k.startswith("lora_")] + # Sort them by index + lora_keys.sort(key=lambda x: int(x.split("_")[1])) + + for key in lora_keys: + i = key.split("_")[1] + lora_name = kwargs.get(f"lora_{i}") + strength = kwargs.get(f"strength_{i}", 1.0) + + if lora_name and lora_name != "None": + lora_path = folder_paths.get_full_path("loras", lora_name) + if lora_path: + lora = comfy.utils.load_torch_file( + lora_path, safe_load=True) + out_model, out_clip = comfy.sd.load_lora_for_models( + out_model, None, lora, strength, strength) + + return io.NodeOutput(out_model) diff --git a/nodes/EmbeddrLoadVideo.py b/nodes/EmbeddrLoadVideo.py new file mode 100644 index 0000000..8c7c242 --- /dev/null +++ b/nodes/EmbeddrLoadVideo.py @@ -0,0 +1,190 @@ +import requests +import torch +import numpy as np +import cv2 +import tempfile +import os +import shutil +from comfy_api.latest import io, ui +from .utils import get_config + + +class EmbeddrLoadVideoNode(io.ComfyNode): + _cache = {} + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="embeddr.LoadVideo", + display_name="Embeddr Load Video", + description="Loads a video from Embeddr ID.", + category="Embeddr", + inputs=[ + io.String.Input("image_id", default=""), + io.Int.Input("frame_load_cap", default=0, min=0, max=100000, + step=1, tooltip="Stop loading after this many frames (0=all)"), + io.Int.Input("skip_first_frames", default=0, min=0, max=10000, + step=1, tooltip="Skip this many frames at the start"), + io.Int.Input("select_every_nth", default=1, min=1, + max=100, step=1, tooltip="Load every Nth frame"), + io.Int.Input("force_rate", default=0, min=0, max=120, + step=1, tooltip="Force playback FPS (0=original)"), + io.Int.Input("custom_width", default=0, min=0, max=4096, + step=8, tooltip="Resize width (0=original)"), + io.Int.Input("custom_height", default=0, min=0, max=4096, + step=8, tooltip="Resize height (0=original)"), + ], + outputs=[ + io.Image.Output("images"), + io.Int.Output("frame_count"), + # Removed experimental outputs for standard compatibility + # io.Output("audio"), + # io.Output("video_info"), + ], + ) + + @classmethod + def execute(cls, image_id, frame_load_cap, skip_first_frames, select_every_nth, force_rate, custom_width, custom_height): + if not image_id: + # Return empty + empty_image = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + return io.NodeOutput(empty_image, 0) + + # URL construction + try: + config = get_config() + endpoint = config.get("endpoint", "http://localhost:8003") + endpoint = endpoint.rstrip("/") + api_url = f"{endpoint}/api/v1/images/{image_id}/file" + except Exception: + print(f"[Embeddr] Could not get config for endpoint") + empty_image = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + return io.NodeOutput(empty_image, 0) + + # Download to temp file + temp_file_path = None + try: + # Check cache for path? For simplicity, we just download. + # In production, we should cache the file path if it's the same ID. + + # Stream download + with requests.get(api_url, stream=True) as r: + r.raise_for_status() + # Determine extension + content_type = r.headers.get('content-type', '') + ext = '.mp4' + if 'quicktime' in content_type: + ext = '.mov' + if 'webm' in content_type: + ext = '.webm' + + with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as f: + shutil.copyfileobj(r.raw, f) + temp_file_path = f.name + + # Open with CV2 + cap = cv2.VideoCapture(temp_file_path) + if not cap.isOpened(): + raise ValueError( + f"Could not open video file: {temp_file_path}") + + fps = cap.get(cv2.CAP_PROP_FPS) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + duration = total_frames / fps if fps > 0 else 0 + + # Determine target dimensions + target_w, target_h = width, height + if custom_width > 0: + target_w = custom_width + if custom_height == 0: + target_h = int(height * (custom_width / width)) + if custom_height > 0: + target_h = custom_height + if custom_width == 0: + target_w = int(width * (custom_height / height)) + + # Ensure divisible by 2 for some codecs if needed, but for simple tensor it's fine. + # ComfyUI usually expects divisible by 8 for VAEs? + # We won't enforce it unless user sets it. + + frames = [] + + current_frame = 0 + collected = 0 + + while True: + ret, frame = cap.read() + if not ret: + break + + # Skip logic + if current_frame < skip_first_frames: + current_frame += 1 + continue + + if (current_frame - skip_first_frames) % select_every_nth != 0: + current_frame += 1 + continue + + # Process frame + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + if target_w != width or target_h != height: + frame = cv2.resize( + frame, (target_w, target_h), interpolation=cv2.INTER_LINEAR) + + # Normalize to 0-1 + frame = frame.astype(np.float32) / 255.0 + frames.append(frame) + + collected += 1 + if frame_load_cap > 0 and collected >= frame_load_cap: + break + + current_frame += 1 + cap.release() + + if not frames: + empty_image = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + return io.NodeOutput(empty_image, 0) + + video_tensor = torch.from_numpy(np.stack(frames)) + # Shape is (B, H, W, C) + + final_fps = force_rate if force_rate > 0 else fps + + video_info = { + "source_fps": fps, + "source_frame_count": total_frames, + "source_duration": duration, + "source_width": width, + "source_height": height, + "loaded_fps": final_fps, + "loaded_frame_count": len(frames), + "loaded_width": target_w, + "loaded_height": target_h, + } + + # Audio - currently returning None/Empty as we don't have audio extraction logic + # To support audio properly we'd need audio libraries unavailable in minimal env. + audio = None + + return io.NodeOutput(video_tensor, len(frames)) + + except Exception as e: + print(f"[Embeddr] Error loading video: {e}") + empty_image = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + return io.NodeOutput(empty_image, 0) + + finally: + if temp_file_path and os.path.exists(temp_file_path): + try: + os.remove(temp_file_path) + except: + pass diff --git a/ui/main.tsx b/ui/main.tsx index b8dcd24..b7697c9 100644 --- a/ui/main.tsx +++ b/ui/main.tsx @@ -8,6 +8,7 @@ import EmbeddrPanel from "./components/panels/EmbeddrPanel.js"; import { GlobalDialog } from "./components/GlobalDialog"; import "./nodes/EmbeddrLoadImage.js"; import "./nodes/EmbeddrMergeIds.js"; +import "./nodes/EmbeddrLoRAStack.js"; // @ts-ignore import "./globals.css"; diff --git a/ui/nodes/EmbeddrLoRAStack.ts b/ui/nodes/EmbeddrLoRAStack.ts new file mode 100644 index 0000000..d3ed531 --- /dev/null +++ b/ui/nodes/EmbeddrLoRAStack.ts @@ -0,0 +1,118 @@ +import { app } from "../../../scripts/app.js"; + +const _ID = "embeddr.LoRAStack"; + +app.registerExtension({ + name: "embeddr.dynamic_lora_stack", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name !== _ID) { + return; + } + + const onNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + const me = onNodeCreated?.apply(this); + + // We expect lora_1 and strength_1 to exist from the python definition + const loraWidget = this.widgets.find((w) => w.name === "lora_1"); + if (!loraWidget) return me; + + // Store the options for creating new widgets + this.loraOptions = loraWidget.options.values; + + // Helper to add a new pair of widgets + this.addLoRAPair = (index) => { + const loraName = `lora_${index}`; + const strengthName = `strength_${index}`; + + // Add LoRA Combo + const w = this.addWidget( + "combo", + loraName, + "None", + (v) => { + this.updateWidgets(); + }, + { values: this.loraOptions } + ); + + // Add Strength Float + const s = this.addWidget("number", strengthName, 1.0, (v) => {}, { + min: -10.0, + max: 10.0, + step: 0.01, + precision: 2, + }); + + return { w, s }; + }; + + // Helper to update widgets based on values + this.updateWidgets = () => { + const widgets = this.widgets; + // Find all lora widgets + const loraWidgets = widgets.filter((w) => w.name.startsWith("lora_")); + + // Sort by index + loraWidgets.sort((a, b) => { + const idxA = parseInt(a.name.split("_")[1]); + const idxB = parseInt(b.name.split("_")[1]); + return idxA - idxB; + }); + + const lastWidget = loraWidgets[loraWidgets.length - 1]; + const lastIndex = parseInt(lastWidget.name.split("_")[1]); + + // If last widget has a value != "None", add a new one + if (lastWidget.value !== "None") { + this.addLoRAPair(lastIndex + 1); + } + + // If we have more than 1 widget, and the last TWO are "None", remove the last one + // Actually, we just want to ensure there is exactly one "None" at the end? + // Or maybe just ensure there is at least one "None" at the end. + // And remove any "None" that are not at the end? No, user might want to skip one. + + // Let's stick to: Always have one empty slot at the end. + // If the last one is filled, add one. + // If the second to last one is empty AND the last one is empty, remove the last one. + + if (loraWidgets.length > 1) { + const secondLastWidget = loraWidgets[loraWidgets.length - 2]; + if ( + lastWidget.value === "None" && + secondLastWidget.value === "None" + ) { + // Remove the last pair + // We need to remove both lora and strength widgets + const strengthName = `strength_${lastIndex}`; + + // Find index of widgets to remove + const wIndex = this.widgets.findIndex( + (w) => w.name === lastWidget.name + ); + if (wIndex > -1) this.widgets.splice(wIndex, 1); + + const sIndex = this.widgets.findIndex( + (w) => w.name === strengthName + ); + if (sIndex > -1) this.widgets.splice(sIndex, 1); + + // Resize node + this.onResize?.(this.size); + this.graph?.setDirtyCanvas(true); + } + } + }; + + // Hook into the first widget's callback + const originalCallback = loraWidget.callback; + loraWidget.callback = (v) => { + originalCallback?.(v); + this.updateWidgets(); + }; + + return me; + }; + }, +}); From 1d71fec77c1a7e60a0b82a5cc8e96cd0c4680fbb Mon Sep 17 00:00:00 2001 From: Nynxz Date: Tue, 27 Jan 2026 17:58:03 +1000 Subject: [PATCH 2/9] checkpoint: pre execution refine --- __init__.py | 18 +- nodes/EmbeddrFindCollection.py | 97 ++++++++++ nodes/EmbeddrFindSimilar.py | 111 ----------- nodes/EmbeddrFindSimilarArtifacts.py | 92 +++++++++ nodes/EmbeddrLoadArtifact.py | 141 ++++++++++++++ nodes/EmbeddrLoadArtifacts.py | 180 ++++++++++++++++++ nodes/EmbeddrLoadImage.py | 75 -------- nodes/EmbeddrLoadImages.py | 180 +++++++++++++----- nodes/EmbeddrUploadArtifact.py | 131 +++++++++++++ nodes/EmbeddrUploadImage.py | 162 ---------------- nodes/EmbeddrUploadVideo.py | 7 +- nodes/utils/__init__.py | 4 +- nodes/utils/api.py | 12 +- nodes/utils/config.py | 45 +++++ package.json | 1 + ui/components/GlobalDialog.tsx | 151 ++++++++++++--- ui/components/panels/EmbeddrPanel.tsx | 5 +- ui/components/panels/ImageDetails.tsx | 178 ++++++++++++++++- .../selectors/CollectionSelector.tsx | 130 +++++++++++++ ui/components/tabs/ExploreTab.tsx | 70 ++++--- ui/components/ui/ImageGrid.tsx | 10 +- ui/components/ui/SearchBar.tsx | 28 +-- ui/hooks/useEmbeddrApi.ts | 22 ++- ui/hooks/useEmbeddrCollections.ts | 122 ++++++++++++ ui/hooks/useEmbeddrImages.ts | 173 +++++++++++++---- ui/hooks/useEmbeddrSettings.ts | 2 +- ui/main.tsx | 4 +- ui/nodes/EmbeddrFindCollection.ts | 52 +++++ ui/nodes/EmbeddrLoadArtifact.ts | 51 +++++ ui/nodes/EmbeddrLoadImage.ts | 35 ---- ui/nodes/EmbeddrUploadArtifact.ts | 67 +++++++ ui/utils/nodeDragAndDrop.ts | 65 +++++++ 32 files changed, 1867 insertions(+), 554 deletions(-) create mode 100644 nodes/EmbeddrFindCollection.py delete mode 100644 nodes/EmbeddrFindSimilar.py create mode 100644 nodes/EmbeddrFindSimilarArtifacts.py create mode 100644 nodes/EmbeddrLoadArtifact.py create mode 100644 nodes/EmbeddrLoadArtifacts.py delete mode 100644 nodes/EmbeddrLoadImage.py create mode 100644 nodes/EmbeddrUploadArtifact.py delete mode 100644 nodes/EmbeddrUploadImage.py create mode 100644 ui/components/selectors/CollectionSelector.tsx create mode 100644 ui/hooks/useEmbeddrCollections.ts create mode 100644 ui/nodes/EmbeddrFindCollection.ts create mode 100644 ui/nodes/EmbeddrLoadArtifact.ts delete mode 100644 ui/nodes/EmbeddrLoadImage.ts create mode 100644 ui/nodes/EmbeddrUploadArtifact.ts create mode 100644 ui/utils/nodeDragAndDrop.ts diff --git a/__init__.py b/__init__.py index a7bcb88..c669560 100644 --- a/__init__.py +++ b/__init__.py @@ -4,15 +4,16 @@ from server import PromptServer from comfy_api.latest import ComfyExtension, io -from .nodes.EmbeddrUploadImage import EmbeddrSaveToFolderNode -from .nodes.EmbeddrLoadImage import EmbeddrLoadImageNode -from .nodes.EmbeddrLoadImages import EmbeddrLoadImagesNode +from .nodes.EmbeddrUploadArtifact import EmbeddrUploadArtifactNode +from .nodes.EmbeddrLoadArtifact import EmbeddrLoadArtifactNode +from .nodes.EmbeddrLoadArtifacts import EmbeddrLoadArtifactsNode from .nodes.EmbeddrMergeIDs import EmbeddrMergeIDsNode -from .nodes.EmbeddrFindSimilar import EmbeddrFindSimilarNode +from .nodes.EmbeddrFindSimilarArtifacts import EmbeddrFindSimilarArtifactsNode from .nodes.EmbeddrFindSimilarText import EmbeddrFindSimilarTextNode from .nodes.EmbeddrUploadVideo import EmbeddrUploadVideo from .nodes.EmbeddrLoadVideo import EmbeddrLoadVideoNode from .nodes.EmbeddrLoRAStack import EmbeddrLoRAStack +from .nodes.EmbeddrFindCollection import EmbeddrFindCollectionNode CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.json") @@ -88,15 +89,16 @@ async def get_config(request): class EmbeddrComfyUIExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - EmbeddrFindSimilarNode, + EmbeddrFindSimilarArtifactsNode, EmbeddrFindSimilarTextNode, - EmbeddrLoadImageNode, - EmbeddrLoadImagesNode, + EmbeddrLoadArtifactNode, + EmbeddrLoadArtifactsNode, EmbeddrMergeIDsNode, - EmbeddrSaveToFolderNode, + EmbeddrUploadArtifactNode, EmbeddrUploadVideo, EmbeddrLoadVideoNode, EmbeddrLoRAStack, + EmbeddrFindCollectionNode, ] diff --git a/nodes/EmbeddrFindCollection.py b/nodes/EmbeddrFindCollection.py new file mode 100644 index 0000000..871b2af --- /dev/null +++ b/nodes/EmbeddrFindCollection.py @@ -0,0 +1,97 @@ +import requests +from comfy_api.latest import io +from .utils import get_config + + +def Embeddr_Log(message: str): + print(f"[Embeddr] {message}") + + +class EmbeddrFindCollectionNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="embeddr.FindCollection", + display_name="Embeddr Find Collection", + category="Embeddr", + inputs=[ + io.String.Input("collection_name", default="", optional=True, + tooltip="Name to find (or create if missing)"), + io.String.Input("collection_id", default="", + tooltip="Direct Collection ID (overrides Name)"), + io.Boolean.Input("create_if_missing", default=True, + tooltip="Create collection if it doesn't exist (Only applies to Name)"), + ], + outputs=[ + io.String.Output("collection_id"), + ] + ) + + @classmethod + def execute(cls, collection_name, create_if_missing, collection_id=""): + Embeddr_Log( + f"EXECUTE FindCollection: name='{collection_name}', id='{collection_id}', create={create_if_missing}") + config = get_config() + base_url = config.get("embeddr_url") or config.get( + "endpoint") or "http://localhost:8003" + base_url = base_url.rstrip("/") + + # 1. Direct ID Priority + if collection_id and len(str(collection_id).strip()) > 10: + Embeddr_Log(f"Using Direct Collection ID: {collection_id}") + # Assume valid UUID if present + return io.NodeOutput(collection_id) + + try: + # 2. List Collections to Find by Name + # Note: Removed limit=1000 to avoid potential 422 if API doesn't support it + resp = requests.get(f"{base_url}/api/v2/collections") + # If 404, maybe endpoint is different. + if resp.status_code == 404: + # Fallback to V1? Or just fail. + pass + + collections = [] + if resp.status_code == 200: + data = resp.json() + if isinstance(data, list): + collections = data + elif isinstance(data, dict) and "items" in data: + collections = data["items"] + + found = None + if collection_name: + for c in collections: + # Case insensitive match? user might prefer exact. + # API returns 'label' usually, but maybe 'name' in some versions + label = c.get("label") or c.get("name") + if label and label.lower() == collection_name.lower(): + found = c + break + + if found: + Embeddr_Log( + f"Found Collection: {found.get('label', 'Unnamed')} ({found.get('id')})") + return io.NodeOutput(str(found.get("id"))) + + if collection_name and create_if_missing: + # Create + payload = {"label": collection_name, + "type_name": "collection:mix", + "uri": f"embeddr:///collections/{collection_name.lower().replace(' ', '_')}"} + resp = requests.post( + f"{base_url}/api/v2/collections", json=payload) + resp.raise_for_status() + new_col = resp.json() + Embeddr_Log( + f"Created Collection: {new_col.get('label')} ({new_col.get('id')})") + return io.NodeOutput(str(new_col.get("id"))) + + Embeddr_Log( + f"Collection '{collection_name}' not found and creation disabled.") + # Fallback to empty string + return io.NodeOutput("") + + except Exception as e: + Embeddr_Log(f"FindCollection error: {e}") + return io.NodeOutput("") diff --git a/nodes/EmbeddrFindSimilar.py b/nodes/EmbeddrFindSimilar.py deleted file mode 100644 index 9abd896..0000000 --- a/nodes/EmbeddrFindSimilar.py +++ /dev/null @@ -1,111 +0,0 @@ -import requests -import torch -import numpy as np -from PIL import Image -import io as pyio -from comfy_api.latest import io, ui -from .utils import get_config -from .utils.api import get_libraries, get_collections - - -class EmbeddrFindSimilarNode(io.ComfyNode): - @classmethod - def define_schema(cls) -> io.Schema: - # Fetch dynamic options - libraries = ["All"] + get_libraries() - collections = ["All"] + get_collections() - - return io.Schema( - node_id="embeddr.FindSimilar", - display_name="Embeddr Find Similar", - description="Finds similar images in your Embeddr library using an input image.", - category="Embeddr", - inputs=[ - io.Image.Input("image"), - io.Combo.Input("library", options=libraries, default="All"), - io.Combo.Input( - "collection", options=collections, default="All"), - io.Int.Input("limit", default=5, min=1, max=50), - io.Float.Input("threshold", default=0.0, min=0.0, - max=1.0, step=0.01, display_name="Min Score"), - ], - outputs=[ - io.Image.Output("images", is_output_list=True), - io.String.Output("embeddr_ids", is_output_list=True), - ], - ) - - @classmethod - def execute(cls, image, library="All", collection="All", limit=5, threshold=0.0): - config = get_config() - endpoint = config.get("endpoint", "http://localhost:8003") - api_url = endpoint.rstrip("/") + "/api/v1/images/search/image" - - # Prepare image (take first of batch) - img_array = (image[0].cpu().numpy() * 255).astype(np.uint8) - img = Image.fromarray(np.clip(img_array, 0, 255)) - - buf = pyio.BytesIO() - img.save(buf, format="PNG") - buf.seek(0) - - files = {"file": ("image.png", buf, "image/png")} - data = { - "limit": limit, - } - - # Parse IDs from "ID: Name" format - if library != "All": - try: - lib_id = int(library.split(":")[0]) - data["library_id"] = lib_id - except: - pass - - if collection != "All": - try: - col_id = int(collection.split(":")[0]) - data["collection_id"] = col_id - except: - pass - - try: - response = requests.post(api_url, files=files, data=data) - response.raise_for_status() - results = response.json() - items = results.get("items", []) - - if not items: - # Return empty - empty = torch.zeros( - (1, 64, 64, 3), dtype=torch.float32, device="cpu") - return io.NodeOutput(empty, "[]") - - # Load images - output_images = [] - output_ids = [] - - for item in items: - # Fetch image file - img_url = endpoint.rstrip( - "/") + f"/api/v1/images/{item['id']}/file" - img_resp = requests.get(img_url) - if img_resp.status_code == 200: - i = Image.open(pyio.BytesIO(img_resp.content)) - i = i.convert("RGB") - i = np.array(i).astype(np.float32) / 255.0 - output_images.append(torch.from_numpy(i)) - output_ids.append(str(item['id'])) - - if not output_images: - return io.NodeOutput([], []) - - # Return list of images (unsqueeze to add batch dimension 1) - final_images = [img.unsqueeze(0) for img in output_images] - return io.NodeOutput(final_images, output_ids) - - except Exception as e: - print(f"[Embeddr] Search failed: {e}") - empty = torch.zeros( - (1, 64, 64, 3), dtype=torch.float32, device="cpu") - return io.NodeOutput(empty, "[]") diff --git a/nodes/EmbeddrFindSimilarArtifacts.py b/nodes/EmbeddrFindSimilarArtifacts.py new file mode 100644 index 0000000..bf5c8bc --- /dev/null +++ b/nodes/EmbeddrFindSimilarArtifacts.py @@ -0,0 +1,92 @@ +import requests +import torch +import numpy as np +from PIL import Image +import io as pyio +from comfy_api.latest import io, ui +from .utils import get_config + + +class EmbeddrFindSimilarArtifactsNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="embeddr.FindSimilarArtifacts", + display_name="Embeddr Find Similar Artifacts (V2)", + description="Finds similar artifacts using an input image via V2 API.", + category="Embeddr", + inputs=[ + io.Image.Input("image"), + io.Int.Input("limit", default=5, min=1, max=50), + ], + outputs=[ + io.Image.Output("images", is_output_list=True), + io.String.Output("artifact_ids", is_output_list=True), + ], + ) + + @classmethod + def execute(cls, image, limit): + config = get_config() + base_url = config.get("embeddr_url") or config.get( + "endpoint") or "http://localhost:8003" + base_url = base_url.rstrip("/") + + # Endpoint in Plugin + api_url = f"{base_url}/api/v2/plugins/embeddr-comfyui/find_similar" + + # Prepare image (take first of batch for query) + img_array = (image[0].cpu().numpy() * 255).astype(np.uint8) + img = Image.fromarray(np.clip(img_array, 0, 255)) + + buf = pyio.BytesIO() + img.save(buf, format="PNG") + buf.seek(0) + + files = {"file": ("query.png", buf, "image/png")} + data = {"limit": limit} + + try: + # Upload & Search + response = requests.post(api_url, files=files, data=data) + response.raise_for_status() + results = response.json() + items = results.get("items", []) # List of objects {id, uri, ...} + + if not items: + empty = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + return io.NodeOutput([empty], ["-1"]) + + output_images = [] + output_ids = [] + + for item in items: + art_id = item.get("id") + content_url = f"{base_url}/api/v2/plugins/embeddr-comfyui/content/{art_id}" + + try: + img_resp = requests.get(content_url) + if img_resp.status_code == 200: + i = Image.open(pyio.BytesIO(img_resp.content)) + i = i.convert("RGB") + i_np = np.array(i).astype(np.float32) / 255.0 + # Add batch dim [1, H, W, C] + output_images.append(torch.from_numpy(i_np)[None,]) + output_ids.append(str(art_id)) + except Exception as e: + print( + f"Failed to fetch content for similar item {art_id}: {e}") + + if not output_images: + empty = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + return io.NodeOutput([empty], ["-1"]) + + return io.NodeOutput(output_images, output_ids) + + except Exception as e: + print(f"[Embeddr] FindSimilar Error: {e}") + empty = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + return io.NodeOutput([empty], ["-1"]) diff --git a/nodes/EmbeddrLoadArtifact.py b/nodes/EmbeddrLoadArtifact.py new file mode 100644 index 0000000..ce540a1 --- /dev/null +++ b/nodes/EmbeddrLoadArtifact.py @@ -0,0 +1,141 @@ +import requests +from urllib.parse import urljoin, urlparse +import logging +import os +import torch +import numpy as np +from PIL import Image, ImageOps +from io import BytesIO +from comfy_api.latest import io, ui +from .utils import get_config + + +class EmbeddrLoadArtifactNode(io.ComfyNode): + _cache = {} + + _logger = logging.getLogger("embeddr.comfyui.load_artifact") + + @classmethod + def _debug(cls, message: str, **fields): + if os.environ.get("EMBEDDR_COMFYUI_DEBUG", "").lower() not in { + "1", + "true", + "yes", + }: + return + try: + cls._logger.info("%s | %s", message, fields) + except Exception: + cls._logger.info("%s", message) + + @classmethod + def _resolve_artifact_url(cls, base_url: str, artifact_id: str): + resolve_url = f"{base_url}/api/v2/artifacts/{artifact_id}/resolve?variant=original&proxy=1" + cls._debug("resolving_artifact", artifact_id=artifact_id, + resolve_url=resolve_url) + res = requests.get(resolve_url) + res.raise_for_status() + data = res.json() + url = data.get("url") + headers = data.get("headers") or {} + if url and url.startswith("/"): + url = urljoin(base_url, url) + + if url and "/api/v2/artifacts/" in url and "/content" in url and "proxy=" not in url: + url = f"{url}?proxy=1" + + base_netloc = urlparse(base_url).netloc + url_netloc = urlparse(url).netloc if url else "" + if url and base_netloc and url_netloc and url_netloc != base_netloc: + proxy_url = f"{base_url}/api/v2/artifacts/{artifact_id}/content?proxy=1" + cls._debug( + "forcing_proxy_url", + artifact_id=artifact_id, + resolved_url=url, + proxy_url=proxy_url, + ) + return proxy_url, {} + cls._debug("resolved_artifact", artifact_id=artifact_id, + url=url, headers=headers) + return url, headers + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="embeddr.LoadArtifact", + display_name="Embeddr Load Artifact (V2)", + description="Loads an image/artifact from Embeddr by UUID.", + category="Embeddr", + inputs=[ + io.String.Input("artifact_id", default="", + tooltip="UUID of the artifact to load"), + io.Boolean.Input("use_cache", default=True) + ], + outputs=[ + io.Image.Output("image"), + io.Mask.Output("mask"), + io.String.Output("artifact_id_out"), + ], + ) + + @classmethod + def execute(cls, artifact_id, use_cache): + if not artifact_id: + # Return empty black image if no ID + empty_image = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + empty_mask = torch.zeros( + (1, 64, 64), dtype=torch.float32, device="cpu") + return io.NodeOutput(empty_image, empty_mask, "") + + if use_cache and artifact_id in cls._cache: + image, mask = cls._cache[artifact_id] + return io.NodeOutput(image, mask, artifact_id) + + try: + config = get_config() + base_url = config.get("embeddr_url") or config.get( + "endpoint") or "http://localhost:8003" + base_url = base_url.rstrip("/") + endpoint, content_headers = cls._resolve_artifact_url( + base_url, artifact_id + ) + + cls._debug( + "requesting_artifact_content", + artifact_id=artifact_id, + endpoint=endpoint, + ) + + response = requests.get(endpoint, headers=content_headers) + response.raise_for_status() + cls._debug( + "artifact_content_response", + status=response.status_code, + content_type=response.headers.get("content-type"), + length=response.headers.get("content-length"), + ) + img = Image.open(BytesIO(response.content)) + + img = ImageOps.exif_transpose(img) + image = img.convert("RGB") + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + + if 'A' in img.getbands(): + mask = np.array(img.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") + + if use_cache: + cls._cache[artifact_id] = (image, mask) + + return io.NodeOutput(image, mask, artifact_id, ui=ui.PreviewImage(image)) + + except Exception as e: + print(f"[Embeddr] Error loading artifact {artifact_id}: {e}") + cls._debug("artifact_load_failed", + artifact_id=artifact_id, error=str(e)) + # Raise error to stop workflow if loading fails + raise e diff --git a/nodes/EmbeddrLoadArtifacts.py b/nodes/EmbeddrLoadArtifacts.py new file mode 100644 index 0000000..a3a589b --- /dev/null +++ b/nodes/EmbeddrLoadArtifacts.py @@ -0,0 +1,180 @@ +import requests +import logging +import os +from urllib.parse import urljoin, urlparse +import torch +import numpy as np +from PIL import Image, ImageOps +from io import BytesIO +from comfy_api.latest import io, ui +from .utils import get_config +from .utils.api import get_collections + + +class EmbeddrLoadArtifactsNode(io.ComfyNode): + _cache = {} + + _logger = logging.getLogger("embeddr.comfyui.load_artifacts") + + @classmethod + def _debug(cls, message: str, **fields): + if os.environ.get("EMBEDDR_COMFYUI_DEBUG", "").lower() not in { + "1", + "true", + "yes", + }: + return + try: + cls._logger.info("%s | %s", message, fields) + except Exception: + cls._logger.info("%s", message) + + @classmethod + def _resolve_artifact_url(cls, base_url: str, artifact_id: str): + resolve_url = f"{base_url}/api/v2/artifacts/{artifact_id}/resolve?variant=original&proxy=1" + cls._debug("resolving_artifact", artifact_id=artifact_id, + resolve_url=resolve_url) + res = requests.get(resolve_url) + res.raise_for_status() + data = res.json() + url = data.get("url") + headers = data.get("headers") or {} + if url and url.startswith("/"): + url = urljoin(base_url, url) + + if url and "/api/v2/artifacts/" in url and "/content" in url and "proxy=" not in url: + url = f"{url}?proxy=1" + + base_netloc = urlparse(base_url).netloc + url_netloc = urlparse(url).netloc if url else "" + if url and base_netloc and url_netloc and url_netloc != base_netloc: + proxy_url = f"{base_url}/api/v2/artifacts/{artifact_id}/content?proxy=1" + cls._debug( + "forcing_proxy_url", + artifact_id=artifact_id, + resolved_url=url, + proxy_url=proxy_url, + ) + return proxy_url, {} + cls._debug("resolved_artifact", artifact_id=artifact_id, + url=url, headers=headers) + return url, headers + + @classmethod + def define_schema(cls) -> io.Schema: + collections = ["All"] + get_collections() + + return io.Schema( + node_id="embeddr.LoadArtifacts", + display_name="Embeddr Load Artifacts (V2)", + description="Loads generic artifacts (images) from Embeddr using V2 API.", + category="Embeddr", + inputs=[ + io.Combo.Input( + "collection", options=collections, default="All"), + io.Combo.Input("sort_by", options=[ + "newest", "random"], default="newest"), + io.Int.Input("limit", default=5, min=1, max=50), + io.Int.Input("seed", default=0, display_name="Random Seed"), + ], + outputs=[ + io.Image.Output("images", is_output_list=True), + io.String.Output("artifact_ids", is_output_list=True), + io.Mask.Output("masks", is_output_list=True), + ], + ) + + @classmethod + def execute(cls, collection, sort_by, limit, seed): + cache_key = (collection, sort_by, limit, seed) + if cache_key in cls._cache: + return cls._cache[cache_key] + + try: + config = get_config() + base_url = config.get("embeddr_url") or config.get( + "endpoint") or "http://localhost:8003" + base_url = base_url.rstrip("/") + + # List Artifacts + api_url = f"{base_url}/api/v2/artifacts/" + params = { + "limit": limit, + "type_name": "image", + "offset": 0 + } + + if collection != "All": + try: + col_id = collection.split(":")[0].strip() + params["collection_id"] = col_id + except: + pass + + if sort_by == "random": + params["sort"] = "random" + params["seed"] = seed # Pass seed if API supports it + else: + params["sort"] = "new" + + response = requests.get(api_url, params=params) + response.raise_for_status() + data = response.json() + items = data.get("items", []) + + if not items: + return cls._return_empty() + + images_list = [] + masks_list = [] + ids_list = [] + + for item in items: + art_id = item.get("id") + + content_url, content_headers = cls._resolve_artifact_url( + base_url, art_id) + try: + c_resp = requests.get(content_url, headers=content_headers) + c_resp.raise_for_status() + img = Image.open(BytesIO(c_resp.content)) + img = ImageOps.exif_transpose(img) + + i = img.convert("RGB") + i_np = np.array(i).astype(np.float32) / 255.0 + images_list.append(torch.from_numpy(i_np)) + + if 'A' in img.getbands(): + m_np = np.array(img.getchannel('A')).astype( + np.float32) / 255.0 + masks_list.append(1. - torch.from_numpy(m_np)) + else: + masks_list.append(torch.zeros( + (i_np.shape[0], i_np.shape[1]), dtype=torch.float32, device="cpu")) + + ids_list.append(str(art_id)) + except Exception as e: + print(f"Failed to load artifact {art_id}: {e}") + cls._debug("artifact_load_failed", + artifact_id=str(art_id), error=str(e)) + + if not images_list: + return cls._return_empty() + + final_images = [img[None,] for img in images_list] + + res = io.NodeOutput(final_images, ids_list, masks_list) + cls._cache[cache_key] = res + return res + + except Exception as e: + print(f"[Embeddr] LoadArtifacts error: {e}") + return cls._return_empty() + + @classmethod + def _return_empty(cls): + empty_image = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + empty_mask = torch.zeros( + (1, 64, 64), dtype=torch.float32, device="cpu") + return io.NodeOutput([empty_image], ["-1"], [empty_mask]) diff --git a/nodes/EmbeddrLoadImage.py b/nodes/EmbeddrLoadImage.py deleted file mode 100644 index 437b116..0000000 --- a/nodes/EmbeddrLoadImage.py +++ /dev/null @@ -1,75 +0,0 @@ -import requests -import torch -import numpy as np -from PIL import Image, ImageOps -from io import BytesIO -from comfy_api.latest import io, ui -from .utils import get_config - - -class EmbeddrLoadImageNode(io.ComfyNode): - _cache = {} - - @classmethod - def define_schema(cls) -> io.Schema: - return io.Schema( - node_id="embeddr.LoadImage", - display_name="Embeddr Load Image", - description="Loads an image from a Embeddr Image ID.", - category="Embeddr", - inputs=[ - io.String.Input("image_id", default=""), - ], - outputs=[ - io.Image.Output("image"), - io.Mask.Output("mask"), - io.String.Output("embeddr_id"), - ], - ) - - @classmethod - def execute(cls, image_id): - if not image_id: - # Return empty black image if no ID - empty_image = torch.zeros( - (1, 64, 64, 3), dtype=torch.float32, device="cpu") - empty_mask = torch.zeros( - (1, 64, 64), dtype=torch.float32, device="cpu") - return io.NodeOutput(empty_image, empty_mask, "") - - if image_id in cls._cache: - image, mask = cls._cache[image_id] - return io.NodeOutput(image, mask, image_id) - - try: - config = get_config() - endpoint = config.get("endpoint", "http://localhost:8003") - # Ensure endpoint doesn't end with slash - endpoint = endpoint.rstrip("/") - api_url = f"{endpoint}/api/v1/images/{image_id}/file" - - response = requests.get(api_url) - response.raise_for_status() - img = Image.open(BytesIO(response.content)) - - img = ImageOps.exif_transpose(img) - image = img.convert("RGB") - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - - if 'A' in img.getbands(): - mask = np.array(img.getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) - else: - mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") - - cls._cache[image_id] = (image, mask) - return io.NodeOutput(image, mask, image_id, ui=ui.PreviewImage(image)) - - except Exception as e: - print(f"[Embeddr] Error loading image {image_id}: {e}") - empty_image = torch.zeros( - (1, 64, 64, 3), dtype=torch.float32, device="cpu") - empty_mask = torch.zeros( - (1, 64, 64), dtype=torch.float32, device="cpu") - return io.NodeOutput(empty_image, empty_mask, "") diff --git a/nodes/EmbeddrLoadImages.py b/nodes/EmbeddrLoadImages.py index 555855e..739b4b7 100644 --- a/nodes/EmbeddrLoadImages.py +++ b/nodes/EmbeddrLoadImages.py @@ -1,4 +1,7 @@ import requests +from urllib.parse import urljoin, urlparse +import logging +import os import torch import numpy as np from PIL import Image, ImageOps @@ -12,15 +15,64 @@ class EmbeddrLoadImagesNode(io.ComfyNode): _cache = {} + _logger = logging.getLogger("embeddr.comfyui.load_images") + + @classmethod + def _debug(cls, message: str, **fields): + if os.environ.get("EMBEDDR_COMFYUI_DEBUG", "").lower() not in { + "1", + "true", + "yes", + }: + return + try: + cls._logger.info("%s | %s", message, fields) + except Exception: + cls._logger.info("%s", message) + + @classmethod + def _resolve_artifact_url(cls, base_url: str, artifact_id: str): + resolve_url = f"{base_url}/api/v2/artifacts/{artifact_id}/resolve?variant=original&proxy=1" + cls._debug("resolving_artifact", artifact_id=artifact_id, + resolve_url=resolve_url) + res = requests.get(resolve_url) + res.raise_for_status() + data = res.json() + url = data.get("url") + headers = data.get("headers") or {} + if url and url.startswith("/"): + url = urljoin(base_url, url) + + if url and "/api/v2/artifacts/" in url and "/content" in url and "proxy=" not in url: + url = f"{url}?proxy=1" + + base_netloc = urlparse(base_url).netloc + url_netloc = urlparse(url).netloc if url else "" + if url and base_netloc and url_netloc and url_netloc != base_netloc: + proxy_url = f"{base_url}/api/v2/artifacts/{artifact_id}/content?proxy=1" + cls._debug( + "forcing_proxy_url", + artifact_id=artifact_id, + resolved_url=url, + proxy_url=proxy_url, + ) + return proxy_url, {} + cls._debug("resolved_artifact", artifact_id=artifact_id, + url=url, headers=headers) + return url, headers + @classmethod def define_schema(cls) -> io.Schema: + # Note: Dynamic fetching (get_collections) might need updates to V2 too + # but for now we focus on the execution logic collections = ["All"] + get_collections() + # Legacy libraries -> Folders/Collections? libraries = ["All"] + get_libraries() return io.Schema( node_id="embeddr.EmbeddrLoadImages", display_name="Embeddr Load Images", - description="Loads images from Embeddr with filtering and sorting.", + description="Loads images from Embeddr using V2 Artifacts API.", category="Embeddr", inputs=[ io.Combo.Input("library", options=libraries, default="All"), @@ -48,41 +100,44 @@ def execute(cls, library, collection, sort_by, limit, seed): try: config = get_config() - endpoint = config.get("endpoint", "http://localhost:8003") - api_base_url = endpoint.rstrip("/") + "/api/v1" + base_url = config.get("embeddr_url") or config.get( + "endpoint") or "http://localhost:8003" + base_url = base_url.rstrip("/") + + # V2 API: List Artifacts + api_url = f"{base_url}/api/v2/artifacts/" params = { "limit": limit, + "type_name": "image", # Filter for images + "offset": 0 } - # Parse Library ID - if library != "All": + # Parse Collection ID (UUID in V2?) + # Legacy logic assumed "123: Name" + if collection != "All": try: - lib_id = int(library.split(":")[0]) - params["library_id"] = lib_id + col_id = collection.split(":")[0].strip() + params["collection_id"] = col_id except: pass - # Parse Collection ID - if collection != "All": + # Legacy Libraries mapped to ? + # In V2 we might ignore library or treat as another collection filter + if library != "All": try: - col_id = int(collection.split(":")[0]) - params["collection_id"] = col_id + lib_id = library.split(":")[0].strip() + params["library_id"] = lib_id except: pass - # Handle Sort if sort_by == "random": params["sort"] = "random" - # Note: Server-side random might not respect seed, but we pass it just in case - # or we could implement client-side shuffle if we fetched more. - # For now, we rely on server-side random for efficiency. + params["seed"] = seed # Pass seed if API supports it else: params["sort"] = "new" - # Fetch images - url = f"{api_base_url}/images" - response = requests.get(url, params=params) + response = requests.get(api_url, params=params) response.raise_for_status() data = response.json() items = data.get("items", []) @@ -95,51 +150,90 @@ def execute(cls, library, collection, sort_by, limit, seed): ids_list = [] for item in items: - image_id = item.get("id") - if not image_id: + art_id = item.get("id") + if not art_id: continue - image_url = f"{api_base_url}/images/{image_id}/file" + # Fetch Content via Plugin Endpoint + # Uses the plugin endpoint we defined to get raw content + content_url, content_headers = cls._resolve_artifact_url( + base_url, art_id) + + cls._debug( + "requesting_artifact_content", + artifact_id=str(art_id), + content_url=content_url, + ) try: - img_resp = requests.get(image_url) + img_resp = requests.get( + content_url, headers=content_headers) img_resp.raise_for_status() + cls._debug( + "artifact_content_response", + artifact_id=str(art_id), + status=img_resp.status_code, + content_type=img_resp.headers.get("content-type"), + length=img_resp.headers.get("content-length"), + ) + img = Image.open(BytesIO(img_resp.content)) img = ImageOps.exif_transpose(img) - image = img.convert("RGB") - image_np = np.array(image).astype(np.float32) / 255.0 - # Add batch dimension: [1, H, W, C] - image_tensor = torch.from_numpy(image_np).unsqueeze(0) + # Convert to tensor conformant + i = img.convert("RGB") + i_np = np.array(i).astype(np.float32) / 255.0 + images_list.append(torch.from_numpy(i_np)) + # Mask if 'A' in img.getbands(): - mask_np = np.array(img.getchannel( - 'A')).astype(np.float32) / 255.0 - mask_tensor = 1. - torch.from_numpy(mask_np) - mask_tensor = mask_tensor.unsqueeze(0) # [1, H, W] + m_np = np.array(img.getchannel('A')).astype( + np.float32) / 255.0 + masks_list.append(1. - torch.from_numpy(m_np)) else: - mask_tensor = torch.zeros( - (1, img.height, img.width), dtype=torch.float32, device="cpu") + masks_list.append(torch.zeros( + (i_np.shape[0], i_np.shape[1]), dtype=torch.float32, device="cpu")) + + ids_list.append(str(art_id)) - images_list.append(image_tensor) - masks_list.append(mask_tensor) - ids_list.append(str(image_id)) except Exception as e: - print(f"[Embeddr] Failed to load image {image_id}: {e}") - continue + print(f"Failed to load artifact {art_id}: {e}") + cls._debug( + "artifact_load_failed", + artifact_id=str(art_id), + error=str(e), + ) if not images_list: return cls._return_empty() - # Return lists - result = io.NodeOutput(images_list, ids_list, masks_list) + # Stack images into batch [B, H, W, C]? + # Note: If images have different sizes, stacking might fail. + # ComfyUI usually expects same size for batch. + # If size differs, we might return list or resize. + # Since output is_output_list=True, likely returns list of individual tensors. + + # Correct return for output list + # "images" -> list of [H, W, C] or list of [1, H, W, C]? + # Usually list output means list of whatever single output is. + # Single Image output is [1, H, W, C]. + + # Ensure each is [1, H, W, C] + final_images = [img[None,] for img in images_list] + + # Cache result + result = io.NodeOutput(final_images, ids_list, masks_list) cls._cache[cache_key] = result return result except Exception as e: - print(f"[Embeddr] Error loading images: {e}") + print(f"[Embeddr] LoadImages Error: {e}") return cls._return_empty() - @staticmethod - def _return_empty(): - return io.NodeOutput([], [], []) + @classmethod + def _return_empty(cls): + empty_image = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + empty_mask = torch.zeros( + (1, 64, 64), dtype=torch.float32, device="cpu") + return io.NodeOutput([empty_image], ["-1"], [empty_mask]) diff --git a/nodes/EmbeddrUploadArtifact.py b/nodes/EmbeddrUploadArtifact.py new file mode 100644 index 0000000..c0cfccb --- /dev/null +++ b/nodes/EmbeddrUploadArtifact.py @@ -0,0 +1,131 @@ +import requests +import numpy as np +from PIL import Image +import io as pyio +import json +from comfy_api.latest import io, ui +from .utils import get_embeddr_base_url, get_upload_mode + + +def Embeddr_Log(message: str): + print(f"[Embeddr] {message}") + + +def normalize_list(value): + if not value: + return [] + + if isinstance(value, str): + items = [v.strip() for v in value.split(",")] + elif isinstance(value, list): + items = [str(v).strip() for v in value] + else: + raise TypeError(f"Unsupported type: {type(value)}") + + out = [] + seen = set() + for v in items: + if not v: + continue + lv = v.lower() + if lv in ("none", "null", "undefined"): + continue + if v not in seen: + seen.add(v) + out.append(v) + return out + + +class EmbeddrUploadArtifactNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="embeddr.UploadArtifact", + display_name="Embeddr Upload (V2)", + category="Embeddr", + is_output_node=True, + inputs=[ + io.Image.Input("image"), + io.String.Input("parent_ids", default="", optional=True, + tooltip="Comma separated parent artifact UUIDs"), + io.String.Input("collection_ids", default="", optional=True, + tooltip="Comma separated Collection UUIDs (for grouping)"), + io.String.Input("tags", default="generated,comfyui", + tooltip="Comma separated tags"), + io.Boolean.Input("trigger_automation", default=True, + tooltip="Trigger Auto-Analysis (Thumbnails, Embeddings, etc)"), + ], + outputs=[ + io.String.Output("artifact_ids"), + ] + ) + + @classmethod + def execute(cls, image, parent_ids, collection_ids, tags, trigger_automation): + base_url = get_embeddr_base_url() + upload_mode = get_upload_mode() + endpoint = f"{base_url}/api/v2/plugins/embeddr-comfyui/upload" + + results = [] + + if upload_mode in {"skip", "disabled", "off", "none"}: + Embeddr_Log( + "Upload disabled (EMBEDDR_UPLOAD_MODE). Skipping Embeddr upload." + ) + return io.NodeOutput("", ui=ui.PreviewImage(image)) + + if upload_mode in {"best_effort", "auto"}: + try: + health_url = f"{base_url}/api/v2/system/routes" + requests.get(health_url, timeout=2) + except Exception as e: + Embeddr_Log( + f"Embeddr backend unavailable ({e}); skipping upload." + ) + return io.NodeOutput("", ui=ui.PreviewImage(image)) + + # 'image' input is a batch tensor [B, H, W, C] + for batch_idx, img_tensor in enumerate(image): + try: + # Convert tensor to PIL + i = 255. * img_tensor.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + + # Save to buffer + img_byte_arr = pyio.BytesIO() + img.save(img_byte_arr, format='PNG') + img_byte_arr.seek(0) + + # Prepare Metadata + meta = { + "parent_ids": normalize_list(parent_ids), + "collection_ids": normalize_list(collection_ids), + "tags": normalize_list(tags), + "trigger_automation": trigger_automation, + "compute_embedding": trigger_automation, # Legacy Compat + "batch_index": batch_idx, + "confirm": True + } + + # Prepare multipart upload + files = {'file': (f'image_{batch_idx}.png', + img_byte_arr, 'image/png')} + data = {'metadata': json.dumps(meta)} + + # Post to Embeddr Core Plugin + response = requests.post(endpoint, files=files, data=data) + response.raise_for_status() + res_json = response.json() + + art_id = res_json.get("id") + results.append(str(art_id)) + Embeddr_Log(f"Uploaded Artifact: {art_id}") + + except Exception as e: + Embeddr_Log(f"Upload failed for batch {batch_idx}: {e}") + # We don't crash the whole node, but result might be partial + + result_str = ",".join(results) + + # Return IDs and UI Preview + return io.NodeOutput(result_str, ui=ui.PreviewImage(image)) diff --git a/nodes/EmbeddrUploadImage.py b/nodes/EmbeddrUploadImage.py deleted file mode 100644 index d471806..0000000 --- a/nodes/EmbeddrUploadImage.py +++ /dev/null @@ -1,162 +0,0 @@ -import folder_paths -from .utils.api import get_libraries, get_collections -import os -import json -import requests -import numpy as np -from PIL import Image -from PIL.PngImagePlugin import PngInfo -from comfy_api.latest import io, ui -from comfy_api.latest._io import _UIOutput, ComfyNode, FolderType -import io as pyio -import random -from .utils import get_config - - -def Embeddr_Log(message: str): - print(f"[Embeddr] {message}") - - -class EmbeddrImage(ui.PreviewImage): - def __init__(self, image: io.Image.Type, ids=None, animated: bool = False, cls: type[ComfyNode] = None, **kwargs): - super().__init__(image, animated, cls=cls) - self.extra = {} - if ids is not None: - self.extra["embeddr_ids"] = ids - - def as_dict(self): - d = { - "images": self.values, - "animated": (self.animated,) - } - d.update(self.extra) - return d - - -class EmbeddrSaveToFolderNode(io.ComfyNode): - - @classmethod - def define_schema(cls) -> io.Schema: - libraries = ["Default"] + get_libraries() - collections = ["None"] + get_collections() - - return io.Schema( - node_id="embeddr.SaveToFolder", - display_name="Embeddr Upload Image", - category="Embeddr", - is_output_node=True, - inputs=[ - io.Image.Input("image"), - io.String.Input("caption", optional=True), - io.String.Input("parent_ids", optional=True), - io.Combo.Input("library", options=libraries, - default="Default"), - io.Combo.Input( - "collection", options=collections, default="None"), - io.String.Input("tags", optional=True, default=""), - io.Boolean.Input("allow_duplicates", default=False, - display_name="Allow Duplicates"), - io.Boolean.Input("save_backup", default=False, - display_name="Save to Comfy History"), - ], - outputs=[ - # THIS IS KEY: output name must match - io.String.Output("embeddr_id"), - ], - ) - - @classmethod - def VALIDATE_INPUTS(cls, **kwargs): - return True - - @classmethod - def execute(cls, image, caption=None, parent_ids=None, library="Default", collection="None", tags="", allow_duplicates=False, save_backup=False, **kwargs): - """ - image: input tensor(s) from previous node - caption: optional string - parent_ids: optional string (comma separated) or list of IDs - Returns the backend ID in the outputs so Comfy history sees it. - """ - uploaded_ids = [] - - config = get_config() - print("Loaded Config: ", config) - endpoint = config.get("endpoint", "http://localhost:8003") - api_base_url = endpoint.rstrip("/") + "/api/v1" - upload_url = f"{api_base_url}/images/upload" - - # Loop over batch images - for i in range(image.shape[0]): - img_array = (image[i].cpu().numpy() * 255).astype(np.uint8) - img = Image.fromarray(np.clip(img_array, 0, 255)) - - # Save backup if requested - if save_backup: - try: - output_dir = folder_paths.get_output_directory() - filename_prefix = "Embeddr_Backup" - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( - filename_prefix, output_dir, img.size[0], img.size[1]) - - metadata = PngInfo() - if caption: - metadata.add_text("parameters", caption) - - file = f"{filename}_{counter:05}_.png" - img.save(os.path.join(full_output_folder, file), - pnginfo=metadata, compress_level=4) - except Exception as e: - print(f"[Embeddr] Failed to save backup: {e}") - - buf = pyio.BytesIO() - img.save(buf, format="PNG") - buf.seek(0) - - files = {"file": ("image.png", buf, "image/png")} - data = {"prompt": caption or ""} - - if allow_duplicates: - data["force"] = "true" - - if tags: - data["tags"] = tags - - if library != "Default": - try: - data["library_id"] = int(library.split(":")[0]) - except: - pass - - if parent_ids: - if isinstance(parent_ids, list): - data["parent_ids"] = ",".join(map(str, parent_ids)) - else: - data["parent_ids"] = str(parent_ids) - - try: - response = requests.post(upload_url, files=files, data=data) - response.raise_for_status() - result = response.json() - uploaded_id = result.get("id") - uploaded_ids.append(str(uploaded_id)) - - # Add to collection if selected - if collection and collection != "None" and uploaded_id: - try: - collection_id = int(collection.split(":")[0]) - requests.post( - f"{api_base_url}/collections/{collection_id}/items", - json={"image_id": uploaded_id} - ) - except Exception as e: - print(f"[Embeddr] Failed to add to collection: {e}") - - except Exception as e: - print(f"[Embeddr] Upload failed: {e}") - uploaded_ids.append("-1") - - # Create preview - preview = EmbeddrImage(image, uploaded_ids, cls=cls) - - # Return IDs as comma-separated string - return io.NodeOutput(",".join(uploaded_ids), ui=preview) diff --git a/nodes/EmbeddrUploadVideo.py b/nodes/EmbeddrUploadVideo.py index ce6d91f..1db9e9f 100644 --- a/nodes/EmbeddrUploadVideo.py +++ b/nodes/EmbeddrUploadVideo.py @@ -6,7 +6,7 @@ import tempfile from comfy_api.latest import io, ui from comfy_api.latest._io import ComfyNode -from .utils import get_config +from .utils import get_embeddr_base_url class EmbeddrUploadVideo(io.ComfyNode): @@ -52,9 +52,8 @@ def VALIDATE_INPUTS(cls, **kwargs): @classmethod def execute(cls, video, caption=None, parent_ids=None, library="Default", collection="None", tags="", format="mp4", codec="h264", allow_duplicates=False, save_backup=False, **kwargs): uploaded_ids = [] - config = get_config() - endpoint = config.get("endpoint", "http://localhost:8003") - api_base_url = endpoint.rstrip("/") + "/api/v1" + base_url = get_embeddr_base_url() + api_base_url = f"{base_url}/api/v1" upload_url = f"{api_base_url}/images/upload" try: diff --git a/nodes/utils/__init__.py b/nodes/utils/__init__.py index dc46902..4159ef7 100644 --- a/nodes/utils/__init__.py +++ b/nodes/utils/__init__.py @@ -1,3 +1,3 @@ -from .config import get_config +from .config import get_config, get_embeddr_base_url, get_upload_mode -__all__ = ["get_config"] +__all__ = ["get_config", "get_embeddr_base_url", "get_upload_mode"] diff --git a/nodes/utils/api.py b/nodes/utils/api.py index 0ee1c1a..fea25a1 100644 --- a/nodes/utils/api.py +++ b/nodes/utils/api.py @@ -1,12 +1,11 @@ import requests -from .config import get_config +from .config import get_embeddr_base_url def get_libraries(): try: - config = get_config() - endpoint = config.get("endpoint", "http://localhost:8003") - api_url = endpoint.rstrip("/") + "/api/v1/libraries" + base_url = get_embeddr_base_url() + api_url = f"{base_url}/api/v1/libraries" response = requests.get(api_url) if response.status_code == 200: data = response.json() @@ -22,9 +21,8 @@ def get_libraries(): def get_collections(): try: - config = get_config() - endpoint = config.get("endpoint", "http://localhost:8003") - api_url = endpoint.rstrip("/") + "/api/v1/collections" + base_url = get_embeddr_base_url() + api_url = f"{base_url}/api/v1/collections" response = requests.get(api_url) if response.status_code == 200: data = response.json() diff --git a/nodes/utils/config.py b/nodes/utils/config.py index c349da9..08000e1 100644 --- a/nodes/utils/config.py +++ b/nodes/utils/config.py @@ -2,6 +2,27 @@ import json +def _normalize_base_url(url: str | None, default: str) -> str: + if not url: + return default + + clean = str(url).strip().rstrip("/") + + # Strip API suffixes to get the root base + if clean.endswith("/api/v1"): + clean = clean[:-7] + elif clean.endswith("/api/v2"): + clean = clean[:-7] + elif clean.endswith("/api"): + clean = clean[:-4] + + # Add scheme if missing + if "://" not in clean: + clean = f"http://{clean}" + + return clean.rstrip("/") + + def get_config(): try: # Go up 3 levels: utils -> nodes -> embeddr-comfyui @@ -13,3 +34,27 @@ def get_config(): except Exception: pass return {} + + +def get_embeddr_base_url(default: str = "http://localhost:8003") -> str: + cfg = get_config() + env_url = ( + os.environ.get("EMBEDDR_BACKEND_URL") + or os.environ.get("EMBEDDR_URL") + or os.environ.get("EMBEDDR_ENDPOINT") + ) + cfg_url = ( + cfg.get("embeddr_url") + or cfg.get("backend_url") + or cfg.get("endpoint") + ) + + return _normalize_base_url(env_url or cfg_url, default) + + +def get_upload_mode(default: str = "require") -> str: + cfg = get_config() + env_mode = os.environ.get("EMBEDDR_UPLOAD_MODE") + cfg_mode = cfg.get("upload_mode") + mode = (env_mode or cfg_mode or default or "require").strip().lower() + return mode diff --git a/package.json b/package.json index 1e58a2c..53c88fa 100644 --- a/package.json +++ b/package.json @@ -21,6 +21,7 @@ "@dnd-kit/core": "^6.3.1", "@dnd-kit/sortable": "^10.0.0", "@dnd-kit/utilities": "^3.2.2", + "@embeddr/api": "workspace:*", "@embeddr/react-ui": "^0.1.4", "@radix-ui/react-aspect-ratio": "^1.1.8", "@radix-ui/react-avatar": "^1.1.11", diff --git a/ui/components/GlobalDialog.tsx b/ui/components/GlobalDialog.tsx index 68fd9fb..f1aebf4 100644 --- a/ui/components/GlobalDialog.tsx +++ b/ui/components/GlobalDialog.tsx @@ -6,13 +6,17 @@ import { DialogTitle, } from "@embeddr/react-ui/components/dialog"; import { ExploreTab } from "./tabs/ExploreTab"; +import { CollectionSelector } from "./selectors/CollectionSelector"; import { useEmbeddrApi } from "../hooks/useEmbeddrApi"; // @ts-ignore import { app } from "../../../scripts/app.js"; +type DialogMode = "image" | "collection"; + export function GlobalDialog() { const [isOpen, setIsOpen] = useState(false); const [targetNodeId, setTargetNodeId] = useState(null); + const [mode, setMode] = useState("image"); const api = useEmbeddrApi(); @@ -20,11 +24,17 @@ export function GlobalDialog() { const handleOpen = (e: Event) => { const customEvent = e as CustomEvent; setTargetNodeId(customEvent.detail.nodeId); + const requestedMode = customEvent.detail.mode || "image"; + setMode(requestedMode); setIsOpen(true); // Trigger a fetch if needed if (api.configLoaded) { - api.fetchImages(true); + if (requestedMode === "collection") { + api.fetchCollections(); + } else { + api.fetchImages(true); + } } }; @@ -69,31 +79,109 @@ export function GlobalDialog() { return () => observer.disconnect(); }, [api.theme]); - const handleSelect = (image: any) => { + const handleSelect = (item: any) => { if (targetNodeId !== null) { const node = app.graph.getNodeById(targetNodeId); if (node) { - // Assuming the first widget is the image_id input - // We should check the widget name or type - const idWidget = node.widgets?.find((w: any) => w.name === "image_id"); - if (idWidget) { - idWidget.value = image.id.toString(); - if (idWidget.callback) { - idWidget.callback(idWidget.value); + if (mode === "collection") { + // Handle Collection Selection + + // 1. Try "collection_ids" -> Set ID + const idWidget = node.widgets?.find( + (w: any) => w.name === "collection_ids" + ); + if (idWidget) { + idWidget.value = item.id.toString(); + if (idWidget.callback) { + idWidget.callback(idWidget.value); + } + } + + // 2. Try "collection_id" -> Set ID (for FindCollection node V2) + console.log( + "[Embeddr] Debug: Listing node widgets:", + node.widgets?.map((w: any) => w.name) + ); + + const colIdWidget = node.widgets?.find( + (w: any) => w.name === "collection_id" + ); + + let idSet = false; + if (colIdWidget && item.id) { + console.log("[Embeddr] Setting collection_id to", item.id); + colIdWidget.value = item.id.toString(); + if (colIdWidget.callback) { + colIdWidget.callback(colIdWidget.value); + } + idSet = true; + } else { + console.warn( + "[Embeddr] collection_id widget not found on node or invalid item.id!", + { widgetFound: !!colIdWidget, itemId: item.id } + ); + } + + // 3. Update Info Widget (Friendly display) + const infoWidget = node.widgets?.find((w: any) => w.name === "Info"); + if (infoWidget) { + const count = item.file_count !== undefined ? item.file_count : "?"; + infoWidget.value = `${ + item.label || item.name + } (${count} items) [ID: ${item.id?.substring(0, 8)}...]`; + } + + // 4. Update or Clear Name Widget + // If we successfully set the ID, clear the name widget to avoid confusion + // and ensure the backend uses the ID. + const nameWidget = node.widgets?.find( + (w: any) => w.name === "collection_name" + ); + if (nameWidget) { + if (idSet) { + // Clear name to prioritize ID match and avoid ambiguity + console.log( + "[Embeddr] Clearing collection_name widget to prioritize ID" + ); + nameWidget.value = ""; + } else { + nameWidget.value = item.label || item.name; + } + + if (nameWidget.callback) { + nameWidget.callback(nameWidget.value); + } + } + + // Force UI update + if (app.graph) { + app.graph.setDirtyCanvas(true, true); } } else { - // Fallback to first widget if name doesn't match - if (node.widgets && node.widgets[0]) { - node.widgets[0].value = image.id.toString(); + // Handle Image Selection + // Check for artifact_id (V2) or image_id (V1) + const idWidget = node.widgets?.find( + (w: any) => w.name === "artifact_id" || w.name === "image_id" + ); + if (idWidget) { + idWidget.value = item.id.toString(); + if (idWidget.callback) { + idWidget.callback(idWidget.value); + } + } else { + // Fallback to first widget if name doesn't match + if (node.widgets && node.widgets[0]) { + node.widgets[0].value = item.id.toString(); + } } - } - // Also update image_url if it exists (for preview/compatibility) - const urlWidget = node.widgets?.find( - (w: any) => w.name === "image_url" - ); - if (urlWidget) { - urlWidget.value = image.image_url; + // Also update image_url if it exists (for preview/compatibility) + const urlWidget = node.widgets?.find( + (w: any) => w.name === "image_url" + ); + if (urlWidget) { + urlWidget.value = item.image_url; + } } app.graph.setDirtyCanvas(true, true); @@ -106,14 +194,27 @@ export function GlobalDialog() { - Select Image + + {mode === "collection" ? "Select Collection" : "Select Image"} +
- + {mode === "collection" ? ( + + ) : ( + + )}
diff --git a/ui/components/panels/EmbeddrPanel.tsx b/ui/components/panels/EmbeddrPanel.tsx index 06c1433..fba7560 100644 --- a/ui/components/panels/EmbeddrPanel.tsx +++ b/ui/components/panels/EmbeddrPanel.tsx @@ -34,6 +34,7 @@ export default function EmbeddrPanel() { setSimilarImageId, theme, setTheme, + apiClient, } = useEmbeddrApi(); const { openExternal } = useExternalNav(); @@ -88,7 +89,7 @@ export default function EmbeddrPanel() { size="icon" className={cn( "ml-auto", - activeTab === "settings" ? "bg-primary/50" : "" + activeTab === "settings" ? "bg-primary/50" : "", )} onClick={() => setActiveTab("settings")} > @@ -115,6 +116,8 @@ export default function EmbeddrPanel() { similarImageId={similarImageId} setSimilarImageId={setSimilarImageId} mode={mode} + apiBase={endpoint} + apiClient={apiClient} gridSize={gridSize} gridPreviewContain={gridPreviewContain} configLoaded={configLoaded} diff --git a/ui/components/panels/ImageDetails.tsx b/ui/components/panels/ImageDetails.tsx index ab15ac6..d4fdae2 100644 --- a/ui/components/panels/ImageDetails.tsx +++ b/ui/components/panels/ImageDetails.tsx @@ -1,7 +1,22 @@ -import React, { useState } from "react"; +import React, { useMemo, useState, useEffect } from "react"; import { Button } from "@embeddr/react-ui/components/button"; import { ScrollArea } from "@embeddr/react-ui/components/scroll-area"; -import { ArrowBigRightDashIcon, Check, Copy, Plus } from "lucide-react"; +import { Badge } from "@embeddr/react-ui/components/badge"; +import { Separator } from "@embeddr/react-ui/components/separator"; +import { Skeleton } from "@embeddr/react-ui/components/skeleton"; +import { EmbeddrApiClient } from "@embeddr/api"; + +import { + ArrowBigRightDashIcon, + Check, + Copy, + Plus, + Tag, + Folder, + Hash, + Info, + FileText, +} from "lucide-react"; import type { PromptImageRead } from "@hooks/useEmbeddrApi"; import type { TargetNode } from "@hooks/useNodeScanner"; @@ -10,6 +25,17 @@ interface ImageDetailsProps { targetNodes: Array; onLoadIntoNode: (nodeId: number, imageUrl: string) => void; onUseImage: (imageUrl: string) => void; + apiBase?: string; + apiClient?: EmbeddrApiClient; +} + +interface ArtifactDetail { + id: string; + type_name: string; + uri: string; + metadata_json: Record; + collections: Array<{ id: string; name: string }>; + tags: Array<{ id: string; name: string }>; } export function ImageDetails({ @@ -17,8 +43,57 @@ export function ImageDetails({ targetNodes, onLoadIntoNode, onUseImage, + apiBase, + apiClient, }: ImageDetailsProps) { const [copied, setCopied] = useState(false); + const [artifact, setArtifact] = useState(null); + const [loading, setLoading] = useState(false); + + const localClient = useMemo(() => { + if (apiClient) return apiClient; + if (!apiBase) return null; + let baseUrl = apiBase; + if (baseUrl.endsWith("/")) baseUrl = baseUrl.slice(0, -1); + if (!baseUrl.endsWith("/api/v2")) { + baseUrl = `${baseUrl}/api/v2`; + } + return new EmbeddrApiClient({ baseUrl }); + }, [apiClient, apiBase]); + + useEffect(() => { + if (!selectedImage?.id) { + setArtifact(null); + return; + } + + setLoading(true); + const client = localClient; + const fallback = apiBase + ? apiBase.replace(/\/+$/, "").replace(/\/+api\/v2$/, "") + "/api/v2" + : ""; + + if (client) { + client.artifacts + .get(selectedImage.id) + .then((data) => setArtifact(data as ArtifactDetail)) + .catch((e) => console.error("Failed to fetch artifact details", e)) + .finally(() => setLoading(false)); + return; + } + + if (!fallback) { + setArtifact(null); + setLoading(false); + return; + } + + fetch(`${fallback}/artifacts/${selectedImage.id}`) + .then((res) => (res.ok ? res.json() : null)) + .then((data) => setArtifact(data)) + .catch((e) => console.error("Failed to fetch artifact details", e)) + .finally(() => setLoading(false)); + }, [selectedImage.id, apiBase, localClient]); const handleCopyPrompt = () => { if (selectedImage?.prompt) { @@ -90,11 +165,100 @@ export function ImageDetails({ - {selectedImage.prompt && ( -
- {selectedImage.prompt} -
- )} +
+ {selectedImage.prompt && ( +
+
+ Prompt +
+
+ {selectedImage.prompt} +
+
+ )} + + {loading && ( +
+ + +
+ )} + + {!loading && artifact && ( + <> + + {/* Tech Specs */} +
+
+ + Dimensions + + + {artifact.metadata_json?.width || selectedImage.width} x{" "} + {artifact.metadata_json?.height || selectedImage.height} + +
+
+ + Type + + + {artifact.metadata_json?.format || artifact.type_name} + +
+ {/* File Path (usually useful for debugging or local) */} +
+ + {/* Collections */} + {artifact.collections && artifact.collections.length > 0 && ( +
+
+ Collections +
+
+ {artifact.collections.map((c) => ( + + {c.name} + + ))} +
+
+ )} + + {/* Tags */} + {((artifact.tags && artifact.tags.length > 0) || + (artifact.metadata_json?.tags && + Array.isArray(artifact.metadata_json.tags))) && ( +
+
+ Tags +
+
+ {/* Prefer relational tags, fallback to metadata tags */} + {(artifact.tags?.length > 0 + ? artifact.tags + : (artifact.metadata_json.tags as string[]).map( + (t) => ({ id: t, name: t }), + ) + ).map((t) => ( + + #{t.name} + + ))} +
+
+ )} + + )} +
diff --git a/ui/components/selectors/CollectionSelector.tsx b/ui/components/selectors/CollectionSelector.tsx new file mode 100644 index 0000000..658808b --- /dev/null +++ b/ui/components/selectors/CollectionSelector.tsx @@ -0,0 +1,130 @@ +import React, { useEffect, useState } from "react"; +import { + Card, + CardDescription, + CardHeader, + CardTitle, +} from "@embeddr/react-ui/components/card"; +import { ScrollArea } from "@embeddr/react-ui/components/scroll-area"; +import { Input } from "@embeddr/react-ui/components/input"; +import { Button } from "@embeddr/react-ui/components/button"; +import { Folder, Search, Plus } from "lucide-react"; +import type { Collection } from "../../hooks/useEmbeddrCollections"; + +interface CollectionSelectorProps { + collections: Collection[]; + loading: boolean; + onSelect: (collection: Collection) => void; + fetchCollections: () => void; + createCollection: (label: string) => Promise; + creating: boolean; +} + +export function CollectionSelector({ + collections, + loading, + onSelect, + fetchCollections, + createCollection, + creating, +}: CollectionSelectorProps) { + const [search, setSearch] = useState(""); + const [showCreate, setShowCreate] = useState(false); + const [newLabel, setNewLabel] = useState(""); + + useEffect(() => { + fetchCollections(); + }, []); + + const handleCreate = async () => { + if (!newLabel.trim()) return; + const success = await createCollection(newLabel); + if (success) { + setNewLabel(""); + setShowCreate(false); + } + }; + + const filtered = collections.filter((c) => + (c.label || "").toLowerCase().includes(search.toLowerCase()) + ); + + return ( +
+
+
+ + setSearch(e.target.value)} + className="pl-8" + /> +
+ +
+ + {showCreate && ( + +

Create New Collection

+
+ setNewLabel(e.target.value)} + /> + +
+
+ )} + + + {loading ? ( +
+ Scanning collections... +
+ ) : filtered.length === 0 ? ( +
+
No collections found
+
Create one above to get started
+
+ ) : ( +
+ {filtered.map((collection) => ( + onSelect(collection)} + > + +
+ +
+
+ + {collection.label} + + + {collection.file_count ?? 0} items + +
+
+
+ ))} +
+ )} +
+
+ ); +} diff --git a/ui/components/tabs/ExploreTab.tsx b/ui/components/tabs/ExploreTab.tsx index 73a1218..75a3594 100644 --- a/ui/components/tabs/ExploreTab.tsx +++ b/ui/components/tabs/ExploreTab.tsx @@ -11,11 +11,13 @@ import { ImageGrid } from "@components/ui/ImageGrid"; import { useNodeScanner } from "@hooks/useNodeScanner"; import { ImageDetails } from "../panels/ImageDetails"; import { SearchBar } from "../ui/SearchBar"; +import { useEmbeddrCollections } from "../../hooks/useEmbeddrCollections"; import type { ApiMode, LibraryPath, PromptImageRead, } from "@hooks/useEmbeddrApi"; +import type { EmbeddrApiClient } from "@embeddr/api"; interface ExploreTabProps { images: Array; @@ -26,12 +28,15 @@ interface ExploreTabProps { query?: string, viewMode?: "all" | "mine", libId?: number | null, - similarId?: number | null + similarId?: string | number | null, + collectionId?: string | null, ) => Promise; libraries: Array; similarImageId: number | null; setSimilarImageId: (id: number | null) => void; mode: ApiMode; + apiBase?: string; // Need apiBase for collections + apiClient?: EmbeddrApiClient; gridSize: number; setGridSize?: (size: number) => void; gridPreviewContain: boolean; @@ -49,6 +54,8 @@ export function ExploreTab({ similarImageId, setSimilarImageId, mode, + apiBase = "", + apiClient, gridSize, setGridSize, gridPreviewContain, @@ -60,20 +67,35 @@ export function ExploreTab({ const { openImage, closeImage, setGalleryImages, currentGallery } = useImageDialog(); + const { collections, fetchCollections } = useEmbeddrCollections({ + apiBase, + configLoaded, + apiClient, + }); + + // Load collections on mount/config load + useEffect(() => { + if (configLoaded) { + fetchCollections(); + } + }, [configLoaded, fetchCollections]); + const [searchQuery, setSearchQuery] = useState(""); const [viewMode, setViewMode] = useState<"all" | "mine">("all"); - const [selectedLibrary, setSelectedLibrary] = useState("all"); + const [selectedCollectionId, setSelectedCollectionId] = + useState("all"); const [selectedImage, setSelectedImage] = useState( - null + null, ); const scrollRef = useRef(null); // Fetch when dependencies change useEffect(() => { if (!configLoaded) return; - const libId = selectedLibrary === "all" ? null : parseInt(selectedLibrary); - fetchImages(true, searchQuery, viewMode, libId, similarImageId); - }, [viewMode, configLoaded, selectedLibrary, mode, similarImageId]); // Re-fetch when view mode changes or config is loaded + const colId = selectedCollectionId === "all" ? null : selectedCollectionId; + // libraryId is null now as we use collections + fetchImages(true, searchQuery, viewMode, null, similarImageId, colId); + }, [viewMode, configLoaded, selectedCollectionId, mode, similarImageId]); // Re-fetch when view mode changes or config is loaded // Sync images to lightbox when they change useEffect(() => { @@ -90,8 +112,8 @@ export function ExploreTab({ const handleSearch = (e: React.FormEvent) => { e.preventDefault(); - const libId = selectedLibrary === "all" ? null : parseInt(selectedLibrary); - fetchImages(true, searchQuery, viewMode, libId, similarImageId); + const colId = selectedCollectionId === "all" ? null : selectedCollectionId; + fetchImages(true, searchQuery, viewMode, null, similarImageId, colId); }; return ( @@ -112,9 +134,9 @@ export function ExploreTab({ } }} mode={mode} - selectedLibrary={selectedLibrary} - setSelectedLibrary={setSelectedLibrary} - libraries={libraries} + selectedCollectionId={selectedCollectionId} + setSelectedCollectionId={setSelectedCollectionId} + collections={collections} viewMode={viewMode} setViewMode={setViewMode} /> @@ -152,18 +174,19 @@ export function ExploreTab({ setSelectedImage(e); }} onLoadMore={() => { - const libId = - selectedLibrary === "all" ? null : parseInt(selectedLibrary); + const colId = + selectedCollectionId === "all" ? null : selectedCollectionId; fetchImages( false, searchQuery, viewMode, - libId, - similarImageId + null, + similarImageId, + colId, ); }} onSimilarSearch={(image) => { - setSimilarImageId(image); + setSimilarImageId(image.id); // Assuming image.id is number/string }} onSelect={(image) => { if (!image) return; @@ -191,16 +214,17 @@ export function ExploreTab({ totalImages: totalImages, fetchMore: async (_dir: any, _offset: any) => { if (hasMore) { - const libId = - selectedLibrary === "all" + const colId = + selectedCollectionId === "all" ? null - : parseInt(selectedLibrary); + : selectedCollectionId; await fetchImages( false, searchQuery, viewMode, - libId, - similarImageId + null, + similarImageId, + colId, ); } }, @@ -224,7 +248,7 @@ export function ExploreTab({ }, }, ], - image.prompt + image.prompt, ); }} selectedId={selectedImage?.id} @@ -243,6 +267,8 @@ export function ExploreTab({ targetNodes={targetNodes} onUseImage={handleUseImage} onLoadIntoNode={handleLoadIntoNode} + apiBase={apiBase} + apiClient={apiClient} /> diff --git a/ui/components/ui/ImageGrid.tsx b/ui/components/ui/ImageGrid.tsx index 05787be..aba1373 100644 --- a/ui/components/ui/ImageGrid.tsx +++ b/ui/components/ui/ImageGrid.tsx @@ -56,9 +56,11 @@ export function ImageGrid({ }; }, [hasMore, loading, onLoadMore]); - const handleDragStart = (e: React.DragEvent, imageUrl: string) => { - e.dataTransfer.setData("text/plain", imageUrl); - e.dataTransfer.setData("text/uri-list", imageUrl); + const handleDragStart = (e: React.DragEvent, image: any) => { + e.dataTransfer.setData("text/plain", image.image_url); + e.dataTransfer.setData("text/uri-list", image.image_url); + e.dataTransfer.setData("embeddr/id", image.id.toString()); + e.dataTransfer.setData("embeddr/json", JSON.stringify(image)); }; return ( @@ -81,7 +83,7 @@ export function ImageGrid({ }} onClick={() => onSelect?.(image)} draggable - onDragStart={(e) => handleDragStart(e, image.image_url)} + onDragStart={(e) => handleDragStart(e, image)} > void; mode: ApiMode; - selectedLibrary: string; - setSelectedLibrary: (lib: string) => void; - libraries: Array; + selectedCollectionId: string; + setSelectedCollectionId: (id: string) => void; + collections: Array; viewMode: "all" | "mine"; setViewMode: (mode: "all" | "mine") => void; } @@ -35,9 +36,9 @@ export function SearchBar({ similarImageId, setSimilarImageId, mode, - selectedLibrary, - setSelectedLibrary, - libraries, + selectedCollectionId, + setSelectedCollectionId, + collections, viewMode, setViewMode, }: SearchBarProps) { @@ -75,15 +76,18 @@ export function SearchBar({
{mode === "local" ? ( - - + - All Libraries - {libraries.map((lib) => ( - - {lib.name || lib.path} + All Collections + {collections.map((col) => ( + + {col.label || "Untitled"} ({col.file_count || 0}) ))} diff --git a/ui/hooks/useEmbeddrApi.ts b/ui/hooks/useEmbeddrApi.ts index 5eb140d..39c3fca 100644 --- a/ui/hooks/useEmbeddrApi.ts +++ b/ui/hooks/useEmbeddrApi.ts @@ -1,9 +1,15 @@ +import { useMemo } from "react"; +import { EmbeddrApiClient } from "@embeddr/api"; import { useEmbeddrSettings } from "./useEmbeddrSettings"; import { useEmbeddrLibraries } from "./useEmbeddrLibraries"; import { useEmbeddrImages } from "./useEmbeddrImages"; +import { + useEmbeddrCollections, + type Collection, +} from "./useEmbeddrCollections"; import type { ApiMode, LibraryPath, PromptImageRead } from "@types"; -export type { PromptImageRead, LibraryPath, ApiMode }; +export type { PromptImageRead, LibraryPath, ApiMode, Collection }; interface UseEmbeddrApiProps { baseUrl?: string; @@ -14,6 +20,11 @@ export function useEmbeddrApi({ }: UseEmbeddrApiProps = {}) { const settings = useEmbeddrSettings({ baseUrl }); + const apiClient = useMemo( + () => new EmbeddrApiClient({ baseUrl: settings.apiBase }), + [settings.apiBase], + ); + const libraries = useEmbeddrLibraries({ apiBase: settings.apiBase, mode: settings.mode, @@ -24,11 +35,20 @@ export function useEmbeddrApi({ apiBase: settings.apiBase, mode: settings.mode, configLoaded: settings.configLoaded, + apiClient, + }); + + const collections = useEmbeddrCollections({ + apiBase: settings.apiBase, + configLoaded: settings.configLoaded, + apiClient, }); return { ...settings, + apiClient, ...libraries, ...images, + ...collections, }; } diff --git a/ui/hooks/useEmbeddrCollections.ts b/ui/hooks/useEmbeddrCollections.ts new file mode 100644 index 0000000..5c02634 --- /dev/null +++ b/ui/hooks/useEmbeddrCollections.ts @@ -0,0 +1,122 @@ +import { useState, useCallback } from "react"; +import type { EmbeddrApiClient } from "@embeddr/api"; + +export interface Collection { + id: string; + label: string; + type_name: string; + file_count: number; + uri?: string; + created_at?: string; +} + +interface UseEmbeddrCollectionsProps { + apiBase: string; + configLoaded: boolean; + apiClient?: EmbeddrApiClient; +} + +export function useEmbeddrCollections({ + apiBase, + configLoaded, + apiClient, +}: UseEmbeddrCollectionsProps) { + const [collections, setCollections] = useState([]); + const [loadingCollections, setLoadingCollections] = useState(false); + + const fetchCollections = useCallback(async () => { + if (!configLoaded || !apiBase) return; + + setLoadingCollections(true); + try { + const headers: Record = { + "Content-Type": "application/json", + }; + + let baseUrl = apiBase; + if (baseUrl.endsWith("/")) baseUrl = baseUrl.slice(0, -1); + // Ensure we target the V2 API + if (!baseUrl.endsWith("/api/v2")) { + baseUrl = `${baseUrl}/api/v2`; + } + + const res = await fetch(`${baseUrl}/collections`, { + method: "GET", + headers, + }); + if (res.ok) { + const data = await res.json(); + // data could be paginated or just a list + // Assuming list for now based on typical embeddr api + const list = Array.isArray(data) ? data : data.items || []; + setCollections(list); + } else { + console.error("Failed to fetch collections", res.status); + } + } catch (error) { + console.error("Error fetching collections:", error); + } finally { + setLoadingCollections(false); + } + }, [apiBase, configLoaded]); + + const [creating, setCreating] = useState(false); + + const createCollection = useCallback( + async (label: string) => { + if (!configLoaded || !apiBase) return; + setCreating(true); + try { + const headers: Record = { + "Content-Type": "application/json", + }; + const payload = { + label: label, + type_name: "collection:mix", // Default to simple mix + uri: `embeddr:///collections/${label + .toLowerCase() + .replace(/\s/g, "_")}_${Date.now()}`, + }; + + let baseUrl = apiBase; + if (baseUrl.endsWith("/")) baseUrl = baseUrl.slice(0, -1); + if (!baseUrl.endsWith("/api/v2")) { + baseUrl = `${baseUrl}/api/v2`; + } + + // Use the artifact endpoint to create a collection, since /collections might be read-only or alias + // But if /api/v2/collections exists as a dedicated endpoint, we use it. + // Assuming /api/v2/collections POST works as expected for creating collections specificically. + // If not, we might need to POST to /artifacts with type=collection. + // Let's stick to the user's requested endpoint /api/v2/collections for now. + const res = await fetch(`${baseUrl}/collections`, { + method: "POST", + headers, + body: JSON.stringify(payload), + }); + + if (res.ok) { + await fetchCollections(); // Refresh list + return true; + } else { + console.error("Failed to create collection", res.status); + return false; + } + } catch (e) { + console.error(e); + return false; + } finally { + setCreating(false); + } + }, + [apiBase, configLoaded, fetchCollections], + ); + + return { + collections, + fetchCollections, + loadingCollections, + createCollection, + creating, + }; +} diff --git a/ui/hooks/useEmbeddrImages.ts b/ui/hooks/useEmbeddrImages.ts index d6448c3..d434562 100644 --- a/ui/hooks/useEmbeddrImages.ts +++ b/ui/hooks/useEmbeddrImages.ts @@ -1,4 +1,5 @@ import { useCallback, useRef, useState } from "react"; +import type { EmbeddrApiClient } from "@embeddr/api"; // @ts-ignore import { app } from "../../../scripts/app.js"; import type { ApiMode, PromptImageRead } from "@types"; @@ -7,19 +8,23 @@ interface UseEmbeddrImagesProps { apiBase: string; mode: ApiMode; configLoaded: boolean; + apiClient?: EmbeddrApiClient; } export function useEmbeddrImages({ apiBase, mode, configLoaded, + apiClient, }: UseEmbeddrImagesProps) { const [images, setImages] = useState>([]); const [loading, setLoading] = useState(false); const loadingRef = useRef(false); const pageRef = useRef(1); const [hasMore, setHasMore] = useState(true); - const [similarImageId, setSimilarImageId] = useState(null); + const [similarImageId, setSimilarImageId] = useState( + null, + ); const fetchImages = useCallback( async ( @@ -27,7 +32,8 @@ export function useEmbeddrImages({ searchQuery = "", viewMode: "all" | "mine" = "all", libraryId?: number | null, - similarId?: number | null, + similarId?: string | number | null, + collectionId?: string | null, ) => { if (!configLoaded) return; if (loadingRef.current && !reset) return; @@ -35,7 +41,9 @@ export function useEmbeddrImages({ loadingRef.current = true; setLoading(true); try { - const headers: Record = {}; + const headers: Record = { + "Content-Type": "application/json", + }; const storedKey = localStorage.getItem("embeddr_api_key"); if (storedKey) { headers["Authorization"] = `Bearer ${storedKey}`; @@ -46,53 +54,152 @@ export function useEmbeddrImages({ let baseUrl = apiBase; if (baseUrl.endsWith("/")) baseUrl = baseUrl.slice(0, -1); + // Ensure V2 + if (!baseUrl.endsWith("/api/v2")) { + baseUrl = `${baseUrl}/api/v2`; + } - let url; const currentSimilarId = similarId !== undefined ? similarId : similarImageId; + let url = ""; + let method = "GET"; + let body: string | undefined = undefined; + + // V2 API Logic if (currentSimilarId) { - url = `${baseUrl}/images/${currentSimilarId}/similar?limit=20&skip=${offset}`; - if (libraryId) { - url += `&library_id=${libraryId}`; + // Use embeddr-search plugin for similar items + url = `${baseUrl}/plugins/embeddr-search/similar`; + method = "POST"; + body = JSON.stringify({ + artifact_id: currentSimilarId.toString(), + limit: 20, + }); + } else if (searchQuery) { + // Use Embeddr Search Plugin (semantic text search) + url = `${baseUrl}/plugins/embeddr-search/query`; + method = "POST"; + body = JSON.stringify({ + query: searchQuery, + limit: 20, + }); + } else if (apiClient) { + const list = await apiClient.artifacts.list({ + limit: 20, + offset, + type_name: "image", + sort: "new", + library_id: libraryId ? String(libraryId) : undefined, + collection_id: collectionId || undefined, + }); + + const items = list.items || []; + const mapped = items.map((item: any) => { + const id = item.id; + const metadata = item.metadata_json || {}; + return { + id: id, + prompt: metadata.prompt || metadata.filename || "Untitled", + image_url: apiClient.artifacts.getContentUrl(id), + thumb_url: apiClient.artifacts.getPreviewUrl(id, "thumbnail"), + created_at: item.created_at || new Date().toISOString(), + like_count: 0, + liked_by_me: false, + width: metadata.width || 0, + height: metadata.height || 0, + }; + }); + + if (reset) { + setImages(mapped); + pageRef.current = 2; + } else { + setImages((prev) => [...prev, ...mapped]); + pageRef.current = currentPage + 1; } - } else if (mode === "local") { - // Local API - url = `${baseUrl}/images?limit=20&skip=${offset}`; + + setHasMore(offset + mapped.length < list.total); + return; + } else { + // List Artifacts + url = `${baseUrl}/artifacts/?type_name=image&sort=new&limit=20&offset=${offset}`; if (libraryId) { url += `&library_id=${libraryId}`; } + if (collectionId) { + url += `&collection_id=${collectionId}`; + } } - if (searchQuery) { - url += `&q=${encodeURIComponent(searchQuery)}`; + let response = await fetch(url, { method, headers, body }); + + // Fallback for Similar Search if plugin missing (404) + if (!response.ok && currentSimilarId && response.status === 404) { + console.warn( + "Embeddr Search plugin not found, falling back to latest", + ); + url = `${baseUrl}/artifacts/?type_name=image&sort=new&limit=20&offset=${offset}`; + method = "GET"; + body = undefined; + response = await fetch(url, { method, headers }); } - const response = await fetch(url, { headers }); + // Fallback for Text Search if plugin missing (404) + if (!response.ok && searchQuery && response.status === 404) { + console.warn( + "Embeddr Search plugin not found, falling back to simple search", + ); + url = `${baseUrl}/artifacts/search?q=${encodeURIComponent( + searchQuery, + )}&limit=20&offset=${offset}`; + method = "GET"; + body = undefined; + response = await fetch(url, { method, headers }); + } if (response.ok) { const data = await response.json(); let items: Array = []; - if (mode === "local") { - // Local API returns { items: [], total: ... } - items = data.items.map((item: any) => ({ - id: item.id, - prompt: item.prompt, // Use filename as prompt for now - image_url: `${baseUrl}/images/${item.id}/file`, - thumb_url: `${baseUrl}/images/${item.id}/thumbnail`, - created_at: item.created_at, - like_count: 0, - liked_by_me: false, - // Local specific - filename: item.filename, - path: item.path, - width: item.width, - height: item.height, - })); - setHasMore(items.length === 20); // Simple check - } else { - // Cloud API returns array + if (data.items) { + // V2 Paginated Response or Search Response + items = data.items.map((item: any) => { + // Map V2 Artifact to PromptImageRead + // OR Map SearchResultItem (which only has ID/Score) + const id = item.id; + // Check if we have metadata (Artifact) or just ID (Search) + const isFullArtifact = !!item.uri; + const metadata = item.metadata_json || {}; + + return { + id: id, + prompt: + metadata.prompt || + metadata.filename || + (isFullArtifact ? "Untitled" : "Similar Result"), + image_url: apiClient + ? apiClient.artifacts.getContentUrl(id) + : `${baseUrl}/artifacts/${id}/content`, + thumb_url: apiClient + ? apiClient.artifacts.getPreviewUrl(id, "thumbnail") + : `${baseUrl}/artifacts/${id}/preview?preview_type=thumbnail`, + created_at: item.created_at || new Date().toISOString(), + like_count: 0, + liked_by_me: false, + width: metadata.width || 0, + height: metadata.height || 0, + score: item.score, // specific to search + }; + }); + // Handle pagination check + if (data.total !== undefined) { + setHasMore(offset + items.length < data.total); + } else { + // Search plugin might return count + setHasMore(items.length === 20); + } + } else if (Array.isArray(data)) { + // Legacy Cloud API or simple array items = data; setHasMore(data.length === 20); } @@ -105,7 +212,7 @@ export function useEmbeddrImages({ pageRef.current = currentPage + 1; } } else { - throw new Error("Failed to fetch images"); + throw new Error(`Failed to fetch images: ${response.status}`); } } catch (error) { console.error("Error fetching images:", error); diff --git a/ui/hooks/useEmbeddrSettings.ts b/ui/hooks/useEmbeddrSettings.ts index b90ab61..4ed288b 100644 --- a/ui/hooks/useEmbeddrSettings.ts +++ b/ui/hooks/useEmbeddrSettings.ts @@ -34,7 +34,7 @@ export function useEmbeddrSettings({ // computed API base for requests const apiBase = useMemo(() => { const url = endpoint.replace(/\/$/, ""); // remove trailing slash - return `${url}/api/v1`; // append API path automatically + return `${url}/api/v2`; // append API path automatically }, [endpoint]); // Apply theme diff --git a/ui/main.tsx b/ui/main.tsx index b7697c9..23cd608 100644 --- a/ui/main.tsx +++ b/ui/main.tsx @@ -6,9 +6,11 @@ import { ExternalNavProvider } from "@embeddr/react-ui"; import { app } from "../../../scripts/app.js"; import EmbeddrPanel from "./components/panels/EmbeddrPanel.js"; import { GlobalDialog } from "./components/GlobalDialog"; -import "./nodes/EmbeddrLoadImage.js"; +import "./nodes/EmbeddrLoadArtifact.js"; import "./nodes/EmbeddrMergeIds.js"; import "./nodes/EmbeddrLoRAStack.js"; +import "./nodes/EmbeddrUploadArtifact.js"; +import "./nodes/EmbeddrFindCollection.js"; // @ts-ignore import "./globals.css"; diff --git a/ui/nodes/EmbeddrFindCollection.ts b/ui/nodes/EmbeddrFindCollection.ts new file mode 100644 index 0000000..b3c0936 --- /dev/null +++ b/ui/nodes/EmbeddrFindCollection.ts @@ -0,0 +1,52 @@ +// @ts-ignore +import { app } from "../../../scripts/app.js"; + +app.registerExtension({ + name: "Embeddr.FindCollection", + async beforeRegisterNodeDef(nodeType: any, nodeData: any, app: any) { + if (nodeData.name === "embeddr.FindCollection") { + // Add a button to open the dialog + const onNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + const r = onNodeCreated + ? onNodeCreated.apply(this, arguments) + : undefined; + + // Custom Widget for Displaying Collection Info + // We create a custom widget that just renders text, no input + const displayWidget = this.addWidget( + "text", + "Info", + "No Collection Selected", + () => {}, + { serialize: false } + ); + if (displayWidget.inputEl) { + displayWidget.inputEl.readOnly = true; + displayWidget.inputEl.style.opacity = "0.6"; + displayWidget.inputEl.style.fontSize = "10px"; + } + + // Add button widget + this.addWidget("button", "Search Existing", "search_collection", () => { + // Dispatch event to open dialog with collection mode + const event = new CustomEvent("embeddr-open-dialog", { + detail: { + nodeId: this.id, + mode: "collection", + }, + }); + window.dispatchEvent(event); + }); + + // Patch onConfigure to restore display if needed? + // For now, let's just use the widget value if it was saved? + // Ah, serialize: false means it won't be saved. + // Maybe we want to update it if the collection_id has a value on load? + // That would require fetching details which is complex here. + + return r; + }; + } + }, +}); diff --git a/ui/nodes/EmbeddrLoadArtifact.ts b/ui/nodes/EmbeddrLoadArtifact.ts new file mode 100644 index 0000000..c294dd0 --- /dev/null +++ b/ui/nodes/EmbeddrLoadArtifact.ts @@ -0,0 +1,51 @@ +// @ts-ignore +import { app } from "../../../scripts/app.js"; +import { registerNodeDragAndDrop } from "../utils/nodeDragAndDrop.js"; + +app.registerExtension({ + name: "Embeddr.LoadArtifact", + async beforeRegisterNodeDef(nodeType: any, nodeData: any, app: any) { + if (nodeData.name === "embeddr.LoadArtifact") { + // Add a button to open the dialog + const onNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + const r = onNodeCreated + ? onNodeCreated.apply(this, arguments) + : undefined; + + // Add button widget + this.addWidget("button", "Search Image", "search", () => { + // Dispatch event to open dialog + const event = new CustomEvent("embeddr-open-dialog", { + detail: { nodeId: this.id }, + }); + window.dispatchEvent(event); + }); + + return r; + }; + + // Use utility to register drag and drop + registerNodeDragAndDrop(nodeType, { + acceptTypes: ["embeddr/id"], + onDrop: (e: DragEvent, node: any) => { + const id = e.dataTransfer?.getData("embeddr/id"); + if (id) { + const widget = node.widgets?.find( + (w: any) => w.name === "artifact_id" + ); + if (widget) { + widget.value = id; + if (widget.callback) { + widget.callback(widget.value); + } + node.setDirtyCanvas(true, true); + return true; // handled + } + } + return false; + }, + }); + } + }, +}); diff --git a/ui/nodes/EmbeddrLoadImage.ts b/ui/nodes/EmbeddrLoadImage.ts deleted file mode 100644 index eb5f7d6..0000000 --- a/ui/nodes/EmbeddrLoadImage.ts +++ /dev/null @@ -1,35 +0,0 @@ -// @ts-ignore -import { app } from "../../../scripts/app.js"; - -app.registerExtension({ - name: "Embeddr.LoadImage", - async beforeRegisterNodeDef(nodeType: any, nodeData: any, app: any) { - if (nodeData.name === "embeddr.LoadImage") { - // Add a button to open the dialog - const onNodeCreated = nodeType.prototype.onNodeCreated; - nodeType.prototype.onNodeCreated = function () { - const r = onNodeCreated - ? onNodeCreated.apply(this, arguments) - : undefined; - - // Add button widget - this.addWidget("button", "Search Image", "search", () => { - // Dispatch event to open dialog - const event = new CustomEvent("embeddr-open-dialog", { - detail: { nodeId: this.id }, - }); - window.dispatchEvent(event); - }); - - return r; - }; - - // Handle image_id updates (if we want to preview) - // We can reuse the updatePreview logic from LoadImage if we want - // But here we might just want to show the ID or fetch the image. - // The python node returns the image, so ComfyUI will handle the preview if we return it in the output. - // But if we want to show it on the node before execution, we need to fetch it. - // For now, let's just handle the ID insertion. - } - }, -}); diff --git a/ui/nodes/EmbeddrUploadArtifact.ts b/ui/nodes/EmbeddrUploadArtifact.ts new file mode 100644 index 0000000..04c99c2 --- /dev/null +++ b/ui/nodes/EmbeddrUploadArtifact.ts @@ -0,0 +1,67 @@ +// @ts-ignore +import { app } from "../../../scripts/app.js"; +import { registerNodeDragAndDrop } from "../utils/nodeDragAndDrop.js"; + +app.registerExtension({ + name: "Embeddr.UploadArtifact", + async beforeRegisterNodeDef(nodeType: any, nodeData: any, app: any) { + if (nodeData.name === "embeddr.UploadArtifact") { + // Add a button to open the dialog + const onNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + const r = onNodeCreated + ? onNodeCreated.apply(this, arguments) + : undefined; + + // Add button widget + this.addWidget( + "button", + "Select Collection", + "select_collection", + () => { + // Dispatch event to open dialog with collection mode + const event = new CustomEvent("embeddr-open-dialog", { + detail: { + nodeId: this.id, + mode: "collection", + }, + }); + window.dispatchEvent(event); + } + ); + + return r; + }; + + // Use utility to register drag and drop for collections + registerNodeDragAndDrop(nodeType, { + acceptTypes: ["embeddr/collection_id"], + onDrop: (e: DragEvent, node: any) => { + const id = e.dataTransfer?.getData("embeddr/collection_id"); + if (id) { + const widget = node.widgets?.find( + (w: any) => w.name === "collection_ids" + ); + if (widget) { + // Append if shift key held? Or just replace? + // For now, let's assume replacement or simple csv append if non-empty + const current = widget.value ? widget.value.toString() : ""; + if (current && !current.includes(id)) { + widget.value = `${current},${id}`; + } else { + widget.value = id; + } + + if (widget.callback) { + widget.callback(widget.value); + } + node.setDirtyCanvas(true, true); + return true; // handled + } + } + return false; + }, + }); + } + }, +}); diff --git a/ui/utils/nodeDragAndDrop.ts b/ui/utils/nodeDragAndDrop.ts new file mode 100644 index 0000000..b0d5b23 --- /dev/null +++ b/ui/utils/nodeDragAndDrop.ts @@ -0,0 +1,65 @@ +// Emulate a React hook-like behavior but designed for LiteGraph nodes class/prototype usage +// Since LiteGraph nodes are not React components, this is more of a mixin/utility function + +export interface DragAndDropOptions { + // Return true if handled + onDragOver?: (e: DragEvent, node: any) => boolean; + onDrop?: (e: DragEvent, node: any) => boolean; + // If specific data types are required + acceptTypes?: string[]; +} + +/** + * Adds drag and drop file/data handling to a node prototype + */ +export function registerNodeDragAndDrop( + nodeType: any, + options: DragAndDropOptions +) { + const { onDragOver, onDrop, acceptTypes = [] } = options; + + const originalDragOver = nodeType.prototype.onDragOver; + nodeType.prototype.onDragOver = function (e: DragEvent) { + if (e.dataTransfer) { + // Check if we have any of the accepted types + const hasAcceptedType = + acceptTypes.length === 0 || + acceptTypes.some( + (type) => e.dataTransfer && e.dataTransfer.types.includes(type) + ); + + if (hasAcceptedType) { + if (onDragOver) { + // Allow custom handler to intervene + const handled = onDragOver.call(this, e, this); + if (handled) { + e.preventDefault(); + return true; + } + } + // Default behavior if type matches: allow drop + e.preventDefault(); + return true; + } + } + // Fallback to original + return originalDragOver ? originalDragOver.apply(this, arguments) : false; + }; + + const originalDragDrop = nodeType.prototype.onDragDrop; + nodeType.prototype.onDragDrop = function (e: DragEvent) { + if (e.dataTransfer) { + const hasAcceptedType = + acceptTypes.length === 0 || + acceptTypes.some( + (type) => e.dataTransfer && e.dataTransfer.types.includes(type) + ); + + if (hasAcceptedType && onDrop) { + const handled = onDrop.call(this, e, this); + if (handled) return true; + } + } + return originalDragDrop ? originalDragDrop.apply(this, arguments) : false; + }; +} From d788b0d48633ecf06321876d601e8f65dc37b056 Mon Sep 17 00:00:00 2001 From: Nynxz Date: Sun, 8 Feb 2026 21:45:57 +1000 Subject: [PATCH 3/9] checkpoint: pre remote-worker mode --- __init__.py | 165 +++- nodes/EmbeddrExtractArtifactInfo.py | 94 +++ nodes/EmbeddrFindCollection.py | 7 +- nodes/EmbeddrFindSimilarArtifacts.py | 16 +- nodes/EmbeddrFindSimilarText.py | 6 +- nodes/EmbeddrFindSimilarToArtifact.py | 99 +++ nodes/EmbeddrLoadArtifact.py | 77 +- nodes/EmbeddrLoadArtifacts.py | 92 ++- nodes/EmbeddrLoadImages.py | 20 +- nodes/EmbeddrLoadOptions.py | 0 nodes/EmbeddrLoadVideo.py | 3 +- nodes/EmbeddrMergeIDs.py | 6 +- nodes/EmbeddrSplitIDs.py | 33 + nodes/EmbeddrUploadArtifact.py | 120 +-- nodes/EmbeddrUploadOptions.py | 44 ++ nodes/EmbeddrUploadVideo.py | 134 ++-- nodes/types.py | 42 + nodes/utils/api.py | 6 +- nodes/utils/config.py | 19 +- nodes/utils/ids.py | 46 ++ package.json | 10 +- ui/components/GlobalDialog.tsx | 33 +- ui/components/ZenShell.tsx | 1026 +++++++++++++++++++++++++ ui/components/panels/EmbeddrPanel.tsx | 67 +- ui/components/panels/ImageDetails.tsx | 33 +- ui/components/tabs/ExploreTab.tsx | 4 + ui/components/tabs/SettingsForm.tsx | 79 +- ui/components/ui/AuthorizedImage.tsx | 47 ++ ui/components/ui/ImageGrid.tsx | 13 +- ui/globals.css | 395 ++++++++-- ui/hooks/useEmbeddrApi.ts | 38 +- ui/hooks/useEmbeddrCollections.ts | 44 +- ui/hooks/useEmbeddrImages.ts | 82 +- ui/hooks/useEmbeddrLibraries.ts | 28 +- ui/hooks/useEmbeddrSettings.ts | 156 +++- ui/hooks/useThemePacks.ts | 26 + ui/main.tsx | 147 +++- ui/nodes/EmbeddrMergeIds.ts | 6 +- ui/utils/proxyFetch.ts | 32 + ui/utils/themePacks.ts | 154 ++++ 40 files changed, 3084 insertions(+), 365 deletions(-) create mode 100644 nodes/EmbeddrExtractArtifactInfo.py create mode 100644 nodes/EmbeddrFindSimilarToArtifact.py create mode 100644 nodes/EmbeddrLoadOptions.py create mode 100644 nodes/EmbeddrSplitIDs.py create mode 100644 nodes/EmbeddrUploadOptions.py create mode 100644 nodes/types.py create mode 100644 nodes/utils/ids.py create mode 100644 ui/components/ZenShell.tsx create mode 100644 ui/components/ui/AuthorizedImage.tsx create mode 100644 ui/hooks/useThemePacks.ts create mode 100644 ui/utils/proxyFetch.ts create mode 100644 ui/utils/themePacks.ts diff --git a/__init__.py b/__init__.py index c669560..8ca5983 100644 --- a/__init__.py +++ b/__init__.py @@ -1,5 +1,6 @@ import os import json +import aiohttp from aiohttp import web from server import PromptServer from comfy_api.latest import ComfyExtension, io @@ -8,15 +9,21 @@ from .nodes.EmbeddrLoadArtifact import EmbeddrLoadArtifactNode from .nodes.EmbeddrLoadArtifacts import EmbeddrLoadArtifactsNode from .nodes.EmbeddrMergeIDs import EmbeddrMergeIDsNode +from .nodes.EmbeddrExtractArtifactInfo import EmbeddrExtractArtifactInfoNode +from .nodes.EmbeddrSplitIDs import EmbeddrSplitIDsNode from .nodes.EmbeddrFindSimilarArtifacts import EmbeddrFindSimilarArtifactsNode +from .nodes.EmbeddrFindSimilarToArtifact import EmbeddrFindSimilarToArtifactNode from .nodes.EmbeddrFindSimilarText import EmbeddrFindSimilarTextNode from .nodes.EmbeddrUploadVideo import EmbeddrUploadVideo from .nodes.EmbeddrLoadVideo import EmbeddrLoadVideoNode from .nodes.EmbeddrLoRAStack import EmbeddrLoRAStack from .nodes.EmbeddrFindCollection import EmbeddrFindCollectionNode +from .nodes.EmbeddrUploadOptions import UploadArtifactOptionsNode CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.json") +print("[Embeddr Extension] Loading... Proxy routes registered.") + def get_api_key(): if os.path.exists(CONFIG_PATH): @@ -29,6 +36,83 @@ def get_api_key(): return "" +def _load_config() -> dict: + if os.path.exists(CONFIG_PATH): + try: + with open(CONFIG_PATH, "r") as f: + return json.load(f) + except Exception: + return {} + return {} + + +def _normalize_base_url(url: str | None, default: str = "http://localhost:8003") -> str: + if not url: + return default + clean = str(url).strip().rstrip("/") + if clean.endswith("/api/v1"): + clean = clean[:-7] + elif clean.endswith("/api"): + clean = clean[:-4] + return clean.rstrip("/") + + +@PromptServer.instance.routes.get("/embeddr/proxy") +@PromptServer.instance.routes.post("/embeddr/proxy") +@PromptServer.instance.routes.put("/embeddr/proxy") +@PromptServer.instance.routes.delete("/embeddr/proxy") +async def proxy_request(request): + url = request.rel_url.query.get("url") + if not url: + return web.Response(status=400, text="Missing url param") + + # Load config to get API key + config = {} + if os.path.exists(CONFIG_PATH): + try: + with open(CONFIG_PATH, "r") as f: + config = json.load(f) + except: + pass + + api_key = config.get("api_key", "") + + headers = {} + if api_key: + headers["X-API-Key"] = api_key + + # Forward Content-Type if present + if "Content-Type" in request.headers: + headers["Content-Type"] = request.headers["Content-Type"] + + method = request.method + data = None + if request.can_read_body: + data = await request.read() + + async with aiohttp.ClientSession() as session: + try: + async with session.request(method, url, headers=headers, data=data) as resp: + # Create response with status code from upstream + response = web.StreamResponse( + status=resp.status, reason=resp.reason) + + # Forward relevant headers + for h in ['Content-Type', 'Content-Length', 'Content-Disposition']: + if h in resp.headers: + response.headers[h] = resp.headers[h] + + await response.prepare(request) + + async for chunk in resp.content.iter_chunked(1024*64): + await response.write(chunk) + + return response + except Exception as e: + print(f"[Embeddr Proxy Error] {e}") + return web.Response(status=500, text=str(e)) + + @PromptServer.instance.routes.post("/embeddr/config") async def save_config(request): try: @@ -66,35 +150,98 @@ async def save_config(request): @PromptServer.instance.routes.get("/embeddr/config") async def get_config(request): - config = {} - if os.path.exists(CONFIG_PATH): - try: - with open(CONFIG_PATH, "r") as f: - config = json.load(f) - except: - pass + config = _load_config() endpoint = config.get("endpoint", "http://localhost:8003") mode = config.get("mode", "local") grid_preview_contain = config.get("grid_preview_contain", False) + api_key = config.get("api_key", "") - # Return masked key for UI + # Return key for UI so frontend can make authorized requests return web.json_response({ "endpoint": endpoint, "mode": mode, - "grid_preview_contain": grid_preview_contain + "grid_preview_contain": grid_preview_contain, + "api_key": api_key }) +@PromptServer.instance.routes.get("/embeddr/health") +async def embeddr_health(request): + config = _load_config() + endpoint = config.get("endpoint") or "http://localhost:8003" + api_key = config.get("api_key") + base_url = _normalize_base_url(endpoint) + + if not api_key: + return web.json_response( + { + "ok": False, + "status": "missing_key", + "endpoint": base_url, + "note": "ComfyUI API key is not configured.", + }, + status=400, + ) + + url = f"{base_url}/api/v1/plugins/embeddr-comfyui/check" + headers = {"X-API-Key": api_key} + + async with aiohttp.ClientSession() as session: + try: + async with session.get(url, headers=headers, timeout=5) as resp: + payload = None + try: + payload = await resp.json() + except Exception: + payload = None + + if resp.status == 403: + return web.json_response( + { + "ok": False, + "status": "unauthorized", + "endpoint": base_url, + "note": "Embeddr rejected the API key.", + "embeddr": payload, + }, + status=403, + ) + + return web.json_response( + { + "ok": resp.status < 400, + "status": resp.status, + "endpoint": base_url, + "embeddr": payload, + }, + status=200 if resp.status < 400 else resp.status, + ) + except Exception as e: + return web.json_response( + { + "ok": False, + "status": "error", + "endpoint": base_url, + "note": str(e), + }, + status=500, + ) + + class EmbeddrComfyUIExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ EmbeddrFindSimilarArtifactsNode, + EmbeddrFindSimilarToArtifactNode, EmbeddrFindSimilarTextNode, EmbeddrLoadArtifactNode, EmbeddrLoadArtifactsNode, EmbeddrMergeIDsNode, + EmbeddrSplitIDsNode, + EmbeddrExtractArtifactInfoNode, EmbeddrUploadArtifactNode, + UploadArtifactOptionsNode, EmbeddrUploadVideo, EmbeddrLoadVideoNode, EmbeddrLoRAStack, diff --git a/nodes/EmbeddrExtractArtifactInfo.py b/nodes/EmbeddrExtractArtifactInfo.py new file mode 100644 index 0000000..8aceedf --- /dev/null +++ b/nodes/EmbeddrExtractArtifactInfo.py @@ -0,0 +1,94 @@ +from comfy_api.latest import io +from .types import EmbeddrArtifactInfo, EmbeddrArtifactID, EmbeddrArtifactIDObject + + +class EmbeddrExtractArtifactInfoNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="embeddr.ExtractArtifactInfo", + display_name="Embeddr Extract Artifact Info (V2)", + description="Extracts specific metadata fields from an Artifact Info object.", + category="Embeddr", + inputs=[ + EmbeddrArtifactInfo.Input( + "artifact_info", tooltip="Artifact Info object from Load Artifact"), + ], + outputs=[ + EmbeddrArtifactID.Output( + "parent_ids", tooltip="List of parent Artifact IDs"), + # EmbeddrArtifactID.Output("collection_ids", tooltip="List of collection IDs"), + io.String.Output("tags", tooltip="Comma-separated tags"), + io.String.Output( + "all_json", tooltip="Full JSON dump of the artifact metadata"), + ], + ) + + @classmethod + def execute(cls, artifact_info): + if not artifact_info: + return io.NodeOutput(EmbeddrArtifactIDObject(artifact_id=""), "", "{}") + + data = artifact_info.data + + # Parse Parents (from relations or metadata) + # Note: API usually returns `parents` relation list in `relations` key or similar depending on implementation + # For now, let's assume standard Artifact response structure. + # If the API doesn't return relations inline, we might need to rely on metadata_json['parent_ids'] + + parents = [] + + # Helper to safely dig into dicts + def get_parents_from_dict(d): + if not isinstance(d, dict): + return [] + res = [] + # Check direct key + if "parent_ids" in d: + val = d["parent_ids"] + if isinstance(val, list): + res.extend([str(v) for v in val]) + elif isinstance(val, str): + res.extend([v.strip() for v in val.split(",")]) + + # Check comfy_meta subkey + if "comfy_meta" in d: + res.extend(get_parents_from_dict(d["comfy_meta"])) + + return res + + # 1. Try Top Level + parents.extend(get_parents_from_dict(data)) + + # 2. Try metadata_json (if distinct from data root) + meta = data.get("metadata_json") + if meta and isinstance(meta, dict): + parents.extend(get_parents_from_dict(meta)) + + # Dedup + parents = list(set(parents)) + parent_str = ",".join(parents) + + # Tags (often in relations or separate 'tags' key) + tags = [] + if "tags" in data: + # If tags are objects + if isinstance(data["tags"], list): + for t in data["tags"]: + if isinstance(t, dict): + tags.append(t.get("label", "")) + else: + tags.append(str(t)) + elif "tags" in meta: + tags = meta["tags"] + + tags_str = ",".join(tags) + + import json + json_str = json.dumps(data, indent=2) + + return io.NodeOutput( + EmbeddrArtifactIDObject(artifact_id=parent_str), + tags_str, + json_str + ) diff --git a/nodes/EmbeddrFindCollection.py b/nodes/EmbeddrFindCollection.py index 871b2af..449512e 100644 --- a/nodes/EmbeddrFindCollection.py +++ b/nodes/EmbeddrFindCollection.py @@ -1,6 +1,6 @@ import requests from comfy_api.latest import io -from .utils import get_config +from .utils.config import get_config, get_auth_headers def Embeddr_Log(message: str): @@ -45,7 +45,8 @@ def execute(cls, collection_name, create_if_missing, collection_id=""): try: # 2. List Collections to Find by Name # Note: Removed limit=1000 to avoid potential 422 if API doesn't support it - resp = requests.get(f"{base_url}/api/v2/collections") + resp = requests.get( + f"{base_url}/api/v1/collections", headers=get_auth_headers()) # If 404, maybe endpoint is different. if resp.status_code == 404: # Fallback to V1? Or just fail. @@ -80,7 +81,7 @@ def execute(cls, collection_name, create_if_missing, collection_id=""): "type_name": "collection:mix", "uri": f"embeddr:///collections/{collection_name.lower().replace(' ', '_')}"} resp = requests.post( - f"{base_url}/api/v2/collections", json=payload) + f"{base_url}/api/v1/collections", json=payload, headers=get_auth_headers()) resp.raise_for_status() new_col = resp.json() Embeddr_Log( diff --git a/nodes/EmbeddrFindSimilarArtifacts.py b/nodes/EmbeddrFindSimilarArtifacts.py index bf5c8bc..7233fed 100644 --- a/nodes/EmbeddrFindSimilarArtifacts.py +++ b/nodes/EmbeddrFindSimilarArtifacts.py @@ -5,6 +5,7 @@ import io as pyio from comfy_api.latest import io, ui from .utils import get_config +from .utils.config import get_auth_headers class EmbeddrFindSimilarArtifactsNode(io.ComfyNode): @@ -18,6 +19,7 @@ def define_schema(cls) -> io.Schema: inputs=[ io.Image.Input("image"), io.Int.Input("limit", default=5, min=1, max=50), + io.String.Input("model_name", default="lotus"), ], outputs=[ io.Image.Output("images", is_output_list=True), @@ -26,14 +28,14 @@ def define_schema(cls) -> io.Schema: ) @classmethod - def execute(cls, image, limit): + def execute(cls, image, limit, model_name="lotus"): config = get_config() base_url = config.get("embeddr_url") or config.get( "endpoint") or "http://localhost:8003" base_url = base_url.rstrip("/") # Endpoint in Plugin - api_url = f"{base_url}/api/v2/plugins/embeddr-comfyui/find_similar" + api_url = f"{base_url}/api/v1/plugins/embeddr-comfyui/find_similar" # Prepare image (take first of batch for query) img_array = (image[0].cpu().numpy() * 255).astype(np.uint8) @@ -44,11 +46,12 @@ def execute(cls, image, limit): buf.seek(0) files = {"file": ("query.png", buf, "image/png")} - data = {"limit": limit} + data = {"limit": limit, "model_name": model_name} try: # Upload & Search - response = requests.post(api_url, files=files, data=data) + response = requests.post( + api_url, files=files, data=data, headers=get_auth_headers()) response.raise_for_status() results = response.json() items = results.get("items", []) # List of objects {id, uri, ...} @@ -63,10 +66,11 @@ def execute(cls, image, limit): for item in items: art_id = item.get("id") - content_url = f"{base_url}/api/v2/plugins/embeddr-comfyui/content/{art_id}" + content_url = f"{base_url}/api/v1/plugins/embeddr-comfyui/content/{art_id}" try: - img_resp = requests.get(content_url) + img_resp = requests.get( + content_url, headers=get_auth_headers()) if img_resp.status_code == 200: i = Image.open(pyio.BytesIO(img_resp.content)) i = i.convert("RGB") diff --git a/nodes/EmbeddrFindSimilarText.py b/nodes/EmbeddrFindSimilarText.py index 0769fae..e93daed 100644 --- a/nodes/EmbeddrFindSimilarText.py +++ b/nodes/EmbeddrFindSimilarText.py @@ -6,6 +6,7 @@ from comfy_api.latest import io, ui from .utils import get_config from .utils.api import get_libraries, get_collections +from .utils.config import get_auth_headers class EmbeddrFindSimilarTextNode(io.ComfyNode): @@ -60,7 +61,8 @@ def execute(cls, prompt, library="All", collection="All", limit=5): pass try: - response = requests.get(api_url, params=params) + response = requests.get( + api_url, params=params, headers=get_auth_headers()) response.raise_for_status() results = response.json() items = results.get("items", []) @@ -82,7 +84,7 @@ def execute(cls, prompt, library="All", collection="All", limit=5): # Fetch image file img_url = endpoint.rstrip( "/") + f"/api/v1/images/{item['id']}/file" - img_resp = requests.get(img_url) + img_resp = requests.get(img_url, headers=get_auth_headers()) if img_resp.status_code == 200: i = Image.open(pyio.BytesIO(img_resp.content)) i = ImageOps.exif_transpose(i) diff --git a/nodes/EmbeddrFindSimilarToArtifact.py b/nodes/EmbeddrFindSimilarToArtifact.py new file mode 100644 index 0000000..2eaceca --- /dev/null +++ b/nodes/EmbeddrFindSimilarToArtifact.py @@ -0,0 +1,99 @@ +import requests +import torch +import numpy as np +from PIL import Image +import io as pyio +from comfy_api.latest import io, ui +from .utils import get_config +from .utils.config import get_auth_headers +from .types import EmbeddrArtifactID + + +class EmbeddrFindSimilarToArtifactNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="embeddr.FindSimilarToArtifact", + display_name="Embeddr Find Similar To Artifact ID", + description="Finds similar artifacts using an existing Artifact ID.", + category="Embeddr", + inputs=[ + EmbeddrArtifactID.Input( + "artifact_id", tooltip="The Source Artifact ID"), + io.Int.Input("limit", default=5, min=1, max=50), + io.String.Input("model_name", default="lotus"), + ], + outputs=[ + io.Image.Output("images", is_output_list=True), + io.String.Output("artifact_ids", is_output_list=True), + ], + ) + + @classmethod + def execute(cls, artifact_id, limit, model_name="lotus"): + config = get_config() + base_url = config.get("embeddr_url") or config.get( + "endpoint") or "http://localhost:8003" + base_url = base_url.rstrip("/") + + # Endpoint in Plugin + api_url = f"{base_url}/api/v1/plugins/embeddr-comfyui/find_similar_to_id" + + # Resolve ID if object + aid = str(artifact_id) + if hasattr(artifact_id, "artifact_id"): + aid = str(artifact_id.artifact_id) + if isinstance(aid, list): + aid = aid[0] # Take first if list + + data = { + "artifact_id": aid, + "limit": limit, + "model_name": model_name + } + + try: + # Search + response = requests.post( + api_url, data=data, headers=get_auth_headers()) + response.raise_for_status() + results = response.json() + items = results.get("items", []) + + if not items: + empty = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + return io.NodeOutput([empty], ["-1"]) + + output_images = [] + output_ids = [] + + for item in items: + art_id = item.get("id") + content_url = f"{base_url}/api/v1/plugins/embeddr-comfyui/content/{art_id}" + + try: + img_resp = requests.get( + content_url, headers=get_auth_headers()) + if img_resp.status_code == 200: + i = Image.open(pyio.BytesIO(img_resp.content)) + i = i.convert("RGB") + i_np = np.array(i).astype(np.float32) / 255.0 + output_images.append(torch.from_numpy(i_np)[None,]) + output_ids.append(str(art_id)) + except Exception as e: + print( + f"Failed to fetch content for similar item {art_id}: {e}") + + if not output_images: + empty = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + return io.NodeOutput([empty], ["-1"]) + + return io.NodeOutput(output_images, output_ids) + + except Exception as e: + print(f"[Embeddr] FindSimilarToID Error: {e}") + empty = torch.zeros( + (1, 64, 64, 3), dtype=torch.float32, device="cpu") + return io.NodeOutput([empty], ["-1"]) diff --git a/nodes/EmbeddrLoadArtifact.py b/nodes/EmbeddrLoadArtifact.py index ce540a1..590563f 100644 --- a/nodes/EmbeddrLoadArtifact.py +++ b/nodes/EmbeddrLoadArtifact.py @@ -8,6 +8,8 @@ from io import BytesIO from comfy_api.latest import io, ui from .utils import get_config +from .utils.config import get_auth_headers +from .types import EmbeddrArtifactID, EmbeddrArtifactIDObject, EmbeddrArtifactInfo, EmbeddrArtifactInfoObject class EmbeddrLoadArtifactNode(io.ComfyNode): @@ -30,10 +32,10 @@ def _debug(cls, message: str, **fields): @classmethod def _resolve_artifact_url(cls, base_url: str, artifact_id: str): - resolve_url = f"{base_url}/api/v2/artifacts/{artifact_id}/resolve?variant=original&proxy=1" + resolve_url = f"{base_url}/api/v1/artifacts/{artifact_id}/resolve?variant=original&proxy=1" cls._debug("resolving_artifact", artifact_id=artifact_id, resolve_url=resolve_url) - res = requests.get(resolve_url) + res = requests.get(resolve_url, headers=get_auth_headers()) res.raise_for_status() data = res.json() url = data.get("url") @@ -41,13 +43,13 @@ def _resolve_artifact_url(cls, base_url: str, artifact_id: str): if url and url.startswith("/"): url = urljoin(base_url, url) - if url and "/api/v2/artifacts/" in url and "/content" in url and "proxy=" not in url: + if url and "/api/v1/artifacts/" in url and "/content" in url and "proxy=" not in url: url = f"{url}?proxy=1" base_netloc = urlparse(base_url).netloc url_netloc = urlparse(url).netloc if url else "" if url and base_netloc and url_netloc and url_netloc != base_netloc: - proxy_url = f"{base_url}/api/v2/artifacts/{artifact_id}/content?proxy=1" + proxy_url = f"{base_url}/api/v1/artifacts/{artifact_id}/content?proxy=1" cls._debug( "forcing_proxy_url", artifact_id=artifact_id, @@ -67,47 +69,82 @@ def define_schema(cls) -> io.Schema: description="Loads an image/artifact from Embeddr by UUID.", category="Embeddr", inputs=[ - io.String.Input("artifact_id", default="", - tooltip="UUID of the artifact to load"), + io.String.Input("manual_artifact_id", + tooltip="Manual UUID string (used if input not connected)", default=""), + EmbeddrArtifactID.Input("artifact_id", + tooltip="UUID of the artifact to load", optional=True), io.Boolean.Input("use_cache", default=True) ], outputs=[ io.Image.Output("image"), io.Mask.Output("mask"), - io.String.Output("artifact_id_out"), + EmbeddrArtifactID.Output("artifact_id_out"), + EmbeddrArtifactInfo.Output("artifact_info"), ], ) @classmethod - def execute(cls, artifact_id, use_cache): - if not artifact_id: + def execute(cls, use_cache, artifact_id=None, manual_artifact_id=None): + # Resolve artifact_id from connection or manual input + resolved_id = "" + + if artifact_id is not None: + if isinstance(artifact_id, EmbeddrArtifactIDObject): + resolved_id = artifact_id.artifact_id + else: + resolved_id = str(artifact_id) + + # Fallback to manual input if connection is empty + if not resolved_id and manual_artifact_id: + resolved_id = str(manual_artifact_id).strip() + + if not resolved_id: # Return empty black image if no ID empty_image = torch.zeros( (1, 64, 64, 3), dtype=torch.float32, device="cpu") empty_mask = torch.zeros( (1, 64, 64), dtype=torch.float32, device="cpu") - return io.NodeOutput(empty_image, empty_mask, "") + return io.NodeOutput( + empty_image, + empty_mask, + EmbeddrArtifactIDObject(artifact_id=""), + EmbeddrArtifactInfoObject(data={}) + ) - if use_cache and artifact_id in cls._cache: - image, mask = cls._cache[artifact_id] - return io.NodeOutput(image, mask, artifact_id) + if use_cache and resolved_id in cls._cache: + image, mask, info = cls._cache[resolved_id] + return io.NodeOutput(image, mask, EmbeddrArtifactIDObject(artifact_id=resolved_id), info) try: config = get_config() base_url = config.get("embeddr_url") or config.get( "endpoint") or "http://localhost:8003" base_url = base_url.rstrip("/") + + # 1. Fetch JSON metadata first + meta_url = f"{base_url}/api/v1/artifacts/{resolved_id}" + meta_res = requests.get(meta_url, headers=get_auth_headers()) + meta_res.raise_for_status() + artifact_data = meta_res.json() + info_obj = EmbeddrArtifactInfoObject(data=artifact_data) + + # 2. Resolve content endpoint, content_headers = cls._resolve_artifact_url( - base_url, artifact_id + base_url, resolved_id ) cls._debug( "requesting_artifact_content", - artifact_id=artifact_id, + artifact_id=resolved_id, endpoint=endpoint, ) - response = requests.get(endpoint, headers=content_headers) + # Merge auth headers with any resolved headers (e.g. S3 signed headers or similar) + final_headers = get_auth_headers() + if content_headers: + final_headers.update(content_headers) + + response = requests.get(endpoint, headers=final_headers) response.raise_for_status() cls._debug( "artifact_content_response", @@ -129,13 +166,13 @@ def execute(cls, artifact_id, use_cache): mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") if use_cache: - cls._cache[artifact_id] = (image, mask) + cls._cache[resolved_id] = (image, mask, info_obj) - return io.NodeOutput(image, mask, artifact_id, ui=ui.PreviewImage(image)) + return io.NodeOutput(image, mask, EmbeddrArtifactIDObject(artifact_id=resolved_id), info_obj, ui=ui.PreviewImage(image)) except Exception as e: - print(f"[Embeddr] Error loading artifact {artifact_id}: {e}") + print(f"[Embeddr] Error loading artifact {resolved_id}: {e}") cls._debug("artifact_load_failed", - artifact_id=artifact_id, error=str(e)) + artifact_id=resolved_id, error=str(e)) # Raise error to stop workflow if loading fails raise e diff --git a/nodes/EmbeddrLoadArtifacts.py b/nodes/EmbeddrLoadArtifacts.py index a3a589b..6c40dc8 100644 --- a/nodes/EmbeddrLoadArtifacts.py +++ b/nodes/EmbeddrLoadArtifacts.py @@ -9,6 +9,9 @@ from comfy_api.latest import io, ui from .utils import get_config from .utils.api import get_collections +from .utils.config import get_auth_headers +from .utils.ids import normalize_ids +from .types import EmbeddrArtifactID class EmbeddrLoadArtifactsNode(io.ComfyNode): @@ -31,10 +34,10 @@ def _debug(cls, message: str, **fields): @classmethod def _resolve_artifact_url(cls, base_url: str, artifact_id: str): - resolve_url = f"{base_url}/api/v2/artifacts/{artifact_id}/resolve?variant=original&proxy=1" + resolve_url = f"{base_url}/api/v1/artifacts/{artifact_id}/resolve?variant=original&proxy=1" cls._debug("resolving_artifact", artifact_id=artifact_id, resolve_url=resolve_url) - res = requests.get(resolve_url) + res = requests.get(resolve_url, headers=get_auth_headers()) res.raise_for_status() data = res.json() url = data.get("url") @@ -42,13 +45,13 @@ def _resolve_artifact_url(cls, base_url: str, artifact_id: str): if url and url.startswith("/"): url = urljoin(base_url, url) - if url and "/api/v2/artifacts/" in url and "/content" in url and "proxy=" not in url: + if url and "/api/v1/artifacts/" in url and "/content" in url and "proxy=" not in url: url = f"{url}?proxy=1" base_netloc = urlparse(base_url).netloc url_netloc = urlparse(url).netloc if url else "" if url and base_netloc and url_netloc and url_netloc != base_netloc: - proxy_url = f"{base_url}/api/v2/artifacts/{artifact_id}/content?proxy=1" + proxy_url = f"{base_url}/api/v1/artifacts/{artifact_id}/content?proxy=1" cls._debug( "forcing_proxy_url", artifact_id=artifact_id, @@ -70,6 +73,8 @@ def define_schema(cls) -> io.Schema: description="Loads generic artifacts (images) from Embeddr using V2 API.", category="Embeddr", inputs=[ + EmbeddrArtifactID.Input( + "artifact_ids", tooltip="Optional list of IDs to load (overrides collection)", optional=True), io.Combo.Input( "collection", options=collections, default="All"), io.Combo.Input("sort_by", options=[ @@ -85,8 +90,18 @@ def define_schema(cls) -> io.Schema: ) @classmethod - def execute(cls, collection, sort_by, limit, seed): - cache_key = (collection, sort_by, limit, seed) + def execute(cls, collection, sort_by, limit, seed, artifact_ids=None): + # Check for explicit IDs first + manual_ids = normalize_ids(artifact_ids) + + # We cache based on manual_ids if present, else collection params + if manual_ids: + # Sort for stability in cache key + manual_ids.sort() + cache_key = ("ids", tuple(manual_ids)) + else: + cache_key = (collection, sort_by, limit, seed) + if cache_key in cls._cache: return cls._cache[cache_key] @@ -96,31 +111,44 @@ def execute(cls, collection, sort_by, limit, seed): "endpoint") or "http://localhost:8003" base_url = base_url.rstrip("/") - # List Artifacts - api_url = f"{base_url}/api/v2/artifacts/" - params = { - "limit": limit, - "type_name": "image", - "offset": 0 - } + items = [] - if collection != "All": - try: - col_id = collection.split(":")[0].strip() - params["collection_id"] = col_id - except: - pass - - if sort_by == "random": - params["sort"] = "random" - params["seed"] = seed # Pass seed if API supports it - else: - params["sort"] = "new" + # 1. Load by IDs + if manual_ids: + # We can fetch them one by one or via a wrapper if API supports batch. + # Assuming singular fetch for now to be safe, or check if /api/v1/artifacts/ supports ids=... + # Iterate and construct items list + for aid in manual_ids: + items.append({"id": aid}) - response = requests.get(api_url, params=params) - response.raise_for_status() - data = response.json() - items = data.get("items", []) + # 2. Load via Search/Collection + else: + # List Artifacts + api_url = f"{base_url}/api/v1/artifacts/" + params = { + "limit": limit, + "type_name": "image", + "offset": 0 + } + + if collection != "All": + try: + col_id = collection.split(":")[0].strip() + params["collection_id"] = col_id + except: + pass + + if sort_by == "random": + params["sort"] = "random" + params["seed"] = seed # Pass seed if API supports it + else: + params["sort"] = "new" + + response = requests.get( + api_url, params=params, headers=get_auth_headers()) + response.raise_for_status() + data = response.json() + items = data.get("items", []) if not items: return cls._return_empty() @@ -135,7 +163,11 @@ def execute(cls, collection, sort_by, limit, seed): content_url, content_headers = cls._resolve_artifact_url( base_url, art_id) try: - c_resp = requests.get(content_url, headers=content_headers) + final_headers = get_auth_headers() + if content_headers: + final_headers.update(content_headers) + + c_resp = requests.get(content_url, headers=final_headers) c_resp.raise_for_status() img = Image.open(BytesIO(c_resp.content)) img = ImageOps.exif_transpose(img) diff --git a/nodes/EmbeddrLoadImages.py b/nodes/EmbeddrLoadImages.py index 739b4b7..61b559e 100644 --- a/nodes/EmbeddrLoadImages.py +++ b/nodes/EmbeddrLoadImages.py @@ -10,6 +10,7 @@ from comfy_api.latest import io, ui from .utils import get_config from .utils.api import get_collections, get_libraries +from .utils.config import get_auth_headers class EmbeddrLoadImagesNode(io.ComfyNode): @@ -32,10 +33,10 @@ def _debug(cls, message: str, **fields): @classmethod def _resolve_artifact_url(cls, base_url: str, artifact_id: str): - resolve_url = f"{base_url}/api/v2/artifacts/{artifact_id}/resolve?variant=original&proxy=1" + resolve_url = f"{base_url}/api/v1/artifacts/{artifact_id}/resolve?variant=original&proxy=1" cls._debug("resolving_artifact", artifact_id=artifact_id, resolve_url=resolve_url) - res = requests.get(resolve_url) + res = requests.get(resolve_url, headers=get_auth_headers()) res.raise_for_status() data = res.json() url = data.get("url") @@ -43,13 +44,13 @@ def _resolve_artifact_url(cls, base_url: str, artifact_id: str): if url and url.startswith("/"): url = urljoin(base_url, url) - if url and "/api/v2/artifacts/" in url and "/content" in url and "proxy=" not in url: + if url and "/api/v1/artifacts/" in url and "/content" in url and "proxy=" not in url: url = f"{url}?proxy=1" base_netloc = urlparse(base_url).netloc url_netloc = urlparse(url).netloc if url else "" if url and base_netloc and url_netloc and url_netloc != base_netloc: - proxy_url = f"{base_url}/api/v2/artifacts/{artifact_id}/content?proxy=1" + proxy_url = f"{base_url}/api/v1/artifacts/{artifact_id}/content?proxy=1" cls._debug( "forcing_proxy_url", artifact_id=artifact_id, @@ -105,7 +106,7 @@ def execute(cls, library, collection, sort_by, limit, seed): base_url = base_url.rstrip("/") # V2 API: List Artifacts - api_url = f"{base_url}/api/v2/artifacts/" + api_url = f"{base_url}/api/v1/artifacts/" params = { "limit": limit, @@ -137,7 +138,8 @@ def execute(cls, library, collection, sort_by, limit, seed): else: params["sort"] = "new" - response = requests.get(api_url, params=params) + response = requests.get( + api_url, params=params, headers=get_auth_headers()) response.raise_for_status() data = response.json() items = data.get("items", []) @@ -166,8 +168,12 @@ def execute(cls, library, collection, sort_by, limit, seed): ) try: + final_headers = get_auth_headers() + if content_headers: + final_headers.update(content_headers) + img_resp = requests.get( - content_url, headers=content_headers) + content_url, headers=final_headers) img_resp.raise_for_status() cls._debug( "artifact_content_response", diff --git a/nodes/EmbeddrLoadOptions.py b/nodes/EmbeddrLoadOptions.py new file mode 100644 index 0000000..e69de29 diff --git a/nodes/EmbeddrLoadVideo.py b/nodes/EmbeddrLoadVideo.py index 8c7c242..42fa5d3 100644 --- a/nodes/EmbeddrLoadVideo.py +++ b/nodes/EmbeddrLoadVideo.py @@ -7,6 +7,7 @@ import shutil from comfy_api.latest import io, ui from .utils import get_config +from .utils.config import get_auth_headers class EmbeddrLoadVideoNode(io.ComfyNode): @@ -70,7 +71,7 @@ def execute(cls, image_id, frame_load_cap, skip_first_frames, select_every_nth, # In production, we should cache the file path if it's the same ID. # Stream download - with requests.get(api_url, stream=True) as r: + with requests.get(api_url, stream=True, headers=get_auth_headers()) as r: r.raise_for_status() # Determine extension content_type = r.headers.get('content-type', '') diff --git a/nodes/EmbeddrMergeIDs.py b/nodes/EmbeddrMergeIDs.py index 15a8954..3dd9f82 100644 --- a/nodes/EmbeddrMergeIDs.py +++ b/nodes/EmbeddrMergeIDs.py @@ -1,4 +1,5 @@ from comfy_api.latest import io, ui +from .types import EmbeddrArtifactID class EmbeddrMergeIDsNode(io.ComfyNode): @@ -12,7 +13,8 @@ def define_schema(cls) -> io.Schema: category="Embeddr", inputs=[], outputs=[ - io.String.Output("ids"), + EmbeddrArtifactID.Output( + "artifact_ids", tooltip="List of merged Artifact IDs", display_name="artifact_ids"), ], ) @@ -21,7 +23,7 @@ def execute(cls, **kwargs): ids = [] # Iterate through all possible inputs for key, value in kwargs.items(): - if key.startswith("id") and value: + if key.startswith("artifact_") and value: ids.append(value) # Return as list of strings diff --git a/nodes/EmbeddrSplitIDs.py b/nodes/EmbeddrSplitIDs.py new file mode 100644 index 0000000..8cd8eb1 --- /dev/null +++ b/nodes/EmbeddrSplitIDs.py @@ -0,0 +1,33 @@ +from comfy_api.latest import io +from .types import EmbeddrArtifactID, EmbeddrArtifactIDObject +from .utils.ids import normalize_ids + + +class EmbeddrSplitIDsNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="embeddr.SplitIDs", + display_name="Embeddr Split IDs", + description="Splits a comma-separated list of Artifact IDs into individual items for batch processing.", + category="Embeddr", + inputs=[ + EmbeddrArtifactID.Input( + "artifact_ids", tooltip="Comma-separated IDs or list of IDs"), + ], + outputs=[ + EmbeddrArtifactID.Output( + "split_ids", tooltip="Individual IDs (List Execution)", is_output_list=True), + ], + ) + + @classmethod + def execute(cls, artifact_ids): + # Normalize to list of ID strings + final_ids = normalize_ids(artifact_ids) + + # Return a LIST of objects. ComfyUI will iterate this list for downstream nodes if they support it. + results = [EmbeddrArtifactIDObject( + artifact_id=fid) for fid in final_ids] + + return io.NodeOutput(results) diff --git a/nodes/EmbeddrUploadArtifact.py b/nodes/EmbeddrUploadArtifact.py index c0cfccb..e5510bd 100644 --- a/nodes/EmbeddrUploadArtifact.py +++ b/nodes/EmbeddrUploadArtifact.py @@ -3,39 +3,19 @@ from PIL import Image import io as pyio import json +from .utils.ids import normalize_ids +from .utils.config import get_embeddr_base_url, get_upload_mode, get_auth_headers from comfy_api.latest import io, ui -from .utils import get_embeddr_base_url, get_upload_mode +from .EmbeddrUploadOptions import EmbeddrUploadArtifactOptions, EmbeddrUploadArtifactOptionsObject +from .types import EmbeddrArtifactID, EmbeddrArtifactIDObject + +from .utils.ids import normalize_ids def Embeddr_Log(message: str): print(f"[Embeddr] {message}") -def normalize_list(value): - if not value: - return [] - - if isinstance(value, str): - items = [v.strip() for v in value.split(",")] - elif isinstance(value, list): - items = [str(v).strip() for v in value] - else: - raise TypeError(f"Unsupported type: {type(value)}") - - out = [] - seen = set() - for v in items: - if not v: - continue - lv = v.lower() - if lv in ("none", "null", "undefined"): - continue - if v not in seen: - seen.add(v) - out.append(v) - return out - - class EmbeddrUploadArtifactNode(io.ComfyNode): @classmethod def define_schema(cls) -> io.Schema: @@ -46,25 +26,24 @@ def define_schema(cls) -> io.Schema: is_output_node=True, inputs=[ io.Image.Input("image"), - io.String.Input("parent_ids", default="", optional=True, - tooltip="Comma separated parent artifact UUIDs"), - io.String.Input("collection_ids", default="", optional=True, - tooltip="Comma separated Collection UUIDs (for grouping)"), - io.String.Input("tags", default="generated,comfyui", - tooltip="Comma separated tags"), - io.Boolean.Input("trigger_automation", default=True, - tooltip="Trigger Auto-Analysis (Thumbnails, Embeddings, etc)"), + EmbeddrArtifactID.Input("parent_ids", optional=True, + tooltip="Parent artifact UUIDs"), + EmbeddrUploadArtifactOptions.Input("options", tooltip="Upload Artifact Options", + optional=True, display_name="Options"), ], outputs=[ - io.String.Output("artifact_ids"), + EmbeddrArtifactID.Output("artifact_ids"), ] ) @classmethod - def execute(cls, image, parent_ids, collection_ids, tags, trigger_automation): + def execute(cls, image, parent_ids: EmbeddrArtifactIDObject = None, options: EmbeddrUploadArtifactOptionsObject = None) -> io.NodeOutput: base_url = get_embeddr_base_url() upload_mode = get_upload_mode() - endpoint = f"{base_url}/api/v2/plugins/embeddr-comfyui/upload" + endpoint = f"{base_url}/api/v1/plugins/embeddr-comfyui/upload" + + # parent_ids could be handled by normalize_list directly + normalized_parent_ids = normalize_ids(parent_ids) results = [] @@ -72,17 +51,17 @@ def execute(cls, image, parent_ids, collection_ids, tags, trigger_automation): Embeddr_Log( "Upload disabled (EMBEDDR_UPLOAD_MODE). Skipping Embeddr upload." ) - return io.NodeOutput("", ui=ui.PreviewImage(image)) + return io.NodeOutput(EmbeddrArtifactIDObject(artifact_id=""), ui=ui.PreviewImage(image)) if upload_mode in {"best_effort", "auto"}: try: - health_url = f"{base_url}/api/v2/system/routes" + health_url = f"{base_url}/api/v1/system/routes" requests.get(health_url, timeout=2) except Exception as e: Embeddr_Log( f"Embeddr backend unavailable ({e}); skipping upload." ) - return io.NodeOutput("", ui=ui.PreviewImage(image)) + return io.NodeOutput(EmbeddrArtifactIDObject(artifact_id=""), ui=ui.PreviewImage(image)) # 'image' input is a batch tensor [B, H, W, C] for batch_idx, img_tensor in enumerate(image): @@ -95,25 +74,71 @@ def execute(cls, image, parent_ids, collection_ids, tags, trigger_automation): img_byte_arr = pyio.BytesIO() img.save(img_byte_arr, format='PNG') img_byte_arr.seek(0) + Embeddr_Log(f""" + Uploading artifact batch {batch_idx}... + Storage Provider: {options.storage_provider if options else 'default'} + Storage Path: {options.storage_path if options else 'default'} + Tags: {options.tags if options else 'none'} + Related Artifacts: {options.related_artifact_ids if options else 'none'} + Parent IDs: {normalized_parent_ids} + """) + + storage_provider = None + storage_path = None + if options: + if isinstance(options, dict): + storage_provider = options.get("storage_provider") + storage_path = options.get("storage_path") + else: + storage_provider = getattr( + options, "storage_provider", None) + storage_path = getattr(options, "storage_path", None) + + storage_provider = ( + str(storage_provider).strip().lower() + if storage_provider not in (None, "", "__default__") + else None + ) + storage_path = ( + str(storage_path).strip() + if storage_path not in (None, "", "__default__") + else None + ) # Prepare Metadata meta = { - "parent_ids": normalize_list(parent_ids), - "collection_ids": normalize_list(collection_ids), - "tags": normalize_list(tags), - "trigger_automation": trigger_automation, - "compute_embedding": trigger_automation, # Legacy Compat + "parent_ids": normalized_parent_ids, + "collection_ids": normalize_ids(options.related_artifact_ids) if options else [], + "tags": normalize_ids(options.tags) if options else [], + "trigger_automation": options.trigger_ingest if options else True, + "compute_embedding": options.trigger_ingest if options else True, # Legacy Compat "batch_index": batch_idx, "confirm": True } + if storage_provider: + meta["storage_provider"] = storage_provider + meta["storage_backend"] = storage_provider + if storage_path: + meta["storage_path"] = storage_path + + if storage_provider or storage_path: + Embeddr_Log( + f"Upload storage overrides: provider={storage_provider or 'default'} path={storage_path or 'default'}" + ) + # Prepare multipart upload files = {'file': (f'image_{batch_idx}.png', img_byte_arr, 'image/png')} data = {'metadata': json.dumps(meta)} # Post to Embeddr Core Plugin - response = requests.post(endpoint, files=files, data=data) + response = requests.post( + endpoint, + files=files, + data=data, + headers=get_auth_headers() + ) response.raise_for_status() res_json = response.json() @@ -126,6 +151,7 @@ def execute(cls, image, parent_ids, collection_ids, tags, trigger_automation): # We don't crash the whole node, but result might be partial result_str = ",".join(results) + result_obj = EmbeddrArtifactIDObject(artifact_id=result_str) # Return IDs and UI Preview - return io.NodeOutput(result_str, ui=ui.PreviewImage(image)) + return io.NodeOutput(result_obj, ui=ui.PreviewImage(image)) diff --git a/nodes/EmbeddrUploadOptions.py b/nodes/EmbeddrUploadOptions.py new file mode 100644 index 0000000..87fbd1f --- /dev/null +++ b/nodes/EmbeddrUploadOptions.py @@ -0,0 +1,44 @@ +from comfy_api.latest import io +from .types import EmbeddrUploadArtifactOptions, EmbeddrUploadArtifactOptionsObject + + +class UploadArtifactOptionsNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="embeddr.UploadArtifact.Options", + display_name="Upload Artifact Options", + category="Embeddr", + inputs=[ + io.String.Input( + "storage_provider", tooltip="Override default storage provider for the artifact", default=None), + io.String.Input( + "storage_path", tooltip="Override default storage path for the artifact", default=None), + + io.Boolean.Input("trigger_ingest", default=True, + tooltip="Trigger ingest process after upload", display_name="Trigger Ingest"), + io.String.Input("tags", default=[], + tooltip="Tags to associate with the uploaded artifact", display_name="Tags"), + io.String.Input("related_artifact_ids", default=[], + tooltip="IDs of related artifacts", display_name="Related Artifact IDs"), + # io.Combo.Input( + # "scale_mode", options=[e.value for e in ResizeModeEnum], default="contain", + # tooltip="Choose how images are scaled to fit the target size", display_name="Scale Mode"), + ], + outputs=[ + EmbeddrUploadArtifactOptions.Output( + "options", tooltip="Upload Artifact Options Object", display_name="options"), + ] + ) + + @classmethod + def execute(cls, storage_provider, storage_path, trigger_ingest, tags, related_artifact_ids): + options = EmbeddrUploadArtifactOptionsObject( + storage_provider=storage_provider, + storage_path=storage_path, + trigger_ingest=trigger_ingest, + tags=tags, + related_artifact_ids=related_artifact_ids, + ) + + return io.NodeOutput(options) diff --git a/nodes/EmbeddrUploadVideo.py b/nodes/EmbeddrUploadVideo.py index 1db9e9f..cea15a0 100644 --- a/nodes/EmbeddrUploadVideo.py +++ b/nodes/EmbeddrUploadVideo.py @@ -1,21 +1,18 @@ -import folder_paths -from .utils.api import get_libraries, get_collections import os import json import requests import tempfile -from comfy_api.latest import io, ui -from comfy_api.latest._io import ComfyNode -from .utils import get_embeddr_base_url +from comfy_api.latest import io +from .utils.ids import normalize_ids +from .utils.config import get_embeddr_base_url, get_upload_mode, get_auth_headers +from .EmbeddrUploadOptions import EmbeddrUploadArtifactOptions, EmbeddrUploadArtifactOptionsObject +from .types import EmbeddrArtifactID, EmbeddrArtifactIDObject class EmbeddrUploadVideo(io.ComfyNode): @classmethod def define_schema(cls) -> io.Schema: - libraries = ["Default"] + get_libraries() - collections = ["None"] + get_collections() - formats = ["mp4", "mkv", "webm", "mov", "avi"] codecs = ["h264", "h265", "vp9", "vp8", "prores"] @@ -27,21 +24,15 @@ def define_schema(cls) -> io.Schema: inputs=[ io.Video.Input("video", tooltip="The video to save."), io.String.Input("caption", optional=True), - io.String.Input("parent_ids", optional=True), - io.Combo.Input("library", options=libraries, - default="Default"), - io.Combo.Input( - "collection", options=collections, default="None"), - io.String.Input("tags", optional=True, default=""), + EmbeddrArtifactID.Input("parent_ids", optional=True, + tooltip="Parent artifact UUIDs"), + EmbeddrUploadArtifactOptions.Input("options", tooltip="Upload Artifact Options", + optional=True, display_name="Options"), io.Combo.Input("format", options=formats, default="mp4"), io.Combo.Input("codec", options=codecs, default="h264"), - io.Boolean.Input("allow_duplicates", default=False, - display_name="Allow Duplicates"), - io.Boolean.Input("save_backup", default=False, - display_name="Save to Comfy History"), ], outputs=[ - io.String.Output("embeddr_id"), + EmbeddrArtifactID.Output("artifact_ids"), ], ) @@ -50,11 +41,36 @@ def VALIDATE_INPUTS(cls, **kwargs): return True @classmethod - def execute(cls, video, caption=None, parent_ids=None, library="Default", collection="None", tags="", format="mp4", codec="h264", allow_duplicates=False, save_backup=False, **kwargs): + def execute( + cls, + video, + caption=None, + parent_ids: EmbeddrArtifactIDObject = None, + options: EmbeddrUploadArtifactOptionsObject = None, + format="mp4", + codec="h264", + **kwargs, + ): uploaded_ids = [] base_url = get_embeddr_base_url() - api_base_url = f"{base_url}/api/v1" - upload_url = f"{api_base_url}/images/upload" + upload_mode = get_upload_mode() + upload_url = f"{base_url}/api/v1/plugins/embeddr-comfyui/upload" + + normalized_parent_ids = normalize_ids(parent_ids) + + if upload_mode in {"skip", "disabled", "off", "none"}: + print( + "[Embeddr] Upload disabled (EMBEDDR_UPLOAD_MODE). Skipping Embeddr upload.") + return io.NodeOutput(EmbeddrArtifactIDObject(artifact_id="")) + + if upload_mode in {"best_effort", "auto"}: + try: + health_url = f"{base_url}/api/v1/system/routes" + requests.get(health_url, timeout=2) + except Exception as e: + print( + f"[Embeddr] Embeddr backend unavailable ({e}); skipping upload.") + return io.NodeOutput(EmbeddrArtifactIDObject(artifact_id="")) try: # Create temp file @@ -67,38 +83,59 @@ def execute(cls, video, caption=None, parent_ids=None, library="Default", collec # But often strings work or we can map them if we knew the library. video.save_to(temp_path, format=format, codec=codec) - # Upload + storage_provider = None + storage_path = None + if options: + if isinstance(options, dict): + storage_provider = options.get("storage_provider") + storage_path = options.get("storage_path") + else: + storage_provider = getattr( + options, "storage_provider", None) + storage_path = getattr(options, "storage_path", None) + + storage_provider = ( + str(storage_provider).strip().lower() + if storage_provider not in (None, "", "__default__") + else None + ) + storage_path = ( + str(storage_path).strip() + if storage_path not in (None, "", "__default__") + else None + ) + + meta = { + "parent_ids": normalized_parent_ids, + "collection_ids": normalize_ids(options.related_artifact_ids) if options else [], + "tags": normalize_ids(options.tags) if options else [], + "trigger_automation": options.trigger_ingest if options else True, + "compute_embedding": options.trigger_ingest if options else True, + "caption": caption or "", + "confirm": True, + } + + if storage_provider: + meta["storage_provider"] = storage_provider + meta["storage_backend"] = storage_provider + if storage_path: + meta["storage_path"] = storage_path + with open(temp_path, "rb") as f: files = {"file": (f"video.{format}", f, f"video/{format}")} - data = {"prompt": caption or ""} - if allow_duplicates: - data["force"] = "true" - if tags: - data["tags"] = tags - if library != "Default": - try: - data["library_id"] = int(library.split(":")[0]) - except: - pass - if parent_ids: - data["parent_ids"] = parent_ids - - response = requests.post(upload_url, files=files, data=data) + data = {"metadata": json.dumps(meta)} + + response = requests.post( + upload_url, + files=files, + data=data, + headers=get_auth_headers(), + ) response.raise_for_status() result = response.json() uploaded_id = result.get("id") uploaded_ids.append(str(uploaded_id)) - if collection and collection != "None" and uploaded_id: - try: - collection_id = int(collection.split(":")[0]) - requests.post( - f"{api_base_url}/collections/{collection_id}/items", - json={"image_id": uploaded_id} - ) - except Exception as e: - print(f"[Embeddr] Failed to add to collection: {e}") - except Exception as e: print(f"[Embeddr] Video upload failed: {e}") uploaded_ids.append("-1") @@ -106,4 +143,5 @@ def execute(cls, video, caption=None, parent_ids=None, library="Default", collec if 'temp_path' in locals() and os.path.exists(temp_path): os.remove(temp_path) - return io.NodeOutput(",".join(uploaded_ids)) + result_str = ",".join(uploaded_ids) + return io.NodeOutput(EmbeddrArtifactIDObject(artifact_id=result_str)) diff --git a/nodes/types.py b/nodes/types.py new file mode 100644 index 0000000..de31786 --- /dev/null +++ b/nodes/types.py @@ -0,0 +1,42 @@ +from comfy_api.latest._io import ComfyTypeIO, comfytype + + +## Base Artifact ID Type## +@comfytype(io_type="EMBEDDR_ARTIFACT_ID") +class EmbeddrArtifactID(ComfyTypeIO): + Type = str | list[str] + + +class EmbeddrArtifactIDObject: + def __init__(self, artifact_id: str | list[str]): + self.artifact_id: str | list[str] = artifact_id + + +## Upload Artifact Options Type ## +@comfytype(io_type="EMBEDDR_UPLOADARTIFACT_OPTS") +class EmbeddrUploadArtifactOptions(ComfyTypeIO): + Type = object + + +class EmbeddrUploadArtifactOptionsObject: + def __init__(self, storage_provider=None, storage_path=None, trigger_ingest=True, tags=None, related_artifact_ids=None): + self.storage_provider: str | None = storage_provider + self.storage_path: str | None = storage_path + self.trigger_ingest: bool = trigger_ingest + self.tags: list[str] | None = tags + self.related_artifact_ids: list[str] | None = related_artifact_ids + + +## Artifact Info Type ## +@comfytype(io_type="EMBEDDR_ARTIFACT_INFO") +class EmbeddrArtifactInfo(ComfyTypeIO): + Type = dict + + +class EmbeddrArtifactInfoObject: + def __init__(self, data: dict): + self.data = data + self.id = data.get("id") + self.uri = data.get("uri") + self.type_name = data.get("type_name") + self.metadata = data.get("metadata_json") or {} diff --git a/nodes/utils/api.py b/nodes/utils/api.py index fea25a1..18dcf19 100644 --- a/nodes/utils/api.py +++ b/nodes/utils/api.py @@ -1,12 +1,12 @@ import requests -from .config import get_embeddr_base_url +from .config import get_embeddr_base_url, get_auth_headers def get_libraries(): try: base_url = get_embeddr_base_url() api_url = f"{base_url}/api/v1/libraries" - response = requests.get(api_url) + response = requests.get(api_url, headers=get_auth_headers()) if response.status_code == 200: data = response.json() # Return list of names, but we might need IDs. @@ -23,7 +23,7 @@ def get_collections(): try: base_url = get_embeddr_base_url() api_url = f"{base_url}/api/v1/collections" - response = requests.get(api_url) + response = requests.get(api_url, headers=get_auth_headers()) if response.status_code == 200: data = response.json() return [f"{col['id']}: {col['name']}" for col in data] diff --git a/nodes/utils/config.py b/nodes/utils/config.py index 08000e1..f8acfdd 100644 --- a/nodes/utils/config.py +++ b/nodes/utils/config.py @@ -11,7 +11,7 @@ def _normalize_base_url(url: str | None, default: str) -> str: # Strip API suffixes to get the root base if clean.endswith("/api/v1"): clean = clean[:-7] - elif clean.endswith("/api/v2"): + elif clean.endswith("/api/v1"): clean = clean[:-7] elif clean.endswith("/api"): clean = clean[:-4] @@ -52,6 +52,23 @@ def get_embeddr_base_url(default: str = "http://localhost:8003") -> str: return _normalize_base_url(env_url or cfg_url, default) +def get_api_key() -> str | None: + cfg = get_config() + return ( + os.environ.get("EMBEDDR_API_KEY") + or os.environ.get("EMBEDDR_KEY") + or cfg.get("api_key") + or cfg.get("key") + ) + + +def get_auth_headers() -> dict[str, str]: + key = get_api_key() + if key: + return {"X-API-Key": key} + return {} + + def get_upload_mode(default: str = "require") -> str: cfg = get_config() env_mode = os.environ.get("EMBEDDR_UPLOAD_MODE") diff --git a/nodes/utils/ids.py b/nodes/utils/ids.py new file mode 100644 index 0000000..8583fab --- /dev/null +++ b/nodes/utils/ids.py @@ -0,0 +1,46 @@ +def normalize_ids(value) -> list[str]: + """ + Normalizes a variety of input types (single string, list of strings, + comma-separated strings, EmbeddrArtifactIDObjects) into a clean list of string IDs. + """ + if not value: + return [] + + # Helper to extract string from potential ID objects + def extract_val(x): + if hasattr(x, "artifact_id"): + # artifact_id can be str or list + val = x.artifact_id + if isinstance(val, list): + return ",".join(str(v) for v in val) + return str(val) + return str(x) + + # Initial collection of strings + raw_strings = [] + if isinstance(value, str): + raw_strings = [value] + elif isinstance(value, list): + raw_strings = [extract_val(v) for v in value] + elif hasattr(value, "artifact_id"): + raw_strings = [extract_val(value)] + else: + # Fallback for other single types + raw_strings = [str(value)] + + # Flatten by splitting commas and cleaning + out = [] + seen = set() + for s in raw_strings: + parts = s.split(',') + for p in parts: + clean_p = p.strip() + if not clean_p: + continue + if clean_p.lower() in ("none", "null", "undefined"): + continue + + if clean_p not in seen: + seen.add(clean_p) + out.append(clean_p) + return out diff --git a/package.json b/package.json index 53c88fa..9d6b820 100644 --- a/package.json +++ b/package.json @@ -7,9 +7,9 @@ }, "devDependencies": { "@comfyorg/comfyui-frontend-types": "^1.22.1", + "@tanstack/eslint-config": "^0.3.4", "@vitejs/plugin-vue": "^5.2.3", "tw-animate-css": "^1.4.0", - "@tanstack/eslint-config": "^0.3.4", "vite": "^6.3.5", "vite-plugin-vue-devtools": "^7.7.2" }, @@ -22,7 +22,9 @@ "@dnd-kit/sortable": "^10.0.0", "@dnd-kit/utilities": "^3.2.2", "@embeddr/api": "workspace:*", - "@embeddr/react-ui": "^0.1.4", + "@embeddr/react-ui": "workspace:*", + "@embeddr/zen-shell": "workspace:*", + "@fontsource-variable/jetbrains-mono": "^5.2.8", "@radix-ui/react-aspect-ratio": "^1.1.8", "@radix-ui/react-avatar": "^1.1.11", "@radix-ui/react-dialog": "^1.1.15", @@ -34,15 +36,17 @@ "@radix-ui/react-separator": "^1.1.8", "@radix-ui/react-slot": "^1.2.4", "@radix-ui/react-switch": "^1.2.6", - "shadcn": "^3.6.2", "@radix-ui/react-tabs": "^1.1.13", "@tailwindcss/vite": "^4.1.17", + "@tanstack/react-query": "^5.59.0", "@vitejs/plugin-react": "^5.1.1", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "lucide-react": "^0.556.0", "react": "^19.2.1", "react-dom": "^19.2.1", + "recharts": "^2.13.0", + "shadcn": "^3.6.2", "sonner": "^2.0.7", "tailwind-merge": "^3.4.0", "tailwindcss": "^4.1.17" diff --git a/ui/components/GlobalDialog.tsx b/ui/components/GlobalDialog.tsx index f1aebf4..babb311 100644 --- a/ui/components/GlobalDialog.tsx +++ b/ui/components/GlobalDialog.tsx @@ -5,6 +5,7 @@ import { DialogHeader, DialogTitle, } from "@embeddr/react-ui/components/dialog"; +import { useImageDialog } from "@embeddr/react-ui"; import { ExploreTab } from "./tabs/ExploreTab"; import { CollectionSelector } from "./selectors/CollectionSelector"; import { useEmbeddrApi } from "../hooks/useEmbeddrApi"; @@ -19,6 +20,13 @@ export function GlobalDialog() { const [mode, setMode] = useState("image"); const api = useEmbeddrApi(); + const { setApiKey: setDialogApiKey } = useImageDialog(); + + useEffect(() => { + if (setDialogApiKey && api.apiKey) { + setDialogApiKey(api.apiKey); + } + }, [api.apiKey, setDialogApiKey]); useEffect(() => { const handleOpen = (e: Event) => { @@ -46,7 +54,7 @@ export function GlobalDialog() { useEffect(() => { const applyTheme = () => { const portals = document.querySelectorAll( - "[data-radix-portal], [data-slot='dialog-content'], [data-slot='dialog-overlay'], [data-slot='select-content'], [data-slot='select-viewport'], [data-slot='popover-content'], [data-slot='dropdown-menu-content']" + "[data-radix-portal], [data-slot='dialog-content'], [data-slot='dialog-overlay'], [data-slot='select-content'], [data-slot='select-viewport'], [data-slot='popover-content'], [data-slot='dropdown-menu-content']", ); const isDark = api.theme === "dark"; @@ -55,6 +63,12 @@ export function GlobalDialog() { if (!portal.classList.contains("tailwind")) { portal.classList.add("tailwind"); } + if (!portal.classList.contains("font-sans")) { + portal.classList.add("font-sans"); + } + if (!portal.classList.contains("embeddr-theme-root")) { + portal.classList.add("embeddr-theme-root"); + } if (isDark) { portal.classList.add("dark"); } else { @@ -88,7 +102,7 @@ export function GlobalDialog() { // 1. Try "collection_ids" -> Set ID const idWidget = node.widgets?.find( - (w: any) => w.name === "collection_ids" + (w: any) => w.name === "collection_ids", ); if (idWidget) { idWidget.value = item.id.toString(); @@ -100,11 +114,11 @@ export function GlobalDialog() { // 2. Try "collection_id" -> Set ID (for FindCollection node V2) console.log( "[Embeddr] Debug: Listing node widgets:", - node.widgets?.map((w: any) => w.name) + node.widgets?.map((w: any) => w.name), ); const colIdWidget = node.widgets?.find( - (w: any) => w.name === "collection_id" + (w: any) => w.name === "collection_id", ); let idSet = false; @@ -118,7 +132,7 @@ export function GlobalDialog() { } else { console.warn( "[Embeddr] collection_id widget not found on node or invalid item.id!", - { widgetFound: !!colIdWidget, itemId: item.id } + { widgetFound: !!colIdWidget, itemId: item.id }, ); } @@ -135,13 +149,13 @@ export function GlobalDialog() { // If we successfully set the ID, clear the name widget to avoid confusion // and ensure the backend uses the ID. const nameWidget = node.widgets?.find( - (w: any) => w.name === "collection_name" + (w: any) => w.name === "collection_name", ); if (nameWidget) { if (idSet) { // Clear name to prioritize ID match and avoid ambiguity console.log( - "[Embeddr] Clearing collection_name widget to prioritize ID" + "[Embeddr] Clearing collection_name widget to prioritize ID", ); nameWidget.value = ""; } else { @@ -161,7 +175,7 @@ export function GlobalDialog() { // Handle Image Selection // Check for artifact_id (V2) or image_id (V1) const idWidget = node.widgets?.find( - (w: any) => w.name === "artifact_id" || w.name === "image_id" + (w: any) => w.name === "artifact_id" || w.name === "image_id", ); if (idWidget) { idWidget.value = item.id.toString(); @@ -177,7 +191,7 @@ export function GlobalDialog() { // Also update image_url if it exists (for preview/compatibility) const urlWidget = node.widgets?.find( - (w: any) => w.name === "image_url" + (w: any) => w.name === "image_url", ); if (urlWidget) { urlWidget.value = item.image_url; @@ -213,6 +227,7 @@ export function GlobalDialog() { {...api} activeTab="explore" onImageSelect={handleSelect} + apiKey={api.apiKey} /> )}
diff --git a/ui/components/ZenShell.tsx b/ui/components/ZenShell.tsx new file mode 100644 index 0000000..64172f1 --- /dev/null +++ b/ui/components/ZenShell.tsx @@ -0,0 +1,1026 @@ +import React, { useEffect, useState, useMemo, useCallback } from "react"; +import type { EmbeddrAPI } from "@embeddr/react-ui/types"; +import { + ZenPanelManagerCore, + useZenWindowStore, + usePluginRegistry, + loadExternalPlugins, + DynamicPluginComponent, + PluginErrorBoundary, + type PluginManifest, + type ZenWindowRendererProps, + EmbeddrProvider, + type PluginLoaderAdapter, +} from "@embeddr/zen-shell"; +import { useEmbeddrApi } from "../hooks/useEmbeddrApi"; +import { + Terminal, + Grid, + Maximize2, + Minimize2, + X, + Play, + RefreshCw, + LayoutTemplate, +} from "lucide-react"; +import { Button } from "@embeddr/react-ui/components/button"; +import { cn } from "@embeddr/react-ui"; + +// Helper to resolve component ID to plugin and component name +function resolveComponentId(fullId: string, plugins: Record) { + if (!fullId) return null; + + // Try longest prefix match for pluginId + let bestPid: string | null = null; + for (const pid of Object.keys(plugins)) { + // Exact match check (for simple ID cases) + if (fullId === pid) { + if (!bestPid || pid.length > bestPid.length) bestPid = pid; + } + // Prefix match + const prefix = pid + "-"; + if (fullId.startsWith(prefix)) { + if (!bestPid || pid.length > bestPid.length) bestPid = pid; + } + } + + if (!bestPid) return null; + + const localId = fullId.slice(bestPid.length + 1); + const plugin = plugins[bestPid]; + + // If localId is empty, it might be the main/default component + const compDef = plugin.components?.find( + (c: any) => + c.name === localId || + c.component === localId || + c.exportName === localId || + (!localId && c.name === "main"), + ); + + return { + pluginId: bestPid, + componentName: compDef?.exportName || compDef?.component || localId, + def: compDef, + }; +} + +const CustomWindowRenderer = React.memo((props: ZenWindowRendererProps) => { + const { id, windowState, isActive } = props; + const { plugins } = usePluginRegistry(); + const api = useEmbeddrApi(); + const updateWindow = useZenWindowStore((s) => s.updateWindow); + const baseApi = useMemo(() => createEmbeddrApiAdapter(api), [api]); + + const resolved = useMemo( + () => resolveComponentId(windowState.componentId, plugins), + [windowState.componentId, plugins], + ); + + const handleClose = useCallback(() => { + useZenWindowStore.getState().closeWindow(id); + }, [id]); + + const handleFocus = useCallback(() => { + useZenWindowStore.getState().bringToFront(id); + }, [id]); + + if (!resolved) { + return ( +
+ +
+
Component Not Found
+
+ ID: {windowState.componentId} +
+
Plugin might not be loaded yet.
+
+
+
+ ); + } + + return ( +
+ + {(() => { + const pluginApi = extendApiForPlugin(baseApi, resolved.pluginId); + const GlobalEmbeddrProvider = + (window as any).EmbeddrUI?.EmbeddrProvider || EmbeddrProvider; + return ( + updateWindow(id, { position: pos })} + onSizeChange={(next) => updateWindow(id, { size: next })} + > + + + + + + + ); + })()} + +
+ ); +}); + +function BasicWindowPanel({ + id, + title, + position, + size, + isActive, + zIndex, + onClose, + onMouseDown, + onPositionChange, + onSizeChange, + children, +}: { + id: string; + title: string; + position?: { x: number; y: number }; + size?: { width: number; height: number }; + isActive?: boolean; + zIndex?: number; + onClose: () => void; + onMouseDown?: (event: React.MouseEvent) => void; + onPositionChange?: (pos: { x: number; y: number }) => void; + onSizeChange?: (size: { width: number; height: number }) => void; + children: React.ReactNode; +}) { + const [pos, setPos] = useState(position || { x: 100, y: 100 }); + const [dimensions, setDimensions] = useState( + size || { width: 500, height: 400 }, + ); + const dragRef = React.useRef<{ + startX: number; + startY: number; + pointerId: number; + } | null>(null); + const resizeRef = React.useRef<{ + startX: number; + startY: number; + startW: number; + startH: number; + pointerId: number; + } | null>(null); + + useEffect(() => { + if (position) setPos(position); + }, [position?.x, position?.y]); + + useEffect(() => { + if (size) setDimensions(size); + }, [size?.width, size?.height]); + + const handlePointerDown = (event: React.PointerEvent) => { + event.stopPropagation(); + onMouseDown?.(event as unknown as React.MouseEvent); + dragRef.current = { + startX: event.clientX - pos.x, + startY: event.clientY - pos.y, + pointerId: event.pointerId, + }; + }; + + const handlePointerMove = useCallback((event: PointerEvent) => { + if (!dragRef.current || dragRef.current.pointerId !== event.pointerId) { + return; + } + const next = { + x: event.clientX - dragRef.current.startX, + y: event.clientY - dragRef.current.startY, + }; + setPos(next); + }, []); + + const handlePointerUp = useCallback( + (event: PointerEvent) => { + if (!dragRef.current || dragRef.current.pointerId !== event.pointerId) { + return; + } + dragRef.current = null; + onPositionChange?.(pos); + }, + [onPositionChange, pos], + ); + + const handleResizeDown = (event: React.PointerEvent) => { + event.stopPropagation(); + resizeRef.current = { + startX: event.clientX, + startY: event.clientY, + startW: dimensions.width, + startH: dimensions.height, + pointerId: event.pointerId, + }; + }; + + const handleResizeMove = useCallback( + (event: PointerEvent) => { + if ( + !resizeRef.current || + resizeRef.current.pointerId !== event.pointerId + ) { + return; + } + const next = { + width: Math.max( + 240, + resizeRef.current.startW + (event.clientX - resizeRef.current.startX), + ), + height: Math.max( + 180, + resizeRef.current.startH + (event.clientY - resizeRef.current.startY), + ), + }; + setDimensions(next); + }, + [dimensions.width, dimensions.height], + ); + + const handleResizeUp = useCallback( + (event: PointerEvent) => { + if ( + !resizeRef.current || + resizeRef.current.pointerId !== event.pointerId + ) { + return; + } + resizeRef.current = null; + onSizeChange?.(dimensions); + }, + [dimensions, onSizeChange], + ); + + useEffect(() => { + window.addEventListener("pointermove", handlePointerMove); + window.addEventListener("pointerup", handlePointerUp); + window.addEventListener("pointermove", handleResizeMove); + window.addEventListener("pointerup", handleResizeUp); + return () => { + window.removeEventListener("pointermove", handlePointerMove); + window.removeEventListener("pointerup", handlePointerUp); + window.removeEventListener("pointermove", handleResizeMove); + window.removeEventListener("pointerup", handleResizeUp); + }; + }, [handlePointerMove, handlePointerUp, handleResizeMove, handleResizeUp]); + + return ( +
+
+
{title}
+ +
+
{children}
+
+
+ ); +} + +class WindowErrorBoundary extends React.Component< + { + title: string; + onClose: () => void; + position?: { x: number; y: number }; + size?: { width: number; height: number }; + children: React.ReactNode; + }, + { error?: Error } +> { + state = { error: undefined } as { error?: Error }; + + static getDerivedStateFromError(error: Error) { + return { error }; + } + + componentDidCatch(error: Error) { + console.error("[ZenShell] Window renderer crashed", error); + } + + render() { + if (this.state.error) { + const pos = this.props.position || { x: 100, y: 100 }; + const size = this.props.size || { width: 500, height: 400 }; + return ( +
+
+
+ {this.props.title} +
+ +
+
+ {this.state.error.message} +
+
+ ); + } + + return this.props.children; + } +} + +type EmbeddrApiAdapterInput = ReturnType; + +function createEmbeddrApiAdapter(input: EmbeddrApiAdapterInput): EmbeddrAPI { + const backendUrl = (input.endpoint || "http://localhost:8003").replace( + /\/$/, + "", + ); + const apiBase = `${backendUrl}/api/v1`; + + const jsonRequest = async (path: string, init?: RequestInit) => { + const normalized = path.startsWith("/") ? path : `/${path}`; + const url = path.startsWith("http") ? path : `${apiBase}${normalized}`; + + const key = input.apiKey || ""; + const headers = new Headers(init?.headers || {}); + if (key && !headers.has("X-API-Key")) { + headers.set("X-API-Key", key); + } + const nextInit: RequestInit = { ...init, headers }; + + const addTrailingSlash = (inputUrl: string) => { + const [base, query] = inputUrl.split("?"); + if (base.endsWith("/")) return inputUrl; + return query ? `${base}/?${query}` : `${base}/`; + }; + + const run = async (target: string) => { + const res = await input.apiClient.fetch(target, nextInit); + if (res.ok) return res.json(); + return res; + }; + + const first = await run(url); + if (first instanceof Response) { + if (first.status === 404) { + const fallbackUrl = addTrailingSlash(url); + const second = await run(fallbackUrl); + if (second instanceof Response) { + const txt = await second.text().catch(() => ""); + throw new Error(txt || second.statusText || "Request failed"); + } + return second; + } + const txt = await first.text().catch(() => ""); + throw new Error(txt || first.statusText || "Request failed"); + } + return first; + }; + + const eventTarget = new EventTarget(); + + const api: EmbeddrAPI = { + stores: { + global: { + selectedImage: null, + selectImage: () => {}, + }, + generation: { + workflows: [], + selectedWorkflow: null, + generations: [], + isGenerating: false, + generate: async () => {}, + setWorkflowInput: () => {}, + selectWorkflow: () => {}, + }, + }, + ui: { + activePanelId: null, + isPanelActive: () => false, + }, + workspaces: { + getState: () => ({}), + subscribe: () => () => {}, + list: () => [], + getActiveId: () => null, + ensureDefault: () => {}, + create: () => "default", + save: () => {}, + saveActive: () => {}, + apply: () => {}, + rename: () => {}, + clone: () => null, + remove: () => {}, + setTemplate: () => {}, + }, + settings: { + get: (key: string, defaultValue?: T) => { + const raw = localStorage.getItem(key); + return (raw !== null ? (JSON.parse(raw) as T) : defaultValue) as T; + }, + set: (key: string, value: any) => { + localStorage.setItem(key, JSON.stringify(value)); + }, + getPlugin: (pluginId: string, key: string, defaultValue?: T) => { + const raw = localStorage.getItem(`${pluginId}:${key}`); + return (raw !== null ? (JSON.parse(raw) as T) : defaultValue) as T; + }, + setPlugin: (pluginId: string, key: string, value: any) => { + localStorage.setItem(`${pluginId}:${key}`, JSON.stringify(value)); + }, + }, + toast: { + success: (message: string) => console.log("[Embeddr]", message), + error: (message: string) => console.error("[Embeddr]", message), + info: (message: string) => console.info("[Embeddr]", message), + }, + utils: { + backendUrl, + getApiKey: () => input.apiKey || null, + uploadImage: async () => { + throw new Error("uploadImage not implemented in ComfyUI shell"); + }, + getPluginUrl: (path: string) => { + const cleanPath = path.startsWith("/") ? path.slice(1) : path; + return `${apiBase}/plugins/${cleanPath}`; + }, + }, + artifacts: { + list: (inputData) => { + const q = new URLSearchParams(); + if (inputData?.limit !== undefined) + q.append("limit", String(inputData.limit)); + if (inputData?.offset !== undefined) + q.append("offset", String(inputData.offset)); + if (inputData?.type_name) q.append("type_name", inputData.type_name); + if (inputData?.sort) q.append("sort", inputData.sort); + if (inputData?.ids?.length) q.append("ids", inputData.ids.join(",")); + const qs = q.toString(); + return jsonRequest(`/artifacts${qs ? `?${qs}` : ""}`); + }, + get: (id: string) => jsonRequest(`/artifacts/${id}`), + getContentUrl: (id: string) => `${apiBase}/artifacts/${id}/content`, + resolve: (inputData: any) => jsonRequest(`/artifacts/${inputData.id}`), + getPreviewUrl: ( + id: string, + type: "thumbnail" | "preview" = "thumbnail", + ) => `${apiBase}/artifacts/${id}/preview?preview_type=${type}`, + getEmbeddings: (id: string) => jsonRequest(`/artifacts/${id}/embeddings`), + getAnnotations: (id: string) => + jsonRequest(`/artifacts/${id}/annotations`), + getLineage: (id: string) => jsonRequest(`/artifacts/${id}/lineage`), + getRelations: (id: string) => jsonRequest(`/artifacts/${id}/relations`), + addRelation: (sourceId: string, inputData: any) => + jsonRequest(`/artifacts/${sourceId}/relations`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + target_id: inputData?.target_id, + relation_type: inputData?.relation_type || "contains", + metadata_json: inputData?.metadata_json || {}, + }), + }), + getSubgraph: (id: string, params: any) => { + const q = new URLSearchParams(); + if (params?.maxDepth !== undefined) + q.append("max_depth", String(params.maxDepth)); + if (params?.includeLineage !== undefined) + q.append("include_lineage", String(params.includeLineage)); + if (params?.includeRelations !== undefined) + q.append("include_relations", String(params.includeRelations)); + const qs = q.toString(); + return jsonRequest(`/artifacts/${id}/subgraph${qs ? `?${qs}` : ""}`); + }, + }, + collections: input.apiClient.collections, + library: input.apiClient.collections as any, + executions: { + create: (payload) => + jsonRequest(`/executions`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload), + }), + get: (executionId) => jsonRequest(`/executions/${executionId}`), + list: (payload) => { + const q = new URLSearchParams(); + if (payload?.plugin_name) q.append("plugin_name", payload.plugin_name); + if (payload?.status) q.append("status", payload.status); + if (payload?.limit) q.append("limit", String(payload.limit)); + if (payload?.offset) q.append("offset", String(payload.offset)); + return jsonRequest(`/executions?${q.toString()}`); + }, + }, + lotus: { + invoke: (capId: string, payload?: Record) => + jsonRequest(`/lotus/${capId}`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(payload ?? {}), + }), + query: (query: string, limit = 20) => + jsonRequest( + `/lotus/query?q=${encodeURIComponent(query)}&limit=${limit}`, + ), + list: (payload?: any) => { + const q = new URLSearchParams(); + if (payload?.kind) q.append("kind", payload.kind); + if (payload?.plugin) q.append("plugin", payload.plugin); + if (payload?.slot) q.append("slot", payload.slot); + if (payload?.limit) q.append("limit", String(payload.limit)); + if (payload?.offset) q.append("offset", String(payload.offset)); + return jsonRequest( + `/lotus/list${q.toString() ? `?${q.toString()}` : ""}`, + ); + }, + }, + client: { + plugins: { + call: (pluginId: string, path: string, method = "GET", body?: any) => { + const normalized = path.startsWith("/") ? path : `/${path}`; + return jsonRequest(`/plugins/${pluginId}${normalized}`, { + method, + headers: body ? { "Content-Type": "application/json" } : undefined, + body: body ? JSON.stringify(body) : undefined, + }); + }, + }, + } as any, + events: { + on: (event, listener) => { + const handler = (e: Event) => listener((e as CustomEvent).detail); + eventTarget.addEventListener(event, handler as EventListener); + return () => + eventTarget.removeEventListener(event, handler as EventListener); + }, + off: (event, listener) => { + eventTarget.removeEventListener(event, listener as EventListener); + }, + emit: (event, payload) => { + eventTarget.dispatchEvent(new CustomEvent(event, { detail: payload })); + }, + }, + comfy: { + getLoras: async () => ({ items: [], total: 0, page: 1, pages: 1 }), + getCheckpoints: async () => ({ items: [], total: 0, page: 1, pages: 1 }), + getEmbeddings: async () => ({ items: [], total: 0, page: 1, pages: 1 }), + getSamplers: async () => ({ samplers: [], schedulers: [] }), + }, + windows: { + open: (id: string, title: string, componentId: string, props?: any) => + useZenWindowStore.getState().openWindow({ + id, + title, + componentId, + props, + }), + spawn: (componentId: string, title: string, props?: any) => + useZenWindowStore.getState().spawnWindow(componentId, title, props), + register: () => {}, + getState: () => useZenWindowStore.getState(), + list: () => Object.values(useZenWindowStore.getState().windows), + }, + }; + + (api as any).__proxyFetch = input.apiClient.fetch; + return api; +} + +function extendApiForPlugin(api: EmbeddrAPI, pluginId: string): EmbeddrAPI { + if (!api?.utils) return api; + return { + ...api, + utils: { + ...api.utils, + getPluginUrl: (path: string) => { + const cleanPath = path.startsWith("/") ? path.slice(1) : path; + return `${api.utils.backendUrl}/api/v1/plugins/${pluginId}/${cleanPath}`; + }, + }, + plugin: { + fetch: (path: string, init?: RequestInit) => { + const proxyFetch = (api as any).__proxyFetch || fetch; + const key = (api as any).utils?.getApiKey?.() || ""; + const headers = new Headers(init?.headers || {}); + if (key && !headers.has("X-API-Key")) headers.set("X-API-Key", key); + const nextInit: RequestInit = { ...init, headers }; + if (path.startsWith("http")) { + return proxyFetch(path, nextInit); + } + const cleanPath = path.startsWith("/") ? path.slice(1) : path; + const url = `${api.utils.backendUrl}/api/v1/plugins/${pluginId}/${cleanPath}`; + return proxyFetch(url, nextInit); + }, + request: async (path: string, init?: RequestInit) => { + const proxyFetch = (api as any).__proxyFetch || fetch; + const key = (api as any).utils?.getApiKey?.() || ""; + const headers = new Headers(init?.headers || {}); + if (key && !headers.has("X-API-Key")) headers.set("X-API-Key", key); + const nextInit: RequestInit = { ...init, headers }; + const url = path.startsWith("http") + ? path + : `${api.utils.backendUrl}/api/v1/plugins/${pluginId}/${ + path.startsWith("/") ? path.slice(1) : path + }`; + + if ( + url.includes("/api/v1/lotus/") && + (nextInit.method || "GET").toUpperCase() === "POST" + ) { + const capId = url.split("/api/v1/lotus/")[1] || ""; + let payload: any = undefined; + if (typeof nextInit.body === "string") { + try { + payload = JSON.parse(nextInit.body); + } catch { + payload = {}; + } + } else if (nextInit.body && typeof nextInit.body === "object") { + payload = nextInit.body as any; + } + return api.lotus.invoke(capId, payload); + } + const res = await proxyFetch(url, nextInit); + if (!res.ok) { + const errorText = await res.text().catch(() => res.statusText); + throw new Error(errorText || `Request failed: ${res.status}`); + } + return res.json(); + }, + } as any, + } as EmbeddrAPI; +} + +export function ZenShell() { + console.log("[ZenShell] Rendering..."); + const [isOpen, setIsOpen] = useState(false); + const [minimized, setMinimized] = useState(false); + + let api; + try { + api = useEmbeddrApi(); + } catch (e) { + console.error("[ZenShell] Failed to get API context", e); + return null; + } + + const { plugins, knownPlugins } = usePluginRegistry(); + const spawnWindow = useZenWindowStore((s) => s.spawnWindow); + const embeddrApi = useMemo( + () => createEmbeddrApiAdapter(api), + [api.endpoint, api.apiKey, api.apiClient], + ); + + useEffect(() => { + console.log("[ZenShell] Mounted"); + return () => console.log("[ZenShell] Unmounted"); + }, []); + + const adapter = useMemo(() => { + console.log("[ZenShell] Recreating adapter"); + return { + list: async () => { + try { + const baseUrl = api.endpoint || "http://localhost:8003"; + // Reverting to /v2/plugins if that's what was working, or checking both? + // Let's assume /api/v1/plugins is correct based on recent changes, but we'll log it. + const targetUrl = baseUrl.endsWith("/") + ? `${baseUrl}api/v1/plugins` + : `${baseUrl}/api/v1/plugins`; + + console.log("[ZenShell] Fetching plugins from", targetUrl); + const res = await api.apiClient.fetch(targetUrl); + if (!res.ok) { + console.error( + "[ZenShell] Plugin fetch failed", + res.status, + res.statusText, + ); + return []; + } + const data = await res.json(); + console.log("[ZenShell] Fetched plugins:", data.length); + return data; + } catch (e) { + console.error("Failed to list plugins via proxy", e); + return []; + } + }, + resolveScriptUrl: (manifest) => { + const baseUrl = (api.endpoint || "http://localhost:8003").replace( + /\/$/, + "", + ); + const url = manifest.url; + + if (!url) return ""; + + if (url.startsWith("/")) { + const target = `${baseUrl}${url}`; + return `/embeddr/proxy?url=${encodeURIComponent(target)}`; + } + return url; + }, + resolveCssUrl: (manifest) => { + const baseUrl = (api.endpoint || "http://localhost:8003").replace( + /\/$/, + "", + ); + let url = manifest.url; + + if (!url) return null; + + if (url.startsWith("/")) { + if (url.endsWith(".js")) { + url = url.replace(".js", ".css"); + } + const target = `${baseUrl}${url}`; + return `/embeddr/proxy?url=${encodeURIComponent(target)}`; + } + + if (url.endsWith(".js")) return url.replace(".js", ".css"); + return null; + }, + }; + }, [api.endpoint, api.apiClient]); + + useEffect(() => { + const handleToggle = () => { + console.log("[ZenShell] Toggle event received"); + setIsOpen((prev) => !prev); + }; + const handleLaunch = (e: CustomEvent) => { + console.log("[ZenShell] Launch event received", e.detail); + setIsOpen(true); + if (e.detail && e.detail.componentId) { + const title = e.detail.title || e.detail.componentId; + spawnWindow(e.detail.componentId, title, e.detail.props); + } + }; + + const targets: Window[] = []; + const addTarget = (target?: Window | null) => { + if (!target) return; + if (!targets.includes(target)) targets.push(target); + }; + + addTarget(window); + try { + addTarget(window.parent); + } catch (e) { + console.warn("[ZenShell] Unable to access window.parent", e); + } + try { + addTarget(window.top); + } catch (e) { + console.warn("[ZenShell] Unable to access window.top", e); + } + + targets.forEach((target) => { + target.addEventListener("embeddr-toggle-shell", handleToggle); + target.addEventListener( + "embeddr-launch-window", + handleLaunch as EventListener, + ); + }); + + return () => { + targets.forEach((target) => { + target.removeEventListener("embeddr-toggle-shell", handleToggle); + target.removeEventListener( + "embeddr-launch-window", + handleLaunch as EventListener, + ); + }); + }; + }, [spawnWindow]); + + // Initial Load + useEffect(() => { + if (!api.configLoaded) return; + console.log("[ZenShell] Loading external plugins..."); + loadExternalPlugins({ adapter }); + }, [api.configLoaded, adapter]); + + // We always render the Manager (so windows exist), but maybe hide the Dock + return ( + <> +
+ {/* The Window Manager Layer */} +
+ + + +
+
+ + {/* The Shell Dock / Launcher */} + {isOpen && ( +
setMinimized(false) : undefined} + > + {minimized ? ( +
+ +
+ ) : ( + <> + {/* Header */} +
+
+ + Zen Launcher +
+
+ + +
+
+ + {/* Content */} +
+
+ Available Plugins ({knownPlugins.length}) +
+ {knownPlugins.length === 0 && ( +
+ + Loading plugins... +
+ )} + + {knownPlugins.length > 0 && ( +
+ {knownPlugins.map((pid, idx) => { + const p = plugins[pid]; + return ( +
+
+
+
+ {p.name || pid} +
+
+
+ {p.components?.map((c: any, cIdx: number) => { + const componentId = `${pid}-${ + c.exportName || + c.component || + c.name || + `comp-${cIdx}` + }`; + return ( + + ); + })} +
+
+ ); + })} +
+ )} + + +
+ + )} +
+ )} + + ); +} diff --git a/ui/components/panels/EmbeddrPanel.tsx b/ui/components/panels/EmbeddrPanel.tsx index fba7560..2e64125 100644 --- a/ui/components/panels/EmbeddrPanel.tsx +++ b/ui/components/panels/EmbeddrPanel.tsx @@ -6,8 +6,14 @@ import { TabsList, TabsTrigger, } from "@embeddr/react-ui/components/tabs"; -import { useExternalNav } from "@embeddr/react-ui"; -import { GlobeIcon, MessageCircleIcon, Search, Settings } from "lucide-react"; +import { useExternalNav, useImageDialog } from "@embeddr/react-ui"; +import { + GlobeIcon, + MessageCircleIcon, + Search, + Settings, + LayoutTemplate, +} from "lucide-react"; import { Button } from "@embeddr/react-ui/components/button"; import { useEmbeddrApi } from "@hooks/useEmbeddrApi"; import { SettingsForm } from "../tabs/SettingsForm"; @@ -34,10 +40,22 @@ export default function EmbeddrPanel() { setSimilarImageId, theme, setTheme, + themePackId, + setThemePackId, + apiBase, apiClient, + apiKey, + setApiKey, } = useEmbeddrApi(); const { openExternal } = useExternalNav(); + const { setApiKey: setDialogApiKey } = useImageDialog(); + + useEffect(() => { + if (setDialogApiKey && apiKey) { + setDialogApiKey(apiKey); + } + }, [apiKey, setDialogApiKey]); const [activeTab, setActiveTab] = useState("explore"); @@ -55,7 +73,36 @@ export default function EmbeddrPanel() { }, [viewMode, configLoaded, selectedLibrary, mode, similarImageId]); const handleSave = async () => { - await saveSettings(endpoint, mode, gridSize, gridPreviewContain); + await saveSettings(endpoint, mode, gridSize, gridPreviewContain, apiKey); + }; + + const dispatchShellEvent = ( + name: string, + detail?: Record, + ) => { + const targets: Window[] = []; + console.log("[EmbeddrPanel] Dispatching event", name, detail); + const addTarget = (target?: Window | null) => { + if (!target) return; + if (!targets.includes(target)) targets.push(target); + }; + addTarget(window); + try { + addTarget(window.parent); + } catch (e) { + console.warn("[EmbeddrPanel] Unable to access window.parent", e); + } + try { + addTarget(window.top); + } catch (e) { + console.warn("[EmbeddrPanel] Unable to access window.top", e); + } + + targets.forEach((target) => { + const event = new CustomEvent(name, { detail }); + console.log("[EmbeddrPanel] Sending event to target", target, event); + target.dispatchEvent(event); + }); }; return ( @@ -84,6 +131,14 @@ export default function EmbeddrPanel() { +
+
+ + setApiKey && setApiKey(e.target.value)} + /> +

+ Optional API Key for authentication (X-API-Key). +

+
+
@@ -139,6 +162,34 @@ export function SettingsForm({
+
+ + +

+ Theme packs apply token overrides and optional CSS. +

+
+ diff --git a/ui/components/ui/AuthorizedImage.tsx b/ui/components/ui/AuthorizedImage.tsx new file mode 100644 index 0000000..fcb30c3 --- /dev/null +++ b/ui/components/ui/AuthorizedImage.tsx @@ -0,0 +1,47 @@ +import React, { useEffect, useState, useRef } from "react"; +import { cn } from "@embeddr/react-ui"; + +interface AuthorizedImageProps extends React.ImgHTMLAttributes { + src: string; + fallbackSrc?: string; + authHeader?: Record; + apiKey?: string; +} + +export function AuthorizedImage({ + src, + fallbackSrc, + className, + alt, + authHeader, + apiKey, + ...props +}: AuthorizedImageProps) { + // Switch to Proxy strategy for robustness against CORS and Auth issues. + const isEmbeddrUrl = + src.includes("/api/v") && + (src.includes("/artifacts") || src.includes("/content")); + const shouldUseProxy = (!!apiKey || isEmbeddrUrl) && src.startsWith("http"); + const proxySrc = shouldUseProxy + ? `/embeddr/proxy?url=${encodeURIComponent(src)}` + : src; + + const [error, setError] = useState(false); + + if (error && fallbackSrc) { + return {alt}; + } + + return ( + {alt} { + if (fallbackSrc) e.currentTarget.src = fallbackSrc; + setError(true); + }} + {...props} + /> + ); +} diff --git a/ui/components/ui/ImageGrid.tsx b/ui/components/ui/ImageGrid.tsx index aba1373..e1622f3 100644 --- a/ui/components/ui/ImageGrid.tsx +++ b/ui/components/ui/ImageGrid.tsx @@ -3,6 +3,7 @@ import { ScrollArea } from "@embeddr/react-ui/components/scroll-area"; import { Eye, PenLineIcon } from "lucide-react"; import { Button } from "@embeddr/react-ui/components/button"; import { cn } from "@embeddr/react-ui"; +import { AuthorizedImage } from "./AuthorizedImage"; import type { PromptImageRead } from "@hooks/useEmbeddrApi"; @@ -18,6 +19,7 @@ interface ImageGridProps { gridSize?: number; imagePreviewContain?: boolean; scrollRef?: React.RefObject; + apiKey?: string; } export function ImageGrid({ @@ -32,6 +34,7 @@ export function ImageGrid({ gridSize = 3, imagePreviewContain = true, scrollRef, + apiKey, }: ImageGridProps) { const observerTarget = useRef(null); @@ -42,7 +45,7 @@ export function ImageGrid({ onLoadMore(); } }, - { threshold: 0, rootMargin: "200px" } + { threshold: 0, rootMargin: "200px" }, ); if (observerTarget.current) { @@ -74,7 +77,7 @@ export function ImageGrid({ {images.map((image) => (
{ @@ -85,14 +88,14 @@ export function ImageGrid({ draggable onDragStart={(e) => handleDragStart(e, image)} > - {image.prompt}
diff --git a/ui/globals.css b/ui/globals.css index 93b8ad8..0b828ce 100644 --- a/ui/globals.css +++ b/ui/globals.css @@ -1,35 +1,23 @@ @import "tailwindcss"; -@source "../node_modules/@embeddr/react-ui/dist/**/*"; +@source "../node_modules/@embeddr/react-ui/dist/**/*.{js,jsx,ts,tsx}"; +@source "../../embeddr-react-ui/dist/**/*.{js,jsx,ts,tsx}"; +@source "../../embeddr-react-ui/src/**/*.{js,jsx,ts,tsx}"; +@source "../../embeddr-plugins/plugins/**/*.{ts,tsx,js,jsx}"; +@source "../../embeddr-plugins/plugins/**/*.{css,scss}"; @import "tw-animate-css"; @import "shadcn/tailwind.css"; +@import "@fontsource-variable/jetbrains-mono"; -@custom-variant dark (&:is(.dark *)); - -@layer base { - /* Ensure dialogs are above ComfyUI UI (which can have high z-indices) */ - .tailwind[data-slot="dialog-overlay"], - [data-slot="dialog-overlay"] { - z-index: 9999 !important; - } - - .tailwind[data-slot="dialog-content"], - [data-slot="dialog-content"] { - z-index: 10000 !important; - } - - .tailwind[data-slot="select-content"], - [data-slot="select-content"] { - z-index: 10000 !important; - background-color: var(--popover); - color: var(--popover-foreground); - } - .tailwind[data-slot="select-viewport"], - [data-slot="select-viewport"] { - z-index: 10000 !important; - } -} +@custom-variant dark (&:is(.dark *, .midnight *, .forest *, .frappe *)); @theme inline { + --radius-sm: calc(var(--radius) - 4px); + --radius-md: calc(var(--radius) - 2px); + --radius-lg: var(--radius); + --radius-xl: calc(var(--radius) + 4px); + --radius-2xl: calc(var(--radius) + 8px); + --radius-3xl: calc(var(--radius) + 12px); + --radius-4xl: calc(var(--radius) + 16px); --color-background: var(--background); --color-foreground: var(--foreground); --color-card: var(--card); @@ -45,7 +33,6 @@ --color-accent: var(--accent); --color-accent-foreground: var(--accent-foreground); --color-destructive: var(--destructive); - --color-destructive-foreground: var(--destructive-foreground); --color-border: var(--border); --color-input: var(--input); --color-ring: var(--ring); @@ -54,10 +41,6 @@ --color-chart-3: var(--chart-3); --color-chart-4: var(--chart-4); --color-chart-5: var(--chart-5); - --radius-sm: calc(var(--radius) - 4px); - --radius-md: calc(var(--radius) - 2px); - --radius-lg: var(--radius); - --radius-xl: calc(var(--radius) + 4px); --color-sidebar: var(--sidebar); --color-sidebar-foreground: var(--sidebar-foreground); --color-sidebar-primary: var(--sidebar-primary); @@ -66,22 +49,11 @@ --color-sidebar-accent-foreground: var(--sidebar-accent-foreground); --color-sidebar-border: var(--sidebar-border); --color-sidebar-ring: var(--sidebar-ring); -} - -@layer base { - * { - @apply border-border outline-ring/50; - } - - body { - @apply bg-background text-foreground; - } + --font-sans: "JetBrains Mono Variable", monospace; } :root { --radius: 0.625rem; - --background: oklch(1 0 0); - --foreground: oklch(0.145 0 0); --card: oklch(1 0 0); --card-foreground: oklch(0.145 0 0); --popover: oklch(1 0 0); @@ -94,15 +66,15 @@ --muted-foreground: oklch(0.556 0 0); --accent: oklch(0.97 0 0); --accent-foreground: oklch(0.205 0 0); - --destructive: oklch(0.577 0.245 27.325); + --destructive: oklch(0.58 0.22 27); --border: oklch(0.922 0 0); --input: oklch(0.922 0 0); --ring: oklch(0.708 0 0); - --chart-1: oklch(0.646 0.222 41.116); - --chart-2: oklch(0.6 0.118 184.704); - --chart-3: oklch(0.398 0.07 227.392); - --chart-4: oklch(0.828 0.189 84.429); - --chart-5: oklch(0.769 0.188 70.08); + --chart-1: oklch(0.809 0.105 251.813); + --chart-2: oklch(0.623 0.214 259.815); + --chart-3: oklch(0.546 0.245 262.881); + --chart-4: oklch(0.488 0.243 264.376); + --chart-5: oklch(0.424 0.199 265.638); --sidebar: oklch(0.985 0 0); --sidebar-foreground: oklch(0.145 0 0); --sidebar-primary: oklch(0.205 0 0); @@ -111,32 +83,71 @@ --sidebar-accent-foreground: oklch(0.205 0 0); --sidebar-border: oklch(0.922 0 0); --sidebar-ring: oklch(0.708 0 0); + --background: oklch(1 0 0); + --foreground: oklch(0.145 0 0); } -.dark { +html.embeddr:not([data-embeddr-theme-pack]), +body.embeddr:not([data-embeddr-theme-pack]) { + --radius: 0; + --background: oklch(0.145 0 0); + --foreground: oklch(0.985 0 0); + --card: oklch(0.205 0 0); + --card-foreground: oklch(0.985 0 0); + --popover: oklch(0.205 0 0); + --popover-foreground: oklch(0.985 0 0); + --primary: oklch(0.87 0 0); + --primary-foreground: oklch(0.205 0 0); + --secondary: oklch(0.269 0 0); + --secondary-foreground: oklch(0.985 0 0); + --muted: oklch(0.269 0 0); + --muted-foreground: oklch(0.708 0 0); + --accent: oklch(0.371 0 0); + --accent-foreground: oklch(0.985 0 0); + --destructive: rgb(231, 130, 132); + --border: oklch(1 0 0 / 10%); + --input: oklch(1 0 0 / 15%); + --ring: oklch(0.556 0 0); + --chart-1: oklch(0.809 0.105 251.813); + --chart-2: oklch(0.623 0.214 259.815); + --chart-3: oklch(0.546 0.245 262.881); + --chart-4: oklch(0.488 0.243 264.376); + --chart-5: oklch(0.424 0.199 265.638); + --sidebar: oklch(0.205 0 0); + --sidebar-foreground: oklch(0.985 0 0); + --sidebar-primary: oklch(0.488 0.243 264.376); + --sidebar-primary-foreground: oklch(0.985 0 0); + --sidebar-accent: oklch(0.269 0 0); + --sidebar-accent-foreground: oklch(0.985 0 0); + --sidebar-border: oklch(1 0 0 / 10%); + --sidebar-ring: oklch(0.556 0 0); +} +html.dark:not([data-embeddr-theme-pack]), +body.dark:not([data-embeddr-theme-pack]) { + --radius: 0; --background: oklch(0.145 0 0); --foreground: oklch(0.985 0 0); --card: oklch(0.205 0 0); --card-foreground: oklch(0.985 0 0); --popover: oklch(0.205 0 0); --popover-foreground: oklch(0.985 0 0); - --primary: oklch(0.922 0 0); + --primary: oklch(0.87 0 0); --primary-foreground: oklch(0.205 0 0); --secondary: oklch(0.269 0 0); --secondary-foreground: oklch(0.985 0 0); --muted: oklch(0.269 0 0); --muted-foreground: oklch(0.708 0 0); - --accent: oklch(0.269 0 0); + --accent: oklch(0.371 0 0); --accent-foreground: oklch(0.985 0 0); - --destructive: oklch(0.704 0.191 22.216); + --destructive: rgb(231, 130, 132); --border: oklch(1 0 0 / 10%); --input: oklch(1 0 0 / 15%); --ring: oklch(0.556 0 0); - --chart-1: oklch(0.488 0.243 264.376); - --chart-2: oklch(0.696 0.17 162.48); - --chart-3: oklch(0.769 0.188 70.08); - --chart-4: oklch(0.627 0.265 303.9); - --chart-5: oklch(0.645 0.246 16.439); + --chart-1: oklch(0.809 0.105 251.813); + --chart-2: oklch(0.623 0.214 259.815); + --chart-3: oklch(0.546 0.245 262.881); + --chart-4: oklch(0.488 0.243 264.376); + --chart-5: oklch(0.424 0.199 265.638); --sidebar: oklch(0.205 0 0); --sidebar-foreground: oklch(0.985 0 0); --sidebar-primary: oklch(0.488 0.243 264.376); @@ -146,3 +157,271 @@ --sidebar-border: oklch(1 0 0 / 10%); --sidebar-ring: oklch(0.556 0 0); } + +html.midnight:not([data-embeddr-theme-pack]), +body.midnight:not([data-embeddr-theme-pack]) { + --radius: 0; + --background: oklch(0.12 0.05 270); + --foreground: oklch(0.9 0.02 270); + --card: oklch(0.15 0.05 270); + --card-foreground: oklch(0.9 0.02 270); + --popover: oklch(0.15 0.05 270); + --popover-foreground: oklch(0.9 0.02 270); + --primary: oklch(0.7 0.15 280); + --primary-foreground: oklch(0.12 0.05 270); + --secondary: oklch(0.2 0.05 270); + --secondary-foreground: oklch(0.9 0.02 270); + --muted: oklch(0.2 0.05 270); + --muted-foreground: oklch(0.6 0.05 270); + --accent: oklch(0.25 0.05 270); + --accent-foreground: oklch(0.9 0.02 270); + --destructive: oklch(0.6 0.2 30); + --border: oklch(0.3 0.05 270); + --input: oklch(0.3 0.05 270); + --ring: oklch(0.7 0.15 280); + --sidebar: oklch(0.12 0.05 270); + --sidebar-foreground: oklch(0.9 0.02 270); + --sidebar-primary: oklch(0.7 0.15 280); + --sidebar-primary-foreground: oklch(0.12 0.05 270); + --sidebar-accent: oklch(0.2 0.05 270); + --sidebar-accent-foreground: oklch(0.9 0.02 270); + --sidebar-border: oklch(0.3 0.05 270); + --sidebar-ring: oklch(0.7 0.15 280); +} + +html.latte:not([data-embeddr-theme-pack]), +body.latte:not([data-embeddr-theme-pack]) { + --radius: 0; + --background: oklch(0.95 0.02 50); + --foreground: oklch(0.25 0.05 40); + --card: oklch(0.98 0.01 50); + --card-foreground: oklch(0.25 0.05 40); + --popover: oklch(0.98 0.01 50); + --popover-foreground: oklch(0.25 0.05 40); + --primary: oklch(0.6 0.15 35); + --primary-foreground: oklch(0.98 0.01 50); + --secondary: oklch(0.9 0.03 50); + --secondary-foreground: oklch(0.25 0.05 40); + --muted: oklch(0.9 0.03 50); + --muted-foreground: oklch(0.5 0.05 40); + --accent: oklch(0.9 0.03 50); + --accent-foreground: oklch(0.25 0.05 40); + --destructive: oklch(0.6 0.18 20); + --border: oklch(0.85 0.03 50); + --input: oklch(0.85 0.03 50); + --ring: oklch(0.6 0.15 35); + --sidebar: oklch(0.95 0.02 50); + --sidebar-foreground: oklch(0.25 0.05 40); + --sidebar-primary: oklch(0.6 0.15 35); + --sidebar-primary-foreground: oklch(0.98 0.01 50); + --sidebar-accent: oklch(0.9 0.03 50); + --sidebar-accent-foreground: oklch(0.25 0.05 40); + --sidebar-border: oklch(0.85 0.03 50); + --sidebar-ring: oklch(0.6 0.15 35); +} + +html.forest:not([data-embeddr-theme-pack]), +body.forest:not([data-embeddr-theme-pack]) { + --radius: 0; + --background: oklch(0.12 0.04 145); + --foreground: oklch(0.85 0.05 140); + --card: oklch(0.15 0.04 145); + --card-foreground: oklch(0.85 0.05 140); + --popover: oklch(0.15 0.04 145); + --popover-foreground: oklch(0.85 0.05 140); + --primary: oklch(0.65 0.15 140); + --primary-foreground: oklch(0.1 0.05 140); + --secondary: oklch(0.2 0.05 140); + --secondary-foreground: oklch(0.85 0.05 140); + --muted: oklch(0.2 0.05 140); + --muted-foreground: oklch(0.6 0.05 140); + --accent: oklch(0.25 0.05 140); + --accent-foreground: oklch(0.85 0.05 140); + --destructive: oklch(0.5 0.15 20); + --border: oklch(0.3 0.05 140); + --input: oklch(0.3 0.05 140); + --ring: oklch(0.65 0.15 140); + --sidebar: oklch(0.12 0.04 145); + --sidebar-foreground: oklch(0.85 0.05 140); + --sidebar-primary: oklch(0.65 0.15 140); + --sidebar-primary-foreground: oklch(0.1 0.05 140); + --sidebar-accent: oklch(0.2 0.05 140); + --sidebar-accent-foreground: oklch(0.85 0.05 140); + --sidebar-border: oklch(0.3 0.05 140); + --sidebar-ring: oklch(0.65 0.15 140); +} + +html.frappe:not([data-embeddr-theme-pack]), +body.frappe:not([data-embeddr-theme-pack]) { + --radius: 0.5rem; + /* Base */ + --background: #303446; + /* Text */ + --foreground: #c6d0f5; + + /* Surface 0 */ + --card: #303446; + --card-foreground: #c6d0f5; + + --popover: #303446; + --popover-foreground: #c6d0f5; + + /* Pink */ + --primary: #f4b8e4 !important; + --primary-foreground: #232634; + + /* Surface 0 */ + --secondary: #414559; + --secondary-foreground: #c6d0f5; + + /* Surface 0 */ + --muted: #414559; + /* Overlay 0 */ + --muted-foreground: #737994; + + /* Surface 1 */ + --accent: #51576d; + --accent-foreground: #c6d0f5; + + /* Red */ + --destructive: #e78284; + --destructive-foreground: #232634; + + /* Surface 1 */ + --border: #51576d; + --input: #51576d; + /* Pink */ + --ring: #f4b8e4; + + --chart-1: #e78284; + --chart-2: #ef9f76; + --chart-3: #e5c890; + --chart-4: #a6d189; + --chart-5: #85c1dc; + + /* Mantle */ + --sidebar: #292c3c; + --sidebar-foreground: #c6d0f5; + --sidebar-primary: #f4b8e4; + --sidebar-primary-foreground: #303446; + --sidebar-accent: #303446; + --sidebar-accent-foreground: #c6d0f5; + --sidebar-border: #414559; + --sidebar-ring: #f2d5cf; +} + +@layer base { + * { + @apply border-border outline-ring/50; + @apply border-border outline-ring/50; + } + body { + @apply bg-background text-foreground; + @apply font-sans bg-background text-foreground; + } + html { + @apply font-sans; + } + + /* Make the app's main content the only scrolling area. Add a nicer scrollbar */ + .app-scroll { + border-radius: 0; + } + + /* Webkit-based browsers */ + .app-scroll::-webkit-scrollbar { + border-radius: 0; + width: 12px; + } + .app-scroll::-webkit-scrollbar-track { + border-radius: 0; + background: transparent; + } + .app-scroll::-webkit-scrollbar-thumb { + background-color: rgba(255, 100, 100, 0.35); + border-radius: 0; + border: 3px solid transparent; + background-clip: padding-box; + } + + /* Firefox */ + .app-scroll { + scrollbar-width: thin; + scrollbar-color: rgba(100, 100, 100, 0.35) transparent; + } +} + +@layer utilities { + .aspect-3\/4 { + aspect-ratio: 3 / 4; + } +} + +@layer base { + * { + @apply border-border outline-ring/50; + @apply border-border outline-ring/50; + } + body { + @apply bg-background text-foreground; + @apply font-sans bg-background text-foreground; + } + html { + @apply font-sans; + } + + /* Make the app's main content the only scrolling area. Add a nicer scrollbar */ + .app-scroll { + } + + /* Webkit-based browsers */ + .app-scroll::-webkit-scrollbar { + width: 12px; + } + .app-scroll::-webkit-scrollbar-track { + background: transparent; + } + .app-scroll::-webkit-scrollbar-thumb { + background-color: rgba(255, 100, 100, 0.35); + border: 3px solid transparent; + background-clip: padding-box; + } + + /* Firefox */ + .app-scroll { + scrollbar-width: thin; + scrollbar-color: rgba(100, 100, 100, 0.35) transparent; + } +} + +@layer utilities { + .aspect-3\/4 { + aspect-ratio: 3 / 4; + } +} + +@layer base { + /* Ensure dialogs are above ComfyUI UI (which can have high z-indices) */ + .tailwind[data-slot="dialog-overlay"], + [data-slot="dialog-overlay"] { + z-index: 18000 !important; + } + + .tailwind[data-slot="dialog-content"], + [data-slot="dialog-content"] { + z-index: 19000 !important; + } + + .tailwind[data-slot="select-content"], + [data-slot="select-content"], + .tailwind[data-slot="select-viewport"], + [data-slot="select-viewport"], + .tailwind[data-slot="popover-content"], + [data-slot="popover-content"], + .tailwind[data-slot="dropdown-menu-content"], + [data-slot="dropdown-menu-content"] { + z-index: 20000 !important; + background-color: var(--popover); + color: var(--popover-foreground); + } +} diff --git a/ui/hooks/useEmbeddrApi.ts b/ui/hooks/useEmbeddrApi.ts index 39c3fca..3ed92f5 100644 --- a/ui/hooks/useEmbeddrApi.ts +++ b/ui/hooks/useEmbeddrApi.ts @@ -7,6 +7,7 @@ import { useEmbeddrCollections, type Collection, } from "./useEmbeddrCollections"; +import { proxyFetch } from "../utils/proxyFetch"; import type { ApiMode, LibraryPath, PromptImageRead } from "@types"; export type { PromptImageRead, LibraryPath, ApiMode, Collection }; @@ -21,27 +22,56 @@ export function useEmbeddrApi({ const settings = useEmbeddrSettings({ baseUrl }); const apiClient = useMemo( - () => new EmbeddrApiClient({ baseUrl: settings.apiBase }), - [settings.apiBase], + () => + new EmbeddrApiClient({ + baseUrl: settings.apiBase, + // In ComfyUI, we don't need to pass headers client-side because the proxy handles it. + // But we keep this for local dev or direct connection scenarios if proxy logic wasn't used. + // However, since we are injecting proxyFetch, headers passed here are forwarded to fetch, + // which sends them to the proxy endpoint. The proxy endpoint usually EXPECTS raw structure, + // but here we are sending metadata. + // Actually, our proxy implementation forwards the body and method, but headers handling + // in __init__.py is explicit about keys. It adds X-API-Key itself. + // So we can strictly rely on proxyFetch to do the lifting. + fetch: proxyFetch, + headers: () => { + const headers: Record = {}; + // We can attach the key here, but the proxy will OVERWRITE/Attach it from server-side config. + // It's safer to rely on server-side config for the key to avoid exposing it in browser network tab + // if we can. But providing it here doesn't hurt. + const key = + settings.apiKey || localStorage.getItem("embeddr_api_key"); + if (key) { + // For now we do NOT attach it here to ensure we test the proxy works from server config + // headers["X-API-Key"] = key; + } + return headers; + }, + }), + [settings.apiBase, settings.apiKey], ); const libraries = useEmbeddrLibraries({ apiBase: settings.apiBase, mode: settings.mode, configLoaded: settings.configLoaded, + apiKey: settings.apiKey, + apiClient, // Pass client }); const images = useEmbeddrImages({ apiBase: settings.apiBase, mode: settings.mode, configLoaded: settings.configLoaded, - apiClient, + apiClient, // Pass client - hook will use it now + apiKey: settings.apiKey, }); const collections = useEmbeddrCollections({ apiBase: settings.apiBase, configLoaded: settings.configLoaded, - apiClient, + apiClient, // Pass client + apiKey: settings.apiKey, }); return { diff --git a/ui/hooks/useEmbeddrCollections.ts b/ui/hooks/useEmbeddrCollections.ts index 5c02634..609fd6a 100644 --- a/ui/hooks/useEmbeddrCollections.ts +++ b/ui/hooks/useEmbeddrCollections.ts @@ -14,12 +14,14 @@ interface UseEmbeddrCollectionsProps { apiBase: string; configLoaded: boolean; apiClient?: EmbeddrApiClient; + apiKey?: string; } export function useEmbeddrCollections({ apiBase, configLoaded, apiClient, + apiKey, }: UseEmbeddrCollectionsProps) { const [collections, setCollections] = useState([]); const [loadingCollections, setLoadingCollections] = useState(false); @@ -32,18 +34,30 @@ export function useEmbeddrCollections({ const headers: Record = { "Content-Type": "application/json", }; + if (apiKey) { + headers["X-API-Key"] = apiKey; + } let baseUrl = apiBase; if (baseUrl.endsWith("/")) baseUrl = baseUrl.slice(0, -1); // Ensure we target the V2 API - if (!baseUrl.endsWith("/api/v2")) { - baseUrl = `${baseUrl}/api/v2`; + if (!baseUrl.endsWith("/api/v1")) { + baseUrl = `${baseUrl}/api/v1`; + } + + const url = `${baseUrl}/collections`; + let res: Response; + + // Use Proxy + if (url.startsWith("http")) { + res = await fetch(`/embeddr/proxy?url=${encodeURIComponent(url)}`, { + method: "GET", + headers, + }); + } else { + res = await fetch(url, { method: "GET", headers }); } - const res = await fetch(`${baseUrl}/collections`, { - method: "GET", - headers, - }); if (res.ok) { const data = await res.json(); // data could be paginated or just a list @@ -58,7 +72,7 @@ export function useEmbeddrCollections({ } finally { setLoadingCollections(false); } - }, [apiBase, configLoaded]); + }, [apiBase, configLoaded, apiKey]); const [creating, setCreating] = useState(false); @@ -70,6 +84,10 @@ export function useEmbeddrCollections({ const headers: Record = { "Content-Type": "application/json", }; + if (apiKey) { + headers["X-API-Key"] = apiKey; + } + const payload = { label: label, type_name: "collection:mix", // Default to simple mix @@ -80,15 +98,15 @@ export function useEmbeddrCollections({ let baseUrl = apiBase; if (baseUrl.endsWith("/")) baseUrl = baseUrl.slice(0, -1); - if (!baseUrl.endsWith("/api/v2")) { - baseUrl = `${baseUrl}/api/v2`; + if (!baseUrl.endsWith("/api/v1")) { + baseUrl = `${baseUrl}/api/v1`; } // Use the artifact endpoint to create a collection, since /collections might be read-only or alias - // But if /api/v2/collections exists as a dedicated endpoint, we use it. - // Assuming /api/v2/collections POST works as expected for creating collections specificically. + // But if /api/v1/collections exists as a dedicated endpoint, we use it. + // Assuming /api/v1/collections POST works as expected for creating collections specificically. // If not, we might need to POST to /artifacts with type=collection. - // Let's stick to the user's requested endpoint /api/v2/collections for now. + // Let's stick to the user's requested endpoint /api/v1/collections for now. const res = await fetch(`${baseUrl}/collections`, { method: "POST", headers, @@ -109,7 +127,7 @@ export function useEmbeddrCollections({ setCreating(false); } }, - [apiBase, configLoaded, fetchCollections], + [apiBase, configLoaded, fetchCollections, apiKey], ); return { diff --git a/ui/hooks/useEmbeddrImages.ts b/ui/hooks/useEmbeddrImages.ts index d434562..862505f 100644 --- a/ui/hooks/useEmbeddrImages.ts +++ b/ui/hooks/useEmbeddrImages.ts @@ -9,6 +9,7 @@ interface UseEmbeddrImagesProps { mode: ApiMode; configLoaded: boolean; apiClient?: EmbeddrApiClient; + apiKey?: string; } export function useEmbeddrImages({ @@ -16,6 +17,7 @@ export function useEmbeddrImages({ mode, configLoaded, apiClient, + apiKey, }: UseEmbeddrImagesProps) { const [images, setImages] = useState>([]); const [loading, setLoading] = useState(false); @@ -25,6 +27,8 @@ export function useEmbeddrImages({ const [similarImageId, setSimilarImageId] = useState( null, ); + const failureCountRef = useRef(0); + const nextAllowedFetchRef = useRef(0); const fetchImages = useCallback( async ( @@ -36,6 +40,11 @@ export function useEmbeddrImages({ collectionId?: string | null, ) => { if (!configLoaded) return; + const now = Date.now(); + if (!reset && now < nextAllowedFetchRef.current) { + return; + } + // If we are loading more pages (not reset) and already loading, skip if (loadingRef.current && !reset) return; loadingRef.current = true; @@ -44,9 +53,10 @@ export function useEmbeddrImages({ const headers: Record = { "Content-Type": "application/json", }; - const storedKey = localStorage.getItem("embeddr_api_key"); - if (storedKey) { - headers["Authorization"] = `Bearer ${storedKey}`; + // Use prop apiKey first, then fallback to localStorage if needed (though prop should be source of truth) + const currentKey = apiKey || localStorage.getItem("embeddr_api_key"); + if (currentKey) { + headers["X-API-Key"] = currentKey; } const currentPage = reset ? 1 : pageRef.current; @@ -55,8 +65,8 @@ export function useEmbeddrImages({ let baseUrl = apiBase; if (baseUrl.endsWith("/")) baseUrl = baseUrl.slice(0, -1); // Ensure V2 - if (!baseUrl.endsWith("/api/v2")) { - baseUrl = `${baseUrl}/api/v2`; + if (!baseUrl.endsWith("/api/v1")) { + baseUrl = `${baseUrl}/api/v1`; } const currentSimilarId = @@ -84,6 +94,7 @@ export function useEmbeddrImages({ limit: 20, }); } else if (apiClient) { + // Use apiClient for listing items, which now uses proxyFetch internally through constructor const list = await apiClient.artifacts.list({ limit: 20, offset, @@ -93,15 +104,29 @@ export function useEmbeddrImages({ collection_id: collectionId || undefined, }); + // Helper to ensure urls are proxied even if returned from client helpers + // The client's getContentUrl returns raw url, we must wrap it if display is needed + const proxify = (u: string) => + u.startsWith("http") + ? `/embeddr/proxy?url=${encodeURIComponent(u)}` + : u; + const items = list.items || []; const mapped = items.map((item: any) => { const id = item.id; const metadata = item.metadata_json || {}; + + const rawImageUrl = apiClient.artifacts.getContentUrl(id); + const rawThumbUrl = apiClient.artifacts.getPreviewUrl( + id, + "thumbnail", + ); + return { id: id, prompt: metadata.prompt || metadata.filename || "Untitled", - image_url: apiClient.artifacts.getContentUrl(id), - thumb_url: apiClient.artifacts.getPreviewUrl(id, "thumbnail"), + image_url: proxify(rawImageUrl), + thumb_url: proxify(rawThumbUrl), created_at: item.created_at || new Date().toISOString(), like_count: 0, liked_by_me: false, @@ -121,7 +146,7 @@ export function useEmbeddrImages({ setHasMore(offset + mapped.length < list.total); return; } else { - // List Artifacts + // List Artifacts (No Client Fallback - Should rarely happen if hook setup correct) url = `${baseUrl}/artifacts/?type_name=image&sort=new&limit=20&offset=${offset}`; if (libraryId) { url += `&library_id=${libraryId}`; @@ -131,7 +156,17 @@ export function useEmbeddrImages({ } } - let response = await fetch(url, { method, headers, body }); + let response: Response; + + // Use Proxy for all requests to avoid CORS/Auth issues in ComfyUI environment + // The backend proxy injects the API key from server-side config + const isComfyEnv = true; // We are in ComfyUI extension + if (isComfyEnv && url.startsWith("http")) { + const proxyUrl = `/embeddr/proxy?url=${encodeURIComponent(url)}`; + response = await fetch(proxyUrl, { method, headers, body }); + } else { + response = await fetch(url, { method, headers, body }); + } // Fallback for Similar Search if plugin missing (404) if (!response.ok && currentSimilarId && response.status === 404) { @@ -171,18 +206,25 @@ export function useEmbeddrImages({ const isFullArtifact = !!item.uri; const metadata = item.metadata_json || {}; + const rawImageUrl = apiClient + ? apiClient.artifacts.getContentUrl(id) + : `${baseUrl}/artifacts/${id}/content`; + const rawThumbUrl = apiClient + ? apiClient.artifacts.getPreviewUrl(id, "thumbnail") + : `${baseUrl}/artifacts/${id}/preview?preview_type=thumbnail`; + return { id: id, prompt: metadata.prompt || metadata.filename || (isFullArtifact ? "Untitled" : "Similar Result"), - image_url: apiClient - ? apiClient.artifacts.getContentUrl(id) - : `${baseUrl}/artifacts/${id}/content`, - thumb_url: apiClient - ? apiClient.artifacts.getPreviewUrl(id, "thumbnail") - : `${baseUrl}/artifacts/${id}/preview?preview_type=thumbnail`, + image_url: rawImageUrl.startsWith("http") + ? `/embeddr/proxy?url=${encodeURIComponent(rawImageUrl)}` + : rawImageUrl, + thumb_url: rawThumbUrl.startsWith("http") + ? `/embeddr/proxy?url=${encodeURIComponent(rawThumbUrl)}` + : rawThumbUrl, created_at: item.created_at || new Date().toISOString(), like_count: 0, liked_by_me: false, @@ -211,11 +253,19 @@ export function useEmbeddrImages({ setImages((prev) => [...prev, ...items]); pageRef.current = currentPage + 1; } + failureCountRef.current = 0; + nextAllowedFetchRef.current = 0; } else { throw new Error(`Failed to fetch images: ${response.status}`); } } catch (error) { console.error("Error fetching images:", error); + failureCountRef.current += 1; + const backoffMs = Math.min( + 60000, + 2000 * 2 ** (failureCountRef.current - 1), + ); + nextAllowedFetchRef.current = Date.now() + backoffMs; if (app.extensionManager?.toast) { app.extensionManager.toast.add({ severity: "error", @@ -230,7 +280,7 @@ export function useEmbeddrImages({ setLoading(false); } }, - [apiBase, configLoaded, mode, similarImageId], + [apiBase, configLoaded, mode, similarImageId, apiKey], ); return { diff --git a/ui/hooks/useEmbeddrLibraries.ts b/ui/hooks/useEmbeddrLibraries.ts index da7315c..7a0cb57 100644 --- a/ui/hooks/useEmbeddrLibraries.ts +++ b/ui/hooks/useEmbeddrLibraries.ts @@ -5,12 +5,14 @@ interface UseEmbeddrLibrariesProps { apiBase: string; mode: ApiMode; configLoaded: boolean; + apiKey?: string; } export function useEmbeddrLibraries({ apiBase, mode, configLoaded, + apiKey, }: UseEmbeddrLibrariesProps) { const [libraries, setLibraries] = useState>([]); @@ -19,7 +21,25 @@ export function useEmbeddrLibraries({ let baseUrl = apiBase; if (baseUrl.endsWith("/")) baseUrl = baseUrl.slice(0, -1); - const res = await fetch(`${baseUrl}/workspace/paths`); + const headers: Record = { + "Content-Type": "application/json", + }; + if (apiKey) { + headers["X-API-Key"] = apiKey; + } + + const url = `${baseUrl}/workspace/paths`; + let res: Response; + // Use Proxy + if (url.startsWith("http")) { + res = await fetch(`/embeddr/proxy?url=${encodeURIComponent(url)}`, { + method: "GET", + headers, + }); + } else { + res = await fetch(url, { method: "GET", headers }); + } + if (res.ok) { const data = await res.json(); setLibraries(data); @@ -32,9 +52,11 @@ export function useEmbeddrLibraries({ // Fetch libraries when in local mode useEffect(() => { if (configLoaded && mode === "local") { - fetchLibraries(); + // TODO: Backend route /api/v1/workspace/paths is currently missing. + // Re-enable this when the endpoint is restored or replaced. + // fetchLibraries(); } - }, [configLoaded, mode, apiBase]); + }, [configLoaded, mode, apiBase, apiKey]); return { libraries, diff --git a/ui/hooks/useEmbeddrSettings.ts b/ui/hooks/useEmbeddrSettings.ts index 4ed288b..73044fb 100644 --- a/ui/hooks/useEmbeddrSettings.ts +++ b/ui/hooks/useEmbeddrSettings.ts @@ -1,7 +1,13 @@ -import { useEffect, useMemo, useState } from "react"; +import { useEffect, useMemo, useRef, useState } from "react"; // @ts-ignore import { app } from "../../../scripts/app.js"; import type { ApiMode } from "@types"; +import { + applyThemePackCss, + applyThemePackTokens, + clearThemePackTokens, + loadThemePacks, +} from "../utils/themePacks"; interface UseEmbeddrSettingsProps { baseUrl?: string; @@ -10,58 +16,118 @@ interface UseEmbeddrSettingsProps { export function useEmbeddrSettings({ baseUrl = "http://localhost:8003", }: UseEmbeddrSettingsProps = {}) { + const normalizeEndpoint = (value: string) => + value.replace(/\/api\/v1\/?$/, "").replace(/\/+$/, ""); + const [endpoint, setEndpoint] = useState(() => { const stored = localStorage.getItem("embeddr_endpoint"); - return stored || baseUrl; + return normalizeEndpoint(stored || baseUrl); }); const [mode, setMode] = useState(() => "local"); + const [apiKey, setApiKey] = useState(() => { + return localStorage.getItem("embeddr_api_key") || ""; + }); + const [gridSize, setGridSize] = useState(() => - parseInt(localStorage.getItem("embeddr_grid_size") || "3") + parseInt(localStorage.getItem("embeddr_grid_size") || "3"), ); const [gridPreviewContain, setGridPreviewContain] = useState( - () => localStorage.getItem("embeddr_grid_preview_contain") === "true" + () => localStorage.getItem("embeddr_grid_preview_contain") === "true", ); const [theme, setTheme] = useState(() => { return localStorage.getItem("embeddr_theme") || "dark"; }); + const [themePackId, setThemePackId] = useState(() => { + return localStorage.getItem("embeddr_theme_pack") || ""; + }); + + const appliedTokenKeysRef = useRef([]); + const [configLoaded, setConfigLoaded] = useState(false); // computed API base for requests const apiBase = useMemo(() => { - const url = endpoint.replace(/\/$/, ""); // remove trailing slash - return `${url}/api/v2`; // append API path automatically + const url = normalizeEndpoint(endpoint); + return url ? `${url}/api/v1` : "/api/v1"; }, [endpoint]); // Apply theme useEffect(() => { - const container = document.querySelector(".embeddr-sidebar-container"); - if (container) { + const roots = document.querySelectorAll(".embeddr-theme-root"); + roots.forEach((root) => { if (theme === "dark") { - container.classList.add("dark"); + root.classList.add("dark"); } else { - container.classList.remove("dark"); + root.classList.remove("dark"); } - } + }); - // Also handle portals immediately - const portals = document.querySelectorAll( - "[data-radix-portal], [data-slot='dialog-content'], [data-slot='dialog-overlay'], [data-slot='select-content'], [data-slot='select-viewport'], [data-slot='popover-content'], [data-slot='dropdown-menu-content']" - ); - portals.forEach((portal) => { - if (theme === "dark") { - portal.classList.add("dark"); + localStorage.setItem("embeddr_theme", theme); + }, [theme]); + + useEffect(() => { + localStorage.setItem("embeddr_theme_pack", themePackId); + }, [themePackId]); + + useEffect(() => { + const roots = [document.documentElement, document.body]; + roots.forEach((root) => { + if (themePackId) { + root.setAttribute("data-embeddr-theme-pack", "1"); } else { - portal.classList.remove("dark"); + root.removeAttribute("data-embeddr-theme-pack"); } }); + }, [themePackId]); - localStorage.setItem("embeddr_theme", theme); - }, [theme]); + useEffect(() => { + let isActive = true; + + const applyTheme = async () => { + if (!apiBase) return; + const targets = Array.from( + document.querySelectorAll(".embeddr-theme-root"), + ); + + clearThemePackTokens(targets, appliedTokenKeysRef.current); + + if (!themePackId) { + applyThemePackCss(null); + appliedTokenKeysRef.current = []; + return; + } + + try { + const packs = await loadThemePacks(apiBase); + if (!isActive) return; + const pack = packs.find((item) => item.id === themePackId); + if (!pack) { + applyThemePackCss(null); + appliedTokenKeysRef.current = []; + return; + } + + applyThemePackCss(pack); + appliedTokenKeysRef.current = applyThemePackTokens( + targets, + pack, + theme === "dark" ? "dark" : "light", + ); + } catch (e) { + console.warn("[EmbeddrUI] Failed to load theme packs", e); + } + }; + + applyTheme(); + return () => { + isActive = false; + }; + }, [apiBase, theme, themePackId]); // Load config on mount useEffect(() => { @@ -71,8 +137,11 @@ export function useEmbeddrSettings({ const data = await res.json(); if (data.endpoint) { - setEndpoint(new URL(data.endpoint).toString()); - localStorage.setItem("embeddr_endpoint", data.endpoint); + const normalized = normalizeEndpoint( + new URL(data.endpoint).toString(), + ); + setEndpoint(normalized); + localStorage.setItem("embeddr_endpoint", normalized); } if (data.mode) { @@ -80,11 +149,16 @@ export function useEmbeddrSettings({ localStorage.setItem("embeddr_mode", data.mode); } + if (data.api_key) { + setApiKey(data.api_key); + localStorage.setItem("embeddr_api_key", data.api_key); + } + if (data.grid_preview_contain !== undefined) { setGridPreviewContain(data.grid_preview_contain); localStorage.setItem( "embeddr_grid_preview_contain", - data.grid_preview_contain.toString() + data.grid_preview_contain.toString(), ); } } catch (e) { @@ -100,33 +174,43 @@ export function useEmbeddrSettings({ newEndpoint: string, newMode: ApiMode, newGridSize: number, - newGridPreviewContain: boolean + newGridPreviewContain: boolean, + newApiKey?: string, ) => { try { - localStorage.setItem("embeddr_endpoint", newEndpoint); + const normalizedEndpoint = normalizeEndpoint(newEndpoint); + localStorage.setItem("embeddr_endpoint", normalizedEndpoint); localStorage.setItem("embeddr_mode", newMode); localStorage.setItem("embeddr_grid_size", newGridSize.toString()); localStorage.setItem( "embeddr_grid_preview_contain", - newGridPreviewContain.toString() + newGridPreviewContain.toString(), ); + const payload: any = { + endpoint: normalizedEndpoint, + mode: newMode, + grid_size: newGridSize, + grid_preview_contain: newGridPreviewContain, + }; + + if (newApiKey !== undefined) { + payload.api_key = newApiKey; + setApiKey(newApiKey); + localStorage.setItem("embeddr_api_key", newApiKey); + } + const res = await fetch("/embeddr/config", { method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ - endpoint: newEndpoint, - mode: newMode, - grid_size: newGridSize, - grid_preview_contain: newGridPreviewContain, - }), + body: JSON.stringify(payload), }); if (!res.ok) { throw new Error(`Server returned ${res.status}`); } - setEndpoint(newEndpoint); + setEndpoint(normalizedEndpoint); setMode(newMode); setGridSize(newGridSize); setGridPreviewContain(newGridPreviewContain); @@ -170,7 +254,11 @@ export function useEmbeddrSettings({ setGridPreviewContain, theme, setTheme, + themePackId, + setThemePackId, configLoaded, + apiKey, + setApiKey, apiBase, saveSettings, }; diff --git a/ui/hooks/useThemePacks.ts b/ui/hooks/useThemePacks.ts new file mode 100644 index 0000000..94938a7 --- /dev/null +++ b/ui/hooks/useThemePacks.ts @@ -0,0 +1,26 @@ +import { useCallback, useEffect, useState } from "react"; +import { loadThemePacks, type ThemePack } from "../utils/themePacks"; + +export function useThemePacks(apiBase: string, enabled = true) { + const [packs, setPacks] = useState([]); + const [isLoading, setIsLoading] = useState(false); + + const reload = useCallback(async () => { + if (!enabled || !apiBase) return; + setIsLoading(true); + try { + const data = await loadThemePacks(apiBase); + setPacks(data); + } catch { + setPacks([]); + } finally { + setIsLoading(false); + } + }, [apiBase, enabled]); + + useEffect(() => { + reload(); + }, [reload]); + + return { packs, isLoading, reload }; +} diff --git a/ui/main.tsx b/ui/main.tsx index 23cd608..b1d4b4e 100644 --- a/ui/main.tsx +++ b/ui/main.tsx @@ -1,11 +1,18 @@ import React from "react"; import ReactDOM from "react-dom/client"; +import * as ReactDOMLib from "react-dom"; +import * as EmbeddrUI from "@embeddr/react-ui"; +import * as Lucide from "lucide-react"; +import * as ReactQuery from "@tanstack/react-query"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import * as Recharts from "recharts"; import { ImageDialogProvider } from "@embeddr/react-ui/providers/ImageDialogProvider"; import { ExternalNavProvider } from "@embeddr/react-ui"; // @ts-ignore import { app } from "../../../scripts/app.js"; import EmbeddrPanel from "./components/panels/EmbeddrPanel.js"; import { GlobalDialog } from "./components/GlobalDialog"; +import { ZenShell } from "./components/ZenShell"; import "./nodes/EmbeddrLoadArtifact.js"; import "./nodes/EmbeddrMergeIds.js"; import "./nodes/EmbeddrLoRAStack.js"; @@ -14,36 +21,137 @@ import "./nodes/EmbeddrFindCollection.js"; // @ts-ignore import "./globals.css"; +(window as any).React = React; +(window as any).ReactDOM = ReactDOMLib; +(window as any).Lucide = Lucide; +(window as any).ReactQuery = ReactQuery; +(window as any).Recharts = Recharts; + +const embeddrUI: Record = { ...EmbeddrUI }; +if (!("usePluginDrop" in embeddrUI)) { + embeddrUI.usePluginDrop = () => ({ + isOver: false, + canDrop: false, + dropRef: () => {}, + }); +} +if (!("usePluginStorage" in embeddrUI)) { + embeddrUI.usePluginStorage = ( + pluginId: string, + key: string, + initialValue: T, + ) => { + const storageKey = pluginId ? `plugin-storage:${pluginId}:${key}` : key; + const [value, setValue] = React.useState(() => { + const raw = localStorage.getItem(storageKey); + if (raw !== null) { + try { + return JSON.parse(raw) as T; + } catch { + return initialValue; + } + } + return initialValue; + }); + React.useEffect(() => { + try { + localStorage.setItem(storageKey, JSON.stringify(value)); + } catch (e) { + console.warn("[EmbeddrUI] Failed to persist plugin storage", e); + } + }, [storageKey, value]); + return [value, setValue] as const; + }; +} +(window as any).EmbeddrUI = embeddrUI; +(window as any)["@embeddr/react-ui"] = embeddrUI; +(window as any).EmbeddrReactUI = embeddrUI; +(window as any).embeddr_react_ui = embeddrUI; +(window as any)["lucide-react"] = Lucide; +(window as any).lucideReact = Lucide; +(window as any)["@tanstack/react-query"] = ReactQuery; +(window as any).reactQuery = ReactQuery; +(window as any)["recharts"] = Recharts; +(window as any).Recharts = Recharts; + +document.documentElement.classList.add("embeddr-theme-root", "font-sans"); +document.body.classList.add("embeddr-theme-root", "font-sans"); + +const originalFetch = window.fetch.bind(window); +window.fetch = async (input: RequestInfo | URL, init?: RequestInit) => { + const urlStr = input instanceof Request ? input.url : input.toString(); + const endpoint = + localStorage.getItem("embeddr_endpoint") || "http://localhost:8003"; + const apiKey = localStorage.getItem("embeddr_api_key") || ""; + + const shouldAttachKey = (() => { + try { + if (!apiKey) return false; + const target = urlStr.startsWith("http") + ? new URL(urlStr) + : new URL(urlStr, window.location.origin); + const base = new URL(endpoint); + return ( + target.origin === base.origin && target.pathname.startsWith("/api/v1/") + ); + } catch { + return false; + } + })(); + + if (!shouldAttachKey) { + return originalFetch(input, init); + } + + const headers = new Headers( + input instanceof Request ? input.headers : undefined, + ); + if (init?.headers) { + new Headers(init.headers).forEach((value, key) => headers.set(key, value)); + } + if (!headers.has("X-API-Key")) headers.set("X-API-Key", apiKey); + + if (input instanceof Request) { + const nextRequest = new Request(input, { ...init, headers }); + return originalFetch(nextRequest); + } + + return originalFetch(input, { ...init, headers }); +}; + +const queryClient = new QueryClient(); + // Mount Global Dialog const dialogContainer = document.createElement("div"); dialogContainer.id = "embeddr-global-dialog-root"; // Add tailwind class to ensure styles work if they rely on parent class dialogContainer.classList.add("tailwind"); -dialogContainer.classList.add("dark"); +dialogContainer.classList.add("font-sans"); +dialogContainer.classList.add("embeddr-theme-root"); document.body.appendChild(dialogContainer); const dialogRoot = ReactDOM.createRoot(dialogContainer); dialogRoot.render( - + + - + , ); - -// Register Embeddr Sidebar app.extensionManager.registerSidebarTab({ id: "embeddr", icon: "mdi mdi-cloud-search-outline", title: "Embeddr", type: "custom", render(container) { - document.documentElement.classList.add("dark", "tailwind"); + document.documentElement.classList.add("tailwind"); container.innerHTML = ""; container.classList.add("tailwind"); + container.classList.add("font-sans"); container.classList.add("embeddr-sidebar-container"); // Default to dark, but let React handle it - container.classList.add("dark"); + container.classList.add("embeddr-theme-root"); // Prevent the parent container from scrolling container.style.overflow = "hidden"; container.style.height = "100%"; @@ -53,14 +161,21 @@ app.extensionManager.registerSidebarTab({ mutations.forEach((mutation) => { if (mutation.type === "childList") { const portals = document.querySelectorAll( - "[data-radix-portal], [data-slot='dialog-content'], [data-slot='dialog-overlay'], [data-slot='select-content'], [data-slot='select-viewport'], [data-slot='popover-content'], [data-slot='dropdown-menu-content']" + "[data-radix-portal], [data-slot='dialog-content'], [data-slot='dialog-overlay'], [data-slot='select-content'], [data-slot='select-viewport'], [data-slot='popover-content'], [data-slot='dropdown-menu-content']", ); - const isDark = container.classList.contains("dark"); + const isDark = + container.classList.contains("dark") || + localStorage.getItem("embeddr_theme") === "dark"; portals.forEach((portal) => { if (!portal.classList.contains("tailwind")) { portal.classList.add("tailwind"); } - // // Sync dark mode + if (!portal.classList.contains("font-sans")) { + portal.classList.add("font-sans"); + } + if (!portal.classList.contains("embeddr-theme-root")) { + portal.classList.add("embeddr-theme-root"); + } if (isDark) { portal.classList.add("dark"); } else { @@ -74,11 +189,13 @@ app.extensionManager.registerSidebarTab({ const root = ReactDOM.createRoot(container); root.render( - - - - - + + + + + + + , ); return () => { observer.disconnect(); diff --git a/ui/nodes/EmbeddrMergeIds.ts b/ui/nodes/EmbeddrMergeIds.ts index 5a3e847..f5841c0 100644 --- a/ui/nodes/EmbeddrMergeIds.ts +++ b/ui/nodes/EmbeddrMergeIds.ts @@ -1,8 +1,8 @@ import { app } from "../../../scripts/app.js"; const _ID = "embeddr.MergeIDs"; -const _PREFIX = "id"; -const _TYPE = "STRING"; +const _PREFIX = "artifact_"; +const _TYPE = "EMBEDDR_ARTIFACT_ID"; app.registerExtension({ name: "embeddr.dynamic_merge_ids", @@ -29,7 +29,7 @@ app.registerExtension({ slotIdx, event, linkInfo, - nodeSlot + nodeSlot, ) { const me = onConnectionsChange?.apply(this, arguments); diff --git a/ui/utils/proxyFetch.ts b/ui/utils/proxyFetch.ts new file mode 100644 index 0000000..e7dc7b0 --- /dev/null +++ b/ui/utils/proxyFetch.ts @@ -0,0 +1,32 @@ +// Utility to proxy requests through ComfyUI backend +// This solves CORS and Auth Header issues by delegating the request to the python server +// which attaches the key from config.json and forwards it. + +export const proxyFetch = async ( + input: RequestInfo | URL, + init?: RequestInit, +): Promise => { + let urlStr = input.toString(); + + // If input is Request object, handle it (though usually it's string in our app) + if (input instanceof Request) { + urlStr = input.url; + } + + // Only proxy http(s) requests + if (urlStr.startsWith("http")) { + const proxyUrl = `/embeddr/proxy?url=${encodeURIComponent(urlStr)}`; + const proxyRes = await fetch(proxyUrl, init); + if (proxyRes.status === 404 || proxyRes.status === 405) { + console.warn("[proxyFetch] Proxy missing, falling back to direct fetch", { + url: urlStr, + status: proxyRes.status, + }); + return fetch(urlStr, init); + } + return proxyRes; // Let ComfyUI handle the request + } + + // Fallback to normal fetch for relative or other protocols + return fetch(input, init); +}; diff --git a/ui/utils/themePacks.ts b/ui/utils/themePacks.ts new file mode 100644 index 0000000..cdabeb5 --- /dev/null +++ b/ui/utils/themePacks.ts @@ -0,0 +1,154 @@ +import { proxyFetch } from "./proxyFetch"; + +export type ThemePackTokens = { + light?: Record; + dark?: Record; +}; + +export type ThemePack = { + id: string; + name: string; + version?: string; + author?: string; + description?: string; + preview?: string; + iconUrl?: string; + bannerUrl?: string; + tokens?: ThemePackTokens; + css?: string; + cssUrl?: string; + icon?: string; + banner?: string; + cssFile?: string; +}; + +export type ThemePackIndex = { + packs: ThemePack[]; +}; + +const THEME_STYLE_ID = "embeddr-theme-pack-css"; + +const buildThemeUrl = (apiBase: string) => { + const trimmed = apiBase.replace(/\/+$/, ""); + if (/\/api\/v1$/.test(trimmed)) { + return `${trimmed}/themes`; + } + return `${trimmed}/api/v1/themes`; +}; + +const resolveAssetBase = (apiBase: string) => { + const trimmed = apiBase.replace(/\/+$/, ""); + return trimmed.replace(/\/api\/v1$/, ""); +}; + +export async function loadThemePacks(apiBase: string): Promise { + const packs = new Map(); + + const addPack = (pack?: ThemePack) => { + if (!pack?.id) return; + packs.set(pack.id, pack); + }; + + const addIndex = (index?: ThemePackIndex) => { + if (!index?.packs) return; + index.packs.forEach((pack) => addPack(pack)); + }; + + const themeUrl = buildThemeUrl(apiBase); + const assetBase = resolveAssetBase(apiBase); + + const res = await proxyFetch(themeUrl, { cache: "no-store" }); + if (!res.ok) { + throw new Error(`Theme pack fetch failed (${res.status})`); + } + + const data = (await res.json()) as ThemePackIndex; + const normalized: ThemePackIndex = { + packs: (data.packs || []).map((pack) => { + const iconUrl = + pack.iconUrl || + (pack.icon ? `${assetBase}/themes/${pack.id}/${pack.icon}` : undefined); + const bannerUrl = + pack.bannerUrl || + (pack.banner + ? `${assetBase}/themes/${pack.id}/${pack.banner}` + : undefined); + const cssUrl = + pack.cssUrl || + (pack.cssFile + ? `${assetBase}/themes/${pack.id}/${pack.cssFile}` + : undefined); + + return { + ...pack, + iconUrl: + iconUrl && iconUrl.startsWith("/") + ? `${assetBase}${iconUrl}` + : iconUrl, + bannerUrl: + bannerUrl && bannerUrl.startsWith("/") + ? `${assetBase}${bannerUrl}` + : bannerUrl, + cssUrl: + cssUrl && cssUrl.startsWith("/") ? `${assetBase}${cssUrl}` : cssUrl, + }; + }), + }; + + addIndex(normalized); + return Array.from(packs.values()); +} + +export function applyThemePackCss(pack?: ThemePack | null) { + if (typeof document === "undefined") return; + const existing = document.getElementById(THEME_STYLE_ID); + if (existing?.parentElement) { + existing.parentElement.removeChild(existing); + } + + if (!pack) return; + + if (pack.css) { + const style = document.createElement("style"); + style.id = THEME_STYLE_ID; + style.textContent = pack.css; + document.head.appendChild(style); + return; + } + + if (pack.cssUrl) { + const link = document.createElement("link"); + link.id = THEME_STYLE_ID; + link.rel = "stylesheet"; + link.href = pack.cssUrl; + document.head.appendChild(link); + } +} + +export function applyThemePackTokens( + targets: HTMLElement[], + pack: ThemePack | null | undefined, + mode: "light" | "dark", +) { + if (!pack?.tokens) return [] as string[]; + const tokens = + (mode === "dark" ? pack.tokens.dark : pack.tokens.light) || + pack.tokens.light || + pack.tokens.dark; + if (!tokens) return [] as string[]; + const keys = Object.keys(tokens); + targets.forEach((target) => { + Object.entries(tokens).forEach(([key, value]) => { + if (typeof value !== "string") return; + target.style.setProperty(key, value); + }); + }); + return keys; +} + +export function clearThemePackTokens(targets: HTMLElement[], keys: string[]) { + if (!keys.length) return; + targets.forEach((target) => { + keys.forEach((key) => target.style.removeProperty(key)); + }); +} From fc998290c911dab3f03233c3dd2dc63e7880fd68 Mon Sep 17 00:00:00 2001 From: Nynxz Date: Fri, 20 Feb 2026 22:24:39 +1000 Subject: [PATCH 4/9] checkpoint: pre 0.2.0 cleanup --- __init__.py | 2 + nodes/EmbeddrAction.py | 480 ++++++++++++ nodes/EmbeddrLoadArtifact.py | 24 +- nodes/EmbeddrLoadArtifacts.py | 20 +- nodes/EmbeddrLoadImages.py | 18 +- nodes/EmbeddrUploadArtifact.py | 24 +- nodes/utils/config.py | 21 +- package.json | 2 +- ui/components/GlobalDialog.tsx | 2 +- ui/components/ZenShell.tsx | 199 +++-- ui/components/panels/EmbeddrPanel.tsx | 15 +- ui/components/panels/ImageDetails.tsx | 12 +- .../selectors/CollectionSelector.tsx | 8 +- ui/components/tabs/ExploreTab.tsx | 6 +- ui/components/tabs/PromptTab.tsx | 436 +++++++++++ ui/components/tabs/SettingsForm.tsx | 12 +- ui/components/ui/ImageGrid.tsx | 44 +- ui/components/ui/SearchBar.tsx | 8 +- ui/hooks/useEmbeddrApi.ts | 2 +- ui/hooks/useEmbeddrCollections.ts | 2 +- ui/hooks/useEmbeddrImages.ts | 2 +- ui/hooks/useEmbeddrSettings.ts | 23 +- ui/main.tsx | 1 + ui/nodes/EmbeddrAction.ts | 693 ++++++++++++++++++ ui/nodes/EmbeddrLoadArtifact.ts | 37 +- ui/utils/themePacks.ts | 270 ++++++- vite.config.mts | 3 + 27 files changed, 2244 insertions(+), 122 deletions(-) create mode 100644 nodes/EmbeddrAction.py create mode 100644 ui/components/tabs/PromptTab.tsx create mode 100644 ui/nodes/EmbeddrAction.ts diff --git a/__init__.py b/__init__.py index 8ca5983..1c3ec12 100644 --- a/__init__.py +++ b/__init__.py @@ -19,6 +19,7 @@ from .nodes.EmbeddrLoRAStack import EmbeddrLoRAStack from .nodes.EmbeddrFindCollection import EmbeddrFindCollectionNode from .nodes.EmbeddrUploadOptions import UploadArtifactOptionsNode +from .nodes.EmbeddrAction import EmbeddrActionNode CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.json") @@ -246,6 +247,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: EmbeddrLoadVideoNode, EmbeddrLoRAStack, EmbeddrFindCollectionNode, + EmbeddrActionNode, ] diff --git a/nodes/EmbeddrAction.py b/nodes/EmbeddrAction.py new file mode 100644 index 0000000..4b0d276 --- /dev/null +++ b/nodes/EmbeddrAction.py @@ -0,0 +1,480 @@ +""" +EmbeddrAction – Dynamic Lotus Action node for ComfyUI. + +Introspects Lotus capabilities (kind=action) at execution time and +dispatches the selected action via the Embeddr REST API. + +The frontend extension (ui/nodes/EmbeddrAction.ts) handles the +dynamic input/widget creation based on the selected action's schema. +""" +import json +import requests +from comfy_api.latest import io +from .types import EmbeddrArtifactID, EmbeddrArtifactIDObject +from .utils.config import get_embeddr_base_url, get_auth_headers + + +def _log(msg: str): + print(f"[Embeddr Action] {msg}") + + +_DYN_PREFIX = "dyn_" + + +def _fetch_actions() -> list[dict]: + """Fetch all exposed Lotus actions from the backend.""" + try: + base_url = get_embeddr_base_url() + resp = requests.get( + f"{base_url}/api/v1/lotus/list", + params={"kind": "action", "limit": 500}, + headers=get_auth_headers(), + timeout=5, + ) + if resp.status_code == 200: + data = resp.json() + items = data.get("items", []) if isinstance(data, dict) else data + # Only return actions that are exposed via lotus API + exposed = [] + for cap in items: + action = cap.get("action") or {} + expose = action.get("expose") or {} + # Include if exposed, or if no expose policy (lenient) + if expose.get("lotus", True): + exposed.append(cap) + return exposed + except Exception as e: + _log(f"Failed to fetch actions: {e}") + return [] + + +# ── Cache for capability schema lookup ── +_cap_cache: dict[str, dict] = {} +_cap_cache_ts: float = 0.0 +_CAP_CACHE_TTL = 30.0 + + +def _get_cap_schema(cap_id: str) -> dict | None: + """Get the capability dict for a given cap_id, using a short-lived cache.""" + import time + global _cap_cache, _cap_cache_ts + + now = time.time() + if now - _cap_cache_ts > _CAP_CACHE_TTL: + actions = _fetch_actions() + _cap_cache = {a["id"]: a for a in actions if "id" in a} + _cap_cache_ts = now + + return _cap_cache.get(cap_id) + + +def _get_input_schema_props(cap: dict) -> dict[str, dict]: + """Extract the JSON Schema 'properties' from a capability's input schema.""" + action = cap.get("action") or cap.get("data", {}) + inp = action.get("input") or {} + schema = inp.get("schema") or {} + return schema.get("properties", {}) + + +def _get_output_schema_props(cap: dict) -> dict[str, dict]: + """Extract JSON Schema 'properties' for action output.""" + action = cap.get("action") or cap.get("data", {}) + out = action.get("output") or {} + schema = out.get("schema") or {} + return schema.get("properties", {}) + + +def _pick_first_value(source: dict, keys: list[str]): + for key in keys: + if key in source and source.get(key) is not None: + return source.get(key) + return None + + +def _pick_first_value_deep(source, keys: list[str]): + """Depth-first search through nested dict/list values for preferred keys.""" + if not isinstance(source, (dict, list)): + return None + + stack = [source] + seen_ids: set[int] = set() + while stack: + cur = stack.pop() + cur_id = id(cur) + if cur_id in seen_ids: + continue + seen_ids.add(cur_id) + + if isinstance(cur, dict): + val = _pick_first_value(cur, keys) + if val is not None: + return val + for v in cur.values(): + if isinstance(v, (dict, list)): + stack.append(v) + elif isinstance(cur, list): + for item in cur: + if isinstance(item, (dict, list)): + stack.append(item) + + return None + + +def _build_action_choices() -> list[str]: + """Build a list of action IDs for the combo dropdown.""" + actions = _fetch_actions() + choices = ["(select action)"] + for cap in actions: + cap_id = cap.get("id", "") + title = cap.get("title", cap_id) + plugin = cap.get("plugin", "") + label = f"{cap_id}" + if title and title != cap_id: + label = f"{cap_id} [{title}]" + if plugin: + label = f"{plugin}/{label}" + choices.append(label) + return choices + + +def _extract_cap_id(choice: str) -> str: + """Extract the raw capability ID from a combo label. + + Labels look like: + - "plugin/cap.id [Title]" + - "plugin/cap.id" + - "cap.id [Title]" + - "cap.id" + """ + if not choice or choice == "(select action)": + return "" + + # Strip plugin prefix if present + if "/" in choice: + choice = choice.split("/", 1)[1] + + # Strip title suffix + if " [" in choice: + choice = choice.split(" [", 1)[0] + + return choice.strip() + + +class EmbeddrActionNode(io.ComfyNode): + """ + Executes any Lotus action dynamically. + + The node exposes a combo dropdown populated with available actions + and a JSON payload input. The frontend extension adds typed widgets + for each input declared by the selected action's schema; those values + are merged into the payload at execution time. + """ + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="embeddr.Action", + display_name="Embeddr Action", + description=( + "Execute any Lotus action. Select an action from the " + "dropdown — input widgets are created automatically." + ), + category="Embeddr/Lotus", + inputs=[ + # String input — the frontend extension converts this into + # a combo widget populated with live Lotus actions. + # We intentionally avoid io.Combo.Input because ComfyUI + # validates combo values against the static options list + # defined at schema time, which cannot include actions + # discovered dynamically from the backend. + io.String.Input( + "action_id", + default="", + tooltip="Select a Lotus action to execute (populated by frontend)", + ), + io.String.Input( + "payload_json", + default="{}", + multiline=True, + tooltip=( + "Additional JSON payload merged with widget inputs. " + "Widget values take precedence." + ), + ), + # Optional artifact input for actions that operate on artifacts + EmbeddrArtifactID.Input( + "artifact_id", + optional=True, + tooltip="Artifact ID input (passed as 'artifact_id' in payload)", + ), + EmbeddrArtifactID.Input( + "artifact_ids", + optional=True, + tooltip="Multiple artifact IDs (passed as 'artifact_ids' in payload)", + ), + ], + outputs=[ + io.String.Output( + "result_json", + tooltip="Full JSON response from the action", + ), + io.String.Output( + "status", + tooltip="'ok', 'error', or 'pending'", + ), + EmbeddrArtifactID.Output( + "output_artifact_ids", + tooltip="Artifact IDs extracted from the result (if any)", + ), + io.String.Output( + "text", + tooltip="Primary text content (caption, message, value, etc.)", + ), + io.String.Output( + "error", + tooltip="Error message (empty on success)", + ), + ], + ) + + @classmethod + def execute( + cls, + action_id: str, + payload_json: str = "{}", + artifact_id: EmbeddrArtifactIDObject | None = None, + artifact_ids: EmbeddrArtifactIDObject | None = None, + **kwargs, + ): + cap_id = _extract_cap_id(action_id) + if not cap_id: + _log("No action selected") + return io.NodeOutput( + json.dumps({"error": "No action selected"}), + "error", + "", + "", + "No action selected", + ) + + # ── Build payload from payload_json ── + # The frontend packs all dyn_* widget values into payload_json + # before serialization (since ComfyUI only passes schema-declared + # inputs to execute). + try: + payload = json.loads(payload_json) if payload_json.strip() else {} + except json.JSONDecodeError as e: + _log(f"Invalid payload JSON: {e}") + payload = {} + + # Merge connected dynamic inputs (dyn_*) so graph-linked values + # can drive action payload fields. These override payload_json. + for key, val in (kwargs or {}).items(): + if not isinstance(key, str) or not key.startswith(_DYN_PREFIX): + continue + + payload_key = key[len(_DYN_PREFIX):] + if not payload_key: + continue + + runtime_val = val + if isinstance(runtime_val, list) and len(runtime_val) == 1: + runtime_val = runtime_val[0] + + if isinstance(runtime_val, EmbeddrArtifactIDObject): + runtime_val = runtime_val.artifact_id + + if isinstance(runtime_val, str): + trimmed = runtime_val.strip() + if trimmed.startswith("{") or trimmed.startswith("["): + try: + runtime_val = json.loads(trimmed) + except Exception: + pass + + if runtime_val is None: + continue + if isinstance(runtime_val, str) and runtime_val == "": + continue + + payload[payload_key] = runtime_val + + # ── Schema-aware artifact_id mapping ── + cap = _get_cap_schema(cap_id) + schema_props = _get_input_schema_props(cap) if cap else {} + has_resource_field = "resource" in schema_props + + # Extract connected artifact IDs + raw_aid = None + if artifact_id: + raw_aid = artifact_id.artifact_id if isinstance( + artifact_id, EmbeddrArtifactIDObject) else artifact_id + + raw_aids = None + if artifact_ids: + raw_aids = artifact_ids.artifact_id if isinstance( + artifact_ids, EmbeddrArtifactIDObject) else artifact_ids + if raw_aids and not isinstance(raw_aids, list): + raw_aids = [raw_aids] + + if raw_aid or raw_aids: + # Deduplicate + seen: set[str] = set() + unique_ids: list[str] = [] + for aid in ([raw_aid] if isinstance(raw_aid, str) and raw_aid else (raw_aid if isinstance(raw_aid, list) else [])): + if aid and aid not in seen: + seen.add(aid) + unique_ids.append(aid) + for aid in (raw_aids or []): + if aid and aid not in seen: + seen.add(aid) + unique_ids.append(aid) + + if unique_ids: + if has_resource_field and "resource" not in payload: + payload["resource"] = unique_ids[0] if len( + unique_ids) == 1 else unique_ids + else: + if "artifact_id" not in payload: + payload["artifact_id"] = unique_ids[0] + if len(unique_ids) > 1 and "artifact_ids" not in payload: + payload["artifact_ids"] = unique_ids + + # Dispatch the action + base_url = get_embeddr_base_url() + url = f"{base_url}/api/v1/lotus/{cap_id}" + + _log( + f"Invoking action '{cap_id}' with payload: {json.dumps(payload, default=str)[:500]}") + + try: + resp = requests.post( + url, + json=payload, + headers={ + **get_auth_headers(), + "Content-Type": "application/json", + }, + timeout=120, + ) + + _log(f"Response: {resp.status_code}") + + if resp.status_code >= 400: + error_text = resp.text[:500] + _log(f"Action error: {error_text}") + return io.NodeOutput( + json.dumps( + {"error": error_text, "status_code": resp.status_code}), + "error", + "", + "", + error_text, + ) + + result = resp.json() if resp.headers.get( + "content-type", "").startswith("application/json") else {"raw": resp.text} + + # ── Extract typed outputs from result ── + output_ids = "" + text_output = "" + error_output = "" + status_str = "ok" + + if isinstance(result, dict): + outputs_obj = result.get("outputs") if isinstance( + result.get("outputs"), dict) else {} + output_schema_props = _get_output_schema_props( + cap) if cap else {} + + # Status + if "status" in result: + status_str = str(result["status"]) + elif result.get("ok") is False: + status_str = "error" + + # Artifact IDs + artifact_keys = [ + key for key in ( + "artifact_id", "id", "artifact_ids", "ids", "output_artifact_id", + ) + if (not output_schema_props) or key in output_schema_props + ] or ["artifact_id", "id", "artifact_ids", "ids", "output_artifact_id"] + artifact_val = _pick_first_value(outputs_obj, artifact_keys) + if artifact_val is None: + artifact_val = _pick_first_value(result, artifact_keys) + if artifact_val: + output_ids = artifact_val if isinstance( + artifact_val, str) else json.dumps(artifact_val) + + # Primary text content (captions, messages, values, etc.) + _TEXT_KEYS = ( + "response_text", "caption_text", "caption", "value", "text", + "message", "content", "description", "summary", + "output", "answer", "response", + ) + text_keys = [ + key for key in _TEXT_KEYS + if (not output_schema_props) or key in output_schema_props + ] or list(_TEXT_KEYS) + text_val = _pick_first_value(outputs_obj, text_keys) + if text_val is None: + text_val = _pick_first_value(result, text_keys) + if text_val is None: + text_val = _pick_first_value_deep(outputs_obj, text_keys) + if text_val is None: + text_val = _pick_first_value_deep(result, text_keys) + if isinstance(text_val, str) and text_val.strip().lower() in { + "completed", + "queued", + "pending", + "ok", + "success", + }: + better_val = _pick_first_value_deep( + outputs_obj, + [k for k in text_keys if k != "message"], + ) + if better_val is None: + better_val = _pick_first_value_deep( + result, + [k for k in text_keys if k != "message"], + ) + if better_val is not None: + text_val = better_val + if text_val is not None: + text_output = str(text_val) if not isinstance( + text_val, str) else text_val + + # Error + if result.get("error"): + error_output = str(result["error"]) + elif isinstance(outputs_obj, dict) and outputs_obj.get("error"): + error_output = str(outputs_obj["error"]) + + return io.NodeOutput( + json.dumps(result, default=str), + status_str, + output_ids, + text_output, + error_output, + ) + + except requests.Timeout: + _log(f"Action '{cap_id}' timed out") + return io.NodeOutput( + json.dumps({"error": "Request timed out (120s)"}), + "error", + "", + "", + "Request timed out (120s)", + ) + except Exception as e: + _log(f"Action '{cap_id}' failed: {e}") + return io.NodeOutput( + json.dumps({"error": str(e)}), + "error", + "", + "", + str(e), + ) diff --git a/nodes/EmbeddrLoadArtifact.py b/nodes/EmbeddrLoadArtifact.py index 590563f..ed43c8f 100644 --- a/nodes/EmbeddrLoadArtifact.py +++ b/nodes/EmbeddrLoadArtifact.py @@ -31,11 +31,12 @@ def _debug(cls, message: str, **fields): cls._logger.info("%s", message) @classmethod - def _resolve_artifact_url(cls, base_url: str, artifact_id: str): + def _resolve_artifact_url(cls, base_url: str, artifact_id: str, auth_ticket: str = ""): resolve_url = f"{base_url}/api/v1/artifacts/{artifact_id}/resolve?variant=original&proxy=1" cls._debug("resolving_artifact", artifact_id=artifact_id, resolve_url=resolve_url) - res = requests.get(resolve_url, headers=get_auth_headers()) + res = requests.get( + resolve_url, headers=get_auth_headers(auth_ticket=auth_ticket)) res.raise_for_status() data = res.json() url = data.get("url") @@ -73,6 +74,8 @@ def define_schema(cls) -> io.Schema: tooltip="Manual UUID string (used if input not connected)", default=""), EmbeddrArtifactID.Input("artifact_id", tooltip="UUID of the artifact to load", optional=True), + io.String.Input("auth_ticket", default="", + tooltip="Ephemeral auth ticket for per-user access", optional=True), io.Boolean.Input("use_cache", default=True) ], outputs=[ @@ -84,7 +87,7 @@ def define_schema(cls) -> io.Schema: ) @classmethod - def execute(cls, use_cache, artifact_id=None, manual_artifact_id=None): + def execute(cls, use_cache, artifact_id=None, manual_artifact_id=None, auth_ticket: str = ""): # Resolve artifact_id from connection or manual input resolved_id = "" @@ -111,8 +114,10 @@ def execute(cls, use_cache, artifact_id=None, manual_artifact_id=None): EmbeddrArtifactInfoObject(data={}) ) - if use_cache and resolved_id in cls._cache: - image, mask, info = cls._cache[resolved_id] + cache_key = (resolved_id, str(auth_ticket or "")) + + if use_cache and cache_key in cls._cache: + image, mask, info = cls._cache[cache_key] return io.NodeOutput(image, mask, EmbeddrArtifactIDObject(artifact_id=resolved_id), info) try: @@ -123,14 +128,15 @@ def execute(cls, use_cache, artifact_id=None, manual_artifact_id=None): # 1. Fetch JSON metadata first meta_url = f"{base_url}/api/v1/artifacts/{resolved_id}" - meta_res = requests.get(meta_url, headers=get_auth_headers()) + meta_res = requests.get( + meta_url, headers=get_auth_headers(auth_ticket=auth_ticket)) meta_res.raise_for_status() artifact_data = meta_res.json() info_obj = EmbeddrArtifactInfoObject(data=artifact_data) # 2. Resolve content endpoint, content_headers = cls._resolve_artifact_url( - base_url, resolved_id + base_url, resolved_id, auth_ticket ) cls._debug( @@ -140,7 +146,7 @@ def execute(cls, use_cache, artifact_id=None, manual_artifact_id=None): ) # Merge auth headers with any resolved headers (e.g. S3 signed headers or similar) - final_headers = get_auth_headers() + final_headers = get_auth_headers(auth_ticket=auth_ticket) if content_headers: final_headers.update(content_headers) @@ -166,7 +172,7 @@ def execute(cls, use_cache, artifact_id=None, manual_artifact_id=None): mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") if use_cache: - cls._cache[resolved_id] = (image, mask, info_obj) + cls._cache[cache_key] = (image, mask, info_obj) return io.NodeOutput(image, mask, EmbeddrArtifactIDObject(artifact_id=resolved_id), info_obj, ui=ui.PreviewImage(image)) diff --git a/nodes/EmbeddrLoadArtifacts.py b/nodes/EmbeddrLoadArtifacts.py index 6c40dc8..0791650 100644 --- a/nodes/EmbeddrLoadArtifacts.py +++ b/nodes/EmbeddrLoadArtifacts.py @@ -33,11 +33,12 @@ def _debug(cls, message: str, **fields): cls._logger.info("%s", message) @classmethod - def _resolve_artifact_url(cls, base_url: str, artifact_id: str): + def _resolve_artifact_url(cls, base_url: str, artifact_id: str, auth_ticket: str = ""): resolve_url = f"{base_url}/api/v1/artifacts/{artifact_id}/resolve?variant=original&proxy=1" cls._debug("resolving_artifact", artifact_id=artifact_id, resolve_url=resolve_url) - res = requests.get(resolve_url, headers=get_auth_headers()) + res = requests.get( + resolve_url, headers=get_auth_headers(auth_ticket=auth_ticket)) res.raise_for_status() data = res.json() url = data.get("url") @@ -75,6 +76,8 @@ def define_schema(cls) -> io.Schema: inputs=[ EmbeddrArtifactID.Input( "artifact_ids", tooltip="Optional list of IDs to load (overrides collection)", optional=True), + io.String.Input("auth_ticket", default="", + tooltip="Ephemeral auth ticket for per-user access", optional=True), io.Combo.Input( "collection", options=collections, default="All"), io.Combo.Input("sort_by", options=[ @@ -90,7 +93,7 @@ def define_schema(cls) -> io.Schema: ) @classmethod - def execute(cls, collection, sort_by, limit, seed, artifact_ids=None): + def execute(cls, collection, sort_by, limit, seed, artifact_ids=None, auth_ticket: str = ""): # Check for explicit IDs first manual_ids = normalize_ids(artifact_ids) @@ -98,9 +101,10 @@ def execute(cls, collection, sort_by, limit, seed, artifact_ids=None): if manual_ids: # Sort for stability in cache key manual_ids.sort() - cache_key = ("ids", tuple(manual_ids)) + cache_key = ("ids", tuple(manual_ids), str(auth_ticket or "")) else: - cache_key = (collection, sort_by, limit, seed) + cache_key = (collection, sort_by, limit, + seed, str(auth_ticket or "")) if cache_key in cls._cache: return cls._cache[cache_key] @@ -145,7 +149,7 @@ def execute(cls, collection, sort_by, limit, seed, artifact_ids=None): params["sort"] = "new" response = requests.get( - api_url, params=params, headers=get_auth_headers()) + api_url, params=params, headers=get_auth_headers(auth_ticket=auth_ticket)) response.raise_for_status() data = response.json() items = data.get("items", []) @@ -161,9 +165,9 @@ def execute(cls, collection, sort_by, limit, seed, artifact_ids=None): art_id = item.get("id") content_url, content_headers = cls._resolve_artifact_url( - base_url, art_id) + base_url, art_id, auth_ticket) try: - final_headers = get_auth_headers() + final_headers = get_auth_headers(auth_ticket=auth_ticket) if content_headers: final_headers.update(content_headers) diff --git a/nodes/EmbeddrLoadImages.py b/nodes/EmbeddrLoadImages.py index 61b559e..5162a68 100644 --- a/nodes/EmbeddrLoadImages.py +++ b/nodes/EmbeddrLoadImages.py @@ -32,11 +32,12 @@ def _debug(cls, message: str, **fields): cls._logger.info("%s", message) @classmethod - def _resolve_artifact_url(cls, base_url: str, artifact_id: str): + def _resolve_artifact_url(cls, base_url: str, artifact_id: str, auth_ticket: str = ""): resolve_url = f"{base_url}/api/v1/artifacts/{artifact_id}/resolve?variant=original&proxy=1" cls._debug("resolving_artifact", artifact_id=artifact_id, resolve_url=resolve_url) - res = requests.get(resolve_url, headers=get_auth_headers()) + res = requests.get( + resolve_url, headers=get_auth_headers(auth_ticket=auth_ticket)) res.raise_for_status() data = res.json() url = data.get("url") @@ -82,6 +83,8 @@ def define_schema(cls) -> io.Schema: io.Combo.Input("sort_by", options=[ "newest", "random"], default="newest"), io.Int.Input("limit", default=5, min=1, max=100), + io.String.Input("auth_ticket", default="", + tooltip="Ephemeral auth ticket for per-user access", optional=True), io.Int.Input("seed", default=0, display_name="Seed (Random Sort)"), ], @@ -93,9 +96,10 @@ def define_schema(cls) -> io.Schema: ) @classmethod - def execute(cls, library, collection, sort_by, limit, seed): + def execute(cls, library, collection, sort_by, limit, auth_ticket: str = "", seed=0): # Cache key based on inputs - cache_key = (library, collection, sort_by, limit, seed) + cache_key = (library, collection, sort_by, + limit, seed, str(auth_ticket or "")) if cache_key in cls._cache: return cls._cache[cache_key] @@ -139,7 +143,7 @@ def execute(cls, library, collection, sort_by, limit, seed): params["sort"] = "new" response = requests.get( - api_url, params=params, headers=get_auth_headers()) + api_url, params=params, headers=get_auth_headers(auth_ticket=auth_ticket)) response.raise_for_status() data = response.json() items = data.get("items", []) @@ -159,7 +163,7 @@ def execute(cls, library, collection, sort_by, limit, seed): # Fetch Content via Plugin Endpoint # Uses the plugin endpoint we defined to get raw content content_url, content_headers = cls._resolve_artifact_url( - base_url, art_id) + base_url, art_id, auth_ticket) cls._debug( "requesting_artifact_content", @@ -168,7 +172,7 @@ def execute(cls, library, collection, sort_by, limit, seed): ) try: - final_headers = get_auth_headers() + final_headers = get_auth_headers(auth_ticket=auth_ticket) if content_headers: final_headers.update(content_headers) diff --git a/nodes/EmbeddrUploadArtifact.py b/nodes/EmbeddrUploadArtifact.py index e5510bd..b62b9ef 100644 --- a/nodes/EmbeddrUploadArtifact.py +++ b/nodes/EmbeddrUploadArtifact.py @@ -16,6 +16,14 @@ def Embeddr_Log(message: str): print(f"[Embeddr] {message}") +def _has_auth_ticket(value) -> bool: + if isinstance(value, (list, tuple)): + value = value[0] if value else None + if isinstance(value, dict): + value = value.get("auth_ticket") or value.get("ticket") + return bool(str(value).strip()) if value is not None else False + + class EmbeddrUploadArtifactNode(io.ComfyNode): @classmethod def define_schema(cls) -> io.Schema: @@ -28,6 +36,8 @@ def define_schema(cls) -> io.Schema: io.Image.Input("image"), EmbeddrArtifactID.Input("parent_ids", optional=True, tooltip="Parent artifact UUIDs"), + io.String.Input("auth_ticket", default="", + tooltip="Ephemeral auth ticket for per-user ownership", optional=True), EmbeddrUploadArtifactOptions.Input("options", tooltip="Upload Artifact Options", optional=True, display_name="Options"), ], @@ -37,7 +47,7 @@ def define_schema(cls) -> io.Schema: ) @classmethod - def execute(cls, image, parent_ids: EmbeddrArtifactIDObject = None, options: EmbeddrUploadArtifactOptionsObject = None) -> io.NodeOutput: + def execute(cls, image, parent_ids: EmbeddrArtifactIDObject = None, auth_ticket: str = "", options: EmbeddrUploadArtifactOptionsObject = None) -> io.NodeOutput: base_url = get_embeddr_base_url() upload_mode = get_upload_mode() endpoint = f"{base_url}/api/v1/plugins/embeddr-comfyui/upload" @@ -81,6 +91,7 @@ def execute(cls, image, parent_ids: EmbeddrArtifactIDObject = None, options: Emb Tags: {options.tags if options else 'none'} Related Artifacts: {options.related_artifact_ids if options else 'none'} Parent IDs: {normalized_parent_ids} + Auth Ticket: {'present' if _has_auth_ticket(auth_ticket) else 'absent'} """) storage_provider = None @@ -137,7 +148,7 @@ def execute(cls, image, parent_ids: EmbeddrArtifactIDObject = None, options: Emb endpoint, files=files, data=data, - headers=get_auth_headers() + headers=get_auth_headers(auth_ticket=auth_ticket) ) response.raise_for_status() res_json = response.json() @@ -146,6 +157,15 @@ def execute(cls, image, parent_ids: EmbeddrArtifactIDObject = None, options: Emb results.append(str(art_id)) Embeddr_Log(f"Uploaded Artifact: {art_id}") + except requests.HTTPError as e: + body = "" + try: + body = e.response.text[:500] if e.response is not None else "" + except Exception: + body = "" + suffix = f" body={body}" if body else "" + Embeddr_Log( + f"Upload failed for batch {batch_idx}: {e}{suffix}") except Exception as e: Embeddr_Log(f"Upload failed for batch {batch_idx}: {e}") # We don't crash the whole node, but result might be partial diff --git a/nodes/utils/config.py b/nodes/utils/config.py index f8acfdd..51f8bd7 100644 --- a/nodes/utils/config.py +++ b/nodes/utils/config.py @@ -62,11 +62,26 @@ def get_api_key() -> str | None: ) -def get_auth_headers() -> dict[str, str]: +def get_auth_headers(auth_ticket: str | None = None) -> dict[str, str]: + headers: dict[str, str] = {} + + ticket_value = auth_ticket + if isinstance(ticket_value, (list, tuple)): + ticket_value = ticket_value[0] if ticket_value else None + if isinstance(ticket_value, dict): + ticket_value = ticket_value.get( + "auth_ticket") or ticket_value.get("ticket") + + if ticket_value is not None: + ticket_text = str(ticket_value).strip() + if ticket_text: + headers["X-Embeddr-Ticket"] = ticket_text + key = get_api_key() if key: - return {"X-API-Key": key} - return {} + headers["X-API-Key"] = key + + return headers def get_upload_mode(default: str = "require") -> str: diff --git a/package.json b/package.json index 9d6b820..b04d9a2 100644 --- a/package.json +++ b/package.json @@ -21,7 +21,7 @@ "@dnd-kit/core": "^6.3.1", "@dnd-kit/sortable": "^10.0.0", "@dnd-kit/utilities": "^3.2.2", - "@embeddr/api": "workspace:*", + "@embeddr/client-typescript": "workspace:*", "@embeddr/react-ui": "workspace:*", "@embeddr/zen-shell": "workspace:*", "@fontsource-variable/jetbrains-mono": "^5.2.8", diff --git a/ui/components/GlobalDialog.tsx b/ui/components/GlobalDialog.tsx index babb311..b32adc0 100644 --- a/ui/components/GlobalDialog.tsx +++ b/ui/components/GlobalDialog.tsx @@ -4,7 +4,7 @@ import { DialogContent, DialogHeader, DialogTitle, -} from "@embeddr/react-ui/components/dialog"; +} from "@embeddr/react-ui/components/ui"; import { useImageDialog } from "@embeddr/react-ui"; import { ExploreTab } from "./tabs/ExploreTab"; import { CollectionSelector } from "./selectors/CollectionSelector"; diff --git a/ui/components/ZenShell.tsx b/ui/components/ZenShell.tsx index 64172f1..727e12a 100644 --- a/ui/components/ZenShell.tsx +++ b/ui/components/ZenShell.tsx @@ -11,6 +11,9 @@ import { type ZenWindowRendererProps, EmbeddrProvider, type PluginLoaderAdapter, + CoreUIEventBridge, + ZenWebSocketProvider, + globalEventBus, } from "@embeddr/zen-shell"; import { useEmbeddrApi } from "../hooks/useEmbeddrApi"; import { @@ -23,9 +26,16 @@ import { RefreshCw, LayoutTemplate, } from "lucide-react"; -import { Button } from "@embeddr/react-ui/components/button"; +import { Button } from "@embeddr/react-ui/components/ui"; import { cn } from "@embeddr/react-ui"; +const PANEL_SAFE_AREA = { top: 8, right: 8, bottom: 8, left: 8 }; + +const clamp = (value: number, min: number, max: number) => { + if (max < min) return min; + return Math.min(Math.max(value, min), max); +}; + // Helper to resolve component ID to plugin and component name function resolveComponentId(fullId: string, plugins: Record) { if (!fullId) return null; @@ -148,6 +158,7 @@ const CustomWindowRenderer = React.memo((props: ZenWindowRendererProps) => { componentName={resolved.componentName} api={pluginApi} windowId={id} + isActive={isActive} {...windowState.props} /> @@ -220,16 +231,30 @@ function BasicWindowPanel({ }; }; - const handlePointerMove = useCallback((event: PointerEvent) => { - if (!dragRef.current || dragRef.current.pointerId !== event.pointerId) { - return; - } - const next = { - x: event.clientX - dragRef.current.startX, - y: event.clientY - dragRef.current.startY, - }; - setPos(next); - }, []); + const handlePointerMove = useCallback( + (event: PointerEvent) => { + if (!dragRef.current || dragRef.current.pointerId !== event.pointerId) { + return; + } + const maxX = window.innerWidth - dimensions.width - PANEL_SAFE_AREA.right; + const maxY = + window.innerHeight - dimensions.height - PANEL_SAFE_AREA.bottom; + const next = { + x: clamp( + event.clientX - dragRef.current.startX, + PANEL_SAFE_AREA.left, + maxX, + ), + y: clamp( + event.clientY - dragRef.current.startY, + PANEL_SAFE_AREA.top, + maxY, + ), + }; + setPos(next); + }, + [dimensions.height, dimensions.width], + ); const handlePointerUp = useCallback( (event: PointerEvent) => { @@ -261,19 +286,31 @@ function BasicWindowPanel({ ) { return; } + const minWidth = 240; + const minHeight = 180; + const maxWidth = Math.max( + minWidth, + window.innerWidth - pos.x - PANEL_SAFE_AREA.right, + ); + const maxHeight = Math.max( + minHeight, + window.innerHeight - pos.y - PANEL_SAFE_AREA.bottom, + ); const next = { - width: Math.max( - 240, + width: clamp( resizeRef.current.startW + (event.clientX - resizeRef.current.startX), + minWidth, + maxWidth, ), - height: Math.max( - 180, + height: clamp( resizeRef.current.startH + (event.clientY - resizeRef.current.startY), + minHeight, + maxHeight, ), }; setDimensions(next); }, - [dimensions.width, dimensions.height], + [pos.x, pos.y], ); const handleResizeUp = useCallback( @@ -303,6 +340,43 @@ function BasicWindowPanel({ }; }, [handlePointerMove, handlePointerUp, handleResizeMove, handleResizeUp]); + useEffect(() => { + const ensureInBounds = () => { + const maxX = window.innerWidth - dimensions.width - PANEL_SAFE_AREA.right; + const maxY = + window.innerHeight - dimensions.height - PANEL_SAFE_AREA.bottom; + setPos((prev) => ({ + x: clamp(prev.x, PANEL_SAFE_AREA.left, maxX), + y: clamp(prev.y, PANEL_SAFE_AREA.top, maxY), + })); + setDimensions((prev) => { + const minWidth = 240; + const minHeight = 180; + const constrainedWidth = clamp( + prev.width, + minWidth, + Math.max(minWidth, window.innerWidth - pos.x - PANEL_SAFE_AREA.right), + ); + const constrainedHeight = clamp( + prev.height, + minHeight, + Math.max( + minHeight, + window.innerHeight - pos.y - PANEL_SAFE_AREA.bottom, + ), + ); + return { + width: constrainedWidth, + height: constrainedHeight, + }; + }); + }; + + ensureInBounds(); + window.addEventListener("resize", ensureInBounds); + return () => window.removeEventListener("resize", ensureInBounds); + }, [dimensions.height, dimensions.width, pos.x, pos.y]); + return (
{}, + setPipelineInput: () => {}, + selectPipeline: () => {}, + }; + + const modelCatalog = { + list: async (input: { + category: string; + page?: number; + limit?: number; + }) => ({ + items: [], + total: 0, + page: input.page || 1, + pages: 1, + category: input.category, + }), + listSamplers: async () => ({ samplers: [], schedulers: [] }), + }; const api: EmbeddrAPI = { stores: { @@ -452,15 +549,7 @@ function createEmbeddrApiAdapter(input: EmbeddrApiAdapterInput): EmbeddrAPI { selectedImage: null, selectImage: () => {}, }, - generation: { - workflows: [], - selectedWorkflow: null, - generations: [], - isGenerating: false, - generate: async () => {}, - setWorkflowInput: () => {}, - selectWorkflow: () => {}, - }, + execution: executionStore, }, ui: { activePanelId: null, @@ -615,25 +704,19 @@ function createEmbeddrApiAdapter(input: EmbeddrApiAdapterInput): EmbeddrAPI { }, } as any, events: { - on: (event, listener) => { - const handler = (e: Event) => listener((e as CustomEvent).detail); - eventTarget.addEventListener(event, handler as EventListener); - return () => - eventTarget.removeEventListener(event, handler as EventListener); - }, - off: (event, listener) => { - eventTarget.removeEventListener(event, listener as EventListener); - }, - emit: (event, payload) => { - eventTarget.dispatchEvent(new CustomEvent(event, { detail: payload })); - }, - }, - comfy: { - getLoras: async () => ({ items: [], total: 0, page: 1, pages: 1 }), - getCheckpoints: async () => ({ items: [], total: 0, page: 1, pages: 1 }), - getEmbeddings: async () => ({ items: [], total: 0, page: 1, pages: 1 }), - getSamplers: async () => ({ samplers: [], schedulers: [] }), + on: (event, listener) => + globalEventBus.on( + event as string, + listener as (...args: any[]) => void, + ), + off: (event, listener) => + globalEventBus.off( + event as string, + listener as (...args: any[]) => void, + ), + emit: (event, payload) => globalEventBus.emit(event as string, payload), }, + models: modelCatalog, windows: { open: (id: string, title: string, componentId: string, props?: any) => useZenWindowStore.getState().openWindow({ @@ -734,16 +817,30 @@ export function ZenShell() { const { plugins, knownPlugins } = usePluginRegistry(); const spawnWindow = useZenWindowStore((s) => s.spawnWindow); + const setPanelConstraints = useZenWindowStore((s) => s.setPanelConstraints); + const [pluginReloadTick, setPluginReloadTick] = useState(0); const embeddrApi = useMemo( () => createEmbeddrApiAdapter(api), [api.endpoint, api.apiKey, api.apiClient], ); + const wsBackendUrl = useMemo( + () => (api.endpoint || "http://localhost:8003").replace(/\/$/, ""), + [api.endpoint], + ); useEffect(() => { console.log("[ZenShell] Mounted"); return () => console.log("[ZenShell] Unmounted"); }, []); + useEffect(() => { + setPanelConstraints({ + enabled: true, + safeArea: PANEL_SAFE_AREA, + snapThreshold: 24, + }); + }, [setPanelConstraints]); + const adapter = useMemo(() => { console.log("[ZenShell] Recreating adapter"); return { @@ -872,7 +969,10 @@ export function ZenShell() { // We always render the Manager (so windows exist), but maybe hide the Dock return ( - <> +
+ @@ -1012,7 +1114,10 @@ export function ZenShell() { variant="ghost" size="sm" className="w-full mt-4 text-xs text-muted-foreground hover:text-foreground" - onClick={() => loadExternalPlugins({ adapter })} + onClick={async () => { + await loadExternalPlugins({ adapter }); + setPluginReloadTick((prev) => prev + 1); + }} > Reload Plugins @@ -1021,6 +1126,6 @@ export function ZenShell() { )}
)} - +
); } diff --git a/ui/components/panels/EmbeddrPanel.tsx b/ui/components/panels/EmbeddrPanel.tsx index 2e64125..9195af0 100644 --- a/ui/components/panels/EmbeddrPanel.tsx +++ b/ui/components/panels/EmbeddrPanel.tsx @@ -5,7 +5,7 @@ import { TabsContent, TabsList, TabsTrigger, -} from "@embeddr/react-ui/components/tabs"; +} from "@embeddr/react-ui/components/ui"; import { useExternalNav, useImageDialog } from "@embeddr/react-ui"; import { GlobeIcon, @@ -14,10 +14,11 @@ import { Settings, LayoutTemplate, } from "lucide-react"; -import { Button } from "@embeddr/react-ui/components/button"; +import { Button } from "@embeddr/react-ui/components/ui"; import { useEmbeddrApi } from "@hooks/useEmbeddrApi"; import { SettingsForm } from "../tabs/SettingsForm"; import { ExploreTab } from "../tabs/ExploreTab"; +import { PromptTab } from "../tabs/PromptTab"; export default function EmbeddrPanel() { const { @@ -181,6 +182,16 @@ export default function EmbeddrPanel() { /> + + + openExternal(`${endpoint.replace(/\/+$/, "")}/docs`) + } + /> + + ; diff --git a/ui/components/tabs/PromptTab.tsx b/ui/components/tabs/PromptTab.tsx new file mode 100644 index 0000000..49e7a8a --- /dev/null +++ b/ui/components/tabs/PromptTab.tsx @@ -0,0 +1,436 @@ +import React, { useCallback, useEffect, useMemo, useState } from "react"; +import { AlertCircle, Bot, Loader2, Send, Wrench } from "lucide-react"; +import { Button } from "@embeddr/react-ui/components/ui"; +import { Card } from "@embeddr/react-ui/components/ui"; +import { Input } from "@embeddr/react-ui/components/ui"; +import { Label } from "@embeddr/react-ui/components/ui"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@embeddr/react-ui/components/ui"; +import { ScrollArea } from "@embeddr/react-ui/components/ui"; +import { Textarea } from "@embeddr/react-ui/components/ui"; + +type ChatMessage = { + id?: string; + role: "user" | "assistant" | "error"; + content: string; +}; + +type LlmProvider = { + id: string; + name?: string; +}; + +type LlmModel = { + id: string; + name?: string; + provider_id?: string; +}; + +interface PromptTabProps { + endpoint: string; + apiKey?: string; + onOpenDocs: () => void; +} + +const POLL_INTERVAL_MS = 1200; +const MAX_POLL_ATTEMPTS = 120; + +const withApiPrefix = (endpoint: string, path: string) => { + const cleanEndpoint = endpoint.replace(/\/+$/, ""); + const cleanPath = path.startsWith("/") ? path : `/${path}`; + if (cleanEndpoint.endsWith("/api/v1")) { + return `${cleanEndpoint}${cleanPath}`; + } + return `${cleanEndpoint}/api/v1${cleanPath}`; +}; + +const getHeaders = (apiKey?: string) => { + const headers: Record = { + "Content-Type": "application/json", + }; + if (apiKey?.trim()) { + headers["X-API-Key"] = apiKey.trim(); + } + return headers; +}; + +async function jsonRequest( + endpoint: string, + path: string, + init: RequestInit = {}, + apiKey?: string, +): Promise { + const url = withApiPrefix(endpoint, path); + const res = await fetch(url, { + ...init, + headers: { + ...getHeaders(apiKey), + ...(init.headers || {}), + }, + }); + if (!res.ok) { + const body = await res.text().catch(() => ""); + throw new Error(body || `Request failed (${res.status})`); + } + return (await res.json()) as T; +} + +export function PromptTab({ endpoint, apiKey, onOpenDocs }: PromptTabProps) { + const [checkingPlugin, setCheckingPlugin] = useState(true); + const [hasLlmPlugin, setHasLlmPlugin] = useState(false); + const [pluginCheckError, setPluginCheckError] = useState(null); + + const [providers, setProviders] = useState([]); + const [models, setModels] = useState([]); + const [selectedProviderId, setSelectedProviderId] = useState("auto"); + const [selectedModel, setSelectedModel] = useState("auto"); + + const [systemPrompt, setSystemPrompt] = useState(""); + const [prompt, setPrompt] = useState(""); + const [messages, setMessages] = useState([]); + const [isSending, setIsSending] = useState(false); + + const selectedProviderModels = useMemo(() => { + if (selectedProviderId === "auto") { + return models; + } + return models.filter((model) => model.provider_id === selectedProviderId); + }, [models, selectedProviderId]); + + const refreshPluginState = useCallback(async () => { + if (!endpoint) { + setCheckingPlugin(false); + setHasLlmPlugin(false); + setPluginCheckError("No API endpoint configured."); + return; + } + setCheckingPlugin(true); + setPluginCheckError(null); + try { + const data = await jsonRequest(endpoint, "/plugins", {}, apiKey); + const list = Array.isArray(data) + ? data + : Array.isArray(data?.items) + ? data.items + : []; + const found = list.some((plugin: any) => { + const id = String(plugin?.id || plugin?.plugin_id || ""); + return id === "embeddr-llm"; + }); + setHasLlmPlugin(found); + } catch (e: any) { + setHasLlmPlugin(false); + setPluginCheckError(e?.message || "Failed to check installed plugins."); + } finally { + setCheckingPlugin(false); + } + }, [endpoint, apiKey]); + + const loadProvidersAndModels = useCallback(async () => { + if (!hasLlmPlugin) return; + try { + const [providerRes, modelRes] = await Promise.all([ + jsonRequest( + endpoint, + "/plugins/embeddr-llm/providers", + {}, + apiKey, + ), + jsonRequest(endpoint, "/plugins/embeddr-llm/models", {}, apiKey), + ]); + + const providerList = Array.isArray(providerRes?.data) + ? providerRes.data + : Array.isArray(providerRes) + ? providerRes + : []; + const modelList = Array.isArray(modelRes?.data) + ? modelRes.data + : Array.isArray(modelRes) + ? modelRes + : []; + + setProviders(providerList); + setModels(modelList); + } catch { + setProviders([]); + setModels([]); + } + }, [endpoint, apiKey, hasLlmPlugin]); + + useEffect(() => { + refreshPluginState(); + }, [refreshPluginState]); + + useEffect(() => { + loadProvidersAndModels(); + }, [loadProvidersAndModels]); + + const runChat = useCallback(async () => { + const text = prompt.trim(); + if (!text || isSending || !hasLlmPlugin) return; + + const pendingId = `pending-${Date.now()}`; + setPrompt(""); + setIsSending(true); + setMessages((prev) => [ + ...prev, + { role: "user", content: text }, + { id: pendingId, role: "assistant", content: "Thinking..." }, + ]); + + try { + const execution = await jsonRequest( + endpoint, + "/executions", + { + method: "POST", + body: JSON.stringify({ + plugin_name: "embeddr-llm", + job_type: "llm.respond", + inputs: { + prompt: text, + system_prompt: systemPrompt.trim() || undefined, + provider_id: + selectedProviderId !== "auto" ? selectedProviderId : undefined, + model: selectedModel !== "auto" ? selectedModel : undefined, + }, + }), + }, + apiKey, + ); + + const executionId = execution?.id; + if (!executionId) { + throw new Error("No execution id returned."); + } + + let assistantText = ""; + let failedText = ""; + + for (let attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt += 1) { + await new Promise((resolve) => setTimeout(resolve, POLL_INTERVAL_MS)); + const job = await jsonRequest( + endpoint, + `/executions/${executionId}`, + { method: "GET" }, + apiKey, + ); + + if (job?.status === "completed") { + assistantText = + job?.outputs?.response_text || + job?.outputs?.response || + job?.outputs?.text || + JSON.stringify(job?.outputs || {}, null, 2); + break; + } + + if (job?.status === "failed") { + failedText = job?.error || job?.message || "LLM execution failed."; + break; + } + } + + if (assistantText) { + setMessages((prev) => + prev.map((message) => + message.id === pendingId + ? { role: "assistant", content: assistantText } + : message, + ), + ); + } else if (failedText) { + setMessages((prev) => + prev.map((message) => + message.id === pendingId + ? { role: "error", content: failedText } + : message, + ), + ); + } else { + setMessages((prev) => + prev.map((message) => + message.id === pendingId + ? { + role: "error", + content: "Timed out while waiting for LLM response.", + } + : message, + ), + ); + } + } catch (e: any) { + setMessages((prev) => + prev.map((message) => + message.id === pendingId + ? { role: "error", content: e?.message || "Failed to send prompt." } + : message, + ), + ); + } finally { + setIsSending(false); + } + }, [ + apiKey, + endpoint, + hasLlmPlugin, + isSending, + prompt, + selectedModel, + selectedProviderId, + systemPrompt, + ]); + + if (checkingPlugin) { + return ( +
+ + Checking LLM plugin... +
+ ); + } + + if (!hasLlmPlugin) { + return ( +
+ +
+ + LLM chat is not available +
+

+ This tab needs the embeddr-llm plugin installed and + enabled on your configured instance. +

+ {pluginCheckError ? ( +
+ + {pluginCheckError} +
+ ) : null} +
+ + +
+
+
+ ); + } + + return ( +
+
+
+ + +
+ +
+ + +
+
+ +
+ +