From 20fc852a2d7ec3f2d424b4c40de90b7b7cfc4bc5 Mon Sep 17 00:00:00 2001 From: Jun Luan Date: Wed, 21 Jan 2026 15:22:45 -0800 Subject: [PATCH] Add support for conv1d case in quantized_conv2d_nhwc op Differential Revision: D90901384 --- .../generic/operators/op_quantized_conv2d.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/backends/cadence/generic/operators/op_quantized_conv2d.cpp b/backends/cadence/generic/operators/op_quantized_conv2d.cpp index ca701957866..a2484294a03 100644 --- a/backends/cadence/generic/operators/op_quantized_conv2d.cpp +++ b/backends/cadence/generic/operators/op_quantized_conv2d.cpp @@ -406,19 +406,21 @@ void quantized_conv2d_nhwc( float output_scale, int32_t output_zero_point, Tensor& out) { + + bool conv1d = input.dim() == 3; // input = [n, h, w, c] const int n = input.size(0); - const int h = input.size(1); - const int w = input.size(2); - const int c = input.size(3); + const int h = conv1d ? 1 : input.size(1); + const int w = conv1d ? input.size(1) : input.size(2); + const int c = conv1d ? input.size(2) : input.size(3); // weight = [oc, wh, ww, wc] const int oc = weight.size(0); - const int wh = weight.size(1); - const int ww = weight.size(2); - const int wc = weight.size(3); + const int wh = conv1d ? 1 : weight.size(1); + const int ww = conv1d ? weight.size(1) : weight.size(2); + const int wc = conv1d ? weight.size(2) : weight.size(3); // output = [n, oh, ow, oc] - const int oh = out.size(1); - const int ow = out.size(2); + const int oh = conv1d ? 1 : out.size(1); + const int ow = conv1d ? out.size(1) : out.size(2); // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) if (out.scalar_type() == ScalarType::Short &&