
Torch deepseek v2 #1621

Merged: 32 commits into InternLM:main on Jun 24, 2024
Conversation

@grimoire (Collaborator) commented May 20, 2024

  • MLA implementation hinted by https://kexue.fm/archives/10091. k and v share the same cache blocks (see the sketch after this list).
  • q, k are loaded in two blocks (Triton requires block sizes that are powers of 2).
  • q_a_proj, kv_a_proj_with_mqa in the attention layer and gate in the MoE layer are not distributed, so fewer NCCL ops are required at the cost of extra memory.
  • Each GPU is responsible for 20 experts (8 A100).
  • block_size=32 gives better performance.
  • A large amount of runtime memory is required; a big cache_max_entry_count or max_prefill_token_num might lead to OOM.
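
A minimal sketch (not the actual lmdeploy kernels) of why k and v can share cache blocks under MLA: only the compressed kv latent plus the rope part of k is cached per token, and both k and v are recovered from that latent at attention time. The dimensions and the names kv_lora_rank, rope_dim, w_uk, w_uv below are illustrative assumptions.

import torch

kv_lora_rank, rope_dim = 512, 64          # assumed MLA dimensions
num_heads, head_dim = 16, 128

def write_cache(cache_block, slot, kv_latent, k_rope):
    # one cache row per token: [compressed kv latent | rope part of k]
    cache_block[slot, :kv_lora_rank] = kv_latent
    cache_block[slot, kv_lora_rank:] = k_rope

def read_kv(cache_block, slots, w_uk, w_uv):
    latent = cache_block[slots, :kv_lora_rank]
    k_rope = cache_block[slots, kv_lora_rank:]
    # k (nope part) and v are both re-projected from the same cached latent,
    # so a single block serves as the shared k/v cache
    k_nope = latent @ w_uk                 # (num_tokens, num_heads * head_dim)
    v = latent @ w_uv
    return k_nope, k_rope, v

cache = torch.zeros(32, kv_lora_rank + rope_dim)
write_cache(cache, 0, torch.randn(kv_lora_rank), torch.randn(rope_dim))
w_uk = torch.randn(kv_lora_rank, num_heads * head_dim)
w_uv = torch.randn(kv_lora_rank, num_heads * head_dim)
k_nope, k_rope, v = read_kv(cache, torch.tensor([0]), w_uk, w_uv)
print(k_nope.shape, k_rope.shape, v.shape)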

result of deepseek-v2-lite (WIP)

python3 benchmark/profile_throughput.py \
    ShareGPT_V3_unfiltered_cleaned_split.json \
    DeepSeek-V2-Lite-Chat \
    --backend pytorch \
    --cache-max-entry-count 0.8 \
    --num-prompts 3000 \
    --cache-block-seq-len 32 \
    --concurrency 256
--------------------------------------------------
concurrency: 256
elapsed_time: 311.155s

first token latency(s)(min, max, ave): 2.337, 19.513, 4.276
per-token latency(s) percentile(50, 75, 95, 99): [0.056, 0.112, 0.311, 0.408]

number of prompt tokens: 721790
number of completion tokens: 668262
token throughput (completion token): 2147.679 token/s
token throughput (prompt + completion token): 4467.387 token/s
RPS (request per second): 9.641 req/s
RPM (request per minute): 578.489 req/min
--------------------------------------------------

requirements:

@grimoire grimoire marked this pull request as draft May 20, 2024 13:24
@grimoire grimoire linked an issue May 20, 2024 that may be closed by this pull request
@grimoire grimoire marked this pull request as ready for review June 12, 2024 03:47
@grimoire grimoire changed the title [Draft] Torch deepseek v2 Torch deepseek v2 Jun 12, 2024
@zhyncs (Contributor) commented Jun 19, 2024

Hi @lvhan028, after this PR is ready and merged, will LMDeploy publish a new release? Thanks. @grimoire

@zhyncs (Contributor) commented Jun 19, 2024

result of deepseek-v2-lite

Hi @grimoire, are the current performance benchmark results as expected, and how large is the advantage compared to vLLM? Thanks.

https://github.com/deepseek-ai/DeepSeek-V2?tab=readme-ov-file#inference-with-vllm-recommended

@grimoire (Collaborator, Author)

@zhyncs The latest profile results (256 concurrency, 3000 prompts, block_size=32, with --cache-max-entry-count adjusted to prevent OOM):

  • deepseek v2: 3.627 req/s
  • deepseek v2 lite: 10.404 req/s

Apart from the fact that the default value cannot be used for block_size, the rest is relatively acceptable.

We have not benchmarked vLLM yet; 8 A100s are not always available (T T).

@RunningLeon (Collaborator)

@grimoire The unit test TestMBGMV.test_mbgmv failed; it may need a fix.

@zhyncs (Contributor) commented Jun 20, 2024

Hi @grimoire, I used your commit to run the workflow at https://github.com/zhyncs/lmdeploy/actions/runs/9584655537 and obtained the wheel https://github.com/zhyncs/dl/releases/tag/0620, but I encountered the error triton-lang/triton#4172. Do you have any ideas? Thanks!

@grimoire (Collaborator, Author)

Triton ships a prepackaged ptxas, which might not match your CUDA driver version. You can point to your own ptxas (/path/to/cuda/bin/ptxas) via the TRITON_PTXAS_PATH environment variable.

@zhyncs (Contributor) commented Jun 20, 2024

export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas

It works for me. Thanks and cheers. @grimoire

@zhyncs (Contributor) commented Jun 20, 2024

  • deepseek v2 lite: 10.404 req/s

Hi @grimoire, may I ask whether this uses a single A100 card or 8 cards? Thanks.

@zhyncs (Contributor) commented Jun 20, 2024

LMDeploy
https://github.com/zhyncs/dl/releases/tag/0620
https://github.com/zhyncs/lmdeploy/actions/runs/9584655537
grimoire@995e0ed

single A100

# server
python3 -m lmdeploy serve api_server DeepSeek-V2-Lite --backend pytorch --cache-block-seq-len 32

# client
# https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_serving.py
python3 benchmark_serving.py --backend lmdeploy --host 127.0.0.1 --port 23333 --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --model DeepSeek-V2-Lite --tokenizer DeepSeek-V2-Lite --num-prompts 1000 --request-rate 128

result

# ignore_eos false
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  154.05
Total input tokens:                      236142
Total generated tokens:                  148682
Request throughput (req/s):              6.49
Input token throughput (tok/s):          1532.88
Output token throughput (tok/s):         965.14
---------------Time to First Token----------------
Mean TTFT (ms):                          56583.14
Median TTFT (ms):                        55727.01
P99 TTFT (ms):                           113475.30
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          116.80
Median TPOT (ms):                        90.45
P99 TPOT (ms):                           475.46
---------------Inter-token Latency----------------
Mean ITL (ms):                           77.64
Median ITL (ms):                         58.83
P99 ITL (ms):                            430.49
==================================================

# ignore_eos true
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  181.48
Total input tokens:                      236142
Total generated tokens:                  215605
Request throughput (req/s):              5.51
Input token throughput (tok/s):          1301.17
Output token throughput (tok/s):         1188.01
---------------Time to First Token----------------
Mean TTFT (ms):                          65216.61
Median TTFT (ms):                        65241.68
P99 TTFT (ms):                           135946.04
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          96.65
Median TPOT (ms):                        80.46
P99 TPOT (ms):                           267.16
---------------Inter-token Latency----------------
Mean ITL (ms):                           72.56
Median ITL (ms):                         60.45
P99 ITL (ms):                            372.03
==================================================

@grimoire (Collaborator, Author)

Hi @grimoire May I ask if this uses a single A100 card or 8 cards? Thanks.

It is profiled on a single A100. The bottleneck of the lite model is on the host side; TP would make it worse.

@zhyncs (Contributor) commented Jun 20, 2024

  • block_size=32 would have better performance.

May we set the cache-block-seq-len to 32 by default when running DeepSeek V2 for inference? From the benchmark results, there is a significant performance gap.

# python3 benchmark_serving.py --backend lmdeploy --host 127.0.0.1 --port 23333 --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --model DeepSeek-V2-Lite --tokenizer DeepSeek-V2-Lite --num-prompts 1000 --request-rate 128
# cache-block-seq-len 32, ignore_eos true
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  181.48
Total input tokens:                      236142
Total generated tokens:                  215605
Request throughput (req/s):              5.51
Input token throughput (tok/s):          1301.17
Output token throughput (tok/s):         1188.01
---------------Time to First Token----------------
Mean TTFT (ms):                          65216.61
Median TTFT (ms):                        65241.68
P99 TTFT (ms):                           135946.04
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          96.65
Median TPOT (ms):                        80.46
P99 TPOT (ms):                           267.16
---------------Inter-token Latency----------------
Mean ITL (ms):                           72.56
Median ITL (ms):                         60.45
P99 ITL (ms):                            372.03
==================================================


# cache-block-seq-len default, ignore_eos true
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  384.18
Total input tokens:                      236142
Total generated tokens:                  215594
Request throughput (req/s):              2.60
Input token throughput (tok/s):          614.67
Output token throughput (tok/s):         561.18
---------------Time to First Token----------------
Mean TTFT (ms):                          155387.50
Median TTFT (ms):                        153036.64
P99 TTFT (ms):                           328194.63
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          196.80
Median TPOT (ms):                        181.42
P99 TPOT (ms):                           515.61
---------------Inter-token Latency----------------
Mean ITL (ms):                           163.64
Median ITL (ms):                         96.56
P99 ITL (ms):                            1542.84
==================================================

@RunningLeon (Collaborator) left a comment

LGTM

"""adjust block_size."""
# TODO: support kernel with both large head dim and large block size.
if model_config.k_head_dim >= 512 and cache_config.block_size > 32:
cache_config.block_size = 32
Contributor

Will this affect models other than DeepSeek v2?

Collaborator Author

Yes. The MHA kernel needs enough shared memory (smem) to cache the kv_cache block and the query block, so any model with such a large head_dim should be limited.
Among all the models the PyTorch engine currently supports, only DeepSeek-V2 with the MLA implementation meets that condition.
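
A rough back-of-the-envelope sketch (assumed tile shapes and dtype, not the actual kernel accounting) of why the block size is capped when head_dim >= 512: the attention kernel keeps a KV-cache tile and a query tile in shared memory, and an A100 SM exposes at most about 164 KiB of shared memory per thread block.

def smem_estimate_kib(block_size, k_head_dim, q_block=64, dtype_bytes=2):
    # hypothetical accounting: one KV tile plus one query tile, fp16
    kv_tile = block_size * k_head_dim * dtype_bytes
    q_tile = q_block * k_head_dim * dtype_bytes
    return (kv_tile + q_tile) / 1024

for bs in (64, 32):
    print(bs, smem_estimate_kib(bs, 576), "KiB")
# with block_size=64 the tiles alone take ~144 KiB, close to the ~164 KiB A100
# limit once softmax scratch and staging buffers are added; 32 leaves headroom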

@zhyncs (Contributor) commented Jun 21, 2024

LGTM

hold on plz. @grimoire @RunningLeon

python3 -m lmdeploy serve api_server /workdir/DeepSeek-V2-Lite-Chat --backend pytorch

# running multiple times gives different results when temperature is 0
python3 benchmark/profile_restful_api.py 127.0.0.1:23333 /workdir/DeepSeek-V2-Lite-Chat /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --model_name /workdir/DeepSeek-V2-Lite-Chat --num_prompts 1 --concurrency 1 --temperature 0

@grimoire (Collaborator, Author)

@zhyncs temperature=0 is an invalid value:
https://github.com/huggingface/transformers/blob/730a440734e1fb47c903c17e3231dac18e3e5fd6/src/transformers/generation/logits_process.py#L298

I set it to 1 if temperature <= 0 in the PyTorch engine:

if temperature <= 0:
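
A hypothetical standalone version of that guard (the function name and fallback value are illustrative, not the engine's exact code):

def sanitize_temperature(temperature: float, fallback: float = 1.0) -> float:
    # logits are divided by temperature, so it must be strictly positive;
    # fall back to 1.0 (no scaling) for non-positive values
    if temperature <= 0:
        return fallback
    return temperature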

@zhyncs (Contributor) commented Jun 21, 2024

@grimoire If temperature 0 is not supported, how can I get a deterministic answer?

@zhyncs (Contributor) commented Jun 21, 2024

TurboMind supports [0,2] for temperature

@grimoire (Collaborator, Author)

Just set top_k=1 or use a small enough temperature:
https://github.com/huggingface/transformers/blob/730a440734e1fb47c903c17e3231dac18e3e5fd6/src/transformers/generation/logits_process.py#L275-L279

Note that a small temperature might still lead to different results if two values in the logits are close.
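
For example, a minimal usage sketch of greedy-style sampling, assuming lmdeploy's pipeline API; the model path and generation values are illustrative:

from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig

# block_size=32 mirrors --cache-block-seq-len 32 from the benchmarks above
pipe = pipeline('DeepSeek-V2-Lite-Chat',
                backend_config=PytorchEngineConfig(block_size=32))
# top_k=1 picks the argmax token at each step, i.e. effectively greedy decoding
gen_cfg = GenerationConfig(top_k=1, temperature=1.0, max_new_tokens=128)
print(pipe(['Hello, who are you?'], gen_config=gen_cfg))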

@lvhan028 lvhan028 added the enhancement New feature or request label Jun 21, 2024
@@ -157,11 +179,16 @@ def __forward_hook(module, args, kwargs, output):
    target_args = args
    target_kwargs = kwargs
    target_output = output
    raise ExtractorFound()
Collaborator

why?

Collaborator Author

This tool is used to extract the input/output of a submodule (for debugging); the computation after that module is not needed, so the hook raises to stop the forward pass early.
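
A minimal standalone sketch of that pattern (illustrative names, not lmdeploy's actual extractor): register a forward hook on the target submodule, capture its args and output, then raise a sentinel exception so the rest of the model's forward pass is skipped.

import torch
import torch.nn as nn

class ExtractorFound(Exception):
    """Sentinel raised once the target submodule has run."""

def extract_submodule_io(model, submodule, *inputs):
    captured = {}

    def hook(module, args, output):
        captured['args'] = args
        captured['output'] = output
        raise ExtractorFound()          # stop the forward pass here

    handle = submodule.register_forward_hook(hook)
    try:
        model(*inputs)
    except ExtractorFound:
        pass                            # expected: we only wanted this submodule's I/O
    finally:
        handle.remove()
    return captured

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
io = extract_submodule_io(model, model[0], torch.randn(2, 8))
print(io['output'].shape)               # torch.Size([2, 16])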

from lmdeploy.pytorch.engine.model_agent import StepContext

if model_config is None:
Collaborator

Should we assert model_config?

@lvhan028 lvhan028 merged commit da439df into InternLM:main Jun 24, 2024
4 of 5 checks passed
Labels: enhancement (New feature or request)
Successfully merging this pull request may close these issues: [Feature] Support DeepSeek-V2 Model
4 participants