ICASSP 2023论文模型开源|语音分离Mossformer( 三 )


▏如何训练自有的语音分离模型?
第一步:训练您的模型
环境准备

ICASSP 2023论文模型开源|语音分离Mossformer

文章插图
网站官方提供的环境已经安装好了所有依赖,能够直接开始训练 。如果您要在自己的设备上训练,可以参考模型主页上的环境准备步骤操作 。环境准备完成后建议运行模型主页上推理示例代码验证模型可以正常工作 。
数据准备
魔搭社区上开放的模型使用约30小时2人混合语音作为训练数据 。混合语音是基于WSJ0数据集生成的,由于WSJ0的问题无法在这里分享 。我们在上提供了基于数据集生成的混合音频,以便您快速开始训练 。其中训练集包含约42小时语音,共13900条,大小约7G 。请访问官网页面了解数据集详情,链接在文章末尾 。
模型训练
以下列出的为训练示例代码,其中可以替换成您需要的路径 。
数据训练一遍为一个epoch,默认共训练120个epoch,需要约10天 。
import osfrom datasets import load_datasetfrom modelscope.metainfo import Trainersfrom modelscope.msdatasets import MsDatasetfrom modelscope.preprocessors.audio import AudioBrainPreprocessorfrom modelscope.trainers import build_trainerfrom modelscope.utils.audio.audio_utils import to_segmentwork_dir = './train_dir'if not os.path.exists(work_dir):os.makedirs(work_dir)train_dataset = MsDataset.load('Libri2Mix_8k', split='train').to_torch_dataset(preprocessors=[AudioBrainPreprocessor(takes='mix_wav:FILE', provides='mix_sig'),AudioBrainPreprocessor(takes='s1_wav:FILE', provides='s1_sig'),AudioBrainPreprocessor(takes='s2_wav:FILE', provides='s2_sig')],to_tensor=False)eval_dataset = MsDataset.load('Libri2Mix_8k', split='validation').to_torch_dataset(preprocessors=[AudioBrainPreprocessor(takes='mix_wav:FILE', provides='mix_sig'),AudioBrainPreprocessor(takes='s1_wav:FILE', provides='s1_sig'),AudioBrainPreprocessor(takes='s2_wav:FILE', provides='s2_sig')],to_tensor=False)kwargs = dict(model='damo/speech_mossformer_separation_temporal_8k',train_dataset=train_dataset,eval_dataset=eval_dataset,work_dir=work_dir)trainer = build_trainer(Trainers.speech_separation, default_args=kwargs)trainer.train()
第二步:评估你的模型
以下列出的为模型评估代码,其中必须是您训练时指定的路径 。程序会搜索路径下的最佳模型并自动加载 。
import osfrom datasets import load_datasetfrom modelscope.metainfo import Trainersfrom modelscope.msdatasets import MsDatasetfrom modelscope.preprocessors.audio import AudioBrainPreprocessorfrom modelscope.trainers import build_trainerfrom modelscope.utils.audio.audio_utils import to_segmentwork_dir = './train_dir'if not os.path.exists(work_dir):os.makedirs(work_dir)train_dataset = Noneeval_dataset = MsDataset.load('Libri2Mix_8k', split='test').to_torch_dataset(preprocessors=[AudioBrainPreprocessor(takes='mix_wav:FILE', provides='mix_sig'),AudioBrainPreprocessor(takes='s1_wav:FILE', provides='s1_sig'),AudioBrainPreprocessor(takes='s2_wav:FILE', provides='s2_sig')],to_tensor=False)