# Review notes for the patch below (leading text before the first "diff --git"
# header; the patch content itself is unchanged).
#
# Scope: a unified git diff touching layers/tacotron.py and train.py. The
# original line breaks of the patch have been collapsed onto a single line.
#
# layers/tacotron.py, hunk @@ -339,11 +339,10 @@ (Decoder._reshape_memory):
#   Removes the inner "if memory is not None" guard. The method reads
#   memory.shape[0], and, when the last dimension equals self.memory_dim,
#   makes the tensor contiguous and views it as
#   (B, memory.size(1) // self.r, -1); it then transposes dims 0 and 1 and
#   returns the result. After this change the method requires a non-None
#   argument.
#
# layers/tacotron.py, hunk @@ -370,7 +369,8 @@ (a Decoder method; its name is
#   not visible in this hunk):
#   The None check moves to the call site, so _reshape_memory is only called
#   when memory is not None.
#   NOTE(review): the next context line, "T_decoder = memory.size(0)", is
#   still unconditional, so a None memory would fail there. Whether later code
#   handles that case cannot be seen from this hunk -- confirm.
#
# layers/tacotron.py, hunk @@ -461,4 +461,4 @@ (StopNet):
#   Only adds the missing newline at end of file.
#
# train.py, hunk @@ -401,6 +401,16 @@ (main, restore_path branch):
#   Adds partial initialization: keeps the checkpoint['model'] entries whose
#   key is present in model_dict, updates model_dict with them, and loads
#   model_dict into the model.
#   NOTE(review): no assignment to model_dict is visible in this hunk --
#   verify it is defined (presumably model.state_dict()) before this code.
#   NOTE(review): the pre-existing "model.load_state_dict(checkpoint['model'])"
#   call is kept as a context line ahead of the new code, so the full load is
#   still attempted first -- confirm this is intended.
#   NOTE(review): the filter compares key names only (k in model_dict); it
#   does not compare tensor shapes, although the added comment speaks of a
#   mismatch between new and old layers being skipped.
diff --git a/layers/tacotron.py b/layers/tacotron.py index 3e82486f..15cf4b28 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -339,11 +339,10 @@ class Decoder(nn.Module): def _reshape_memory(self, memory): B = memory.shape[0] - if memory is not None: - # Grouping multiple frames if necessary - if memory.size(-1) == self.memory_dim: - memory = memory.contiguous() - memory = memory.view(B, memory.size(1) // self.r, -1) + # Grouping multiple frames if necessary + if memory.size(-1) == self.memory_dim: + memory = memory.contiguous() + memory = memory.view(B, memory.size(1) // self.r, -1) # Time first (T_decoder, B, memory_dim) memory = memory.transpose(0, 1) return memory @@ -370,7 +369,8 @@ class Decoder(nn.Module): T = inputs.size(1) # Run greedy decoding if memory is None greedy = not self.training - memory = self._reshape_memory(memory) + if memory is not None: + memory = self._reshape_memory(memory) T_decoder = memory.size(0) # go frame as zeros matrix initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_() @@ -461,4 +461,4 @@ class StopNet(nn.Module): outputs = self.dropout(inputs) outputs = self.linear(outputs) outputs = self.sigmoid(outputs) - return outputs \ No newline at end of file + return outputs diff --git a/train.py b/train.py index 54d6140f..8fe07ded 100644 --- a/train.py +++ b/train.py @@ -401,6 +401,16 @@ def main(args): if args.restore_path: checkpoint = torch.load(args.restore_path) model.load_state_dict(checkpoint['model']) + # Partial initialization: if there is a mismatch with new and old layer, it is skipped. + # 1. filter out unnecessary keys + pretrained_dict = { + k: v + for k, v in checkpoint['model'].items() if k in model_dict + } + # 2. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + # 3. load the new state dict + model.load_state_dict(model_dict) if use_cuda: model = model.cuda() criterion.cuda()