# models.py
# 658 Bytes
import torch.nn as nn
import torch.nn.functional as F
from layers import GraphConvolution
class GCN(nn.Module):
    """Two graph-convolution layers followed by a two-layer linear head.

    The forward pass is::

        gc1 -> ReLU -> dropout -> gc2 -> ReLU -> fc1 -> ReLU -> fc2

    Args:
        nfeat: First argument given to the first ``GraphConvolution``
            (presumably the per-node input feature size; confirm against
            ``layers.GraphConvolution``).
        nhid: Hidden width. Used as the output size of ``gc1``, both sizes
            of ``gc2``, and both sizes of ``fc1``.
        nclass: Output size of the final linear layer ``fc2``.
        dropout: Probability handed to ``F.dropout`` after the first graph
            convolution.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        """Build the two graph convolutions and the two linear layers."""
        super().__init__()

        # Graph part: nfeat -> nhid -> nhid.
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nhid)

        # Dense head: nhid -> nhid -> nclass.
        self.fc1 = nn.Linear(nhid, nhid)
        self.fc2 = nn.Linear(nhid, nclass)

        # Stored as a plain number; it is only consumed by F.dropout in forward().
        self.dropout = dropout

    def forward(self, x, adj):
        """Run the network and return the raw output of ``fc2``.

        Args:
            x: Input features, passed to ``gc1`` together with ``adj``
                (presumably one row per graph node; confirm against the
                caller).
            adj: Passed unchanged to both graph convolutions (presumably the
                graph adjacency matrix; confirm against the caller).

        Returns:
            The output of ``fc2``. No softmax or log-softmax is applied here.
        """
        x = F.relu(self.gc1(x, adj))
        # Dropout is applied only here, and only while the module is in
        # training mode (self.training is toggled by train()/eval()).
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.gc2(x, adj))
        x = F.relu(self.fc1(x))
        return self.fc2(x)