# model.py
import torch.nn as nn
import torch.nn.functional as F
import torch
import const

class Net(nn.Module):
    """Small three-stage CNN followed by a single linear head.

    Each of ``f1``..``f3`` is a 3x3 ``Conv2d`` (stride 1, no padding) followed
    by an in-place ReLU, so every stage shrinks each spatial dimension by 2
    pixels. ``f4`` is a ``Linear`` layer applied to the flattened feature map.

    The defaults reproduce the original hard-coded configuration: a
    single-channel 29x29 input (29 - 3 * 2 = 23, hence ``8 * 23 * 23`` input
    features to ``f4``) and 2 outputs. Parameter names and shapes in the
    ``state_dict`` are unchanged for the default configuration.

    Args:
        input_size: Spatial size of the input image, either an int (square
            image) or a ``(height, width)`` pair. Each dimension must be at
            least 7 so the last convolution yields a non-empty feature map.
        num_outputs: Number of output features produced by the final
            linear layer.

    Raises:
        ValueError: If ``input_size`` is too small for three unpadded 3x3
            convolutions.
    """

    def __init__(self, input_size=29, num_outputs=2):
        super().__init__()

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size

        # Three unpadded, stride-1 3x3 convs each remove 2 pixels per side pair.
        shrink = 3 * (3 - 1)
        feat_h, feat_w = height - shrink, width - shrink
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size {height}x{width} is too small; each dimension "
                f"must be at least {shrink + 1}"
            )

        self.f1 = nn.Sequential(
            nn.Conv2d(1, 2, 3),
            nn.ReLU(True),
        )
        self.f2 = nn.Sequential(
            nn.Conv2d(2, 4, 3),
            nn.ReLU(True),
        )
        self.f3 = nn.Sequential(
            nn.Conv2d(4, 8, 3),
            nn.ReLU(True),
        )
        # 8 channels out of f3, flattened over the remaining spatial grid.
        self.f4 = nn.Sequential(
            nn.Linear(8 * feat_h * feat_w, num_outputs),
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Batched tensor of shape ``(N, 1, H, W)`` whose spatial size
                matches the ``input_size`` given at construction (a batch
                dimension is required because flattening starts at dim 1).

        Returns:
            Tensor of shape ``(N, num_outputs)`` with the raw outputs of the
            linear head (no activation applied).
        """
        x = self.f1(x)
        x = self.f2(x)
        x = self.f3(x)
        x = torch.flatten(x, 1)
        x = self.f4(x)
        return x