change the bitwise for masking and small fixes

pull/10/head
Eren Golge 2019-08-19 16:24:28 +02:00
parent c637aa04a2
commit 72ad58d893
3 changed files with 2 additions and 3 deletions

View File

@ -103,7 +103,6 @@ class MyDataset(Dataset):
if self.enable_eos_bos: if self.enable_eos_bos:
phonemes = pad_with_eos_bos(phonemes) phonemes = pad_with_eos_bos(phonemes)
phonemes = np.asarray(phonemes, dtype=np.int32) phonemes = np.asarray(phonemes, dtype=np.int32)
return phonemes return phonemes
def load_data(self, idx): def load_data(self, idx):

View File

@ -234,7 +234,7 @@ class Attention(nn.Module):
query, processed_inputs) query, processed_inputs)
# apply masking # apply masking
if mask is not None: if mask is not None:
attention.data.masked_fill_(torch.bitwise_not(mask), self._mask_value) attention.data.masked_fill_(~mask, self._mask_value)
# apply windowing - only in eval mode # apply windowing - only in eval mode
if not self.training and self.windowing: if not self.training and self.windowing:
attention = self.apply_windowing(attention, inputs) attention = self.apply_windowing(attention, inputs)

View File

@ -315,7 +315,7 @@ class Decoder(nn.Module):
# learn init values instead of zero init. # learn init values instead of zero init.
self.stopnet = StopNet(256 + memory_dim * self.r_init) self.stopnet = StopNet(256 + memory_dim * self.r_init)
def _set_r(self, new_r): def set_r(self, new_r):
self.r = new_r self.r = new_r
def _reshape_memory(self, memory): def _reshape_memory(self, memory):