Conversation

CC-Yeh (Contributor) commented Jan 20, 2026

Proposed changes

Add Metal quantized SDPA vector kernels based on #1515

With M4, L = 32768, H = 32, D = 128, Lq = 1:

| Precision | SDPA (ms) | Quant SDPA (ms) | Ops-Based (ms) | Quant Ops-Based (ms) |
| --- | --- | --- | --- | --- |
| mxfp4 | 98.63080 | 15.32626 | 43.72120 | 24.71464 |
| mxfp8 | 97.37316 | 18.71779 | 42.89932 | 46.47875 |

TODO:

What improved performance:

  • Removed per-thread storage of k and v to reduce register pressure (the previous version was stalling on synchronization).
  • Fused the computation with dequantization (a rough sketch of the idea follows this list).
  • Tuned the read size (`uint16_t`/`uint32_t`) for the quantized k/v.
  • Manual unrolling performed better than the clang loop optimizer.
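
To make the fusion point concrete, here is a minimal Python/NumPy sketch of the idea only, not the Metal kernel: each group of a quantized K row is decoded and consumed by the dot product right away, so a full floating-point row is never materialized. The group size, scale layout, and already-unpacked integer codes below are simplified assumptions.

```python
import numpy as np

def fused_quant_scores(q, k_codes, k_scales, group_size=32):
    # q: (D,) fp32 query for one head.
    # k_codes: (L, D) integer codes (assumed already unpacked for clarity).
    # k_scales: (L, D // group_size) per-group scales.
    L, D = k_codes.shape
    scores = np.empty(L, dtype=np.float32)
    for i in range(L):
        acc = np.float32(0)
        for g in range(D // group_size):
            sl = slice(g * group_size, (g + 1) * group_size)
            # Dequantize one group and use it immediately; nothing larger
            # than a group of K ever exists in fp, which is what saves
            # registers/bandwidth compared to dequantizing K up front.
            k_deq = k_codes[i, sl].astype(np.float32) * k_scales[i, g]
            acc += np.dot(q[sl], k_deq)
        scores[i] = acc
    return scores
```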

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

awni (Member) commented Jan 21, 2026

The numbers seem quite good.. a little too good to be true 😅

What's the difference between SDPA and Attention in the benchmark? Also what's the query sequence length used for the benchmark?

CC-Yeh (Contributor, Author) commented Jan 21, 2026

> The numbers seem quite good.. a little too good to be true 😅

Totally agree, must be missing something 🤔

> What's the difference between SDPA and Attention in the benchmark? Also what's the query sequence length used for the benchmark?

Attention is a simple reference implementation built from matmul + softmax + matmul (Maybe too naive?).
SDPA uses mx.fast.scaled_dot_product_attention, which hits the sdpa_vector_2pass kernels when Lq ≤ 8 (this case).

The query sequence length here is 1 (q.shape = (1, 32, 1, 128)), so this benchmark is measuring the single-token decode case, where one new token attends to a long KV cache (L = 32768).
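
For reference, a minimal sketch of the two paths being compared, assuming the shapes from this thread; the timing loop and the quantized variants added in this PR are left out.

```python
import math
import mlx.core as mx

B, H, Lq, L, D = 1, 32, 1, 32768, 128
scale = 1.0 / math.sqrt(D)

q = mx.random.normal((B, H, Lq, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))

def attention_reference(q, k, v, scale):
    # Ops-based baseline: matmul + softmax + matmul.
    scores = (q * scale) @ mx.swapaxes(k, 2, 3)
    return mx.softmax(scores, axis=-1) @ v

out_ref = attention_reference(q, k, v, scale)
# Fused path; with Lq = 1 this dispatches to the sdpa_vector kernels.
out_sdpa = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
mx.eval(out_ref, out_sdpa)
```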

CC-Yeh (Contributor, Author) commented Jan 21, 2026

@awni
Fixed some bugs in the 8-bit dequantization and in the benchmark (unnecessary dequantization steps).
Now the numbers make more sense 😃

awni (Member) commented Jan 21, 2026

So if I’m understanding correctly the fused implementation is slower in the quantized case than the unfused ops-based one?

CC-Yeh (Contributor, Author) commented Jan 21, 2026

Fused SDPA is faster: MXFP4 15.33 ms vs 24.71 ms, and MXFP8 26.09 ms vs 46.48 ms to decode a single query.

awni (Member) commented Jan 21, 2026

Very nice!!

Comment on lines +875 to +878:

```cpp
if (qmode == QuantizationMode::Nvfp4) {
  throw std::invalid_argument(
      "[quantized_scaled_dot_product_attention] Mode 'nvfp4' is not supported for fast attention.");
}
```

awni (Member): Why not nvfp4?

CC-Yeh (Contributor, Author): It’s on the way! I just wanted to make sure the PR structure was okay first.

Comment on lines +871 to +874:

```cpp
if (qmode == QuantizationMode::Affine) {
  throw std::invalid_argument(
      "[quantized_scaled_dot_product_attention] Only fp quantization modes are supported.");
}
```

awni (Member): Why not affine?
