# Source code for pgl.sampling.custom

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pgl
import numpy as np
from pgl.graph import Graph

# Public API of this module.
__all__ = ["subgraph"]


def subgraph(graph,
             nodes,
             eid=None,
             edges=None,
             with_node_feat=True,
             with_edge_feat=True):
    """Generate a subgraph from node ids and edge ids (or an explicit edge list).

    This function builds a new :code:`pgl.Graph` object and copies the
    corresponding node and edge features from the parent graph. Nodes and
    edges will be reindexed from 0; node ``nodes[i]`` becomes node ``i`` of
    the subgraph. ``eid`` and ``edges`` can't both be None.

    WARNING: ALL NODES IN EID MUST BE INCLUDED BY NODES

    Args:
        graph: The parent graph. It must not be in tensor mode (call
            ``Graph.numpy()`` first).
        nodes: Node ids which will be included in the subgraph.
        eid (optional): Edge ids which will be included in the subgraph.
            Used to gather edge features, and to look up the edge list when
            ``edges`` is None.
        edges (optional): Edge (src, dst) list which will be included in the
            subgraph. When given, it is used as the subgraph structure
            instead of ``graph._edges[eid]``.
        with_node_feat: Whether to inherit node features from parent graph.
        with_edge_feat: Whether to inherit edge features from parent graph.
            Edge features are gathered by ``eid``, so ``eid`` is required
            whenever the parent graph has edge features.

    Return:
        A :code:`pgl.Graph` object.

    Raises:
        AssertionError: If ``graph`` is in tensor mode (checked with
            ``assert``, so skipped under ``python -O``).
        ValueError: If both ``eid`` and ``edges`` are None, or if
            ``with_edge_feat`` is True, the parent graph has edge features,
            and ``eid`` is None.
    """
    assert not graph.is_tensor(), "You must call Graph.numpy() first."

    if eid is None and edges is None:
        raise ValueError("Eid and edges can't be None at the same time.")

    # Parent node id -> position in ``nodes`` (its id in the subgraph).
    # NOTE(review): if ``nodes`` contains duplicates, the last position wins
    # here while ``num_nodes``/node features still count every occurrence.
    reindex = {node: ind for ind, node in enumerate(nodes)}

    if edges is None:
        edges = graph._edges[eid]
    else:
        edges = np.array(edges, dtype="int64")

    sub_edge_feat = {}
    if with_edge_feat:
        edge_feat_items = list(graph._edge_feat.items())
        # Features can only be gathered by edge id. An explicit ``edges``
        # list carries no link back to the parent's feature rows. This check
        # is loop-invariant, so it is done once rather than per feature.
        if edge_feat_items and eid is None:
            raise ValueError("Eid can not be None with edge features.")
        sub_edge_feat = {key: value[eid] for key, value in edge_feat_items}

    # New edge ids are 0..len(edges)-1. ``map_edges`` presumably rewrites
    # each (src, dst) through ``reindex``; confirm against graph_kernel.
    sub_edges = pgl.graph_kernel.map_edges(
        np.arange(len(edges), dtype="int64"), edges, reindex)

    sub_node_feat = {}
    if with_node_feat:
        sub_node_feat = {
            key: value[nodes]
            for key, value in graph._node_feat.items()
        }

    return Graph(
        edges=sub_edges,
        num_nodes=len(nodes),
        node_feat=sub_node_feat,
        edge_feat=sub_edge_feat)