use memory queue if r is smaller than queue size

pull/10/head
Eren Golge 2019-07-15 15:39:16 +02:00
parent 59a0606e06
commit 72c5062c02
1 changed files with 1 additions and 1 deletions

View File

@ -395,7 +395,7 @@ class Decoder(nn.Module):
return output, stop_token, self.attention_layer.attention_weights
def _update_memory_queue(self, new_memory):
if self.memory_size > 0:
if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
self.memory_input = torch.cat([
self.memory_input[:, self.r * self.memory_dim:].clone(),
new_memory