Created
December 17, 2024 14:44
-
-
Save behrica/9667f65bd5308c7b70ba01ab484ad106 to your computer and use it in GitHub Desktop.
using dl4j with Clojure dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns iris | |
(:require | |
[scicloj.metamorph.core :as mm] | |
[scicloj.metamorph.ml.preprocessing :as preprocessing] | |
[scicloj.metamorph.ml.toydata :as toydata] | |
[sweet-array.core :as sa] | |
[tablecloth.api :as tc] | |
[tech.v3.dataset :as ds] | |
[tech.v3.dataset.column-filters :as cf]) | |
(:import | |
[org.deeplearning4j.nn.conf NeuralNetConfiguration$Builder] | |
[org.deeplearning4j.nn.conf.layers DenseLayer$Builder OutputLayer$Builder] | |
[org.deeplearning4j.nn.multilayer MultiLayerNetwork] | |
[org.deeplearning4j.nn.weights WeightInit] | |
[org.deeplearning4j.optimize.listeners ScoreIterationListener] | |
[org.nd4j.evaluation.classification Evaluation] | |
[org.nd4j.linalg.activations Activation] | |
[org.nd4j.linalg.dataset DataSet] | |
[org.nd4j.linalg.factory Nd4j] | |
[org.nd4j.linalg.learning.config Sgd] | |
[org.nd4j.linalg.lossfunctions LossFunctions$LossFunction])) | |
;; clojure | |
(def iris
  "The dataset returned by `toydata/iris-ds`; everything below is derived from it."
  (toydata/iris-ds))
(def split
  "First element of the sequence produced by a `:holdout` split of `iris`.
  Read later in this file through its `:train` and `:test` keys."
  (-> iris
      (tc/split->seq :holdout)
      first))
(def feature-names
  "Names of the columns of `iris` that `cf/feature` selects."
  (-> iris
      cf/feature
      tc/column-names))
;; NOTE(review): the one-hot step is a plain lifted function, so it is applied
;; to the test data independently of the training data. Presumably the
;; resulting columns come out in the same order for both -- confirm.
(def my-pipe
  "Metamorph pipeline with two steps, in order:
  1. `preprocessing/std-scale` over `feature-names` (empty options map),
  2. `ds/categorical->one-hot` on `:species`, lifted with `mm/lift`."
  (let [scale-step   (preprocessing/std-scale feature-names {})
        one-hot-step (mm/lift ds/categorical->one-hot [:species])]
    (mm/pipeline scale-step one-hot-step)))
(def fitted-ctx
  "Context returned by `mm/fit-pipe` when `my-pipe` is run on the `:train`
  part of `split`."
  (let [train-data (:train split)]
    (mm/fit-pipe train-data my-pipe)))
(def transformed-ctx
  "Context returned by `mm/transform-pipe` when `my-pipe` is run on the
  `:test` part of `split`, using `fitted-ctx` as the fitted context."
  (let [test-data (:test split)]
    (mm/transform-pipe test-data my-pipe fitted-ctx)))
(def train-ds
  "The `:metamorph/data` entry of `fitted-ctx`."
  (get fitted-ctx :metamorph/data))
(def test-ds
  "The `:metamorph/data` entry of `transformed-ctx`."
  (get transformed-ctx :metamorph/data))
(defn- ds->nd-float-array
  "Turns dataset `ds` into the value returned by `Nd4j/create`.

  The rows of `ds` (via `tc/rows`) are packed into a two-dimensional float
  array with `sa/into-array`, and that array is passed to `Nd4j/create`."
  [ds]
  ;; `sa/into-array` stays the direct argument of `Nd4j/create`, exactly as in
  ;; the threaded form this replaces.
  (Nd4j/create
   (sa/into-array [[float]]
                  (tc/rows ds))))
(def nd-train-features
  "Feature columns of `train-ds` (per `cf/feature`), converted with
  `ds->nd-float-array`."
  (-> train-ds cf/feature ds->nd-float-array))
(def nd-train-labels
  "Target columns of `train-ds` (per `cf/target`), converted with
  `ds->nd-float-array`."
  (-> train-ds cf/target ds->nd-float-array))
(def nd-test-features
  "Feature columns of `test-ds` (per `cf/feature`), converted with
  `ds->nd-float-array`."
  (-> test-ds cf/feature ds->nd-float-array))
(def nd-test-labels
  "Target columns of `test-ds` (per `cf/target`), converted with
  `ds->nd-float-array`."
  ;; Thread-first, matching the three sibling `nd-*` definitions. Every step
  ;; takes a single argument, so the result is the same as with `->>`.
  (-> test-ds
      (cf/target)
      ds->nd-float-array))
;; Java interop (DL4J) starts here
(def num-inputs
  "Value passed as `nIn` to the first dense layer of the network."
  4)

(def output-num
  "Value passed as `nOut` to the output layer and to the `Evaluation`
  constructor."
  3)

(def seed
  "Value passed to `.seed` on the network configuration builder."
  6)
(def train-dataset
  "`DataSet` built from `nd-train-features` and `nd-train-labels`."
  (new DataSet nd-train-features nd-train-labels))

(def test-dataset
  "`DataSet` built from `nd-test-features` and `nd-test-labels`."
  (new DataSet nd-test-features nd-test-labels))
(def conf
  "Network configuration built with `NeuralNetConfiguration$Builder`.

  Global settings: `seed`, TANH activation, XAVIER weight init, an `Sgd`
  updater constructed with 0.1, and an l2 value of 1e-4.
  Layers, in order: two dense layers (`num-inputs` -> 3, 3 -> 3) and an
  output layer (3 -> `output-num`) using NEGATIVELOGLIKELIHOOD loss and
  SOFTMAX activation."
  (let [hidden-size  3
        first-dense  (-> (DenseLayer$Builder.)
                         (.nIn num-inputs)
                         (.nOut hidden-size)
                         (.build))
        second-dense (-> (DenseLayer$Builder.)
                         (.nIn hidden-size)
                         (.nOut hidden-size)
                         (.build))
        out-layer    (-> (OutputLayer$Builder. LossFunctions$LossFunction/NEGATIVELOGLIKELIHOOD)
                         ;; Override the global TANH activation for this layer.
                         (.activation Activation/SOFTMAX)
                         (.nIn hidden-size)
                         (.nOut output-num)
                         (.build))]
    (-> (NeuralNetConfiguration$Builder.)
        (.seed seed)
        (.activation Activation/TANH)
        (.weightInit WeightInit/XAVIER)
        (.updater (Sgd. 0.1))
        (.l2 1e-4)
        (.list)
        (.layer first-dense)
        (.layer second-dense)
        (.layer out-layer)
        (.build))))
(def model
  "`MultiLayerNetwork` constructed from `conf`."
  (new MultiLayerNetwork conf))

;; Initialise the network, then attach a listener that records the score
;; every 100 iterations.
(doto model
  (.init)
  (.setListeners [(ScoreIterationListener. 100)]))
;; Train the model for 1000 epochs: each pass calls `.fit` once with the
;; whole `train-dataset`.
(let [epochs 1000]
  (dotimes [_ epochs]
    (.fit model train-dataset)))
;; Evaluate the model on the test set
(def evaluation
  "`Evaluation` object constructed with `output-num`."
  (new Evaluation output-num))

(def output
  "Result of `.output` on `model` for the features of `test-dataset`."
  (.output model (.getFeatures test-dataset)))

;; Feed the test labels and the model output to the evaluation, then print
;; its stats.
(let [labels (.getLabels test-dataset)]
  (.eval evaluation labels output))

(println (.stats evaluation))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment