revert logging.info to print statements for trainer

pull/602/head
Eren Gölge 2021-05-27 11:38:46 +02:00
parent fd6afe5ae5
commit c7ff175592
1 changed file with 5 additions and 16 deletions

View File

@ -150,7 +150,7 @@ class TrainerTTS:
# count model size
num_params = count_parameters(self.model)
logging.info("\n > Model has {} parameters".format(num_params))
print("\n > Model has {} parameters".format(num_params))
@staticmethod
def get_model(num_chars: int, num_speakers: int, config: Coqpit,
@ -186,7 +186,6 @@ class TrainerTTS:
out_path: str = "",
data_train: List = []) -> SpeakerManager:
speaker_manager = SpeakerManager()
if config.use_speaker_embedding:
if restore_path:
speakers_file = os.path.join(os.path.dirname(restore_path),
"speaker.json")
@ -196,16 +195,6 @@ class TrainerTTS:
)
speakers_file = config.external_speaker_embedding_file
if config.use_external_speaker_embedding_file:
speaker_manager.load_x_vectors_file(speakers_file)
else:
speaker_manager.load_ids_file(speakers_file)
elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
speaker_manager.load_x_vectors_file(
config.external_speaker_embedding_file)
else:
speaker_manager.parse_speakers_from_items(data_train)
file_path = os.path.join(out_path, "speakers.json")
speaker_manager.save_ids_file(file_path)
return speaker_manager
@ -238,15 +227,15 @@ class TrainerTTS:
print(" > Restoring from %s ..." % os.path.basename(restore_path))
checkpoint = torch.load(restore_path)
try:
logging.info(" > Restoring Model...")
print(" > Restoring Model...")
model.load_state_dict(checkpoint["model"])
logging.info(" > Restoring Optimizer...")
print(" > Restoring Optimizer...")
optimizer.load_state_dict(checkpoint["optimizer"])
if "scaler" in checkpoint and config.mixed_precision:
logging.info(" > Restoring AMP Scaler...")
print(" > Restoring AMP Scaler...")
scaler.load_state_dict(checkpoint["scaler"])
except (KeyError, RuntimeError):
logging.info(" > Partial model initialization...")
print(" > Partial model initialization...")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint["model"], config)
model.load_state_dict(model_dict)