bug fixes

pull/10/head
erogol 2020-05-01 14:34:14 +02:00
parent 736f169cc9
commit de2918c85b
1 changed files with 4 additions and 3 deletions

View File

@ -82,8 +82,8 @@ class Encoder(nn.Module):
o = x o = x
for layer in self.convolutions: for layer in self.convolutions:
o = layer(o) o = layer(o)
o = x.transpose(1, 2) o = o.transpose(1, 2)
self.lstm.flatten_parameters() # self.lstm.flatten_parameters()
o, _ = self.lstm(o) o, _ = self.lstm(o)
return o return o
@ -140,7 +140,8 @@ class Decoder(nn.Module):
attn_K=attn_K) attn_K=attn_K)
self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim, self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim,
self.decoder_rnn_dim, 1) self.decoder_rnn_dim,
bias=False)
self.linear_projection = Linear(self.decoder_rnn_dim + input_dim, self.linear_projection = Linear(self.decoder_rnn_dim + input_dim,
self.frame_dim * self.r_init) self.frame_dim * self.r_init)