Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 110 additions & 7 deletions src/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

import json5
try:
import json5
except ImportError:
json5 = None # type: ignore
from jsonschema import ValidationError, validate

from actions import load_action
Expand Down Expand Up @@ -88,6 +91,45 @@ def validate_config_schema(raw_config: dict) -> None:
raise


def _validate_required_field(
config_dict: dict,
field_name: str,
context: str = "",
field_type: str = "field",
) -> Any:
"""
Validate that a required field exists in a configuration dictionary.

Parameters
----------
config_dict : dict
The configuration dictionary to check.
field_name : str
The name of the required field.
context : str
Additional context for error message (e.g., mode name, rule index).
field_type : str
Type of field for error message (e.g., 'field', 'property').

Returns
-------
Any
The field value if it exists.

Raises
------
ValueError
If the required field is missing with detailed error context.
"""
if field_name not in config_dict:
context_msg = f" in {context}" if context else ""
raise ValueError(
f"Required {field_type} '{field_name}' is missing{context_msg}. "
f"Please ensure this field is defined in your configuration."
)
return config_dict[field_name]


@dataclass
class RuntimeConfig:
"""
Expand Down Expand Up @@ -543,6 +585,8 @@ def load_mode_config(

with open(config_path, "r") as f:
try:
if json5 is None:
raise ImportError("json5 is required to load configuration files")
raw_config = json5.load(f)
except Exception as e:
raise ValueError(
Expand All @@ -564,10 +608,15 @@ def load_mode_config(

load_unitree(g_ut_eth)

# Validate and extract default_mode early to provide clear error messages
default_mode = _validate_required_field(
raw_config, "default_mode", context="global configuration", field_type="field"
)

mode_system_config = ModeSystemConfig(
version=config_version,
name=raw_config.get("name", "mode_system"),
default_mode=raw_config["default_mode"],
default_mode=default_mode,
config_name=config_name,
allow_manual_switching=raw_config.get("allow_manual_switching", True),
mode_memory_enabled=raw_config.get("mode_memory_enabled", True),
Expand All @@ -585,12 +634,20 @@ def load_mode_config(
)

for mode_name, mode_data in raw_config.get("modes", {}).items():
# Validate required mode fields early
system_prompt_base = _validate_required_field(
mode_data,
"system_prompt_base",
context=f"mode '{mode_name}'",
field_type="field",
)

mode_config = ModeConfig(
version=mode_data.get("version", "1.0.1"),
name=mode_name,
display_name=mode_data.get("display_name", mode_name),
description=mode_data.get("description", ""),
system_prompt_base=mode_data["system_prompt_base"],
system_prompt_base=system_prompt_base,
hertz=mode_data.get("hertz", 1.0),
lifecycle_hooks=parse_lifecycle_hooks(
mode_data.get("lifecycle_hooks", []), api_key=g_api_key
Expand All @@ -610,11 +667,31 @@ def load_mode_config(

mode_system_config.modes[mode_name] = mode_config

for rule_data in raw_config.get("transition_rules", []):
for rule_idx, rule_data in enumerate(raw_config.get("transition_rules", [])):
# Validate required transition rule fields early
from_mode = _validate_required_field(
rule_data,
"from_mode",
context=f"transition rule at index {rule_idx}",
field_type="field",
)
to_mode = _validate_required_field(
rule_data,
"to_mode",
context=f"transition rule at index {rule_idx}",
field_type="field",
)
transition_type_str = _validate_required_field(
rule_data,
"transition_type",
context=f"transition rule at index {rule_idx}",
field_type="field",
)

rule = TransitionRule(
from_mode=rule_data["from_mode"],
to_mode=rule_data["to_mode"],
transition_type=TransitionType(rule_data["transition_type"]),
from_mode=from_mode,
to_mode=to_mode,
transition_type=TransitionType(transition_type_str),
trigger_keywords=rule_data.get("trigger_keywords", []),
priority=rule_data.get("priority", 1),
cooldown_seconds=rule_data.get("cooldown_seconds", 0.0),
Expand All @@ -623,6 +700,32 @@ def load_mode_config(
)
mode_system_config.transition_rules.append(rule)

# Validate that default_mode exists in the loaded modes
if default_mode not in mode_system_config.modes:
available_modes = ", ".join(mode_system_config.modes.keys())
raise ValueError(
f"Default mode '{default_mode}' not found in available modes. "
f"Available modes: {available_modes if available_modes else 'none'}. "
f"Please ensure the default_mode matches one of the defined modes in your configuration."
)

# Validate that modes referenced in transition rules exist
for rule_idx, rule in enumerate(mode_system_config.transition_rules):
if rule.from_mode not in mode_system_config.modes:
available_modes = ", ".join(mode_system_config.modes.keys())
raise ValueError(
f"Transition rule at index {rule_idx} references unknown 'from_mode' '{rule.from_mode}'. "
f"Available modes: {available_modes}. "
f"Please ensure all transition rules reference valid mode names."
)
if rule.to_mode not in mode_system_config.modes:
available_modes = ", ".join(mode_system_config.modes.keys())
raise ValueError(
f"Transition rule at index {rule_idx} references unknown 'to_mode' '{rule.to_mode}'. "
f"Available modes: {available_modes}. "
f"Please ensure all transition rules reference valid mode names."
)

return mode_system_config


Expand Down
17 changes: 13 additions & 4 deletions src/runtime/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional

import json5
import zenoh
try:
import json5
except ImportError:
json5 = None # type: ignore

try:
import zenoh
except ImportError:
zenoh = None # type: ignore

from runtime.config import (
LifecycleHookType,
Expand Down Expand Up @@ -146,6 +153,8 @@ def _create_runtime_config_file(self):

temp_file = runtime_config_path + ".tmp"
with open(temp_file, "w") as f:
if json5 is None:
raise ImportError("json5 is required to write configuration files")
json5.dump(runtime_config, f, indent=2)

os.rename(temp_file, runtime_config_path)
Expand Down Expand Up @@ -715,7 +724,7 @@ async def process_tick(

return None

def _zenoh_mode_status_request(self, data: zenoh.Sample):
def _zenoh_mode_status_request(self, data: "zenoh.Sample"):
"""
Process incoming mode status requests via Zenoh.

Expand Down Expand Up @@ -764,7 +773,7 @@ def _zenoh_mode_status_request(self, data: zenoh.Sample):
mode_status_response.serialize()
)

def _zenoh_context_update(self, data: zenoh.Sample):
def _zenoh_context_update(self, data: "zenoh.Sample"):
"""
Process incoming context update messages via Zenoh.

Expand Down
15 changes: 12 additions & 3 deletions src/zenoh_msgs/session.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

try:
import zenoh
except ImportError:
zenoh = None # type: ignore

import zenoh
if TYPE_CHECKING:
from zenoh import Config, Session

logging.basicConfig(level=logging.INFO)


def create_zenoh_config(network_discovery: bool = True) -> zenoh.Config:
def create_zenoh_config(network_discovery: bool = True) -> "zenoh.Config":
"""
Create a Zenoh configuration for a client connecting to a local server.

Expand All @@ -27,7 +36,7 @@ def create_zenoh_config(network_discovery: bool = True) -> zenoh.Config:
return config


def open_zenoh_session() -> zenoh.Session:
def open_zenoh_session() -> "zenoh.Session":
"""
Open a Zenoh session with a local connection first, then fall back to network discovery.

Expand Down
Loading
Loading