-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample_plot.py
More file actions
125 lines (94 loc) · 3.01 KB
/
example_plot.py
File metadata and controls
125 lines (94 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 29 15:13:12 2023
@author: SomgBird
"""
import torch
from pathlib import Path
import networkx as nx
import torch_geometric
import matplotlib.pyplot as plt
from models import ChebNetStacked
from data_generator import GNNDataset
# Root directory holding the dataset (``data/``) and trained weights (``models/``).
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
path = Path('K:/Dev/PyTorch projects/data/rusautocon2023_data/w3-10_n1.0_s30/')
torch.set_printoptions(sci_mode=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = GNNDataset(root=path/'data', transform=None)

# Architecture hyperparameters. They must match the checkpoint loaded below,
# otherwise load_state_dict (strict by default) raises on mismatched keys/shapes.
# The checkpoint filename appears to encode them (hc128, hl4, K6, lc512, ll1).
in_channels = 1
hidden_channels = 128
out_channels = 2
K = 6
cheb_depth = 4
linear_depth = 1
linear_channels = 512
model = ChebNetStacked(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    K=K,
    cheb_depth=cheb_depth,
    linear_depth=linear_depth,
    linear_channels=linear_channels,
    normalization=None,
    bias=True,
    droprate=0.3).to(device)
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved from a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(
    torch.load(path/'models/double_chebfc_hc128_hl4_K6_lc512_ll1.pth',
               map_location=device))
# Inference mode (presumably disables the dropout configured via droprate).
model.eval()
# Run the trained model on one sample graph (index 450) from the dataset.
data = dataset[450].to(device)
# Pure inference: skip building the autograd graph.
with torch.no_grad():
    # NOTE(review): the positional ``None`` and ``4`` are model-specific
    # arguments of ChebNetStacked.forward — confirm their meaning there.
    y = model(
        data.x,
        data.edge_index,
        data.edge_weight,
        None,
        4)
# Predicted class per node: index of the largest output along dim 1.
pred = y.argmax(dim=1)
# Attach results to the Data object so to_networkx copies them onto the nodes.
data.output = y
data.pred = pred
G = torch_geometric.utils.to_networkx(
    data,
    node_attrs=['x', 'output', 'pred'],
    edge_attrs=['edge_weight'],
    to_undirected=True,
)
# Fixed seed so the spring layout (and therefore every figure) is reproducible.
pos_nodes = nx.spring_layout(G, seed=131232)
# Anchor each node's label slightly up and to the right of its marker.
pos_labels = {node: (xy[0] + 0.01, xy[1] + 0.1) for node, xy in pos_nodes.items()}

# Axis limits derived from the bounding box of the label anchors.
x_values, y_values = zip(*pos_labels.values())
x_min, x_max = min(x_values), max(x_values)
x_margin = (x_max - x_min) * 0.4
# Labels are drawn left-aligned, so they run rightwards from their anchor:
# pad the right side by the full margin and the left side by a seventh of it.
x_min, x_max = x_min - x_margin / 7, x_max + x_margin
y_min, y_max = min(y_values), max(y_values)
y_margin = (y_max - y_min) * 0.1
y_min, y_max = y_min - y_margin, y_max + y_margin
def draw(G, pos_nodes, node_labels, node_colors, node_size,
         pos_labels=None, xlim=None, ylim=None):
    """Draw ``G`` with custom per-node labels and colours, then show the figure.

    Parameters
    ----------
    G : networkx.Graph
        Graph to draw.
    pos_nodes : dict
        Node -> (x, y) position of each node marker.
    node_labels : dict
        Node -> label text drawn next to the node.
    node_colors : sequence of float
        Per-node scalar values mapped through the ``cool`` colormap, in the
        order of ``G``'s nodes.
    node_size : int
        Marker size passed to ``nx.draw_networkx``.
    pos_labels : dict, optional
        Node -> (x, y) anchor of each label. Defaults to ``pos_nodes`` shifted
        by (+0.01, +0.1).
    xlim, ylim : tuple of float, optional
        Axis limits. Default to the bounding box of ``pos_labels`` padded by
        40% of its width on the right, 40%/7 of its width on the left, and 10%
        of its height above and below. No limits are set for an empty graph.

    Returns
    -------
    tuple
        The created matplotlib ``(fig, ax)``.
    """
    if pos_labels is None:
        # Nudge each label up and to the right of its node marker.
        pos_labels = {n: (p[0] + 0.01, p[1] + 0.1) for n, p in pos_nodes.items()}

    # Derive default limits from the label anchors (skipped for an empty graph,
    # where there is no bounding box to compute).
    if pos_labels and (xlim is None or ylim is None):
        xs, ys = zip(*pos_labels.values())
        if xlim is None:
            lo, hi = min(xs), max(xs)
            margin = (hi - lo) * 0.4
            # Labels are left-aligned and extend rightwards from their anchor,
            # so pad generously on the right and only slightly on the left.
            xlim = (lo - margin / 7, hi + margin)
        if ylim is None:
            lo, hi = min(ys), max(ys)
            margin = (hi - lo) * 0.1
            ylim = (lo - margin, hi + margin)

    fig, ax = plt.subplots(figsize=(7, 7))
    nx.draw_networkx(G, pos_nodes, with_labels=False, node_size=node_size,
                     ax=ax, node_color=node_colors, cmap=plt.cm.cool)
    nx.draw_networkx_labels(G, pos_labels, labels=node_labels,
                            horizontalalignment='left', ax=ax, font_size=20)
    # Hide the axes frame (same effect as plt.box(False) on these axes).
    ax.set_frame_on(False)
    if xlim is not None:
        ax.set_xlim(*xlim)
    if ylim is not None:
        ax.set_ylim(*ylim)
    plt.show()
    return fig, ax
# Input data: colour each node by its scalar input feature and label it with
# that value. -1.0 is printed as a bare "-1" (presumably a marker value —
# confirm against GNNDataset).
input_values = nx.get_node_attributes(G, 'x')
node_colors = list(input_values.values())
node_labels = {
    node: "-1" if value == -1.0 else f"{value:.4f}"
    for node, value in input_values.items()
}
draw(G, pos_nodes, node_labels, node_colors, 300)
# Output data: colour each node by its first output value and label it with
# the first two raw model outputs, stacked on two lines.
raw_outputs = nx.get_node_attributes(G, 'output')
node_colors = [out[0] for out in raw_outputs.values()]
node_labels = {
    node: f"[{out[0]:.4f},\n {out[1]:.4f}]"
    for node, out in raw_outputs.items()
}
draw(G, pos_nodes, node_labels, node_colors, 300)
# Postprocessed: label each node with its predicted class index; the colour
# value is 1 / (1 + class), e.g. 1.0 for class 0 and 0.5 for class 1.
predicted = nx.get_node_attributes(G, 'pred')
node_colors = []
node_labels = {}
for node, cls in predicted.items():
    node_colors.append(1 / (1 + cls))
    node_labels[node] = f"{cls:.0f}"
draw(G, pos_nodes, node_labels, node_colors, 300)