mirror of https://github.com/coqui-ai/TTS.git
change the bitwise for masking and small fixes
parent c637aa04a2
commit 72ad58d893
@@ -103,7 +103,6 @@ class MyDataset(Dataset):
         if self.enable_eos_bos:
             phonemes = pad_with_eos_bos(phonemes)
         phonemes = np.asarray(phonemes, dtype=np.int32)
-
         return phonemes

     def load_data(self, idx):
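For context, pad_with_eos_bos wraps the phoneme id sequence with begin- and end-of-sentence markers before the int32 conversion above. A minimal sketch of that behavior, assuming illustrative _BOS_ID/_EOS_ID values rather than the repo's actual vocabulary ids:

import numpy as np

_BOS_ID = 1  # hypothetical begin-of-sentence id, for illustration only
_EOS_ID = 2  # hypothetical end-of-sentence id, for illustration only

def pad_with_eos_bos(phoneme_sequence):
    # Prepend BOS and append EOS around the raw phoneme id sequence.
    return [_BOS_ID] + list(phoneme_sequence) + [_EOS_ID]

phonemes = pad_with_eos_bos([17, 42, 9])
phonemes = np.asarray(phonemes, dtype=np.int32)  # array([ 1, 17, 42,  9,  2], dtype=int32)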
@@ -234,7 +234,7 @@ class Attention(nn.Module):
             query, processed_inputs)
         # apply masking
         if mask is not None:
-            attention.data.masked_fill_(torch.bitwise_not(mask), self._mask_value)
+            attention.data.masked_fill_(~mask, self._mask_value)
         # apply windowing - only in eval mode
         if not self.training and self.windowing:
             attention = self.apply_windowing(attention, inputs)
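This hunk is the commit's headline change: torch.bitwise_not(mask) becomes the ~ operator. For boolean tensors the two are interchangeable, and ~ avoids depending on the newer torch.bitwise_not function. A quick equivalence check, using float("-inf") in place of the module's self._mask_value (an assumption; the actual value lives in the Attention module):

import torch

mask = torch.tensor([True, True, False, False])  # True = real step, False = padding
attention = torch.tensor([0.4, 0.3, 0.2, 0.1])

# Both forms compute the boolean complement of the mask.
assert torch.equal(~mask, torch.bitwise_not(mask))

# masked_fill_ overwrites the padded positions in place.
attention.data.masked_fill_(~mask, float("-inf"))
print(attention)  # tensor([0.4000, 0.3000, -inf, -inf])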
@@ -315,7 +315,7 @@ class Decoder(nn.Module):
         # learn init values instead of zero init.
         self.stopnet = StopNet(256 + memory_dim * self.r_init)

-    def _set_r(self, new_r):
+    def set_r(self, new_r):
         self.r = new_r

     def _reshape_memory(self, memory):
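The other "small fix" renames _set_r to set_r, promoting the reduction-factor setter to the decoder's public interface. r is the number of frames the decoder predicts per step, and it is typically lowered from the training loop as training progresses. A hedged usage sketch; the schedule values and the model.decoder attribute path are assumptions for illustration, not the repo's actual config:

# (start_epoch, r) pairs: drop the reduction factor as training progresses.
GRADUAL_TRAINING = [(0, 7), (10, 5), (50, 3), (130, 2), (290, 1)]

def adjust_reduction_factor(model, epoch):
    new_r = GRADUAL_TRAINING[0][1]
    for start_epoch, r in GRADUAL_TRAINING:
        if epoch >= start_epoch:
            new_r = r
    model.decoder.set_r(new_r)  # public setter after this commit's rename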