【Loss Functions】Keras Loss Function

PS: The code in the first half is all in PyTorch. A Keras version of uPIT-SiSNR is attached at the end.
The loss function is also a very important part of model training.
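In Keras, a custom loss is simply a callable of (y_true, y_pred) that returns a per-sample value and is passed to model.compile. A minimal sketch of the mechanics (hypothetical l1_loss and toy model, not from the original article):

import tensorflow as tf

def l1_loss(y_true, y_pred):
    # Mean absolute error per sample (L1), as used by several models below.
    return tf.reduce_mean(tf.abs(y_true - y_pred), axis=-1)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss=l1_loss)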
Common loss functions:
Loss function examples:
Speech separation
【】: uPIT-SiSNR
【 RNN】: SiSNR
【】: SiSNR
【Conv-TasNet】: SiSNR
Music separation
【】: L1
Speech denoising
【】: L1
【】: SiSNR or …
【】: L1
【】: L1
【DCCRN】: SiSNR
【】: wSNR
【DF-】: SNR
About uPIT Si-SNR
SpeechBrain's code 【//nnet/.py】:
SNR (Signal-to-Noise Ratio)
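The usual definition, written out here in LaTeX for reference, with $s$ the clean target and $\hat{s}$ the estimate:

\mathrm{SNR}(\hat{s}, s) = 10 \log_{10} \frac{\lVert s \rVert^2}{\lVert s - \hat{s} \rVert^2}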
Si-SNR (Scale-Invariant Signal-to-Noise Ratio)
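The variant implemented in the code below (one of several definitions in the literature; the notation matches the in-code comments) first zero-means both signals and then projects the estimate onto the target:

s_{\text{target}} = \frac{\langle \hat{s}, s \rangle \, s}{\lVert s \rVert^2}, \qquad
e_{\text{noise}} = \hat{s} - s_{\text{target}}, \qquad
\text{Si-SNR} = 10 \log_{10} \frac{\lVert s_{\text{target}} \rVert^2}{\lVert e_{\text{noise}} \rVert^2}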
See also the wording in the paper 【… scale-invariant signal-to-noise ratio and … for multi-… in noisy …】.

[Figure: definitions of Si-SNR in the literature]
As the figure shows, there is actually more than one definition of Si-SNR.
Here we take the code in SpeechBrain as an example; see also the SpeechBrain homepage.
【//nnet/.py】:
import torch


def cal_si_snr(source, estimate_source):
    """Calculate SI-SNR.

    Arguments:
    ---------
    source: [T, B, C]
        Where B is the batch size, T is the length of the sources, C is
        the number of sources. The ordering is made so that this loss is
        compatible with the class PitWrapper.
    estimate_source: [T, B, C]
        The estimated source.

    Example:
    ---------
    >>> import torch
    >>> x = torch.Tensor([[1, 0], [123, 45], [34, 5], [2312, 421]])
    >>> xhat = x[:, (1, 0)]
    >>> x = x.unsqueeze(-1).repeat(1, 1, 2)
    >>> xhat = xhat.unsqueeze(1).repeat(1, 2, 1)
    >>> si_snr = -cal_si_snr(x, xhat)
    >>> print(si_snr)
    tensor([[[ 25.2142, 144.1789],
             [130.9283,  25.2142]]])
    """
    EPS = 1e-8
    assert source.size() == estimate_source.size()
    device = estimate_source.device.type

    source_lengths = torch.tensor(
        [estimate_source.shape[0]] * estimate_source.shape[1], device=device
    )
    mask = get_mask(source, source_lengths)
    estimate_source *= mask

    num_samples = source_lengths.contiguous().reshape(1, -1, 1).float()  # [1, B, 1]
    mean_target = torch.sum(source, dim=0, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=0, keepdim=True) / num_samples
    zero_mean_target = source - mean_target
    zero_mean_estimate = estimate_source - mean_estimate
    # mask padding position along T
    zero_mean_target *= mask
    zero_mean_estimate *= mask

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    s_target = zero_mean_target  # [T, B, C]
    s_estimate = zero_mean_estimate  # [T, B, C]
    # s_target = <s', s>s / ||s||^2
    dot = torch.sum(s_estimate * s_target, dim=0, keepdim=True)  # [1, B, C]
    s_target_energy = torch.sum(s_target ** 2, dim=0, keepdim=True) + EPS  # [1, B, C]
    proj = dot * s_target / s_target_energy  # [T, B, C]
    # e_noise = s' - s_target
    e_noise = s_estimate - proj  # [T, B, C]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    si_snr_beforelog = torch.sum(proj ** 2, dim=0) / (
        torch.sum(e_noise ** 2, dim=0) + EPS
    )
    si_snr = 10 * torch.log10(si_snr_beforelog + EPS)  # [B, C]

    return -si_snr.unsqueeze(0)


def get_mask(source, source_lengths):
    """Create a binary mask that zeroes the padded positions along T.

    Arguments
    ---------
    source : [T, B, C]
    source_lengths : [B]

    Returns
    -------
    mask : [T, B, 1]

    Example:
    ---------
    >>> source = torch.randn(4, 3, 2)
    >>> source_lengths = torch.Tensor([2, 1, 4]).int()
    >>> mask = get_mask(source, source_lengths)
    >>> print(mask)
    tensor([[[1.],
             [1.],
             [1.]],

            [[1.],
             [0.],
             [1.]],

            [[0.],
             [0.],
             [1.]],

            [[0.],
             [0.],
             [1.]]])
    """
    T, B, _ = source.size()
    mask = source.new_ones((T, B, 1))
    for i in range(B):
        mask[source_lengths[i]:, i, :] = 0
    return mask
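For a quick smoke test (a hypothetical snippet, not part of SpeechBrain), the function can be called directly on random [T, B, C] tensors; it returns the negative Si-SNR, so lower is better:

T, B, C = 8000, 4, 2                            # 8000 samples, batch of 4, 2 sources
source = torch.randn(T, B, C)
estimate = source + 0.1 * torch.randn(T, B, C)  # mildly corrupted estimate
loss = cal_si_snr(source, estimate)             # [1, B, C], negative Si-SNR per source
print(loss.shape)                               # torch.Size([1, 4, 2])
print(loss.mean().item())                       # roughly -20 (dB) at this noise level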
Note that both source and estimate_source have their means subtracted here before the projection. Also, to avoid division-by-zero errors, EPS is added to the denominators.
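Because the estimate is projected onto the target, rescaling the estimate leaves the Si-SNR unchanged. A quick check (hypothetical snippet; .clone() is used because cal_si_snr masks its second argument in place):

torch.manual_seed(0)
s = torch.randn(8000, 2, 2)                       # [T, B, C]
s_hat = s + 0.3 * torch.randn_like(s)
loss_a = cal_si_snr(s, s_hat.clone())
loss_b = cal_si_snr(s, (5.0 * s_hat).clone())     # same estimate, 5x louder
print(torch.allclose(loss_a, loss_b, atol=1e-4))  # True: the scale does not matter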
PIT (Permutation Invariant Training)
PIT is a training method whose full name is Permutation Invariant Training. With this scheme the separation model can be trained end-to-end. The overall idea is quite intuitive: first assume an arbitrary ordering of the target speakers with respect to the outputs and train for a few steps to get an initial model. Then, on the next training step, compute an evaluation metric such as SI-SDR twice, once per assignment, (red → output 1, blue → output 2) and (blue → output 1, red → output 2), where red and blue denote the two target speakers; take the assignment with the smaller loss as the ordering, and keep training in that order. A minimal sketch follows below.
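As a concrete illustration, here is a minimal uPIT-style wrapper around cal_si_snr above (a hypothetical helper, not SpeechBrain's PitWrapper): it evaluates the loss under every permutation of the estimated sources and keeps the best permutation for each utterance in the batch.

from itertools import permutations

import torch

def upit_si_snr_loss(source, estimate_source):
    """Utterance-level PIT: try every speaker permutation and keep the
    lowest loss per utterance. source / estimate_source: [T, B, C]."""
    T, B, C = source.shape
    per_perm = []
    for perm in permutations(range(C)):
        # Pair estimated channel perm[c] with target channel c; clone() is a
        # defensive copy since cal_si_snr masks its second argument in place.
        neg_si_snr = cal_si_snr(source, estimate_source[:, :, list(perm)].clone())
        per_perm.append(neg_si_snr.mean(dim=2))  # average over sources -> [1, B]
    per_perm = torch.stack(per_perm, dim=0)      # [C!, 1, B]
    best, _ = per_perm.min(dim=0)                # best permutation per utterance
    return best.mean()                           # scalar training loss

SpeechBrain's PitWrapper implements the same idea generically for any pairwise loss; for larger numbers of sources, the factorial search over permutations is often replaced by a Hungarian-algorithm assignment.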