mirror of https://github.com/coqui-ai/TTS.git
pass sequence mask to the same device as the input
parent
8805370645
commit
523fa5dfd2
|
@ -99,7 +99,7 @@ def sequence_mask(sequence_length, max_len=None):
|
|||
seq_range = torch.arange(0, max_len).long()
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
if sequence_length.is_cuda:
|
||||
seq_range_expand = seq_range_expand.cuda()
|
||||
seq_range_expand = seq_range_expand.to(sequence_length.device)
|
||||
seq_length_expand = (
|
||||
sequence_length.unsqueeze(1).expand_as(seq_range_expand))
|
||||
# B x T_max
|
||||
|
|
Loading…
Reference in New Issue