InstructGPT高效实践——【DeepSpeed( 四 )


定义好自定义后,还需要对其进行“注册”,具体可见下述代码块 。
# applications/DeepSpeed-Chat/training/utils/data/data_utils.pydef get_raw_dataset(dataset_name, output_path, seed, local_rank):if "Dahoas/rm-static" in dataset_name:return raw_datasets.DahoasRmstaticDataset(output_path, seed,local_rank, dataset_name)elif "Dahoas/full-hh-rlhf" in dataset_name:return raw_datasets.DahoasFullhhrlhfDataset(output_path, seed,local_rank, dataset_name)···"""将自定义的PromptRawDataset在此处进行注册届时在传参“--data_path”中赋值“custom”即可读取到相应的数据集"""elif "custom" in dataset_name:return raw_datasets.CustomDataset(output_path, seed,local_rank, dataset_name)else:raise RuntimeError(f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py.")
至此完成自定义数据集的设置 。理论上来说,只要实例函数能完全按照注释要求对原始数据进行处理,那么后续的数据流基本也无需再进行任何额外修改也能顺畅运行了 。
0.2.3.2 阶段数据集处理过程
UML时序图(10-12)
这部分处理得到的数据形式,基本接近于数据传入阶段模型前的最终形式,因此通过理解这部分的数据处理过程,可以直接了解到模型所需要的输入形式 。
# applications/DeepSpeed-Chat/training/utils/data/data_utils.pydef create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer,end_of_conversation_token, max_seq_len):"""将根据不同的阶段(train_phase)对数据集进行处理,主要是调用原先在PromptRawDataset类中定义的实例函数来实现 。"""prompt_dataset = []chosen_dataset = []reject_dataset = []if train_phase == 1:# 因为phase1只需要用到chosen数据,所以只取chosen进行处理for i, tmp_data in enumerate(current_dataset):# 获取chosen_sentence,即是将prompt和chosen拼接起来形成完整对话# 具体样例可参照“数据格式基本概念”中的样例chosen_sentence = raw_dataset.get_prompt_and_chosen(tmp_data)if chosen_sentence is not None:# 在对话末尾加入对话终止符chosen_sentence += end_of_conversation_token# 使用tokenizer处理chosen_sentence,采取截断truncationchosen_token = tokenizer(chosen_sentence,max_length=max_seq_len,padding="max_length",truncation=True,return_tensors="pt")# 去掉batch维度chosen_token["input_ids"] = chosen_token["input_ids"].squeeze(0)chosen_token["attention_mask"] = chosen_token["attention_mask"].squeeze(0)# 存储tokenize结果至列表chosen_datasetchosen_dataset.append(chosen_token)elif train_phase == 2:# phase2需要用到chosen_sentence和reject_sentence# 所以需要对两者都进行处理for i, tmp_data in enumerate(current_dataset):# 获取chosen_sentence,即是将prompt和chosen拼接起来形成完整对话# 具体样例可参照“数据格式基本概念”中的样例chosen_sentence = raw_dataset.get_prompt_and_chosen(tmp_data)# the accept response# 获取reject_sentence,即是将prompt和rejeced拼接起来形成完整对话# 具体样例可参照“数据格式基本概念”中的样例reject_sentence = raw_dataset.get_prompt_and_rejected(tmp_data)if chosen_sentence is not None and reject_sentence is not None:# 在对话末尾加入对话终止符chosen_sentence += end_of_conversation_token# the accept responsereject_sentence += end_of_conversation_token# 使用tokenizer处理,采取截断truncationchosen_token = tokenizer(chosen_sentence,max_length=max_seq_len,padding="max_length",truncation=True,return_tensors="pt")reject_token = tokenizer(reject_sentence,max_length=max_seq_len,padding="max_length",truncation=True,return_tensors="pt")chosen_token["input_ids"] = chosen_token["input_ids"]chosen_token["attention_mask"] = chosen_token["attention_mask"]# 存储tokenize结果至列表chosen_datasetchosen_dataset.append(chosen_token)reject_token["input_ids"] = reject_token["input_ids"]reject_token["attention_mask"] = reject_token["attention_mask"]# 存储tokenize结果至列表reject_datasetreject_dataset.append(reject_token)elif train_phase == 3:# phase3用到prompt,prompt将被用来生成经验数据for i, tmp_data in enumerate(current_dataset):# 直接获取prompt# 具体样例可参照“数据格式基本概念”中的样例prompt = raw_dataset.get_prompt(tmp_data)if prompt is not None:prompt_token = tokenizer(prompt, return_tensors="pt")prompt_token["input_ids"] = prompt_token["input_ids"]prompt_token["attention_mask"] = prompt_token["attention_mask"]for key_word in ["input_ids", "attention_mask"]:# 获取当前文本token的实际长度length = prompt_token[key_word].size()[-1]# phase3此处的max_seq_len其实是max_prompt_len,默认只有256if length > max_seq_len:# 如果当前文本token长度比max_prompt_len还长# 那么就截断文本前面的部分,保留后面max_prompt_len长度的部分文本# 然后将token进行flip(翻转/倒序),之后在data_collator中再将其flip回来y = prompt_token[key_word].squeeze(0)[length -(max_seq_len -1):].flip(0)else:# 将token进行flip(翻转/倒序),之后在data_collator中再将其flip回来y = prompt_token[key_word].squeeze(0).flip(0)prompt_token[key_word] = yprompt_dataset.append(prompt_token)# 返回PromptDataset实例,该实例相当于torch中的Dataset,可供DataLoader调用return PromptDataset(prompt_dataset, chosen_dataset, reject_dataset,tokenizer.pad_token_id, train_phase)