add trainer and train_tts

pull/506/head
Eren Gölge 2021-05-25 14:45:27 +02:00
parent 34f8a74e4d
commit 5f07315722
2 changed files with 784 additions and 0 deletions

28
TTS/bin/train_tts.py Normal file
View File

@ -0,0 +1,28 @@
import os
import sys
import traceback
from TTS.utils.arguments import init_training
from TTS.utils.generic_utils import remove_experiment_folder
from TTS.trainer import TrainerTTS
def main():
# try:
args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(
sys.argv)
trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=OUT_PATH)
trainer.fit()
# except KeyboardInterrupt:
# remove_experiment_folder(OUT_PATH)
# try:
# sys.exit(0)
# except SystemExit:
# os._exit(0) # pylint: disable=protected-access
# except Exception: # pylint: disable=broad-except
# remove_experiment_folder(OUT_PATH)
# traceback.print_exc()
# sys.exit(1)
if __name__ == "__main__":
main()

756
TTS/trainer.py Normal file
View File

@ -0,0 +1,756 @@
# -*- coding: utf-8 -*-
import os
import sys
import time
import traceback
from random import randrange
import logging
import importlib
import numpy as np
import torch
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.tts.datasets import load_meta_data, TTSDataset
from TTS.tts.layers import setup_loss
from TTS.tts.models import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.arguments import init_training
from TTS.tts.utils.visual import plot_spectrogram, plot_alignment
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict, find_module
from TTS.utils.training import setup_torch_training_env, check_update
@dataclass
class TrainingArgs(Coqpit):
continue_path: str = field(
default='',
metadata={
'help':
'Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder.'
})
restore_path: str = field(
default='',
metadata={
'help':
'Path to a model checkpoit. Restore the model with the given checkpoint and start a new training.'
})
best_path: str = field(
default='',
metadata={
'help':
"Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
})
config_path: str = field(
default='', metadata={'help': 'Path to the configuration file.'})
rank: int = field(
default=0, metadata={'help': 'Process rank in distributed training.'})
group_id: str = field(
default='',
metadata={'help': 'Process group id in distributed training.'})
# pylint: disable=import-outside-toplevel, too-many-public-methods
class TrainerTTS:
use_cuda, num_gpus = setup_torch_training_env(True, False)
def __init__(self,
args,
config,
c_logger,
tb_logger,
model=None,
output_path=None):
self.args = args
self.config = config
self.c_logger = c_logger
self.tb_logger = tb_logger
self.output_path = output_path
self.total_steps_done = 0
self.epochs_done = 0
self.restore_step = 0
self.best_loss = float("inf")
self.train_loader = None
self.eval_loader = None
self.output_audio_path = os.path.join(output_path, 'test_audios')
self.keep_avg_train = None
self.keep_avg_eval = None
# model, audio processor, datasets, loss
# init audio processor
self.ap = AudioProcessor(**config.audio.to_dict())
# init character processor
self.model_characters = self.init_character_processor()
# load dataset samples
self.data_train, self.data_eval = load_meta_data(config.datasets)
# default speaker manager
self.speaker_manager = self.init_speaker_manager()
# init TTS model
if model is not None:
self.model = model
else:
self.model = self.init_model()
# setup criterion
self.criterion = self.init_criterion()
# DISTRUBUTED
if self.num_gpus > 1:
init_distributed(args.rank, self.num_gpus, args.group_id,
config.distributed["backend"],
config.distributed["url"])
# scalers for mixed precision training
self.scaler = torch.cuda.amp.GradScaler(
) if config.mixed_precision else None
# setup optimizer
self.optimizer = self.init_optimizer(self.model)
# setup scheduler
self.scheduler = self.init_scheduler(self.config, self.optimizer)
if self.args.restore_path:
self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
self.config, args.restore_path, self.model, self.optimizer,
self.scaler)
if self.use_cuda:
self.model.cuda()
self.criterion.cuda()
# DISTRUBUTED
if self.num_gpus > 1:
self.model = DDP_th(self.model, device_ids=[args.rank])
# count model size
num_params = count_parameters(self.model)
logging.info("\n > Model has {} parameters".format(num_params),
flush=True)
def init_model(self):
model = setup_model(
len(self.model_characters),
self.speaker_manager.num_speakers,
self.config,
self.speaker_manager.x_vector_dim
if self.speaker_manager.x_vectors else None,
)
return model
def init_optimizer(self, model):
optimizer_name = self.config.optimizer
optimizer_params = self.config.optimizer_params
if optimizer_name.lower() == "radam":
module = importlib.import_module("TTS.utils.radam")
optimizer = getattr(module, "RAdam")
else:
optimizer = getattr(torch.optim, optimizer_name)
return optimizer(model.parameters(),
lr=self.config.lr,
**optimizer_params)
def init_character_processor(self):
# setup custom characters if set in config file.
# TODO: implement CharacterProcessor
if self.config.characters is not None:
symbols, phonemes = make_symbols(
**self.config.characters.to_dict())
else:
from TTS.tts.utils.text.symbols import symbols, phonemes
model_characters = phonemes if self.config.use_phonemes else symbols
return model_characters
def init_speaker_manager(self, restore_path: str = "", out_path: str = ""):
speaker_manager = SpeakerManager()
if restore_path:
speakers_file = os.path.join(os.path.dirname(restore_path),
"speaker.json")
if not os.path.exists(speakers_file):
logging.info(
"WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
)
speakers_file = self.config.external_speaker_embedding_file
if self.config.use_external_speaker_embedding_file:
speaker_manager.load_x_vectors_file(speakers_file)
else:
self.speaker_manage.load_speaker_mapping(speakers_file)
elif self.config.use_external_speaker_embedding_file and self.config.external_speaker_embedding_file:
speaker_manager.load_x_vectors_file(
self.config.external_speaker_embedding_file)
else:
speaker_manager.parse_speakers_from_items(self.data_train)
file_path = os.path.join(out_path, "speakers.json")
speaker_manager.save_ids_file(file_path)
return speaker_manager
def init_scheduler(self, config, optimizer):
lr_scheduler = config.lr_scheduler
lr_scheduler_params = config.lr_scheduler_params
if lr_scheduler is None:
return None
if lr_scheduler.lower() == "noamlr":
from TTS.utils.training import NoamLR
scheduler = NoamLR
else:
scheduler = getattr(torch.optim, lr_scheduler)
return scheduler(optimizer, **lr_scheduler_params)
def init_criterion(self):
return setup_loss(self.config)
def restore_model(self,
config,
restore_path,
model,
optimizer,
scaler=None):
logging.info(f" > Restoring from {os.path.basename(restore_path)}...")
checkpoint = torch.load(restore_path, map_location="cpu")
try:
logging.info(" > Restoring Model...")
model.load_state_dict(checkpoint["model"])
# optimizer restore
logging.info(" > Restoring Optimizer...")
optimizer.load_state_dict(checkpoint["optimizer"])
if "scaler" in checkpoint and config.mixed_precision:
logging.info(" > Restoring AMP Scaler...")
scaler.load_state_dict(checkpoint["scaler"])
except (KeyError, RuntimeError):
logging.info(" > Partial model initialization...")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint["model"], config)
model.load_state_dict(model_dict)
del model_dict
for group in optimizer.param_groups:
group["lr"] = self.config.lr
logging.info(" > Model restored from step %d" % checkpoint["step"],
flush=True)
restore_step = checkpoint["step"]
return model, optimizer, scaler, restore_step
def _setup_loader(self, r, ap, is_eval, data_items, verbose,
speaker_mapping):
if is_eval and not self.config.run_eval:
loader = None
else:
dataset = TTSDataset(
outputs_per_step=r,
text_cleaner=self.config.text_cleaner,
compute_linear_spec= 'tacotron' == self.config.model.lower(),
meta_data=data_items,
ap=ap,
tp=self.config.characters,
add_blank=self.config["add_blank"],
batch_group_size=0 if is_eval else
self.config.batch_group_size * self.config.batch_size,
min_seq_len=self.config.min_seq_len,
max_seq_len=self.config.max_seq_len,
phoneme_cache_path=self.config.phoneme_cache_path,
use_phonemes=self.config.use_phonemes,
phoneme_language=self.config.phoneme_language,
enable_eos_bos=self.config.enable_eos_bos_chars,
use_noise_augment=not is_eval,
verbose=verbose,
speaker_mapping=speaker_mapping
if self.config.use_speaker_embedding
and self.config.use_external_speaker_embedding_file else None,
)
if self.config.use_phonemes and self.config.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.
dataset.compute_input_seq(self.config.num_loader_workers)
dataset.sort_items()
sampler = DistributedSampler(
dataset) if self.num_gpus > 1 else None
loader = DataLoader(
dataset,
batch_size=self.config.eval_batch_size
if is_eval else self.config.batch_size,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=False,
sampler=sampler,
num_workers=self.config.num_val_loader_workers
if is_eval else self.config.num_loader_workers,
pin_memory=False,
)
return loader
def setup_train_dataloader(self, r, ap, data_items, verbose,
speaker_mapping):
return self._setup_loader(r, ap, False, data_items, verbose,
speaker_mapping)
def setup_eval_dataloder(self, r, ap, data_items, verbose,
speaker_mapping):
return self._setup_loader(r, ap, True, data_items, verbose,
speaker_mapping)
def format_batch(self, batch):
# setup input batch
text_input = batch[0]
text_lengths = batch[1]
speaker_names = batch[2]
linear_input = batch[3] if self.config.model.lower() in ["tacotron"
] else None
mel_input = batch[4]
mel_lengths = batch[5]
stop_targets = batch[6]
item_idx = batch[7]
speaker_embeddings = batch[8]
attn_mask = batch[9]
max_text_length = torch.max(text_lengths.float())
max_spec_length = torch.max(mel_lengths.float())
# convert speaker names to ids
if self.config.use_speaker_embedding:
if self.config.use_external_speaker_embedding_file:
speaker_embeddings = batch[8]
speaker_ids = None
else:
speaker_ids = [
self.speaker_manager.speaker_ids[speaker_name]
for speaker_name in speaker_names
]
speaker_ids = torch.LongTensor(speaker_ids)
speaker_embeddings = None
else:
speaker_embeddings = None
speaker_ids = None
# compute durations from attention masks
if attn_mask is not None:
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
for idx, am in enumerate(attn_mask):
# compute raw durations
c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1]
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
dur[c_idxs] = counts
# smooth the durations and set any 0 duration to 1
# by cutting off from the largest duration indeces.
extra_frames = dur.sum() - mel_lengths[idx]
largest_idxs = torch.argsort(-dur)[:extra_frames]
dur[largest_idxs] -= 1
assert (
dur.sum() == mel_lengths[idx]
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
durations[idx, :text_lengths[idx]] = dur
# set stop targets view, we predict a single stop token per iteration.
stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // self.config.r,
-1)
stop_targets = (stop_targets.sum(2) >
0.0).unsqueeze(2).float().squeeze(2)
# dispatch batch to GPU
if self.use_cuda:
text_input = text_input.cuda(non_blocking=True)
text_lengths = text_lengths.cuda(non_blocking=True)
mel_input = mel_input.cuda(non_blocking=True)
mel_lengths = mel_lengths.cuda(non_blocking=True)
linear_input = linear_input.cuda(
non_blocking=True) if self.config.model.lower() in [
"tacotron"
] else None
stop_targets = stop_targets.cuda(non_blocking=True)
attn_mask = attn_mask.cuda(
non_blocking=True) if attn_mask is not None else None
durations = durations.cuda(
non_blocking=True) if attn_mask is not None else None
if speaker_ids is not None:
speaker_ids = speaker_ids.cuda(non_blocking=True)
if speaker_embeddings is not None:
speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
return {
"text_input": text_input,
"text_lengths": text_lengths,
"mel_input": mel_input,
"mel_lengths": mel_lengths,
"linear_input": linear_input,
"stop_targets": stop_targets,
"attn_mask": attn_mask,
"durations": durations,
"speaker_ids": speaker_ids,
"x_vectors": speaker_embeddings,
"max_text_length": max_text_length,
"max_spec_length": max_spec_length,
"item_idx": item_idx
}
def train_step(self, batch, batch_n_steps, step, loader_start_time):
self.on_train_step_start()
step_start_time = time.time()
# format data
batch = self.format_batch(batch)
loader_time = time.time() - loader_start_time
# zero-out optimizer
self.optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
outputs, loss_dict = self.model.train_step(batch, self.criterion)
# check nan loss
if torch.isnan(loss_dict["loss"]).any():
raise RuntimeError(
f"Detected NaN loss at step {self.total_steps_done}.")
# optimizer step
if self.config.mixed_precision:
# model optimizer step in mixed precision mode
self.scaler.scale(loss_dict["loss"]).backward()
self.scaler.unscale_(self.optimizer)
grad_norm, _ = check_update(self.model,
self.config.grad_clip,
ignore_stopnet=True)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# main model optimizer step
loss_dict["loss"].backward()
grad_norm, _ = check_update(self.model,
self.config.grad_clip,
ignore_stopnet=True)
self.optimizer.step()
step_time = time.time() - step_start_time
# setup lr
if self.config.lr_scheduler:
self.scheduler.step()
# detach loss values
loss_dict_new = dict()
for key, value in loss_dict.items():
if isinstance(value, (int, float)):
loss_dict_new[key] = value
else:
loss_dict_new[key] = value.item()
loss_dict = loss_dict_new
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values["avg_" + key] = value
update_train_values["avg_loader_time"] = loader_time
update_train_values["avg_step_time"] = step_time
self.keep_avg_train.update_values(update_train_values)
# print training progress
current_lr = self.optimizer.param_groups[0]["lr"]
if self.total_steps_done % self.config.print_step == 0:
log_dict = {
"max_spec_length": [batch["max_spec_length"],
1], # value, precision
"max_text_length": [batch["max_text_length"], 1],
"step_time": [step_time, 4],
"loader_time": [loader_time, 2],
"current_lr": current_lr,
}
self.c_logger.print_train_step(batch_n_steps, step,
self.total_steps_done, log_dict,
loss_dict,
self.keep_avg_train.avg_values)
if self.args.rank == 0:
# Plot Training Iter Stats
# reduce TB load
if self.total_steps_done % self.config.tb_plot_step == 0:
iter_stats = {
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time,
}
iter_stats.update(loss_dict)
self.tb_logger.tb_train_step_stats(self.total_steps_done,
iter_stats)
if self.total_steps_done % self.config.save_step == 0:
if self.config.checkpoint:
# save model
save_checkpoint(
self.model,
self.optimizer,
self.total_steps_done,
self.epochs_done,
self.config.r,
self.output_path,
model_loss=loss_dict["loss"],
characters=self.model_characters,
scaler=self.scaler.state_dict()
if self.config.mixed_precision else None,
)
# training visualizations
figures, audios = self.model.train_log(self.ap, batch, outputs)
self.tb_logger.tb_train_figures(self.total_steps_done, figures)
self.tb_logger.tb_train_audios(self.total_steps_done,
{"TrainAudio": audios},
self.ap.sample_rate)
self.total_steps_done += 1
self.on_train_step_end()
return outputs, loss_dict
def train_epoch(self):
self.model.train()
epoch_start_time = time.time()
if self.use_cuda:
batch_num_steps = int(
len(self.train_loader.dataset) /
(self.config.batch_size * self.num_gpus))
else:
batch_num_steps = int(
len(self.train_loader.dataset) / self.config.batch_size)
self.c_logger.print_train_start()
loader_start_time = time.time()
for cur_step, batch in enumerate(self.train_loader):
_, _ = self.train_step(batch, batch_num_steps, cur_step,
loader_start_time)
epoch_time = time.time() - epoch_start_time
# Plot self.epochs_done Stats
if self.args.rank == 0:
epoch_stats = {"epoch_time": epoch_time}
epoch_stats.update(self.keep_avg_train.avg_values)
self.tb_logger.tb_train_epoch_stats(self.total_steps_done,
epoch_stats)
if self.config.tb_model_param_stats:
self.tb_logger.tb_model_weights(self.model,
self.total_steps_done)
def eval_step(self, batch, step):
with torch.no_grad():
step_start_time = time.time()
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
outputs, loss_dict = self.model.eval_step(
batch, self.criterion)
step_time = time.time() - step_start_time
# detach loss values
loss_dict_new = dict()
for key, value in loss_dict.items():
if isinstance(value, (int, float)):
loss_dict_new[key] = value
else:
loss_dict_new[key] = value.item()
loss_dict = loss_dict_new
# update avg stats
update_eval_values = dict()
for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value
update_eval_values["avg_step_time"] = step_time
self.keep_avg_eval.update_values(update_eval_values)
if self.config.print_eval:
self.c_logger.print_eval_step(step, loss_dict,
self.keep_avg_eval.avg_values)
return outputs, loss_dict
def eval_epoch(self):
self.model.eval()
if self.use_cuda:
batch_num_steps = int(
len(self.train_loader.dataset) /
(self.config.batch_size * self.num_gpus))
else:
batch_num_steps = int(
len(self.train_loader.dataset) / self.config.batch_size)
self.c_logger.print_eval_start()
loader_start_time = time.time()
for cur_step, batch in enumerate(self.eval_loader):
# format data
batch = self.format_batch(batch)
loader_time = time.time() - loader_start_time
self.keep_avg_eval.update_values({'avg_loader_time': loader_time})
outputs, _ = self.eval_step(batch, cur_step)
# Plot epoch stats and samples from the last batch.
if self.args.rank == 0:
figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
self.tb_logger.tb_eval_audios(self.total_steps_done,
{"EvalAudio": eval_audios},
self.ap.sample_rate)
def test_run(self, ):
logging.info(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
cond_inputs = self._get_cond_inputs()
for idx, sen in enumerate(test_sentences):
wav, alignment, model_outputs, _ = synthesis(
self.model,
sen,
self.config,
self.use_cuda,
self.ap,
speaker_id=cond_inputs['speaker_id'],
x_vector=cond_inputs['x_vector'],
style_wav=cond_inputs['style_wav'],
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
use_griffin_lim=True,
do_trim_silence=False,
).values()
file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
self.ap.save_wav(wav, file_path)
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
model_outputs, self.ap, output_fig=False)
test_figures["{}-alignment".format(idx)] = plot_alignment(
alignment, output_fig=False)
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
self.config.audio["sample_rate"])
self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)
def _get_cond_inputs(self):
# setup speaker_id
speaker_id = 0 if self.config.use_speaker_embedding else None
# setup x_vector
x_vector = self.speaker_manager.get_x_vectors_by_speaker(
self.speaker_manager.speaker_ids[0]
) if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding else None
# setup style_mel
if self.config.has('gst_style_input'):
style_wav = self.config.gst_style_input
else:
style_wav = None
if style_wav is None and 'use_gst' in self.config and self.config.use_gst:
# inicialize GST with zero dict.
style_wav = {}
print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
for i in range(self.config.gst["gst_num_style_tokens"]):
style_wav[str(i)] = 0
cond_inputs = {'speaker_id': speaker_id, 'style_wav': style_wav, 'x_vector': x_vector}
return cond_inputs
def fit(self):
if self.restore_step != 0 or self.args.best_path:
logging.info(" > Restoring best loss from "
f"{os.path.basename(self.args.best_path)} ...")
self.best_loss = torch.load(self.args.best_path,
map_location="cpu")["model_loss"]
logging.info(
f" > Starting with loaded last best loss {self.best_loss}.")
# define data loaders
self.train_loader = self.setup_train_dataloader(
self.config.r,
self.ap,
self.data_train,
verbose=True,
speaker_mapping=self.speaker_manager.speaker_ids)
self.eval_loader = self.setup_eval_dataloder(
self.config.r,
self.ap,
self.data_train,
verbose=True,
speaker_mapping=self.speaker_manager.speaker_ids
) if self.config.run_eval else None
self.total_steps_done = self.restore_step
for epoch in range(0, self.config.epochs):
self.on_epoch_start()
self.keep_avg_train = KeepAverage()
self.keep_avg_eval = KeepAverage(
) if self.config.run_eval else None
self.epochs_done = epoch
self.c_logger.print_epoch_start(epoch, self.config.epochs)
self.train_epoch()
if self.config.run_eval:
self.eval_epoch()
if epoch >= self.config.test_delay_epochs:
self.test_run()
self.c_logger.print_epoch_end(
epoch, self.keep_avg_eval.avg_values
if self.config.run_eval else self.keep_avg_train.avg_values)
self.save_best_model()
self.on_epoch_end()
def save_best_model(self):
self.best_loss = save_best_model(
self.keep_avg_eval['avg_loss']
if self.keep_avg_eval else self.keep_avg_train['avg_loss'],
self.best_loss,
self.model,
self.optimizer,
self.total_steps_done,
self.epochs_done,
self.config.r,
self.output_path,
self.model_characters,
keep_all_best=self.config.keep_all_best,
keep_after=self.config.keep_after,
scaler=self.scaler.state_dict()
if self.config.mixed_precision else None,
)
def on_epoch_start(self):
if hasattr(self.model, 'on_epoch_start'):
self.model.on_epoch_start(self)
if hasattr(self.criterion, "on_epoch_start"):
self.criterion.on_epoch_start(self)
if hasattr(self.optimizer, "on_epoch_start"):
self.optimizer.on_epoch_start(self)
def on_epoch_end(self):
if hasattr(self.model, "on_epoch_start"):
self.model.on_epoch_end(self)
if hasattr(self.criterion, "on_epoch_end"):
self.criterion.on_epoch_end(self)
if hasattr(self.optimizer, "on_epoch_end"):
self.optimizer.on_epoch_end(self)
def on_train_step_start(self):
if hasattr(self.model, "on_epoch_start"):
self.model.on_train_step_start(self)
if hasattr(self.criterion, "on_train_step_start"):
self.criterion.on_train_step_start(self)
if hasattr(self.optimizer, "on_train_step_start"):
self.optimizer.on_train_step_start(self)
def on_train_step_end(self):
if hasattr(self.model, "on_train_step_end"):
self.model.on_train_step_end(self)
if hasattr(self.criterion, "on_train_step_end"):
self.criterion.on_train_step_end(self)
if hasattr(self.optimizer, "on_train_step_end"):
self.optimizer.on_train_step_end(self)