mirror of https://github.com/MycroftAI/mimic2.git
ability to assign gpu during training, fixed prenet dropout bug
parent
ba27a6d95d
commit
1fc9886e4d
|
@ -8,7 +8,7 @@ def prenet(inputs, is_training, layer_sizes=[256, 128], scope=None):
|
||||||
with tf.variable_scope(scope or 'prenet'):
|
with tf.variable_scope(scope or 'prenet'):
|
||||||
for i, size in enumerate(layer_sizes):
|
for i, size in enumerate(layer_sizes):
|
||||||
dense = tf.layers.dense(x, units=size, activation=tf.nn.relu, name='dense_%d' % (i+1))
|
dense = tf.layers.dense(x, units=size, activation=tf.nn.relu, name='dense_%d' % (i+1))
|
||||||
x = tf.layers.dropout(dense, rate=drop_rate, name='dropout_%d' % (i+1))
|
x = tf.layers.dropout(dense, rate=drop_rate, training=True, name='dropout_%d' % (i+1))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
2
train.py
2
train.py
|
@ -139,7 +139,9 @@ def main():
|
||||||
parser.add_argument('--slack_url', help='Slack webhook URL to get periodic reports.')
|
parser.add_argument('--slack_url', help='Slack webhook URL to get periodic reports.')
|
||||||
parser.add_argument('--tf_log_level', type=int, default=1, help='Tensorflow C++ log level.')
|
parser.add_argument('--tf_log_level', type=int, default=1, help='Tensorflow C++ log level.')
|
||||||
parser.add_argument('--git', action='store_true', help='If set, verify that the client is clean.')
|
parser.add_argument('--git', action='store_true', help='If set, verify that the client is clean.')
|
||||||
|
parser.add_argument('--gpu_assignment', default='0', help='Set the gpu the model should run on')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_assignment
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_log_level)
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_log_level)
|
||||||
run_name = args.name or args.model
|
run_name = args.name or args.model
|
||||||
log_dir = os.path.join(args.base_dir, 'logs-%s' % run_name)
|
log_dir = os.path.join(args.base_dir, 'logs-%s' % run_name)
|
||||||
|
|
Loading…
Reference in New Issue