Showing 1 changed file with 127 additions and 0 deletions

2DCNN/tests/verify_mri_lstm.py (new file, 0 → 100644)
import torch
from torch import nn

"""
Code to test our LSTM implementation against the reference implementation of Lam et al.
Our implementation uses vectorization and should be faster, but its output needs to be
verified against the original slice-by-slice loop.
"""
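# Note on the two code paths compared below:
# - encode_old follows the reference implementation: it feeds the H axial slices
#   through the CNN one at a time inside a Python loop, stepping the LSTM per slice.
# - encode_new folds the H slices into the batch dimension, runs the CNN once on a
#   (B * H)-sized batch, and hands the whole slice sequence to the LSTM in one call.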


def encoder_blk(in_channels, out_channels):
    # Conv3x3 -> InstanceNorm -> MaxPool(2) -> ReLU; each block halves the spatial size.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=1),
        nn.InstanceNorm2d(out_channels),
        nn.MaxPool2d(2, stride=2),
        nn.ReLU()
    )


class MRI_LSTM(nn.Module):

    def __init__(self, lstm_feat_dim, lstm_latent_dim, *args, **kwargs):
        super(MRI_LSTM, self).__init__()

        self.input_dim = (1, 109, 91)

        self.feat_embed_dim = lstm_feat_dim
        self.latent_dim = lstm_latent_dim

        # Build encoder
        encoder_blocks = [
            encoder_blk(1, 32),
            encoder_blk(32, 64),
            encoder_blk(64, 128),
            encoder_blk(128, 256),
            encoder_blk(256, 256)
        ]
        self.encoder = nn.Sequential(*encoder_blocks)

        # Post processing
        self.post_proc = nn.Sequential(
            nn.Conv2d(256, 64, 1, stride=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.AvgPool2d([3, 2]),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, self.feat_embed_dim, 1)
        )

        # Connect w/ LSTM
        self.n_layers = 1
        self.lstm = nn.LSTM(self.feat_embed_dim, self.latent_dim, self.n_layers, batch_first=True)

        # Build regressor
        self.lstm_post = nn.Linear(self.latent_dim, 64)
        self.regressor = nn.Sequential(nn.ReLU(), nn.Linear(64, 1))

        self.init_weights()

    def init_weights(self):
        for k, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear) and "regressor" in k:
                m.bias.data.fill_(62.68)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def init_hidden(self, x):
        h_0 = torch.zeros(self.n_layers, x.size(0), self.latent_dim, device=x.device)
        c_0 = torch.zeros(self.n_layers, x.size(0), self.latent_dim, device=x.device)
        h_0.requires_grad = True
        c_0.requires_grad = True
        return h_0, c_0

    def encode_old(self, x):
        # Reference implementation: loop over the H slices, stepping the LSTM per slice.
        B, C, H, W, D = x.size()
        h_t, c_t = self.init_hidden(x)
        for i in range(H):
            out = self.encoder(x[:, :, i, :, :])
            out = self.post_proc(out)
            out = out.view(B, 1, self.feat_embed_dim)
            h_t = h_t.view(1, B, self.latent_dim)
            c_t = c_t.view(1, B, self.latent_dim)
            h_t, (_, c_t) = self.lstm(out, (h_t, c_t))
        encoding = h_t.view(B, self.latent_dim)
        return encoding

    def encode_new(self, x):
        # Vectorized implementation: fold the H slices into the batch dimension so the
        # CNN runs once, then give the whole slice sequence to the LSTM in a single call.
        h_0, c_0 = self.init_hidden(x)
        B, C, H, W, D = x.size()
        # Convert the volume to a batch of 2D slices, apply the encoder, then reshape for the LSTM.
        new_input = torch.cat([x[:, :, i, :, :] for i in range(H)], dim=0)
        encoding = self.encoder(new_input)
        encoding = self.post_proc(encoding)
        # (B*H) x C_out x W_out x D_out
        encoding = torch.stack(torch.split(encoding, B, dim=0), dim=2)
        # B x C_out x H x W_out x D_out
        encoding = encoding.squeeze(4).squeeze(3)
        # The LSTM expects (batch, seq_len, feat_dim).
        encoding = encoding.permute(0, 2, 1)
        # nn.LSTM defaults to a zero initial state, which matches init_hidden.
        _, (encoding, _) = self.lstm(encoding)
        # h_n is (1, batch, hidden); drop the layer dimension.
        encoding = encoding.squeeze(0)
        return encoding

    def forward(self, x):
        embedding_old = self.encode_old(x)
        embedding_new = self.encode_new(x)

        return embedding_new, embedding_old


if __name__ == "__main__":
    B = 4
    new_model = MRI_LSTM(lstm_feat_dim=2, lstm_latent_dim=128)
    new_model.eval()
    inp = torch.rand(B, 1, 91, 109, 91)
    output = new_model(inp)
    # Should print True if the vectorized and loop-based implementations agree.
    print(torch.allclose(output[0], output[1]))
    # breakpoint()
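    # Optional, rough check of the "should be faster" claim from the module docstring:
    # a single-run timing of both encode paths on the same random volume. This is only
    # a sketch, not a rigorous benchmark (no warm-up, no repetition, CPU timing only).
    import time
    with torch.no_grad():
        t0 = time.time()
        new_model.encode_old(inp)
        t_loop = time.time() - t0
        t0 = time.time()
        new_model.encode_new(inp)
        t_vec = time.time() - t0
    print(f"loop: {t_loop:.3f}s  vectorized: {t_vec:.3f}s")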