Skip to content

Instantly share code, notes, and snippets.

@behrica
Created December 17, 2024 14:44
Show Gist options
  • Save behrica/9667f65bd5308c7b70ba01ab484ad106 to your computer and use it in GitHub Desktop.
Using Deeplearning4j (DL4J) with a Clojure dataset
(ns iris
(:require
[scicloj.metamorph.core :as mm]
[scicloj.metamorph.ml.preprocessing :as preprocessing]
[scicloj.metamorph.ml.toydata :as toydata]
[sweet-array.core :as sa]
[tablecloth.api :as tc]
[tech.v3.dataset :as ds]
[tech.v3.dataset.column-filters :as cf])
(:import
[org.deeplearning4j.nn.conf NeuralNetConfiguration$Builder]
[org.deeplearning4j.nn.conf.layers DenseLayer$Builder OutputLayer$Builder]
[org.deeplearning4j.nn.multilayer MultiLayerNetwork]
[org.deeplearning4j.nn.weights WeightInit]
[org.deeplearning4j.optimize.listeners ScoreIterationListener]
[org.nd4j.evaluation.classification Evaluation]
[org.nd4j.linalg.activations Activation]
[org.nd4j.linalg.dataset DataSet]
[org.nd4j.linalg.factory Nd4j]
[org.nd4j.linalg.learning.config Sgd]
[org.nd4j.linalg.lossfunctions LossFunctions$LossFunction]))
;; clojure
(def iris
  "The iris toy dataset, loaded via scicloj's toydata helper."
  (toydata/iris-ds))
(def split
  "A single train/test holdout split of the iris dataset
  (a map with :train and :test datasets)."
  (-> (tc/split->seq iris :holdout)
      first))
(def feature-names
  "Column names of the feature (non-target) columns of the iris dataset."
  (-> iris
      cf/feature
      tc/column-names))
(def my-pipe
  "Preprocessing pipeline: standard-scale the feature columns,
  then one-hot encode the :species target column."
  (mm/pipeline
   (preprocessing/std-scale feature-names {})
   (mm/lift ds/categorical->one-hot [:species])))
(def fitted-ctx
  "Pipeline context obtained by fitting `my-pipe` on the training split.
  Holds both the fitted pipeline state and the transformed training data."
  (mm/fit-pipe (:train split) my-pipe))
(def transformed-ctx
  "Pipeline context from applying the already-fitted `my-pipe`
  (state taken from `fitted-ctx`) to the test split."
  (mm/transform-pipe (:test split) my-pipe fitted-ctx))
(def train-ds
  "The preprocessed (scaled + one-hot encoded) training dataset."
  (:metamorph/data fitted-ctx))
(def test-ds
  "The preprocessed (scaled + one-hot encoded) test dataset."
  (:metamorph/data transformed-ctx))
(defn- ds->nd-float-array
  "Convert the rows of dataset `ds` into an ND4J float matrix
  (via a Java float[][] built with sweet-array)."
  [ds]
  (Nd4j/create
   (sa/into-array [[float]] (tc/rows ds))))
(def nd-train-features
  "ND4J matrix of the training-set feature columns."
  (ds->nd-float-array (cf/feature train-ds)))
(def nd-train-labels
  "ND4J matrix of the training-set target (one-hot label) columns."
  (ds->nd-float-array (cf/target train-ds)))
(def nd-test-features
  "ND4J matrix of the test-set feature columns."
  (ds->nd-float-array (cf/feature test-ds)))
(def nd-test-labels
  "ND4J matrix of the test-set target (one-hot label) columns."
  ;; Use thread-first (->) for consistency with the three sibling defs
  ;; above; the original used ->> which only behaved identically by
  ;; accident of every form taking a single argument.
  (-> test-ds
      (cf/target)
      ds->nd-float-array))
;; Clojure interop starts
;; Network dimensions and RNG seed.
;; Derive the input width from the actual feature columns instead of
;; hard-coding 4, so the network stays in sync with the preprocessing.
(def num-inputs (count feature-names))
(def output-num 3) ; three iris species classes (one-hot encoded target)
(def seed 6)       ; fixed seed for reproducible weight initialization

;; Wrap the ND4J feature/label matrices as DL4J DataSets.
(def train-dataset (DataSet. nd-train-features nd-train-labels))
(def test-dataset (DataSet. nd-test-features nd-test-labels))
;; Network configuration: 4 -> 3 -> 3 -> 3 fully-connected classifier.
;; Global settings (activation, weight init, updater, L2) apply to every
;; layer unless overridden per-layer, as on the output layer below.
(def conf (-> (NeuralNetConfiguration$Builder.)
(.seed seed)
(.activation Activation/TANH)
(.weightInit WeightInit/XAVIER)
(.updater (Sgd. 0.1)) ; plain SGD, learning rate 0.1
(.l2 1e-4)            ; L2 weight-decay regularization
(.list)               ; switch to multi-layer (ListBuilder) configuration
;; Hidden layer 1: num-inputs -> 3
(.layer (-> (DenseLayer$Builder.)
(.nIn num-inputs)
(.nOut 3)
(.build)))
;; Hidden layer 2: 3 -> 3
(.layer (-> (DenseLayer$Builder.)
(.nIn 3)
(.nOut 3)
(.build)))
;; Output layer: softmax + negative log-likelihood for classification
(.layer (-> (OutputLayer$Builder. LossFunctions$LossFunction/NEGATIVELOGLIKELIHOOD)
(.activation Activation/SOFTMAX) ; Override global TANH activation
(.nIn 3)
(.nOut output-num)
(.build)))
(.build)))
(def model
  "The multi-layer network instantiated from `conf`."
  (MultiLayerNetwork. conf))

;; Initialize the network's parameters before training.
(.init model)

;; Log the training score every 100 iterations.
(.setListeners model [(ScoreIterationListener. 100)])
;; Train the model for 1000 epochs
;; Train the model for 1000 epochs over the full training DataSet.
;; The loop index is unused, so bind it to `_` per Clojure convention.
(dotimes [_ 1000]
  (.fit model train-dataset))
;; Evaluate the model on the test set
;; Evaluate the trained model on the held-out test set.
(def evaluation
  "DL4J classification evaluator for `output-num` classes."
  (Evaluation. output-num))

(def output
  "Model predictions (class probability rows) for the test features."
  (.output model (.getFeatures test-dataset)))

;; Compare predictions against the true labels, then print the
;; accuracy/precision/recall/F1 summary.
(.eval evaluation (.getLabels test-dataset) output)
(println (.stats evaluation))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment