2018-04-03 10:24:57 +00:00
|
|
|
import torch
|
2018-03-22 21:35:02 +00:00
|
|
|
from torch.nn import functional
|
2018-03-25 02:22:45 +00:00
|
|
|
from torch import nn
|
2018-07-13 12:50:55 +00:00
|
|
|
from utils.generic_utils import sequence_mask
|
2018-03-22 21:06:54 +00:00
|
|
|
|
|
|
|
|
2018-07-13 12:50:55 +00:00
|
|
|
class L1LossMasked(nn.Module):
    """L1 (absolute error) loss averaged over the unpadded steps only."""

    def __init__(self):
        super(L1LossMasked, self).__init__()

    def forward(self, input, target, length):
        """Compute the length-masked L1 loss.

        Args:
            input: predicted values, indexed here as (batch, max_len, dim).
            target: reference values with the same shape as ``input``.
            length: tensor of size (batch,) holding the number of valid
                steps of each sequence in the batch.

        Returns:
            loss: scalar tensor; the sum of the masked element-wise losses
                divided by ``length.sum() * dim``.
        """
        # Element-wise loss with no reduction: (batch, max_len, dim).
        # ``reduction="none"`` replaces the deprecated
        # ``size_average=False, reduce=False`` pair; no flattening is needed
        # because the loss is computed independently per element.
        losses = functional.l1_loss(input, target, reduction="none")

        # mask: (batch, max_len, 1), broadcast over the feature axis.
        # NOTE(review): assumes sequence_mask returns a (batch, max_len) mask
        # that is true on valid steps -- confirm in utils.generic_utils.
        mask = sequence_mask(
            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
        losses = losses * mask.float()

        # Normalize by the number of valid elements (valid steps * dim), so
        # padding does not dilute the average.
        valid_elements = length.float().sum() * float(target.shape[2])
        return losses.sum() / valid_elements
|
|
|
|
|
|
|
|
|
|
|
|
class MSELossMasked(nn.Module):
    """Mean squared error loss averaged over the unpadded steps only."""

    def __init__(self):
        super(MSELossMasked, self).__init__()

    def forward(self, input, target, length):
        """Compute the length-masked MSE loss.

        Args:
            input: predicted values, indexed here as (batch, max_len, dim).
            target: reference values with the same shape as ``input``.
            length: tensor of size (batch,) holding the number of valid
                steps of each sequence in the batch.

        Returns:
            loss: scalar tensor; the sum of the masked element-wise squared
                errors divided by ``length.sum() * dim``.
        """
        # Element-wise squared error with no reduction:
        # (batch, max_len, dim). ``reduction="none"`` replaces the deprecated
        # ``size_average=False, reduce=False`` pair; no flattening is needed
        # because the loss is computed independently per element.
        losses = functional.mse_loss(input, target, reduction="none")

        # mask: (batch, max_len, 1), broadcast over the feature axis.
        # NOTE(review): assumes sequence_mask returns a (batch, max_len) mask
        # that is true on valid steps -- confirm in utils.generic_utils.
        mask = sequence_mask(
            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
        losses = losses * mask.float()

        # Normalize by the number of valid elements (valid steps * dim), so
        # padding does not dilute the average.
        valid_elements = length.float().sum() * float(target.shape[2])
        return losses.sum() / valid_elements
|