[Feature] Add RLHF code #736

Open
wants to merge 13 commits into main
Conversation

hui-zhao-1

No description provided.

@hijkzzz commented Jun 3, 2024

Hi, XTuner Team

Could you please add a citation for the source of the Ray+vLLM-based RLHF architecture, OpenRLHF, for example in the README.md acknowledgement section: https://github.com/InternLM/xtuner?tab=readme-ov-file#%EF%B8%8F-acknowledgement.
We noticed that most of the RLHF-related code, particularly the Ray RLHF architecture in XTuner, is refactored from OpenRLHF. Under OpenRLHF's Apache License 2.0, the original copyright statement must be included.

An example:

XTuner's RLHF solution mainly references OpenRLHF, for which we are grateful.

Related MRs:
#736
#764

Thank you

),
)

dataset_config = {

Suggest putting the settings that users adjust most often near the top of the config file, and adding more comments to the config file so users can understand it. The existing xtuner configs are a good reference.
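For illustration only, a rough layout along those lines (the key names below are hypothetical, not the PR's actual config fields):

# --- frequently tuned settings, kept at the top with comments ---
model_path = 'internlm/internlm2-chat-1_8b-sft'  # actor model to fine-tune
prompt_batch_size = 128       # prompts sampled per rollout
policy_learning_rate = 1e-6   # PPO is sensitive to the actor learning rate
kl_coeff = 0.01               # weight of the KL penalty against the SFT policy

# --- rarely changed plumbing (dataset_config, parallelism, logging) below ---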

if sample_data[i].rm_meta != 'default':
cur_rm_data = [{
'role': 'system',
'content': META_PROMPT[sample_data[i].rm_meta]

A comment here could explain that this is the conditional system prompt, ideally with a link to the paper.
Also, META_PROMPT is too abstract a variable name; better to follow the naming used in the paper.

# Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration."""
import logging

Why do some places use the loguru logger while a custom logger is used here? It would be better to unify them.

if not gather or labels is None:
return logp
logpy = torch.gather(logp, -1, labels.unsqueeze(2)).squeeze(-1)
return logpy.cuda()

Why is .cuda() needed here? Can the tensor end up on the CPU?

if policy_output[key] is None:
continue
batch_size, seq_len = policy_output[
key].shape # assert: only support 2d tensor

Suggested change
key].shape # assert: only support 2d tensor
key].shape[:2]

Wouldn't this be a bit more compatible?

def padding_policy_outputs(policy_outputs: list[PolicyOutput],
padding_token_map={}):
DEFAULT_PADDING_ID = 0
RIGHT_PADDING = True

Are these meant to become configurable? Right now they are hardcoded.
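A minimal sketch of making these configurable (the keyword names are hypothetical; the defaults mirror the current constants):

from typing import Optional


def padding_policy_outputs(policy_outputs: list,
                           padding_token_map: Optional[dict] = None,
                           default_padding_id: int = 0,
                           right_padding: bool = True):
    # Callers can override the padding id and padding side instead of relying
    # on the hardcoded DEFAULT_PADDING_ID / RIGHT_PADDING constants.
    padding_token_map = padding_token_map or {}
    ...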

import time

import torch
from loguru import logger

Please unify which logger is used.

num_actions = action_mask.size(1)
if sft_model is not None:
self.sft_model: BaseModelServer = sft_model
kl_rewards, entropy, kl_distance, policy_logprobs, sft_logprobs = self._get_kl_rewards( # noqa: E501

Suggested change
kl_rewards, entropy, kl_distance, policy_logprobs, sft_logprobs = self._get_kl_rewards( # noqa: E501
(kl_rewards, entropy, kl_distance, policy_logprobs, sft_logprobs) = self._get_kl_rewards(

With the parentheses added, yapf can wrap this line for you.

s_t = time.time()
value_output = value_model.infer_get(value_output_ref)
raw_values = value_output.logits.squeeze(-1)
logger.info(

There are many places in the code with this kind of timing need; adding a timing context manager to utils would be more convenient to use, for example:

import time

from loguru import logger


class TimeLogger:
    """Context manager that logs how long the wrapped block took."""

    def __init__(self, message: str):
        self.message = message

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        duration = round(time.time() - self.start_time, 2)
        logger.info(f'{self.message} duration: {duration} s')
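Hypothetical usage, replacing the manual s_t = time.time() bookkeeping around calls like the one quoted above:

with TimeLogger('value_model.infer_get'):
    value_output = value_model.infer_get(value_output_ref)
    raw_values = value_output.logits.squeeze(-1)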

from loguru import logger


class Timer:

Hey, there is already a timing context manager here, why isn't it being used?

setup.py Outdated
Comment on lines 135 to 137
'rlhf':
parse_requirements('requirements/rlhf.txt'),

Suggested change
'rlhf':
parse_requirements('requirements/rlhf.txt'),
'rlhf':
parse_requirements('requirements/deepspeed.txt') +
parse_requirements('requirements/rlhf.txt'),

Comment on lines 1 to 2
-r requirements/deepspeed.txt
loguru
ray[default,train]==2.9.1

Suggested change
-r requirements/deepspeed.txt
loguru
ray[default,train]==2.9.1
loguru
ray[default,train]==2.9.1

Comment on lines 6 to 13
num_gpus = 1
if 'parallel' in trainer_config:
parallel = trainer_config['parallel']
data = parallel.get('data', {'size': 1})
tensor = parallel.get('tensor', {'size': 1})
pipeline = parallel.get('pipeline', {'size': 1})
num_gpus = data['size'] * tensor['size'] * pipeline['size']
return num_gpus

Suggested change
num_gpus = 1
if 'parallel' in trainer_config:
parallel = trainer_config['parallel']
data = parallel.get('data', {'size': 1})
tensor = parallel.get('tensor', {'size': 1})
pipeline = parallel.get('pipeline', {'size': 1})
num_gpus = data['size'] * tensor['size'] * pipeline['size']
return num_gpus
return get_dp_size(trainer_config) * get_tp_size(trainer_config) * get_pp_size(trainer_config)
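The suggested one-liner assumes small helpers along these lines (function names come from the suggestion; the bodies are a sketch with the same defaults as the current code):

def _parallel_size(trainer_config: dict, key: str) -> int:
    # Each of data/tensor/pipeline parallel defaults to size 1 when unset.
    return trainer_config.get('parallel', {}).get(key, {'size': 1})['size']


def get_dp_size(trainer_config: dict) -> int:
    return _parallel_size(trainer_config, 'data')


def get_tp_size(trainer_config: dict) -> int:
    return _parallel_size(trainer_config, 'tensor')


def get_pp_size(trainer_config: dict) -> int:
    return _parallel_size(trainer_config, 'pipeline')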

Comment on lines 84 to 86
logger.info(
f'{model_name} {model.__class__.__name__}.is_initialized: {model.is_initialized}' # noqa: E501
)

This log line does not seem very necessary, and it could easily cause confusion.

Comment on lines +32 to +56
try:
client_context = ray.init(
address=self.cluster_address,
runtime_env=runtime_env,
ignore_reinit_error=True,
)
logger.info(
f'Connected to a running ray cluster at {self.cluster_address}'
)
self.context_type = 'client'
self.context = client_context

except ConnectionError:
logger.info(
f'Error connecting to {self.cluster_address}, try initializing a new ray cluster.' # noqa: E501
)
ray_context = ray.init(
address=None,
resources=resources,
runtime_env=runtime_env,
ignore_reinit_error=True,
)
node_ip_address = ray_context.address_info['node_ip_address']
logger.info(f'Initialize a ray cluster at {node_ip_address}')
self.context_type = 'server'
self.context = ray_context

So this automatically falls back to a local ray cluster? It seems better not to fall back and to raise the error directly instead.
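A sketch of the fail-fast alternative, mirroring the snippet above but re-raising instead of starting a local cluster:

try:
    self.context = ray.init(
        address=self.cluster_address,
        runtime_env=runtime_env,
        ignore_reinit_error=True,
    )
    self.context_type = 'client'
    logger.info(f'Connected to a running ray cluster at {self.cluster_address}')
except ConnectionError:
    # Surface the misconfiguration instead of silently falling back.
    logger.error(f'Could not connect to the ray cluster at {self.cluster_address}')
    raise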

Comment on lines 43 to 92
class FileDataset(IterableDataset):
"""Single json file dataset."""

def __init__(self,
filename,
tokenizer,
sys_meta='default',
rm_meta='default'):
self._filename = filename
self.tokenizer = tokenizer
self.data_list = []
self.sys_meta = sys_meta
self.rm_meta = rm_meta
with open_file(self._filename) as fin:
for lineno, line in enumerate(fin):
data = json.loads(line)
self.data_list.append(data)

def __len__(self):
return len(self.data_list)

def __getitem__(self, index: int):
data = self.data_list[index]
try:
self.tokenizer.apply_chat_template(data, tokenize=True)
return {
'data': data,
'sys_meta': self.sys_meta,
'rm_meta': self.rm_meta
}
except Exception:
print(f'[data tokenize check] skip dirty data: {data}')
return None

def __iter__(self):
with open_file(self._filename) as fin:
for lineno, line in enumerate(fin):
data = json.loads(line)
try:
self.tokenizer.apply_chat_template(data, tokenize=True)
except Exception:
print(f'[data tokenize check] skip dirty data: {data}')
continue
if data is None:
continue
yield {
'data': data,
'sys_meta': self.sys_meta,
'rm_meta': self.rm_meta
}

  1. Implementing MapDataset-style interfaces like __len__ and __getitem__ on an IterableDataset does not seem reasonable
  2. As an IterableDataset, loading the entire dataset from disk into memory in __init__ and then reading it from disk again in __iter__ is inefficient
  3. A rename plus an updated docstring would fit better: the file being read is JSON Lines rather than JSON, so JsonlDataset is a better name (a sketch follows below)
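A rough sketch of such a JsonlDataset, streaming lazily from disk and keeping only the IterableDataset interface (plain open() stands in for the original open_file helper):

import json

from torch.utils.data import IterableDataset


class JsonlDataset(IterableDataset):
    """Streams samples from a JSON Lines file, skipping rows that fail the
    tokenization ("dirty data") check."""

    def __init__(self, filename, tokenizer, sys_meta='default', rm_meta='default'):
        self._filename = filename
        self.tokenizer = tokenizer
        self.sys_meta = sys_meta
        self.rm_meta = rm_meta

    def __iter__(self):
        with open(self._filename) as fin:
            for line in fin:
                data = json.loads(line)
                try:
                    self.tokenizer.apply_chat_template(data, tokenize=True)
                except Exception:
                    continue  # skip dirty data
                yield {
                    'data': data,
                    'sys_meta': self.sys_meta,
                    'rm_meta': self.rm_meta,
                }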


The code below has the same issue.

Comment on lines +36 to +41
def __iter__(self):
while True:
self.rng.shuffle(self.indices)
for i in self.indices:
yield self.data[i]

Consider logging a message when an epoch ends, to make debugging easier.
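A minimal sketch of the suggested epoch-end log, reusing the existing epoch_index counter and assuming a logger is in scope:

def __iter__(self):
    while True:
        self.rng.shuffle(self.indices)
        for i in self.indices:
            yield self.data[i]
        self.epoch_index += 1
        logger.info(f'Dataset epoch {self.epoch_index} finished; reshuffling indices.')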

@C1rN09 Jun 14, 2024

This needs refactoring. Suggestions:

  1. Change __init__ to follow the usage shown for Anthropic/hh-rlhf (see the example below), instead of parsing data_dir out of the path
  2. Remove save_to_disk / load_from_disk
  3. Update the docstring
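For reference, the Anthropic/hh-rlhf dataset card documents passing data_dir explicitly to datasets.load_dataset, rather than encoding the subset in a path string:

from datasets import load_dataset

# Full helpful/harmless mix, or a single subset selected via data_dir.
dataset = load_dataset('Anthropic/hh-rlhf')
helpful_base = load_dataset('Anthropic/hh-rlhf', data_dir='helpful-base')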

self.epoch_index = 0

def _init_in_data(self):
print(f"========================= Init in data sampler =========================")

The print statements in this file need a cleanup; remove the unnecessary ones.

Comment on lines 26 to 27
prompt_datasets: list[str] = None,
pretrain_datasets: list[str] = None,

For discussion: would it be better to handle prompt_datasets and pretrain_datasets separately?


For discussion: should this be kept, or changed to a different form?

The original intent was to share it between RM training and PPO training, but in the current repository structure RM and PPO are not in the same place, so it no longer serves a purpose.


def __init__(
self,
dataloader: IterableDataset,

Having dataloader actually be a Dataset does not feel quite right.

dataloader (IterableDataset): generate rl data iteratively
reward_function: reward function that computes scalar reward for each episode # noqa: E501
"""
self.dataloader: IterableDataset = iter(dataloader)

If the dataloader argument is actually a Dataset, using torch.utils.data.DataLoader here would be more appropriate and would bring more features.
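A sketch of wrapping the incoming dataset in torch.utils.data.DataLoader while keeping the current one-sample-at-a-time behaviour (the helper name is illustrative):

from torch.utils.data import DataLoader, IterableDataset


def make_prompt_loader(dataset: IterableDataset) -> DataLoader:
    # batch_size=None disables automatic batching, matching iter(dataset),
    # while exposing DataLoader features such as collation and pin_memory.
    # Note: num_workers > 0 with an IterableDataset requires per-worker sharding.
    return DataLoader(dataset, batch_size=None, pin_memory=True)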

Comment on lines 109 to 119
'role':
'assistant',
'content':
policyout.output_ans_str[i]
}]
else:
cur_rm_data = sample_data[i].message + [{
'role':
'assistant',
'content':
policyout.output_ans_str[i]

The formatting here looks a bit off?

from ..policy_output import logprobs_from_logits


class ActorLoss(torch.nn.Module):

Suggested change
class ActorLoss(torch.nn.Module):
class PPOPolicyLoss(torch.nn.Module):

Comment on lines 34 to 53
def forward(self, logits: torch.Tensor, labels: dict[str, Any]):
"""Forward function of ActorLoss.

Args:
logits (Tensor): Forward result of the model. Its shape may be varied. # noqa: E501
For packed forward: (micro_bsz * seqlen, 1), where micro_bsz = 1 # noqa: E501
For non packed forward: (micro_bsz, seqlen, 1)

labels (tuple[dict]): Label values which are split by pipeline
schedule into pieces. The length of the list is micro_bsz. Each
element is a dict, representing labels to a batch.

Note:
The parameter `labels` seems strange because of pj-colossalai's
pipeline schedule mechanism. Labels are delivered to colosslai.Engine # noqa: E501
in List format, so pipeline schedule split it into micro_bsz pieces, # noqa: E501
and deliver them to loss_fn by `*args`.

Returns:
Tensor: Return the final loss

docstring mismatch


For discussion: this code appears to have been refactored, but it may be incompatible with InterEvo's megatron-style training. Needs additional confirmation.


Also, should the per_seq and per_token modes be kept, and do they need extra documentation?

def policy_learn(self, trajectories, policy_model: BaseModelServer):
if self.policy_minibatch is None:
self.policy_minibatch = len(trajectories.output_ids)
policy_updates = len(trajectories.output_ids) // self.policy_minibatch

If the batch size is not evenly divisible, data will be dropped. It would be best to add an assert / logger.warning here, or change the minibatch logic so that all the data gets used.
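A sketch of the suggested guard, using the names from the surrounding code:

num_samples = len(trajectories.output_ids)
dropped = num_samples % self.policy_minibatch
if dropped != 0:
    logger.warning(
        f'[Policy learn] {dropped} of {num_samples} samples will be dropped: '
        f'policy_minibatch={self.policy_minibatch} does not divide the batch evenly.')
policy_updates = num_samples // self.policy_minibatch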

Comment on lines 78 to 80
assert len(
trajectories.output_ids[begin:end, :]
) == self.policy_minibatch, '[Policy learn] make sure len(policy_batch_inputs) == self.policy_minibatch' # noqa: E501

This check seems redundant?

Comment on lines 133 to 183
def value_learn_async(self, trajectories, value_model: BaseModelServer):
if self.value_minibatch is None:
self.value_minibatch = len(trajectories.output_ids)
value_updates = len(trajectories.output_ids) // self.value_minibatch
value_loss = []
assert value_updates == 1 and self.policy_learn_time == 1, f'value_updates={value_updates} * self.policy_learn_time={self.policy_learn_time} > 1' # noqa: E501
s_t = time.time()
value_batch_inputs, labels = self._value_learn_prepare(
0, 0, trajectories, value_updates)
v_loss_ref = value_model.train_async(
input_ids=value_batch_inputs['input_ids'],
labels=labels,
attention_mask=value_batch_inputs['attention_mask'],
criterion=self.value_criterion,
micro_batch_size=self.critic_micro_bs,
)
logger.info(
f'[critic train] async duration: {round(time.time() - s_t, 2)} s, {self.value_minibatch} batch' # noqa: E501
)
value_loss.append(v_loss_ref)
return value_loss

def value_learn_get(self, value_loss_ref, value_model: BaseModelServer):
with Timer('value_model.train_get'):
return [
value_model.train_get(ref).item() for ref in value_loss_ref
]

def value_learn(self, trajectories, value_model: BaseModelServer):
if self.value_minibatch is None:
self.value_minibatch = len(trajectories.output_ids)
value_updates = len(trajectories.output_ids) // self.value_minibatch
value_loss = []

for learn_i in range(self.policy_learn_time):
for step_i in range(value_updates):
s_t = time.time()
value_batch_inputs, labels = self._value_learn_prepare(
step_i, learn_i, trajectories, value_updates)
v_loss = value_model.train(
input_ids=value_batch_inputs['input_ids'],
labels=labels,
attention_mask=value_batch_inputs['attention_mask'],
criterion=self.value_criterion,
micro_batch_size=self.critic_micro_bs,
)
logger.info(
f'[critic train] duration: {round(time.time() - s_t, 2)} s, {self.value_minibatch} batch,value loss: {v_loss.item()}' # noqa: E501
)
value_loss.append(v_loss.item())
return value_loss

Why does value_learn_async not support value_updates / value_learn_time (note the typo in the code, please fix it) > 1, while value_learn does? The two should have exactly the same functionality, differing only in sync vs. async.
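A sketch of aligning value_learn_async with value_learn so both support multiple value updates; only the asynchronous train call (with the deferred .item() in value_learn_get) differs:

def value_learn_async(self, trajectories, value_model: BaseModelServer):
    if self.value_minibatch is None:
        self.value_minibatch = len(trajectories.output_ids)
    value_updates = len(trajectories.output_ids) // self.value_minibatch
    value_loss_refs = []
    # Note: the review suggests renaming policy_learn_time to value_learn_time here.
    for learn_i in range(self.policy_learn_time):
        for step_i in range(value_updates):
            value_batch_inputs, labels = self._value_learn_prepare(
                step_i, learn_i, trajectories, value_updates)
            value_loss_refs.append(
                value_model.train_async(
                    input_ids=value_batch_inputs['input_ids'],
                    labels=labels,
                    attention_mask=value_batch_inputs['attention_mask'],
                    criterion=self.value_criterion,
                    micro_batch_size=self.critic_micro_bs))
    return value_loss_refs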

clip_reward_max: int = 5,
norm_rewards=True,
reward_scale: bool = False,
fine_grained_rm: bool = False,
@C1rN09 Jun 17, 2024

fine_grained_rm does not look implemented yet; how about removing it for now and adding a new class once it is implemented?

from .running_mean_std import RunningStates


class BaseRepeater:
@C1rN09 Jun 17, 2024

If this is intended as a base class, consider making it more abstract. The current implementation could then inherit from the base class under a new name, e.g. OutcomeKLGAERepeater (just a suggestion).
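A sketch of the more abstract split (the concrete class name is the reviewer's example):

from abc import ABC, abstractmethod


class BaseRepeater(ABC):
    """Turns raw rollouts into training trajectories."""

    @abstractmethod
    def process(self, trajectories, *args, **kwargs):
        """Annotate trajectories with rewards/advantages and return them."""


class OutcomeKLGAERepeater(BaseRepeater):
    """Current concrete behaviour: outcome reward + KL penalty + GAE."""

    def process(self, trajectories, *args, **kwargs):
        ...  # existing implementation moves here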

Comment on lines 47 to 51
policy_model: BaseModelServer,
value_model: BaseModelServer,
sft_model: BaseModelServer = None,
# only used for async reward model.infer_get() in _get_kl_rewards
env=None,

Putting the xxx_model arguments and env into __init__ might be more reasonable and would make the process interface more consistent. My understanding is that not every Repeater needs exactly these xxx_models.

if isinstance(v, torch.Tensor):
if not torch.equal(v, vother):
return False
elif isinstance(v, tuple): # tuple(torch.Tensor)
@C1rN09 Jun 17, 2024

Is there ever a case where v is a Tuple[torch.Tensor]? It seems v can only be torch.Tensor or None.
If such a case does exist, the implementation of to(self, device) below needs to be changed as well.

Comment on lines 24 to 29
if len(self.keys()) != len(other.keys()):
return False
for k, v in self.items():
if k not in other:
return False
vother = other[k]

If other's keys are a superset of self's keys, this check still passes; is that the intended behavior?
If not, consider:

if self.keys() != other.keys():
    return False

Comment on lines 59 to 60
if model_path == 'internlm/internlm2-chat-1_8b-sft':
return InternLM2ForCausalLM

Is there a particular reason for this hardcode?

Comment on lines 67 to 68
trainer_type = self.trainer_config.get('trainer_type',
'huggingface').lower()

Suggested change
trainer_type = self.trainer_config.get('trainer_type',
'huggingface').lower()
trainer_type = self.trainer_config.get('trainer_type',
ENGINE_HUGGINGFACE).lower()

def init_tokenizer_and_config(self, model_config):
super().init_tokenizer_and_config(self.model_config)

self.reward_token_id = self.tokenizer.pad_token_id
@C1rN09 Jun 17, 2024

Defaulting self.reward_token_id = self.tokenizer.pad_token_id is clever, but a one-line comment explaining the reason would help.


# Launch the task. For the first run, setting HF_ENDPOINT=https://hf-mirror.com is recommended to make dataset loading easier
HF_ENDPOINT=https://hf-mirror.com xtuner rlhf -c examples/rlhf/internlm2_chat_1_8b_ppo_ds_vllm_8gpu.py
```

Remember to add a link to the detailed documentation on Read the Docs.

"""Post process sequence: tokenization & truncation."""
message_data = message['data']
new_meaasage_data = []
if self.message_type == 'prompt':

Move the post-processing into the dataset class.

trajectories['rewards'] = rewards

# pretrain data
if self.pretrain_mes_iter is not None:

Move the pretrain-loss tokenization into the pretrain dataset.
