@OhadRubin I'm not familiar with multi_queries_paged_attention, but the flash_attention custom kernel from torch_xla has default block sizes set. You could try adjusting those settings to use more TPU memory.
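For reference, here is a rough sketch of how those block sizes can be overridden when calling the JAX Pallas flash attention kernel directly (which, as far as I understand, is what the torch_xla custom kernel wraps). The shapes and numbers below are made up for illustration, not tuned values:

```python
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.flash_attention import (
    BlockSizes,
    flash_attention,
)

# Toy shapes: (batch, num_heads, seq_len, head_dim).
q = jnp.zeros((1, 8, 4096, 128), dtype=jnp.bfloat16)
k = jnp.zeros((1, 8, 4096, 128), dtype=jnp.bfloat16)
v = jnp.zeros((1, 8, 4096, 128), dtype=jnp.bfloat16)

# Forward-pass block sizes only (enough for inference). Smaller blocks
# shrink the per-step on-chip working set at some throughput cost;
# larger blocks do the opposite.
block_sizes = BlockSizes(
    block_q=256,
    block_k_major=256,
    block_k=256,
    block_b=1,
)

out = flash_attention(q, k, v, causal=True, block_sizes=block_sizes)
```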
I'm running Llama 3 70B with vLLM on a TPU v4-16. With the flash attention kernel I can go up to a 16k sequence length, but with multi_queries_paged_attention at sequence length 256, it seems the page table is taking up too much SMEM.
@vanbasten23 @WoosukKwon any idea how to address this? (I'm familiar with Pallas programming.)
Maybe something along the lines of this? https://github.com/vllm-project/vllm/blob/02222a0256f60319f5bcd56d1d036a943d6334f8/vllm/attention/backends/pallas.py#L260
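In the JAX reference paged attention kernel (which, as far as I can tell, is what these torch_xla kernels are based on), the corresponding knob is the pages_per_compute_block argument, and the page table is the page_indices array that the kernel stages in SMEM. A minimal decode-time sketch with made-up shapes, just to show where those two things enter:

```python
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention

batch_size, num_heads, num_kv_heads, head_dim = 8, 64, 8, 128
page_size, max_model_len, total_num_pages = 16, 256, 1024
pages_per_sequence = max_model_len // page_size  # 16

# Decode-time shapes used by the single-query reference kernel.
q = jnp.zeros((batch_size, num_heads, head_dim), dtype=jnp.bfloat16)
k_pages = jnp.zeros((num_kv_heads, total_num_pages, page_size, head_dim),
                    dtype=jnp.bfloat16)
v_pages = jnp.zeros_like(k_pages)
lengths = jnp.full((batch_size,), max_model_len, dtype=jnp.int32)

# The page table: one row of physical page ids per sequence. Its
# (batch_size, pages_per_sequence) footprint is what has to fit in SMEM.
page_indices = jnp.zeros((batch_size, pages_per_sequence), dtype=jnp.int32)

# pages_per_compute_block sets how many KV pages each grid step reads;
# it must divide pages_per_sequence. Illustrative value, not tuned.
out = paged_attention(
    q,
    k_pages,
    v_pages,
    lengths,
    page_indices,
    pages_per_compute_block=4,
)
```

If multi_queries_paged_attention exposes (or could expose) a similar parameter, that might be the place to trade SMEM usage against grid overhead, along the lines of what the linked vLLM code does for the single-query kernel.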