diff --git a/src-frontend/src/api/workspace.ts b/src-frontend/src/api/workspace.ts
index d570d138..9d7bd7e3 100644
--- a/src-frontend/src/api/workspace.ts
+++ b/src-frontend/src/api/workspace.ts
@@ -7,6 +7,7 @@ export {
useGetWorkspaceSuspense,
useGetWorkspacesSuspenseInfinite,
useUpdateWorkspace,
+ useUpdateWorkspaceNotes,
openWorkspace,
} from "./generated/endpoints/workspace/workspace";
diff --git a/src-frontend/src/features/Tabs/WorkspacePanel/WorkspaceNotesEditForm.tsx b/src-frontend/src/features/Tabs/WorkspacePanel/WorkspaceNotesEditForm.tsx
index 58fb7309..282ada46 100644
--- a/src-frontend/src/features/Tabs/WorkspacePanel/WorkspaceNotesEditForm.tsx
+++ b/src-frontend/src/features/Tabs/WorkspacePanel/WorkspaceNotesEditForm.tsx
@@ -3,7 +3,7 @@ import { useTranslation } from "react-i18next";
import { toast } from "sonner";
import { TABS_WORKSPACE_NAMESPACE } from "@/i18n/resources";
import type { WorkspaceRead } from "@/api/generated/schemas";
-import { invalidateWorkspaceQueries, useUpdateWorkspace } from "@/api/workspace";
+import { invalidateWorkspaceQueries, useUpdateWorkspaceNotes } from "@/api/workspace";
import { FormShell, FormShellFooter } from "@/components/custom/form/FormShell";
import { Button } from "@/components/ui/button";
import { useWorkspaceStore } from "@/stores/workspace-store";
@@ -29,7 +29,7 @@ export function WorkspaceNotesEditForm({ workspace, onConfirm }: WorkspaceNotesE
[workspace.notes]
);
- const updateMutation = useUpdateWorkspace({
+ const updateNotesMutation = useUpdateWorkspaceNotes({
mutation: {
async onSuccess(updatedWorkspace: { id: number; name: string }) {
await invalidateWorkspaceQueries(updatedWorkspace.id);
@@ -47,7 +47,7 @@ export function WorkspaceNotesEditForm({ workspace, onConfirm }: WorkspaceNotesE
const handleSubmit = (data: WorkspaceNotesEditFormValues) => {
const payload = notesEditFormValuesToPayload(data);
- updateMutation.mutate({ workspaceId: workspace.id, data: payload });
+ updateNotesMutation.mutate({ workspaceId: workspace.id, data: payload });
};
return (
@@ -59,8 +59,8 @@ export function WorkspaceNotesEditForm({ workspace, onConfirm }: WorkspaceNotesE
-
diff --git a/src-frontend/src/features/Tabs/WorkspacePanel/form-types.ts b/src-frontend/src/features/Tabs/WorkspacePanel/form-types.ts
index aaf1f810..a1afab48 100644
--- a/src-frontend/src/features/Tabs/WorkspacePanel/form-types.ts
+++ b/src-frontend/src/features/Tabs/WorkspacePanel/form-types.ts
@@ -1,5 +1,6 @@
import type {
WorkspaceCreate,
+ WorkspaceNotesUpdate,
WorkspaceRead,
WorkspaceUpdate,
} from "@/api/generated/schemas";
@@ -52,23 +53,13 @@ export function createFormValuesToPayload(
export function editFormValuesToPayload(
values: WorkspaceEditFormValues,
): WorkspaceUpdate {
- return {
- ...values,
- notes: null, // explicitly exclude `notes`
- };
+ return values;
}
export function notesEditFormValuesToPayload(
values: WorkspaceNotesEditFormValues,
-): WorkspaceUpdate {
+): WorkspaceNotesUpdate {
return {
- name: null,
- directory: null,
- instruction: null,
- usable_agent_ids: null,
- usable_tool_ids: null,
- usable_skill_ids: null,
notes: arboristDataToResources(values.notes),
};
}
-
diff --git a/src-frontend/src/i18n/locales/en/error.json b/src-frontend/src/i18n/locales/en/error.json
index 948bd408..0ff1222a 100644
--- a/src-frontend/src/i18n/locales/en/error.json
+++ b/src-frontend/src/i18n/locales/en/error.json
@@ -24,5 +24,6 @@
"TOOLSET_INTERNAL_KEY_ALREADY_EXISTS": "This toolset already exists.",
"TOOLSET_NOT_FOUND": "The toolset was not found.",
"UNEXPECTED_ERROR": "Something went wrong. Please try again.",
- "WORKSPACE_NOT_FOUND": "The workspace was not found."
+ "WORKSPACE_NOT_FOUND": "The workspace was not found.",
+ "WORKSPACE_NOTES_LOCKED_BY_RUNNING_TASK": "Workspace notes cannot be updated while a task in this workspace is running. Please try again after the task finishes."
}
diff --git a/src-frontend/src/i18n/locales/zh_CN/error.json b/src-frontend/src/i18n/locales/zh_CN/error.json
index 014d3b6a..cf3b8d0c 100644
--- a/src-frontend/src/i18n/locales/zh_CN/error.json
+++ b/src-frontend/src/i18n/locales/zh_CN/error.json
@@ -24,5 +24,6 @@
"TOOLSET_INTERNAL_KEY_ALREADY_EXISTS": "该工具集已存在。",
"TOOLSET_NOT_FOUND": "未找到该工具集。",
"UNEXPECTED_ERROR": "发生错误,请稍后重试。",
- "WORKSPACE_NOT_FOUND": "未找到该工作区。"
+ "WORKSPACE_NOT_FOUND": "未找到该工作区。",
+ "WORKSPACE_NOTES_LOCKED_BY_RUNNING_TASK": "当前工作区下有任务正在运行,暂时无法更新工作区笔记,请在任务结束后重试。"
}
diff --git a/src-server/src/agent/context/__init__.py b/src-server/src/agent/context/__init__.py
index 961fb44d..6f4fca98 100644
--- a/src-server/src/agent/context/__init__.py
+++ b/src-server/src/agent/context/__init__.py
@@ -1,11 +1,3 @@
-from .aliases import BuiltInToolAliases
-from .models import ToolRuntimeContext, AgentContextResource, AgentContextPersistence
-
-__all__ = [
- "AgentContext",
- "ToolRuntimeContext",
-]
-
import platform
import xml.etree.ElementTree as ET
from collections import namedtuple
@@ -14,7 +6,6 @@
from loguru import logger
from dais_sdk.tool import Toolset
from dais_sdk.types import Message, ToolDef
-from src.agent.notes import NoteManager
from src.db import db_context
from src.schemas import (
agent as agent_schemas,
@@ -29,7 +20,10 @@
from src.services.provider import ProviderService
from src.settings import use_app_setting_manager
from .persistence import create_agent_context_persistence
-from ..tool import use_mcp_toolset_manager, BuiltinToolsetManager, McpToolsetManager, BuiltInToolset
+from .aliases import BuiltInToolAliases
+from .models import AgentContextResource, AgentContextPersistence
+from ..notes import NoteMaterializer
+from ..tool import use_mcp_toolset_manager, BuiltinToolsetManager, McpToolsetManager
from ..prompts import (
BASE_INSTRUCTION,
FAILED_TO_LOAD_NOTES_INDEX,
@@ -49,7 +43,7 @@ def __init__(self,
*,
messages: list[Message],
resource: AgentContextResource,
- tool_context: ToolRuntimeContext,
+ usage: ContextUsage,
persistence: AgentContextPersistence,
builtin_toolset_manager: BuiltinToolsetManager,
mcp_toolset_manager: McpToolsetManager):
@@ -58,7 +52,7 @@ def __init__(self,
self._resource = resource
self._messages = messages
- self._tool_context = tool_context
+ self._usage = usage
self._persistence = persistence
self._builtin_toolset_manager = builtin_toolset_manager
self._mcp_toolset_manager = mcp_toolset_manager
@@ -82,12 +76,7 @@ async def create(cls, task: task_runtime_schemas.TaskRuntimeContext) -> Self:
usage = ContextUsage(**asdict(usage))
messages = task.messages
- note_manager = NoteManager(task.workspace_id)
- await note_manager.materialize()
- await note_manager.start_watching()
-
- tool_context = ToolRuntimeContext(usage=usage, note_manager=note_manager)
- builtin_toolset_manager = await BuiltinToolsetManager.create(workspace.id, workspace.directory, tool_context)
+ builtin_toolset_manager = await BuiltinToolsetManager.create(workspace.id, workspace.directory)
mcp_toolset_manager = use_mcp_toolset_manager()
persistence = create_agent_context_persistence(task)
@@ -101,7 +90,7 @@ async def create(cls, task: task_runtime_schemas.TaskRuntimeContext) -> Self:
model=provider_schemas.LlmModelRead.model_validate(model),
skills=[skill_schemas.SkillBrief.model_validate(skill) for skill in skills],
),
- tool_context=tool_context,
+ usage=usage,
persistence=persistence,
builtin_toolset_manager=builtin_toolset_manager,
mcp_toolset_manager=mcp_toolset_manager)
@@ -131,7 +120,7 @@ def _resolve_instructions(self) -> ResolvedInstructions:
@property
def usage(self) -> ContextUsage:
- return self._tool_context.usage
+ return self._usage
@property
def toolsets(self) -> list[Toolset]:
@@ -171,7 +160,7 @@ async def compose_system_instruction(self) -> str:
available_skills = AgentContext._format_skills([
skill for skill in self._resource.skills if skill.is_enabled])
resolved_instructions = self._resolve_instructions()
- notes_index = await self._tool_context.note_manager.get_notes_index()
+ notes_index = await NoteMaterializer.get_notes_index(self._resource.workspace.id)
resolved_notes_index = notes_index if notes_index is not None else FAILED_TO_LOAD_NOTES_INDEX
return resolved_instructions.base.format(
@@ -194,14 +183,8 @@ def find_tool(self, tool_name: str) -> ToolDef | None:
return None
async def persist(self) -> task_runtime_schemas.TaskRuntimeContext:
- try:
- await self._tool_context.note_manager.stop_watching()
- await self._tool_context.note_manager.clear_materialized()
- except:
- self._logger.exception("Failed to execute NoteManager cleanup.")
-
return await self._persistence.persist(
self.task_id,
self._messages,
- self._tool_context.usage,
+ self._usage,
)
diff --git a/src-server/src/agent/context/models.py b/src-server/src/agent/context/models.py
index cb175fad..4c81e5f7 100644
--- a/src-server/src/agent/context/models.py
+++ b/src-server/src/agent/context/models.py
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import Protocol
from dais_sdk.types import Message
-from src.agent.notes import NoteManager
from src.db.models import tasks as task_models
from src.schemas import (
agent as agent_schemas,
@@ -10,14 +9,8 @@
workspace as workspace_schemas,
)
from src.schemas.tasks import runtime as task_runtime_schemas
-from src.agent.types import ContextUsage
-@dataclass(frozen=True)
-class ToolRuntimeContext:
- usage: ContextUsage
- note_manager: NoteManager
-
@dataclass(frozen=True)
class AgentContextResource:
workspace: workspace_schemas.WorkspaceRead
diff --git a/src-server/src/agent/notes/__init__.py b/src-server/src/agent/notes/__init__.py
index fd0cb64c..854c4dfe 100644
--- a/src-server/src/agent/notes/__init__.py
+++ b/src-server/src/agent/notes/__init__.py
@@ -1 +1,3 @@
-from .manager import NoteManager
+from .materializer import NoteMaterializer
+from .watcher import NoteWatcher
+from .workspace_ref_manager import WorkspaceRefManager
diff --git a/src-server/src/agent/notes/manager.py b/src-server/src/agent/notes/manager.py
deleted file mode 100644
index 0f181eb8..00000000
--- a/src-server/src/agent/notes/manager.py
+++ /dev/null
@@ -1,212 +0,0 @@
-import asyncio
-import shutil
-from typing import Any, Callable
-from anyio import Path as AnyioPath
-from pathlib import Path as StdPath
-from loguru import logger
-from watchfiles import awatch, Change as ChangeType
-from src.common import DATA_DIR
-from src.db import db_context
-
-
-type FileChange = tuple[ChangeType, AnyioPath]
-
-class NoteManager:
- NOTES_DIR_ENVNAME = "DAIS_NOTES_DIR"
- _logger = logger.bind(name="NoteManager")
- _workspace_ref_counts: dict[int, int] = {}
- _workspace_ref_lock: asyncio.Lock | None = None
-
- def __init__(self, workspace_id: int):
- self._workspace_id = workspace_id
- self._stop_watching_event = None
- self._watch_task = None
-
- @classmethod
- def _get_lock(cls) -> asyncio.Lock:
- if cls._workspace_ref_lock is None:
- cls._workspace_ref_lock = asyncio.Lock()
- return cls._workspace_ref_lock
-
- @classmethod
- async def _increment_workspace_ref(cls, workspace_id: int) -> int:
- async with cls._get_lock():
- current = cls._workspace_ref_counts.get(workspace_id, 0)
- ref_count = current + 1
- cls._workspace_ref_counts[workspace_id] = ref_count
- return ref_count
-
- @classmethod
- async def _decrement_workspace_ref(cls, workspace_id: int, *, force: bool = False) -> int:
- async with cls._get_lock():
- if force:
- cls._workspace_ref_counts.pop(workspace_id, None)
- return 0
- current = cls._workspace_ref_counts.get(workspace_id, 0)
- if current <= 1:
- cls._workspace_ref_counts.pop(workspace_id, None)
- return 0
- ref_count = current - 1
- cls._workspace_ref_counts[workspace_id] = ref_count
- return ref_count
-
- @staticmethod
- async def get_notes_root_dir() -> AnyioPath:
- notes_root_dir = AnyioPath(DATA_DIR, ".notes")
- await notes_root_dir.mkdir(parents=True, exist_ok=True)
- return notes_root_dir
-
- @staticmethod
- def get_notes_dir_env(workspace_id: int) -> dict[str, str]:
- return {NoteManager.NOTES_DIR_ENVNAME: str(DATA_DIR / ".notes" / str(workspace_id))}
-
- async def get_notes_dir(self) -> AnyioPath:
- notes_root_dir = await self.get_notes_root_dir()
- notes_dir = notes_root_dir / str(self._workspace_id)
- await notes_dir.mkdir(parents=True, exist_ok=True)
- return notes_dir
-
- async def get_notes_index(self) -> str | None:
- notes_dir = await self.get_notes_dir()
- index_file = notes_dir / "NOTES.md"
- if not await index_file.exists(): return None
- try:
- return await index_file.read_text(encoding="utf-8")
- except:
- self._logger.exception(f"Failed to read root NOTES.md for workspace {self._workspace_id}.")
- return None
-
- async def materialize(self) -> AnyioPath:
- """Materialize workspace notes into `$DAIS_NOTES_DIR` for the current task context.
-
- Calling this method increments the workspace-level materialization reference count.
- The caller must later call `clear_materialized()` exactly once to release this
- reference; otherwise temporary notes state may be retained unexpectedly.
- """
- from src.services.workspace import WorkspaceService
-
- async with db_context() as db_session:
- workspace = await WorkspaceService(db_session).get_workspace_by_id(self._workspace_id)
- workspace_notes = workspace.notes
-
- notes_dir = await self.get_notes_dir()
- for note in workspace_notes:
- note_path = notes_dir / note.relative
- await note_path.parent.mkdir(parents=True, exist_ok=True)
- await note_path.write_text(note.content, "utf-8")
-
- await self._increment_workspace_ref(self._workspace_id)
- return notes_dir
-
- async def clear_materialized(self, *, force: bool = False):
- """Release one materialization reference and clean notes when eligible.
-
- This method is the required counterpart of `materialize()`: after each successful
- `materialize()` call, the caller must invoke `clear_materialized()` once.
- When `force=True`, cleanup ignores reference counts (used by workspace deletion).
- """
- notes_root_dir = await self.get_notes_root_dir()
- notes_dir = notes_root_dir / str(self._workspace_id)
-
- ref_count = await self._decrement_workspace_ref(self._workspace_id, force=force)
- should_delete = ref_count == 0
-
- if not should_delete: return
- if not await notes_dir.exists(): return
- await asyncio.to_thread(shutil.rmtree, notes_dir)
-
- async def start_watching(self):
- if self._watch_task is not None:
- await self.stop_watching()
- self._stop_watching_event = asyncio.Event()
- notes_dir = await self.get_notes_dir()
- self._watch_task = asyncio.create_task(
- self._watch_files(notes_dir, self._stop_watching_event))
-
- async def stop_watching(self) -> None:
- if self._stop_watching_event:
- self._stop_watching_event.set()
- self._stop_watching_event = None
-
- if self._watch_task and not self._watch_task.done():
- self._watch_task.cancel()
- try:
- await self._watch_task
- except asyncio.CancelledError:
- pass
- finally:
- self._watch_task = None
-
- async def _handle_file_changes(self, changes: list[FileChange]):
- from src.services.workspace import WorkspaceService
- from src.db.models import workspace as workspace_models
-
- if len(changes) == 0: return
- added_notes: list[tuple[AnyioPath, str]] = [] # (relative, content)
- updated_notes: list[tuple[AnyioPath, str]] = [] # (relative, content)
- deleted_notes: list[AnyioPath] = []
-
- for change_type, file_path in changes:
- try:
- match change_type:
- case ChangeType.added:
- content = await AnyioPath(file_path).read_text("utf-8")
- added_notes.append((file_path, content))
- case ChangeType.modified:
- content = await AnyioPath(file_path).read_text("utf-8")
- updated_notes.append((file_path, content))
- case ChangeType.deleted:
- deleted_notes.append(file_path)
- except Exception:
- # read file failed, skip
- pass
-
- notes_dir = await self.get_notes_dir()
- normalized_path: Callable[[AnyioPath], str] = lambda path: path.relative_to(notes_dir).as_posix()
-
- async with db_context() as db_session:
- workspace = await WorkspaceService(db_session).get_workspace_by_id(self._workspace_id)
- existing_notes: dict[str, workspace_models.WorkspaceNote] = {
- note.relative: note for note in workspace.notes
- }
-
- for path in deleted_notes:
- existing_notes.pop(normalized_path(path))
- for path, content in added_notes:
- relative = normalized_path(path)
- existing_notes[relative] = workspace_models.WorkspaceNote(
- relative=relative,
- content=content,
- )
- for path, content in updated_notes:
- relative = normalized_path(path)
- existing_note = existing_notes.get(relative, workspace_models.WorkspaceNote(
- relative=relative,
- content=content,
- ))
- existing_note.content = content
-
- workspace.notes = list(existing_notes.values())
-
- async def _watch_files(self, notes_dir: AnyioPath, stop_event: asyncio.Event) -> None:
- if not await notes_dir.exists():
- return
-
- try:
- async for changes in awatch(
- StdPath(notes_dir),
- stop_event=stop_event,
- debounce=500,
- recursive=True,
- ):
- markdown_changes: list[Any] = []
- for change_type, path in changes:
- path = AnyioPath(path)
- if await path.is_symlink() or await path.is_dir(): continue
- if path.suffix.lower() != ".md": continue
- markdown_changes.append((change_type, path))
- await self._handle_file_changes(markdown_changes)
- except asyncio.CancelledError:
- self._logger.debug(f"Notes watch cancelled for workspace {self._workspace_id}")
- except Exception:
- self._logger.exception(f"Error watching notes for workspace {self._workspace_id}")
diff --git a/src-server/src/agent/notes/materializer.py b/src-server/src/agent/notes/materializer.py
new file mode 100644
index 00000000..bb1e7e49
--- /dev/null
+++ b/src-server/src/agent/notes/materializer.py
@@ -0,0 +1,79 @@
+import asyncio
+import shutil
+from anyio import Path as AnyioPath
+from loguru import logger
+from src.common import DATA_DIR
+from src.db import db_context, workspace_models
+from src.schemas import workspace as workspace_schemas
+
+
+class NoteMaterializer:
+ NOTES_DIR_ENVNAME = "DAIS_NOTES_DIR"
+ _logger = logger.bind(name="NoteMaterializer")
+
+ @staticmethod
+ async def get_notes_root_dir() -> AnyioPath:
+ notes_root_dir = AnyioPath(DATA_DIR, ".notes")
+ await notes_root_dir.mkdir(parents=True, exist_ok=True)
+ return notes_root_dir
+
+ @classmethod
+ def get_notes_dir_env(cls, workspace_id: int) -> dict[str, str]:
+ return {cls.NOTES_DIR_ENVNAME: str(DATA_DIR / ".notes" / str(workspace_id))}
+
+ @classmethod
+ async def get_notes_dir(cls, workspace_id: int) -> AnyioPath:
+ notes_root_dir = await cls.get_notes_root_dir()
+ notes_dir = notes_root_dir / str(workspace_id)
+ await notes_dir.mkdir(parents=True, exist_ok=True)
+ return notes_dir
+
+ @classmethod
+ async def get_notes_index(cls, workspace_id: int) -> str | None:
+ notes_dir = await cls.get_notes_dir(workspace_id)
+ index_file = notes_dir / "NOTES.md"
+ if not await index_file.exists(): return None
+ try:
+ return await index_file.read_text(encoding="utf-8")
+        except Exception:
+ cls._logger.exception(f"Failed to read root NOTES.md for workspace {workspace_id}.")
+ return None
+
+ @classmethod
+ async def materialize(cls, workspace: workspace_schemas.WorkspaceRead) -> AnyioPath:
+ notes_dir = await cls.get_notes_dir(workspace.id)
+ for note in workspace.notes:
+ note_path = notes_dir / note.relative
+ await note_path.parent.mkdir(parents=True, exist_ok=True)
+ await note_path.write_text(note.content, "utf-8")
+
+ return notes_dir
+
+ @classmethod
+ async def materialize_all(cls):
+ from src.services.workspace import WorkspaceService
+
+ async with db_context() as db_session:
+ workspaces = await WorkspaceService(db_session).get_all_workspaces()
+
+ sem = asyncio.Semaphore(16)
+ async def sem_materialize(workspace: workspace_models.Workspace):
+ async with sem:
+ workspace_read = workspace_schemas.WorkspaceRead.model_validate(workspace)
+ await cls.clear_materialized(workspace_read.id)
+ await cls.materialize(workspace_read)
+
+ tasks = [sem_materialize(workspace) for workspace in workspaces]
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+ for result in results:
+ if isinstance(result, BaseException):
+ cls._logger.opt(exception=result).warning("Failed to materialize workspaces")
+
+ @classmethod
+ async def clear_materialized(cls, workspace_id: int):
+ cls._logger.debug(f"Clearing notes for workspace {workspace_id}")
+ notes_root_dir = await cls.get_notes_root_dir()
+ notes_dir = notes_root_dir / str(workspace_id)
+
+ if not await notes_dir.exists(): return
+ await asyncio.to_thread(shutil.rmtree, notes_dir)
diff --git a/src-server/src/agent/notes/watcher.py b/src-server/src/agent/notes/watcher.py
new file mode 100644
index 00000000..213fdb78
--- /dev/null
+++ b/src-server/src/agent/notes/watcher.py
@@ -0,0 +1,111 @@
+from typing import Callable
+from anyio import Path as AnyioPath
+from watchfiles import Change as ChangeType
+from src.db import db_context
+from src.utils import DirectoryWatcher, FileChange
+from .materializer import NoteMaterializer
+from .workspace_ref_manager import WorkspaceRefManager
+
+
+type NoteChange = tuple[ChangeType, AnyioPath]
+
+class NoteWatcher:
+ def __init__(self, workspace_id: int) -> None:
+ self._workspace_id = workspace_id
+ self._ref_acquired = False
+ self._watcher: DirectoryWatcher | None = None
+
+ async def _start(self):
+ if self._watcher is not None:
+ await self._stop()
+
+ WorkspaceRefManager.increase_workspace_ref(self._workspace_id)
+ self._ref_acquired = True
+ try:
+ notes_dir = await NoteMaterializer.get_notes_dir(self._workspace_id)
+ self._watcher = DirectoryWatcher(notes_dir, on_changes=self._handle_file_changes)
+ await self._watcher.start()
+ except BaseException:
+ WorkspaceRefManager.decrease_workspace_ref(self._workspace_id)
+ self._ref_acquired = False
+ raise
+
+ async def _stop(self) -> None:
+ if self._watcher:
+ await self._watcher.stop()
+ self._watcher = None
+        if self._ref_acquired:
+            self._ref_acquired = False; WorkspaceRefManager.decrease_workspace_ref(self._workspace_id)
+
+ async def _handle_file_changes(self, raw_changes: list[FileChange]) -> None:
+ notes_dir = await NoteMaterializer.get_notes_dir(self._workspace_id)
+
+ markdown_changes: list[NoteChange] = []
+ for change_type, path_str in raw_changes:
+ path = AnyioPath(path_str)
+ if await path.is_symlink() or await path.is_dir():
+ continue
+ if path.suffix.lower() != ".md":
+ continue
+ markdown_changes.append((change_type, path))
+
+ if markdown_changes:
+ await self._handle_note_changes(notes_dir, markdown_changes)
+
+ async def _handle_note_changes(self, base: AnyioPath, changes: list[NoteChange]):
+ from src.services.workspace import WorkspaceService
+ from src.db.models import workspace as workspace_models
+
+ if len(changes) == 0: return
+
+ added_notes: list[tuple[AnyioPath, str]] = [] # (relative, content)
+ updated_notes: list[tuple[AnyioPath, str]] = [] # (relative, content)
+ deleted_notes: list[AnyioPath] = []
+
+ for change_type, file_path in changes:
+ try:
+ match change_type:
+ case ChangeType.added:
+ content = await AnyioPath(file_path).read_text("utf-8")
+ added_notes.append((file_path, content))
+ case ChangeType.modified:
+ content = await AnyioPath(file_path).read_text("utf-8")
+ updated_notes.append((file_path, content))
+ case ChangeType.deleted:
+ deleted_notes.append(file_path)
+ except Exception:
+ # read file failed, skip
+ pass
+
+ normalized_path: Callable[[AnyioPath], str] = lambda path: path.relative_to(base).as_posix()
+
+ async with db_context() as db_session:
+ workspace = await WorkspaceService(db_session).get_workspace_by_id(self._workspace_id)
+ existing_notes: dict[str, workspace_models.WorkspaceNote] = {
+ note.relative: note for note in workspace.notes
+ }
+
+ for path in deleted_notes:
+ existing_notes.pop(normalized_path(path), None)
+ for path, content in added_notes:
+ relative = normalized_path(path)
+ existing_notes[relative] = workspace_models.WorkspaceNote(
+ relative=relative,
+ content=content,
+ )
+ for path, content in updated_notes:
+ relative = normalized_path(path)
+ existing_note = existing_notes.get(relative, workspace_models.WorkspaceNote(
+ relative=relative,
+ content=content,
+ ))
+ existing_note.content = content
+
+ workspace.notes = list(existing_notes.values())
+
+ async def __aenter__(self):
+ await self._start()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self._stop()
diff --git a/src-server/src/agent/notes/workspace_ref_manager.py b/src-server/src/agent/notes/workspace_ref_manager.py
new file mode 100644
index 00000000..960720bb
--- /dev/null
+++ b/src-server/src/agent/notes/workspace_ref_manager.py
@@ -0,0 +1,18 @@
+from collections import Counter
+
+
+class WorkspaceRefManager:
+ _workspace_refs: Counter[int] = Counter()
+
+ @classmethod
+ def increase_workspace_ref(cls, workspace_id: int):
+ cls._workspace_refs[workspace_id] += 1
+
+ @classmethod
+ def decrease_workspace_ref(cls, workspace_id: int):
+ current_count = cls._workspace_refs[workspace_id]
+ cls._workspace_refs[workspace_id] = max(current_count - 1, 0)
+
+ @classmethod
+ def is_workspace_in_use(cls, workspace_id: int) -> bool:
+ return cls._workspace_refs[workspace_id] > 0
diff --git a/src-server/src/agent/skills/materializer.py b/src-server/src/agent/skills/materializer.py
index 844e9002..7482461d 100644
--- a/src-server/src/agent/skills/materializer.py
+++ b/src-server/src/agent/skills/materializer.py
@@ -32,7 +32,7 @@ async def get_skill_dir(skill: skill_schemas.SkillRead) -> Path:
return skill_dir
@classmethod
- async def materialize_skill(cls, skill: skill_schemas.SkillRead) -> Path:
+ async def materialize(cls, skill: skill_schemas.SkillRead) -> Path:
"""
Materialize a skill to a temporary directory and return the directory absolute path.
"""
@@ -57,7 +57,7 @@ async def materialize_skill(cls, skill: skill_schemas.SkillRead) -> Path:
return skill_dir
@classmethod
- async def materialize_skills(cls):
+ async def materialize_all(cls):
async with db_context() as db_session:
skills = await SkillService(db_session).get_all_skills()
@@ -65,17 +65,18 @@ async def materialize_skills(cls):
async def sem_materialize(skill: skill_models.Skill):
async with sem:
skill_read = skill_schemas.SkillRead.model_validate(skill)
- await cls.clear_materialized(skill_read)
- await cls.materialize_skill(skill_read)
+ await cls.clear_materialized(skill_read.id)
+ await cls.materialize(skill_read)
+
tasks = [sem_materialize(skill) for skill in skills]
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, BaseException):
- cls._logger.opt(exception=result).warning("Failed to materialize skill")
+ cls._logger.opt(exception=result).warning("Failed to materialize skills")
@classmethod
- async def clear_materialized(cls, skill: skill_schemas.SkillRead):
+ async def clear_materialized(cls, skill_id: int):
skills_dir = await cls.get_skills_dir()
- skill_dir = skills_dir / str(skill.id)
+ skill_dir = skills_dir / str(skill_id)
if not await skill_dir.exists(): return
await asyncio.to_thread(shutil.rmtree, skill_dir)
diff --git a/src-server/src/agent/task/__init__.py b/src-server/src/agent/task/__init__.py
index 7984aaac..e986d1a9 100644
--- a/src-server/src/agent/task/__init__.py
+++ b/src-server/src/agent/task/__init__.py
@@ -1,10 +1,12 @@
import asyncio
from loguru import logger
from dais_sdk.types import ToolMessage, AssistantMessage
+from src.schemas import workspace as workspace_schemas
from src.schemas.tasks import runtime as task_runtime_schemas
from .message_manager import MessageManager, MessageNotFoundError
from .tool_call_manager import ToolCallManager
from .llm_request_manager import LlmRequestManager
+from ..notes import NoteWatcher
from ..context import AgentContext
from ..types import (
AgentGenerator, StopReason,
@@ -37,72 +39,84 @@ def messages(self) -> MessageManager:
def tool_calls(self) -> ToolCallManager:
return self._tool_call_manager
+ @property
+ def workspace(self) -> workspace_schemas.WorkspaceRead:
+ return self._ctx._resource.workspace
+
async def run(self) -> AgentGenerator:
_exited_by_generator_close = False
retries = 0
max_retries = 3
- try:
- while self._is_running:
- last_chunk: MessageEndEvent | TaskInterruptedEvent | ErrorEvent | None = None
- try:
- llm_stream = self._llm_request_manager.create_llm_call()
- async for chunk in llm_stream:
- if isinstance(chunk, self._llm_request_manager.FINISHING_CHUNK_TYPE):
- last_chunk = chunk
- continue
- yield chunk
- except asyncio.CancelledError:
- # Task cancelled by user
- break
-
- match last_chunk:
- case MessageEndEvent() as message_end_chunk:
- retries = 0
- yield message_end_chunk
- case ErrorEvent(error=error, retryable=retryable) as error_chunk:
- self._logger.warning(f"LLM provider error: {error}")
- if not retryable or retries >= max_retries:
- yield error_chunk
- break
- retries += 1
- continue # retry
- case TaskInterruptedEvent() as interrupted_chunk:
- yield interrupted_chunk
- break
- case _ as chunk:
- self._logger.warning(f"Unexpected message event: {chunk}")
+ async with NoteWatcher(self.workspace.id):
+ try:
+ tail_tool_calls = list(self.messages.tail_tool_messages_iter())
+ dispatch_stream, _ = self.tool_calls.dispatch(tail_tool_calls)
+ async for event in dispatch_stream:
+ yield event
+ has_pending_tool_calls = len(self.tool_calls.collect_pendings()) > 0
+ if has_pending_tool_calls: return
+
+ while self._is_running:
+ last_chunk: MessageEndEvent | TaskInterruptedEvent | ErrorEvent | None = None
+ try:
+ llm_stream = self._llm_request_manager.create_llm_call()
+ async for chunk in llm_stream:
+ if isinstance(chunk, self._llm_request_manager.FINISHING_CHUNK_TYPE):
+ last_chunk = chunk
+ continue
+ yield chunk
+ except asyncio.CancelledError:
+ # Task cancelled by user
break
- assistant_message = last_chunk.message
- if (assistant_message.content == None and
- assistant_message.reasoning_content == None and
- (assistant_message.tool_calls is None or len(assistant_message.tool_calls) == 0)):
- # empty message, retry
- continue
+ match last_chunk:
+ case MessageEndEvent() as message_end_chunk:
+ retries = 0
+ yield message_end_chunk
+ case ErrorEvent(error=error, retryable=retryable) as error_chunk:
+ self._logger.warning(f"LLM provider error: {error}")
+ if not retryable or retries >= max_retries:
+ yield error_chunk
+ break
+ retries += 1
+ continue # retry
+ case TaskInterruptedEvent() as interrupted_chunk:
+ yield interrupted_chunk
+ break
+ case _ as chunk:
+ self._logger.warning(f"Unexpected message event: {chunk}")
+ break
- self._ctx.messages.append(assistant_message)
- tool_call_messages = self._extract_tool_call(assistant_message)
- if tool_call_messages is None:
- break
+ assistant_message = last_chunk.message
+ if (assistant_message.content == None and
+ assistant_message.reasoning_content == None and
+ (assistant_message.tool_calls is None or len(assistant_message.tool_calls) == 0)):
+ # empty message, retry
+ continue
- for message in tool_call_messages:
- self._ctx.messages.append(message)
- yield ToolCallEndEvent(message=message)
+ self._ctx.messages.append(assistant_message)
+ tool_call_messages = self._extract_tool_call(assistant_message)
+ if tool_call_messages is None:
+ break
- dispatch_stream, dispatch_result =\
- self._tool_call_manager.dispatch(tool_call_messages)
- async for event in dispatch_stream:
- yield event
- if (dispatch_result.has_finished_task or
- dispatch_result.has_blocked_tool_calls):
- self._is_running = False
- break
- except GeneratorExit:
- _exited_by_generator_close = True
- finally:
- if not _exited_by_generator_close:
- yield TaskDoneEvent()
+ for message in tool_call_messages:
+ self._ctx.messages.append(message)
+ yield ToolCallEndEvent(message=message)
+
+ dispatch_stream, dispatch_result =\
+ self._tool_call_manager.dispatch(tool_call_messages)
+ async for event in dispatch_stream:
+ yield event
+ if (dispatch_result.has_finished_task or
+ dispatch_result.has_blocked_tool_calls):
+ self._is_running = False
+ break
+ except GeneratorExit:
+ _exited_by_generator_close = True
+ finally:
+ if not _exited_by_generator_close:
+ yield TaskDoneEvent()
async def run_until_done(self) -> StopReason:
async for event in self.run():
diff --git a/src-server/src/agent/tool/builtin_tools/file_system.py b/src-server/src/agent/tool/builtin_tools/file_system.py
index 8093d694..7d386dfd 100644
--- a/src-server/src/agent/tool/builtin_tools/file_system.py
+++ b/src-server/src/agent/tool/builtin_tools/file_system.py
@@ -16,8 +16,6 @@
from src.settings import use_app_setting_manager
from src.utils import MarkdownConverter
from ..toolset_wrapper import built_in_tool, BuiltInToolset, BuiltInToolsetContext, BuiltInToolDefaults
-from ...notes import NoteManager
-from ...skills import SkillMaterializer
from ...prompts import create_one_turn_llm, SemanticFileAnalysis, SemanticFileAnalysisInput
@@ -46,10 +44,13 @@ class FileSystemToolset(BuiltInToolset):
def __init__(self,
ctx: BuiltInToolsetContext,
toolset_ent: toolset_models.Toolset | None = None):
+ from ...notes import NoteMaterializer
+ from ...skills import SkillMaterializer
+
super().__init__(ctx, toolset_ent)
self._path_expander = PathExpander({
**SkillMaterializer.get_skill_dir_env(),
- **NoteManager.get_notes_dir_env(ctx.workspace_id),
+ **NoteMaterializer.get_notes_dir_env(ctx.workspace_id),
})
self._markdown_converter = MarkdownConverter()
diff --git a/src-server/src/agent/tool/builtin_tools/os_interactions.py b/src-server/src/agent/tool/builtin_tools/os_interactions.py
index 443bf579..000ae5c1 100644
--- a/src-server/src/agent/tool/builtin_tools/os_interactions.py
+++ b/src-server/src/agent/tool/builtin_tools/os_interactions.py
@@ -5,23 +5,24 @@
from itertools import islice
from dais_shell import AgentShell, CommandStep
from dais_shell.iostream_reader import IOStreamBuffer
-from src.agent.skills import SkillMaterializer
from src.db.models import toolset as toolset_models
from src.binaries import UV_PATH, NODE_PATH
from ..toolset_wrapper import built_in_tool, BuiltInToolset, BuiltInToolsetContext
-from ...notes import NoteManager
class OsInteractionsToolset(BuiltInToolset):
def __init__(self,
ctx: BuiltInToolsetContext,
toolset_ent: toolset_models.Toolset | None = None):
+ from ...notes import NoteMaterializer
+ from ...skills import SkillMaterializer
+
super().__init__(ctx, toolset_ent)
self._shell = AgentShell(
extra_paths=[str(NODE_PATH.parent), str(UV_PATH.parent)],
extra_env={
**SkillMaterializer.get_skill_dir_env(),
- **NoteManager.get_notes_dir_env(ctx.workspace_id),
+ **NoteMaterializer.get_notes_dir_env(ctx.workspace_id),
}
)
diff --git a/src-server/src/agent/tool/toolset_manager/builtin_toolset_manager.py b/src-server/src/agent/tool/toolset_manager/builtin_toolset_manager.py
index a30c31b2..0446d94c 100644
--- a/src-server/src/agent/tool/toolset_manager/builtin_toolset_manager.py
+++ b/src-server/src/agent/tool/toolset_manager/builtin_toolset_manager.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Callable, Sequence, override
+from typing import Sequence, override
from dais_sdk.tool import Toolset
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import db_context
@@ -7,27 +7,16 @@
from ..toolset_wrapper import BuiltInToolset, BuiltInToolsetContext
from ..builtin_tools import BUILT_IN_TOOLSETS
-if TYPE_CHECKING:
- from ...context import ToolRuntimeContext
-
class BuiltinToolsetManager(ToolsetManager):
- def __init__(self, workspace_id: int, cwd: str, runtime_getter: Callable[[], ToolRuntimeContext]):
- self._ctx = BuiltInToolsetContext(workspace_id, cwd, runtime_getter)
+ def __init__(self, workspace_id: int, cwd: str):
+ self._ctx = BuiltInToolsetContext(workspace_id, cwd)
self._toolset_map: dict[str, toolset_models.Toolset] | None = None
self._toolsets: list[BuiltInToolset] | None = None
@classmethod
def default(cls):
- """Create a manager for static built-in tool metadata export only.
-
- The returned manager is intended for metadata-only operations, such as exporting
- built-in tool definitions before a real agent runtime context exists. Any access
- to runtime-backed context fields from this instance is considered invalid.
- """
- def runtime_getter():
- raise ValueError("ToolRuntimeContext is unavailable in BuiltinToolsetManager.default()")
- return cls(1, "~", runtime_getter)
+ return cls(1, "~")
@staticmethod
async def sync_toolsets(db_session: AsyncSession):
@@ -60,7 +49,7 @@ async def initialize(self):
self._toolsets.append(toolset_t(self._ctx, toolset_ent))
@classmethod
- async def create(cls, workspace_id: int, cwd: str, tool_runtime_context: ToolRuntimeContext) -> BuiltinToolsetManager:
- manager = cls(workspace_id, cwd, lambda: tool_runtime_context)
+ async def create(cls, workspace_id: int, cwd: str) -> BuiltinToolsetManager:
+ manager = cls(workspace_id, cwd)
await manager.initialize()
return manager
diff --git a/src-server/src/agent/tool/toolset_wrapper/built_in_toolset.py b/src-server/src/agent/tool/toolset_wrapper/built_in_toolset.py
index 58470a7a..03c7d1d9 100644
--- a/src-server/src/agent/tool/toolset_wrapper/built_in_toolset.py
+++ b/src-server/src/agent/tool/toolset_wrapper/built_in_toolset.py
@@ -1,15 +1,12 @@
from dataclasses import replace
from pathlib import Path
-from typing import Callable, Self, cast, override, TYPE_CHECKING, TypedDict
+from typing import Self, cast, override, TYPE_CHECKING, TypedDict
from dais_sdk.tool import PythonToolset, python_tool
from dais_sdk.types import ToolDef
from sqlalchemy.ext.asyncio import AsyncSession
from ..types import ToolMetadata
-from ...notes import NoteManager
-from ...types import ContextUsage
if TYPE_CHECKING:
- from ...context import ToolRuntimeContext
from ....db.models import toolset as toolset_models
@@ -22,28 +19,19 @@ class BuiltInToolDefaults(TypedDict, total=False):
needs_user_interaction: bool
class BuiltInToolsetContext:
- def __init__(self, workspace_id: int, cwd: str | Path, runtime_getter: Callable[[], ToolRuntimeContext]):
+ def __init__(self, workspace_id: int, cwd: str | Path):
self.workspace_id = workspace_id
self.cwd = Path(cwd).expanduser().resolve()
- self._runtime_getter = runtime_getter
-
- @property
- def usage(self) -> ContextUsage: return self._runtime_getter().usage
-
- @property
- def note_manager(self) -> NoteManager: return self._runtime_getter().note_manager
@classmethod
def default(cls) -> Self:
"""Create a context for static tool metadata export without runtime state.
The returned context is only intended for code paths that inspect built-in tool
- definitions, such as tool metadata synchronization. Runtime-only properties like
- usage and note_manager are intentionally unavailable on this instance.
+ definitions, such as tool metadata synchronization. Runtime-only properties are
+ intentionally unavailable on this instance.
"""
- def runtime_getter():
- raise ValueError("ToolRuntimeContext is unavailable in BuiltInToolsetContext.default()")
- return cls(1, Path.cwd(), runtime_getter)
+ return cls(1, Path.cwd())
class BuiltInToolset(PythonToolset):
def __init__(self,
diff --git a/src-server/src/api/exceptions.py b/src-server/src/api/exceptions.py
index 87c66778..0a51279a 100644
--- a/src-server/src/api/exceptions.py
+++ b/src-server/src/api/exceptions.py
@@ -14,6 +14,8 @@ class ApiErrorCode(StrEnum):
TASK_RESOURCE_NOT_FOUND = "TASK_RESOURCE_NOT_FOUND"
TASK_RESOURCE_SHOULD_HAVE_FILENAME_AND_CONTENTTYPE = "TASK_RESOURCE_SHOULD_HAVE_FILENAME_AND_CONTENTTYPE"
+ WORKSPACE_NOTES_LOCKED_BY_RUNNING_TASK = "WORKSPACE_NOTES_LOCKED_BY_RUNNING_TASK"
+
INVALID_SKILL_ARCHIVE = "INVALID_SKILL_ARCHIVE"
SUMMARIZE_TITLE_FAILED = "SUMMARIZE_TITLE_FAILED"
diff --git a/src-server/src/api/lifespan.py b/src-server/src/api/lifespan.py
index 6f3491d4..d417f954 100644
--- a/src-server/src/api/lifespan.py
+++ b/src-server/src/api/lifespan.py
@@ -5,6 +5,7 @@
from typing import Coroutine, TypedDict
from fastapi import FastAPI
from src.agent.skills import SkillMaterializer
+from src.agent.notes import NoteMaterializer
from src.agent.task.schedule_runner import init_schedule_runner
from src.agent.tool import BuiltinToolsetManager, McpToolsetManager, use_mcp_toolset_manager
from src.db import engine as database_engine, db_context
@@ -47,7 +48,8 @@ def __init__(self):
self.background_task_manager = BackgroundTaskManager()
async def _init_resources(self):
- self.background_task_manager.add_task(SkillMaterializer.materialize_skills())
+ self.background_task_manager.add_task(SkillMaterializer.materialize_all())
+ self.background_task_manager.add_task(NoteMaterializer.materialize_all())
self.background_task_manager.add_task(self.schedule_runner.load_schedules())
self.background_task_manager.add_task(self._clear_unused_cache())
diff --git a/src-server/src/api/routes/skill.py b/src-server/src/api/routes/skill.py
index 06178c0c..2847a3f1 100644
--- a/src-server/src/api/routes/skill.py
+++ b/src-server/src/api/routes/skill.py
@@ -42,8 +42,8 @@ def process_archive(file_obj: IO[bytes]) -> skill_schemas.SkillCreate:
def start_materializing_background_task(background_tasks: BackgroundTasks, skill_ent: skill_models.Skill):
skill_data = skill_schemas.SkillRead.model_validate(skill_ent)
async def clear_and_rematerialize(skill: skill_schemas.SkillRead):
- await SkillMaterializer.clear_materialized(skill)
- await SkillMaterializer.materialize_skill(skill)
+ await SkillMaterializer.clear_materialized(skill.id)
+ await SkillMaterializer.materialize(skill)
background_tasks.add_task(clear_and_rematerialize, skill_data)
@skills_router.get("/", response_model=Page[skill_schemas.SkillBrief])
@@ -99,5 +99,4 @@ async def delete_skill(
background_tasks: BackgroundTasks,
):
deleted_skill = await SkillService(db_session).delete_skill(skill_id)
- skill_data = skill_schemas.SkillRead.model_validate(deleted_skill)
- background_tasks.add_task(SkillMaterializer.clear_materialized, skill_data)
+ background_tasks.add_task(SkillMaterializer.clear_materialized, deleted_skill.id)
diff --git a/src-server/src/api/routes/tasks/schedule.py b/src-server/src/api/routes/tasks/schedule.py
index 2b8408c3..f4318473 100644
--- a/src-server/src/api/routes/tasks/schedule.py
+++ b/src-server/src/api/routes/tasks/schedule.py
@@ -4,7 +4,9 @@
from src.agent.task.schedule_runner import use_schedule_runner
from src.db import DbSessionDep
from src.schemas.tasks import schedule as schedule_schemas
+from src.services.tasks import TaskResourceService
from src.services.schedule import RunRecordService, ScheduleService
+from src.schemas.tasks import runtime as task_runtime_schemas
schedule_manage_router = APIRouter(tags=["schedule"])
@@ -70,3 +72,4 @@ async def get_run_record(run_record_id: int, db_session: DbSessionDep):
@schedule_manage_router.delete("/records/{run_record_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_run_record(run_record_id: int, db_session: DbSessionDep):
await RunRecordService(db_session).delete_run_record(run_record_id)
+ await TaskResourceService(db_session, task_runtime_schemas.TaskType.SCHEDULE).delete_task_resources(run_record_id)
diff --git a/src-server/src/api/routes/tasks/stream.py b/src-server/src/api/routes/tasks/stream.py
index 3ba7b37b..995a66d0 100644
--- a/src-server/src/api/routes/tasks/stream.py
+++ b/src-server/src/api/routes/tasks/stream.py
@@ -70,23 +70,6 @@ async def continue_task(
body: ContinueTaskBody,
request: Request,
):
- """
- Directly continue a existing task
- """
task = await create_agent_task(task_type, task_id, body.agent_id)
-
- # ensure all approved tool calls are executed before continuing
- try:
- tail_tool_calls = list(task.messages.tail_tool_messages_iter())
- dispatch_stream, _ = task.tool_calls.dispatch(tail_tool_calls)
- async for event in dispatch_stream:
- yield event
- finally:
- await asyncio.shield(task.persist())
-
- if len(task.tool_calls.collect_pendings()) > 0:
- # prevent starting agent loop when there are still unresolved tool calls
- yield TaskDoneEvent()
- return
async for event in stream_connector(task, request):
yield event
diff --git a/src-server/src/api/routes/tasks/task.py b/src-server/src/api/routes/tasks/task.py
index a1f2bcc3..ed672af4 100644
--- a/src-server/src/api/routes/tasks/task.py
+++ b/src-server/src/api/routes/tasks/task.py
@@ -8,8 +8,11 @@
from src.db import DbSessionDep
from src.db.models import agent as agent_models
from src.db.models import tasks as task_models
-from src.services.tasks import TaskService
-from src.schemas.tasks import task as task_schemas
+from src.services.tasks import TaskResourceService, TaskService
+from src.schemas.tasks import (
+ task as task_schemas,
+ runtime as task_runtime_schemas,
+)
from ...exceptions import ApiError, ApiErrorCode
@@ -90,3 +93,4 @@ async def summarize_task_title(task_id: int, db_session: DbSessionDep):
@task_manage_router.delete("/{task_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_task(task_id: int, db_session: DbSessionDep):
await TaskService(db_session).delete_task(task_id)
+ await TaskResourceService(db_session, task_runtime_schemas.TaskType.TASK).delete_task_resources(task_id)
diff --git a/src-server/src/api/routes/workspace.py b/src-server/src/api/routes/workspace.py
index 8c4b008e..f35c8854 100644
--- a/src-server/src/api/routes/workspace.py
+++ b/src-server/src/api/routes/workspace.py
@@ -2,9 +2,11 @@
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlalchemy import apaginate
from src.db import DbSessionDep
+from src.agent.notes import NoteMaterializer, WorkspaceRefManager
from src.services.workspace import WorkspaceService
from src.schemas import workspace as workspace_schemas
from src.utils.open_in_file_manager import open_in_file_manager
+from ..exceptions import ApiError, ApiErrorCode
workspaces_router = APIRouter(tags=["workspace"])
@@ -23,7 +25,10 @@ async def create_workspace(
db_session: DbSessionDep,
body: workspace_schemas.WorkspaceCreate,
):
- return await WorkspaceService(db_session).create_workspace(body)
+ created_workspace = await WorkspaceService(db_session).create_workspace(body)
+ workspace = workspace_schemas.WorkspaceRead.model_validate(created_workspace)
+ await NoteMaterializer.materialize(workspace)
+ return workspace
@workspaces_router.put("/{workspace_id}", response_model=workspace_schemas.WorkspaceRead)
async def update_workspace(
@@ -33,9 +38,24 @@ async def update_workspace(
):
return await WorkspaceService(db_session).update_workspace(workspace_id, body)
+@workspaces_router.put("/{workspace_id}/notes", response_model=workspace_schemas.WorkspaceRead)
+async def update_workspace_notes(
+ workspace_id: int,
+ body: workspace_schemas.WorkspaceNotesUpdate,
+ db_session: DbSessionDep,
+):
+ if WorkspaceRefManager.is_workspace_in_use(workspace_id):
+ raise ApiError(status.HTTP_409_CONFLICT, ApiErrorCode.WORKSPACE_NOTES_LOCKED_BY_RUNNING_TASK)
+ updated_workspace = await WorkspaceService(db_session).update_workspace_notes(workspace_id, body)
+ workspace = workspace_schemas.WorkspaceRead.model_validate(updated_workspace)
+ await NoteMaterializer.clear_materialized(workspace.id)
+ await NoteMaterializer.materialize(workspace)
+ return workspace
+
@workspaces_router.delete("/{workspace_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_workspace(workspace_id: int, db_session: DbSessionDep):
await WorkspaceService(db_session).delete_workspace(workspace_id)
+ await NoteMaterializer.clear_materialized(workspace_id)
# --- --- --- --- --- ---
diff --git a/src-server/src/logger.py b/src-server/src/logger.py
index 53a30407..1e261918 100644
--- a/src-server/src/logger.py
+++ b/src-server/src/logger.py
@@ -39,6 +39,7 @@ def setup_logging(log_level: int):
NOISY_LIBS = (
"aiosqlite",
+ "apscheduler",
"binaryornot",
"httpcore",
"httpx",
diff --git a/src-server/src/schemas/workspace.py b/src-server/src/schemas/workspace.py
index 7774b5b0..c0a706dd 100644
--- a/src-server/src/schemas/workspace.py
+++ b/src-server/src/schemas/workspace.py
@@ -37,7 +37,9 @@ class WorkspaceUpdate(DTOBase):
name: str | None
directory: str | None
instruction: str | None
- notes: list[WorkspaceNoteBase] | None
usable_agent_ids: list[int] | None
usable_tool_ids: list[int] | None
usable_skill_ids: list[int] | None = None
+
+class WorkspaceNotesUpdate(DTOBase):
+ notes: list[WorkspaceNoteBase]
diff --git a/src-server/src/services/agent.py b/src-server/src/services/agent.py
index d03eef8e..8184ea4a 100644
--- a/src-server/src/services/agent.py
+++ b/src-server/src/services/agent.py
@@ -5,7 +5,6 @@
from src.schemas import agent as agent_schemas
from .service_base import ServiceBase
from .exceptions import NotFoundError, ServiceErrorCode
-from .utils import build_load_options, Relations
class AgentNotFoundError(NotFoundError):
@@ -14,10 +13,10 @@ def __init__(self, agent_id: int) -> None:
class AgentService(ServiceBase):
@staticmethod
- def relations() -> Relations:
+ def relations():
return [
- agent_models.Agent.model,
- agent_models.Agent.usable_tools,
+ selectinload(agent_models.Agent.model),
+ selectinload(agent_models.Agent.usable_tools),
]
def get_agents_query(self):
@@ -30,7 +29,7 @@ def get_agents_query(self):
async def get_agent_by_id(self, id: int) -> agent_models.Agent:
agent = await self._db_session.get(
agent_models.Agent, id,
- options=build_load_options(self.relations()),
+ options=self.relations(),
)
if not agent:
raise AgentNotFoundError(id)
diff --git a/src-server/src/services/provider.py b/src-server/src/services/provider.py
index f500056d..746047e8 100644
--- a/src-server/src/services/provider.py
+++ b/src-server/src/services/provider.py
@@ -1,10 +1,10 @@
from loguru import logger
from sqlalchemy import select
+from sqlalchemy.orm import selectinload
from src.db.models import provider as provider_models
from src.schemas import provider as provider_schemas
from .service_base import ServiceBase
from .exceptions import NotFoundError, ServiceErrorCode
-from .utils import build_load_options, Relations
_logger = logger.bind(name="ProviderService")
@@ -15,15 +15,15 @@ def __init__(self, provider_id: int) -> None:
class ProviderService(ServiceBase):
@staticmethod
- def relations() -> Relations:
+ def relations():
return [
- provider_models.Provider.models,
+ selectinload(provider_models.Provider.models),
]
def get_providers_query(self):
return (
select(provider_models.Provider)
- .options(*build_load_options(self.relations()))
+ .options(*self.relations())
.order_by(provider_models.Provider.id.asc())
)
@@ -36,7 +36,7 @@ async def get_provider_by_id(self, provider_id: int) -> provider_models.Provider
provider = await self._db_session.get(
provider_models.Provider,
provider_id,
- options=build_load_options(self.relations()),
+ options=self.relations(),
)
if not provider:
raise ProviderNotFoundError(provider_id)
diff --git a/src-server/src/services/schedule.py b/src-server/src/services/schedule.py
index 6aaaab6b..15af94f5 100644
--- a/src-server/src/services/schedule.py
+++ b/src-server/src/services/schedule.py
@@ -1,10 +1,10 @@
from dais_sdk.types import UserMessage
from sqlalchemy import select
+from sqlalchemy.orm import selectinload
from src.db.models import tasks as task_models
from src.schemas.tasks import schedule as schedule_schemas
from .service_base import ServiceBase
from .exceptions import NotFoundError, ServiceErrorCode
-from .utils import build_load_options, Relations
class ScheduleNotFoundError(NotFoundError):
@@ -13,10 +13,10 @@ def __init__(self, schedule_id: int) -> None:
class ScheduleService(ServiceBase):
@staticmethod
- def relations() -> Relations:
+ def relations():
return [
- task_models.Schedule.agent,
- task_models.Schedule.workspace,
+ selectinload(task_models.Schedule.agent),
+ selectinload(task_models.Schedule.workspace),
]
def get_all_schedules_query(self):
@@ -43,7 +43,7 @@ async def get_schedule_by_id(self, id: int) -> task_models.Schedule:
schedule = await self._db_session.get(
task_models.Schedule,
id,
- options=build_load_options(self.relations()),
+ options=self.relations(),
)
if not schedule:
raise ScheduleNotFoundError(id)
@@ -88,9 +88,9 @@ def __init__(self, run_record_id: int) -> None:
class RunRecordService(ServiceBase):
@staticmethod
- def relations() -> Relations:
+ def relations():
return [
- task_models.RunRecord.schedule,
+ selectinload(task_models.RunRecord.schedule),
]
def get_run_records_query(self, schedule_id: int):
@@ -104,7 +104,7 @@ async def get_run_record_by_id(self, id: int) -> task_models.RunRecord:
run_record = await self._db_session.get(
task_models.RunRecord,
id,
- options=build_load_options(self.relations()),
+ options=self.relations(),
)
if not run_record:
raise RunRecordNotFoundError(id)
diff --git a/src-server/src/services/skill.py b/src-server/src/services/skill.py
index 55ed9c93..89bdddda 100644
--- a/src-server/src/services/skill.py
+++ b/src-server/src/services/skill.py
@@ -4,7 +4,6 @@
from src.schemas import skill as skill_schemas
from .service_base import ServiceBase
from .exceptions import NotFoundError, ConflictError, ServiceErrorCode
-from .utils import build_load_options, Relations
class SkillNotFoundError(NotFoundError):
@@ -20,8 +19,10 @@ def __init__(self, name: str) -> None:
class SkillService(ServiceBase):
@staticmethod
- def relations() -> Relations:
- return [skill_models.Skill.resources]
+ def relations():
+ return [
+ selectinload(skill_models.Skill.resources)
+ ]
def get_skills_query(self):
return (
@@ -34,7 +35,7 @@ async def get_all_skills(self) -> list[skill_models.Skill]:
stmt = (
select(skill_models.Skill)
.order_by(skill_models.Skill.id.asc())
- .options(*build_load_options(self.relations()))
+ .options(*self.relations())
)
skills = (await self._db_session.scalars(stmt)).all()
return list(skills)
@@ -43,7 +44,7 @@ async def get_skill_by_id(self, id: int) -> skill_models.Skill:
skill = await self._db_session.get(
skill_models.Skill,
id,
- options=build_load_options(self.relations()),
+ options=self.relations(),
)
if not skill:
raise SkillNotFoundError(id)
@@ -53,7 +54,7 @@ async def get_skill_by_name(self, name: str) -> skill_models.Skill:
stmt = (
select(skill_models.Skill)
.where(skill_models.Skill.name == name)
- .options(*build_load_options(self.relations()))
+ .options(*self.relations())
)
skill = await self._db_session.scalar(stmt)
if not skill:
diff --git a/src-server/src/services/tasks/schedule.py b/src-server/src/services/tasks/schedule.py
index a8b3e1ca..727c0b0c 100644
--- a/src-server/src/services/tasks/schedule.py
+++ b/src-server/src/services/tasks/schedule.py
@@ -1,11 +1,11 @@
from sqlalchemy import select
+from sqlalchemy.orm import selectinload
from src.db.models import tasks as task_models
from src.schemas.tasks import schedule as schedule_schemas
from src.schemas.tasks import runtime as task_runtime_schemas
from .resource import TaskResourceService
from ..service_base import ServiceBase
from ..exceptions import NotFoundError, ServiceErrorCode
-from ..utils import build_load_options, Relations
class ScheduleNotFoundError(NotFoundError):
@@ -14,10 +14,10 @@ def __init__(self, schedule_id: int) -> None:
class ScheduleService(ServiceBase):
@staticmethod
- def relations() -> Relations:
+ def relations():
return [
- task_models.Schedule.agent,
- task_models.Schedule.workspace,
+ selectinload(task_models.Schedule.agent),
+ selectinload(task_models.Schedule.workspace),
]
def get_all_schedules_query(self):
@@ -44,7 +44,7 @@ async def get_schedule_by_id(self, id: int) -> task_models.Schedule:
schedule = await self._db_session.get(
task_models.Schedule,
id,
- options=build_load_options(self.relations()),
+ options=self.relations(),
)
if not schedule:
raise ScheduleNotFoundError(id)
@@ -89,9 +89,9 @@ def __init__(self, run_record_id: int) -> None:
class RunRecordService(ServiceBase):
@staticmethod
- def relations() -> Relations:
+ def relations():
return [
- task_models.RunRecord.schedule,
+ selectinload(task_models.RunRecord.schedule),
]
def get_run_records_query(self, schedule_id: int):
@@ -105,7 +105,7 @@ async def get_run_record_by_id(self, id: int) -> task_models.RunRecord:
run_record = await self._db_session.get(
task_models.RunRecord,
id,
- options=build_load_options(self.relations()),
+ options=self.relations(),
)
if not run_record:
raise RunRecordNotFoundError(id)
diff --git a/src-server/src/services/tasks/task.py b/src-server/src/services/tasks/task.py
index fba96cee..0608ee09 100644
--- a/src-server/src/services/tasks/task.py
+++ b/src-server/src/services/tasks/task.py
@@ -1,11 +1,9 @@
from sqlalchemy import select
+from sqlalchemy.orm import selectinload
from src.db.models import tasks as task_models
from src.schemas.tasks import task as task_schemas
-from src.schemas.tasks import runtime as task_runtime_schemas
from ..service_base import ServiceBase
from ..exceptions import NotFoundError, ServiceErrorCode
-from ..utils import build_load_options, Relations
-from .resource import TaskResourceService
class TaskNotFoundError(NotFoundError):
@@ -15,10 +13,10 @@ def __init__(self, task_id: int) -> None:
class TaskService(ServiceBase):
@staticmethod
- def relations() -> Relations:
+ def relations():
return [
- task_models.Task.agent,
- task_models.Task.workspace,
+ selectinload(task_models.Task.agent),
+ selectinload(task_models.Task.workspace),
]
def get_tasks_query(self, workspace_id: int):
@@ -40,7 +38,7 @@ def get_recent_tasks_query(self):
async def get_task_by_id(self, id: int) -> task_models.Task:
task = await self._db_session.get(
task_models.Task, id,
- options=build_load_options(self.relations()),
+ options=self.relations(),
)
if not task:
raise TaskNotFoundError(id)
@@ -74,6 +72,5 @@ async def update_task(self, id: int, data: task_schemas.TaskUpdate) -> task_mode
async def delete_task(self, id: int) -> None:
task = await self.get_task_by_id(id)
- await TaskResourceService(self._db_session, task_runtime_schemas.TaskType.TASK).delete_task_resources(id)
await self._db_session.delete(task)
await self._db_session.flush()
diff --git a/src-server/src/services/toolset.py b/src-server/src/services/toolset.py
index 10ef60e7..48f740f6 100644
--- a/src-server/src/services/toolset.py
+++ b/src-server/src/services/toolset.py
@@ -1,11 +1,11 @@
from typing import NamedTuple
from dais_sdk.types import ToolDef
from sqlalchemy import select
+from sqlalchemy.orm import selectinload
from src.db.models import toolset as toolset_models
from src.schemas import toolset as toolset_schemas
from .service_base import ServiceBase
from .exceptions import NotFoundError, ConflictError, ServiceErrorCode
-from .utils import build_load_options, Relations
class ToolsetNotFoundError(NotFoundError):
@@ -28,9 +28,9 @@ class ToolLike(NamedTuple):
auto_approve: bool = False
@staticmethod
- def relations() -> Relations:
+ def relations():
return [
- toolset_models.Toolset.tools,
+ selectinload(toolset_models.Toolset.tools),
]
async def get_all_mcp_toolsets(self) -> list[toolset_models.Toolset]:
@@ -44,7 +44,7 @@ async def get_all_mcp_toolsets(self) -> list[toolset_models.Toolset]:
]
)
)
- .options(*build_load_options(self.relations()))
+ .options(*self.relations())
)
toolsets = (await self._db_session.scalars(stmt)).all()
return list(toolsets)
@@ -53,7 +53,7 @@ async def get_all_built_in_toolsets(self) -> list[toolset_models.Toolset]:
stmt = (
select(toolset_models.Toolset)
.where(toolset_models.Toolset.type == toolset_models.ToolsetType.BUILT_IN)
- .options(*build_load_options(self.relations()))
+ .options(*self.relations())
)
toolsets = (await self._db_session.scalars(stmt)).all()
return list(toolsets)
@@ -62,7 +62,7 @@ async def get_toolset_by_id(self, id: int) -> toolset_models.Toolset:
toolset = await self._db_session.get(
toolset_models.Toolset,
id,
- options=build_load_options(self.relations()),
+ options=self.relations(),
)
if not toolset:
raise ToolsetNotFoundError(id)
@@ -72,7 +72,7 @@ async def get_toolset_by_internal_key(self, internal_key: str) -> toolset_models
stmt = (
select(toolset_models.Toolset)
.where(toolset_models.Toolset.internal_key == internal_key)
- .options(*build_load_options(self.relations()))
+ .options(*self.relations())
)
toolset = await self._db_session.scalar(stmt)
if not toolset:
diff --git a/src-server/src/services/utils.py b/src-server/src/services/utils.py
deleted file mode 100644
index 71cc905b..00000000
--- a/src-server/src/services/utils.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from sqlalchemy.orm import QueryableAttribute, selectinload
-
-type Relations = list[QueryableAttribute]
-
-def build_load_options(
- relations: list[QueryableAttribute],
- root: QueryableAttribute | None = None,
-):
- if root is not None:
- base = selectinload(root)
- return [base.selectinload(rel) for rel in relations]
- else:
- return [selectinload(rel) for rel in relations]
diff --git a/src-server/src/services/workspace.py b/src-server/src/services/workspace.py
index f055902e..a6cb5ffc 100644
--- a/src-server/src/services/workspace.py
+++ b/src-server/src/services/workspace.py
@@ -1,6 +1,5 @@
from sqlalchemy import select
from sqlalchemy.orm import selectinload
-from src.agent.notes.manager import NoteManager
from src.db.models import workspace as workspace_models
from src.db.models import agent as agent_models
from src.db.models import toolset as toolset_models
@@ -9,7 +8,6 @@
from .service_base import ServiceBase
from .exceptions import NotFoundError, ServiceErrorCode
from .agent import AgentService
-from .utils import build_load_options, Relations
class WorkspaceNotFoundError(NotFoundError):
@@ -18,17 +16,20 @@ def __init__(self, workspace_id: int) -> None:
class WorkspaceService(ServiceBase):
@staticmethod
- def relations() -> Relations:
+ def relations():
return [
- workspace_models.Workspace.usable_tools,
- workspace_models.Workspace.usable_agents,
- workspace_models.Workspace.usable_skills,
+ selectinload(workspace_models.Workspace.usable_tools),
+ selectinload(workspace_models.Workspace.usable_agents)
+ .selectinload(agent_models.Agent.model),
+ selectinload(workspace_models.Workspace.usable_skills),
+ selectinload(workspace_models.Workspace.notes),
]
def get_workspaces_query(self):
return (
select(workspace_models.Workspace)
.order_by(workspace_models.Workspace.id.asc())
+ .options(*self.relations())
)
async def _update_relations(
@@ -36,20 +37,11 @@ async def _update_relations(
workspace: workspace_models.Workspace,
data: workspace_schemas.WorkspaceCreate | workspace_schemas.WorkspaceUpdate,
):
- if data.notes is not None:
- workspace.notes = [
- workspace_models.WorkspaceNote(
- relative=note.relative,
- content=note.content,
- )
- for note in data.notes
- ]
-
if data.usable_agent_ids is not None:
stmt = (
select(agent_models.Agent)
.where(agent_models.Agent.id.in_(data.usable_agent_ids))
- .options(*build_load_options(AgentService.relations()))
+ .options(*AgentService.relations())
)
agents = (await self._db_session.scalars(stmt)).all()
workspace.usable_agents = list(agents)
@@ -68,17 +60,20 @@ async def _update_relations(
skills = (await self._db_session.scalars(stmt)).all()
workspace.usable_skills = list(skills)
+ async def get_all_workspaces(self) -> list[workspace_models.Workspace]:
+ stmt = (
+ select(workspace_models.Workspace)
+ .order_by(workspace_models.Workspace.id.asc())
+ .options(*self.relations())
+ )
+ workspaces = (await self._db_session.scalars(stmt)).all()
+ return list(workspaces)
+
async def get_workspace_by_id(self, id: int) -> workspace_models.Workspace:
workspace = await self._db_session.get(
workspace_models.Workspace,
id,
- options=[
- selectinload(workspace_models.Workspace.usable_tools),
- selectinload(workspace_models.Workspace.usable_agents)
- .selectinload(agent_models.Agent.model),
- selectinload(workspace_models.Workspace.usable_skills),
- selectinload(workspace_models.Workspace.notes),
- ],
+ options=self.relations(),
)
if not workspace:
raise WorkspaceNotFoundError(id)
@@ -88,6 +83,13 @@ async def create_workspace(self, data: workspace_schemas.WorkspaceCreate) -> wor
create_data = data.model_dump(exclude={"notes", "usable_agent_ids", "usable_tool_ids", "usable_skill_ids"})
new_workspace = workspace_models.Workspace(**create_data)
+ new_workspace.notes = [
+ workspace_models.WorkspaceNote(
+ relative=note.relative,
+ content=note.content,
+ )
+ for note in data.notes
+ ]
await self._update_relations(new_workspace, data)
self._db_session.add(new_workspace)
@@ -108,8 +110,22 @@ async def update_workspace(self, id: int, data: workspace_schemas.WorkspaceUpdat
updated_workspace = await self.get_workspace_by_id(workspace.id)
return updated_workspace
+ async def update_workspace_notes(self, id, data: workspace_schemas.WorkspaceNotesUpdate):
+ workspace = await self.get_workspace_by_id(id)
+ workspace.notes = [
+ workspace_models.WorkspaceNote(
+ relative=note.relative,
+ content=note.content,
+ )
+ for note in data.notes
+ ]
+ await self._db_session.flush()
+ self._db_session.expunge(workspace)
+
+ updated_workspace = await self.get_workspace_by_id(workspace.id)
+ return updated_workspace
+
async def delete_workspace(self, id: int) -> None:
workspace = await self.get_workspace_by_id(id)
await self._db_session.delete(workspace)
await self._db_session.flush()
- await NoteManager(workspace.id).clear_materialized(force=True)
diff --git a/src-server/src/utils/__init__.py b/src-server/src/utils/__init__.py
index 5a0a55be..00102afa 100644
--- a/src-server/src/utils/__init__.py
+++ b/src-server/src/utils/__init__.py
@@ -1,6 +1,7 @@
+from .directory_watcher import DirectoryWatcher, FileChange
from .markdown_converter import MarkdownConverter
from .get_unique_filename import get_unique_filename
from .open_in_file_manager import open_in_file_manager
from .to_base64_str import to_base64_str
from .parent_watchdog import ParentWatchdog
-from .scheduler import Scheduler
+from .scheduler import Scheduler
\ No newline at end of file
diff --git a/src-server/src/utils/directory_watcher.py b/src-server/src/utils/directory_watcher.py
new file mode 100644
index 00000000..98a7b4a7
--- /dev/null
+++ b/src-server/src/utils/directory_watcher.py
@@ -0,0 +1,116 @@
+import asyncio
+from typing import Callable, Awaitable
+from anyio import Path as AnyioPath
+from pathlib import Path as StdPath
+from loguru import logger
+from watchfiles import awatch, Change as ChangeType
+
+
+type FileChange = tuple[ChangeType, str]
+type ChangeHandler = Callable[[list[FileChange]], Awaitable]
+
+WATCHFILES_SENTINEL = ".watchfiles_sentinel"
+
+class DirectoryWatcher:
+ _logger = logger.bind(name="DirectoryWatcher")
+
+ def __init__(self, dir: AnyioPath, on_changes: ChangeHandler):
+ self._dir = dir
+ self._on_changes = on_changes
+ self._change_event: asyncio.Event | None = None
+ self._stop_event: asyncio.Event | None = None
+ self._watch_task: asyncio.Task | None = None
+
+ async def start(self) -> None:
+ if self._watch_task is not None:
+ await self.stop()
+
+ self._stop_event = asyncio.Event()
+ self._change_event = asyncio.Event()
+ ready_event = asyncio.Event()
+
+ self._watch_task = asyncio.create_task(
+ self._watch_loop(ready_event, self._change_event, self._stop_event))
+
+ # TODO: use the builtin ready_event of awatch after https://github.com/samuelcolvin/watchfiles/pull/356 is merged
+ sentinel = self._dir / WATCHFILES_SENTINEL
+ await sentinel.write_bytes(b"")
+ await sentinel.unlink(missing_ok=True)
+ await ready_event.wait()
+
+ async def stop(self) -> None:
+ if self._stop_event:
+ await self._drain()
+ self._stop_event.set()
+ self._stop_event = None
+
+ if self._watch_task and not self._watch_task.done():
+ try:
+ await asyncio.wait_for(self._watch_task, timeout=2.0)
+ except asyncio.TimeoutError:
+ self._logger.warning("awatch did not exit gracefully, cancelling")
+ self._watch_task.cancel()
+ try:
+ await self._watch_task
+ except asyncio.CancelledError:
+                    # since the watch_task is cancelled programmatically,
+ # we should ignore CancelledError here
+ pass
+ finally:
+ self._watch_task = None
+
+ async def _drain(self) -> None:
+ """
+ Wait until the watcher goes quiet before signalling stop.
+        This ensures that all the changes generated before `stop` are handled.
+ """
+ INITIAL_TIMEOUT_SEC = 0.8
+ EXTRA_TIMEOUT_SEC = 0.4
+
+ change_event = self._change_event
+ if change_event is None:
+ return
+
+ timeout = INITIAL_TIMEOUT_SEC
+ while True:
+ change_event.clear()
+ try:
+ await asyncio.wait_for(change_event.wait(), timeout=timeout)
+ timeout = EXTRA_TIMEOUT_SEC
+ except asyncio.TimeoutError:
+ # Quiet for the full window — safe to stop.
+ break
+ self._change_event = None
+
+ async def _watch_loop(self,
+ ready_event: asyncio.Event,
+ change_event: asyncio.Event,
+ stop_event: asyncio.Event) -> None:
+ if not await self._dir.exists():
+ self._logger.warning(f"Watching directory does not exist: {self._dir}")
+ ready_event.set()
+ return
+
+ try:
+ async for changes in awatch(
+ StdPath(self._dir),
+ stop_event=stop_event,
+ recursive=True,
+ debounce=100,
+ ):
+ filtered: list[FileChange] = []
+ for change_type, path in changes:
+ if WATCHFILES_SENTINEL in path:
+ ready_event.set()
+ continue
+ filtered.append((change_type, path))
+
+ if filtered:
+ change_event.set()
+ await self._on_changes(filtered)
+
+ except asyncio.CancelledError:
+ self._logger.debug(f"Watching cancelled: {self._dir}")
+ raise
+ except Exception:
+ self._logger.exception(f"Unexpected error in watcher: {self._dir}")
diff --git a/src-server/tests/agent/notes/test_manager.py b/src-server/tests/agent/notes/test_manager.py
deleted file mode 100644
index 2ed7b014..00000000
--- a/src-server/tests/agent/notes/test_manager.py
+++ /dev/null
@@ -1,699 +0,0 @@
-"""Tests for the NoteManager class."""
-
-import asyncio
-from pathlib import Path
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from anyio import Path as AnyioPath
-from watchfiles import Change as ChangeType
-
-from src.agent.notes.manager import NoteManager
-from src.db.models import workspace as workspace_models
-
-
-@pytest.fixture(autouse=True)
-def reset_note_manager_ref_counts():
- NoteManager._workspace_ref_counts.clear()
- NoteManager._workspace_ref_lock = None
-
-
-@pytest.fixture
-def note_manager():
- """Return a NoteManager instance for workspace_id=1."""
- return NoteManager(workspace_id=1)
-
-
-@pytest.fixture
-def mock_workspace_with_notes():
- """Return a mock workspace with notes."""
- workspace = MagicMock()
- workspace.id = 1
- workspace.notes = [
- workspace_models.WorkspaceNote(relative="README.md", content="# Hello"),
- workspace_models.WorkspaceNote(relative="docs/guide.md", content="## Guide"),
- ]
- return workspace
-
-
-class TestGetNotesRootDir:
- """Tests for get_notes_root_dir static method."""
-
- @pytest.mark.asyncio
- async def test_creates_directory_if_not_exists(self, tmp_path: Path, monkeypatch):
- """Test that get_notes_root_dir creates the directory if it doesn't exist."""
- notes_root = tmp_path / ".notes"
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
-
- result = await NoteManager.get_notes_root_dir()
-
- assert await result.exists()
- assert result.name == ".notes"
-
- @pytest.mark.asyncio
- async def test_returns_existing_directory(self, tmp_path: Path, monkeypatch):
- """Test that get_notes_root_dir returns existing directory without error."""
- notes_root = tmp_path / ".notes"
- notes_root.mkdir()
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
-
- result = await NoteManager.get_notes_root_dir()
-
- assert result == AnyioPath(notes_root)
-
-
-class TestGetNotesDir:
- """Tests for get_notes_dir method."""
-
- @pytest.mark.asyncio
- async def test_creates_workspace_specific_directory(self, tmp_path: Path, monkeypatch):
- """Test that get_notes_dir creates workspace-specific directory."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=42)
-
- result = await manager.get_notes_dir()
-
- assert await result.exists()
- assert result.name == "42"
- assert result.parent.name == ".notes"
-
-
-class TestGetNotesIndex:
- """Tests for get_notes_index method."""
-
- @pytest.mark.asyncio
- async def test_returns_notes_index_content(self, tmp_path: Path, monkeypatch):
- """Test that get_notes_index returns NOTES.md content when it exists."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=7)
-
- notes_dir = await manager.get_notes_dir()
- notes_index = notes_dir / "NOTES.md"
- await notes_index.write_text("# Workspace notes\n\n- item", "utf-8")
-
- result = await manager.get_notes_index()
-
- assert result == "# Workspace notes\n\n- item"
-
- @pytest.mark.asyncio
- async def test_returns_none_when_notes_index_not_exists(self, tmp_path: Path, monkeypatch):
- """Test that get_notes_index returns None when NOTES.md does not exist."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=8)
-
- result = await manager.get_notes_index()
-
- assert result is None
-
- @pytest.mark.asyncio
- async def test_returns_none_when_reading_notes_index_fails(self, tmp_path: Path, monkeypatch):
- """Test that get_notes_index returns None when NOTES.md cannot be read."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=9)
-
- notes_dir = await manager.get_notes_dir()
- notes_index = notes_dir / "NOTES.md"
- await notes_index.write_text("# Workspace notes", "utf-8")
-
- with patch.object(AnyioPath, "read_text", AsyncMock(side_effect=OSError("read failed"))):
- result = await manager.get_notes_index()
-
- assert result is None
-
-
-class TestNotesDirEnv:
- """Tests for notes_dir_env property."""
-
- def test_returns_env_dict_with_correct_path(self, tmp_path: Path, monkeypatch):
- """Test that notes_dir_env returns correct environment variable dict."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- env = NoteManager.get_notes_dir_env(workspace_id=5)
-
- expected_path = str(tmp_path / ".notes" / "5")
- assert env == {NoteManager.NOTES_DIR_ENVNAME: expected_path}
-
-
-class TestMaterialize:
- """Tests for materialize method."""
-
- @pytest.mark.asyncio
- async def test_materializes_notes_to_files(
- self, tmp_path: Path, monkeypatch, mock_workspace_with_notes
- ):
- """Test that materialize writes notes to the filesystem."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
-
- mock_service = AsyncMock()
- mock_service.get_workspace_by_id.return_value = mock_workspace_with_notes
-
- with patch(
- "src.services.workspace.WorkspaceService", return_value=mock_service
- ):
- with patch("src.agent.notes.manager.db_context") as mock_db_context:
- mock_db_context.return_value.__aenter__ = AsyncMock(
- return_value=AsyncMock()
- )
- mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
-
- manager = NoteManager(workspace_id=1)
- result = await manager.materialize()
-
- # Check files were created
- readme_path = result / "README.md"
- guide_path = result / "docs" / "guide.md"
-
- assert await readme_path.exists()
- assert await guide_path.exists()
- assert await readme_path.read_text("utf-8") == "# Hello"
- assert await guide_path.read_text("utf-8") == "## Guide"
-
- @pytest.mark.asyncio
- async def test_materialize_empty_notes(self, tmp_path: Path, monkeypatch):
- """Test that materialize works with empty notes list."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
-
- workspace = MagicMock()
- workspace.id = 1
- workspace.notes = []
-
- mock_service = AsyncMock()
- mock_service.get_workspace_by_id.return_value = workspace
-
- with patch(
- "src.services.workspace.WorkspaceService", return_value=mock_service
- ):
- with patch("src.agent.notes.manager.db_context") as mock_db_context:
- mock_db_context.return_value.__aenter__ = AsyncMock(
- return_value=AsyncMock()
- )
- mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
-
- manager = NoteManager(workspace_id=1)
- result = await manager.materialize()
-
- assert await result.exists()
-
-
-class TestClearMaterialized:
- """Tests for clear_materialized method."""
-
- @pytest.mark.asyncio
- async def test_removes_notes_directory(self, tmp_path: Path, monkeypatch):
- """Test that clear_materialized removes the workspace notes directory."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
-
- workspace = MagicMock()
- workspace.id = 3
- workspace.notes = [workspace_models.WorkspaceNote(relative="test.md", content="content")]
-
- mock_service = AsyncMock()
- mock_service.get_workspace_by_id.return_value = workspace
-
- with patch("src.services.workspace.WorkspaceService", return_value=mock_service):
- with patch("src.agent.notes.manager.db_context") as mock_db_context:
- mock_db_context.return_value.__aenter__ = AsyncMock(return_value=AsyncMock())
- mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
-
- manager = NoteManager(workspace_id=3)
- notes_dir = await manager.materialize()
-
- assert await notes_dir.exists()
-
- await manager.clear_materialized()
-
- assert not await notes_dir.exists()
-
- @pytest.mark.asyncio
- async def test_keeps_directory_when_workspace_has_other_active_task(self, tmp_path: Path, monkeypatch):
- """Test that clear_materialized keeps notes when another task in same workspace is active."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
-
- workspace = MagicMock()
- workspace.id = 1
- workspace.notes = [workspace_models.WorkspaceNote(relative="README.md", content="# Hello")]
-
- mock_service = AsyncMock()
- mock_service.get_workspace_by_id.return_value = workspace
-
- with patch("src.services.workspace.WorkspaceService", return_value=mock_service):
- with patch("src.agent.notes.manager.db_context") as mock_db_context:
- mock_db_context.return_value.__aenter__ = AsyncMock(return_value=AsyncMock())
- mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
-
- manager_a = NoteManager(workspace_id=1)
- manager_b = NoteManager(workspace_id=1)
-
- notes_dir = await manager_a.materialize()
- await manager_b.materialize()
-
- assert await notes_dir.exists()
-
- await manager_a.clear_materialized()
- assert await notes_dir.exists()
-
- await manager_b.clear_materialized()
- assert not await notes_dir.exists()
-
- @pytest.mark.asyncio
- async def test_force_clear_always_removes_directory(self, tmp_path: Path, monkeypatch):
- """Test that force clear removes notes directory regardless of ref count."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
-
- workspace = MagicMock()
- workspace.id = 1
- workspace.notes = [workspace_models.WorkspaceNote(relative="README.md", content="# Hello")]
-
- mock_service = AsyncMock()
- mock_service.get_workspace_by_id.return_value = workspace
-
- with patch("src.services.workspace.WorkspaceService", return_value=mock_service):
- with patch("src.agent.notes.manager.db_context") as mock_db_context:
- mock_db_context.return_value.__aenter__ = AsyncMock(return_value=AsyncMock())
- mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
-
- manager_a = NoteManager(workspace_id=1)
- manager_b = NoteManager(workspace_id=1)
-
- notes_dir = await manager_a.materialize()
- await manager_b.materialize()
-
- assert await notes_dir.exists()
-
- await manager_a.clear_materialized(force=True)
- assert not await notes_dir.exists()
-
- @pytest.mark.asyncio
- async def test_handles_nonexistent_directory(self, tmp_path: Path, monkeypatch):
- """Test that clear_materialized handles non-existent directory gracefully."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=999)
-
- # Should not raise
- await manager.clear_materialized()
-
-
-class TestStartWatching:
- """Tests for start_watching method."""
-
- @pytest.mark.asyncio
- async def test_creates_watch_task(self, tmp_path: Path, monkeypatch):
- """Test that start_watching creates a watch task."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- # Create the notes directory
- await manager.get_notes_dir()
-
- with patch("src.agent.notes.manager.awatch") as mock_awatch:
- mock_awatch.return_value.__aiter__ = AsyncMock(
- return_value=iter([])
- )
-
- await manager.start_watching()
-
- assert manager._watch_task is not None
- assert not manager._watch_task.done()
-
- # Cleanup
- await manager.stop_watching()
-
- @pytest.mark.asyncio
- async def test_restarts_existing_watch(self, tmp_path: Path, monkeypatch):
- """Test that start_watching stops existing watch before starting new one."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- # Create the notes directory
- await manager.get_notes_dir()
-
- with patch("src.agent.notes.manager.awatch") as mock_awatch:
- mock_awatch.return_value.__aiter__ = AsyncMock(
- return_value=iter([])
- )
-
- await manager.start_watching()
- first_task = manager._watch_task
-
- await manager.start_watching()
- second_task = manager._watch_task
-
- assert first_task is not second_task
-
- # Cleanup
- await manager.stop_watching()
-
-
-class TestStopWatching:
- """Tests for stop_watching method."""
-
- @pytest.mark.asyncio
- async def test_cancels_watch_task(self, tmp_path: Path, monkeypatch):
- """Test that stop_watching cancels the watch task."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- # Create the notes directory
- await manager.get_notes_dir()
-
- with patch("src.agent.notes.manager.awatch") as mock_awatch:
- mock_awatch.return_value.__aiter__ = AsyncMock(
- return_value=iter([])
- )
-
- await manager.start_watching()
- assert manager._watch_task is not None
-
- await manager.stop_watching()
-
- assert manager._watch_task is None
- assert manager._stop_watching_event is None
-
- @pytest.mark.asyncio
- async def test_handles_already_stopped(self, tmp_path: Path, monkeypatch):
- """Test that stop_watching handles already stopped state gracefully."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- # Should not raise when nothing is watching
- await manager.stop_watching()
-
-
-class TestHandleFileChanges:
- """Tests for _handle_file_changes method."""
-
- @pytest.mark.asyncio
- async def test_handles_added_file(self, tmp_path: Path, monkeypatch):
- """Test that added files are processed correctly."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- notes_dir = await manager.get_notes_dir()
-
- # Create a file to simulate "added"
- test_file = notes_dir / "new.md"
- await test_file.write_text("New content", "utf-8")
-
- workspace = MagicMock()
- workspace.id = 1
- workspace.notes = []
-
- mock_service = AsyncMock()
- mock_service.get_workspace_by_id.return_value = workspace
-
- with patch(
- "src.services.workspace.WorkspaceService", return_value=mock_service
- ):
- with patch("src.agent.notes.manager.db_context") as mock_db_context:
- mock_session = AsyncMock()
- mock_db_context.return_value.__aenter__ = AsyncMock(
- return_value=mock_session
- )
- mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
-
- changes = [(ChangeType.added, AnyioPath(test_file))]
- await manager._handle_file_changes(changes)
-
- # Verify workspace.notes was updated
- assert len(workspace.notes) == 1
- assert workspace.notes[0].relative == "new.md"
- assert workspace.notes[0].content == "New content"
-
- @pytest.mark.asyncio
- async def test_handles_modified_file(self, tmp_path: Path, monkeypatch):
- """Test that modified files are processed correctly."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- notes_dir = await manager.get_notes_dir()
-
- # Create existing note
- existing_note = workspace_models.WorkspaceNote(
- relative="existing.md", content="Old content"
- )
-
- workspace = MagicMock()
- workspace.id = 1
- workspace.notes = [existing_note]
-
- # Update the file
- test_file = notes_dir / "existing.md"
- await test_file.parent.mkdir(parents=True, exist_ok=True)
- await test_file.write_text("Updated content", "utf-8")
-
- mock_service = AsyncMock()
- mock_service.get_workspace_by_id.return_value = workspace
-
- with patch(
- "src.services.workspace.WorkspaceService", return_value=mock_service
- ):
- with patch("src.agent.notes.manager.db_context") as mock_db_context:
- mock_session = AsyncMock()
- mock_db_context.return_value.__aenter__ = AsyncMock(
- return_value=mock_session
- )
- mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
-
- changes = [(ChangeType.modified, AnyioPath(test_file))]
- await manager._handle_file_changes(changes)
-
- # Verify note was updated
- assert len(workspace.notes) == 1
- assert workspace.notes[0].content == "Updated content"
-
- @pytest.mark.asyncio
- async def test_handles_deleted_file(self, tmp_path: Path, monkeypatch):
- """Test that deleted files are removed from notes."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- notes_dir = await manager.get_notes_dir()
-
- # Create existing note
- existing_note = workspace_models.WorkspaceNote(
- relative="deleted.md", content="Content"
- )
-
- workspace = MagicMock()
- workspace.id = 1
- workspace.notes = [existing_note]
-
- # File path (doesn't need to exist for deletion test)
- test_file = notes_dir / "deleted.md"
-
- mock_service = AsyncMock()
- mock_service.get_workspace_by_id.return_value = workspace
-
- with patch(
- "src.services.workspace.WorkspaceService", return_value=mock_service
- ):
- with patch("src.agent.notes.manager.db_context") as mock_db_context:
- mock_session = AsyncMock()
- mock_db_context.return_value.__aenter__ = AsyncMock(
- return_value=mock_session
- )
- mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
-
- changes = [(ChangeType.deleted, AnyioPath(test_file))]
- await manager._handle_file_changes(changes)
-
- # Verify note was removed
- assert len(workspace.notes) == 0
-
- @pytest.mark.asyncio
- async def test_skips_non_md_files(self, tmp_path: Path, monkeypatch):
- """Test that non-.md files are ignored."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- notes_dir = await manager.get_notes_dir()
-
- # Create a non-md file
- test_file = notes_dir / "script.py"
- await test_file.write_text("print('hello')", "utf-8")
-
- workspace = MagicMock()
- workspace.id = 1
- workspace.notes = []
-
- mock_service = AsyncMock()
- mock_service.get_workspace_by_id.return_value = workspace
-
- with patch(
- "src.services.workspace.WorkspaceService", return_value=mock_service
- ):
- with patch("src.agent.notes.manager.db_context") as mock_db_context:
- mock_session = AsyncMock()
- mock_db_context.return_value.__aenter__ = AsyncMock(
- return_value=mock_session
- )
- mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
-
- # _watch_files filters non-md files, but _handle_file_changes
- # should handle any file type passed to it
- changes = [(ChangeType.added, AnyioPath(test_file))]
- await manager._handle_file_changes(changes)
-
- # Verify note was still added (handler doesn't filter)
- assert len(workspace.notes) == 1
-
- @pytest.mark.asyncio
- async def test_handles_empty_changes(self, tmp_path: Path, monkeypatch):
- """Test that empty changes are handled gracefully."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- # Should not raise
- await manager._handle_file_changes([])
-
- @pytest.mark.asyncio
- async def test_handles_read_error_gracefully(self, tmp_path: Path, monkeypatch):
- """Test that file read errors are handled gracefully."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- notes_dir = await manager.get_notes_dir()
-
- workspace = MagicMock()
- workspace.id = 1
- workspace.notes = []
-
- mock_service = AsyncMock()
- mock_service.get_workspace_by_id.return_value = workspace
-
- # Create a path that doesn't exist to trigger read error
- nonexistent_file = notes_dir / "nonexistent.md"
-
- with patch(
- "src.services.workspace.WorkspaceService", return_value=mock_service
- ):
- with patch("src.agent.notes.manager.db_context") as mock_db_context:
- mock_session = AsyncMock()
- mock_db_context.return_value.__aenter__ = AsyncMock(
- return_value=mock_session
- )
- mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
-
- changes = [(ChangeType.added, AnyioPath(nonexistent_file))]
- # Should not raise even though file doesn't exist
- await manager._handle_file_changes(changes)
-
-
-class TestWatchFiles:
- """Tests for _watch_files method."""
-
- @pytest.mark.asyncio
- async def test_returns_early_if_directory_not_exists(self, tmp_path: Path):
- """Test that _watch_files returns early if directory doesn't exist."""
- manager = NoteManager(workspace_id=999)
- stop_event = asyncio.Event()
-
- nonexistent_dir = AnyioPath(tmp_path / "nonexistent")
-
- # Should complete without error
- await manager._watch_files(nonexistent_dir, stop_event)
-
- @pytest.mark.asyncio
- async def test_filters_md_files(self, tmp_path: Path, monkeypatch):
- """Test that _watch_files filters to only .md files."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- notes_dir = await manager.get_notes_dir()
- stop_event = asyncio.Event()
-
- # Create files
- md_file = notes_dir / "note.md"
- py_file = notes_dir / "script.py"
- await md_file.write_text("# Note", "utf-8")
- await py_file.write_text("print('hello')", "utf-8")
-
- mock_changes = [
- {(ChangeType.added, str(md_file)), (ChangeType.added, str(py_file))}
- ]
-
- with patch(
- "src.agent.notes.manager.awatch"
- ) as mock_awatch, patch.object(
- manager, "_handle_file_changes"
- ) as mock_handle:
-
- async def mock_awatch_generator(*args, **kwargs):
- for change_set in mock_changes:
- yield change_set
- stop_event.set()
-
- mock_awatch.return_value.__aiter__ = mock_awatch_generator
-
- # Run watch briefly then stop
- watch_task = asyncio.create_task(
- manager._watch_files(notes_dir, stop_event)
- )
-
- # Give it time to process
- await asyncio.sleep(0.1)
- stop_event.set()
-
- try:
- await asyncio.wait_for(watch_task, timeout=1.0)
- except asyncio.TimeoutError:
- watch_task.cancel()
-
- # Verify _handle_file_changes was called with only md file
- mock_handle.assert_called_once()
- call_args = mock_handle.call_args[0][0]
- assert len(call_args) == 1
- change_type, path = list(call_args)[0]
- assert path.name == "note.md"
-
- @pytest.mark.asyncio
- async def test_handles_cancelled_error(self, tmp_path: Path, monkeypatch):
- """Test that _watch_files handles CancelledError gracefully."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- notes_dir = await manager.get_notes_dir()
- stop_event = asyncio.Event()
-
- with patch("src.agent.notes.manager.awatch") as mock_awatch:
- mock_awatch.side_effect = asyncio.CancelledError()
-
- # Should not raise
- await manager._watch_files(notes_dir, stop_event)
-
-
-class TestIntegration:
- """Integration tests for NoteManager."""
-
- @pytest.mark.asyncio
- async def test_full_lifecycle(self, tmp_path: Path, monkeypatch):
- """Test the full lifecycle: materialize, watch, modify, clear."""
- monkeypatch.setattr("src.agent.notes.manager.DATA_DIR", tmp_path)
- manager = NoteManager(workspace_id=1)
-
- # Setup mock workspace
- workspace = MagicMock()
- workspace.id = 1
- workspace.notes = [
- workspace_models.WorkspaceNote(relative="test.md", content="Initial")
- ]
-
- mock_service = AsyncMock()
- mock_service.get_workspace_by_id.return_value = workspace
-
- with patch(
- "src.services.workspace.WorkspaceService", return_value=mock_service
- ):
- with patch("src.agent.notes.manager.db_context") as mock_db_context:
- mock_session = AsyncMock()
- mock_db_context.return_value.__aenter__ = AsyncMock(
- return_value=mock_session
- )
- mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
-
- # Materialize
- notes_dir = await manager.materialize()
- assert await (notes_dir / "test.md").exists()
-
- # Clear
- await manager.clear_materialized()
- assert not await notes_dir.exists()
diff --git a/src-server/tests/agent/notes/test_materializer.py b/src-server/tests/agent/notes/test_materializer.py
new file mode 100644
index 00000000..9a4c1fc2
--- /dev/null
+++ b/src-server/tests/agent/notes/test_materializer.py
@@ -0,0 +1,168 @@
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from anyio import Path as AnyioPath
+
+from src.agent.notes import NoteMaterializer
+from src.db.models import workspace as workspace_models
+from src.schemas import workspace as workspace_schemas
+
+
+@pytest.mark.integration
+class TestNoteMaterializer:
+ @pytest.mark.asyncio
+ async def test_get_notes_root_dir_creates_directory(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)
+
+ result = await NoteMaterializer.get_notes_root_dir()
+
+ assert await result.exists()
+ assert result == AnyioPath(tmp_path / ".notes")
+
+ @pytest.mark.asyncio
+ async def test_get_notes_dir_creates_workspace_directory(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)
+
+ result = await NoteMaterializer.get_notes_dir(42)
+
+ assert await result.exists()
+ assert result == AnyioPath(tmp_path / ".notes" / "42")
+
+ def test_get_notes_dir_env_returns_workspace_specific_env(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)
+
+ env = NoteMaterializer.get_notes_dir_env(5)
+
+ assert env == {
+ NoteMaterializer.NOTES_DIR_ENVNAME: str(tmp_path / ".notes" / "5")
+ }
+
+ @pytest.mark.asyncio
+ async def test_get_notes_index_returns_content(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)
+
+ notes_dir = await NoteMaterializer.get_notes_dir(7)
+ notes_index = notes_dir / "NOTES.md"
+ await notes_index.write_text("# Workspace notes\n\n- item", "utf-8")
+
+ result = await NoteMaterializer.get_notes_index(7)
+
+ assert result == "# Workspace notes\n\n- item"
+
+ @pytest.mark.asyncio
+ async def test_get_notes_index_returns_none_when_missing(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)
+
+ result = await NoteMaterializer.get_notes_index(8)
+
+ assert result is None
+
+ @pytest.mark.asyncio
+ async def test_get_notes_index_returns_none_when_read_fails(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)
+
+ notes_dir = await NoteMaterializer.get_notes_dir(9)
+ notes_index = notes_dir / "NOTES.md"
+ await notes_index.write_text("# Workspace notes", "utf-8")
+
+ with patch.object(AnyioPath, "read_text", AsyncMock(side_effect=OSError("read failed"))):
+ result = await NoteMaterializer.get_notes_index(9)
+
+ assert result is None
+
+ @pytest.mark.asyncio
+ async def test_materialize_writes_workspace_notes_to_files(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)
+ workspace = workspace_schemas.WorkspaceRead.model_validate({
+ "id": 1,
+ "name": "Workspace",
+ "directory": "/tmp/workspace",
+ "instruction": "",
+ "notes": [
+ {"id": 1, "relative": "README.md", "content": "# Hello"},
+ {"id": 2, "relative": "docs/guide.md", "content": "## Guide"},
+ ],
+ "usable_agents": [],
+ "usable_tools": [],
+ "usable_skills": [],
+ })
+
+ notes_dir = await NoteMaterializer.materialize(workspace)
+
+ readme_path = Path(str(notes_dir)) / "README.md"
+ guide_path = Path(str(notes_dir)) / "docs" / "guide.md"
+ assert readme_path.read_text(encoding="utf-8") == "# Hello"
+ assert guide_path.read_text(encoding="utf-8") == "## Guide"
+
+ @pytest.mark.asyncio
+ async def test_clear_materialized_removes_workspace_directory(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)
+ notes_dir = await NoteMaterializer.get_notes_dir(3)
+ note_path = Path(str(notes_dir)) / "test.md"
+ note_path.write_text("content", encoding="utf-8")
+
+ await NoteMaterializer.clear_materialized(3)
+
+ assert not note_path.exists()
+ assert not Path(str(notes_dir)).exists()
+
+ @pytest.mark.asyncio
+ async def test_clear_materialized_handles_nonexistent_directory(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)
+
+ await NoteMaterializer.clear_materialized(999)
+
+ assert not (tmp_path / ".notes" / "999").exists()
+
+ @pytest.mark.asyncio
+ async def test_materialize_all_clears_then_materializes_each_workspace(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+ monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)
+
+ workspace_a = MagicMock(spec=workspace_models.Workspace)
+ workspace_a.id = 1
+ workspace_b = MagicMock(spec=workspace_models.Workspace)
+ workspace_b.id = 2
+
+ mock_service = AsyncMock()
+ mock_service.get_all_workspaces.return_value = [workspace_a, workspace_b]
+
+ clear_mock = AsyncMock()
+ materialize_mock = AsyncMock()
+ workspace_read_mock = MagicMock(side_effect=[
+ workspace_schemas.WorkspaceRead.model_validate({
+ "id": 1,
+ "name": "Workspace A",
+ "directory": "/tmp/workspace-a",
+ "instruction": "",
+ "notes": [{"id": 1, "relative": "a.md", "content": "A"}],
+ "usable_agents": [],
+ "usable_tools": [],
+ "usable_skills": [],
+ }),
+ workspace_schemas.WorkspaceRead.model_validate({
+ "id": 2,
+ "name": "Workspace B",
+ "directory": "/tmp/workspace-b",
+ "instruction": "",
+ "notes": [{"id": 2, "relative": "b.md", "content": "B"}],
+ "usable_agents": [],
+ "usable_tools": [],
+ "usable_skills": [],
+ }),
+ ])
+
+ with patch("src.services.workspace.WorkspaceService", return_value=mock_service):
+ with patch("src.agent.notes.materializer.db_context") as mock_db_context:
+ mock_db_context.return_value.__aenter__ = AsyncMock(return_value=AsyncMock())
+ mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
+ with patch.object(NoteMaterializer, "clear_materialized", clear_mock):
+ with patch.object(NoteMaterializer, "materialize", materialize_mock):
+ with patch.object(workspace_schemas.WorkspaceRead, "model_validate", workspace_read_mock):
+ await NoteMaterializer.materialize_all()
+
+ assert clear_mock.await_args_list == [
+ ((1,), {}),
+ ((2,), {}),
+ ]
+ assert materialize_mock.await_count == 2
diff --git a/src-server/tests/agent/notes/test_watcher.py b/src-server/tests/agent/notes/test_watcher.py
new file mode 100644
index 00000000..108235c8
--- /dev/null
+++ b/src-server/tests/agent/notes/test_watcher.py
@@ -0,0 +1,198 @@
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from anyio import Path as AnyioPath
+from watchfiles import Change as ChangeType
+
+from src.agent.notes.watcher import NoteWatcher
+from src.agent.notes.workspace_ref_manager import WorkspaceRefManager
+from src.db.models import workspace as workspace_models
+from src.utils.directory_watcher import DirectoryWatcher
+
+
+@pytest.mark.integration
+class TestNoteWatcher:  # start/stop lifecycle and file-change handling of NoteWatcher
+    @pytest.mark.asyncio
+    async def test_start_creates_directory_watcher(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+        monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)  # sandbox the notes data root
+        watcher = NoteWatcher(workspace_id=1)
+        directory_watcher = AsyncMock(spec=DirectoryWatcher)
+
+        with patch(
+            "src.agent.notes.watcher.DirectoryWatcher",
+            return_value=directory_watcher,
+        ) as directory_watcher_cls:
+            await watcher._start()
+
+        assert watcher._watcher is directory_watcher  # holds the instance it constructed
+        assert WorkspaceRefManager.is_workspace_in_use(1)  # starting registers a workspace ref
+        directory_watcher_cls.assert_called_once()
+        directory_watcher.start.assert_awaited_once()
+
+        await watcher._stop()
+        directory_watcher.stop.assert_awaited_once()
+        assert watcher._watcher is None
+        assert not WorkspaceRefManager.is_workspace_in_use(1)  # stop releases the workspace ref
+
+    @pytest.mark.asyncio
+    async def test_start_restarts_existing_directory_watcher(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+        monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)
+        watcher = NoteWatcher(workspace_id=1)
+        first_directory_watcher = AsyncMock(spec=DirectoryWatcher)
+        second_directory_watcher = AsyncMock(spec=DirectoryWatcher)
+
+        with patch(
+            "src.agent.notes.watcher.DirectoryWatcher",
+            side_effect=[first_directory_watcher, second_directory_watcher],  # each _start constructs a fresh instance
+        ):
+            await watcher._start()
+            await watcher._start()  # second start must replace, not leak, the first
+
+        first_directory_watcher.start.assert_awaited_once()
+        first_directory_watcher.stop.assert_awaited_once()  # old instance stopped during restart
+        second_directory_watcher.start.assert_awaited_once()
+        assert watcher._watcher is second_directory_watcher
+        assert WorkspaceRefManager.is_workspace_in_use(1)  # ref survives the restart
+
+        await watcher._stop()
+        second_directory_watcher.stop.assert_awaited_once()
+        assert not WorkspaceRefManager.is_workspace_in_use(1)
+
+    @pytest.mark.asyncio
+    async def test_stop_handles_already_stopped_state(self):  # stopping an idle watcher is a no-op, not an error
+        watcher = NoteWatcher(workspace_id=1)
+
+        await watcher._stop()
+
+        assert watcher._watcher is None
+        assert not WorkspaceRefManager.is_workspace_in_use(1)
+
+    @pytest.mark.asyncio
+    async def test_start_rolls_back_workspace_ref_when_directory_watcher_start_fails(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ):
+        monkeypatch.setattr("src.agent.notes.materializer.DATA_DIR", tmp_path)
+        watcher = NoteWatcher(workspace_id=1)
+        directory_watcher = AsyncMock(spec=DirectoryWatcher)
+        directory_watcher.start.side_effect = RuntimeError("boom")  # simulate a failing watcher start
+
+        with patch("src.agent.notes.watcher.DirectoryWatcher", return_value=directory_watcher):
+            with pytest.raises(RuntimeError, match="boom"):
+                await watcher._start()
+
+        assert not WorkspaceRefManager.is_workspace_in_use(1)  # failed start must not leave a dangling ref
+
+    @pytest.mark.asyncio
+    async def test_handle_file_changes_filters_and_forwards_markdown_only(self, tmp_path: Path):
+        watcher = NoteWatcher(workspace_id=1)
+        notes_dir = AnyioPath(tmp_path / "notes")
+        await notes_dir.mkdir(parents=True, exist_ok=True)
+
+        md_file = notes_dir / "note.md"  # only this entry should be forwarded
+        txt_file = notes_dir / "note.txt"  # non-markdown extension -> filtered out
+        subdir = notes_dir / "dir"  # directories -> filtered out
+        await md_file.write_text("# Note", "utf-8")
+        await txt_file.write_text("plain text", "utf-8")
+        await subdir.mkdir()
+
+        with patch(
+            "src.agent.notes.watcher.NoteMaterializer.get_notes_dir",
+            AsyncMock(return_value=notes_dir),
+        ):
+            with patch.object(watcher, "_handle_note_changes", AsyncMock()) as handle_mock:
+                await watcher._handle_file_changes(
+                    [
+                        (ChangeType.added, str(md_file)),  # raw string paths, as watchfiles reports them
+                        (ChangeType.added, str(txt_file)),
+                        (ChangeType.added, str(subdir)),
+                    ]
+                )
+
+        handle_mock.assert_awaited_once()
+        forwarded_base, forwarded_changes = handle_mock.await_args.args
+        assert forwarded_base == notes_dir
+        assert len(forwarded_changes) == 1  # the .txt file and the directory were dropped
+        change_type, path = forwarded_changes[0]
+        assert change_type == ChangeType.added
+        assert path == md_file  # forwarded as a path object equal to the .md file
+
+    @pytest.mark.asyncio
+    async def test_handle_file_changes_ignores_empty_changes(self, tmp_path: Path):
+        watcher = NoteWatcher(workspace_id=1)
+        notes_dir = AnyioPath(tmp_path / "notes")
+        await notes_dir.mkdir(parents=True, exist_ok=True)
+
+        with patch(
+            "src.agent.notes.watcher.NoteMaterializer.get_notes_dir",
+            AsyncMock(return_value=notes_dir),
+        ):
+            with patch.object(watcher, "_handle_note_changes", AsyncMock()) as handle_mock:
+                await watcher._handle_file_changes([])  # empty batch -> nothing is forwarded
+
+        handle_mock.assert_not_awaited()
+
+    @pytest.mark.asyncio
+    async def test_handle_note_changes_adds_updates_and_deletes_notes(self, tmp_path: Path):
+        watcher = NoteWatcher(workspace_id=1)
+        base = AnyioPath(tmp_path / "notes")
+        await base.mkdir(parents=True, exist_ok=True)
+
+        added_path = base / "new.md"
+        updated_path = base / "existing.md"
+        deleted_path = base / "deleted.md"  # never written to disk: exists only as a DB note
+        await added_path.write_text("New content", "utf-8")
+        await updated_path.write_text("Updated content", "utf-8")
+
+        workspace = MagicMock()
+        workspace.id = 1
+        workspace.notes = [  # DB state before the change batch is applied
+            workspace_models.WorkspaceNote(relative="existing.md", content="Old content"),
+            workspace_models.WorkspaceNote(relative="deleted.md", content="Delete me"),
+        ]
+
+        mock_service = AsyncMock()
+        mock_service.get_workspace_by_id.return_value = workspace
+
+        with patch("src.services.workspace.WorkspaceService", return_value=mock_service):
+            with patch("src.agent.notes.watcher.db_context") as mock_db_context:  # fake async DB session context manager
+                mock_db_context.return_value.__aenter__ = AsyncMock(return_value=AsyncMock())
+                mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
+                await watcher._handle_note_changes(
+                    base,
+                    [
+                        (ChangeType.added, added_path),
+                        (ChangeType.modified, updated_path),
+                        (ChangeType.deleted, deleted_path),
+                    ],
+                )
+
+        notes_by_relative = {note.relative: note.content for note in workspace.notes}
+        assert notes_by_relative == {  # added + updated notes survive; the deleted note is gone
+            "existing.md": "Updated content",
+            "new.md": "New content",
+        }
+
+    @pytest.mark.asyncio
+    async def test_handle_note_changes_skips_file_read_errors(self, tmp_path: Path):
+        watcher = NoteWatcher(workspace_id=1)
+        base = AnyioPath(tmp_path / "notes")
+        await base.mkdir(parents=True, exist_ok=True)
+
+        missing_path = base / "missing.md"  # reported as added but absent on disk, so the read fails
+        workspace = MagicMock()
+        workspace.id = 1
+        workspace.notes = []
+
+        mock_service = AsyncMock()
+        mock_service.get_workspace_by_id.return_value = workspace
+
+        with patch("src.services.workspace.WorkspaceService", return_value=mock_service):
+            with patch("src.agent.notes.watcher.db_context") as mock_db_context:
+                mock_db_context.return_value.__aenter__ = AsyncMock(return_value=AsyncMock())
+                mock_db_context.return_value.__aexit__ = AsyncMock(return_value=False)
+                await watcher._handle_note_changes(base, [(ChangeType.added, missing_path)])
+
+        assert workspace.notes == []  # unreadable file is skipped, not recorded as a note
diff --git a/src-server/tests/services/test_task_service.py b/src-server/tests/services/test_task_service.py
index 9114df5a..03576bd1 100644
--- a/src-server/tests/services/test_task_service.py
+++ b/src-server/tests/services/test_task_service.py
@@ -297,7 +297,6 @@ async def test_delete_task_removes_entity_and_task_resources(
workspace_factory,
agent_factory,
tool_factory,
- task_resource_data_dir: Path,
):
tool = await tool_factory(name="Echo", internal_key="echo")
agent = await agent_factory(name="Agent A", usable_tools=[tool])
@@ -320,15 +319,11 @@ async def test_delete_task_removes_entity_and_task_resources(
"note.txt",
b"resource-bytes",
)
- resource_dir = task_resource_data_dir / ".task-resources" / "task" / str(task.id)
- resource_path = resource_dir / resource.filename
await task_service.delete_task(task.id)
await db_session.flush()
db_session.expunge_all()
- assert not resource_path.exists()
- assert not resource_dir.exists()
with pytest.raises(TaskNotFoundError, match=f"Task '{task.id}' not found"):
await task_service.get_task_by_id(task.id)
diff --git a/src-server/tests/services/test_workspace_service.py b/src-server/tests/services/test_workspace_service.py
index 3fa6daf8..8c4cb660 100644
--- a/src-server/tests/services/test_workspace_service.py
+++ b/src-server/tests/services/test_workspace_service.py
@@ -4,7 +4,7 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
-from src.agent.notes.manager import NoteManager
+from src.agent.notes import NoteMaterializer
from src.db.models.markdown_cache import MarkdownCache
from src.db.models import tasks as task_models
from src.db.models import workspace as workspace_models
@@ -137,18 +137,12 @@ async def test_delete_workspace_removes_entity_and_cascade_children(
db_session.add(cache)
await db_session.flush()
- note_manager = NoteManager(workspace.id)
- notes_dir = await note_manager.get_notes_dir()
- note_path = Path(str(notes_dir)) / "note.md"
- note_path.write_text("persisted", encoding="utf-8")
-
await workspace_service.delete_workspace(workspace.id)
await db_session.flush()
db_session.expunge_all()
with pytest.raises(WorkspaceNotFoundError, match=f"Workspace '{workspace.id}' not found"):
await workspace_service.get_workspace_by_id(workspace.id)
- assert not note_path.exists()
note_in_db = await db_session.scalar(
select(workspace_models.WorkspaceNote).where(workspace_models.WorkspaceNote.id == note_id)
diff --git a/src-server/tests/tools/conftest.py b/src-server/tests/tools/conftest.py
index de47c071..3882a723 100644
--- a/src-server/tests/tools/conftest.py
+++ b/src-server/tests/tools/conftest.py
@@ -15,4 +15,4 @@ def _init(self, ctx, toolset_ent=None):
 @pytest.fixture
 def built_in_toolset_context(temp_workspace):
-    return BuiltInToolsetContext(1, temp_workspace, ContextUsage.default())
+    return BuiltInToolsetContext(1, temp_workspace)  # ContextUsage arg dropped to match the new constructor signature
diff --git a/src-server/tests/tools/file_system/test_init.py b/src-server/tests/tools/file_system/test_init.py
index 50f958a4..c68f7a7e 100644
--- a/src-server/tests/tools/file_system/test_init.py
+++ b/src-server/tests/tools/file_system/test_init.py
@@ -29,7 +29,7 @@ async def test_read_file_resolves_tilde_cwd(self):
marker = home / "__dais_test_marker__.txt"
marker.write_text("tilde-ok", encoding="utf-8")
try:
- tool = FileSystemToolset(BuiltInToolsetContext(1, "~", ContextUsage.default()))
+ tool = FileSystemToolset(BuiltInToolsetContext(1, "~"))
result = await tool.read_file("__dais_test_marker__.txt")
assert "tilde-ok" in result
finally: