Showing 1 changed file with 191 additions and 0 deletions

2DCNN/src/arch/brain_age_slice_set.py (new file, 0 → 100644)
1 | +"""code for attention models""" | ||
2 | + | ||
3 | +import math | ||
4 | + | ||
5 | +import torch | ||
6 | +from box import Box | ||
7 | +from torch import nn | ||
8 | + | ||
9 | + | ||
10 | +class MeanPool(nn.Module): | ||
11 | + def forward(self, X): | ||
12 | + return X.mean(dim=1, keepdim=True), None | ||
13 | + | ||
14 | + | ||
15 | +class MaxPool(nn.Module): | ||
16 | + def forward(self, X): | ||
17 | + return X.max(dim=1, keepdim=True)[0], None | ||
18 | + | ||
19 | + | ||
class PooledAttention(nn.Module):
    """Pools a set of per-slice embeddings into a single vector via multi-head
    attention against a learned query. Assumes dim_k == dim_v and that
    num_heads divides dim_v evenly."""

    def __init__(self, input_dim, dim_v, dim_k, num_heads, ln=False):
        super(PooledAttention, self).__init__()
        # learned pooling query ("seed" vector), broadcast across the batch in forward()
        self.S = nn.Parameter(torch.zeros(1, dim_k))
        nn.init.xavier_uniform_(self.S)

        # transforms to get key and value vectors
        self.fc_k = nn.Linear(input_dim, dim_k)
        self.fc_v = nn.Linear(input_dim, dim_v)

        self.dim_v = dim_v
        self.dim_k = dim_k
        self.num_heads = num_heads

        if ln:
            self.ln0 = nn.LayerNorm(dim_v)

    def forward(self, X):
        B, C, H = X.shape

        Q = self.S.repeat(X.size(0), 1, 1)

        K = self.fc_k(X.reshape(-1, H)).reshape(B, C, self.dim_k)
        V = self.fc_v(X.reshape(-1, H)).reshape(B, C, self.dim_v)
        # split features into heads and stack the heads along the batch dimension
        dim_split = self.dim_v // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(dim_split), 2)
        O = torch.cat(A.bmm(V_).split(B, 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        return O, A

    def get_attention(self, X):
        # recomputes only the attention map A from forward()
        B, C, H = X.shape

        Q = self.S.repeat(X.size(0), 1, 1)

        K = self.fc_k(X.reshape(-1, H)).reshape(B, C, self.dim_k)
        V = self.fc_v(X.reshape(-1, H)).reshape(B, C, self.dim_v)
        dim_split = self.dim_v // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(dim_split), 2)
        return A

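# Shape summary for PooledAttention (derived from the code above): for an input
# X of shape (B, C, H), the pooled output O has shape (B, 1, dim_v) and the
# attention map A has shape (B * num_heads, 1, C). Illustrative (hypothetical)
# usage, not part of the original file:
#
#     pool = PooledAttention(input_dim=128, dim_v=128, dim_k=128, num_heads=4)
#     O, A = pool(torch.randn(2, 91, 128))   # O: (2, 1, 128), A: (8, 1, 91)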

def encoder_blk(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=1),
        nn.InstanceNorm2d(out_channels),
        nn.MaxPool2d(2, stride=2),
        nn.ReLU()
    )


class MRI_ATTN(nn.Module):

    def __init__(self, attn_num_heads, attn_dim, attn_drop=False, agg_fn="attention", slice_dim=1,
                 *args, **kwargs):
        super(MRI_ATTN, self).__init__()

        # shape of a single slice of a (91, 109, 91) input volume, for each choice of slice_dim
        self.input_dim = [(1, 109, 91), (91, 1, 91), (91, 109, 1)][slice_dim - 1]

        self.num_heads = attn_num_heads
        self.attn_dim = attn_dim

        # Build encoder: five encoder_blk stages, each halving the spatial dimensions
        encoder_blocks = [
            encoder_blk(1, 32),
            encoder_blk(32, 64),
            encoder_blk(64, 128),
            encoder_blk(128, 256),
            encoder_blk(256, 256)
        ]
        self.encoder = nn.Sequential(*encoder_blocks)

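        # Spatial bookkeeping for the pooling kernels below (derived from the
        # encoder above): each encoder_blk applies MaxPool2d(2), so after five
        # blocks a 109 x 91 slice (slice_dim=1) shrinks as
        # 109 -> 54 -> 27 -> 13 -> 6 -> 3 and 91 -> 45 -> 22 -> 11 -> 5 -> 2,
        # i.e. to a 3 x 2 feature map, hence AvgPool2d([3, 2]). Likewise a
        # 91 x 91 slice (slice_dim=2) ends at 2 x 2 and a 91 x 109 slice
        # (slice_dim=3) ends at 2 x 3.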
        if slice_dim == 1:
            avg = nn.AvgPool2d([3, 2])
        elif slice_dim == 2:
            avg = nn.AvgPool2d([2, 2])
        elif slice_dim == 3:
            avg = nn.AvgPool2d([2, 3])
        else:
            raise Exception("Invalid slice dim")
        self.slice_dim = slice_dim

        # Post processing
        self.post_proc = nn.Sequential(
            nn.Conv2d(256, 64, 1, stride=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            avg,
            nn.Dropout(p=0.5) if attn_drop else nn.Identity(),
            nn.Conv2d(64, self.num_heads * self.attn_dim, 1)
        )

        if agg_fn == "attention":
            self.pooled_attention = PooledAttention(input_dim=self.num_heads * self.attn_dim,
                                                    dim_v=self.num_heads * self.attn_dim,
                                                    dim_k=self.num_heads * self.attn_dim,
                                                    num_heads=self.num_heads)
        elif agg_fn == "mean":
            self.pooled_attention = MeanPool()
        elif agg_fn == "max":
            self.pooled_attention = MaxPool()
        else:
            raise Exception("Invalid aggregation function")

        # Build regressor
        self.attn_post = nn.Linear(self.num_heads * self.attn_dim, 64)
        self.regressor = nn.Sequential(nn.ReLU(), nn.Linear(64, 1))
        self.init_weights()

    def init_weights(self):
        for k, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear) and "regressor" in k:
                # final regression layer: bias starts at a fixed constant
                # (presumably the mean age of the training set); the weight
                # keeps PyTorch's default initialisation
                m.bias.data.fill_(62.68)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def encode(self, x):

        B, C, H, W, D = x.size()
        # Slices along the chosen axis are stacked on the batch dimension so the 2D
        # encoder is shared across slices, then the per-slice embeddings are
        # re-assembled into a (B, num_heads * attn_dim, n_slices) tensor.
        if self.slice_dim == 1:
            new_input = torch.cat([x[:, :, i, :, :] for i in range(H)], dim=0)
            encoding = self.encoder(new_input)
            encoding = self.post_proc(encoding)
            encoding = torch.cat([i.unsqueeze(2) for i in torch.split(encoding, B, dim=0)], dim=2)
            # squeeze only the known singleton dims; a bare squeeze() could drop the batch dim when B == 1
            encoding = encoding.squeeze(4).squeeze(3)
        elif self.slice_dim == 2:
            new_input = torch.cat([x[:, :, :, i, :] for i in range(W)], dim=0)
            encoding = self.encoder(new_input)
            encoding = self.post_proc(encoding)
            encoding = torch.cat([i.unsqueeze(3) for i in torch.split(encoding, B, dim=0)], dim=3)
            # squeeze only the known singleton dims; a bare squeeze() could drop the batch dim when B == 1
            encoding = encoding.squeeze(4).squeeze(2)
        elif self.slice_dim == 3:
            new_input = torch.cat([x[:, :, :, :, i] for i in range(D)], dim=0)
            encoding = self.encoder(new_input)
            encoding = self.post_proc(encoding)
            encoding = torch.cat([i.unsqueeze(4) for i in torch.split(encoding, B, dim=0)], dim=4)
            # squeeze only the known singleton dims; a bare squeeze() could drop the batch dim when B == 1
            encoding = encoding.squeeze(3).squeeze(2)
        else:
            raise Exception("Invalid slice dim")

        # (B, F, n_slices) -> (B, n_slices, F) for the attention pooling
        encoding = encoding.permute((0, 2, 1))
        encoding, attention = self.pooled_attention(encoding)
        return encoding.squeeze(1), attention

    def forward(self, x):
        embedding, attention = self.encode(x)
        post = self.attn_post(embedding)
        y_pred = self.regressor(post)
        return Box({"y_pred": y_pred, "attention": attention})

    def get_attention(self, x):
        _, attention = self.encode(x)
        return attention


def get_arch(*args, **kwargs):
    return {"net": MRI_ATTN(*args, **kwargs)}
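

# Minimal usage sketch (not part of the original file). It assumes a
# single-channel input volume of shape (batch, 1, 91, 109, 91), matching the
# slice shapes hard-coded in MRI_ATTN.input_dim; the argument values below are
# illustrative only.
if __name__ == "__main__":
    net = get_arch(attn_num_heads=4, attn_dim=32, attn_drop=True,
                   agg_fn="attention", slice_dim=1)["net"]
    x = torch.randn(2, 1, 91, 109, 91)
    out = net(x)
    print(out.y_pred.shape)      # torch.Size([2, 1])
    print(out.attention.shape)   # torch.Size([8, 1, 91]) = (batch * heads, 1, n_slices)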