Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ def _create_server_process(
"""
import os

from zndraw import config as config_module

# Reset config singleton for test isolation
config_module._config = None

port = _get_free_port()
storage_path = tmp_path / "zndraw-data"
redis_url = "redis://localhost:6379"
Expand Down Expand Up @@ -295,6 +300,9 @@ def _create_server_process(

remove_server_info(port)

# Reset config singleton again for next test
config_module._config = None


@pytest.fixture
def server(tmp_path) -> t.Generator[str, None, None]:
Expand Down Expand Up @@ -431,3 +439,55 @@ def test_my_feature(joined_room):
assert response.status_code == 200, f"Failed to join room {room}"

return server, room


@pytest.fixture
def server_provider(tmp_path, redis_client, get_free_port):
"""Server with restart capability. Use only when testing restart scenarios.

For normal tests, use `server` or `server_admin_mode` fixtures instead.

This fixture provides a ServerProvider instance that allows:
- Starting/stopping the server
- Restarting the server mid-test
- Checking server status

Example
-------
def test_extension_persistence(server_provider):
vis = ZnDraw(url=server_provider.url, room="test")
vis.register_extension(MyExt, public=True)

server_provider.restart()

# Verify extension re-registers after restart
schema = requests.get(f"{server_provider.url}/api/schema/modifiers").json()
assert "MyExt" in schema
"""
from server_provider import ServerProvider

from zndraw import config as config_module
from zndraw.server_manager import remove_server_info

# Reset config singleton for test isolation
config_module._config = None

port = get_free_port()
storage_path = tmp_path / "zndraw-provider"
storage_path.mkdir()
redis_url = "redis://localhost:6379"

provider = ServerProvider(
port=port,
storage_path=storage_path,
redis_url=redis_url,
)
provider.start()

try:
yield provider
finally:
provider.stop() # Graceful shutdown first
redis_client.flushall()
remove_server_info(port)
config_module._config = None
208 changes: 208 additions & 0 deletions tests/server_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
"""Server lifecycle management for restart scenarios in tests.

This module provides a ServerProvider class that allows starting, stopping,
and restarting ZnDraw servers mid-test. Use this ONLY when you need restart
capability. For normal tests, use the `server` or `server_admin_mode` fixtures.
"""

import os
import socket
import subprocess
import time
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class ServerProvider:
"""Manages a ZnDraw server with restart capability.

Use this ONLY when you need to restart the server mid-test.
For normal tests, use the `server` or `server_admin_mode` fixtures.

Parameters
----------
port
Server port number.
storage_path
Path to LMDB storage directory.
redis_url
Redis connection URL.
admin_user
Admin username (enables deployment mode if set with password).
admin_password
Admin password (enables deployment mode if set with username).

Example
-------
def test_extension_persistence(server_provider):
server = server_provider
vis = ZnDraw(url=server.url, room="test")
vis.register_extension(MyExt, public=True)

server.restart()

# Verify extension re-registers after restart
schema = requests.get(f"{server.url}/api/schema/modifiers").json()
assert "MyExt" in schema
"""

port: int
storage_path: Path
redis_url: str
admin_user: str | None = None
admin_password: str | None = None

# Internal state (not part of dataclass comparison)
_process: subprocess.Popen | None = field(default=None, repr=False, compare=False)

@property
def url(self) -> str:
"""Server URL."""
return f"http://127.0.0.1:{self.port}"

@property
def is_running(self) -> bool:
"""Check if server process is running and responsive."""
if self._process is None or self._process.poll() is not None:
return False
try:
import requests

resp = requests.get(f"{self.url}/health", timeout=1)
return resp.status_code == 200
except Exception:
return False

def start(self) -> str:
"""Start the server.

Returns
-------
str
Server URL.

Raises
------
RuntimeError
If server is already running or fails to start.
"""
if self._process is not None and self._process.poll() is None:
raise RuntimeError("Server already running")

env = os.environ.copy()
if self.admin_user and self.admin_password:
env["ZNDRAW_ADMIN_USERNAME"] = self.admin_user
env["ZNDRAW_ADMIN_PASSWORD"] = self.admin_password
else:
env.pop("ZNDRAW_ADMIN_USERNAME", None)
env.pop("ZNDRAW_ADMIN_PASSWORD", None)

self._process = subprocess.Popen(
[
"zndraw",
"--port",
str(self.port),
"--no-celery",
"--storage-path",
str(self.storage_path),
"--redis-url",
self.redis_url,
"--no-browser",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env,
)

if not _wait_for_server(self.port):
self.kill()
raise RuntimeError(f"Server failed to start on port {self.port}")

return self.url

def stop(self) -> None:
"""Graceful shutdown (SIGTERM)."""
if self._process is None:
return
self._process.terminate()
try:
self._process.wait(timeout=10)
except subprocess.TimeoutExpired:
self._process.kill()
self._process.wait()
self._process = None

def kill(self) -> None:
"""Force kill (SIGKILL)."""
if self._process is None:
return
self._process.kill()
self._process.wait()
self._process = None

def restart(self) -> str:
"""Stop and start server.

Returns
-------
str
Server URL.
"""
self.stop()
return self.start()

def flush_redis(self) -> None:
"""Flush all data from Redis.

Clears all keys in the Redis database, including:
- Public extensions
- Room data
- Session data
"""
import redis

client = redis.Redis.from_url(self.redis_url, decode_responses=True)
client.flushall()

def fresh_restart(self) -> str:
"""Stop server, flush Redis, and start fresh.

Use this when you need a completely clean slate with no
persisted data (extensions, rooms, etc.).

Returns
-------
str
Server URL.
"""
self.stop()
self.flush_redis()
return self.start()


def _wait_for_server(port: int, timeout: float = 30.0) -> bool:
"""Wait for server to become ready.

Parameters
----------
port
Port to check.
timeout
Maximum wait time in seconds.

Returns
-------
bool
True if server is ready, False if timeout.
"""
start = time.time()
while time.time() - start < timeout:
try:
with socket.socket() as sock:
sock.settimeout(0.1)
sock.connect(("127.0.0.1", port))
return True
except (ConnectionRefusedError, OSError):
time.sleep(0.1)
return False
99 changes: 99 additions & 0 deletions tests/test_extension_persistence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Tests for extension persistence across server restarts.

These tests verify that the ZnDraw client auto-reconnects and re-registers
extensions after a server restart.
"""

import time

import requests

from zndraw import ZnDraw
from zndraw.extensions import Category, Extension


class PersistenceTestExtension(Extension):
"""A test extension for verifying re-registration after server restart."""

category = Category.MODIFIER

def run(self, vis: ZnDraw, **kwargs):
"""Dummy run method."""
pass


def test_extension_reregisters_after_restart(server_provider):
"""Test that client auto-reconnects and re-registers extensions after restart.

1. Register extension (client stays connected)
2. Restart server (Redis persists - session remains valid)
3. Client should auto-reconnect and re-register extension
"""
vis = ZnDraw(url=server_provider.url, room="test-room", user="worker")
vis.register_extension(PersistenceTestExtension, public=True)
time.sleep(0.5)

# Verify registered
response = requests.get(
f"{server_provider.url}/api/rooms/test-room/schema/modifiers"
)
assert "PersistenceTestExtension" in [ext["name"] for ext in response.json()]

# Normal restart (Redis persists, session valid) - client should auto-reconnect
server_provider.restart()

# Wait for client to auto-reconnect
for _ in range(20): # Up to 10 seconds
time.sleep(0.5)
if vis.socket.connected:
break
else:
raise AssertionError("Client did not reconnect within 10 seconds")

time.sleep(1.0)

# Extension should be available (persisted in Redis + re-registered by client)
response = requests.get(
f"{server_provider.url}/api/rooms/test-room/schema/modifiers"
)
assert response.status_code == 200
extension_names = [ext["name"] for ext in response.json()]
assert "PersistenceTestExtension" in extension_names, (
f"Extension not available after restart. Available: {extension_names}"
)


def test_fresh_restart_invalidates_session(server_provider):
"""Test that fresh_restart invalidates client session (Redis is flushed).

When Redis is flushed, the client's session is invalidated and it cannot
reconnect with its old credentials. This is expected behavior.
"""
vis = ZnDraw(url=server_provider.url, room="test-room", user="worker")
vis.register_extension(PersistenceTestExtension, public=True)
time.sleep(0.5)

# Verify registered
response = requests.get(
f"{server_provider.url}/api/rooms/test-room/schema/modifiers"
)
assert "PersistenceTestExtension" in [ext["name"] for ext in response.json()]

# Fresh restart flushes Redis - session becomes invalid
server_provider.fresh_restart()

# Client cannot reconnect (session invalidated)
time.sleep(2.0)
assert vis.socket.connected is False, (
"Client should NOT reconnect after fresh restart"
)

# Extension is gone (Redis was flushed)
response = requests.get(
f"{server_provider.url}/api/rooms/test-room/schema/modifiers"
)
assert response.status_code == 200
extension_names = [ext["name"] for ext in response.json()]
assert "PersistenceTestExtension" not in extension_names, (
"Extension should be cleared after fresh restart"
)
Loading
Loading