diff --git a/layers/custom_layers.py b/layers/custom_layers.py new file mode 100644 index 00000000..802091e8 --- /dev/null +++ b/layers/custom_layers.py @@ -0,0 +1,26 @@ +# coding: utf-8 +import torch +from torch.autograd import Variable +from torch import nn + + +class StopProjection(nn.Module): + r""" Simple projection layer to predict the "stop token" + + Args: + in_features (int): size of the input vector + out_features (int or list): size of each output vector. aka number + of predicted frames. + """ + + def __init__(self, in_features, out_features): + super(StopProjection, self).__init__() + self.linear = nn.Linear(in_features, out_features) + self.dropout = nn.Dropout(0.5) + self.sigmoid = nn.Sigmoid() + + def forward(self, inputs): + out = self.dropout(inputs) + out = self.linear(out) + out = self.sigmoid(out) + return out \ No newline at end of file