Precision errors compared with flash_attn_2_cuda.varlen_fwd #335

Open
Amanda-Barbara opened this issue Jun 24, 2024 · 3 comments


Amanda-Barbara commented Jun 24, 2024

I get precision errors compared with flash_attn_2_cuda.varlen_fwd when I use the flashinfer.single_prefill_with_kv_cache function to run the cohere_plus model. Below is the code I used:
fi_fwd_out = flashinfer.single_prefill_with_kv_cache(
    q.contiguous(), k.contiguous(), v.contiguous(),
    causal=True, sm_scale=softmax_scale, allow_fp16_qk_reduction=False,
)
fa2_fwd_out = flash_attn_2_cuda.varlen_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens,     # cu_seqlens_q
    cu_seqlens,     # cu_seqlens_k
    max_s,          # max_seqlen_q
    max_s,          # max_seqlen_k
    0.0,            # dropout_p
    softmax_scale,
    False,          # zero_tensors
    True,           # causal
    False,          # return_softmax
    None,           # generator
)
torch.allclose(fi_fwd_out, fa2_fwd_out, rtol=1e-3, atol=1e-3)
It is worth noting that the outputs of the first half of the layers match, but those of the second half differ.
Could you provide official example code for a precision comparison with flash_attn_2_cuda.varlen_fwd? Thanks!
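
For reference, here is a minimal self-contained version of the comparison I am attempting for a single request. It uses the public flash_attn_varlen_func wrapper instead of the raw flash_attn_2_cuda binding, and the shapes, dtype, and seed are illustrative placeholders rather than values from my actual model:

```python
import torch
import flashinfer
from flash_attn import flash_attn_varlen_func

seqlen, num_heads, head_dim = 1024, 8, 128   # illustrative sizes
softmax_scale = head_dim ** -0.5

torch.manual_seed(0)
q = torch.randn(seqlen, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(seqlen, num_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(seqlen, num_heads, head_dim, dtype=torch.float16, device="cuda")

# FlashInfer single-request prefill, NHD layout: (seqlen, num_heads, head_dim).
fi_out = flashinfer.single_prefill_with_kv_cache(
    q, k, v, causal=True, sm_scale=softmax_scale
)

# FlashAttention-2 varlen path with a single sequence in the batch.
cu_seqlens = torch.tensor([0, seqlen], dtype=torch.int32, device="cuda")
fa2_out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seqlen, max_seqlen_k=seqlen,
    softmax_scale=softmax_scale, causal=True,
)

print(torch.allclose(fi_out, fa2_out, rtol=1e-3, atol=1e-3))
print((fi_out - fa2_out).abs().max().item())
```

With a single sequence, both kernels see identical inputs, so I would expect any remaining difference to be within fp16 accumulation tolerance.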

@Amanda-Barbara
Author

It seems the accuracy mismatch comes from the fact that flashinfer.single_prefill_with_kv_cache doesn't accept cu_seqlens_q and cu_seqlens_k. If I want to use flashinfer's prefill functionality, how do I call it the way I call flash_attn_2_cuda.varlen_fwd?

Collaborator

yzh119 commented Jun 24, 2024

single_prefill_with_kv_cache is designed for a single request only (no batching, no variable-length sequences).

For batch prefill with variable-length sequences, you have to use https://docs.flashinfer.ai/api/python/prefill.html#flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper, where the queries and key/value cache are organized as ragged tensors (see our layout documentation: https://docs.flashinfer.ai/tutorials/kv_layout.html).

If you use a paged KV cache, you should use https://docs.flashinfer.ai/api/python/prefill.html#flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper instead.
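
A rough sketch of the ragged path: the qo_indptr/kv_indptr arrays play the same role as cu_seqlens_q/cu_seqlens_k in varlen_fwd. Method names below follow the current docs (recent releases expose plan()/run(), while some earlier releases used begin_forward()/forward()), and the head counts and lengths are placeholders:

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 8, 8, 128   # placeholder sizes
seqlens = torch.tensor([512, 1024, 256], dtype=torch.int32, device="cuda")

# indptr arrays are the ragged-tensor equivalent of cu_seqlens_q / cu_seqlens_k.
zero = torch.zeros(1, dtype=torch.int32, device="cuda")
qo_indptr = torch.cat([zero, torch.cumsum(seqlens, 0, dtype=torch.int32)])
kv_indptr = qo_indptr.clone()

# Queries and keys/values for all requests are packed along the first dimension.
total = int(qo_indptr[-1])
q = torch.randn(total, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(total, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(total, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, kv_layout="NHD")

wrapper.plan(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, causal=True)
out = wrapper.run(q, k, v)   # packed output with the same layout as q
```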

@Amanda-Barbara
Author

@yzh119 Thanks very much, I will try it.
