KIMI-K2.5-W4A8 · AMD MI308X TP=8 · EAGLE3

TPOT Optimization Summary

Five kernel-level patches that reduce decode-step latency on SGLang serving of Kimi-K2.5-W4A8 with EAGLE3 speculative decoding.
  • TPOT median: −17.2% (90.16 → 74.66 ms)
  • Total throughput: +6.3% (6391 → 6793 tok/s)
  • ITL median: −15.1% (34.23 → 29.05 ms)
  • GSM8K stderr band: ±0.7% (accuracy preserved)

0. Measurement Setup

Unless noted otherwise, the numbers in this report come from a single validation session on hjbog-srdc-2.amd.com, using docker image clementlincf/amdafde:v0.5.10rc0-rocm720-mi30x-kimi-k2.5-opt-20260420; the per-patch TPOT waterfall in End-to-End Results uses the earlier 0418 session. Logs are archived under progress/kimi-k2.5/container-validation-0420/.

Hardware

  • 8× AMD MI308X (gfx942, 80 CUs per device)
  • Tensor-parallel degree: TP=8
  • Interconnect: PCIe + node-local all-reduce

Software

  • SGLang kimi-K2.5-W8A8-dev-rebased
  • AITER kimi-K2.5-W4A8-rebased HEAD b5757d6
  • FlyDSL feature/w4a8-moe-port-rebased HEAD a1d8312
  • ROCm 7.2.0, PyTorch 2.x (container-baked)

Workload

  • Model: Kimi-K2.5-W4A8 (497 GB)
  • Draft: kimi-k2.5-eagle3 (6.0 GB)
  • Prompts: 160 random · concurrency 40
  • Input len 10240 · output len 512
  • EAGLE3: 3 steps · topk 1 · 4 draft tokens

Comparison

  • Before: launch_eagle3.sh — EAGLE3 only, no env-var knobs enabled
  • After: launch_eagle3_opt.sh — identical args + two env vars activating the patched code paths
What the −17.2% figure measures. It is the change in median TPOT between two SGLang benchmark runs against the same image, same workload, same model files. The only differences are (a) three git-tracked source commits in aiter/FlyDSL and (b) two environment variables that activate code paths those commits added. No tokenizer, schedule, or sampling parameters changed.
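
The headline percentages are plain relative changes between the EAGLE3 run and the EAGLE3 + opt run. As a minimal sketch (the helper name `pct_change` is introduced here for illustration; the input values are the ones listed in the End-to-End Results table), they can be recomputed like this:

```python
def pct_change(before: float, after: float) -> float:
    """Relative change from `before` to `after`, in percent."""
    return (after - before) / before * 100.0


# (before, after) pairs: EAGLE3 run vs. EAGLE3 + opt run, as reported in this document.
metrics = {
    "TPOT median (ms)": (90.16, 74.66),
    "ITL median (ms)": (34.23, 29.05),
    "Total throughput (tok/s)": (6391, 6793),
}

for name, (before, after) in metrics.items():
    print(f"{name}: {pct_change(before, after):+.1f}%")
```

A negative value is an improvement for the latency metrics and a positive value is an improvement for throughput.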

1. Patch Summary

Five logically independent changes across three commits in two repositories. Each patch is described in its own section below with the measured kernel-level effect.

| # | Patch | Repo · Commit | Scope | Activation |
|---|-------|---------------|-------|------------|
| 1 | MoE Stage1 scheduler-hints gate | FlyDSL · a1d8312 | kernels/moe_gemm_2stage.py | FLYDSL_MOE_STAGE1_SCHED=0 |
| 2 | Stage1 auto tile_k=256 for decode | AITER · 511df6a | aiter/fused_moe.py | Automatic when block_m≤16 |
| 3 | Stage2 tile_k=256 override | AITER · 511df6a | aiter/fused_moe.py | AITER_FLYDSL_STAGE2_TILE_K=256 |
| 4 | MI308X bf16 GEMM tune (+65 entries) | AITER · b5757d6 | kimik2_bf16_tuned_gemm.csv | Automatic via cu_num=80 lookup |
| 5 | FMHA v3 bf16 rounding rtna → rtz | AITER · b5757d6 | aiter/ops/mha.py | Default changed (no env var needed) |

Side note. The same 511df6a commit also adds AITER_FLYDSL_MOE_GRID_TRIM (default on) and a JSONL shape logger. These are infrastructure knobs; the grid-trim has measurable effect only for M≤8 (outside this EAGLE3 verify shape) and is not claimed as a contributor to the 17.2% figure.

2. Patch 1: Stage1 Scheduler-Hints Gate (FlyDSL)

FlyDSL's MoE Stage1 kernel ships with manual sched_barrier / sched_mfma / sched_dsrd / sched_vmem / sched_dswr hints intended to steer LLVM's instruction scheduler. On the W4A8 decode path, disabling these hints and letting LLVM's default scheduler run produces a shorter kernel.

Behaviour change

Kernel time −15% at M=40
FlyDSL commit a1d8312 · kernels/moe_gemm_2stage.py · +12/−2 lines

Added two environment gates (FLYDSL_MOE_STAGE1_SCHED and FLYDSL_MOE_STAGE2_SCHED, both default on). Setting Stage1 to 0 skips all hand-written sched_* hints during compile. Stage2's hints are kept on by default (they are tuned per tile_m and were approximately neutral in measurements).
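
A default-on gate of this kind can be read with a few lines of Python. The sketch below is an illustration only: `env_flag` and `stage1_hints_enabled` are hypothetical helpers, and the parsing rule (unset or empty means default, "0" means off, anything else means on) is an assumption rather than FlyDSL's actual implementation.

```python
import os


def env_flag(name: str, default: bool = True, env=None) -> bool:
    """Read a boolean gate: unset/empty -> default, "0" -> False, else True."""
    env = os.environ if env is None else env
    raw = env.get(name)
    if raw is None or raw.strip() == "":
        return default
    return raw.strip() != "0"


def stage1_hints_enabled(env=None) -> bool:
    # Hypothetical wrapper: the variable name comes from this report,
    # the default-on behaviour matches the description above.
    return env_flag("FLYDSL_MOE_STAGE1_SCHED", default=True, env=env)


print(stage1_hints_enabled({}))                                # gate unset
print(stage1_hints_enabled({"FLYDSL_MOE_STAGE1_SCHED": "0"}))  # hints disabled
```

Passing the environment as a parameter keeps the helper testable without mutating `os.environ`.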

ISA diff (Stage1, M=40, int4 W4A8 path, tile_m=16, tile_n=64, tile_k=128)

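| Metric | Sched ON (default) | Sched OFF | Δ |
|--------|--------------------|-----------|---|
| Total instructions | 4 833 | 4 590 | −5.0% |
| MFMA | 448 | 448 | 0 |
| buffer_load | 293 | 293 | 0 |
| VALU | 3 335 | 3 335 | 0 |
| s_waitcnt | 399 | 149 | −62.7% |
| s_barrier | 56 | 56 | 0 |
| ds_write | 56 | 56 | 0 |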

The MFMA, memory, and LDS instruction counts are unchanged. The reduction is concentrated in s_waitcnt (instructions that stall a wave until outstanding memory operations complete): 399 → 149, a drop of 250. The manual hints constrain LLVM's vmcnt tracking in a way that causes conservative extra waits; removing them allows a tighter schedule with fewer stalls.
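
An instruction tally like the one above can be produced by counting leading mnemonics in a kernel disassembly. The following is a minimal sketch, assuming a textual listing with one instruction per line, `;` comments, and labels ending in `:`. The sample listing is made up for illustration and is not taken from the real kernel, and `count_mnemonics` / `count_prefix` are hypothetical helpers.

```python
import re
from collections import Counter

# Leading mnemonic of an instruction line, e.g. "s_waitcnt" in "s_waitcnt vmcnt(0)".
_MNEMONIC = re.compile(r"^\s*([a-z][a-z0-9_]*)\b")

# Illustrative listing only; not real kernel output.
SAMPLE_ISA = """
.LBB0_1:
  buffer_load_dwordx4 v[0:3], v4, s[8:11], 0 offen
  s_waitcnt vmcnt(0)            ; wait for the load above
  v_mfma_i32_16x16x32_i8 a[0:3], v[0:1], v[2:3], a[0:3]
  s_waitcnt lgkmcnt(0)
  s_barrier
"""


def count_mnemonics(isa_text: str) -> Counter:
    """Count instructions per mnemonic, skipping comments, labels and directives."""
    counts = Counter()
    for line in isa_text.splitlines():
        line = line.split(";", 1)[0]               # drop trailing comment
        if not line.strip() or line.rstrip().endswith(":"):
            continue                               # blank line or label
        match = _MNEMONIC.match(line)
        if match:                                  # directives start with "." and do not match
            counts[match.group(1)] += 1
    return counts


def count_prefix(counts: Counter, prefix: str) -> int:
    """Sum the counts of all mnemonics starting with `prefix` (e.g. "v_mfma")."""
    return sum(n for name, n in counts.items() if name.startswith(prefix))


counts = count_mnemonics(SAMPLE_ISA)
print(counts["s_waitcnt"], count_prefix(counts, "v_mfma"), sum(counts.values()))
```

Running the same tally on the two builds and subtracting the counters gives the Δ column directly.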

Measured kernel time

| Configuration | Stage1 (μs, M=40) | Δ |
|---------------|-------------------|---|
| Baseline (sched ON, tile_k=128) | 242 | |
| Sched OFF, tile_k=128 | 206 | −14.9% |

3. Patch 2: Stage1 Auto tile_k=256 for Decode (AITER)

Stage1's K-loop count is K / tile_k. For Kimi's int4 W4A8 Stage1 (K=7168), raising tile_k from 128 to 256 halves the number of K-iterations. The patch auto-picks tile_k=256 when block_m≤16 (decode regime) and leaves prefill at 128 where 256 was measured to regress.

Dispatch logic

Stage1 −6.3% additional on top of sched-off (206 → 193 μs)
AITER commit 511df6a · aiter/fused_moe.py
# auto-pick tile_k=256 for W4A8 decode (block_m ≤ 16), keep 128 for prefill
if _is_w_int4 and block_m <= 16:
    stage1_tile_k = 256
else:
    stage1_tile_k = 128

Halving the K-loop count also halves the loop-local barrier and s_waitcnt count. The patch additionally exposes AITER_FLYDSL_STAGE1_TILE_K as a manual override for experimentation.
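
The effect on the loop count is simple arithmetic. The sketch below mirrors the dispatch rule quoted above and computes K / tile_k for the Stage1 shape; `pick_stage1_tile_k` and `k_loop_iterations` are hypothetical helper names, not functions from aiter/fused_moe.py.

```python
def pick_stage1_tile_k(is_w_int4: bool, block_m: int) -> int:
    """Same rule as the quoted dispatch: 256 for W4A8 decode, else 128."""
    return 256 if is_w_int4 and block_m <= 16 else 128


def k_loop_iterations(k: int, tile_k: int) -> int:
    """Number of K-loop iterations; assumes tile_k divides K evenly."""
    if k % tile_k != 0:
        raise ValueError(f"tile_k={tile_k} does not divide K={k}")
    return k // tile_k


K_STAGE1 = 7168  # Stage1 K for Kimi's int4 W4A8 path, as stated above

for block_m in (16, 64):  # a decode-sized and a prefill-sized block_m
    tile_k = pick_stage1_tile_k(is_w_int4=True, block_m=block_m)
    iters = k_loop_iterations(K_STAGE1, tile_k)
    print(f"block_m={block_m}: tile_k={tile_k}, K-iterations={iters}")
```

With K=7168 this gives 28 iterations at tile_k=256 versus 56 at tile_k=128, which matches the 56 s_barrier / ds_write instructions seen in the tile_k=128 ISA listing.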

Why it is safe for decode but not prefill

At decode shapes (small M), the kernel is memory-bound and VGPR pressure is low; the larger tile_k keeps occupancy stable and reduces per-iteration overhead. At prefill shapes (large M), the wider B-tile increases VGPR pressure enough to drop a wave per SIMD, and the per-iteration overhead is a smaller fraction of the compute cost. Measurements confirmed prefill regresses with tile_k=256.

Measured cumulative kernel time (Stage1, M=40)

| Configuration | Stage1 (μs) | Δ from baseline |
|---------------|-------------|-----------------|
| Baseline (sched ON, tile_k=128) | 242 | |
| Sched OFF, tile_k=128 | 206 | −14.9% |
| Sched OFF + auto tile_k=256 | 193 | −20.2% |

4. Patch 3: Stage2 tile_k=256 Override (AITER)

Stage2's K dimension equals the MoE intermediate dimension, which for Kimi-K2.5 is inter_dim=256. The default tile_k=64 therefore runs 4 K-iterations per tile. Setting tile_k=256 collapses the K-loop to a single iteration, eliminating three loop-closing barriers and their associated s_waitcnt instructions per tile.

Environment override

Stage1+2 −5.2% at verify M=160
AITER commit 511df6a · aiter/fused_moe.py line 2519
# in flydsl_moe_stage2 dispatch
_over_tk = os.environ.get('AITER_FLYDSL_STAGE2_TILE_K', '')
if _over_tk: tile_k = int(_over_tk)
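
The quoted snippet takes the environment value as-is. For experimentation, a slightly more defensive variant could reject values that do not divide K. This is an illustration only, not what commit 511df6a does: `resolve_tile_k` is a hypothetical helper, and the divisibility requirement is an assumption about the kernel.

```python
import os


def resolve_tile_k(default_tile_k: int, k_dim: int,
                   env_name: str = "AITER_FLYDSL_STAGE2_TILE_K",
                   env=None) -> int:
    """Return the env override if set and valid, otherwise the default."""
    env = os.environ if env is None else env
    raw = env.get(env_name, "")
    if not raw:
        return default_tile_k
    tile_k = int(raw)  # raises ValueError for non-numeric input
    if tile_k <= 0 or k_dim % tile_k != 0:
        raise ValueError(f"{env_name}={raw} must be a positive divisor of K={k_dim}")
    return tile_k


INTER_DIM = 256  # Stage2 K for this model, as stated above

for env in ({}, {"AITER_FLYDSL_STAGE2_TILE_K": "256"}):
    tile_k = resolve_tile_k(default_tile_k=64, k_dim=INTER_DIM, env=env)
    print(f"tile_k={tile_k}, K-iterations={INTER_DIM // tile_k}")
```

With the override set to 256 the loop count drops from 4 to 1, which is the change this patch relies on.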

Stage2 tile sweep at M=160 (EAGLE3 verify shape)

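| tile_m | tile_n | tile_k | Time (μs) | Δ | Correctness |
|--------|--------|--------|-----------|---|-------------|
| default | 256 | 64 (default) | 833.7 | baseline | |
| 32 | 256 | 64 | 834.1 | 0% | |
| | 128 | 64 | 838.5 | +0.6% | |
| | 256 | 128 | 813.3 | −2.5% | |
| | 256 | 256 | 790.2 | −5.2% | |
| 32 | 256 | 256 | 789.9 | −5.3% | |
| 64 | 256 | 256 | 860.5 | +3.2% | ✗ wrong output |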

Effect across verify M values

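| M | default (μs) | tile_k=256 (μs) | Δ |
|---|--------------|-----------------|---|
| 40 | 370 | 368 | noise |
| 80 | 477 | 465 | −2.4% |
| 128 | 527 | 515 | −2.3% |
| 160 | 835 | 792 | −5.2% |
| 192 | 844 | 800 | −5.2% |
| 256 | 877 | 830 | −5.3% |
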
Why verify shapes (M=160, 192) benefit most. The EAGLE3 verify step concatenates concurrency × num_draft_tokens tokens per forward pass (the 4 draft tokens configured here already include the root token). With conc=40 and num_draft_tokens=4, this lands around M=160–192. At those shapes Stage2's 4-iteration K-loop is a significant share of kernel time; decode-only shapes (M=40) are small enough that the overhead is already negligible.
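
The verify batch size is a simple product, but it depends on whether the configured draft-token count already includes the root token. A minimal sketch, with `verify_batch_m` and its `includes_root` flag as hypothetical names introduced here for illustration:

```python
def verify_batch_m(concurrency: int, num_draft_tokens: int,
                   includes_root: bool = True) -> int:
    """Tokens processed by one verify forward pass at full concurrency."""
    per_request = num_draft_tokens if includes_root else num_draft_tokens + 1
    return concurrency * per_request


# Settings from this report: concurrency 40, 4 draft tokens.
print(verify_batch_m(40, 4, includes_root=True))
print(verify_batch_m(40, 4, includes_root=False))
```

With this report's settings the two conventions give 160 and 200; 160 is the one that matches the verify shape measured above.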

5. Patch 4: MI308X bf16 GEMM Tuning (AITER)

AITER's tuned GEMM configuration file kimik2_bf16_tuned_gemm.csv shipped with entries keyed to cu_num=256 (MI300X). On MI308X (cu_num=80), every lookup missed, so every MLA projection GEMM fell back to the generic torch.matmul path (rocBLAS Tensile). The patch adds 65 entries covering the verify-step shapes that AITER's hand-written ASM GEMM wins on.
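
The miss-then-fallback behaviour can be pictured as an exact-match lookup keyed on the device's CU count and the GEMM shape. The sketch below is a simplified illustration: the column names, the kernel names in the sample rows, and the helpers `load_tuned` / `select_kernel` are assumptions and do not reproduce the real CSV schema or AITER's lookup code.

```python
import csv
import io

# Hypothetical, simplified layout; the real tuned-GEMM CSV has a different schema.
BEFORE_PATCH = """cu_num,M,N,K,kernelName
256,192,2112,7168,placeholder_kernel_for_256_cus
"""
AFTER_PATCH = BEFORE_PATCH + "80,192,2112,7168,bf16gemm_fp32bf16_tn_96x64_pf3_splitk\n"


def load_tuned(csv_text: str) -> dict:
    """Build {(cu_num, M, N, K): kernel_name} from CSV text."""
    table = {}
    for row in csv.DictReader(io.StringIO(csv_text)):
        key = (int(row["cu_num"]), int(row["M"]), int(row["N"]), int(row["K"]))
        table[key] = row["kernelName"]
    return table


def select_kernel(table: dict, cu_num: int, m: int, n: int, k: int,
                  fallback: str = "torch.matmul") -> str:
    """Exact-match lookup; any miss falls back to the generic path."""
    return table.get((cu_num, m, n, k), fallback)


for label, text in (("before", BEFORE_PATCH), ("after", AFTER_PATCH)):
    choice = select_kernel(load_tuned(text), cu_num=80, m=192, n=2112, k=7168)
    print(f"{label}: {choice}")
```

Because the key includes cu_num, a file tuned for another device never matches on MI308X, no matter how many shapes it covers.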

Tuning coverage

65 new entries · up to 4.4× per kernel
AITER commit b5757d6 · aiter/configs/model_configs/kimik2_bf16_tuned_gemm.csv

Shape coverage

M ∈ {72…256} · N ∈ {2112, 3072, 3584, 4608, 7168, 14336} · K ∈ {7168, 14336}. Winning kernels are bf16gemm_fp32bf16_tn_*x64_pf3_splitk and *_splitk_clean (AITER hand-written asm).

Representative wins at M=192 (EAGLE3 verify shape)

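| N | K | rocBLAS (μs) | AITER ASM kernel | AITER (μs) | Speedup |
|---|---|--------------|------------------|------------|---------|
| 2112 | 7168 | 181.6 | 96x64_pf3_splitk sk=1 | 77.5 | 2.3× |
| 4608 | 7168 | 495.2 | 96x64_pf3_splitk sk=1 | 143.9 | 3.4× |
| 3072 | 7168 | 163.2 | 64x64_pf3_splitk sk=1 | 98.5 | 1.7× |
| 3072 | 14336 | 204.4 | 64x64_pf3_splitk sk=1 | 185.5 | 1.1× |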

Why the CSV miss mattered

MLA layers in each decoder block issue four bf16 GEMMs (q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj / o_proj). With 60 layers × 4 GEMMs per forward pass and no cu_num=80 match, all 240 GEMMs per step went through rocBLAS. The added entries target the M range the EAGLE3 verify step actually produces (M=160–256), so the fast-path is hit at every verify invocation.

Coverage gap, by design. Decode shapes M≤72 remain on rocBLAS. rocBLAS is already well-tuned for skinny GEMMs at those sizes and AITER's ASM kernels did not win consistently there during sweep.

6. Patch 5: FMHA bf16 Rounding rtna → rtz (AITER)

AITER's FMHA v3 (asm attention) exposes three bf16 rounding modes for the fp32→bf16 conversion in the attention epilogue: rtne (0), rtna (1, default), and rtz (2). On gfx942, rtz maps to the native single-instruction v_cvt_pk_rtz_bf16_f32; rtna is emulated with multiple instructions. The patch flips the default from 1 to 2 at nine API sites in aiter/ops/mha.py.

Default change

Binary footprint −13% per kernel variant
AITER commit b5757d6 · aiter/ops/mha.py · 9 sites, 1-line change each

The change

# aiter/ops/mha.py, 9 call sites
-    how_v3_bf16_cvt: Optional[int] = 1  # rtna
+    how_v3_bf16_cvt: Optional[int] = 2  # rtz

Dispatch path

The value is threaded through Python → C++ struct mha_fwd_args.how_v3_bf16_cvt → CSV lookup in hsa/gfx942/fmha_v3_fwd/fmha_fwd.csv keyed on (dtype, hdim_q, hdim_v, mask, mode, bf16_cvt). The matching row names a .co (HSA code object) which AiterAsmKernel loads. So the change selects a different pre-compiled ASM binary at every call.
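
Since the rounding mode is part of the lookup key, changing only that field selects a different binary. The sketch below models this with a tuple-keyed dictionary. The key fields follow the description above and the file names are the hd128 code objects listed in this section, but the encodings used for `mask` and `mode`, and the names `FmhaKey` / `select_code_object`, are assumptions made for illustration.

```python
from typing import NamedTuple


class FmhaKey(NamedTuple):
    dtype: str
    hdim_q: int
    hdim_v: int
    mask: int       # assumed encoding: 0 = no mask
    mode: str       # assumed encoding: "batch"
    bf16_cvt: int   # 0 = rtne, 1 = rtna, 2 = rtz


# One entry per rounding mode for the hd128, no-mask, batch-mode variant.
CODE_OBJECTS = {
    FmhaKey("bf16", 128, 128, 0, "batch", 0): "fwd_hd128_bf16_rtne.co",
    FmhaKey("bf16", 128, 128, 0, "batch", 1): "fwd_hd128_bf16_rtna.co",
    FmhaKey("bf16", 128, 128, 0, "batch", 2): "fwd_hd128_bf16_rtz.co",
}


def select_code_object(key: FmhaKey) -> str:
    """Return the code-object file name for an exact key match."""
    try:
        return CODE_OBJECTS[key]
    except KeyError:
        raise LookupError(f"no FMHA v3 code object for {key}") from None


old_default = FmhaKey("bf16", 128, 128, 0, "batch", 1)
new_default = old_default._replace(bf16_cvt=2)
print(select_code_object(old_default), "->", select_code_object(new_default))
```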

Code-object footprint comparison (MI308, hd128, no mask, batch mode)

| Rounding mode | Code object | Size (bytes) |
|---------------|-------------|--------------|
| rtne (0) | fwd_hd128_bf16_rtne.co | 28 720 |
| rtna (1, old default) | fwd_hd128_bf16_rtna.co | 27 120 |
| rtz (2, new default) | fwd_hd128_bf16_rtz.co | 23 272 |

Safety on gfx950 (MI350)

csrc/cpp_itfs/mha_bwd.cu contains a clamp:

// rtna & rtz are deprecated in gfx950
if (get_gfx() == "gfx950" && how_v3_bf16_cvt != 0)
    how_v3_bf16_cvt = 0;  // force back to rtne

The default flip therefore only affects gfx942 targets; gfx950 behaviour is unchanged.

Numerical impact

rtz (truncation) has a larger worst-case single-op error than rtna (round-to-nearest): up to 1 ulp instead of 0.5 ulp, and always toward zero. The difference is still within the 7-bit bf16 mantissa noise floor. End-to-end GSM8K accuracy was checked after the change (see Accuracy) and stayed within the ±0.66% stderr band.
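
The three rounding modes can be emulated on the float32 bit pattern to see how far apart they are. This is a pure-Python sketch for finite, normal inputs only (NaN and infinity handling is omitted), and it models the arithmetic, not the GPU instruction sequence. The function names are introduced here for illustration.

```python
import struct


def _f32_bits(x: float) -> int:
    """Bit pattern of x as an IEEE-754 float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def _bits_f32(bits: int) -> float:
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def bf16_rtz(x: float) -> float:
    """Round toward zero: drop the low 16 bits."""
    return _bits_f32(_f32_bits(x) & 0xFFFF0000)


def bf16_rtna(x: float) -> float:
    """Round to nearest, ties away from zero: add half an ulp, then truncate."""
    return _bits_f32((_f32_bits(x) + 0x8000) & 0xFFFF0000)


def bf16_rtne(x: float) -> float:
    """Round to nearest, ties to even."""
    bits = _f32_bits(x)
    lsb = (bits >> 16) & 1
    return _bits_f32((bits + 0x7FFF + lsb) & 0xFFFF0000)


# 4096 evenly spaced values in [1, 2), all exactly representable in float32.
samples = [1.0 + i / 4096 for i in range(4096)]


def max_rel_err(convert) -> float:
    return max(abs(convert(x) - x) / x for x in samples)


for name, fn in (("rtz", bf16_rtz), ("rtna", bf16_rtna), ("rtne", bf16_rtne)):
    print(f"{name}: max relative error = {max_rel_err(fn):.6f}")
```

On this sample the worst-case relative error of rtz stays below 2^-7 and that of the two round-to-nearest modes stays at or below 2^-8, which is the one-ulp versus half-ulp gap.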

7. End-to-End Results

Three runs from the same image on the same host: baseline W4A8 (no EAGLE3, no opt), EAGLE3 only, EAGLE3 + opt patches. Workload: 160 random prompts, concurrency 40, input 10240, output 512.

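| Metric | Baseline (W4A8) | EAGLE3 | EAGLE3 + opt | Δ (EAGLE3 → opt) |
|--------|-----------------|--------|--------------|------------------|
| Duration (s) | 288.20 | 269.20 | 253.27 | −5.9% |
| Total throughput (tok/s) | 5 969 | 6 391 | 6 793 | +6.3% |
| Output throughput (tok/s) | 284.25 | 304.31 | 323.45 | +6.3% |
| TPOT median (ms) | 93.00 | 90.16 | 74.66 | −17.2% |
| TPOT mean (ms) | 93.29 | 92.38 | 75.72 | −18.0% |
| TPOT P99 (ms) | 135.76 | 187.98 | 128.15 | −31.8% |
| ITL median (ms) | 50.81 | 34.23 | 29.05 | −15.1% |
| E2E median (s) | 71.79 | 65.83 | 62.28 | −5.4% |
| TTFT median (s) | 24.18 | 18.65 | 23.96 | +28.5% |
| Accept length | | 3.93 | 3.93 | 0 |
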
A note on TTFT. The EAGLE3 + opt run shows higher TTFT than EAGLE3 alone. TTFT is dominated by prefill work; the patches target decode kernels (Stage1/2 MoE, FMHA decode, MLA verify-step GEMM). Prefill GEMM shapes (large M) are outside the bf16 GEMM tune's coverage and tile_k=256 regresses there. Under constant concurrency, a faster decode accumulates more queued requests which raises prefill queue depth, inflating TTFT. The 160-prompt concurrency-40 workload is decode-bound, so TPOT and throughput improve despite TTFT moving in the opposite direction.

Cumulative TPOT waterfall

Per-patch deltas are aggregated from the 0418 session (same workload, different host). Absolute TPOT differs between host sessions (80.96 ms on that run vs 90.16 ms on this one, roughly 10%) due to cluster variability; percentage contributions are consistent.

  • Baseline EAGLE3 (this session): 90.16 ms
  • + Stage1 sched off + tile_k auto: ≈ 90 ms (throughput win)
  • + bf16 GEMM tune: ≈ 77 ms
  • + FMHA rtz: ≈ 76 ms
  • + Stage2 tile_k=256: 74.66 ms

Per-patch figures are measured as serial additions on the 0418 session. On the 0420 validation host the patches were applied together, so only the cumulative −17.2% is directly observed end-to-end on this image.

8. Accuracy Preservation

GSM8K (10-shot, concurrent 256) was run at each stage to verify the kernel changes do not degrade generation quality. The ±1 stderr intervals (±0.66–0.70%) of all three configurations overlap.

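| Configuration | Strict-match | stderr | Flexible-extract | stderr |
|---------------|--------------|--------|------------------|--------|
| Baseline (W4A8, no EAGLE3) | 0.9310 | ±0.0070 | 0.9318 | ±0.0069 |
| EAGLE3 | 0.9393 | ±0.0066 | 0.9401 | ±0.0065 |
| EAGLE3 + opt | 0.9386 | ±0.0066 | 0.9378 | ±0.0067 |
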
Numerical claim. The two changes with potential numerical impact are (a) FMHA rtna→rtz and (b) the AITER ASM GEMM kernels replacing rocBLAS on specific shapes. Observed strict-match moved by 0.0007 (EAGLE3 → opt), well inside the 0.0066 stderr of either run. Accept length is unchanged at 3.93, indicating EAGLE3's speculative verifier sees no additional rejection from the numerical shift.
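
One way to make "well inside the stderr" concrete is to express the accuracy difference in units of the combined standard error. The sketch below treats the two runs as independent, which is a simplifying assumption (both runs answer the same GSM8K questions); `z_score` is a hypothetical helper and the inputs are the strict-match figures from the table above.

```python
import math


def z_score(acc_a: float, se_a: float, acc_b: float, se_b: float) -> float:
    """Difference of two accuracies in units of their combined standard error."""
    return (acc_a - acc_b) / math.sqrt(se_a ** 2 + se_b ** 2)


# (strict-match accuracy, stderr) as reported in the table above
eagle3 = (0.9393, 0.0066)
eagle3_opt = (0.9386, 0.0066)

z = z_score(*eagle3_opt, *eagle3)
print(f"EAGLE3 -> opt: z = {z:+.3f}")
```

A |z| far below 1 means the observed shift is much smaller than the run-to-run uncertainty of the benchmark itself.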

9. Reproduction

All numbers in this report are reproducible from the published image. The container entrypoint scripts are archived under progress/kimi-k2.5/container-validation-0420/.

Quick start

# 1. Pull and create container
docker run -d --name mycontainer \
  --network host --cap-add=CAP_IPC_LOCK --cap-add=CAP_SYS_NICE \
  --security-opt seccomp=unconfined --shm-size=128g \
  --device /dev/kfd --device /dev/dri \
  -v /mnt/md0/models:/mnt/md0/models -w /sgl-workspace \
  clementlincf/amdafde:v0.5.10rc0-rocm720-mi30x-kimi-k2.5-opt-20260420 \
  sleep infinity

# 2. Launch optimized server (env vars baked into script)
docker exec -d mycontainer bash -c "cd /opt/scripts && ./launch_eagle3_opt.sh"

# 3. Wait ~6 min for ready, then run serving bench
docker exec mycontainer /opt/scripts/bench_client.sh /tmp/bench.log

# 4. (Optional) GSM8K accuracy check
docker exec mycontainer /opt/scripts/gsm8k_eval.sh /tmp/gsm8k.log

launch_eagle3_opt.sh (key env vars)

# Source-level patches baked into /opt/aiter and /opt/FlyDSL HEADs.
# Runtime knobs below activate the code paths those patches added.
export FLYDSL_MOE_STAGE1_SCHED=0       # Patch 1: disable stage1 hand-sched hints
export AITER_FLYDSL_STAGE2_TILE_K=256  # Patch 3: stage2 single-step K-loop

# Patches 2, 4, 5 auto-activate (block_m check, cu_num=80 CSV lookup, new default).

Expected numbers (MI308X TP=8, conc=40, in=10240, out=512, 160 prompts)

  • TPOT median ≈ 74.66 ms
  • Total throughput ≈ 6 793 tok/s
  • ITL median ≈ 29.05 ms
  • Accept length ≈ 3.93
  • GSM8K strict-match ≈ 0.939

File locations

| Component | Path (in container) |
|-----------|---------------------|
| Launch scripts | /opt/scripts/launch_{baseline,eagle3,eagle3_opt}.sh |
| Bench client | /opt/scripts/bench_client.sh |
| AITER MoE dispatch | /opt/aiter/aiter/fused_moe.py |
| AITER FMHA Python API | /opt/aiter/aiter/ops/mha.py |
| AITER bf16 GEMM CSV | /opt/aiter/aiter/configs/model_configs/kimik2_bf16_tuned_gemm.csv |
| FlyDSL MoE kernel | /opt/FlyDSL/kernels/moe_gemm_2stage.py |
| FMHA ASM code objects | /opt/aiter/hsa/gfx942/fmha_v3_{fwd,bwd}/MI308/ |