Last active
May 3, 2023 02:43
-
-
Save VictorTaelin/af471e66a03a9a3efca25a42f6376aae to your computer and use it in GitHub Desktop.
A complete, working FFT implementation in Haskell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Debug.Trace | |
import Data.Bits | |
-- FFT Algorithm
-- =============

-- | A string of binary digits: 'E' ends it; 'O' and 'I' each prepend a 0 or
-- 1 digit. It serves both as a path into a 'Tree' (see 'getAt') and as a
-- twiddle exponent (see 'mul'). In both uses the outermost constructor is
-- consumed first.
data Nat
  = E
  | O Nat
  | I Nat

-- | Binary tree with values at the leaves ('L') and two-way branches ('B').
data Tree a
  = L a
  | B (Tree a) (Tree a)

-- | A number encoded as a tree of 'Int' coefficients.
--
-- 'gaussToComplex' evaluates @B a0 a1@ as @a0 + a1 * z@, with
-- @z = e^(i*pi/2^d)@ where @d@ is the depth of that subtree. A bare leaf is
-- a plain integer.
type GN = Tree Int
-- | Coefficient-wise sum of two 'GN' values.
--
-- Both operands must have the same shape; a leaf paired with a branch fails
-- with a pattern-match error.
add :: GN -> GN -> GN
add lhs rhs = case (lhs, rhs) of
  (L a, L b) -> L (a + b)
  (B a0 a1, B b0 b1) -> B (add a0 b0) (add a1 b1)
-- | Coefficient-wise difference of two 'GN' values (left minus right).
--
-- Both operands must have the same shape; a leaf paired with a branch fails
-- with a pattern-match error.
sub :: GN -> GN -> GN
sub lhs rhs = case (lhs, rhs) of
  (L a, L b) -> L (a - b)
  (B a0 a1, B b0 b1) -> B (sub a0 b0) (sub a1 b1)
-- | Multiplies a 'GN' by the root @z@ of its top level.
--
-- A leaf is negated. @B x y@ (that is, @x + y*z@) becomes @B (rot y) x@,
-- because @y*z^2@ is @y@ times the child's root, which is @rot y@.
rot :: GN -> GN
rot num = case num of
  L z -> L (negate z)
  B lo hi -> B (rot hi) lo
-- | Multiplies a 'GN' by a power of its top-level root @z@, given as 'Nat'
-- digits.
--
-- The outermost digit contributes @z@ (one 'rot' when it is 'I'). The
-- remaining digits are pushed down to both children, whose root is @z^2@.
-- The 'Nat' therefore reads as a little-endian binary exponent. At a leaf
-- only the current digit matters.
mul :: Nat -> GN -> GN
mul digits num = case digits of
  E -> num
  O rest -> descend rest num
  I rest -> rot (descend rest num)
  where
    -- Apply the remaining digits to each child; leaves are left unchanged.
    descend _ leaf@(L _) = leaf
    descend rest (B lo hi) = B (mul rest lo) (mul rest hi)
-- | Radix-2 FFT over a tree of 'GN' samples.
--
-- A leaf is returned unchanged. At a branch, both halves are transformed
-- with the twiddle-path builder @ang@ extended by an 'O' digit, then the
-- results are combined by 'mix'.
fft :: (Nat -> Nat) -> Tree GN -> Tree GN
fft _ leaf@(L _) = leaf
fft ang (B lhs rhs) = mix ang (recurse lhs) (recurse rhs)
  where
    recurse = fft (ang . O)
-- | Butterfly step combining two transformed halves of equal shape.
--
-- At matching leaves @e@ and @o@ it yields @B (L (e + t)) (L (e - t))@,
-- with @t = mul (ang E) o@. At matching branches it recurses, extending the
-- twiddle path with 'O' on the left and 'I' on the right.
mix :: (Nat -> Nat) -> Tree GN -> Tree GN -> Tree GN
mix ang (L e) (L o) = B (L (add e t)) (L (sub e t))
  where
    t = mul (ang E) o
mix ang (B e0 e1) (B o0 o1) = B (mix (ang . O) e0 o0) (mix (ang . I) e1 o1)
-- Complex
-- =======

-- | A complex number stored as its real and imaginary parts.
data Complex = C Double Double
  deriving Show

-- | Scales both parts by a real factor.
cScale :: Double -> Complex -> Complex
cScale k (C re im) = C (re * k) (im * k)

-- | Complex addition.
cAdd :: Complex -> Complex -> Complex
cAdd (C re1 im1) (C re2 im2) = C re im
  where
    re = re1 + re2
    im = im1 + im2

-- | Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
cMul :: Complex -> Complex -> Complex
cMul (C re1 im1) (C re2 im2) = C re im
  where
    re = re1 * re2 - im1 * im2
    im = re1 * im2 + im1 * re2

-- | The unit complex number at the given angle (radians): cos t + i sin t.
cPol :: Double -> Complex
cPol theta = C c s
  where
    c = cos theta
    s = sin theta
-- Tree
-- ====

-- | Lists the leaves of a tree from left to right.
flatten :: Tree a -> [a]
flatten tree = go tree []
  where
    -- Prepend a subtree's leaves onto an already-built suffix. This avoids
    -- the repeated left-nested (++) of the direct recursive definition.
    go (L x) rest = x : rest
    go (B lhs rhs) rest = go lhs (go rhs rest)
-- | Follows a 'Nat' path from the root ('O' = left, 'I' = right) and
-- returns the leaf where it ends.
--
-- The path length must match the leaf's depth; otherwise this fails with a
-- pattern-match error.
getAt :: Nat -> Tree a -> a
getAt path tree = case (path, tree) of
  (E, L x) -> x
  (O rest, B lhs _) -> getAt rest lhs
  (I rest, B _ rhs) -> getAt rest rhs
-- | Number of branch levels along the leftmost spine.
--
-- This equals the tree height only when all leaves are at the same depth.
depth :: Tree a -> Int
depth = go 0
  where
    go acc (L _) = acc
    go acc (B lhs _) = go (acc + 1) lhs
-- | Rebuilds a tree of the same depth with its leaves in bit-reversed order.
--
-- Each new leaf position is recorded as a path whose first choice is the
-- innermost constructor. 'getAt' reads paths outermost-first, so the input
-- is indexed by the reversed path.
invert :: Tree a -> Tree a
invert t = build (depth t) E
  where
    build 0 path = L (getAt path t)
    build k path = B (build (k - 1) (O path)) (build (k - 1) (I path))
-- Conversions
-- ===========

-- | Embeds an integer as a 'GN' of depth @n@.
--
-- The value goes in the leftmost leaf and every other leaf is 0.
intToGN :: Int -> Int -> GN
intToGN n x
  | n == 0 = L x
  | otherwise = B (intToGN m x) (intToGN m 0)
  where
    m = n - 1
-- | Evaluates a 'GN' as a complex number.
--
-- Each leaf coefficient is scaled by the product of the roots chosen by the
-- 'I' digits on its path. The digit nearest the leaf uses angle pi/2, and
-- each digit further toward the root halves the angle.
gaussToComplex :: GN -> Complex
gaussToComplex = evalAt E
  where
    evalAt path node = case node of
      L coef -> cScale (fromIntegral coef) (twiddle path (pi / 2) (C 1 0))
      B lhs rhs -> cAdd (evalAt (O path) lhs) (evalAt (I path) rhs)
    -- Paths are built leaf-side outermost, so consume them outermost-first
    -- while halving the angle at each step.
    twiddle E _ acc = acc
    twiddle (O rest) theta acc = twiddle rest (theta / 2) acc
    twiddle (I rest) theta acc = twiddle rest (theta / 2) (cMul acc (cPol theta))
-- | Builds the decimation-in-time input tree for 'fft'.
--
-- Elements at even positions go to the left subtree and elements at odd
-- positions go to the right, recursively. A singleton becomes a leaf.
-- The empty list is rejected explicitly: splitting it would otherwise
-- produce an infinite tree that diverges on traversal.
listToTree :: [a] -> Tree a
listToTree [] = error "listToTree: empty list"
listToTree [x] = L x
listToTree xs = B (listToTree evens) (listToTree odds)
  where
    (evens, odds) = deinterleave xs
    -- Split a list into its even-indexed and odd-indexed elements.
    deinterleave (y : z : rest) =
      let (ys, zs) = deinterleave rest in (y : ys, z : zs)
    deinterleave ys = (ys, [])
-- | Undoes the bit-reversed leaf order produced by 'fft', then lists the
-- leaves left to right.
treeToList :: Tree a -> [a]
treeToList = flatten . invert
-- | Floor of the base-2 logarithm of a positive 'Int'.
--
-- Computed exactly from the bit width rather than via floating point.
-- Converting to 'Double' can round (e.g. @2^60 - 1@ becomes @2^60@),
-- which made the old @floor . logBase 2@ version off by one for large
-- inputs.
--
-- Calls 'error' for non-positive arguments, which have no logarithm.
log2 :: Int -> Int
log2 n
  | n <= 0 = error "log2: argument must be positive"
  | otherwise = finiteBitSize n - 1 - countLeadingZeros n
-- | True exactly when the argument is a positive power of two.
--
-- Such a number has a single set bit, so clearing its lowest set bit with
-- @n .&. (n - 1)@ leaves zero.
isPow2 :: Int -> Bool
isPow2 n = n > 0 && n .&. (n - 1) == 0
-- | FFT of a list of integers whose length is a power of two.
--
-- Each sample is embedded as a 'GN' of depth log2 N and transformed with
-- exact integer arithmetic. The results are converted to 'Complex' in
-- natural order.
--
-- NOTE(review): the twiddle convention appears to be the positive-exponent
-- one, X_j = sum_m x_m * e^(2*pi*i*m*j/N) — confirm.
--
-- Calls 'error' when the length is not a power of two (including empty).
listFFT :: [Int] -> [Complex]
listFFT xs
  | not (isPow2 size) = error "list must have 2^n elements"
  | otherwise = map gaussToComplex (treeToList spectrum)
  where
    size = length xs
    samples = map (intToGN (log2 size)) xs
    -- The top-level twiddle path starts with an O digit.
    spectrum = fft O (listToTree samples)
-- Main
-- ====

-- | Prints the FFT of a small example input.
--
-- The previous @let c = intToGN 3@ binding was never used and has been
-- removed.
main :: IO ()
main = print (listFFT [1, 2, 3, 4, 5, 6, 7, 8])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment