Formatting and printing more about the model

pull/10/head
Eren G 2018-08-08 18:45:02 +02:00
parent d5febfb187
commit ecd31af125
5 changed files with 25 additions and 20 deletions


@@ -3,7 +3,7 @@
     "audio_processor": "audio",
     "num_mels": 80,
     "num_freq": 1025,
-    "sample_rate": 22050,
+    "sample_rate": 22000,
     "frame_length_ms": 50,
     "frame_shift_ms": 12.5,
     "preemphasis": 0.97,
@@ -21,7 +21,7 @@
     "eval_batch_size":-1,
     "r": 5,
-    "griffin_lim_iters": 60,
+    "griffin_lim_iters": 50,
     "power": 1.5,
     "num_loader_workers": 8,


@@ -3,7 +3,7 @@
     "audio_processor": "audio",
     "num_mels": 80,
     "num_freq": 1025,
-    "sample_rate": 22050,
+    "sample_rate": 22000,
     "frame_length_ms": 50,
     "frame_shift_ms": 12.5,
     "preemphasis": 0.97,
@@ -21,7 +21,7 @@
     "eval_batch_size":-1,
     "r": 5,
-    "griffin_lim_iters": 60,
+    "griffin_lim_iters": 50,
     "power": 1.5,
     "num_loader_workers": 8,


@@ -62,8 +62,8 @@ class TacotronTrainTest(unittest.TestCase):
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            # ignore pre-higway layer since it works conditional
-            if count not in [148, 59]:
-                assert (param != param_ref).any(
-                ), "param {} with shape {} not updated!! \n{}\n{}".format(
-                    count, param.shape, param, param_ref)
-            count += 1
+            # if count not in [145, 59]:
+            assert (param != param_ref).any(
+            ), "param {} with shape {} not updated!! \n{}\n{}".format(
+                count, param.shape, param, param_ref)
+            count += 1
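
The test tweaked above follows a common PyTorch pattern: deep-copy the model before training, run one update, then assert that every parameter has moved away from its copy (with a skip list for layers that are only used conditionally). A condensed, generic sketch of the pattern, using a stand-in model rather than the Tacotron instance from the test:

import copy
import torch

model = torch.nn.Linear(10, 4)            # stand-in for the Tacotron model under test
model_ref = copy.deepcopy(model)          # frozen copy taken before the update
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 10)).sum()
loss.backward()
optimizer.step()

count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
    assert (param != param_ref).any(), \
        "param {} with shape {} not updated!!".format(count, param.shape)
    count += 1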


@@ -37,7 +37,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     avg_step_time = 0
     print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
     n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
-    batch_n_iter = len(data_loader.dataset) / c.batch_size
+    batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
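
Two small arithmetic notes on this hunk, using the config values from this commit; the dataset and batch sizes below are made-up examples:

# n_priority_freq: how many linear-spectrogram bins fall below ~3 kHz
# (these low bins typically get extra weight in the linear loss)
sample_rate, num_freq = 22000, 1025
n_priority_freq = int(3000 / (sample_rate * 0.5) * num_freq)   # -> 279

# batch_n_iter: int() keeps the iterations-per-epoch printout from showing a float,
# since "/" is true division in Python 3
dataset_len, batch_size = 13100, 32
batch_n_iter = int(dataset_len / batch_size)                   # -> 409
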
@@ -321,13 +321,14 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
         # test sentences
         ap.griffin_lim_iters = 60
         for idx, test_sentence in enumerate(test_sentences):
-            try:
-                wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
-                                                         use_cuda, c.text_cleaner)
-                wav_name = 'TestSentences/{}'.format(idx)
-                tb.add_audio(
-                    wav_name, wav, current_step, sample_rate=c.sample_rate)
+            wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
+                                                     use_cuda, c.text_cleaner)
+            try:
+                wav_name = 'TestSentences/{}'.format(idx)
+                tb.add_audio(
+                    wav_name, wav, current_step, sample_rate=c.sample_rate)
             except:
                 print(" !! Error as creating Test Sentence -", idx)
                 pass
             align_img = alignments[0].data.cpu().numpy()
             linear_spec = plot_spectrogram(linear_spec, ap)
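
Note that this hunk moves the try: below the synthesis call, so only the TensorBoard logging is wrapped and synthesis errors are no longer silently swallowed. For reference, a minimal standalone example of the logging call, assuming tb is a tensorboardX SummaryWriter (the writer class is not shown in this diff) and using a dummy waveform:

import numpy as np
from tensorboardX import SummaryWriter

tb = SummaryWriter("logs/example")                          # hypothetical log directory
wav = np.random.uniform(-1, 1, 22000).astype(np.float32)    # stand-in synthesized waveform
current_step = 1000
tb.add_audio('TestSentences/0', wav, current_step, sample_rate=22000)
tb.close()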


@@ -23,6 +23,7 @@ class AudioProcessor(object):
                  max_mel_freq,
                  griffin_lim_iters=None):
+        print(" > Setting up Audio Processor...")
         self.sample_rate = sample_rate
         self.num_mels = num_mels
         self.min_level_db = min_level_db
@@ -36,11 +37,12 @@ class AudioProcessor(object):
         self.max_mel_freq = max_mel_freq
         self.griffin_lim_iters = griffin_lim_iters
         self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
+        if preemphasis == 0:
+            print(" | > Preemphasis is deactive.")

     def save_wav(self, wav, path):
         wav *= 32767 / max(0.01, np.max(np.abs(wav)))
-        librosa.output.write_wav(
-            path, wav.astype(np.int16), self.sample_rate)
+        librosa.output.write_wav(path, wav.astype(np.int16), self.sample_rate)

     def _linear_to_mel(self, spectrogram):
         global _mel_basis
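
The reformatted save_wav above peak-normalizes the float waveform into the int16 range before writing it out. A small standalone illustration of just that scaling step (dummy waveform, no file I/O):

import numpy as np

wav = np.random.uniform(-0.3, 0.3, 22000).astype(np.float32)  # quiet stand-in waveform
# the 0.01 floor stops near-silent audio from being amplified all the way to full scale
scaled = wav * (32767 / max(0.01, np.max(np.abs(wav))))
print(int(np.abs(scaled).max()))                              # ~32767 after normalization
out = scaled.astype(np.int16)                                 # what gets written to disk
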
@@ -64,6 +66,10 @@ class AudioProcessor(object):
         n_fft = (self.num_freq - 1) * 2
         hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
         win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
+        hop_length = 256
+        win_length = 1024
+        print(" | > fft size: {}, hop length: {}, win length: {}".format(
+            n_fft, hop_length, win_length))
         return n_fft, hop_length, win_length

     def _amp_to_db(self, x):
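
The two hard-coded lines added above take precedence over the values computed from frame_shift_ms and frame_length_ms just before them. At 22000 Hz the fixed sizes correspond to roughly the following frame times (illustrative arithmetic only):

sample_rate = 22000
hop_length, win_length = 256, 1024
print(hop_length / sample_rate * 1000)   # ~11.6 ms hop vs. the 12.5 ms in the config
print(win_length / sample_rate * 1000)   # ~46.5 ms window vs. the 50 ms in the config
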
@@ -123,13 +129,11 @@ class AudioProcessor(object):
         return self._normalize(S)

     def _stft(self, y):
-        n_fft, hop_length, win_length = self._stft_parameters()
         return librosa.stft(
-            y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
+            y=y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length)

     def _istft(self, y):
-        _, hop_length, win_length = self._stft_parameters()
-        return librosa.istft(y, hop_length=hop_length, win_length=win_length)
+        return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)

     def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
         window_length = int(self.sample_rate * min_silence_sec)
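
The new constructor message reports when preemphasis is turned off (preemphasis == 0). For context, a coefficient like the 0.97 in the configs above is typically applied as a first-order filter before spectral analysis and undone after Griffin-Lim; a hedged sketch of that pair, not the methods of this AudioProcessor:

import numpy as np
from scipy import signal

def apply_preemphasis(x, coef=0.97):
    # y[n] = x[n] - coef * x[n-1]: boosts high frequencies before the STFT
    return signal.lfilter([1, -coef], [1], x)

def undo_preemphasis(y, coef=0.97):
    # inverse filter, applied to the reconstructed waveform
    return signal.lfilter([1], [1, -coef], y)

x = np.random.uniform(-1, 1, 22000)
assert np.allclose(undo_preemphasis(apply_preemphasis(x)), x)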