diff --git a/.gitignore b/.gitignore index 755d3df..b383ded 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,7 @@ hola-py/benchmark_results/ CLAUDE.md .claude/ paper/ +/issues/ # OS .DS_Store diff --git a/Cargo.lock b/Cargo.lock index a1745f5..952749b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -112,7 +112,6 @@ dependencies = [ "tower", "tower-layer", "tower-service", - "tracing", ] [[package]] @@ -131,7 +130,6 @@ dependencies = [ "sync_wrapper", "tower-layer", "tower-service", - "tracing", ] [[package]] @@ -841,15 +839,6 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" -[[package]] -name = "lock_api" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" -dependencies = [ - "scopeguard", -] - [[package]] name = "log" version = "0.4.29" @@ -1034,29 +1023,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "parking_lot" -version = "0.12.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-link", -] - [[package]] name = "paste" version = "1.0.15" @@ -1287,15 +1253,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" -[[package]] -name = "redox_syscall" -version = "0.5.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" -dependencies = [ - "bitflags", -] - [[package]] name = "reqwest" version = "0.12.28" @@ -1423,12 +1380,6 @@ dependencies = [ "bytemuck", ] -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - [[package]] name = "serde" version = "1.0.228" @@ -1514,16 +1465,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" -[[package]] -name = "signal-hook-registry" -version = "1.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" -dependencies = [ - "errno", - "libc", -] - [[package]] name = "simba" version = "0.9.1" @@ -1687,9 +1628,7 @@ dependencies = [ "bytes", "libc", "mio", - "parking_lot", "pin-project-lite", - "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.61.2", @@ -1754,7 +1693,6 @@ dependencies = [ "tokio", "tower-layer", "tower-service", - "tracing", ] [[package]] @@ -1803,7 +1741,6 @@ version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ - "log", "pin-project-lite", "tracing-core", ] diff --git a/dashboard/app.js b/dashboard/app.js index 9a0f343..b0b4a8e 100644 --- a/dashboard/app.js +++ b/dashboard/app.js @@ -33,6 +33,26 @@ const S = { // ============================================================================ // Connection // ============================================================================ +function apiToken() { + const urlToken = new URLSearchParams(window.location.search).get('token'); + if (urlToken) { + localStorage.setItem('hola_api_token', urlToken); + return urlToken; + } + return 
localStorage.getItem('hola_api_token') || ''; +} + +function apiFetch(url, options = {}) { + const headers = new Headers(options.headers || {}); + const token = apiToken(); + if (token) headers.set('Authorization', `Bearer ${token}`); + return fetch(url, { ...options, headers }); +} + +function clearElement(el) { + el.replaceChildren(); +} + async function connectToServer() { const url = document.getElementById('server-url').value.trim().replace(/\/+$/, '') || 'http://localhost:8000'; @@ -74,9 +94,9 @@ function startSSE() { S.sse.onmessage = async (e) => { const event = JSON.parse(e.data); if (event.type === 'TrialCompleted') { - // Re-fetch trials - const resp = await fetch(`${S.serverUrl}/api/trials?sorted_by=index&include_infeasible=true`); - S.trials = await resp.json(); + const trial = event.trial || await fetchCompletedTrial(event.trial_id); + if (!trial) return; + upsertTrial(trial); discoverMetrics(); S.lastTrialTime = Date.now(); if (S.previewActive) previewObjectives(); else renderAll(); @@ -84,6 +104,18 @@ function startSSE() { }; } +async function fetchCompletedTrial(trialId) { + const resp = await fetch(`${S.serverUrl}/api/trial/${trialId}?include_infeasible=true`); + if (!resp.ok) return null; + return resp.json(); +} + +function upsertTrial(trial) { + const idx = S.trials.findIndex(t => t.trial_id === trial.trial_id); + if (idx >= 0) S.trials[idx] = trial; + else S.trials.push(trial); +} + function loadCheckpointFile(event) { const file = event.target.files[0]; if (!file) return; @@ -253,7 +285,7 @@ function renderConvergence() { const h = 280; // Clear any previous uPlot DOM content so it doesn't accumulate - container.innerHTML = ''; + clearElement(container); const xs = S.trials.map((_, i) => i); // Convert NaN to null so uPlot treats them as gaps instead of broken values @@ -339,11 +371,11 @@ function renderParetoDropdowns() { const allFields = [...S.metricNames]; const prevX = xSel.value; const prevY = ySel.value; - xSel.innerHTML = ''; - 
ySel.innerHTML = ''; + clearElement(xSel); + clearElement(ySel); for (const f of allFields) { - xSel.innerHTML += ``; - ySel.innerHTML += ``; + xSel.add(new Option(f, f, false, f === prevX)); + ySel.add(new Option(f, f, false, f === prevY)); } if (!prevX && allFields.length >= 2) { xSel.value = allFields[0]; @@ -517,13 +549,22 @@ function attachParetoTooltip() { } if (nearest && nearestDist <= 10) { - tooltip.innerHTML = - `Trial ${nearest.trial.trial_id}
` + - `${nearest.xField}: ${fmtCell(nearest.xVal)}
` + - `${nearest.yField}: ${fmtCell(nearest.yVal)}
` + - (nearest.onFront - ? 'Pareto front' - : 'Dominated'); + clearElement(tooltip); + + const title = document.createElement('strong'); + title.textContent = `Trial ${fmtCell(nearest.trial.trial_id)}`; + + const xLine = document.createElement('div'); + xLine.textContent = `${nearest.xField}: ${fmtCell(nearest.xVal)}`; + + const yLine = document.createElement('div'); + yLine.textContent = `${nearest.yField}: ${fmtCell(nearest.yVal)}`; + + const status = document.createElement('span'); + status.textContent = nearest.onFront ? 'Pareto front' : 'Dominated'; + status.style.color = nearest.onFront ? 'var(--accent-bright)' : 'var(--text-2)'; + + tooltip.append(title, xLine, yLine, status); // Position tooltip near cursor but keep it inside the container const container = canvas.parentElement; const cw = container.clientWidth; @@ -682,10 +723,14 @@ function renderTable() { // Build columns const cols = ['trial_id', 'rank', ...S.paramNames, ...S.metricNames]; - thead.innerHTML = cols.map(c => { - const cls = S.sortCol === c ? (S.sortAsc ? 'sorted-asc' : 'sorted-desc') : ''; - return `${c}`; - }).join(''); + clearElement(thead); + for (const c of cols) { + const th = document.createElement('th'); + th.textContent = c; + if (S.sortCol === c) th.className = S.sortAsc ? 'sorted-asc' : 'sorted-desc'; + th.addEventListener('click', () => sortTable(c)); + thead.appendChild(th); + } // Sort trials let sorted = [...S.trials]; @@ -697,12 +742,18 @@ function renderTable() { }); } - tbody.innerHTML = sorted.map(t => { + clearElement(tbody); + for (const t of sorted) { const isBest = t.trial_id === (S.bestIdx >= 0 ? 
S.trials[S.bestIdx].trial_id : -1); - return `` + - cols.map(c => `${fmtCell(getCellValue(t, c))}`).join('') + - ''; - }).join(''); + const tr = document.createElement('tr'); + if (isBest) tr.className = 'best-row'; + for (const c of cols) { + const td = document.createElement('td'); + td.textContent = fmtCell(getCellValue(t, c)); + tr.appendChild(td); + } + tbody.appendChild(tr); + } } function getCellValue(trial, col) { @@ -733,33 +784,84 @@ function sortTable(col) { // ============================================================================ function renderObjectives() { const container = document.getElementById('objectives-list'); + clearElement(container); if (S.objectives.length === 0) { - container.innerHTML = '
No objectives configured
'; + const empty = document.createElement('div'); + empty.style.color = 'var(--text-2)'; + empty.style.fontSize = '0.82rem'; + empty.style.padding = '8px 0'; + empty.textContent = 'No objectives configured'; + container.appendChild(empty); return; } - container.innerHTML = S.objectives.map((obj, i) => ` -
- ${obj.field} - ${obj.obj_type || obj.type || 'minimize'} - - - - -
- `).join(''); + S.objectives.forEach((obj, i) => { + const row = document.createElement('div'); + row.className = 'objective-row'; + + const field = document.createElement('span'); + field.className = 'obj-field'; + field.textContent = obj.field ?? ''; + + const type = document.createElement('span'); + type.className = 'obj-type'; + type.textContent = obj.obj_type || obj.type || 'minimize'; + + const priorityLabel = makeObjectiveLabel('Priority'); + const priority = document.createElement('input'); + priority.type = 'range'; + priority.min = '0'; + priority.max = '5'; + priority.step = '0.1'; + priority.value = obj.priority ?? 1; + const priorityValue = document.createElement('span'); + priorityValue.className = 'obj-priority-value'; + priorityValue.textContent = priority.value; + priority.addEventListener('input', () => { + S.objectives[i].priority = parseFloat(priority.value); + priorityValue.textContent = priority.value; + }); + priorityLabel.append(priority, priorityValue); + + const targetLabel = makeObjectiveLabel('Target'); + const target = document.createElement('input'); + target.type = 'number'; + target.step = 'any'; + target.value = obj.target ?? ''; + target.addEventListener('change', () => { + S.objectives[i].target = target.value ? parseFloat(target.value) : null; + }); + targetLabel.appendChild(target); + + const limitLabel = makeObjectiveLabel('Limit'); + const limit = document.createElement('input'); + limit.type = 'number'; + limit.step = 'any'; + limit.value = obj.limit ?? ''; + limit.addEventListener('change', () => { + S.objectives[i].limit = limit.value ? parseFloat(limit.value) : null; + }); + limitLabel.appendChild(limit); + + const groupLabel = makeObjectiveLabel('Group'); + const group = document.createElement('input'); + group.type = 'text'; + group.className = 'obj-group-input'; + group.value = obj.group ?? 
''; + group.addEventListener('change', () => { + S.objectives[i].group = group.value || null; + }); + groupLabel.appendChild(group); + + row.append(field, type, priorityLabel, targetLabel, limitLabel, groupLabel); + container.appendChild(row); + }); +} + +function makeObjectiveLabel(text) { + const label = document.createElement('label'); + label.className = 'objective-label'; + label.append(document.createTextNode(text)); + return label; } // Client-side TLP rescalarization for preview mode. @@ -813,7 +915,7 @@ async function applyObjectives() { if (S.mode !== 'live') return; if (!confirm('This will update the server objectives and rescalarize all trials. The server will use these objectives for future sampling. Continue?')) return; try { - const resp = await fetch(`${S.serverUrl}/api/objectives`, { + const resp = await apiFetch(`${S.serverUrl}/api/objectives`, { method: 'PATCH', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ objectives: S.objectives }), @@ -837,13 +939,16 @@ async function applyObjectives() { async function saveCheckpoint() { if (S.mode !== 'live') return; try { - const resp = await fetch(`${S.serverUrl}/api/checkpoint/save`, { + const resp = await apiFetch(`${S.serverUrl}/api/checkpoint/save`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ description: `Dashboard save at ${new Date().toISOString()}` }), }); const data = await resp.json(); - if (resp.ok) alert(`Checkpoint saved: ${data.path} (${data.trials_saved} trials)`); + if (resp.ok) { + const kind = data.checkpoint_type ? 
`${data.checkpoint_type} checkpoint` : 'checkpoint'; + alert(`Saved ${kind}: ${data.path} (${data.trials_saved} trials)`); + } else alert('Save failed: ' + (data.error || 'unknown')); } catch (e) { alert('Save failed: ' + e.message); diff --git a/dashboard/styles.css b/dashboard/styles.css index edbc259..6badb94 100644 --- a/dashboard/styles.css +++ b/dashboard/styles.css @@ -309,6 +309,26 @@ tr.best-row { background: var(--bg-3); border-radius: 4px; } +.objective-label { + font-size: 0.75rem; + color: var(--text-2); +} +.obj-priority-value { + font-family: var(--mono); + color: var(--text-1); + min-width: 30px; + display: inline-block; +} +.obj-group-input { + width: 80px; + padding: 4px 8px; + background: var(--bg-2); + border: 1px solid var(--border); + border-radius: 4px; + color: var(--text-0); + font-family: var(--mono); + font-size: 0.78rem; +} input[type="range"] { -webkit-appearance: none; width: 120px; diff --git a/dashboard/xss-smoke-checkpoint.json b/dashboard/xss-smoke-checkpoint.json new file mode 100644 index 0000000..52a8a80 --- /dev/null +++ b/dashboard/xss-smoke-checkpoint.json @@ -0,0 +1,22 @@ +{ + "leaderboard": { + "trials": [ + { + "trial_id": 0, + "candidate": { + "": "" + }, + "raw_metrics": { + "": "", + "loss": 0.5 + }, + "score_vector": { + "loss": 0.5 + }, + "rank": 0, + "pareto_front": 0, + "timestamp": 0 + } + ] + } +} diff --git a/docs/cli-guide.md b/docs/cli-guide.md index a8f9d0a..7f85768 100644 --- a/docs/cli-guide.md +++ b/docs/cli-guide.md @@ -172,7 +172,7 @@ strategy: type: gmm # "gmm" (default), "sobol", or "random" refit_interval: 20 # how often GMM refits (used by "gmm") seed: 42 # optional seed for reproducible runs - exploration_budget: 50 # number of Sobol trials before switching to GMM + exploration_budget: 50 # number of Sobol asks before switching to GMM elite_fraction: 0.25 # fraction of top trials used for GMM fitting (default: 0.25) ``` @@ -181,7 +181,7 @@ strategy: | `type` | `"gmm"` | Strategy type: `"gmm"`, 
`"sobol"`, or `"random"` | | `refit_interval` | `20` | How often the GMM refits (only used by `"gmm"`) | | `seed` | none | Seed for reproducible runs. When omitted, Sobol uses 42, others use random seeds. | -| `exploration_budget` | none | Number of Sobol exploration trials before switching to GMM exploitation. When omitted, we use a formula based on `total_budget`. | +| `exploration_budget` | none | Number of issued Sobol exploration suggestions before switching to GMM exploitation. Pending asks count against this budget. When omitted, we use a formula based on `total_budget`. | | `elite_fraction` | `0.25` | Fraction of top trials used for GMM refitting. Must be in (0.0, 1.0]. | ### Checkpoint Configuration @@ -203,11 +203,16 @@ hola serve config.yaml --port 8000 | Flag | Default | Description | |------|---------|-------------| | `config` | required | Path to the YAML configuration file | +| `--host` | `127.0.0.1` | Host/interface to bind. Use `0.0.0.0` explicitly for network access | | `--port` | `8000` | Port to listen on | | `--dashboard` | none | Path to a dashboard directory to serve at `/` (e.g. `--dashboard ./dashboard`) | +| `--auth-token` | none | Bearer token required for write-capable API endpoints | +| `--checkpoint-dir` | checkpoint config directory or config file directory | Directory where dashboard/API checkpoint saves are allowed | +| `--cors-origin` | none | Allowed browser CORS origin. Repeat for multiple origins | -The server starts listening on `0.0.0.0:<port>` and exposes -the [REST API](rest-api.md). +The server starts listening on `127.0.0.1:<port>` by default and exposes +the [REST API](rest-api.md). Binding a non-local host requires `--auth-token` +or the `HOLA_API_TOKEN` environment variable.
## Running Workers @@ -220,6 +225,7 @@ hola worker --server http://localhost:8000 --exec "python train.py" | `--server` | required | URL of the HOLA server | | `--exec` | required | Shell command to execute for each trial | | `--mode` | `callback` | Worker mode: `"callback"` or `"exec"` | +| `--token` | none | Bearer token for servers started with `--auth-token` | ### Callback mode (default) @@ -385,13 +391,13 @@ others. **Machine A (server):** ```bash -hola serve config.yaml --port 8000 +hola serve config.yaml --host 0.0.0.0 --port 8000 --auth-token "$HOLA_API_TOKEN" ``` **Machines B, C, D (workers):** ```bash -hola worker --server http://machine-a:8000 --exec "python train.py" +hola worker --server http://machine-a:8000 --token "$HOLA_API_TOKEN" --exec "python train.py" ``` Each worker independently polls the server for trials. The @@ -447,7 +453,9 @@ checkpoint: ``` This saves a checkpoint every 50 completed trials, keeping -the 5 most recent. +the 5 most recent. Automatic checkpoints are leaderboard +checkpoints: they preserve completed trials and can be used to +warm-start a restarted server. ### Manual checkpointing @@ -460,6 +468,8 @@ curl -X POST http://localhost:8000/api/checkpoint/save \ ``` Or from the dashboard's Checkpoints panel. +Manual REST and dashboard saves write full checkpoints with +completed trials, strategy state, and study configuration. ### Resuming from a checkpoint @@ -473,8 +483,13 @@ checkpoint: load_from: ./checkpoints/checkpoint_000100.json ``` -On startup, the server loads the leaderboard (trial history) -from the specified checkpoint file. We then refit the strategy -from the loaded data, so optimization resumes with full -knowledge of previous trials. This is useful for continuing a -study after a server restart or crash. +On startup, the server loads the specified checkpoint file. Full +checkpoints restore both the leaderboard and search strategy state. 
+Legacy leaderboard-only checkpoints are still accepted as a +warm-start path; they restore completed trials but do not restore +strategy state. + +Checkpoint loads intentionally clear any pending or cancelled +in-flight trials from the engine state. Restored studies resume from +the completed trial history, and the next `ask` receives a fresh ID +after the restored completed trials. diff --git a/docs/concepts.md b/docs/concepts.md index e7e4191..233252c 100644 --- a/docs/concepts.md +++ b/docs/concepts.md @@ -115,6 +115,10 @@ The lifecycle follows three phases. 3. **Exploit.** New samples are drawn from the updated GMM, focusing on promising regions. +The exploration budget counts issued suggestions from `ask`, +including suggestions that are still pending. GMM refits are based +on completed trials in the leaderboard. + This strategy works well for larger budgets (50+ trials) where you want to transition from exploration to exploitation. The more trials you run, the more the GMM focuses on the best regions. @@ -233,13 +237,14 @@ Each trial record contains the following fields. ## Persistence -We support atomic JSON checkpoints that capture the leaderboard -state. +We support atomic JSON checkpoints for both warm starts and exact +resumes. - **Leaderboard checkpoint.** All completed trials with params, scores, metrics, and timestamps. -- **Strategy state.** The current state of the search strategy - (e.g., Sobol sequence position, GMM parameters). +- **Full checkpoint.** A leaderboard checkpoint plus the current + search strategy state (e.g., Sobol sequence position, GMM + parameters) and study configuration. Checkpoints enable the following. @@ -247,5 +252,11 @@ Checkpoints enable the following. - Offline analysis in the dashboard - Carrying over a leaderboard to a new engine (warm-start) +Loading a checkpoint replaces the in-memory pending and cancelled +sets. 
Pending or cancelled in-flight trials from the pre-load engine +state do not survive the load; the restored engine resumes from +completed trials and issues fresh trial IDs after the restored +leaderboard. + We write checkpoint files atomically (first to a temp file, then rename) to prevent corruption. diff --git a/docs/python-guide.md b/docs/python-guide.md index 0db7c94..c351290 100644 --- a/docs/python-guide.md +++ b/docs/python-guide.md @@ -423,7 +423,8 @@ Study(strategy=Gmm(refit_interval=10, elite_fraction=0.1), ...) Gaussian Mixture Model strategy. Uses Sobol exploration followed by GMM exploitation. Refits a GMM to the top `elite_fraction` (default 25%) of trials every `refit_interval` (default 20) -completed trials. Uses the +completed trials. The exploration budget counts issued `ask` +suggestions, including pending trials. Uses the [HOLA algorithm](concepts.md#gmm-strategy). - Best for larger budgets (50+ trials) where exploration can @@ -444,7 +445,7 @@ Study(strategy=Gmm(refit_interval=10, elite_fraction=0.1), ...) |-----------|------|---------|-------------| | `refit_interval` | `int` or `None` | 20 | How often the GMM is refit, in completed trials | | `elite_fraction` | `float` or `None` | 0.25 | Fraction of top trials used for refitting. Must be in (0, 1]. | -| `exploration_budget` | `int` or `None` | auto | Number of Sobol exploration trials before GMM exploitation begins. When omitted, computed automatically from the number of dimensions. | +| `exploration_budget` | `int` or `None` | auto | Number of issued Sobol exploration suggestions before GMM exploitation begins. Pending asks count against this budget. When omitted, computed automatically from the number of dimensions. | ### Sobol diff --git a/docs/rest-api.md b/docs/rest-api.md index ace3bcc..9336a9a 100644 --- a/docs/rest-api.md +++ b/docs/rest-api.md @@ -17,6 +17,22 @@ http://localhost:8000 The port is configurable via `hola serve --port `. 
+## Authentication + +By default, a local HOLA server does not require authentication. +When the server is started with `--auth-token <token>`, all +write-capable endpoints require this header: + +```http +Authorization: Bearer <token> +``` + +This applies to `POST /api/ask`, `POST /api/tell`, +`POST /api/cancel`, `PATCH /api/objectives`, and +`POST /api/checkpoint/save`. Read-only endpoints remain +available without a token. The CLI requires an auth token when +binding the server to a non-local host. + ## Error Format All error responses return a JSON object with an `error` @@ -89,7 +105,25 @@ Report the result of a trial. ```json { "status": "ok", - "trial_count": 1 + "trial_count": 1, + "trial": { + "trial_id": 0, + "params": { + "learning_rate": 0.00316, + "num_layers": 5, + "optimizer": "adam", + "momentum": 0.85 + }, + "score_vector": {"loss": 0.42}, + "scores": {"loss": 0.42}, + "metrics": { + "loss": 0.42, + "latency": 120.5 + }, + "rank": 0, + "pareto_front": 0, + "completed_at": 1736935800 + } } ``` @@ -97,9 +131,13 @@ Report the result of a trial. |-------|------|-------------| | `status` | string | `"ok"` on success | | `trial_count` | integer | Total number of completed trials after this tell | +| `trial` | object | Newly completed trial created by this `tell` | -**Error (400).** Returned if the trial ID is unknown or has -already been told. +The returned `trial.trial_id` matches the `trial_id` in the +request. + +**Error (400).** Returned if the trial ID is unknown, cancelled, +or has already been told. ```json {"error": "Trial 0 has already been completed"} @@ -123,7 +161,7 @@ Get the top k trials found so far. | Parameter | Type | Default | Description | |-----------|------|---------|-------------| -| `k` | integer | 1 | Number of top trials to return | +| `k` | integer | required | Number of top trials to return | **Response (200)** @@ -243,11 +281,8 @@ Each element in the array has the following fields.
| `pareto_front` | integer | 0-indexed Pareto front index | | `completed_at` | integer | Unix timestamp in seconds | -**Error (400).** Returned for single-group (scalar) studies. - -```json -"pareto_front() is only available for multi-objective studies (objectives with distinct groups)" -``` +For scalar studies, this endpoint returns an empty array because +there are no Pareto fronts to report. **Example** @@ -271,39 +306,32 @@ Get all completed trials. **Response (200)** ```json -{ - "trials": [ - { - "trial_id": 0, - "params": {"learning_rate": 0.01, "num_layers": 3}, - "score_vector": {"loss": 0.85}, - "scores": {"loss": 0.85}, - "metrics": {"loss": 0.85, "latency": 50.2}, - "rank": 1, - "pareto_front": 0, - "completed_at": 1736935800 - }, - { - "trial_id": 1, - "params": {"learning_rate": 0.001, "num_layers": 7}, - "score_vector": {"loss": 0.42}, - "scores": {"loss": 0.42}, - "metrics": {"loss": 0.42, "latency": 120.5}, - "rank": 0, - "pareto_front": 0, - "completed_at": 1736935805 - } - ], - "total": 2 -} +[ + { + "trial_id": 0, + "params": {"learning_rate": 0.01, "num_layers": 3}, + "score_vector": {"loss": 0.85}, + "scores": {"loss": 0.85}, + "metrics": {"loss": 0.85, "latency": 50.2}, + "rank": 1, + "pareto_front": 0, + "completed_at": 1736935800 + }, + { + "trial_id": 1, + "params": {"learning_rate": 0.001, "num_layers": 7}, + "score_vector": {"loss": 0.42}, + "scores": {"loss": 0.42}, + "metrics": {"loss": 0.42, "latency": 120.5}, + "rank": 0, + "pareto_front": 0, + "completed_at": 1736935805 + } +] ``` -| Field | Type | Description | -|-------|------|-------------| -| `trials` | array | All completed trials | -| `total` | integer | Total number of completed trials | - -Each trial in the array has the following fields. +The response is an array of completed trials. Each trial has the +following fields. 
| Field | Type | Description | |-------|------|-------------| @@ -324,6 +352,53 @@ curl http://localhost:8000/api/trials --- +### GET /api/trial/{trial_id} + +Get one completed trial by public trial ID. + +**Path parameters** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `trial_id` | integer | Public trial ID returned by `ask` and `tell` | + +**Query parameters** + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `include_infeasible` | boolean | true | Whether to return infeasible completed trials | + +**Response (200)** + +```json +{ + "trial_id": 17, + "params": {"learning_rate": 0.001, "num_layers": 4}, + "score_vector": {"loss": 0.42}, + "scores": {"loss": 0.42}, + "metrics": {"loss": 0.42, "latency": 95.3}, + "rank": 0, + "pareto_front": 0, + "completed_at": 1736935800 +} +``` + +**Error (404).** Returned if no completed trial exists for the +given ID, or if the trial is infeasible and `include_infeasible` +is false. + +```json +{"error": "Trial 17 not found"} +``` + +**Example** + +```bash +curl http://localhost:8000/api/trial/17 +``` + +--- + ### GET /api/objectives Get the current objective configuration. @@ -458,20 +533,20 @@ curl http://localhost:8000/api/space ### POST /api/checkpoint/save -Save the current state as a JSON checkpoint file. +Save the current server state as a full JSON checkpoint file. 
**Request** ```json { - "path": "/tmp/checkpoint.json", + "path": "checkpoint.json", "description": "After 100 trials" } ``` | Field | Type | Required | Description | |-------|------|----------|-------------| -| `path` | string | no | File path for the checkpoint (default: `"checkpoint.json"`) | +| `path` | string | no | Relative path under the configured checkpoint directory (default: `"checkpoint.json"`) | | `description` | string | no | Optional description stored in the checkpoint metadata | **Response (200)** @@ -479,11 +554,21 @@ Save the current state as a JSON checkpoint file. ```json { "status": "ok", - "path": "/tmp/checkpoint.json", + "checkpoint_type": "full", + "path": "./checkpoint.json", "trials_saved": 100 } ``` +The saved file includes completed trials, strategy state, and study +configuration. The returned `path` is the resolved server-side path. + +**Error (400)** + +```json +{"error": "Checkpoint path must be relative to the configured checkpoint directory"} +``` + **Error (500)** ```json @@ -568,7 +653,21 @@ The dashboard uses this endpoint for live monitoring. **TrialCompleted.** Emitted after each successful `tell`. ```json -{"type": "TrialCompleted", "trial_id": 42, "score": 0.42} +{ + "type": "TrialCompleted", + "trial_id": 42, + "score": 0.42, + "trial": { + "trial_id": 42, + "params": {"learning_rate": 0.001}, + "score_vector": {"loss": 0.42}, + "scores": {"loss": 0.42}, + "metrics": {"loss": 0.42}, + "rank": 0, + "pareto_front": 0, + "completed_at": 1736935800 + } +} ``` **RefitOccurred.** Emitted when the GMM strategy is refit. 
@@ -603,7 +702,7 @@ TRIAL_ID=$(echo "$TRIAL" | jq '.trial_id') curl -s -X POST http://localhost:8000/api/tell \ -H "Content-Type: application/json" \ -d "{\"trial_id\": $TRIAL_ID, \"metrics\": {\"loss\": 0.42, \"latency\": 120}}" -# {"status":"ok","trial_count":1} +# {"status":"ok","trial_count":1,"trial":{"trial_id":0,...}} # Check the top trial curl -s http://localhost:8000/api/top_k?k=1 diff --git a/hola-cli/Cargo.toml b/hola-cli/Cargo.toml index a4feec8..4cdb441 100644 --- a/hola-cli/Cargo.toml +++ b/hola-cli/Cargo.toml @@ -14,5 +14,6 @@ clap = { version = "4", features = ["derive"] } serde_yaml = "0.9" serde_json = "1" serde = { version = "1", features = ["derive"] } -tokio = { version = "1", features = ["full"] } +# CLI needs the Tokio main macro, multi-threaded runtime for reqwest, and retry sleeps. +tokio = { version = "1", default-features = false, features = ["macros", "rt-multi-thread", "time"] } reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } diff --git a/hola-cli/src/main.rs b/hola-cli/src/main.rs index da6b239..c930701 100644 --- a/hola-cli/src/main.rs +++ b/hola-cli/src/main.rs @@ -13,6 +13,8 @@ use clap::{Parser, Subcommand}; use hola::hola_engine::{HolaEngine, StudyConfig}; +use hola::server::ServerOptions; +use std::net::IpAddr; use std::path::PathBuf; #[derive(Parser)] @@ -32,12 +34,24 @@ enum Commands { Serve { /// Path to the study YAML config file. config: PathBuf, + /// Host/interface to bind. Defaults to localhost; use 0.0.0.0 explicitly for network access. + #[arg(long, default_value = "127.0.0.1")] + host: String, /// Port to listen on. #[arg(long, default_value = "8000")] port: u16, /// Serve the dashboard UI from this directory. #[arg(long)] dashboard: Option, + /// Bearer token required for write-capable API endpoints. + #[arg(long)] + auth_token: Option, + /// Directory where dashboard/API checkpoint saves are allowed. + #[arg(long)] + checkpoint_dir: Option, + /// Allowed CORS origin. 
May be provided multiple times. + #[arg(long = "cors-origin")] + cors_origins: Vec, }, /// Run a worker that polls the server for trials. /// @@ -60,6 +74,9 @@ enum Commands { /// Worker mode: "callback" (default) or "exec". #[arg(long, default_value = "callback")] mode: String, + /// Bearer token for servers started with --auth-token. + #[arg(long)] + token: Option, }, } @@ -69,6 +86,29 @@ fn load_config(path: &PathBuf) -> Result Ok(config) } +fn is_local_host(host: &str) -> bool { + if host.eq_ignore_ascii_case("localhost") { + return true; + } + host.parse::().is_ok_and(|ip| ip.is_loopback()) +} + +fn configured_token(cli_token: Option) -> Option { + cli_token + .or_else(|| std::env::var("HOLA_API_TOKEN").ok()) + .filter(|token| !token.is_empty()) +} + +fn with_bearer_auth( + request: reqwest::RequestBuilder, + token: Option<&str>, +) -> reqwest::RequestBuilder { + match token { + Some(token) => request.bearer_auth(token), + None => request, + } +} + #[tokio::main] async fn main() -> Result<(), Box> { let cli = Cli::parse(); @@ -76,38 +116,85 @@ async fn main() -> Result<(), Box> { match cli.command { Commands::Serve { config, + host, port, dashboard, + auth_token, + checkpoint_dir, + cors_origins, } => { let study_config = load_config(&config)?; let load_from = study_config .checkpoint .as_ref() .and_then(|c| c.load_from.clone()); + let config_checkpoint_dir = study_config + .checkpoint + .as_ref() + .map(|checkpoint| PathBuf::from(&checkpoint.directory)); let engine = HolaEngine::from_config(study_config) .map_err(|e| format!("Failed to create engine: {e}"))?; if let Some(path) = load_from { - engine - .load_leaderboard_checkpoint(&path) + let checkpoint_kind = engine + .load_checkpoint_with_fallback(&path) .await .map_err(|e| format!("Failed to load checkpoint '{path}': {e}"))?; + eprintln!("Loaded {} checkpoint from {path}", checkpoint_kind.as_str()); + } + + let auth_token = configured_token(auth_token); + if !is_local_host(&host) && auth_token.is_none() { 
+ return Err( + "--auth-token or HOLA_API_TOKEN is required when --host is not localhost" + .into(), + ); } - hola::server::serve(engine, port, dashboard.as_deref()).await?; + let mut options = ServerOptions::new(port); + options.host = host; + options.dashboard_dir = dashboard; + options.auth_token = auth_token; + options.checkpoint_dir = checkpoint_dir + .or(config_checkpoint_dir) + .or_else(|| config.parent().map(|path| path.to_path_buf())) + .unwrap_or_else(|| PathBuf::from(".")); + options.cors_allowed_origins = cors_origins; + + hola::server::serve_with_options(engine, options).await?; } - Commands::Worker { server, exec, mode } => { + Commands::Worker { + server, + exec, + mode, + token, + } => { let exec_mode = mode == "exec"; + let token = configured_token(token); eprintln!("Worker connecting to {server} ({mode} mode)..."); eprintln!("Will execute: {exec}"); let client = reqwest::Client::new(); loop { - let resp = client.post(format!("{server}/api/ask")).send().await; + let resp = + with_bearer_auth(client.post(format!("{server}/api/ask")), token.as_deref()) + .send() + .await; match resp { Ok(r) => { + if !r.status().is_success() { + let status = r.status(); + let body = r + .text() + .await + .unwrap_or_else(|_| "unknown error".to_string()); + eprintln!("Server returned {status}: {body}. 
Retrying in 5s..."); + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + continue; + } + let trial: serde_json::Value = r.json().await?; let params = trial.get("params").cloned().unwrap_or_default(); let trial_id = trial.get("trial_id").and_then(|v| v.as_u64()).unwrap_or(0); @@ -125,8 +212,8 @@ async fn main() -> Result<(), Box> { let metrics: serde_json::Value = serde_json::from_str(stdout.trim()) .unwrap_or_else(|_| serde_json::json!({"error": "parse_failed"})); - let tell_resp = client - .post(format!("{server}/api/tell")) + let tell_resp = client.post(format!("{server}/api/tell")); + let tell_resp = with_bearer_auth(tell_resp, token.as_deref()) .json(&serde_json::json!({ "trial_id": trial_id, "metrics": metrics, @@ -145,13 +232,17 @@ async fn main() -> Result<(), Box> { } else { // Callback mode (default): script calls // POST /api/tell itself via HOLA_SERVER. - let status = std::process::Command::new("sh") + let mut command = std::process::Command::new("sh"); + command .arg("-c") .arg(&exec) .env("HOLA_SERVER", &server) .env("HOLA_TRIAL_ID", trial_id.to_string()) - .env("HOLA_PARAMS", params.to_string()) - .status()?; + .env("HOLA_PARAMS", params.to_string()); + if let Some(token) = &token { + command.env("HOLA_API_TOKEN", token); + } + let status = command.status()?; if status.success() { eprintln!("Trial {trial_id}: script exited successfully"); @@ -160,11 +251,13 @@ async fn main() -> Result<(), Box> { "Trial {trial_id}: script failed (exit {}), canceling", status.code().unwrap_or(-1) ); - let _ = client - .post(format!("{server}/api/cancel")) - .json(&serde_json::json!({"trial_id": trial_id})) - .send() - .await; + let _ = with_bearer_auth( + client.post(format!("{server}/api/cancel")), + token.as_deref(), + ) + .json(&serde_json::json!({"trial_id": trial_id})) + .send() + .await; } } } diff --git a/hola-py/Cargo.toml b/hola-py/Cargo.toml index ae94580..368c7fb 100644 --- a/hola-py/Cargo.toml +++ b/hola-py/Cargo.toml @@ -12,5 +12,6 @@ 
crate-type = ["cdylib"] hola_engine = { package = "hola", version = "1.0.1-rc5", path = "../hola", features = ["server"] } pyo3 = { version = "0.28", features = ["extension-module", "abi3-py310"] } serde_json = "1" -tokio = { version = "1", features = ["full"] } +# Python bindings construct Tokio runtimes for local engine calls and remote HTTP. +tokio = { version = "1", default-features = false, features = ["rt-multi-thread"] } reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } diff --git a/hola-py/hola_opt/__init__.pyi b/hola-py/hola_opt/__init__.pyi index dd5c2ae..08d69a7 100644 --- a/hola-py/hola_opt/__init__.pyi +++ b/hola-py/hola_opt/__init__.pyi @@ -191,7 +191,7 @@ class Study: max_trials: int | None = None, ) -> None: ... @staticmethod - def connect(url: str) -> Study: + def connect(url: str, token: str | None = None) -> Study: """Connect to an existing HOLA server.""" ... @staticmethod diff --git a/hola-py/src/lib.rs b/hola-py/src/lib.rs index 8cfe3f5..1c6ba0b 100644 --- a/hola-py/src/lib.rs +++ b/hola-py/src/lib.rs @@ -232,10 +232,10 @@ impl Gmm { exploration_budget: Option, ) -> PyResult { if let Some(ef) = elite_fraction - && (ef <= 0.0 || ef > 1.0) + && (!ef.is_finite() || ef <= 0.0 || ef > 1.0) { return Err(PyValueError::new_err( - "elite_fraction must be between 0.0 (exclusive) and 1.0 (inclusive)", + "elite_fraction must be finite and between 0.0 (exclusive) and 1.0 (inclusive)", )); } if let Some(ri) = refit_interval @@ -426,6 +426,7 @@ enum StudyInner { }, Remote { url: String, + token: Option, client: reqwest::Client, runtime: tokio::runtime::Runtime, }, @@ -478,6 +479,16 @@ fn extract_objectives(objectives: &Bound<'_, PyList>) -> PyResult, +) -> reqwest::RequestBuilder { + match token.as_deref() { + Some(token) => request.bearer_auth(token), + None => request, + } +} + #[pymethods] impl Study { #[new] @@ -568,13 +579,15 @@ impl Study { /// Connect to an existing HOLA server. 
#[staticmethod] - fn connect(url: &str) -> PyResult { + #[pyo3(signature = (url, token=None))] + fn connect(url: &str, token: Option) -> PyResult { let runtime = tokio::runtime::Runtime::new() .map_err(|e| PyValueError::new_err(format!("Failed to create runtime: {e}")))?; let client = reqwest::Client::new(); Ok(Self { inner: StudyInner::Remote { url: url.trim_end_matches('/').to_string(), + token, client, runtime, }, @@ -621,19 +634,24 @@ impl Study { } StudyInner::Remote { url, + token, client, runtime, } => { let resp: serde_json::Value = runtime .block_on(async { - client - .post(format!("{url}/api/ask")) + let resp = with_remote_auth(client.post(format!("{url}/api/ask")), token) .send() .await - .map_err(|e| format!("HTTP error: {e}"))? - .json() - .await - .map_err(|e| format!("JSON error: {e}")) + .map_err(|e| format!("HTTP error: {e}"))?; + if !resp.status().is_success() { + let body = resp + .text() + .await + .unwrap_or_else(|_| "unknown error".to_string()); + return Err(format!("Server error: {body}")); + } + resp.json().await.map_err(|e| format!("JSON error: {e}")) }) .map_err(PyValueError::new_err)?; @@ -670,15 +688,13 @@ impl Study { } StudyInner::Remote { url, + token, client, runtime, } => { - // Remote tell returns lightweight response, so we tell then - // fetch the trial's details via top_k. 
- runtime + let trial_json: serde_json::Value = runtime .block_on(async { - let resp = client - .post(format!("{url}/api/tell")) + let resp = with_remote_auth(client.post(format!("{url}/api/tell")), token) .json(&serde_json::json!({ "trial_id": trial_id, "metrics": raw, @@ -694,14 +710,27 @@ impl Study { .unwrap_or_else(|_| "unknown error".to_string()); return Err(format!("Server error: {body}")); } - Ok(()) - }) - .map_err(PyValueError::new_err)?; + let tell_body: serde_json::Value = + resp.json().await.map_err(|e| format!("JSON error: {e}"))?; + if let Some(trial) = tell_body.get("trial") { + return Ok(trial.clone()); + } - // Fetch the trial details from the server - let trials_resp: Vec = runtime - .block_on(async { - client + let trial_resp = client + .get(format!( + "{url}/api/trial/{trial_id}?include_infeasible=true" + )) + .send() + .await + .map_err(|e| format!("HTTP error: {e}"))?; + if trial_resp.status().is_success() { + return trial_resp + .json() + .await + .map_err(|e| format!("JSON error: {e}")); + } + + let trials_resp: Vec = client .get(format!( "{url}/api/trials?sorted_by=index&include_infeasible=true" )) @@ -710,22 +739,16 @@ impl Study { .map_err(|e| format!("HTTP error: {e}"))? 
.json() .await - .map_err(|e| format!("JSON error: {e}")) + .map_err(|e| format!("JSON error: {e}"))?; + trials_resp + .into_iter() + .find(|t| t.get("trial_id").and_then(|v| v.as_u64()) == Some(trial_id)) + .ok_or_else(|| format!("Trial {trial_id} not found in server response")) }) .map_err(PyValueError::new_err)?; - // Find the trial we just told - let trial_json = trials_resp - .iter() - .find(|t| t.get("trial_id").and_then(|v| v.as_u64()) == Some(trial_id)) - .ok_or_else(|| { - PyValueError::new_err(format!( - "Trial {trial_id} not found in server response" - )) - })?; - let ct: hola_engine::hola_engine::CompletedTrial = - serde_json::from_value(trial_json.clone()).map_err(|e| { + serde_json::from_value(trial_json).map_err(|e| { PyValueError::new_err(format!("Deserialization error: {e}")) })?; rust_to_py_completed(py, &ct) @@ -741,12 +764,12 @@ impl Study { .map_err(PyValueError::new_err), StudyInner::Remote { url, + token, client, runtime, } => runtime .block_on(async { - let resp = client - .post(format!("{url}/api/cancel")) + let resp = with_remote_auth(client.post(format!("{url}/api/cancel")), token) .json(&serde_json::json!({ "trial_id": trial_id })) .send() .await @@ -777,6 +800,7 @@ impl Study { url, client, runtime, + .. } => { let resp: Vec = runtime .block_on(async { @@ -814,6 +838,7 @@ impl Study { url, client, runtime, + .. } => { let resp: Vec = runtime .block_on(async { @@ -851,6 +876,7 @@ impl Study { url, client, runtime, + .. } => { let resp: Vec = runtime .block_on(async { @@ -879,6 +905,7 @@ impl Study { url, client, runtime, + .. 
} => { let resp: serde_json::Value = runtime .block_on(async { @@ -906,22 +933,22 @@ impl Study { fn update_objectives(&self, objectives: &Bound<'_, PyList>) -> PyResult<()> { let obj_configs = extract_objectives(objectives)?; match &self.inner { - StudyInner::Local { engine, runtime } => { - runtime.block_on(engine.update_objectives(obj_configs)); - Ok(()) - } + StudyInner::Local { engine, runtime } => runtime + .block_on(engine.update_objectives(obj_configs)) + .map_err(PyValueError::new_err), StudyInner::Remote { url, + token, client, runtime, } => runtime .block_on(async { - let resp = client - .patch(format!("{url}/api/objectives")) - .json(&serde_json::json!({ "objectives": obj_configs })) - .send() - .await - .map_err(|e| format!("HTTP error: {e}"))?; + let resp = + with_remote_auth(client.patch(format!("{url}/api/objectives")), token) + .json(&serde_json::json!({ "objectives": obj_configs })) + .send() + .await + .map_err(|e| format!("HTTP error: {e}"))?; if !resp.status().is_success() { let body = resp .text() diff --git a/hola-py/tests/test_cli.py b/hola-py/tests/test_cli.py index 220bf85..0f762d7 100644 --- a/hola-py/tests/test_cli.py +++ b/hola-py/tests/test_cli.py @@ -57,6 +57,7 @@ def test_serve_starts_responds(cli_binary, tmp_path): port = _find_free_port() config_path = write_yaml_config(tmp_path) url = f"http://localhost:{port}" + stderr = "" proc = subprocess.Popen( [cli_binary, "serve", str(config_path), "--port", str(port)], @@ -69,7 +70,9 @@ def test_serve_starts_responds(cli_binary, tmp_path): assert status == 200 finally: proc.terminate() - proc.wait(timeout=5) + _, stderr = proc.communicate(timeout=5) + stderr = stderr.decode("utf-8", errors="replace") + assert f"127.0.0.1:{port}" in stderr def test_serve_bad_config_exits(cli_binary): @@ -95,6 +98,36 @@ def test_serve_invalid_yaml_exits(cli_binary, tmp_path): assert result.returncode != 0 +def test_serve_invalid_strategy_config_exits_with_helpful_error(cli_binary, tmp_path): + config_path = 
write_yaml_config(tmp_path, strategy={"type": "soboll"}) + + result = subprocess.run( + [cli_binary, "serve", str(config_path)], + capture_output=True, + text=True, + timeout=10, + ) + assert result.returncode != 0 + assert "Unknown strategy type 'soboll'" in result.stderr + + +def test_serve_invalid_scale_config_exits_with_parameter_name(cli_binary, tmp_path): + config_path = write_yaml_config( + tmp_path, + space={"lr": {"type": "real", "min": 0.001, "max": 0.1, "scale": "log2"}}, + ) + + result = subprocess.run( + [cli_binary, "serve", str(config_path)], + capture_output=True, + text=True, + timeout=10, + ) + assert result.returncode != 0 + assert "Parameter 'lr'" in result.stderr + assert "unknown real scale" in result.stderr + + def test_serve_custom_port(cli_binary, tmp_path): port = _find_free_port() config_path = write_yaml_config(tmp_path) @@ -114,6 +147,20 @@ def test_serve_custom_port(cli_binary, tmp_path): proc.wait(timeout=5) +def test_serve_nonlocal_host_requires_token(cli_binary, tmp_path): + port = _find_free_port() + config_path = write_yaml_config(tmp_path) + + result = subprocess.run( + [cli_binary, "serve", str(config_path), "--host", "0.0.0.0", "--port", str(port)], + capture_output=True, + text=True, + timeout=10, + ) + assert result.returncode != 0 + assert "--auth-token" in result.stderr + + # ========================================================================== # Worker Tests # ========================================================================== diff --git a/hola-py/tests/test_dashboard_security.py b/hola-py/tests/test_dashboard_security.py new file mode 100644 index 0000000..d8469d5 --- /dev/null +++ b/hola-py/tests/test_dashboard_security.py @@ -0,0 +1,37 @@ +# Copyright 2026 BlackRock, Inc. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[2] + + +def test_dashboard_app_avoids_html_sinks_for_untrusted_data(): + app_js = (ROOT / "dashboard" / "app.js").read_text() + + forbidden = [ + "innerHTML", + "insertAdjacentHTML", + "outerHTML", + "document.write", + ] + for token in forbidden: + assert token not in app_js + + +def test_dashboard_xss_smoke_fixture_contains_html_like_values(): + fixture = json.loads((ROOT / "dashboard" / "xss-smoke-checkpoint.json").read_text()) + trial = fixture["leaderboard"]["trials"][0] + + joined = json.dumps(trial) + assert "" in joined + assert "" in joined diff --git a/hola-py/tests/test_hola.py b/hola-py/tests/test_hola.py index 9d90325..dd47a32 100644 --- a/hola-py/tests/test_hola.py +++ b/hola-py/tests/test_hola.py @@ -331,7 +331,58 @@ def test_study_categorical_mixed_space(): # ========================================================================== -# 8. Study.run() Convenience Method +# 8. 
Checkpoint Persistence +# ========================================================================== + + +def test_study_save_load_resume_uses_fresh_trial_id(tmp_path): + from hola_opt import Minimize, Real, Space, Study + + study = Study(space=Space(x=Real(0.0, 1.0)), objectives=[Minimize("loss")]) + for expected_id, loss in enumerate([0.5, 0.3]): + trial = study.ask() + assert trial.trial_id == expected_id + completed = study.tell(trial.trial_id, {"loss": loss}) + assert completed.trial_id == expected_id + + path = tmp_path / "study.json" + study.save(str(path)) + + restored = Study.load(str(path)) + trial = restored.ask() + assert trial.trial_id == 2 + completed = restored.tell(trial.trial_id, {"loss": 0.1}) + assert completed.trial_id == 2 + assert completed.params == trial.params + assert [trial.trial_id for trial in restored.trials()] == [0, 1, 2] + + +def test_study_save_load_resume_uses_fresh_vector_trial_id(tmp_path): + from hola_opt import Minimize, Real, Space, Study + + study = Study( + space=Space(x=Real(0.0, 1.0)), + objectives=[Minimize("f1", priority=1.0), Minimize("f2", priority=2.0)], + ) + for expected_id, metrics in enumerate([{"f1": 1.0, "f2": 3.0}, {"f1": 2.0, "f2": 1.0}]): + trial = study.ask() + assert trial.trial_id == expected_id + completed = study.tell(trial.trial_id, metrics) + assert completed.trial_id == expected_id + + path = tmp_path / "vector-study.json" + study.save(str(path)) + + restored = Study.load(str(path)) + trial = restored.ask() + assert trial.trial_id == 2 + completed = restored.tell(trial.trial_id, {"f1": 0.5, "f2": 2.5}) + assert completed.trial_id == 2 + assert [trial.trial_id for trial in restored.trials()] == [0, 1, 2] + + +# ========================================================================== +# 9. 
Study.run() Convenience Method # ========================================================================== diff --git a/hola-py/tests/test_server.py b/hola-py/tests/test_server.py index 2fb54a9..0bd53a8 100644 --- a/hola-py/tests/test_server.py +++ b/hola-py/tests/test_server.py @@ -19,13 +19,15 @@ updates, and Study.connect() ask/tell/top_k/connection-error. """ +import json import os -import tempfile +import socket +import subprocess import pytest -from conftest import http_json +from conftest import _wait_for_server, http_json -from hola_opt import Study +from hola_opt import Minimize, Real, Space, Study # ========================================================================== # REST API Endpoint Tests @@ -61,6 +63,8 @@ def test_ask_tell_best_flow(self): assert status == 200 assert resp["status"] == "ok" assert resp["trial_count"] == 1 + assert resp["trial"]["trial_id"] == trial_id + assert isinstance(resp["trial"]["score_vector"], dict) # Top K (replacement for /api/best) status, best = http_json(f"{self.url}/api/top_k?k=1") @@ -70,6 +74,11 @@ def test_ask_tell_best_flow(self): assert best[0]["trial_id"] == trial_id assert isinstance(best[0]["score_vector"], dict) + status, completed = http_json(f"{self.url}/api/trial/{trial_id}?include_infeasible=true") + assert status == 200 + assert completed["trial_id"] == trial_id + assert completed["metrics"]["loss"] == 0.42 + def test_trials_empty(self): status, body = http_json(f"{self.url}/api/trials?sorted_by=index&include_infeasible=true") assert status == 200 @@ -141,16 +150,136 @@ def test_checkpoint_save(self): body={"trial_id": trial["trial_id"], "metrics": {"loss": 0.5}}, ) - with tempfile.TemporaryDirectory() as td: - ckpt_path = os.path.join(td, "test_checkpoint.json") - status, body = http_json( - f"{self.url}/api/checkpoint/save", + status, body = http_json( + f"{self.url}/api/checkpoint/save", + method="POST", + body={"path": "test_checkpoint.json"}, + ) + assert status == 200 + assert body["status"] == 
"ok" + assert body["checkpoint_type"] == "full" + assert os.path.exists(body["path"]) + restored = Study.load(body["path"]) + assert restored.trial_count() == 1 + + def test_checkpoint_save_rejects_absolute_path(self): + status, body = http_json( + f"{self.url}/api/checkpoint/save", + method="POST", + body={"path": "/tmp/hola_escape.json"}, + ) + assert status == 400 + assert "relative" in body["error"] + + +def _free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def _write_sobol_server_config(tmp_path, *, load_from=None): + load_from_line = "" + if load_from is not None: + load_from_line = f" load_from: {json.dumps(str(load_from))}\n" + config_path = tmp_path / ("loaded.yaml" if load_from else "study.yaml") + config_path.write_text( + "space:\n" + " x:\n" + " type: real\n" + " min: 0.0\n" + " max: 1.0\n" + "objectives:\n" + " - field: loss\n" + " type: minimize\n" + " priority: 1.0\n" + "strategy:\n" + " type: sobol\n" + " seed: 123\n" + "checkpoint:\n" + f" directory: {json.dumps(str(tmp_path))}\n" + " interval: 50\n" + " max_checkpoints: 5\n" + f"{load_from_line}", + encoding="utf-8", + ) + return config_path + + +def _start_server(cli_binary, config_path, port): + url = f"http://localhost:{port}" + proc = subprocess.Popen( + [cli_binary, "serve", str(config_path), "--port", str(port)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + if not _wait_for_server(url): + proc.kill() + stderr = proc.stderr.read().decode() if proc.stderr else "" + pytest.fail(f"Server failed to start within timeout. 
stderr: {stderr}") + return proc, url + + +def _stop_server(proc): + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + + +def test_cli_load_from_rest_full_checkpoint_preserves_sobol_sequence( + cli_binary, free_port, tmp_path +): + baseline = Study( + space=Space(x=Real(0.0, 1.0)), + objectives=[Minimize("loss")], + strategy="sobol", + seed=123, + ) + for _ in range(3): + baseline.ask() + expected = baseline.ask() + + config_path = _write_sobol_server_config(tmp_path) + proc, url = _start_server(cli_binary, config_path, free_port) + try: + for _ in range(3): + status, trial = http_json(f"{url}/api/ask", method="POST") + assert status == 200 + status, _ = http_json( + f"{url}/api/tell", method="POST", - body={"path": ckpt_path}, + body={ + "trial_id": trial["trial_id"], + "metrics": {"loss": trial["params"]["x"]}, + }, ) assert status == 200 - assert body["status"] == "ok" - assert os.path.exists(ckpt_path) + + status, body = http_json( + f"{url}/api/checkpoint/save", + method="POST", + body={"path": "server-full.json", "description": "server full"}, + ) + assert status == 200 + assert body["checkpoint_type"] == "full" + checkpoint_path = body["path"] + finally: + _stop_server(proc) + + restored = Study.load(checkpoint_path) + assert restored.ask().params == expected.params + + loaded_config_path = _write_sobol_server_config(tmp_path, load_from=checkpoint_path) + proc, loaded_url = _start_server(cli_binary, loaded_config_path, _free_port()) + try: + status, trial = http_json(f"{loaded_url}/api/ask", method="POST") + assert status == 200 + assert trial["trial_id"] == 3 + assert trial["params"] == expected.params + finally: + _stop_server(proc) class TestRestMultiParam: @@ -210,6 +339,19 @@ def test_update_objectives(self, running_server): assert status == 200 assert body["status"] == "ok" + @pytest.mark.server_config( + objectives=[{"field": "loss", "type": "minimize", "priority": 1.0}], + ) + def 
test_update_objectives_rejects_invalid_type(self, running_server): + url = running_server + status, body = http_json( + f"{url}/api/objectives", + method="PATCH", + body={"objectives": [{"field": "accuracy", "type": "larger", "priority": 1.0}]}, + ) + assert status == 400 + assert "Objective 'accuracy'" in body["error"] + # ========================================================================== # Study.connect() Live Integration Tests @@ -227,7 +369,9 @@ def test_study_connect_ask_tell_best(self): assert t.trial_id == 0 assert "x" in t.params - remote.tell(t.trial_id, {"loss": 0.42}) + completed = remote.tell(t.trial_id, {"loss": 0.42}) + assert completed.trial_id == t.trial_id + assert isinstance(completed.score_vector, dict) top = remote.top_k(1) assert len(top) == 1 diff --git a/hola-py/tests/test_study_advanced.py b/hola-py/tests/test_study_advanced.py index 2a26f8e..ec20fe8 100644 --- a/hola-py/tests/test_study_advanced.py +++ b/hola-py/tests/test_study_advanced.py @@ -13,7 +13,7 @@ Advanced Study tests covering error paths, scalarization, convergence, concurrency, and best-tracking. -Tests double-tell errors, bad strategy fallback, maximize/minimize +Tests double-tell errors, config validation, maximize/minimize scalarization, TLP feasibility, multi-objective priorities, Sobol properties, GMM refit, end-to-end convergence, concurrent ask/tell, and monotonic best-tracking. @@ -24,7 +24,7 @@ import pytest -from hola_opt import Categorical, Integer, Maximize, Minimize, Real, Space, Study +from hola_opt import Categorical, Gmm, Integer, Maximize, Minimize, Real, Space, Study # ========================================================================== # Error Paths @@ -39,12 +39,35 @@ def test_double_tell_raises(simple_space): study.tell(t.trial_id, {"loss": 0.3}) -def test_bad_strategy_string_defaults_gracefully(simple_space): - # Unknown strategy names do not raise; the engine falls back to a default. 
- study = Study(space=simple_space, objectives=[Minimize("loss")], strategy="nonexistent") - t = study.ask() - study.tell(t.trial_id, {"loss": 0.5}) - assert study.trial_count() == 1 +def test_bad_strategy_string_raises(simple_space): + with pytest.raises(ValueError, match="Unknown strategy type.*nonexistent"): + Study(space=simple_space, objectives=[Minimize("loss")], strategy="nonexistent") + + +def test_invalid_space_config_raises_with_parameter_name(): + with pytest.raises(ValueError, match="scale must be"): + Real(0.0, 1.0, scale="log2") + + with pytest.raises(ValueError, match="Parameter 'x'.*min must be less than max"): + Study(space=Space(x=Real(1.0, 0.0)), objectives=[Minimize("loss")]) + + with pytest.raises(ValueError, match="Parameter 'layers'.*integer min"): + Study(space=Space(layers=Integer(5, 1)), objectives=[Minimize("loss")]) + + with pytest.raises(ValueError, match="Parameter 'opt'.*choices must not be empty"): + Study(space=Space(opt=Categorical([])), objectives=[Minimize("loss")]) + + with pytest.raises(ValueError, match="Parameter 'x'.*bounds must be finite"): + Study(space=Space(x=Real(float("nan"), 1.0)), objectives=[Minimize("loss")]) + + +def test_invalid_objective_config_raises_with_field_name(simple_space): + with pytest.raises(ValueError, match="Objective 'loss'.*priority"): + Study(space=simple_space, objectives=[Minimize("loss", priority=-1.0)]) + + study = Study(space=simple_space, objectives=[Minimize("loss")]) + with pytest.raises(ValueError, match="Objective 'loss'.*priority"): + study.update_objectives([Minimize("loss", priority=float("inf"))]) def test_run_zero_trials(simple_space): @@ -170,6 +193,47 @@ def test_multi_objective_with_priorities(): assert all(math.isfinite(v) for v in obs.values()) +def test_update_objectives_migrates_scalar_to_vector_leaderboard(): + study = Study( + space=Space(x=Real(0.0, 1.0)), + objectives=[Minimize("f1")], + ) + for metrics in [ + {"f1": 1.0, "f2": 5.0}, + {"f1": 5.0, "f2": 1.0}, + {"f1": 
3.0, "f2": 3.0}, + {"f1": 4.0, "f2": 4.0}, + ]: + trial = study.ask() + study.tell(trial.trial_id, metrics) + + assert study.pareto_front() == [] + + study.update_objectives([Minimize("f1"), Minimize("f2")]) + front_ids = sorted(trial.trial_id for trial in study.pareto_front()) + assert front_ids == [0, 1, 2] + + +def test_update_objectives_migrates_vector_to_scalar_leaderboard(): + study = Study( + space=Space(x=Real(0.0, 1.0)), + objectives=[Minimize("f1"), Minimize("f2")], + ) + for metrics in [ + {"f1": 10.0, "f2": 0.0}, + {"f1": 1.0, "f2": 10.0}, + {"f1": 5.0, "f2": 5.0}, + ]: + trial = study.ask() + study.tell(trial.trial_id, metrics) + + assert study.pareto_front() != [] + + study.update_objectives([Minimize("f1")]) + assert study.pareto_front() == [] + assert study.top_k(1)[0].trial_id == 1 + + # ========================================================================== # Sobol Properties # ========================================================================== @@ -216,6 +280,65 @@ def test_gmm_survives_50_trials(): assert best is not None +def _assert_same_params(actual, expected): + assert actual.keys() == expected.keys() + for key in actual: + assert actual[key] == pytest.approx(expected[key]) + + +def test_gmm_counts_pending_asks_against_exploration_budget(): + space = Space(x=Real(0.0, 1.0)) + study = Study( + space=space, + objectives=[Minimize("loss")], + strategy=Gmm(exploration_budget=2), + seed=17, + ) + sobol = Study(space=space, objectives=[Minimize("loss")], strategy="sobol", seed=17) + gmm = Study( + space=space, + objectives=[Minimize("loss")], + strategy=Gmm(exploration_budget=0), + seed=17, + ) + + trials = [study.ask() for _ in range(4)] + + _assert_same_params(trials[0].params, sobol.ask().params) + _assert_same_params(trials[1].params, sobol.ask().params) + _assert_same_params(trials[2].params, gmm.ask().params) + _assert_same_params(trials[3].params, gmm.ask().params) + assert study.trial_count() == 0 + + +def 
test_gmm_save_load_preserves_pending_ask_accounting(tmp_path): + space = Space(x=Real(0.0, 1.0)) + study = Study( + space=space, + objectives=[Minimize("loss")], + strategy=Gmm(exploration_budget=2), + seed=17, + ) + for _ in range(3): + study.ask() + + path = tmp_path / "auto-pending.json" + study.save(str(path)) + restored = Study.load(str(path)) + + gmm = Study( + space=space, + objectives=[Minimize("loss")], + strategy=Gmm(exploration_budget=0), + seed=17, + ) + gmm.ask() + expected = gmm.ask() + resumed = restored.ask() + + _assert_same_params(resumed.params, expected.params) + + @pytest.mark.slow def test_gmm_vs_random_on_sphere(): def sphere(params): diff --git a/hola/Cargo.toml b/hola/Cargo.toml index a3fa603..3c1c7f3 100644 --- a/hola/Cargo.toml +++ b/hola/Cargo.toml @@ -6,18 +6,19 @@ license = "Apache-2.0" [features] default = [] -server = ["axum", "tokio-stream", "tower-http"] +server = ["axum/http1", "axum/json", "axum/query", "axum/tokio", "tokio/net", "tokio-stream", "tower-http"] [dependencies] opt_engine = { version = "1.0.1-rc5", path = "../opt_engine" } serde = { version = "1", features = ["derive"] } serde_json = "1" -tokio = { version = "1", features = ["full"] } +# Core engine needs async locks and blocking refit tasks; server adds net support. 
+tokio = { version = "1", default-features = false, features = ["macros", "rt-multi-thread", "sync"] } rand = "0.9.3" chrono = "0.4" -axum = { version = "0.8", optional = true } -tokio-stream = { version = "0.1", features = ["sync"], optional = true } -tower-http = { version = "0.6.8", features = ["cors", "fs"], optional = true } +axum = { version = "0.8", default-features = false, optional = true } +tokio-stream = { version = "0.1", default-features = false, features = ["sync"], optional = true } +tower-http = { version = "0.6.8", default-features = false, features = ["cors", "fs"], optional = true } [dev-dependencies] opt_engine = { version = "1.0.1-rc5", path = "../opt_engine" } diff --git a/hola/src/hola_engine.rs b/hola/src/hola_engine.rs index 8e951fa..cc976b1 100644 --- a/hola/src/hola_engine.rs +++ b/hola/src/hola_engine.rs @@ -31,6 +31,7 @@ use opt_engine::traits::{RefitConfig, SampleSpace, StandardizedSpace, Strategy}; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashSet}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use tokio::sync::RwLock; // ============================================================================= @@ -391,12 +392,13 @@ enum DynStrategyInner { /// The default exploration budget follows the formula from the paper: /// `min(floor(S / 5), 50 + 2n)`, where `S` is the intended number of /// simulations and `n` is the dimensionality. 
-#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Debug)] pub struct AutoStrategy { sobol: SobolStrategy, gmm: GmmStrategy, exploration_budget: usize, trial_count: usize, + issued_count: AtomicUsize, } impl AutoStrategy { @@ -429,10 +431,67 @@ impl AutoStrategy { gmm: GmmStrategy::uniform_prior(gmm_seed, dim, 0.1), exploration_budget, trial_count: 0, + issued_count: AtomicUsize::new(0), + } + } +} + +impl Clone for AutoStrategy { + fn clone(&self) -> Self { + Self { + sobol: self.sobol.clone(), + gmm: self.gmm.clone(), + exploration_budget: self.exploration_budget, + trial_count: self.trial_count, + issued_count: AtomicUsize::new(self.issued_count.load(Ordering::Relaxed)), } } } +impl Serialize for AutoStrategy { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + + let mut state = serializer.serialize_struct("AutoStrategy", 5)?; + state.serialize_field("sobol", &self.sobol)?; + state.serialize_field("gmm", &self.gmm)?; + state.serialize_field("exploration_budget", &self.exploration_budget)?; + state.serialize_field("trial_count", &self.trial_count)?; + state.serialize_field("issued_count", &self.issued_count.load(Ordering::Relaxed))?; + state.end() + } +} + +impl<'de> Deserialize<'de> for AutoStrategy { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct AutoStrategySerde { + sobol: SobolStrategy, + gmm: GmmStrategy, + exploration_budget: usize, + trial_count: usize, + #[serde(default)] + issued_count: Option, + } + + let state = AutoStrategySerde::deserialize(deserializer)?; + let issued_count = state.issued_count.unwrap_or(state.trial_count); + Ok(Self { + sobol: state.sobol, + gmm: state.gmm, + exploration_budget: state.exploration_budget, + trial_count: state.trial_count, + issued_count: AtomicUsize::new(issued_count.max(state.trial_count)), + }) + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct 
DynStrategy { inner: DynStrategyInner, @@ -448,7 +507,8 @@ impl Strategy for DynStrategy { DynStrategyInner::Sobol(s) => s.suggest(space), DynStrategyInner::Gmm(s) => s.suggest(space), DynStrategyInner::Auto(s) => { - if s.trial_count < s.exploration_budget { + let issued = s.issued_count.fetch_add(1, Ordering::Relaxed); + if issued < s.exploration_budget { s.sobol.suggest(space) } else { s.gmm.suggest(space) @@ -464,6 +524,7 @@ impl Strategy for DynStrategy { DynStrategyInner::Gmm(s) => s.update(candidate, observation), DynStrategyInner::Auto(s) => { s.trial_count += 1; + s.issued_count.fetch_max(s.trial_count, Ordering::Relaxed); s.sobol.update(candidate, observation); s.gmm.update(candidate, observation); } @@ -482,7 +543,7 @@ impl opt_engine::traits::RefittableStrategy for DynStrategy { } // ============================================================================= -// Configuration types for constructing DynEngine from YAML/JSON +// Configuration types for constructing HolaEngine from YAML/JSON // ============================================================================= #[derive(Clone, Debug, Serialize, Deserialize)] @@ -588,8 +649,150 @@ fn default_checkpoint_interval() -> usize { 50 } +fn validate_study_config(config: &StudyConfig) -> Result<(), String> { + validate_space_config(&config.space)?; + validate_objectives(&config.objectives)?; + if let Some(strategy) = &config.strategy { + validate_strategy_config(strategy)?; + } + if let Some(checkpoint) = &config.checkpoint + && checkpoint.interval == 0 + { + return Err("checkpoint.interval must be at least 1".to_string()); + } + Ok(()) +} + +fn validate_space_config(space: &BTreeMap) -> Result<(), String> { + if space.is_empty() { + return Err("At least one parameter is required".to_string()); + } + + for (name, param) in space { + if name.trim().is_empty() { + return Err("Parameter names must not be empty".to_string()); + } + match param { + ParamConfig::Real { min, max, scale } => { + if 
!min.is_finite() || !max.is_finite() { + return Err(format!( + "Parameter '{name}': real bounds must be finite, got min={min}, max={max}", + )); + } + if min >= max { + return Err(format!( + "Parameter '{name}': min must be less than max, got min={min}, max={max}", + )); + } + match scale.as_str() { + "linear" => {} + "log" | "ln" | "log10" => { + if *min <= 0.0 || *max <= 0.0 { + return Err(format!( + "Parameter '{name}': {scale} scale requires min > 0 and max > 0, got min={min}, max={max}", + )); + } + } + other => { + return Err(format!( + "Parameter '{name}': unknown real scale '{other}'. Expected one of: linear, log, ln, log10", + )); + } + } + } + ParamConfig::Integer { min, max } => { + if min > max { + return Err(format!( + "Parameter '{name}': integer min must be <= max, got min={min}, max={max}", + )); + } + } + ParamConfig::Categorical { choices } => { + if choices.is_empty() { + return Err(format!( + "Parameter '{name}': categorical choices must not be empty", + )); + } + } + } + } + + Ok(()) +} + +fn validate_objectives(objectives: &[ObjectiveConfig]) -> Result<(), String> { + if objectives.is_empty() { + return Err("At least one objective is required. \ + Example: objectives: [{ field: \"loss\", type: \"minimize\" }]" + .to_string()); + } + + for obj in objectives { + if obj.field.trim().is_empty() { + return Err("Objective field names must not be empty".to_string()); + } + match obj.obj_type.as_str() { + "minimize" | "maximize" => {} + other => { + return Err(format!( + "Objective '{}': unknown objective type '{}'. 
Expected 'minimize' or 'maximize'", + obj.field, other + )); + } + } + if !obj.priority.is_finite() || obj.priority < 0.0 { + return Err(format!( + "Objective '{}': priority must be finite and non-negative, got {}", + obj.field, obj.priority + )); + } + if let Some(target) = obj.target + && !target.is_finite() + { + return Err(format!( + "Objective '{}': target must be finite, got {}", + obj.field, target + )); + } + if let Some(limit) = obj.limit + && !limit.is_finite() + { + return Err(format!( + "Objective '{}': limit must be finite, got {}", + obj.field, limit + )); + } + } + + Ok(()) +} + +fn validate_strategy_config(strategy: &StrategyConfig) -> Result<(), String> { + match strategy.strategy_type.as_str() { + "random" | "sobol" | "gmm" | "auto" => {} + other => { + return Err(format!( + "Unknown strategy type '{other}'. Expected one of: random, sobol, gmm, auto", + )); + } + } + + if strategy.refit_interval == 0 { + return Err("strategy.refit_interval must be at least 1".to_string()); + } + if let Some(elite_fraction) = strategy.elite_fraction + && (!elite_fraction.is_finite() || elite_fraction <= 0.0 || elite_fraction > 1.0) + { + return Err(format!( + "strategy.elite_fraction must be finite and in (0, 1], got {elite_fraction}", + )); + } + + Ok(()) +} + // ============================================================================= -// DynEngine: the top-level Ask/Tell interface +// HolaEngine: the top-level Ask/Tell interface // ============================================================================= /// A trial returned by `ask()`. @@ -629,6 +832,24 @@ pub struct CompletedTrial { pub completed_at: u64, } +/// Kind of checkpoint loaded by [`HolaEngine::load_checkpoint_with_fallback`]. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum CheckpointLoadKind { + /// Full checkpoint with leaderboard and strategy state. + Full, + /// Legacy leaderboard-only checkpoint with trial history but no strategy state. 
+ Leaderboard, +} + +impl CheckpointLoadKind { + pub fn as_str(self) -> &'static str { + match self { + CheckpointLoadKind::Full => "full", + CheckpointLoadKind::Leaderboard => "leaderboard", + } + } +} + // ============================================================================= // DynLeaderboard: scalar or vector leaderboard dispatch // ============================================================================= @@ -641,8 +862,17 @@ enum DynLeaderboard { } impl DynLeaderboard { + fn for_objectives(objectives: &[ObjectiveConfig]) -> Self { + if count_priority_groups(objectives) > 1 { + DynLeaderboard::Vector(Leaderboard::new()) + } else { + DynLeaderboard::Scalar(Leaderboard::new()) + } + } + fn push_with_raw( &mut self, + trial_id: u64, candidate: serde_json::Value, raw_metrics: serde_json::Value, objectives: &[ObjectiveConfig], @@ -650,18 +880,39 @@ impl DynLeaderboard { match self { DynLeaderboard::Scalar(lb) => { let score = scalarize_raw(&raw_metrics, objectives); - let id = lb.push_with_raw(candidate, score, raw_metrics); + let id = lb.push_with_raw_trial_id(candidate, score, raw_metrics, trial_id); (id, score) } DynLeaderboard::Vector(lb) => { let obs = vectorize_raw(&raw_metrics, objectives); let score = scalarize_observation(&obs, objectives); - let id = lb.push_with_raw(candidate, obs, raw_metrics); + let id = lb.push_with_raw_trial_id(candidate, obs, raw_metrics, trial_id); (id, score) } } } + fn contains_trial_id(&self, trial_id: u64) -> bool { + match self { + DynLeaderboard::Scalar(lb) => lb.get(trial_id).is_some(), + DynLeaderboard::Vector(lb) => lb.get(trial_id).is_some(), + } + } + + fn next_trial_id(&self) -> u64 { + match self { + DynLeaderboard::Scalar(lb) => lb.next_trial_id(), + DynLeaderboard::Vector(lb) => lb.next_trial_id(), + } + } + + fn normalize_next_trial_id(&mut self) -> u64 { + match self { + DynLeaderboard::Scalar(lb) => lb.normalize_next_trial_id(), + DynLeaderboard::Vector(lb) => lb.normalize_next_trial_id(), + } + } + fn 
len(&self) -> usize { match self { DynLeaderboard::Scalar(lb) => lb.len(), @@ -705,21 +956,90 @@ impl DynLeaderboard { } } + fn migrate_for_objectives(&mut self, objectives: &[ObjectiveConfig]) { + let should_be_vector = count_priority_groups(objectives) > 1; + match (&mut *self, should_be_vector) { + (DynLeaderboard::Scalar(_), false) | (DynLeaderboard::Vector(_), true) => { + self.rescalarize(objectives); + return; + } + _ => {} + } + + let migrated = match self { + DynLeaderboard::Scalar(lb) => { + let mut migrated = Leaderboard::new(); + for trial in lb.iter() { + let raw_metrics = trial.raw_metrics.clone(); + let raw = raw_metrics.as_ref().unwrap_or(&serde_json::Value::Null); + migrated.push_existing_trial(Trial { + candidate: trial.candidate.clone(), + observation: vectorize_raw(raw, objectives), + raw_metrics, + trial_id: trial.trial_id, + timestamp: trial.timestamp, + }); + } + DynLeaderboard::Vector(migrated) + } + DynLeaderboard::Vector(lb) => { + let mut migrated = Leaderboard::new(); + for trial in lb.iter() { + let raw_metrics = trial.raw_metrics.clone(); + let raw = raw_metrics.as_ref().unwrap_or(&serde_json::Value::Null); + migrated.push_existing_trial(Trial { + candidate: trial.candidate.clone(), + observation: scalarize_raw(raw, objectives), + raw_metrics, + trial_id: trial.trial_id, + timestamp: trial.timestamp, + }); + } + DynLeaderboard::Scalar(migrated) + } + }; + *self = migrated; + } + /// Get a single completed trial by ID, computing its rank and Pareto front. fn get_completed( &self, trial_id: u64, + include_infeasible: bool, objectives: &[ObjectiveConfig], ) -> Option { - // Build the full ranked list, then find the requested trial. 
- let all = self.completed_trials("rank", false, objectives); - all.into_iter() - .find(|ct| ct.trial_id == trial_id) - .or_else(|| { - // Trial might be infeasible — try again with infeasible included - let all_with_inf = self.completed_trials("rank", true, objectives); - all_with_inf.into_iter().find(|ct| ct.trial_id == trial_id) - }) + match self { + DynLeaderboard::Scalar(lb) => { + let trial = lb.get(trial_id)?.clone(); + if !include_infeasible + && !Leaderboard::::trial_is_feasible(&trial) + { + return None; + } + + let rank = lb + .iter() + .filter(|other| { + include_infeasible + || Leaderboard::::trial_is_feasible(other) + }) + .filter(|other| { + other + .observation + .partial_cmp(&trial.observation) + .unwrap_or(std::cmp::Ordering::Equal) + .then_with(|| other.trial_id.cmp(&trial.trial_id)) + == std::cmp::Ordering::Less + }) + .count(); + + Some(build_completed_scalar(trial, rank, objectives)) + } + DynLeaderboard::Vector(_) => { + let all = self.completed_trials("rank", include_infeasible, objectives); + all.into_iter().find(|ct| ct.trial_id == trial_id) + } + } } /// Return all trials as CompletedTrial with ranking and scoring. @@ -972,14 +1292,18 @@ struct HolaEngineState { cancelled: HashSet, } +impl HolaEngineState { + fn reset_transient_trial_state_after_load(&mut self) { + self.next_pending_id = self.leaderboard.normalize_next_trial_id(); + self.pending.clear(); + self.cancelled.clear(); + } +} + impl HolaEngine { /// Build a HolaEngine from a StudyConfig (parsed from YAML/JSON). pub fn from_config(config: StudyConfig) -> Result { - if config.objectives.is_empty() { - return Err("At least one objective is required. 
\ - Example: objectives: [{ field: \"loss\", obj_type: \"minimize\" }]" - .to_string()); - } + validate_study_config(&config)?; let mut space = DynSpace::new(); for (name, param) in &config.space { @@ -1051,7 +1375,7 @@ impl HolaEngine { None, ), // "gmm" (default): Sobol exploration followed by GMM exploitation - _ => { + "gmm" | "auto" => { let exploration_budget = strategy_cfg .and_then(|s| s.exploration_budget) .unwrap_or_else(|| { @@ -1073,6 +1397,7 @@ impl HolaEngine { )), ) } + _ => unreachable!("strategy type was validated before construction"), }; let auto_checkpoint = config.checkpoint.as_ref().map(|c| { @@ -1081,11 +1406,7 @@ impl HolaEngine { ac }); - let leaderboard = if count_priority_groups(&config.objectives) > 1 { - DynLeaderboard::Vector(Leaderboard::new()) - } else { - DynLeaderboard::Scalar(Leaderboard::new()) - }; + let leaderboard = DynLeaderboard::for_objectives(&config.objectives); Ok(Self { space, @@ -1119,8 +1440,18 @@ impl HolaEngine { } } let params = state.strategy.suggest(&self.space); - let id = state.next_pending_id; - state.next_pending_id += 1; + let mut id = state.next_pending_id.max(state.leaderboard.next_trial_id()); + while state.pending.contains_key(&id) + || state.cancelled.contains(&id) + || state.leaderboard.contains_trial_id(id) + { + id = id + .checked_add(1) + .ok_or_else(|| "Exhausted trial ID space".to_string())?; + } + state.next_pending_id = id + .checked_add(1) + .ok_or_else(|| "Exhausted trial ID space".to_string())?; state.pending.insert(id, params.clone()); Ok(DynTrial { trial_id: id, @@ -1142,24 +1473,34 @@ impl HolaEngine { return Err(format!("Trial {trial_id} has been cancelled")); } + if state.leaderboard.contains_trial_id(trial_id) { + return Err(format!("Trial {trial_id} has already been completed")); + } + let candidate = state .pending .remove(&trial_id) .ok_or_else(|| format!("Unknown trial_id: {trial_id}"))?; - let (_trial_id, score) = + let (stored_trial_id, score) = state .leaderboard - 
.push_with_raw(candidate.clone(), raw_metrics, &objectives); + .push_with_raw(trial_id, candidate.clone(), raw_metrics, &objectives); + if stored_trial_id != trial_id { + return Err(format!( + "Internal trial ID mismatch: pending trial {trial_id} was stored as {stored_trial_id}" + )); + } state.strategy.update(&candidate, score); let n_trials = state.leaderboard.len(); - let completed = state - .leaderboard - .get_completed(trial_id, &objectives) - .ok_or_else(|| format!("Failed to build CompletedTrial for {trial_id}"))?; + let leaderboard_snapshot = state.leaderboard.clone(); drop(state); + let completed = leaderboard_snapshot + .get_completed(stored_trial_id, true, &objectives) + .ok_or_else(|| format!("Failed to build CompletedTrial for {stored_trial_id}"))?; + // Auto-refit if configured if let Some(ref config) = self.refit_config && config.should_refit(n_trials) @@ -1248,6 +1589,17 @@ impl HolaEngine { ) } + /// Get a single completed trial by ID with scoring and ranking. + pub async fn completed_trial( + &self, + trial_id: u64, + include_infeasible: bool, + ) -> Option { + let objectives = self.objectives.read().await.clone(); + let leaderboard_snapshot = self.state.read().await.leaderboard.clone(); + leaderboard_snapshot.get_completed(trial_id, include_infeasible, &objectives) + } + /// Get all trials with scoring and ranking. pub async fn trials(&self, sorted_by: &str, include_infeasible: bool) -> Vec { let objectives = self.objectives.read().await.clone(); @@ -1315,13 +1667,15 @@ impl HolaEngine { /// updated scalarization. If a refittable strategy (e.g., GMM) is configured, /// a refit is triggered immediately so the sampling distribution reflects /// the new objective weights. 
- pub async fn update_objectives(&self, objectives: Vec) { + pub async fn update_objectives(&self, objectives: Vec) -> Result<(), String> { + validate_objectives(&objectives)?; + // Persist the new objectives *self.objectives.write().await = objectives.clone(); // Re-scalarize all historical trials with the new objectives let n_trials = { let mut state = self.state.write().await; - state.leaderboard.rescalarize(&objectives); + state.leaderboard.migrate_for_objectives(&objectives); state.leaderboard.len() }; @@ -1346,6 +1700,7 @@ impl HolaEngine { self.state.write().await.strategy = fitted; } } + Ok(()) } // ========================================================================= @@ -1366,6 +1721,38 @@ impl HolaEngine { self.load_full_checkpoint(path).await } + /// Load a checkpoint, preferring full checkpoints and falling back to + /// legacy leaderboard-only files. + /// + /// This is used by CLI config `checkpoint.load_from`, which historically + /// accepted leaderboard-only checkpoints. Full checkpoints preserve search + /// strategy state; leaderboard-only checkpoints preserve completed trials. + pub async fn load_checkpoint_with_fallback( + &self, + path: impl AsRef, + ) -> std::io::Result { + let path = path.as_ref(); + let raw: serde_json::Value = { + let file = std::fs::File::open(path)?; + let reader = std::io::BufReader::new(file); + serde_json::from_reader(reader) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))? 
+ }; + + let has_strategy_state = raw + .get("checkpoint") + .unwrap_or(&raw) + .get("strategy_state") + .is_some(); + if has_strategy_state { + self.load_full_checkpoint(path).await?; + Ok(CheckpointLoadKind::Full) + } else { + self.load_leaderboard_checkpoint(path).await?; + Ok(CheckpointLoadKind::Leaderboard) + } + } + // ========================================================================= // Persistence (internal) // ========================================================================= @@ -1410,7 +1797,9 @@ impl HolaEngine { eprintln!("[hola] Loaded leaderboard checkpoint with {n} trials"); DynLeaderboard::Scalar(cp.leaderboard) }; - self.state.write().await.leaderboard = leaderboard; + let mut state = self.state.write().await; + state.leaderboard = leaderboard; + state.reset_transient_trial_state_after_load(); Ok(()) } @@ -1503,6 +1892,7 @@ impl HolaEngine { let mut state = self.state.write().await; state.leaderboard = DynLeaderboard::Vector(cp.leaderboard); state.strategy = cp.strategy_state; + state.reset_transient_trial_state_after_load(); eprintln!("[hola] Loaded full checkpoint with {n_loaded} trials"); } else { let cp: opt_engine::persistence::Checkpoint = @@ -1512,6 +1902,7 @@ impl HolaEngine { let mut state = self.state.write().await; state.leaderboard = DynLeaderboard::Scalar(cp.leaderboard); state.strategy = cp.strategy_state; + state.reset_transient_trial_state_after_load(); eprintln!("[hola] Loaded full checkpoint with {n_loaded} trials"); } Ok(()) @@ -1645,7 +2036,7 @@ fn objective_score(val: f64, obj_type: &str, target: Option, limit: Option< (Some(t), Some(l)) => opt_engine::objectives::tlp_score(val, t, l), _ => opt_engine::objectives::directed_value(val, true), }, - _ => val, + _ => f64::INFINITY, } } diff --git a/hola/src/server.rs b/hola/src/server.rs index 0bfc4c5..8271225 100644 --- a/hola/src/server.rs +++ b/hola/src/server.rs @@ -22,19 +22,23 @@ //! - `POST /api/cancel` - Cancel a pending trial //! 
- `GET /api/top_k` - Get top-k trials by rank //! - `GET /api/pareto_front` - Get Pareto front trials +//! - `GET /api/trial/{trial_id}` - Get one completed trial with scoring/ranking //! - `GET /api/trials` - Get all trials with scoring/ranking //! - `GET /api/trial_count` - Get number of completed trials //! - `PATCH /api/objectives` - Update objectives mid-run //! - `GET /api/objectives` - Get current objectives //! - `GET /api/space` - Get parameter space metadata -//! - `POST /api/checkpoint/save` - Save a checkpoint (internal) +//! - `POST /api/checkpoint/save` - Save a full checkpoint //! - `GET /api/events` - SSE stream of engine events -use crate::hola_engine::{HolaEngine, ObjectiveConfig}; +use crate::hola_engine::{CompletedTrial, HolaEngine, ObjectiveConfig}; use axum::{ Router, - extract::{Query, State}, - http::StatusCode, + extract::{Path as AxumPath, Query, State}, + http::{ + HeaderMap, HeaderValue, Method, StatusCode, + header::{AUTHORIZATION, CONTENT_TYPE}, + }, response::{ Json, sse::{Event, Sse}, @@ -43,12 +47,12 @@ use axum::{ }; use serde::{Deserialize, Serialize}; use std::convert::Infallible; -use std::path::Path; +use std::path::{Component, Path, PathBuf}; use std::sync::Arc; use tokio::sync::broadcast; use tokio_stream::StreamExt; use tokio_stream::wrappers::BroadcastStream; -use tower_http::cors::CorsLayer; +use tower_http::cors::{AllowOrigin, CorsLayer}; use tower_http::services::ServeDir; // ============================================================================= @@ -59,13 +63,44 @@ use tower_http::services::ServeDir; #[derive(Clone, Debug, Serialize)] #[serde(tag = "type")] pub enum EngineEvent { - TrialCompleted { trial_id: u64, score: f64 }, - RefitOccurred { n_trials: usize }, + TrialCompleted { + trial_id: u64, + score: f64, + trial: CompletedTrial, + }, + RefitOccurred { + n_trials: usize, + }, } pub struct ServerState { pub engine: HolaEngine, pub events_tx: broadcast::Sender, + auth_token: Option, + checkpoint_dir: PathBuf, 
+} + +#[derive(Clone, Debug)] +pub struct ServerOptions { + pub host: String, + pub port: u16, + pub dashboard_dir: Option, + pub auth_token: Option, + pub checkpoint_dir: PathBuf, + pub cors_allowed_origins: Vec, +} + +impl ServerOptions { + pub fn new(port: u16) -> Self { + Self { + host: "127.0.0.1".to_string(), + port, + dashboard_dir: None, + auth_token: None, + checkpoint_dir: PathBuf::from("."), + cors_allowed_origins: Vec::new(), + } + } } // ============================================================================= @@ -106,6 +141,12 @@ struct TrialsQuery { include_infeasible: Option, } +#[derive(Deserialize)] +struct TrialQuery { + #[serde(default)] + include_infeasible: Option, +} + #[derive(Deserialize)] struct SaveCheckpointRequest { #[serde(default = "default_checkpoint_path")] @@ -132,9 +173,71 @@ struct UpdateObjectivesRequest { // Handlers // ============================================================================= +fn unauthorized() -> (StatusCode, Json) { + ( + StatusCode::UNAUTHORIZED, + Json(ErrorResponse { + error: "Missing or invalid bearer token".to_string(), + }), + ) +} + +fn authorize_mutation( + state: &ServerState, + headers: &HeaderMap, +) -> Result<(), (StatusCode, Json)> { + let Some(token) = &state.auth_token else { + return Ok(()); + }; + + let expected = format!("Bearer {token}"); + match headers + .get(AUTHORIZATION) + .and_then(|value| value.to_str().ok()) + { + Some(actual) if actual == expected => Ok(()), + _ => Err(unauthorized()), + } +} + +fn invalid_checkpoint_path(message: impl Into) -> (StatusCode, Json) { + ( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: message.into(), + }), + ) +} + +fn resolve_checkpoint_path( + state: &ServerState, + requested: &str, +) -> Result)> { + let path = Path::new(requested); + if path.as_os_str().is_empty() { + return Err(invalid_checkpoint_path("Checkpoint path must not be empty")); + } + if path.is_absolute() + || path.components().any(|component| { + matches!( + 
component, + Component::ParentDir | Component::RootDir | Component::Prefix(_) + ) + }) + { + return Err(invalid_checkpoint_path( + "Checkpoint path must be relative to the configured checkpoint directory", + )); + } + + Ok(state.checkpoint_dir.join(path)) +} + async fn handle_ask( State(state): State>, + headers: HeaderMap, ) -> Result, (StatusCode, Json)> { + authorize_mutation(&state, &headers)?; match state.engine.ask().await { Ok(trial) => Ok(Json(serde_json::to_value(&trial).unwrap())), Err(e) => Err((StatusCode::BAD_REQUEST, Json(ErrorResponse { error: e }))), @@ -143,8 +246,10 @@ async fn handle_ask( async fn handle_tell( State(state): State>, + headers: HeaderMap, Json(req): Json, ) -> Result, (StatusCode, Json)> { + authorize_mutation(&state, &headers)?; match state.engine.tell(req.trial_id, req.metrics).await { Ok(completed) => { let n = state.engine.trial_count().await; @@ -160,11 +265,13 @@ async fn handle_tell( let _ = state.events_tx.send(EngineEvent::TrialCompleted { trial_id: req.trial_id, score, + trial: completed.clone(), }); Ok(Json(serde_json::json!({ "status": "ok", "trial_count": n, + "trial": completed, }))) } Err(e) => Err((StatusCode::BAD_REQUEST, Json(ErrorResponse { error: e }))), @@ -173,8 +280,10 @@ async fn handle_tell( async fn handle_cancel( State(state): State>, + headers: HeaderMap, Json(req): Json, ) -> Result, (StatusCode, Json)> { + authorize_mutation(&state, &headers)?; match state.engine.cancel(req.trial_id).await { Ok(()) => Ok(Json(serde_json::json!({ "status": "ok" }))), Err(e) => Err((StatusCode::BAD_REQUEST, Json(ErrorResponse { error: e }))), @@ -210,6 +319,27 @@ async fn handle_trials( Json(serde_json::to_value(&trials).unwrap_or_default()) } +async fn handle_trial( + State(state): State>, + AxumPath(trial_id): AxumPath, + Query(q): Query, +) -> Result, (StatusCode, Json)> { + let include_infeasible = q.include_infeasible.unwrap_or(true); + match state + .engine + .completed_trial(trial_id, include_infeasible) + .await 
+ { + Some(trial) => Ok(Json(serde_json::to_value(&trial).unwrap_or_default())), + None => Err(( + StatusCode::NOT_FOUND, + Json(ErrorResponse { + error: format!("Trial {trial_id} not found"), + }), + )), + } +} + async fn handle_trial_count(State(state): State>) -> Json { let count = state.engine.trial_count().await; Json(serde_json::json!({ "trial_count": count })) @@ -217,14 +347,20 @@ async fn handle_trial_count(State(state): State>) -> Json>, + headers: HeaderMap, Json(req): Json, -) -> Json { - state.engine.update_objectives(req.objectives).await; - let n = state.engine.trial_count().await; - Json(serde_json::json!({ - "status": "ok", - "rescalarized_trials": n, - })) +) -> Result, (StatusCode, Json)> { + authorize_mutation(&state, &headers)?; + match state.engine.update_objectives(req.objectives).await { + Ok(()) => { + let n = state.engine.trial_count().await; + Ok(Json(serde_json::json!({ + "status": "ok", + "rescalarized_trials": n, + }))) + } + Err(e) => Err((StatusCode::BAD_REQUEST, Json(ErrorResponse { error: e }))), + } } async fn handle_get_objectives(State(state): State>) -> Json { @@ -256,18 +392,33 @@ async fn handle_space(State(state): State>) -> Json>, + headers: HeaderMap, Json(req): Json, ) -> Result, (StatusCode, Json)> { + authorize_mutation(&state, &headers)?; + let path = resolve_checkpoint_path(&state, &req.path)?; + if let Some(parent) = path.parent() + && let Err(e) = std::fs::create_dir_all(parent) + { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: e.to_string(), + }), + )); + } + match state .engine - .save_leaderboard_checkpoint_to(&req.path, req.description.as_deref()) + .save_full_checkpoint(&path, req.description.as_deref()) .await { Ok(()) => { let n = state.engine.trial_count().await; Ok(Json(serde_json::json!({ "status": "ok", - "path": req.path, + "checkpoint_type": "full", + "path": path.to_string_lossy(), "trials_saved": n, }))) } @@ -298,12 +449,42 @@ async fn handle_events( // Router & 
Server // ============================================================================= +fn build_cors(origins: &[String]) -> CorsLayer { + let mut cors = CorsLayer::new() + .allow_methods([Method::GET, Method::POST, Method::PATCH]) + .allow_headers([CONTENT_TYPE, AUTHORIZATION]); + + if !origins.is_empty() { + let parsed: Vec = origins + .iter() + .map(|origin| { + origin + .parse() + .expect("CORS origins must be valid HTTP header values") + }) + .collect(); + cors = cors.allow_origin(AllowOrigin::list(parsed)); + } + + cors +} + /// Create the Axum router for the engine server. pub fn create_router(engine: HolaEngine) -> Router { + create_router_with_options(engine, ServerOptions::new(8000)) +} + +/// Create the Axum router for the engine server with explicit server options. +pub fn create_router_with_options(engine: HolaEngine, options: ServerOptions) -> Router { let (events_tx, _) = broadcast::channel(256); - let state = Arc::new(ServerState { engine, events_tx }); + let state = Arc::new(ServerState { + engine, + events_tx, + auth_token: options.auth_token, + checkpoint_dir: options.checkpoint_dir, + }); - let cors = CorsLayer::permissive(); + let cors = build_cors(&options.cors_allowed_origins); Router::new() .route("/api/ask", post(handle_ask)) @@ -311,6 +492,7 @@ pub fn create_router(engine: HolaEngine) -> Router { .route("/api/cancel", post(handle_cancel)) .route("/api/top_k", get(handle_top_k)) .route("/api/pareto_front", get(handle_pareto_front)) + .route("/api/trial/{trial_id}", get(handle_trial)) .route("/api/trials", get(handle_trials)) .route("/api/trial_count", get(handle_trial_count)) .route( @@ -329,7 +511,22 @@ pub fn create_router(engine: HolaEngine) -> Router { /// API routes under `/api/*` take priority; all other paths fall through to /// serve static files from `dashboard_dir`. 
pub fn create_router_with_dashboard(engine: HolaEngine, dashboard_dir: &Path) -> Router { - create_router(engine).fallback_service(ServeDir::new(dashboard_dir)) + let mut options = ServerOptions::new(8000); + options.dashboard_dir = Some(dashboard_dir.to_path_buf()); + create_router_with_dashboard_and_options(engine, options) +} + +/// Create the Axum router with the dashboard and explicit server options. +pub fn create_router_with_dashboard_and_options( + engine: HolaEngine, + options: ServerOptions, +) -> Router { + let dashboard_dir = options.dashboard_dir.clone(); + let router = create_router_with_options(engine, options); + match dashboard_dir { + Some(dir) => router.fallback_service(ServeDir::new(dir)), + None => router, + } } /// Start the server on the given port. Blocks until the server is shut down. @@ -340,18 +537,31 @@ pub async fn serve( port: u16, dashboard_dir: Option<&Path>, ) -> Result<(), Box> { - let router = match dashboard_dir { - Some(dir) => create_router_with_dashboard(engine, dir), - None => create_router(engine), + let mut options = ServerOptions::new(port); + options.dashboard_dir = dashboard_dir.map(Path::to_path_buf); + serve_with_options(engine, options).await +} + +/// Start the server with explicit host, auth, CORS, and checkpoint options. 
+pub async fn serve_with_options( + engine: HolaEngine, + options: ServerOptions, +) -> Result<(), Box> { + let router = match options.dashboard_dir.as_deref() { + Some(_) => create_router_with_dashboard_and_options(engine, options.clone()), + None => create_router_with_options(engine, options.clone()), }; - let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{port}")).await?; - if let Some(dir) = dashboard_dir { + let listener = + tokio::net::TcpListener::bind(format!("{}:{}", options.host, options.port)).await?; + if let Some(dir) = &options.dashboard_dir { eprintln!( - "HOLA server listening on port {port} (dashboard: {})", + "HOLA server listening on {}:{} (dashboard: {})", + options.host, + options.port, dir.display() ); } else { - eprintln!("HOLA server listening on port {port}"); + eprintln!("HOLA server listening on {}:{}", options.host, options.port); } axum::serve(listener, router).await?; Ok(()) diff --git a/hola/tests/integration/hola_engine.rs b/hola/tests/integration/hola_engine.rs index b0a7edb..58cdd8d 100644 --- a/hola/tests/integration/hola_engine.rs +++ b/hola/tests/integration/hola_engine.rs @@ -14,7 +14,9 @@ //! Exercises config parsing, ask/tell flows, strategy types, scalarization, //! objectives, checkpoints, refit, and all parameter types. 
-use hola::hola_engine::{HolaEngine, ObjectiveConfig, ParamConfig, StrategyConfig, StudyConfig}; +use hola::hola_engine::{ + CheckpointLoadKind, HolaEngine, ObjectiveConfig, ParamConfig, StrategyConfig, StudyConfig, +}; use opt_engine::traits::SampleSpace; use serde_json::json; use std::collections::BTreeMap; @@ -88,6 +90,155 @@ async fn test_dyn_engine_config_with_checkpoint() { let _t = engine.ask().await.unwrap(); } +fn valid_config_for_validation() -> StudyConfig { + StudyConfig { + space: BTreeMap::from([( + "x".to_string(), + ParamConfig::Real { + min: 0.0, + max: 1.0, + scale: "linear".to_string(), + }, + )]), + objectives: vec![ObjectiveConfig { + field: "loss".to_string(), + obj_type: "minimize".to_string(), + target: None, + limit: None, + priority: 1.0, + group: None, + }], + strategy: Some(StrategyConfig { + strategy_type: "gmm".to_string(), + refit_interval: 20, + total_budget: None, + exploration_budget: None, + seed: None, + elite_fraction: None, + }), + checkpoint: None, + max_trials: None, + } +} + +fn assert_config_error(config: StudyConfig, expected: &[&str]) { + let err = match HolaEngine::from_config(config) { + Ok(_) => panic!("expected config validation to fail"), + Err(err) => err, + }; + for needle in expected { + assert!( + err.contains(needle), + "expected error {err:?} to contain {needle:?}" + ); + } +} + +#[test] +fn test_dyn_engine_config_validation_rejects_invalid_scale() { + let mut config = valid_config_for_validation(); + config.space.insert( + "lr".to_string(), + ParamConfig::Real { + min: 1.0e-4, + max: 1.0e-1, + scale: "log2".to_string(), + }, + ); + assert_config_error(config, &["Parameter 'lr'", "unknown real scale", "log2"]); +} + +#[test] +fn test_dyn_engine_config_validation_rejects_invalid_strategy() { + let mut config = valid_config_for_validation(); + config.strategy.as_mut().unwrap().strategy_type = "soboll".to_string(); + assert_config_error(config, &["Unknown strategy type", "soboll"]); +} + +#[test] +fn 
test_dyn_engine_config_validation_rejects_invalid_objective_type() { + let mut config = valid_config_for_validation(); + config.objectives[0].obj_type = "minimise".to_string(); + assert_config_error( + config, + &["Objective 'loss'", "unknown objective type", "minimise"], + ); +} + +#[test] +fn test_dyn_engine_config_validation_rejects_invalid_space_shapes() { + let mut real = valid_config_for_validation(); + real.space.insert( + "x".to_string(), + ParamConfig::Real { + min: 1.0, + max: 1.0, + scale: "linear".to_string(), + }, + ); + assert_config_error(real, &["Parameter 'x'", "min must be less than max"]); + + let mut integer = valid_config_for_validation(); + integer.space.insert( + "layers".to_string(), + ParamConfig::Integer { min: 10, max: 1 }, + ); + assert_config_error( + integer, + &["Parameter 'layers'", "integer min must be <= max"], + ); + + let mut categorical = valid_config_for_validation(); + categorical.space.insert( + "optimizer".to_string(), + ParamConfig::Categorical { choices: vec![] }, + ); + assert_config_error( + categorical, + &["Parameter 'optimizer'", "choices must not be empty"], + ); +} + +#[test] +fn test_dyn_engine_config_validation_rejects_non_finite_real_bounds() { + let mut nan = valid_config_for_validation(); + nan.space.insert( + "x".to_string(), + ParamConfig::Real { + min: f64::NAN, + max: 1.0, + scale: "linear".to_string(), + }, + ); + assert_config_error(nan, &["Parameter 'x'", "bounds must be finite"]); + + let mut inf = valid_config_for_validation(); + inf.space.insert( + "x".to_string(), + ParamConfig::Real { + min: 0.0, + max: f64::INFINITY, + scale: "linear".to_string(), + }, + ); + assert_config_error(inf, &["Parameter 'x'", "bounds must be finite"]); +} + +#[test] +fn test_dyn_engine_config_validation_rejects_invalid_refit_and_priority() { + let mut refit = valid_config_for_validation(); + refit.strategy.as_mut().unwrap().refit_interval = 0; + assert_config_error(refit, &["strategy.refit_interval", "at least 1"]); + + 
let mut priority = valid_config_for_validation(); + priority.objectives[0].priority = -1.0; + assert_config_error(priority, &["Objective 'loss'", "priority"]); + + let mut elite_fraction = valid_config_for_validation(); + elite_fraction.strategy.as_mut().unwrap().elite_fraction = Some(f64::NAN); + assert_config_error(elite_fraction, &["strategy.elite_fraction", "finite"]); +} + // ========================================================================== // Ask/Tell flow // ========================================================================== @@ -202,6 +353,57 @@ async fn test_dyn_engine_double_tell_error() { assert!(engine.tell(t.trial_id, json!({"loss": 0.3})).await.is_err()); } +#[tokio::test] +async fn test_dyn_engine_out_of_order_tell_preserves_public_trial_ids() { + let config = StudyConfig { + space: BTreeMap::from([( + "x".to_string(), + ParamConfig::Real { + min: 0.0, + max: 1.0, + scale: "linear".to_string(), + }, + )]), + objectives: vec![ObjectiveConfig { + field: "loss".to_string(), + obj_type: "minimize".to_string(), + target: None, + limit: None, + priority: 1.0, + group: None, + }], + strategy: None, + checkpoint: None, + max_trials: None, + }; + + let engine = HolaEngine::from_config(config).unwrap(); + let t0 = engine.ask().await.unwrap(); + let t1 = engine.ask().await.unwrap(); + + let completed_1 = engine + .tell(t1.trial_id, json!({"loss": 0.2})) + .await + .unwrap(); + assert_eq!(completed_1.trial_id, t1.trial_id); + assert_eq!(completed_1.params, t1.params); + + let completed_0 = engine + .tell(t0.trial_id, json!({"loss": 0.8})) + .await + .unwrap(); + assert_eq!(completed_0.trial_id, t0.trial_id); + assert_eq!(completed_0.params, t0.params); + + let ids: Vec = engine + .trials("index", true) + .await + .into_iter() + .map(|trial| trial.trial_id) + .collect(); + assert_eq!(ids, vec![0, 1]); +} + // ========================================================================== // All parameter types // 
========================================================================== @@ -650,11 +852,33 @@ async fn test_dyn_engine_update_objectives() { priority: 1.0, group: None, }]) - .await; + .await + .unwrap(); assert!(!engine.top_k(1, false).await.is_empty()); } +#[tokio::test] +async fn test_dyn_engine_update_objectives_rejects_invalid_config() { + let engine = HolaEngine::from_config(valid_config_for_validation()).unwrap(); + let before = engine.objectives().await; + + let err = engine + .update_objectives(vec![ObjectiveConfig { + field: "accuracy".to_string(), + obj_type: "larger_is_better".to_string(), + target: None, + limit: None, + priority: 1.0, + group: None, + }]) + .await + .unwrap_err(); + + assert!(err.contains("Objective 'accuracy'")); + assert_eq!(engine.objectives().await[0].field, before[0].field); +} + #[tokio::test] async fn test_dyn_engine_objectives_accessor() { let config = StudyConfig { @@ -746,12 +970,153 @@ async fn test_dyn_engine_update_objectives_rescalarizes() { priority: 1.0, group: None, }]) - .await; + .await + .unwrap(); let best_after = engine.top_k(1, false).await.into_iter().next().unwrap(); assert_ne!(best_before.trial_id, best_after.trial_id); } +#[tokio::test] +async fn test_dyn_engine_update_objectives_migrates_scalar_to_vector() { + let config = StudyConfig { + space: BTreeMap::from([( + "x".to_string(), + ParamConfig::Real { + min: 0.0, + max: 1.0, + scale: "linear".to_string(), + }, + )]), + objectives: vec![ObjectiveConfig { + field: "f1".to_string(), + obj_type: "minimize".to_string(), + target: None, + limit: None, + priority: 1.0, + group: None, + }], + strategy: None, + checkpoint: None, + max_trials: None, + }; + + let engine = HolaEngine::from_config(config).unwrap(); + for metrics in [ + json!({"f1": 1.0, "f2": 5.0}), + json!({"f1": 5.0, "f2": 1.0}), + json!({"f1": 3.0, "f2": 3.0}), + json!({"f1": 4.0, "f2": 4.0}), + ] { + let trial = engine.ask().await.unwrap(); + engine.tell(trial.trial_id, metrics).await.unwrap(); 
+ } + assert!(engine.pareto_front(0, false).await.is_empty()); + + engine + .update_objectives(vec![ + ObjectiveConfig { + field: "f1".to_string(), + obj_type: "minimize".to_string(), + target: None, + limit: None, + priority: 1.0, + group: None, + }, + ObjectiveConfig { + field: "f2".to_string(), + obj_type: "minimize".to_string(), + target: None, + limit: None, + priority: 1.0, + group: None, + }, + ]) + .await + .unwrap(); + + let mut front_ids: Vec = engine + .pareto_front(0, false) + .await + .into_iter() + .map(|trial| trial.trial_id) + .collect(); + front_ids.sort_unstable(); + assert_eq!(front_ids, vec![0, 1, 2]); + + let migrated = engine + .trials("index", true) + .await + .into_iter() + .find(|trial| trial.trial_id == 0) + .unwrap(); + assert!(migrated.score_vector.get("f1").is_some()); + assert!(migrated.score_vector.get("f2").is_some()); +} + +#[tokio::test] +async fn test_dyn_engine_update_objectives_migrates_vector_to_scalar() { + let config = StudyConfig { + space: BTreeMap::from([( + "x".to_string(), + ParamConfig::Real { + min: 0.0, + max: 1.0, + scale: "linear".to_string(), + }, + )]), + objectives: vec![ + ObjectiveConfig { + field: "f1".to_string(), + obj_type: "minimize".to_string(), + target: None, + limit: None, + priority: 1.0, + group: None, + }, + ObjectiveConfig { + field: "f2".to_string(), + obj_type: "minimize".to_string(), + target: None, + limit: None, + priority: 1.0, + group: None, + }, + ], + strategy: None, + checkpoint: None, + max_trials: None, + }; + + let engine = HolaEngine::from_config(config).unwrap(); + for metrics in [ + json!({"f1": 10.0, "f2": 0.0}), + json!({"f1": 1.0, "f2": 10.0}), + json!({"f1": 5.0, "f2": 5.0}), + ] { + let trial = engine.ask().await.unwrap(); + engine.tell(trial.trial_id, metrics).await.unwrap(); + } + assert!(!engine.pareto_front(0, false).await.is_empty()); + + engine + .update_objectives(vec![ObjectiveConfig { + field: "f1".to_string(), + obj_type: "minimize".to_string(), + target: None, + 
limit: None, + priority: 1.0, + group: None, + }]) + .await + .unwrap(); + + assert!(engine.pareto_front(0, false).await.is_empty()); + let best = engine.top_k(1, false).await.into_iter().next().unwrap(); + assert_eq!(best.trial_id, 1); + assert_eq!(best.score_vector.as_object().unwrap().len(), 1); +} + // ========================================================================== // Rescalarize // ========================================================================== @@ -942,7 +1307,8 @@ async fn test_update_objectives_triggers_refit() { priority: 1.0, group: None, }]) - .await; + .await + .unwrap(); let best_after = engine.top_k(1, false).await.into_iter().next().unwrap(); assert!(best_after.rank < best_before.rank || best_after.trial_id != best_before.trial_id); @@ -955,6 +1321,64 @@ async fn test_update_objectives_triggers_refit() { // Checkpoints // ========================================================================== +fn scalar_checkpoint_config(max_trials: Option) -> StudyConfig { + StudyConfig { + space: BTreeMap::from([( + "x".to_string(), + ParamConfig::Real { + min: 0.0, + max: 1.0, + scale: "linear".to_string(), + }, + )]), + objectives: vec![ObjectiveConfig { + field: "loss".to_string(), + obj_type: "minimize".to_string(), + target: None, + limit: None, + priority: 1.0, + group: None, + }], + strategy: None, + checkpoint: None, + max_trials, + } +} + +fn vector_checkpoint_config() -> StudyConfig { + StudyConfig { + space: BTreeMap::from([( + "x".to_string(), + ParamConfig::Real { + min: 0.0, + max: 1.0, + scale: "linear".to_string(), + }, + )]), + objectives: vec![ + ObjectiveConfig { + field: "f1".to_string(), + obj_type: "minimize".to_string(), + target: None, + limit: None, + priority: 1.0, + group: None, + }, + ObjectiveConfig { + field: "f2".to_string(), + obj_type: "minimize".to_string(), + target: None, + limit: None, + priority: 2.0, + group: None, + }, + ], + strategy: None, + checkpoint: None, + max_trials: None, + } +} + 
#[tokio::test] async fn test_dyn_engine_leaderboard_checkpoint() { let config = StudyConfig { @@ -1003,6 +1427,59 @@ async fn test_dyn_engine_leaderboard_checkpoint() { assert_eq!(engine2.trial_count().await, 2); } +#[tokio::test] +async fn test_dyn_engine_leaderboard_checkpoint_resume_uses_fresh_trial_id() { + let config = scalar_checkpoint_config(Some(3)); + let engine = HolaEngine::from_config(config.clone()).unwrap(); + + for (expected_id, loss) in [0.5, 0.3].into_iter().enumerate() { + let trial = engine.ask().await.unwrap(); + assert_eq!(trial.trial_id, expected_id as u64); + let completed = engine + .tell(trial.trial_id, json!({"loss": loss})) + .await + .unwrap(); + assert_eq!(completed.trial_id, expected_id as u64); + } + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("lb.json"); + engine + .save_leaderboard_checkpoint_to(&path, Some("2 trials")) + .await + .unwrap(); + + let restored = HolaEngine::from_config(config).unwrap(); + let stale_pending = restored.ask().await.unwrap(); + let stale_cancelled = restored.ask().await.unwrap(); + restored.cancel(stale_cancelled.trial_id).await.unwrap(); + + restored.load_leaderboard_checkpoint(&path).await.unwrap(); + assert!( + restored + .tell(stale_pending.trial_id, json!({"loss": 0.0})) + .await + .is_err(), + "pending trials from the pre-load engine state must not survive checkpoint load" + ); + + let trial = restored.ask().await.unwrap(); + assert_eq!(trial.trial_id, 2); + let completed = restored + .tell(trial.trial_id, json!({"loss": 0.1})) + .await + .unwrap(); + assert_eq!(completed.trial_id, 2); + + let ids: Vec = restored + .trials("index", true) + .await + .into_iter() + .map(|trial| trial.trial_id) + .collect(); + assert_eq!(ids, vec![0, 1, 2]); +} + #[tokio::test] async fn test_dyn_engine_full_checkpoint() { let config = StudyConfig { @@ -1043,10 +1520,175 @@ async fn test_dyn_engine_full_checkpoint() { assert_eq!(engine2.trial_count().await, 1); } +#[tokio::test] +async fn 
test_dyn_engine_full_checkpoint_resume_returns_new_completed_trial() { + let config = scalar_checkpoint_config(None); + let engine = HolaEngine::from_config(config.clone()).unwrap(); + + for loss in [0.5, 0.3] { + let trial = engine.ask().await.unwrap(); + engine + .tell(trial.trial_id, json!({"loss": loss})) + .await + .unwrap(); + } + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("full.json"); + engine + .save_full_checkpoint(&path, Some("2 trials")) + .await + .unwrap(); + + let restored = HolaEngine::from_config(config).unwrap(); + restored.load_full_checkpoint(&path).await.unwrap(); + + let trial = restored.ask().await.unwrap(); + assert_eq!(trial.trial_id, 2); + let completed = restored + .tell(trial.trial_id, json!({"loss": 0.1})) + .await + .unwrap(); + assert_eq!(completed.trial_id, 2); + assert_eq!(completed.params, trial.params); + + let ids: Vec = restored + .trials("index", true) + .await + .into_iter() + .map(|trial| trial.trial_id) + .collect(); + assert_eq!(ids, vec![0, 1, 2]); +} + +#[tokio::test] +async fn test_dyn_engine_full_checkpoint_resume_preserves_vector_trial_ids() { + let config = vector_checkpoint_config(); + let engine = HolaEngine::from_config(config.clone()).unwrap(); + + for metrics in [json!({"f1": 1.0, "f2": 3.0}), json!({"f1": 2.0, "f2": 1.0})] { + let trial = engine.ask().await.unwrap(); + let completed = engine.tell(trial.trial_id, metrics).await.unwrap(); + assert_eq!(completed.trial_id, trial.trial_id); + } + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("vector-full.json"); + engine + .save_full_checkpoint(&path, Some("vector checkpoint")) + .await + .unwrap(); + + let restored = HolaEngine::from_config(config).unwrap(); + restored.load_full_checkpoint(&path).await.unwrap(); + + let trial = restored.ask().await.unwrap(); + assert_eq!(trial.trial_id, 2); + let completed = restored + .tell(trial.trial_id, json!({"f1": 0.5, "f2": 2.5})) + .await + .unwrap(); + 
assert_eq!(completed.trial_id, 2); + + let ids: Vec = restored + .trials("index", true) + .await + .into_iter() + .map(|trial| trial.trial_id) + .collect(); + assert_eq!(ids, vec![0, 1, 2]); +} + +#[tokio::test] +async fn test_dyn_engine_checkpoint_load_with_fallback_supports_full_and_leaderboard() { + let config = scalar_checkpoint_config(None); + let engine = HolaEngine::from_config(config.clone()).unwrap(); + let trial = engine.ask().await.unwrap(); + engine + .tell(trial.trial_id, json!({"loss": 0.5})) + .await + .unwrap(); + + let dir = tempfile::tempdir().unwrap(); + let full_path = dir.path().join("full.json"); + engine + .save_full_checkpoint(&full_path, Some("full")) + .await + .unwrap(); + + let restored_full = HolaEngine::from_config(config.clone()).unwrap(); + let kind = restored_full + .load_checkpoint_with_fallback(&full_path) + .await + .unwrap(); + assert_eq!(kind, CheckpointLoadKind::Full); + assert_eq!(restored_full.trial_count().await, 1); + assert_eq!(restored_full.ask().await.unwrap().trial_id, 1); + + let leaderboard_path = dir.path().join("leaderboard.json"); + engine + .save_leaderboard_checkpoint_to(&leaderboard_path, Some("leaderboard")) + .await + .unwrap(); + + let restored_leaderboard = HolaEngine::from_config(config).unwrap(); + let kind = restored_leaderboard + .load_checkpoint_with_fallback(&leaderboard_path) + .await + .unwrap(); + assert_eq!(kind, CheckpointLoadKind::Leaderboard); + assert_eq!(restored_leaderboard.trial_count().await, 1); + assert_eq!(restored_leaderboard.ask().await.unwrap().trial_id, 1); +} + // ========================================================================== // Auto strategy (Sobol -> GMM switching) // ========================================================================== +fn auto_strategy_test_config(exploration_budget: usize, seed: u64) -> StudyConfig { + StudyConfig { + space: BTreeMap::from([( + "x".to_string(), + ParamConfig::Real { + min: 0.0, + max: 1.0, + scale: "linear".to_string(), + }, + 
)]), + objectives: vec![ObjectiveConfig { + field: "loss".to_string(), + obj_type: "minimize".to_string(), + target: None, + limit: None, + priority: 1.0, + group: None, + }], + strategy: Some(StrategyConfig { + strategy_type: "auto".to_string(), + refit_interval: 5, + total_budget: None, + exploration_budget: Some(exploration_budget), + seed: Some(seed), + elite_fraction: None, + }), + checkpoint: None, + max_trials: None, + } +} + +fn sobol_strategy_test_config(seed: u64) -> StudyConfig { + let mut config = auto_strategy_test_config(0, seed); + config.strategy = Some(StrategyConfig { + strategy_type: "sobol".to_string(), + refit_interval: 20, + total_budget: None, + exploration_budget: None, + seed: Some(seed), + elite_fraction: None, + }); + config +} + #[tokio::test] async fn test_auto_strategy_default() { // With no strategy config, should use "auto" and work correctly @@ -1157,6 +1799,50 @@ fn test_auto_strategy_default_exploration_budget() { assert_eq!(AutoStrategy::default_exploration_budget(5, 1), 1); // min(1, 52) = 1 } +#[tokio::test] +async fn test_auto_strategy_counts_pending_asks_against_exploration_budget() { + let auto = HolaEngine::from_config(auto_strategy_test_config(2, 17)).unwrap(); + let sobol = HolaEngine::from_config(sobol_strategy_test_config(17)).unwrap(); + let gmm = HolaEngine::from_config(auto_strategy_test_config(0, 17)).unwrap(); + + let auto_trials = [ + auto.ask().await.unwrap(), + auto.ask().await.unwrap(), + auto.ask().await.unwrap(), + auto.ask().await.unwrap(), + ]; + + assert_eq!(auto_trials[0].params, sobol.ask().await.unwrap().params); + assert_eq!(auto_trials[1].params, sobol.ask().await.unwrap().params); + assert_eq!(auto_trials[2].params, gmm.ask().await.unwrap().params); + assert_eq!(auto_trials[3].params, gmm.ask().await.unwrap().params); + assert_eq!(auto.trial_count().await, 0); +} + +#[tokio::test] +async fn test_auto_strategy_full_checkpoint_preserves_pending_ask_accounting() { + let engine = 
HolaEngine::from_config(auto_strategy_test_config(2, 17)).unwrap(); + for _ in 0..3 { + engine.ask().await.unwrap(); + } + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("auto-full.json"); + engine + .save_full_checkpoint(&path, Some("auto pending asks")) + .await + .unwrap(); + + let restored = HolaEngine::load_from_checkpoint(&path).await.unwrap(); + + let gmm = HolaEngine::from_config(auto_strategy_test_config(0, 17)).unwrap(); + gmm.ask().await.unwrap(); + let expected = gmm.ask().await.unwrap(); + let resumed = restored.ask().await.unwrap(); + + assert_eq!(resumed.params, expected.params); +} + // ========================================================================== // Seed determinism tests // ========================================================================== diff --git a/hola/tests/integration/main.rs b/hola/tests/integration/main.rs index 0e4ea8e..a448a2a 100644 --- a/hola/tests/integration/main.rs +++ b/hola/tests/integration/main.rs @@ -9,7 +9,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Integration tests for HOLA: DynEngine, REST API, and end-to-end optimization. +//! Integration tests for HOLA: HolaEngine, REST API, and end-to-end optimization. mod end_to_end; mod hola_engine; diff --git a/hola/tests/integration/server.rs b/hola/tests/integration/server.rs index e9a1b48..d23edc4 100644 --- a/hola/tests/integration/server.rs +++ b/hola/tests/integration/server.rs @@ -14,8 +14,8 @@ //! Exercises ask/tell/top_k/trials endpoints, space/objectives info, //! checkpoints, error handling, cancel, and objective rescalarization. 
-use hola::hola_engine::{HolaEngine, ObjectiveConfig, ParamConfig, StudyConfig}; -use hola::server::create_router; +use hola::hola_engine::{HolaEngine, ObjectiveConfig, ParamConfig, StrategyConfig, StudyConfig}; +use hola::server::{ServerOptions, create_router, create_router_with_options}; use http_body_util::BodyExt; use serde_json::json; use std::collections::BTreeMap; @@ -49,6 +49,37 @@ fn minimal_config() -> StudyConfig { } } +fn sobol_config(seed: u64) -> StudyConfig { + StudyConfig { + space: BTreeMap::from([( + "x".to_string(), + ParamConfig::Real { + min: 0.0, + max: 1.0, + scale: "linear".to_string(), + }, + )]), + objectives: vec![ObjectiveConfig { + field: "loss".to_string(), + obj_type: "minimize".to_string(), + target: None, + limit: None, + priority: 1.0, + group: None, + }], + strategy: Some(StrategyConfig { + strategy_type: "sobol".to_string(), + refit_interval: 20, + total_budget: None, + exploration_budget: None, + seed: Some(seed), + elite_fraction: None, + }), + checkpoint: None, + max_trials: None, + } +} + fn multi_param_config() -> StudyConfig { StudyConfig { space: BTreeMap::from([ @@ -90,8 +121,21 @@ async fn json_request( method: &str, uri: &str, body: Option, +) -> (u16, serde_json::Value) { + json_request_with_headers(app, method, uri, body, &[]).await +} + +async fn json_request_with_headers( + app: axum::Router, + method: &str, + uri: &str, + body: Option, + headers: &[(&str, &str)], ) -> (u16, serde_json::Value) { let mut builder = hyper::Request::builder().method(method).uri(uri); + for (name, value) in headers { + builder = builder.header(*name, *value); + } let body = if let Some(b) = body { builder = builder.header("content-type", "application/json"); axum::body::Body::from(serde_json::to_vec(&b).unwrap()) @@ -106,6 +150,22 @@ async fn json_request( (status, json) } +async fn options_request( + app: axum::Router, + uri: &str, + origin: &str, + requested_method: &str, +) -> hyper::Response { + let req = hyper::Request::builder() + 
.method("OPTIONS") + .uri(uri) + .header("origin", origin) + .header("access-control-request-method", requested_method) + .body(axum::body::Body::empty()) + .unwrap(); + app.oneshot(req).await.unwrap() +} + // ========================================================================== // Core flow: ask -> tell -> top_k -> trials // ========================================================================== @@ -144,6 +204,8 @@ async fn test_server_ask_tell_top_k_flow() { assert_eq!(status, 200); assert_eq!(result["status"], "ok"); assert_eq!(result["trial_count"], 1); + assert_eq!(result["trial"]["trial_id"], trial_id); + assert!(result["trial"]["score_vector"].is_object()); // Top-k let (status, top) = json_request( @@ -162,6 +224,18 @@ async fn test_server_ask_tell_top_k_flow() { assert!(top_arr[0]["scores"].is_object()); assert!(top_arr[0]["rank"].is_u64()); + // Single-trial lookup + let (status, single) = json_request( + app.clone(), + "GET", + &format!("/api/trial/{trial_id}?include_infeasible=true"), + None, + ) + .await; + assert_eq!(status, 200); + assert_eq!(single["trial_id"], trial_id); + assert_eq!(single["metrics"]["loss"], 0.42); + // Trials let (status, trials) = json_request( app, @@ -196,6 +270,111 @@ async fn test_server_trial_count() { assert_eq!(body["trial_count"], 0); } +// ========================================================================== +// Security options: auth and CORS +// ========================================================================== + +#[tokio::test] +async fn test_server_auth_rejects_missing_and_invalid_bearer() { + let engine = HolaEngine::from_config(minimal_config()).unwrap(); + let mut options = ServerOptions::new(8000); + options.auth_token = Some("secret".to_string()); + let app = create_router_with_options(engine, options); + + let (status, body) = json_request(app.clone(), "POST", "/api/ask", None).await; + assert_eq!(status, 401); + assert!(body["error"].as_str().unwrap().contains("bearer token")); + + for 
(method, uri, body) in [ + ( + "POST", + "/api/tell", + Some(json!({"trial_id": 0, "metrics": {"loss": 0.5}})), + ), + ("POST", "/api/cancel", Some(json!({"trial_id": 0}))), + ( + "PATCH", + "/api/objectives", + Some(json!({"objectives": [{"field": "loss", "type": "minimize"}]})), + ), + ( + "POST", + "/api/checkpoint/save", + Some(json!({"path": "checkpoint.json"})), + ), + ] { + let (status, body) = json_request(app.clone(), method, uri, body).await; + assert_eq!(status, 401, "{method} {uri}"); + assert!(body["error"].as_str().unwrap().contains("bearer token")); + } + + let (status, body) = json_request_with_headers( + app, + "POST", + "/api/ask", + None, + &[("authorization", "Bearer wrong")], + ) + .await; + assert_eq!(status, 401); + assert!(body["error"].as_str().unwrap().contains("bearer token")); +} + +#[tokio::test] +async fn test_server_auth_accepts_valid_bearer_for_mutations() { + let engine = HolaEngine::from_config(minimal_config()).unwrap(); + let mut options = ServerOptions::new(8000); + options.auth_token = Some("secret".to_string()); + let app = create_router_with_options(engine, options); + + let (status, trial) = json_request_with_headers( + app.clone(), + "POST", + "/api/ask", + None, + &[("authorization", "Bearer secret")], + ) + .await; + assert_eq!(status, 200); + + let tell = json!({"trial_id": trial["trial_id"], "metrics": {"loss": 0.5}}); + let (status, body) = json_request_with_headers( + app, + "POST", + "/api/tell", + Some(tell), + &[("authorization", "Bearer secret")], + ) + .await; + assert_eq!(status, 200); + assert_eq!(body["status"], "ok"); +} + +#[tokio::test] +async fn test_server_cors_allows_configured_origin_only() { + let engine = HolaEngine::from_config(minimal_config()).unwrap(); + let mut options = ServerOptions::new(8000); + options.cors_allowed_origins = vec!["http://allowed.example".to_string()]; + let app = create_router_with_options(engine, options); + + let allowed = options_request(app.clone(), "/api/ask", 
"http://allowed.example", "POST").await; + assert_eq!( + allowed + .headers() + .get("access-control-allow-origin") + .unwrap(), + "http://allowed.example" + ); + + let disallowed = options_request(app, "/api/ask", "http://disallowed.example", "POST").await; + assert!( + disallowed + .headers() + .get("access-control-allow-origin") + .is_none() + ); +} + // ========================================================================== // Error handling // ========================================================================== @@ -310,6 +489,22 @@ async fn test_server_update_objectives() { assert_eq!(result["status"], "ok"); } +#[tokio::test] +async fn test_server_update_objectives_rejects_invalid_type() { + let engine = HolaEngine::from_config(minimal_config()).unwrap(); + let app = create_router(engine); + + let patch = json!({"objectives": [{"field": "accuracy", "type": "larger", "priority": 1.0}]}); + let (status, result) = json_request(app, "PATCH", "/api/objectives", Some(patch)).await; + assert_eq!(status, 400); + assert!( + result["error"] + .as_str() + .unwrap() + .contains("Objective 'accuracy'") + ); +} + #[tokio::test] async fn test_server_update_objectives_rescalarizes() { let engine = HolaEngine::from_config(minimal_config()).unwrap(); @@ -366,19 +561,96 @@ async fn test_server_ask_sequential_ids() { #[tokio::test] async fn test_server_checkpoint_save_endpoint() { let engine = HolaEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); + let dir = tempfile::tempdir().unwrap(); + let mut options = ServerOptions::new(8000); + options.checkpoint_dir = dir.path().to_path_buf(); + let app = create_router_with_options(engine, options); let (_, trial) = json_request(app.clone(), "POST", "/api/ask", None).await; let tell = json!({"trial_id": trial["trial_id"], "metrics": {"loss": 0.5}}); json_request(app.clone(), "POST", "/api/tell", Some(tell)).await; - let dir = tempfile::tempdir().unwrap(); let path = dir.path().join("ckpt.json"); 
- let save_req = json!({"path": path.to_string_lossy(), "description": "server test"}); + let save_req = json!({"path": "ckpt.json", "description": "server test"}); - let (status, _) = json_request(app, "POST", "/api/checkpoint/save", Some(save_req)).await; + let (status, body) = json_request(app, "POST", "/api/checkpoint/save", Some(save_req)).await; assert_eq!(status, 200); + assert_eq!(body["checkpoint_type"], "full"); + assert_eq!(body["path"].as_str().unwrap(), path.to_string_lossy()); assert!(path.exists()); + + let saved: serde_json::Value = + serde_json::from_reader(std::fs::File::open(&path).unwrap()).unwrap(); + assert!(saved.get("config").is_some()); + assert!(saved["checkpoint"].get("strategy_state").is_some()); +} + +#[tokio::test] +async fn test_server_checkpoint_save_preserves_sobol_sequence() { + let config = sobol_config(123); + let baseline = HolaEngine::from_config(config.clone()).unwrap(); + let engine = HolaEngine::from_config(config).unwrap(); + let dir = tempfile::tempdir().unwrap(); + let mut options = ServerOptions::new(8000); + options.checkpoint_dir = dir.path().to_path_buf(); + let app = create_router_with_options(engine, options); + + for _ in 0..3 { + let baseline_trial = baseline.ask().await.unwrap(); + baseline + .tell( + baseline_trial.trial_id, + json!({"loss": baseline_trial.params["x"]}), + ) + .await + .unwrap(); + + let (status, trial) = json_request(app.clone(), "POST", "/api/ask", None).await; + assert_eq!(status, 200); + assert_eq!(trial["params"], baseline_trial.params); + let tell = + json!({"trial_id": trial["trial_id"], "metrics": {"loss": trial["params"]["x"]}}); + let (status, _) = json_request(app.clone(), "POST", "/api/tell", Some(tell)).await; + assert_eq!(status, 200); + } + let expected_next = baseline.ask().await.unwrap(); + + let path = dir.path().join("full.json"); + let save_req = json!({"path": "full.json", "description": "server full"}); + let (status, body) = json_request(app, "POST", 
"/api/checkpoint/save", Some(save_req)).await; + assert_eq!(status, 200); + assert_eq!(body["checkpoint_type"], "full"); + + let restored = HolaEngine::load_from_checkpoint(&path).await.unwrap(); + let restored_next = restored.ask().await.unwrap(); + assert_eq!(restored_next.trial_id, expected_next.trial_id); + assert_eq!(restored_next.params, expected_next.params); +} + +#[tokio::test] +async fn test_server_checkpoint_save_rejects_unconfined_paths() { + let engine = HolaEngine::from_config(minimal_config()).unwrap(); + let dir = tempfile::tempdir().unwrap(); + let mut options = ServerOptions::new(8000); + options.checkpoint_dir = dir.path().to_path_buf(); + let app = create_router_with_options(engine, options); + + let absolute_req = json!({"path": dir.path().join("escape.json").to_string_lossy()}); + let (status, body) = json_request( + app.clone(), + "POST", + "/api/checkpoint/save", + Some(absolute_req), + ) + .await; + assert_eq!(status, 400); + assert!(body["error"].as_str().unwrap().contains("relative")); + + let traversal_req = json!({"path": "../escape.json"}); + let (status, body) = + json_request(app, "POST", "/api/checkpoint/save", Some(traversal_req)).await; + assert_eq!(status, 400); + assert!(body["error"].as_str().unwrap().contains("relative")); } // ========================================================================== diff --git a/issues/08-autostrategy-pending-trial-accounting.md b/issues/08-autostrategy-pending-trial-accounting.md new file mode 100644 index 0000000..fea3343 --- /dev/null +++ b/issues/08-autostrategy-pending-trial-accounting.md @@ -0,0 +1,94 @@ +# AutoStrategy Exploration Accounting Ignores Pending Trials + +## Status + +Resolved in branch `audit/known-issues` with the issue 08 fix. 
+ +## Tags + +- severity: medium +- type: algorithm +- area: strategy +- area: distributed +- area: ask-tell +- user-impact: sampling-quality + +## Summary + +`AutoStrategy` switches from Sobol exploration to GMM exploitation based on +completed trial count. In distributed usage, many `ask()` calls can be issued +before any `tell()` completes. Because `trial_count` increments only during +`update()`, a batch of pending trials can exceed the exploration budget while +still being sampled from Sobol. + +## Evidence + +- `suggest()` checks `s.trial_count < s.exploration_budget`. + - `hola/src/hola_engine.rs:451` +- `trial_count` increments only on `update()`, which occurs during `tell()`. + - `hola/src/hola_engine.rs:460` + - `hola/src/hola_engine.rs:466` + +## Impact + +- Distributed workers can oversample the exploration phase. +- The intended exploration/exploitation split depends on worker concurrency and + result latency. +- GMM may start later than intended in high-parallelism runs. + +## Proposed Fix + +Track issued suggestions separately from completed updates. + +Recommended design options: + +1. Add an issued counter to `AutoStrategy` and increment it in `suggest()`. +2. Or let `HolaEngine::ask()` decide exploration mode based on + completed + pending counts. + +Preferred path: + +- Add `issued_count` to `AutoStrategy`. +- Serialize it in checkpoints. +- Use `issued_count` for the Sobol/GMM switch. +- Keep `completed_count` or existing `trial_count` for refit/update logic. + +## Resolution + +- Added an `issued_count` counter to `AutoStrategy` and increment it from + `suggest()`, so the Sobol/GMM boundary is based on issued trials rather than + completed `tell()` calls. +- Kept the existing `trial_count` as completed-update accounting for refit and + strategy update behavior. +- Serialized `issued_count` in full checkpoints, with backward-compatible + deserialization that defaults older checkpoints to at least the completed + trial count. 
+- Added focused Rust integration coverage for pending `ask()` batches crossing + the exploration boundary and for full-checkpoint save/load after pending + asks. +- Added Python coverage for `Study.save()` + `Study.load()` continuing the + AutoStrategy issued-trial budget after pending asks. + +## Acceptance Criteria + +- Issuing N concurrent `ask()` calls crosses the exploration boundary exactly + once based on issued trials. +- Checkpoint save/load preserves the issued counter. +- Existing sequential behavior remains unchanged. + +## Suggested Tests + +- Configure `exploration_budget = 2`, call `ask()` 4 times before any `tell()`, + assert first 2 come from Sobol path and later suggestions use GMM path. +- Save after pending asks, load, continue asking, assert budget continuity. + +## Verification + +- `cargo test -p hola --test integration auto_strategy --all-features` +- `cargo test -p hola --test integration checkpoint --all-features` +- `cd hola-py && uv run maturin develop` +- `cd hola-py && uv run pytest -q tests/test_study_advanced.py -k 'gmm_counts_pending_asks or gmm_save_load_preserves_pending_ask_accounting'` +- `cargo test --workspace --all-features` +- `cd hola-py && uv run pytest -q` + +All passed after the fix. diff --git a/issues/09-ranking-refit-scalability.md b/issues/09-ranking-refit-scalability.md new file mode 100644 index 0000000..1bc12a8 --- /dev/null +++ b/issues/09-ranking-refit-scalability.md @@ -0,0 +1,103 @@ +# Ranking And Refit Paths Scale Poorly Under Locks + +## Status + +Resolved in branch `audit/known-issues` with the issue 09 fix. + +## Tags + +- severity: medium +- type: performance +- area: leaderboard +- area: concurrency +- area: dashboard +- area: multi-objective +- regression-risk: medium + +## Summary + +Several hot paths rebuild ranked views by cloning and sorting the entire +leaderboard while engine locks are held or immediately after mutating state. +Multi-objective ranking is documented as `O(M * N^2)`. 
This is acceptable for +small studies, but will become a bottleneck for large dashboards, frequent +SSE-driven refreshes, and distributed ask/tell workloads. + +## Evidence + +- `tell()` computes a completed view immediately after pushing a trial. + - `hola/src/hola_engine.rs:1157` +- `get_completed()` builds full ranked lists and searches them. + - `hola/src/hola_engine.rs:714` +- Multi-objective non-dominated sort is documented as `O(M * N^2)`. + - `opt_engine/src/leaderboard.rs:653` +- Dashboard refetches all trials on each SSE trial completion. + - `dashboard/app.js:77` + +## Impact + +- `tell()` latency grows with leaderboard size. +- Dashboard-connected runs can repeatedly trigger full ranking and JSON + serialization. +- Multi-objective studies can become slow at hundreds or thousands of trials. +- Long read/write lock holds can block concurrent asks/tells. + +## Proposed Fix + +Separate mutation, ranking, and presentation concerns. + +Recommended design: + +- In `tell()`, return the just completed trial with lightweight rank info or + compute rank outside the write lock using a cloned snapshot. +- Add cached ranking snapshots invalidated on leaderboard mutation. +- Add paginated or incremental `/api/trials` responses. +- Add `/api/trial/{id}` so remote Python `tell()` does not fetch every trial. +- Consider specialized Pareto/ranking data structures for large + multi-objective studies. + +## Resolution + +- Changed `HolaEngine::tell()` to snapshot the leaderboard and release the + write lock before constructing the ranked `CompletedTrial`. +- Added scalar single-trial completion lookup that computes exact rank with a + linear scan instead of rebuilding the full ranked list. +- Added `HolaEngine::completed_trial()` and `GET /api/trial/{trial_id}` for + single completed-trial retrieval. +- Included the completed trial in `POST /api/tell` responses and SSE + `TrialCompleted` events. 
+- Updated remote Python `Study.tell()` to consume the returned completed trial, + falling back to `/api/trial/{id}` and only then to the legacy full-trials + fetch for older servers. +- Updated the dashboard SSE path to upsert a single completed trial instead of + refetching all trials after each completion. +- Optimized scalar `top_k`/`bottom_k` and vector `top_k_scalarized` to select + only the requested prefix before sorting it, avoiding full sorts for refit + and small top-k requests. +- Added ignored Rust scalability probes for representative scalar and vector + leaderboard ranking sizes. + +## Acceptance Criteria + +- `tell()` does not rebuild the entire ranked list while holding the write lock. +- Dashboard can update from a single completed-trial event or paginated fetch. +- Benchmarks exist for scalar and vector leaderboard ranking at representative + sizes, for example 1k, 10k, and 50k scalar trials. + +## Suggested Tests + +- Add criterion or lightweight benchmark tests for `top_k`, `trials`, and + `pareto_front`. +- Add a concurrency test with many asks/tells and dashboard reads to catch lock + contention regressions. + +## Verification + +- `cargo test -p opt_engine leaderboard::tests::test_top_k --all-features` +- `cargo test -p opt_engine --test integration leaderboard --all-features` +- `cargo test -p hola --test integration server --features server` +- `cd hola-py && uv run maturin develop` +- `cd hola-py && uv run pytest -q tests/test_server.py -k 'ask_tell_best_flow or study_connect_ask_tell_best'` +- `cargo test --workspace --all-features` +- `cd hola-py && uv run pytest -q` + +All passed after the fix. 
diff --git a/issues/10-dependency-feature-minimization.md b/issues/10-dependency-feature-minimization.md new file mode 100644 index 0000000..331d102 --- /dev/null +++ b/issues/10-dependency-feature-minimization.md @@ -0,0 +1,98 @@ +# Dependency Feature Sets Are Broader Than Needed + +## Status + +Resolved in branch `audit/known-issues` with the issue 10 fix. + +## Tags + +- severity: low +- type: maintenance +- area: dependencies +- area: build +- area: security-surface +- regression-risk: low + +## Summary + +The Rust crates use broad dependency feature sets, notably `tokio` with +`features = ["full"]`. This increases compile time, dependency graph size, and +security/update surface beyond what the code appears to need. + +## Evidence + +- `opt_engine` depends on Tokio with all features. + - `opt_engine/Cargo.toml:10` +- `hola` depends on Tokio with all features. + - `hola/Cargo.toml:15` +- `hola-cli` depends on Tokio with all features. + - `hola-cli/Cargo.toml:17` + +## Impact + +- Slower clean builds and CI. +- More transitive code included in the dependency graph. +- Larger audit and vulnerability review surface. + +## Proposed Fix + +Minimize features per crate. + +Likely needed features: + +- `opt_engine`: `sync`, `rt`, and possibly `rt-multi-thread` only if tests or + `spawn_blocking` require it. +- `hola`: `sync`, `rt`, `rt-multi-thread`, `net`, `macros` depending on server + and tests. +- `hola-cli`: `macros`, `rt-multi-thread`, `time`, `process` if async process + APIs are used, otherwise standard process does not need Tokio process. + +Validate by reducing features incrementally and running the full Rust and +Python test suites. + +## Resolution + +- Replaced `tokio = { features = ["full"] }` in `opt_engine`, `hola`, + `hola-cli`, and `hola-py` with explicit minimal feature sets. +- `opt_engine` now enables only `macros`, `rt-multi-thread`, and `sync` for + async locks, blocking refit tasks, and tests/examples. 
+- `hola` now enables only `macros`, `rt-multi-thread`, and `sync` by default; + the `server` feature adds `tokio/net` and the minimal Axum features needed + for JSON, query extractors, HTTP/1, and Tokio serving. +- `hola-cli` now enables only `macros`, `rt-multi-thread`, and `time` for the + Tokio main runtime, async HTTP, and retry sleeps. +- `hola-py` now enables only `rt-multi-thread` for constructing runtimes from + Python bindings. +- Disabled unnecessary default features on optional server dependencies where + the code uses narrower feature sets: `axum`, `tokio-stream`, and + `tower-http`. +- Updated `Cargo.lock`; the resolved graph no longer includes Tokio + full-feature-only dependencies such as `parking_lot` and + `signal-hook-registry`. + +## Acceptance Criteria + +- Replace `tokio = { features = ["full"] }` with the minimal feature set in + each crate. +- `cargo check --workspace --all-features` passes. +- `cargo test --workspace --all-features` passes. +- Python test suite still passes, especially server and CLI integration tests. +- Document why each enabled feature is needed if the feature list is not + obvious. + +## Suggested Tests + +- Run `cargo tree -e features` before and after to confirm feature reduction. +- Run the full existing verification suite. + +## Verification + +- `cargo tree -e features -i tokio --workspace --all-features` +- `cargo check --workspace --all-features` +- `cargo check --workspace` +- `cargo test --workspace --all-features` +- `cargo check --workspace --all-features --all-targets` +- `cd hola-py && uv run maturin develop` +- `cd hola-py && uv run pytest -q` + +All passed after the fix. 
diff --git a/opt_engine/Cargo.toml b/opt_engine/Cargo.toml index 41da2bc..ef70210 100644 --- a/opt_engine/Cargo.toml +++ b/opt_engine/Cargo.toml @@ -7,7 +7,8 @@ license = "Apache-2.0" [dependencies] serde = { version = "1", features = ["derive"] } serde_json = "1" -tokio = { version = "1", features = ["full"] } +# Engine internals need async locks and blocking refit tasks; tests/examples use Tokio macros. +tokio = { version = "1", default-features = false, features = ["macros", "rt-multi-thread", "sync"] } rand = "0.9.3" sobol_burley = "0.5" nalgebra = { version = "0.34", features = ["serde-serialize"] } diff --git a/opt_engine/src/leaderboard.rs b/opt_engine/src/leaderboard.rs index d0eaed1..eb45c6c 100644 --- a/opt_engine/src/leaderboard.rs +++ b/opt_engine/src/leaderboard.rs @@ -24,6 +24,7 @@ use chrono::Utc; use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; use std::collections::BTreeMap; // ============================================================================= @@ -140,6 +141,18 @@ impl Leaderboard { trial_id } + /// Append a trial using an externally assigned trial ID. + /// + /// This is useful when trial IDs are issued before completion. The internal + /// auto-assignment counter is advanced so future generated IDs do not reuse + /// the provided ID. + pub fn push_with_trial_id(&mut self, candidate: D, observation: Obs, trial_id: u64) -> u64 { + self.next_id = self.next_id.max(trial_id.saturating_add(1)); + self.trials + .push(Trial::new(candidate, observation, trial_id)); + trial_id + } + /// Append a trial with raw metrics preserved for lazy re-scalarization. pub fn push_with_raw( &mut self, @@ -158,6 +171,55 @@ impl Leaderboard { trial_id } + /// Append a trial with an externally assigned trial ID and raw metrics + /// preserved for lazy re-scalarization. 
+ pub fn push_with_raw_trial_id( + &mut self, + candidate: D, + observation: Obs, + raw_metrics: serde_json::Value, + trial_id: u64, + ) -> u64 { + self.next_id = self.next_id.max(trial_id.saturating_add(1)); + self.trials.push(Trial::with_raw_metrics( + candidate, + observation, + raw_metrics, + trial_id, + )); + trial_id + } + + /// Append an existing trial record, preserving its ID and timestamp. + /// + /// This supports rebuilding a leaderboard with a different observation type + /// while preserving public trial identity and completion metadata. + pub fn push_existing_trial(&mut self, trial: Trial) -> u64 { + let trial_id = trial.trial_id; + self.next_id = self.next_id.max(trial_id.saturating_add(1)); + self.trials.push(trial); + trial_id + } + + /// Return the next ID that can be assigned without reusing a stored trial ID. + pub fn next_trial_id(&self) -> u64 { + let next_from_trials = self + .trials + .iter() + .map(|trial| trial.trial_id.saturating_add(1)) + .max() + .unwrap_or(0); + self.next_id.max(next_from_trials) + } + + /// Repair the internal next-ID counter after deserializing or manually + /// modifying a leaderboard. + pub fn normalize_next_trial_id(&mut self) -> u64 { + let next_id = self.next_trial_id(); + self.next_id = next_id; + next_id + } + pub fn len(&self) -> usize { self.trials.len() } @@ -243,6 +305,20 @@ impl Leaderboard { is_feasible_scalar(trial.observation) } + fn compare_best(a: &Trial, b: &Trial) -> Ordering { + a.observation + .partial_cmp(&b.observation) + .unwrap_or(Ordering::Equal) + .then_with(|| a.trial_id.cmp(&b.trial_id)) + } + + fn compare_worst(a: &Trial, b: &Trial) -> Ordering { + b.observation + .partial_cmp(&a.observation) + .unwrap_or(Ordering::Equal) + .then_with(|| a.trial_id.cmp(&b.trial_id)) + } + /// Return only feasible trials (those with finite observations). 
pub fn feasible_trials(&self) -> Vec> { self.trials @@ -276,14 +352,14 @@ impl Leaderboard { .filter(|t| Self::trial_is_feasible(t)) .collect(); - feasible.sort_by(|a, b| { - a.observation - .partial_cmp(&b.observation) - .unwrap_or(std::cmp::Ordering::Equal) - .then_with(|| a.trial_id.cmp(&b.trial_id)) - }); + let limit = k.min(feasible.len()); + if limit < feasible.len() { + feasible.select_nth_unstable_by(limit, |a, b| Self::compare_best(a, b)); + feasible.truncate(limit); + } + feasible.sort_by(|a, b| Self::compare_best(a, b)); - feasible.into_iter().take(k).cloned().collect() + feasible.into_iter().cloned().collect() } /// Return the k trials with the lowest observations, including infeasible. @@ -292,21 +368,15 @@ impl Leaderboard { return Vec::new(); } - let mut indices: Vec = (0..self.trials.len()).collect(); - indices.sort_by(|&a, &b| { - let obs_a = self.trials[a].observation; - let obs_b = self.trials[b].observation; - obs_a - .partial_cmp(&obs_b) - .unwrap_or(std::cmp::Ordering::Equal) - .then_with(|| self.trials[a].trial_id.cmp(&self.trials[b].trial_id)) - }); + let mut trials: Vec<&Trial> = self.trials.iter().collect(); + let limit = k.min(trials.len()); + if limit < trials.len() { + trials.select_nth_unstable_by(limit, |a, b| Self::compare_best(a, b)); + trials.truncate(limit); + } + trials.sort_by(|a, b| Self::compare_best(a, b)); - indices - .into_iter() - .take(k) - .map(|i| self.trials[i].clone()) - .collect() + trials.into_iter().cloned().collect() } /// Return feasible trials sorted by observation (ascending). 
@@ -345,14 +415,14 @@ impl Leaderboard { .filter(|t| Self::trial_is_feasible(t)) .collect(); - feasible.sort_by(|a, b| { - b.observation - .partial_cmp(&a.observation) - .unwrap_or(std::cmp::Ordering::Equal) - .then_with(|| a.trial_id.cmp(&b.trial_id)) - }); + let limit = k.min(feasible.len()); + if limit < feasible.len() { + feasible.select_nth_unstable_by(limit, |a, b| Self::compare_worst(a, b)); + feasible.truncate(limit); + } + feasible.sort_by(|a, b| Self::compare_worst(a, b)); - feasible.into_iter().take(k).cloned().collect() + feasible.into_iter().cloned().collect() } /// Return the k worst trials, including infeasible. @@ -361,21 +431,15 @@ impl Leaderboard { return Vec::new(); } - let mut indices: Vec = (0..self.trials.len()).collect(); - indices.sort_by(|&a, &b| { - let obs_a = self.trials[a].observation; - let obs_b = self.trials[b].observation; - obs_b - .partial_cmp(&obs_a) - .unwrap_or(std::cmp::Ordering::Equal) - .then_with(|| self.trials[a].trial_id.cmp(&self.trials[b].trial_id)) - }); + let mut trials: Vec<&Trial> = self.trials.iter().collect(); + let limit = k.min(trials.len()); + if limit < trials.len() { + trials.select_nth_unstable_by(limit, |a, b| Self::compare_worst(a, b)); + trials.truncate(limit); + } + trials.sort_by(|a, b| Self::compare_worst(a, b)); - indices - .into_iter() - .take(k) - .map(|i| self.trials[i].clone()) - .collect() + trials.into_iter().cloned().collect() } /// Compute the quantile threshold for feasible observations. @@ -419,6 +483,8 @@ impl Leaderboard { // Multi-Objective Ranking (BTreeMap observations) // ============================================================================= +type ScoredMultiTrial<'a, D> = (&'a Trial>, f64); + impl Leaderboard> { /// Check if a trial is feasible (all objective values are finite). 
/// @@ -577,27 +643,37 @@ impl Leaderboard> { where F: Fn(&BTreeMap) -> f64, { - let feasible = self.feasible_trials(); - if feasible.is_empty() || k == 0 { + if self.trials.is_empty() || k == 0 { return Vec::new(); } - let mut indexed: Vec<(usize, f64)> = feasible + let mut scored: Vec> = self + .trials .iter() - .enumerate() - .map(|(i, t)| (i, scalarizer(&t.observation))) + .filter(|t| Self::trial_is_feasible(t)) + .map(|t| (t, scalarizer(&t.observation))) .collect(); - indexed.sort_by(|a, b| { + if scored.is_empty() { + return Vec::new(); + } + + let compare = |a: &ScoredMultiTrial<'_, D>, b: &ScoredMultiTrial<'_, D>| { a.1.partial_cmp(&b.1) - .unwrap_or(std::cmp::Ordering::Equal) - .then_with(|| feasible[a.0].trial_id.cmp(&feasible[b.0].trial_id)) - }); + .unwrap_or(Ordering::Equal) + .then_with(|| a.0.trial_id.cmp(&b.0.trial_id)) + }; - indexed + let limit = k.min(scored.len()); + if limit < scored.len() { + scored.select_nth_unstable_by(limit, compare); + scored.truncate(limit); + } + scored.sort_by(compare); + + scored .into_iter() - .take(k) - .map(|(i, _)| feasible[i].clone()) + .map(|(trial, _)| (*trial).clone()) .collect() } @@ -1037,6 +1113,44 @@ mod tests { assert_eq!(id3, 2); } + #[test] + fn test_externally_assigned_trial_ids_advance_next_id() { + let mut lb: Leaderboard = Leaderboard::new(); + + assert_eq!(lb.push_with_trial_id(0.1, 0.1, 7), 7); + assert_eq!(lb.push(0.2, 0.2), 8); + assert_eq!(lb.next_trial_id(), 9); + } + + #[test] + fn test_normalize_next_trial_id_repairs_stale_counter() { + let mut lb: Leaderboard = Leaderboard::new(); + lb.push(0.1, 0.1); + lb.push(0.2, 0.2); + lb.next_id = 0; + + assert_eq!(lb.next_trial_id(), 2); + assert_eq!(lb.normalize_next_trial_id(), 2); + assert_eq!(lb.push(0.3, 0.3), 2); + } + + #[test] + fn test_push_existing_trial_preserves_timestamp_and_advances_next_id() { + let mut lb: Leaderboard = Leaderboard::new(); + let trial = Trial { + candidate: 0.1, + observation: 0.2, + raw_metrics: 
Some(serde_json::json!({"loss": 0.2})), + trial_id: 4, + timestamp: 123, + }; + + assert_eq!(lb.push_existing_trial(trial), 4); + let stored = lb.get(4).unwrap(); + assert_eq!(stored.timestamp, 123); + assert_eq!(lb.push(0.3, 0.3), 5); + } + #[test] fn test_top_k_scalar() { let mut lb: Leaderboard<&str, f64> = Leaderboard::new(); diff --git a/opt_engine/src/lib.rs b/opt_engine/src/lib.rs index 91bab24..7b896df 100644 --- a/opt_engine/src/lib.rs +++ b/opt_engine/src/lib.rs @@ -25,8 +25,8 @@ //! - **Transformers** convert raw worker output into typed observations. //! - **[`Engine`]** orchestrates the loop with full compile-time type checking. //! -//! For the type-erased HOLA frontend (`DynEngine`, REST server), see the `hola` -//! crate which builds on top of `opt_engine`. +//! For the type-erased HOLA frontend (`HolaEngine`, Python bindings, CLI, and +//! REST server), see the `hola` crate which builds on top of `opt_engine`. //! //! # Quick start //! diff --git a/opt_engine/src/server.rs b/opt_engine/src/server.rs deleted file mode 100644 index d1d4a09..0000000 --- a/opt_engine/src/server.rs +++ /dev/null @@ -1,329 +0,0 @@ -// Copyright 2026 BlackRock, Inc. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Axum HTTP server for the DynEngine. -//! -//! Provides REST endpoints for distributed Ask/Tell optimization, -//! Server-Sent Events for real-time dashboard integration, and -//! dashboard API endpoints for space/objectives/checkpoint management. -//! -//! 
# Endpoints -//! -//! - `POST /api/ask` - Request the next trial -//! - `POST /api/tell` - Report trial results -//! - `GET /api/leaderboard` - Get all completed trials -//! - `GET /api/best` - Get the best trial -//! - `GET /api/pareto_front` - Get Pareto front (multi-objective only) -//! - `PATCH /api/objectives` - Update objectives mid-run -//! - `GET /api/objectives` - Get current objectives -//! - `GET /api/space` - Get parameter space metadata -//! - `POST /api/checkpoint/save` - Save a checkpoint -//! - `GET /api/mode` - Get the server mode ("live") -//! - `GET /api/events` - SSE stream of engine events - -use crate::dyn_engine::{DynEngine, DynTrial, ObjectiveConfig}; -use axum::{ - Router, - extract::State, - http::StatusCode, - response::{ - Json, - sse::{Event, Sse}, - }, - routing::{get, patch, post}, -}; -use serde::{Deserialize, Serialize}; -use std::convert::Infallible; -use std::path::Path; -use std::sync::Arc; -use tokio::sync::broadcast; -use tokio_stream::StreamExt; -use tokio_stream::wrappers::BroadcastStream; -use tower_http::cors::CorsLayer; -use tower_http::services::ServeDir; - -// ============================================================================= -// Shared state -// ============================================================================= - -/// Events emitted by the engine for SSE consumers. 
-#[derive(Clone, Debug, Serialize)] -#[serde(tag = "type")] -pub enum EngineEvent { - TrialCompleted { trial_id: u64, score: f64 }, - RefitOccurred { n_trials: usize }, -} - -pub struct ServerState { - pub engine: DynEngine, - pub events_tx: broadcast::Sender, -} - -// ============================================================================= -// Request/Response types -// ============================================================================= - -#[derive(Deserialize)] -struct TellRequest { - trial_id: u64, - metrics: serde_json::Value, -} - -#[derive(Deserialize)] -struct SaveCheckpointRequest { - #[serde(default = "default_checkpoint_path")] - path: String, - #[serde(default)] - description: Option, -} - -fn default_checkpoint_path() -> String { - "checkpoint.json".to_string() -} - -#[derive(Serialize)] -struct ErrorResponse { - error: String, -} - -// ============================================================================= -// Handlers -// ============================================================================= - -async fn handle_ask(State(state): State>) -> Json { - let trial = state.engine.ask().await; - Json(trial) -} - -async fn handle_tell( - State(state): State>, - Json(req): Json, -) -> Result, (StatusCode, Json)> { - match state.engine.tell(req.trial_id, req.metrics).await { - Ok(()) => { - let n = state.engine.trial_count().await; - let best = state.engine.best().await; - let score = best - .map(|b| b.observation.as_f64().unwrap_or(f64::INFINITY)) - .unwrap_or(f64::INFINITY); - - let _ = state.events_tx.send(EngineEvent::TrialCompleted { - trial_id: req.trial_id, - score, - }); - - Ok(Json(serde_json::json!({ - "status": "ok", - "trial_count": n, - }))) - } - Err(e) => Err((StatusCode::BAD_REQUEST, Json(ErrorResponse { error: e }))), - } -} - -async fn handle_leaderboard(State(state): State>) -> Json { - let trials = state.engine.trials_as_json().await; - let total = trials.len(); - Json(serde_json::json!({ - "trials": trials, - 
"total": total, - })) -} - -async fn handle_best(State(state): State>) -> Json { - match state.engine.best().await { - Some(best) => Json(serde_json::json!({ - "trial_id": best.trial_id, - "candidate": best.candidate, - "observation": best.observation, - "raw_metrics": best.raw_metrics, - })), - None => Json(serde_json::json!({ "trial_id": null })), - } -} - -#[derive(Deserialize)] -struct UpdateObjectivesRequest { - objectives: Vec, -} - -async fn handle_pareto_front( - State(state): State>, -) -> Result, (StatusCode, Json)> { - match state.engine.pareto_front().await { - Ok(front) => { - let trials: Vec = front - .into_iter() - .map(|t| { - let mut objectives = serde_json::Map::new(); - if let Some(ref raw) = t.raw_metrics { - for key in t.observation.keys() { - if let Some(val) = raw.get(key).and_then(|v| v.as_f64()) { - objectives.insert( - key.clone(), - serde_json::Value::from(val), - ); - } - } - } - serde_json::json!({ - "trial_id": t.trial_id, - "candidate": t.candidate, - "objectives": objectives, - }) - }) - .collect(); - Ok(Json(serde_json::Value::Array(trials))) - } - Err(e) => Err((StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": e})))), - } -} - -async fn handle_update_objectives( - State(state): State>, - Json(req): Json, -) -> Json { - state.engine.update_objectives(req.objectives).await; - let n = state.engine.trial_count().await; - Json(serde_json::json!({ - "status": "ok", - "rescalarized_trials": n, - })) -} - -async fn handle_get_objectives(State(state): State>) -> Json { - let objectives = state.engine.objectives().await; - Json(serde_json::json!({ "objectives": objectives })) -} - -async fn handle_space(State(state): State>) -> Json { - let params: Vec = state - .engine - .space_config() - .into_iter() - .map(|(name, info)| { - serde_json::json!({ - "name": name, - "type": info.param_type, - "min": info.min, - "max": info.max, - "scale": info.scale, - }) - }) - .collect(); - Json(serde_json::json!({ "params": params })) -} - -async 
fn handle_checkpoint_save( - State(state): State>, - Json(req): Json, -) -> Result, (StatusCode, Json)> { - match state - .engine - .save_leaderboard_checkpoint_to(&req.path, req.description.as_deref()) - .await - { - Ok(()) => { - let n = state.engine.trial_count().await; - Ok(Json(serde_json::json!({ - "status": "ok", - "path": req.path, - "trials_saved": n, - }))) - } - Err(e) => Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: e.to_string(), - }), - )), - } -} - -async fn handle_mode() -> Json { - Json(serde_json::json!({ "mode": "live" })) -} - -async fn handle_events( - State(state): State>, -) -> Sse>> { - let rx = state.events_tx.subscribe(); - let stream = BroadcastStream::new(rx).filter_map(|result| match result { - Ok(event) => { - let data = serde_json::to_string(&event).unwrap_or_default(); - Some(Ok(Event::default().data(data))) - } - Err(_) => None, - }); - Sse::new(stream) -} - -// ============================================================================= -// Router & Server -// ============================================================================= - -/// Create the Axum router for the engine server. 
-pub fn create_router(engine: DynEngine) -> Router { - let (events_tx, _) = broadcast::channel(256); - let state = Arc::new(ServerState { engine, events_tx }); - - let cors = CorsLayer::permissive(); - - Router::new() - .route("/api/ask", post(handle_ask)) - .route("/api/tell", post(handle_tell)) - .route("/api/leaderboard", get(handle_leaderboard)) - .route("/api/best", get(handle_best)) - .route("/api/pareto_front", get(handle_pareto_front)) - .route( - "/api/objectives", - patch(handle_update_objectives).get(handle_get_objectives), - ) - .route("/api/space", get(handle_space)) - .route("/api/checkpoint/save", post(handle_checkpoint_save)) - .route("/api/mode", get(handle_mode)) - .route("/api/events", get(handle_events)) - .layer(cors) - .with_state(state) -} - -/// Create the Axum router with the dashboard served from a local directory. -/// -/// API routes under `/api/*` take priority; all other paths fall through to -/// serve static files from `dashboard_dir`. -pub fn create_router_with_dashboard(engine: DynEngine, dashboard_dir: &Path) -> Router { - create_router(engine).fallback_service(ServeDir::new(dashboard_dir)) -} - -/// Start the server on the given port. Blocks until the server is shut down. -/// -/// If `dashboard_dir` is provided, the dashboard UI is served at `/`. 
-pub async fn serve( - engine: DynEngine, - port: u16, - dashboard_dir: Option<&Path>, -) -> Result<(), Box> { - let router = match dashboard_dir { - Some(dir) => create_router_with_dashboard(engine, dir), - None => create_router(engine), - }; - let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{port}")).await?; - if let Some(dir) = dashboard_dir { - eprintln!( - "HOLA server listening on port {port} (dashboard: {})", - dir.display() - ); - } else { - eprintln!("HOLA server listening on port {port}"); - } - axum::serve(listener, router).await?; - Ok(()) -} diff --git a/opt_engine/tests/integration/dyn_engine.rs b/opt_engine/tests/integration/dyn_engine.rs deleted file mode 100644 index 0eef300..0000000 --- a/opt_engine/tests/integration/dyn_engine.rs +++ /dev/null @@ -1,1318 +0,0 @@ -// Copyright 2026 BlackRock, Inc. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Integration tests for DynEngine (type-erased layer). -//! -//! Exercises config parsing, ask/tell flows, strategy types, scalarization, -//! objectives, checkpoints, refit, and all parameter types. 
- -use opt_engine::dyn_engine::{ - DynEngine, ObjectiveConfig, ParamConfig, StrategyConfig, StudyConfig, -}; -use opt_engine::traits::SampleSpace; -use serde_json::json; -use std::collections::BTreeMap; - -// ========================================================================== -// Config parsing -// ========================================================================== - -#[tokio::test] -async fn test_dyn_engine_config_parsing() { - let yaml_config = r#" - space: - learning_rate: - type: real - min: 0.0001 - max: 0.1 - scale: log10 - num_layers: - type: integer - min: 1 - max: 10 - objectives: - - field: loss - type: minimize - priority: 1.0 - - field: latency - type: minimize - target: 100 - limit: 500 - priority: 0.5 - strategy: - type: sobol - refit_interval: 20 - "#; - - let config: StudyConfig = serde_yaml::from_str(yaml_config).unwrap(); - assert_eq!(config.space.len(), 2); - assert_eq!(config.objectives.len(), 2); - assert!(config.strategy.is_some()); - assert!(config.checkpoint.is_none()); - - let engine = DynEngine::from_config(config).unwrap(); - assert_eq!(engine.trial_count().await, 0); -} - -#[tokio::test] -async fn test_dyn_engine_config_with_checkpoint() { - let yaml = r#" - space: - x: - type: real - min: 0.0 - max: 1.0 - objectives: - - field: loss - type: minimize - checkpoint: - directory: "/tmp/hola_test_ckpts" - interval: 10 - max_checkpoints: 3 - "#; - - let config: StudyConfig = serde_yaml::from_str(yaml).unwrap(); - assert!(config.checkpoint.is_some()); - let ckpt = config.checkpoint.as_ref().unwrap(); - assert_eq!(ckpt.directory, "/tmp/hola_test_ckpts"); - assert_eq!(ckpt.interval, 10); - assert_eq!(ckpt.max_checkpoints, Some(3)); - - let engine = DynEngine::from_config(config).unwrap(); - let _t = engine.ask().await; -} - -// ========================================================================== -// Ask/Tell flow -// ========================================================================== - -#[tokio::test] -async fn 
test_dyn_engine_ask_tell_flow() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - assert_eq!(engine.trial_count().await, 0); - assert!(engine.best().await.is_none()); - - let t0 = engine.ask().await; - assert_eq!(t0.trial_id, 0); - let t1 = engine.ask().await; - assert_eq!(t1.trial_id, 1); - - engine.tell(t0.trial_id, json!({"loss": 0.8})).await.unwrap(); - assert_eq!(engine.trial_count().await, 1); - - engine.tell(t1.trial_id, json!({"loss": 0.2})).await.unwrap(); - assert_eq!(engine.trial_count().await, 2); - - let best = engine.best().await.unwrap(); - assert_eq!(best.trial_id, 1); - assert_eq!(engine.trial_count().await, 2); -} - -#[tokio::test] -async fn test_dyn_engine_unknown_trial_error() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - let result = engine.tell(999, json!({"loss": 0.5})).await; - assert!(result.is_err()); - assert!(result.unwrap_err().contains("999")); -} - -#[tokio::test] -async fn test_dyn_engine_double_tell_error() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: 
"minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - let t = engine.ask().await; - engine.tell(t.trial_id, json!({"loss": 0.5})).await.unwrap(); - assert!(engine.tell(t.trial_id, json!({"loss": 0.3})).await.is_err()); -} - -// ========================================================================== -// All parameter types -// ========================================================================== - -#[tokio::test] -async fn test_dyn_engine_all_param_types() { - let config = StudyConfig { - space: BTreeMap::from([ - ( - "lr".to_string(), - ParamConfig::Real { - min: 1e-4, - max: 0.1, - scale: "log10".to_string(), - }, - ), - ( - "layers".to_string(), - ParamConfig::Integer { min: 1, max: 10 }, - ), - ( - "optimizer".to_string(), - ParamConfig::Categorical { - choices: vec!["adam".into(), "sgd".into(), "rmsprop".into()], - }, - ), - ( - "dropout".to_string(), - ParamConfig::Real { - min: 0.0, - max: 0.5, - scale: "linear".to_string(), - }, - ), - ]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - for _ in 0..10 { - let t = engine.ask().await; - assert!(engine.space().contains(&t.params)); - engine.tell(t.trial_id, json!({"loss": 0.5})).await.unwrap(); - } - assert_eq!(engine.trial_count().await, 10); -} - -#[tokio::test] -async fn test_dyn_engine_categorical_params() { - let config = StudyConfig { - space: BTreeMap::from([( - "optimizer".to_string(), - ParamConfig::Categorical { - choices: vec!["adam".into(), "sgd".into(), "rmsprop".into()], - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 
1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - let valid_choices: Vec = vec!["adam".into(), "sgd".into(), "rmsprop".into()]; - for _ in 0..20 { - let t = engine.ask().await; - let opt = t.params.get("optimizer").unwrap().as_str().unwrap(); - assert!(valid_choices.contains(&opt.to_string())); - } -} - -#[tokio::test] -async fn test_dyn_engine_ask_returns_valid_params() { - let config = StudyConfig { - space: BTreeMap::from([ - ( - "lr".to_string(), - ParamConfig::Real { - min: 0.001, - max: 1.0, - scale: "log10".to_string(), - }, - ), - ( - "batch".to_string(), - ParamConfig::Integer { min: 16, max: 256 }, - ), - ]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - for _ in 0..20 { - let t = engine.ask().await; - assert!(engine.space().contains(&t.params)); - } -} - -// ========================================================================== -// Param info -// ========================================================================== - -#[tokio::test] -async fn test_dyn_engine_param_info() { - let config = StudyConfig { - space: BTreeMap::from([ - ( - "lr".to_string(), - ParamConfig::Real { - min: 1e-4, - max: 0.1, - scale: "log10".to_string(), - }, - ), - ( - "layers".to_string(), - ParamConfig::Integer { min: 1, max: 10 }, - ), - ( - "opt".to_string(), - ParamConfig::Categorical { - choices: vec!["adam".into(), "sgd".into()], - }, - ), - ]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - let info = engine.space_config(); - 
assert_eq!(info.len(), 3); - - let info_map: BTreeMap = info.into_iter().collect(); - assert_eq!(info_map["lr"].param_type, "real"); - assert_eq!(info_map["lr"].scale, "log10"); - assert_eq!(info_map["layers"].param_type, "integer"); - assert_eq!(info_map["opt"].param_type, "categorical"); - assert_eq!(info_map["opt"].choices.as_ref().unwrap().len(), 2); -} - -// ========================================================================== -// Strategy types -// ========================================================================== - -#[tokio::test] -async fn test_dyn_engine_strategy_types() { - // Random - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: Some(StrategyConfig { - strategy_type: "random".to_string(), - refit_interval: 20, - total_budget: None, - exploration_budget: None, - seed: None, - elite_fraction: None, - }), - checkpoint: None, - }; - let engine = DynEngine::from_config(config).unwrap(); - assert!(engine.space().contains(&engine.ask().await.params)); - - // GMM - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: Some(StrategyConfig { - strategy_type: "gmm".to_string(), - refit_interval: 20, - total_budget: None, - exploration_budget: None, - seed: None, - elite_fraction: None, - }), - checkpoint: None, - }; - let engine = DynEngine::from_config(config).unwrap(); - assert!(engine.space().contains(&engine.ask().await.params)); - - // Default (Sobol) - let config = 
StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - let engine = DynEngine::from_config(config).unwrap(); - assert!(engine.space().contains(&engine.ask().await.params)); -} - -// ========================================================================== -// Scalarization -// ========================================================================== - -#[tokio::test] -async fn test_dyn_engine_scalarize_missing_field_infinity() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - let t = engine.ask().await; - engine.tell(t.trial_id, json!({"accuracy": 0.9})).await.unwrap(); - - // Missing field → infeasible trial → no feasible best - assert!(engine.best().await.is_none()); -} - -#[tokio::test] -async fn test_dyn_engine_scalarize_maximize() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "accuracy".to_string(), - obj_type: "maximize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - let t1 = engine.ask().await; - engine.tell(t1.trial_id, json!({"accuracy": 0.9})).await.unwrap(); - let t2 
= engine.ask().await; - engine.tell(t2.trial_id, json!({"accuracy": 0.5})).await.unwrap(); - - let best = engine.best().await.unwrap(); - assert!( - best.observation.as_f64().unwrap() < 0.0, - "Maximized field should be negated" - ); -} - -// ========================================================================== -// TLP objectives -// ========================================================================== - -#[tokio::test] -async fn test_dyn_engine_tlp_objectives() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: Some(0.0), - limit: Some(1.0), - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - let t1 = engine.ask().await; - engine.tell(t1.trial_id, json!({"loss": 0.5})).await.unwrap(); - let t2 = engine.ask().await; - engine.tell(t2.trial_id, json!({"loss": 2.0})).await.unwrap(); - - // Two trials told, but one is infeasible (loss >= limit=1.5) - assert_eq!(engine.trial_count().await, 2); - assert!(engine.best().await.unwrap().observation.as_f64().unwrap().is_finite()); -} - -// ========================================================================== -// Objectives -// ========================================================================== - -#[tokio::test] -async fn test_dyn_engine_update_objectives() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - let t 
= engine.ask().await; - engine - .tell(t.trial_id, json!({"loss": 0.5, "accuracy": 0.9})) - .await - .unwrap(); - - engine - .update_objectives(vec![ObjectiveConfig { - field: "accuracy".to_string(), - obj_type: "maximize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }]) - .await; - - assert!(engine.best().await.is_some()); -} - -#[tokio::test] -async fn test_dyn_engine_objectives_accessor() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ - ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }, - ObjectiveConfig { - field: "acc".to_string(), - obj_type: "maximize".to_string(), - target: None, - limit: None, - priority: 0.5, - group: None, - }, - ], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - let objs = engine.objectives().await; - assert_eq!(objs.len(), 2); - assert_eq!(objs[0].field, "loss"); - assert_eq!(objs[1].field, "acc"); -} - -#[tokio::test] -async fn test_dyn_engine_update_objectives_rescalarizes() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - - let metrics = vec![ - json!({"loss": 0.1, "accuracy": 0.3}), - json!({"loss": 0.5, "accuracy": 0.9}), - json!({"loss": 0.3, "accuracy": 0.5}), - json!({"loss": 0.8, "accuracy": 0.95}), - json!({"loss": 0.2, "accuracy": 0.4}), - ]; - for m in metrics { - let t = engine.ask().await; - engine.tell(t.trial_id, 
m).await.unwrap(); - } - - let best_before = engine.best().await.unwrap(); - assert_eq!(best_before.trial_id, 0); - - engine - .update_objectives(vec![ObjectiveConfig { - field: "accuracy".to_string(), - obj_type: "maximize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }]) - .await; - - let best_after = engine.best().await.unwrap(); - assert_ne!(best_before.trial_id, best_after.trial_id); -} - -// ========================================================================== -// Rescalarize -// ========================================================================== - -#[tokio::test] -async fn test_dyn_engine_rescalarize() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - let t = engine.ask().await; - engine - .tell(t.trial_id, json!({"loss": 0.5, "acc": 0.9})) - .await - .unwrap(); - - engine.rescalarize().await; - assert_eq!(engine.trial_count().await, 1); -} - -// ========================================================================== -// GMM with refit -// ========================================================================== - -#[tokio::test] -async fn test_dyn_engine_gmm_with_refit() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: Some(StrategyConfig { - strategy_type: "gmm".to_string(), - refit_interval: 5, - total_budget: None, - 
exploration_budget: None, - seed: None, - elite_fraction: None, - }), - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - for i in 0..30 { - let t = engine.ask().await; - engine - .tell(t.trial_id, json!({"loss": (i as f64) * 0.03})) - .await - .unwrap(); - } - - assert_eq!(engine.trial_count().await, 30); - assert!(engine.best().await.is_some()); -} - -#[tokio::test] -async fn test_refit_excludes_infeasible_scalar() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: Some(0.0), - limit: Some(1.0), - priority: 1.0, - group: None, - }], - strategy: Some(StrategyConfig { - strategy_type: "gmm".to_string(), - refit_interval: 1, - total_budget: None, - exploration_budget: None, - seed: None, - elite_fraction: None, - }), - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - - for i in 0..25 { - let t = engine.ask().await; - let loss_val = if i % 5 == 4 { 2.0 } else { (i as f64) * 0.03 }; - engine.tell(t.trial_id, json!({"loss": loss_val})).await.unwrap(); - } - - // 25 trials: 5 infeasible (loss=2.0), 20 feasible - assert_eq!(engine.trial_count().await, 25); - assert!(engine.best().await.is_some()); - - let t = engine.ask().await; - assert!(engine.space().contains(&t.params)); -} - -#[tokio::test] -async fn test_update_objectives_triggers_refit() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: Some(StrategyConfig { - strategy_type: "gmm".to_string(), - refit_interval: 10, - total_budget: 
None, - exploration_budget: None, - seed: None, - elite_fraction: None, - }), - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - - for i in 0..30 { - let t = engine.ask().await; - let x = (i as f64) / 29.0; - engine - .tell(t.trial_id, json!({"loss": x, "accuracy": x})) - .await - .unwrap(); - } - - let best_before = engine.best().await.unwrap(); - - engine - .update_objectives(vec![ObjectiveConfig { - field: "accuracy".to_string(), - obj_type: "maximize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }]) - .await; - - let best_after = engine.best().await.unwrap(); - assert!( - best_after.observation.as_f64().unwrap() - < best_before.observation.as_f64().unwrap() - ); - - let t = engine.ask().await; - assert!(engine.space().contains(&t.params)); -} - -// ========================================================================== -// Checkpoints -// ========================================================================== - -#[tokio::test] -async fn test_dyn_engine_leaderboard_checkpoint() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config.clone()).unwrap(); - let t0 = engine.ask().await; - engine.tell(t0.trial_id, json!({"loss": 0.5})).await.unwrap(); - let t1 = engine.ask().await; - engine.tell(t1.trial_id, json!({"loss": 0.3})).await.unwrap(); - - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("lb.json"); - engine - .save_leaderboard_checkpoint_to(&path, Some("2 trials")) - .await - .unwrap(); - - let engine2 = DynEngine::from_config(config).unwrap(); - 
engine2.load_leaderboard_checkpoint(&path).await.unwrap(); - assert_eq!(engine2.trial_count().await, 2); -} - -#[tokio::test] -async fn test_dyn_engine_full_checkpoint() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config.clone()).unwrap(); - let t = engine.ask().await; - engine.tell(t.trial_id, json!({"loss": 0.5})).await.unwrap(); - - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("full.json"); - engine - .save_full_checkpoint(&path, Some("full checkpoint test")) - .await - .unwrap(); - - let engine2 = DynEngine::from_config(config).unwrap(); - engine2.load_full_checkpoint(&path).await.unwrap(); - assert_eq!(engine2.trial_count().await, 1); -} - -// ========================================================================== -// Auto strategy (Sobol -> GMM switching) -// ========================================================================== - -#[tokio::test] -async fn test_auto_strategy_default() { - // With no strategy config, should use "auto" and work correctly - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, // should default to "auto" - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - - // Run enough trials to cross the exploration threshold and trigger refit - for i in 0..60 { - let t = engine.ask().await; - 
assert!(engine.space().contains(&t.params)); - let loss = (i as f64) / 59.0; - engine.tell(t.trial_id, json!({"loss": loss})).await.unwrap(); - } - - assert_eq!(engine.trial_count().await, 60); - let best = engine.best().await; - assert!(best.is_some()); -} - -#[tokio::test] -async fn test_auto_strategy_with_explicit_exploration_budget() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: Some(StrategyConfig { - strategy_type: "auto".to_string(), - refit_interval: 5, - total_budget: None, - exploration_budget: Some(10), // switch to GMM after 10 trials - seed: None, - elite_fraction: None, - }), - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - - for i in 0..30 { - let t = engine.ask().await; - assert!(engine.space().contains(&t.params)); - let loss = (i as f64) / 29.0; - engine.tell(t.trial_id, json!({"loss": loss})).await.unwrap(); - } - - assert_eq!(engine.trial_count().await, 30); -} - -#[test] -fn test_auto_strategy_default_exploration_budget() { - use opt_engine::dyn_engine::AutoStrategy; - - // min(40, 56) = 40 -> round down to 32 - assert_eq!(AutoStrategy::default_exploration_budget(200, 3), 32); - - // min(20, 52) = 20 -> round down to 16 - assert_eq!(AutoStrategy::default_exploration_budget(100, 1), 16); - - // min(200, 60) = 60 -> round down to 32 - assert_eq!(AutoStrategy::default_exploration_budget(1000, 5), 32); - - // min(10, 70) = 10 -> round down to 8 - assert_eq!(AutoStrategy::default_exploration_budget(50, 10), 8); - - // Edge cases - assert_eq!(AutoStrategy::default_exploration_budget(10, 1), 2); // min(2, 52) = 2 - assert_eq!(AutoStrategy::default_exploration_budget(5, 1), 1); // min(1, 52) = 1 -} - -// 
========================================================================== -// Seed determinism tests -// ========================================================================== - -#[tokio::test] -async fn test_seed_determinism_sobol() { - let make_engine = |seed| { - DynEngine::from_config(StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: Some(StrategyConfig { - strategy_type: "sobol".to_string(), - refit_interval: 20, - total_budget: None, - exploration_budget: None, - seed: Some(seed), - elite_fraction: None, - }), - checkpoint: None, - }) - .unwrap() - }; - - let e1 = make_engine(123); - let e2 = make_engine(123); - - for _ in 0..10 { - let t1 = e1.ask().await; - let t2 = e2.ask().await; - assert_eq!(t1.params, t2.params, "Same seed should produce same candidates"); - } -} - -#[tokio::test] -async fn test_seed_determinism_random() { - let make_engine = |seed| { - DynEngine::from_config(StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: Some(StrategyConfig { - strategy_type: "random".to_string(), - refit_interval: 20, - total_budget: None, - exploration_budget: None, - seed: Some(seed), - elite_fraction: None, - }), - checkpoint: None, - }) - .unwrap() - }; - - let e1 = make_engine(42); - let e2 = make_engine(42); - - for _ in 0..10 { - let t1 = e1.ask().await; - let t2 = e2.ask().await; - assert_eq!(t1.params, t2.params, "Same seed should produce same candidates"); - } -} - -// 
========================================================================== -// Pareto front tests -// ========================================================================== - -#[tokio::test] -async fn test_pareto_front_multi_objective() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ - ObjectiveConfig { - field: "f1".to_string(), - obj_type: "minimize".to_string(), - target: Some(0.0), - limit: Some(10.0), - priority: 1.0, - group: None, - }, - ObjectiveConfig { - field: "f2".to_string(), - obj_type: "minimize".to_string(), - target: Some(0.0), - limit: Some(10.0), - priority: 2.0, - group: None, - }, - ], - strategy: Some(StrategyConfig { - strategy_type: "random".to_string(), - refit_interval: 20, - total_budget: None, - exploration_budget: None, - seed: Some(0), - elite_fraction: None, - }), - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - - // Tell trials with known Pareto structure: - // (1,5) and (5,1) are non-dominated; (3,3) dominated by neither but (4,4) dominated by (3,3) - let trials_data = vec![ - json!({"f1": 1.0, "f2": 5.0}), // Pareto-optimal - json!({"f1": 5.0, "f2": 1.0}), // Pareto-optimal - json!({"f1": 3.0, "f2": 3.0}), // Pareto-optimal - json!({"f1": 4.0, "f2": 4.0}), // Dominated by (3,3) - ]; - - for data in trials_data { - let t = engine.ask().await; - engine.tell(t.trial_id, data).await.unwrap(); - } - - let front = engine.pareto_front().await.unwrap(); - assert_eq!(front.len(), 3, "Should have 3 non-dominated trials"); -} - -#[tokio::test] -async fn test_pareto_front_scalar_study_errors() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: 
None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - assert!(engine.pareto_front().await.is_err()); -} - -#[tokio::test] -async fn test_pareto_front_empty() { - let config = StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ - ObjectiveConfig { - field: "f1".to_string(), - obj_type: "minimize".to_string(), - target: Some(0.0), - limit: Some(10.0), - priority: 1.0, - group: None, - }, - ObjectiveConfig { - field: "f2".to_string(), - obj_type: "minimize".to_string(), - target: Some(0.0), - limit: Some(10.0), - priority: 2.0, - group: None, - }, - ], - strategy: None, - checkpoint: None, - }; - - let engine = DynEngine::from_config(config).unwrap(); - let front = engine.pareto_front().await.unwrap(); - assert!(front.is_empty()); -} diff --git a/opt_engine/tests/integration/leaderboard_scalability.rs b/opt_engine/tests/integration/leaderboard_scalability.rs new file mode 100644 index 0000000..f283fac --- /dev/null +++ b/opt_engine/tests/integration/leaderboard_scalability.rs @@ -0,0 +1,91 @@ +// Copyright 2026 BlackRock, Inc. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Ignored scalability probes for leaderboard ranking hot paths. +//! +//! Run with: +//! 
`RUN_EXPENSIVE_BENCHMARKS=1 cargo test -p opt_engine --test integration leaderboard_scalability -- --ignored --nocapture` + +use opt_engine::leaderboard::Leaderboard; +use std::collections::BTreeMap; +use std::hint::black_box; +use std::time::Instant; + +const REPRESENTATIVE_SIZES: &[usize] = &[1_000, 10_000, 50_000]; + +fn benchmark_sizes() -> Vec { + if std::env::var_os("RUN_EXPENSIVE_BENCHMARKS").is_some() { + REPRESENTATIVE_SIZES.to_vec() + } else { + vec![1_000] + } +} + +fn scalar_leaderboard(n: usize) -> Leaderboard { + let mut lb = Leaderboard::with_capacity(n); + for i in 0..n { + let score = ((i.wrapping_mul(37) % n.max(1)) as f64) / n.max(1) as f64; + lb.push(i, score); + } + lb +} + +fn vector_leaderboard(n: usize) -> Leaderboard> { + let mut lb = Leaderboard::with_capacity(n); + for i in 0..n { + let mut obs = BTreeMap::new(); + obs.insert("loss".to_string(), (i % 997) as f64 / 997.0); + obs.insert("latency".to_string(), ((n - i) % 991) as f64 / 991.0); + lb.push(i, obs); + } + lb +} + +fn time_it(label: &str, f: impl FnOnce() -> T) -> T { + let start = Instant::now(); + let result = f(); + eprintln!("{label}: {:?}", start.elapsed()); + result +} + +#[test] +#[ignore = "performance probe; run explicitly with --ignored --nocapture"] +fn leaderboard_scalability_scalar_ranking() { + for n in benchmark_sizes() { + let lb = scalar_leaderboard(n); + let top = time_it(&format!("scalar n={n} top_k(100)"), || { + black_box(lb.top_k(100)) + }); + assert_eq!(top.len(), 100.min(n)); + + let all = time_it(&format!("scalar n={n} sorted_all()"), || { + black_box(lb.sorted_all()) + }); + assert_eq!(all.len(), n); + } +} + +#[test] +#[ignore = "performance probe; run explicitly with --ignored --nocapture"] +fn leaderboard_scalability_vector_ranking() { + for n in benchmark_sizes() { + let lb = vector_leaderboard(n); + let top = time_it(&format!("vector n={n} top_k_scalarized(100)"), || { + black_box(lb.top_k_scalarized(100, |obs| obs.values().sum())) + }); + 
assert_eq!(top.len(), 100.min(n)); + + let front = time_it(&format!("vector n={n} pareto_front()"), || { + black_box(lb.pareto_front()) + }); + assert!(!front.is_empty()); + } +} diff --git a/opt_engine/tests/integration/main.rs b/opt_engine/tests/integration/main.rs index e098538..9498735 100644 --- a/opt_engine/tests/integration/main.rs +++ b/opt_engine/tests/integration/main.rs @@ -15,3 +15,33 @@ mod end_to_end; mod engine; +mod leaderboard_scalability; + +#[test] +fn integration_test_files_are_referenced_by_harness() { + let integration_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("integration"); + let allowed = [ + "end_to_end.rs", + "engine.rs", + "leaderboard_scalability.rs", + "main.rs", + ]; + + let mut unreferenced = Vec::new(); + for entry in std::fs::read_dir(&integration_dir).unwrap() { + let path = entry.unwrap().path(); + if path.extension().is_some_and(|ext| ext == "rs") { + let name = path.file_name().unwrap().to_string_lossy().into_owned(); + if !allowed.contains(&name.as_str()) { + unreferenced.push(name); + } + } + } + + assert!( + unreferenced.is_empty(), + "opt_engine integration test files must be referenced by tests/integration/main.rs: {unreferenced:?}" + ); +} diff --git a/opt_engine/tests/integration/server.rs b/opt_engine/tests/integration/server.rs deleted file mode 100644 index 72508ae..0000000 --- a/opt_engine/tests/integration/server.rs +++ /dev/null @@ -1,442 +0,0 @@ -// Copyright 2026 BlackRock, Inc. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -//! Integration tests for the REST API (in-process HTTP). -//! -//! Exercises ask/tell/best/leaderboard endpoints, space/mode/objectives info, -//! checkpoints, error handling, and objective rescalarization. - -use http_body_util::BodyExt; -use opt_engine::dyn_engine::{DynEngine, ObjectiveConfig, ParamConfig, StudyConfig}; -use opt_engine::server::create_router; -use serde_json::json; -use std::collections::BTreeMap; -use tower::ServiceExt; - -// ========================================================================== -// Helpers -// ========================================================================== - -fn minimal_config() -> StudyConfig { - StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - } -} - -fn multi_param_config() -> StudyConfig { - StudyConfig { - space: BTreeMap::from([ - ( - "lr".to_string(), - ParamConfig::Real { - min: 0.001, - max: 1.0, - scale: "log10".to_string(), - }, - ), - ( - "layers".to_string(), - ParamConfig::Integer { min: 1, max: 10 }, - ), - ( - "opt".to_string(), - ParamConfig::Categorical { - choices: vec!["adam".into(), "sgd".into()], - }, - ), - ]), - objectives: vec![ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: None, - limit: None, - priority: 1.0, - group: None, - }], - strategy: None, - checkpoint: None, - } -} - -async fn json_request( - app: axum::Router, - method: &str, - uri: &str, - body: Option, -) -> (u16, serde_json::Value) { - let mut builder = hyper::Request::builder().method(method).uri(uri); - let body = if let Some(b) = body { - builder = 
builder.header("content-type", "application/json"); - axum::body::Body::from(serde_json::to_vec(&b).unwrap()) - } else { - axum::body::Body::empty() - }; - let req = builder.body(body).unwrap(); - let resp = app.oneshot(req).await.unwrap(); - let status = resp.status().as_u16(); - let bytes = resp.into_body().collect().await.unwrap().to_bytes(); - let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap(); - (status, json) -} - -// ========================================================================== -// Core flow: ask → tell → best → leaderboard -// ========================================================================== - -#[tokio::test] -async fn test_server_ask_endpoint() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - let req = hyper::Request::builder() - .method("POST") - .uri("/api/ask") - .body(axum::body::Body::empty()) - .unwrap(); - let resp = app.oneshot(req).await.unwrap(); - assert_eq!(resp.status(), 200); - - let body = resp.into_body().collect().await.unwrap().to_bytes(); - let trial: serde_json::Value = serde_json::from_slice(&body).unwrap(); - assert_eq!(trial["trial_id"], 0); - assert!(trial["params"]["x"].is_number()); -} - -#[tokio::test] -async fn test_server_ask_tell_best_flow() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - // Ask - let (_, trial) = json_request(app.clone(), "POST", "/api/ask", None).await; - let trial_id = trial["trial_id"].as_u64().unwrap(); - - // Tell - let tell = json!({"trial_id": trial_id, "metrics": {"loss": 0.42}}); - let (status, result) = json_request(app.clone(), "POST", "/api/tell", Some(tell)).await; - assert_eq!(status, 200); - assert_eq!(result["status"], "ok"); - assert_eq!(result["trial_count"], 1); - - // Best - let (status, best) = json_request(app.clone(), "GET", "/api/best", None).await; - assert_eq!(status, 200); - assert_eq!(best["trial_id"], 0); - 
assert!(best["observation"].is_number()); - - // Leaderboard - let (status, lb) = json_request(app, "GET", "/api/leaderboard", None).await; - assert_eq!(status, 200); - assert_eq!(lb["total"], 1); - assert_eq!(lb["trials"].as_array().unwrap().len(), 1); -} - -#[tokio::test] -async fn test_server_best_empty() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - let (status, best) = json_request(app, "GET", "/api/best", None).await; - assert_eq!(status, 200); - assert!(best["trial_id"].is_null()); -} - -#[tokio::test] -async fn test_server_leaderboard_empty() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - let (status, body) = json_request(app, "GET", "/api/leaderboard", None).await; - assert_eq!(status, 200); - assert_eq!(body["total"], 0); -} - -// ========================================================================== -// Error handling -// ========================================================================== - -#[tokio::test] -async fn test_server_tell_unknown_trial() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - let tell = json!({"trial_id": 999, "metrics": {"loss": 0.5}}); - let (status, err) = json_request(app, "POST", "/api/tell", Some(tell)).await; - assert_eq!(status, 400); - assert!(err["error"].as_str().unwrap().contains("999")); -} - -#[tokio::test] -async fn test_server_double_tell_returns_400() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - let (_, trial) = json_request(app.clone(), "POST", "/api/ask", None).await; - let tell = json!({"trial_id": trial["trial_id"], "metrics": {"loss": 0.5}}); - - let (status, _) = json_request(app.clone(), "POST", "/api/tell", Some(tell.clone())).await; - assert_eq!(status, 200); - - let (status, body) = json_request(app, "POST", "/api/tell", Some(tell)).await; - assert_eq!(status, 
400); - assert!( - body["error"] - .as_str() - .unwrap() - .contains(&trial["trial_id"].to_string()) - ); -} - -// ========================================================================== -// Info endpoints: space, mode, objectives -// ========================================================================== - -#[tokio::test] -async fn test_server_space_endpoint() { - let engine = DynEngine::from_config(multi_param_config()).unwrap(); - let app = create_router(engine); - - let (status, body) = json_request(app, "GET", "/api/space", None).await; - assert_eq!(status, 200); - - let params = body["params"].as_array().unwrap(); - assert_eq!(params.len(), 3); - let names: Vec<&str> = params.iter().map(|p| p["name"].as_str().unwrap()).collect(); - assert!(names.contains(&"lr")); - assert!(names.contains(&"layers")); - assert!(names.contains(&"opt")); -} - -#[tokio::test] -async fn test_server_space_with_all_param_types() { - let engine = DynEngine::from_config(multi_param_config()).unwrap(); - let app = create_router(engine); - - let (_, body) = json_request(app, "GET", "/api/space", None).await; - let params = body["params"].as_array().unwrap(); - - let find = - |name: &str| -> &serde_json::Value { params.iter().find(|p| p["name"] == name).unwrap() }; - assert_eq!(find("lr")["type"], "real"); - assert_eq!(find("lr")["scale"], "log10"); - assert_eq!(find("layers")["type"], "integer"); - assert_eq!(find("opt")["type"], "categorical"); -} - -#[tokio::test] -async fn test_server_mode_endpoint() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - let (status, body) = json_request(app, "GET", "/api/mode", None).await; - assert_eq!(status, 200); - assert_eq!(body["mode"], "live"); -} - -#[tokio::test] -async fn test_server_get_objectives_endpoint() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - let (status, body) = json_request(app, "GET", "/api/objectives", 
None).await; - assert_eq!(status, 200); - assert_eq!(body["objectives"].as_array().unwrap().len(), 1); - assert_eq!(body["objectives"][0]["field"], "loss"); -} - -// ========================================================================== -// Objectives: update and rescalarize -// ========================================================================== - -#[tokio::test] -async fn test_server_update_objectives() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - let patch = json!({"objectives": [{"field": "accuracy", "type": "maximize", "priority": 1.0}]}); - let (status, result) = json_request(app, "PATCH", "/api/objectives", Some(patch)).await; - assert_eq!(status, 200); - assert_eq!(result["status"], "ok"); -} - -#[tokio::test] -async fn test_server_update_objectives_rescalarizes() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - let (_, trial) = json_request(app.clone(), "POST", "/api/ask", None).await; - let tell = json!({"trial_id": trial["trial_id"], "metrics": {"loss": 0.5, "accuracy": 0.9}}); - json_request(app.clone(), "POST", "/api/tell", Some(tell)).await; - - let (_, best_before) = json_request(app.clone(), "GET", "/api/best", None).await; - assert!(best_before["observation"].as_f64().unwrap() > 0.0); - - let patch = json!({"objectives": [{"field": "accuracy", "type": "maximize", "priority": 1.0}]}); - json_request(app.clone(), "PATCH", "/api/objectives", Some(patch)).await; - - let (_, best_after) = json_request(app, "GET", "/api/best", None).await; - assert!(best_after["observation"].as_f64().unwrap() < 0.0); -} - -// ========================================================================== -// Sequential asks + monotonic IDs -// ========================================================================== - -#[tokio::test] -async fn test_server_ask_sequential_ids() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - 
let app = create_router(engine); - - let mut prev_id: Option = None; - for _ in 0..5 { - let (status, trial) = json_request(app.clone(), "POST", "/api/ask", None).await; - assert_eq!(status, 200); - let id = trial["trial_id"].as_u64().unwrap(); - if let Some(prev) = prev_id { - assert!(id > prev); - } - prev_id = Some(id); - } -} - -// ========================================================================== -// Checkpoint save -// ========================================================================== - -#[tokio::test] -async fn test_server_checkpoint_save_endpoint() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - let (_, trial) = json_request(app.clone(), "POST", "/api/ask", None).await; - let tell = json!({"trial_id": trial["trial_id"], "metrics": {"loss": 0.5}}); - json_request(app.clone(), "POST", "/api/tell", Some(tell)).await; - - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("ckpt.json"); - let save_req = json!({"path": path.to_string_lossy(), "description": "server test"}); - - let (status, _) = json_request(app, "POST", "/api/checkpoint/save", Some(save_req)).await; - assert_eq!(status, 200); - assert!(path.exists()); -} - -// ========================================================================== -// Pareto front -// ========================================================================== - -fn multi_objective_config() -> StudyConfig { - StudyConfig { - space: BTreeMap::from([( - "x".to_string(), - ParamConfig::Real { - min: 0.0, - max: 1.0, - scale: "linear".to_string(), - }, - )]), - objectives: vec![ - ObjectiveConfig { - field: "loss".to_string(), - obj_type: "minimize".to_string(), - target: Some(0.0), - limit: Some(5.0), - priority: 1.0, - group: None, - }, - ObjectiveConfig { - field: "latency".to_string(), - obj_type: "minimize".to_string(), - target: Some(0.0), - limit: Some(100.0), - priority: 2.0, - group: None, - }, - ], - strategy: None, - checkpoint: 
None, - } -} - -#[tokio::test] -async fn test_server_pareto_front_multi_objective() { - let engine = DynEngine::from_config(multi_objective_config()).unwrap(); - let app = create_router(engine); - - // Complete a few trials - for i in 0..3 { - let (_, trial) = json_request(app.clone(), "POST", "/api/ask", None).await; - let tell = json!({ - "trial_id": trial["trial_id"], - "metrics": {"loss": (i as f64) * 0.5, "latency": 50.0 - (i as f64) * 10.0} - }); - json_request(app.clone(), "POST", "/api/tell", Some(tell)).await; - } - - let (status, body) = json_request(app, "GET", "/api/pareto_front", None).await; - assert_eq!(status, 200); - assert!(body.is_array()); - let front = body.as_array().unwrap(); - assert!(!front.is_empty()); - // Each trial in the front should have trial_id, candidate, objectives - for trial in front { - assert!(trial["trial_id"].is_u64()); - assert!(trial["candidate"].is_object()); - assert!(trial["objectives"].is_object()); - assert!(trial["objectives"]["loss"].is_f64()); - assert!(trial["objectives"]["latency"].is_f64()); - } -} - -#[tokio::test] -async fn test_server_pareto_front_scalar_returns_400() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - let (status, _) = json_request(app, "GET", "/api/pareto_front", None).await; - assert_eq!(status, 400); -} - -// ========================================================================== -// Multiple metric fields -// ========================================================================== - -#[tokio::test] -async fn test_server_tell_with_multiple_fields() { - let engine = DynEngine::from_config(minimal_config()).unwrap(); - let app = create_router(engine); - - let (_, trial) = json_request(app.clone(), "POST", "/api/ask", None).await; - let tell = json!({ - "trial_id": trial["trial_id"], - "metrics": {"loss": 0.3, "accuracy": 0.9, "latency": 50.0} - }); - let (status, body) = json_request(app, "POST", "/api/tell", Some(tell)).await; - 
assert_eq!(status, 200); - assert_eq!(body["status"], "ok"); -}