Skip to content

Instantly share code, notes, and snippets.

@dkuku
Last active February 7, 2022 07:09
Show Gist options
  • Save dkuku/e8ca7c7e2229b218ca49d705378a3a3e to your computer and use it in GitHub Desktop.
Save dkuku/e8ca7c7e2229b218ca49d705378a3a3e to your computer and use it in GitHub Desktop.

Neural network from scratch in ... Elixir

Pure Elixir initial version test

inputs = [1, 2, 3, 2.5]
weights = [0.2, 0.8, -0.5, 1.0]
bias = 2.0

defmodule NN do
  @doc """
  Computes the output of a single neuron.

  Each input is multiplied by the weight at the same position, the products
  are summed, and `bias` is added to that sum. If the two lists differ in
  length, the surplus elements of the longer list are ignored.
  """
  def neuron(inputs, weights, bias) do
    weighted_sum =
      inputs
      |> Enum.zip_with(weights, fn input, weight -> input * weight end)
      |> Enum.sum()

    weighted_sum + bias
  end
end

NN.neuron(inputs, weights, bias)

First layer using NX

Mix.install([
  # {:nx, path: "/home/kuku/Projects/nx/nx"},
  {:nx, "~> 0.1.0"},
  # {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
  {:kino, "~> 0.4.1"},
  {:vega_lite, "~> 0.1.2"}
])
inputs = Nx.tensor([[1, 2, 3, 2.5], [2.0, 5.0, -1.0, 2.0], [-1.5, 2.7, 3.3, -0.8]])
weights = Nx.tensor([[0.2, 0.8, -0.5, 1.0], [0.5, -0.91, 0.26, -0.5], [-0.26, -0.27, 0.17, 0.87]])
biases = Nx.tensor([2.0, 3.0, 0.5])

Shapes of our tensors:

  • inputs 3 x 4
  • weights 3 x 4 -> transposed 4 x 3
  • dot product 3 x 3
  • biases 3 (a vector, broadcast across the rows as 1 x 3)
  • layer1_outputs 3 x 3
layer1_outputs = Nx.dot(inputs, Nx.transpose(weights)) |> Nx.add(biases)
  • weights2 3 x 3 -> transposed 3 x 3
  • dot product 3 x 3
  • biases2 3 (a vector, broadcast across the rows as 1 x 3)
  • layer2_outputs 3 x 3
weights2 = Nx.tensor([[0.1, -0.14, 0.5], [-0.5, 0.12, -0.33], [-0.44, 0.73, -0.13]])
biases2 = Nx.tensor([-1.0, 2.0, -0.5])
layer2_outputs = Nx.dot(layer1_outputs, Nx.transpose(weights2)) |> Nx.add(biases2)

nnfs spiral data

x_spiral =
  Nx.tensor([
    [0.000000, 0.000000],
    [0.002996, 0.009647],
    [0.012881, 0.015563],
    [0.029975, 0.004448],
    [0.039312, 0.009328],
    [0.000829, 0.050498],
    [0.053484, 0.028506],
    [0.041736, 0.057075],
    [0.055463, 0.058769],
    [0.081604, 0.040066],
    [0.089188, 0.047420],
    [0.107161, -0.029364],
    [0.121183, -0.002648],
    [0.128778, 0.025679],
    [0.141113, -0.009224],
    [0.150579, -0.016813],
    [0.113476, -0.115078],
    [0.171553, -0.007518],
    [0.167187, -0.071459],
    [0.191326, 0.015079],
    [0.136772, 0.148680],
    [0.135606, -0.163114],
    [0.104025, -0.196371],
    [0.215634, -0.086465],
    [-0.098303, -0.221599],
    [0.246031, -0.056899],
    [0.124166, -0.231420],
    [0.132642, -0.238299],
    [-0.123804, -0.254292],
    [-0.146155, -0.253863],
    [0.010245, -0.302857],
    [-0.055691, -0.308139],
    [0.111795, -0.303284],
    [0.238807, -0.232556],
    [-0.040702, -0.341014],
    [-0.160721, -0.314890],
    [-0.325695, -0.161722],
    [-0.347926, -0.136481],
    [-0.185075, -0.336273],
    [-0.237355, -0.314406],
    [-0.147814, -0.376031],
    [-0.118350, -0.396871],
    [-0.104012, -0.411294],
    [-0.356894, 0.247550],
    [-0.385084, -0.221903],
    [-0.421106, -0.171118],
    [-0.357669, -0.296596],
    [-0.432742, 0.195242],
    [-0.379778, -0.301408],
    [-0.490726, 0.064523],
    [-0.502019, -0.055251],
    [-0.415473, 0.304571],
    [-0.504104, 0.147544],
    [-0.534662, 0.027200],
    [-0.407162, 0.362960],
    [-0.272887, 0.483916],
    [-0.316287, 0.468967],
    [-0.210549, 0.535879],
    [-0.395798, 0.431942],
    [-0.288309, 0.521580],
    [-0.321319, 0.513871],
    [-0.182880, 0.588396],
    [-0.259437, 0.569998],
    [-0.448757, 0.451194],
    [0.172541, 0.623014],
    [0.053504, 0.654382],
    [-0.278894, 0.605527],
    [0.444152, 0.510631],
    [0.090303, 0.680907],
    [0.455620, 0.527425],
    [0.645984, 0.287495],
    [0.584617, 0.415401],
    [0.727224, 0.008403],
    [0.336009, 0.656367],
    [0.735656, 0.132396],
    [0.616298, 0.440566],
    [0.628074, 0.441419],
    [0.726366, 0.278084],
    [0.781131, 0.102894],
    [0.788548, -0.122324],
    [0.757484, 0.281448],
    [0.582663, -0.574392],
    [0.653484, -0.508931],
    [0.814699, 0.197873],
    [0.165077, -0.832272],
    [-0.093422, -0.853488],
    [0.128543, -0.859124],
    [0.583979, -0.656686],
    [0.780273, -0.425790],
    [-0.082901, -0.895159],
    [0.463849, -0.781850],
    [-0.339550, -0.854178],
    [0.025749, -0.928936],
    [-0.414920, -0.842794],
    [-0.233395, -0.920363],
    [-0.480520, -0.830617],
    [-0.265596, -0.932615],
    [-0.934678, -0.293908],
    [-0.506961, -0.850230],
    [-0.700585, -0.713569],
    [-0.000000, -0.000000],
    [-0.000024, -0.010101],
    [-0.002859, -0.019999],
    [-0.029651, -0.006253],
    [-0.015462, -0.037328],
    [-0.044768, 0.023380],
    [-0.050146, -0.034037],
    [-0.055761, -0.043476],
    [-0.057789, 0.056484],
    [-0.072163, 0.055290],
    [-0.058895, 0.082064],
    [-0.093631, 0.059824],
    [-0.118445, -0.025750],
    [-0.038581, 0.125518],
    [-0.135727, 0.039701],
    [-0.091694, 0.120620],
    [-0.074141, 0.143607],
    [-0.137313, 0.103112],
    [-0.077633, 0.164411],
    [-0.035506, 0.188606],
    [-0.070780, 0.189215],
    [-0.177902, 0.115526],
    [-0.043066, 0.218009],
    [0.094744, 0.212127],
    [-0.113965, 0.213966],
    [-0.029211, 0.250830],
    [-0.041241, 0.259368],
    [0.241242, 0.127211],
    [0.159244, 0.233738],
    [0.157248, 0.247145],
    [0.023907, 0.302086],
    [0.232032, 0.210267],
    [0.104370, 0.305918],
    [0.235107, 0.236296],
    [0.179523, 0.292777],
    [0.341324, 0.092115],
    [0.355456, 0.076695],
    [0.327276, 0.180471],
    [0.381891, 0.038618],
    [0.289766, 0.266878],
    [0.268962, 0.301510],
    [0.403800, -0.091971],
    [0.416720, -0.079536],
    [0.376112, -0.217241],
    [0.031949, -0.443295],
    [0.289429, -0.350489],
    [0.463997, -0.024562],
    [0.185741, -0.436904],
    [0.484158, -0.025861],
    [0.413590, -0.271880],
    [0.322098, -0.389010],
    [-0.152348, -0.492109],
    [0.386825, -0.355326],
    [0.371773, -0.385212],
    [0.162570, -0.520665],
    [0.258585, -0.491707],
    [-0.277088, -0.493142],
    [0.271146, -0.507914],
    [0.240553, -0.534195],
    [-0.019498, -0.595641],
    [-0.062692, -0.602809],
    [-0.609022, -0.093526],
    [-0.536692, -0.322748],
    [-0.409688, -0.486944],
    [-0.092747, -0.639777],
    [-0.628977, -0.188326],
    [-0.295126, -0.597783],
    [-0.191790, -0.649023],
    [-0.675281, 0.125633],
    [-0.689082, -0.104561],
    [-0.684672, 0.176560],
    [-0.716190, 0.037506],
    [-0.663356, 0.298135],
    [-0.717866, -0.168489],
    [-0.709262, -0.235935],
    [-0.608644, 0.451080],
    [-0.767672, -0.002838],
    [-0.768544, 0.119492],
    [-0.733915, 0.286570],
    [-0.605273, 0.520016],
    [-0.655299, 0.472840],
    [-0.799242, 0.175023],
    [-0.642872, 0.522272],
    [-0.838365, -0.005552],
    [-0.044179, 0.847334],
    [-0.750967, 0.416194],
    [-0.569447, 0.656009],
    [-0.031273, 0.878231],
    [-0.288415, 0.840797],
    [0.710776, 0.550437],
    [-0.358262, 0.835521],
    [0.421108, 0.817057],
    [0.382093, 0.847107],
    [-0.037250, 0.938655],
    [0.747897, 0.584971],
    [0.588822, 0.757702],
    [0.914996, 0.321083],
    [0.958146, 0.204843],
    [0.838563, -0.526035],
    [0.969427, -0.245380],
    [0.000000, 0.000000],
    [0.009143, 0.004294],
    [0.019102, -0.006575],
    [0.029635, -0.006326],
    [0.038554, -0.012085],
    [0.037844, 0.033446],
    [0.059696, -0.010465],
    [0.070468, 0.005812],
    [0.069816, -0.040690],
    [0.082263, -0.038694],
    [0.050711, -0.087358],
    [0.073384, -0.083429],
    [0.045619, -0.112300],
    [0.103771, -0.080466],
    [0.122612, -0.070458],
    [0.095392, -0.117717],
    [0.060471, -0.149877],
    [0.014476, -0.171106],
    [-0.148598, -0.104769],
    [0.017839, -0.191088],
    [0.087675, -0.182003],
    [0.009103, -0.211926],
    [0.000161, -0.222222],
    [-0.125921, -0.195238],
    [0.080229, -0.228764],
    [-0.135363, -0.213181],
    [-0.171858, -0.198588],
    [-0.205293, -0.179541],
    [-0.145042, -0.242806],
    [-0.214011, -0.200017],
    [-0.097370, -0.286961],
    [-0.244015, -0.196235],
    [-0.266583, -0.182791],
    [-0.333263, -0.006862],
    [-0.261409, -0.222738],
    [-0.330371, 0.125866],
    [-0.247638, 0.266283],
    [-0.243964, -0.283128],
    [-0.338751, 0.180498],
    [-0.297228, 0.258542],
    [-0.398896, 0.064268],
    [-0.384470, 0.153934],
    [-0.346745, 0.244437],
    [-0.350296, 0.256801],
    [-0.331783, 0.295721],
    [-0.449742, 0.065907],
    [0.041710, 0.462771],
    [0.073077, 0.469089],
    [-0.308975, 0.373648],
    [-0.392718, 0.301243],
    [0.089572, 0.497044],
    [-0.136970, 0.496609],
    [0.099591, 0.515725],
    [0.032706, 0.534354],
    [0.337681, 0.428360],
    [0.386817, 0.398766],
    [0.090692, 0.558339],
    [-0.038459, 0.574472],
    [-0.038419, 0.584598],
    [0.539512, 0.253170],
    [0.198358, 0.572681],
    [0.431344, 0.439998],
    [0.462868, 0.421850],
    [0.594173, 0.227854],
    [0.230009, 0.604162],
    [0.655945, 0.028546],
    [0.649852, -0.148786],
    [0.672799, -0.073185],
    [0.686836, -0.006667],
    [0.662707, -0.215840],
    [0.596542, -0.379587],
    [0.435551, 0.569764],
    [0.017199, -0.727069],
    [0.478152, -0.561330],
    [0.674677, -0.321760],
    [0.590618, -0.474437],
    [0.263736, -0.720951],
    [0.409586, -0.661194],
    [0.777827, -0.125452],
    [-0.536718, -0.590513],
    [0.198980, -0.783200],
    [-0.331806, -0.747881],
    [0.274292, -0.781547],
    [-0.644941, -0.535667],
    [-0.293880, -0.795966],
    [-0.496032, -0.700801],
    [0.095442, -0.863428],
    [-0.790659, -0.383570],
    [-0.727897, -0.510186],
    [-0.880292, -0.182396],
    [-0.437860, -0.796696],
    [-0.575886, -0.716428],
    [-0.767106, 0.524532],
    [-0.522328, -0.780791],
    [-0.867686, -0.385567],
    [-0.911101, 0.301196],
    [-0.964920, -0.096131],
    [-0.950698, 0.237016],
    [-0.979387, -0.143880],
    [-0.942789, 0.333391]
  ])

# Class labels for the spiral samples: 100 zeros, then 100 ones, then 100 twos.
y_spiral =
  0..2
  |> Enum.flat_map(fn class -> List.duplicate(class, 100) end)
  |> Nx.tensor()

Prepare for plotting

# Split the two columns of `x_spiral` into flat Elixir lists:
# column 0 becomes `x_coord` and column 1 becomes `y_coord`.
x_coord = Nx.slice(x_spiral, [0, 0], [300, 1]) |> Nx.reshape({300}) |> Nx.to_flat_list()
y_coord = Nx.slice(x_spiral, [0, 1], [300, 1]) |> Nx.reshape({300}) |> Nx.to_flat_list()
# Flat list of the class labels, used as the color series of the chart.
color = Nx.to_flat_list(y_spiral)

Visualize data

alias VegaLite, as: Vl

Vl.new(width: 400, height: 400, title: "nnfs spiral_data")
|> Vl.mark(:circle, opacity: 0.8, size: 40)
|> Vl.data_from_series(x: x_coord, y: y_coord, color: color)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:color, "color", type: :nominal, scale: [scheme: "category10"])

shapes

  • weights 2 x 3
  • biases 1 x 3
  • inputs 300 x 2
  • forward function output [300,2] x [2,3] -> 300 x 3

The output starts to differ here because the weights are drawn from a random normal function.

:rand.seed(:exsss, 0)

defmodule LayerDense do
  @moduledoc """
  A dense layer represented as a `{weights, biases}` tuple of tensors.
  """

  @doc """
  Builds the parameters of a layer with `n_inputs` inputs and `n_neurons` neurons.

  The weights are a `{n_inputs, n_neurons}` tensor of random normal values
  scaled by 0.01; the biases are a `{1, n_neurons}` tensor of zeros.
  """
  def new(n_inputs, n_neurons) do
    scaled_weights =
      {n_inputs, n_neurons}
      |> Nx.random_normal()
      |> Nx.multiply(0.01)

    zero_biases = Nx.broadcast(Nx.tensor(0), {1, n_neurons})

    {scaled_weights, zero_biases}
  end

  @doc """
  Returns the dot product of `inputs` and the layer weights, plus the biases.
  """
  def forward(inputs, {weights, biases}) do
    inputs
    |> Nx.dot(weights)
    |> Nx.add(biases)
  end
end

dense1 =
  x_spiral
  |> LayerDense.forward(LayerDense.new(2, 3))

Nx.slice(dense1, [0, 0], [5, 3])

Activation functions

compare activation functions

based on vega docs

Vl.new(width: 460, height: 200)
|> Vl.data(sequence: [start: -5, stop: 5, step: 0.1, as: "x"])
# |> Vl.transform(calculate: "datum.x", as: "x=y")
# |> Vl.transform(calculate: "datum.x>0", as: "step")
|> Vl.transform(calculate: "1/(1 + exp(-1 * datum.x))", as: "sigmoid")
|> Vl.transform(calculate: "if(datum.x>0, datum.x, 0)", as: "ReLU")
|> Vl.transform(fold: ["sigmoid", "ReLU"])
|> Vl.mark(:line)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "value", type: :quantitative)
|> Vl.encode_field(:color, "key", type: :nominal, title: nil)
defmodule ActivationReLU do
  @doc """
  Applies ReLU element-wise: the result is the maximum of 0 and each element
  of `inputs`.
  """
  def forward(inputs), do: Nx.max(0, inputs)
end

activated1 = ActivationReLU.forward(dense1)
Nx.slice(activated1, [0, 0], [5, 3])
defmodule ActivationSoftmax do
  @doc """
  Applies softmax along axis 1 of `inputs`.

  The maximum along axis 1 is subtracted before exponentiating, then the
  exponentials are divided by their sum along the same axis.
  """
  def forward(inputs) do
    row_max = Nx.reduce_max(inputs, axes: [1], keep_axes: true)

    exponentials =
      inputs
      |> Nx.subtract(row_max)
      |> Nx.exp()

    Nx.divide(exponentials, Nx.sum(exponentials, axes: [1], keep_axes: true))
  end
end

activated2 =
  activated1
  |> LayerDense.forward(LayerDense.new(3, 3))
  |> ActivationSoftmax.forward()

Nx.slice(activated2, [0, 0], [5, 3])

# ActivationSoftmax.forward(Nx.tensor([[4.8, 1.21, 2.385], [8.9, -1.81, 0.2], [1.41, 1.051, 0.026]]))
defmodule LossCategoricalCrossentropy do
  # Predictions are clipped to [@eps, 1 - @eps] before the logarithm is taken.
  @eps 1.0e-7

  @doc """
  Returns the negative log of the clipped prediction mass assigned to the
  targets, one value per row of `y_pred`.

  A rank-1 `y_true` is treated as a vector of class labels and expanded with
  `labels_to_matrix/2`; any other `y_true` is used as given.
  """
  def forward(y_pred, y_true) do
    targets =
      case tuple_size(Nx.shape(y_true)) do
        1 -> labels_to_matrix(y_pred, y_true)
        _ -> y_true
      end

    y_pred
    |> Nx.clip(@eps, 1 - @eps)
    |> Nx.multiply(targets)
    |> Nx.sum(axes: [1])
    |> Nx.log()
    |> then(&Nx.multiply(-1, &1))
  end

  @doc """
  Expands a vector of class labels into a 0/1 matrix with the shape of `y_pred`.

  Each label is compared against the class indices `0..(n_labels - 1)`, so a
  row contains 1 in the column that equals its label and 0 elsewhere.
  """
  def labels_to_matrix(y_pred, labels) do
    {samples, n_labels} = Nx.shape(y_pred)
    label_column = Nx.reshape(labels, {samples, 1}, names: [:batch, :output])
    class_indices = Nx.tensor(Enum.to_list(0..(n_labels - 1)))

    Nx.equal(label_column, class_indices)
  end

  @doc """
  Returns the mean of the per-sample losses computed by `forward/2`.
  """
  def calculate(output, y) do
    Nx.mean(forward(output, y))
  end
end

import Nx, only: [sigil_M: 2]
softmax_outputs = ~M(
  0.7 0.1 0.2
  0.1 0.5 0.4
  0.02 0.9 0.08)

class_targets = ~M(
  1 0 0
  0 1 0
  0 1 0)
LossCategoricalCrossentropy.calculate(softmax_outputs, class_targets)

LossCategoricalCrossentropy.forward(softmax_outputs, Nx.tensor([0, 1, 1]))
|> IO.inspect()

loss = LossCategoricalCrossentropy.calculate(activated2, y_spiral)
IO.puts("loss: #{Nx.to_number(loss)}")
predictions = Nx.argmax(activated2, axis: 1)

y =
  if 2 == tuple_size(Nx.shape(y_spiral)) do
    Nx.argmax(y_spiral, axis: 1)
  else
    y_spiral
  end

accuracy = Nx.mean(Nx.equal(y, predictions))
IO.puts("accuracy #{Nx.to_number(accuracy)}")

Optimisation

x = [
  [0.176405, 0.688315],
  [0.040016, 0.365224],
  [0.097874, 0.372952],
  [0.224089, 0.596940],
  [0.186756, 0.382688],
  [-0.097728, 0.694362],
  [0.095009, 0.458638],
  [-0.015136, 0.425255],
  [-0.010322, 0.692294],
  [0.041060, 0.648052],
  [0.014404, 0.686756],
  [0.145427, 0.590604],
  [0.076104, 0.413877],
  [0.012168, 0.691006],
  [0.044386, 0.473200],
  [0.033367, 0.580246],
  [0.149408, 0.594725],
  [-0.020516, 0.484499],
  [0.031307, 0.561408],
  [-0.085410, 0.592221],
  [-0.255299, 0.537643],
  [0.065362, 0.390060],
  [0.086444, 0.529824],
  [-0.074217, 0.632639],
  [0.226975, 0.430543],
  [-0.145437, 0.485037],
  [0.004576, 0.456485],
  [-0.018718, 0.684926],
  [0.153278, 0.567229],
  [0.146936, 0.540746],
  [0.015495, 0.423008],
  [0.037816, 0.553925],
  [-0.088779, 0.432567],
  [-0.198080, 0.503183],
  [-0.034791, 0.436415],
  [0.015635, 0.567643],
  [0.123029, 0.557659],
  [0.120238, 0.479170],
  [-0.038733, 0.539601],
  [-0.030230, 0.390694],
  [-0.104855, 0.350874],
  [-0.142002, 0.543939],
  [-0.170627, 0.516667],
  [0.195078, 0.563503],
  [-0.050965, 0.738315],
  [-0.043807, 0.594448],
  [-0.125280, 0.408718],
  [0.077749, 0.611702],
  [-0.161390, 0.368409],
  [-0.021274, 0.453842],
  [-0.089547, 0.493176],
  [0.038690, 0.671334],
  [-0.051081, 0.425525],
  [-0.118063, 0.417356],
  [-0.002818, 0.490155],
  [0.042833, 0.433652],
  [0.006652, 0.612664],
  [0.030247, 0.392007],
  [-0.063432, 0.385253],
  [-0.036274, 0.456218],
  [-0.067246, 0.450197],
  [-0.035955, 0.692953],
  [-0.081315, 0.594942],
  [-0.172628, 0.508755],
  [0.017743, 0.377456],
  [-0.040178, 0.584436],
  [-0.163020, 0.399978],
  [0.046278, 0.345523],
  [-0.090730, 0.618803],
  [0.005195, 0.531694],
  [0.072909, 0.592086],
  [0.012898, 0.531873],
  [0.113940, 0.585683],
  [-0.123483, 0.434897],
  [0.040234, 0.396576],
  [-0.068481, 0.568159],
  [-0.087080, 0.419659],
  [-0.057885, 0.431045],
  [-0.031155, 0.454447],
  [0.005617, 0.501748],
  [-0.116515, 0.464601],
  [0.090083, 0.362505],
  [0.046566, 0.435638],
  [-0.153624, 0.277660],
  [0.148825, 0.562523],
  [0.189589, 0.339794],
  [0.117878, 0.389562],
  [-0.017992, 0.505216],
  [-0.107075, 0.426044],
  [0.105445, 0.654301],
  [-0.040318, 0.370714],
  [0.122244, 0.526705],
  [0.020827, 0.496072],
  [0.097664, 0.383191],
  [0.035637, 0.552328],
  [0.070657, 0.482845],
  [0.001050, 0.577179],
  [0.178587, 0.582350],
  [0.012691, 0.716324],
  [0.040199, 0.633653],
  [0.296415, 0.369347],
  [0.309395, 0.665813],
  [0.443299, 0.488184],
  [0.398860, 0.431982],
  [0.397346, 0.566638],
  [0.171638, 0.453928],
  [0.330901, 0.366574],
  [0.259530, 0.365328],
  [0.361326, 0.569377],
  [0.323518, 0.484043],
  [0.424351, 0.486630],
  [0.365055, 0.607774],
  [0.411966, 0.387317],
  [0.286691, 0.426932],
  [0.238889, 0.461512],
  [0.292328, 0.509435],
  [0.331631, 0.495783],
  [0.371249, 0.471311],
  [0.559264, 0.493837],
  [0.329108, 0.489269],
  [0.237739, 0.428040],
  [0.298735, 0.418701],
  [0.286974, 0.527452],
  [0.381481, 0.410908],
  [0.179254, 0.384264],
  [0.339660, 0.468771],
  [0.348984, 0.484233],
  [0.356551, 0.725672],
  [0.273602, 0.429530],
  [0.309541, 0.594326],
  [0.190927, 0.574719],
  [0.284001, 0.381106],
  [0.279047, 0.577325],
  [0.374938, 0.381612],
  [0.217715, 0.234083],
  [0.411453, 0.560632],
  [0.482782, 0.324411],
  [0.126335, 0.545093],
  [0.375959, 0.431599],
  [0.401024, 0.665955],
  [0.269590, 0.606851],
  [0.293606, 0.454661],
  [0.320045, 0.431216],
  [0.303554, 0.378592],
  [0.302432, 0.455908],
  [0.165733, 0.471964],
  [0.448566, 0.463531],
  [0.441295, 0.515670],
  [0.251997, 0.557852],
  [0.186691, 0.534965],
  [0.385440, 0.423586],
  [0.275755, 0.356221],
  [0.347529, 0.636453],
  [0.301401, 0.431055],
  [0.402487, 0.434771],
  [0.402808, 0.447881],
  [0.260774, 0.315693],
  [0.194997, 0.452203],
  [0.175039, 0.452034],
  [0.394371, 0.562036],
  [0.214447, 0.569846],
  [0.282652, 0.500377],
  [0.273702, 0.593185],
  [0.328077, 0.533997],
  [0.139705, 0.498432],
  [0.352211, 0.516093],
  [0.385722, 0.480935],
  [0.342176, 0.460515],
  [0.302245, 0.473227],
  [0.343073, 0.387199],
  [0.373238, 0.528044],
  [0.056074, 0.400688],
  [0.528925, 0.584163],
  [0.372343, 0.475054],
  [0.268092, 0.504950],
  [0.294238, 0.549384],
  [0.382708, 0.564331],
  [0.321723, 0.342938],
  [0.130265, 0.479310],
  [0.539783, 0.588018],
  [0.322279, 0.330189],
  [0.435351, 0.538728],
  [0.264128, 0.274444],
  [0.486971, 0.397749],
  [0.361968, 0.503863],
  [0.394218, 0.334328],
  [0.228808, 0.401449],
  [0.454448, 0.352816],
  [0.402315, 0.664814],
  [0.463518, 0.516423],
  [0.270525, 0.556729],
  [0.285231, 0.477732],
  [0.563725, 0.464657],
  [0.227332, 0.338353],
  [0.319738, 0.470816],
  [0.447022, 0.423851],
  [0.343106, 0.585792],
  [0.391629, 0.614110],
  [0.293388, 0.646658],
  [0.370339, 0.585255],
  [0.606801, 0.538273],
  [0.555077, 0.496576],
  [0.743333, 0.609635],
  [0.702296, 0.476578],
  [0.489813, 0.465255],
  [0.702215, 0.441873],
  [0.748119, 0.336737],
  [0.672559, 0.343223],
  [0.648161, 0.382084],
  [0.585902, 0.630143],
  [0.522013, 0.589526],
  [0.746696, 0.637496],
  [0.635755, 0.366779],
  [0.643320, 0.303138],
  [0.839939, 0.433994],
  [0.735117, 0.517582],
  [0.703749, 0.549869],
  [0.680873, 0.604797],
  [0.818666, 0.528428],
  [0.838626, 0.674267],
  [0.759617, 0.477739],
  [0.724889, 0.408692],
  [0.457206, 0.331878],
  [0.679039, 0.411103],
  [0.653656, 0.524212],
  [0.676062, 0.411128],
  [0.760971, 0.593674],
  [0.392699, 0.641233],
  [0.609735, 0.263041],
  [0.693657, 0.586405],
  [0.619982, 0.276040],
  [0.524976, 0.540150],
  [0.753563, 0.622487],
  [0.694354, 0.506486],
  [0.569556, 0.372031],
  [0.698148, 0.441457],
  [0.748825, 0.473835],
  [0.667196, 0.481776],
  [0.746723, 0.479710],
  [0.674493, 0.489012],
  [0.627144, 0.521348],
  [0.550725, 0.379143],
  [0.658074, 0.475798],
  [0.686096, 0.651826],
  [0.754250, 0.461535],
  [0.655156, 0.455616],
  [0.712408, 0.607820],
  [0.570206, 0.244082],
  [0.588404, 0.618138],
  [0.655628, 0.436810],
  [0.561204, 0.516393],
  [0.748691, 0.509632],
  [0.712980, 0.594247],
  [0.694576, 0.473241],
  [0.700557, 0.432197],
  [0.868771, 0.629785],
  [0.619780, 0.263583],
  [0.446523, 0.502033],
  [0.686597, 0.365207],
  [0.661606, 0.423843],
  [0.614915, 0.701126],
  [0.568784, 0.495540],
  [0.622748, 0.519507],
  [0.684801, 0.321844],
  [0.616385, 0.427096],
  [0.907912, 0.519656],
  [0.570616, 0.535476],
  [0.587355, 0.561689],
  [0.437805, 0.500863],
  [0.691815, 0.552700],
  [0.465026, 0.545378],
  [0.612721, 0.317026],
  [0.639100, 0.503701],
  [0.595694, 0.576790],
  [0.840554, 0.558988],
  [0.766106, 0.463614],
  [0.798580, 0.419437],
  [0.578425, 0.388169],
  [0.779526, 0.486895],
  [0.716267, 0.613308],
  [0.743807, 0.304820],
  [0.769611, 0.434011],
  [0.575790, 0.386020],
  [0.624235, 0.578496],
  [0.752926, 0.444569],
  [0.401105, 0.452936],
  [0.817999, 0.478305],
  [0.721980, 0.544539],
  [0.662096, 0.460761],
  [0.688717, 0.195386],
  [0.563673, 0.554331],
  [0.631672, 0.543904],
  [0.776695, 0.478046],
  [0.796469, 0.391596],
  [0.936289, 0.535178],
  [0.659274, 0.537924],
  [0.600811, 0.452997],
  [0.615243, 0.478327],
  [0.564862, 0.406984],
  [0.658881, 0.482141]
]

y =
  for x <- 0..2, _ <- 0..99 do
    x
  end

Vl.new(width: 400, height: 400, title: "nnfs vertical_data")
|> Vl.mark(:circle, opacity: 0.8, size: 40)
|> Vl.data_from_series(
  x: Enum.map(x, fn [x, _] -> x end),
  y: Enum.map(x, fn [_, y] -> y end),
  color: y
)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
|> Vl.encode_field(:color, "color", type: :nominal, scale: [scheme: "category10"])
# Sample points 0.000, 0.001, ..., 5.000 for the x axis.
x = Enum.map(0..5000, fn x -> x / 1000 end)

# The function being differentiated: f(x) = 2 * x^2.
f = fn x -> 2 * :math.pow(x, 2) end
y_chart = Enum.map(x, &f.(&1))
# Approximate the derivative at x1 from two points that are p2_delta apart.
p2_delta = 0.0001
x1 = 2
x2 = x1 + p2_delta
y1 = f.(x1)
y2 = f.(x2)
approximate_derivative = (y2 - y1) / (x2 - x1)
# Intercept of the line with that slope passing through (x2, y2).
b = y2 - approximate_derivative * x2
tangent_line = fn x -> approximate_derivative * x + b end
# Evaluate the tangent line at every sample point so it can be drawn with f.
tangent_at_2 = Enum.map(x, &tangent_line.(&1))

Vl.new(width: 200, height: 400, title: "examples")
|> Vl.data_from_series(x: x, y1: y_chart, tangent2: tangent_at_2)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.layers([
  Vl.new()
  |> Vl.mark(:line, color: :blue, size: 3)
  |> Vl.encode_field(:y, "y1", type: :quantitative),
  Vl.new()
  |> Vl.mark(:line, color: :red, size: 1)
  |> Vl.encode_field(:y, "tangent2", type: :quantitative)
])

Backpropagation

import Nx, only: [sigil_V: 2, sigil_M: 2]
x = ~V(1 -2 3)f32
w = ~V(-3 -1 2)f32
b = ~V(1)f32

[Nx.multiply(x, w), b]
|> Nx.concatenate()
|> Nx.sum()
|> Nx.max(~V(0)f32)
|> Nx.greater(~V(0)f32)
z = ~M{
1 2 -3 -4
2 -7 -1 3
-1 2 5 -1}
dvalues = ~M{
  1 2 3 4
  5 6 7 8
  9 10 11 12}
Nx.multiply(dvalues, Nx.less_equal(0, z))
dvalues = ~M(
  1.0 1.0 1.0
  2.0 2.0 2.0
  3.0 3.0 3.0)
# Forward and backward pass of one dense layer followed by ReLU.
# Three samples with four features each.
inputs = ~M(
  1 2 3 2.5
  2 5 -1 2
  -1.5 2.7 3.3 -0.8)

# One row per neuron, transposed so that `Nx.dot(inputs, weights)` lines up.
# Same weights as the first-layer example (third row: -0.26 -0.27 0.17 0.87).
weights =
  ~M(
  0.2  0.8    -0.5  1
  0.5  -0.91  0.26 -0.5
  -0.26 -0.27 0.17 0.87)
  |> Nx.transpose()

biases = ~M(2 3 0.5)
# Forward pass: dense layer, then ReLU.
layer_outputs = Nx.dot(inputs, weights) |> Nx.add(biases)
relu_outputs = Nx.max(0, layer_outputs)
# Backward pass: the ReLU outputs stand in for the incoming gradient and are
# masked by the sign of the layer outputs.
drelu = Nx.multiply(relu_outputs, Nx.less_equal(0, layer_outputs))
dinputs = Nx.dot(drelu, weights |> Nx.transpose())
dweights = Nx.dot(Nx.transpose(inputs), drelu)
dbiases = Nx.sum(drelu, axes: [0], keep_axes: true)
# Move the parameters against the gradients, scaled by 0.001.
weights = Nx.multiply(-0.001, dweights) |> Nx.add(weights)
biases = Nx.multiply(-0.001, dbiases) |> Nx.add(biases)
# Redefines LayerDense, adding a backward pass.
defmodule LayerDense do
  @moduledoc """
  A dense layer represented as a `{weights, biases}` tuple of tensors.
  """

  @doc """
  Builds a `{weights, biases}` tuple: `{n_inputs, n_neurons}` random normal
  weights scaled by 0.01 and `{1, n_neurons}` zero biases.
  """
  def new(n_inputs, n_neurons) do
    random_weights = Nx.random_normal({n_inputs, n_neurons})

    {
      Nx.multiply(random_weights, 0.01),
      Nx.broadcast(Nx.tensor(0), {1, n_neurons})
    }
  end

  @doc """
  Returns the dot product of `inputs` and the layer weights, plus the biases.
  """
  def forward(inputs, {weights, biases}) do
    weighted = Nx.dot(inputs, weights)
    Nx.add(weighted, biases)
  end

  @doc """
  Computes the gradients of the layer from the incoming gradient `dvalues`.

  Returns `{{dweights, dbiases}, dinputs}` where `dweights` is the transposed
  `old_inputs` dotted with `dvalues`, `dbiases` is `dvalues` summed over
  axis 0 (axis kept), and `dinputs` is `dvalues` dotted with the transposed
  weights.
  """
  def backward(dvalues, {old_weights, _}, old_inputs) do
    weight_gradient =
      old_inputs
      |> Nx.transpose()
      |> Nx.dot(dvalues)

    bias_gradient = Nx.sum(dvalues, axes: [0], keep_axes: true)
    input_gradient = Nx.dot(dvalues, Nx.transpose(old_weights))

    {{weight_gradient, bias_gradient}, input_gradient}
  end
end
defmodule ActivationReLU do
  @doc """
  Applies ReLU element-wise: the result is the maximum of 0 and each element
  of `inputs`.
  """
  def forward(inputs) do
    Nx.max(0, inputs)
  end

  @doc """
  Back-propagates `dvalues` through the ReLU.

  The gradient is passed through only where the corresponding element of
  `old_inputs` (the tensor that was given to `forward/1`) is strictly
  positive. It is zeroed everywhere else, including where the input is
  exactly zero, because the forward output there is 0.
  """
  def backward(dvalues, old_inputs) do
    # Strict comparison: an input of exactly 0 must not let gradient through.
    Nx.multiply(dvalues, Nx.greater(old_inputs, 0))
  end
end
softmax_outputs = ~M[0.7 0.1 0.2]

Nx.subtract(
  Nx.size(softmax_outputs[0])
  |> Nx.eye()
  |> Nx.multiply(softmax_outputs),
  softmax_outputs
  |> Nx.transpose()
  |> Nx.dot(softmax_outputs)
)
defmodule ActivationSoftmaxLossCategoricalCrossentropy do
  @doc """
  Applies softmax to `inputs` and returns the categorical cross-entropy loss
  of the result against `y_true`.
  """
  def forward(inputs, y_true) do
    LossCategoricalCrossentropy.forward(ActivationSoftmax.forward(inputs), y_true)
  end

  @doc """
  Combined softmax + cross-entropy gradient.

  The logic is a bit different from the book. There, 1 is subtracted at the
  index of the true class in every sample, and a one-hot matrix is first
  converted to a label vector. Here the 0/1 target matrix is subtracted from
  the whole `dvalues` matrix instead (a label vector is expanded to that
  matrix first). The effect is the same without mutating any data, and the
  result matches the examples on page 65.

  The difference is divided by the number of samples.
  """
  def backward(dvalues, y_true) do
    {n_samples, _n_labels} = Nx.shape(dvalues)

    targets =
      case tuple_size(Nx.shape(y_true)) do
        1 -> LossCategoricalCrossentropy.labels_to_matrix(dvalues, y_true)
        _ -> y_true
      end

    dvalues
    |> Nx.subtract(targets)
    |> Nx.divide(n_samples)
  end
end
softmax_outputs = ~M(
  0.7 0.1 0.2
  0.1 0.5 0.4
  0.02 0.9 0.08
)
class_targets_vector = ~V(0 1 1)
class_targets_matrix = ~M(
  1 0 0
  0 1 0
  0 1 0)

softmax_outputs
|> ActivationSoftmaxLossCategoricalCrossentropy.backward(class_targets_vector)
|> IO.inspect(label: :with_vector)

softmax_outputs
|> ActivationSoftmaxLossCategoricalCrossentropy.backward(class_targets_matrix)
|> IO.inspect(label: :with_matrix)

{time, _result} =
  :timer.tc(
    &ActivationSoftmaxLossCategoricalCrossentropy.backward/2,
    [softmax_outputs, class_targets_matrix]
  )

IO.puts("execution took #{time} us")
dense1 = LayerDense.new(2, 3)
dense2 = LayerDense.new(3, 3)

# 300x3
dense1_output = LayerDense.forward(x_spiral, dense1)
# 300x3
relu_output = ActivationReLU.forward(dense1_output)
# 300x3
dense2_output = LayerDense.forward(relu_output, dense2)
# 300x3
softmax_output = ActivationSoftmax.forward(dense2_output)
# 300
loss =
  LossCategoricalCrossentropy.calculate(softmax_output, y_spiral)
  |> IO.inspect()

d_softmax = ActivationSoftmaxLossCategoricalCrossentropy.backward(softmax_output, y_spiral)
{d_dense2, dinputs2} = LayerDense.backward(d_softmax, dense2, relu_output)
d_relu = ActivationReLU.backward(dinputs2, dense1_output)
{d_dense1, dinputs1} = LayerDense.backward(d_relu, dense1, x_spiral)
defmodule OptimizerSGD do
  @moduledoc """
  Gradient descent with momentum and learning-rate decay.
  """

  @doc """
  Applies one optimizer step to a layer.

  Takes the layer `{weights, biases}`, its gradients `{dweights, dbiases}` and
  the previous momentums, and returns
  `{{weights, biases}, {weight_momentums, bias_momentums}}`.

  `settings` may contain `:momentum` (default 0.1) and
  `:current_learning_rate` (default 1.0).
  """
  def update_params(
        {old_weights, old_biases},
        {dweights, dbiases},
        {old_weight_momentums, old_bias_momentums},
        settings \\ %{}
      ) do
    weight_momentums = next_momentums(old_weight_momentums, dweights, settings)
    bias_momentums = next_momentums(old_bias_momentums, dbiases, settings)

    updated_layer = {
      Nx.add(old_weights, weight_momentums),
      Nx.add(old_biases, bias_momentums)
    }

    {updated_layer, {weight_momentums, bias_momentums}}
  end

  # New momentum = previous momentum * momentum factor - gradient * learning rate.
  defp next_momentums(previous, gradient, settings) do
    retained = Nx.multiply(previous, Map.get(settings, :momentum, 0.1))
    step = Nx.multiply(gradient, Map.get(settings, :current_learning_rate, 1.0))

    Nx.subtract(retained, step)
  end

  @doc """
  Stores the decayed learning rate for `epoch` under `:current_learning_rate`.

  The rate is `learning_rate * (1 / (1 + decay * epoch))`, with
  `:learning_rate` defaulting to 1.0 and `:decay` defaulting to 0.
  """
  def pre_update_params(epoch, settings) do
    base_rate = Map.get(settings, :learning_rate, 1.0)
    decay = Map.get(settings, :decay, 0)

    Map.put(settings, :current_learning_rate, base_rate * (1.0 / (1 + decay * epoch)))
  end

  @doc """
  Returns zero-filled momentums with the same shapes as the layer's weights
  and biases.
  """
  def new_momentum({weights, biases}) do
    {
      Nx.broadcast(Nx.tensor(0), weights),
      Nx.broadcast(Nx.tensor(0), biases)
    }
  end
end

IO.inspect(dense1)
OptimizerSGD.update_params(dense1, dense1, OptimizerSGD.new_momentum(dense1))

With SGD + momentum

widget =
  Vl.new(width: 400, height: 400, title: "learning")
  |> Vl.encode_field(:x, "epoch", type: :quantitative)
  |> Vl.layers([
    Vl.new()
    |> Vl.mark(:line, color: :blue, size: 1)
    |> Vl.encode_field(:y, "loss", type: :quantitative),
    Vl.new()
    |> Vl.mark(:line, color: :red, size: 1)
    |> Vl.encode_field(:y, "accuracy", type: :quantitative)
  ])
  |> Kino.VegaLite.new()
  |> Kino.render()

dense1 = LayerDense.new(2, 64)
dense2 = LayerDense.new(64, 3)
# momentum is called cache for easier copy and paste
cache1 = OptimizerSGD.new_momentum(dense1)
cache2 = OptimizerSGD.new_momentum(dense2)

settings = %{
  epochs: 1,
  update_every: 10,
  learning_rate: 1.0,
  decay: 1.0e-3,
  momentum: 0.9
}

initial_acc = %{
  dense1: dense1,
  dense2: dense2,
  cache1: cache1,
  cache2: cache2,
  settings: settings
}

# Training loop: the accumulator carries both layers, their momentums and the
# settings from one epoch to the next. The range is inclusive, so the body
# runs for epochs 0 through `settings.epochs`.
Enum.reduce(0..settings.epochs, initial_acc, fn epoch, acc ->
  # forward pass
  dense1_output = LayerDense.forward(x_spiral, acc.dense1)
  relu_output = ActivationReLU.forward(dense1_output)
  dense2_output = LayerDense.forward(relu_output, acc.dense2)
  softmax_output = ActivationSoftmax.forward(dense2_output)
  # backward pass
  d_softmax = ActivationSoftmaxLossCategoricalCrossentropy.backward(softmax_output, y_spiral)
  {d_dense2, dinputs2} = LayerDense.backward(d_softmax, acc.dense2, relu_output)
  d_relu = ActivationReLU.backward(dinputs2, dense1_output)
  {d_dense1, _dinputs1} = LayerDense.backward(d_relu, acc.dense1, x_spiral)
  # settings update
  new_settings = OptimizerSGD.pre_update_params(epoch, acc.settings)

  {new_dense1, new_cache1} =
    OptimizerSGD.update_params(acc.dense1, d_dense1, acc.cache1, new_settings)

  {new_dense2, new_cache2} =
    OptimizerSGD.update_params(acc.dense2, d_dense2, acc.cache2, new_settings)

  # calculate how good we are
  # (loss and accuracy come from this epoch's forward pass, i.e. before the update)
  loss = LossCategoricalCrossentropy.calculate(softmax_output, y_spiral) |> Nx.to_number()
  predictions = Nx.argmax(softmax_output, axis: 1)
  accuracy = Nx.mean(Nx.equal(predictions, y_spiral)) |> Nx.to_number()
  # print chart
  # (only on epochs that are a multiple of `update_every`)
  if rem(epoch, acc.settings.update_every) == 0 do
    point = %{epoch: epoch, loss: loss, accuracy: accuracy}
    Kino.VegaLite.push(widget, point, window: 1000)
  end

  # Hand the updated layers, momentums and settings to the next epoch.
  %{
    acc
    | dense1: new_dense1,
      dense2: new_dense2,
      cache1: new_cache1,
      cache2: new_cache2,
      settings: new_settings
  }
end)

With Adagrad optimizer

defmodule OptimizerAdagrad do
  @moduledoc """
  Adagrad optimizer with learning-rate decay.
  """

  @doc """
  Applies one optimizer step to a layer.

  Takes the layer `{weights, biases}`, its gradients `{dweights, dbiases}` and
  the previous caches, and returns
  `{{weights, biases}, {weight_cache, bias_cache}}`.

  `settings` may contain `:current_learning_rate` (default 1.0) and
  `:epsilon` (default 1.0e-7).
  """
  def update_params(
        {old_weights, old_biases},
        {dweights, dbiases},
        {old_weight_cache, old_bias_cache},
        settings \\ %{}
      ) do
    weight_cache = accumulate(old_weight_cache, dweights)
    weights = step(old_weights, dweights, weight_cache, settings)

    bias_cache = accumulate(old_bias_cache, dbiases)
    biases = step(old_biases, dbiases, bias_cache, settings)

    {{weights, biases}, {weight_cache, bias_cache}}
  end

  # New cache = previous cache + squared gradient.
  defp accumulate(old_cache, dparam) do
    Nx.add(old_cache, Nx.power(dparam, 2))
  end

  # New parameter = old parameter - learning_rate * gradient / (sqrt(cache) + epsilon).
  defp step(old_param, dparam, cache, settings) do
    learning_rate = Map.get(settings, :current_learning_rate, 1.0)
    # Read the key as `:epsilon`, the spelling used by the settings map.
    epsilon = Map.get(settings, :epsilon, 1.0e-7)

    Nx.add(
      old_param,
      Nx.divide(
        Nx.multiply(dparam, -1 * learning_rate),
        Nx.add(Nx.sqrt(cache), epsilon)
      )
    )
  end

  @doc """
  Stores the decayed learning rate for `epoch` under `:current_learning_rate`.

  The rate is `learning_rate * (1 / (1 + decay * epoch))`, with
  `:learning_rate` defaulting to 1.0 and `:decay` defaulting to 0.
  """
  def pre_update_params(epoch, settings) do
    current_learning_rate =
      Map.get(settings, :learning_rate, 1.0) * (1.0 / (1 + Map.get(settings, :decay, 0) * epoch))

    Map.put(settings, :current_learning_rate, current_learning_rate)
  end

  @doc """
  Returns zero-filled caches with the same shapes as the layer's weights and
  biases.
  """
  def new_cache({weights, biases}) do
    weights = Nx.broadcast(Nx.tensor(0), weights)
    biases = Nx.broadcast(Nx.tensor(0), biases)
    {weights, biases}
  end
end

# Smoke check: print dense1, then run one Adagrad step on it using its own
# parameters as stand-in gradients and a zero cache. The result is not kept.
dense1
|> IO.inspect()
|> OptimizerAdagrad.update_params(dense1, OptimizerAdagrad.new_cache(dense1))

# Live chart with one line per metric: loss in blue, accuracy in red,
# both plotted against the epoch number.
widget =
  Vl.new(width: 400, height: 400, title: "learning")
  |> Vl.encode_field(:x, "epoch", type: :quantitative)
  |> Vl.layers(
    for {field, color} <- [{"loss", :blue}, {"accuracy", :red}] do
      Vl.new()
      |> Vl.mark(:line, color: color, size: 1)
      |> Vl.encode_field(:y, field, type: :quantitative)
    end
  )
  |> Kino.VegaLite.new()
  |> Kino.render()

# Fresh 2 -> 64 -> 3 network, plus one zeroed Adagrad cache per layer.
dense1 = LayerDense.new(2, 64)
dense2 = LayerDense.new(64, 3)
[cache1, cache2] = Enum.map([dense1, dense2], &OptimizerAdagrad.new_cache/1)

settings = %{epochs: 1, update_every: 10, learning_rate: 1.0, decay: 1.0e-4, epsilon: 1.0e-7}

# Everything the training loop threads from one epoch to the next.
initial_acc = %{
  settings: settings,
  dense1: dense1,
  cache1: cache1,
  dense2: dense2,
  cache2: cache2
}

# Full-batch training on the spiral data with Adagrad. Each pass returns the
# updated layers, caches and settings as the accumulator for the next epoch.
Enum.reduce(0..settings.epochs, initial_acc, fn epoch, acc ->
  %{dense1: layer1, dense2: layer2, cache1: layer1_cache, cache2: layer2_cache} = acc

  # Forward: dense -> ReLU -> dense -> softmax.
  hidden = LayerDense.forward(x_spiral, layer1)
  activated = ActivationReLU.forward(hidden)
  probabilities = activated |> LayerDense.forward(layer2) |> ActivationSoftmax.forward()

  # Backward: the same layers in reverse order.
  grad_logits = ActivationSoftmaxLossCategoricalCrossentropy.backward(probabilities, y_spiral)
  {layer2_grads, grad_activated} = LayerDense.backward(grad_logits, layer2, activated)

  {layer1_grads, _grad_inputs} =
    grad_activated
    |> ActivationReLU.backward(hidden)
    |> LayerDense.backward(layer1, x_spiral)

  # Refresh the decayed learning rate, then step both layers with it.
  epoch_settings = OptimizerAdagrad.pre_update_params(epoch, acc.settings)

  {next_layer1, next_cache1} =
    OptimizerAdagrad.update_params(layer1, layer1_grads, layer1_cache, epoch_settings)

  {next_layer2, next_cache2} =
    OptimizerAdagrad.update_params(layer2, layer2_grads, layer2_cache, epoch_settings)

  # Metrics come from this epoch's forward pass, i.e. the pre-update parameters.
  loss = probabilities |> LossCategoricalCrossentropy.calculate(y_spiral) |> Nx.to_number()

  accuracy =
    probabilities
    |> Nx.argmax(axis: 1)
    |> Nx.equal(y_spiral)
    |> Nx.mean()
    |> Nx.to_number()

  # Push a chart point only every `update_every` epochs.
  if rem(epoch, acc.settings.update_every) == 0 do
    Kino.VegaLite.push(widget, %{epoch: epoch, loss: loss, accuracy: accuracy}, window: 1000)
  end

  %{
    acc
    | settings: epoch_settings,
      dense1: next_layer1,
      cache1: next_cache1,
      dense2: next_layer2,
      cache2: next_cache2
  }
end)

With RMSProp optimizer

defmodule OptimizerRMSProp do
  @moduledoc """
  RMSProp optimizer: like Adagrad, but the cache is an exponential moving
  average of squared gradients, weighted by `:rho`.

  Settings keys read (with defaults): `:current_learning_rate` (1.0),
  `:epsilon` (1.0e-7), `:rho` (0.9), `:learning_rate` (1.0) and `:decay` (0).
  """

  @doc """
  Performs one RMSProp step for a layer.

  Takes the layer's `{weights, biases}`, the matching gradients
  `{dweights, dbiases}` and the previous `{weight_cache, bias_cache}`.
  Returns `{{new_weights, new_biases}, {new_weight_cache, new_bias_cache}}`.
  """
  def update_params(
        {old_weights, old_biases},
        {dweights, dbiases},
        {old_weight_cache, old_bias_cache},
        settings \\ %{}
      ) do
    weight_updates = get_updates(dweights, old_weight_cache, settings)
    weights = get_param(old_weights, dweights, weight_updates, settings)

    bias_updates = get_updates(dbiases, old_bias_cache, settings)
    biases = get_param(old_biases, dbiases, bias_updates, settings)

    {{weights, biases}, {weight_updates, bias_updates}}
  end

  # New cache = rho * old cache + (1 - rho) * gradient squared.
  defp get_updates(dparam, old_param_cache, settings) do
    rho = Map.get(settings, :rho, 0.9)

    Nx.add(
      Nx.multiply(old_param_cache, rho),
      Nx.multiply(Nx.power(dparam, 2), 1 - rho)
    )
  end

  # param - learning_rate * gradient / (sqrt(cache) + epsilon)
  defp get_param(old_param, dparam, updates, settings) do
    learning_rate = Map.get(settings, :current_learning_rate, 1.0)
    # The key used to be misspelled `:espilon`, so a configured `:epsilon`
    # was never picked up and the default was always used.
    epsilon = Map.get(settings, :epsilon, 1.0e-7)

    Nx.add(
      old_param,
      Nx.divide(
        Nx.multiply(dparam, -1 * learning_rate),
        Nx.add(Nx.sqrt(updates), epsilon)
      )
    )
  end

  @doc """
  Stores the decayed learning rate for `epoch` under `:current_learning_rate`:
  `learning_rate * 1 / (1 + decay * epoch)`.
  """
  def pre_update_params(epoch, settings) do
    current_learning_rate =
      Map.get(settings, :learning_rate, 1.0) * (1.0 / (1 + Map.get(settings, :decay, 0) * epoch))

    Map.put(settings, :current_learning_rate, current_learning_rate)
  end

  @doc """
  Builds a zero-filled `{weight_cache, bias_cache}` pair shaped like the given
  `{weights, biases}`.
  """
  def new_cache({weights, biases}) do
    weights = Nx.broadcast(Nx.tensor(0), weights)
    biases = Nx.broadcast(Nx.tensor(0), biases)
    {weights, biases}
  end
end

# Smoke check: print dense1, then run one RMSProp step on it using its own
# parameters as stand-in gradients and a zero cache. The result is not kept.
IO.inspect(dense1)
OptimizerRMSProp.update_params(dense1, dense1, OptimizerRMSProp.new_cache(dense1))

# Live chart: loss (blue) and accuracy (red) against the epoch number.
widget =
  Vl.new(width: 400, height: 400, title: "learning")
  |> Vl.encode_field(:x, "epoch", type: :quantitative)
  |> Vl.layers([
    Vl.new()
    |> Vl.mark(:line, color: :blue, size: 1)
    |> Vl.encode_field(:y, "loss", type: :quantitative),
    Vl.new()
    |> Vl.mark(:line, color: :red, size: 1)
    |> Vl.encode_field(:y, "accuracy", type: :quantitative)
  ])
  |> Kino.VegaLite.new()
  |> Kino.render()

# Fresh 2 -> 64 -> 3 network.
dense1 = LayerDense.new(2, 64)
dense2 = LayerDense.new(64, 3)
# Build the caches with the optimizer that consumes them. This section used
# OptimizerAdagrad.new_cache before; both produce zero-filled tensors, so the
# change only removes the cross-module dependency.
cache1 = OptimizerRMSProp.new_cache(dense1)
cache2 = OptimizerRMSProp.new_cache(dense2)

settings = %{
  epochs: 1,
  update_every: 10,
  learning_rate: 0.02,
  decay: 1.0e-5,
  epsilon: 1.0e-7,
  rho: 0.999
}

# Everything the training loop threads from one epoch to the next.
initial_acc = %{
  dense1: dense1,
  dense2: dense2,
  cache1: cache1,
  cache2: cache2,
  settings: settings
}

# Full-batch training on the spiral data with RMSProp. Each pass returns the
# updated layers, caches and settings as the accumulator for the next epoch.
Enum.reduce(0..settings.epochs, initial_acc, fn epoch, acc ->
  %{dense1: layer1, dense2: layer2, cache1: layer1_cache, cache2: layer2_cache} = acc

  # Forward: dense -> ReLU -> dense -> softmax.
  hidden = LayerDense.forward(x_spiral, layer1)
  activated = ActivationReLU.forward(hidden)
  probabilities = activated |> LayerDense.forward(layer2) |> ActivationSoftmax.forward()

  # Backward: the same layers in reverse order.
  grad_logits = ActivationSoftmaxLossCategoricalCrossentropy.backward(probabilities, y_spiral)
  {layer2_grads, grad_activated} = LayerDense.backward(grad_logits, layer2, activated)

  {layer1_grads, _grad_inputs} =
    grad_activated
    |> ActivationReLU.backward(hidden)
    |> LayerDense.backward(layer1, x_spiral)

  # Refresh the decayed learning rate, then step both layers with it.
  epoch_settings = OptimizerRMSProp.pre_update_params(epoch, acc.settings)

  {next_layer1, next_cache1} =
    OptimizerRMSProp.update_params(layer1, layer1_grads, layer1_cache, epoch_settings)

  {next_layer2, next_cache2} =
    OptimizerRMSProp.update_params(layer2, layer2_grads, layer2_cache, epoch_settings)

  # Metrics come from this epoch's forward pass, i.e. the pre-update parameters.
  loss = probabilities |> LossCategoricalCrossentropy.calculate(y_spiral) |> Nx.to_number()

  accuracy =
    probabilities
    |> Nx.argmax(axis: 1)
    |> Nx.equal(y_spiral)
    |> Nx.mean()
    |> Nx.to_number()

  # Push a chart point only every `update_every` epochs.
  if rem(epoch, acc.settings.update_every) == 0 do
    Kino.VegaLite.push(widget, %{epoch: epoch, loss: loss, accuracy: accuracy}, window: 1000)
  end

  %{
    acc
    | settings: epoch_settings,
      dense1: next_layer1,
      cache1: next_cache1,
      dense2: next_layer2,
      cache2: next_cache2
  }
end)

With Adam optimizer

defmodule OptimizerAdam do
  @moduledoc """
  Adam optimizer: a momentum term (moving average of gradients) combined with
  an RMSProp-style cache (moving average of squared gradients), both
  bias-corrected by the step number.

  Settings keys read (with defaults): `:current_learning_rate` (1.0),
  `:epsilon` (1.0e-7), `:momentum` (0.9), `:rho` (0.999), `:epoch` (0),
  `:learning_rate` (1.0) and `:decay` (0).
  """

  @doc """
  Performs one Adam step for a layer.

  Takes the layer's `{weights, biases}`, the matching gradients
  `{dweights, dbiases}` and the optimizer state
  `{{weight_cache, bias_cache}, {weight_momentum, bias_momentum}}`.
  Returns `{{new_weights, new_biases}, new_state}` with `new_state` in the
  same layout as the state argument.
  """
  def update_params(
        {old_weights, old_biases},
        {dweights, dbiases},
        {{old_weight_cache, old_bias_cache}, {old_weights_momentum, old_bias_momentum}},
        settings \\ %{}
      ) do
    {weights, weight_updates, weight_momentum} =
      step(old_weights, dweights, old_weight_cache, old_weights_momentum, settings)

    {biases, bias_updates, bias_momentum} =
      step(old_biases, dbiases, old_bias_cache, old_bias_momentum, settings)

    # Hand back the freshly computed momentums. Returning the old ones (as
    # this function used to) keeps the momentum at its initial value forever.
    {{weights, biases}, {{weight_updates, bias_updates}, {weight_momentum, bias_momentum}}}
  end

  # One Adam update for a single tensor; returns {param, cache, momentum}.
  # The uncorrected cache and momentum are what gets carried to the next step.
  defp step(old_param, dparam, old_cache, old_momentum, settings) do
    momentum = get_momentum(dparam, old_momentum, settings)
    momentum_corrected = get_momentum_corrected(momentum, settings)
    updates = get_updates(dparam, old_cache, settings)
    updates_corrected = get_updates_corrected(updates, settings)

    param = get_param(old_param, momentum_corrected, updates_corrected, settings)
    {param, updates, momentum}
  end

  # momentum * old momentum + (1 - momentum) * gradient
  defp get_momentum(dparam, old_param_momentum, settings) do
    momentum = Map.get(settings, :momentum, 0.9)

    Nx.add(
      Nx.multiply(old_param_momentum, momentum),
      Nx.multiply(dparam, 1 - momentum)
    )
  end

  # Bias correction: divide by 1 - momentum^(epoch + 1).
  defp get_momentum_corrected(param_momentum, settings) do
    momentum = Map.get(settings, :momentum, 0.9)
    epoch = Map.get(settings, :epoch, 0)

    Nx.divide(
      param_momentum,
      Nx.subtract(Nx.tensor(1), Nx.power(momentum, epoch + 1))
    )
  end

  # rho * old cache + (1 - rho) * gradient squared
  defp get_updates(dparam, old_param_cache, settings) do
    rho = Map.get(settings, :rho, 0.999)

    Nx.add(
      Nx.multiply(old_param_cache, rho),
      Nx.multiply(Nx.power(dparam, 2), 1 - rho)
    )
  end

  # Bias correction: divide by 1 - rho^(epoch + 1).
  defp get_updates_corrected(param_updates, settings) do
    rho = Map.get(settings, :rho, 0.999)
    epoch = Map.get(settings, :epoch, 0)

    Nx.divide(
      param_updates,
      Nx.subtract(Nx.tensor(1), Nx.power(rho, epoch + 1))
    )
  end

  # param - learning_rate * corrected momentum / (sqrt(corrected cache) + epsilon)
  defp get_param(old_param, momentum_corrected, updates_corrected, settings) do
    learning_rate = Map.get(settings, :current_learning_rate, 1.0)
    # The key used to be misspelled `:espilon`, so a configured `:epsilon`
    # was never picked up and the default was always used.
    epsilon = Map.get(settings, :epsilon, 1.0e-7)

    Nx.add(
      old_param,
      Nx.divide(
        Nx.multiply(momentum_corrected, -1 * learning_rate),
        Nx.add(Nx.sqrt(updates_corrected), epsilon)
      )
    )
  end

  @doc """
  Prepares `settings` for the update at `epoch`.

  Stores the decayed learning rate (`learning_rate * 1 / (1 + decay * epoch)`)
  under `:current_learning_rate` and the epoch itself under `:epoch`. The bias
  correction reads `:epoch`; nothing wrote it before, so every step was
  corrected as if it were the first one.
  """
  def pre_update_params(epoch, settings) do
    current_learning_rate =
      Map.get(settings, :learning_rate, 1.0) * (1.0 / (1 + Map.get(settings, :decay, 0) * epoch))

    settings
    |> Map.put(:current_learning_rate, current_learning_rate)
    |> Map.put(:epoch, epoch)
  end

  @doc """
  Builds a zero-filled `{weight_cache, bias_cache}` pair shaped like the given
  `{weights, biases}`.
  """
  def new_cache({weights, biases}) do
    weights = Nx.broadcast(Nx.tensor(0), weights)
    biases = Nx.broadcast(Nx.tensor(0), biases)
    {weights, biases}
  end

  @doc "Builds a zero-filled momentum pair; same layout as `new_cache/1`."
  def new_momentum(params), do: new_cache(params)
end
# Live chart with one line per metric: loss in blue, accuracy in red,
# both plotted against the epoch number.
widget =
  Vl.new(width: 400, height: 400, title: "learning")
  |> Vl.encode_field(:x, "epoch", type: :quantitative)
  |> Vl.layers(
    for {field, color} <- [{"loss", :blue}, {"accuracy", :red}] do
      Vl.new()
      |> Vl.mark(:line, color: color, size: 1)
      |> Vl.encode_field(:y, field, type: :quantitative)
    end
  )
  |> Kino.VegaLite.new()
  |> Kino.render()

# Fresh 2 -> 64 -> 3 network. Adam keeps a {cache, momentum} pair per layer.
dense1 = LayerDense.new(2, 64)
dense2 = LayerDense.new(64, 3)

[cache1, cache2] =
  for layer <- [dense1, dense2] do
    {OptimizerAdam.new_cache(layer), OptimizerAdam.new_momentum(layer)}
  end

settings = %{
  epochs: 1,
  update_every: 5,
  learning_rate: 0.05,
  decay: 5.0e-7,
  epsilon: 1.0e-7,
  momentum: 0.95,
  rho: 0.95
}

# Everything the training loop threads from one epoch to the next.
initial_acc = %{
  settings: settings,
  dense1: dense1,
  cache1: cache1,
  dense2: dense2,
  cache2: cache2
}

# Full-batch training on the spiral data with Adam. Each pass returns the
# updated layers, optimizer state and settings as the next accumulator.
Enum.reduce(0..settings.epochs, initial_acc, fn epoch, acc ->
  %{dense1: layer1, dense2: layer2, cache1: layer1_state, cache2: layer2_state} = acc

  # Forward: dense -> ReLU -> dense -> softmax.
  hidden = LayerDense.forward(x_spiral, layer1)
  activated = ActivationReLU.forward(hidden)
  probabilities = activated |> LayerDense.forward(layer2) |> ActivationSoftmax.forward()

  # Backward: the same layers in reverse order.
  grad_logits = ActivationSoftmaxLossCategoricalCrossentropy.backward(probabilities, y_spiral)
  {layer2_grads, grad_activated} = LayerDense.backward(grad_logits, layer2, activated)

  {layer1_grads, _grad_inputs} =
    grad_activated
    |> ActivationReLU.backward(hidden)
    |> LayerDense.backward(layer1, x_spiral)

  # Refresh the per-epoch settings, then step both layers with them.
  epoch_settings = OptimizerAdam.pre_update_params(epoch, acc.settings)

  {next_layer1, next_state1} =
    OptimizerAdam.update_params(layer1, layer1_grads, layer1_state, epoch_settings)

  {next_layer2, next_state2} =
    OptimizerAdam.update_params(layer2, layer2_grads, layer2_state, epoch_settings)

  # Metrics come from this epoch's forward pass, i.e. the pre-update parameters.
  loss = probabilities |> LossCategoricalCrossentropy.calculate(y_spiral) |> Nx.to_number()

  accuracy =
    probabilities
    |> Nx.argmax(axis: 1)
    |> Nx.equal(y_spiral)
    |> Nx.mean()
    |> Nx.to_number()

  # Push a chart point only every `update_every` epochs.
  if rem(epoch, acc.settings.update_every) == 0 do
    Kino.VegaLite.push(widget, %{epoch: epoch, loss: loss, accuracy: accuracy}, window: 1000)
  end

  %{
    acc
    | settings: epoch_settings,
      dense1: next_layer1,
      cache1: next_state1,
      dense2: next_layer2,
      cache2: next_state2
  }
end)

L1 and L2 Regularization

# Redefines LayerDense so that backward also adds L1/L2 regularization gradients.
defmodule LayerDense do
  @moduledoc """
  Fully connected layer represented as a `{weights, biases}` tuple, with
  optional L1/L2 regularization applied in the backward pass.
  """

  @doc """
  Creates a layer: weights are `Nx.random_normal` values scaled by 0.01 with
  shape `{n_inputs, n_neurons}`; biases are zeros with shape `{1, n_neurons}`.
  """
  def new(n_inputs, n_neurons) do
    weights = Nx.random_normal({n_inputs, n_neurons}) |> Nx.multiply(0.01)
    biases = Nx.broadcast(Nx.tensor(0), {1, n_neurons})
    {weights, biases}
  end

  @doc """
  Builds the regularization strengths map. A value of 0 (the default) disables
  the corresponding penalty.
  """
  def new_regularization(l1_w \\ 0, l1_b \\ 0, l2_w \\ 0, l2_b \\ 0) do
    %{l1_w: l1_w, l1_b: l1_b, l2_w: l2_w, l2_b: l2_b}
  end

  @doc "Computes `inputs · weights + biases`."
  def forward(inputs, {weights, biases}) do
    Nx.dot(inputs, weights) |> Nx.add(biases)
  end

  @doc """
  Backward pass. Returns `{{dweights, dbiases}, dinputs}`.

  `old_inputs` are the inputs given to `forward/2`. `regularization` defaults
  to all-zero strengths, so the three-argument call form used by code written
  against the earlier definition of this module keeps working.
  """
  def backward(
        dvalues,
        {old_weights, old_biases},
        old_inputs,
        regularization \\ %{l1_w: 0, l1_b: 0, l2_w: 0, l2_b: 0}
      ) do
    %{l1_w: l1_w, l1_b: l1_b, l2_w: l2_w, l2_b: l2_b} = regularization

    dweights =
      Nx.dot(Nx.transpose(old_inputs), dvalues)
      |> dl1(old_weights, l1_w)
      |> dl2(old_weights, l2_w)

    dbiases =
      Nx.sum(dvalues, axes: [0], keep_axes: true)
      |> dl1(old_biases, l1_b)
      |> dl2(old_biases, l2_b)

    dinputs = Nx.dot(dvalues, Nx.transpose(old_weights))
    {{dweights, dbiases}, dinputs}
  end

  @doc """
  Adds the L1 penalty gradient, `l1 * sign(old_param)`, to `dparam`.
  Returns `dparam` untouched when `l1 <= 0`.
  """
  def dl1(dparam, _old_param, l1) when l1 <= 0, do: dparam

  def dl1(dparam, old_param, l1) do
    # Maps positive parameters to +1 and everything else to -1.
    sign = Nx.multiply(Nx.subtract(Nx.greater(old_param, 0), 0.5), 2)
    # The sign is scaled by l1. It used to be added to l1, which produced
    # gradients of sign + l1 regardless of how small l1 was.
    Nx.add(dparam, Nx.multiply(sign, l1))
  end

  @doc """
  Adds the L2 penalty gradient, `2 * l2 * old_param`, to `dparam`.
  Returns `dparam` untouched when `l2 <= 0`.
  """
  def dl2(dparam, _old_param, l2) when l2 <= 0, do: dparam

  def dl2(dparam, old_param, l2) do
    Nx.add(dparam, Nx.multiply(old_param, l2 * 2))
  end
end

# Quick check of dl1/3: a 2x2 gradient, all-ones parameters and l1 = 1.
LayerDense.dl1(
  ~M{1 1
  -1 1
},
  ~M{1 1
    1 1},
  1
)
defmodule LossCategoricalCrossentropy do
  @moduledoc """
  Categorical cross-entropy loss plus the L1/L2 regularization penalty for a
  `{weights, biases}` layer.
  """

  @eps 1.0e-7

  @doc """
  Per-sample loss: minus the log of the clipped confidence assigned to the
  true class. Rank-1 `y_true` is treated as class indices and expanded with
  `labels_to_matrix/2`; any other rank is used as given.
  """
  def forward(y_pred, y_true) do
    targets =
      case tuple_size(Nx.shape(y_true)) do
        1 -> labels_to_matrix(y_pred, y_true)
        _rank -> y_true
      end

    # Clipping keeps the log away from 0 and 1.
    y_pred
    |> Nx.clip(@eps, 1 - @eps)
    |> Nx.multiply(targets)
    |> Nx.sum(axes: [1])
    |> Nx.log()
    |> then(&Nx.multiply(-1, &1))
  end

  @doc """
  Compares each label against the class ids `0..n_labels - 1`, giving a
  `{samples, n_labels}` mask that is 1 where the label equals the column index.
  """
  def labels_to_matrix(y_pred, labels) do
    {samples, n_labels} = Nx.shape(y_pred)
    class_ids = Nx.tensor(Enum.to_list(0..(n_labels - 1)))
    label_column = Nx.reshape(labels, {samples, 1}, names: [:batch, :output])

    Nx.equal(label_column, class_ids)
  end

  @doc """
  Sums the enabled penalties: `l1 * sum(|param|)` and `l2 * sum(param^2)` for
  weights and biases. A strength that is not greater than 0 contributes 0.0.
  """
  def regularization_loss({weights, biases}, %{l1_w: l1_w, l1_b: l1_b, l2_w: l2_w, l2_b: l2_b}) do
    penalties = [
      l1_penalty(weights, l1_w),
      l1_penalty(biases, l1_b),
      l2_penalty(weights, l2_w),
      l2_penalty(biases, l2_b)
    ]

    Enum.reduce(penalties, Nx.tensor(0.0), fn penalty, total -> Nx.add(total, penalty) end)
  end

  defp l1_penalty(param, strength) when strength > 0 do
    Nx.multiply(Nx.sum(Nx.abs(param)), strength)
  end

  defp l1_penalty(_param, _strength), do: Nx.tensor(0.0)

  defp l2_penalty(param, strength) when strength > 0 do
    Nx.multiply(Nx.sum(Nx.power(param, 2)), strength)
  end

  defp l2_penalty(_param, _strength), do: Nx.tensor(0.0)

  @doc "Mean of the per-sample losses from `forward/2`."
  def calculate(output, y) do
    Nx.mean(forward(output, y))
  end
end
# Live chart with one line per metric: data loss in blue, regularization loss
# in green, accuracy in red, all plotted against the epoch number.
widget =
  Vl.new(width: 400, height: 400, title: "learning")
  |> Vl.encode_field(:x, "epoch", type: :quantitative)
  |> Vl.layers(
    for {field, color} <- [{"loss", :blue}, {"reg_loss", :green}, {"accuracy", :red}] do
      Vl.new()
      |> Vl.mark(:line, color: color, size: 1)
      |> Vl.encode_field(:y, field, type: :quantitative)
    end
  )
  |> Kino.VegaLite.new()
  |> Kino.render()

# Fresh 2 -> 64 -> 3 network; both layers use the same L2-only strengths.
dense1 = LayerDense.new(2, 64)
dense2 = LayerDense.new(64, 3)
dense1_regularizer = LayerDense.new_regularization(0, 0, 5.0e-4, 5.0e-4)
dense2_regularizer = dense1_regularizer

# Adam keeps a {cache, momentum} pair per layer.
[cache1, cache2] =
  for layer <- [dense1, dense2] do
    {OptimizerAdam.new_cache(layer), OptimizerAdam.new_momentum(layer)}
  end

settings = %{
  epochs: 1,
  update_every: 10,
  learning_rate: 0.05,
  decay: 5.0e-7,
  epsilon: 1.0e-7,
  momentum: 0.95,
  rho: 0.95
}

# Everything the training loop threads from one epoch to the next.
initial_acc = %{
  settings: settings,
  dense1: dense1,
  cache1: cache1,
  dense2: dense2,
  cache2: cache2
}

# Full-batch Adam training on the spiral data with per-layer regularization.
# The accumulator carries layers, optimizer state and settings between epochs.
Enum.reduce(0..settings.epochs, initial_acc, fn epoch, acc ->
  # forward pass
  dense1_output = LayerDense.forward(x_spiral, acc.dense1)
  relu_output = ActivationReLU.forward(dense1_output)
  dense2_output = LayerDense.forward(relu_output, acc.dense2)
  softmax_output = ActivationSoftmax.forward(dense2_output)
  # backward pass
  d_softmax = ActivationSoftmaxLossCategoricalCrossentropy.backward(softmax_output, y_spiral)

  {d_dense2, dinputs2} =
    LayerDense.backward(d_softmax, acc.dense2, relu_output, dense2_regularizer)

  d_relu = ActivationReLU.backward(dinputs2, dense1_output)
  {d_dense1, _dinputs1} = LayerDense.backward(d_relu, acc.dense1, x_spiral, dense1_regularizer)
  # settings update
  new_settings = OptimizerAdam.pre_update_params(epoch, acc.settings)

  {new_dense1, new_cache1} =
    OptimizerAdam.update_params(acc.dense1, d_dense1, acc.cache1, new_settings)

  {new_dense2, new_cache2} =
    OptimizerAdam.update_params(acc.dense2, d_dense2, acc.cache2, new_settings)

  # calculate how good we are (using this epoch's pre-update parameters)
  data_loss = LossCategoricalCrossentropy.calculate(softmax_output, y_spiral)

  regularization_loss =
    Nx.add(
      LossCategoricalCrossentropy.regularization_loss(acc.dense1, dense1_regularizer),
      LossCategoricalCrossentropy.regularization_loss(acc.dense2, dense2_regularizer)
    )

  # The summed total loss used to be bound to `loss` here but was never read
  # (unused-variable warning). The chart plots the two parts separately, so
  # the sum is no longer computed.
  predictions = Nx.argmax(softmax_output, axis: 1)
  accuracy = Nx.mean(Nx.equal(predictions, y_spiral)) |> Nx.to_number()
  # print chart: only every `update_every` epochs
  if rem(epoch, acc.settings.update_every) == 0 do
    point = %{
      epoch: epoch,
      loss: Nx.to_number(data_loss),
      reg_loss: Nx.to_number(regularization_loss),
      accuracy: accuracy
    }

    Kino.VegaLite.push(widget, point, window: 1000)
  end

  %{
    acc
    | dense1: new_dense1,
      dense2: new_dense2,
      cache1: new_cache1,
      cache2: new_cache2,
      settings: new_settings
  }
end)
defmodule LayerDropout do
  @moduledoc """
  Inverted dropout: randomly zeroes inputs and rescales the survivors so the
  expected activation is the same with and without dropout.
  """

  @doc """
  Applies dropout to `inputs`. Returns `{outputs, mask}`.

  `settings[:dropout_rate]` (default 0.1) is the fraction of values to drop
  and is expected to be below 1. The mask is 0 for dropped positions and
  `1 / (1 - dropout_rate)` for kept ones. Without that scaling the layer
  shrinks activations during training relative to a forward pass that skips
  dropout. Pass the mask to `backward/2`.
  """
  def forward(inputs, settings) do
    dropout_rate = Map.get(settings, :dropout_rate, 0.1)
    keep_rate = 1 - dropout_rate

    binary_mask =
      inputs
      |> Nx.random_uniform()
      |> Nx.greater(dropout_rate)
      |> Nx.divide(keep_rate)

    {Nx.multiply(inputs, binary_mask), binary_mask}
  end

  @doc "Routes gradients through the mask produced by `forward/2`."
  def backward(dvalues, binary_mask) do
    Nx.multiply(dvalues, binary_mask)
  end
end
# Live chart with one line per metric: data loss in blue, regularization loss
# in green, accuracy in red, all plotted against the epoch number.
widget =
  Vl.new(width: 400, height: 400, title: "learning")
  |> Vl.encode_field(:x, "epoch", type: :quantitative)
  |> Vl.layers(
    for {field, color} <- [{"loss", :blue}, {"reg_loss", :green}, {"accuracy", :red}] do
      Vl.new()
      |> Vl.mark(:line, color: color, size: 1)
      |> Vl.encode_field(:y, field, type: :quantitative)
    end
  )
  |> Kino.VegaLite.new()
  |> Kino.render()

# Fresh 2 -> 64 -> 3 network; both layers use the same L2-only strengths.
dense1 = LayerDense.new(2, 64)
dense2 = LayerDense.new(64, 3)
dense1_regularizer = LayerDense.new_regularization(0, 0, 5.0e-4, 5.0e-4)
dense2_regularizer = dense1_regularizer

# Adam keeps a {cache, momentum} pair per layer.
[cache1, cache2] =
  for layer <- [dense1, dense2] do
    {OptimizerAdam.new_cache(layer), OptimizerAdam.new_momentum(layer)}
  end

# `dropout_rate` is read by LayerDropout.forward/2 from these same settings.
settings = %{
  epochs: 1,
  update_every: 10,
  learning_rate: 0.05,
  decay: 5.0e-5,
  epsilon: 1.0e-7,
  momentum: 0.95,
  rho: 0.95,
  dropout_rate: 0.1
}

# Everything the training loop threads from one epoch to the next.
initial_acc = %{
  settings: settings,
  dense1: dense1,
  cache1: cache1,
  dense2: dense2,
  cache2: cache2
}

# Full-batch Adam training with regularization and a dropout layer placed
# between the first dense layer and its ReLU.
Enum.reduce(0..settings.epochs, initial_acc, fn epoch, acc ->
  # forward pass
  dense1_output = LayerDense.forward(x_spiral, acc.dense1)
  {dropout1_output, binary_mask1} = LayerDropout.forward(dense1_output, acc.settings)
  relu_output = ActivationReLU.forward(dropout1_output)
  dense2_output = LayerDense.forward(relu_output, acc.dense2)
  softmax_output = ActivationSoftmax.forward(dense2_output)
  # backward pass
  d_softmax = ActivationSoftmaxLossCategoricalCrossentropy.backward(softmax_output, y_spiral)

  {d_dense2, dinputs2} =
    LayerDense.backward(d_softmax, acc.dense2, relu_output, dense2_regularizer)

  # ReLU consumed the dropout output in the forward pass, so its backward pass
  # gets that same tensor (it used to be handed the pre-dropout values).
  d_relu = ActivationReLU.backward(dinputs2, dropout1_output)
  d_dropout1 = LayerDropout.backward(d_relu, binary_mask1)

  {d_dense1, _dinputs1} =
    LayerDense.backward(d_dropout1, acc.dense1, x_spiral, dense1_regularizer)

  # settings update
  new_settings = OptimizerAdam.pre_update_params(epoch, acc.settings)

  {new_dense1, new_cache1} =
    OptimizerAdam.update_params(acc.dense1, d_dense1, acc.cache1, new_settings)

  {new_dense2, new_cache2} =
    OptimizerAdam.update_params(acc.dense2, d_dense2, acc.cache2, new_settings)

  # calculate how good we are (using this epoch's pre-update parameters)
  data_loss = LossCategoricalCrossentropy.calculate(softmax_output, y_spiral)

  regularization_loss =
    Nx.add(
      LossCategoricalCrossentropy.regularization_loss(acc.dense1, dense1_regularizer),
      LossCategoricalCrossentropy.regularization_loss(acc.dense2, dense2_regularizer)
    )

  # The summed total loss used to be bound to `loss` here but was never read
  # (unused-variable warning), so it is no longer computed.
  predictions = Nx.argmax(softmax_output, axis: 1)
  accuracy = Nx.mean(Nx.equal(predictions, y_spiral)) |> Nx.to_number()
  # print chart: only every `update_every` epochs
  if rem(epoch, acc.settings.update_every) == 0 do
    point = %{
      epoch: epoch,
      loss: Nx.to_number(data_loss),
      reg_loss: Nx.to_number(regularization_loss),
      accuracy: accuracy
    }

    Kino.VegaLite.push(widget, point, window: 1000)
  end

  %{
    acc
    | dense1: new_dense1,
      dense2: new_dense2,
      cache1: new_cache1,
      cache2: new_cache2,
      settings: new_settings
  }
end)

Binary Logistic Regression

Section

defmodule ActivationSigmoid do
  @moduledoc "Sigmoid activation: `1 / (1 + exp(-x))`, applied element-wise."

  @doc "Applies the sigmoid to every element of `inputs`."
  def forward(inputs) do
    exp_of_negated = Nx.exp(Nx.negate(inputs))
    Nx.divide(1, Nx.add(1, exp_of_negated))
  end

  @doc """
  Multiplies `dvalues` by the sigmoid derivative `(1 - out) * out`, where
  `old_outputs` are the values returned by `forward/1`.
  """
  def backward(dvalues, old_outputs) do
    local_slope = Nx.multiply(Nx.subtract(1, old_outputs), old_outputs)
    Nx.multiply(local_slope, dvalues)
  end
end

# Quick check: sigmoid of a small row, then its backward pass fed with the
# sigmoid output itself as the incoming gradient.
sig_output = ActivationSigmoid.forward(~M{1 3 -4 5})
sig_output |> ActivationSigmoid.backward(sig_output)
defmodule LossBinaryCrossentropy do
  @moduledoc """
  Binary cross-entropy loss:
  `-(y * log(p) + (1 - y) * log(1 - p))`, averaged over the last axis.
  """

  # Predictions are clipped to [@eps, 1 - @eps] so neither log sees 0.
  @eps 1.0e-7

  @doc """
  Per-sample loss for predictions `y_pred` against targets `y_true`.

  Both terms are negated. Previously only the first one was, which subtracted
  the penalty for `y_true == 0` samples instead of adding it.
  """
  def forward(y_pred, y_true) do
    y_clipped = Nx.clip(y_pred, @eps, 1 - @eps)

    positive_term = Nx.multiply(y_true, Nx.log(y_clipped))

    negative_term =
      Nx.multiply(
        Nx.subtract(1, y_true),
        Nx.log(Nx.subtract(1, y_clipped))
      )

    positive_term
    |> Nx.add(negative_term)
    |> Nx.negate()
    |> Nx.mean(axes: [-1])
  end

  @doc """
  Gradient of the loss with respect to the predictions.

  `dvalues` holds the predictions as a `{samples, outputs}` tensor. The result
  is `-(y / p - (1 - y) / (1 - p)) / outputs / samples`.
  """
  def backward(dvalues, y_true) do
    {samples, outputs} = Nx.shape(dvalues)
    d_clipped = Nx.clip(dvalues, @eps, 1 - @eps)

    Nx.negate(
      Nx.divide(
        Nx.subtract(
          Nx.divide(y_true, d_clipped),
          Nx.divide(Nx.subtract(1, y_true), Nx.subtract(1, d_clipped))
        ),
        outputs
      )
    )
    |> Nx.divide(samples)
  end

  @doc "Mean of the per-sample losses from `forward/2`."
  def calculate(output, y) do
    output
    |> forward(y)
    |> Nx.mean()
  end
end

# Quick check of the binary cross-entropy against the sigmoid demo output.
y_true = ~M(1 1 0 1)
sig_output |> LossBinaryCrossentropy.forward(y_true)
sig_output |> LossBinaryCrossentropy.backward(y_true)
# Live chart with one line per metric: data loss in blue, regularization loss
# in green, accuracy in red, all plotted against the epoch number.
widget =
  Vl.new(width: 400, height: 400, title: "learning")
  |> Vl.encode_field(:x, "epoch", type: :quantitative)
  |> Vl.layers(
    for {field, color} <- [{"loss", :blue}, {"reg_loss", :green}, {"accuracy", :red}] do
      Vl.new()
      |> Vl.mark(:line, color: color, size: 1)
      |> Vl.encode_field(:y, field, type: :quantitative)
    end
  )
  |> Kino.VegaLite.new()
  |> Kino.render()

# 2 -> 64 -> 1 network: a single output neuron for the binary case.
# Only the first layer is regularized; the second uses the all-zero defaults.
dense1 = LayerDense.new(2, 64)
dense2 = LayerDense.new(64, 1)
dense1_regularizer = LayerDense.new_regularization(0, 0, 5.0e-4, 5.0e-4)
dense2_regularizer = LayerDense.new_regularization()

# Adam keeps a {cache, momentum} pair per layer.
[cache1, cache2] =
  for layer <- [dense1, dense2] do
    {OptimizerAdam.new_cache(layer), OptimizerAdam.new_momentum(layer)}
  end

settings = %{
  epochs: 1,
  update_every: 10,
  learning_rate: 0.05,
  decay: 5.0e-7,
  epsilon: 1.0e-7,
  momentum: 0.95,
  rho: 0.95,
  dropout_rate: 0.1
}

# Everything the training loop threads from one epoch to the next.
initial_acc = %{
  settings: settings,
  dense1: dense1,
  cache1: cache1,
  dense2: dense2,
  cache2: cache2
}

# Keep only the first 200 samples; labels become a {200, 1} column.
y_spiral2 = y_spiral |> Nx.slice([0], [200]) |> Nx.reshape({200, 1})
x_spiral2 = x_spiral |> Nx.slice([0, 0], [200, 2])

# Full-batch Adam training of the binary classifier:
# dense -> ReLU -> dense -> sigmoid, scored with binary cross-entropy.
Enum.reduce(0..settings.epochs, initial_acc, fn epoch, acc ->
  # forward pass
  dense1_output = LayerDense.forward(x_spiral2, acc.dense1)
  relu_output = ActivationReLU.forward(dense1_output)
  dense2_output = LayerDense.forward(relu_output, acc.dense2)
  sigmoid_output = ActivationSigmoid.forward(dense2_output)

  data_loss = LossBinaryCrossentropy.calculate(sigmoid_output, y_spiral2)

  regularization_loss =
    Nx.add(
      LossCategoricalCrossentropy.regularization_loss(acc.dense1, dense1_regularizer),
      LossCategoricalCrossentropy.regularization_loss(acc.dense2, dense2_regularizer)
    )

  # The summed total loss used to be bound to `loss` here but was never read
  # (unused-variable warning), so it is no longer computed.

  # A sample counts as class 1 when the sigmoid output exceeds 0.5.
  predictions = Nx.greater(sigmoid_output, 0.5)
  accuracy = Nx.mean(Nx.equal(predictions, y_spiral2)) |> Nx.to_number()
  # backward pass

  d_loss = LossBinaryCrossentropy.backward(sigmoid_output, y_spiral2)

  d_sigmoid = ActivationSigmoid.backward(d_loss, sigmoid_output)

  {d_dense2, dinputs2} =
    LayerDense.backward(d_sigmoid, acc.dense2, relu_output, dense2_regularizer)

  d_relu = ActivationReLU.backward(dinputs2, dense1_output)

  {d_dense1, _dinputs1} = LayerDense.backward(d_relu, acc.dense1, x_spiral2, dense1_regularizer)

  # settings update
  new_settings = OptimizerAdam.pre_update_params(epoch, acc.settings)

  {new_dense1, new_cache1} =
    OptimizerAdam.update_params(acc.dense1, d_dense1, acc.cache1, new_settings)

  {new_dense2, new_cache2} =
    OptimizerAdam.update_params(acc.dense2, d_dense2, acc.cache2, new_settings)

  # print chart: only every `update_every` epochs
  if rem(epoch, acc.settings.update_every) == 0 do
    point = %{
      epoch: epoch,
      loss: Nx.to_number(data_loss),
      reg_loss: Nx.to_number(regularization_loss),
      accuracy: accuracy
    }

    Kino.VegaLite.push(widget, point, window: 1000)
  end

  %{
    acc
    | dense1: new_dense1,
      dense2: new_dense2,
      cache1: new_cache1,
      cache2: new_cache2,
      settings: new_settings
  }
end)

Regression

defmodule ActivationLinear do
  @moduledoc "Identity activation: values and gradients pass through unchanged."

  @doc "Returns `inputs` as is."
  def forward(inputs) do
    inputs
  end

  @doc "Returns the incoming gradient `dinputs` as is."
  def backward(dinputs) do
    dinputs
  end
end
defmodule LossMSE do
  @moduledoc "Mean squared error loss."

  @doc "Per-sample loss: mean of `(y_true - y_pred)^2` over the last axis."
  def forward(y_pred, y_true) do
    Nx.mean(Nx.power(Nx.subtract(y_true, y_pred), 2), axes: [-1])
  end

  @doc """
  Gradient of the loss with respect to the predictions.

  `dvalues` holds the predictions as a `{samples, outputs}` tensor. The result
  is `-2 * (y_true - dvalues) / outputs / samples`.
  """
  def backward(dvalues, y_true) do
    {samples, outputs} = Nx.shape(dvalues)

    # Returned directly; the former `dinputs =` binding was never read and
    # only produced an unused-variable warning.
    Nx.multiply(Nx.divide(Nx.subtract(y_true, dvalues), outputs), -2)
    |> Nx.divide(samples)
  end

  @doc "Mean of the per-sample losses from `forward/2`."
  def calculate(output, y) do
    output
    |> forward(y)
    |> Nx.mean()
  end
end
defmodule LossMAE do
  @moduledoc "Mean absolute error loss."

  @doc "Per-sample loss: mean of `|y_true - y_pred|` over the last axis."
  def forward(y_pred, y_true) do
    Nx.mean(Nx.abs(Nx.subtract(y_true, y_pred)), axes: [-1])
  end

  @doc """
  Gradient of the loss with respect to the predictions.

  `dvalues` holds the predictions as a `{samples, outputs}` tensor. The result
  is `sign(dvalues - y_true) / outputs / samples`.

  The derivative of `|y - p|` with respect to `p` is `sign(p - y)`. The
  operands used to be the other way round, which gave the gradient the wrong
  sign (a descent step would have increased the loss).
  """
  def backward(dvalues, y_true) do
    {samples, outputs} = Nx.shape(dvalues)

    Nx.divide(Nx.sign(Nx.subtract(dvalues, y_true)), outputs)
    |> Nx.divide(samples)
  end

  @doc "Mean of the per-sample losses from `forward/2`."
  def calculate(output, y) do
    output
    |> forward(y)
    |> Nx.mean()
  end
end
# One full sine period sampled at x = 0.001, 0.002, ..., 1.0.
x_sine = for step <- 1..1000, do: step / 1000
y_sine = for x <- x_sine, do: :math.sin(x * :math.pi() * 2)

# Scatter plot of the raw samples.
Vl.new(width: 400, height: 400, title: "sine")
|> Vl.mark(:circle, opacity: 0.8, size: 4)
|> Vl.data_from_series(x: x_sine, y: y_sine)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)

# Turn both lists into {1000, 1} column tensors for the network.
x_sine = x_sine |> Nx.tensor() |> Nx.reshape({1000, 1})
y_sine = y_sine |> Nx.tensor() |> Nx.reshape({1000, 1})

# standard deviation: sqrt(sum((y - mean)^2) * (1 / n))
std = fn y ->
  {total_samples, _} = Nx.shape(y)
  mean = y |> Nx.sum() |> Nx.divide(total_samples)
  squared_deviation_sum = y |> Nx.subtract(mean) |> Nx.power(2) |> Nx.sum()

  Nx.sqrt(Nx.multiply(squared_deviation_sum, Nx.divide(1, total_samples)))
end

# A prediction counts as accurate when it is within std / 250 of the target.
accuracy_precision = y_sine |> std.() |> Nx.divide(250)
# Live chart with one line per metric: data loss in blue, regularization loss
# in green, accuracy in red, all plotted against the epoch number.
widget =
  Vl.new(width: 400, height: 400, title: "learning")
  |> Vl.encode_field(:x, "epoch", type: :quantitative)
  |> Vl.layers(
    for {field, color} <- [{"loss", :blue}, {"reg_loss", :green}, {"accuracy", :red}] do
      Vl.new()
      |> Vl.mark(:line, color: color, size: 1)
      |> Vl.encode_field(:y, field, type: :quantitative)
    end
  )
  |> Kino.VegaLite.new()
  |> Kino.render()

# 1 -> 64 -> 1 regression network.
# Only the first layer is regularized; the second uses the all-zero defaults.
dense1 = LayerDense.new(1, 64)
dense2 = LayerDense.new(64, 1)
dense1_regularizer = LayerDense.new_regularization(0, 0, 5.0e-4, 5.0e-4)
dense2_regularizer = LayerDense.new_regularization()

# Adam keeps a {cache, momentum} pair per layer.
[cache1, cache2] =
  for layer <- [dense1, dense2] do
    {OptimizerAdam.new_cache(layer), OptimizerAdam.new_momentum(layer)}
  end

settings = %{
  epochs: 100,
  update_every: 10,
  learning_rate: 0.05,
  decay: 5.0e-7,
  epsilon: 1.0e-7,
  momentum: 0.95,
  rho: 0.95,
  dropout_rate: 0.1
}

# Everything the training loop threads from one epoch to the next.
initial_acc = %{
  settings: settings,
  dense1: dense1,
  cache1: cache1,
  dense2: dense2,
  cache2: cache2
}

# Full-batch Adam training of the sine regressor:
# dense -> ReLU -> dense -> linear, scored with mean squared error.
# The final accumulator is kept in `acc` so later cells can predict with it.
acc =
  Enum.reduce(0..settings.epochs, initial_acc, fn epoch, acc ->
    # forward pass
    dense1_output = LayerDense.forward(x_sine, acc.dense1)
    relu_output = ActivationReLU.forward(dense1_output)
    dense2_output = LayerDense.forward(relu_output, acc.dense2)
    linear_output = ActivationLinear.forward(dense2_output)

    data_loss = LossMSE.calculate(linear_output, y_sine)

    regularization_loss =
      Nx.add(
        LossCategoricalCrossentropy.regularization_loss(acc.dense1, dense1_regularizer),
        LossCategoricalCrossentropy.regularization_loss(acc.dense2, dense2_regularizer)
      )

    # The summed total loss used to be bound to `loss` here but was never read
    # (unused-variable warning), so it is no longer computed.
    predictions = linear_output

    # Share of predictions whose absolute error is below `accuracy_precision`.
    accuracy =
      Nx.mean(Nx.greater(accuracy_precision, Nx.abs(Nx.subtract(predictions, y_sine))))
      |> Nx.to_number()

    # backward pass

    d_loss = LossMSE.backward(linear_output, y_sine)

    # Was named `d_sigmoid`, but this network ends in a linear activation.
    d_linear = ActivationLinear.backward(d_loss)

    {d_dense2, dinputs2} =
      LayerDense.backward(d_linear, acc.dense2, relu_output, dense2_regularizer)

    d_relu = ActivationReLU.backward(dinputs2, dense1_output)

    {d_dense1, _dinputs1} = LayerDense.backward(d_relu, acc.dense1, x_sine, dense1_regularizer)

    # settings update
    new_settings = OptimizerAdam.pre_update_params(epoch, acc.settings)

    {new_dense1, new_cache1} =
      OptimizerAdam.update_params(acc.dense1, d_dense1, acc.cache1, new_settings)

    {new_dense2, new_cache2} =
      OptimizerAdam.update_params(acc.dense2, d_dense2, acc.cache2, new_settings)

    # print chart: only every `update_every` epochs
    if rem(epoch, acc.settings.update_every) == 0 do
      point = %{
        epoch: epoch,
        loss: Nx.to_number(data_loss),
        reg_loss: Nx.to_number(regularization_loss),
        accuracy: accuracy
      }

      Kino.VegaLite.push(widget, point, window: 1000)
    end

    %{
      acc
      | dense1: new_dense1,
        dense2: new_dense2,
        cache1: new_cache1,
        cache2: new_cache2,
        settings: new_settings
    }
  end)
# Same x grid as the training data: 0.001, 0.002, ..., 1.0.
x_input = for step <- 1..1000, do: step / 1000

# Forward pass through the trained layers held in `acc`.
predict = fn x ->
  x
  |> LayerDense.forward(acc.dense1)
  |> ActivationReLU.forward()
  |> LayerDense.forward(acc.dense2)
  |> ActivationLinear.forward()
end

y_predicted =
  x_input
  |> Nx.tensor()
  |> Nx.reshape({1000, 1})
  |> then(predict)
  |> Nx.to_flat_list()

# Scatter plot of the model's predictions over the x grid.
Vl.new(width: 400, height: 400, title: "sine")
|> Vl.mark(:circle, opacity: 0.8, size: 4)
|> Vl.data_from_series(x: x_input, y: y_predicted)
|> Vl.encode_field(:x, "x", type: :quantitative)
|> Vl.encode_field(:y, "y", type: :quantitative)
# Redefines LayerDense so that backward also adds L1/L2 regularization gradients.
defmodule LayerDense do
  @moduledoc """
  Fully connected layer represented as a `{weights, biases}` tuple, with
  optional L1/L2 regularization applied in the backward pass.
  """

  @doc """
  Creates a layer: weights are `Nx.random_normal` values scaled by 0.01 with
  shape `{n_inputs, n_neurons}`; biases are zeros with shape `{1, n_neurons}`.
  """
  def new(n_inputs, n_neurons) do
    weights = Nx.random_normal({n_inputs, n_neurons}) |> Nx.multiply(0.01)
    biases = Nx.broadcast(Nx.tensor(0), {1, n_neurons})
    {weights, biases}
  end

  @doc """
  Builds the regularization strengths map. A value of 0 (the default) disables
  the corresponding penalty.
  """
  def new_regularization(l1_w \\ 0, l1_b \\ 0, l2_w \\ 0, l2_b \\ 0) do
    %{l1_w: l1_w, l1_b: l1_b, l2_w: l2_w, l2_b: l2_b}
  end

  @doc "Computes `inputs · weights + biases`."
  def forward(inputs, {weights, biases}) do
    Nx.dot(inputs, weights) |> Nx.add(biases)
  end

  @doc """
  Backward pass. Returns `{{dweights, dbiases}, dinputs}`.

  `old_inputs` are the inputs given to `forward/2`. `regularization` defaults
  to all-zero strengths, so the three-argument call form used by code written
  against the earlier definition of this module keeps working.
  """
  def backward(
        dvalues,
        {old_weights, old_biases},
        old_inputs,
        regularization \\ %{l1_w: 0, l1_b: 0, l2_w: 0, l2_b: 0}
      ) do
    %{l1_w: l1_w, l1_b: l1_b, l2_w: l2_w, l2_b: l2_b} = regularization

    dweights =
      Nx.dot(Nx.transpose(old_inputs), dvalues)
      |> dl1(old_weights, l1_w)
      |> dl2(old_weights, l2_w)

    dbiases =
      Nx.sum(dvalues, axes: [0], keep_axes: true)
      |> dl1(old_biases, l1_b)
      |> dl2(old_biases, l2_b)

    dinputs = Nx.dot(dvalues, Nx.transpose(old_weights))
    {{dweights, dbiases}, dinputs}
  end

  @doc """
  Adds the L1 penalty gradient, `l1 * sign(old_param)`, to `dparam`.
  Returns `dparam` untouched when `l1 <= 0`.
  """
  def dl1(dparam, _old_param, l1) when l1 <= 0, do: dparam

  def dl1(dparam, old_param, l1) do
    # Maps positive parameters to +1 and everything else to -1.
    sign = Nx.multiply(Nx.subtract(Nx.greater(old_param, 0), 0.5), 2)
    # The sign is scaled by l1. It used to be added to l1, which produced
    # gradients of sign + l1 regardless of how small l1 was.
    Nx.add(dparam, Nx.multiply(sign, l1))
  end

  @doc """
  Adds the L2 penalty gradient, `2 * l2 * old_param`, to `dparam`.
  Returns `dparam` untouched when `l2 <= 0`.
  """
  def dl2(dparam, _old_param, l2) when l2 <= 0, do: dparam

  def dl2(dparam, old_param, l2) do
    Nx.add(dparam, Nx.multiply(old_param, l2 * 2))
  end
end

# Try LayerDense.dl1/3 on a 2x2 gradient with all-ones parameters and l1 = 1.
LayerDense.dl1(
  ~M{1 1
  -1 1
},
  ~M{1 1
    1 1},
  1
)
# Live training chart: one line layer per metric, all sharing the "epoch" x axis.
widget =
  Vl.new(width: 400, height: 400, title: "learning")
  |> Vl.encode_field(:x, "epoch", type: :quantitative)
  |> Vl.layers(
    for {color, field} <- [blue: "loss", green: "reg_loss", red: "accuracy"] do
      Vl.new()
      |> Vl.mark(:line, color: color, size: 1)
      |> Vl.encode_field(:y, field, type: :quantitative)
    end
  )
  |> Kino.VegaLite.new()
  |> Kino.render()

# Three dense layers: 1 -> 64 -> 64 -> 1.
[dense1, dense2, dense3] =
  for {n_inputs, n_neurons} <- [{1, 64}, {64, 64}, {64, 1}] do
    LayerDense.new(n_inputs, n_neurons)
  end

# Only the first layer carries a penalty (L2 on weights and biases).
dense1_regularizer = LayerDense.new_regularization(0, 0, 5.0e-4, 5.0e-4)
dense2_regularizer = LayerDense.new_regularization()
dense3_regularizer = LayerDense.new_regularization()

# Per-layer optimizer state, stored as a {cache, momentum} pair.
[cache1, cache2, cache3] =
  Enum.map([dense1, dense2, dense3], fn layer ->
    {OptimizerAdam.new_cache(layer), OptimizerAdam.new_momentum(layer)}
  end)

# Hyperparameters handed to the optimizer and the training loop.
settings = %{
  epochs: 100,
  update_every: 10,
  learning_rate: 0.05,
  decay: 5.0e-7,
  epsilon: 1.0e-7,
  momentum: 0.95,
  rho: 0.95,
  dropout_rate: 0.1
}

# Everything that changes from one epoch to the next.
initial_acc = %{
  settings: settings,
  dense1: dense1,
  cache1: cache1,
  dense2: dense2,
  cache2: cache2,
  dense3: dense3,
  cache3: cache3
}

# Train the three-layer regressor on the sine data. The accumulator carries the
# layer parameters, the per-layer optimizer state and the (decaying) settings.
# The range is inclusive, so the body runs `settings.epochs + 1` times.
Enum.reduce(0..settings.epochs, initial_acc, fn epoch, acc ->
  # forward pass
  dense1_output = LayerDense.forward(x_sine, acc.dense1)
  relu1_output = ActivationReLU.forward(dense1_output)
  dense2_output = LayerDense.forward(relu1_output, acc.dense2)
  relu2_output = ActivationReLU.forward(dense2_output)
  dense3_output = LayerDense.forward(relu2_output, acc.dense3)
  linear_output = ActivationLinear.forward(dense3_output)

  data_loss = LossMSE.calculate(linear_output, y_sine)

  regularization_loss =
    LossCategoricalCrossentropy.regularization_loss(acc.dense1, dense1_regularizer)
    |> Nx.add(LossCategoricalCrossentropy.regularization_loss(acc.dense2, dense2_regularizer))
    |> Nx.add(LossCategoricalCrossentropy.regularization_loss(acc.dense3, dense3_regularizer))

  predictions = linear_output

  # Share of predictions whose absolute error is below `accuracy_precision`.
  accuracy =
    Nx.mean(Nx.greater(accuracy_precision, Nx.abs(Nx.subtract(predictions, y_sine))))
    |> Nx.to_number()

  # backward pass
  d_loss = LossMSE.backward(linear_output, y_sine)

  d_linear = ActivationLinear.backward(d_loss)

  {d_dense3, dinputs3} =
    LayerDense.backward(d_linear, acc.dense3, relu2_output, dense3_regularizer)

  # Each ReLU is differentiated with respect to the tensor it received in the
  # forward pass: dense2_output for ReLU 2 and dense1_output for ReLU 1.
  # NOTE(review): assumes the second argument of ActivationReLU.backward/2 is
  # the ReLU's forward input - confirm against its definition.
  d_relu2 = ActivationReLU.backward(dinputs3, dense2_output)

  {d_dense2, dinputs2} =
    LayerDense.backward(d_relu2, acc.dense2, relu1_output, dense2_regularizer)

  d_relu1 = ActivationReLU.backward(dinputs2, dense1_output)

  {d_dense1, _dinputs1} = LayerDense.backward(d_relu1, acc.dense1, x_sine, dense1_regularizer)

  # settings update
  new_settings = OptimizerAdam.pre_update_params(epoch, acc.settings)

  {new_dense1, new_cache1} =
    OptimizerAdam.update_params(acc.dense1, d_dense1, acc.cache1, new_settings)

  {new_dense2, new_cache2} =
    OptimizerAdam.update_params(acc.dense2, d_dense2, acc.cache2, new_settings)

  {new_dense3, new_cache3} =
    OptimizerAdam.update_params(acc.dense3, d_dense3, acc.cache3, new_settings)

  # Push a point to the live chart every `update_every` epochs.
  if rem(epoch, acc.settings.update_every) == 0 do
    point = %{
      epoch: epoch,
      loss: Nx.to_number(data_loss),
      reg_loss: Nx.to_number(regularization_loss),
      accuracy: accuracy
    }

    Kino.VegaLite.push(widget, point, window: 1000)
  end

  %{
    acc
    | dense1: new_dense1,
      dense2: new_dense2,
      dense3: new_dense3,
      cache1: new_cache1,
      cache2: new_cache2,
      cache3: new_cache3,
      settings: new_settings
  }
end)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment