mirror of https://github.com/coqui-ai/TTS.git
fix duration predictor in AlignTTS
parent
c2d29e5cd4
commit
07269e639b
|
@ -2,8 +2,9 @@ import torch
|
|||
import math
|
||||
from torch import nn
|
||||
from TTS.tts.layers.feed_forward.decoder import Decoder
|
||||
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
|
||||
from TTS.tts.layers.feed_forward.encoder import Encoder, PositionalEncoding
|
||||
from TTS.tts.layers.align_tts.duration_predictor import DurationPredictor
|
||||
from TTS.tts.layers.feed_forward.encoder import Encoder
|
||||
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path
|
||||
from TTS.tts.layers.align_tts.mdn import MDNBlock
|
||||
|
@ -11,6 +12,7 @@ from TTS.tts.layers.align_tts.mdn import MDNBlock
|
|||
|
||||
|
||||
|
||||
|
||||
class AlignTTS(nn.Module):
|
||||
"""Speedy Speech model with Monotonic Alignment Search
|
||||
https://arxiv.org/abs/2008.03802
|
||||
|
@ -75,11 +77,9 @@ class AlignTTS(nn.Module):
|
|||
self.pos_encoder = PositionalEncoding(hidden_channels)
|
||||
self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
|
||||
decoder_params)
|
||||
self.duration_predictor = DurationPredictor(hidden_channels +
|
||||
c_in_channels)
|
||||
self.duration_predictor = DurationPredictor(num_chars, hidden_channels, hidden_channels_ffn=1024, num_heads=2)
|
||||
|
||||
self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
|
||||
# self.wn_spec_encoder = WNSpecEncoder(out_channels, hidden_channels, c_in_channels=c_in_channels)
|
||||
self.mdn_block = MDNBlock(hidden_channels, 2*out_channels)
|
||||
|
||||
if num_speakers > 1 and not external_c:
|
||||
|
|
Loading…
Reference in New Issue