
Add interfaces to the pipeline to obtain logits and ppl #1652

Merged (25 commits into InternLM:main, Jun 25, 2024)

Conversation

@irexyc (Collaborator) commented May 24, 2024

Motivation

Add interfaces to the pipeline to obtain logits and ppl

Use cases (Optional)

from lmdeploy import pipeline
import numpy as np
import torch
from lmdeploy.vl import load_image

pipe = pipeline('/nvme/shared/llava-v1.5-7b/')
im = load_image('tiger.jpeg')

# prepare_inputs applies the chat template and vision preprocessing, and
# returns input_ids plus the image embeddings and their token ranges
out = pipe.prepare_inputs([('hello', im)])
logits = pipe.get_logits(out['input_ids'], out['input_embeddings'],
                         out['input_embedding_ranges'])
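
For ppl on a plain text prompt, a minimal counterpart sketch (assuming the model directory ships a Hugging Face tokenizer; get_ppl takes plain token ids):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('/nvme/shared/llava-v1.5-7b/')
input_ids = tok.encode('hello')
ppl = pipe.get_ppl(input_ids)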

@irexyc irexyc added the WIP label May 24, 2024
@lvhan028 (Collaborator)

Could you update both the LLM pipeline and VLM pipeline user guides by adding examples of calculating logits and ppl, respectively?
Regarding the example in the description, I suggest using internlm2-7b and xcomposer2-7b as the candidate models.
As for the tokenizer, how about using AutoTokenizer instead?
Users already know AutoTokenizer well, so there is no need to introduce our own tokenizer.
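
A sketch of what the LLM guide example might look like (the model id and prompt are placeholders; get_logits and get_ppl take plain token ids, as in the later examples in this PR):

from transformers import AutoTokenizer
from lmdeploy import pipeline

model_path = 'internlm/internlm2-7b'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
pipe = pipeline(model_path)

input_ids = tokenizer.encode('Hello, how are you?')

# logits
logits = pipe.get_logits(input_ids)

# ppl
ppl = pipe.get_ppl(input_ids)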

@lvhan028 lvhan028 added the enhancement New feature or request label May 24, 2024
return event_loop


class LogitsMixin:
Collaborator

Do we have strong reasons to make a new class?
Is there any concern if we put the following APIs into AsyncEngine?

Collaborator Author

The motivation is to keep the AsyncEngine class clean. If it is not necessary, I can move these functions into AsyncEngine.

Collaborator

I understand your concern, but in my opinion the motivation is not strong.
The "Mixin" pattern is often used to achieve some of the benefits of multiple inheritance while avoiding the complexities and potential issues associated with it.
But in our case, there is no multiple inheritance.
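
A minimal illustration of the pattern under discussion (the bodies here are placeholders, not the PR's code):

class LogitsMixin:
    """Adds logits/ppl helpers to whichever engine class inherits it."""

    def get_logits(self, input_ids):
        # placeholder: a real implementation would call into the inference engine
        raise NotImplementedError


class AsyncEngine(LogitsMixin):
    # a single base class, i.e. no multiple inheritance is involved
    pass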

Collaborator

@lzhangzz @AllentDan any comments?

Collaborator

@irexyc If you insist on keeping LogitsMixin, I suggest naming the file logits_mixin.py instead of utils.py.

@grimoire (Collaborator) left a comment

LGTM

@irexyc (Collaborator, Author) commented Jun 6, 2024

There may be some bugs:

  1. 2024-06-06 06:30:36,737 - lmdeploy - ERROR - Engine loop failed with error: CUDA error: an illegal memory access was encountered
from lmdeploy.turbomind import TurboMind
from lmdeploy import pipeline, TurbomindEngineConfig, PytorchEngineConfig
pipe = pipeline('/nvme/shared/vicuna-7b-v1.5/', log_level='INFO', backend_config=PytorchEngineConfig(session_len=33000))

g = pipe.engine.create_instance()
g.decode([[100] * 10000], sequence_end=False)
g.decode([[100] * 10000], sequence_start=False)
  2. nan logits (both pytorch and turbomind backend)
from lmdeploy.turbomind import TurboMind
from lmdeploy import pipeline, TurbomindEngineConfig, PytorchEngineConfig
pipe = pipeline('/nvme/shared/vicuna-7b-v1.5/', log_level='INFO', backend_config=TurbomindEngineConfig(session_len=33000))
# pipe = pipeline('/nvme/shared/vicuna-7b-v1.5/', log_level='INFO', backend_config=PytorchEngineConfig(session_len=33000))
g = pipe.engine.create_instance()
g.decode([list(range(9700))])

tensor([[[-5.2604,  4.2658,  6.0810,  ..., -0.2098, -1.4354, -0.3216],
         [-7.8710,  2.2585,  3.1645,  ..., -2.7675, -5.0839, -3.0246],
         [-2.0298,  9.4127,  2.9417,  ..., -0.4258, -1.0950, -0.9109],
         ...,
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan]]],
       device='cuda:0')

@grimoire (Collaborator) commented Jun 6, 2024

The first error is caused by

block_offsets = self.block_offsets[:, :block_end]

[:, :block_end] should be removed.

@RunningLeon (Collaborator)

@irexyc Could you fix the conflict with the main branch?

Conflicts:
	lmdeploy/pytorch/engine/engine_instance.py
(float*)allocator_->malloc(sizeof(float) * model_->vocab_size_padded_ * max_context_token_num_);
const auto tp = model_->tensor_para_.world_size_;
context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * model_->vocab_size_padded_ * num_token);
const auto tp = model_->tensor_para_.world_size_;
Collaborator

In the current implementation, these buffers are only allocated once; num_token is just for this single iteration.

@lvhan028 (Collaborator)

(quoting the bug report above: the CUDA illegal-memory-access error and the nan logits with both backends)

Does this issue still exist?

"""Helper class to calculate logits and ppl."""

def prepare_inputs(self, prompts: Union[PromptType, List[PromptType]]):
if hasattr(self, '_convert_prompts'):
@lvhan028 (Collaborator) Jun 24, 2024

So, we always apply the chat template for VLM models, but we don't do it for LLMs, right?
If that is the case, I suggest we use AutoTokenizer.apply_chat_template as the example in pipeline.md, so that we don't have to explain to users whether we apply the chat template.

for prompt in prompts:
out = _get_event_loop().run_until_complete(
self._get_prompt_input(prompt,
do_preprocess=True,
Collaborator

We'd better not hardcode do_preprocess for LLMs

@@ -289,7 +289,7 @@ def split(self, split_size: int, block_size: int):
if overlap:
block_end += 1

block_offsets = self.block_offsets[:, :block_end]
block_offsets = self.block_offsets
Collaborator

@RunningLeon Is this OK?

Collaborator

Right. There is a bug here, as mentioned in #1652 (comment).

logits = pipe.get_logits(input_ids)

# ppl
ppl = pipe.get_ppl(input_ids)
Collaborator

Just an interesting result: ppl differs by around 4% between the pytorch and turbomind backends for this example.

        Turbomind    PyTorch
ppl     5.5916224    5.3524413
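
For reference, a common way to derive a perplexity number from logits with plain PyTorch, independent of either backend's kernels (a sketch; the shift and shape conventions are assumptions, not necessarily what get_ppl does internally):

import torch
import torch.nn.functional as F

def ppl_from_logits(logits: torch.Tensor, input_ids: torch.Tensor) -> float:
    # logits: [seq_len, vocab_size], input_ids: [seq_len]
    # position i predicts token i + 1, so shift by one
    shift_logits = logits[:-1].float()
    shift_labels = input_ids[1:]
    loss = F.cross_entropy(shift_logits, shift_labels)  # mean NLL in nats
    return torch.exp(loss).item()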

Collaborator

It's possible. They use different CUDA kernels.

Collaborator

import torch
import fire

def main(model_path, backend='turbomind'):
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    messages = [
        {"role": "user", "content": "Hello, how are you?"},
    ]
    inputs = tokenizer.apply_chat_template(messages, return_tensors='pt', return_dict=True)
    input_ids = inputs["input_ids"][0].tolist()
    if backend == 'turbomind':
        from lmdeploy import pipeline, TurbomindEngineConfig
        pipe = pipeline(model_path, backend_config=TurbomindEngineConfig(session_len=33000))
        ppl = pipe.get_ppl(input_ids)
        print(ppl)
    elif backend == 'pytorch':
        from lmdeploy import pipeline, PytorchEngineConfig
        pipe = pipeline(model_path, backend_config=PytorchEngineConfig(session_len=33000))
        ppl = pipe.get_ppl(input_ids)
        print(ppl)
    elif backend == 'transformers':
        # from transformers.models.llama import LlamaForCausalLM
        # model = LlamaForCausalLM.from_pretrained(
        from transformers import AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(
            model_path, 
            attn_implementation='flash_attention_2', 
            torch_dtype=torch.float16,
            trust_remote_code=True)
        model.to("cuda")
        inputs.to(model.device)
        with torch.no_grad():
            # with labels provided, outputs.loss is the mean next-token
            # cross-entropy over the prompt
            outputs = model(
                **inputs,
                use_cache=False,
                labels=inputs["input_ids"],
            )
            logits = outputs.logits.squeeze(0)
            print(outputs.loss)


if __name__ == "__main__":
    fire.Fire(main)

turbomind: 5.5916224
pytorch: 5.365595
transformers: 5.57444953918457

@lvhan028 lvhan028 merged commit c59a704 into InternLM:main Jun 25, 2024
9 checks passed