diff --git a/src-frontend/src/api/workspace.ts b/src-frontend/src/api/workspace.ts index d570d138..9d7bd7e3 100644 --- a/src-frontend/src/api/workspace.ts +++ b/src-frontend/src/api/workspace.ts @@ -7,6 +7,7 @@ export { useGetWorkspaceSuspense, useGetWorkspacesSuspenseInfinite, useUpdateWorkspace, + useUpdateWorkspaceNotes, openWorkspace, } from "./generated/endpoints/workspace/workspace"; diff --git a/src-frontend/src/features/Tabs/WorkspacePanel/WorkspaceNotesEditForm.tsx b/src-frontend/src/features/Tabs/WorkspacePanel/WorkspaceNotesEditForm.tsx index 58fb7309..282ada46 100644 --- a/src-frontend/src/features/Tabs/WorkspacePanel/WorkspaceNotesEditForm.tsx +++ b/src-frontend/src/features/Tabs/WorkspacePanel/WorkspaceNotesEditForm.tsx @@ -3,7 +3,7 @@ import { useTranslation } from "react-i18next"; import { toast } from "sonner"; import { TABS_WORKSPACE_NAMESPACE } from "@/i18n/resources"; import type { WorkspaceRead } from "@/api/generated/schemas"; -import { invalidateWorkspaceQueries, useUpdateWorkspace } from "@/api/workspace"; +import { invalidateWorkspaceQueries, useUpdateWorkspaceNotes } from "@/api/workspace"; import { FormShell, FormShellFooter } from "@/components/custom/form/FormShell"; import { Button } from "@/components/ui/button"; import { useWorkspaceStore } from "@/stores/workspace-store"; @@ -29,7 +29,7 @@ export function WorkspaceNotesEditForm({ workspace, onConfirm }: WorkspaceNotesE [workspace.notes] ); - const updateMutation = useUpdateWorkspace({ + const updateNotesMutation = useUpdateWorkspaceNotes({ mutation: { async onSuccess(updatedWorkspace: { id: number; name: string }) { await invalidateWorkspaceQueries(updatedWorkspace.id); @@ -47,7 +47,7 @@ export function WorkspaceNotesEditForm({ workspace, onConfirm }: WorkspaceNotesE const handleSubmit = (data: WorkspaceNotesEditFormValues) => { const payload = notesEditFormValuesToPayload(data); - updateMutation.mutate({ workspaceId: workspace.id, data: payload }); + updateNotesMutation.mutate({ workspaceId: workspace.id, data: payload }); }; return ( @@ -59,8 +59,8 @@ export function WorkspaceNotesEditForm({ workspace, onConfirm }: WorkspaceNotesE - diff --git a/src-frontend/src/features/Tabs/WorkspacePanel/form-types.ts b/src-frontend/src/features/Tabs/WorkspacePanel/form-types.ts index aaf1f810..a1afab48 100644 --- a/src-frontend/src/features/Tabs/WorkspacePanel/form-types.ts +++ b/src-frontend/src/features/Tabs/WorkspacePanel/form-types.ts @@ -1,5 +1,6 @@ import type { WorkspaceCreate, + WorkspaceNotesUpdate, WorkspaceRead, WorkspaceUpdate, } from "@/api/generated/schemas"; @@ -52,23 +53,13 @@ export function createFormValuesToPayload( export function editFormValuesToPayload( values: WorkspaceEditFormValues, ): WorkspaceUpdate { - return { - ...values, - notes: null, // explicitly exclude `notes` - }; + return values; } export function notesEditFormValuesToPayload( values: WorkspaceNotesEditFormValues, -): WorkspaceUpdate { +): WorkspaceNotesUpdate { return { - name: null, - directory: null, - instruction: null, - usable_agent_ids: null, - usable_tool_ids: null, - usable_skill_ids: null, notes: arboristDataToResources(values.notes), }; } - diff --git a/src-frontend/src/i18n/locales/en/error.json b/src-frontend/src/i18n/locales/en/error.json index 948bd408..0ff1222a 100644 --- a/src-frontend/src/i18n/locales/en/error.json +++ b/src-frontend/src/i18n/locales/en/error.json @@ -24,5 +24,6 @@ "TOOLSET_INTERNAL_KEY_ALREADY_EXISTS": "This toolset already exists.", "TOOLSET_NOT_FOUND": "The toolset was not found.", "UNEXPECTED_ERROR": "Something went wrong. Please try again.", - "WORKSPACE_NOT_FOUND": "The workspace was not found." + "WORKSPACE_NOT_FOUND": "The workspace was not found.", + "WORKSPACE_NOTES_LOCKED_BY_RUNNING_TASK": "Workspace notes cannot be updated while a task in this workspace is running. Please try again after the task finishes." } diff --git a/src-frontend/src/i18n/locales/zh_CN/error.json b/src-frontend/src/i18n/locales/zh_CN/error.json index 014d3b6a..cf3b8d0c 100644 --- a/src-frontend/src/i18n/locales/zh_CN/error.json +++ b/src-frontend/src/i18n/locales/zh_CN/error.json @@ -24,5 +24,6 @@ "TOOLSET_INTERNAL_KEY_ALREADY_EXISTS": "该工具集已存在。", "TOOLSET_NOT_FOUND": "未找到该工具集。", "UNEXPECTED_ERROR": "发生错误,请稍后重试。", - "WORKSPACE_NOT_FOUND": "未找到该工作区。" + "WORKSPACE_NOT_FOUND": "未找到该工作区。", + "WORKSPACE_NOTES_LOCKED_BY_RUNNING_TASK": "当前工作区下有任务正在运行,暂时无法更新工作区笔记,请在任务结束后重试。" } diff --git a/src-server/src/agent/context/__init__.py b/src-server/src/agent/context/__init__.py index 961fb44d..6f4fca98 100644 --- a/src-server/src/agent/context/__init__.py +++ b/src-server/src/agent/context/__init__.py @@ -1,11 +1,3 @@ -from .aliases import BuiltInToolAliases -from .models import ToolRuntimeContext, AgentContextResource, AgentContextPersistence - -__all__ = [ - "AgentContext", - "ToolRuntimeContext", -] - import platform import xml.etree.ElementTree as ET from collections import namedtuple @@ -14,7 +6,6 @@ from loguru import logger from dais_sdk.tool import Toolset from dais_sdk.types import Message, ToolDef -from src.agent.notes import NoteManager from src.db import db_context from src.schemas import ( agent as agent_schemas, @@ -29,7 +20,10 @@ from src.services.provider import ProviderService from src.settings import use_app_setting_manager from .persistence import create_agent_context_persistence -from ..tool import use_mcp_toolset_manager, BuiltinToolsetManager, McpToolsetManager, BuiltInToolset +from .aliases import BuiltInToolAliases +from .models import AgentContextResource, AgentContextPersistence +from ..notes import NoteMaterializer +from ..tool import use_mcp_toolset_manager, BuiltinToolsetManager, McpToolsetManager from ..prompts import ( BASE_INSTRUCTION, FAILED_TO_LOAD_NOTES_INDEX, @@ -49,7 +43,7 @@ def __init__(self, *, messages: list[Message], resource: AgentContextResource, - tool_context: ToolRuntimeContext, + usage: ContextUsage, persistence: AgentContextPersistence, builtin_toolset_manager: BuiltinToolsetManager, mcp_toolset_manager: McpToolsetManager): @@ -58,7 +52,7 @@ def __init__(self, self._resource = resource self._messages = messages - self._tool_context = tool_context + self._usage = usage self._persistence = persistence self._builtin_toolset_manager = builtin_toolset_manager self._mcp_toolset_manager = mcp_toolset_manager @@ -82,12 +76,7 @@ async def create(cls, task: task_runtime_schemas.TaskRuntimeContext) -> Self: usage = ContextUsage(**asdict(usage)) messages = task.messages - note_manager = NoteManager(task.workspace_id) - await note_manager.materialize() - await note_manager.start_watching() - - tool_context = ToolRuntimeContext(usage=usage, note_manager=note_manager) - builtin_toolset_manager = await BuiltinToolsetManager.create(workspace.id, workspace.directory, tool_context) + builtin_toolset_manager = await BuiltinToolsetManager.create(workspace.id, workspace.directory) mcp_toolset_manager = use_mcp_toolset_manager() persistence = create_agent_context_persistence(task) @@ -101,7 +90,7 @@ async def create(cls, task: task_runtime_schemas.TaskRuntimeContext) -> Self: model=provider_schemas.LlmModelRead.model_validate(model), skills=[skill_schemas.SkillBrief.model_validate(skill) for skill in skills], ), - tool_context=tool_context, + usage=usage, persistence=persistence, builtin_toolset_manager=builtin_toolset_manager, mcp_toolset_manager=mcp_toolset_manager) @@ -131,7 +120,7 @@ def _resolve_instructions(self) -> ResolvedInstructions: @property def usage(self) -> ContextUsage: - return self._tool_context.usage + return self._usage @property def toolsets(self) -> list[Toolset]: @@ -171,7 +160,7 @@ async def compose_system_instruction(self) -> str: available_skills = AgentContext._format_skills([ skill for skill in self._resource.skills if skill.is_enabled]) resolved_instructions = self._resolve_instructions() - notes_index = await self._tool_context.note_manager.get_notes_index() + notes_index = await NoteMaterializer.get_notes_index(self._resource.workspace.id) resolved_notes_index = notes_index if notes_index is not None else FAILED_TO_LOAD_NOTES_INDEX return resolved_instructions.base.format( @@ -194,14 +183,8 @@ def find_tool(self, tool_name: str) -> ToolDef | None: return None async def persist(self) -> task_runtime_schemas.TaskRuntimeContext: - try: - await self._tool_context.note_manager.stop_watching() - await self._tool_context.note_manager.clear_materialized() - except: - self._logger.exception("Failed to execute NoteManager cleanup.") - return await self._persistence.persist( self.task_id, self._messages, - self._tool_context.usage, + self._usage, ) diff --git a/src-server/src/agent/context/models.py b/src-server/src/agent/context/models.py index cb175fad..4c81e5f7 100644 --- a/src-server/src/agent/context/models.py +++ b/src-server/src/agent/context/models.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Protocol from dais_sdk.types import Message -from src.agent.notes import NoteManager from src.db.models import tasks as task_models from src.schemas import ( agent as agent_schemas, @@ -10,14 +9,8 @@ workspace as workspace_schemas, ) from src.schemas.tasks import runtime as task_runtime_schemas -from src.agent.types import ContextUsage -@dataclass(frozen=True) -class ToolRuntimeContext: - usage: ContextUsage - note_manager: NoteManager - @dataclass(frozen=True) class AgentContextResource: workspace: workspace_schemas.WorkspaceRead diff --git a/src-server/src/agent/notes/__init__.py b/src-server/src/agent/notes/__init__.py index fd0cb64c..854c4dfe 100644 --- a/src-server/src/agent/notes/__init__.py +++ b/src-server/src/agent/notes/__init__.py @@ -1 +1,3 @@ -from .manager import NoteManager +from .materializer import NoteMaterializer +from .watcher import NoteWatcher +from .workspace_ref_manager import WorkspaceRefManager diff --git a/src-server/src/agent/notes/manager.py b/src-server/src/agent/notes/manager.py deleted file mode 100644 index 0f181eb8..00000000 --- a/src-server/src/agent/notes/manager.py +++ /dev/null @@ -1,212 +0,0 @@ -import asyncio -import shutil -from typing import Any, Callable -from anyio import Path as AnyioPath -from pathlib import Path as StdPath -from loguru import logger -from watchfiles import awatch, Change as ChangeType -from src.common import DATA_DIR -from src.db import db_context - - -type FileChange = tuple[ChangeType, AnyioPath] - -class NoteManager: - NOTES_DIR_ENVNAME = "DAIS_NOTES_DIR" - _logger = logger.bind(name="NoteManager") - _workspace_ref_counts: dict[int, int] = {} - _workspace_ref_lock: asyncio.Lock | None = None - - def __init__(self, workspace_id: int): - self._workspace_id = workspace_id - self._stop_watching_event = None - self._watch_task = None - - @classmethod - def _get_lock(cls) -> asyncio.Lock: - if cls._workspace_ref_lock is None: - cls._workspace_ref_lock = asyncio.Lock() - return cls._workspace_ref_lock - - @classmethod - async def _increment_workspace_ref(cls, workspace_id: int) -> int: - async with cls._get_lock(): - current = cls._workspace_ref_counts.get(workspace_id, 0) - ref_count = current + 1 - cls._workspace_ref_counts[workspace_id] = ref_count - return ref_count - - @classmethod - async def _decrement_workspace_ref(cls, workspace_id: int, *, force: bool = False) -> int: - async with cls._get_lock(): - if force: - cls._workspace_ref_counts.pop(workspace_id, None) - return 0 - current = cls._workspace_ref_counts.get(workspace_id, 0) - if current <= 1: - cls._workspace_ref_counts.pop(workspace_id, None) - return 0 - ref_count = current - 1 - cls._workspace_ref_counts[workspace_id] = ref_count - return ref_count - - @staticmethod - async def get_notes_root_dir() -> AnyioPath: - notes_root_dir = AnyioPath(DATA_DIR, ".notes") - await notes_root_dir.mkdir(parents=True, exist_ok=True) - return notes_root_dir - - @staticmethod - def get_notes_dir_env(workspace_id: int) -> dict[str, str]: - return {NoteManager.NOTES_DIR_ENVNAME: str(DATA_DIR / ".notes" / str(workspace_id))} - - async def get_notes_dir(self) -> AnyioPath: - notes_root_dir = await self.get_notes_root_dir() - notes_dir = notes_root_dir / str(self._workspace_id) - await notes_dir.mkdir(parents=True, exist_ok=True) - return notes_dir - - async def get_notes_index(self) -> str | None: - notes_dir = await self.get_notes_dir() - index_file = notes_dir / "NOTES.md" - if not await index_file.exists(): return None - try: - return await index_file.read_text(encoding="utf-8") - except: - self._logger.exception(f"Failed to read root NOTES.md for workspace {self._workspace_id}.") - return None - - async def materialize(self) -> AnyioPath: - """Materialize workspace notes into `$DAIS_NOTES_DIR` for the current task context. - - Calling this method increments the workspace-level materialization reference count. - The caller must later call `clear_materialized()` exactly once to release this - reference; otherwise temporary notes state may be retained unexpectedly. - """ - from src.services.workspace import WorkspaceService - - async with db_context() as db_session: - workspace = await WorkspaceService(db_session).get_workspace_by_id(self._workspace_id) - workspace_notes = workspace.notes - - notes_dir = await self.get_notes_dir() - for note in workspace_notes: - note_path = notes_dir / note.relative - await note_path.parent.mkdir(parents=True, exist_ok=True) - await note_path.write_text(note.content, "utf-8") - - await self._increment_workspace_ref(self._workspace_id) - return notes_dir - - async def clear_materialized(self, *, force: bool = False): - """Release one materialization reference and clean notes when eligible. - - This method is the required counterpart of `materialize()`: after each successful - `materialize()` call, the caller must invoke `clear_materialized()` once. - When `force=True`, cleanup ignores reference counts (used by workspace deletion). - """ - notes_root_dir = await self.get_notes_root_dir() - notes_dir = notes_root_dir / str(self._workspace_id) - - ref_count = await self._decrement_workspace_ref(self._workspace_id, force=force) - should_delete = ref_count == 0 - - if not should_delete: return - if not await notes_dir.exists(): return - await asyncio.to_thread(shutil.rmtree, notes_dir) - - async def start_watching(self): - if self._watch_task is not None: - await self.stop_watching() - self._stop_watching_event = asyncio.Event() - notes_dir = await self.get_notes_dir() - self._watch_task = asyncio.create_task( - self._watch_files(notes_dir, self._stop_watching_event)) - - async def stop_watching(self) -> None: - if self._stop_watching_event: - self._stop_watching_event.set() - self._stop_watching_event = None - - if self._watch_task and not self._watch_task.done(): - self._watch_task.cancel() - try: - await self._watch_task - except asyncio.CancelledError: - pass - finally: - self._watch_task = None - - async def _handle_file_changes(self, changes: list[FileChange]): - from src.services.workspace import WorkspaceService - from src.db.models import workspace as workspace_models - - if len(changes) == 0: return - added_notes: list[tuple[AnyioPath, str]] = [] # (relative, content) - updated_notes: list[tuple[AnyioPath, str]] = [] # (relative, content) - deleted_notes: list[AnyioPath] = [] - - for change_type, file_path in changes: - try: - match change_type: - case ChangeType.added: - content = await AnyioPath(file_path).read_text("utf-8") - added_notes.append((file_path, content)) - case ChangeType.modified: - content = await AnyioPath(file_path).read_text("utf-8") - updated_notes.append((file_path, content)) - case ChangeType.deleted: - deleted_notes.append(file_path) - except Exception: - # read file failed, skip - pass - - notes_dir = await self.get_notes_dir() - normalized_path: Callable[[AnyioPath], str] = lambda path: path.relative_to(notes_dir).as_posix() - - async with db_context() as db_session: - workspace = await WorkspaceService(db_session).get_workspace_by_id(self._workspace_id) - existing_notes: dict[str, workspace_models.WorkspaceNote] = { - note.relative: note for note in workspace.notes - } - - for path in deleted_notes: - existing_notes.pop(normalized_path(path)) - for path, content in added_notes: - relative = normalized_path(path) - existing_notes[relative] = workspace_models.WorkspaceNote( - relative=relative, - content=content, - ) - for path, content in updated_notes: - relative = normalized_path(path) - existing_note = existing_notes.get(relative, workspace_models.WorkspaceNote( - relative=relative, - content=content, - )) - existing_note.content = content - - workspace.notes = list(existing_notes.values()) - - async def _watch_files(self, notes_dir: AnyioPath, stop_event: asyncio.Event) -> None: - if not await notes_dir.exists(): - return - - try: - async for changes in awatch( - StdPath(notes_dir), - stop_event=stop_event, - debounce=500, - recursive=True, - ): - markdown_changes: list[Any] = [] - for change_type, path in changes: - path = AnyioPath(path) - if await path.is_symlink() or await path.is_dir(): continue - if path.suffix.lower() != ".md": continue - markdown_changes.append((change_type, path)) - await self._handle_file_changes(markdown_changes) - except asyncio.CancelledError: - self._logger.debug(f"Notes watch cancelled for workspace {self._workspace_id}") - except Exception: - self._logger.exception(f"Error watching notes for workspace {self._workspace_id}") diff --git a/src-server/src/agent/notes/materializer.py b/src-server/src/agent/notes/materializer.py new file mode 100644 index 00000000..bb1e7e49 --- /dev/null +++ b/src-server/src/agent/notes/materializer.py @@ -0,0 +1,79 @@ +import asyncio +import shutil +from anyio import Path as AnyioPath +from loguru import logger +from src.common import DATA_DIR +from src.db import db_context, workspace_models +from src.schemas import workspace as workspace_schemas + + +class NoteMaterializer: + NOTES_DIR_ENVNAME = "DAIS_NOTES_DIR" + _logger = logger.bind(name="NoteMaterializer") + + @staticmethod + async def get_notes_root_dir() -> AnyioPath: + notes_root_dir = AnyioPath(DATA_DIR, ".notes") + await notes_root_dir.mkdir(parents=True, exist_ok=True) + return notes_root_dir + + @classmethod + def get_notes_dir_env(cls, workspace_id: int) -> dict[str, str]: + return {cls.NOTES_DIR_ENVNAME: str(DATA_DIR / ".notes" / str(workspace_id))} + + @classmethod + async def get_notes_dir(cls, workspace_id: int) -> AnyioPath: + notes_root_dir = await cls.get_notes_root_dir() + notes_dir = notes_root_dir / str(workspace_id) + await notes_dir.mkdir(parents=True, exist_ok=True) + return notes_dir + + @classmethod + async def get_notes_index(cls, workspace_id: int) -> str | None: + notes_dir = await cls.get_notes_dir(workspace_id) + index_file = notes_dir / "NOTES.md" + if not await index_file.exists(): return None + try: + return await index_file.read_text(encoding="utf-8") + except: + cls._logger.exception(f"Failed to read root NOTES.md for workspace {workspace_id}.") + return None + + @classmethod + async def materialize(cls, workspace: workspace_schemas.WorkspaceRead) -> AnyioPath: + notes_dir = await cls.get_notes_dir(workspace.id) + for note in workspace.notes: + note_path = notes_dir / note.relative + await note_path.parent.mkdir(parents=True, exist_ok=True) + await note_path.write_text(note.content, "utf-8") + + return notes_dir + + @classmethod + async def materialize_all(cls): + from src.services.workspace import WorkspaceService + + async with db_context() as db_session: + workspaces = await WorkspaceService(db_session).get_all_workspaces() + + sem = asyncio.Semaphore(16) + async def sem_materialize(workspace: workspace_models.Workspace): + async with sem: + workspace_read = workspace_schemas.WorkspaceRead.model_validate(workspace) + await cls.clear_materialized(workspace_read.id) + await cls.materialize(workspace_read) + + tasks = [sem_materialize(workspace) for workspace in workspaces] + results = await asyncio.gather(*tasks, return_exceptions=True) + for result in results: + if isinstance(result, BaseException): + cls._logger.opt(exception=result).warning("Failed to materialize workspaces") + + @classmethod + async def clear_materialized(cls, workspace_id: int): + cls._logger.debug(f"Clearing notes for workspace {workspace_id}") + notes_root_dir = await cls.get_notes_root_dir() + notes_dir = notes_root_dir / str(workspace_id) + + if not await notes_dir.exists(): return + await asyncio.to_thread(shutil.rmtree, notes_dir) diff --git a/src-server/src/agent/notes/watcher.py b/src-server/src/agent/notes/watcher.py new file mode 100644 index 00000000..213fdb78 --- /dev/null +++ b/src-server/src/agent/notes/watcher.py @@ -0,0 +1,111 @@ +from typing import Callable +from anyio import Path as AnyioPath +from watchfiles import Change as ChangeType +from src.db import db_context +from src.utils import DirectoryWatcher, FileChange +from .materializer import NoteMaterializer +from .workspace_ref_manager import WorkspaceRefManager + + +type NoteChange = tuple[ChangeType, AnyioPath] + +class NoteWatcher: + def __init__(self, workspace_id: int) -> None: + self._workspace_id = workspace_id + self._ref_acquired = False + self._watcher: DirectoryWatcher | None = None + + async def _start(self): + if self._watcher is not None: + await self._stop() + + WorkspaceRefManager.increase_workspace_ref(self._workspace_id) + self._ref_acquired = True + try: + notes_dir = await NoteMaterializer.get_notes_dir(self._workspace_id) + self._watcher = DirectoryWatcher(notes_dir, on_changes=self._handle_file_changes) + await self._watcher.start() + except BaseException: + WorkspaceRefManager.decrease_workspace_ref(self._workspace_id) + self._ref_acquired = False + raise + + async def _stop(self) -> None: + if self._watcher: + await self._watcher.stop() + self._watcher = None + if self._ref_acquired: + WorkspaceRefManager.decrease_workspace_ref(self._workspace_id) + + async def _handle_file_changes(self, raw_changes: list[FileChange]) -> None: + notes_dir = await NoteMaterializer.get_notes_dir(self._workspace_id) + + markdown_changes: list[NoteChange] = [] + for change_type, path_str in raw_changes: + path = AnyioPath(path_str) + if await path.is_symlink() or await path.is_dir(): + continue + if path.suffix.lower() != ".md": + continue + markdown_changes.append((change_type, path)) + + if markdown_changes: + await self._handle_note_changes(notes_dir, markdown_changes) + + async def _handle_note_changes(self, base: AnyioPath, changes: list[NoteChange]): + from src.services.workspace import WorkspaceService + from src.db.models import workspace as workspace_models + + if len(changes) == 0: return + + added_notes: list[tuple[AnyioPath, str]] = [] # (relative, content) + updated_notes: list[tuple[AnyioPath, str]] = [] # (relative, content) + deleted_notes: list[AnyioPath] = [] + + for change_type, file_path in changes: + try: + match change_type: + case ChangeType.added: + content = await AnyioPath(file_path).read_text("utf-8") + added_notes.append((file_path, content)) + case ChangeType.modified: + content = await AnyioPath(file_path).read_text("utf-8") + updated_notes.append((file_path, content)) + case ChangeType.deleted: + deleted_notes.append(file_path) + except Exception: + # read file failed, skip + pass + + normalized_path: Callable[[AnyioPath], str] = lambda path: path.relative_to(base).as_posix() + + async with db_context() as db_session: + workspace = await WorkspaceService(db_session).get_workspace_by_id(self._workspace_id) + existing_notes: dict[str, workspace_models.WorkspaceNote] = { + note.relative: note for note in workspace.notes + } + + for path in deleted_notes: + existing_notes.pop(normalized_path(path), None) + for path, content in added_notes: + relative = normalized_path(path) + existing_notes[relative] = workspace_models.WorkspaceNote( + relative=relative, + content=content, + ) + for path, content in updated_notes: + relative = normalized_path(path) + existing_note = existing_notes.get(relative, workspace_models.WorkspaceNote( + relative=relative, + content=content, + )) + existing_note.content = content + + workspace.notes = list(existing_notes.values()) + + async def __aenter__(self): + await self._start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._stop() diff --git a/src-server/src/agent/notes/workspace_ref_manager.py b/src-server/src/agent/notes/workspace_ref_manager.py new file mode 100644 index 00000000..960720bb --- /dev/null +++ b/src-server/src/agent/notes/workspace_ref_manager.py @@ -0,0 +1,18 @@ +from collections import Counter + + +class WorkspaceRefManager: + _workspace_refs: Counter[int] = Counter() + + @classmethod + def increase_workspace_ref(cls, workspace_id: int): + cls._workspace_refs[workspace_id] += 1 + + @classmethod + def decrease_workspace_ref(cls, workspace_id: int): + current_count = cls._workspace_refs[workspace_id] + cls._workspace_refs[workspace_id] = max(current_count - 1, 0) + + @classmethod + def is_workspace_in_use(cls, workspace_id: int) -> bool: + return cls._workspace_refs[workspace_id] > 0 diff --git a/src-server/src/agent/skills/materializer.py b/src-server/src/agent/skills/materializer.py index 844e9002..7482461d 100644 --- a/src-server/src/agent/skills/materializer.py +++ b/src-server/src/agent/skills/materializer.py @@ -32,7 +32,7 @@ async def get_skill_dir(skill: skill_schemas.SkillRead) -> Path: return skill_dir @classmethod - async def materialize_skill(cls, skill: skill_schemas.SkillRead) -> Path: + async def materialize(cls, skill: skill_schemas.SkillRead) -> Path: """ Materialize a skill to a temporary directory and return the directory absolute path. """ @@ -57,7 +57,7 @@ async def materialize_skill(cls, skill: skill_schemas.SkillRead) -> Path: return skill_dir @classmethod - async def materialize_skills(cls): + async def materialize_all(cls): async with db_context() as db_session: skills = await SkillService(db_session).get_all_skills() @@ -65,17 +65,18 @@ async def materialize_skills(cls): async def sem_materialize(skill: skill_models.Skill): async with sem: skill_read = skill_schemas.SkillRead.model_validate(skill) - await cls.clear_materialized(skill_read) - await cls.materialize_skill(skill_read) + await cls.clear_materialized(skill_read.id) + await cls.materialize(skill_read) + tasks = [sem_materialize(skill) for skill in skills] results = await asyncio.gather(*tasks, return_exceptions=True) for result in results: if isinstance(result, BaseException): - cls._logger.opt(exception=result).warning("Failed to materialize skill") + cls._logger.opt(exception=result).warning("Failed to materialize skills") @classmethod - async def clear_materialized(cls, skill: skill_schemas.SkillRead): + async def clear_materialized(cls, skill_id: int): skills_dir = await cls.get_skills_dir() - skill_dir = skills_dir / str(skill.id) + skill_dir = skills_dir / str(skill_id) if not await skill_dir.exists(): return await asyncio.to_thread(shutil.rmtree, skill_dir) diff --git a/src-server/src/agent/task/__init__.py b/src-server/src/agent/task/__init__.py index 7984aaac..e986d1a9 100644 --- a/src-server/src/agent/task/__init__.py +++ b/src-server/src/agent/task/__init__.py @@ -1,10 +1,12 @@ import asyncio from loguru import logger from dais_sdk.types import ToolMessage, AssistantMessage +from src.schemas import workspace as workspace_schemas from src.schemas.tasks import runtime as task_runtime_schemas from .message_manager import MessageManager, MessageNotFoundError from .tool_call_manager import ToolCallManager from .llm_request_manager import LlmRequestManager +from ..notes import NoteWatcher from ..context import AgentContext from ..types import ( AgentGenerator, StopReason, @@ -37,72 +39,84 @@ def messages(self) -> MessageManager: def tool_calls(self) -> ToolCallManager: return self._tool_call_manager + @property + def workspace(self) -> workspace_schemas.WorkspaceRead: + return self._ctx._resource.workspace + async def run(self) -> AgentGenerator: _exited_by_generator_close = False retries = 0 max_retries = 3 - try: - while self._is_running: - last_chunk: MessageEndEvent | TaskInterruptedEvent | ErrorEvent | None = None - try: - llm_stream = self._llm_request_manager.create_llm_call() - async for chunk in llm_stream: - if isinstance(chunk, self._llm_request_manager.FINISHING_CHUNK_TYPE): - last_chunk = chunk - continue - yield chunk - except asyncio.CancelledError: - # Task cancelled by user - break - - match last_chunk: - case MessageEndEvent() as message_end_chunk: - retries = 0 - yield message_end_chunk - case ErrorEvent(error=error, retryable=retryable) as error_chunk: - self._logger.warning(f"LLM provider error: {error}") - if not retryable or retries >= max_retries: - yield error_chunk - break - retries += 1 - continue # retry - case TaskInterruptedEvent() as interrupted_chunk: - yield interrupted_chunk - break - case _ as chunk: - self._logger.warning(f"Unexpected message event: {chunk}") + async with NoteWatcher(self.workspace.id): + try: + tail_tool_calls = list(self.messages.tail_tool_messages_iter()) + dispatch_stream, _ = self.tool_calls.dispatch(tail_tool_calls) + async for event in dispatch_stream: + yield event + has_pending_tool_calls = len(self.tool_calls.collect_pendings()) > 0 + if has_pending_tool_calls: return + + while self._is_running: + last_chunk: MessageEndEvent | TaskInterruptedEvent | ErrorEvent | None = None + try: + llm_stream = self._llm_request_manager.create_llm_call() + async for chunk in llm_stream: + if isinstance(chunk, self._llm_request_manager.FINISHING_CHUNK_TYPE): + last_chunk = chunk + continue + yield chunk + except asyncio.CancelledError: + # Task cancelled by user break - assistant_message = last_chunk.message - if (assistant_message.content == None and - assistant_message.reasoning_content == None and - (assistant_message.tool_calls is None or len(assistant_message.tool_calls) == 0)): - # empty message, retry - continue + match last_chunk: + case MessageEndEvent() as message_end_chunk: + retries = 0 + yield message_end_chunk + case ErrorEvent(error=error, retryable=retryable) as error_chunk: + self._logger.warning(f"LLM provider error: {error}") + if not retryable or retries >= max_retries: + yield error_chunk + break + retries += 1 + continue # retry + case TaskInterruptedEvent() as interrupted_chunk: + yield interrupted_chunk + break + case _ as chunk: + self._logger.warning(f"Unexpected message event: {chunk}") + break - self._ctx.messages.append(assistant_message) - tool_call_messages = self._extract_tool_call(assistant_message) - if tool_call_messages is None: - break + assistant_message = last_chunk.message + if (assistant_message.content == None and + assistant_message.reasoning_content == None and + (assistant_message.tool_calls is None or len(assistant_message.tool_calls) == 0)): + # empty message, retry + continue - for message in tool_call_messages: - self._ctx.messages.append(message) - yield ToolCallEndEvent(message=message) + self._ctx.messages.append(assistant_message) + tool_call_messages = self._extract_tool_call(assistant_message) + if tool_call_messages is None: + break - dispatch_stream, dispatch_result =\ - self._tool_call_manager.dispatch(tool_call_messages) - async for event in dispatch_stream: - yield event - if (dispatch_result.has_finished_task or - dispatch_result.has_blocked_tool_calls): - self._is_running = False - break - except GeneratorExit: - _exited_by_generator_close = True - finally: - if not _exited_by_generator_close: - yield TaskDoneEvent() + for message in tool_call_messages: + self._ctx.messages.append(message) + yield ToolCallEndEvent(message=message) + + dispatch_stream, dispatch_result =\ + self._tool_call_manager.dispatch(tool_call_messages) + async for event in dispatch_stream: + yield event + if (dispatch_result.has_finished_task or + dispatch_result.has_blocked_tool_calls): + self._is_running = False + break + except GeneratorExit: + _exited_by_generator_close = True + finally: + if not _exited_by_generator_close: + yield TaskDoneEvent() async def run_until_done(self) -> StopReason: async for event in self.run(): diff --git a/src-server/src/agent/tool/builtin_tools/file_system.py b/src-server/src/agent/tool/builtin_tools/file_system.py index 8093d694..7d386dfd 100644 --- a/src-server/src/agent/tool/builtin_tools/file_system.py +++ b/src-server/src/agent/tool/builtin_tools/file_system.py @@ -16,8 +16,6 @@ from src.settings import use_app_setting_manager from src.utils import MarkdownConverter from ..toolset_wrapper import built_in_tool, BuiltInToolset, BuiltInToolsetContext, BuiltInToolDefaults -from ...notes import NoteManager -from ...skills import SkillMaterializer from ...prompts import create_one_turn_llm, SemanticFileAnalysis, SemanticFileAnalysisInput @@ -46,10 +44,13 @@ class FileSystemToolset(BuiltInToolset): def __init__(self, ctx: BuiltInToolsetContext, toolset_ent: toolset_models.Toolset | None = None): + from ...notes import NoteMaterializer + from ...skills import SkillMaterializer + super().__init__(ctx, toolset_ent) self._path_expander = PathExpander({ **SkillMaterializer.get_skill_dir_env(), - **NoteManager.get_notes_dir_env(ctx.workspace_id), + **NoteMaterializer.get_notes_dir_env(ctx.workspace_id), }) self._markdown_converter = MarkdownConverter() diff --git a/src-server/src/agent/tool/builtin_tools/os_interactions.py b/src-server/src/agent/tool/builtin_tools/os_interactions.py index 443bf579..000ae5c1 100644 --- a/src-server/src/agent/tool/builtin_tools/os_interactions.py +++ b/src-server/src/agent/tool/builtin_tools/os_interactions.py @@ -5,23 +5,24 @@ from itertools import islice from dais_shell import AgentShell, CommandStep from dais_shell.iostream_reader import IOStreamBuffer -from src.agent.skills import SkillMaterializer from src.db.models import toolset as toolset_models from src.binaries import UV_PATH, NODE_PATH from ..toolset_wrapper import built_in_tool, BuiltInToolset, BuiltInToolsetContext -from ...notes import NoteManager class OsInteractionsToolset(BuiltInToolset): def __init__(self, ctx: BuiltInToolsetContext, toolset_ent: toolset_models.Toolset | None = None): + from ...notes import NoteMaterializer + from ...skills import SkillMaterializer + super().__init__(ctx, toolset_ent) self._shell = AgentShell( extra_paths=[str(NODE_PATH.parent), str(UV_PATH.parent)], extra_env={ **SkillMaterializer.get_skill_dir_env(), - **NoteManager.get_notes_dir_env(ctx.workspace_id), + **NoteMaterializer.get_notes_dir_env(ctx.workspace_id), } ) diff --git a/src-server/src/agent/tool/toolset_manager/builtin_toolset_manager.py b/src-server/src/agent/tool/toolset_manager/builtin_toolset_manager.py index a30c31b2..0446d94c 100644 --- a/src-server/src/agent/tool/toolset_manager/builtin_toolset_manager.py +++ b/src-server/src/agent/tool/toolset_manager/builtin_toolset_manager.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, Sequence, override +from typing import Sequence, override from dais_sdk.tool import Toolset from sqlalchemy.ext.asyncio import AsyncSession from src.db import db_context @@ -7,27 +7,16 @@ from ..toolset_wrapper import BuiltInToolset, BuiltInToolsetContext from ..builtin_tools import BUILT_IN_TOOLSETS -if TYPE_CHECKING: - from ...context import ToolRuntimeContext - class BuiltinToolsetManager(ToolsetManager): - def __init__(self, workspace_id: int, cwd: str, runtime_getter: Callable[[], ToolRuntimeContext]): - self._ctx = BuiltInToolsetContext(workspace_id, cwd, runtime_getter) + def __init__(self, workspace_id: int, cwd: str): + self._ctx = BuiltInToolsetContext(workspace_id, cwd) self._toolset_map: dict[str, toolset_models.Toolset] | None = None self._toolsets: list[BuiltInToolset] | None = None @classmethod def default(cls): - """Create a manager for static built-in tool metadata export only. - - The returned manager is intended for metadata-only operations, such as exporting - built-in tool definitions before a real agent runtime context exists. Any access - to runtime-backed context fields from this instance is considered invalid. - """ - def runtime_getter(): - raise ValueError("ToolRuntimeContext is unavailable in BuiltinToolsetManager.default()") - return cls(1, "~", runtime_getter) + return cls(1, "~") @staticmethod async def sync_toolsets(db_session: AsyncSession): @@ -60,7 +49,7 @@ async def initialize(self): self._toolsets.append(toolset_t(self._ctx, toolset_ent)) @classmethod - async def create(cls, workspace_id: int, cwd: str, tool_runtime_context: ToolRuntimeContext) -> BuiltinToolsetManager: - manager = cls(workspace_id, cwd, lambda: tool_runtime_context) + async def create(cls, workspace_id: int, cwd: str) -> BuiltinToolsetManager: + manager = cls(workspace_id, cwd) await manager.initialize() return manager diff --git a/src-server/src/agent/tool/toolset_wrapper/built_in_toolset.py b/src-server/src/agent/tool/toolset_wrapper/built_in_toolset.py index 58470a7a..03c7d1d9 100644 --- a/src-server/src/agent/tool/toolset_wrapper/built_in_toolset.py +++ b/src-server/src/agent/tool/toolset_wrapper/built_in_toolset.py @@ -1,15 +1,12 @@ from dataclasses import replace from pathlib import Path -from typing import Callable, Self, cast, override, TYPE_CHECKING, TypedDict +from typing import Self, cast, override, TYPE_CHECKING, TypedDict from dais_sdk.tool import PythonToolset, python_tool from dais_sdk.types import ToolDef from sqlalchemy.ext.asyncio import AsyncSession from ..types import ToolMetadata -from ...notes import NoteManager -from ...types import ContextUsage if TYPE_CHECKING: - from ...context import ToolRuntimeContext from ....db.models import toolset as toolset_models @@ -22,28 +19,19 @@ class BuiltInToolDefaults(TypedDict, total=False): needs_user_interaction: bool class BuiltInToolsetContext: - def __init__(self, workspace_id: int, cwd: str | Path, runtime_getter: Callable[[], ToolRuntimeContext]): + def __init__(self, workspace_id: int, cwd: str | Path): self.workspace_id = workspace_id self.cwd = Path(cwd).expanduser().resolve() - self._runtime_getter = runtime_getter - - @property - def usage(self) -> ContextUsage: return self._runtime_getter().usage - - @property - def note_manager(self) -> NoteManager: return self._runtime_getter().note_manager @classmethod def default(cls) -> Self: """Create a context for static tool metadata export without runtime state. The returned context is only intended for code paths that inspect built-in tool - definitions, such as tool metadata synchronization. Runtime-only properties like - usage and note_manager are intentionally unavailable on this instance. + definitions, such as tool metadata synchronization. Runtime-only properties are + intentionally unavailable on this instance. """ - def runtime_getter(): - raise ValueError("ToolRuntimeContext is unavailable in BuiltInToolsetContext.default()") - return cls(1, Path.cwd(), runtime_getter) + return cls(1, Path.cwd()) class BuiltInToolset(PythonToolset): def __init__(self, diff --git a/src-server/src/api/exceptions.py b/src-server/src/api/exceptions.py index 87c66778..0a51279a 100644 --- a/src-server/src/api/exceptions.py +++ b/src-server/src/api/exceptions.py @@ -14,6 +14,8 @@ class ApiErrorCode(StrEnum): TASK_RESOURCE_NOT_FOUND = "TASK_RESOURCE_NOT_FOUND" TASK_RESOURCE_SHOULD_HAVE_FILENAME_AND_CONTENTTYPE = "TASK_RESOURCE_SHOULD_HAVE_FILENAME_AND_CONTENTTYPE" + WORKSPACE_NOTES_LOCKED_BY_RUNNING_TASK = "WORKSPACE_NOTES_LOCKED_BY_RUNNING_TASK" + INVALID_SKILL_ARCHIVE = "INVALID_SKILL_ARCHIVE" SUMMARIZE_TITLE_FAILED = "SUMMARIZE_TITLE_FAILED" diff --git a/src-server/src/api/lifespan.py b/src-server/src/api/lifespan.py index 6f3491d4..d417f954 100644 --- a/src-server/src/api/lifespan.py +++ b/src-server/src/api/lifespan.py @@ -5,6 +5,7 @@ from typing import Coroutine, TypedDict from fastapi import FastAPI from src.agent.skills import SkillMaterializer +from src.agent.notes import NoteMaterializer from src.agent.task.schedule_runner import init_schedule_runner from src.agent.tool import BuiltinToolsetManager, McpToolsetManager, use_mcp_toolset_manager from src.db import engine as database_engine, db_context @@ -47,7 +48,8 @@ def __init__(self): self.background_task_manager = BackgroundTaskManager() async def _init_resources(self): - self.background_task_manager.add_task(SkillMaterializer.materialize_skills()) + self.background_task_manager.add_task(SkillMaterializer.materialize_all()) + self.background_task_manager.add_task(NoteMaterializer.materialize_all()) self.background_task_manager.add_task(self.schedule_runner.load_schedules()) self.background_task_manager.add_task(self._clear_unused_cache()) diff --git a/src-server/src/api/routes/skill.py b/src-server/src/api/routes/skill.py index 06178c0c..2847a3f1 100644 --- a/src-server/src/api/routes/skill.py +++ b/src-server/src/api/routes/skill.py @@ -42,8 +42,8 @@ def process_archive(file_obj: IO[bytes]) -> skill_schemas.SkillCreate: def start_materializing_background_task(background_tasks: BackgroundTasks, skill_ent: skill_models.Skill): skill_data = skill_schemas.SkillRead.model_validate(skill_ent) async def clear_and_rematerialize(skill: skill_schemas.SkillRead): - await SkillMaterializer.clear_materialized(skill) - await SkillMaterializer.materialize_skill(skill) + await SkillMaterializer.clear_materialized(skill.id) + await SkillMaterializer.materialize(skill) background_tasks.add_task(clear_and_rematerialize, skill_data) @skills_router.get("/", response_model=Page[skill_schemas.SkillBrief]) @@ -99,5 +99,4 @@ async def delete_skill( background_tasks: BackgroundTasks, ): deleted_skill = await SkillService(db_session).delete_skill(skill_id) - skill_data = skill_schemas.SkillRead.model_validate(deleted_skill) - background_tasks.add_task(SkillMaterializer.clear_materialized, skill_data) + background_tasks.add_task(SkillMaterializer.clear_materialized, deleted_skill.id) diff --git a/src-server/src/api/routes/tasks/schedule.py b/src-server/src/api/routes/tasks/schedule.py index 2b8408c3..f4318473 100644 --- a/src-server/src/api/routes/tasks/schedule.py +++ b/src-server/src/api/routes/tasks/schedule.py @@ -4,7 +4,9 @@ from src.agent.task.schedule_runner import use_schedule_runner from src.db import DbSessionDep from src.schemas.tasks import schedule as schedule_schemas +from src.services.tasks import TaskResourceService from src.services.schedule import RunRecordService, ScheduleService +from src.schemas.tasks import runtime as task_runtime_schemas schedule_manage_router = APIRouter(tags=["schedule"]) @@ -70,3 +72,4 @@ async def get_run_record(run_record_id: int, db_session: DbSessionDep): @schedule_manage_router.delete("/records/{run_record_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_run_record(run_record_id: int, db_session: DbSessionDep): await RunRecordService(db_session).delete_run_record(run_record_id) + await TaskResourceService(db_session, task_runtime_schemas.TaskType.SCHEDULE).delete_task_resources(run_record_id) diff --git a/src-server/src/api/routes/tasks/stream.py b/src-server/src/api/routes/tasks/stream.py index 3ba7b37b..995a66d0 100644 --- a/src-server/src/api/routes/tasks/stream.py +++ b/src-server/src/api/routes/tasks/stream.py @@ -70,23 +70,6 @@ async def continue_task( body: ContinueTaskBody, request: Request, ): - """ - Directly continue a existing task - """ task = await create_agent_task(task_type, task_id, body.agent_id) - - # ensure all approved tool calls are executed before continuing - try: - tail_tool_calls = list(task.messages.tail_tool_messages_iter()) - dispatch_stream, _ = task.tool_calls.dispatch(tail_tool_calls) - async for event in dispatch_stream: - yield event - finally: - await asyncio.shield(task.persist()) - - if len(task.tool_calls.collect_pendings()) > 0: - # prevent starting agent loop when there are still unresolved tool calls - yield TaskDoneEvent() - return async for event in stream_connector(task, request): yield event diff --git a/src-server/src/api/routes/tasks/task.py b/src-server/src/api/routes/tasks/task.py index a1f2bcc3..ed672af4 100644 --- a/src-server/src/api/routes/tasks/task.py +++ b/src-server/src/api/routes/tasks/task.py @@ -8,8 +8,11 @@ from src.db import DbSessionDep from src.db.models import agent as agent_models from src.db.models import tasks as task_models -from src.services.tasks import TaskService -from src.schemas.tasks import task as task_schemas +from src.services.tasks import TaskResourceService, TaskService +from src.schemas.tasks import ( + task as task_schemas, + runtime as task_runtime_schemas, +) from ...exceptions import ApiError, ApiErrorCode @@ -90,3 +93,4 @@ async def summarize_task_title(task_id: int, db_session: DbSessionDep): @task_manage_router.delete("/{task_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_task(task_id: int, db_session: DbSessionDep): await TaskService(db_session).delete_task(task_id) + await TaskResourceService(db_session, task_runtime_schemas.TaskType.TASK).delete_task_resources(task_id) diff --git a/src-server/src/api/routes/workspace.py b/src-server/src/api/routes/workspace.py index 8c4b008e..f35c8854 100644 --- a/src-server/src/api/routes/workspace.py +++ b/src-server/src/api/routes/workspace.py @@ -2,9 +2,11 @@ from fastapi_pagination import Page from fastapi_pagination.ext.sqlalchemy import apaginate from src.db import DbSessionDep +from src.agent.notes import NoteMaterializer, WorkspaceRefManager from src.services.workspace import WorkspaceService from src.schemas import workspace as workspace_schemas from src.utils.open_in_file_manager import open_in_file_manager +from ..exceptions import ApiError, ApiErrorCode workspaces_router = APIRouter(tags=["workspace"]) @@ -23,7 +25,10 @@ async def create_workspace( db_session: DbSessionDep, body: workspace_schemas.WorkspaceCreate, ): - return await WorkspaceService(db_session).create_workspace(body) + created_workspace = await WorkspaceService(db_session).create_workspace(body) + workspace = workspace_schemas.WorkspaceRead.model_validate(created_workspace) + await NoteMaterializer.materialize(workspace) + return workspace @workspaces_router.put("/{workspace_id}", response_model=workspace_schemas.WorkspaceRead) async def update_workspace( @@ -33,9 +38,24 @@ async def update_workspace( ): return await WorkspaceService(db_session).update_workspace(workspace_id, body) +@workspaces_router.put("/{workspace_id}/notes", response_model=workspace_schemas.WorkspaceRead) +async def update_workspace_notes( + workspace_id: int, + body: workspace_schemas.WorkspaceNotesUpdate, + db_session: DbSessionDep, +): + if WorkspaceRefManager.is_workspace_in_use(workspace_id): + raise ApiError(status.HTTP_409_CONFLICT, ApiErrorCode.WORKSPACE_NOTES_LOCKED_BY_RUNNING_TASK) + updated_workspace = await WorkspaceService(db_session).update_workspace_notes(workspace_id, body) + workspace = workspace_schemas.WorkspaceRead.model_validate(updated_workspace) + await NoteMaterializer.clear_materialized(workspace.id) + await NoteMaterializer.materialize(workspace) + return workspace + @workspaces_router.delete("/{workspace_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_workspace(workspace_id: int, db_session: DbSessionDep): await WorkspaceService(db_session).delete_workspace(workspace_id) + await NoteMaterializer.clear_materialized(workspace_id) # --- --- --- --- --- --- diff --git a/src-server/src/logger.py b/src-server/src/logger.py index 53a30407..1e261918 100644 --- a/src-server/src/logger.py +++ b/src-server/src/logger.py @@ -39,6 +39,7 @@ def setup_logging(log_level: int): NOISY_LIBS = ( "aiosqlite", + "apscheduler", "binaryornot", "httpcore", "httpx", diff --git a/src-server/src/schemas/workspace.py b/src-server/src/schemas/workspace.py index 7774b5b0..c0a706dd 100644 --- a/src-server/src/schemas/workspace.py +++ b/src-server/src/schemas/workspace.py @@ -37,7 +37,9 @@ class WorkspaceUpdate(DTOBase): name: str | None directory: str | None instruction: str | None - notes: list[WorkspaceNoteBase] | None usable_agent_ids: list[int] | None usable_tool_ids: list[int] | None usable_skill_ids: list[int] | None = None + +class WorkspaceNotesUpdate(DTOBase): + notes: list[WorkspaceNoteBase] diff --git a/src-server/src/services/agent.py b/src-server/src/services/agent.py index d03eef8e..8184ea4a 100644 --- a/src-server/src/services/agent.py +++ b/src-server/src/services/agent.py @@ -5,7 +5,6 @@ from src.schemas import agent as agent_schemas from .service_base import ServiceBase from .exceptions import NotFoundError, ServiceErrorCode -from .utils import build_load_options, Relations class AgentNotFoundError(NotFoundError): @@ -14,10 +13,10 @@ def __init__(self, agent_id: int) -> None: class AgentService(ServiceBase): @staticmethod - def relations() -> Relations: + def relations(): return [ - agent_models.Agent.model, - agent_models.Agent.usable_tools, + selectinload(agent_models.Agent.model), + selectinload(agent_models.Agent.usable_tools), ] def get_agents_query(self): @@ -30,7 +29,7 @@ def get_agents_query(self): async def get_agent_by_id(self, id: int) -> agent_models.Agent: agent = await self._db_session.get( agent_models.Agent, id, - options=build_load_options(self.relations()), + options=self.relations(), ) if not agent: raise AgentNotFoundError(id) diff --git a/src-server/src/services/provider.py b/src-server/src/services/provider.py index f500056d..746047e8 100644 --- a/src-server/src/services/provider.py +++ b/src-server/src/services/provider.py @@ -1,10 +1,10 @@ from loguru import logger from sqlalchemy import select +from sqlalchemy.orm import selectinload from src.db.models import provider as provider_models from src.schemas import provider as provider_schemas from .service_base import ServiceBase from .exceptions import NotFoundError, ServiceErrorCode -from .utils import build_load_options, Relations _logger = logger.bind(name="ProviderService") @@ -15,15 +15,15 @@ def __init__(self, provider_id: int) -> None: class ProviderService(ServiceBase): @staticmethod - def relations() -> Relations: + def relations(): return [ - provider_models.Provider.models, + selectinload(provider_models.Provider.models), ] def get_providers_query(self): return ( select(provider_models.Provider) - .options(*build_load_options(self.relations())) + .options(*self.relations()) .order_by(provider_models.Provider.id.asc()) ) @@ -36,7 +36,7 @@ async def get_provider_by_id(self, provider_id: int) -> provider_models.Provider provider = await self._db_session.get( provider_models.Provider, provider_id, - options=build_load_options(self.relations()), + options=self.relations(), ) if not provider: raise ProviderNotFoundError(provider_id) diff --git a/src-server/src/services/schedule.py b/src-server/src/services/schedule.py index 6aaaab6b..15af94f5 100644 --- a/src-server/src/services/schedule.py +++ b/src-server/src/services/schedule.py @@ -1,10 +1,10 @@ from dais_sdk.types import UserMessage from sqlalchemy import select +from sqlalchemy.orm import selectinload from src.db.models import tasks as task_models from src.schemas.tasks import schedule as schedule_schemas from .service_base import ServiceBase from .exceptions import NotFoundError, ServiceErrorCode -from .utils import build_load_options, Relations class ScheduleNotFoundError(NotFoundError): @@ -13,10 +13,10 @@ def __init__(self, schedule_id: int) -> None: class ScheduleService(ServiceBase): @staticmethod - def relations() -> Relations: + def relations(): return [ - task_models.Schedule.agent, - task_models.Schedule.workspace, + selectinload(task_models.Schedule.agent), + selectinload(task_models.Schedule.workspace), ] def get_all_schedules_query(self): @@ -43,7 +43,7 @@ async def get_schedule_by_id(self, id: int) -> task_models.Schedule: schedule = await self._db_session.get( task_models.Schedule, id, - options=build_load_options(self.relations()), + options=self.relations(), ) if not schedule: raise ScheduleNotFoundError(id) @@ -88,9 +88,9 @@ def __init__(self, run_record_id: int) -> None: class RunRecordService(ServiceBase): @staticmethod - def relations() -> Relations: + def relations(): return [ - task_models.RunRecord.schedule, + selectinload(task_models.RunRecord.schedule), ] def get_run_records_query(self, schedule_id: int): @@ -104,7 +104,7 @@ async def get_run_record_by_id(self, id: int) -> task_models.RunRecord: run_record = await self._db_session.get( task_models.RunRecord, id, - options=build_load_options(self.relations()), + options=self.relations(), ) if not run_record: raise RunRecordNotFoundError(id) diff --git a/src-server/src/services/skill.py b/src-server/src/services/skill.py index 55ed9c93..89bdddda 100644 --- a/src-server/src/services/skill.py +++ b/src-server/src/services/skill.py @@ -4,7 +4,6 @@ from src.schemas import skill as skill_schemas from .service_base import ServiceBase from .exceptions import NotFoundError, ConflictError, ServiceErrorCode -from .utils import build_load_options, Relations class SkillNotFoundError(NotFoundError): @@ -20,8 +19,10 @@ def __init__(self, name: str) -> None: class SkillService(ServiceBase): @staticmethod - def relations() -> Relations: - return [skill_models.Skill.resources] + def relations(): + return [ + selectinload(skill_models.Skill.resources) + ] def get_skills_query(self): return ( @@ -34,7 +35,7 @@ async def get_all_skills(self) -> list[skill_models.Skill]: stmt = ( select(skill_models.Skill) .order_by(skill_models.Skill.id.asc()) - .options(*build_load_options(self.relations())) + .options(*self.relations()) ) skills = (await self._db_session.scalars(stmt)).all() return list(skills) @@ -43,7 +44,7 @@ async def get_skill_by_id(self, id: int) -> skill_models.Skill: skill = await self._db_session.get( skill_models.Skill, id, - options=build_load_options(self.relations()), + options=self.relations(), ) if not skill: raise SkillNotFoundError(id) @@ -53,7 +54,7 @@ async def get_skill_by_name(self, name: str) -> skill_models.Skill: stmt = ( select(skill_models.Skill) .where(skill_models.Skill.name == name) - .options(*build_load_options(self.relations())) + .options(*self.relations()) ) skill = await self._db_session.scalar(stmt) if not skill: diff --git a/src-server/src/services/tasks/schedule.py b/src-server/src/services/tasks/schedule.py index a8b3e1ca..727c0b0c 100644 --- a/src-server/src/services/tasks/schedule.py +++ b/src-server/src/services/tasks/schedule.py @@ -1,11 +1,11 @@ from sqlalchemy import select +from sqlalchemy.orm import selectinload from src.db.models import tasks as task_models from src.schemas.tasks import schedule as schedule_schemas from src.schemas.tasks import runtime as task_runtime_schemas from .resource import TaskResourceService from ..service_base import ServiceBase from ..exceptions import NotFoundError, ServiceErrorCode -from ..utils import build_load_options, Relations class ScheduleNotFoundError(NotFoundError): @@ -14,10 +14,10 @@ def __init__(self, schedule_id: int) -> None: class ScheduleService(ServiceBase): @staticmethod - def relations() -> Relations: + def relations(): return [ - task_models.Schedule.agent, - task_models.Schedule.workspace, + selectinload(task_models.Schedule.agent), + selectinload(task_models.Schedule.workspace), ] def get_all_schedules_query(self): @@ -44,7 +44,7 @@ async def get_schedule_by_id(self, id: int) -> task_models.Schedule: schedule = await self._db_session.get( task_models.Schedule, id, - options=build_load_options(self.relations()), + options=self.relations(), ) if not schedule: raise ScheduleNotFoundError(id) @@ -89,9 +89,9 @@ def __init__(self, run_record_id: int) -> None: class RunRecordService(ServiceBase): @staticmethod - def relations() -> Relations: + def relations(): return [ - task_models.RunRecord.schedule, + selectinload(task_models.RunRecord.schedule), ] def get_run_records_query(self, schedule_id: int): @@ -105,7 +105,7 @@ async def get_run_record_by_id(self, id: int) -> task_models.RunRecord: run_record = await self._db_session.get( task_models.RunRecord, id, - options=build_load_options(self.relations()), + options=self.relations(), ) if not run_record: raise RunRecordNotFoundError(id) diff --git a/src-server/src/services/tasks/task.py b/src-server/src/services/tasks/task.py index fba96cee..0608ee09 100644 --- a/src-server/src/services/tasks/task.py +++ b/src-server/src/services/tasks/task.py @@ -1,11 +1,9 @@ from sqlalchemy import select +from sqlalchemy.orm import selectinload from src.db.models import tasks as task_models from src.schemas.tasks import task as task_schemas -from src.schemas.tasks import runtime as task_runtime_schemas from ..service_base import ServiceBase from ..exceptions import NotFoundError, ServiceErrorCode -from ..utils import build_load_options, Relations -from .resource import TaskResourceService class TaskNotFoundError(NotFoundError): @@ -15,10 +13,10 @@ def __init__(self, task_id: int) -> None: class TaskService(ServiceBase): @staticmethod - def relations() -> Relations: + def relations(): return [ - task_models.Task.agent, - task_models.Task.workspace, + selectinload(task_models.Task.agent), + selectinload(task_models.Task.workspace), ] def get_tasks_query(self, workspace_id: int): @@ -40,7 +38,7 @@ def get_recent_tasks_query(self): async def get_task_by_id(self, id: int) -> task_models.Task: task = await self._db_session.get( task_models.Task, id, - options=build_load_options(self.relations()), + options=self.relations(), ) if not task: raise TaskNotFoundError(id) @@ -74,6 +72,5 @@ async def update_task(self, id: int, data: task_schemas.TaskUpdate) -> task_mode async def delete_task(self, id: int) -> None: task = await self.get_task_by_id(id) - await TaskResourceService(self._db_session, task_runtime_schemas.TaskType.TASK).delete_task_resources(id) await self._db_session.delete(task) await self._db_session.flush() diff --git a/src-server/src/services/toolset.py b/src-server/src/services/toolset.py index 10ef60e7..48f740f6 100644 --- a/src-server/src/services/toolset.py +++ b/src-server/src/services/toolset.py @@ -1,11 +1,11 @@ from typing import NamedTuple from dais_sdk.types import ToolDef from sqlalchemy import select +from sqlalchemy.orm import selectinload from src.db.models import toolset as toolset_models from src.schemas import toolset as toolset_schemas from .service_base import ServiceBase from .exceptions import NotFoundError, ConflictError, ServiceErrorCode -from .utils import build_load_options, Relations class ToolsetNotFoundError(NotFoundError): @@ -28,9 +28,9 @@ class ToolLike(NamedTuple): auto_approve: bool = False @staticmethod - def relations() -> Relations: + def relations(): return [ - toolset_models.Toolset.tools, + selectinload(toolset_models.Toolset.tools), ] async def get_all_mcp_toolsets(self) -> list[toolset_models.Toolset]: @@ -44,7 +44,7 @@ async def get_all_mcp_toolsets(self) -> list[toolset_models.Toolset]: ] ) ) - .options(*build_load_options(self.relations())) + .options(*self.relations()) ) toolsets = (await self._db_session.scalars(stmt)).all() return list(toolsets) @@ -53,7 +53,7 @@ async def get_all_built_in_toolsets(self) -> list[toolset_models.Toolset]: stmt = ( select(toolset_models.Toolset) .where(toolset_models.Toolset.type == toolset_models.ToolsetType.BUILT_IN) - .options(*build_load_options(self.relations())) + .options(*self.relations()) ) toolsets = (await self._db_session.scalars(stmt)).all() return list(toolsets) @@ -62,7 +62,7 @@ async def get_toolset_by_id(self, id: int) -> toolset_models.Toolset: toolset = await self._db_session.get( toolset_models.Toolset, id, - options=build_load_options(self.relations()), + options=self.relations(), ) if not toolset: raise ToolsetNotFoundError(id) @@ -72,7 +72,7 @@ async def get_toolset_by_internal_key(self, internal_key: str) -> toolset_models stmt = ( select(toolset_models.Toolset) .where(toolset_models.Toolset.internal_key == internal_key) - .options(*build_load_options(self.relations())) + .options(*self.relations()) ) toolset = await self._db_session.scalar(stmt) if not toolset: diff --git a/src-server/src/services/utils.py b/src-server/src/services/utils.py deleted file mode 100644 index 71cc905b..00000000 --- a/src-server/src/services/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -from sqlalchemy.orm import QueryableAttribute, selectinload - -type Relations = list[QueryableAttribute] - -def build_load_options( - relations: list[QueryableAttribute], - root: QueryableAttribute | None = None, -): - if root is not None: - base = selectinload(root) - return [base.selectinload(rel) for rel in relations] - else: - return [selectinload(rel) for rel in relations] diff --git a/src-server/src/services/workspace.py b/src-server/src/services/workspace.py index f055902e..a6cb5ffc 100644 --- a/src-server/src/services/workspace.py +++ b/src-server/src/services/workspace.py @@ -1,6 +1,5 @@ from sqlalchemy import select from sqlalchemy.orm import selectinload -from src.agent.notes.manager import NoteManager from src.db.models import workspace as workspace_models from src.db.models import agent as agent_models from src.db.models import toolset as toolset_models @@ -9,7 +8,6 @@ from .service_base import ServiceBase from .exceptions import NotFoundError, ServiceErrorCode from .agent import AgentService -from .utils import build_load_options, Relations class WorkspaceNotFoundError(NotFoundError): @@ -18,17 +16,20 @@ def __init__(self, workspace_id: int) -> None: class WorkspaceService(ServiceBase): @staticmethod - def relations() -> Relations: + def relations(): return [ - workspace_models.Workspace.usable_tools, - workspace_models.Workspace.usable_agents, - workspace_models.Workspace.usable_skills, + selectinload(workspace_models.Workspace.usable_tools), + selectinload(workspace_models.Workspace.usable_agents) + .selectinload(agent_models.Agent.model), + selectinload(workspace_models.Workspace.usable_skills), + selectinload(workspace_models.Workspace.notes), ] def get_workspaces_query(self): return ( select(workspace_models.Workspace) .order_by(workspace_models.Workspace.id.asc()) + .options(*self.relations()) ) async def _update_relations( @@ -36,20 +37,11 @@ async def _update_relations( workspace: workspace_models.Workspace, data: workspace_schemas.WorkspaceCreate | workspace_schemas.WorkspaceUpdate, ): - if data.notes is not None: - workspace.notes = [ - workspace_models.WorkspaceNote( - relative=note.relative, - content=note.content, - ) - for note in data.notes - ] - if data.usable_agent_ids is not None: stmt = ( select(agent_models.Agent) .where(agent_models.Agent.id.in_(data.usable_agent_ids)) - .options(*build_load_options(AgentService.relations())) + .options(*AgentService.relations()) ) agents = (await self._db_session.scalars(stmt)).all() workspace.usable_agents = list(agents) @@ -68,17 +60,20 @@ async def _update_relations( skills = (await self._db_session.scalars(stmt)).all() workspace.usable_skills = list(skills) + async def get_all_workspaces(self) -> list[workspace_models.Workspace]: + stmt = ( + select(workspace_models.Workspace) + .order_by(workspace_models.Workspace.id.asc()) + .options(*self.relations()) + ) + workspaces = (await self._db_session.scalars(stmt)).all() + return list(workspaces) + async def get_workspace_by_id(self, id: int) -> workspace_models.Workspace: workspace = await self._db_session.get( workspace_models.Workspace, id, - options=[ - selectinload(workspace_models.Workspace.usable_tools), - selectinload(workspace_models.Workspace.usable_agents) - .selectinload(agent_models.Agent.model), - selectinload(workspace_models.Workspace.usable_skills), - selectinload(workspace_models.Workspace.notes), - ], + options=self.relations(), ) if not workspace: raise WorkspaceNotFoundError(id) @@ -88,6 +83,13 @@ async def create_workspace(self, data: workspace_schemas.WorkspaceCreate) -> wor create_data = data.model_dump(exclude={"notes", "usable_agent_ids", "usable_tool_ids", "usable_skill_ids"}) new_workspace = workspace_models.Workspace(**create_data) + new_workspace.notes = [ + workspace_models.WorkspaceNote( + relative=note.relative, + content=note.content, + ) + for note in data.notes + ] await self._update_relations(new_workspace, data) self._db_session.add(new_workspace) @@ -108,8 +110,22 @@ async def update_workspace(self, id: int, data: workspace_schemas.WorkspaceUpdat updated_workspace = await self.get_workspace_by_id(workspace.id) return updated_workspace + async def update_workspace_notes(self, id, data: workspace_schemas.WorkspaceNotesUpdate): + workspace = await self.get_workspace_by_id(id) + workspace.notes = [ + workspace_models.WorkspaceNote( + relative=note.relative, + content=note.content, + ) + for note in data.notes + ] + await self._db_session.flush() + self._db_session.expunge(workspace) + + updated_workspace = await self.get_workspace_by_id(workspace.id) + return updated_workspace + async def delete_workspace(self, id: int) -> None: workspace = await self.get_workspace_by_id(id) await self._db_session.delete(workspace) await self._db_session.flush() - await NoteManager(workspace.id).clear_materialized(force=True) diff --git a/src-server/src/utils/__init__.py b/src-server/src/utils/__init__.py index 5a0a55be..00102afa 100644 --- a/src-server/src/utils/__init__.py +++ b/src-server/src/utils/__init__.py @@ -1,6 +1,7 @@ +from .directory_watcher import DirectoryWatcher, FileChange from .markdown_converter import MarkdownConverter from .get_unique_filename import get_unique_filename from .open_in_file_manager import open_in_file_manager from .to_base64_str import to_base64_str from .parent_watchdog import ParentWatchdog -from .scheduler import Scheduler +from .scheduler import Scheduler \ No newline at end of file diff --git a/src-server/src/utils/directory_watcher.py b/src-server/src/utils/directory_watcher.py new file mode 100644 index 00000000..98a7b4a7 --- /dev/null +++ b/src-server/src/utils/directory_watcher.py @@ -0,0 +1,116 @@ +import asyncio +from typing import Callable, Awaitable +from anyio import Path as AnyioPath +from pathlib import Path as StdPath +from loguru import logger +from watchfiles import awatch, Change as ChangeType + + +type FileChange = tuple[ChangeType, str] +type ChangeHandler = Callable[[list[FileChange]], Awaitable] + +WATCHFILES_SENTINEL = ".watchfiles_sentinel" + +class DirectoryWatcher: + _logger = logger.bind(name="DirectoryWatcher") + + def __init__(self, dir: AnyioPath, on_changes: ChangeHandler): + self._dir = dir + self._on_changes = on_changes + self._change_event: asyncio.Event | None = None + self._stop_event: asyncio.Event | None = None + self._watch_task: asyncio.Task | None = None + + async def start(self) -> None: + if self._watch_task is not None: + await self.stop() + + self._stop_event = asyncio.Event() + self._change_event = asyncio.Event() + ready_event = asyncio.Event() + + self._watch_task = asyncio.create_task( + self._watch_loop(ready_event, self._change_event, self._stop_event)) + + # TODO: use the builtin ready_event of awatch after https://github.com/samuelcolvin/watchfiles/pull/356 is merged + sentinel = self._dir / WATCHFILES_SENTINEL + await sentinel.write_bytes(b"") + await sentinel.unlink(missing_ok=True) + await ready_event.wait() + + async def stop(self) -> None: + if self._stop_event: + await self._drain() + self._stop_event.set() + self._stop_event = None + + if self._watch_task and not self._watch_task.done(): + try: + await asyncio.wait_for(self._watch_task, timeout=2.0) + except asyncio.TimeoutError: + self._logger.warning("awatch did not exit gracefully, cancelling") + self._watch_task.cancel() + try: + await self._watch_task + except asyncio.CancelledError: + # since the watch_task is cancelled programatically, + # we should ignore CancelledError here + pass + finally: + self._watch_task = None + + async def _drain(self) -> None: + """ + Wait until the watcher goes quiet before signalling stop. + This ensures all the changes generated before `stop` to be handled. + """ + INITIAL_TIMEOUT_SEC = 0.8 + EXTRA_TIMEOUT_SEC = 0.4 + + change_event = self._change_event + if change_event is None: + return + + timeout = INITIAL_TIMEOUT_SEC + while True: + change_event.clear() + try: + await asyncio.wait_for(change_event.wait(), timeout=timeout) + timeout = EXTRA_TIMEOUT_SEC + except asyncio.TimeoutError: + # Quiet for the full window — safe to stop. + break + self._change_event = None + + async def _watch_loop(self, + ready_event: asyncio.Event, + change_event: asyncio.Event, + stop_event: asyncio.Event) -> None: + if not await self._dir.exists(): + self._logger.warning(f"Watching directory does not exist: {self._dir}") + ready_event.set() + return + + try: + async for changes in awatch( + StdPath(self._dir), + stop_event=stop_event, + recursive=True, + debounce=100, + ): + filtered: list[FileChange] = [] + for change_type, path in changes: + if WATCHFILES_SENTINEL in path: + ready_event.set() + continue + filtered.append((change_type, path)) + + if filtered: + change_event.set() + await self._on_changes(filtered) + + except asyncio.CancelledError: + self._logger.debug(f"Watching cancelled: {self._dir}") + raise + except Exception: + self._logger.exception(f"Unexpected error in watcher: {self._dir}") diff --git a/src-server/tests/agent/notes/test_manager.py b/src-server/tests/agent/notes/test_manager.py deleted file mode 100644 index 2ed7b014..00000000 --- a/src-server/tests/agent/notes/test_manager.py +++ /dev/null @@ -1,699 +0,0 @@ -"""Tests for the NoteManager class.""" - -import asyncio -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from anyio import Path as AnyioPath -from watchfiles import Change as ChangeType - -from src.agent.notes.manager import NoteManager -from src.db.models import workspace as workspace_models - - -@pytest.fixture(autouse=True) -def reset_note_manager_ref_counts(): - NoteManager._workspace_ref_counts.clear() - NoteManager._workspace_ref_lock = None - - -@pytest.fixture -def note_manager(): - """Return a NoteManager instance for workspace_id=1.""" - return NoteManager(workspace_id=1) - - -@pytest.fixture -def mock_workspace_with_notes(): - """Return a mock workspace with notes.""" - workspace = MagicMock() - workspace.id = 1 - workspace.notes = [ - workspace_models.WorkspaceNote(relative="README.md", content="# Hello"), - workspace_models.WorkspaceNote(relative="docs/guide.md", content="## Guide"), - ] - return workspace - - -class TestGetNotesRootDir: - """Tests for get_notes_root_dir static method.""" - - @pytest.mark.asyncio - async def test_creates_directory_if_not_exists(self, tmp_path: Path, monkeypatch): - """Test that get_notes_root_dir creates the directory if it doesn't exist.""" - notes_root = tmp_path / ".notes" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - - result = await NoteManager.get_notes_root_dir() - - assert await result.exists() - assert result.name == ".notes" - - @pytest.mark.asyncio - async def test_returns_existing_directory(self, tmp_path: Path, monkeypatch): - """Test that get_notes_root_dir returns existing directory without error.""" - notes_root = tmp_path / ".notes" - notes_root.mkdir() - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - - result = await NoteManager.get_notes_root_dir() - - assert result == AnyioPath(notes_root) - - -class TestGetNotesDir: - """Tests for get_notes_dir method.""" - - @pytest.mark.asyncio - async def test_creates_workspace_specific_directory(self, tmp_path: Path, monkeypatch): - """Test that get_notes_dir creates workspace-specific directory.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=42) - - result = await manager.get_notes_dir() - - assert await result.exists() - assert result.name == "42" - assert result.parent.name == ".notes" - - -class TestGetNotesIndex: - """Tests for get_notes_index method.""" - - @pytest.mark.asyncio - async def test_returns_notes_index_content(self, tmp_path: Path, monkeypatch): - """Test that get_notes_index returns NOTES.md content when it exists.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=7) - - notes_dir = await manager.get_notes_dir() - notes_index = notes_dir / "NOTES.md" - await notes_index.write_text("# Workspace notes\n\n- item", "utf-8") - - result = await manager.get_notes_index() - - assert result == "# Workspace notes\n\n- item" - - @pytest.mark.asyncio - async def test_returns_none_when_notes_index_not_exists(self, tmp_path: Path, monkeypatch): - """Test that get_notes_index returns None when NOTES.md does not exist.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=8) - - result = await manager.get_notes_index() - - assert result is None - - @pytest.mark.asyncio - async def test_returns_none_when_reading_notes_index_fails(self, tmp_path: Path, monkeypatch): - """Test that get_notes_index returns None when NOTES.md cannot be read.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=9) - - notes_dir = await manager.get_notes_dir() - notes_index = notes_dir / "NOTES.md" - await notes_index.write_text("# Workspace notes", "utf-8") - - with patch.object(AnyioPath, "read_text", AsyncMock(side_effect=OSError("read failed"))): - result = await manager.get_notes_index() - - assert result is None - - -class TestNotesDirEnv: - """Tests for notes_dir_env property.""" - - def test_returns_env_dict_with_correct_path(self, tmp_path: Path, monkeypatch): - """Test that notes_dir_env returns correct environment variable dict.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - env = NoteManager.get_notes_dir_env(workspace_id=5) - - expected_path = str(tmp_path / ".notes" / "5") - assert env == {NoteManager.NOTES_DIR_ENVNAME: expected_path} - - -class TestMaterialize: - """Tests for materialize method.""" - - @pytest.mark.asyncio - async def test_materializes_notes_to_files( - self, tmp_path: Path, monkeypatch, mock_workspace_with_notes - ): - """Test that materialize writes notes to the filesystem.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - - mock_service = AsyncMock() - mock_service.get_workspace_by_id.return_value = mock_workspace_with_notes - - with patch( - "src.services.workspace.WorkspaceService", return_value=mock_service - ): - with patch("src.agent.notes.manager.db_context") as mock_db_context: - mock_db_context.return_value.__aenter__ = AsyncMock( - return_value=AsyncMock() - ) - mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) - - manager = NoteManager(workspace_id=1) - result = await manager.materialize() - - # Check files were created - readme_path = result / "README.md" - guide_path = result / "docs" / "guide.md" - - assert await readme_path.exists() - assert await guide_path.exists() - assert await readme_path.read_text("utf-8") == "# Hello" - assert await guide_path.read_text("utf-8") == "## Guide" - - @pytest.mark.asyncio - async def test_materialize_empty_notes(self, tmp_path: Path, monkeypatch): - """Test that materialize works with empty notes list.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - - workspace = MagicMock() - workspace.id = 1 - workspace.notes = [] - - mock_service = AsyncMock() - mock_service.get_workspace_by_id.return_value = workspace - - with patch( - "src.services.workspace.WorkspaceService", return_value=mock_service - ): - with patch("src.agent.notes.manager.db_context") as mock_db_context: - mock_db_context.return_value.__aenter__ = AsyncMock( - return_value=AsyncMock() - ) - mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) - - manager = NoteManager(workspace_id=1) - result = await manager.materialize() - - assert await result.exists() - - -class TestClearMaterialized: - """Tests for clear_materialized method.""" - - @pytest.mark.asyncio - async def test_removes_notes_directory(self, tmp_path: Path, monkeypatch): - """Test that clear_materialized removes the workspace notes directory.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - - workspace = MagicMock() - workspace.id = 3 - workspace.notes = [workspace_models.WorkspaceNote(relative="test.md", content="content")] - - mock_service = AsyncMock() - mock_service.get_workspace_by_id.return_value = workspace - - with patch("src.services.workspace.WorkspaceService", return_value=mock_service): - with patch("src.agent.notes.manager.db_context") as mock_db_context: - mock_db_context.return_value.__aenter__ = AsyncMock(return_value=AsyncMock()) - mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) - - manager = NoteManager(workspace_id=3) - notes_dir = await manager.materialize() - - assert await notes_dir.exists() - - await manager.clear_materialized() - - assert not await notes_dir.exists() - - @pytest.mark.asyncio - async def test_keeps_directory_when_workspace_has_other_active_task(self, tmp_path: Path, monkeypatch): - """Test that clear_materialized keeps notes when another task in same workspace is active.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - - workspace = MagicMock() - workspace.id = 1 - workspace.notes = [workspace_models.WorkspaceNote(relative="README.md", content="# Hello")] - - mock_service = AsyncMock() - mock_service.get_workspace_by_id.return_value = workspace - - with patch("src.services.workspace.WorkspaceService", return_value=mock_service): - with patch("src.agent.notes.manager.db_context") as mock_db_context: - mock_db_context.return_value.__aenter__ = AsyncMock(return_value=AsyncMock()) - mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) - - manager_a = NoteManager(workspace_id=1) - manager_b = NoteManager(workspace_id=1) - - notes_dir = await manager_a.materialize() - await manager_b.materialize() - - assert await notes_dir.exists() - - await manager_a.clear_materialized() - assert await notes_dir.exists() - - await manager_b.clear_materialized() - assert not await notes_dir.exists() - - @pytest.mark.asyncio - async def test_force_clear_always_removes_directory(self, tmp_path: Path, monkeypatch): - """Test that force clear removes notes directory regardless of ref count.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - - workspace = MagicMock() - workspace.id = 1 - workspace.notes = [workspace_models.WorkspaceNote(relative="README.md", content="# Hello")] - - mock_service = AsyncMock() - mock_service.get_workspace_by_id.return_value = workspace - - with patch("src.services.workspace.WorkspaceService", return_value=mock_service): - with patch("src.agent.notes.manager.db_context") as mock_db_context: - mock_db_context.return_value.__aenter__ = AsyncMock(return_value=AsyncMock()) - mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) - - manager_a = NoteManager(workspace_id=1) - manager_b = NoteManager(workspace_id=1) - - notes_dir = await manager_a.materialize() - await manager_b.materialize() - - assert await notes_dir.exists() - - await manager_a.clear_materialized(force=True) - assert not await notes_dir.exists() - - @pytest.mark.asyncio - async def test_handles_nonexistent_directory(self, tmp_path: Path, monkeypatch): - """Test that clear_materialized handles non-existent directory gracefully.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=999) - - # Should not raise - await manager.clear_materialized() - - -class TestStartWatching: - """Tests for start_watching method.""" - - @pytest.mark.asyncio - async def test_creates_watch_task(self, tmp_path: Path, monkeypatch): - """Test that start_watching creates a watch task.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - # Create the notes directory - await manager.get_notes_dir() - - with patch("src.agent.notes.manager.awatch") as mock_awatch: - mock_awatch.return_value.__aiter__ = AsyncMock( - return_value=iter([]) - ) - - await manager.start_watching() - - assert manager._watch_task is not None - assert not manager._watch_task.done() - - # Cleanup - await manager.stop_watching() - - @pytest.mark.asyncio - async def test_restarts_existing_watch(self, tmp_path: Path, monkeypatch): - """Test that start_watching stops existing watch before starting new one.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - # Create the notes directory - await manager.get_notes_dir() - - with patch("src.agent.notes.manager.awatch") as mock_awatch: - mock_awatch.return_value.__aiter__ = AsyncMock( - return_value=iter([]) - ) - - await manager.start_watching() - first_task = manager._watch_task - - await manager.start_watching() - second_task = manager._watch_task - - assert first_task is not second_task - - # Cleanup - await manager.stop_watching() - - -class TestStopWatching: - """Tests for stop_watching method.""" - - @pytest.mark.asyncio - async def test_cancels_watch_task(self, tmp_path: Path, monkeypatch): - """Test that stop_watching cancels the watch task.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - # Create the notes directory - await manager.get_notes_dir() - - with patch("src.agent.notes.manager.awatch") as mock_awatch: - mock_awatch.return_value.__aiter__ = AsyncMock( - return_value=iter([]) - ) - - await manager.start_watching() - assert manager._watch_task is not None - - await manager.stop_watching() - - assert manager._watch_task is None - assert manager._stop_watching_event is None - - @pytest.mark.asyncio - async def test_handles_already_stopped(self, tmp_path: Path, monkeypatch): - """Test that stop_watching handles already stopped state gracefully.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - # Should not raise when nothing is watching - await manager.stop_watching() - - -class TestHandleFileChanges: - """Tests for _handle_file_changes method.""" - - @pytest.mark.asyncio - async def test_handles_added_file(self, tmp_path: Path, monkeypatch): - """Test that added files are processed correctly.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - notes_dir = await manager.get_notes_dir() - - # Create a file to simulate "added" - test_file = notes_dir / "new.md" - await test_file.write_text("New content", "utf-8") - - workspace = MagicMock() - workspace.id = 1 - workspace.notes = [] - - mock_service = AsyncMock() - mock_service.get_workspace_by_id.return_value = workspace - - with patch( - "src.services.workspace.WorkspaceService", return_value=mock_service - ): - with patch("src.agent.notes.manager.db_context") as mock_db_context: - mock_session = AsyncMock() - mock_db_context.return_value.__aenter__ = AsyncMock( - return_value=mock_session - ) - mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) - - changes = [(ChangeType.added, AnyioPath(test_file))] - await manager._handle_file_changes(changes) - - # Verify workspace.notes was updated - assert len(workspace.notes) == 1 - assert workspace.notes[0].relative == "new.md" - assert workspace.notes[0].content == "New content" - - @pytest.mark.asyncio - async def test_handles_modified_file(self, tmp_path: Path, monkeypatch): - """Test that modified files are processed correctly.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - notes_dir = await manager.get_notes_dir() - - # Create existing note - existing_note = workspace_models.WorkspaceNote( - relative="existing.md", content="Old content" - ) - - workspace = MagicMock() - workspace.id = 1 - workspace.notes = [existing_note] - - # Update the file - test_file = notes_dir / "existing.md" - await test_file.parent.mkdir(parents=True, exist_ok=True) - await test_file.write_text("Updated content", "utf-8") - - mock_service = AsyncMock() - mock_service.get_workspace_by_id.return_value = workspace - - with patch( - "src.services.workspace.WorkspaceService", return_value=mock_service - ): - with patch("src.agent.notes.manager.db_context") as mock_db_context: - mock_session = AsyncMock() - mock_db_context.return_value.__aenter__ = AsyncMock( - return_value=mock_session - ) - mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) - - changes = [(ChangeType.modified, AnyioPath(test_file))] - await manager._handle_file_changes(changes) - - # Verify note was updated - assert len(workspace.notes) == 1 - assert workspace.notes[0].content == "Updated content" - - @pytest.mark.asyncio - async def test_handles_deleted_file(self, tmp_path: Path, monkeypatch): - """Test that deleted files are removed from notes.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - notes_dir = await manager.get_notes_dir() - - # Create existing note - existing_note = workspace_models.WorkspaceNote( - relative="deleted.md", content="Content" - ) - - workspace = MagicMock() - workspace.id = 1 - workspace.notes = [existing_note] - - # File path (doesn't need to exist for deletion test) - test_file = notes_dir / "deleted.md" - - mock_service = AsyncMock() - mock_service.get_workspace_by_id.return_value = workspace - - with patch( - "src.services.workspace.WorkspaceService", return_value=mock_service - ): - with patch("src.agent.notes.manager.db_context") as mock_db_context: - mock_session = AsyncMock() - mock_db_context.return_value.__aenter__ = AsyncMock( - return_value=mock_session - ) - mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) - - changes = [(ChangeType.deleted, AnyioPath(test_file))] - await manager._handle_file_changes(changes) - - # Verify note was removed - assert len(workspace.notes) == 0 - - @pytest.mark.asyncio - async def test_skips_non_md_files(self, tmp_path: Path, monkeypatch): - """Test that non-.md files are ignored.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - notes_dir = await manager.get_notes_dir() - - # Create a non-md file - test_file = notes_dir / "script.py" - await test_file.write_text("print('hello')", "utf-8") - - workspace = MagicMock() - workspace.id = 1 - workspace.notes = [] - - mock_service = AsyncMock() - mock_service.get_workspace_by_id.return_value = workspace - - with patch( - "src.services.workspace.WorkspaceService", return_value=mock_service - ): - with patch("src.agent.notes.manager.db_context") as mock_db_context: - mock_session = AsyncMock() - mock_db_context.return_value.__aenter__ = AsyncMock( - return_value=mock_session - ) - mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) - - # _watch_files filters non-md files, but _handle_file_changes - # should handle any file type passed to it - changes = [(ChangeType.added, AnyioPath(test_file))] - await manager._handle_file_changes(changes) - - # Verify note was still added (handler doesn't filter) - assert len(workspace.notes) == 1 - - @pytest.mark.asyncio - async def test_handles_empty_changes(self, tmp_path: Path, monkeypatch): - """Test that empty changes are handled gracefully.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - # Should not raise - await manager._handle_file_changes([]) - - @pytest.mark.asyncio - async def test_handles_read_error_gracefully(self, tmp_path: Path, monkeypatch): - """Test that file read errors are handled gracefully.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - notes_dir = await manager.get_notes_dir() - - workspace = MagicMock() - workspace.id = 1 - workspace.notes = [] - - mock_service = AsyncMock() - mock_service.get_workspace_by_id.return_value = workspace - - # Create a path that doesn't exist to trigger read error - nonexistent_file = notes_dir / "nonexistent.md" - - with patch( - "src.services.workspace.WorkspaceService", return_value=mock_service - ): - with patch("src.agent.notes.manager.db_context") as mock_db_context: - mock_session = AsyncMock() - mock_db_context.return_value.__aenter__ = AsyncMock( - return_value=mock_session - ) - mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) - - changes = [(ChangeType.added, AnyioPath(nonexistent_file))] - # Should not raise even though file doesn't exist - await manager._handle_file_changes(changes) - - -class TestWatchFiles: - """Tests for _watch_files method.""" - - @pytest.mark.asyncio - async def test_returns_early_if_directory_not_exists(self, tmp_path: Path): - """Test that _watch_files returns early if directory doesn't exist.""" - manager = NoteManager(workspace_id=999) - stop_event = asyncio.Event() - - nonexistent_dir = AnyioPath(tmp_path / "nonexistent") - - # Should complete without error - await manager._watch_files(nonexistent_dir, stop_event) - - @pytest.mark.asyncio - async def test_filters_md_files(self, tmp_path: Path, monkeypatch): - """Test that _watch_files filters to only .md files.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - notes_dir = await manager.get_notes_dir() - stop_event = asyncio.Event() - - # Create files - md_file = notes_dir / "note.md" - py_file = notes_dir / "script.py" - await md_file.write_text("# Note", "utf-8") - await py_file.write_text("print('hello')", "utf-8") - - mock_changes = [ - {(ChangeType.added, str(md_file)), (ChangeType.added, str(py_file))} - ] - - with patch( - "src.agent.notes.manager.awatch" - ) as mock_awatch, patch.object( - manager, "_handle_file_changes" - ) as mock_handle: - - async def mock_awatch_generator(*args, **kwargs): - for change_set in mock_changes: - yield change_set - stop_event.set() - - mock_awatch.return_value.__aiter__ = mock_awatch_generator - - # Run watch briefly then stop - watch_task = asyncio.create_task( - manager._watch_files(notes_dir, stop_event) - ) - - # Give it time to process - await asyncio.sleep(0.1) - stop_event.set() - - try: - await asyncio.wait_for(watch_task, timeout=1.0) - except asyncio.TimeoutError: - watch_task.cancel() - - # Verify _handle_file_changes was called with only md file - mock_handle.assert_called_once() - call_args = mock_handle.call_args[0][0] - assert len(call_args) == 1 - change_type, path = list(call_args)[0] - assert path.name == "note.md" - - @pytest.mark.asyncio - async def test_handles_cancelled_error(self, tmp_path: Path, monkeypatch): - """Test that _watch_files handles CancelledError gracefully.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - notes_dir = await manager.get_notes_dir() - stop_event = asyncio.Event() - - with patch("src.agent.notes.manager.awatch") as mock_awatch: - mock_awatch.side_effect = asyncio.CancelledError() - - # Should not raise - await manager._watch_files(notes_dir, stop_event) - - -class TestIntegration: - """Integration tests for NoteManager.""" - - @pytest.mark.asyncio - async def test_full_lifecycle(self, tmp_path: Path, monkeypatch): - """Test the full lifecycle: materialize, watch, modify, clear.""" - monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path) - manager = NoteManager(workspace_id=1) - - # Setup mock workspace - workspace = MagicMock() - workspace.id = 1 - workspace.notes = [ - workspace_models.WorkspaceNote(relative="test.md", content="Initial") - ] - - mock_service = AsyncMock() - mock_service.get_workspace_by_id.return_value = workspace - - with patch( - "src.services.workspace.WorkspaceService", return_value=mock_service - ): - with patch("src.agent.notes.manager.db_context") as mock_db_context: - mock_session = AsyncMock() - mock_db_context.return_value.__aenter__ = AsyncMock( - return_value=mock_session - ) - mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) - - # Materialize - notes_dir = await manager.materialize() - assert await (notes_dir / "test.md").exists() - - # Clear - await manager.clear_materialized() - assert not await notes_dir.exists() diff --git a/src-server/tests/agent/notes/test_materializer.py b/src-server/tests/agent/notes/test_materializer.py new file mode 100644 index 00000000..9a4c1fc2 --- /dev/null +++ b/src-server/tests/agent/notes/test_materializer.py @@ -0,0 +1,168 @@ +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from anyio import Path as AnyioPath + +from src.agent.notes import NoteMaterializer +from src.db.models import workspace as workspace_models +from src.schemas import workspace as workspace_schemas + + +@pytest.mark.integration +class TestNoteMaterializer: + @pytest.mark.asyncio + async def test_get_notes_root_dir_creates_directory(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + + result = await NoteMaterializer.get_notes_root_dir() + + assert await result.exists() + assert result == AnyioPath(tmp_path / ".notes") + + @pytest.mark.asyncio + async def test_get_notes_dir_creates_workspace_directory(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + + result = await NoteMaterializer.get_notes_dir(42) + + assert await result.exists() + assert result == AnyioPath(tmp_path / ".notes" / "42") + + def test_get_notes_dir_env_returns_workspace_specific_env(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + + env = NoteMaterializer.get_notes_dir_env(5) + + assert env == { + NoteMaterializer.NOTES_DIR_ENVNAME: str(tmp_path / ".notes" / "5") + } + + @pytest.mark.asyncio + async def test_get_notes_index_returns_content(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + + notes_dir = await NoteMaterializer.get_notes_dir(7) + notes_index = notes_dir / "NOTES.md" + await notes_index.write_text("# Workspace notes\n\n- item", "utf-8") + + result = await NoteMaterializer.get_notes_index(7) + + assert result == "# Workspace notes\n\n- item" + + @pytest.mark.asyncio + async def test_get_notes_index_returns_none_when_missing(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + + result = await NoteMaterializer.get_notes_index(8) + + assert result is None + + @pytest.mark.asyncio + async def test_get_notes_index_returns_none_when_read_fails(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + + notes_dir = await NoteMaterializer.get_notes_dir(9) + notes_index = notes_dir / "NOTES.md" + await notes_index.write_text("# Workspace notes", "utf-8") + + with patch.object(AnyioPath, "read_text", AsyncMock(side_effect=OSError("read failed"))): + result = await NoteMaterializer.get_notes_index(9) + + assert result is None + + @pytest.mark.asyncio + async def test_materialize_writes_workspace_notes_to_files(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + workspace = workspace_schemas.WorkspaceRead.model_validate({ + "id": 1, + "name": "Workspace", + "directory": "/tmp/workspace", + "instruction": "", + "notes": [ + {"id": 1, "relative": "README.md", "content": "# Hello"}, + {"id": 2, "relative": "docs/guide.md", "content": "## Guide"}, + ], + "usable_agents": [], + "usable_tools": [], + "usable_skills": [], + }) + + notes_dir = await NoteMaterializer.materialize(workspace) + + readme_path = Path(str(notes_dir)) / "README.md" + guide_path = Path(str(notes_dir)) / "docs" / "guide.md" + assert readme_path.read_text(encoding="utf-8") == "# Hello" + assert guide_path.read_text(encoding="utf-8") == "## Guide" + + @pytest.mark.asyncio + async def test_clear_materialized_removes_workspace_directory(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + notes_dir = await NoteMaterializer.get_notes_dir(3) + note_path = Path(str(notes_dir)) / "test.md" + note_path.write_text("content", encoding="utf-8") + + await NoteMaterializer.clear_materialized(3) + + assert not note_path.exists() + assert not Path(str(notes_dir)).exists() + + @pytest.mark.asyncio + async def test_clear_materialized_handles_nonexistent_directory(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + + await NoteMaterializer.clear_materialized(999) + + assert not (tmp_path / ".notes" / "999").exists() + + @pytest.mark.asyncio + async def test_materialize_all_clears_then_materializes_each_workspace(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + + workspace_a = MagicMock(spec=workspace_models.Workspace) + workspace_a.id = 1 + workspace_b = MagicMock(spec=workspace_models.Workspace) + workspace_b.id = 2 + + mock_service = AsyncMock() + mock_service.get_all_workspaces.return_value = [workspace_a, workspace_b] + + clear_mock = AsyncMock() + materialize_mock = AsyncMock() + workspace_read_mock = MagicMock(side_effect=[ + workspace_schemas.WorkspaceRead.model_validate({ + "id": 1, + "name": "Workspace A", + "directory": "/tmp/workspace-a", + "instruction": "", + "notes": [{"id": 1, "relative": "a.md", "content": "A"}], + "usable_agents": [], + "usable_tools": [], + "usable_skills": [], + }), + workspace_schemas.WorkspaceRead.model_validate({ + "id": 2, + "name": "Workspace B", + "directory": "/tmp/workspace-b", + "instruction": "", + "notes": [{"id": 2, "relative": "b.md", "content": "B"}], + "usable_agents": [], + "usable_tools": [], + "usable_skills": [], + }), + ]) + + with patch("src.services.workspace.WorkspaceService", return_value=mock_service): + with patch("src.agent.notes.materializer.db_context") as mock_db_context: + mock_db_context.return_value.__aenter__ = AsyncMock(return_value=AsyncMock()) + mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) + with patch.object(NoteMaterializer, "clear_materialized", clear_mock): + with patch.object(NoteMaterializer, "materialize", materialize_mock): + with patch.object(workspace_schemas.WorkspaceRead, "model_validate", workspace_read_mock): + await NoteMaterializer.materialize_all() + + assert clear_mock.await_args_list == [ + ((1,), {}), + ((2,), {}), + ] + assert materialize_mock.await_count == 2 diff --git a/src-server/tests/agent/notes/test_watcher.py b/src-server/tests/agent/notes/test_watcher.py new file mode 100644 index 00000000..108235c8 --- /dev/null +++ b/src-server/tests/agent/notes/test_watcher.py @@ -0,0 +1,198 @@ +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from anyio import Path as AnyioPath +from watchfiles import Change as ChangeType + +from src.agent.notes.watcher import NoteWatcher +from src.agent.notes.workspace_ref_manager import WorkspaceRefManager +from src.db.models import workspace as workspace_models +from src.utils.directory_watcher import DirectoryWatcher + + +@pytest.mark.integration +class TestNoteWatcher: + @pytest.mark.asyncio + async def test_start_creates_directory_watcher(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + watcher = NoteWatcher(workspace_id=1) + directory_watcher = AsyncMock(spec=DirectoryWatcher) + + with patch( + "src.agent.notes.watcher.DirectoryWatcher", + return_value=directory_watcher, + ) as directory_watcher_cls: + await watcher._start() + + assert watcher._watcher is directory_watcher + assert WorkspaceRefManager.is_workspace_in_use(1) + directory_watcher_cls.assert_called_once() + directory_watcher.start.assert_awaited_once() + + await watcher._stop() + directory_watcher.stop.assert_awaited_once() + assert watcher._watcher is None + assert not WorkspaceRefManager.is_workspace_in_use(1) + + @pytest.mark.asyncio + async def test_start_restarts_existing_directory_watcher(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + watcher = NoteWatcher(workspace_id=1) + first_directory_watcher = AsyncMock(spec=DirectoryWatcher) + second_directory_watcher = AsyncMock(spec=DirectoryWatcher) + + with patch( + "src.agent.notes.watcher.DirectoryWatcher", + side_effect=[first_directory_watcher, second_directory_watcher], + ): + await watcher._start() + await watcher._start() + + first_directory_watcher.start.assert_awaited_once() + first_directory_watcher.stop.assert_awaited_once() + second_directory_watcher.start.assert_awaited_once() + assert watcher._watcher is second_directory_watcher + assert WorkspaceRefManager.is_workspace_in_use(1) + + await watcher._stop() + second_directory_watcher.stop.assert_awaited_once() + assert not WorkspaceRefManager.is_workspace_in_use(1) + + @pytest.mark.asyncio + async def test_stop_handles_already_stopped_state(self): + watcher = NoteWatcher(workspace_id=1) + + await watcher._stop() + + assert watcher._watcher is None + assert not WorkspaceRefManager.is_workspace_in_use(1) + + @pytest.mark.asyncio + async def test_start_rolls_back_workspace_ref_when_directory_watcher_start_fails( + self, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + ): + monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path) + watcher = NoteWatcher(workspace_id=1) + directory_watcher = AsyncMock(spec=DirectoryWatcher) + directory_watcher.start.side_effect = RuntimeError("boom") + + with patch("src.agent.notes.watcher.DirectoryWatcher", return_value=directory_watcher): + with pytest.raises(RuntimeError, match="boom"): + await watcher._start() + + assert not WorkspaceRefManager.is_workspace_in_use(1) + + @pytest.mark.asyncio + async def test_handle_file_changes_filters_and_forwards_markdown_only(self, tmp_path: Path): + watcher = NoteWatcher(workspace_id=1) + notes_dir = AnyioPath(tmp_path / "notes") + await notes_dir.mkdir(parents=True, exist_ok=True) + + md_file = notes_dir / "note.md" + txt_file = notes_dir / "note.txt" + subdir = notes_dir / "dir" + await md_file.write_text("# Note", "utf-8") + await txt_file.write_text("plain text", "utf-8") + await subdir.mkdir() + + with patch( + "src.agent.notes.watcher.NoteMaterializer.get_notes_dir", + AsyncMock(return_value=notes_dir), + ): + with patch.object(watcher, "_handle_note_changes", AsyncMock()) as handle_mock: + await watcher._handle_file_changes( + [ + (ChangeType.added, str(md_file)), + (ChangeType.added, str(txt_file)), + (ChangeType.added, str(subdir)), + ] + ) + + handle_mock.assert_awaited_once() + forwarded_base, forwarded_changes = handle_mock.await_args.args + assert forwarded_base == notes_dir + assert len(forwarded_changes) == 1 + change_type, path = forwarded_changes[0] + assert change_type == ChangeType.added + assert path == md_file + + @pytest.mark.asyncio + async def test_handle_file_changes_ignores_empty_changes(self, tmp_path: Path): + watcher = NoteWatcher(workspace_id=1) + notes_dir = AnyioPath(tmp_path / "notes") + await notes_dir.mkdir(parents=True, exist_ok=True) + + with patch( + "src.agent.notes.watcher.NoteMaterializer.get_notes_dir", + AsyncMock(return_value=notes_dir), + ): + with patch.object(watcher, "_handle_note_changes", AsyncMock()) as handle_mock: + await watcher._handle_file_changes([]) + + handle_mock.assert_not_awaited() + + @pytest.mark.asyncio + async def test_handle_note_changes_adds_updates_and_deletes_notes(self, tmp_path: Path): + watcher = NoteWatcher(workspace_id=1) + base = AnyioPath(tmp_path / "notes") + await base.mkdir(parents=True, exist_ok=True) + + added_path = base / "new.md" + updated_path = base / "existing.md" + deleted_path = base / "deleted.md" + await added_path.write_text("New content", "utf-8") + await updated_path.write_text("Updated content", "utf-8") + + workspace = MagicMock() + workspace.id = 1 + workspace.notes = [ + workspace_models.WorkspaceNote(relative="existing.md", content="Old content"), + workspace_models.WorkspaceNote(relative="deleted.md", content="Delete me"), + ] + + mock_service = AsyncMock() + mock_service.get_workspace_by_id.return_value = workspace + + with patch("src.services.workspace.WorkspaceService", return_value=mock_service): + with patch("src.agent.notes.watcher.db_context") as mock_db_context: + mock_db_context.return_value.__aenter__ = AsyncMock(return_value=AsyncMock()) + mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) + await watcher._handle_note_changes( + base, + [ + (ChangeType.added, added_path), + (ChangeType.modified, updated_path), + (ChangeType.deleted, deleted_path), + ], + ) + + notes_by_relative = {note.relative: note.content for note in workspace.notes} + assert notes_by_relative == { + "existing.md": "Updated content", + "new.md": "New content", + } + + @pytest.mark.asyncio + async def test_handle_note_changes_skips_file_read_errors(self, tmp_path: Path): + watcher = NoteWatcher(workspace_id=1) + base = AnyioPath(tmp_path / "notes") + await base.mkdir(parents=True, exist_ok=True) + + missing_path = base / "missing.md" + workspace = MagicMock() + workspace.id = 1 + workspace.notes = [] + + mock_service = AsyncMock() + mock_service.get_workspace_by_id.return_value = workspace + + with patch("src.services.workspace.WorkspaceService", return_value=mock_service): + with patch("src.agent.notes.watcher.db_context") as mock_db_context: + mock_db_context.return_value.__aenter__ = AsyncMock(return_value=AsyncMock()) + mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False) + await watcher._handle_note_changes(base, [(ChangeType.added, missing_path)]) + + assert workspace.notes == [] diff --git a/src-server/tests/services/test_task_service.py b/src-server/tests/services/test_task_service.py index 9114df5a..03576bd1 100644 --- a/src-server/tests/services/test_task_service.py +++ b/src-server/tests/services/test_task_service.py @@ -297,7 +297,6 @@ async def test_delete_task_removes_entity_and_task_resources( workspace_factory, agent_factory, tool_factory, - task_resource_data_dir: Path, ): tool = await tool_factory(name="Echo", internal_key="echo") agent = await agent_factory(name="Agent A", usable_tools=[tool]) @@ -320,15 +319,11 @@ async def test_delete_task_removes_entity_and_task_resources( "note.txt", b"resource-bytes", ) - resource_dir = task_resource_data_dir / ".task-resources" / "task" / str(task.id) - resource_path = resource_dir / resource.filename await task_service.delete_task(task.id) await db_session.flush() db_session.expunge_all() - assert not resource_path.exists() - assert not resource_dir.exists() with pytest.raises(TaskNotFoundError, match=f"Task '{task.id}' not found"): await task_service.get_task_by_id(task.id) diff --git a/src-server/tests/services/test_workspace_service.py b/src-server/tests/services/test_workspace_service.py index 3fa6daf8..8c4cb660 100644 --- a/src-server/tests/services/test_workspace_service.py +++ b/src-server/tests/services/test_workspace_service.py @@ -4,7 +4,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from src.agent.notes.manager import NoteManager +from src.agent.notes import NoteMaterializer from src.db.models.markdown_cache import MarkdownCache from src.db.models import tasks as task_models from src.db.models import workspace as workspace_models @@ -137,18 +137,12 @@ async def test_delete_workspace_removes_entity_and_cascade_children( db_session.add(cache) await db_session.flush() - note_manager = NoteManager(workspace.id) - notes_dir = await note_manager.get_notes_dir() - note_path = Path(str(notes_dir)) / "note.md" - note_path.write_text("persisted", encoding="utf-8") - await workspace_service.delete_workspace(workspace.id) await db_session.flush() db_session.expunge_all() with pytest.raises(WorkspaceNotFoundError, match=f"Workspace '{workspace.id}' not found"): await workspace_service.get_workspace_by_id(workspace.id) - assert not note_path.exists() note_in_db = await db_session.scalar( select(workspace_models.WorkspaceNote).where(workspace_models.WorkspaceNote.id == note_id) diff --git a/src-server/tests/tools/conftest.py b/src-server/tests/tools/conftest.py index de47c071..3882a723 100644 --- a/src-server/tests/tools/conftest.py +++ b/src-server/tests/tools/conftest.py @@ -15,4 +15,4 @@ def _init(self, ctx, toolset_ent=None): @pytest.fixture def built_in_toolset_context(temp_workspace): - return BuiltInToolsetContext(1, temp_workspace, ContextUsage.default()) + return BuiltInToolsetContext(1, temp_workspace) diff --git a/src-server/tests/tools/file_system/test_init.py b/src-server/tests/tools/file_system/test_init.py index 50f958a4..c68f7a7e 100644 --- a/src-server/tests/tools/file_system/test_init.py +++ b/src-server/tests/tools/file_system/test_init.py @@ -29,7 +29,7 @@ async def test_read_file_resolves_tilde_cwd(self): marker = home / "__dais_test_marker__.txt" marker.write_text("tilde-ok", encoding="utf-8") try: - tool = FileSystemToolset(BuiltInToolsetContext(1, "~", ContextUsage.default())) + tool = FileSystemToolset(BuiltInToolsetContext(1, "~")) result = await tool.read_file("__dais_test_marker__.txt") assert "tilde-ok" in result finally: