#!/usr/bin/env python3
#
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0 and the following additional limitation. Functionality enabled by the
# files subject to the Elastic License 2.0 may only be used in production when
# invoked by an Elasticsearch process with a license key installed that permits
# use of machine learning features. You may not use this file except in
# compliance with the Elastic License 2.0 and the foregoing additional
# limitation.
#
"""Shared utilities for extracting and inspecting TorchScript operations."""
import json
import os
import sys
from pathlib import Path
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer


def load_model_config(config_path: Path) -> dict[str, dict]:
    """Load a model config JSON file and normalise entries.

    Each entry is either a plain model-name string or a dict with
    ``model_id`` (required) and an optional ``quantized`` boolean. All
    entries are normalised to ``{"model_id": str, "quantized": bool}``.
    Keys starting with ``_comment`` are silently skipped.

    Raises ``ValueError`` for malformed entries so that config problems
    are caught early with an actionable message.
    """
    with open(config_path) as f:
        raw = json.load(f)

    models: dict[str, dict] = {}
    for key, value in raw.items():
        if key.startswith("_comment"):
            continue
        if isinstance(value, str):
            models[key] = {"model_id": value, "quantized": False}
        elif isinstance(value, dict):
            if "model_id" not in value:
                raise ValueError(
                    f"Config entry {key!r} is a dict but missing required "
                    f"'model_id' key: {value!r}")
            models[key] = {
                "model_id": value["model_id"],
                "quantized": value.get("quantized", False),
            }
        else:
            raise ValueError(
                f"Config entry {key!r} has unsupported type "
                f"{type(value).__name__}: {value!r}. "
                f"Expected a model name string or a dict with 'model_id'.")
    return models
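
# Illustrative config shape accepted by load_model_config. The keys and model
# IDs below are hypothetical examples, not values this module requires:
#
#   {
#       "_comment": "Models whose TorchScript op sets we want to extract.",
#       "minilm": "sentence-transformers/all-MiniLM-L6-v2",
#       "minilm-quantized": {
#           "model_id": "sentence-transformers/all-MiniLM-L6-v2",
#           "quantized": true
#       }
#   }
#
# normalises to:
#
#   {
#       "minilm": {"model_id": "sentence-transformers/all-MiniLM-L6-v2",
#                  "quantized": False},
#       "minilm-quantized": {"model_id": "sentence-transformers/all-MiniLM-L6-v2",
#                            "quantized": True},
#   }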


def collect_graph_ops(graph) -> set[str]:
    """Collect all operation names from a TorchScript graph, including nested blocks."""
    ops = set()
    for node in graph.nodes():
        ops.add(node.kind())
        # Control-flow nodes such as prim::If and prim::Loop carry nested
        # blocks; recurse into them so their ops are counted too.
        for block in node.blocks():
            ops.update(collect_graph_ops(block))
    return ops


def collect_inlined_ops(module) -> set[str]:
    """Clone the forward graph, inline all calls, and return the op set."""
    # Copy first: _jit_pass_inline mutates the graph in place, and we do not
    # want to alter the module's own graph.
    graph = module.forward.graph.copy()
    torch._C._jit_pass_inline(graph)
    return collect_graph_ops(graph)
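
# Without the inline pass, calls into submodules and functions appear in the
# top-level graph only as prim::CallMethod / prim::CallFunction nodes, hiding
# the aten::* and quantized::* ops inside them. Inlining first lets
# collect_graph_ops report the full set of ops the model actually executes.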


def load_and_trace_hf_model(model_name: str, quantize: bool = False):
    """Load a HuggingFace model, tokenize a sample input, and trace to TorchScript.

    When *quantize* is True the model is dynamically quantized (``nn.Linear``
    layers converted to ``quantized::linear_dynamic``) before tracing. This
    mirrors what Eland does when importing models for Elasticsearch.

    Returns the traced module, or None if the model could not be loaded or traced.
    """
    # Optional auth token for gated or private HuggingFace models.
    token = os.environ.get("HF_TOKEN")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        config = AutoConfig.from_pretrained(
            model_name, torchscript=True, token=token)
        model = AutoModel.from_pretrained(
            model_name, config=config, token=token)
        model.eval()
    except Exception as exc:
        print(f" LOAD ERROR: {exc}", file=sys.stderr)
        return None

    if quantize:
        try:
            model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8)
            print(" Applied dynamic quantization (nn.Linear -> qint8)",
                  file=sys.stderr)
        except Exception as exc:
            print(f" QUANTIZE ERROR: {exc}", file=sys.stderr)
            return None

    inputs = tokenizer(
        "This is a sample input for graph extraction.",
        return_tensors="pt", padding="max_length",
        max_length=32, truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    try:
        # strict=False tolerates models whose forward returns dicts or other
        # mutable containers, which tracing rejects in strict mode.
        return torch.jit.trace(
            model, (input_ids, attention_mask), strict=False)
    except Exception as exc:
        print(f" TRACE WARNING: {exc}", file=sys.stderr)
        print(" Falling back to torch.jit.script...", file=sys.stderr)
        try:
            return torch.jit.script(model)
        except Exception as exc2:
            print(f" SCRIPT ERROR: {exc2}", file=sys.stderr)
            return None
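

# A minimal usage sketch: trace a model and print its inlined op set. Assumes
# network access to download the model; the model name is illustrative, and
# gated models additionally need HF_TOKEN set in the environment.
if __name__ == "__main__":
    traced = load_and_trace_hf_model("bert-base-uncased")
    if traced is None:
        sys.exit(1)
    for op in sorted(collect_inlined_ops(traced)):
        print(op)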