fizzbuzz with Nx
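A small two-layer neural network written with Nx that learns fizzbuzz: each number is encoded as its remainders mod 3 and mod 5, and the network classifies it as fizz, buzz, fizzbuzz, or neither ("womp").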
defmodule Mlearning do
  @moduledoc false

  # encode a number as its remainders mod 3 and mod 5 -- the only
  # signal the network needs to learn fizzbuzz
  def mods(x) do
    [rem(x, 3), rem(x, 5)]
  end

  # one-hot label: [fizz, buzz, fizzbuzz, neither]
  def fizzbuzz(n) do
    cond do
      rem(n, 15) == 0 -> [0, 0, 1, 0]
      rem(n, 3) == 0 -> [1, 0, 0, 0]
      rem(n, 5) == 0 -> [0, 1, 0, 0]
      true -> [0, 0, 0, 1]
    end
  end

  defmodule Demo do
    import Nx.Defn

    # relu with an explicit gradient: pass the upstream gradient
    # through where x > 0, zero elsewhere
    defn relu(x) do
      custom_grad(
        Nx.max(x, 0),
        [x],
        fn g -> [Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))] end
      )
    end

    defn softmax(logits) do
      Nx.exp(logits) / Nx.sum(Nx.exp(logits))
    end

    # cross-entropy against the one-hot labels
    defn loss({w1, b1, w2, b2}, numbers, labels) do
      preds = predict({w1, b1, w2, b2}, numbers)
      -Nx.sum(Nx.mean(Nx.log(preds) * labels))
    end

    # one step of gradient descent with a fixed learning rate of 0.1
    defn update({w1, b1, w2, b2} = params, numbers, labels) do
      {grad_w1, grad_b1, grad_w2, grad_b2} = grad(params, &loss(&1, numbers, labels))
      {w1 - grad_w1 * 0.1, b1 - grad_b1 * 0.1, w2 - grad_w2 * 0.1, b2 - grad_b2 * 0.1}
    end

    # small random weights for a 2 -> 3 -> 4 network
    defn init_params(key) do
      w1 =
        Nx.Random.normal_split(key, 0.0, 0.1,
          shape: {2, 3},
          names: [:input, :hidden],
          type: {:f, 32}
        )

      b1 = Nx.Random.normal_split(key, 0.0, 0.1, shape: {3}, names: [:hidden], type: {:f, 32})

      w2 =
        Nx.Random.normal_split(key, 0.0, 0.1,
          shape: {3, 4},
          names: [:hidden, :output],
          type: {:f, 32}
        )

      b2 = Nx.Random.normal_split(key, 0.0, 0.1, shape: {4}, names: [:output], type: {:f, 32})
      {w1, b1, w2, b2}
    end

    # forward pass: input {2} -> hidden {3} -> relu -> output {4} -> softmax
    defn predict({w1, b1, w2, b2}, numbers) do
      numbers
      |> Nx.dot(w1)
      |> Nx.add(b1)
      |> relu()
      |> Nx.dot(w2)
      |> Nx.add(b2)
      |> softmax()
    end
  end

  def world() do
    # features and one-hot labels for 1..1000
    init_nums = Enum.map(1..1000, &mods/1)
    init_labels = Enum.map(1..1000, &fizzbuzz/1)

    key = Nx.Random.key(1)
    init_params = Demo.init_params(key)
    data = Enum.zip(init_nums, init_labels)

    # 5 epochs of per-sample gradient descent
    params =
      Enum.reduce(1..5, init_params, fn _, params ->
        Enum.reduce(data, params, fn {numbers, labels}, cur_params ->
          Demo.update(cur_params, Nx.tensor(numbers), Nx.tensor(labels))
        end)
      end)

    # argmax over the four outputs picks the predicted class
    guess = fn x ->
      mod = Nx.tensor(mods(x))

      case Demo.predict(params, mod) |> Nx.argmax() |> Nx.to_number() do
        0 -> "fizz"
        1 -> "buzz"
        2 -> "fizzbuzz"
        3 -> "womp"
      end
    end

    guess.(3) |> IO.inspect(label: "3")
    guess.(5) |> IO.inspect(label: "5")
    guess.(15) |> IO.inspect(label: "15")
    guess.(16) |> IO.inspect(label: "16")
    guess.(15_432_115) |> IO.inspect(label: "15,432,115")
    guess.(20_399_985) |> IO.inspect(label: "20,399,985")
    guess.(20_399_997) |> IO.inspect(label: "20,399,997")
    guess.(20_399_998) |> IO.inspect(label: "20,399,998")
    :ok
  end
end
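
One way to run it, as a sketch (the Nx version constraint is an assumption; any recent release should work):

Mix.install([{:nx, "~> 0.6"}])
# paste the Mlearning module above into the same iex session, then:
Mlearning.world()
# each IO.inspect line prints the model's guess, e.g. 3: "fizz",
# once training has converged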
toranb commented (Mar 19, 2023):
the inspiration for this came from Bruce and the Programmer Passport series on Nx

In reply:
Just saw this. I like the way this is layered. Thanks for the credit!