diff --git a/layers/losses.py b/layers/losses.py index 9e467ef8..f2b6fcb0 100644 --- a/layers/losses.py +++ b/layers/losses.py @@ -44,10 +44,11 @@ class L1LossMasked(nn.Module): # target_flat: (batch * max_len, dim) target_flat = target.view(-1, target.shape[-1]) # losses_flat: (batch * max_len, dim) - losses_flat = functional.l1_loss(input, target, size_average=False, + losses_flat = functional.l1_loss(input, target_flat, size_average=False, reduce=False) # losses: (batch, max_len, dim) losses = losses_flat.view(*target.size()) + # mask: (batch, max_len, 1) mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)