Graph Neural Networks and Implementing Them in TensorFlow

Krishna Chaitanya
7 min read · Aug 12, 2023

This article guides you through understanding Graph Neural Networks (GNNs) and implementing one in TensorFlow. A follow-up article discusses different variants of GNNs and their implementations. Here's a step-by-step plan:

  1. Usage of Graph Neural Networks (GNNs): We’ll start by discussing what GNNs are, how they work, and where they are used.
  2. Understanding Graphs: Before we dive into GNNs, it’s important to understand the basics of graphs, including nodes, edges, adjacency matrices, and graph representations.
  3. Understanding Graph Neural Networks: We'll walk through how GNNs aggregate and transform node features, and how this relates to an ordinary neural network layer.
  4. Variants of Graph Neural Networks (GNNs): These are covered in the follow-up article.
  5. Implementing a GNN with TensorFlow: Finally, we’ll walk through the process of implementing a simple GNN using TensorFlow.

Usage of Graph Neural Networks (GNNs)

Graph Neural Networks (GNNs) are a type of neural network designed to perform machine learning tasks on graph data structures. They are particularly useful for tasks where data is represented as graphs, such as social networks, molecular structures, and recommendation systems.

GNNs work by propagating information between neighboring nodes. Each node's state is updated based on the states of its neighbors, and this process is repeated for a fixed number of iterations (typically one per GNN layer). The final node states can then be used to make predictions.

For example, in a social network, a GNN could be used to predict the interests of a user based on the interests of their friends. The GNN would start with some initial representation of each user, and then update each user’s representation based on the representations of their friends. After a few iterations, the final representation of each user would capture not just their own interests, but also the interests of their friends, their friends’ friends, and so on.
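To make the "friends of friends" idea concrete, here is a small NumPy sketch (an illustration, not part of the original walkthrough) on a hypothetical chain of four users A-B-C-D. Assuming each update mixes a node's own state with its direct neighbors' states, the nodes that can influence a given node after k updates are the nonzero entries of that node's row in (A + I)^k, where A is the adjacency matrix and I the identity:

```python
import numpy as np

# Hypothetical chain graph A - B - C - D (undirected), nodes indexed 0..3
edges = [(0, 1), (1, 2), (2, 3)]
n = 4
A = np.zeros((n, n), dtype=int)
for u, v in edges:
    A[u, v] = A[v, u] = 1

# Adding the identity lets each update keep the node's own state
A_hat = A + np.eye(n, dtype=int)

def receptive_field(node, k):
    """Hypothetical helper: nodes that can influence `node` after k updates."""
    reach = np.linalg.matrix_power(A_hat, k)
    return sorted(np.nonzero(reach[node])[0].tolist())

for k in range(4):
    print(k, receptive_field(0, k))
```

Node 0 (user A) depends only on itself at k = 0, on B after one update, on C after two, and on D after three: the neighborhood grows by one hop per iteration.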

Understanding Graphs

A graph is a mathematical structure that models relationships between objects. It consists of nodes (also called vertices) and edges. Nodes represent objects, and edges represent relationships between those objects.

For example, in a social network, each person could be represented by a node, and each friendship could be represented by an edge connecting two nodes.

There are two main types of graphs:

  1. Undirected Graphs: In an undirected graph, the edges do not have a direction. That is, if there is an edge from node A to node B, there is also an edge from node B to node A. An example of this is a Facebook friendship: if person A is friends with person B, then person B is also friends with person A.
  2. Directed Graphs: In a directed graph, the edges do have a direction. That is, if there is an edge from node A to node B, it does not necessarily mean that there is an edge from node B to node A. An example of this is a Twitter follow: if person A follows person B, it does not mean that person B follows person A.

Graphs can be represented in several ways, but one of the most common is the adjacency matrix: a square matrix whose entry in the i-th row and j-th column is the number of edges between nodes i and j (in a directed graph, the number of edges from i to j). For a simple graph, with at most one edge between any pair of nodes, every entry is 0 or 1. For an undirected graph, the adjacency matrix is symmetric.

Another common representation is the edge list, where each edge is represented by a pair of nodes.
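Converting between the two representations is straightforward. A minimal NumPy sketch, using a hypothetical helper `adjacency_from_edges`, that builds the adjacency matrix from an edge list and checks the symmetry property for the undirected case:

```python
import numpy as np

def adjacency_from_edges(edges, n_nodes, directed=False):
    """Hypothetical helper: dense adjacency matrix from (source, target) pairs.

    Entry [i, j] counts edges from node i to node j, so repeated edges
    (a multigraph) produce values above 1. A self-loop is counted once
    here; conventions for undirected self-loops differ.
    """
    A = np.zeros((n_nodes, n_nodes), dtype=int)
    for i, j in edges:
        A[i, j] += 1
        if not directed and i != j:
            A[j, i] += 1  # undirected: the edge also runs from j to i
    return A

# Hypothetical 3-node example with edges 0-1 and 1-2
edge_list = [(0, 1), (1, 2)]
A_undirected = adjacency_from_edges(edge_list, 3)
A_directed = adjacency_from_edges(edge_list, 3, directed=True)

print(A_undirected)
print(np.array_equal(A_undirected, A_undirected.T))  # True: symmetric
print(np.array_equal(A_directed, A_directed.T))      # False: 0 -> 1 but not 1 -> 0
```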

Understanding these basics of graphs is crucial for understanding how Graph Neural Networks work, as they operate directly on the graph structure.

Understanding Graph Neural Networks

The key idea behind GNNs is to capture the dependencies between connected nodes in the graph. They do this by aggregating the features of each node's neighbors to generate an embedding for that node. These embeddings can then be used for tasks such as node classification, link prediction, and graph classification.

Here’s a more detailed step-by-step process of how GNNs work:

  1. Node Feature Initialization: Each node in the graph is initialized with a feature vector. This could be a one-hot encoding of the node’s label, some real-valued vector specific to the node, or even a vector of zeros.
  2. Feature Aggregation: Each node aggregates the feature vectors of its neighboring nodes to update its own feature vector. This is typically done using a function that takes in the feature vectors of the node and its neighbors and outputs a new feature vector. The function could be a simple average, a weighted sum, or a more complex function.
  3. Feature Transformation: The aggregated feature vector is then transformed, typically using a linear transformation followed by a non-linear activation function. This is similar to what happens in a traditional neural network layer.
  4. Repeat Steps 2 and 3: Steps 2 and 3 are repeated for a certain number of iterations. With each iteration, the nodes aggregate and transform features from a larger and larger neighborhood.
  5. Readout: For graph-level tasks, a readout function aggregates the feature vectors of all nodes in the graph after the final iteration to produce a graph-level output. For node-level tasks, the final node feature vectors are used directly.
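The five steps above can be sketched in plain NumPy. This is a minimal illustration, assuming mean aggregation over each node and its neighbors, a ReLU activation, random (untrained) weights and a sum readout; `gnn_layer` is a hypothetical helper, and real layers such as GCN or GraphSAGE use their own normalization and aggregation rules:

```python
import numpy as np

rng = np.random.default_rng(0)

def gnn_layer(H, A, W):
    """Hypothetical helper: one aggregate-then-transform step.

    H: (n_nodes, d_in) node features, A: (n_nodes, n_nodes) adjacency,
    W: (d_in, d_out) weights shared by every node.
    """
    A_hat = A + np.eye(A.shape[0])            # include each node's own features
    deg = A_hat.sum(axis=1, keepdims=True)    # neighborhood sizes (>= 1)
    H_agg = (A_hat @ H) / deg                 # step 2: mean aggregation
    return np.maximum(H_agg @ W, 0.0)         # step 3: linear map + ReLU

# Hypothetical graph: 4 nodes in a chain, 3 input features per node
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
H = rng.normal(size=(4, 3))                   # step 1: initial node features

W1 = rng.normal(size=(3, 8))
W2 = rng.normal(size=(8, 8))
H = gnn_layer(H, A, W1)                       # step 4: repeat steps 2 and 3
H = gnn_layer(H, A, W2)

graph_vector = H.sum(axis=0)                  # step 5: sum readout
print(H.shape, graph_vector.shape)            # (4, 8) (8,)
```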

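The beauty of GNNs is that they can handle graphs of varying sizes and shapes, because the same learned aggregation and transformation are applied at every node regardless of how many nodes or edges the graph has. Each layer captures local structure around a node, and stacking layers lets the model capture structure further across the graph.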
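A small sketch of why graph size does not matter, using a hypothetical `embed_graph` helper with sum aggregation and sum readout (illustrative choices, not a specific published layer). The weight matrix W has shape (features in, features out) and never depends on the number of nodes, so the same W embeds a 3-node and a 5-node ring graph into vectors of the same length:

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(2, 4))  # one weight matrix, independent of graph size

def embed_graph(A, X, W):
    """Hypothetical helper: sum-aggregate neighbors (plus self),
    apply the shared weights and ReLU, then sum-pool over nodes."""
    H = np.maximum((A + np.eye(len(A))) @ X @ W, 0.0)
    return H.sum(axis=0)

def ring(n):
    """Hypothetical helper: adjacency matrix of an n-node cycle graph."""
    A = np.zeros((n, n))
    for k in range(n):
        A[k, (k + 1) % n] = A[(k + 1) % n, k] = 1
    return A

small = embed_graph(ring(3), rng.normal(size=(3, 2)), W)
large = embed_graph(ring(5), rng.normal(size=(5, 2)), W)
print(small.shape, large.shape)  # (4,) (4,)
```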

Implementing a GNN with TensorFlow

There are several libraries built on top of TensorFlow that provide implementations of various types of GNNs, such as Graph Nets and Spektral. We can use one of these libraries to make the implementation process easier.

First, you’ll need to install the Spektral library. You can do this using pip:

```shell
pip install spektral
```

Once you’ve installed Spektral, you can start by importing the necessary libraries:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout
# Spektral exposes global sum pooling as the GlobalSumPool layer
from spektral.layers import GCNConv, GlobalSumPool
from spektral.data import DisjointLoader
from spektral.datasets import TUDataset
```

In this example, we’ll use the TUDataset, which is a collection of benchmark datasets for graph classification.

Next, let’s load the dataset:

```python
dataset = TUDataset('PROTEINS')
```

This will download the PROTEINS dataset, which is a graph classification dataset of protein structures.

For graph classification on PROTEINS, the model also needs a readout step: after the final graph convolution layer, a readout (pooling) function aggregates the feature vectors of all nodes in each graph into a single graph-level vector.

Now, let’s see how we can implement a simple GraphSAGE model using the Spektral library in TensorFlow:

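```python
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dropout, Dense
from spektral.layers import GraphSageConv, GlobalSumPool

# Define the model: two GraphSAGE layers, a sum readout, and a classifier
class GraphSageModel(Model):
    def __init__(self, n_hidden, n_labels):
        super().__init__()
        self.dropout = Dropout(0.5)
        # Each GraphSageConv combines a node's features with an
        # aggregate of its neighbors' features
        self.sage_conv1 = GraphSageConv(n_hidden, activation='relu')
        self.sage_conv2 = GraphSageConv(n_hidden, activation='relu')
        # Sums the node embeddings of each graph into one graph-level vector
        self.pool = GlobalSumPool()
        self.dense = Dense(n_labels, activation='softmax')

    def call(self, inputs, training=False):
        # Disjoint mode: node features x, sparse adjacency a, batch index i
        x, a, i = inputs
        x = self.dropout(x, training=training)
        x = self.sage_conv1([x, a])
        x = self.sage_conv2([x, a])
        x = self.pool([x, i])
        return self.dense(x)

# Instantiate the model
model = GraphSageModel(n_hidden=64, n_labels=dataset.n_labels)
```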

This model takes as input a graph represented by its node features x, adjacency matrix a, and a batch index i. The model first applies dropout to the node features, then applies two graph convolution layers, pools the node features into a graph-level representation, and finally applies a dense layer to predict the class of each graph.

Next, let’s compile and train our model:

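```python
# `model` is the GraphSageModel instantiated above
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
# DisjointLoader merges each batch of 32 graphs into one disjoint graph
# and yields ((x, a, i), y); the number of epochs is set on the loader
loader = DisjointLoader(dataset, batch_size=32, epochs=10)
model.fit(loader.load(), steps_per_epoch=loader.steps_per_epoch)
```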

What does GlobalSumPool represent?

In the context of Graph Neural Networks (GNNs), pooling is a technique used to aggregate information from an entire graph into a single vector representation. This is particularly useful for graph-level prediction tasks, where we want to make a prediction for an entire graph (as opposed to individual nodes or edges).

GlobalSumPool is one such pooling layer provided by the Spektral library. As the name suggests, it sums the feature vectors of all nodes in a graph to produce a single vector. Because addition does not depend on the order of its terms, this operation is invariant to the order of the nodes in the graph. That property matters because graphs have no canonical node ordering.
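The operation itself is a column-wise sum. A NumPy sketch for a single graph (illustrating the idea, not Spektral's implementation), using a hypothetical `sum_pool` helper and checking that reordering the nodes leaves the result unchanged:

```python
import numpy as np

# Hypothetical node embeddings for one graph: 3 nodes, 2 features each
X = np.array([[1.0, 2.0],
              [3.0, 0.5],
              [0.0, 1.5]])

def sum_pool(X):
    """Hypothetical helper: add all node vectors into one graph vector."""
    return X.sum(axis=0)

perm = [2, 0, 1]                 # reorder the nodes
print(sum_pool(X))               # [4. 4.]
print(np.allclose(sum_pool(X), sum_pool(X[perm])))  # True
```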

It's worth noting that sum pooling is a very simple operation. Other options include mean pooling and max pooling, as well as more sophisticated methods such as attention-based pooling (for example, Spektral's GlobalAttentionPool) and SortPool. The choice of pooling operation can significantly affect the performance of a GNN, and the best choice depends on the specific task and data.
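One way to see why the choice matters: the hypothetical comparison below pools a small graph `g1` and a second graph `g2` that simply duplicates every node of `g1`. It assumes plain NumPy reductions over the node axis stand in for the pooling layers:

```python
import numpy as np

# Hypothetical mapping from pooling mode to a NumPy reduction
POOLERS = {"sum": np.sum, "mean": np.mean, "max": np.max}

def pool(X, mode):
    """Hypothetical helper: global pooling over the node axis (axis 0)."""
    return POOLERS[mode](X, axis=0)

# Hypothetical graphs: g2 contains every node of g1 twice
g1 = np.array([[1.0, 0.0],
               [0.0, 2.0]])
g2 = np.vstack([g1, g1])

for mode in ("sum", "mean", "max"):
    print(mode, pool(g1, mode), pool(g2, mode))
```

The sum for `g2` is double that for `g1`, while mean and max return the same vectors for both graphs. A model relying on mean or max pooling therefore cannot tell these two graphs apart from their pooled features alone, whereas sum pooling keeps track of graph size.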

What does i denote in x = self.pool([x, i])?

The i passed to the pooling layer represents the batch index of each node.

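When you're working with graph data in a batch setting (i.e., multiple graphs in a single batch), you need a way to indicate which nodes belong to which graphs. Unlike images or text data, graphs in a batch can have different sizes (i.e., different numbers of nodes and edges), so you can't simply stack them into a single tensor. In Spektral's disjoint mode, which DisjointLoader uses, the graphs are instead merged into one large graph made of disconnected components: node features are concatenated and the adjacency matrices are arranged block-diagonally.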

The batch index i is a vector that assigns each node to a specific graph in the batch. For example, if you have two graphs in a batch, the first with 3 nodes and the second with 2 nodes, your batch index i would be [0, 0, 0, 1, 1]. This indicates that the first three nodes belong to the first graph and the last two nodes belong to the second graph.
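A NumPy sketch of this disjoint batching, assuming dense adjacency matrices for readability (Spektral's DisjointLoader works with sparse ones). The hypothetical helper `disjoint_union` builds the stacked features, block-diagonal adjacency and batch index, and the hypothetical `segment_sum_pool` uses i to sum node features per graph, which is what a global sum pooling layer needs i for:

```python
import numpy as np

def disjoint_union(graphs):
    """Hypothetical helper: merge (X, A) pairs into one disconnected graph.

    Returns stacked node features, a block-diagonal adjacency matrix,
    and the batch index i mapping each node to its graph.
    """
    xs, idx = [], []
    n_total = sum(len(X) for X, _ in graphs)
    A_big = np.zeros((n_total, n_total))
    offset = 0
    for g, (X, A) in enumerate(graphs):
        n = len(X)
        A_big[offset:offset + n, offset:offset + n] = A  # no edges between graphs
        xs.append(X)
        idx.extend([g] * n)
        offset += n
    return np.vstack(xs), A_big, np.array(idx)

def segment_sum_pool(X, i):
    """Hypothetical helper: sum node features per graph using batch index i."""
    out = np.zeros((i.max() + 1, X.shape[1]))
    np.add.at(out, i, X)
    return out

# Hypothetical batch: a 3-node triangle and a 2-node graph, 1 feature per node
g1 = (np.array([[1.0], [2.0], [3.0]]), np.ones((3, 3)) - np.eye(3))
g2 = (np.array([[10.0], [20.0]]), np.array([[0.0, 1.0], [1.0, 0.0]]))

X, A, i = disjoint_union([g1, g2])
print(i)                                # [0 0 0 1 1]
print(segment_sum_pool(X, i).ravel())   # one sum per graph: 6 and 30
```

In TensorFlow, `tf.math.segment_sum(x, i)` computes the same per-graph sums when i is sorted, as it is here.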

In a follow-up article, we discuss different variants of GNNs and their implementations.

Written by Krishna Chaitanya

Krishna is a researcher specializing in deep learning, control theory, and machine learning to optimize data-driven models for off-road mobile robots.