load_checkpoint func for vocoder models

pull/10/head
root 2021-01-20 02:12:29 +00:00
parent ea39715305
commit ca3743539a
4 changed files with 38 additions and 0 deletions

View File

@ -95,3 +95,11 @@ class MelganGenerator(nn.Module):
nn.utils.remove_weight_norm(layer)
except ValueError:
layer.remove_weight_norm()
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
if eval:
self.eval()
assert not self.training
self.remove_weight_norm()

View File

@ -39,6 +39,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
self.upsample_factors = upsample_factors
self.upsample_scale = np.prod(upsample_factors)
self.inference_padding = inference_padding
self.use_weight_norm = use_weight_norm
# check the number of layers and stacks
assert num_res_blocks % stacks == 0
@ -156,3 +157,12 @@ class ParallelWaveganGenerator(torch.nn.Module):
def receptive_field_size(self):
return self._get_receptive_field_size(self.layers, self.stacks,
self.kernel_size)
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
if eval:
self.eval()
assert not self.training
if self.use_weight_norm:
self.remove_weight_norm()

View File

@ -175,3 +175,22 @@ class Wavegrad(nn.Module):
self.x_conv = weight_norm(self.x_conv)
self.out_conv = weight_norm(self.out_conv)
self.y_conv = weight_norm(self.y_conv)
def load_checkpoint(self, config, checkpoint_path, eval=False):
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
if eval:
self.eval()
assert not self.training
if self.use_weight_norm:
self.remove_weight_norm()
betas = np.linspace(config['test_noise_schedule']['min_val'],
config['test_noise_schedule']['max_val'],
config['test_noise_schedule']['num_steps'])
self.compute_noise_level(betas)
else:
betas = np.linspace(config['train_noise_schedule']['min_val'],
config['train_noise_schedule']['max_val'],
config['train_noise_schedule']['num_steps'])
self.compute_noise_level(betas)

View File

@ -1,4 +1,5 @@
import re
import torch
import importlib
import numpy as np
from matplotlib import pyplot as plt