Hyunji

torch utils

1 +""" things torch should have but it doesn't"""
2 +import logging
3 +
4 +import torch
5 +import torch.nn as nn
6 +from torch.autograd import Function
7 +
8 +logger = logging.getLogger()
9 +EPSILON = 1e-8
10 +
11 +
def reset_seed():
    """Re-seed torch's global RNG from a non-deterministic source,
    retrying until seed generation succeeds."""
    while True:
        try:
            torch.seed()
        except RuntimeError:
            logger.error("Error generating seed")
        else:
            break


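# Usage sketch (illustrative, not part of the original API): return to
# non-deterministic randomness after a seeded, reproducible phase:
#
#   torch.manual_seed(0)    # reproducible phase
#   reset_seed()            # fresh, non-deterministic seed from here on
#   noise = torch.randn(4)  # no longer identical across runs

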
class Reshape(nn.Module):
    """
    Reshape module that reshapes any input to (batch_size, *shape).
    By default it flattens the input, but any target shape can be passed.
    """

    def __init__(self, shape=(-1,)):
        super().__init__()
        self.shape = tuple(shape)  # tuple() so lists and torch.Size also work

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view((batch_size,) + self.shape)

    def extra_repr(self):
        return f"shape={self.shape}"


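# Usage sketch (illustrative): flatten feature maps before a linear head, or
# fold a flat vector back into a grid:
#
#   flat = Reshape()                     # default: flatten each sample
#   flat(torch.rand(8, 3, 4, 4)).shape  # -> torch.Size([8, 48])
#
#   grid = Reshape((3, 16))
#   grid(torch.rand(8, 48)).shape       # -> torch.Size([8, 3, 16])

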
class Offset(nn.Module):
    """Broadcast a fixed (C, H, W) offset across the batch.

    NOTE: ``net`` is stored but currently unused; its contribution is
    commented out in ``forward``.
    """

    def __init__(self, offset, net):
        super().__init__()
        # Frozen parameter: follows .to()/.cuda() but is never trained.
        self.offset = nn.Parameter(offset, requires_grad=False)
        self.net = net

    def forward(self, *args):
        batch_size = args[0].shape[0]
        return self.offset.expand((batch_size, -1, -1, -1)) + EPSILON  # + self.net(*args)


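# Usage sketch (illustrative): with a zero offset the module just broadcasts
# EPSILON over the batch; the wrapped net is ignored for now:
#
#   m = Offset(torch.zeros(3, 8, 8), net=nn.Identity())
#   m(torch.rand(5, 3, 8, 8)).shape  # -> torch.Size([5, 3, 8, 8])

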
def batch_eye(N, D, device="cpu"):
    """Return a (N, D, D) tensor holding N identity matrices."""
    x = torch.eye(D, device=device)
    x = x.unsqueeze(0)
    # repeat (rather than expand) so each slice is an independent, writable copy
    x = x.repeat(N, 1, 1)
    return x


def batch_eye_like(tensor):
    """Return identity matrices matching a (N, D, D) batch of square matrices."""
    assert len(tensor.shape) == 3 and tensor.shape[1] == tensor.shape[2]
    N = tensor.shape[0]
    D = tensor.shape[1]
    return batch_eye(N, D, device=tensor.device)


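# Usage sketch (illustrative): regularize a batch of covariance matrices by
# adding a scaled identity to every slice:
#
#   cov = torch.rand(4, 3, 3)
#   cov = cov + 1e-3 * batch_eye_like(cov)  # shape stays (4, 3, 3)

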
class _RevGrad(Function):
    """Identity in the forward pass; negates the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, input_):
        # Nothing needs to be saved for backward: the gradient is just negated.
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_input = -grad_output
        return grad_input


revgrad = _RevGrad.apply


class RevGrad(nn.Module):
    """
    A gradient reversal layer.
    This layer has no parameters, and simply reverses the gradient
    in the backward pass.
    """

    def forward(self, input_):
        return revgrad(input_)


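# Usage sketch (illustrative), as used in domain-adversarial training: the
# forward pass is the identity while the backward pass flips the gradient:
#
#   x = torch.ones(3, requires_grad=True)
#   RevGrad()(x).sum().backward()
#   x.grad  # -> tensor([-1., -1., -1.])

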
def infer_shape(net, input_shape):
    """Infer net's per-sample output shape by pushing a dummy batch through it."""
    with torch.no_grad():  # no autograd graph needed just to probe shapes
        # A batch of 2 (rather than 1) surfaces accidental batch-dim squeezing.
        x = torch.rand((2,) + tuple(input_shape))
        return net(x).shape[1:]
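# Usage sketch (illustrative):
#
#   net = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3), nn.ReLU())
#   infer_shape(net, (3, 32, 32))  # -> torch.Size([16, 30, 30])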