mirror of https://github.com/coqui-ai/TTS.git
Use `torch.linalg.qr` for pytorch > `v1.9.0`
parent
0a1962b583
commit
d42d1c02ea
|
@ -1,3 +1,5 @@
|
|||
from distutils.version import LooseVersion
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
@ -81,7 +83,11 @@ class InvConvNear(nn.Module):
|
|||
self.no_jacobian = no_jacobian
|
||||
self.weight_inv = None
|
||||
|
||||
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
|
||||
if LooseVersion(torch.__version__) < LooseVersion("1.9"):
|
||||
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
|
||||
else:
|
||||
w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0]
|
||||
|
||||
if torch.det(w_init) < 0:
|
||||
w_init[:, 0] = -1 * w_init[:, 0]
|
||||
self.weight = nn.Parameter(w_init)
|
||||
|
|
Loading…
Reference in New Issue