mirror of https://github.com/coqui-ai/TTS.git

commit c03768bb53 ("Make style")
parent a1c431e6a9

@@ -107,7 +107,7 @@ class FastSpeechConfig(BaseTTSConfig):
     base_model: str = "forward_tts"

     # model specific params
-    model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=False)
+    model_args: ForwardTTSArgs = field(default_factory=lambda: ForwardTTSArgs(use_pitch=False))

     # multi-speaker settings
     num_speakers: int = 0

@@ -123,7 +123,7 @@ class Fastspeech2Config(BaseTTSConfig):
     base_model: str = "forward_tts"

     # model specific params
-    model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=True, use_energy=True)
+    model_args: ForwardTTSArgs = field(default_factory=lambda: ForwardTTSArgs(use_pitch=True, use_energy=True))

     # multi-speaker settings
     num_speakers: int = 0

@@ -103,26 +103,28 @@ class SpeedySpeechConfig(BaseTTSConfig):
     base_model: str = "forward_tts"

     # set model args as SpeedySpeech
-    model_args: ForwardTTSArgs = ForwardTTSArgs(
-        use_pitch=False,
-        encoder_type="residual_conv_bn",
-        encoder_params={
-            "kernel_size": 4,
-            "dilations": 4 * [1, 2, 4] + [1],
-            "num_conv_blocks": 2,
-            "num_res_blocks": 13,
-        },
-        decoder_type="residual_conv_bn",
-        decoder_params={
-            "kernel_size": 4,
-            "dilations": 4 * [1, 2, 4, 8] + [1],
-            "num_conv_blocks": 2,
-            "num_res_blocks": 17,
-        },
-        out_channels=80,
-        hidden_channels=128,
-        positional_encoding=True,
-        detach_duration_predictor=True,
+    model_args: ForwardTTSArgs = field(
+        default_factory=lambda: ForwardTTSArgs(
+            use_pitch=False,
+            encoder_type="residual_conv_bn",
+            encoder_params={
+                "kernel_size": 4,
+                "dilations": 4 * [1, 2, 4] + [1],
+                "num_conv_blocks": 2,
+                "num_res_blocks": 13,
+            },
+            decoder_type="residual_conv_bn",
+            decoder_params={
+                "kernel_size": 4,
+                "dilations": 4 * [1, 2, 4, 8] + [1],
+                "num_conv_blocks": 2,
+                "num_res_blocks": 17,
+            },
+            out_channels=80,
+            hidden_channels=128,
+            positional_encoding=True,
+            detach_duration_predictor=True,
+        )
     )

     # multi-speaker settings

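The three config hunks above all make the same fix: a ForwardTTSArgs instance used directly as a class-level dataclass default is replaced with field(default_factory=...). A minimal sketch of why this matters, using hypothetical Args and Config classes as stand-ins for the real TTS classes:

from dataclasses import dataclass, field


@dataclass
class Args:  # hypothetical stand-in for ForwardTTSArgs
    hidden_channels: int = 128


@dataclass
class Config:  # hypothetical stand-in for the config classes above
    # A bare default like `args: Args = Args()` is a single object shared by every
    # Config instance, and Python 3.11+ rejects it outright with
    # "mutable default ... is not allowed: use default_factory".
    args: Args = field(default_factory=lambda: Args(hidden_channels=256))


a, b = Config(), Config()
a.args.hidden_channels = 64
assert b.args.hidden_channels == 256  # each instance gets its own Args

The lambda keeps the per-model overrides (use_pitch, encoder/decoder params, and so on) while guaranteeing a fresh ForwardTTSArgs for every config instance.
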
@@ -165,7 +165,7 @@ class BCELossMasked(nn.Module):

     def __init__(self, pos_weight: float = None):
         super().__init__()
-        self.pos_weight = nn.Parameter(torch.tensor([pos_weight]), requires_grad=False)
+        self.register_buffer("pos_weight", torch.tensor([pos_weight]))

     def forward(self, x, target, length):
         """

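The __init__ hunk above swaps a frozen nn.Parameter for a registered buffer. A small sketch of the behavioural difference, with a hypothetical WeightedLoss module standing in for BCELossMasked: a buffer is still checkpointed in state_dict and follows .to()/.cuda() with the module, but it never appears in parameters(), so an optimizer cannot pick it up.

import torch
from torch import nn


class WeightedLoss(nn.Module):  # hypothetical stand-in for BCELossMasked
    def __init__(self, pos_weight: float = 5.0):
        super().__init__()
        # buffer: saved in state_dict and device-tracked, but not trainable
        self.register_buffer("pos_weight", torch.tensor([pos_weight]))


loss_fn = WeightedLoss()
print(list(loss_fn.parameters()))  # [] -- nothing exposed to the optimizer
print(loss_fn.state_dict())        # contains 'pos_weight': tensor([5.])
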
@@ -191,10 +191,15 @@ class BCELossMasked(nn.Module):
             mask = sequence_mask(sequence_length=length, max_len=target.size(1))
             num_items = mask.sum()
             loss = functional.binary_cross_entropy_with_logits(
-                x.masked_select(mask), target.masked_select(mask), pos_weight=self.pos_weight, reduction="sum"
+                x.masked_select(mask),
+                target.masked_select(mask),
+                pos_weight=self.pos_weight.to(x.device),
+                reduction="sum",
             )
         else:
-            loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
+            loss = functional.binary_cross_entropy_with_logits(
+                x, target, pos_weight=self.pos_weight.to(x.device), reduction="sum"
+            )
             num_items = torch.numel(x)
         loss = loss / num_items
         return loss

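For the forward hunk above, here is a short sketch of what the masked branch computes, with made-up batch shapes and a minimal re-implementation of sequence_mask (the real helper lives elsewhere in TTS and is not part of this diff): only the first length[i] stop-token logits of each sequence enter the loss, and the summed loss is divided by the number of unmasked items.

import torch
from torch.nn import functional


def sequence_mask(sequence_length, max_len):
    # minimal stand-in for TTS's helper: True for valid (unpadded) steps
    steps = torch.arange(max_len, device=sequence_length.device)
    return steps.unsqueeze(0) < sequence_length.unsqueeze(1)


x = torch.randn(2, 6)                         # stop-token logits, shape (B, T)
target = torch.randint(0, 2, (2, 6)).float()  # stop targets
length = torch.tensor([4, 6])                 # valid steps per sequence
pos_weight = torch.tensor([10.0])

mask = sequence_mask(sequence_length=length, max_len=target.size(1))
loss = functional.binary_cross_entropy_with_logits(
    x.masked_select(mask),
    target.masked_select(mask),
    pos_weight=pos_weight.to(x.device),
    reduction="sum",
) / mask.sum()
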
@@ -16,7 +16,7 @@ from TTS.utils.audio import AudioProcessor

 torch.manual_seed(1)
 use_cuda = torch.cuda.is_available()
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda" if use_cuda else "cpu")

 config_global = TacotronConfig(num_chars=32, num_speakers=5, out_channels=513, decoder_output_dim=80)

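The test-setup hunk above simply reuses the existing use_cuda flag and drops the explicit ":0" index; "cuda" resolves to the current CUDA device. A one-line sketch of the pattern, assuming nothing beyond stock PyTorch:

import torch

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
dummy_input = torch.zeros(8, 80).to(device)  # tensors and modules follow via .to(device)
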
@@ -288,7 +288,6 @@ class TacotronCapacitronTrainTest(unittest.TestCase):
             batch["text_input"].shape[0], batch["stop_targets"].size(1) // config.r, -1
         )
         batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze()
-
         model = Tacotron(config).to(device)
         criterion = model.get_criterion()
         optimizer = model.get_optimizer()