Update gan

pull/1324/head
Eren Gölge 2022-02-20 11:37:34 +01:00
parent 20a677c623
commit fc3b6d2861
1 changed files with 5 additions and 5 deletions

View File

@ -80,8 +80,8 @@ class GAN(BaseVocoder):
Returns: Returns:
Tuple[Dict, Dict]: model outputs and the computed loss values. Tuple[Dict, Dict]: model outputs and the computed loss values.
""" """
outputs = None outputs = {}
loss_dict = None loss_dict = {}
x = batch["input"] x = batch["input"]
y = batch["waveform"] y = batch["waveform"]
@ -311,7 +311,7 @@ class GAN(BaseVocoder):
config: Coqpit, config: Coqpit,
assets: Dict, assets: Dict,
is_eval: True, is_eval: True,
data_items: List, samples: List,
verbose: bool, verbose: bool,
num_gpus: int, num_gpus: int,
rank: int = None, # pylint: disable=unused-argument rank: int = None, # pylint: disable=unused-argument
@ -322,7 +322,7 @@ class GAN(BaseVocoder):
config (Coqpit): Model config. config (Coqpit): Model config.
ap (AudioProcessor): Audio processor. ap (AudioProcessor): Audio processor.
is_eval (True): Set the dataloader for evaluation if true. is_eval (True): Set the dataloader for evaluation if true.
data_items (List): Data samples. samples (List): Data samples.
verbose (bool): Log information if true. verbose (bool): Log information if true.
num_gpus (int): Number of GPUs in use. num_gpus (int): Number of GPUs in use.
rank (int): Rank of the current GPU. Defaults to None. rank (int): Rank of the current GPU. Defaults to None.
@ -332,7 +332,7 @@ class GAN(BaseVocoder):
""" """
dataset = GANDataset( dataset = GANDataset(
ap=self.ap, ap=self.ap,
items=data_items, items=samples,
seq_len=config.seq_len, seq_len=config.seq_len,
hop_len=self.ap.hop_length, hop_len=self.ap.hop_length,
pad_short=config.pad_short, pad_short=config.pad_short,