-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
127 lines (98 loc) · 3.95 KB
/
model.py
File metadata and controls
127 lines (98 loc) · 3.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.nn.pytorch import GraphConv, GINConv
import math
from utils import idx_shuffle, row_normalization
class MLP(nn.Module):
    """Multi-layer perceptron encoder whose output rows are L2-normalized.

    Args:
        in_dim: dimensionality of the input features.
        out_dim: dimensionality of every hidden/output layer.
        activation: activation module instance (e.g. ``nn.ReLU()``); the same
            instance is shared by all layers (fine for stateless activations).
        layers: total number of Linear layers (>= 1).
        dropout: dropout probability applied after each activation.
    """

    def __init__(self, in_dim, out_dim, activation, layers, dropout):
        super(MLP, self).__init__()
        # Input layer: Linear -> activation -> Dropout.
        encoder = [nn.Linear(in_dim, out_dim), activation, nn.Dropout(dropout)]
        # Hidden layers now use the same Linear -> activation -> Dropout order.
        # (Previously Dropout was appended before the activation here, which
        # was inconsistent with the input layer's ordering.)
        for _ in range(layers - 1):
            encoder.append(nn.Linear(out_dim, out_dim))
            encoder.append(activation)
            encoder.append(nn.Dropout(dropout))
        # Kept as ModuleList (iterated manually in forward) to preserve the
        # original attribute type.
        self.mlp = nn.ModuleList(encoder)

    def forward(self, feats):
        """Encode ``feats`` and L2-normalize each output row.

        Args:
            feats: (N, in_dim) input feature matrix.

        Returns:
            (N, out_dim) tensor with unit-norm rows (zero rows stay zero).
        """
        h = feats
        for layer in self.mlp:
            h = layer(h)
        # Unit-norm rows so downstream dot products behave like cosine scores.
        h = F.normalize(h, p=2, dim=1)
        return h
class GNN(nn.Module):
    """Graph encoder: a stack of GIN convolutions, each followed by dropout.

    The graph ``g`` is fixed at construction time; ``forward`` takes only the
    node feature matrix. NOTE(review): the ``activation`` argument is accepted
    but not used by the GIN layers (it was only used by a previous
    GraphConv-based variant) — confirm before removing.
    """

    def __init__(self, g, in_dim, hid_dim, activation, layers, dropout):
        super(GNN, self).__init__()
        self.g = g
        self.dropout = nn.Dropout(p=dropout)
        # First layer maps in_dim -> hid_dim, the remaining layers-1 map
        # hid_dim -> hid_dim. Each GINConv applies a single Linear and keeps
        # a learnable epsilon.
        dims = [in_dim] + [hid_dim] * (layers - 1)
        self.gcn = nn.ModuleList(
            GINConv(nn.Linear(d, hid_dim), learn_eps=True) for d in dims
        )

    def forward(self, feats):
        """Propagate ``feats`` through every GIN layer, with dropout after each."""
        h = feats
        for conv in self.gcn:
            h = self.dropout(conv(self.g, h))
        return h
class MeanAggregator(nn.Module):
    """Computes the difference between node embeddings and a mean baseline.

    Two baselines are supported:
      * ``'local'``   - the mean embedding of each node's graph neighbours.
      * ``'cluster'`` - the mean embedding of the cluster a node belongs to.
    """

    def __init__(self):
        super(MeanAggregator, self).__init__()

    def extract_H_diff(self, graph, h, cluster_ids, mode='local'):
        """Return ``h`` minus its mean baseline.

        Args:
            graph: DGL graph (used only when ``mode='local'``).
            h: (N, D) node embedding matrix.
            cluster_ids: length-N cluster assignment (tensor, list, or array;
                used only when ``mode='cluster'``).
            mode: ``'local'`` or ``'cluster'``.

        Returns:
            (N, D) tensor of differences.

        Raises:
            ValueError: for an unknown mode, or when ``cluster_ids`` is
                missing in cluster mode.
        """
        if mode == 'local':
            # Neighbour mean via message passing; local_scope keeps the
            # temporary node data off the caller's graph.
            with graph.local_scope():
                graph.ndata['h'] = h
                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
                neigh_means = graph.ndata['neigh']
            return h - neigh_means
        elif mode == 'cluster':
            if cluster_ids is None:
                raise ValueError("cluster_ids is required when mode='cluster'")
            # as_tensor avoids a copy (and a UserWarning) when cluster_ids is
            # already a tensor; placing it on h's device keeps GPU inputs
            # working (the old CPU torch.tensor(...) broke CUDA indexing).
            cluster_ids = torch.as_tensor(cluster_ids, dtype=torch.long,
                                          device=h.device)
            # Vectorized per-cluster mean — O(N·D) instead of a Python loop
            # over clusters.
            uniq, inverse = torch.unique(cluster_ids, return_inverse=True)
            sums = torch.zeros(uniq.numel(), h.size(1),
                               dtype=h.dtype, device=h.device)
            sums.index_add_(0, inverse, h)
            counts = torch.bincount(inverse, minlength=uniq.numel())
            cluster_means = sums / counts.unsqueeze(1).to(h.dtype)
            return h - cluster_means[inverse]
        else:
            raise ValueError("Invalid mode. Supported modes are 'local' and 'cluster'.")
class Discriminator(nn.Module):
    """Bilinear discriminator scoring feature/summary pairs.

    local:  score_i = <x_i W, s_i>  (one per-node summary per node)
    global: score   = x W g         (a single global summary vector)
    """

    def __init__(self, hid_dim) -> None:
        super().__init__()
        # Bilinear weight of the discriminator.
        self.weight = nn.Parameter(torch.Tensor(hid_dim, hid_dim))
        self.reset_parameters()

    def uniform(self, size, tensor):
        """Fill ``tensor`` in-place with U(-1/sqrt(size), 1/sqrt(size))."""
        bound = 1.0 / math.sqrt(size)
        if tensor is not None:
            tensor.data.uniform_(-bound, bound)

    def reset_parameters(self):
        """Re-initialize the bilinear weight."""
        size = self.weight.size(0)
        self.uniform(size, self.weight)

    def forward(self, features, centers, mode):
        """Score ``features`` against ``centers`` under the given mode.

        Args:
            features: (N, hid_dim) node embeddings.
            centers: (N, hid_dim) per-node summaries for 'local', or a
                (hid_dim,) global summary vector for 'global'.
            mode: ``'local'`` or ``'global'``.

        Returns:
            (N,) tensor of discriminator scores.

        Raises:
            ValueError: for an unknown mode. (The previous
            ``assert mode=='local' or 'global'`` was always true because the
            non-empty string 'global' is truthy, so it never fired.)
        """
        if mode == 'local':
            tmp = torch.matmul(features, self.weight)  # tmp = xW
            res = torch.sum(tmp * centers, dim=1)      # row-wise <xW, s>
        elif mode == 'global':
            res = torch.matmul(features, torch.matmul(self.weight, centers))  # xWg
        else:
            raise ValueError("mode must be local or global")
        return res