Large Language Model Text Generation Inference

Fix mask passed to flashinfer (#3324)

Custom masks are padded to the shape `[batch_size, max_len, max_len]`.
However, flashinfer expects an unpadded mask of the shape
`[sum(q_len[i] * k_len[i] for i in range(batch_size))]`.

This change unpads the custom mask (currently only used by Gemma 3)
to this shape (assuming q_len == k_len, since we only use the custom
mask during prefill).
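The unpadding described above can be sketched as follows. This is an illustrative NumPy version, not the actual TGI helper: the function name `unpad_custom_mask` and its signature are assumptions, and the real code operates on tensors on the GPU. Each sequence's valid `q_len x k_len` block is sliced out of the padded `[batch_size, max_len, max_len]` mask and flattened, then all blocks are concatenated:

```python
import numpy as np

def unpad_custom_mask(mask, seq_lens):
    """Convert a padded mask of shape [batch_size, max_len, max_len] into
    the flattened layout flashinfer expects:
    [sum(q_len[i] * k_len[i] for i in range(batch_size))].

    Assumes q_len == k_len per sequence, matching the prefill-only usage
    described in the commit message.
    """
    pieces = [mask[i, :l, :l].reshape(-1) for i, l in enumerate(seq_lens)]
    return np.concatenate(pieces)

# Example: batch of two sequences with lengths 2 and 3, padded to max_len = 3.
max_len = 3
seq_lens = [2, 3]
mask = np.zeros((2, max_len, max_len), dtype=bool)
mask[0, :2, :2] = True  # valid region for sequence 0
mask[1, :3, :3] = True  # valid region for sequence 1

flat = unpad_custom_mask(mask, seq_lens)
print(flat.shape)  # (13,) == 2*2 + 3*3
```

Only the per-sequence valid regions survive, so the padding entries never reach flashinfer.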
Daniël de Kok committed
c6071749db61208dc22f658689f37f4eb803bde6
Parent: 4f067c2
Committed by GitHub <noreply@github.com> on 9/8/2025, 5:47:03 PM