Skip to content

Commit 48a1e66

Browse files
authored
[ML] Reapply: Run allowlist validation in PyTorch edge pipeline (#3007)
The Linux build/test Docker images don't include Python 3 (it's only used during image builds to compile PyTorch, then dropped in the multi-stage final image). Move the validation to a dedicated pipeline step using a python:3 agent image, triggered only for run_pytorch_tests builds.
1 parent 4387502 commit 48a1e66

5 files changed

Lines changed: 104 additions & 51 deletions

File tree

.buildkite/pipeline.json.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ def main():
8484
".buildkite/pipelines/check_build_regression.yml.sh",
8585
soft_fail=True))
8686

87+
# Validate the PyTorch allowlist against HuggingFace models when
88+
# triggered from the PyTorch edge pipeline. Runs in a python:3
89+
# container since the build/test images don't include Python.
90+
if config.run_pytorch_tests:
91+
pipeline_steps.append(pipeline_steps.generate_step("Upload PyTorch allowlist validation",
92+
".buildkite/pipelines/validate_pytorch_allowlist.yml.sh",
93+
soft_fail=True))
94+
8795
pipeline["env"] = env
8896
pipeline["steps"] = pipeline_steps
8997
print(json.dumps(pipeline, indent=2))
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#!/bin/bash
2+
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
# or more contributor license agreements. Licensed under the Elastic License
4+
# 2.0 and the following additional limitation. Functionality enabled by the
5+
# files subject to the Elastic License 2.0 may only be used in production when
6+
# invoked by an Elasticsearch process with a license key installed that permits
7+
# use of machine learning features. You may not use this file except in
8+
# compliance with the Elastic License 2.0 and the foregoing additional
9+
# limitation.
10+
11+
cat <<'EOL'
12+
steps:
13+
- label: "Validate PyTorch allowlist :torch:"
14+
key: "validate_pytorch_allowlist"
15+
timeout_in_minutes: 60
16+
command:
17+
- "if [ ! -f dev-tools/extract_model_ops/validate_allowlist.py ]; then echo 'validate_allowlist.py not found, skipping'; exit 0; fi"
18+
- "pip install -r dev-tools/extract_model_ops/requirements.txt"
19+
- "python3 dev-tools/extract_model_ops/validate_allowlist.py --config dev-tools/extract_model_ops/validation_models.json --pt-dir dev-tools/extract_model_ops/es_it_models --verbose"
20+
EOL
21+
22+
# Depend on the build steps so validation doesn't start before the
23+
# pipeline is fully generated.
24+
if [ -n "${ML_BUILD_STEP_KEYS:-}" ]; then
25+
echo ' depends_on:'
26+
IFS=',' read -ra STEP_KEYS <<< "$ML_BUILD_STEP_KEYS"
27+
for key in "${STEP_KEYS[@]}"; do
28+
echo " - \"${key}\""
29+
done
30+
fi
31+
32+
cat <<'EOL'
33+
allow_dependency_failure: true
34+
agents:
35+
image: "python:3.12"
36+
memory: "32G"
37+
ephemeralStorage: "30G"
38+
notify:
39+
- github_commit_status:
40+
context: "Validate PyTorch allowlist"
41+
EOL

.buildkite/scripts/steps/run_tests.sh

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -105,28 +105,6 @@ else
105105
-P cmake/run-all-tests-parallel.cmake || TEST_OUTCOME=$?
106106
fi
107107

108-
# --- PyTorch allowlist validation ---
109-
# When triggered from the PyTorch edge pipeline, run the Python-based
110-
# allowlist validation which traces live HuggingFace models with the
111-
# new PyTorch version and verifies every op is in ALLOWED_OPERATIONS.
112-
VALIDATION_OUTCOME=0
113-
if [[ "${GITHUB_PR_COMMENT_VAR_ACTION:-}" == "run_pytorch_tests" ]] && [ -f cmake/run-validation.cmake ]; then
114-
echo "--- Validating PyTorch allowlist against HuggingFace models"
115-
cmake \
116-
-DSOURCE_DIR="$(pwd)" \
117-
-DVALIDATE_CONFIG="$(pwd)/dev-tools/extract_model_ops/validation_models.json" \
118-
-DVALIDATE_PT_DIR="$(pwd)/dev-tools/extract_model_ops/es_it_models" \
119-
-DVALIDATE_VERBOSE=TRUE \
120-
-DOPTIONAL=TRUE \
121-
-P cmake/run-validation.cmake || VALIDATION_OUTCOME=$?
122-
123-
if [[ $VALIDATION_OUTCOME -ne 0 ]]; then
124-
echo "^^^ +++"
125-
echo "Allowlist validation failed — the new PyTorch version may introduce ops not in ALLOWED_OPERATIONS."
126-
echo "See dev-tools/extract_model_ops/README.md for how to update the allowlist."
127-
fi
128-
fi
129-
130108
# Upload test results
131109
echo "--- Uploading test results"
132110
TEST_RESULTS_ARCHIVE=${OS}-${HARDWARE_ARCH}-unit_test_results.tgz
@@ -139,6 +117,4 @@ else
139117
echo "No test results archive created"
140118
fi
141119

142-
if [[ $TEST_OUTCOME -ne 0 || $VALIDATION_OUTCOME -ne 0 ]]; then
143-
exit 1
144-
fi
120+
exit $TEST_OUTCOME

dev-tools/extract_model_ops/torchscript_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,18 @@ def load_and_trace_hf_model(model_name: str, quantize: bool = False,
145145
attention_mask = inputs["attention_mask"]
146146

147147
try:
148-
return torch.jit.trace(
148+
traced = torch.jit.trace(
149149
model, (input_ids, attention_mask), strict=False)
150150
except Exception as exc:
151151
print(f" TRACE WARNING: {exc}", file=sys.stderr)
152152
print(" Falling back to torch.jit.script...", file=sys.stderr)
153153
try:
154-
return torch.jit.script(model)
154+
traced = torch.jit.script(model)
155155
except Exception as exc2:
156156
print(f" SCRIPT ERROR: {exc2}", file=sys.stderr)
157157
return None
158+
159+
# Free the original HF model to reduce peak memory when validating
160+
# many models sequentially.
161+
del model, tokenizer, inputs
162+
return traced

dev-tools/extract_model_ops/validate_allowlist.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"""
3030

3131
import argparse
32+
import gc
3233
import re
3334
import sys
3435
from pathlib import Path
@@ -104,30 +105,44 @@ def validate_model(model_name: str,
104105
allowed: set[str],
105106
forbidden: set[str],
106107
verbose: bool,
107-
quantize: bool = False) -> bool:
108-
"""Validate one HuggingFace model. Returns True if all ops pass."""
108+
quantize: bool = False,
109+
auto_class: str | None = None,
110+
config_overrides: dict | None = None) -> str:
111+
"""Validate one HuggingFace model.
112+
113+
Returns "pass", "fail" (op validation failed), or "skip" (could not
114+
load/trace — e.g. private model without HF_TOKEN).
115+
"""
109116
label = f"{model_name} (quantized)" if quantize else model_name
110117
print(f" {label}...", file=sys.stderr)
111-
traced = load_and_trace_hf_model(model_name, quantize=quantize)
118+
traced = load_and_trace_hf_model(model_name, quantize=quantize,
119+
auto_class=auto_class,
120+
config_overrides=config_overrides)
112121
if traced is None:
113-
print(f" FAILED (could not load/trace)", file=sys.stderr)
114-
return False
122+
print(f" SKIPPED (could not load/trace)", file=sys.stderr)
123+
return "skip"
115124
ops = collect_inlined_ops(traced)
116-
return check_ops(ops, allowed, forbidden, verbose)
125+
result = "pass" if check_ops(ops, allowed, forbidden, verbose) else "fail"
126+
del traced
127+
gc.collect()
128+
return result
117129

118130

119131
def validate_pt_file(name: str,
120132
pt_path: str,
121133
allowed: set[str],
122134
forbidden: set[str],
123-
verbose: bool) -> bool:
124-
"""Validate a local TorchScript .pt file. Returns True if all ops pass."""
135+
verbose: bool) -> str:
136+
"""Validate a local TorchScript .pt file.
137+
138+
Returns "pass", "fail", or "skip".
139+
"""
125140
print(f" {name} ({pt_path})...", file=sys.stderr)
126141
ops = load_pt_and_collect_ops(pt_path)
127142
if ops is None:
128-
print(f" FAILED (could not load)", file=sys.stderr)
129-
return False
130-
return check_ops(ops, allowed, forbidden, verbose)
143+
print(f" SKIPPED (could not load)", file=sys.stderr)
144+
return "skip"
145+
return "pass" if check_ops(ops, allowed, forbidden, verbose) else "fail"
131146

132147

133148
def main():
@@ -151,7 +166,7 @@ def main():
151166
print(f"Parsed {len(allowed)} allowed ops and {len(forbidden)} "
152167
f"forbidden ops from {SUPPORTED_OPS_CC.name}", file=sys.stderr)
153168

154-
results: dict[str, bool] = {}
169+
results: dict[str, str] = {}
155170

156171
models = load_model_config(args.config)
157172

@@ -161,7 +176,9 @@ def main():
161176
for arch, spec in models.items():
162177
results[arch] = validate_model(
163178
spec["model_id"], allowed, forbidden, args.verbose,
164-
quantize=spec["quantized"])
179+
quantize=spec["quantized"],
180+
auto_class=spec.get("auto_class"),
181+
config_overrides=spec.get("config_overrides"))
165182

166183
if args.pt_dir and args.pt_dir.is_dir():
167184
pt_files = sorted(args.pt_dir.glob("*.pt"))
@@ -175,26 +192,32 @@ def main():
175192

176193
print(file=sys.stderr)
177194
print("=" * 60, file=sys.stderr)
178-
all_pass = all(results.values())
179-
for key, passed in results.items():
180-
status = "PASS" if passed else "FAIL"
195+
for key, status in results.items():
196+
display = status.upper()
181197
if key.startswith("pt:"):
182-
print(f" {key}: {status}", file=sys.stderr)
198+
print(f" {key}: {display}", file=sys.stderr)
183199
else:
184200
spec = models[key]
185201
label = spec["model_id"]
186202
if spec["quantized"]:
187203
label += " (quantized)"
188-
print(f" {key} ({label}): {status}", file=sys.stderr)
204+
print(f" {key} ({label}): {display}", file=sys.stderr)
205+
206+
failed = [a for a, s in results.items() if s == "fail"]
207+
skipped = [a for a, s in results.items() if s == "skip"]
208+
passed = [a for a, s in results.items() if s == "pass"]
189209

190210
print("=" * 60, file=sys.stderr)
191-
if all_pass:
192-
print("All models PASS - no false positives.", file=sys.stderr)
193-
else:
194-
failed = [a for a, p in results.items() if not p]
195-
print(f"FAILED models: {', '.join(failed)}", file=sys.stderr)
211+
print(f"{len(passed)} passed, {len(failed)} failed, "
212+
f"{len(skipped)} skipped", file=sys.stderr)
213+
214+
if skipped:
215+
print(f"Skipped (could not load/trace — may need HF_TOKEN "
216+
f"for private models): {', '.join(skipped)}", file=sys.stderr)
217+
if failed:
218+
print(f"FAILED (op validation): {', '.join(failed)}", file=sys.stderr)
196219

197-
sys.exit(0 if all_pass else 1)
220+
sys.exit(0 if not failed else 1)
198221

199222

200223
if __name__ == "__main__":

0 commit comments

Comments
 (0)