pass sequence mask to the same device as the input

pull/10/head
erogol 2020-05-18 11:35:19 +02:00
parent 8805370645
commit 523fa5dfd2
1 changed file with 1 addition and 1 deletion

@@ -99,7 +99,7 @@ def sequence_mask(sequence_length, max_len=None):
     seq_range = torch.arange(0, max_len).long()
     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
     if sequence_length.is_cuda:
-        seq_range_expand = seq_range_expand.cuda()
+        seq_range_expand = seq_range_expand.to(sequence_length.device)
     seq_length_expand = (
         sequence_length.unsqueeze(1).expand_as(seq_range_expand))
     # B x T_max
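
The fix replaces the hard-coded .cuda() call with .to(sequence_length.device), so the expanded index range lands on whatever device the input lengths tensor already occupies instead of the default CUDA device. Below is a minimal sketch of the patched function for context; the max_len default and the final comparison are not shown in the hunk and are filled in here as assumptions based on the usual PyTorch idiom:

import torch

def sequence_mask(sequence_length, max_len=None):
    # Build a boolean mask of shape (B, T_max) that is True for valid time steps.
    if max_len is None:
        # Assumption: max_len defaults to the longest sequence in the batch.
        max_len = sequence_length.max().item()
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    if sequence_length.is_cuda:
        # Patched line: follow the device of the input lengths instead of
        # calling .cuda(), which always targets the current default GPU.
        seq_range_expand = seq_range_expand.to(sequence_length.device)
    seq_length_expand = (
        sequence_length.unsqueeze(1).expand_as(seq_range_expand))
    # B x T_max
    return seq_range_expand < seq_length_expand

For example, sequence_mask(torch.tensor([3, 5, 2])) yields a (3, 5) boolean mask whose first row is True for the first three positions only.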