# NeuralGraphPDE
Documentation for NeuralGraphPDE.
## Features
- Layers and graphs are coupled and decoupled at the same time: you can bind a graph to a layer at initialization, but the graph is stored in `st`, not in the layer. They are decoupled in the sense that you can easily update or change the graph by changing `st`:
```julia
using NeuralGraphPDE, Random, Lux

g = rand_graph(5, 4; bidirected=false)
x = randn(3, g.num_nodes)

# create the layer, binding the initial graph
l = GCNConv(3 => 5; initialgraph=g)

# set up the layer
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, l)

# forward pass
y, st = l(x, ps, st) # you don't need to feed a graph explicitly

# change the graph
new_g = rand_graph(5, 7; bidirected=false)
st = updategraph(st, new_g)
y, st = l(x, ps, st)
```
```
([-1.2892177734821657 -1.2709222282216477 … -0.30403150986816907 -0.051113133960325796; -0.006378785979397673 -0.46649378478968345 … 0.3184408274996586 -0.6528909625035688; … ; -0.7239448136837706 -0.8694849206614506 … 0.09199019941242034 -0.7866858541325896; -1.114598683308912 -0.8549547591966108 … -0.29924792936262 -0.16143250287349156], (graph = GNNGraph(5, 7),))
```
- You can omit the keyword argument `initialgraph` at initialization and later call `updategraph` on `st` to put the graph in it. All GNN layers work smoothly with other layers defined by `Lux`:
```julia
g = rand_graph(5, 4; bidirected=false)
x = randn(3, g.num_nodes)

model = Chain(Dense(3 => 5), GCNConv(5 => 5), GCNConv(5 => 3)) # you don't need to use `g` for initialization

# set up the model
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model) # the default graph is empty
st = updategraph(st, g) # put the graph in st

# forward pass
y, st = model(x, ps, st)
```
```
([1.8060158815335092 -1.6607651913254147 … 2.390543114285151 0.9390967263933919; 1.2397189080259086 -0.23349064693339155 … 1.4497739233496794 0.5905613584971816; 0.9476761631383458 0.12809287458085628 … 1.0537667430866087 0.45529803001843594], (layer_1 = NamedTuple(), layer_2 = (graph = GNNGraph(5, 4),), layer_3 = (graph = GNNGraph(5, 4),)))
```
- A unified interface for graph-level tasks. As pointed out here, it is hard for GNNs to work with other neural networks when the input graph changes. This is not an issue here: you always have the unified interface `y, st = model(x, ps, st)`. There are several benefits to doing so (see the sketch after this list):
    - Each layer can take in different graphs.
    - You can modify the graph inside a layer and return it.
    - Multigraphs: a layer can take in any number of graphs in `st`.
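  A minimal sketch of the first point, mine rather than the package's: it relies only on the `st` layout visible in the output above, where each layer holds its own `graph` entry, so different layers can be given different graphs by rebuilding `st` directly:

```julia
using NeuralGraphPDE, Random, Lux

g1 = rand_graph(5, 4; bidirected=false)
g2 = rand_graph(5, 6; bidirected=false)
x = randn(3, 5)

model = Chain(GCNConv(3 => 5), GCNConv(5 => 3))
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)

# every layer keeps its graph in its own entry of `st`,
# so the two layers can operate on different graphs
st = (layer_1 = (graph = g1,), layer_2 = (graph = g2,))
y, st = model(x, ps, st)
```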
- Trainable node embeddings and nontrainable features are stored separately: the trainable embeddings are passed as `x`, while the nontrainable features travel with the graph in `st.graph`.
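  As an illustration (my own sketch, assuming the `ndata` node-data field of GraphNeuralNetworks.jl's `GNNGraph`), static per-node features can be attached to the graph while the trainable embeddings go through `x`:

```julia
using NeuralGraphPDE, Random, Lux

# static (nontrainable) per-node features stored on the graph itself
g = rand_graph(5, 4; bidirected=false, ndata=(; coords=rand(2, 5)))
x = randn(3, g.num_nodes) # trainable node embeddings, passed explicitly

l = GCNConv(3 => 5; initialgraph=g)
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, l)

y, st = l(x, ps, st)
st.graph.ndata.coords # the static features remain available in `st`
```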
## Limitations
- We assume all graphs have the same structure.
- The input must be a matrix or a named tuple of matrices.
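To make the input contract concrete: features follow the `(num_features, num_nodes)` layout used in the examples above, and a named tuple groups several such matrices. The field names below are hypothetical, and which layers accept named-tuple inputs depends on the layer:

```julia
using NeuralGraphPDE

g = rand_graph(5, 4; bidirected=false)

# a single feature matrix of size (num_features, num_nodes)
x = randn(3, g.num_nodes)

# or several feature fields as a named tuple of matrices
# (the field names `u` and `v` are hypothetical)
x = (u = randn(3, g.num_nodes), v = randn(2, g.num_nodes))
```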