Created
March 9, 2010 15:05
-
-
Save pervognsen/326672 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
;; The double-array should be encapsulated inside a closure, with no possible escape.
;; Here it is dangerously exposed for ease of testing.
;; A matrix is a struct map with the positional keys :num-rows, :num-cols and
;; :entries (the double-array that into-matrix fills in row-major order).
(def matrix (create-struct :num-rows :num-cols :entries))
(defn row-col-index
  "Returns the offset into the row-major :entries array of m for the cell at
  (row-index, col-index). Only the :num-cols key of m is read, so a bare map
  such as {:num-cols n} works too. No bounds checking is performed."
  [m row-index col-index]
  (let [{stride :num-cols} m]
    (-> row-index (* stride) (+ col-index))))
(defn num-entries
  "Total number of cells in m: its row count times its column count."
  [m]
  (let [{r :num-rows, c :num-cols} m]
    (* r c)))
(defn has-index?
  "True when i is a valid offset into the flattened entries of m, i.e.
  0 <= i < num-rows * num-cols. This validates only the flattened offset,
  not that a (row, col) pair is individually in range."
  [m i]
  ;; Non-negativity (not pos?) so that offset 0, the first cell, is accepted.
  (and (<= 0 i) (< i (* (:num-rows m) (:num-cols m)))))
(defn into-matrix
  "Builds a matrix of num-rows rows and num-cols columns from rows, a
  sequence of num-rows sequences holding num-cols numbers each. Values are
  coerced with double and copied into a fresh row-major double-array.
  Asserts that the row count and every row's length match the declared
  dimensions."
  [num-rows num-cols rows]
  (let [entries (double-array (* num-rows num-cols))]
    (assert (= (count rows) num-rows))
    ;; (map vector (range) xs) pairs each element with its index using only
    ;; clojure.core; the contrib `indexed` helper is not required by this file.
    (doseq [[row-index row] (map vector (range) rows)]
      (assert (= (count row) num-cols))
      (doseq [[col-index x] (map vector (range) row)]
        (aset entries
              (row-col-index {:num-cols num-cols} row-index col-index)
              (double x))))
    (struct matrix num-rows num-cols entries)))
(defn fill-matrix
  "Builds a num-rows x num-cols matrix whose cell at (row, col) is
  (fill-fn row col), delegating construction to into-matrix."
  [num-rows num-cols fill-fn]
  (letfn [(build-row [r]
            (map (fn [c] (fill-fn r c)) (range num-cols)))]
    (into-matrix num-rows num-cols (map build-row (range num-rows)))))
(defn has-row?
  "True when row-index is a valid row of m: 0 <= row-index < (:num-rows m)."
  [m row-index]
  ;; `in-range?` is neither in clojure.core nor defined in this file, so the
  ;; bounds test is written as a chained comparison.
  (< -1 row-index (:num-rows m)))
(defn has-col?
  "True when col-index is a valid column of m: 0 <= col-index < (:num-cols m)."
  [m col-index]
  ;; `in-range?` is neither in clojure.core nor defined in this file, so the
  ;; bounds test is written as a chained comparison.
  (< -1 col-index (:num-cols m)))
;; Unchecked private accessors
(defn- unsafe-index
  "Reads the element at flattened offset i of m's :entries array. The only
  bounds check is whatever aget itself performs."
  [m i]
  (-> m :entries (aget i)))
(defn- unsafe-row
  "Lazy sequence of the values in row row-index of m, read straight from the
  entries array without bounds checking."
  [m row-index]
  (let [width (:num-cols m)
        start (* row-index width)]
    ;; unsafe-index is this file's raw accessor; there is no unsafe-entry.
    (map #(unsafe-index m %) (range start (+ start width)))))
(defn- unsafe-col
  "Lazy sequence of the values in column col-index of m: the entries at
  offsets col-index, col-index + num-cols, ... below (num-entries m).
  No bounds checking."
  [m col-index]
  ;; unsafe-index is this file's raw accessor; there is no unsafe-entry.
  (map #(unsafe-index m %)
       (range col-index (num-entries m) (:num-cols m))))
;; Public functions
(defn entry
  "Returns the value at (row, col) of m, asserting that row and col are each
  within m's dimensions."
  [m row col]
  ;; Check row and column separately: a flattened-offset check alone accepts
  ;; out-of-range pairs such as (0, num-cols), which silently alias the first
  ;; cell of the next row.
  (assert (has-row? m row))
  (assert (has-col? m col))
  (unsafe-index m (row-col-index m row col)))
(defn row
  "Lazy sequence of the values in row row-index of m; asserts the row exists."
  [m row-index]
  {:pre [(has-row? m row-index)]}
  (unsafe-row m row-index))
(defn col
  "Lazy sequence of the values in column col-index of m; asserts the column
  exists."
  [m col-index]
  {:pre [(has-col? m col-index)]}
  (unsafe-col m col-index))
(defn rows
  "Lazy sequence of every row of m, top to bottom; each row is itself a lazy
  sequence of values."
  [m]
  (map (partial unsafe-row m) (range (:num-rows m))))
(defn cols
  "Lazy sequence of every column of m, left to right; each column is itself a
  lazy sequence of values."
  [m]
  (map (partial unsafe-col m) (range (:num-cols m))))
;; Let the games begin!
(defn transpose
  "Returns a new matrix whose rows are the columns of m."
  [m]
  (let [{r :num-rows, c :num-cols} m]
    (into-matrix c r (cols m))))
(defn map-rows
  "Applies f cell-wise across the row sequences rs: the i-th result row is f
  mapped over the i-th row of every input. Like map, it stops at the shortest
  input at both the row and the cell level."
  [f & rs]
  (letfn [(zip-row [& parallel-rows]
            (apply map f parallel-rows))]
    (apply map zip-row rs)))
(defn map-matrix
  "Returns a new matrix whose cell (i, j) is f applied to cell (i, j) of each
  matrix in ms. All matrices must share the same dimensions (asserted); the
  result takes them from the first."
  [f & ms]
  (let [{:keys [num-rows num-cols]} (first ms)]
    ;; map-rows truncates to the shortest input, so without this check a
    ;; larger later operand would be silently cropped instead of rejected.
    (assert (every? #(and (= num-rows (:num-rows %))
                          (= num-cols (:num-cols %)))
                    ms))
    (into-matrix num-rows num-cols (apply map-rows f (map rows ms)))))
(defn scale-matrix
  "Returns a new matrix with every entry of m multiplied by the scalar s."
  [s m]
  (map-matrix (partial * s) m))
(defn add-matrix
  "Element-wise sum of the matrices ms, computed through map-matrix."
  [& ms]
  (apply map-matrix (cons + ms)))
(defn dot-product
  "Sum of the pairwise products of xs and ys, starting from 0 (so the result
  is 0 for empty input). Extra elements of the longer sequence are ignored."
  [xs ys]
  (->> (map * xs ys)
       (reduce + 0)))
(defn multiply-matrix
  "Matrix product m1 * m2: cell (i, j) is the dot product of row i of m1 and
  column j of m2. Asserts that m1's column count equals m2's row count."
  [m1 m2]
  ;; dot-product's (map * ...) truncates silently, so mismatched inner
  ;; dimensions would otherwise give a wrong result rather than an error.
  (assert (= (:num-cols m1) (:num-rows m2)))
  ;; m2's columns do not depend on the row of m1: compute them once. Local
  ;; names avoid shadowing the public row/col functions.
  (let [m2-cols (cols m2)]
    (into-matrix (:num-rows m1) (:num-cols m2)
                 (for [m1-row (rows m1)]
                   (for [m2-col m2-cols]
                     (dot-product m1-row m2-col))))))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment