# DGL: creating, saving and loading graphs (DGL_图的创建、保存、加载)
import dgl
import torch as th
from dgl.data.utils import load_graphs, load_labels, save_graphs

# Demo: build two small DGL graphs with node/edge features, save both to one
# binary file together with graph-level labels, then load them back.

GRAPH_PATH = "data/try1.bin"  # single file shared by the save and load steps

# g1: 3 nodes and 6 edges, including the self-loops 0->0, 1->1 and 2->2.
g1 = dgl.DGLGraph()
g1.add_nodes(3)
g1.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
g1.ndata["x"] = th.ones(3, 5)  # one 5-dim feature row per node (3 nodes)
g1.edata["y"] = th.zeros(6, 5)  # one 5-dim feature row per edge (6 edges)

# Other accepted call forms of add_edges (examples only, NOT executed here;
# note they use node IDs beyond g1's nodes 0-2):
# g1.add_edges(th.tensor([3, 4, 5]), 1)  # three edges: 3->1, 4->1, 5->1
# g1.add_edges(4, [7, 8, 9])             # three edges: 4->7, 4->8, 4->9
# g1.add_edges([1, 2, 3], [3, 4, 5])     # three edges: 1->3, 2->4, 3->5

# g2: 3 nodes and 3 edges, with a 4-dim edge feature "e" and no node features.
g2 = dgl.DGLGraph()
g2.add_nodes(3)
g2.add_edges([0, 1, 2], [1, 2, 1])
g2.edata["e"] = th.ones(3, 4)

# Graph-level labels saved alongside the graphs: here, the node counts of g1
# and g2 (both have 3 nodes).
graph_labels = {"graph_sizes": th.tensor([3, 3])}
save_graphs(GRAPH_PATH, [g1, g2], graph_labels)

# Load only the graph at index 0 (g1). To load every saved graph instead:
#   glist, label_dict = load_graphs(GRAPH_PATH)  # glist will be [g1, g2]
# The returned label_dict still holds the labels of ALL saved graphs
# (see the printed output below), not just those of the selected indices.
glist, label_dict = load_graphs(GRAPH_PATH, [0])  # glist will be [g1]

# load_labels returns only the label dict. Despite the variable name, this is
# the whole dict ({'graph_sizes': ...}), not just the sizes tensor.
graph_sizes = load_labels(GRAPH_PATH)

print(glist)
# [DGLGraph(num_nodes=3, num_edges=6,
#           ndata_schemes={'x': Scheme(shape=(5,), dtype=torch.float32)}
#           edata_schemes={'y': Scheme(shape=(5,), dtype=torch.float32)})]
print(label_dict)
# {'graph_sizes': tensor([3, 3])}
print(graph_sizes)
# {'graph_sizes': tensor([3, 3])}

