mirror of https://github.com/coqui-ai/TTS.git
Add a missing class variable to attention class
parent
8ffc85008a
commit
fe99baec5a
|
@ -80,6 +80,7 @@ class AttentionRNN(nn.Module):
|
||||||
align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
|
align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
|
||||||
"""
|
"""
|
||||||
super(AttentionRNN, self).__init__()
|
super(AttentionRNN, self).__init__()
|
||||||
|
self.align_model = align_model
|
||||||
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
|
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
|
||||||
# pick bahdanau or location sensitive attention
|
# pick bahdanau or location sensitive attention
|
||||||
if align_model == 'b':
|
if align_model == 'b':
|
||||||
|
|
Loading…
Reference in New Issue