
Unroll GQA and remove neg ops #460

Merged
mhs4670go merged 1 commit into Samsung:main from mhs4670go:qqq2
Feb 5, 2026

Conversation

@mhs4670go
Contributor

This commit rewrites the wrapper as follows:

  • Remove rank-5 Broadcast ops by unrolling GQA.
  • Fuse neg ops into position embeddings.

These changes make the exported graph hardware-friendly.
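
For illustration, here is a minimal sketch of the unrolling idea, with assumed tensor names and shapes rather than the actual wrapper code: instead of expanding K across the query groups with a rank-5 expand/broadcast, slice the query heads per KV head and concatenate the per-group results.

import torch

def unrolled_gqa_scores(q, k, n_kv_heads):
    # q: [B, n_q_heads, S, D], k: [B, n_kv_heads, S, D]
    group = q.size(1) // n_kv_heads
    parts = []
    for h in range(n_kv_heads):
        q_g = q[:, h * group:(h + 1) * group]      # [B, group, S, D]
        k_h = k[:, h:h + 1]                        # [B, 1, S, D]
        parts.append(q_g @ k_h.transpose(-1, -2))  # [B, group, S, S]
    return torch.cat(parts, dim=1)                 # [B, n_q_heads, S, S]

The neg from rotate_half can presumably be folded offline into a pre-negated half of the sin table, which is what fusing the neg ops into the position embeddings refers to here.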

TICO-DCO-1.0-Signed-off-by: seongwoo mhs4670go@naver.com

@mhs4670go mhs4670go added the DRAFT label Feb 3, 2026
@mhs4670go
Contributor Author

mhs4670go commented Feb 3, 2026

one-optimize issue

circle2circle fails when running it with the O1 option:

circle2circle: /home/seongwoo/ONE/compiler/luci/lang/src/Nodes/CircleConst.cpp:48: typename loco::DataTypeImpl<DT>::Type& luci::CircleConst::at(uint32_t) [with loco::DataType DT = loco::DataType::U4; typename loco::DataTypeImpl<DT>::Type = unsigned char; uint32_t = unsigned int]: Assertion `n < size<DT>()' failed.
[1]    2252871 abort (core dumped)  ~/ONE/build/compiler/circle2circle/circle2circle decoder_layer.q.circle

RMSNorm issue

This is not related to this PR, but the current main branch also fails. It seems this is because of CircleRMSNorm.

circle2circle: ERROR: Optimized graph is invalid

When I tested this issue with the TinyLlamaWithFusedRMSNorm test model, the error didn't happen.

@stamalakhov
Contributor

stamalakhov commented Feb 3, 2026

@mhs4670go
After rebasing your draft onto current main and patching infer_dtype as follows:

    if (
        weight_val.within(0, 15)
        and zp_val.within(0, 15)
        and dtype == torch.uint8
        # prevents a scalar multiplier from being quantized to `uint4`
        # while the other input is quantized to `uint8`
        and weight.numel() > 1
    ):
        return "uint4"

I managed to run .../circle2circle ...decoder_layer.q.circle ...decoder_layer.q.opt.circle at least.
So if you rebase onto current main and set default_dtype to 16 bits, .../circle2circle ...decoder_layer.q.circle ...decoder_layer.q.opt.circle runs successfully.

@mhs4670go
Contributor Author

mhs4670go commented Feb 3, 2026

After the patch, the error on the main branch is resolved, but the current branch still fails:

circle2circle: /home/seongwoo/ONE/compiler/luci/lang/src/Nodes/CircleConst.cpp:47: typename loco::DataTypeImpl<DT>::Type& luci::CircleConst::at(uint32_t) [with loco::DataType DT = loco::DataType::FLOAT32; typename loco::DataTypeImpl<DT>::Type = float; uint32_t = unsigned int]: Assertion `dtype() == DT' failed.

The error happens in the ForwardTransposeOpPass pass in one-optimize. A CircleMul node seems to be the problem.

It turns out that ForwardTransposeOpPass doesn't handle the uint8 dtype properly.

@stamalakhov
Contributor

ForwardTransposeOpPass

@mhs4670go
I'm not sure, but BroadcastTo may be the problem (its output is float32 while its input is uint8).

@mhs4670go
Contributor Author

@stamalakhov Ah, it seems the following is the problem.

It turns out that ForwardTransposeOpPass doesn't handle the uint8 dtype properly.

Some passes haven't considered quantized inputs, and this may be one of those cases. I'll update the code soon.

Comment on lines 269 to 270
attn_weights = torch.cat(attn_weights_parts, dim=1)
attn_out_h = torch.cat(attn_out_parts, dim=1)
Contributor

@mhs4670go
I believe attn_weights should be quantized as well; currently they are produced as floats. Or is that intended?
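
For illustration only, a generic way to keep the concatenated weights in the quantized domain, using torch's built-in fake-quant op with placeholder scale/zero-point values (this is not the repo's own quantization wrapper):

attn_weights = torch.cat(attn_weights_parts, dim=1)
# placeholder scale/zero_point; an observer would normally provide these
attn_weights = torch.fake_quantize_per_tensor_affine(
    attn_weights, scale=0.05, zero_point=128, quant_min=0, quant_max=255
)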

Contributor Author

@stamalakhov Ah, I missed it. Thanks for letting me know!

@mhs4670go
Contributor Author

I am observing a significant increase in PEIR after introducing the changes, even though I ran it with float RMSNorm.

python tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py
┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.037503
│ PEIR       : 60.925078 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
20.6┤                                            │
    │                          •                 │
    │                                            │
15.7┤                                            │
    │                                            │
    │                                            │
    │                                            │
10.8┤                                            │
    │                             •              │
    │                                            │
 6.0┤                 ••                         │
    │                 •                          │
    │                •                           │
    │                                            │
 1.1┤              •                             │
    │            ••                              │
    │                                            │
-3.8┤                                            │
    │          •                                 │
    │    • • •                                   │
    │          •                                 │
-8.6┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -8.6       -1.3        6.0      13.3      20.6 

@mhs4670go
Contributor Author

@stamalakhov FYI, RMSNorm seems to have a similar problem.

After applying QuantRMSNorm, even though I set it to int16, it still has a high PEIR.

@stamalakhov
Contributor

@stamalakhov FYI, RMSNorm seems to have a similar problem.

After applying QuantRMSNorm, even though I set it to int16, it still has a high PEIR.

@mhs4670go
Ahh. This is sad. I'll try to investigate it.

L = hidden_states.size(1)
attention_mask = self._slice_causal(L, hidden_states.device)

if position_embeddings is None:
Contributor

@mhs4670go I believe these should be used unconditionally; otherwise LlamaModel will send its own position_embeddings and they will be desynchronized with _rot(). IMHO.
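
A minimal sketch of that suggestion, assuming the wrapper already builds its own rotary tables (the helper and variable names below are hypothetical): drop the fallback and always use the internal embeddings so they stay in sync with _rot().

L = hidden_states.size(1)
attention_mask = self._slice_causal(L, hidden_states.device)
# always use the wrapper's own tables; ignore any position_embeddings passed in
cos, sin = self._own_position_embeddings(L)  # hypothetical helper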

Contributor

@mhs4670go This seems to be the problem. The position embeddings are wrong; they don't have -[]. Moreover, right now seq_len is hardcoded to 256, so the whole sequence should be padded to 256.

Contributor

@mhs4670go This seems to be the problem. The position embeddings are wrong; they don't have -[]. Moreover, right now seq_len is hardcoded to 256, so the whole sequence should be padded to 256.

pads = torch.zeros(
    ids["input_ids"].shape[0],
    model.config.max_position_embeddings - ids["input_ids"].shape[1],
    dtype=ids["input_ids"].dtype,
)
for j in range(model.config.max_position_embeddings - ids["input_ids"].shape[-1]):
    pads[0, j] = tokenizer.pad_token_id
ids["input_ids"] = torch.cat((ids["input_ids"], pads), dim=1)

or in some other way, but sequence length should be 256.

Contributor

@mhs4670go ids should be padded to 256, maybe this way:

pads = torch.zeros(
    ids["input_ids"].shape[0],
    model.config.max_position_embeddings - ids["input_ids"].shape[1],
    dtype=ids["input_ids"].dtype,
)
for j in range(model.config.max_position_embeddings - ids["input_ids"].shape[-1]):
    pads[0, j] = tokenizer.pad_token_id
ids["input_ids"] = torch.cat((ids["input_ids"], pads), dim=1)

or in some other way.
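
For reference, the same padding could also be requested from the tokenizer itself, assuming the standard Hugging Face tokenizer call (prompt here is a placeholder for the input text):

ids = tokenizer(
    prompt,
    return_tensors="pt",
    padding="max_length",
    max_length=model.config.max_position_embeddings,
)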

@stamalakhov
Contributor

stamalakhov commented Feb 4, 2026

@mhs4670go
Please see the comments about the remaining position_embeddings. Applying all the padding and removing all position_embeddings from the layer inference produces the following PEIR for unsloth/Llama...:

┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.023445
│ PEIR       : 8.307570 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
 5.5┤                                            │
    │                                            │
    │                                            │
 3.9┤                                         •  │
    │                                            │
    │                                            │
    │                                            │
 2.3┤                                            │
    │                                            │
    │                         ••                 │
 0.7┤                        •                   │
    │                   •••                      │
    │               • ••••                       │
    │               ••••                         │
-0.9┤              ••                            │
    │           ••                               │
    │                                            │
-2.5┤                                            │
    │                                            │
    │  •                                         │
    │                                            │
-4.1┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -4.1       -1.7        0.7       3.1       5.5 

which seems to be valid.
Separately, I've copied quant_attn.py and quant_decoder_layer.py to the draft and ran the full quantize_full_qmodel_with_gptq.py quantization script for SmolLM, which produced ~40 ppl (original ~17), the same ppl as before copying quant_attn.py and quant_decoder_layer.py. So the whole draft seems to be correct.

@mhs4670go
Contributor Author

mhs4670go commented Feb 5, 2026

After a bug fix following @stamalakhov's suggestion, the error is resolved. Thanks a lot!

┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.020452
│ PEIR       : 1.983250 %
└──────────────────────────────────────────────────────
     ┌───────────────────────────────────────────┐
 22.5┤                                           │
     │                                        •  │
     │                                           │
 16.7┤                                     •     │
     │                                           │
     │                                           │
     │                                           │
 10.9┤                                           │
     │                                           │
     │                        •                  │
  5.1┤                      ••                   │
     │                     •                     │
     │                                           │
     │               ••                          │
 -0.6┤              ••                           │
     │                                           │
     │         •                                 │
 -6.4┤        •                                  │
     │      •                                    │
     │    •                                      │
     │  •                                        │
-12.2┤                                           │
     └┬──────────┬─────────┬──────────┬─────────┬┘
    -12.2      -3.5       5.1       13.8     22.5

@mhs4670go mhs4670go changed the title [DRAFT] Unroll GQA and remove neg ops Unroll GQA and remove neg ops Feb 5, 2026
@mhs4670go mhs4670go requested a review from stamalakhov February 5, 2026 07:34
@mhs4670go mhs4670go removed the DRAFT label Feb 5, 2026
This commit unrolls GQA and removes neg ops.

TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>
Contributor

@stamalakhov stamalakhov left a comment

LGTM! Thank you!

@mhs4670go mhs4670go merged commit 84d6f6f into Samsung:main Feb 5, 2026
7 checks passed
@mhs4670go mhs4670go deleted the qqq2 branch February 5, 2026 09:47