Skip to content

Commit adb3c18

Browse files
author
Agent-Planner
committed
Fix CodeRabbit PR AutoForgeAI#117 issues and type annotations
- registry.py: Include Ollama models in VALID_MODELS for validation - server/routers/settings.py: Pass limit=100 to get_denied_commands() - server/websocket.py: Add feature_update message emission on completion - parallel_orchestrator.py: Add stdin=DEVNULL and Windows CREATE_NO_WINDOW flags - requirements.txt: Document CVE-2026-24486 python-multipart fix - server/routers/projects.py: Add defense-in-depth filename validation with os.path.basename() - security.py: Simplify regex, add comprehensive type annotations with cast()
1 parent 5d12ec3 commit adb3c18

7 files changed

Lines changed: 114 additions & 63 deletions

File tree

parallel_orchestrator.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -679,14 +679,19 @@ async def _run_initializer(self) -> bool:
679679

680680
print("Running initializer agent...", flush=True)
681681

682-
proc = subprocess.Popen(
683-
cmd,
684-
stdout=subprocess.PIPE,
685-
stderr=subprocess.STDOUT,
686-
text=True,
687-
cwd=str(AUTOCODER_ROOT),
688-
env={**os.environ, "PYTHONUNBUFFERED": "1"},
689-
)
682+
# Add stdin and Windows flags to prevent blocking/popups
683+
popen_kwargs = {
684+
"stdout": subprocess.PIPE,
685+
"stderr": subprocess.STDOUT,
686+
"stdin": subprocess.DEVNULL,
687+
"text": True,
688+
"cwd": str(AUTOCODER_ROOT),
689+
"env": {**os.environ, "PYTHONUNBUFFERED": "1"},
690+
}
691+
if sys.platform == "win32":
692+
popen_kwargs["creationflags"] = subprocess.CREATE_NO_WINDOW
693+
694+
proc = subprocess.Popen(cmd, **popen_kwargs)
690695

691696
debug_log.log("INIT", "Initializer subprocess started", pid=proc.pid)
692697

registry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@
4949
AVAILABLE_MODELS = CLAUDE_MODELS
5050

5151
# List of valid model IDs (derived from AVAILABLE_MODELS)
52-
VALID_MODELS = [m["id"] for m in CLAUDE_MODELS]
52+
# Include both Claude and Ollama models for validation
53+
VALID_MODELS = [m["id"] for m in CLAUDE_MODELS + OLLAMA_MODELS]
5354

5455
# Default model and settings
5556
DEFAULT_MODEL = "claude-opus-4-5-20251101"

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ sqlalchemy~=2.0
55
fastapi>=0.115.3,<0.116
66
uvicorn[standard]~=0.32
77
websockets~=13.0
8+
# CVE-2026-24486: >=0.0.22 strips directory components from uploaded filenames
9+
# python-multipart 0.0.21+ requires Python >=3.10 (current: 3.11)
810
python-multipart>=0.0.22,<0.1
911
psutil~=6.0
1012
aiofiles~=24.0

security.py

Lines changed: 60 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from dataclasses import dataclass
1717
from datetime import datetime, timezone
1818
from pathlib import Path
19-
from typing import Optional
19+
from typing import Any, Optional, cast
2020

2121
import yaml
2222

@@ -85,7 +85,7 @@ def redact_string(s: str, max_preview: int = 20) -> str:
8585
)
8686

8787

88-
def get_denied_commands(limit: int = 50) -> list[dict]:
88+
def get_denied_commands(limit: int = 50) -> list[dict[str, Any]]:
8989
"""
9090
Get the most recent denied commands.
9191
@@ -299,7 +299,7 @@ def split_command_segments(command_string: str) -> list[str]:
299299
segments = re.split(r"\s*(?:&&|\|\|)\s*", command_string)
300300

301301
# Further split on semicolons
302-
result = []
302+
result: list[str] = []
303303
for segment in segments:
304304
sub_segments = re.split(r'(?<!["\'])\s*;\s*(?!["\'])', segment)
305305
for sub in sub_segments:
@@ -362,7 +362,7 @@ def extract_commands(command_string: str) -> list[str]:
362362
Returns:
363363
List of command names found in the string
364364
"""
365-
commands = []
365+
commands: list[str] = []
366366

367367
# shlex doesn't treat ; as a separator, so we need to pre-process
368368
# (re is already imported at module level)
@@ -382,7 +382,7 @@ def extract_commands(command_string: str) -> list[str]:
382382
# Malformed command (unclosed quotes, etc.)
383383
# Security: Only use fallback if segment contains no chaining operators
384384
# This prevents allowlist bypass via malformed commands hiding chained operators
385-
if re.search(r'\s*(\|\||&&|\||&)\s*', segment):
385+
if re.search(r'\|\||&&|\||&', segment):
386386
# Segment has operators but shlex failed - refuse to parse for safety
387387
continue
388388

@@ -483,7 +483,7 @@ def validate_pkill_command(
483483
return False, "Empty pkill command"
484484

485485
# Separate flags from arguments
486-
args = []
486+
args: list[str] = []
487487
for token in tokens[1:]:
488488
if not token.startswith("-"):
489489
args.append(token)
@@ -493,14 +493,14 @@ def validate_pkill_command(
493493

494494
# Validate every non-flag argument (pkill accepts multiple patterns on BSD)
495495
# This defensively ensures no disallowed process can be targeted
496-
targets = []
496+
targets: list[str] = []
497497
for arg in args:
498498
# For -f flag (full command line match), take the first word as process name
499499
# e.g., "pkill -f 'node server.js'" -> target is "node server.js", process is "node"
500-
t = arg.split()[0] if " " in arg else arg
500+
t: str = arg.split()[0] if " " in arg else arg
501501
targets.append(t)
502502

503-
disallowed = [t for t in targets if t not in allowed_process_names]
503+
disallowed: list[str] = [t for t in targets if t not in allowed_process_names]
504504
if not disallowed:
505505
return True, ""
506506
return False, f"pkill only allowed for processes: {sorted(allowed_process_names)}"
@@ -524,7 +524,7 @@ def validate_chmod_command(command_string: str) -> tuple[bool, str]:
524524
# Look for the mode argument
525525
# Valid modes: +x, u+x, a+x, etc. (anything ending with +x for execute permission)
526526
mode = None
527-
files = []
527+
files: list[str] = []
528528

529529
for token in tokens[1:]:
530530
if token.startswith("-"):
@@ -645,7 +645,7 @@ def get_org_config_path() -> Path:
645645
return Path.home() / ".autocoder" / "config.yaml"
646646

647647

648-
def load_org_config() -> Optional[dict]:
648+
def load_org_config() -> Optional[dict[str, Any]]:
649649
"""
650650
Load organization-level config from ~/.autocoder/config.yaml.
651651
@@ -676,10 +676,11 @@ def load_org_config() -> Optional[dict]:
676676

677677
# Validate allowed_commands if present
678678
if "allowed_commands" in config:
679-
allowed = config["allowed_commands"]
680-
if not isinstance(allowed, list):
679+
allowed_raw = cast(Any, config["allowed_commands"])
680+
if not isinstance(allowed_raw, list):
681681
logger.warning(f"Org config at {config_path}: 'allowed_commands' must be a list")
682682
return None
683+
allowed = cast(list[dict[str, Any]], allowed_raw)
683684
for i, cmd in enumerate(allowed):
684685
if not isinstance(cmd, dict):
685686
logger.warning(f"Org config at {config_path}: allowed_commands[{i}] must be a dict")
@@ -694,23 +695,25 @@ def load_org_config() -> Optional[dict]:
694695

695696
# Validate blocked_commands if present
696697
if "blocked_commands" in config:
697-
blocked = config["blocked_commands"]
698-
if not isinstance(blocked, list):
698+
blocked_raw = cast(Any, config["blocked_commands"])
699+
if not isinstance(blocked_raw, list):
699700
logger.warning(f"Org config at {config_path}: 'blocked_commands' must be a list")
700701
return None
702+
blocked = cast(list[str], blocked_raw)
701703
for i, cmd in enumerate(blocked):
702704
if not isinstance(cmd, str):
703705
logger.warning(f"Org config at {config_path}: blocked_commands[{i}] must be a string")
704706
return None
705707

706708
# Validate pkill_processes if present
707709
if "pkill_processes" in config:
708-
processes = config["pkill_processes"]
709-
if not isinstance(processes, list):
710+
processes_raw = cast(Any, config["pkill_processes"])
711+
if not isinstance(processes_raw, list):
710712
logger.warning(f"Org config at {config_path}: 'pkill_processes' must be a list")
711713
return None
714+
processes = cast(list[Any], processes_raw)
712715
# Normalize and validate each process name against safe pattern
713-
normalized = []
716+
normalized: list[str] = []
714717
for i, proc in enumerate(processes):
715718
if not isinstance(proc, str):
716719
logger.warning(f"Org config at {config_path}: pkill_processes[{i}] must be a string")
@@ -723,7 +726,7 @@ def load_org_config() -> Optional[dict]:
723726
normalized.append(proc)
724727
config["pkill_processes"] = normalized
725728

726-
return config
729+
return cast(dict[str, Any], config)
727730

728731
except yaml.YAMLError as e:
729732
logger.warning(f"Failed to parse org config at {config_path}: {e}")
@@ -733,7 +736,7 @@ def load_org_config() -> Optional[dict]:
733736
return None
734737

735738

736-
def load_project_commands(project_dir: Path) -> Optional[dict]:
739+
def load_project_commands(project_dir: Path) -> Optional[dict[str, Any]]:
737740
"""
738741
Load allowed commands from project-specific YAML config.
739742
@@ -765,10 +768,11 @@ def load_project_commands(project_dir: Path) -> Optional[dict]:
765768
logger.warning(f"Project config at {config_path} missing required 'version' field")
766769
return None
767770

768-
commands = config.get("commands", [])
769-
if not isinstance(commands, list):
771+
commands_raw = cast(Any, config["commands"] if "commands" in config else [])
772+
if not isinstance(commands_raw, list):
770773
logger.warning(f"Project config at {config_path}: 'commands' must be a list")
771774
return None
775+
commands = cast(list[dict[str, Any]], commands_raw)
772776

773777
# Enforce 100 command limit
774778
if len(commands) > 100:
@@ -790,12 +794,13 @@ def load_project_commands(project_dir: Path) -> Optional[dict]:
790794

791795
# Validate pkill_processes if present
792796
if "pkill_processes" in config:
793-
processes = config["pkill_processes"]
794-
if not isinstance(processes, list):
797+
processes_raw = cast(Any, config["pkill_processes"])
798+
if not isinstance(processes_raw, list):
795799
logger.warning(f"Project config at {config_path}: 'pkill_processes' must be a list")
796800
return None
801+
processes = cast(list[Any], processes_raw)
797802
# Normalize and validate each process name against safe pattern
798-
normalized = []
803+
normalized: list[str] = []
799804
for i, proc in enumerate(processes):
800805
if not isinstance(proc, str):
801806
logger.warning(f"Project config at {config_path}: pkill_processes[{i}] must be a string")
@@ -808,7 +813,7 @@ def load_project_commands(project_dir: Path) -> Optional[dict]:
808813
normalized.append(proc)
809814
config["pkill_processes"] = normalized
810815

811-
return config
816+
return cast(dict[str, Any], config)
812817

813818
except yaml.YAMLError as e:
814819
logger.warning(f"Failed to parse project config at {config_path}: {e}")
@@ -818,7 +823,7 @@ def load_project_commands(project_dir: Path) -> Optional[dict]:
818823
return None
819824

820825

821-
def validate_project_command(cmd_config: dict) -> tuple[bool, str]:
826+
def validate_project_command(cmd_config: dict[str, Any]) -> tuple[bool, str]:
822827
"""
823828
Validate a single command entry from project config.
824829
@@ -828,7 +833,7 @@ def validate_project_command(cmd_config: dict) -> tuple[bool, str]:
828833
Returns:
829834
Tuple of (is_valid, error_message)
830835
"""
831-
if not isinstance(cmd_config, dict):
836+
if not isinstance(cmd_config, dict): # type: ignore[misc]
832837
return False, "Command must be a dict"
833838

834839
if "name" not in cmd_config:
@@ -855,9 +860,10 @@ def validate_project_command(cmd_config: dict) -> tuple[bool, str]:
855860

856861
# Args validation (Phase 1 - just check structure)
857862
if "args" in cmd_config:
858-
args = cmd_config["args"]
859-
if not isinstance(args, list):
863+
args_raw = cmd_config["args"]
864+
if not isinstance(args_raw, list):
860865
return False, "Args must be a list"
866+
args = cast(list[str], args_raw)
861867
for arg in args:
862868
if not isinstance(arg, str):
863869
return False, "Each arg must be a string"
@@ -892,13 +898,13 @@ def get_effective_commands(project_dir: Optional[Path]) -> tuple[set[str], set[s
892898
org_config = load_org_config()
893899
if org_config:
894900
# Add org-level blocked commands (cannot be overridden)
895-
org_blocked = org_config.get("blocked_commands", [])
901+
org_blocked: Any = org_config.get("blocked_commands", [])
896902
blocked |= set(org_blocked)
897903

898904
# Add org-level allowed commands
899905
for cmd_config in org_config.get("allowed_commands", []):
900906
if isinstance(cmd_config, dict) and "name" in cmd_config:
901-
allowed.add(cmd_config["name"])
907+
allowed.add(cast(str, cmd_config["name"]))
902908

903909
# Load project config and apply
904910
if project_dir:
@@ -908,7 +914,10 @@ def get_effective_commands(project_dir: Optional[Path]) -> tuple[set[str], set[s
908914
for cmd_config in project_config.get("commands", []):
909915
valid, error = validate_project_command(cmd_config)
910916
if valid:
911-
allowed.add(cmd_config["name"])
917+
allowed.add(cast(str, cmd_config["name"]))
918+
else:
919+
# Log validation error for debugging
920+
logger.debug(f"Project command validation failed: {error}")
912921

913922
# Remove blocked commands from allowed (blocklist takes precedence)
914923
allowed -= blocked
@@ -928,7 +937,8 @@ def get_project_allowed_commands(project_dir: Optional[Path]) -> set[str]:
928937
Returns:
929938
Set of allowed command names (including patterns)
930939
"""
931-
allowed, blocked = get_effective_commands(project_dir)
940+
allowed, _blocked = get_effective_commands(project_dir)
941+
# _blocked is used in get_effective_commands for precedence logic
932942
return allowed
933943

934944

@@ -953,16 +963,18 @@ def get_effective_pkill_processes(project_dir: Optional[Path]) -> set[str]:
953963
# Add org-level pkill_processes
954964
org_config = load_org_config()
955965
if org_config:
956-
org_processes = org_config.get("pkill_processes", [])
957-
if isinstance(org_processes, list):
966+
org_processes_raw = org_config.get("pkill_processes", [])
967+
if isinstance(org_processes_raw, list):
968+
org_processes = cast(list[Any], org_processes_raw)
958969
processes |= {p for p in org_processes if isinstance(p, str) and p.strip()}
959970

960971
# Add project-level pkill_processes
961972
if project_dir:
962973
project_config = load_project_commands(project_dir)
963974
if project_config:
964-
proj_processes = project_config.get("pkill_processes", [])
965-
if isinstance(proj_processes, list):
975+
proj_processes_raw = project_config.get("pkill_processes", [])
976+
if isinstance(proj_processes_raw, list):
977+
proj_processes = cast(list[Any], proj_processes_raw)
966978
processes |= {p for p in proj_processes if isinstance(p, str) and p.strip()}
967979

968980
return processes
@@ -991,7 +1003,11 @@ def is_command_allowed(command: str, allowed_commands: set[str]) -> bool:
9911003
return False
9921004

9931005

994-
async def bash_security_hook(input_data, tool_use_id=None, context=None):
1006+
async def bash_security_hook(
1007+
input_data: dict[str, Any],
1008+
tool_use_id: Optional[str] = None,
1009+
context: Optional[dict[str, Any]] = None
1010+
) -> dict[str, Any]:
9951011
"""
9961012
Pre-tool-use hook that validates bash commands using an allowlist.
9971013
@@ -1015,15 +1031,16 @@ async def bash_security_hook(input_data, tool_use_id=None, context=None):
10151031
if input_data.get("tool_name") != "Bash":
10161032
return {}
10171033

1018-
command = input_data.get("tool_input", {}).get("command", "")
1034+
command_raw: Any = input_data.get("tool_input", {}).get("command", "")
1035+
command = str(command_raw) if command_raw else ""
10191036
if not command:
10201037
return {}
10211038

10221039
# Get project directory from context early (needed for denied command recording)
10231040
project_dir = None
1024-
if context and isinstance(context, dict):
1025-
project_dir_str = context.get("project_dir")
1026-
if project_dir_str:
1041+
if context and isinstance(context, dict): # type: ignore[misc]
1042+
project_dir_str: Any = context.get("project_dir")
1043+
if project_dir_str and isinstance(project_dir_str, str):
10271044
project_dir = Path(project_dir_str)
10281045

10291046
# SECURITY LAYER 1: Pre-validate for dangerous shell patterns

0 commit comments

Comments
 (0)