From 18172f489db625c6c92c3fd261c4f5b5b286edec Mon Sep 17 00:00:00 2001
From: gcunhase <4861122+gcunhase@users.noreply.github.com>
Date: Fri, 24 Apr 2026 12:03:52 -0400
Subject: [PATCH 1/2] Added support inp->DQ graphs

---
Reviewer note: _remove_input_dq now rejects multi-element scales with an explicit
ValueError, converts the scale with ndarray.item(), and documents its contract.
Indentation and blank context lines in this series were restored by hand from a
whitespace-collapsed copy, and the blob ids on the `index` lines were not
regenerated; re-check both against the tree before applying.

 tools/qdq-translator/CHANGELOG.md      |  5 +++
 tools/qdq-translator/qdq_translator.py | 49 ++++++++++++++++++++++++++
 2 files changed, 54 insertions(+)

diff --git a/tools/qdq-translator/CHANGELOG.md b/tools/qdq-translator/CHANGELOG.md
index a0d9d04..4b754c2 100644
--- a/tools/qdq-translator/CHANGELOG.md
+++ b/tools/qdq-translator/CHANGELOG.md
@@ -1,6 +1,11 @@
 # NVIDIA QDQ Translator change log
 
 Dates are in YYYY-MM-DD format.
+
+## Unreleased
+
+- Added support for translating models with a standalone DequantizeLinear connected to an int8/uint8 graph input.
+
 
 ## v0.2.0 (2023-08-09)
 - Added "infer_mul_scales" arg for handling Mul op.
diff --git a/tools/qdq-translator/qdq_translator.py b/tools/qdq-translator/qdq_translator.py
index 0dbe4ab..4963ea4 100644
--- a/tools/qdq-translator/qdq_translator.py
+++ b/tools/qdq-translator/qdq_translator.py
@@ -288,12 +288,61 @@ def extract_qdq_scales(quantize_node: gs.Node, dequantize_node: gs.Node):
         assert (quant_scales == dequant_scales).all()
         return quant_scales
 
+    @staticmethod
+    def _remove_input_dq(graph: gs.Graph, precision_config: Dict[str, float]):
+        """Strip DequantizeLinear nodes that consume an int8/uint8 graph input directly.
+
+        For each such node, its scale is recorded in ``precision_config`` under the graph
+        input's name, the input takes the DequantizeLinear output's dtype, and uses of that
+        output are redirected to the graph input through ``node_replace_input`` and
+        ``graph_replace_output``. The node is left in ``graph.nodes`` with its outputs cleared.
+
+        Args:
+            graph: Graph to rewrite in place.
+            precision_config: Tensor name -> scale mapping, updated in place.
+
+        Raises:
+            ValueError: If a matched DequantizeLinear scale holds more than one value.
+            AssertionError: If a constant zero-point input is not all zeros.
+        """
+        gin_names = {i.name for i in graph.inputs}
+        for dq in list(graph.nodes):
+            if dq.op != "DequantizeLinear":
+                continue
+            x = dq.inputs[0]
+            if not (isinstance(x, gs.Variable) and x.name in gin_names
+                    and len(x.inputs) == 0 and x.dtype in (np.int8, np.uint8)):
+                continue
+            dq_out = dq.outputs[0]
+            if isinstance(dq.inputs[1], gs.Constant):
+                dq_scale = dq.inputs[1].values
+            else:
+                # Scale comes from a producer node (presumably a Constant op): read its "value" attribute.
+                dq_scale = dq.inputs[1].inputs[0].attrs["value"].values
+            # precision_config keeps one float per tensor, so only a single-element scale can be stored.
+            # ndarray.item() also avoids float() on a 1-element array, which newer NumPy deprecates.
+            dq_scale = np.asarray(dq_scale)
+            if dq_scale.size != 1:
+                raise ValueError(
+                    "DequantizeLinear '{}' on graph input '{}' has {} scale values; "
+                    "a single per-tensor scale is required.".format(dq.name, x.name, dq_scale.size))
+            if len(dq.inputs) > 2 and isinstance(dq.inputs[2], gs.Constant):
+                assert (dq.inputs[2].values == 0).all()
+            precision_config[x.name] = float(dq_scale.item())
+            # Flip to float dtype so downstream consumers pass ONNX type checks; binding-side int8 is reasserted via TRT flags + precision_config.
+            x.dtype = dq_out.dtype
+            for node in graph.nodes:
+                QATModelParser.node_replace_input(node, dq_out.name, x, None, None)
+            QATModelParser.graph_replace_output(graph, dq_out.name, x)
+            dq.outputs.clear()
+
     @staticmethod
     def extract_precision_config(graph: gs.Graph, calibration_type: str):
         precision_config = {}
         # Check for all zero weighted inputs of QuantizeLinear and
         # Conv nodes and add to this set to skip for the later check
         zero_check_skip = set()
+        QATModelParser._remove_input_dq(graph, precision_config)
         for node in graph.nodes:
             if node.op != "QuantizeLinear":
                 if node.op in ("Conv", "ConvTranspose", "Gemm"):

From c0c1ca36068510ffc64d97962afac1e68776b2f3 Mon Sep 17 00:00:00 2001
From: gcunhase <4861122+gcunhase@users.noreply.github.com>
Date: Fri, 24 Apr 2026 13:04:17 -0400
Subject: [PATCH 2/2] Skip onnxoptimizer's fuse_bn_into_conv pass when BN has FP16 params

This avoids graph decomposition: https://github.com/onnx/optimizer/blob/master/onnxoptimizer/passes/fuse_bn_into_conv.h#L40
---
 tools/qdq-translator/CHANGELOG.md      |  1 +
 tools/qdq-translator/qdq_translator.py | 11 +++++++++++
 2 files changed, 12 insertions(+)

diff --git a/tools/qdq-translator/CHANGELOG.md b/tools/qdq-translator/CHANGELOG.md
index 4b754c2..7126b6d 100644
--- a/tools/qdq-translator/CHANGELOG.md
+++ b/tools/qdq-translator/CHANGELOG.md
@@ -5,6 +5,7 @@ Dates are in YYYY-MM-DD format.
 ## Unreleased
 
 - Added support for translating models with a standalone DequantizeLinear connected to an int8/uint8 graph input.
+- Skip the onnxoptimizer ``fuse_bn_into_conv`` pass when any BatchNormalization parameter is fp16, to avoid graph decomposition (see https://github.com/onnx/optimizer/blob/master/onnxoptimizer/passes/fuse_bn_into_conv.h#L40).
 
 
 ## v0.2.0 (2023-08-09)
diff --git a/tools/qdq-translator/qdq_translator.py b/tools/qdq-translator/qdq_translator.py
index 4963ea4..37254f0 100644
--- a/tools/qdq-translator/qdq_translator.py
+++ b/tools/qdq-translator/qdq_translator.py
@@ -627,6 +627,17 @@ def parse(model_path: str, output_dir: str, post_opt_passes: List[str],
         model = onnx.shape_inference.infer_shapes(model)
         graph = gs.import_onnx(model)
 
+        if 'fuse_bn_into_conv' in post_opt_passes:
+            has_bn_fp16 = any(
+                n.op == 'BatchNormalization'
+                and any(isinstance(c, gs.Constant) and c.values.dtype == np.float16 for c in n.inputs)
+                for n in graph.nodes
+            )
+            if has_bn_fp16:
+                logging.warning(
+                    'Skipping fuse_bn_into_conv: onnxoptimizer decomposes BN with fp16 params '
+                    '(see https://github.com/onnx/optimizer/blob/master/onnxoptimizer/passes/fuse_bn_into_conv.h#L40).')
+                post_opt_passes = [p for p in post_opt_passes if p != 'fuse_bn_into_conv']
         if rename_node_outputs:
             for node in graph.nodes:
                 for idx, out in enumerate(node.outputs):