Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,25 @@ module = [
ignore_missing_imports = true

[tool.pytest.ini_options]
testpaths = ["tests"]
python_paths = ["src"]
addopts = "-v --tb=short"
testpaths = ["tests/unit", "tests/integration", "tests/e2e"]
pythonpath = ["src"]
addopts = "-v --tb=short -ra --strict-markers -W default"
minversion = "3.8"
filterwarnings = [
"ignore::DeprecationWarning",
"ignore::PendingDeprecationWarning",
]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"integration: marks tests as integration tests",
"unit: Fast unit tests (no external dependencies)",
"integration: Integration tests (may be slower, test multiple components)",
"e2e: End-to-end tests (require full environment, Docker, may be very slow)",
"slow: Slow tests (can be skipped with -m \"not slow\")",
"gpu: Tests that require GPU hardware",
"amd: Tests specific to AMD GPUs",
"nvidia: Tests specific to NVIDIA GPUs",
"cpu: Tests for CPU-only execution",
"requires_docker: Tests that require Docker daemon",
"requires_models: Tests that require model fixtures",
]

[tool.coverage.run]
Expand Down
85 changes: 0 additions & 85 deletions pytest.ini

This file was deleted.

6 changes: 6 additions & 0 deletions src/madengine/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,12 @@ def get_gpu_renderD_nodes(self) -> typing.Optional[typing.List[int]]:
print(f"Warning: Failed to parse unique_id from line '{item}': {e}")
continue

if kfd_renderDs is None:
raise RuntimeError(
"KFD topology not accessible and required for ROCm < 6.4.1 GPU mapping. "
"Check permissions on /sys/devices/virtual/kfd/kfd/topology/nodes"
)

if len(kfd_unique_ids) != len(kfd_renderDs):
raise RuntimeError(
f"Mismatch between unique_ids count ({len(kfd_unique_ids)}) "
Expand Down
9 changes: 5 additions & 4 deletions src/madengine/core/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,11 @@ def __init__(
# add mounts
if mounts is not None:
for mount in mounts:
command += "-v " + mount + ":" + mount + " "
quoted_mount = shlex.quote(mount)
command += "-v " + quoted_mount + ":" + quoted_mount + " "

# add current working directory
command += "-v " + cwd + ":/myworkspace/ "
command += "-v " + shlex.quote(cwd) + ":/myworkspace/ "

# add envVars
_env_key_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
Expand All @@ -105,8 +106,8 @@ def __init__(
command += "-e " + evar + "=" + shlex.quote(str(envVars[evar])) + " "

command += "--workdir /myworkspace/ "
command += "--name " + container_name + " "
command += image + " "
command += "--name " + shlex.quote(container_name) + " "
command += shlex.quote(image) + " "

# Use 'cat' to keep container alive (blocks waiting for stdin)
# Works reliably across all deployment types (local, k8s, slurm)
Expand Down
2 changes: 1 addition & 1 deletion src/madengine/deployment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _monitor_until_complete(self, deployment_id: str) -> DeploymentResult:
while True:
status = self.monitor(deployment_id)

if status.status in [DeploymentStatus.SUCCESS, DeploymentStatus.FAILED, DeploymentStatus.UNKNOWN]:
if status.status in [DeploymentStatus.SUCCESS, DeploymentStatus.FAILED, DeploymentStatus.UNKNOWN, DeploymentStatus.CANCELLED]:
return status

# Still running, wait and check again
Expand Down
2 changes: 1 addition & 1 deletion src/madengine/deployment/presets/k8s/defaults.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"ttl_seconds_after_finished": null,
"allow_privileged_profiling": null,
"nfs_storage_class": "nfs-banff",
"local_path_storage_class": "local-path",
"storage_class": "nfs-banff",
"data_storage_class": "nfs-banff",
"recreate_shared_data_pvc": false,
"secrets": {
Expand Down
10 changes: 5 additions & 5 deletions src/madengine/execution/container_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,16 +557,16 @@ def pull_image(
print(f"🔄 Using fresh pull policy for SLURM compute node (prevents cached layer corruption)")
# Remove any existing cached image to force fresh pull
try:
self.console.sh(f"docker rmi -f {registry_image} 2>/dev/null || true")
self.console.sh(f"docker rmi -f {shlex.quote(registry_image)} 2>/dev/null || true")
print(f"✓ Removed cached image layers")
except Exception:
pass # It's okay if image doesn't exist

try:
self.console.sh(f"docker pull {registry_image}")
self.console.sh(f"docker pull {shlex.quote(registry_image)}")

if local_name:
self.console.sh(f"docker tag {registry_image} {local_name}")
self.console.sh(f"docker tag {shlex.quote(registry_image)} {shlex.quote(local_name)}")
print(f"🏷️ Tagged as: {local_name}")
self.rich_console.print(f"[bold green]✅ Successfully pulled and tagged image[/bold green]")
self.rich_console.print(f"[dim]{'='*80}[/dim]")
Expand Down Expand Up @@ -688,7 +688,7 @@ def get_mount_arg(self, mount_datapaths: typing.List) -> str:
for mount_datapath in mount_datapaths:
if mount_datapath:
mount_args += (
f"-v {mount_datapath['path']}:{mount_datapath['home']}"
f"-v {shlex.quote(mount_datapath['path'])}:{shlex.quote(mount_datapath['home'])}"
)
if (
"readwrite" in mount_datapath
Expand All @@ -702,7 +702,7 @@ def get_mount_arg(self, mount_datapaths: typing.List) -> str:
if "docker_mounts" in self.context.ctx:
for mount_arg in self.context.ctx["docker_mounts"].keys():
mount_args += (
f"-v {self.context.ctx['docker_mounts'][mount_arg]}:{mount_arg} "
f"-v {shlex.quote(self.context.ctx['docker_mounts'][mount_arg])}:{shlex.quote(mount_arg)} "
)

return mount_args
Expand Down
14 changes: 7 additions & 7 deletions src/madengine/execution/docker_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ def build_image(

build_command = (
f"docker build {use_cache_str} --network=host "
f"-t {docker_image} --pull -f {dockerfile} "
f"{build_args} {docker_context}"
f"-t {shlex.quote(docker_image)} --pull -f {shlex.quote(dockerfile)} "
f"{build_args} {shlex.quote(docker_context)}"
)
Comment thread
coketaste marked this conversation as resolved.

# Execute build with log redirection
Expand All @@ -207,7 +207,7 @@ def build_image(
base_docker = self.context.ctx["docker_build_arg"]["BASE_DOCKER"]
else:
base_docker = self.console.sh(
f"grep '^ARG BASE_DOCKER=' {dockerfile} | sed -E 's/ARG BASE_DOCKER=//g'"
f"grep '^ARG BASE_DOCKER=' {shlex.quote(dockerfile)} | sed -E 's/ARG BASE_DOCKER=//g'"
)

print(f"BASE DOCKER is {base_docker}")
Expand All @@ -216,7 +216,7 @@ def build_image(
docker_sha = ""
try:
docker_sha = self.console.sh(
f'docker manifest inspect {base_docker} | grep digest | head -n 1 | cut -d \\" -f 4'
f'docker manifest inspect {shlex.quote(base_docker)} | grep digest | head -n 1 | cut -d \\" -f 4'
)
print(f"BASE DOCKER SHA is {docker_sha}")
except Exception as e:
Expand Down Expand Up @@ -297,15 +297,15 @@ def push_image(
# Tag the image if different from local name
if registry_image != docker_image:
print(f"Tagging image: docker tag {docker_image} {registry_image}")
tag_command = f"docker tag {docker_image} {registry_image}"
tag_command = f"docker tag {shlex.quote(docker_image)} {shlex.quote(registry_image)}"
self.console.sh(tag_command)
else:
print(
f"No tag needed, docker_image and registry_image are the same: {docker_image}"
)

# Push the image
push_command = f"docker push {registry_image}"
push_command = f"docker push {shlex.quote(registry_image)}"
self.rich_console.print(f"\n[bold blue]🚀 Starting docker push to registry...[/bold blue]")
print(f"📤 Registry: {registry}")
print(f"🏷️ Image: {registry_image}")
Expand Down Expand Up @@ -559,7 +559,7 @@ def _get_dockerfiles_for_model(self, model_info: typing.Dict) -> typing.List[str
for cur_docker_file in all_dockerfiles:
# Get context of dockerfile
dockerfiles[cur_docker_file] = self.console.sh(
f"head -n5 {cur_docker_file} | grep '# CONTEXT ' | sed 's/# CONTEXT //g'"
f"head -n5 {shlex.quote(cur_docker_file)} | grep '# CONTEXT ' | sed 's/# CONTEXT //g'"
)

# Filter dockerfiles based on context
Expand Down
5 changes: 3 additions & 2 deletions src/madengine/orchestration/run_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import json
import os
import shlex
import subprocess
from pathlib import Path
from typing import Dict, Optional
Expand Down Expand Up @@ -391,12 +392,12 @@ def _create_manifest_from_local_image(

# Validate that the image exists locally or can be pulled
try:
self.console.sh(f"docker image inspect {image_name} > /dev/null 2>&1")
self.console.sh(f"docker image inspect {shlex.quote(image_name)} > /dev/null 2>&1")
self.rich_console.print(f"[green]✓ Image {image_name} found locally[/green]")
except (subprocess.CalledProcessError, RuntimeError) as e:
self.rich_console.print(f"[yellow]⚠️ Image {image_name} not found locally, attempting to pull...[/yellow]")
try:
self.console.sh(f"docker pull {image_name}")
self.console.sh(f"docker pull {shlex.quote(image_name)}")
self.rich_console.print(f"[green]✓ Successfully pulled {image_name}[/green]")
except Exception as e:
raise RuntimeError(
Expand Down
36 changes: 1 addition & 35 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,9 @@

import json
import os
import sys
import tempfile
from pathlib import Path

_SRC = Path(__file__).resolve().parents[1] / "src"
if _SRC.is_dir() and str(_SRC) not in sys.path:
sys.path.insert(0, str(_SRC))
from unittest.mock import MagicMock, patch

import pytest


Expand Down Expand Up @@ -361,35 +356,6 @@ def integration_test_env():
yield env_vars


# ============================================================================
# Pytest Configuration
# ============================================================================

def pytest_configure(config):
"""Configure pytest with custom markers."""
config.addinivalue_line(
"markers", "integration: marks tests as integration tests (may be slow)"
)
config.addinivalue_line(
"markers", "unit: marks tests as fast unit tests"
)
config.addinivalue_line(
"markers", "gpu: marks tests that require GPU hardware"
)
config.addinivalue_line(
"markers", "amd: marks tests specific to AMD GPUs"
)
config.addinivalue_line(
"markers", "nvidia: marks tests specific to NVIDIA GPUs"
)
config.addinivalue_line(
"markers", "cpu: marks tests for CPU-only execution"
)
config.addinivalue_line(
"markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')"
)


# ============================================================================
# Utility Functions for Tests
# ============================================================================
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ def test_handle_generic_error(self):
class TestGlobalErrorHandler:
"""Test global error handler functionality."""

def setup_method(self):
set_error_handler(None)

def teardown_method(self):
set_error_handler(None)

def test_set_and_get_error_handler(self):
"""Test setting and getting global error handler."""
mock_console = Mock(spec=Console)
Expand Down
Loading