A bug prevents it from running #8

Open
The-kamisato opened this issue Feb 29, 2024 · 1 comment

Comments

@The-kamisato

The-kamisato commented Feb 29, 2024

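I noticed this statement at line 219 of your chat.py:

```python
get_func = text_processor.get_func(inputs, **inputs_dic) if hasattr(text_processor, 'get_func') else get_masks_and_position_ids_default
```

But if I don't provide an image at the start, then image_position < 5, so inputs_dic is never assigned text_processor(new_prompt) (line 205), and the script fails with a "variable used before it is defined" error.

How can I fix this? Thanks.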
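The control-flow problem described above can be reproduced without the model. In Python, reading a local variable that was only assigned on a branch that did not run raises `UnboundLocalError`. This is a minimal sketch: `build_inputs_buggy` and its whitespace "tokenization" are hypothetical stand-ins, not the actual chat.py code.

```python
def build_inputs_buggy(prompt, image_position):
    """Hypothetical reduction of the branch structure in chat.py."""
    if image_position < 5:
        # No image: tokens come straight from the prompt; inputs_dic is never bound.
        inputs = prompt.split()
    else:
        # Image present: inputs_dic is created only on this branch.
        inputs_dic = {"input_ids": prompt[image_position:].split()}
        inputs = inputs_dic["input_ids"]
    # inputs_dic is used unconditionally afterwards, like the get_func line.
    return inputs, dict(inputs_dic)


try:
    build_inputs_buggy("hello world", image_position=0)
except UnboundLocalError as err:
    print("text-only input fails with", type(err).__name__)
```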

@qijimrc
Collaborator

qijimrc commented Mar 9, 2024

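Hi, thanks for your interest in our work and for your question. Because we designed CogCoM as a multimodal model for image inputs, training did not cover inputs that contain no image. However, you can make the model handle image-free inputs by replacing lines 198-228 of chat.py with the following code: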

```python
        # Old text-only branch, now disabled: it skipped text_processor,
        # so inputs_dic was never assigned when no image was given.
        # if image_position < 5: # no image
        #     inputs = text_processor.tokenizer([prompt], return_tensors="pt").to(model.parameters().__next__().device)['input_ids'][0]
        #     # pre_image = 0
        # else:
        # Always build new_prompt and run text_processor, with or without an image.
        new_prompt = prompt[image_position:] if image_position >= 5 else prompt[image_position+1:]
        # new_prompt = prompt[image_position:]
        if not torch_image or hasattr(text_processor, 'no_eoi'):
            new_prompt = new_prompt.replace(text_processor.tokenizer.eoi, '', 1)
        inputs_dic = text_processor(new_prompt)
        # Cast non-integer tensors to the model's dtype, then move every tensor
        # to the model's device.
        for k in inputs_dic:
            if isinstance(inputs_dic[k], torch.Tensor):
                if inputs_dic[k].dtype not in (torch.int, torch.long):
                    inputs_dic[k] = inputs_dic[k].to(next(model.parameters()).dtype)
                inputs_dic[k] = inputs_dic[k].to(next(model.parameters()).device)
        inputs = inputs_dic['input_ids'][0]  # already on the model's device
        # pre_image = inputs_dic['pre_image']

        # Pad the token sequence to max_length with -1 placeholders.
        seq = torch.cat(
            [inputs, torch.tensor([-1]*(max_length-len(inputs)), device=inputs.device)], dim=0
        )
        strategy = BaseStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[text_processor.tokenizer.eos_token_id],
                                invalid_slices=invalid_slices, repetition_penalty=repetition_penalty)
        # inputs_dic is now defined on every path, so this also works for text-only prompts.
        get_func = text_processor.get_func(inputs, **inputs_dic) if hasattr(text_processor, 'get_func') else get_masks_and_position_ids_default
        if image_position < 5:
            # Text only: pass just the text_processor outputs (minus the token ids).
            inputs_dic.pop('input_ids')
            inputs = {**inputs_dic}
        else:
            # With an image: prefix the vision and cross-attention image features,
            # then add the text_processor outputs.
            inputs = {**{'vision_'+k:v for k,v in torch_image.items()}, **{'cross_'+k:v for k,v in cross_image.items()}}
            inputs_dic.pop('input_ids')
            inputs = {**inputs, **inputs_dic}
```
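The essence of the fix is that `inputs_dic` is bound on every path, and the image check only decides what gets merged into the final keyword arguments. The sketch below isolates that pattern in plain Python. `build_model_inputs`, the `tokenize` callable, and the `vision_features` dict are hypothetical stand-ins for `text_processor` and `torch_image`; the `cross_` features and `get_func` are omitted for brevity.

```python
from typing import Any, Callable, Dict, Optional


def build_model_inputs(
    prompt: str,
    image_position: int,
    tokenize: Callable[[str], Dict[str, Any]],
    vision_features: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    # Slice the prompt the same way as the patched code, image or not.
    new_prompt = prompt[image_position:] if image_position >= 5 else prompt[image_position + 1:]
    # inputs_dic is bound on every path, so later uses cannot hit UnboundLocalError.
    inputs_dic = dict(tokenize(new_prompt))
    inputs_dic.pop("input_ids")  # token ids travel separately (as `seq` in chat.py)
    if image_position < 5 or vision_features is None:
        return inputs_dic
    # Image present: prefixed image features first, text-side outputs merged on top.
    return {**{"vision_" + k: v for k, v in vision_features.items()}, **inputs_dic}
```

With a toy tokenizer such as `lambda s: {"input_ids": s.split(), "n_chars": len(s)}`, a text-only call returns just `{"n_chars": ...}`, while a call with an image also carries the `vision_`-prefixed keys.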

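Note, however, that in our current model version the multimodal training stage contained no text-only samples (they were masked out within the context window), so in our tests the model's responses to text-only input are relatively poor. This can be mitigated by fine-tuning that also includes text-only data; for fine-tuning, see finetune.py.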
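One generic way to combine text-only data with existing image-text data during fine-tuning is to interleave the two sources at a chosen ratio. This is a hypothetical sketch, not how finetune.py works: `mix_samples` and `text_ratio` are illustrative names, and a suitable ratio would have to be found empirically.

```python
import random
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def mix_samples(
    multimodal: Sequence[T], text_only: Sequence[T], text_ratio: float, seed: int = 0
) -> Iterator[T]:
    """Hypothetical helper: shuffle both pools, then draw a text-only sample
    with probability `text_ratio` while both pools still have items."""
    if not 0.0 <= text_ratio <= 1.0:
        raise ValueError("text_ratio must be in [0, 1]")
    rng = random.Random(seed)
    mm: List[T] = list(multimodal)
    txt: List[T] = list(text_only)
    rng.shuffle(mm)
    rng.shuffle(txt)
    i = j = 0
    while i < len(mm) or j < len(txt):
        # Once one pool is exhausted, drain the other one.
        take_text = j < len(txt) and (i >= len(mm) or rng.random() < text_ratio)
        if take_text:
            yield txt[j]
            j += 1
        else:
            yield mm[i]
            i += 1
```

Every sample from both pools is yielded exactly once; `text_ratio` only controls how densely text-only samples are spread among the multimodal ones.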
