Neural Graph Partial Differential Equations

This tutorial is adapted from the paper "Learning Continuous-Time PDEs from Sparse Data with Graph Neural Networks". We will use VMHConv to learn the dynamics of the convection-diffusion equation defined as

\[\frac{\partial u(x, y, t)}{\partial t}=0.25 \nabla^{2} u(x, y, t)-\mathbf{v} \cdot \nabla u(x, y, t).\]

Here $\mathbf{v}$ is the velocity field. Specifically, we will learn the operator mapping the initial condition to the solution on the given temporal and spatial domain.
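In other words, the graph network $f_\theta$ built below parameterizes the right-hand side of an ODE on the graph nodes, and integrating it in time yields the solution operator (notation introduced here for exposition only):

\[\frac{d u}{d t}=f_{\theta}(u), \qquad \mathcal{G}_{\theta}: u(\cdot, \cdot, 0) \mapsto u(\cdot, \cdot, t).\]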

Load the packages

using DataDeps, MLUtils, Fetch
using NeuralGraphPDE, Lux, Optimisers, Random
using CUDA, JLD2
using SciMLSensitivity, DifferentialEquations
using Zygote
using Flux.Losses: mse
import Lux: initialparameters, initialstates
using NNlib
using DiffEqFlux: NeuralODE

Load data

function register_convdiff()
    return register(DataDep("Convection_Diffusion_Equation",
                            """
                            Convection-Diffusion equation dataset from
                            [Learning continuous-time PDEs from sparse data with graph neural networks](https://github.com/yakovlev31/graphpdes_experiments)
                            """,
                            "https://drive.google.com/file/d/1oyatNeLizoO5co2ZVXIwZmWjJ046E9j6/view?usp=sharing";
                            fetch_method=gdownload)) # gdownload (from Fetch.jl) handles Google Drive links
end

register_convdiff()

function get_data()
    data = load(joinpath(datadep"Convection_Diffusion_Equation", "convdiff_n3000.jld2"))

    train_data = (data["gs_train"], data["u_train"])
    test_data = (data["gs_test"], data["u_test"])
    return train_data, test_data, data["dt_train"], data["dt_test"], data["tspan_train"],
           data["tspan_test"]
end

train_data, test_data, dt_train, dt_test, tspan_train, tspan_test = get_data()
(output truncated: each split is a vector of GNNGraphs with 2912 nodes and 17472 edges plus the corresponding solution arrays; dt_train = dt_test = 0.01f0, tspan_train = (0.0f0, 0.2f0), tspan_test = (0.0f0, 0.6f0))

The training data contains 24 simulations on the time interval $[0, 0.2]$. Each simulation is observed on a different 2D grid of 3000 scattered points. Neighbors for each node were selected by applying Delaunay triangulation to the measurement positions: two nodes are neighbors if they lie on the same edge of at least one triangle, as sketched below.
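The dataset ships with the graphs prebuilt, but the construction is straightforward to reproduce. A minimal sketch, assuming the triangle list (3-tuples of node indices) has already been obtained from a Delaunay library of your choice:

# Collect the unique undirected edges of a Delaunay triangulation,
# then expand them into the directed (source, target) lists used for graphs.
function edges_from_triangles(triangles::Vector{NTuple{3, Int}})
    edges = Set{NTuple{2, Int}}()
    for (i, j, k) in triangles
        # every side of every triangle connects a pair of neighbors
        for (a, b) in ((i, j), (j, k), (k, i))
            push!(edges, minmax(a, b))  # store each undirected edge once
        end
    end
    s, t = Int[], Int[]
    for (a, b) in edges
        push!(s, a, b)
        push!(t, b, a)
    end
    return s, t
end

Each undirected edge appears in both directions, which is why the graphs above report an even directed edge count (17472).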

Utility functions

# Convert an ODESolution into a plain array so it can flow through the rest of the Chain.
function diffeqsol_to_array(x::ODESolution{T, N, <:AbstractVector{<:CuArray}}) where {T, N}
    return gpu(x)  # keep GPU solutions on the device
end

diffeqsol_to_array(x::ODESolution) = Array(x)  # CPU fallback
diffeqsol_to_array (generic function with 2 methods)
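To see what the CPU method returns, here is a tiny self-contained example (the scalar ODE and values are illustrative only): Array on an ODESolution stacks the saved states along a new trailing time dimension, which is exactly the layout the training targets use.

prob = ODEProblem((u, p, t) -> -u, Float32[1.0], (0.0f0, 1.0f0))
sol = solve(prob, Tsit5(); saveat=0.5f0)
diffeqsol_to_array(sol)  # 1×3 Matrix{Float32}: one state, three saved times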

Model

We will use only one message passing layer. First, we overload initialparameters and initialstates so that Lux.setup can recurse into a NeuralODE; the layer itself is built from a message network ϕ and an update network γ:

initialparameters(rng::AbstractRNG, node::NeuralODE) = initialparameters(rng, node.model)
initialstates(rng::AbstractRNG, node::NeuralODE) = initialstates(rng, node.model)

act = tanh
nhidden = 60
nout = 40

# message network ϕ: 4 input features per edge
ϕ = Chain(Dense(4 => nhidden, act), Dense(nhidden => nhidden, act),
          Dense(nhidden => nhidden, act), Dense(nhidden => nout))

# update network γ: the aggregated nout-dimensional message plus the scalar node state
γ = Chain(Dense(nout + 1 => nhidden, act), Dense(nhidden => nhidden, act),
          Dense(nhidden => nhidden, act), Dense(nhidden => 1))

gnn = VMHConv(ϕ, γ)

node = NeuralODE(gnn, tspan_train, Tsit5(); saveat=dt_train, reltol=1e-9, abstol=1e-3)

model = Chain(node, diffeqsol_to_array)
Chain(
    layer_1 = NeuralODE(),              # 19_961 parameters
    layer_2 = WrappedFunction(diffeqsol_to_array),
)         # Total: 19_961 parameters,
          #        plus 0 states, summarysize 784 bytes.
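Reading the sizes off the networks above (a hedged interpretation following the paper's message passing scheme, not the layer's documentation): ϕ builds a message per edge from the node state $u_i$ (1 feature), the neighbor difference $u_j - u_i$ (1 feature), and the relative position $x_j - x_i$ (2 features), hence 4 inputs; γ maps the aggregated nout-dimensional message together with $u_i$ (hence nout + 1 inputs) to the scalar time derivative:

\[\dot{u}_{i}=\gamma\Big(u_{i},\; \square_{j \in \mathcal{N}(i)}\, \phi\big(u_{i},\, u_{j}-u_{i},\, x_{j}-x_{i}\big)\Big),\]

where $\square$ denotes a permutation-invariant aggregation over the neighbors $\mathcal{N}(i)$.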

Optimiser

Since we only have 24 samples, we train full-batch and use the Rprop optimiser. Rprop adapts a per-parameter step size from the sign of the gradient, which makes it a good fit for full-batch training; its arguments are the initial step size, the decrease/increase factors, and the bounds on the step size.

opt = Rprop(1.0f-6, (5.0f-1, 1.2f0), (1.0f-8, 10.0f0))
Optimisers.Rprop{Float32}(1.0f-6, (0.5f0, 1.2f0), (1.0f-8, 10.0f0))

Loss function

We will use the mean squared error (mse) loss.

function loss(x, y, ps, st)
    ŷ, st = model(x, ps, st)  # ŷ has the same (1, nodes, timesteps) layout as y
    return mse(ŷ, y)
end
loss (generic function with 1 method)
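For reference, mse from Flux.Losses is the mean of the elementwise squared differences. A quick standalone check (the values are illustrative):

using Statistics: mean

ŷ, y = Float32[1, 2], Float32[1.5, 2.5]
mse(ŷ, y) ≈ mean(abs2, ŷ .- y)  # true; both evaluate to 0.25f0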

Train the model

The solution data has the shape (space_points, time_points, num_samples). We first permute the last two dimensions, giving (space_points, num_samples, time_points), and then flatten the first two dimensions into (1, space_points * num_samples, time_points). The initial condition (the first time slice) is the input to the model, and the model output has the same (1, space_points * num_samples, time_points) layout as the target.
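A minimal shape check with made-up sizes (5 grid points, 3 timesteps, 2 samples; the random array is purely illustrative):

u = rand(Float32, 5, 3, 2)      # (space_points, time_points, num_samples)
u0 = reshape(u[:, 1, :], 1, :)  # initial condition: (1, 5 * 2)
ut = permutedims(u, (1, 3, 2))  # (space_points, num_samples, time_points)
ut = reshape(ut, 1, 5 * 2, :)   # target: (1, 10, 3)
size(u0), size(ut)              # ((1, 10), (1, 10, 3))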

mydevice = CUDA.functional() ? gpu : cpu
train_loader = DataLoader(train_data; batchsize=24, shuffle=true)

rng = Random.default_rng()
Random.seed!(rng, 0)

function train()
    ps, st = Lux.setup(rng, model)
    ps = Lux.ComponentArray(ps) |> mydevice
    st = st |> mydevice
    st_opt = Optimisers.setup(opt, ps)

    for i in 1:200
        for (g, u) in train_loader
            g = g |> mydevice
            st = updategraph(st, g)  # store the current batch's graph in the layer state
            u = u |> mydevice
            u0 = reshape(u[:, 1, :], 1, :)       # initial condition: (1, num_nodes)
            ut = permutedims(u, (1, 3, 2))
            ut = reshape(ut, 1, g.num_nodes, :)  # target: (1, num_nodes, timesteps)

            l, back = pullback(p -> loss(u0, ut, p, st), ps)
            (i % 10 == 0) && @info "epoch $i | train loss = $l"
            gs = back(one(l))[1]
            st_opt, ps = Optimisers.update(st_opt, ps, gs)
        end
    end
    return ps, st  # hand back the trained parameters and states
end

ps, st = train()

Expected output

[ Info: epoch 10 | train loss = 0.027209122
[ Info: epoch 20 | train loss = 0.026874812
[ Info: epoch 30 | train loss = 0.025392009
[ Info: epoch 40 | train loss = 0.023239506
[ Info: epoch 50 | train loss = 0.010599495
[ Info: epoch 60 | train loss = 0.010421633
[ Info: epoch 70 | train loss = 0.0098072495
[ Info: epoch 80 | train loss = 0.008936066
[ Info: epoch 90 | train loss = 0.0063929264
[ Info: epoch 100 | train loss = 0.004207685
[ Info: epoch 110 | train loss = 0.0026181203
[ Info: epoch 120 | train loss = 0.0023022622
[ Info: epoch 130 | train loss = 0.0019534715
[ Info: epoch 140 | train loss = 0.0017379699
[ Info: epoch 150 | train loss = 0.0015728847
[ Info: epoch 160 | train loss = 0.0013444767
[ Info: epoch 170 | train loss = 0.0012353633
[ Info: epoch 180 | train loss = 0.0011409305
[ Info: epoch 190 | train loss = 0.0010424983
[ Info: epoch 200 | train loss = 0.0009809926
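
Evaluate on the test set

The test split loaded earlier spans the longer interval $[0, 0.6]$, so it probes extrapolation beyond the training horizon. A hedged sketch of how one could evaluate the trained parameters: it mirrors the training loop above, and evaluate and its batch size are choices made here, not part of the original tutorial (it assumes train() returns ps as written above).

function evaluate(ps)
    # rebuild the NeuralODE on the test time span
    test_node = NeuralODE(gnn, tspan_test, Tsit5(); saveat=dt_test, reltol=1e-9, abstol=1e-3)
    test_model = Chain(test_node, diffeqsol_to_array)
    _, st = Lux.setup(rng, test_model)
    st = st |> mydevice
    total, n = 0.0f0, 0
    for (g, u) in DataLoader(test_data; batchsize=24, shuffle=false)
        g = g |> mydevice
        st = updategraph(st, g)
        u = u |> mydevice
        u0 = reshape(u[:, 1, :], 1, :)
        ut = reshape(permutedims(u, (1, 3, 2)), 1, g.num_nodes, :)
        ŷ, _ = test_model(u0, ps, st)
        total += mse(ŷ, ut)
        n += 1
    end
    return total / n  # mean test MSE over batches
end

evaluate(ps)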