fix for tensorflow 1.4.0

pull/2/head
Matthew Goldey 2017-10-13 14:56:00 -05:00
parent 268995c08c
commit 7b09826434
1 changed files with 16 additions and 0 deletions

View File

@ -15,6 +15,14 @@ class TacoTestHelper(Helper):
def batch_size(self):
return self._batch_size
@property
def sample_ids_shape(self):
return tf.TensorShape([])
@property
def sample_ids_dtype(self):
return np.int32
def initialize(self, name=None):
return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim))
@ -48,6 +56,14 @@ class TacoTrainingHelper(Helper):
def batch_size(self):
return self._batch_size
@property
def sample_ids_shape(self):
return tf.TensorShape([])
@property
def sample_ids_dtype(self):
return np.int32
def initialize(self, name=None):
return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim))