Neural Graph Ordinary Differential Equations
This tutorial is adapted from SciMLSensitivity, GraphNeuralNetworks, and Lux.
Load the packages
# Load the packages: GNN/ODE stack, Lux layers, AD, optimisers, and the dataset.
using NeuralGraphPDE, DifferentialEquations
using Lux, NNlib, Optimisers, Zygote, Random
using GraphNeuralNetworks: mldataset2gnngraph
using ComponentArrays, OneHotArrays
using SciMLSensitivity
using Statistics: mean
using MLDatasets: Cora
using CUDA
CUDA.allowscalar(false)  # fail fast on accidental scalar indexing of GPU arrays
device = CUDA.functional() ? gpu : cpu  # use the GPU only when one is actually available
cpu (generic function with 2 methods)
Load data
# Load the Cora citation graph and move it to the chosen device.
dataset = Cora();
classes = dataset.metadata["classes"]
g = device(mldataset2gnngraph(dataset))  # single graph with node features/targets/masks
X = g.ndata.features
y = onehotbatch(g.ndata.targets, classes) # one-hot targets (a dense one-hot matrix is not optimal storage, but fine at this scale)
(; train_mask, val_mask, test_mask) = g.ndata
ytrain = y[:, train_mask]
7×140 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
⋅ ⋅ ⋅ 1 ⋅ ⋅ 1 ⋅ ⋅ ⋅ 1 1 ⋅ … ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ 1 1 1 1 1 1 1
⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
1 ⋅ ⋅ ⋅ 1 ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ … 1 1 1 ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
Model and data configuration
nin = size(X, 1)        # input feature dimension (per node)
nhidden = 16            # hidden width of the GCN layers / ODE state
nout = length(classes)  # number of target classes
epochs = 40             # number of full-batch training iterations
40
Define Neural ODE
"""
    NeuralODE{M, So, Se, T, K}

Container layer wrapping a Lux `model` as the right-hand side of an ODE.
Registered as a `Lux.AbstractExplicitContainerLayer` over `(:model,)` so that
parameter/state setup flows through to the wrapped model.
"""
struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, K} <:
       Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M       # Lux layer defining du/dt
    solver::So     # ODE solver instance, e.g. Tsit5()
    sensealg::Se   # sensitivity algorithm for differentiating through solve
    tspan::T       # integration interval
    kwargs::K      # extra keyword arguments forwarded to `solve`
end
"""
    NeuralODE(model; solver=Tsit5(), sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
              tspan=(0.0f0, 1.0f0), kwargs...)

Convenience constructor with sensible defaults: `Tsit5` solver and an
interpolating adjoint using Zygote vector-Jacobian products. Any remaining
`kwargs` (tolerances, saving options, …) are stored and forwarded to `solve`.
"""
function NeuralODE(model::Lux.AbstractExplicitLayer; solver=Tsit5(),
                   sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
                   tspan=(0.0f0, 1.0f0), kwargs...)
    return NeuralODE(model, solver, sensealg, tspan, kwargs)
end
"""
    (n::NeuralODE)(x, ps, st)

Integrate `du/dt = model(u, ps, st)` from `x` over `n.tspan` and return the
`ODESolution` together with the (unchanged) layer state `st`.
"""
function (n::NeuralODE)(x, ps, st)
    # RHS of the ODE: one forward pass of the wrapped model, discarding the
    # returned state. The original code reassigned the captured `st` inside
    # the closure (`u_, st = n.model(...)`), which boxes the variable — a
    # type-instability under Zygote — and makes the returned state depend on
    # how many times the solver happens to evaluate the RHS. Reading `st`
    # without rebinding fixes both issues.
    dudt(u, p, t) = first(n.model(u, p, st))
    prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps)
    return solve(prob, n.solver; sensealg=n.sensealg, n.kwargs...), st
end
# Convert an ODE solution to a plain feature matrix by dropping the time axis.
# With `save_everystep=false, save_start=false` only the final state is saved,
# so the stacked solution is (features × nodes × 1) and `dims=3` removes the
# singleton time dimension. NOTE(review): both methods assume exactly one
# saved time point — confirm against the solve kwargs used in `create_model`.
function diffeqsol_to_array(x::ODESolution{T, N, <:AbstractVector{<:CuArray}}) where {T, N}
    # GPU path: `gpu(x)` materializes the solution states as one CuArray.
    return dropdims(gpu(x); dims=3)
end
# CPU fallback: `Array(x)` stacks the saved states into a dense array.
diffeqsol_to_array(x::ODESolution) = dropdims(Array(x); dims=3)
diffeqsol_to_array (generic function with 2 methods)
Create and initialize the Neural Graph ODE layer
"""
    create_model()

Build the graph Neural ODE classifier (`GCNConv` encoder → Neural ODE over two
GCN layers → dense decoder), initialize parameters/state with a fixed seed,
attach the global graph `g` to the state, and move everything to `device`.
Returns `(model, ps, st)`.
"""
function create_model()
    # The ODE evolves the hidden node features; only the final state is saved.
    gnn_ode = NeuralODE(Chain(GCNConv(nhidden => nhidden, relu),
                              GCNConv(nhidden => nhidden, relu));
                        save_everystep=false, reltol=1e-3, abstol=1e-3,
                        save_start=false)
    model = Chain(GCNConv(nin => nhidden, relu), gnn_ode, diffeqsol_to_array,
                  Dense(nhidden, nout))
    rng = Random.default_rng()
    Random.seed!(rng, 0)  # deterministic initialization
    params, states = Lux.setup(rng, model)
    params = device(ComponentArray(params))   # flat parameter vector for Optimisers
    states = device(updategraph(states, g))   # store the graph inside the layer state
    return model, params, states
end
create_model (generic function with 1 method)
Define the loss function
# Cross-entropy on raw logits: log-softmax over classes (dims=1), then the
# mean negative log-likelihood across the batch.
function logitcrossentropy(ŷ, y)
    return -mean(sum(y .* logsoftmax(ŷ); dims=1))
end
# Forward pass followed by masked cross-entropy; `y` is already restricted to
# the masked columns, so only the predictions are sliced. Returns (loss, st).
function loss(x, y, mask, model, ps, st)
    logits, st = model(x, ps, st)
    masked_logits = logits[:, mask]
    return logitcrossentropy(masked_logits, y), st
end
# Evaluate cross-entropy loss and accuracy (in percent) on the nodes selected
# by `mask`; state updates from the forward pass are discarded.
function eval_loss_accuracy(X, y, mask, model, ps, st)
    ŷ, _ = model(X, ps, st)
    logits = ŷ[:, mask]
    labels = y[:, mask]
    l = logitcrossentropy(logits, labels)
    acc = mean(onecold(logits) .== onecold(labels)) * 100
    return (loss=round(l; digits=4), acc=round(acc; digits=2))
end
eval_loss_accuracy (generic function with 1 method)
Train the model
"""
    train()

Full-batch training loop: build the model, then run `epochs` Adam steps on the
training nodes, printing validation loss/accuracy after every step.
"""
function train()
    model, ps, st = create_model()
    # Optimizer state over the flat ComponentArray of parameters.
    opt_state = Optimisers.setup(Optimisers.Adam(0.01f0), ps)
    # Training loop
    for epoch in 1:epochs
        # Differentiate only w.r.t. `ps`; the returned state is threaded through.
        (lval, st), back = pullback(θ -> loss(X, ytrain, train_mask, model, θ, st), ps)
        # Seed the pullback with (1, nothing): no gradient flows through `st`.
        grads = first(back((one(lval), nothing)))
        opt_state, ps = Optimisers.update(opt_state, ps, grads)
        @show eval_loss_accuracy(X, y, val_mask, model, ps, st)
    end
end
train()  # run full-batch training; per-epoch validation metrics are printed below
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.8719f0, acc = 38.0)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.7944f0, acc = 39.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.6933f0, acc = 41.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.5726f0, acc = 48.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.447f0, acc = 57.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.3408f0, acc = 62.8)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.2522f0, acc = 68.8)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.1682f0, acc = 70.0)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.0832f0, acc = 72.8)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.0095f0, acc = 74.8)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9572f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9304f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9238f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9119f0, acc = 74.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.8986f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9082f0, acc = 75.8)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9463f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9797f0, acc = 75.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 0.9941f0, acc = 75.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.0208f0, acc = 75.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.0752f0, acc = 75.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.1417f0, acc = 75.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.1981f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.2332f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.2646f0, acc = 75.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.311f0, acc = 74.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.3783f0, acc = 74.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.4583f0, acc = 74.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.5366f0, acc = 74.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.5999f0, acc = 74.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.6421f0, acc = 73.8)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.6685f0, acc = 74.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.6907f0, acc = 74.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.7205f0, acc = 74.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.7648f0, acc = 74.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.8221f0, acc = 74.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.8847f0, acc = 74.4)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.9453f0, acc = 74.6)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 1.9962f0, acc = 74.2)
eval_loss_accuracy(X, y, val_mask, model, ps, st) = (loss = 2.0326f0, acc = 74.2)