mobilenet.py 793 Bytes
import torch.nn as nn
from torchvision.models import mobilenet as vmn
import torch.utils.model_zoo as model_zoo

class MobileNet(vmn.MobileNetV2):
    'MobileNetV2: Inverted Residuals and Linear Bottlenecks - https://arxiv.org/abs/1801.04381'

    def __init__(self, outputs=[18], url=None):
        self.stride = 128
        self.url = url
        super().__init__()
        self.outputs = outputs
        self.unused_modules = ['features.18', 'classifier']

    def initialize(self):
        if self.url:
            self.load_state_dict(model_zoo.load_url(self.url))

    def forward(self, x):
        outputs = []
        for indx, feat in enumerate(self.features[:-1]):
            x = feat(x)
            if indx in self.outputs:
                outputs.append(x)
        return outputs