Large Language Model Text Generation Inference
Fix mask passed to flashinfer (#3324)
Custom masks are padded to the shape `[batch_size, max_len, max_len]`. However, flashinfer expects an unpadded mask of the shape `[sum(q_len[i] * k_len[i] for i in range(batch_size)]`. This change unpads the custom mask (currently only used by Gemma 3) to this shape (assuming q_len == k_len, since we only use the custom mask during prefill).
D
Daniël de Kok committed
c6071749db61208dc22f658689f37f4eb803bde6
Parent: 4f067c2
Committed by GitHub <noreply@github.com>
on 9/8/2025, 5:47:03 PM