Fitting a nonlinear discontinuous function

This example is taken from here. However, we do not use adaptive activation functions; instead, we show that suitable non-parametric activation functions immediately perform better.

Consider the following function, which has a discontinuity at $x=0$:

\[u(x)= \begin{cases}0.2 \sin (18 x) & \text { if } x \leq 0 \\ 1+0.3 x \cos (54 x) & \text { otherwise }\end{cases}\]

The domain is $[-1,1]$. The number of training points used is 50.

Import packages

using Lux, Sophon
using NNlib, Optimisers, Plots, Random, StatsBase, Zygote

Dataset

function u(x)
    if x <= 0
        return 0.2 * sin(18 * x)
    else
        return 1 + 0.3 * x * cos(54 * x)
    end
end

function generate_data(n=50)
    # n equally spaced points on [-1, 1], shaped (1, n) as Lux expects
    x = reshape(collect(range(-1.0f0, 1.0f0, n)), (1, n))
    y = u.(x)
    return (x, y)
end
generate_data (generic function with 2 methods)

Let's visualize the data.

x_train, y_train = generate_data(50)
x_test, y_test = generate_data(200)
Plots.plot(vec(x_test), vec(y_test),label=false)

Naive Neural Networks

First, we demonstrate that naive, fully connected neural nets are not sufficient for fitting this function.

model = FullyConnected((1,50,50,50,50,1), relu)
Chain(
    layer_1 = Dense(1 => 50, relu),     # 100 parameters
    layer_2 = Dense(50 => 50, relu),    # 2_550 parameters
    layer_3 = Dense(50 => 50, relu),    # 2_550 parameters
    layer_4 = Dense(50 => 50, relu),    # 2_550 parameters
    layer_5 = Dense(50 => 1),           # 51 parameters
)         # Total: 7_801 parameters,
          #        plus 0 states.

Train the model

function train(model, x, y)
    ps, st = Lux.setup(Random.default_rng(), model)
    opt = Adam()
    st_opt = Optimisers.setup(opt, ps)

    # mean squared error between the model prediction and the data
    function loss(model, ps, st, x, y)
        y_pred, _ = model(x, ps, st)
        mse = mean(abs2, y_pred .- y)
        return mse
    end

    for i in 1:2000
        gs = gradient(p -> loss(model, p, st, x, y), ps)[1]
        st_opt, ps = Optimisers.update(st_opt, ps, gs)
        if i % 100 == 1 || i == 2000
            println("Epoch $i ||  ", loss(model, ps, st, x, y))
        end
    end
    return ps, st
end
train (generic function with 1 method)

Plot the result

@time ps, st = train(model, x_train, y_train)
y_pred = model(x_test,ps,st)[1]
Plots.plot(vec(x_test), vec(y_pred),label="Prediction",line = (:dot, 4))
Plots.plot!(vec(x_test), vec(y_test),label="Exact",legend=:topleft)
Epoch 1 ||  0.33211493414236365
Epoch 101 ||  0.017558964981011626
Epoch 201 ||  0.016286750225435142
Epoch 301 ||  0.015586845137269774
Epoch 401 ||  0.01453105260796386
Epoch 501 ||  0.013318687689341928
Epoch 601 ||  0.013218828779241455
Epoch 701 ||  0.013177559261395601
Epoch 801 ||  0.013254494040674678
Epoch 901 ||  0.013241565045628123
Epoch 1001 ||  0.013237247265629004
Epoch 1101 ||  0.01327065945666888
Epoch 1201 ||  0.013289406808524218
Epoch 1301 ||  0.01328566200528125
Epoch 1401 ||  0.013353929202729176
Epoch 1501 ||  0.013220101895295624
Epoch 1601 ||  0.013218167911195795
Epoch 1701 ||  0.013348619817749357
Epoch 1801 ||  0.013205242895896507
Epoch 1901 ||  0.013200326024322438
Epoch 2000 ||  0.013197263673170225
  7.528670 seconds (9.20 M allocations: 1.215 GiB, 2.16% gc time, 93.96% compilation time)

Siren

We use four hidden layers with 50 neurons in each.

model = Siren(1,50,50,50,50,1; omega = 30f0)
Chain(
    layer_1 = Dense(1 => 50, sin),      # 100 parameters
    layer_2 = Dense(50 => 50, sin),     # 2_550 parameters
    layer_3 = Dense(50 => 50, sin),     # 2_550 parameters
    layer_4 = Dense(50 => 50, sin),     # 2_550 parameters
    layer_5 = Dense(50 => 1),           # 51 parameters
)         # Total: 7_801 parameters,
          #        plus 0 states.
@time ps, st = train(model, x_train, y_train)
y_pred = model(x_test,ps,st)[1]
Plots.plot(vec(x_test), vec(y_pred),label="Prediction",line = (:dot, 4))
Plots.plot!(vec(x_test), vec(y_test),label="Exact",legend=:topleft)
Epoch 1 ||  1.244547143747077
Epoch 101 ||  0.0020304667366154144
Epoch 201 ||  0.00021802140174867173
Epoch 301 ||  1.554377020172236e-5
Epoch 401 ||  8.653130013132294e-7
Epoch 501 ||  3.1872430149209275e-8
Epoch 601 ||  8.301891632653796e-10
Epoch 701 ||  2.612457416595017e-11
Epoch 801 ||  2.2810923185535154e-12
Epoch 901 ||  5.952888961198464e-13
Epoch 1001 ||  1.58945849778866e-13
Epoch 1101 ||  1.1546265045420297e-13
Epoch 1201 ||  7.396298629086989e-14
Epoch 1301 ||  3.774152011743428e-14
Epoch 1401 ||  4.654341438893344e-14
Epoch 1501 ||  5.011853506021905e-14
Epoch 1601 ||  4.3152554444728475e-14
Epoch 1701 ||  7.214382684986497e-14
Epoch 1801 ||  5.2201648173338806e-14
Epoch 1901 ||  7.705561112155462e-14
Epoch 2000 ||  7.514691037423146e-14
  5.259629 seconds (5.20 M allocations: 1.085 GiB, 1.59% gc time, 89.72% compilation time)

As we can see, the model overfits the data, and the high-frequency components cannot be optimized away. The hyperparameter omega controls the frequency scale of Siren's first layer, so we need to tune it down.

model = Siren(1,50,50,50,50,1; omega = 10f0)
Chain(
    layer_1 = Dense(1 => 50, sin),      # 100 parameters
    layer_2 = Dense(50 => 50, sin),     # 2_550 parameters
    layer_3 = Dense(50 => 50, sin),     # 2_550 parameters
    layer_4 = Dense(50 => 50, sin),     # 2_550 parameters
    layer_5 = Dense(50 => 1),           # 51 parameters
)         # Total: 7_801 parameters,
          #        plus 0 states.
@time ps, st = train(model, x_train, y_train)
y_pred = model(x_test,ps,st)[1]
Plots.plot(vec(x_test), vec(y_pred),label="Prediction",line = (:dot, 4))
Plots.plot!(vec(x_test), vec(y_test),label="Exact",legend=:topleft)
Epoch 1 ||  1.5044918360104416
Epoch 101 ||  0.008747278671585977
Epoch 201 ||  0.0054290369565187105
Epoch 301 ||  0.004057629321358789
Epoch 401 ||  0.0028611589052059245
Epoch 501 ||  0.0019241110225008684
Epoch 601 ||  0.0012575806711284767
Epoch 701 ||  0.0007338853031752823
Epoch 801 ||  0.0003498864344587249
Epoch 901 ||  0.00014608042143291494
Epoch 1001 ||  7.367923424203053e-5
Epoch 1101 ||  5.2614493193239925e-5
Epoch 1201 ||  4.458113243774209e-5
Epoch 1301 ||  3.954863152445854e-5
Epoch 1401 ||  3.5486182951160676e-5
Epoch 1501 ||  3.1911478663684325e-5
Epoch 1601 ||  2.865442635128827e-5
Epoch 1701 ||  2.5646842154570724e-5
Epoch 1801 ||  2.2855599675166522e-5
Epoch 1901 ||  2.026878942623e-5
Epoch 2000 ||  1.7902375028683566e-5
  0.496432 seconds (1.26 M allocations: 901.495 MiB, 3.34% gc time)

Gaussian activation function

We can also try a fully connected network with the Gaussian activation function.

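For intuition, the Gaussian activation in this family of coordinate-network activations is a bell-shaped bump rather than a periodic function. The sketch below is only illustrative: the width a is an assumed example value, not necessarily the default used by Sophon's gaussian.

# Illustrative Gaussian-style activation; the width a is an assumed example
# value, not necessarily the default of Sophon's exported gaussian.
gaussian_act(x, a = 0.1f0) = exp(-x^2 / (2 * a^2))

A small width gives large gradients near zero, which is the property the conclusion at the end of this example relies on.
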
model = FullyConnected((1,50,50,50,50,1), gaussian)
Chain(
    layer_1 = Dense(1 => 50, gaussian),  # 100 parameters
    layer_2 = Dense(50 => 50, gaussian),  # 2_550 parameters
    layer_3 = Dense(50 => 50, gaussian),  # 2_550 parameters
    layer_4 = Dense(50 => 50, gaussian),  # 2_550 parameters
    layer_5 = Dense(50 => 1),           # 51 parameters
)         # Total: 7_801 parameters,
          #        plus 0 states.
@time ps, st = train(model, x_train, y_train)
y_pred = model(x_test,ps,st)[1]
Plots.plot(vec(x_test), vec(y_pred),label="Prediction",line = (:dot, 4))
Plots.plot!(vec(x_test), vec(y_test),label="Exact",legend=:topleft)
Epoch 1 ||  0.2834719319724448
Epoch 101 ||  0.005419890202136385
Epoch 201 ||  0.0034669216189830198
Epoch 301 ||  0.00027882821308004114
Epoch 401 ||  4.7382583704128875e-6
Epoch 501 ||  8.016170748585772e-7
Epoch 601 ||  3.889809627308763e-7
Epoch 701 ||  7.244211338137569e-7
Epoch 801 ||  0.00024102173619896546
Epoch 901 ||  1.2921137443294608e-7
Epoch 1001 ||  7.869109384847333e-8
Epoch 1101 ||  4.9112123646171555e-5
Epoch 1201 ||  3.8001445874615374e-8
Epoch 1301 ||  2.488655013958093e-8
Epoch 1401 ||  2.3003831928445148e-6
Epoch 1501 ||  3.2602870214588686e-8
Epoch 1601 ||  1.676143051063272e-8
Epoch 1701 ||  2.1762552546753528e-7
Epoch 1801 ||  3.623238876298654e-8
Epoch 1901 ||  1.4262278759043627e-6
Epoch 2000 ||  5.314626695535025e-7
  4.113636 seconds (4.72 M allocations: 1.058 GiB, 1.69% gc time, 84.83% compilation time)

Quadratic activation function

The quadratic activation function is much cheaper to compute than the Gaussian one.

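For reference, the quadratic activation in this family is a rational bump of the form sketched below, which avoids the exponential and is therefore cheaper to evaluate. The sketch is only illustrative: the scale a is an assumed example value, not necessarily the default of Sophon's quadratic.

# Illustrative quadratic-style activation; the scale a is an assumed example
# value, not necessarily the default of Sophon's exported quadratic.
quadratic_act(x, a = 1.0f0) = 1 / (1 + (a * x)^2)
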
model = FullyConnected((1,50,50,50,50,1), quadratic)
Chain(
    layer_1 = Dense(1 => 50, quadratic),  # 100 parameters
    layer_2 = Dense(50 => 50, quadratic),  # 2_550 parameters
    layer_3 = Dense(50 => 50, quadratic),  # 2_550 parameters
    layer_4 = Dense(50 => 50, quadratic),  # 2_550 parameters
    layer_5 = Dense(50 => 1),           # 51 parameters
)         # Total: 7_801 parameters,
          #        plus 0 states.
@time ps, st = train(model, x_train, y_train)
y_pred = model(x_test,ps,st)[1]
Plots.plot(vec(x_test), vec(y_pred),label="Prediction",line = (:dot, 4))
Plots.plot!(vec(x_test), vec(y_test),label="Exact",legend=:topleft)
Epoch 1 ||  0.2934415682217071
Epoch 101 ||  0.0065038672964299995
Epoch 201 ||  0.006139738932723895
Epoch 301 ||  0.005817771321898906
Epoch 401 ||  0.00531100405785607
Epoch 501 ||  0.004462746060709031
Epoch 601 ||  0.00267753556381101
Epoch 701 ||  0.0008346869339238527
Epoch 801 ||  6.489995555251278e-5
Epoch 901 ||  4.108444969977914e-6
Epoch 1001 ||  3.6060054272585023e-6
Epoch 1101 ||  1.8404302678300762e-8
Epoch 1201 ||  1.8416048214737336e-7
Epoch 1301 ||  3.0715145117515355e-5
Epoch 1401 ||  5.226529890157809e-7
Epoch 1501 ||  3.007607902280181e-8
Epoch 1601 ||  8.796082780201602e-6
Epoch 1701 ||  2.3935236582396397e-8
Epoch 1801 ||  3.565741733668846e-7
Epoch 1901 ||  6.098279125984983e-7
Epoch 2000 ||  6.254504025767731e-7
  3.931608 seconds (4.48 M allocations: 1.042 GiB, 1.47% gc time, 88.88% compilation time)

Conclusion

"Neural networks suppress high-frequency components" is a misinterpretation of the spectral bias. The accurate way of putting it is that the lower frequencies in the error are optimized first in the optimization process. This can be seen in Siren's example of overfitting data, where you do not have implicit regularization. The high frequency in the network will never go away because it has fitted the data perfectly.

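To make this spectral picture concrete, one can look at the frequency content of the training residual as training progresses. The sketch below is a minimal illustration and assumes FFTW.jl is installed (it is not used elsewhere in this example); logging the spectrum at a few checkpoints of the naive relu run would show the low-frequency part of the error shrinking first while the high-frequency part persists.

# Minimal sketch (assumes FFTW.jl): one-sided amplitude spectrum of the residual.
using FFTW

function residual_spectrum(model, ps, st, x, y)
    r = vec(model(x, ps, st)[1] .- y)  # pointwise training error
    return abs.(rfft(r))               # amplitudes from low to high frequency
end

Calling residual_spectrum inside the training loop, for example next to the existing println, makes the spectral bias visible directly instead of being inferred from the fitted curve.
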
The mainstream view attributes the phenomenon that neural networks "suppress" high frequencies to gradient descent. That is not the whole picture: initialization also plays an important role. Siren mitigates the problem by initializing the first layer with larger weights, while activation functions such as the Gaussian have sufficiently large gradients and sufficiently broad support of the second derivative when their hyperparameters are chosen properly. Please refer to [1], [2] and [3] if you want to dive deeper into this.