1D Convection Equation

Consider the following 1D convection equation with periodic boundary conditions in x.

\[\begin{aligned} &\frac{\partial u}{\partial t}+c \frac{\partial u}{\partial x}=0, \quad x \in[0,1],\ t \in[0,1], \\ &u(x, 0)=\sin(2\pi x), \\ &u(0, t)=u(1, t). \end{aligned}\]
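The exact solution used below as u_analytic follows from the method of characteristics: u is constant along each line x = x_0 + ct, so the initial profile is carried to the right at speed c. Substituting it back confirms both the equation and the periodic boundary condition:

\[\begin{aligned} &\frac{\mathrm{d}}{\mathrm{d}t}\,u(x_0+ct,\,t)=\frac{\partial u}{\partial t}+c\,\frac{\partial u}{\partial x}=0 \quad\Longrightarrow\quad u(x,t)=u(x-ct,\,0)=\sin\big(2\pi(x-ct)\big), \\ &\frac{\partial u}{\partial t}+c\,\frac{\partial u}{\partial x}=-2\pi c\cos\big(2\pi(x-ct)\big)+2\pi c\cos\big(2\pi(x-ct)\big)=0, \\ &u(x+1,t)=\sin\big(2\pi(x-ct)+2\pi\big)=u(x,t). \end{aligned}\]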

First we define the PDE.

using ModelingToolkit, Sophon, IntervalSets, CairoMakie
using Optimization, OptimizationOptimJL, Zygote

@parameters x, t
@variables u(..)
Dₜ = Differential(t)
Dₓ = Differential(x)

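c = 6  # convection speed
eq = Dₜ(u(x,t)) + c * Dₓ(u(x,t)) ~ 0
u_analytic(x,t) = sinpi(2*(x-c*t))  # exact solution sin(2π(x - ct))

domains = [x ∈ 0..1, t ∈ 0..1]

bcs = [u(x,0) ~ u_analytic(x,0)]  # initial condition only; periodicity in x comes from the network architecture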

@named convection = PDESystem(eq, bcs, domains, [x,t], [u(x,t)])

\[ \begin{align} \frac{\mathrm{d}}{\mathrm{d}t} u\left( x, t \right) + 6 \frac{\mathrm{d}}{\mathrm{d}x} u\left( x, t \right) =& 0 \end{align} \]

Imposing periodic boundary conditions

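We use BACON to impose the periodic boundary conditions. Its filters are Fourier features with discrete frequencies, so the network is periodic in its inputs by construction and no boundary loss term is needed; we only have to set the period, the fourth positional argument of BACON, to one.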

chain = BACON(2, 1, 8, 1; hidden_dims = 32, num_layers=4)
MultiplicativeFilterNet(
    filters = BranchLayer(
        filter_1 = DiscreteFourierFeature(2 => 32),  # 32 parameters, plus 64
        filter_2 = DiscreteFourierFeature(2 => 32),  # 32 parameters, plus 64
        filter_3 = DiscreteFourierFeature(2 => 32),  # 32 parameters, plus 64
        filter_4 = DiscreteFourierFeature(2 => 32),  # 32 parameters, plus 64
    ),
    linear_layers = PairwiseFusion(
        Base.Broadcast.BroadcastFunction(*)
        layer_1 = Dense(32 => 32),      # 1_056 parameters
        layer_2 = Dense(32 => 32),      # 1_056 parameters
        layer_3 = Dense(32 => 32),      # 1_056 parameters
    ),
    output_layer = Dense(32 => 1),      # 33 parameters
)         # Total: 3_329 parameters,
          #        plus 256 states.
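The parameter totals in the printout can be reproduced by hand. A Dense(in => out) layer has in * out weights plus out biases, and each filter contributes 32 trainable parameters. The 64 non-trainable states per filter match a 32 × 2 frequency matrix (32 features, 2 inputs); that reading is an inference from the printout, not something the output states. The helper below exists only for the arithmetic.

```julia
# Reproduce the parameter totals of the printed BACON model.
dense_params(in_dims, out_dims) = in_dims * out_dims + out_dims  # weights + biases

hidden, n_inputs, n_filters = 32, 2, 4
filter_params = n_filters * hidden                # "32 parameters" per filter
linear_params = 3 * dense_params(hidden, hidden)  # layer_1 to layer_3
output_params = dense_params(hidden, 1)           # output_layer
total  = filter_params + linear_params + output_params
states = n_filters * hidden * n_inputs            # "plus 64" per filter (assumed 32×2 frequencies)

println((total, states))  # (3329, 256)
```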
Note

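For demonstration purposes, the model is also periodic in time, with the same period of one. This does no harm here: because c = 6 is an integer, the exact solution sin(2π(x - ct)) repeats every 1/c = 1/6 in time and is therefore also periodic with period one.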
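Why the model is periodic in both inputs can be seen on a toy version of a multiplicative filter network. The sketch assumes that each filter is a sinusoid whose frequencies are integer multiples of 2π/period (which the name DiscreteFourierFeature suggests); fourier_filter and toy_mfn are hypothetical stand-ins, not Sophon's implementation. Shifting either input by one period changes every sine argument by an integer multiple of 2π, and linear layers, elementwise products and a linear output layer all preserve that periodicity.

```julia
# Hypothetical, simplified "discrete" Fourier filter: sin(2π/P * (k * x) + b)
# with an integer frequency matrix k. Not Sophon's actual implementation.
fourier_filter(k, b, x; period = 1.0) = sin.((2π / period) .* (k * x) .+ b)

# Toy multiplicative filter network on inputs (x, t): two filters combined
# through a linear layer and an elementwise product, then a linear readout.
function toy_mfn(x; period = 1.0)
    k1 = [1 0; 0 2; 3 -1]           # integer frequencies: 3 features, 2 inputs
    k2 = [2 1; -1 1; 0 3]
    b1 = [0.1, -0.4, 0.7]
    b2 = [0.3, 0.2, -0.5]
    W  = [0.5 -0.2 0.1; 0.3 0.8 -0.6; -0.7 0.4 0.2]
    z = fourier_filter(k1, b1, x; period)             # first filter
    z = (W * z) .* fourier_filter(k2, b2, x; period)  # linear layer times second filter
    return sum(z)                                     # stand-in for the output layer
end

p = [0.37, 0.81]
shift_x = toy_mfn(p .+ [1.0, 0.0]) - toy_mfn(p)  # shift x by one period
shift_t = toy_mfn(p .+ [0.0, 1.0]) - toy_mfn(p)  # shift t by one period
println((shift_x, shift_t))  # both ≈ 0 up to rounding
```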

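sampler = QuasiRandomSampler(500, 100) # 500 collocation points for the PDE, 100 for the initial condition
strategy = NonAdaptiveTraining(1, 500) # fixed loss weights: 1 for the PDE, 500 for the initial condition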
pinn = PINN(chain)

prob = Sophon.discretize(convection, pinn, sampler, strategy)

@showprogress res = Optimization.solve(prob, BFGS(); maxiters = 1000)
retcode: Failure
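The sampler and the training strategy together determine the loss that BFGS minimizes. As a rough sketch, assuming the usual PINN formulation of a weighted sum of mean squared residuals (Sophon's internals may differ in detail), the weights 1 and 500 passed to NonAdaptiveTraining would enter as below. pinn_loss is a hypothetical helper; to stay self-contained it uses central differences on a regular grid instead of automatic differentiation at quasi-random points.

```julia
# Hypothetical fixed-weight PINN loss for u_t + c u_x = 0 with u(x, 0) = sin(2πx).
# Central differences stand in for automatic differentiation.
function pinn_loss(u, c; w_pde = 1.0, w_ic = 500.0, n = 21, h = 1e-5)
    grid = range(0, 1; length = n)
    # PDE residual at the collocation points (x, t)
    pde_res = [(u(x, t + h) - u(x, t - h)) / (2h) +
               c * (u(x + h, t) - u(x - h, t)) / (2h) for x in grid, t in grid]
    # Initial-condition residual at t = 0
    ic_res = [u(x, 0.0) - sinpi(2x) for x in grid]
    # Weighted sum of mean squared residuals
    return w_pde * sum(abs2, pde_res) / length(pde_res) +
           w_ic * sum(abs2, ic_res) / length(ic_res)
end

c = 6
exact(x, t) = sinpi(2 * (x - c * t))
zero_guess(x, t) = 0.0

println(pinn_loss(exact, c))       # ≈ 0 (finite-difference error only)
println(pinn_loss(zero_guess, c))  # ≈ 238.1 (= 500 * 10/21), all from the IC term
```

The large weight on the initial condition matters because the zero function satisfies the PDE exactly; only the initial-condition term rules it out.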

Let's visualize the result.

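phi = pinn.phi  # the network as a function phi(input, parameters)

xs, ts = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]  # evaluation grids for x and t
u_pred = [sum(phi([x, t], res.u)) for x in xs, t in ts]  # sum turns the 1-element output into a scalar
u_real = u_analytic.(xs, ts')  # exact solution on the same grid
# Left panel: prediction. Right panel: absolute error, with a colorbar.
fig, ax, hm = heatmap(ts, xs, u_pred', axis = (xlabel = "t", ylabel = "x", title = "c = $c"))
ax2, hm2 = heatmap(fig[1, end+1], ts, xs, abs.(u_pred' .- u_real'), axis = (xlabel = "t", ylabel = "x", title = "Absolute error"))
Colorbar(fig[:, end+1], hm2)
display(fig)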
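The error heatmap can also be condensed into one number. A common choice is the relative L2 error over the plotting grid; relative_l2_error is a hypothetical helper (not part of Sophon or Makie), shown here on toy data.

```julia
# Relative L2 error of a prediction with respect to a reference on the same grid.
relative_l2_error(pred, ref) = sqrt(sum(abs2, pred .- ref) / sum(abs2, ref))

# Toy data: a "prediction" that is uniformly 1% too large.
xs_toy = 0:0.25:1
ts_toy = 0:0.25:1
ref  = [sinpi(2 * (x - 6t)) for x in xs_toy, t in ts_toy]
pred = 1.01 .* ref
println(relative_l2_error(pred, ref))  # ≈ 0.01
```

With the arrays from the visualization step it would be called as relative_l2_error(u_pred, u_real).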

We can verify that our model is indeed periodic by evaluating it on [0, 2] × [0, 2], twice the training domain in each direction.

xs, ts = [infimum(d.domain):0.01:2 * supremum(d.domain) for d in domains]  # grids on [0, 2]
u_pred = [sum(phi([x, t], res.u)) for x in xs, t in ts]
fig, ax, hm = heatmap(ts, xs, u_pred', axis = (xlabel = "t", ylabel = "x", title = "c = $c"))
display(fig)
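The heatmap checks periodicity visually. A numerical version compares values one period apart; it assumes the trained network can be wrapped as a plain function of (x, t), mirroring the evaluation code above. max_period_gap is a hypothetical helper, demonstrated here on the analytic solution.

```julia
# Largest change in f(x, t) when x or t is shifted by one period.
function max_period_gap(f, xs, ts; period = 1.0)
    gap_x = maximum(abs(f(x + period, t) - f(x, t)) for x in xs, t in ts)
    gap_t = maximum(abs(f(x, t + period) - f(x, t)) for x in xs, t in ts)
    return max(gap_x, gap_t)
end

c = 6
u_analytic(x, t) = sinpi(2 * (x - c * t))
println(max_period_gap(u_analytic, 0:0.01:1, 0:0.01:1))  # ≈ 0 up to rounding
```

For the trained model one would call max_period_gap((x, t) -> sum(phi([x, t], res.u)), 0:0.01:1, 0:0.01:1); a gap near zero would confirm numerically what the second heatmap shows.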