Toggle navigation
Toggle navigation
This project
Loading...
Sign in
Hyunji
/
A-Performance-Evaluation-of-CNN-for-Brain-Age-Prediction-Using-Structural-MRI-Data
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
Hyunji
2021-12-20 04:29:52 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
908ef1601e3dd51b46a8b5bc1808ff7530234d04
908ef160
1 parent
eedd76b3
verify mri lstm
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
127 additions
and
0 deletions
2DCNN/tests/verify_mri_lstm.py
2DCNN/tests/verify_mri_lstm.py
0 → 100644
View file @
908ef16
import
torch
from
torch
import
nn
"""
Code to test LSTM implementation with Lam et.al.
Our implementation use vectorization and should be faster... but need to be verified.
"""
def
encoder_blk
(
in_channels
,
out_channels
):
return
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
,
out_channels
,
3
,
padding
=
1
,
stride
=
1
),
nn
.
InstanceNorm2d
(
out_channels
),
nn
.
MaxPool2d
(
2
,
stride
=
2
),
nn
.
ReLU
()
)
class
MRI_LSTM
(
nn
.
Module
):
def
__init__
(
self
,
lstm_feat_dim
,
lstm_latent_dim
,
*
args
,
**
kwargs
):
super
(
MRI_LSTM
,
self
)
.
__init__
()
self
.
input_dim
=
(
1
,
109
,
91
)
self
.
feat_embed_dim
=
lstm_feat_dim
self
.
latent_dim
=
lstm_latent_dim
# Build Encoder
encoder_blocks
=
[
encoder_blk
(
1
,
32
),
encoder_blk
(
32
,
64
),
encoder_blk
(
64
,
128
),
encoder_blk
(
128
,
256
),
encoder_blk
(
256
,
256
)
]
self
.
encoder
=
nn
.
Sequential
(
*
encoder_blocks
)
# Post processing
self
.
post_proc
=
nn
.
Sequential
(
nn
.
Conv2d
(
256
,
64
,
1
,
stride
=
1
),
nn
.
InstanceNorm2d
(
64
),
nn
.
ReLU
(),
nn
.
AvgPool2d
([
3
,
2
]),
nn
.
Dropout
(
p
=
0.5
),
nn
.
Conv2d
(
64
,
self
.
feat_embed_dim
,
1
)
)
# Connect w/ LSTM
self
.
n_layers
=
1
self
.
lstm
=
nn
.
LSTM
(
self
.
feat_embed_dim
,
self
.
latent_dim
,
self
.
n_layers
,
batch_first
=
True
)
# Build regressor
self
.
lstm_post
=
nn
.
Linear
(
self
.
latent_dim
,
64
)
self
.
regressor
=
nn
.
Sequential
(
nn
.
ReLU
(),
nn
.
Linear
(
64
,
1
))
self
.
init_weights
()
def
init_weights
(
self
):
for
k
,
m
in
self
.
named_modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
nn
.
init
.
kaiming_normal_
(
m
.
weight
,
mode
=
"fan_out"
,
nonlinearity
=
"relu"
)
if
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
Linear
)
and
"regressor"
in
k
:
m
.
bias
.
data
.
fill_
(
62.68
)
elif
isinstance
(
m
,
nn
.
Linear
):
nn
.
init
.
normal_
(
m
.
weight
,
0
,
0.01
)
nn
.
init
.
constant_
(
m
.
bias
,
0
)
def
init_hidden
(
self
,
x
):
h_0
=
torch
.
zeros
(
self
.
n_layers
,
x
.
size
(
0
),
self
.
latent_dim
,
device
=
x
.
device
)
c_0
=
torch
.
zeros
(
self
.
n_layers
,
x
.
size
(
0
),
self
.
latent_dim
,
device
=
x
.
device
)
h_0
.
requires_grad
=
True
c_0
.
requires_grad
=
True
return
h_0
,
c_0
def
encode_old
(
self
,
x
,
):
B
,
C
,
H
,
W
,
D
=
x
.
size
()
h_t
,
c_t
=
self
.
init_hidden
(
x
)
for
i
in
range
(
H
):
out
=
self
.
encoder
(
x
[:,
:,
i
,
:,
:])
out
=
self
.
post_proc
(
out
)
out
=
out
.
view
(
B
,
1
,
self
.
feat_embed_dim
)
h_t
=
h_t
.
view
(
1
,
B
,
self
.
latent_dim
)
c_t
=
c_t
.
view
(
1
,
B
,
self
.
latent_dim
)
h_t
,
(
_
,
c_t
)
=
self
.
lstm
(
out
,
(
h_t
,
c_t
))
encoding
=
h_t
.
view
(
B
,
self
.
latent_dim
)
return
encoding
def
encode_new
(
self
,
x
):
h_0
,
c_0
=
self
.
init_hidden
(
x
)
B
,
C
,
H
,
W
,
D
=
x
.
size
()
# convert to 2D images, apply encoder and then reshape for lstm
new_input
=
torch
.
cat
([
x
[:,
:,
i
,
:,
:]
for
i
in
range
(
H
)],
dim
=
0
)
encoding
=
self
.
encoder
(
new_input
)
encoding
=
self
.
post_proc
(
encoding
)
# (BxH) X C_out X W_out X D_out
encoding
=
torch
.
stack
(
torch
.
split
(
encoding
,
B
,
dim
=
0
),
dim
=
2
)
# B X C_out X H X W_out X D_out
encoding
=
encoding
.
squeeze
(
4
)
.
squeeze
(
3
)
# lstm take batch x seq_len x dim
encoding
=
encoding
.
permute
(
0
,
2
,
1
)
_
,
(
encoding
,
_
)
=
self
.
lstm
(
encoding
)
# output is 1 X batch x hidden
encoding
=
encoding
.
squeeze
(
0
)
# pass it to lstm and get encoding
return
encoding
def
forward
(
self
,
x
):
embedding_old
=
self
.
encode_old
(
x
)
embedding_new
=
self
.
encode_new
(
x
)
return
embedding_new
,
embedding_old
if
__name__
==
"__main__"
:
B
=
4
new_model
=
MRI_LSTM
(
lstm_feat_dim
=
2
,
lstm_latent_dim
=
128
)
new_model
.
eval
()
inp
=
torch
.
rand
(
4
,
1
,
91
,
109
,
91
)
output
=
new_model
(
inp
)
print
(
torch
.
allclose
(
output
[
0
],
output
[
1
]))
# breakpoint()
Please
register
or
login
to post a comment