refactor(tacotron): remove duplicate function

pull/4115/head^2
Enno Hermann 2024-11-22 21:35:26 +01:00
parent 0f69d31f70
commit fa844e0fb7
3 changed files with 13 additions and 16 deletions

View File

@ -3,6 +3,8 @@ from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal as MVN
from torch.nn import functional as F
from TTS.tts.layers.tacotron.common_layers import calculate_post_conv_height
class CapacitronVAE(nn.Module):
"""Effective Use of Variational Embedding Capacity for prosody transfer.
@ -97,7 +99,7 @@ class ReferenceEncoder(nn.Module):
self.training = False
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
post_conv_height = calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
self.recurrence = nn.LSTM(
input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False
)
@ -155,13 +157,6 @@ class ReferenceEncoder(nn.Module):
return last_output.to(inputs.device) # [B, 128]
@staticmethod
def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
"""Height of spec after n convolutions with fixed kernel/stride/pad."""
for _ in range(n_convs):
height = (height - kernel_size + 2 * pad) // stride + 1
return height
class TextSummary(nn.Module):
def __init__(self, embedding_dim, encoder_output_dim):

View File

@ -3,6 +3,13 @@ from torch import nn
from torch.nn import functional as F
def calculate_post_conv_height(height: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int:
"""Height of spec after n convolutions with fixed kernel/stride/pad."""
for _ in range(n_convs):
height = (height - kernel_size + 2 * pad) // stride + 1
return height
class Linear(nn.Module):
"""Linear layer with a specific initialization.

View File

@ -2,6 +2,8 @@ import torch
import torch.nn.functional as F
from torch import nn
from TTS.tts.layers.tacotron.common_layers import calculate_post_conv_height
class GST(nn.Module):
"""Global Style Token Module for factorizing prosody in speech.
@ -44,7 +46,7 @@ class ReferenceEncoder(nn.Module):
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 1, num_layers)
post_conv_height = calculate_post_conv_height(num_mel, 3, 2, 1, num_layers)
self.recurrence = nn.GRU(
input_size=filters[-1] * post_conv_height, hidden_size=embedding_dim // 2, batch_first=True
)
@ -71,13 +73,6 @@ class ReferenceEncoder(nn.Module):
return out.squeeze(0)
@staticmethod
def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
"""Height of spec after n convolutions with fixed kernel/stride/pad."""
for _ in range(n_convs):
height = (height - kernel_size + 2 * pad) // stride + 1
return height
class StyleTokenLayer(nn.Module):
"""NN Module attending to style tokens based on prosody encodings."""