这节课我们简单地介绍一下变分自编码网络的 PyTorch 实践。
GCN 网络
import torch
import torch.nn as nn
import torch.nn.functional as F
class GraphConvolution(nn.Module):
    """Single graph-convolution layer: ``act(adj @ (dropout(input) @ weight))``.

    Args:
        in_features: Number of columns expected in the input feature matrix.
        out_features: Number of columns produced per row of the output.
        dropout: Dropout probability applied to the input while training.
        act: Callable applied to the aggregated output (default ``F.relu``).
    """

    def __init__(self, in_features, out_features, dropout=0., act=F.relu):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.act = act
        # The tensor is allocated uninitialised; reset_parameters() fills it.
        # `nn.Parameter` (not a bare `Parameter`) is used because only
        # `torch.nn as nn` is imported by this file.
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the weight matrix with Xavier-uniform values."""
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, input, adj):
        """Apply dropout, the linear projection, aggregation and activation.

        Args:
            input: 2-D feature matrix whose second dimension is ``in_features``
                (required by ``torch.mm`` with the weight).
            adj: Matrix multiplied from the left via ``torch.spmm``; presumably
                a sparse normalised adjacency matrix -- confirm with the caller.

        Returns:
            ``act(adj @ (input @ weight))`` with ``out_features`` columns.
        """
        input = F.dropout(input, self.dropout, self.training)
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        output = self.act(output)
        return output
encoder/decoder
class GCNModelVAE(nn.Module):
    """Variational autoencoder built from graph-convolution layers.

    A shared first layer feeds two parallel linear heads that produce the
    latent mean and the latent spread; an inner-product decoder turns a
    latent sample back into pairwise scores.

    NOTE: the second head is named ``logvar`` but ``reparameterize`` uses
    ``exp(logvar)`` directly as the standard deviation, i.e. the value is
    treated as log(sigma) rather than log(sigma^2).
    """

    def __init__(self, input_feat_dim, hidden_dim1, hidden_dim2, dropout):
        super().__init__()
        # Shared hidden layer with ReLU.
        self.gc1 = GraphConvolution(
            input_feat_dim, hidden_dim1, dropout, act=F.relu
        )
        # Two identity-activated heads: mean (gc2) and log-spread (gc3).
        self.gc2 = GraphConvolution(
            hidden_dim1, hidden_dim2, dropout, act=lambda x: x
        )
        self.gc3 = GraphConvolution(
            hidden_dim1, hidden_dim2, dropout, act=lambda x: x
        )
        # The decoder is given an identity activation, so it emits raw scores.
        self.dc = InnerProductDecoder(dropout, act=lambda x: x)

    def encode(self, x, adj):
        """Return ``(mu, logvar)`` computed from features ``x`` and ``adj``."""
        shared = self.gc1(x, adj)
        mu = self.gc2(shared, adj)
        logvar = self.gc3(shared, adj)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar)`` in training; return ``mu`` in eval."""
        if not self.training:
            return mu
        sigma = torch.exp(logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x, adj):
        """Encode, sample a latent code, decode; return ``(decoded, mu, logvar)``."""
        mu, logvar = self.encode(x, adj)
        latent = self.reparameterize(mu, logvar)
        decoded = self.dc(latent)
        return decoded, mu, logvar
class InnerProductDecoder(nn.Module):
    """Decode latent rows into pairwise scores via ``act(z @ z.T)``.

    Args:
        dropout: Dropout probability applied to ``z`` while training.
        act: Callable applied to the inner-product matrix
            (default ``torch.sigmoid``).
    """

    def __init__(self, dropout, act=torch.sigmoid):
        super().__init__()
        self.act = act
        self.dropout = dropout

    def forward(self, z):
        """Return ``act(z_d @ z_d.T)`` where ``z_d`` is ``z`` after dropout."""
        dropped = F.dropout(z, p=self.dropout, training=self.training)
        scores = torch.mm(dropped, dropped.t())
        return self.act(scores)
loss function
def loss_function(preds, labels, mu, logvar, n_nodes, norm, pos_weight):
    """Return the weighted reconstruction loss plus the KL regulariser.

    Args:
        preds: Raw (pre-sigmoid) scores; fed to BCE-with-logits.
        labels: Target tensor of the same shape as ``preds``.
        mu: Latent means; reduced over dimension 1, so at least 2-D.
        logvar: Latent log-spread. The formula below uses ``2 * logvar`` as
            log(sigma^2) and ``exp(logvar) ** 2`` as sigma^2, i.e. the value
            is treated as log(sigma) despite its name.
        n_nodes: Divisor applied to the KL term.
        norm: Scalar multiplier for the reconstruction term.
        pos_weight: Weight of positive targets. May be a tensor, ``None`` or a
            plain Python/NumPy scalar; scalars are converted to a tensor
            because ``binary_cross_entropy_with_logits`` only accepts tensors.

    Returns:
        Scalar tensor ``norm * BCE + KLD``.
    """
    if pos_weight is not None and not torch.is_tensor(pos_weight):
        # Match the logits' dtype/device so the scalar broadcasts cleanly.
        pos_weight = torch.as_tensor(
            pos_weight, dtype=preds.dtype, device=preds.device
        )
    cost = norm * F.binary_cross_entropy_with_logits(
        preds, labels, pos_weight=pos_weight
    )
    # KL(q || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
    # summed per row, averaged over rows, then scaled by 1 / n_nodes.
    KLD = -0.5 / n_nodes * torch.mean(torch.sum(
        1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1))
    return cost + KLD
# Class-balancing constants derived from `adj` (defined elsewhere).
# adj.shape[0] ** 2 is the total number of entries; adj.sum() is presumably
# the number of positive (edge) entries -- confirm adj is a 0/1 matrix.
# pos_weight = (#entries - sum) / sum
pos_weight = float(adj.shape[0] ** 2 - adj.sum()) / adj.sum()
# norm = #entries / (2 * (#entries - sum))
norm = adj.shape[0] ** 2 / float(2 * (adj.shape[0] ** 2 - adj.sum()))
PREVIOUS: 自编码网络