# coding: utf-8
# NOTE: all code below is commented out; this module defines nothing when
# imported.
#
# import torch
# from torch import nn
#
#
# class StopProjection(nn.Module):
#     r""" Simple projection layer to predict the "stop token"
#
#     Args:
#         in_features (int): size of the input vector
#         out_features (int or list): size of each output vector. aka number
#             of predicted frames.
#     """
#
#     def __init__(self, in_features, out_features):
#         super(StopProjection, self).__init__()
#         self.linear = nn.Linear(in_features, out_features)
#         self.dropout = nn.Dropout(0.5)
#         self.sigmoid = nn.Sigmoid()
#
#     def forward(self, inputs):
#         out = self.dropout(inputs)
#         out = self.linear(out)
#         out = self.sigmoid(out)
#         return out