Device dispatcher #1775

Merged: 20 commits merged into InternLM:main on Jun 21, 2024

Conversation

@grimoire (Collaborator) commented Jun 13, 2024

  • Add kernel in pytorch/kernels/<device name>
  • Update StepContext in pytorch/engine/devices/<device name>
  • Add rewrite module and register in pytorch/models/module_map.py XXX_MODULE_MAP (see the sketch after this list)
  • Enable device in PytorchEngineConfig
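For illustration, here is a rough sketch of what the last two steps could look like for a hypothetical `ascend` backend. The dict name follows the `XXX_MODULE_MAP` convention above, but the module paths and the config field are assumptions for illustration, not the actual lmdeploy source:

```python
# pytorch/models/module_map.py -- hypothetical entries for an "ascend" device.
# Keys map original model modules to device-specific rewrite modules.
ASCEND_MODULE_MAP = {
    'transformers.models.llama.modeling_llama.LlamaAttention':
    'lmdeploy.pytorch.models.llama.LlamaAttentionAscend',  # illustrative path
}

# User side: pick the device when building the engine config
# (assuming PytorchEngineConfig exposes a device selector for this).
from lmdeploy import PytorchEngineConfig

engine_config = PytorchEngineConfig(tp=1, device_type='ascend')
```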

requirement

@grimoire grimoire marked this pull request as ready for review June 19, 2024 09:18

class CUDADeviceUtils(BaseDeviceUtils):

    device = 'cuda'
Collaborator

The class name indicates CUDA. Is "device = 'cuda'" necessary?

Collaborator Author

BaseDeviceUtils._sub_classes[sub_cls.device] = sub_cls

The class is registered automatically under this key when it is imported.
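A minimal sketch of the registration pattern being described; the `__init_subclass__` hook is an assumption about how `BaseDeviceUtils` wires this up, and only the `_sub_classes[...] = sub_cls` assignment comes from the diff:

```python
class BaseDeviceUtils:
    _sub_classes = {}
    device = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Each subclass registers itself under its `device` attribute, so
        # `device = 'cuda'` serves as the lookup key, not redundant metadata.
        BaseDeviceUtils._sub_classes[cls.device] = cls


class CUDADeviceUtils(BaseDeviceUtils):
    device = 'cuda'


assert BaseDeviceUtils._sub_classes['cuda'] is CUDADeviceUtils
```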

num_ignore_eos=num_ignore_eos,
output_que=out_que,
)
await self._async_step_background(
Collaborator

"with torch.cuda.stream(self.stream):" is removed. Will it bring side effect?

Collaborator Author

It has been moved out to:

with device_manager.context(self.device_context), torch.cuda.stream(

So the context only needs to be entered once.
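As a toy, runnable sketch of the hoisting described here (the `device_context` helper stands in for `device_manager.context`, and none of the names are the actual lmdeploy code):

```python
import contextlib

import torch


@contextlib.contextmanager
def device_context(name):
    # Stand-in for device_manager.context(...); the real one selects the
    # active device context for kernel dispatch.
    yield name


def serve_loop(num_steps):
    # The device context and the CUDA stream are entered once around the
    # whole loop, so individual steps no longer need their own
    # `with torch.cuda.stream(self.stream):` block.
    stream_ctx = (torch.cuda.stream(torch.cuda.Stream())
                  if torch.cuda.is_available() else contextlib.nullcontext())
    with device_context('cuda'), stream_ctx:
        for _ in range(num_steps):
            pass  # one engine step per iteration


serve_loop(3)
```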

@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
Collaborator

What is the "default" package used for?

Collaborator Author

This package is designed to hold pure PyTorch implementations without any device assumption.
So if a backend does not provide a custom kernel, we can still perform inference.
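As an illustration, a pure-PyTorch fallback kernel in such a `default` package might look like the following; the function name and signature are assumptions, not the actual lmdeploy kernels:

```python
import torch


def rms_norm(hidden_states: torch.Tensor,
             weight: torch.Tensor,
             eps: float = 1e-6) -> torch.Tensor:
    """RMSNorm written in plain PyTorch with no device assumption.

    A device backend may register a fused kernel instead; when it does not,
    this implementation still produces correct (if slower) results anywhere
    PyTorch runs.
    """
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return hidden_states * weight
```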

@lvhan028 (Collaborator)

@zhulinJulia24
This PR brings in a minor reduction in inference speed.
Could you help quantify it?
The candidate models are llama3-7b, mixtral-moe-7x8b, internlm2-20b and llama3-70b.

@lvhan028 (Collaborator)

The models' evaluation has to be performed, too. @zhulinJulia24

@zhyncs (Contributor) commented Jun 20, 2024

The models' evaluation has to be performed, too. @zhulinJulia24

Do we really need to run a full evaluation? For example, could we verify basic correctness by comparing results at temperature 0 against transformers, and only run the evaluation when necessary? Of course, this is just a suggestion; if resources are sufficient, quickly running an evaluation task should not be a problem.

@zhulinJulia24 (Collaborator) commented Jun 21, 2024

https://github.com/InternLM/lmdeploy/actions/runs/9600920652
(screenshot: precision evaluation results)

@grimoire @lvhan028
llama3 loses 10+ points in the precision evaluation, and internlm2-chat-20b is 5+ points lower than HF transformers' precision.

All precisions are improved compared to https://github.com/zhulinJulia24/lmdeploy/actions/runs/9240064913, which is the 0.4.2 version's precision.

@zhulinJulia24 (Collaborator) commented Jun 21, 2024

@zhulinJulia24 This PR brings in a minor reduction in inference speed. Could you help quantify it? The candidate models are llama3-7b, mixtral-moe-7x8b, internlm2-20b and llama3-70b.

llama3-70b

lmdeploy serve api_server /mnt/models-new/llm_models/models--meta-llama--Meta-Llama-3-70B-Instruct/snapshots/0cac6d727e4cdf117e1bde11e4c7badd8b963919 --server-port 24555 --tp 4 --backend pytorch

concurrency: 256
elapsed_time: 483.240s

first_token latency(min, max, ave): 2.923s, 371.669s, 29.348s

number of prompt tokens: 447592
number of completion tokens: 404681
token throughput (completion token): 837.433 token/s
token throughput (prompt + completion token): 1763.666 token/s
RPS (request per second): 4.139 req/s
RPM (request per minute): 248.324 req/min

concurrency: 128
elapsed_time: 489.771s

first_token latency(min, max, ave): 0.465s, 17.343s, 3.991s

number of prompt tokens: 447592
number of completion tokens: 404681
token throughput (completion token): 826.265 token/s
token throughput (prompt + completion token): 1740.145 token/s
RPS (request per second): 4.084 req/s
RPM (request per minute): 245.012 req/min

mixtral-moe-7x8b

lmdeploy serve api_server /nvme/qa_test_models/mistralai/Mixtral-8x7B-Instruct-v0.1 --server-port 24555 --tp 2 --backend pytorch

concurrency: 128
elapsed_time: 345.898s

first_token latency(min, max, ave): 1.912s, 20.649s, 3.140s

number of prompt tokens: 491513
number of completion tokens: 474800
token throughput (completion token): 1372.658 token/s
token throughput (prompt + completion token): 2793.634 token/s
RPS (request per second): 5.782 req/s
RPM (request per minute): 346.923 req/min

concurrency: 256
elapsed_time: 324.822s

first_token latency(min, max, ave): 0.248s, 236.570s, 19.633s

number of prompt tokens: 491513
number of completion tokens: 474800
token throughput (completion token): 1461.722 token/s
token throughput (prompt + completion token): 2974.898 token/s
RPS (request per second): 6.157 req/s
RPM (request per minute): 369.433 req/min

internlm2-chat-20b

|   | batch | num_prompts | RPS   | RPM     | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | throughput(out tok/s) | throughput(total tok/s) |
|---|-------|-------------|-------|---------|-------------|-------------|-------------|-----------------------|-------------------------|
| 0 | 128   | 5000.0      | 7.328 | 439.652 | 2.406       | 1.746       | 13.746      | 1501.208              | 3206.576                |
| 1 | 256   | 5000.0      | 7.530 | 451.825 | 17.863      | 0.299       | 402.773     | 1542.775              | 3295.362                |

meta-Llama-3-8B-Instruct

|   | batch | num_prompts | RPS    | RPM     | FTL(ave)(s) | FTL(min)(s) | FTL(max)(s) | throughput(out tok/s) | throughput(total tok/s) |
|---|-------|-------------|--------|---------|-------------|-------------|-------------|-----------------------|-------------------------|
| 0 | 128   | 5000.0      | 12.094 | 725.610 | 1.512       | 1.149       | 6.681       | 2428.294              | 5176.386                |
| 1 | 256   | 5000.0      | 11.937 | 716.228 | 11.475      | 0.177       | 96.754      | 2396.895              | 5109.453                |

internlm2-chat-20b and meta-Llama-3-8B-Instruct are consistent with the 0.4.2 baseline.

@lvhan028 changed the title from "[Draft] Device dispatcher" to "Device dispatcher" on Jun 21, 2024
@lvhan028 (Collaborator)

Compared to the previous torch engine, as shown in https://github.com/zhulinJulia24/lmdeploy/actions/runs/9240064913, the evaluation accuracy doesn't degrade, does it?

@lvhan028 (Collaborator)

After internal discussion: this PR does not cause accuracy degradation compared to the previous version.
We'll check whether something is wrong with the evaluation config.

@lvhan028 added the enhancement (New feature or request) label on Jun 21, 2024
        importlib.import_module(f'{__name__}.{device_type}')
        assert device_type in loaded_utils
    except ImportError:
        logger.debug('Failed to import device utils for '
Collaborator

It falls back to cuda; should this be a warning instead of debug?

Collaborator Author

Most kernels won't have a device-specific implementation, so a warning would be noisy.
And we will fall back to default instead of cuda in the future.
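A runnable approximation of the import-and-fallback logic under discussion; `loaded_utils` and the function name are simplified assumptions about the surrounding code:

```python
import importlib
import logging

logger = logging.getLogger(__name__)

# Devices whose utils module registered itself; in the real code this
# registry is filled as a side effect of the import below.
loaded_utils = {'cuda'}


def get_device_utils_name(device_type: str) -> str:
    """Import device-specific utils, falling back to cuda when unavailable."""
    try:
        importlib.import_module(f'{__name__}.{device_type}')
        assert device_type in loaded_utils
        return device_type
    except ImportError:
        # Most devices ship no specialized utils, so this stays at debug
        # level; a warning would fire for nearly every non-cuda device.
        logger.debug(f'Failed to import device utils for {device_type}, '
                     'fallback to cuda.')
        return 'cuda'


print(get_device_utils_name('maca'))  # unknown device -> 'cuda'
```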

@RunningLeon (Collaborator) left a comment

LGTM

@lvhan028 merged commit 3b39322 into InternLM:main on Jun 21, 2024
5 checks passed
Labels: enhancement (New feature or request)

5 participants