1D Convection Equation
Consider the following 1D convection equation with periodic boundary conditions:
\[\begin{aligned} &\frac{\partial u}{\partial t}+c \frac{\partial u}{\partial x}=0, \quad x \in[0,1],\ t \in[0,1], \\ &u(x, 0)=\sin(2\pi x), \\ &u(0, t)=u(1, t). \end{aligned}\]
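As a quick sanity check, the exact solution used later in the code, u(x,t) = sin(2π(x − ct)), satisfies the PDE, the initial condition and the periodic boundary condition:
\[\begin{aligned} \frac{\partial u}{\partial t} &= -2\pi c\cos\big(2\pi(x-ct)\big), \qquad \frac{\partial u}{\partial x} = 2\pi\cos\big(2\pi(x-ct)\big), \\ \frac{\partial u}{\partial t} + c\frac{\partial u}{\partial x} &= -2\pi c\cos\big(2\pi(x-ct)\big) + 2\pi c\cos\big(2\pi(x-ct)\big) = 0, \\ u(x,0) &= \sin(2\pi x), \qquad u(1,t) = \sin\big(2\pi(0-ct) + 2\pi\big) = u(0,t). \end{aligned}\]
In other words, the initial sine profile is transported to the right with speed c without changing shape.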
First we define the PDE.
using ModelingToolkit, Sophon, IntervalSets, CairoMakie
using Optimization, OptimizationOptimJL, Zygote
@parameters x, t
@variables u(..)
Dₜ = Differential(t)
Dₓ = Differential(x)
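c = 6  # convection speed
eq = Dₜ(u(x,t)) + c * Dₓ(u(x,t)) ~ 0
u_analytic(x,t) = sinpi(2*(x-c*t))  # exact solution sin(2π(x - ct))
domains = [x ∈ 0..1, t ∈ 0..1]
bcs = [u(x,0) ~ u_analytic(x,0)]  # only the initial condition; periodicity in x is built into the network below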
@named convection = PDESystem(eq, bcs, domains, [x,t], [u(x,t)])
\[ \begin{align} \frac{\mathrm{d}}{\mathrm{d}t} u\left( x, t \right) + 6 \frac{\mathrm{d}}{\mathrm{d}x} u\left( x, t \right) =& 0 \end{align} \]
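Before training, one can check numerically that `u_analytic` really solves the PDE. The following is a package-free sketch that approximates the residual u_t + c u_x with central finite differences; the helper name `residual`, the step size `h` and the evaluation grid are illustrative choices, not part of Sophon.

```julia
c = 6
u_analytic(x, t) = sinpi(2 * (x - c * t))

# Hypothetical helper: central-difference approximation of u_t + c*u_x at (x, t)
function residual(x, t; h = 1e-5)
    u_t = (u_analytic(x, t + h) - u_analytic(x, t - h)) / (2h)
    u_x = (u_analytic(x + h, t) - u_analytic(x - h, t)) / (2h)
    return u_t + c * u_x
end

# Largest residual on an 11 × 11 grid covering [0, 1] × [0, 1]
pts = [(x, t) for x in 0:0.1:1, t in 0:0.1:1]
max_res = maximum(abs(residual(x, t)) for (x, t) in pts)
println("max |u_t + c u_x| ≈ ", max_res)
```

The residual is not exactly zero only because of the finite-difference truncation error, which scales with h².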
Imposing periodic boundary conditions
We will use BACON to impose the periodic boundary conditions. Since BACON is periodic by construction, we simply set its period to one, the length of the spatial domain.
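Why does a fixed period make the boundary condition hold automatically? The sketch below uses a hypothetical discrete Fourier feature, sin(2π/period · W z) with an integer frequency matrix `W`; it is an illustration of the idea, not Sophon's `DiscreteFourierFeature` implementation. Shifting either input by one period changes every argument by an integer multiple of 2π, so the features do not change.

```julia
period = 1.0   # plays the role of BACON's period argument (assumed analogy)
N = 8          # largest integer frequency (hypothetical, for illustration)

# Hypothetical frequency matrix with integer entries: 16 features of z = [x, t]
W = rand(-N:N, 16, 2)

# Hypothetical discrete Fourier feature map: sin(2π/period * W z)
fourier_feature(z) = sin.((2π / period) .* (W * z))

z = [0.37, 0.81]
gap_x = maximum(abs.(fourier_feature(z .+ [period, 0.0]) .- fourier_feature(z)))
gap_t = maximum(abs.(fourier_feature(z .+ [0.0, period]) .- fourier_feature(z)))
println("shift in x: ", gap_x, ", shift in t: ", gap_t)
```

Both gaps are at the level of floating-point rounding. Since BACON combines its filters only through linear layers and elementwise products, which preserve a common period, the whole network output inherits the same periodicity in every input.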
chain = BACON(2, 1, 8, 1; hidden_dims = 32, num_layers=4)
MultiplicativeFilterNet(
filters = BranchLayer(
filter_1 = DiscreteFourierFeature(2 => 32), # 32 parameters, plus 64
filter_2 = DiscreteFourierFeature(2 => 32), # 32 parameters, plus 64
filter_3 = DiscreteFourierFeature(2 => 32), # 32 parameters, plus 64
filter_4 = DiscreteFourierFeature(2 => 32), # 32 parameters, plus 64
),
linear_layers = PairwiseFusion(
Base.Broadcast.BroadcastFunction(*)
layer_1 = Dense(32 => 32), # 1_056 parameters
layer_2 = Dense(32 => 32), # 1_056 parameters
layer_3 = Dense(32 => 32), # 1_056 parameters
),
output_layer = Dense(32 => 1), # 33 parameters
) # Total: 3_329 parameters,
# plus 256 states.
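The parameter counts in the printout can be reproduced by hand. This sketch assumes that each `DiscreteFourierFeature(2 => 32)` holds one trainable bias per output (32 parameters) and keeps its 2 × 32 frequency matrix as non-trainable state (the "plus 64"); that split is inferred from the printed numbers, not taken from Sophon's source.

```julia
in_dims, hidden_dims, num_layers, out_dims = 2, 32, 4, 1

filter_params = num_layers * hidden_dims                          # one bias vector per filter
filter_states = num_layers * in_dims * hidden_dims                # fixed frequency matrices (assumed)
linear_params = (num_layers - 1) * (hidden_dims^2 + hidden_dims)  # Dense(32 => 32): weight + bias
output_params = hidden_dims * out_dims + out_dims                 # Dense(32 => 1): weight + bias

total_params = filter_params + linear_params + output_params
println("parameters: ", total_params, ", states: ", filter_states)
```

This gives 4 × 32 + 3 × 1056 + 33 = 3329 parameters and 4 × 64 = 256 states, matching the summary above.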
For demonstration purposes, the model is also periodic in time, because the same period of one applies to both inputs x and t.
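sampler = QuasiRandomSampler(500, 100) # 500 PDE collocation points, 100 initial-condition points
strategy = NonAdaptiveTraining(1, 500) # fixed loss weights: 1 for the PDE, 500 for the initial condition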
pinn = PINN(chain)
prob = Sophon.discretize(convection, pinn, sampler, strategy)
@showprogress res = Optimization.solve(prob, BFGS(); maxiters = 1000)
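Conceptually, training minimizes a weighted sum of a PDE residual term and an initial-condition term. The sketch below assumes that `QuasiRandomSampler(500, 100)` means 500 interior and 100 initial-condition points and that `NonAdaptiveTraining(1, 500)` means fixed weights of 1 and 500; the mean-squared form, the uniform random sampling and the finite-difference derivatives are simplifications (Sophon samples quasi-randomly and differentiates automatically). The helpers `pde_residual` and `pinn_loss` are hypothetical.

```julia
using Random

c = 6
λ_pde, λ_bc = 1.0, 500.0   # assumed meaning of NonAdaptiveTraining(1, 500)
n_pde, n_bc = 500, 100     # assumed meaning of QuasiRandomSampler(500, 100)

rng = MersenneTwister(0)
# Uniform random points stand in for quasi-random sampling
pde_pts = [(rand(rng), rand(rng)) for _ in 1:n_pde]  # (x, t) in [0, 1]^2
bc_xs = rand(rng, n_bc)                               # x values on the line t = 0

# Hypothetical helper: PDE residual u_t + c*u_x by central differences
function pde_residual(u, x, t; h = 1e-5)
    u_t = (u(x, t + h) - u(x, t - h)) / (2h)
    u_x = (u(x + h, t) - u(x - h, t)) / (2h)
    return u_t + c * u_x
end

# Hypothetical helper: weighted mean-squared loss (PDE term + initial-condition term)
function pinn_loss(u)
    pde = sum(abs2(pde_residual(u, x, t)) for (x, t) in pde_pts) / n_pde
    ic = sum(abs2(u(x, 0.0) - sinpi(2x)) for x in bc_xs) / n_bc
    return λ_pde * pde + λ_bc * ic
end

exact_loss = pinn_loss((x, t) -> sinpi(2 * (x - c * t)))
zero_loss = pinn_loss((x, t) -> 0.0)
println("exact solution: ", exact_loss, ", zero function: ", zero_loss)
```

The exact solution drives both terms to (nearly) zero, while the trivial function u ≡ 0 satisfies the PDE but is heavily penalized through the large initial-condition weight; this is the kind of degenerate solution a large boundary weight is meant to rule out.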
Let's visualize the result.
phi = pinn.phi
xs, ts = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
u_pred = [sum(phi([x,t],res.u)) for x in xs, t in ts]  # the network output has a single entry; sum turns it into a scalar
u_real = u_analytic.(xs,ts')
fig, ax, hm = heatmap(ts, xs, u_pred', axis=(xlabel="t", ylabel="x", title="c = $c"))
ax2, hm2 = heatmap(fig[1,end+1], ts,xs, abs.(u_pred' .- u_real'), axis = (xlabel="t", ylabel="x", title="Absolute error"))
Colorbar(fig[:, end+1], hm2)
display(fig)
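To summarize the error heatmap with a single number, one can compute the relative L2 error between the prediction grid and the exact grid. The helper `relative_l2` below is hypothetical, and `u_fake` is a synthetic stand-in (the exact solution scaled by 1.01), not a trained result.

```julia
using LinearAlgebra

c = 6
u_analytic(x, t) = sinpi(2 * (x - c * t))

# Hypothetical helper: relative L2 (Frobenius) error of a prediction grid
relative_l2(pred, exact) = norm(pred .- exact) / norm(exact)

xs = 0:0.01:1
ts = 0:0.01:1
u_real = u_analytic.(xs, ts')

# Synthetic stand-in for a network prediction: exact solution with a 1% scaling error
u_fake = 1.01 .* u_real

err = relative_l2(u_fake, u_real)
println("relative L2 error ≈ ", err)
```

With the trained model, the same helper would be called as `relative_l2(u_pred, u_real)` after the plotting code above.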
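We can verify that our model is indeed periodic by evaluating it on [0, 2] × [0, 2], twice the training domain in each direction; the pattern should repeat with period one in both x and t.
xs, ts = [infimum(d.domain):0.01:supremum(d.domain)*2 for d in domains]  # extend both x and t to [0, 2]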
u_pred = [sum(phi([x,t],res.u)) for x in xs, t in ts]
fig, ax, hm = heatmap(ts, xs, u_pred', axis=(xlabel="t", ylabel="x", title="c = $c"))
display(fig)
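The heatmap gives a visual check; a numeric one is to measure how much a function changes when either input is shifted by one period. The helper `max_periodicity_gap` below is hypothetical; it is demonstrated on the exact solution (periodic in x and t) and on a non-periodic function.

```julia
# Hypothetical helper: largest change of f(x, t) under a shift by `period` in x or in t
function max_periodicity_gap(f; period = 1.0, grid = 0:0.05:1)
    gap = 0.0
    for x in grid, t in grid
        v = f(x, t)
        gap = max(gap, abs(f(x + period, t) - v), abs(f(x, t + period) - v))
    end
    return gap
end

periodic_gap = max_periodicity_gap((x, t) -> sinpi(2 * (x - 6t)))
nonperiodic_gap = max_periodicity_gap((x, t) -> x + t)
println("periodic: ", periodic_gap, ", non-periodic: ", nonperiodic_gap)
```

For the trained model, one would pass `(x, t) -> sum(phi([x, t], res.u))`; a gap at the level of floating-point rounding confirms the periodicity seen in the plot.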