Use `torch.linalg.qr` for pytorch > `v1.9.0`

pull/506/head
Eren Gölge 2021-06-23 13:45:59 +02:00
parent 0a1962b583
commit d42d1c02ea
1 changed files with 7 additions and 1 deletions

View File

@ -1,3 +1,5 @@
from distutils.version import LooseVersion
import torch
from torch import nn
from torch.nn import functional as F
@ -81,7 +83,11 @@ class InvConvNear(nn.Module):
self.no_jacobian = no_jacobian
self.weight_inv = None
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
if LooseVersion(torch.__version__) < LooseVersion("1.9"):
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
else:
w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0]
if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0]
self.weight = nn.Parameter(w_init)