mirror of https://github.com/coqui-ai/TTS.git
revert logging.info to print statements for trainer
parent
fd6afe5ae5
commit
c7ff175592
|
@ -150,7 +150,7 @@ class TrainerTTS:
|
|||
|
||||
# count model size
|
||||
num_params = count_parameters(self.model)
|
||||
logging.info("\n > Model has {} parameters".format(num_params))
|
||||
print("\n > Model has {} parameters".format(num_params))
|
||||
|
||||
@staticmethod
|
||||
def get_model(num_chars: int, num_speakers: int, config: Coqpit,
|
||||
|
@ -186,7 +186,6 @@ class TrainerTTS:
|
|||
out_path: str = "",
|
||||
data_train: List = []) -> SpeakerManager:
|
||||
speaker_manager = SpeakerManager()
|
||||
if config.use_speaker_embedding:
|
||||
if restore_path:
|
||||
speakers_file = os.path.join(os.path.dirname(restore_path),
|
||||
"speaker.json")
|
||||
|
@ -196,16 +195,6 @@ class TrainerTTS:
|
|||
)
|
||||
speakers_file = config.external_speaker_embedding_file
|
||||
|
||||
if config.use_external_speaker_embedding_file:
|
||||
speaker_manager.load_x_vectors_file(speakers_file)
|
||||
else:
|
||||
speaker_manager.load_ids_file(speakers_file)
|
||||
elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
|
||||
speaker_manager.load_x_vectors_file(
|
||||
config.external_speaker_embedding_file)
|
||||
else:
|
||||
speaker_manager.parse_speakers_from_items(data_train)
|
||||
file_path = os.path.join(out_path, "speakers.json")
|
||||
speaker_manager.save_ids_file(file_path)
|
||||
return speaker_manager
|
||||
|
||||
|
@ -238,15 +227,15 @@ class TrainerTTS:
|
|||
print(" > Restoring from %s ..." % os.path.basename(restore_path))
|
||||
checkpoint = torch.load(restore_path)
|
||||
try:
|
||||
logging.info(" > Restoring Model...")
|
||||
print(" > Restoring Model...")
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
logging.info(" > Restoring Optimizer...")
|
||||
print(" > Restoring Optimizer...")
|
||||
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
if "scaler" in checkpoint and config.mixed_precision:
|
||||
logging.info(" > Restoring AMP Scaler...")
|
||||
print(" > Restoring AMP Scaler...")
|
||||
scaler.load_state_dict(checkpoint["scaler"])
|
||||
except (KeyError, RuntimeError):
|
||||
logging.info(" > Partial model initialization...")
|
||||
print(" > Partial model initialization...")
|
||||
model_dict = model.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint["model"], config)
|
||||
model.load_state_dict(model_dict)
|
||||
|
|
Loading…
Reference in New Issue