[Loss Functions] Keras Loss Function (Part 2)


uPIT (utterance-level PIT)
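uPIT amounts to evaluating every one of the output-to-target combinations described above over the whole utterance and keeping the assignment with the lowest loss.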
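The selection step can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the helper name best_permutation and the layout loss_mat[i][j] (loss between target i and prediction j, already averaged over the utterance) are hypothetical and chosen for this sketch.

```python
from itertools import permutations


def best_permutation(loss_mat):
    """Return (lowest mean loss, permutation) for a square pairwise loss matrix.

    Assumption for this sketch: loss_mat[i][j] is the utterance-level loss
    between target i and prediction j.
    """
    n = len(loss_mat)
    best_loss, best_perm = None, None
    for perm in permutations(range(n)):
        # Pair target i with prediction perm[i] and average the losses.
        loss = sum(loss_mat[i][perm[i]] for i in range(n)) / n
        if best_loss is None or loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm


# Two speakers whose outputs came out in swapped order:
# the off-diagonal pairing is much cheaper than the diagonal one.
loss_mat = [[0.9, 0.1],
            [0.2, 0.8]]
print(best_permutation(loss_mat))
```

With S sources there are S! candidate assignments, so this exhaustive search is only practical when the number of sources is small.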
Implementation:
```python
from itertools import permutations

import torch
import torch.nn as nn


class PitWrapper(nn.Module):
    """
    Permutation Invariant Wrapper to allow Permutation Invariant Training
    (PIT) with existing losses.

    Permutation invariance is calculated over the sources/classes axis which is
    assumed to be the rightmost dimension: predictions and targets tensors are
    assumed to have shape [batch, ..., channels, sources].

    Arguments
    ---------
    base_loss : function
        Base loss function, e.g. torch.nn.MSELoss. It is assumed that it takes
        two arguments:
        predictions and targets and no reduction is performed.
        (if a pytorch loss is used, the user must specify reduction="none").

    Returns
    -------
    pit_loss : torch.nn.Module
        Torch module supporting forward method for PIT.

    Example
    -------
    >>> pit_mse = PitWrapper(nn.MSELoss(reduction="none"))
    >>> targets = torch.rand((2, 32, 4))
    >>> p = (3, 0, 2, 1)
    >>> predictions = targets[..., p]
    >>> loss, opt_p = pit_mse(predictions, targets)
    >>> loss
    tensor([0., 0.])
    """

    def __init__(self, base_loss):
        super(PitWrapper, self).__init__()
        self.base_loss = base_loss

    def _fast_pit(self, loss_mat):
        """
        Arguments
        ---------
        loss_mat : torch.Tensor
            Tensor of shape [sources, sources] containing loss values for each
            possible permutation of predictions.

        Returns
        -------
        loss : torch.Tensor
            Permutation invariant loss for the current batch, tensor of shape [1]
        assigned_perm : tuple
            Indexes for optimal permutation of the input over sources which
            minimizes the loss.
        """
        loss = None
        assigned_perm = None
        for p in permutations(range(loss_mat.shape[0])):
            # Mean loss when target i is paired with prediction p[i].
            c_loss = loss_mat[range(loss_mat.shape[0]), p].mean()
            # return loss_mat[range(loss_mat.shape[0]), p][0], p
            # IMPORTANT: keep the permutation with the lowest mean loss.
            if loss is None or loss > c_loss:
                loss = c_loss
                assigned_perm = p
        return loss, assigned_perm

    def _opt_perm_loss(self, pred, target):
        """
        Arguments
        ---------
        pred : torch.Tensor
            Network prediction for the current example, tensor of
            shape [..., sources].
        target : torch.Tensor
            Target for the current example, tensor of shape [..., sources].

        Returns
        -------
        loss : torch.Tensor
            Permutation invariant loss for the current example, tensor of shape [1]
        assigned_perm : tuple
            Indexes for optimal permutation of the input over sources which
            minimizes the loss.
        """
        n_sources = pred.size(-1)

        # Tile both tensors to [..., sources, sources] so that entry (i, j)
        # compares target i with prediction j.
        pred = pred.unsqueeze(-2).repeat(
            *[1 for x in range(len(pred.shape) - 1)], n_sources, 1
        )
        target = target.unsqueeze(-1).repeat(
            1, *[1 for x in range(len(target.shape) - 1)], n_sources
        )

        loss_mat = self.base_loss(pred, target)
        assert (
            len(loss_mat.shape) >= 2
        ), "Base loss should not perform any reduction operation"
        # Average over every axis except the last two (the pairing axes).
        mean_over = [x for x in range(len(loss_mat.shape))]
        loss_mat = loss_mat.mean(dim=mean_over[:-2])

        return self._fast_pit(loss_mat)

    def reorder_tensor(self, tensor, p):
        """
        Arguments
        ---------
        tensor : torch.Tensor
            Tensor to reorder given the optimal permutation, of shape
            [batch, ..., sources].
        p : list of tuples
            List of optimal permutations, e.g. for batch=2 and n_sources=3
            [(0, 1, 2), (0, 2, 1)].

        Returns
        -------
        reordered : torch.Tensor
            Reordered tensor given permutation p.
        """
        reordered = torch.zeros_like(tensor, device=tensor.device)
        for b in range(tensor.shape[0]):
            reordered[b] = tensor[b][..., p[b]].clone()
        return reordered

    def forward(self, preds, targets):
        """
        Arguments
        ---------
        preds : torch.Tensor
            Network predictions tensor, of shape
            [batch, channels, ..., sources].
        targets : torch.Tensor
            Target tensor, of shape [batch, channels, ..., sources].

        Returns
        -------
        loss : torch.Tensor
            Permutation invariant loss for current examples, tensor of
            shape [batch]
        perms : list
            List of indexes for optimal permutation of the inputs over
            sources.
            e.g., [(0, 1, 2), (2, 1, 0)] for three sources and 2 examples
            per batch.
        """
        losses = []
        perms = []
        # Solve the assignment independently for each example in the batch.
        for pred, label in zip(preds, targets):
            loss, p = self._opt_perm_loss(pred, label)
            perms.append(p)
            losses.append(loss)
        loss = torch.stack(losses)
        return loss, perms
```
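To make the tensor manipulation in _opt_perm_loss, _fast_pit and reorder_tensor easier to follow, here is a rough plain-Python trace of the same three steps for a single example: build the pairwise loss matrix, search the permutations, then reorder the predictions. It is a sketch, not the class itself. The helper names mse, pairwise_loss_matrix and pit are hypothetical, the toy signals are made up, and the sources are stored as a list of sequences (first axis) rather than on the last tensor axis that PitWrapper expects.

```python
from itertools import permutations


def mse(a, b):
    """Mean squared error between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def pairwise_loss_matrix(targets, preds):
    """loss_mat[i][j] = loss between target source i and predicted source j."""
    return [[mse(t, p) for p in preds] for t in targets]


def pit(targets, preds):
    """Brute-force search for the permutation with the lowest mean loss."""
    loss_mat = pairwise_loss_matrix(targets, preds)
    n = len(targets)
    candidates = (
        (sum(loss_mat[i][perm[i]] for i in range(n)) / n, perm)
        for perm in permutations(range(n))
    )
    return min(candidates, key=lambda pair: pair[0])


# Hypothetical toy signals: three sources, four samples each.
targets = [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], [2.0, 2.0, 2.0, 2.0]]
# The "network" emits the same sources in the order (2, 0, 1).
preds = [targets[2], targets[0], targets[1]]

loss, perm = pit(targets, preds)
# Reorder the predictions so that position i lines up with target i.
reordered = [preds[j] for j in perm]
print(loss, perm, reordered == targets)  # 0.0 (1, 2, 0) True
```

The returned permutation reads as "target i is matched with prediction perm[i]", which is the same convention reorder_tensor relies on when it indexes the source axis with p[b].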
In the code above, permutations is the class from itertools that enumerates all permutations: