-- Last active May 3, 2023
-- A complete, working FFT in Haskell.
import Debug.Trace
import Data.Bits
-- FFT Algorithm
-- =============
-- | A path through a binary tree, or equivalently a binary natural number
-- written least-significant bit first.
--
-- * As a path (see 'getAt'): the outermost constructor picks the branch at
--   the current node, 'O' for the left subtree and 'I' for the right, and
--   'E' ends the path at a leaf.
-- * As a number (see 'mul'): 'E' is 0, @O n@ is @2n@ and @I n@ is @2n + 1@.
data Nat = E | O Nat | I Nat
  deriving (Eq, Show)
-- | A binary tree with values at the leaves. The algorithms below assume it
-- is perfectly balanced, but the type itself does not enforce that.
data Tree a = L a | B (Tree a) (Tree a)
  deriving (Eq, Show)

-- | A ring element stored as a tree of integer coefficients.
--
-- As interpreted by 'gaussToComplex', a tree of depth @k@ stands for
-- @sum_j c_j * zeta^j@ with @zeta = exp(i*pi/2^k)@. The root's left subtree
-- holds the even-indexed coefficients and the right subtree holds the
-- odd-indexed ones, recursively (as polynomials in @zeta^2@).
type GN = Tree Int
-- | Add two 'GN' values coefficient by coefficient.
-- Both trees must have the same shape; mismatched shapes are a
-- pattern-match failure.
add :: GN -> GN -> GN
add (B l1 r1) (B l2 r2) = B (add l1 l2) (add r1 r2)
add (L m) (L n) = L (m + n)
-- | Subtract the second 'GN' from the first, coefficient by coefficient.
-- Both trees must have the same shape; mismatched shapes are a
-- pattern-match failure.
sub :: GN -> GN -> GN
sub (B l1 r1) (B l2 r2) = B (sub l1 l2) (sub r1 r2)
sub (L m) (L n) = L (m - n)
-- | Multiply a 'GN' by its generator @zeta = exp(i*pi/2^d)@, where @d@ is the
-- tree depth (the interpretation used by 'gaussToComplex').
--
-- For a leaf (@d = 0@), @zeta = -1@, so the coefficient is negated.
-- A node @B evens odds@ stands for @e(zeta^2) + zeta * o(zeta^2)@. Multiplying
-- by @zeta@ gives @zeta^2 * o(zeta^2) + zeta * e(zeta^2)@. The new even half
-- is therefore the old odd half times the subtree generator @zeta^2@ (a
-- recursive 'rot'), and the new odd half is the old even half.
rot :: GN -> GN
rot num = case num of
  L c -> L (negate c)
  B evens odds -> B (rot odds) evens
-- | @mul n g@ multiplies @g@ by @zeta^n@, where @zeta = exp(i*pi/2^d)@ for a
-- tree of depth @d@ (the interpretation used by 'gaussToComplex').
-- The 'Nat' is read as a binary number, least-significant bit first:
-- 'E' is 0, @O k@ is @2k@, and @I k@ is @2k + 1@.
--
-- Multiplying by @zeta^(2k)@ means multiplying both halves by @(zeta^2)^k@.
-- Since @zeta^2@ is the generator of the subtrees, the remaining exponent
-- @k@ is pushed down; an odd exponent needs one extra 'rot'. At a leaf
-- @zeta = -1@, so only the parity of the exponent matters there.
mul :: Nat -> GN -> GN
mul expo g = case expo of
  E -> g
  O k -> pushDown k g
  I k -> rot (pushDown k g)
  where
    pushDown _ leaf@(L _) = leaf
    pushDown k (B evens odds) = B (mul k evens) (mul k odds)
-- | Radix-2 FFT on a tree of 'GN' samples.
--
-- The input tree is expected to hold the even-indexed samples in its left
-- subtree and the odd-indexed samples in its right subtree (recursively),
-- as built by 'listToTree'. Both halves are transformed recursively and
-- then merged by 'mix'. The result's leaves come out in bit-reversed
-- order, which 'treeToList' undoes.
--
-- @ang@ is a continuation that builds the 'Nat' exponent handed to 'mul'
-- inside 'mix'; it is composed with 'O' before recursing into the halves.
fft :: (Nat -> Nat) -> Tree GN -> Tree GN
fft _ leaf@(L _) = leaf
fft ang (B evens odds) = mix ang (half evens) (half odds)
  where
    half = fft (ang . O)
-- | Butterfly combination step of 'fft'.
--
-- Takes the transforms of the even half and the odd half (two trees of
-- identical shape) and walks them in lockstep. At each pair of leaves @e@
-- and @o@, it multiplies @o@ by @zeta^(ang E)@ using 'mul' and replaces the
-- pair with the two-leaf node @B (L (e + t)) (L (e - t))@.
-- Descending into left children extends the exponent continuation with
-- 'O'; descending into right children extends it with 'I'.
--
-- Mismatched tree shapes are a pattern-match failure.
mix :: (Nat -> Nat) -> Tree GN -> Tree GN -> Tree GN
mix ang (B evenL evenR) (B oddL oddR) =
  B (mix (ang . O) evenL oddL) (mix (ang . I) evenR oddR)
mix ang (L e) (L o) = B (L (add e twiddled)) (L (sub e twiddled))
  where
    twiddled = mul (ang E) o
-- Complex
-- =======
-- | A complex number in rectangular form: @C re im@ is @re + im*i@.
data Complex
  = C
      Double -- ^ real part
      Double -- ^ imaginary part
  deriving (Show)
-- | Multiply both components of a complex number by a real factor.
cScale :: Double -> Complex -> Complex
cScale k (C re im) = C (scale re) (scale im)
  where
    scale = (* k)
-- | Complex addition, component-wise.
cAdd :: Complex -> Complex -> Complex
cAdd (C a b) (C c d) = C re im
  where
    re = a + c
    im = b + d
-- | Complex multiplication: @(a + bi)(c + di) = (ac - bd) + (ad + bc)i@.
cMul :: Complex -> Complex -> Complex
cMul (C a b) (C c d) = C re im
  where
    re = a * c - b * d
    im = a * d + b * c
-- | The unit complex number at the given angle (radians): @exp(i*theta)@.
cPol :: Double -> Complex
cPol theta = C x y
  where
    x = cos theta
    y = sin theta
-- Tree
-- ====
-- | All leaf values of a tree, from left to right.
flatten :: Tree a -> [a]
flatten t = collect t []
  where
    -- Prepend the leaves of a subtree to the already-collected rest,
    -- which avoids left-nested (++).
    collect (L v) rest = v : rest
    collect (B lhs rhs) rest = collect lhs (collect rhs rest)
-- | Follow a path from the root to a leaf and return the leaf's value.
-- The outermost constructor of the path selects the branch at the root.
-- The path length must equal the leaf's depth; otherwise this is a
-- pattern-match failure.
getAt :: Nat -> Tree a -> a
getAt E (L v) = v
getAt (O rest) (B left _) = getAt rest left
getAt (I rest) (B _ right) = getAt rest right
-- | Number of branch nodes on the leftmost spine. This equals the depth of
-- the tree only when the tree is perfectly balanced.
depth :: Tree a -> Int
depth = go 0
  where
    go acc (L _) = acc
    go acc (B lhs _) = let acc' = acc + 1 in acc' `seq` go acc' lhs
-- | Permute the leaves of a balanced tree by reversing their paths.
--
-- The leaf reached by branch choices @c1, c2, ..., cd@ (root first) in the
-- result is the leaf reached by @cd, ..., c2, c1@ in the input. This is a
-- bit-reversal permutation. The depth is taken from the leftmost spine, so
-- the input is assumed to be perfectly balanced.
invert :: Tree a -> Tree a
invert src = build (depth src) E
  where
    -- Build @k@ more levels. @path@ holds the choices made so far, with the
    -- most recent one outermost. 'getAt' consumes it outermost-first, so
    -- the input is read along the reversed path.
    build 0 path = L (getAt path src)
    build k path = B (build (k - 1) (O path)) (build (k - 1) (I path))
-- Conversions
-- ===========
-- | Embed an integer as a 'GN' of the given depth. The value goes to the
-- leftmost leaf (the coefficient of @zeta^0@ under 'gaussToComplex') and
-- every other leaf is 0.
intToGN :: Int -> Int -> GN
intToGN levels x
  | levels == 0 = L x
  | otherwise = B (intToGN below x) (intToGN below 0)
  where
    below = levels - 1
-- | Evaluate a 'GN' as a complex number.
--
-- Each leaf coefficient is scaled by a unit complex factor determined by
-- its root-to-leaf path, and the results are summed. The branch chosen at
-- the deepest level contributes @pi/2@, the level above @pi/4@, and so on,
-- halving toward the root.
gaussToComplex :: GN -> Complex
gaussToComplex g = walk g E
  where
    -- Sum over all leaves. @path@ records the choices from the root, with
    -- the most recent (deepest) choice outermost.
    walk (L coeff) path = cScale (fromIntegral coeff) (twiddle path (pi / 2) (C 1 0))
    walk (B lo hi) path = cAdd (walk lo (O path)) (walk hi (I path))
    -- Consume the path deepest-first, halving the angle at each step and
    -- rotating by it whenever the branch taken was 'I'.
    twiddle E _ acc = acc
    twiddle (O rest) theta acc = twiddle rest (theta / 2) acc
    twiddle (I rest) theta acc = twiddle rest (theta / 2) (cMul acc (cPol theta))
-- | Build a tree from a non-empty list by recursively splitting it into its
-- even-indexed elements (left subtree) and odd-indexed elements (right
-- subtree). This is the input ordering 'fft' expects.
--
-- The tree is perfectly balanced only when the length is a power of two;
-- other non-empty lengths still produce a (lopsided) tree. An empty list
-- has no tree representation and is reported with 'error'. Previously it
-- recursed forever.
listToTree :: [a] -> Tree a
listToTree [] = error "listToTree: empty list"
listToTree [x] = L x
listToTree list = B (listToTree evens) (listToTree odds)
  where
    (evens, odds) = deinterleave list
    -- Split a list into the elements at even and at odd positions.
    deinterleave (a : b : rest) =
      let (es, os) = deinterleave rest in (a : es, b : os)
    deinterleave short = (short, [])
-- | List the leaves of a balanced tree after reversing every leaf's path
-- with 'invert'. 'listFFT' uses this to put 'fft''s bit-reversed output
-- back into natural order.
treeToList :: Tree a -> [a]
treeToList = flatten . invert
-- | Integer base-2 logarithm, rounded down: the index of the highest set bit.
--
-- This is computed exactly with bit operations rather than
-- @floor . logBase 2@. That version goes through 'Double' and can be off by
-- one for large arguments: for example, the natural logarithms of @2^53 - 1@
-- and @2^53@ round to the same 'Double', so it cannot return both 52 and 53.
--
-- Precondition: the argument is positive; otherwise this calls 'error'.
log2 :: Int -> Int
log2 n
  | n <= 0 = error "log2: argument must be positive"
  | otherwise = finiteBitSize n - 1 - countLeadingZeros n
-- | True exactly when the argument is a positive power of two (1, 2, 4, ...).
-- A positive power of two has a single set bit, so clearing its lowest set
-- bit with @n .&. (n - 1)@ leaves zero.
isPow2 :: Int -> Bool
isPow2 n = n > 0 && n .&. (n - 1) == 0
-- | Discrete Fourier transform of a list of integers whose length @n@ is a
-- power of two.
--
-- Returns @X_k = sum_j x_j * exp(+2*pi*i*j*k/n)@ for @k = 0 .. n-1@. The
-- exponent is positive and there is no @1/n@ normalisation. Each sample is
-- lifted to a 'GN' of depth @log2 n@ (generator @zeta = exp(i*pi/n)@), so
-- every twiddle multiplication inside 'fft' is exact integer work:
-- coefficient shuffles and negations. Floating point appears only in the
-- final 'gaussToComplex'. The initial 'O' passed to 'fft' makes every
-- twiddle exponent even, i.e. a power of @zeta^2 = exp(2*pi*i/n)@.
--
-- Calls 'error' when the length is not a power of two (including 0).
listFFT :: [Int] -> [Complex]
listFFT xs
  | not (isPow2 n) = error "list must have 2^n elements"
  | otherwise =
      map gaussToComplex
        . treeToList
        . fft O
        . listToTree
        . map (intToGN (log2 n))
        $ xs
  where
    n = length xs
-- Main
-- ====
-- | Print the transform computed by 'listFFT' for the sample input
-- @[1 .. 8]@.
main :: IO ()
main = print $ listFFT [1, 2, 3, 4, 5, 6, 7, 8]
