mirror of https://github.com/coqui-ai/TTS.git
implement the Speaker Encoder H/ASP
parent
85ccad7e0a
commit
3fcc748b2e
|
@ -133,3 +133,6 @@ TTS/tts/layers/glow_tts/monotonic_align/core.c
|
|||
.vscode-upload.json
|
||||
temp_build/*
|
||||
recipes/*
|
||||
|
||||
# nohup logs
|
||||
*.out
|
|
@ -12,8 +12,7 @@ from torch.utils.data import DataLoader
|
|||
|
||||
from TTS.speaker_encoder.dataset import MyDataset
|
||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxLoss, SoftmaxAngleProtoLoss
|
||||
from TTS.speaker_encoder.model import SpeakerEncoder
|
||||
from TTS.speaker_encoder.utils.generic_utils import check_config_speaker_encoder, save_best_model
|
||||
from TTS.speaker_encoder.utils.generic_utils import check_config_speaker_encoder, save_best_model, save_checkpoint, setup_model
|
||||
from TTS.speaker_encoder.utils.visual import plot_embeddings
|
||||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
@ -66,21 +65,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
|
|||
return loader, dataset.get_num_speakers()
|
||||
|
||||
|
||||
def train(model, optimizer, scheduler, ap, global_step):
|
||||
data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)
|
||||
|
||||
if c.loss == "ge2e":
|
||||
criterion = GE2ELoss(loss_method="softmax")
|
||||
elif c.loss == "angleproto":
|
||||
criterion = AngleProtoLoss()
|
||||
elif c.loss == "softmaxproto":
|
||||
criterion = SoftmaxAngleProtoLoss(c.model["proj_dim"], num_speakers)
|
||||
else:
|
||||
raise Exception("The %s not is a loss supported" % c.loss)
|
||||
|
||||
if use_cuda:
|
||||
model = model.cuda()
|
||||
criterion.cuda()
|
||||
def train(model, optimizer, scheduler, criterion, data_loader, ap, global_step):
|
||||
|
||||
model.train()
|
||||
epoch_time = 0
|
||||
|
@ -154,7 +139,7 @@ def train(model, optimizer, scheduler, ap, global_step):
|
|||
)
|
||||
|
||||
# save best model
|
||||
best_loss = save_best_model(model, optimizer, avg_loss, best_loss, OUT_PATH, global_step)
|
||||
best_loss = save_best_model(model, optimizer, criterion, avg_loss, best_loss, OUT_PATH, global_step)
|
||||
|
||||
end_time = time.time()
|
||||
return avg_loss, global_step
|
||||
|
@ -166,14 +151,24 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
global meta_data_eval
|
||||
|
||||
ap = AudioProcessor(**c.audio)
|
||||
model = SpeakerEncoder(
|
||||
input_dim=c.model["input_dim"],
|
||||
proj_dim=c.model["proj_dim"],
|
||||
lstm_dim=c.model["lstm_dim"],
|
||||
num_lstm_layers=c.model["num_lstm_layers"],
|
||||
)
|
||||
model = setup_model(c)
|
||||
optimizer = RAdam(model.parameters(), lr=c.lr)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
|
||||
|
||||
data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)
|
||||
|
||||
if c.loss == "ge2e":
|
||||
criterion = GE2ELoss(loss_method="softmax")
|
||||
elif c.loss == "angleproto":
|
||||
criterion = AngleProtoLoss()
|
||||
elif c.loss == "softmaxproto":
|
||||
criterion = SoftmaxAngleProtoLoss(c.model["proj_dim"], num_speakers)
|
||||
else:
|
||||
raise Exception("The %s not is a loss supported" % c.loss)
|
||||
|
||||
|
||||
if args.restore_path:
|
||||
checkpoint = torch.load(args.restore_path)
|
||||
try:
|
||||
|
@ -183,14 +178,19 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
if c.reinit_layers:
|
||||
raise RuntimeError
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
except KeyError:
|
||||
|
||||
if 'criterion' in checkpoint:
|
||||
criterion.load_state_dict(checkpoint["criterion"])
|
||||
|
||||
except (KeyError, RuntimeError):
|
||||
print(" > Partial model initialization.")
|
||||
model_dict = model.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint, c)
|
||||
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
||||
model.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = c.lr
|
||||
|
||||
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
|
||||
args.restore_step = checkpoint["step"]
|
||||
else:
|
||||
|
@ -204,11 +204,13 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
num_params = count_parameters(model)
|
||||
print("\n > Model has {} parameters".format(num_params), flush=True)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
|
||||
if use_cuda:
|
||||
model = model.cuda()
|
||||
criterion.cuda()
|
||||
|
||||
global_step = args.restore_step
|
||||
_, global_step = train(model, optimizer, scheduler, ap, global_step)
|
||||
# save_checkpoint(model, optimizer, criterion, 0.9, '../', global_step, 1)
|
||||
_, global_step = train(model, optimizer, scheduler, criterion, data_loader, ap, global_step)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
{
|
||||
"model_name": "lstm",
|
||||
"run_name": "mueller91",
|
||||
"run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ",
|
||||
"audio":{
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
{
|
||||
"model_name": "resnet",
|
||||
"run_name": "speaker_encoder",
|
||||
"run_description": "train speaker encoder with VCTK",
|
||||
"audio":{
|
||||
|
@ -41,7 +42,7 @@
|
|||
"steps_plot_stats": 10, // number of steps to plot embeddings.
|
||||
|
||||
// Speakers config
|
||||
"num_speakers_in_batch": 2, // Batch size for training.
|
||||
"num_speakers_in_batch": 128, // Batch size for training.
|
||||
"num_utters_per_speaker": 2, //
|
||||
"skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
|
||||
|
||||
|
@ -91,10 +92,7 @@
|
|||
},
|
||||
"model": {
|
||||
"input_dim": 80,
|
||||
"proj_dim": 512,
|
||||
"lstm_dim": 768,
|
||||
"num_lstm_layers": 3,
|
||||
"use_lstm_with_projection": true
|
||||
"proj_dim": 512
|
||||
},
|
||||
"storage": {
|
||||
"sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
|
||||
{
|
||||
"model_name": "resnet",
|
||||
"run_name": "speaker_encoder",
|
||||
"run_description": "train speaker encoder with VCTK",
|
||||
"audio":{
|
||||
// Audio processing parameters
|
||||
"num_mels": 64, // size of the mel spec frame.
|
||||
"fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
|
||||
"sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
||||
"win_length": 1024, // stft window length in ms.
|
||||
"hop_length": 256, // stft window hop-lengh in ms.
|
||||
"frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used.
|
||||
"frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used.
|
||||
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
|
||||
"min_level_db": -100, // normalization range
|
||||
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
|
||||
"power": 1.5, // value to sharpen wav signals after GL algorithm.
|
||||
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
|
||||
"stft_pad_mode": "reflect",
|
||||
// Normalization parameters
|
||||
"signal_norm": true, // normalize the spec values in range [0, 1]
|
||||
"symmetric_norm": true, // move normalization to range [-1, 1]
|
||||
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
||||
"clip_norm": true, // clip normalized values into the range.
|
||||
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
||||
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
|
||||
"spec_gain": 20.0,
|
||||
"do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
||||
"trim_db": 60, // threshold for timming silence. Set this according to your dataset.
|
||||
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
|
||||
},
|
||||
"reinit_layers": [],
|
||||
|
||||
"loss": "softmaxproto", // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss and "softmaxproto" to use Softmax with Angular Prototypical loss
|
||||
"grad_clip": 3.0, // upper limit for gradients for clipping.
|
||||
"epochs": 1000, // total number of epochs to train.
|
||||
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
||||
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
|
||||
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
"steps_plot_stats": 10, // number of steps to plot embeddings.
|
||||
|
||||
// Speakers config
|
||||
"num_speakers_in_batch": 256, // Batch size for training.
|
||||
"num_utters_per_speaker": 2, //
|
||||
"skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
|
||||
|
||||
"voice_len": 2, // number of seconds for each training instance
|
||||
|
||||
"num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
||||
"wd": 0.000001, // Weight decay weight.
|
||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
|
||||
"print_step": 20, // Number of steps to log traning on console.
|
||||
"output_path": "../../../checkpoints/speaker_encoder/continue-training-voxceleb-trainer/", // DATASET-RELATED: output path for all training outputs.
|
||||
|
||||
"audio_augmentation": {
|
||||
"p": 0.75, // propability of apply this method, 0 is disable rir and additive noise augmentation
|
||||
"rir":{
|
||||
"rir_path": "/workspace/store/ecasanova/ComParE/RIRS_NOISES/simulated_rirs/",
|
||||
"conv_mode": "full"
|
||||
},
|
||||
"additive":{
|
||||
"sounds_path": "/workspace/store/ecasanova/ComParE/musan/",
|
||||
// list of each of the directories in your data augmentation, if a directory is in "sounds_path" but is not listed here it will be ignored
|
||||
"speech":{
|
||||
"min_snr_in_db": 13,
|
||||
"max_snr_in_db": 20,
|
||||
"min_num_noises": 3,
|
||||
"max_num_noises": 7
|
||||
},
|
||||
"noise":{
|
||||
"min_snr_in_db": 0,
|
||||
"max_snr_in_db": 15,
|
||||
"min_num_noises": 1,
|
||||
"max_num_noises": 1
|
||||
},
|
||||
"music":{
|
||||
"min_snr_in_db": 5,
|
||||
"max_snr_in_db": 15,
|
||||
"min_num_noises": 1,
|
||||
"max_num_noises": 1
|
||||
}
|
||||
},
|
||||
//add a gaussian noise to the data in order to increase robustness
|
||||
"gaussian":{ // as the insertion of Gaussian noise is quick to be calculated, we added it after loading the wav file, this way, even audios that were reused with the cache can receive this noise
|
||||
"p": 0.5, // propability of apply this method, 0 is disable
|
||||
"min_amplitude": 0.0,
|
||||
"max_amplitude": 1e-5
|
||||
}
|
||||
},
|
||||
"model": {
|
||||
"input_dim": 64,
|
||||
"proj_dim": 512
|
||||
},
|
||||
"storage": {
|
||||
"sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage
|
||||
"storage_size": 1 // the size of the in-memory storage with respect to a single batch
|
||||
},
|
||||
"datasets":
|
||||
[
|
||||
{
|
||||
"name": "voxceleb2",
|
||||
"path": "/workspace/scratch/ecasanova/datasets/VoxCeleb/vox2_dev_aac/",
|
||||
"meta_file_train": null,
|
||||
"meta_file_val": null
|
||||
}
|
||||
]
|
||||
}
|
|
@ -29,7 +29,7 @@ class LSTMWithoutProjection(nn.Module):
|
|||
return self.relu(self.linear(hidden[-1]))
|
||||
|
||||
|
||||
class SpeakerEncoder(nn.Module):
|
||||
class LSTMSpeakerEncoder(nn.Module):
|
||||
def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
|
||||
super().__init__()
|
||||
self.use_lstm_with_projection = use_lstm_with_projection
|
|
@ -0,0 +1,157 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class SELayer(nn.Module):
|
||||
def __init__(self, channel, reduction=8):
|
||||
super(SELayer, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(channel, channel // reduction),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channel // reduction, channel),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, _, _ = x.size()
|
||||
y = self.avg_pool(x).view(b, c)
|
||||
y = self.fc(y).view(b, c, 1, 1)
|
||||
return x * y
|
||||
|
||||
class SEBasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
|
||||
super(SEBasicBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.se = SELayer(planes, reduction)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.relu(out)
|
||||
out = self.bn1(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.se(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
class ResNetSpeakerEncoder(nn.Module):
|
||||
"""Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
|
||||
Adapted from: https://github.com/clovaai/voxceleb_trainer
|
||||
"""
|
||||
def __init__(self, input_dim=64, proj_dim=512, layers=[3, 4, 6, 3], num_filters=[32, 64, 128, 256], encoder_type='ASP', log_input=False):
|
||||
super(ResNetSpeakerEncoder, self).__init__()
|
||||
|
||||
self.encoder_type = encoder_type
|
||||
self.input_dim = input_dim
|
||||
self.log_input = log_input
|
||||
self.conv1 = nn.Conv2d(1, num_filters[0] , kernel_size=3, stride=1, padding=1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.bn1 = nn.BatchNorm2d(num_filters[0])
|
||||
|
||||
self.inplanes = num_filters[0]
|
||||
self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
|
||||
self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
|
||||
self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
|
||||
self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))
|
||||
|
||||
self.instancenorm = nn.InstanceNorm1d(input_dim)
|
||||
|
||||
outmap_size = int(self.input_dim/8)
|
||||
|
||||
self.attention = nn.Sequential(
|
||||
nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
|
||||
nn.Softmax(dim=2),
|
||||
)
|
||||
|
||||
if self.encoder_type == "SAP":
|
||||
out_dim = num_filters[3] * outmap_size
|
||||
elif self.encoder_type == "ASP":
|
||||
out_dim = num_filters[3] * outmap_size * 2
|
||||
else:
|
||||
raise ValueError('Undefined encoder')
|
||||
|
||||
self.fc = nn.Linear(out_dim, proj_dim)
|
||||
|
||||
self._init_layers()
|
||||
|
||||
def _init_layers(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def create_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def new_parameter(self, *size):
|
||||
out = nn.Parameter(torch.FloatTensor(*size))
|
||||
nn.init.xavier_normal_(out)
|
||||
return out
|
||||
|
||||
def forward(self, x):
|
||||
x = x.transpose(1, 2)
|
||||
with torch.no_grad():
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
if self.log_input: x = (x+1e-6).log()
|
||||
x = self.instancenorm(x).unsqueeze(1)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.bn1(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = x.reshape(x.size()[0],-1,x.size()[-1])
|
||||
|
||||
w = self.attention(x)
|
||||
|
||||
if self.encoder_type == "SAP":
|
||||
x = torch.sum(x * w, dim=2)
|
||||
elif self.encoder_type == "ASP":
|
||||
mu = torch.sum(x * w, dim=2)
|
||||
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu ** 2 ).clamp(min=1e-5) )
|
||||
x = torch.cat((mu, sg),1)
|
||||
|
||||
x = x.view(x.size()[0], -1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
|
@ -8,7 +8,8 @@ import glob
|
|||
import random
|
||||
|
||||
from scipy import signal
|
||||
from TTS.speaker_encoder.model import SpeakerEncoder
|
||||
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
|
||||
from TTS.utils.generic_utils import check_argument
|
||||
|
||||
class AugmentWAV(object):
|
||||
|
@ -146,11 +147,14 @@ def to_camel(text):
|
|||
|
||||
|
||||
def setup_model(c):
|
||||
model = SpeakerEncoder(c.model["input_dim"], c.model["proj_dim"], c.model["lstm_dim"], c.model["num_lstm_layers"])
|
||||
if c.model_name.lower() == 'lstm':
|
||||
model = LSTMSpeakerEncoder(c.model["input_dim"], c.model["proj_dim"], c.model["lstm_dim"], c.model["num_lstm_layers"])
|
||||
elif c.model_name.lower() == 'resnet':
|
||||
model = ResNetSpeakerEncoder(input_dim=c.model["input_dim"], proj_dim=c.model["proj_dim"])
|
||||
return model
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, model_loss, out_path, current_step, epoch):
|
||||
def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
|
||||
checkpoint_path = "checkpoint_{}.pth.tar".format(current_step)
|
||||
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
||||
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
|
||||
|
@ -159,6 +163,7 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step, epoch)
|
|||
state = {
|
||||
"model": new_state_dict,
|
||||
"optimizer": optimizer.state_dict() if optimizer is not None else None,
|
||||
"criterion": criterion.state_dict(),
|
||||
"step": current_step,
|
||||
"epoch": epoch,
|
||||
"loss": model_loss,
|
||||
|
@ -167,12 +172,13 @@ def save_checkpoint(model, optimizer, model_loss, out_path, current_step, epoch)
|
|||
torch.save(state, checkpoint_path)
|
||||
|
||||
|
||||
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
|
||||
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step):
|
||||
if model_loss < best_loss:
|
||||
new_state_dict = model.state_dict()
|
||||
state = {
|
||||
"model": new_state_dict,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"criterion": criterion.state_dict(),
|
||||
"step": current_step,
|
||||
"loss": model_loss,
|
||||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
|
@ -234,11 +240,13 @@ def check_config_speaker_encoder(c):
|
|||
|
||||
# model parameters
|
||||
check_argument("model", c, restricted=True, val_type=dict)
|
||||
check_argument("model_name", c, restricted=True, val_type=str)
|
||||
check_argument("input_dim", c["model"], restricted=True, val_type=int)
|
||||
check_argument("proj_dim", c["model"], restricted=True, val_type=int)
|
||||
check_argument("lstm_dim", c["model"], restricted=True, val_type=int)
|
||||
check_argument("num_lstm_layers", c["model"], restricted=True, val_type=int)
|
||||
check_argument("use_lstm_with_projection", c["model"], restricted=True, val_type=bool)
|
||||
if c.model_name.lower() == 'lstm':
|
||||
check_argument("proj_dim", c["model"], restricted=True, val_type=int)
|
||||
check_argument("lstm_dim", c["model"], restricted=True, val_type=int)
|
||||
check_argument("num_lstm_layers", c["model"], restricted=True, val_type=int)
|
||||
check_argument("use_lstm_with_projection", c["model"], restricted=True, val_type=bool)
|
||||
|
||||
# in-memory storage parameters
|
||||
check_argument("storage", c, restricted=True, val_type=dict)
|
||||
|
|
Loading…
Reference in New Issue