
[CPU] SparseAttention op #21110

Draft · tianleiwu wants to merge 16 commits into main
Conversation

tianleiwu
Contributor

@tianleiwu tianleiwu commented Jun 20, 2024

Description

Add a CPU implementation of the SparseAttention operator. It depends on the CPU Flash Attention work in #20805.

This work is still in progress:

  • Refactoring GQAAttentionBase
  • Add SparseAttention implementation
  • Add test cases
  • Test performance
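For context on what the operator computes: in block-sparse attention, score blocks excluded by a block mask are set to negative infinity before the softmax, so the masked key/value blocks contribute zero weight. The single-head NumPy sketch below is a toy illustration of that idea only, not this PR's implementation; all names (block_sparse_attention, block_mask, block_size) are illustrative, and it assumes every query block keeps at least one key block (otherwise a softmax row would be all -inf).

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def block_sparse_attention(q, k, v, block_mask, block_size):
    # q, k, v: (seq_len, head_size); block_mask: (num_blocks, num_blocks) bool.
    # Assumes seq_len is a multiple of block_size and each query block
    # attends to at least one key block.
    seq_len, head_size = q.shape
    scores = q @ k.T / np.sqrt(head_size)
    num_blocks = seq_len // block_size
    for qb in range(num_blocks):
        for kb in range(num_blocks):
            if not block_mask[qb, kb]:
                # Masked-out block: -inf scores become zero softmax weight.
                scores[qb * block_size:(qb + 1) * block_size,
                       kb * block_size:(kb + 1) * block_size] = -np.inf
    return softmax(scores) @ v
```

With an all-True mask this reduces exactly to dense attention; a lower-triangular block mask gives block-causal attention.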

Motivation and Context

@tianleiwu tianleiwu requested a review from a team as a code owner June 20, 2024 03:43
@tianleiwu tianleiwu marked this pull request as draft June 20, 2024 03:43

@github-advanced-security github-advanced-security bot left a comment


PREfast found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

@@ -0,0 +1,321 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning: Code scanning / lintrunner, CLANGFORMAT/format
See https://clang.llvm.org/docs/ClangFormat.html. Run lintrunner -a to apply this patch.
@@ -0,0 +1,210 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning: Code scanning / lintrunner, CLANGFORMAT/format
See https://clang.llvm.org/docs/ClangFormat.html. Run lintrunner -a to apply this patch.
sequence_length = sequence_lengths[i % len(sequence_lengths)]
num_heads = heads[i % len(heads)]
head_size = head_sizes[i % len(head_sizes)]
format = formats[i % len(formats)]

Check warning: Code scanning / CodeQL, Variable defined multiple times
This assignment to 'format' is unnecessary as it is redefined before this value is used.
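What CodeQL is flagging here is a dead store: a value is assigned to format and then overwritten before it is ever read. A minimal illustration of the pattern and the usual fix (the names below are hypothetical, not the PR's actual code):

```python
formats = ["BSNH", "BNSH"]

def pick_format(i):
    fmt = formats[i % len(formats)]  # dead store: overwritten below before use
    fmt = "BSNH"                     # this redefinition triggers the warning
    return fmt

def pick_format_fixed(i):
    # Fix: keep only the assignment whose value is actually used.
    return formats[i % len(formats)]
```

The fix is to delete the first assignment (or actually use its value) so each variable is written once per read.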
def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, debug=False):
    if provider == "CUDAExecutionProvider" and not has_cuda_support():
        return
        yield

Check warning: Code scanning / CodeQL, Unreachable code
This statement is unreachable.
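Assuming the function yields test cases further down (the snippet is truncated), a bare return in a generator function already ends iteration, so the yield placed after it never executes. A minimal sketch of that behavior, with illustrative names:

```python
def cases(enabled):
    # In a generator, a bare `return` ends iteration on its own; a `yield`
    # after it would be unreachable, which is what CodeQL reports.
    if not enabled:
        return
    yield "case_1"
    yield "case_2"
```

If a function contains no yield at all and must still be a generator, the unreachable-yield idiom is sometimes used deliberately; here the later yields make it unnecessary.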
Labels: None yet
Projects: None yet
Development: successfully merging this pull request may close these issues (none yet)
1 participant