Fix WaveRNN config and test

pull/847/head
Eren Gölge 2021-09-30 16:20:12 +00:00
parent 55d9209221
commit 7edbe04fe0
2 changed files with 7 additions and 7 deletions

View File

@ -17,11 +17,11 @@ class BaseVocoderConfig(BaseTrainingConfig):
Number of instances used for evaluation. Defaults to 10.
data_path (str):
Root path of the training data. All the audio files found recursively from this root path are used for
training. Defaults to MISSING.
training. Defaults to `""`.
feature_path (str):
Root path to the precomputed feature files. Defaults to None.
seq_len (int):
Length of the waveform segments used for training. Defaults to MISSING.
Length of the waveform segments used for training. Defaults to 1000.
pad_short (int):
Extra padding for the waveforms shorter than `seq_len`. Defaults to 0.
conv_path (int):
@ -45,9 +45,9 @@ class BaseVocoderConfig(BaseTrainingConfig):
use_noise_augment: bool = False # enable/disable random noise augmentation in spectrograms.
eval_split_size: int = 10 # number of samples used for evaluation.
# dataset
data_path: str = MISSING # root data path. It finds all wav files recursively from there.
data_path: str = "" # root data path. It finds all wav files recursively from there.
feature_path: str = None # if you use precomputed features
seq_len: int = MISSING # signal length used in training.
seq_len: int = 1000 # signal length used in training.
pad_short: int = 0 # additional padding for short wavs
conv_pad: int = 0 # additional padding against convolutions applied to spectrograms
use_cache: bool = False # use in memory cache to keep the computed features. This might cause OOM.

View File

@ -12,7 +12,7 @@ def test_wavernn():
config.model_args = WavernnArgs(
rnn_dims=512,
fc_dims=512,
mode=10,
mode="mold",
mulaw=False,
pad=2,
use_aux_net=True,
@ -37,13 +37,13 @@ def test_wavernn():
assert np.all(output.shape == (2, 1280, 30)), output.shape
# mode: gauss
config.model_params.mode = "gauss"
config.model_args.mode = "gauss"
model = Wavernn(config)
output = model(dummy_x, dummy_m)
assert np.all(output.shape == (2, 1280, 2)), output.shape
# mode: quantized
config.model_params.mode = 4
config.model_args.mode = 4
model = Wavernn(config)
output = model(dummy_x, dummy_m)
assert np.all(output.shape == (2, 1280, 2 ** 4)), output.shape