Neural Graph Ordinary Differential Equations

This tutorial is adapted from the corresponding tutorials in SciMLSensitivity, GraphNeuralNetworks.jl, and Lux.
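
Conceptually, a graph neural ODE treats the hidden node features H(t) as the state of an ODE whose right-hand side is a graph neural network. In the notation used below (a sketch of the formulation, not a quote from the sources above):

    H(0) = embed(X),    dH(t)/dt = GNN(H(t); θ),    ŷ = Dense(H(1))

Backpropagation through the solve is handled by an adjoint sensitivity method from SciMLSensitivity.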

Load the packages

using NeuralGraphPDE, DifferentialEquations
using Lux, NNlib, Optimisers, Zygote, Random
using GraphNeuralNetworks: mldataset2gnngraph
using ComponentArrays, OneHotArrays
using SciMLSensitivity
using Statistics: mean
using MLDatasets: Cora
using CUDA
CUDA.allowscalar(false)
device = CUDA.functional() ? gpu : cpu  # fall back to the CPU when no GPU is available
cpu (generic function with 2 methods)

Load data

dataset = Cora();
classes = dataset.metadata["classes"]
g = device(mldataset2gnngraph(dataset))
X = g.ndata.features
y = onehotbatch(g.ndata.targets, classes) # a dense one-hot matrix is not optimal memory-wise, but it keeps the example simple
(; train_mask, val_mask, test_mask) = g.ndata
ytrain = y[:, train_mask]
7×140 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  ⋅  ⋅  1  ⋅  ⋅  1  ⋅  ⋅  ⋅  1  1  ⋅  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  1  ⋅  1  1  1  1  1  1  1
 ⋅  ⋅  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅  1  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 1  ⋅  ⋅  ⋅  1  ⋅  ⋅  1  1  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  1  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  1     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  1  1  1  ⋅  1  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
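
As a quick aside (a toy example, not part of the Cora pipeline), onehotbatch encodes labels as columns of a Boolean matrix and onecold inverts the encoding:

using OneHotArrays
oh = onehotbatch([:b, :a, :c], [:a, :b, :c])  # 3×3 OneHotMatrix
onecold(oh, [:a, :b, :c])                     # [:b, :a, :c]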

Model and data configuration

nin = size(X, 1)
nhidden = 16
nout = length(classes)
epochs = 40
40

Define Neural ODE

struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, K} <:
       Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M      # the network defining the dynamics du/dt = model(u)
    solver::So    # ODE solver, e.g. Tsit5()
    sensealg::Se  # sensitivity algorithm used for the adjoint/backward pass
    tspan::T      # integration time span
    kwargs::K     # extra keyword arguments forwarded to solve
end

function NeuralODE(model::Lux.AbstractExplicitLayer; solver=Tsit5(),
                   sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
                   tspan=(0.0f0, 1.0f0), kwargs...)
    return NeuralODE(model, solver, sensealg, tspan, kwargs)
end

function (n::NeuralODE)(x, ps, st)
    # The closure captures `st`, so state updates from the wrapped model
    # propagate back out of the solve.
    function dudt(u, p, t)
        u_, st = n.model(u, p, st)
        return u_
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; sensealg=n.sensealg, n.kwargs...), st
end

# Drop the trailing time dimension: with save_start=false and
# save_everystep=false (set below), the solution holds a single time point.
function diffeqsol_to_array(x::ODESolution{T, N, <:AbstractVector{<:CuArray}}) where {T, N}
    return dropdims(gpu(x); dims=3)
end
diffeqsol_to_array(x::ODESolution) = dropdims(Array(x); dims=3)
diffeqsol_to_array (generic function with 2 methods)
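
Why dims=3: since only the final state is saved, converting the solution to an array leaves a trailing time dimension of length 1. A minimal shape sketch with made-up but representative sizes:

sol_arr = rand(Float32, 16, 2708, 1)  # nhidden × num_nodes × num_saved_steps
h = dropdims(sol_arr; dims=3)
size(h)                               # (16, 2708), ready for the Dense head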

Create and initialize the Neural Graph ODE layer

function create_model()
    node_chain = Chain(GCNConv(nhidden => nhidden, relu), GCNConv(nhidden => nhidden, relu))

    node = NeuralODE(node_chain; save_everystep=false, reltol=1e-3, abstol=1e-3,
                     save_start=false)

    model = Chain(GCNConv(nin => nhidden, relu), node, diffeqsol_to_array,
                  Dense(nhidden, nout))

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    ps = ComponentArray(ps) |> device
    st = updategraph(st, g) |> device  # store the graph in the layer states

    return model, ps, st
end
create_model (generic function with 1 method)
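
A quick forward-pass sanity check (a sketch using the objects defined above; for Cora the output should be nout × num_nodes, i.e. 7 × 2708):

model, ps, st = create_model()
ŷ, _ = model(X, ps, st)
size(ŷ)  # expected (7, 2708)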

Define the loss function

logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ); dims=1))

function loss(x, y, mask, model, ps, st)
    ŷ, st = model(x, ps, st)
    # `y` is already restricted to the mask (e.g. `ytrain`), so only `ŷ` is indexed
    return logitcrossentropy(ŷ[:, mask], y), st
end
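
A toy sanity check with hypothetical values: confident, correct logits drive logitcrossentropy toward zero.

ŷ_toy = Float32[10 -10; -10 10]  # logits: 2 classes × 2 samples
y_toy = Float32[1 0; 0 1]        # matching one-hot targets
logitcrossentropy(ŷ_toy, y_toy)  # ≈ 0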

function eval_loss_accuracy(X, y, mask, model, ps, st)
    ŷ, _ = model(X, ps, st)
    l = logitcrossentropy(ŷ[:, mask], y[:, mask])
    acc = mean(onecold(ŷ[:, mask]) .== onecold(y[:, mask]))
    return (loss=round(l; digits=4), acc=round(acc * 100; digits=2))
end
eval_loss_accuracy (generic function with 1 method)
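
The same helper works for any split; for example, test-set performance can be checked with the test mask (assuming model, ps, and st are in scope):

eval_loss_accuracy(X, y, test_mask, model, ps, st)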

Train the model

function train()
    model, ps, st = create_model()

    # Optimizer
    opt = Optimisers.Adam(0.01f0)
    st_opt = Optimisers.setup(opt, ps)

    # Training Loop
    for epoch in 1:epochs
        # Differentiate the loss w.r.t. the parameters only; the pullback seed
        # `(one(l), nothing)` ignores the returned state.
        (l, st), back = pullback(p -> loss(X, ytrain, train_mask, model, p, st), ps)
        gs = back((one(l), nothing))[1]
        st_opt, ps = Optimisers.update(st_opt, ps, gs)
        @show eval_loss_accuracy(X, y, val_mask, model, ps, st)
    end
end

train()
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.8719f0, acc = 38.0)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.7944f0, acc = 39.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.6933f0, acc = 41.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.5726f0, acc = 48.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.447f0, acc = 57.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.3408f0, acc = 62.8)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.2522f0, acc = 68.8)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.1682f0, acc = 70.0)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.0832f0, acc = 72.8)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.0095f0, acc = 74.8)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9572f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9304f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9238f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9119f0, acc = 74.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.8986f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9082f0, acc = 75.8)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9463f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9797f0, acc = 75.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9941f0, acc = 75.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.0208f0, acc = 75.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.0752f0, acc = 75.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.1417f0, acc = 75.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.1981f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.2332f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.2646f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.311f0, acc = 74.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.3783f0, acc = 74.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.4583f0, acc = 74.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.5366f0, acc = 74.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.5999f0, acc = 74.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.6421f0, acc = 73.8)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.6685f0, acc = 74.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.6907f0, acc = 74.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.7205f0, acc = 74.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.7648f0, acc = 74.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.8221f0, acc = 74.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.8847f0, acc = 74.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.9453f0, acc = 74.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.9962f0, acc = 74.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 2.0326f0, acc = 74.2)
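
Validation accuracy peaks around 75% while the validation loss starts rising after roughly epoch 15, a sign of overfitting. A natural follow-up (a sketch, assuming train is modified to end with `return model, ps, st`) is to evaluate on the held-out test split:

model, ps, st = train()
@show eval_loss_accuracy(X, y, test_mask, model, ps, st)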