TTS/layers/losses.py

56 lines
2.2 KiB
Python
Raw Normal View History

2018-03-25 02:22:45 +00:00
from torch import nn
2019-07-19 06:46:23 +00:00
from torch.nn import functional
2018-07-13 12:50:55 +00:00
from utils.generic_utils import sequence_mask
2018-03-22 21:06:54 +00:00
2018-07-13 12:50:55 +00:00
class L1LossMasked(nn.Module):
    """L1 loss computed only over the valid (unpadded) time steps of each sequence."""

    def forward(self, x, target, length):
        """Compute the length-masked L1 loss.

        Args:
            x: predicted values, expected shape (batch, max_len, dim).
            target: ground-truth values compared element-wise with ``x``
                (expected to have the same shape as ``x``). No gradient is
                propagated into ``target``.
            length: tensor of shape (batch,) holding the valid length of
                each sequence in the batch; passed to ``sequence_mask``.
        Returns:
            loss: scalar tensor equal to the summed absolute error over the
            unmasked elements divided by the number of unmasked elements.
            Elements are counted across every ``dim`` feature. If the mask
            is all zeros the division is 0 / 0 and the result is NaN.
        """
        # Detach rather than assigning ``target.requires_grad = False``:
        # that assignment raises on non-leaf tensors and silently changes
        # the caller's tensor.
        target = target.detach()
        # mask: (batch, max_len, 1), cast to float so it can scale x/target.
        mask = sequence_mask(
            sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
        # Broadcast to x's shape so mask.sum() counts every valid feature
        # value, not just every valid time step.
        mask = mask.expand_as(x)
        loss = functional.l1_loss(x * mask, target * mask, reduction="sum")
        return loss / mask.sum()


class MSELossMasked(nn.Module):
    """MSE loss computed only over the valid (unpadded) time steps of each sequence."""

    def forward(self, x, target, length):
        """Compute the length-masked mean squared error.

        Args:
            x: predicted values, expected shape (batch, max_len, dim).
            target: ground-truth values compared element-wise with ``x``
                (expected to have the same shape as ``x``). No gradient is
                propagated into ``target``.
            length: tensor of shape (batch,) holding the valid length of
                each sequence in the batch; passed to ``sequence_mask``.
        Returns:
            loss: scalar tensor equal to the summed squared error over the
            unmasked elements divided by the number of unmasked elements.
            Elements are counted across every ``dim`` feature. If the mask
            is all zeros the division is 0 / 0 and the result is NaN.
        """
        # Detach rather than assigning ``target.requires_grad = False``:
        # that assignment raises on non-leaf tensors and silently changes
        # the caller's tensor.
        target = target.detach()
        # mask: (batch, max_len, 1), cast to float so it can scale x/target.
        mask = sequence_mask(
            sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
        # Broadcast to x's shape so mask.sum() counts every valid feature
        # value, not just every valid time step.
        mask = mask.expand_as(x)
        loss = functional.mse_loss(x * mask, target * mask, reduction="sum")
        return loss / mask.sum()