Hyunji

verify mri lstm

1 +import torch
2 +from torch import nn
3 +
4 +"""
5 +Code to test the LSTM implementation against Lam et al.
6 +Our implementation uses vectorization and should be faster, but this still needs to be verified.
7 +"""
8 +
9 +
def encoder_blk(in_channels, out_channels):
    """Build one encoder stage: 3x3 conv -> instance norm -> 2x2 max-pool -> ReLU.

    The convolution uses padding 1 and stride 1, so it keeps the spatial size;
    the max-pool (kernel 2, stride 2) then halves each spatial dimension.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the four layers in the order above.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
    norm = nn.InstanceNorm2d(out_channels)
    pool = nn.MaxPool2d(kernel_size=2, stride=2)
    activation = nn.ReLU()
    return nn.Sequential(conv, norm, pool, activation)
17 +
18 +
class MRI_LSTM(nn.Module):
    """Slice-wise 2D CNN encoder followed by an LSTM over the slice axis.

    The 5-D input is cut into 2D slices along dim 2. Each slice is embedded by
    a shared convolutional encoder, and the resulting sequence of slice
    embeddings is summarised by an LSTM. Two encoders are provided so that
    they can be compared against each other:

    * ``encode_old`` feeds the LSTM one slice at a time in a Python loop.
    * ``encode_new`` pushes all slices through the CNN in one batched pass
      and runs the LSTM once over the whole sequence.

    ``forward`` returns both encodings. ``lstm_post`` and ``regressor`` are
    built and initialised here but are not applied by ``forward``.

    Args:
        lstm_feat_dim: size of the per-slice embedding fed to the LSTM.
        lstm_latent_dim: hidden size of the LSTM, which is also the size of
            the returned encodings.
        *args: accepted and ignored.
        **kwargs: accepted and ignored.
    """

    def __init__(self, lstm_feat_dim, lstm_latent_dim, *args, **kwargs):
        super().__init__()

        # Not read anywhere in this class; kept as a public attribute.
        self.input_dim = (1, 109, 91)

        self.feat_embed_dim = lstm_feat_dim
        self.latent_dim = lstm_latent_dim

        # Build encoder: five conv stages, each halving the spatial size.
        self.encoder = nn.Sequential(
            encoder_blk(1, 32),
            encoder_blk(32, 64),
            encoder_blk(64, 128),
            encoder_blk(128, 256),
            encoder_blk(256, 256),
        )

        # Post processing. A 109x91 slice leaves the encoder as 3x2, so the
        # [3, 2] average pool collapses it to 1x1; both encode paths rely on
        # that when they reshape to ``feat_embed_dim`` features per slice.
        self.post_proc = nn.Sequential(
            nn.Conv2d(256, 64, 1, stride=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.AvgPool2d([3, 2]),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, self.feat_embed_dim, 1),
        )

        # Connect with LSTM.
        self.n_layers = 1
        self.lstm = nn.LSTM(
            self.feat_embed_dim, self.latent_dim, self.n_layers, batch_first=True
        )

        # Build regressor.
        self.lstm_post = nn.Linear(self.latent_dim, 64)
        self.regressor = nn.Sequential(nn.ReLU(), nn.Linear(64, 1))

        self.init_weights()

    def init_weights(self):
        """Initialise the conv and linear layers in place.

        * ``nn.Conv2d``: Kaiming-normal weights (fan_out, ReLU), zero bias.
        * ``nn.Linear`` whose module name contains ``"regressor"``: only the
          bias is set, to 62.68; the weight keeps PyTorch's default init.
          NOTE(review): 62.68 is presumably the mean of the regression
          target -- confirm.
        * every other ``nn.Linear``: weights ~ N(0, 0.01), zero bias.
        """
        for name, module in self.named_modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                if "regressor" in name:
                    module.bias.data.fill_(62.68)
                else:
                    nn.init.normal_(module.weight, 0, 0.01)
                    nn.init.constant_(module.bias, 0)

    def init_hidden(self, x):
        """Return zero initial LSTM states ``(h_0, c_0)`` for the batch in ``x``.

        Both tensors have shape ``(n_layers, x.size(0), latent_dim)``, live on
        ``x``'s device and have ``requires_grad`` set to ``True``.
        """
        h_0 = torch.zeros(self.n_layers, x.size(0), self.latent_dim, device=x.device)
        c_0 = torch.zeros(self.n_layers, x.size(0), self.latent_dim, device=x.device)
        h_0.requires_grad = True
        c_0.requires_grad = True
        return h_0, c_0

    def encode_old(self, x):
        """Encode ``x`` by running the CNN and one LSTM step per slice.

        Args:
            x: 5-D tensor; dim 0 is the batch and dim 2 is the slice axis that
                is iterated over.

        Returns:
            Tensor of shape ``(batch, latent_dim)``: the last LSTM layer's
            hidden state after the final slice.
        """
        batch_size, _, n_slices, _, _ = x.size()
        h_t, c_t = self.init_hidden(x)
        for i in range(n_slices):
            feat = self.post_proc(self.encoder(x[:, :, i, :, :]))
            # One time step per call; the view fails loudly if post_proc did
            # not reduce the slice to a single spatial location.
            step = feat.view(batch_size, 1, self.feat_embed_dim)
            # nn.LSTM returns (output, (h_n, c_n)). Carry the real states
            # forward instead of re-using the output sequence as h.
            _, (h_t, c_t) = self.lstm(step, (h_t, c_t))
        # Index the last layer so the result stays (batch, latent_dim) even
        # if n_layers is raised above 1.
        return h_t[-1]

    def encode_new(self, x):
        """Encode ``x`` with one batched CNN pass and a single LSTM call.

        Args:
            x: 5-D tensor; dim 0 is the batch and dim 2 is the slice axis that
                becomes the LSTM sequence axis.

        Returns:
            Tensor of shape ``(batch, latent_dim)``: the last LSTM layer's
            final hidden state.
        """
        batch_size, channels, n_slices, rows, cols = x.size()
        # Fold the slice axis into the batch, slice-major: row i * B + b
        # holds slice i of sample b.
        slices = x.permute(2, 0, 1, 3, 4).reshape(
            n_slices * batch_size, channels, rows, cols
        )
        feats = self.post_proc(self.encoder(slices))
        # Undo the fold and put batch first: (B, n_slices, feat_embed_dim).
        seq = feats.reshape(n_slices, batch_size, self.feat_embed_dim).transpose(0, 1)
        # No initial state is passed, so nn.LSTM starts from zeros, which is
        # what init_hidden would have supplied.
        _, (h_n, _) = self.lstm(seq)
        return h_n[-1]

    def forward(self, x):
        """Run both encoders on ``x`` and return ``(embedding_new, embedding_old)``."""
        embedding_old = self.encode_old(x)
        embedding_new = self.encode_new(x)

        return embedding_new, embedding_old
118 +
119 +
if __name__ == "__main__":
    # Sanity check: the looped and the vectorized encoders should produce the
    # same encoding for the same random input.
    B = 4
    new_model = MRI_LSTM(lstm_feat_dim=2, lstm_latent_dim=128)
    # eval() disables Dropout, so both encode paths are deterministic.
    new_model.eval()
    inp = torch.rand(B, 1, 91, 109, 91)
    # Only the values are compared, so skip building the autograd graph.
    with torch.no_grad():
        output = new_model(inp)
    print(torch.allclose(output[0], output[1]))