mirror of https://github.com/coqui-ai/TTS.git
refactor(tacotron): remove duplicate function
parent
0f69d31f70
commit
fa844e0fb7
|
@ -3,6 +3,8 @@ from torch import nn
|
|||
from torch.distributions.multivariate_normal import MultivariateNormal as MVN
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.layers.tacotron.common_layers import calculate_post_conv_height
|
||||
|
||||
|
||||
class CapacitronVAE(nn.Module):
|
||||
"""Effective Use of Variational Embedding Capacity for prosody transfer.
|
||||
|
@ -97,7 +99,7 @@ class ReferenceEncoder(nn.Module):
|
|||
self.training = False
|
||||
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
|
||||
|
||||
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
|
||||
post_conv_height = calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
|
||||
self.recurrence = nn.LSTM(
|
||||
input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False
|
||||
)
|
||||
|
@ -155,13 +157,6 @@ class ReferenceEncoder(nn.Module):
|
|||
|
||||
return last_output.to(inputs.device) # [B, 128]
|
||||
|
||||
@staticmethod
|
||||
def calculate_post_conv_height(height: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int:
    """Height of spec after n convolutions with fixed kernel/stride/pad.

    Each of the ``n_convs`` iterations replaces ``height`` with
    ``(height - kernel_size + 2 * pad) // stride + 1``.

    Args:
        height: Starting height fed to the first convolution.
        kernel_size: Kernel size used in the formula for every iteration.
        stride: Stride used in the formula for every iteration.
        pad: Padding term; it is counted twice per iteration.
        n_convs: Number of times the formula is applied.

    Returns:
        The height after ``n_convs`` applications of the formula. When
        ``n_convs`` is 0 the input ``height`` is returned unchanged.
    """
    for _ in range(n_convs):
        # Floor division: a trailing partial window produces no output row.
        height = (height - kernel_size + 2 * pad) // stride + 1
    return height
|
||||
|
||||
|
||||
class TextSummary(nn.Module):
|
||||
def __init__(self, embedding_dim, encoder_output_dim):
|
||||
|
|
|
@ -3,6 +3,13 @@ from torch import nn
|
|||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def calculate_post_conv_height(height: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int:
    """Return the spec height left after a stack of identical convolutions.

    The update ``(h - kernel_size + 2 * pad) // stride + 1`` is applied
    ``n_convs`` times in a row, starting from ``height``; the same kernel
    size, stride and padding are used for every step.
    """
    current_height = height
    for _layer in range(n_convs):
        # Extent covered by the kernel positions, before striding.
        padded_span = current_height - kernel_size + 2 * pad
        current_height = padded_span // stride + 1
    return current_height
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
"""Linear layer with a specific initialization.
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.tacotron.common_layers import calculate_post_conv_height
|
||||
|
||||
|
||||
class GST(nn.Module):
|
||||
"""Global Style Token Module for factorizing prosody in speech.
|
||||
|
@ -44,7 +46,7 @@ class ReferenceEncoder(nn.Module):
|
|||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
|
||||
|
||||
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 1, num_layers)
|
||||
post_conv_height = calculate_post_conv_height(num_mel, 3, 2, 1, num_layers)
|
||||
self.recurrence = nn.GRU(
|
||||
input_size=filters[-1] * post_conv_height, hidden_size=embedding_dim // 2, batch_first=True
|
||||
)
|
||||
|
@ -71,13 +73,6 @@ class ReferenceEncoder(nn.Module):
|
|||
|
||||
return out.squeeze(0)
|
||||
|
||||
@staticmethod
|
||||
def calculate_post_conv_height(height: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int:
    """Height of spec after n convolutions with fixed kernel/stride/pad.

    Each of the ``n_convs`` iterations replaces ``height`` with
    ``(height - kernel_size + 2 * pad) // stride + 1``.

    Args:
        height: Starting height fed to the first convolution.
        kernel_size: Kernel size used in the formula for every iteration.
        stride: Stride used in the formula for every iteration.
        pad: Padding term; it is counted twice per iteration.
        n_convs: Number of times the formula is applied.

    Returns:
        The height after ``n_convs`` applications of the formula. When
        ``n_convs`` is 0 the input ``height`` is returned unchanged.
    """
    for _ in range(n_convs):
        # Floor division: a trailing partial window produces no output row.
        height = (height - kernel_size + 2 * pad) // stride + 1
    return height
|
||||
|
||||
|
||||
class StyleTokenLayer(nn.Module):
|
||||
"""NN Module attending to style tokens based on prosody encodings."""
|
||||
|
|
Loading…
Reference in New Issue