Merge branch 'dev'

pull/2435/head v0.12.0
Eren Gölge 2023-03-17 13:31:08 +01:00
commit 12f3365185
5 changed files with 29 additions and 16 deletions

View File

@ -1 +1 @@
0.11.1
0.12.0

View File

@ -179,6 +179,7 @@ class NeuralhmmTTS(BaseTTS):
Args:
aux_inputs (Dict): Dictionary containing the auxiliary inputs.
"""
default_input_dict = default_input_dict.copy()
default_input_dict.update(
{
"sampling_temp": self.sampling_temp,
@ -187,8 +188,8 @@ class NeuralhmmTTS(BaseTTS):
}
)
if aux_input:
return format_aux_input(aux_input, default_input_dict)
return None
return format_aux_input(default_input_dict, aux_input)
return default_input_dict
@torch.no_grad()
def inference(
@ -319,7 +320,7 @@ class NeuralhmmTTS(BaseTTS):
# sample one item from the batch; -1 will give the smallest item
print(" | > Synthesising audio from the model...")
inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lenghts": batch["text_lengths"][-1].unsqueeze(0)}
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
)
figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3))

View File

@ -192,6 +192,7 @@ class Overflow(BaseTTS):
Args:
aux_inputs (Dict): Dictionary containing the auxiliary inputs.
"""
default_input_dict = default_input_dict.copy()
default_input_dict.update(
{
"sampling_temp": self.sampling_temp,
@ -200,8 +201,8 @@ class Overflow(BaseTTS):
}
)
if aux_input:
return format_aux_input(aux_input, default_input_dict)
return None
return format_aux_input(default_input_dict, aux_input)
return default_input_dict
@torch.no_grad()
def inference(
@ -335,7 +336,7 @@ class Overflow(BaseTTS):
# sample one item from the batch; -1 will give the smallest item
print(" | > Synthesising audio from the model...")
inference_output = self.inference(
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lenghts": batch["text_lengths"][-1].unsqueeze(0)}
batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)}
)
figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3))

View File

@ -1628,13 +1628,23 @@ class Vits(BaseTTS):
pin_memory=False,
)
else:
loader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
if num_gpus > 1:
loader = DataLoader(
dataset,
sampler=sampler,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
else:
loader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
return loader
def get_optimizer(self) -> List:

View File

@ -167,9 +167,10 @@ def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
Returns:
Dict: arguments with formatted auxiliary inputs.
"""
kwargs = kwargs.copy()
for name in def_args:
if name not in kwargs:
kwargs[def_args[name]] = None
if name not in kwargs or kwargs[name] is None:
kwargs[name] = def_args[name]
return kwargs