diff --git a/tools/qdq-translator/CHANGELOG.md b/tools/qdq-translator/CHANGELOG.md
index a0d9d04..7126b6d 100644
--- a/tools/qdq-translator/CHANGELOG.md
+++ b/tools/qdq-translator/CHANGELOG.md
@@ -1,6 +1,12 @@
 # NVIDIA QDQ Translator change log
 
 Dates are in YYYY-MM-DD format.
+
+## Unreleased
+
+- Added support for translating models with a standalone DequantizeLinear connected to an int8/uint8 graph input.
+- Skip the onnxoptimizer ``fuse_bn_into_conv`` pass when any BatchNormalization parameter is fp16, to avoid graph decomposition (see https://github.com/onnx/optimizer/blob/master/onnxoptimizer/passes/fuse_bn_into_conv.h#L40).
+
 ## v0.2.0 (2023-08-09)
 
 - Added "infer_mul_scales" arg for handling Mul op.
diff --git a/tools/qdq-translator/qdq_translator.py b/tools/qdq-translator/qdq_translator.py
index 0dbe4ab..37254f0 100644
--- a/tools/qdq-translator/qdq_translator.py
+++ b/tools/qdq-translator/qdq_translator.py
@@ -288,12 +288,39 @@ def extract_qdq_scales(quantize_node: gs.Node, dequantize_node: gs.Node):
         assert (quant_scales == dequant_scales).all()
         return quant_scales
 
+    @staticmethod
+    def _remove_input_dq(graph: gs.Graph, precision_config: Dict[str, float]):
+        """Strip DequantizeLinear nodes that consume an int8/uint8 graph input directly."""
+        gin_names = {i.name for i in graph.inputs}
+        for dq in list(graph.nodes):
+            if dq.op != "DequantizeLinear":
+                continue
+            x = dq.inputs[0]
+            if not (isinstance(x, gs.Variable) and x.name in gin_names
+                    and len(x.inputs) == 0 and x.dtype in (np.int8, np.uint8)):
+                continue
+            dq_out = dq.outputs[0]
+            if isinstance(dq.inputs[1], gs.Constant):
+                dq_scale = dq.inputs[1].values
+            else:
+                dq_scale = dq.inputs[1].inputs[0].attrs["value"].values
+            if len(dq.inputs) > 2 and isinstance(dq.inputs[2], gs.Constant):
+                assert (dq.inputs[2].values == 0).all()
+            precision_config[x.name] = float(dq_scale)
+            # Flip to float dtype so downstream consumers pass ONNX type checks; binding-side int8 is reasserted via TRT flags + precision_config.
+            x.dtype = dq_out.dtype
+            for node in graph.nodes:
+                QATModelParser.node_replace_input(node, dq_out.name, x, None, None)
+            QATModelParser.graph_replace_output(graph, dq_out.name, x)
+            dq.outputs.clear()
+
     @staticmethod
     def extract_precision_config(graph: gs.Graph, calibration_type: str):
         precision_config = {}
         # Check for all zero weighted inputs of QuantizeLinear and
         # Conv nodes and add to this set to skip for the later check
         zero_check_skip = set()
+        QATModelParser._remove_input_dq(graph, precision_config)
         for node in graph.nodes:
             if node.op != "QuantizeLinear":
                 if node.op in ("Conv", "ConvTranspose", "Gemm"):
@@ -578,6 +605,17 @@ def parse(model_path: str, output_dir: str, post_opt_passes: List[str],
     model = onnx.shape_inference.infer_shapes(model)
     graph = gs.import_onnx(model)
 
+    if 'fuse_bn_into_conv' in post_opt_passes:
+        has_bn_fp16 = any(
+            n.op == 'BatchNormalization'
+            and any(isinstance(c, gs.Constant) and c.values.dtype == np.float16 for c in n.inputs)
+            for n in graph.nodes
+        )
+        if has_bn_fp16:
+            logging.warning(
+                'Skipping fuse_bn_into_conv: onnxoptimizer decomposes BN with fp16 params '
+                '(see https://github.com/onnx/optimizer/blob/master/onnxoptimizer/passes/fuse_bn_into_conv.h#L40).')
+            post_opt_passes = [p for p in post_opt_passes if p != 'fuse_bn_into_conv']
     if rename_node_outputs:
         for node in graph.nodes:
             for idx, out in enumerate(node.outputs):