Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions backends/cadence/generic/operators/op_quantized_conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,19 +406,21 @@ void quantized_conv2d_nhwc(
float output_scale,
int32_t output_zero_point,
Tensor& out) {

bool conv1d = input.dim() == 3;
// input = [n, h, w, c]
const int n = input.size(0);
const int h = input.size(1);
const int w = input.size(2);
const int c = input.size(3);
const int h = conv1d ? 1 : input.size(1);
const int w = conv1d ? input.size(1) : input.size(2);
const int c = conv1d ? input.size(2) : input.size(3);
// weight = [oc, wh, ww, wc]
const int oc = weight.size(0);
const int wh = weight.size(1);
const int ww = weight.size(2);
const int wc = weight.size(3);
const int wh = conv1d ? 1 : weight.size(1);
const int ww = conv1d ? weight.size(1) : weight.size(2);
const int wc = conv1d ? weight.size(2) : weight.size(3);
// output = [n, oh, ow, oc]
const int oh = out.size(1);
const int ow = out.size(2);
const int oh = conv1d ? 1 : out.size(1);
const int ow = conv1d ? out.size(1) : out.size(2);

// Handle W8A16 heterogeneous type (int16_t activations, int8_t weights)
if (out.scalar_type() == ScalarType::Short &&
Expand Down
Loading