Add and rename folder and tags arguments

Everything is now referred to as tags rather than db, since it's a set of tags, not a database. This commit also switches the positional argument to refer to the regular structured data folder, and adds a separate tags-folder option to override the folder from which the file ids listed in the tags file are loaded.
pull/10/head
Matthew D. Scholefield 2018-04-18 15:59:02 -05:00
parent 19bf49a4ae
commit 45c72a80c2
6 changed files with 46 additions and 36 deletions

View File

@ -78,7 +78,7 @@ def test_pocketsphinx(listener: PocketsphinxListener, data_files) -> Stats:
def main():
args = TrainData.parse_args(create_parser(usage))
data = TrainData.from_both(args.db_file, args.db_folder, args.data_dir)
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
data_files = data.train_files if args.use_train else data.test_files
listener = PocketsphinxListener(
args.key_phrase, args.dict_file, args.hmm_folder, args.threshold

View File

@ -66,7 +66,7 @@ def main():
):
parser.error('Must pass all or no Pocketsphinx arguments')
data = TrainData.from_both(args.db_file, args.db_folder, args.data_dir)
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
data_files = data.train_files if args.use_train else data.test_files
print('Data:', data)

View File

@ -96,7 +96,7 @@ def main():
inject_params(args.model)
data = TrainData.from_both(args.db_file, args.db_folder, args.data_dir)
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
train, test = data.load(args.use_train, not args.use_train)
inputs, targets = train if args.use_train else test

View File

@ -50,7 +50,7 @@ def main():
inject_params(args.model)
save_params(args.model)
data = TrainData.from_both(args.db_file, args.db_folder, args.data_dir)
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
print('Data:', data)
(inputs, outputs), test_data = data.load(True, not args.no_validation)

View File

@ -58,8 +58,8 @@ usage = '''
Disable accuracy and validation calculation
to improve speed during training
:-r --random-data-dir str data/random
Directories with properly encoded wav files of
:-r --random-data-folder str data/random
Folder with properly encoded wav files of
random audio that should not cause an activation
...
@ -95,8 +95,8 @@ class IncrementalTrainer:
from keras.callbacks import ModelCheckpoint
self.checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
save_best_only=args.save_best)
data = TrainData.from_db(args.db_file, args.db_folder)
self.db_data = data.load(True, not args.no_validation)
data = TrainData.from_tags(args.tags_file, args.tags_folder)
self.tags_data = data.load(True, not args.no_validation)
if not isfile(args.model):
create_model(args.model, args.no_validation, args.extra_metrics).save(args.model)
@ -104,11 +104,11 @@ class IncrementalTrainer:
def retrain(self):
"""Train for a session, pulling in any new data from the filesystem"""
folder = TrainData.from_folder(self.args.data_dir)
folder = TrainData.from_folder(self.args.folder)
train_data, test_data = folder.load(True, not self.args.no_validation)
train_data = TrainData.merge(train_data, self.db_data[0])
test_data = TrainData.merge(test_data, self.db_data[1])
train_data = TrainData.merge(train_data, self.tags_data[0])
test_data = TrainData.merge(test_data, self.tags_data[1])
print()
try:
self.listener.runner.model.fit(*train_data, self.args.batch_size, self.args.epochs,
@ -132,7 +132,7 @@ class IncrementalTrainer:
if conf > 0.5:
samples_since_train += 1
name = splitext(basename(fn))[0] + '-' + str(i) + '.wav'
name = join(self.args.data_dir, 'test' if save_test else '', 'not-wake-word',
name = join(self.args.folder, 'test' if save_test else '', 'not-wake-word',
'generated', name)
save_audio(name, audio_buffer)
print()
@ -147,7 +147,7 @@ class IncrementalTrainer:
Begin reading through audio files, saving false
activations and retraining when necessary
"""
for fn in glob_all(self.args.random_data_dir, '*.wav'):
for fn in glob_all(self.args.random_data_folder, '*.wav'):
if fn in self.trained_fns:
print('Skipping ' + fn + '...')
continue
@ -164,8 +164,8 @@ def main():
args = TrainData.parse_args(create_parser(usage))
for i in (
join(args.data_dir, 'not-wake-word', 'generated'),
join(args.data_dir, 'test', 'not-wake-word', 'generated')
join(args.folder, 'not-wake-word', 'generated'),
join(args.folder, 'test', 'not-wake-word', 'generated')
):
makedirs(i, exist_ok=True)

View File

@ -19,13 +19,14 @@ from os.path import join, isfile, dirname
from typing import *
import numpy as np
from prettyparse import add_to_parser
from precise.util import find_wavs
from precise.vectorization import load_vector, vectorize_inhibit, vectorize
class TrainData:
"""Class to handle loading of wave data from categorized folders and SQLite dbs"""
"""Class to handle loading of wave data from categorized folders and tagged text files"""
def __init__(self, train_files: Tuple[List[str], List[str]],
test_files: Tuple[List[str], List[str]]):
@ -49,40 +50,40 @@ class TrainData:
return cls(find_wavs(folder), find_wavs(join(folder, 'test')))
@classmethod
def from_db(cls, db_file: str, db_folder: str) -> 'TrainData':
def from_tags(cls, tags_file: str, tags_folder: str) -> 'TrainData':
"""
Load a set of data from a text database in the following format:
Load a set of data from a text file with tags in the following format:
<file_id> (tab) <tag>
<file_id> (tab) <tag>
file_id: identifier of file such that the following
file exists: {db_folder}/{data_id}.wav
file exists: {tags_folder}/{data_id}.wav
tag: "wake-word" or "not-wake-word"
"""
if not db_file:
if not tags_file:
return cls(([], []), ([], []))
if not isfile(db_file):
raise RuntimeError('Database file does not exist: ' + db_file)
if not isfile(tags_file):
raise RuntimeError('Database file does not exist: ' + tags_file)
train_groups = {}
train_group_file = join(db_file.replace('.txt', '') + '.groups.json')
train_group_file = join(tags_file.replace('.txt', '') + '.groups.json')
if isfile(train_group_file):
with open(train_group_file) as f:
train_groups = json.load(f)
db_files = {
tags_files = {
'wake-word': [],
'not-wake-word': []
}
with open(db_file) as f:
with open(tags_file) as f:
for line in f.read().split('\n'):
if not line:
continue
file, tag = line.split('\t')
db_files[tag.strip()].append(join(db_folder, file.strip() + '.wav'))
tags_files[tag.strip()].append(join(tags_folder, file.strip() + '.wav'))
train_files, test_files = ([], []), ([], [])
for label, rows in enumerate([db_files['wake-word'], db_files['not-wake-word']]):
for label, rows in enumerate([tags_files['wake-word'], tags_files['not-wake-word']]):
for fn in rows:
if not isfile(fn):
print('Missing file:', fn)
@ -103,9 +104,9 @@ class TrainData:
return cls(train_files, test_files)
@classmethod
def from_both(cls, db_file: str, db_folder: str, data_dir: str) -> 'TrainData':
def from_both(cls, tags_file: str, tags_folder: str, folder: str) -> 'TrainData':
"""Load data from both a database and a structured folder"""
return cls.from_db(db_file, db_folder) + cls.from_folder(data_dir)
return cls.from_tags(tags_file, tags_folder) + cls.from_folder(folder)
def load(self, train=True, test=True) -> tuple:
"""
@ -140,14 +141,23 @@ class TrainData:
@staticmethod
def parse_args(parser: ArgumentParser) -> Any:
"""Return parsed args from parser, adding options for train data inputs"""
parser.add_argument('db_folder', help='Folder to load database references from')
parser.add_argument(
'-db', '--db-file', default='', help='Text database to load from where '
'each line is <file_id>\t(wake-word|not-wake-word) and {db_folder}/<file_id>.wav exists..')
parser.add_argument('-d', '--data-dir', default='{db_folder}',
help='Load files from a different directory')
extra_usage = '''
:folder str
Folder to wav files from
:-tf --tags-folder str {folder}
Specify a different folder to load file ids
in tags file from
:-tg --tags-file str -
Text file to load tags from where each line is
<file_id> TAB (wake-word|not-wake-word) and
{folder}/<file_id>.wav exists
'''
add_to_parser(parser, extra_usage)
args = parser.parse_args()
args.data_dir = args.data_dir.format(db_folder=args.db_folder)
args.tags_folder = args.tags_folder.format(folder=args.folder)
return args
def __repr__(self) -> str: