Synchronize DDP processes

pull/725/head
Eren Gölge 2021-08-13 10:40:50 +00:00
parent ecf5f17dca
commit 7c0d564965
3 changed files with 28 additions and 6 deletions


@@ -84,7 +84,10 @@ class TrainingArgs(Coqpit):
     config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
     rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
     group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
-    use_ddp: bool= field(default=False, metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."})
+    use_ddp: bool = field(
+        default=False,
+        metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."},
+    )


 class Trainer:
@@ -362,7 +365,9 @@ class Trainer:
     ) -> DataLoader:
         if num_gpus > 1:
             if hasattr(model.module, "get_data_loader"):
-                loader = model.module.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank)
+                loader = model.module.get_data_loader(
+                    config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
+                )
         else:
             if hasattr(model, "get_data_loader"):
                 loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
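Because a DDP-wrapped model only exposes custom methods such as get_data_loader through its .module attribute, the trainer has to branch on num_gpus before the call. A minimal sketch of that access pattern, with a hypothetical SimpleModel standing in for a real TTS model:

    from torch import nn

    class SimpleModel(nn.Module):
        # hypothetical stand-in for a TTS model with a custom data-loader helper
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 4)

        def get_data_loader(self):
            return "loader"  # a real model would build and return a DataLoader here

    def build_loader(model: nn.Module, num_gpus: int):
        # with num_gpus > 1 the model arrives wrapped in DistributedDataParallel,
        # so model-specific helpers must be reached through model.module
        target = model.module if num_gpus > 1 else model
        if hasattr(target, "get_data_loader"):
            return target.get_data_loader()
        return None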
@@ -797,6 +802,7 @@ class Trainer:
             loader_time = time.time() - loader_start_time
             self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
             outputs, _ = self.eval_step(batch, cur_step)
+            loader_start_time = time.time()
         # plot epoch stats, artifacts and figures
         if self.args.rank == 0:
             figures, audios = None, None
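Resetting loader_start_time after each eval_step means avg_loader_time only measures the wait on the DataLoader rather than accumulating the step time as well. A small standalone sketch of that timing pattern (hypothetical loader and step_fn, not the Trainer's actual loop):

    import time

    def timed_loop(loader, step_fn):
        # measure only the time spent waiting on the DataLoader for each batch
        loader_start_time = time.time()
        for cur_step, batch in enumerate(loader):
            loader_time = time.time() - loader_start_time  # wait time for this batch
            step_fn(batch, cur_step)
            loader_start_time = time.time()  # reset so the step itself is not counted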
@@ -839,7 +845,7 @@ class Trainer:
         self.total_steps_done = self.restore_step
         for epoch in range(0, self.config.epochs):
-            if self.num_gpus:
+            if self.num_gpus > 1:
                 # let all processes sync up before starting with a new epoch of training
                 dist.barrier()
             self.callbacks.on_epoch_start()
@@ -868,6 +874,9 @@ class Trainer:
             self.callbacks.on_keyboard_interrupt()
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
+            # clear the DDP processes
+            if self.num_gpus > 1:
+                dist.destroy_process_group()
             # finish the wandb run and sync data
             self.dashboard_logger.finish()
             # stop without error signal
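Taken together, the trainer now follows the usual torch.distributed lifecycle: a barrier at the top of every epoch so no rank races ahead, and an explicit destroy_process_group() on the interrupt path so no worker is left hanging. A minimal sketch of that lifecycle, assuming MASTER_ADDR/MASTER_PORT are already exported (as distribute.py normally arranges); function and argument names here are illustrative:

    import torch.distributed as dist

    def run_training(rank: int, num_gpus: int, epochs: int) -> None:
        # hypothetical sketch of the DDP lifecycle the trainer follows
        if num_gpus > 1:
            dist.init_process_group("nccl", rank=rank, world_size=num_gpus)
        try:
            for _ in range(epochs):
                if num_gpus > 1:
                    # let all processes sync up before starting a new epoch
                    dist.barrier()
                # train_epoch() / eval_epoch() would run here
        except KeyboardInterrupt:
            pass
        finally:
            if num_gpus > 1:
                # clear the DDP processes so no worker is left hanging
                dist.destroy_process_group()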


@@ -2,6 +2,7 @@ import os
 from typing import Dict, List, Tuple
 import torch
+import torch.distributed as dist
 from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
@@ -164,7 +165,14 @@ class BaseTTS(BaseModel):
         }

     def get_data_loader(
-        self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int, rank: int=None
+        self,
+        config: Coqpit,
+        ap: AudioProcessor,
+        is_eval: bool,
+        data_items: List,
+        verbose: bool,
+        num_gpus: int,
+        rank: int = None,
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None
@@ -228,6 +236,10 @@ class BaseTTS(BaseModel):
                 else:
                     self.train_data_items = dataset.items

+            # halt DDP processes for the main process to finish computing the phoneme cache
+            if num_gpus > 1:
+                dist.barrier()
+
             dataset.sort_items()

             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
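The new barrier in get_data_loader keeps the non-zero ranks from reading a phoneme cache that the main process may still be writing. The underlying "rank 0 prepares, everyone else waits" pattern can be sketched as follows (hypothetical helper name and compute_cache callable, assuming the process group is already initialized):

    import torch.distributed as dist

    def prepare_cache(rank: int, num_gpus: int, compute_cache) -> None:
        # rank 0 builds the on-disk cache; the other ranks block at the barrier
        # until it is finished, then every process proceeds to build its loader
        if rank == 0:
            compute_cache()  # e.g. precompute phonemes into the cache folder
        if num_gpus > 1:
            dist.barrier()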


@@ -1,5 +1,5 @@
 import importlib
-from typing import Dict, List
+from typing import Dict, List, Tuple

 import torch
@@ -9,7 +9,8 @@ from TTS.utils.training import NoamLR
 def is_apex_available():
     return importlib.util.find_spec("apex") is not None

-def setup_torch_training_env(cudnn_enable:bool, cudnn_benchmark:bool, use_ddp:bool=False) -> Tuple[bool, int]:
+def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]:
     """Setup PyTorch environment for training.

     Args: