From 4e4f876bc4f1ba3087c950e2406cd7cc078aa2f5 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Thu, 22 Mar 2018 12:34:31 -0700 Subject: [PATCH] Stop token prediction - does train yet --- layers/custom_layers.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 layers/custom_layers.py diff --git a/layers/custom_layers.py b/layers/custom_layers.py new file mode 100644 index 00000000..802091e8 --- /dev/null +++ b/layers/custom_layers.py @@ -0,0 +1,26 @@ +# coding: utf-8 +import torch +from torch.autograd import Variable +from torch import nn + + +class StopProjection(nn.Module): + r""" Simple projection layer to predict the "stop token" + + Args: + in_features (int): size of the input vector + out_features (int or list): size of each output vector. aka number + of predicted frames. + """ + + def __init__(self, in_features, out_features): + super(StopProjection, self).__init__() + self.linear = nn.Linear(in_features, out_features) + self.dropout = nn.Dropout(0.5) + self.sigmoid = nn.Sigmoid() + + def forward(self, inputs): + out = self.dropout(inputs) + out = self.linear(out) + out = self.sigmoid(out) + return out \ No newline at end of file