"""
    기존의 Self Attention을 경량화한 CONVOLUTION BLOCK ATTENTION MODULE
    @FUNCTION se_block : Squeeze and Excitation Block
    @FUNCTION cbam_block : Convolution Block Attetntion Module
    @FUNCTION channel_attention : Channel Attention
    @FUNCITON Spatial_attention : Spation_attention
"""
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import GlobalAvgPool2D, GlobalMaxPool2D
from tensorflow.keras.layers import Reshape, Dense, Lambda
from tensorflow.keras.layers import Add, Activation
from tensorflow.keras import backend as K
from tensorflow.keras import layers

"""
    Squeeze-and-Excitation(SE) Block
    @brief : 채널간의 관계를 재종정 시켜줌
    @param input_feature : tensor
"""
def se_block(input_feature, ratio=8):
    
    se_feature = GlobalAvgPool2D()(input_feature)
    channel = input_feature.shape[-1]
    
    se_feature = Reshape((1, 1, channel))(se_feature)
    se_feature = Dense(channel // ratio,
                       activation='relu',
                       kernel_initializer='he_normal',
                       use_bias=True,
                       bias_initializer='zeros')(se_feature)
    
    se_feature = Dense(channel,
                       activation='sigmoid',
                       kernel_initializer='he_normal',
                       use_bias=True,
                       bias_initializer='zeros')(se_feature)
    
    se_feature = layers.multiply([input_feature, se_feature])
    
    return se_feature
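
# Illustrative usage sketch (not part of the original module): shows where
# se_block slots into a small functional model. The input shape, filter count
# and the helper name _se_block_example are assumptions for demonstration only.
def _se_block_example():
    from tensorflow.keras import Input, Model
    inputs = Input(shape=(32, 32, 64))
    x = Conv2D(64, 3, padding='same', activation='relu')(inputs)
    x = se_block(x, ratio=8)  # recalibrate the channels of the conv output
    return Model(inputs, x)
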

"""
    CBAM_BLOCK
    @brief : Convolution Block Attention Module
    @param cbam_feature : input tensor
    @param ratio(int) : channel reduce ratio
    @return cbam_feature : dynamic feature selection
"""
def cbam_block(cbam_feature, ratio=8):
    
    cbam_feature = channel_attention(cbam_feature, ratio)
    cbam_feature = spatial_attention(cbam_feature)
    
    return cbam_feature
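
# Illustrative usage sketch (an assumption, not taken from the original code):
# CBAM is commonly placed inside a residual branch before the skip connection,
# as in the ResNet integration described in the CBAM paper. The shapes and the
# helper name _cbam_residual_example are hypothetical.
def _cbam_residual_example():
    from tensorflow.keras import Input, Model
    inputs = Input(shape=(32, 32, 64))
    x = Conv2D(64, 3, padding='same', activation='relu')(inputs)
    x = cbam_block(x, ratio=8)  # refine along the channel axis, then the spatial axes
    x = Add()([inputs, x])      # residual (skip) connection
    return Model(inputs, x)
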

"""
    Channel Attention
    @brief : Channel Attention, average pool과 max pool을 사용(파라미터 양을 줄일 수 있음)
            두 가지 pooled feature는 같은 의미를 공유하는 값이기 때문에 하나의 공유된 MLP를 사용
    @param input_feature = input_tensor
    @return cbam_feature
"""
def channel_attention(input_feature, ratio=8):
    
    # channels-last layout: the channel dimension is the last axis
    channel = input_feature.shape[-1]
    
    shared_layer_one = Dense(channel//ratio,
                             activation='relu',
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')
    
    shared_layer_two = Dense(channel,
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')
    
    # compute both an average-pooled and a max-pooled channel descriptor
    avg_pool = GlobalAvgPool2D()(input_feature)
    avg_pool = Reshape((1, 1, channel))(avg_pool)
    avg_pool = shared_layer_one(avg_pool)
    avg_pool = shared_layer_two(avg_pool)

    max_pool = GlobalMaxPool2D()(input_feature)
    max_pool = Reshape((1, 1, channel))(max_pool)
    max_pool = shared_layer_one(max_pool)
    max_pool = shared_layer_two(max_pool)
    
    cbam_feature = Add()([avg_pool, max_pool])
    # sigmoid is used instead of a mutually exclusive softmax: the goal is not
    # to pick out a single most important feature, so each channel is weighted independently
    cbam_feature = Activation('sigmoid')(cbam_feature)
    cbam_feature = layers.multiply([input_feature, cbam_feature])
    return cbam_feature

"""
    Spatial Attention
    @brief : 2차원의 spatial attention, single convolution을 사용하여 특징이 보이는 
            channel을 만듬, 정보가 어디에 있는지 중점을 둠
    @param ipnut_feature : input_tensor(Channel-refined feature)
"""
def spatial_attention(input_feature, kernel_size=7):
    
    cbam_feature = input_feature
    
    # collapse the channel axis with mean and max to obtain two 2-D spatial descriptors
    avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)
    max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)
    concat = layers.concatenate([avg_pool, max_pool])
    cbam_feature = Conv2D(filters=1,
                          kernel_size=kernel_size,
                          strides=1,
                          padding='same',
                          activation='sigmoid',
                          kernel_initializer='he_normal',
                          use_bias=False)(concat)
    
    return layers.multiply([input_feature, cbam_feature])
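
# Hypothetical smoke test (an addition for illustration, not from the original
# module): applies cbam_block eagerly to a random channels-last feature map and
# checks that attention preserves the input shape. The 2x8x8x16 shape is arbitrary.
if __name__ == "__main__":
    import tensorflow as tf
    x = tf.random.normal((2, 8, 8, 16))
    y = cbam_block(x, ratio=8)
    print("input shape :", x.shape)   # (2, 8, 8, 16)
    print("output shape:", y.shape)   # attention blocks preserve the input shape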