fix duration predictor in AlignTTS

pull/373/head
Eren Gölge 2021-03-16 17:06:46 +01:00
parent c2d29e5cd4
commit 07269e639b
1 changed files with 5 additions and 5 deletions

View File

@ -2,8 +2,9 @@ import torch
import math
from torch import nn
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.layers.feed_forward.encoder import Encoder, PositionalEncoding
from TTS.tts.layers.align_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
from TTS.tts.layers.align_tts.mdn import MDNBlock
@ -11,6 +12,7 @@ from TTS.tts.layers.align_tts.mdn import MDNBlock
class AlignTTS(nn.Module):
"""Speedy Speech model with Monotonic Alignment Search
https://arxiv.org/abs/2008.03802
@ -75,11 +77,9 @@ class AlignTTS(nn.Module):
self.pos_encoder = PositionalEncoding(hidden_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels +
c_in_channels)
self.duration_predictor = DurationPredictor(num_chars, hidden_channels, hidden_channels_ffn=1024, num_heads=2)
self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
# self.wn_spec_encoder = WNSpecEncoder(out_channels, hidden_channels, c_in_channels=c_in_channels)
self.mdn_block = MDNBlock(hidden_channels, 2*out_channels)
if num_speakers > 1 and not external_c: