Add a series of test cases using the new api
The primary purpose is to ensure all the scripts run end to end rather than completely verifying its functionalitypull/102/head
parent
a37958f3f2
commit
6196bdca26
|
@ -0,0 +1,32 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from precise.scripts.train import TrainScript
|
||||
from test.scripts.test_train import DummyTrainFolder
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def train_folder():
|
||||
folder = DummyTrainFolder(10)
|
||||
try:
|
||||
yield folder
|
||||
finally:
|
||||
folder.cleanup()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def train_script(train_folder):
|
||||
return TrainScript.create(model=train_folder.model, folder=train_folder.root, epochs=1)
|
|
@ -0,0 +1,55 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import atexit
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
from os import makedirs
|
||||
from os.path import isdir, join
|
||||
from shutil import rmtree
|
||||
from tempfile import mkdtemp
|
||||
|
||||
from precise.params import pr
|
||||
from precise.util import save_audio
|
||||
|
||||
|
||||
class DummyAudioFolder:
|
||||
def __init__(self, count=10):
|
||||
self.count = count
|
||||
self.root = mkdtemp()
|
||||
atexit.register(self.cleanup)
|
||||
|
||||
def rand(self, min, max):
|
||||
return min + (max - min) * np.random.random() * pr.buffer_t
|
||||
|
||||
def generate_samples(self, folder, name, value, duration):
|
||||
for i in range(self.count):
|
||||
save_audio(join(folder, name.format(i)), np.array([value] * int(duration * pr.sample_rate)))
|
||||
|
||||
def subdir(self, *parts):
|
||||
folder = self.path(*parts)
|
||||
if not isdir(folder):
|
||||
makedirs(folder)
|
||||
return folder
|
||||
|
||||
def path(self, *path):
|
||||
return join(self.root, *path)
|
||||
|
||||
def count_files(self, folder):
|
||||
return sum([len(files) for r, d, files in os.walk(folder)])
|
||||
|
||||
def cleanup(self):
|
||||
if isdir(self.root):
|
||||
rmtree(self.root)
|
|
@ -0,0 +1,51 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from precise.scripts.add_noise import AddNoiseScript
|
||||
|
||||
from test.scripts.dummy_audio_folder import DummyAudioFolder
|
||||
|
||||
|
||||
class DummyNoiseFolder(DummyAudioFolder):
|
||||
def __init__(self, count=10):
|
||||
super().__init__(count)
|
||||
self.source = self.subdir('source')
|
||||
self.noise = self.subdir('noise')
|
||||
self.output = self.subdir('output')
|
||||
|
||||
self.generate_samples(self.subdir('source', 'wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2))
|
||||
self.generate_samples(self.subdir('source', 'not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2))
|
||||
self.generate_samples(self.noise, 'noise-{}.wav', 0.5, self.rand(10, 20))
|
||||
|
||||
|
||||
class TestAddNoise:
|
||||
def get_base_data(self, count):
|
||||
folders = DummyNoiseFolder(count)
|
||||
base_args = dict(
|
||||
folder=folders.source, noise_folder=folders.noise,
|
||||
output_folder=folders.output
|
||||
)
|
||||
return folders, base_args
|
||||
|
||||
def test_run_basic(self):
|
||||
folders, base_args = self.get_base_data(10)
|
||||
script = AddNoiseScript.create(inflation_factor=1, **base_args)
|
||||
script.run()
|
||||
assert folders.count_files(folders.output) == 20
|
||||
|
||||
def test_run_basic_2(self):
|
||||
folders, base_args = self.get_base_data(10)
|
||||
script = AddNoiseScript.create(inflation_factor=2, **base_args)
|
||||
script.run()
|
||||
assert folders.count_files(folders.output) == 40
|
|
@ -0,0 +1,43 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
from os.path import isfile
|
||||
|
||||
from precise.scripts.calc_threshold import CalcThresholdScript
|
||||
from precise.scripts.eval import EvalScript
|
||||
from precise.scripts.graph import GraphScript
|
||||
|
||||
|
||||
def read_content(filename):
|
||||
with open(filename) as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def test_combined(train_folder, train_script):
|
||||
train_script.run()
|
||||
params_file = train_folder.model + '.params'
|
||||
assert isfile(train_folder.model)
|
||||
assert isfile(params_file)
|
||||
|
||||
EvalScript.create(folder=train_folder.root, models=[train_folder.model]).run()
|
||||
|
||||
out_file = train_folder.path('outputs.npz')
|
||||
graph_script = GraphScript.create(folder=train_folder.root, models=[train_folder.model], output_file=out_file)
|
||||
graph_script.run()
|
||||
assert isfile(out_file)
|
||||
|
||||
params_before = read_content(params_file)
|
||||
CalcThresholdScript.create(folder=train_folder.root, model=train_folder.model, input_file=out_file).run()
|
||||
assert params_before != read_content(params_file)
|
|
@ -0,0 +1,24 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from os.path import isfile
|
||||
|
||||
from precise.scripts.convert import ConvertScript
|
||||
|
||||
|
||||
def test_convert(train_folder, train_script):
|
||||
train_script.run()
|
||||
|
||||
ConvertScript.create(model=train_folder.model, out=train_folder.model + '.pb').run()
|
||||
assert isfile(train_folder.model + '.pb')
|
|
@ -0,0 +1,49 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
import glob
|
||||
import re
|
||||
from os.path import join
|
||||
|
||||
from precise.scripts.engine import EngineScript
|
||||
from runner.precise_runner import ReadWriteStream
|
||||
|
||||
|
||||
class FakeStdin:
|
||||
def __init__(self, data: bytes):
|
||||
self.buffer = ReadWriteStream(data)
|
||||
|
||||
def isatty(self):
|
||||
return False
|
||||
|
||||
|
||||
class FakeStdout:
|
||||
def __init__(self):
|
||||
self.buffer = ReadWriteStream()
|
||||
|
||||
|
||||
def test_engine(train_folder, train_script):
|
||||
train_script.run()
|
||||
with open(glob.glob(join(train_folder.root, 'wake-word', '*.wav'))[0], 'rb') as f:
|
||||
data = f.read()
|
||||
try:
|
||||
sys.stdin = FakeStdin(data)
|
||||
sys.stdout = FakeStdout()
|
||||
EngineScript.create(model_name=train_folder.model).run()
|
||||
assert re.match(rb'[01]\.[0-9]+', sys.stdout.buffer.buffer)
|
||||
finally:
|
||||
sys.stdin = sys.__stdin__
|
||||
sys.stdout = sys.__stdout__
|
|
@ -0,0 +1,37 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright 2019 Mycroft AI Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from os.path import isfile
|
||||
|
||||
from precise.params import pr
|
||||
from precise.scripts.train import TrainScript
|
||||
from test.scripts.dummy_audio_folder import DummyAudioFolder
|
||||
|
||||
|
||||
class DummyTrainFolder(DummyAudioFolder):
|
||||
def __init__(self, count=10):
|
||||
super().__init__(count)
|
||||
self.generate_samples(self.subdir('wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2 * pr.buffer_t))
|
||||
self.generate_samples(self.subdir('not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2 * pr.buffer_t))
|
||||
self.generate_samples(self.subdir('test', 'wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2 * pr.buffer_t))
|
||||
self.generate_samples(self.subdir('test', 'not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2 * pr.buffer_t))
|
||||
self.model = self.path('model.net')
|
||||
|
||||
|
||||
class TestTrain:
|
||||
def test_run_basic(self):
|
||||
folders = DummyTrainFolder(10)
|
||||
script = TrainScript.create(model=folders.model, folder=folders.root)
|
||||
script.run()
|
||||
assert isfile(folders.model)
|
Loading…
Reference in New Issue