mirror of https://github.com/coqui-ai/TTS.git
load_checkpoint func for vocoder models
parent
ea39715305
commit
ca3743539a
|
@ -95,3 +95,11 @@ class MelganGenerator(nn.Module):
|
|||
nn.utils.remove_weight_norm(layer)
|
||||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.remove_weight_norm()
|
||||
|
|
|
@ -39,6 +39,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
|
|||
self.upsample_factors = upsample_factors
|
||||
self.upsample_scale = np.prod(upsample_factors)
|
||||
self.inference_padding = inference_padding
|
||||
self.use_weight_norm = use_weight_norm
|
||||
|
||||
# check the number of layers and stacks
|
||||
assert num_res_blocks % stacks == 0
|
||||
|
@ -156,3 +157,12 @@ class ParallelWaveganGenerator(torch.nn.Module):
|
|||
def receptive_field_size(self):
|
||||
return self._get_receptive_field_size(self.layers, self.stacks,
|
||||
self.kernel_size)
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
if self.use_weight_norm:
|
||||
self.remove_weight_norm()
|
||||
|
|
|
@ -175,3 +175,22 @@ class Wavegrad(nn.Module):
|
|||
self.x_conv = weight_norm(self.x_conv)
|
||||
self.out_conv = weight_norm(self.out_conv)
|
||||
self.y_conv = weight_norm(self.y_conv)
|
||||
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
if self.use_weight_norm:
|
||||
self.remove_weight_norm()
|
||||
betas = np.linspace(config['test_noise_schedule']['min_val'],
|
||||
config['test_noise_schedule']['max_val'],
|
||||
config['test_noise_schedule']['num_steps'])
|
||||
self.compute_noise_level(betas)
|
||||
else:
|
||||
betas = np.linspace(config['train_noise_schedule']['min_val'],
|
||||
config['train_noise_schedule']['max_val'],
|
||||
config['train_noise_schedule']['num_steps'])
|
||||
self.compute_noise_level(betas)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import re
|
||||
import torch
|
||||
import importlib
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
|
Loading…
Reference in New Issue