Synchronize DDP processes

pull/725/head
Eren Gölge 2021-08-13 10:40:50 +00:00
parent ecf5f17dca
commit 7c0d564965
3 changed files with 28 additions and 6 deletions

View File

@@ -84,7 +84,10 @@ class TrainingArgs(Coqpit):
config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
use_ddp: bool= field(default=False, metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."})
use_ddp: bool = field(
default=False,
metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."},
)
class Trainer:
@@ -362,7 +365,9 @@ class Trainer:
) -> DataLoader:
if num_gpus > 1:
if hasattr(model.module, "get_data_loader"):
loader = model.module.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank)
loader = model.module.get_data_loader(
config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
)
else:
if hasattr(model, "get_data_loader"):
loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
@@ -797,6 +802,7 @@ class Trainer:
loader_time = time.time() - loader_start_time
self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
outputs, _ = self.eval_step(batch, cur_step)
loader_start_time = time.time()
# plot epoch stats, artifacts and figures
if self.args.rank == 0:
figures, audios = None, None
@@ -839,7 +845,7 @@ class Trainer:
self.total_steps_done = self.restore_step
for epoch in range(0, self.config.epochs):
if self.num_gpus:
if self.num_gpus > 1:
# let all processes sync up before starting with a new epoch of training
dist.barrier()
self.callbacks.on_epoch_start()
@@ -868,6 +874,9 @@ class Trainer:
self.callbacks.on_keyboard_interrupt()
# if the output folder is empty remove the run.
remove_experiment_folder(self.output_path)
# clear the DDP processes
if self.num_gpus > 1:
dist.destroy_process_group()
# finish the wandb run and sync data
self.dashboard_logger.finish()
# stop without error signal
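
Taken together, the trainer changes above follow a standard DDP pattern: every rank waits at a barrier before starting an epoch, and the process group is torn down on the way out. A minimal sketch of that pattern, assuming the default process group is already initialized (as `distribute.py` does before launching the Trainer); `fit` and `run_epoch` are hypothetical stand-ins, not Trainer APIs:

import torch.distributed as dist

def fit(run_epoch, epochs: int, num_gpus: int) -> None:
    try:
        for epoch in range(epochs):
            if num_gpus > 1:
                # let all processes sync up before starting a new epoch
                dist.barrier()
            run_epoch(epoch)
    except KeyboardInterrupt:
        pass
    finally:
        if num_gpus > 1:
            # clear the DDP processes so no worker is left hanging
            dist.destroy_process_group()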

View File

@@ -2,6 +2,7 @@ import os
from typing import Dict, List, Tuple
import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
@@ -164,7 +165,14 @@ class BaseTTS(BaseModel):
}
def get_data_loader(
self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int, rank: int=None
self,
config: Coqpit,
ap: AudioProcessor,
is_eval: bool,
data_items: List,
verbose: bool,
num_gpus: int,
rank: int = None,
) -> "DataLoader":
if is_eval and not config.run_eval:
loader = None
@@ -228,6 +236,10 @@ class BaseTTS(BaseModel):
else:
self.train_data_items = dataset.items
# halt DDP processes for the main process to finish computing the phoneme cache
if num_gpus > 1:
dist.barrier()
dataset.sort_items()
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
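
The barrier added here is a simple handshake: the main process finishes the expensive phoneme-cache computation while the other ranks wait, so every process sees a complete cache before building its sampler and loader. A minimal sketch of that idea, assuming rank 0 does the one-off work; `prepare_phoneme_cache` is a hypothetical helper, not the BaseTTS API:

import os
import torch.distributed as dist

def prepare_phoneme_cache(rank: int, num_gpus: int, cache_path: str) -> None:
    if rank == 0:
        # stand-in for the heavy one-off phoneme pre-computation
        os.makedirs(cache_path, exist_ok=True)
    if num_gpus > 1:
        # halt the other DDP processes until rank 0 has finished writing the cache
        dist.barrier()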

View File

@@ -1,5 +1,5 @@
import importlib
from typing import Dict, List
from typing import Dict, List, Tuple
import torch
@@ -9,7 +9,8 @@ from TTS.utils.training import NoamLR
def is_apex_available():
return importlib.util.find_spec("apex") is not None
def setup_torch_training_env(cudnn_enable:bool, cudnn_benchmark:bool, use_ddp:bool=False) -> Tuple[bool, int]:
def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]:
"""Setup PyTorch environment for training.
Args: