mirror of https://github.com/coqui-ai/TTS.git

make using different samples for G and D networks optional

parent 67f8248492
commit 57f6bd1afa
@@ -42,6 +42,7 @@ def setup_loader(ap, is_val=False, verbose=False):
             hop_len=ap.hop_length,
             pad_short=c.pad_short,
             conv_pad=c.conv_pad,
+            return_pairs=c.diff_samples_for_G_and_D if 'diff_samples_for_G_and_D' in c else False,
             is_training=not is_val,
             return_segments=not is_val,
             use_noise_augment=c.use_noise_augment,
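The new `return_pairs` argument is read defensively so that older config files keep working. A minimal runnable sketch of that fallback, with a plain dict standing in for the repo's config object (which supports the same `in` test, though via attribute-style access):

    # hypothetical old config that predates the new key
    c = {"pad_short": 2000, "conv_pad": 0}
    # missing key -> keep the previous single-sample behaviour
    return_pairs = c["diff_samples_for_G_and_D"] if "diff_samples_for_G_and_D" in c else False
    assert return_pairs is False

With a real dict, `c.get("diff_samples_for_G_and_D", False)` would express the same default more idiomatically.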
@@ -62,25 +63,19 @@ def setup_loader(ap, is_val=False, verbose=False):

 def format_data(data):
     if isinstance(data[0], list):
-        # setup input data
-        c_G, x_G = data[0]
-        c_D, x_D = data[1]
-
-        # dispatch data to GPU
+        x_G, y_G = data[0]
+        x_D, y_D = data[1]
         if use_cuda:
-            c_G = c_G.cuda(non_blocking=True)
             x_G = x_G.cuda(non_blocking=True)
-            c_D = c_D.cuda(non_blocking=True)
+            y_G = y_G.cuda(non_blocking=True)
             x_D = x_D.cuda(non_blocking=True)
-
-        return c_G, x_G, c_D, x_D
-
-    # return a whole audio segment
-    co, x = data
+            y_D = y_D.cuda(non_blocking=True)
+        return x_G, y_G, x_D, y_D
+    x, y = data
     if use_cuda:
-        co = co.cuda(non_blocking=True)
         x = x.cuda(non_blocking=True)
-    return co, x, None, None
+        y = y.cuda(non_blocking=True)
+    return x, y, None, None


 def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
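After the rename, `format_data` always hands the training loop a 4-tuple `(x_G, y_G, x_D, y_D)`, with the discriminator slots set to `None` for whole-audio batches. A self-contained sketch of that contract (CUDA dispatch elided; the tensor shapes below are illustrative assumptions, not the repo's):

    import torch

    def format_data_sketch(data):
        if isinstance(data[0], list):
            x_G, y_G = data[0]   # conditioning features / audio for G
            x_D, y_D = data[1]   # conditioning features / audio for D
            return x_G, y_G, x_D, y_D
        x, y = data              # whole audio segment
        return x, y, None, None

    batch = [[torch.zeros(4, 80, 32), torch.zeros(4, 1, 8192)],
             [torch.zeros(4, 80, 32), torch.zeros(4, 1, 8192)]]
    x_G, y_G, x_D, y_D = format_data_sketch(batch)
    assert x_D is not None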
@@ -143,13 +138,20 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             if D_out_real is None:
                 feats_real = None
             else:
+                # we don't need scores for real samples for training G since they are always 1
                 _, feats_real = D_out_real
         else:
             scores_fake = D_out_fake

         # compute losses
-        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
-                                  feats_real, y_hat_sub, y_G_sub)
+        loss_G_dict = criterion_G(y_hat=y_hat,
+                                  y=y_G,
+                                  scores_fake=scores_fake,
+                                  feats_fake=feats_fake,
+                                  feats_real=feats_real,
+                                  y_hat_sub=y_hat_sub,
+                                  y_sub=y_G_sub)
+
         loss_G = loss_G_dict['G_loss']

         # optimizer generator
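Switching the loss call to keyword arguments is defensive: several of the seven parameters are optional tensors with similar shapes, so a positional transposition (say, `feats_fake` and `feats_real`) would fail silently rather than raise. A sketch with a hypothetical stand-in criterion that returns the same `G_loss` key the loop reads:

    def criterion_G_sketch(y_hat, y, scores_fake=None, feats_fake=None,
                           feats_real=None, y_hat_sub=None, y_sub=None):
        # placeholder arithmetic only; the real criterion combines
        # adversarial, feature-matching and spectral terms
        return {"G_loss": 0.0 if scores_fake is None else 1.0}

    loss_G = criterion_G_sketch(y_hat=None, y=None, scores_fake=[0.5])["G_loss"]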
@@ -174,16 +176,22 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         ##############################
         if global_step >= c.steps_to_start_discriminator:
             # discriminator pass
-            with torch.no_grad():
-                y_hat = model_G(c_D)
+            if c.diff_samples_for_G_and_D:
+                # use a different sample than generator
+                with torch.no_grad():
+                    y_hat = model_G(c_D)

-            # PQMF formatting
-            if y_hat.shape[1] > 1:
-                y_hat = model_G.pqmf_synthesis(y_hat)
+                # PQMF formatting
+                if y_hat.shape[1] > 1:
+                    y_hat = model_G.pqmf_synthesis(y_hat)
+            else:
+                # use the same samples as generator
+                c_D = c_G.clone()
+                y_D = y_G.clone()

             # run D with or without cond. features
             if len(signature(model_D.forward).parameters) == 2:
-                D_out_fake = model_D(y_hat.detach(), c_D)
+                D_out_fake = model_D(y_hat.detach().clone(), c_D)
                 D_out_real = model_D(y_D, c_D)
             else:
                 D_out_fake = model_D(y_hat.detach())
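The `signature(...)` check lets one loop drive both conditional discriminators, `D(x, c)`, and unconditional ones, `D(x)`. A self-contained sketch with toy stand-in modules (not the repo's models); `inspect.signature` on a bound method excludes `self`, so the parameter count is 2 or 1:

    from inspect import signature

    import torch
    import torch.nn as nn

    class CondD(nn.Module):
        def forward(self, x, c):          # conditional: audio + features
            return x.mean() + c.mean()

    class UncondD(nn.Module):
        def forward(self, x):             # unconditional: audio only
            return x.mean()

    x, cond = torch.ones(2, 1, 16), torch.ones(2, 80, 4)
    for model_D in (CondD(), UncondD()):
        if len(signature(model_D.forward).parameters) == 2:
            out = model_D(x, cond)
        else:
            out = model_D(x)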
@@ -191,12 +199,14 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,

         # format D outputs
         if isinstance(D_out_fake, tuple):
+            # model_D returns scores and features
             scores_fake, feats_fake = D_out_fake
             if D_out_real is None:
                 scores_real, feats_real = None, None
             else:
                 scores_real, feats_real = D_out_real
         else:
+            # model D returns only scores
             scores_fake = D_out_fake
             scores_real = D_out_real

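The comments added above document the two output conventions a discriminator may follow. The same normalisation can be written as a small helper; a hedged sketch with toy values, not the repo's API:

    def unpack_D_output(d_out):
        # (scores, features) tuple, or bare scores with no features
        if isinstance(d_out, tuple):
            return d_out
        return d_out, None

    scores_fake, feats_fake = unpack_D_output(([0.1], [[0.0]]))
    scores_real, feats_real = unpack_D_output([0.9])
    assert feats_real is None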
@@ -283,6 +293,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
                                  {'train/audio': sample_voice},
                                  c.audio["sample_rate"])
+        end_time = time.time()
         torch.cuda.empty_cache()

     # print epoch stats
     c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
@@ -422,6 +433,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
         if c.print_eval:
             c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

+    torch.cuda.empty_cache()
+
+
     if args.rank == 0:
         # compute spectrograms
         figures = plot_results(y_hat, y_G, ap, global_step, 'eval')

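`torch.cuda.empty_cache()` returns cached allocator blocks to the driver between the eval loop and the rank-0 plotting work. A guarded form keeps the same call safe on CPU-only machines (a usage sketch, not a change this commit makes):

    import torch

    if torch.cuda.is_available():
        torch.cuda.empty_cache()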
@@ -15,6 +15,7 @@
     "preemphasis": 0.0,         // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
     "ref_level_db": 20,         // reference level db, theoretically 20db is the sound of air.
+    "log_func": "np.log",
     "do_sound_norm": true,

     // Silence trimming
     "do_trim_silence": false,   // enable trimming of silence of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)

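The diff only adds the `log_func` key; how the audio code consumes it is not shown here. One plausible consumer maps the string onto a logarithm base for the amplitude-to-dB conversion (hypothetical helper, not the repo's AudioProcessor):

    import numpy as np

    def resolve_log_base(log_func: str) -> float:
        # assumption: only the two numpy spellings are supported
        if log_func == "np.log":
            return np.e
        if log_func == "np.log10":
            return 10.0
        raise ValueError(f"unsupported log_func: {log_func}")

    base = resolve_log_base("np.log")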
@@ -89,6 +90,7 @@
     // "downsample_factors":[4, 4, 4]
     //},

     "steps_to_start_discriminator": 0,  // steps required to start GAN training.
+    "diff_samples_for_G_and_D": false,  // draw a new sample from the dataset for the D pass.

     // GENERATOR
     "generator_model": "hifigan_generator",

@@ -20,6 +20,7 @@ class GANDataset(Dataset):
                  hop_len,
                  pad_short,
                  conv_pad=2,
+                 return_pairs=False,
                  is_training=True,
                  return_segments=True,
                  use_noise_augment=False,
@@ -33,6 +34,7 @@ class GANDataset(Dataset):
         self.hop_len = hop_len
         self.pad_short = pad_short
         self.conv_pad = conv_pad
+        self.return_pairs = return_pairs
         self.is_training = is_training
         self.return_segments = return_segments
         self.use_cache = use_cache
@@ -65,11 +67,17 @@ class GANDataset(Dataset):
     def __getitem__(self, idx):
         """ Return different items for Generator and Discriminator and
        cache acoustic features """
+
+        # set the seed differently for each worker
+        random.seed(torch.utils.data.get_worker_info().seed)
+
         if self.return_segments:
-            idx2 = self.G_to_D_mappings[idx]
             item1 = self.load_item(idx)
-            item2 = self.load_item(idx2)
-            return item1, item2
+            if self.return_pairs:
+                idx2 = self.G_to_D_mappings[idx]
+                item2 = self.load_item(idx2)
+                return item1, item2
+            return item1
         item1 = self.load_item(idx)
         return item1

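The rewritten `__getitem__` combines the two mechanisms this commit touches: per-worker RNG seeding (so DataLoader workers do not draw identical "random" segments) and a G-to-D index mapping that pairs each generator sample with an independently drawn discriminator sample when `return_pairs` is on. A self-contained sketch of both, with `G_to_D_mappings` assumed (as in this class) to be a shuffled permutation of the dataset indices; outside a worker process `get_worker_info()` returns `None`, hence the guard:

    import random

    import torch.utils.data

    class PairDataset(torch.utils.data.Dataset):
        def __init__(self, items, return_pairs=False):
            self.items = items
            self.return_pairs = return_pairs
            self.G_to_D_mappings = list(range(len(items)))
            random.shuffle(self.G_to_D_mappings)

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                # decorrelate the per-worker RNG streams
                random.seed(worker_info.seed)
            item1 = self.items[idx]
            if self.return_pairs:
                item2 = self.items[self.G_to_D_mappings[idx]]
                return item1, item2
            return item1

    ds = PairDataset(list(range(10)), return_pairs=True)
    g_item, d_item = ds[0]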