Skip to content

Instantly share code, notes, and snippets.

@toranb
Last active April 16, 2024 19:29
Show Gist options
Save toranb/56a3e65ca81fba1c4f6c92c6f1857681 to your computer and use it in GitHub Desktop.
fizzbuzz with Nx
defmodule Mlearning do
  @moduledoc false

  # Feature vector the network is trained on: the remainders of `x` modulo 3
  # and modulo 5.
  @spec mods(integer()) :: [integer()]
  def mods(x) do
    [rem(x, 3), rem(x, 5)]
  end

  # One-hot label for `n`. Class order must match the decoder in `world/0`:
  # index 0 = "fizz", 1 = "buzz", 2 = "fizzbuzz", 3 = anything else.
  # The divisible-by-15 clause comes first so it wins over the 3 and 5 clauses.
  @spec fizzbuzz(integer()) :: [0 | 1]
  def fizzbuzz(n) do
    cond do
      rem(n, 15) == 0 -> [0, 0, 1, 0]
      rem(n, 3) == 0 -> [1, 0, 0, 0]
      rem(n, 5) == 0 -> [0, 1, 0, 0]
      true -> [0, 0, 0, 1]
    end
  end

  defmodule Demo do
    @moduledoc false
    import Nx.Defn

    # Step size for the plain gradient-descent update in `update/3`.
    @learning_rate 0.1

    # ReLU with an explicit gradient: the upstream gradient `g` passes through
    # where x > 0 and is zeroed elsewhere.
    defn relu(x) do
      custom_grad(
        Nx.max(x, 0),
        [x],
        fn g -> [Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))] end
      )
    end

    # Softmax over all elements of `logits`.
    # Shifting by the maximum is mathematically a no-op but keeps `Nx.exp` from
    # overflowing to inf (and the division from producing NaN) for large logits.
    # `stop_grad` keeps the shift out of the gradient; its true gradient is zero.
    defn softmax(logits) do
      exps = Nx.exp(logits - stop_grad(Nx.reduce_max(logits)))
      exps / Nx.sum(exps)
    end

    # Loss for one sample: -mean(log(preds) * labels).
    # With one-hot `labels`, this is the cross-entropy of the true class divided
    # by the number of elements averaged (4 classes for a single sample).
    # NOTE(review): if a non-target prediction underflows to exactly 0,
    # log(0) * 0 yields NaN. This is unlikely at this scale but possible.
    defn loss({w1, b1, w2, b2}, numbers, labels) do
      preds = predict({w1, b1, w2, b2}, numbers)
      -Nx.mean(Nx.log(preds) * labels)
    end

    # One step of plain gradient descent on `loss/3` for a single sample.
    # Returns the updated {w1, b1, w2, b2} tuple.
    defn update({w1, b1, w2, b2} = params, numbers, labels) do
      {grad_w1, grad_b1, grad_w2, grad_b2} = grad(params, &loss(&1, numbers, labels))

      {
        w1 - grad_w1 * @learning_rate,
        b1 - grad_b1 * @learning_rate,
        w2 - grad_w2 * @learning_rate,
        b2 - grad_b2 * @learning_rate
      }
    end

    # Initial parameters for a 2 -> 3 -> 4 network, drawn from N(0.0, 0.1) as f32.
    # The key is split into one subkey per tensor so the four tensors are
    # sampled independently rather than all from the same key.
    defn init_params(key) do
      keys = Nx.Random.split(key, parts: 4)

      w1 =
        Nx.Random.normal_split(keys[0], 0.0, 0.1,
          shape: {2, 3},
          names: [:input, :hidden],
          type: {:f, 32}
        )

      b1 =
        Nx.Random.normal_split(keys[1], 0.0, 0.1,
          shape: {3},
          names: [:hidden],
          type: {:f, 32}
        )

      w2 =
        Nx.Random.normal_split(keys[2], 0.0, 0.1,
          shape: {3, 4},
          names: [:hidden, :output],
          type: {:f, 32}
        )

      b2 =
        Nx.Random.normal_split(keys[3], 0.0, 0.1,
          shape: {4},
          names: [:output],
          type: {:f, 32}
        )

      {w1, b1, w2, b2}
    end

    # Forward pass: dense -> relu -> dense -> softmax. Returns class probabilities.
    defn predict({w1, b1, w2, b2}, numbers) do
      numbers
      |> Nx.dot(w1)
      |> Nx.add(b1)
      |> relu()
      |> Nx.dot(w2)
      |> Nx.add(b2)
      |> softmax()
    end
  end

  # Trains the network on 1..1000 for 5 epochs (one update per sample), then
  # prints its guesses for a few sample numbers. Returns :ok.
  # NOTE(review): no defn compiler is selected here, so `Demo` runs under the
  # configured default (Nx.Defn.Evaluator unless EXLA is configured).
  def world do
    # Build the tensors once. They are identical on every epoch, so there is
    # no need to rebuild them 5x inside the training loop.
    data =
      Enum.map(1..1000, fn n ->
        {Nx.tensor(mods(n)), Nx.tensor(fizzbuzz(n))}
      end)

    key = Nx.Random.key(1)
    initial_params = Demo.init_params(key)

    params =
      Enum.reduce(1..5, initial_params, fn _epoch, epoch_params ->
        Enum.reduce(data, epoch_params, fn {numbers, labels}, cur_params ->
          Demo.update(cur_params, numbers, labels)
        end)
      end)

    # Decodes the argmax class index using the same order as `fizzbuzz/1`.
    guess = fn x ->
      class =
        params
        |> Demo.predict(Nx.tensor(mods(x)))
        |> Nx.argmax()
        |> Nx.to_number()

      case class do
        0 -> "fizz"
        1 -> "buzz"
        2 -> "fizzbuzz"
        3 -> "womp"
      end
    end

    [
      {3, "3"},
      {5, "5"},
      {15, "15"},
      {16, "16"},
      {15_432_115, "15,432,115"},
      {20_399_985, "20,399,985"},
      {20_399_997, "20,399,997"},
      {20_399_998, "20,399,998"}
    ]
    |> Enum.each(fn {n, label} -> n |> guess.() |> IO.inspect(label: label) end)

    :ok
  end
end
@toranb
Copy link
Author

toranb commented Mar 19, 2023

  defp deps do
    [
      {:exla, "~> 0.6"},
      {:nx, "~> 0.6"}
    ]
  end

@toranb
Copy link
Author

toranb commented Sep 4, 2023

The inspiration for this came from Bruce and the Programmer Passport series on Nx:

https://www.youtube.com/watch?v=NcsqGS6SVXg

@batate
Copy link

batate commented Apr 16, 2024

Just saw this. I like the way this is layered. Thanks for the credit!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment