Add rename folder and tags arguments
Everything is now referred to as tags rather than db since it's a set of tags not a database. It also switches the positional argument to refer to the regular structured data folder, with a separate tags-folder option to override where to load the file ids from the tags frompull/10/head
parent
19bf49a4ae
commit
45c72a80c2
|
@ -78,7 +78,7 @@ def test_pocketsphinx(listener: PocketsphinxListener, data_files) -> Stats:
|
|||
|
||||
def main():
|
||||
args = TrainData.parse_args(create_parser(usage))
|
||||
data = TrainData.from_both(args.db_file, args.db_folder, args.data_dir)
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
data_files = data.train_files if args.use_train else data.test_files
|
||||
listener = PocketsphinxListener(
|
||||
args.key_phrase, args.dict_file, args.hmm_folder, args.threshold
|
||||
|
|
|
@ -66,7 +66,7 @@ def main():
|
|||
):
|
||||
parser.error('Must pass all or no Pocketsphinx arguments')
|
||||
|
||||
data = TrainData.from_both(args.db_file, args.db_folder, args.data_dir)
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
data_files = data.train_files if args.use_train else data.test_files
|
||||
print('Data:', data)
|
||||
|
||||
|
|
|
@ -96,7 +96,7 @@ def main():
|
|||
|
||||
inject_params(args.model)
|
||||
|
||||
data = TrainData.from_both(args.db_file, args.db_folder, args.data_dir)
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
train, test = data.load(args.use_train, not args.use_train)
|
||||
inputs, targets = train if args.use_train else test
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ def main():
|
|||
inject_params(args.model)
|
||||
save_params(args.model)
|
||||
|
||||
data = TrainData.from_both(args.db_file, args.db_folder, args.data_dir)
|
||||
data = TrainData.from_both(args.tags_file, args.tags_folder, args.folder)
|
||||
print('Data:', data)
|
||||
(inputs, outputs), test_data = data.load(True, not args.no_validation)
|
||||
|
||||
|
|
|
@ -58,8 +58,8 @@ usage = '''
|
|||
Disable accuracy and validation calculation
|
||||
to improve speed during training
|
||||
|
||||
:-r --random-data-dir str data/random
|
||||
Directories with properly encoded wav files of
|
||||
:-r --random-data-folder str data/random
|
||||
Folder with properly encoded wav files of
|
||||
random audio that should not cause an activation
|
||||
|
||||
...
|
||||
|
@ -95,8 +95,8 @@ class IncrementalTrainer:
|
|||
from keras.callbacks import ModelCheckpoint
|
||||
self.checkpoint = ModelCheckpoint(args.model, monitor=args.metric_monitor,
|
||||
save_best_only=args.save_best)
|
||||
data = TrainData.from_db(args.db_file, args.db_folder)
|
||||
self.db_data = data.load(True, not args.no_validation)
|
||||
data = TrainData.from_tags(args.tags_file, args.tags_folder)
|
||||
self.tags_data = data.load(True, not args.no_validation)
|
||||
|
||||
if not isfile(args.model):
|
||||
create_model(args.model, args.no_validation, args.extra_metrics).save(args.model)
|
||||
|
@ -104,11 +104,11 @@ class IncrementalTrainer:
|
|||
|
||||
def retrain(self):
|
||||
"""Train for a session, pulling in any new data from the filesystem"""
|
||||
folder = TrainData.from_folder(self.args.data_dir)
|
||||
folder = TrainData.from_folder(self.args.folder)
|
||||
train_data, test_data = folder.load(True, not self.args.no_validation)
|
||||
|
||||
train_data = TrainData.merge(train_data, self.db_data[0])
|
||||
test_data = TrainData.merge(test_data, self.db_data[1])
|
||||
train_data = TrainData.merge(train_data, self.tags_data[0])
|
||||
test_data = TrainData.merge(test_data, self.tags_data[1])
|
||||
print()
|
||||
try:
|
||||
self.listener.runner.model.fit(*train_data, self.args.batch_size, self.args.epochs,
|
||||
|
@ -132,7 +132,7 @@ class IncrementalTrainer:
|
|||
if conf > 0.5:
|
||||
samples_since_train += 1
|
||||
name = splitext(basename(fn))[0] + '-' + str(i) + '.wav'
|
||||
name = join(self.args.data_dir, 'test' if save_test else '', 'not-wake-word',
|
||||
name = join(self.args.folder, 'test' if save_test else '', 'not-wake-word',
|
||||
'generated', name)
|
||||
save_audio(name, audio_buffer)
|
||||
print()
|
||||
|
@ -147,7 +147,7 @@ class IncrementalTrainer:
|
|||
Begin reading through audio files, saving false
|
||||
activations and retraining when necessary
|
||||
"""
|
||||
for fn in glob_all(self.args.random_data_dir, '*.wav'):
|
||||
for fn in glob_all(self.args.random_data_folder, '*.wav'):
|
||||
if fn in self.trained_fns:
|
||||
print('Skipping ' + fn + '...')
|
||||
continue
|
||||
|
@ -164,8 +164,8 @@ def main():
|
|||
args = TrainData.parse_args(create_parser(usage))
|
||||
|
||||
for i in (
|
||||
join(args.data_dir, 'not-wake-word', 'generated'),
|
||||
join(args.data_dir, 'test', 'not-wake-word', 'generated')
|
||||
join(args.folder, 'not-wake-word', 'generated'),
|
||||
join(args.folder, 'test', 'not-wake-word', 'generated')
|
||||
):
|
||||
makedirs(i, exist_ok=True)
|
||||
|
||||
|
|
|
@ -19,13 +19,14 @@ from os.path import join, isfile, dirname
|
|||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
from prettyparse import add_to_parser
|
||||
|
||||
from precise.util import find_wavs
|
||||
from precise.vectorization import load_vector, vectorize_inhibit, vectorize
|
||||
|
||||
|
||||
class TrainData:
|
||||
"""Class to handle loading of wave data from categorized folders and SQLite dbs"""
|
||||
"""Class to handle loading of wave data from categorized folders and tagged text files"""
|
||||
|
||||
def __init__(self, train_files: Tuple[List[str], List[str]],
|
||||
test_files: Tuple[List[str], List[str]]):
|
||||
|
@ -49,40 +50,40 @@ class TrainData:
|
|||
return cls(find_wavs(folder), find_wavs(join(folder, 'test')))
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, db_file: str, db_folder: str) -> 'TrainData':
|
||||
def from_tags(cls, tags_file: str, tags_folder: str) -> 'TrainData':
|
||||
"""
|
||||
Load a set of data from a text database in the following format:
|
||||
Load a set of data from a text file with tags in the following format:
|
||||
<file_id> (tab) <tag>
|
||||
<file_id> (tab) <tag>
|
||||
|
||||
file_id: identifier of file such that the following
|
||||
file exists: {db_folder}/{data_id}.wav
|
||||
file exists: {tags_folder}/{data_id}.wav
|
||||
tag: "wake-word" or "not-wake-word"
|
||||
"""
|
||||
if not db_file:
|
||||
if not tags_file:
|
||||
return cls(([], []), ([], []))
|
||||
if not isfile(db_file):
|
||||
raise RuntimeError('Database file does not exist: ' + db_file)
|
||||
if not isfile(tags_file):
|
||||
raise RuntimeError('Database file does not exist: ' + tags_file)
|
||||
|
||||
train_groups = {}
|
||||
train_group_file = join(db_file.replace('.txt', '') + '.groups.json')
|
||||
train_group_file = join(tags_file.replace('.txt', '') + '.groups.json')
|
||||
if isfile(train_group_file):
|
||||
with open(train_group_file) as f:
|
||||
train_groups = json.load(f)
|
||||
|
||||
db_files = {
|
||||
tags_files = {
|
||||
'wake-word': [],
|
||||
'not-wake-word': []
|
||||
}
|
||||
with open(db_file) as f:
|
||||
with open(tags_file) as f:
|
||||
for line in f.read().split('\n'):
|
||||
if not line:
|
||||
continue
|
||||
file, tag = line.split('\t')
|
||||
db_files[tag.strip()].append(join(db_folder, file.strip() + '.wav'))
|
||||
tags_files[tag.strip()].append(join(tags_folder, file.strip() + '.wav'))
|
||||
|
||||
train_files, test_files = ([], []), ([], [])
|
||||
for label, rows in enumerate([db_files['wake-word'], db_files['not-wake-word']]):
|
||||
for label, rows in enumerate([tags_files['wake-word'], tags_files['not-wake-word']]):
|
||||
for fn in rows:
|
||||
if not isfile(fn):
|
||||
print('Missing file:', fn)
|
||||
|
@ -103,9 +104,9 @@ class TrainData:
|
|||
return cls(train_files, test_files)
|
||||
|
||||
@classmethod
|
||||
def from_both(cls, db_file: str, db_folder: str, data_dir: str) -> 'TrainData':
|
||||
def from_both(cls, tags_file: str, tags_folder: str, folder: str) -> 'TrainData':
|
||||
"""Load data from both a database and a structured folder"""
|
||||
return cls.from_db(db_file, db_folder) + cls.from_folder(data_dir)
|
||||
return cls.from_tags(tags_file, tags_folder) + cls.from_folder(folder)
|
||||
|
||||
def load(self, train=True, test=True) -> tuple:
|
||||
"""
|
||||
|
@ -140,14 +141,23 @@ class TrainData:
|
|||
@staticmethod
|
||||
def parse_args(parser: ArgumentParser) -> Any:
|
||||
"""Return parsed args from parser, adding options for train data inputs"""
|
||||
parser.add_argument('db_folder', help='Folder to load database references from')
|
||||
parser.add_argument(
|
||||
'-db', '--db-file', default='', help='Text database to load from where '
|
||||
'each line is <file_id>\t(wake-word|not-wake-word) and {db_folder}/<file_id>.wav exists..')
|
||||
parser.add_argument('-d', '--data-dir', default='{db_folder}',
|
||||
help='Load files from a different directory')
|
||||
extra_usage = '''
|
||||
:folder str
|
||||
Folder to wav files from
|
||||
|
||||
:-tf --tags-folder str {folder}
|
||||
Specify a different folder to load file ids
|
||||
in tags file from
|
||||
|
||||
:-tg --tags-file str -
|
||||
Text file to load tags from where each line is
|
||||
<file_id> TAB (wake-word|not-wake-word) and
|
||||
{folder}/<file_id>.wav exists
|
||||
|
||||
'''
|
||||
add_to_parser(parser, extra_usage)
|
||||
args = parser.parse_args()
|
||||
args.data_dir = args.data_dir.format(db_folder=args.db_folder)
|
||||
args.tags_folder = args.tags_folder.format(folder=args.folder)
|
||||
return args
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
|
Loading…
Reference in New Issue