bug fixes

pull/10/head
erogol 2020-05-01 14:34:14 +02:00
parent 736f169cc9
commit de2918c85b
1 changed files with 4 additions and 3 deletions

View File

@ -82,8 +82,8 @@ class Encoder(nn.Module):
o = x
for layer in self.convolutions:
o = layer(o)
o = x.transpose(1, 2)
self.lstm.flatten_parameters()
o = o.transpose(1, 2)
# self.lstm.flatten_parameters()
o, _ = self.lstm(o)
return o
@ -140,7 +140,8 @@ class Decoder(nn.Module):
attn_K=attn_K)
self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim,
self.decoder_rnn_dim, 1)
self.decoder_rnn_dim,
bias=False)
self.linear_projection = Linear(self.decoder_rnn_dim + input_dim,
self.frame_dim * self.r_init)