Skip to content

Add WebGPU support for TopK#27560

Open
xenova wants to merge 8 commits intomicrosoft:mainfrom
xenova:webgpu-topk
Open

Add WebGPU support for TopK#27560
xenova wants to merge 8 commits intomicrosoft:mainfrom
xenova:webgpu-topk

Conversation

@xenova
Copy link
Contributor

@xenova xenova commented Mar 5, 2026

Description

This PR adds WebGPU support for TopK. Needed by several onnx-community models, like https://huggingface.co/onnx-community/Trinity-Nano-Preview-ONNX (which I converted and have been testing recently).

Motivation and Context

While testing onnx-community/Trinity-Nano-Preview-ONNX in Transformers.js, I noticed it was exceptionally slow... achieving only ~8-10 tps on my M4 Max. After a bit of profiling, it became obvious that the bottleneck was the CPU <-> GPU memory copies (MemcpyToHost)... happening before and after each TopK op (needed by the model to correctly select expert... not taken into account by the current QMoE op because Trinity-Nano uses a sigmoid scoring function, while other QMoE models like GPT-OSS use softmax scoring functions, taken into account in the op itself).

The algorithm is basically copied exactly from the QMoE TopK implementation. Maybe we can deduplicate some code here by templating?

Summary of PR:

  • Before: MemcpyToHost (50.1%)
  • After: MemcpyToHost (0.0%)

PROFILING BEFORE:

Top ops by total time:
  MemcpyToHost                    total= 1778.58 ms  count= 702  avg=2.534 ms  (50.1%)
  QMoE                            total= 1312.10 ms  count= 702  avg=1.869 ms  (36.9%)
  Add                             total=   94.52 ms  count=2860  avg=0.033 ms  (2.7%)
  SimplifiedLayerNormalization    total=   86.80 ms  count=4368  avg=0.020 ms  (2.4%)
  MatMulNBits                     total=   63.84 ms  count=5200  avg=0.012 ms  (1.8%)
  GroupQueryAttention             total=   31.45 ms  count= 728  avg=0.043 ms  (0.9%)
  Mul                             total=   29.55 ms  count=2899  avg=0.010 ms  (0.8%)
  Reshape                         total=   21.40 ms  count=3614  avg=0.006 ms  (0.6%)
  Sigmoid                         total=   19.99 ms  count=2158  avg=0.009 ms  (0.6%)
  MatMul                          total=   18.50 ms  count=1339  avg=0.014 ms  (0.5%)
  TopK                            total=   18.08 ms  count= 702  avg=0.026 ms  (0.5%)
  MemcpyFromHost                  total=   17.07 ms  count= 715  avg=0.024 ms  (0.5%)
  Neg                             total=   13.94 ms  count=1404  avg=0.010 ms  (0.4%)
  Expand                          total=   10.18 ms  count= 702  avg=0.014 ms  (0.3%)
  ScatterElements                 total=    8.64 ms  count= 702  avg=0.012 ms  (0.2%)
  GatherElements                  total=    8.33 ms  count= 702  avg=0.012 ms  (0.2%)
  Softplus                        total=    7.42 ms  count= 702  avg=0.011 ms  (0.2%)
  GatherBlockQuantized            total=    4.83 ms  count=  13  avg=0.371 ms  (0.1%)
  Shape                           total=    4.02 ms  count= 715  avg=0.006 ms  (0.1%)
  SkipSimplifiedLayerNormalization  total=    1.74 ms  count=  13  avg=0.134 ms  (0.0%)
  TOTAL                           total= 3551.63 ms

================================================================================
EXECUTION PROVIDER BREAKDOWN
================================================================================

  [CPUExecutionProvider]  total=18.80 ms  nodes=780  (0.5% of all time)
    TopK                            total=   18.08 ms  count= 702  avg=0.026 ms  (96.2%)
    Cast                            total=    0.29 ms  count=  26  avg=0.011 ms  (1.5%)
    ReduceSum                       total=    0.17 ms  count=  13  avg=0.013 ms  (0.9%)
    Sub                             total=    0.12 ms  count=  13  avg=0.009 ms  (0.6%)
    Gather                          total=    0.08 ms  count=  13  avg=0.006 ms  (0.4%)
    Shape                           total=    0.06 ms  count=  13  avg=0.004 ms  (0.3%)

  [WebGpuExecutionProvider]  total=3532.84 ms  nodes=30225  (99.5% of all time)
    MemcpyToHost                    total= 1778.58 ms  count= 702  avg=2.534 ms  (50.3%)
    QMoE                            total= 1312.10 ms  count= 702  avg=1.869 ms  (37.1%)
    Add                             total=   94.52 ms  count=2860  avg=0.033 ms  (2.7%)
    SimplifiedLayerNormalization    total=   86.80 ms  count=4368  avg=0.020 ms  (2.5%)
    MatMulNBits                     total=   63.84 ms  count=5200  avg=0.012 ms  (1.8%)
    GroupQueryAttention             total=   31.45 ms  count= 728  avg=0.043 ms  (0.9%)
    Mul                             total=   29.55 ms  count=2899  avg=0.010 ms  (0.8%)
    Reshape                         total=   21.40 ms  count=3614  avg=0.006 ms  (0.6%)
    Sigmoid                         total=   19.99 ms  count=2158  avg=0.009 ms  (0.6%)
    MatMul                          total=   18.50 ms  count=1339  avg=0.014 ms  (0.5%)
    MemcpyFromHost                  total=   17.07 ms  count= 715  avg=0.024 ms  (0.5%)
    Neg                             total=   13.94 ms  count=1404  avg=0.010 ms  (0.4%)
    Expand                          total=   10.18 ms  count= 702  avg=0.014 ms  (0.3%)
    ScatterElements                 total=    8.64 ms  count= 702  avg=0.012 ms  (0.2%)
    GatherElements                  total=    8.33 ms  count= 702  avg=0.012 ms  (0.2%)

================================================================================
DATA TRANSFER (Memcpy) DETAILS
================================================================================

  Total Memcpy time: 1795.65 ms across 1417 transfers

  Top 30 slowest transfers:
      13.521 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_64_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       7.677 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_94_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       7.392 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_89_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       7.312 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_91_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       7.226 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_93_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       6.818 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_92_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       6.780 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_88_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       6.710 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_90_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       5.550 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_87_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       5.331 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_100_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       5.293 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_101_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       5.273 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_64_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       5.257 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_99_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.943 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_101_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.913 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_89_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.896 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_81_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.893 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_82_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.840 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_93_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.810 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_64_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.746 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_79_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.665 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_84_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.590 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_64_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.571 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_78_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.569 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_64_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.494 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_64_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.472 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_74_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.466 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_103_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.427 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_84_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.407 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_94_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]
       4.402 ms  MemcpyToHost          WebGpuExecutionProvider         Memcpy_token_80_kernel_time
              in:  [{'float': [1, 128, 128]}]
              out: [{'float': [1, 128, 128]}]

================================================================================
CPU FALLBACK NODES (non-Memcpy)
================================================================================

  Total CPU fallback time: 18.80 ms across 60 unique nodes (780 events)

  All CPU fallback nodes:
       0.126 ms  Cast                  /model/attn_mask_reformat/attn_mask_subgraph/Sub/Cast
              in:  [{'int64': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.078 ms  TopK                  /model/layers.20/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.067 ms  TopK                  /model/layers.4/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.065 ms  ReduceSum             /model/attn_mask_reformat/attn_mask_subgraph/ReduceSum
              in:  [{'int64': [1, 128]}, {'int64': [1]}]
              out: [{'int64': [1, 1]}]
       0.049 ms  TopK                  /model/layers.45/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.047 ms  TopK                  /model/layers.53/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.045 ms  TopK                  /model/layers.52/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.044 ms  TopK                  /model/layers.30/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.044 ms  TopK                  /model/layers.38/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.041 ms  TopK                  /model/layers.37/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.041 ms  TopK                  /model/layers.46/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.040 ms  TopK                  /model/layers.2/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.040 ms  TopK                  /model/layers.26/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.039 ms  TopK                  /model/layers.50/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.038 ms  TopK                  /model/layers.35/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.038 ms  TopK                  /model/layers.51/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.038 ms  TopK                  /model/layers.54/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.037 ms  TopK                  /model/layers.36/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.037 ms  TopK                  /model/layers.43/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.036 ms  Sub                   /model/attn_mask_reformat/attn_mask_subgraph/Sub
              in:  [{'int64': [1, 1]}, {'int64': [1]}]
              out: [{'int64': [1, 1]}]
       0.036 ms  TopK                  /model/layers.55/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.036 ms  TopK                  /model/layers.39/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.036 ms  TopK                  /model/layers.49/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.035 ms  TopK                  /model/layers.24/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.035 ms  TopK                  /model/layers.47/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.034 ms  TopK                  /model/layers.33/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.034 ms  TopK                  /model/layers.41/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.034 ms  TopK                  /model/layers.48/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.034 ms  TopK                  /model/layers.9/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.033 ms  TopK                  /model/layers.29/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.033 ms  TopK                  /model/layers.44/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.032 ms  TopK                  /model/layers.25/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.032 ms  TopK                  /model/layers.32/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.032 ms  TopK                  /model/layers.27/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.032 ms  TopK                  /model/layers.31/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.032 ms  TopK                  /model/layers.11/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.031 ms  TopK                  /model/layers.3/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.030 ms  TopK                  /model/layers.34/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.030 ms  TopK                  /model/layers.17/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.030 ms  TopK                  /model/layers.40/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.030 ms  TopK                  /model/layers.42/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.030 ms  TopK                  /model/layers.7/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.029 ms  TopK                  /model/layers.21/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.029 ms  TopK                  /model/layers.16/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.029 ms  TopK                  /model/layers.8/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.029 ms  TopK                  /model/layers.14/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.028 ms  TopK                  /model/layers.13/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.027 ms  TopK                  /model/layers.19/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.027 ms  TopK                  /model/layers.22/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.027 ms  TopK                  /model/layers.5/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.027 ms  TopK                  /model/layers.15/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.027 ms  TopK                  /model/layers.23/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.027 ms  TopK                  /model/layers.6/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.026 ms  TopK                  /model/layers.18/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.026 ms  TopK                  /model/layers.12/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.026 ms  TopK                  /model/layers.28/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.023 ms  TopK                  /model/layers.10/moe/router/TopK
              in:  [{'float': [1, 128, 128]}, {'int64': [1]}]
              out: [{'float': [1, 128, 8]}, {'int64': [1, 128, 8]}]
       0.007 ms  Gather                /model/attn_mask_reformat/attn_mask_subgraph/Gather_1
              in:  [{'int64': [2]}, {'int64': []}]
              out: [{'int64': []}]
       0.006 ms  Cast                  /model/attn_mask_reformat/attn_mask_subgraph/Gather/Cast
              in:  [{'int64': []}]
              out: [{'int32': []}]
       0.006 ms  Shape                 /model/attn_mask_reformat/attn_mask_subgraph/Shape
              in:  [{'int64': [1, 128]}]
              out: [{'int64': [2]}]

================================================================================
TOP 30 SLOWEST INDIVIDUAL NODES
================================================================================
    13.521 ms  WebGpu    MemcpyToHost               Memcpy_token_64_kernel_time
     7.677 ms  WebGpu    MemcpyToHost               Memcpy_token_94_kernel_time
     7.392 ms  WebGpu    MemcpyToHost               Memcpy_token_89_kernel_time
     7.312 ms  WebGpu    MemcpyToHost               Memcpy_token_91_kernel_time
     7.226 ms  WebGpu    MemcpyToHost               Memcpy_token_93_kernel_time
     6.818 ms  WebGpu    MemcpyToHost               Memcpy_token_92_kernel_time
     6.780 ms  WebGpu    MemcpyToHost               Memcpy_token_88_kernel_time
     6.710 ms  WebGpu    MemcpyToHost               Memcpy_token_90_kernel_time
     6.110 ms  WebGpu    QMoE                       /model/layers.2/moe/MoE_Quant_kernel_time
     5.550 ms  WebGpu    MemcpyToHost               Memcpy_token_87_kernel_time
     5.331 ms  WebGpu    MemcpyToHost               Memcpy_token_100_kernel_time
     5.293 ms  WebGpu    MemcpyToHost               Memcpy_token_101_kernel_time
     5.273 ms  WebGpu    MemcpyToHost               Memcpy_token_64_kernel_time
     5.257 ms  WebGpu    MemcpyToHost               Memcpy_token_99_kernel_time
     4.943 ms  WebGpu    MemcpyToHost               Memcpy_token_101_kernel_time
     4.913 ms  WebGpu    MemcpyToHost               Memcpy_token_89_kernel_time
     4.896 ms  WebGpu    MemcpyToHost               Memcpy_token_81_kernel_time
     4.893 ms  WebGpu    MemcpyToHost               Memcpy_token_82_kernel_time
     4.840 ms  WebGpu    MemcpyToHost               Memcpy_token_93_kernel_time
     4.810 ms  WebGpu    MemcpyToHost               Memcpy_token_64_kernel_time
     4.783 ms  WebGpu    QMoE                       /model/layers.45/moe/MoE_Quant_kernel_time
     4.781 ms  WebGpu    QMoE                       /model/layers.46/moe/MoE_Quant_kernel_time
     4.746 ms  WebGpu    MemcpyToHost               Memcpy_token_79_kernel_time
     4.665 ms  WebGpu    MemcpyToHost               Memcpy_token_84_kernel_time
     4.590 ms  WebGpu    MemcpyToHost               Memcpy_token_64_kernel_time
     4.571 ms  WebGpu    MemcpyToHost               Memcpy_token_78_kernel_time
     4.569 ms  WebGpu    MemcpyToHost               Memcpy_token_64_kernel_time
     4.494 ms  WebGpu    MemcpyToHost               Memcpy_token_64_kernel_time
     4.472 ms  WebGpu    MemcpyToHost               Memcpy_token_74_kernel_time
     4.466 ms  WebGpu    MemcpyToHost               Memcpy_token_103_kernel_time

PROFILING AFTER:

Top ops by total time:
  QMoE                            total= 3269.42 ms  count= 702  avg=4.657 ms  (86.1%)
  Add                             total=  119.22 ms  count=2860  avg=0.042 ms  (3.1%)
  SimplifiedLayerNormalization    total=  108.61 ms  count=4368  avg=0.025 ms  (2.9%)
  MatMulNBits                     total=   81.46 ms  count=5200  avg=0.016 ms  (2.1%)
  Mul                             total=   35.81 ms  count=2899  avg=0.012 ms  (0.9%)
  GroupQueryAttention             total=   35.60 ms  count= 728  avg=0.049 ms  (0.9%)
  Reshape                         total=   27.15 ms  count=3614  avg=0.008 ms  (0.7%)
  Sigmoid                         total=   24.87 ms  count=2158  avg=0.012 ms  (0.7%)
  MatMul                          total=   21.87 ms  count=1339  avg=0.016 ms  (0.6%)
  Neg                             total=   16.03 ms  count=1404  avg=0.011 ms  (0.4%)
  TopK                            total=    9.80 ms  count= 702  avg=0.014 ms  (0.3%)
  ScatterElements                 total=    9.58 ms  count= 702  avg=0.014 ms  (0.3%)
  GatherElements                  total=    9.10 ms  count= 702  avg=0.013 ms  (0.2%)
  Expand                          total=    8.95 ms  count= 702  avg=0.013 ms  (0.2%)
  Softplus                        total=    7.86 ms  count= 702  avg=0.011 ms  (0.2%)
  Shape                           total=    4.97 ms  count= 715  avg=0.007 ms  (0.1%)
  GatherBlockQuantized            total=    2.98 ms  count=  13  avg=0.230 ms  (0.1%)
  SkipSimplifiedLayerNormalization  total=    1.52 ms  count=  13  avg=0.117 ms  (0.0%)
  MemcpyFromHost                  total=    1.48 ms  count=  13  avg=0.114 ms  (0.0%)
  Cast                            total=    0.23 ms  count=  26  avg=0.009 ms  (0.0%)
  TOTAL                           total= 3796.87 ms

================================================================================
EXECUTION PROVIDER BREAKDOWN
================================================================================

  [CPUExecutionProvider]  total=0.64 ms  nodes=78  (0.0% of all time)
    Cast                            total=    0.23 ms  count=  26  avg=0.009 ms  (35.6%)
    ReduceSum                       total=    0.15 ms  count=  13  avg=0.012 ms  (24.0%)
    Sub                             total=    0.10 ms  count=  13  avg=0.008 ms  (16.3%)
    Gather                          total=    0.09 ms  count=  13  avg=0.007 ms  (13.7%)
    Shape                           total=    0.07 ms  count=  13  avg=0.005 ms  (10.4%)

  [WebGpuExecutionProvider]  total=3796.23 ms  nodes=29523  (100.0% of all time)
    QMoE                            total= 3269.42 ms  count= 702  avg=4.657 ms  (86.1%)
    Add                             total=  119.22 ms  count=2860  avg=0.042 ms  (3.1%)
    SimplifiedLayerNormalization    total=  108.61 ms  count=4368  avg=0.025 ms  (2.9%)
    MatMulNBits                     total=   81.46 ms  count=5200  avg=0.016 ms  (2.1%)
    Mul                             total=   35.81 ms  count=2899  avg=0.012 ms  (0.9%)
    GroupQueryAttention             total=   35.60 ms  count= 728  avg=0.049 ms  (0.9%)
    Reshape                         total=   27.15 ms  count=3614  avg=0.008 ms  (0.7%)
    Sigmoid                         total=   24.87 ms  count=2158  avg=0.012 ms  (0.7%)
    MatMul                          total=   21.87 ms  count=1339  avg=0.016 ms  (0.6%)
    Neg                             total=   16.03 ms  count=1404  avg=0.011 ms  (0.4%)
    TopK                            total=    9.80 ms  count= 702  avg=0.014 ms  (0.3%)
    ScatterElements                 total=    9.58 ms  count= 702  avg=0.014 ms  (0.3%)
    GatherElements                  total=    9.10 ms  count= 702  avg=0.013 ms  (0.2%)
    Expand                          total=    8.95 ms  count= 702  avg=0.013 ms  (0.2%)
    Softplus                        total=    7.86 ms  count= 702  avg=0.011 ms  (0.2%)

================================================================================
DATA TRANSFER (Memcpy) DETAILS
================================================================================

  Total Memcpy time: 1.48 ms across 13 transfers

  Top 30 slowest transfers:
       0.731 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.072 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.069 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.065 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.064 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.063 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.063 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.061 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.060 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.060 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.059 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.058 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.057 ms  MemcpyFromHost        WebGpuExecutionProvider         Memcpy_kernel_time
              in:  [{'int32': [1, 1]}]
              out: [{'int32': [1, 1]}]

================================================================================
CPU FALLBACK NODES (non-Memcpy)
================================================================================

  Total CPU fallback time: 0.64 ms across 6 unique nodes (78 events)

  All CPU fallback nodes:
       0.050 ms  Cast                  /model/attn_mask_reformat/attn_mask_subgraph/Sub/Cast
              in:  [{'int64': [1, 1]}]
              out: [{'int32': [1, 1]}]
       0.037 ms  ReduceSum             /model/attn_mask_reformat/attn_mask_subgraph/ReduceSum
              in:  [{'int64': [1, 128]}, {'int64': [1]}]
              out: [{'int64': [1, 1]}]
       0.020 ms  Sub                   /model/attn_mask_reformat/attn_mask_subgraph/Sub
              in:  [{'int64': [1, 1]}, {'int64': [1]}]
              out: [{'int64': [1, 1]}]
       0.007 ms  Gather                /model/attn_mask_reformat/attn_mask_subgraph/Gather_1
              in:  [{'int64': [2]}, {'int64': []}]
              out: [{'int64': []}]
       0.007 ms  Cast                  /model/attn_mask_reformat/attn_mask_subgraph/Gather/Cast
              in:  [{'int64': []}]
              out: [{'int32': []}]
       0.006 ms  Shape                 /model/attn_mask_reformat/attn_mask_subgraph/Shape
              in:  [{'int64': [1, 128]}]
              out: [{'int64': [2]}]

================================================================================
TOP 30 SLOWEST INDIVIDUAL NODES
================================================================================
    14.307 ms  WebGpu    QMoE                       /model/layers.2/moe/MoE_Quant_kernel_time
    11.587 ms  WebGpu    QMoE                       /model/layers.25/moe/MoE_Quant_kernel_time
     8.657 ms  WebGpu    QMoE                       /model/layers.16/moe/MoE_Quant_kernel_time
     8.575 ms  WebGpu    QMoE                       /model/layers.55/moe/MoE_Quant_kernel_time
     8.515 ms  WebGpu    QMoE                       /model/layers.4/moe/MoE_Quant_kernel_time
     8.300 ms  WebGpu    QMoE                       /model/layers.2/moe/MoE_Quant_kernel_time
     7.936 ms  WebGpu    QMoE                       /model/layers.51/moe/MoE_Quant_kernel_time
     7.921 ms  WebGpu    QMoE                       /model/layers.50/moe/MoE_Quant_kernel_time
     7.720 ms  WebGpu    QMoE                       /model/layers.52/moe/MoE_Quant_kernel_time
     7.354 ms  WebGpu    QMoE                       /model/layers.54/moe/MoE_Quant_kernel_time
     7.350 ms  WebGpu    QMoE                       /model/layers.2/moe/MoE_Quant_kernel_time
     7.282 ms  WebGpu    QMoE                       /model/layers.48/moe/MoE_Quant_kernel_time
     7.255 ms  WebGpu    QMoE                       /model/layers.53/moe/MoE_Quant_kernel_time
     7.195 ms  WebGpu    QMoE                       /model/layers.2/moe/MoE_Quant_kernel_time
     7.170 ms  WebGpu    QMoE                       /model/layers.49/moe/MoE_Quant_kernel_time
     7.110 ms  WebGpu    QMoE                       /model/layers.2/moe/MoE_Quant_kernel_time
     7.085 ms  WebGpu    QMoE                       /model/layers.47/moe/MoE_Quant_kernel_time
     7.045 ms  WebGpu    QMoE                       /model/layers.44/moe/MoE_Quant_kernel_time
     6.975 ms  WebGpu    QMoE                       /model/layers.52/moe/MoE_Quant_kernel_time
     6.812 ms  WebGpu    QMoE                       /model/layers.46/moe/MoE_Quant_kernel_time
     6.769 ms  WebGpu    QMoE                       /model/layers.55/moe/MoE_Quant_kernel_time
     6.747 ms  WebGpu    QMoE                       /model/layers.2/moe/MoE_Quant_kernel_time
     6.745 ms  WebGpu    QMoE                       /model/layers.45/moe/MoE_Quant_kernel_time
     6.695 ms  WebGpu    QMoE                       /model/layers.50/moe/MoE_Quant_kernel_time
     6.652 ms  WebGpu    QMoE                       /model/layers.49/moe/MoE_Quant_kernel_time
     6.499 ms  WebGpu    QMoE                       /model/layers.53/moe/MoE_Quant_kernel_time
     6.384 ms  WebGpu    QMoE                       /model/layers.46/moe/MoE_Quant_kernel_time
     6.370 ms  WebGpu    QMoE                       /model/layers.40/moe/MoE_Quant_kernel_time
     6.364 ms  WebGpu    QMoE                       /model/layers.2/moe/MoE_Quant_kernel_time
     6.309 ms  WebGpu    QMoE                       /model/layers.55/moe/MoE_Quant_kernel_time

@xenova
Copy link
Contributor Author

xenova commented Mar 5, 2026

cc @guschmue

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Mar 5, 2026
Co-authored-by: Joshua Lochner <admin@xenova.com>
@guschmue
Copy link
Contributor

guschmue commented Mar 5, 2026

this gets stuck on ut, ie TopKOperator.SmallArrayTopKSorted.
The issue is that the wg_size(256) is smaller than cols(400).
Need to do this via for() in the shader I think.

@xenova
Copy link
Contributor Author

xenova commented Mar 5, 2026

Was able to reproduce. Sure, let me update

@xenova
Copy link
Contributor Author

xenova commented Mar 5, 2026

Opus helped cook up a bitonic sort, and I just had to get it working with some edge cases (e.g., with duplicate values).
All my tests pass now, and no infinite hangs.

@guschmue
Copy link
Contributor

guschmue commented Mar 6, 2026

'lintrunner -a' please :)

@xenova
Copy link
Contributor Author

xenova commented Mar 6, 2026

lintrunner installation seems to be broken in my env :/

✗ lintrunner init                     
[2026-03-06T16:44:03Z INFO lintrunner::linter] Initializing linter: 'RUFF'
[2026-03-06T16:44:03Z INFO lintrunner::linter] the init commands are ["python", "-m", "lintrunner_adapters", "run", "pip_init", "--dry-run=0", "--requirement=requirements-lintrunner.txt"]
error:        No such file or directory (os error 2)

I'll just apply the patch manually

@xenova
Copy link
Contributor Author

xenova commented Mar 6, 2026

hmm, weird. suggestion looks identical
image

@guschmue
Copy link
Contributor

guschmue commented Mar 6, 2026

BigArrayBigTopKSorted unit test is failing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ep:WebGPU ort-web webgpu provider

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants