Schrödinger equation

The nonlinear Schrödinger equation is given by

\[\mathrm{i} \partial_t \psi=-\frac{1}{2} \sigma \partial_{x x} \psi-\beta|\psi|^2 \psi\]

Let $\sigma=\beta=1$ and write $\psi=u+v\mathrm{i}$; the equation can then be transformed into a system of real partial differential equations.
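
Explicitly, substituting $\psi = u + v\mathrm{i}$ and separating the real and imaginary parts gives

\[
\begin{aligned}
\partial_t u &= -\frac{1}{2}\partial_{xx} v - \left(u^{2} + v^{2}\right) v,\\
\partial_t v &= \frac{1}{2}\partial_{xx} u + \left(u^{2} + v^{2}\right) u,
\end{aligned}
\]

which is the system encoded in the code below.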

using ModelingToolkit, IntervalSets, Sophon, CairoMakie
using Optimization, OptimizationOptimJL, Zygote

@parameters x,t
@variables u(..), v(..)
Dₜ = Differential(t)
Dₓ² = Differential(x)^2

eqs=[Dₜ(u(x,t)) ~ -Dₓ²(v(x,t))/2 - (abs2(v(x,t)) + abs2(u(x,t))) * v(x,t),
     Dₜ(v(x,t)) ~  Dₓ²(u(x,t))/2 + (abs2(v(x,t)) + abs2(u(x,t))) * u(x,t)]

bcs = [u(x, 0.0) ~ 2sech(x),
       v(x, 0.0) ~ 0.0,
       u(-5.0, t) ~ u(5.0, t),
       v(-5.0, t) ~ v(5.0, t)]

domains = [x ∈ Interval(-5.0, 5.0),
           t ∈ Interval(0.0, π/2)]

@named pde_system = PDESystem(eqs, bcs, domains, [x,t], [u(x,t),v(x,t)])

\[ \begin{align} \frac{\mathrm{d}}{\mathrm{d}t} u\left( x, t \right) =& - \frac{1}{2} \frac{\mathrm{d}}{\mathrm{d}x} \frac{\mathrm{d}}{\mathrm{d}x} v\left( x, t \right) - v\left( x, t \right) \left( \left|u\left( x, t \right)\right|^{2} + \left|v\left( x, t \right)\right|^{2} \right) \\ \frac{\mathrm{d}}{\mathrm{d}t} v\left( x, t \right) =& \frac{1}{2} \frac{\mathrm{d}}{\mathrm{d}x} \frac{\mathrm{d}}{\mathrm{d}x} u\left( x, t \right) + u\left( x, t \right) \left( \left|u\left( x, t \right)\right|^{2} + \left|v\left( x, t \right)\right|^{2} \right) \end{align} \]

# two Siren networks (2 inputs → 1 output, 4 hidden layers of width 16) approximating u and v
pinn = PINN(u = Siren(2, 1; hidden_dims = 16, num_layers = 4, omega = 1.0),
            v = Siren(2, 1; hidden_dims = 16, num_layers = 4, omega = 1.0))

# 500 quasi-random points for the two PDE residuals, (200, 200, 20, 20) for the four initial/boundary conditions
sampler = QuasiRandomSampler(500, (200, 200, 20, 20))
# fixed loss weights: 1 for each PDE residual, (10, 10, 1, 1) for the initial/boundary conditions
strategy = NonAdaptiveTraining(1, (10, 10, 1, 1))

prob = Sophon.discretize(pde_system, pinn, sampler, strategy)
OptimizationProblem. In-place: true
u0: ComponentVector{Float64}(u = (layer_1 = …, layer_2 = …, layer_3 = …, layer_4 = …, layer_5 = …), v = (layer_1 = …, layer_2 = …, layer_3 = …, layer_4 = …, layer_5 = …))

Now we train the neural networks, resampling the training data periodically as training proceeds.

function train(pde_system, prob, sampler, strategy, resample_period = 500, n = 10)
    bfgs = BFGS()
    # initial training run on the first batch of sampled points
    res = Optimization.solve(prob, bfgs; maxiters = 2000)

    for i in 1:n
        # draw a fresh data set and warm-start from the current parameters
        data = Sophon.sample(pde_system, sampler)
        prob = remake(prob; u0 = res.u, p = data)
        res = Optimization.solve(prob, bfgs; maxiters = resample_period)
    end
    return res
end

res = train(pde_system, prob, sampler, strategy)
retcode: Failure
u: ComponentVector{Float64}(u = (layer_1 = …, layer_2 = …, layer_3 = …, layer_4 = …, layer_5 = …), v = (layer_1 = …, layer_2 = …, layer_3 = …, layer_4 = …, layer_5 = …))
phi = pinn.phi
ps = res.u

# evaluation grids over the spatial and temporal domains
xs, ts = [infimum(d.domain):0.01:supremum(d.domain) for d in pde_system.domain]

u = [sum(phi.u([x, t], ps.u)) for x in xs, t in ts]
v = [sum(phi.v([x, t], ps.v)) for x in xs, t in ts]
ψ = @. sqrt(u^2 + v^2)   # modulus |ψ|
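
As a quick, optional sanity check (not part of the original example), we can compare the network's prediction at $t = 0$ with the prescribed initial condition $u(x, 0) = 2\operatorname{sech}(x)$; the names `u0_pred` and `u0_exact` below are purely illustrative:

# hypothetical sanity check: deviation of the predicted u(x, 0) from 2*sech(x)
u0_pred = [sum(phi.u([x, 0.0], ps.u)) for x in xs]
u0_exact = @. 2 * sech(xs)
@show maximum(abs, u0_pred .- u0_exact)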

axis = (xlabel="t", ylabel="x", title="u")
fig, ax1, hm1 = heatmap(ts, xs, u', axis=axis)
ax2, hm2 = heatmap(fig[1, end+1], ts, xs, v', axis = merge(axis, (; title="v")))
display(fig)

axis = (xlabel="t", ylabel="x", title="ψ")
fig, ax1, hm1 = heatmap(ts, xs, ψ', axis=axis, colormap=:jet)
Colorbar(fig[:, end+1], hm1)
display(fig)

Customize Sampling

Basically any sampling method is supported. For example, we can resample training points with probability proportional to the predicted magnitude |ψ|.

using StatsBase

data = vec([[x, t] for x in xs, t in ts])   # candidate points on the evaluation grid
wv = vec(ψ)                                 # weights proportional to the predicted |ψ|
new_data = wsample(data, wv, 500)           # draw 500 points by weighted sampling
new_data = reduce(hcat, new_data)           # stack into a 2×500 matrix of [x; t] columns
fig, ax = scatter(new_data[2, :], new_data[1, :])

# swap in the new training data for the two PDE residual losses
prob.p[1] = new_data
prob.p[2] = new_data
prob = remake(prob; u0 = res.u)
# res = Optimization.solve(prob, BFGS(); maxiters=1000)
OptimizationProblem. In-place: true
u0: ComponentVector{Float64}(u = (layer_1 = …, layer_2 = …, layer_3 = …, layer_4 = …, layer_5 = …), v = (layer_1 = …, layer_2 = …, layer_3 = …, layer_4 = …, layer_5 = …))