Neural Graph Partial Differential Equations
This tutorial is adapted from the paper Learning Continuous-Time PDEs from Sparse Data with Graph Neural Networks. We will use VMHConv to learn the dynamics of the convection-diffusion equation, defined as
\[\frac{\partial u(x, y, t)}{\partial t}=0.25 \nabla^{2} u(x, y, t)-\mathbf{v} \cdot \nabla u(x, y, t).\]
Specifically, we will learn the operator from the initial condition to the solution on the given temporal and spatial domain.
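Before learning the dynamics from data, it can help to see what the ground-truth right-hand side looks like when discretized. Below is a minimal central-difference sketch on a uniform periodic grid; it is illustrative only (the dataset below lives on irregular meshes), and the constant velocity v = (0.1, 0.1) is an assumption made up for this example:
function convdiff_rhs(u::AbstractMatrix, h; ν=0.25, v=(0.1, 0.1))
    # circularly shift u by d along dimension dim (periodic boundary)
    shift(A, d, dim) = circshift(A, ntuple(i -> i == dim ? d : 0, 2))
    ux = (shift(u, -1, 1) .- shift(u, 1, 1)) ./ (2h)    # ∂u/∂x, central difference
    uy = (shift(u, -1, 2) .- shift(u, 1, 2)) ./ (2h)    # ∂u/∂y
    Δu = (shift(u, -1, 1) .+ shift(u, 1, 1) .+ shift(u, -1, 2) .+
          shift(u, 1, 2) .- 4 .* u) ./ h^2              # 5-point Laplacian ∇²u
    return ν .* Δu .- v[1] .* ux .- v[2] .* uy
end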
Load the packages
using DataDeps, MLUtils, Fetch
using NeuralGraphPDE, Lux, Optimisers, Random
using CUDA, JLD2
using SciMLSensitivity, DifferentialEquations
using Zygote
using Flux.Losses: mse
import Lux: initialparameters, initialstates
using NNlib
using DiffEqFlux: NeuralODE
Load data
function register_convdiff()
return register(DataDep("Convection_Diffusion_Equation",
"""
Convection-Diffusion equation dataset from
[Learning continuous-time PDEs from sparse data with graph neural networks](https://github.com/yakovlev31/graphpdes_experiments)
""",
"https://drive.google.com/file/d/1oyatNeLizoO5co2ZVXIwZmWjJ046E9j6/view?usp=sharing";
fetch_method=gdownload))
end
register_convdiff()
function get_data()
data = load(joinpath(datadep"Convection_Diffusion_Equation", "convdiff_n3000.jld2"))
train_data = (data["gs_train"], data["u_train"])
test_data = (data["gs_test"], data["u_test"])
return train_data, test_data, data["dt_train"], data["dt_test"], data["tspan_train"],
data["tspan_test"]
end
train_data, test_data, dt_train, dt_test, tspan_train, tspan_test = get_data()
((GNNGraph[GNNGraph(2912, 17472), GNNGraph(2912, 17472), …], Float32[…]), (GNNGraph[GNNGraph(2912, 17472), GNNGraph(2912, 17472), …], Float32[…]), 0.01f0, 0.01f0, (0.0f0, 0.2f0), (0.0f0, 0.6f0))
The training data contains 24 simulations on the time interval $[0, 0.2]$. Each simulation is observed on a different 2D grid with 3000 measurement points. Neighbors for each node were selected by applying Delaunay triangulation to the measurement positions; two nodes are neighbors if they share an edge of at least one triangle (see the sketch below).
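To make the neighbor rule concrete, here is a minimal sketch of turning a triangle list into a graph edge list. The triangulation itself would come from a Delaunay package and is assumed given; triangles_to_edges is a hypothetical helper, not part of the dataset code:
function triangles_to_edges(triangles::Vector{NTuple{3, Int}})
    s, t = Int[], Int[]
    seen = Set{NTuple{2, Int}}()
    for (a, b, c) in triangles
        for (i, j) in ((a, b), (b, c), (c, a))
            e = minmax(i, j)
            e in seen && continue                     # skip edges shared by two triangles
            push!(seen, e)
            append!(s, (i, j)); append!(t, (j, i))    # both directions: undirected graph
        end
    end
    return s, t
end
The resulting (s, t) pair is the edge-index form from which a GNNGraph can be constructed.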
Utility functions
function diffeqsol_to_array(x::ODESolution{T, N, <:AbstractVector{<:CuArray}}) where {T, N}
    # keep GPU solutions on the device
    return gpu(x)
end
# on the CPU, stack the saved states into an array with a trailing time dimension
diffeqsol_to_array(x::ODESolution) = Array(x)
diffeqsol_to_array (generic function with 2 methods)
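To see why the converted solution gains a trailing time dimension, here is a tiny self-contained example (the decay ODE is made up for illustration): a state of size (1, 3) saved at five time points yields a (1, 3, 5) array.
prob = ODEProblem((u, p, t) -> -u, ones(1, 3), (0.0, 1.0))
sol = solve(prob, Tsit5(); saveat=0.25)           # saved at t = 0.0, 0.25, …, 1.0
@assert size(Array(sol)) == (1, 3, 5)             # states stacked along a new last axis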
Model
We will use only one message passing layer. The layer has the following structure: the message network ϕ acts edgewise on four local inputs (matching the scalar states of the two endpoint nodes plus the 2D positional offset between them), the resulting nout-dimensional messages are aggregated over each node's neighborhood, and the update network γ maps the node's own state together with the aggregated message (nout + 1 inputs) to the time derivative of the state:
# Delegate parameter/state initialization of a NeuralODE to its wrapped model
initialparameters(rng::AbstractRNG, node::NeuralODE) = initialparameters(rng, node.model)
initialstates(rng::AbstractRNG, node::NeuralODE) = initialstates(rng, node.model)
act = tanh
nhidden = 60
nout = 40
ϕ = Chain(Dense(4 => nhidden, act), Dense(nhidden => nhidden, act),
Dense(nhidden => nhidden, act), Dense(nhidden => nout))
γ = Chain(Dense(nout + 1 => nhidden, act), Dense(nhidden => nhidden, act),
Dense(nhidden => nhidden, act), Dense(nhidden => 1))
gnn = VMHConv(ϕ, γ)
node = NeuralODE(gnn, tspan_train, Tsit5(); saveat=dt_train, reltol=1e-9, abstol=1e-3)
model = Chain(node, diffeqsol_to_array)
Chain(
layer_1 = NeuralODE(), # 19_961 parameters
layer_2 = WrappedFunction(diffeqsol_to_array),
) # Total: 19_961 parameters,
# plus 0 states, summarysize 784 bytes.
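As a sanity check on the reported parameter count: ϕ has 4·60+60 + 60·60+60 + 60·60+60 + 60·40+40 = 10,060 parameters, γ has 41·60+60 + 60·60+60 + 60·60+60 + 60·1+1 = 9,901, and 10,060 + 9,901 = 19,961, matching the display above.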
Optimiser
Since we only have 24 samples and train on full batches, we will use the Rprop optimiser, which adapts a per-parameter step size from the sign of the gradient.
opt = Rprop(1.0f-6, (5.0f-1, 1.2f0), (1.0f-8, 10.0f0))
Optimisers.Rprop{Float32}(1.0f-6, (0.5f0, 1.2f0), (1.0f-8, 10.0f0))
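As a quick, self-contained illustration of the Optimisers.jl update loop we will use below (the toy quadratic objective and the local names here are made up for the example):
let θ = Float32[1.0, -2.0]                  # toy parameters
    state = Optimisers.setup(opt, θ)
    for _ in 1:100
        g = 2 .* θ                          # gradient of sum(abs2, θ)
        state, θ = Optimisers.update(state, θ, g)
    end
    θ                                       # driven towards zero
end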
Loss function
We will use the mse (mean squared error) loss function.
function loss(x, y, ps, st)
    # roll the model out from the initial condition x and compare with the target y
    ŷ, st = model(x, ps, st)
    return mse(ŷ, y)
end
loss (generic function with 1 method)
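For reference, mse with its default mean aggregation is equivalent to this one-liner (mse_manual is a hypothetical helper, shown only for clarity):
mse_manual(ŷ, y) = sum(abs2, ŷ .- y) / length(y)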
Train the model
The solution data has the shape (space_points, time_points, num_samples). We will first permute the last two dimensions, resulting in the shape (space_points, num_samples, time_points). Then we flatten the first two dimensions into (1, space_points * num_samples, time_points) and use the initial condition as the input to the model. The output of the model will likewise be of size (1, space_points * num_samples, time_points), as the shape check below illustrates.
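Here is a quick sanity check of these reshapes on a dummy array with 3 space points, 2 time points, and 2 samples (the values are arbitrary):
u = reshape(collect(Float32, 1:12), 3, 2, 2)        # (space, time, samples)
u0 = reshape(u[:, 1, :], 1, :)                      # initial condition, (1, 3 * 2)
ut = reshape(permutedims(u, (1, 3, 2)), 1, 6, :)    # target, (1, 3 * 2, 2)
@assert size(u0) == (1, 6) && size(ut) == (1, 6, 2)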
mydevice = CUDA.functional() ? gpu : cpu
train_loader = DataLoader(train_data; batchsize=24, shuffle=true)
rng = Random.default_rng()
Random.seed!(rng, 0)
function train()
    ps, st = Lux.setup(rng, model)
    ps = Lux.ComponentArray(ps) |> mydevice
    st = st |> mydevice
    st_opt = Optimisers.setup(opt, ps)
    for i in 1:200
        for (g, u) in train_loader
            g = g |> mydevice
            st = updategraph(st, g)               # load the current graphs into the state
            u = u |> mydevice
            u0 = reshape(u[:, 1, :], 1, :)        # initial condition, (1, space * samples)
            ut = permutedims(u, (1, 3, 2))        # (space, samples, time)
            ut = reshape(ut, 1, g.num_nodes, :)   # target, (1, space * samples, time)
            l, back = pullback(p -> loss(u0, ut, p, st), ps)
            (i % 10 == 0) && @info "epoch $i | train loss = $l"
            gs = back(one(l))[1]
            st_opt, ps = Optimisers.update(st_opt, ps, gs)
        end
    end
    return ps, st
end
ps, st = train()
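With the trained parameters in hand, here is a hedged sketch of a test-time rollout. This is not part of the original tutorial: it rebuilds the NeuralODE on tspan_test = (0.0f0, 0.6f0) and assumes the trained parameters and states carry over unchanged to the new Chain, which has the same layer structure:
node_test = NeuralODE(gnn, tspan_test, Tsit5(); saveat=dt_test, reltol=1e-9, abstol=1e-3)
model_test = Chain(node_test, diffeqsol_to_array)
g = test_data[1][1] |> mydevice            # first test graph
u = test_data[2][:, :, 1] |> mydevice      # its solution, (space, time)
st = updategraph(st, g)
u0 = reshape(u[:, 1], 1, :)                # initial condition, (1, num_nodes)
û, _ = model_test(u0, ps, st)              # predicted trajectory, (1, num_nodes, n_times)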
Expected output
[ Info: epoch 10 | train loss = 0.027209122
[ Info: epoch 20 | train loss = 0.026874812
[ Info: epoch 30 | train loss = 0.025392009
[ Info: epoch 40 | train loss = 0.023239506
[ Info: epoch 50 | train loss = 0.010599495
[ Info: epoch 60 | train loss = 0.010421633
[ Info: epoch 70 | train loss = 0.0098072495
[ Info: epoch 80 | train loss = 0.008936066
[ Info: epoch 90 | train loss = 0.0063929264
[ Info: epoch 100 | train loss = 0.004207685
[ Info: epoch 110 | train loss = 0.0026181203
[ Info: epoch 120 | train loss = 0.0023022622
[ Info: epoch 130 | train loss = 0.0019534715
[ Info: epoch 140 | train loss = 0.0017379699
[ Info: epoch 150 | train loss = 0.0015728847
[ Info: epoch 160 | train loss = 0.0013444767
[ Info: epoch 170 | train loss = 0.0012353633
[ Info: epoch 180 | train loss = 0.0011409305
[ Info: epoch 190 | train loss = 0.0010424983
[ Info: epoch 200 | train loss = 0.0009809926