
Questions about Qwen1.5 long-sequence training #761

Open
macheng6 opened this issue Jun 11, 2024 · 4 comments

Comments

@macheng6
macheng6 commented Jun 11, 2024

I am training Qwen1.5-32B with long sequences on 32 A100 GPUs using XTuner, and have a few questions:

  1. Qwen1.5 does not seem to support the use_varlen_attn=True option. Is there a plan to add this later?
  2. For a given model, is there a noticeable difference between use_varlen_attn=True and use_varlen_attn=False?
  3. When the scheduler is set to update per step, is there a formula for computing the total number of training steps, so that begin and end can be set accordingly?
  4. In multi-node multi-GPU training, how is a sample packed to max_length split according to sequence_parallel_size and the total number of GPUs?
@HIT-cwh
Collaborator

HIT-cwh commented Jun 12, 2024

  1. Qwen1.5 currently does support use_varlen_attn=True.
  2. Setting varlen to True or False decides whether attention is computed only within each short sample, or over the entire packed sequence of max_length. If the samples in your training set are generally short and unrelated to each other, attending over the whole packed sequence makes training harder, but it can also be viewed as a form of data augmentation.
  3. All configs provided by XTuner update the learning rate per step. For example, in the qwen1_5_0_5b_chat_full_alpaca_e3.py config the schedule is: the first 3% of iterations are warmup, and the remaining 97% follow cosine decay. A rough formula for the total number of steps is sketched below.
  4. Each packed sample is split into sequence_parallel_size chunks, and dp_size = world_size // sp_size.
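For question 3, a back-of-the-envelope estimate (a sketch only, not XTuner's exact accounting; all names below are illustrative assumptions): with sequence parallelism, ranks in the same SP group consume the same packed sample, so the effective data-parallel size is world_size // sequence_parallel_size, and the step count follows from the dataset size, per-device batch size, and gradient accumulation.

```python
import math

def estimate_total_steps(num_packed_samples: int,
                         batch_size_per_device: int,
                         accumulative_counts: int,
                         world_size: int,
                         sequence_parallel_size: int,
                         max_epochs: int) -> int:
    """Rough estimate of total optimizer steps (illustrative, not XTuner's exact code)."""
    # Ranks inside one sequence-parallel group share the same sample,
    # so only world_size // sp_size groups consume distinct data.
    dp_size = world_size // sequence_parallel_size
    samples_per_step = batch_size_per_device * dp_size * accumulative_counts
    steps_per_epoch = math.ceil(num_packed_samples / samples_per_step)
    return steps_per_epoch * max_epochs

# Example: 32 GPUs with sp=4 -> dp=8; 20k packed samples, bs=1, grad accum=2, 3 epochs.
total = estimate_total_steps(20000, 1, 2, 32, 4, 3)
warmup_steps = int(0.03 * total)  # e.g. first ~3% warmup, remainder cosine decay
print(total, warmup_steps)
```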

@macheng6
Author

macheng6 commented Jun 12, 2024

Thanks for the reply. Two follow-up questions:

  1. With use_varlen_attn=True, I get the following error:
    File "/opt/conda/lib/python3.8/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 748, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
    File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
    File "/nas/macheng.ma/projects/xtuner-main/xtuner/model/modules/dispatch/qwen2.py", line 317, in qwen2_varlen_attn_forward
    attn_output = varlen_flash_attn(attn_output = varlen_flash_attn(

    File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/attention.py", line 65, in sequence_parallel_attn
    attn_output = varlen_flash_attn( File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/attention.py", line 65, in sequence_parallel_attn

File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/attention.py", line 65, in sequence_parallel_attn
pre_process_for_sequence_parallel_attn(pre_process_for_sequence_parallel_attn(

File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/attention.py", line 23, in pre_process_for_sequence_parallel_attn
File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/attention.py", line 23, in pre_process_for_sequence_parallel_attn
pre_process_for_sequence_parallel_attn(
File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/attention.py", line 23, in pre_process_for_sequence_parallel_attn
query_states = all_to_all(
File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/comm.py", line 87, in all_to_all
query_states = all_to_all(
File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/comm.py", line 87, in all_to_all
query_states = all_to_all(
File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/comm.py", line 87, in all_to_all
return _AllToAll.apply(input, sp_group, scatter_dim, gather_dim)
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
return _AllToAll.apply(input, sp_group, scatter_dim, gather_dim)
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
return _AllToAll.apply(input, sp_group, scatter_dim, gather_dim)
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/comm.py", line 42, in forward
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/comm.py", line 42, in forward
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/comm.py", line 42, in forward
output = _all_to_all(input, ctx.world_size, sp_group, scatter_dim,
File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/comm.py", line 21, in _all_to_all
output = _all_to_all(input, ctx.world_size, sp_group, scatter_dim,
File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/comm.py", line 21, in _all_to_all
output = _all_to_all(input, ctx.world_size, sp_group, scatter_dim,
File "/nas/macheng.ma/projects/xtuner-main/xtuner/parallel/sequence/comm.py", line 21, in _all_to_all
dist.all_to_all(output_list, input_list, group=group)
File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1436, in wrapper
dist.all_to_all(output_list, input_list, group=group)
File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1436, in wrapper
dist.all_to_all(output_list, input_list, group=group)
File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1441, in wrapper
return func(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 3266, in all_to_all
"args": f"{args}, {kwargs}",return func(*args, **kwargs)

File "/opt/conda/lib/python3.8/site-packages/torch/_tensor.py", line 426, in repr
File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 3266, in all_to_all
return torch._tensor_str._str(self, tensor_contents=tensor_contents)
File "/opt/conda/lib/python3.8/site-packages/torch/_tensor_str.py", line 636, in _str
work = group.alltoall(output_tensor_list, input_tensor_list, opts)
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1275, internal error, NCCL version 2.14.3
ncclInternalError: Internal check failed.

  2. In qwen1_5_0_5b_chat_full_alpaca_e3.py, by_epoch=True should mean the learning rate is updated per epoch, not per step?

@HIT-cwh
Collaborator

HIT-cwh commented Jun 14, 2024

  1. This is most likely an NCCL version problem. First check whether any NCCL-related environment variables are set in your bashrc, then install a recent version of torch via pip install, which will pull in the matching NCCL dependency automatically.
  2. The lr scheduler config has a convert_to_iter_based=True setting; we automatically convert the epoch values in begin and end to iteration counts, so the learning rate is updated per step.
    (screenshot of the param_scheduler config)
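For reference, the relevant part of such an XTuner config looks roughly like the sketch below (values are illustrative; warmup_ratio and max_epochs stand in for whatever the actual config defines):

```python
from mmengine.optim import CosineAnnealingLR, LinearLR

max_epochs = 3
warmup_ratio = 0.03  # roughly the first 3% of training used for linear warmup

# begin/end are written in epochs (by_epoch=True), but because
# convert_to_iter_based=True they are converted to iteration counts,
# so the learning rate is in fact updated every step.
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True),
]
```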

@macheng6
macheng6 commented Jun 14, 2024
Author

Got it, thanks for the answers.
