mirror of https://github.com/coqui-ai/TTS.git
use memory queue if r is smaller than queue size
parent
59a0606e06
commit
72c5062c02
|
@ -395,7 +395,7 @@ class Decoder(nn.Module):
|
||||||
return output, stop_token, self.attention_layer.attention_weights
|
return output, stop_token, self.attention_layer.attention_weights
|
||||||
|
|
||||||
def _update_memory_queue(self, new_memory):
|
def _update_memory_queue(self, new_memory):
|
||||||
if self.memory_size > 0:
|
if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
|
||||||
self.memory_input = torch.cat([
|
self.memory_input = torch.cat([
|
||||||
self.memory_input[:, self.r * self.memory_dim:].clone(),
|
self.memory_input[:, self.r * self.memory_dim:].clone(),
|
||||||
new_memory
|
new_memory
|
||||||
|
|
Loading…
Reference in New Issue