@VictorTaelin
Last active November 12, 2024 09:06
Implementing complex numbers and FFT with just datatypes (no floats)

In this article, I'll explain why implementing numbers with just algebraic datatypes is desirable. I'll then talk about common implementations of FFT (Fast Fourier Transform) and why they hide inherent inefficiencies. I'll then show how to implement integers and complex numbers with just algebraic datatypes, in a way that is extremely simple and elegant. I'll conclude by deriving a pure functional implementation of complex FFT with just datatypes, no floats.

Why implement numbers with ADTs?

For most programmers, "real numbers" are a given: they just use floats and call it a day. But, in some cases, it doesn't work well, and I'm not talking about precision issues. When trying to prove statements on real numbers in proof assistants like Coq, we can't use doubles, we must formalize reals using datatypes, which can be very hard. For me, the real issue arises when I'm trying to implement optimal algorithms on HVM. To do that, I need functions to fuse, which is a fancy way of saying that (+ 2) . (+ 2) "morphs into" (+ 4) during execution - yet, machine floats block that mechanism. Because of that, I've been trying to come up with an elegant way to implement numbers with just datatypes. For natural numbers, it is easy: we can just implement them as bitstrings:

-- O stands for bit 0
-- I stands for bit 1
-- E stands for end-of-string
data Nat = O Nat | I Nat | E deriving (Show)

-- Applies 'f' n times to 'x'.
rep :: Nat -> (a -> a) -> a -> a
rep (O x) f = rep x (\x -> f (f x))
rep (I x) f = rep x (\x -> f (f x)) . f
rep E     f = id

-- Increments a Nat.
inc :: Nat -> Nat
inc E     = O E
inc (O x) = I x
inc (I x) = O (inc x)

-- Adds two Nats.
add :: Nat -> Nat -> Nat
add a = rep a inc

main :: IO ()
main = do
  let a = I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I E)))))))))))))))))))))))))
  let b = O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O E)))))))))))))))))))))))))
  print (add a b)

The program above implements addition by repeated increment, and is the simplest numerical example where fusion makes a big difference, being exponential on Haskell and linear on HVM. But this isn't useful in practice, as we already have fast addition algorithms like add-carry. The question, then, becomes: is it possible to exploit optimality to implement numeric algorithms that actually beat existing ones in practice? The idea is that we could implement numbers with datatypes, which are usually much slower than floats, but, if the resulting asymptotics are improved, then this would easily pay off. One of the most successful numeric algorithms ever is the FFT. What would an "optimal FFT" look like? Let's find out!
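As a sanity check on the Nat encoding, here is a minimal sketch that converts small numbers to and from bitstrings and adds them with rep. The helpers toNat and fromNat are hypothetical (they are not part of the original program), and the bitstrings are read with the outermost constructor as the least significant bit. Note that inc wraps around to zero on overflow, so addition is modulo 2^width of the second argument.

```haskell
-- Same Nat, rep, inc and add as above, plus two hypothetical helpers.
data Nat = O Nat | I Nat | E deriving (Show, Eq)

-- Applies 'f' n times to 'x'.
rep :: Nat -> (a -> a) -> a -> a
rep (O x) f = rep x (f . f)
rep (I x) f = rep x (f . f) . f
rep E     _ = id

-- Increments a Nat; the value wraps around to zero on overflow.
inc :: Nat -> Nat
inc E     = O E
inc (O x) = I x
inc (I x) = O (inc x)

-- Adds two Nats by applying inc to b, a times.
add :: Nat -> Nat -> Nat
add a = rep a inc

-- Hypothetical helper: encodes n using exactly w bits, least significant first.
toNat :: Int -> Integer -> Nat
toNat 0 _ = E
toNat w n
  | even n    = O (toNat (w - 1) (n `div` 2))
  | otherwise = I (toNat (w - 1) (n `div` 2))

-- Hypothetical helper: decodes a bitstring back to an Integer.
fromNat :: Nat -> Integer
fromNat E     = 0
fromNat (O x) = 2 * fromNat x
fromNat (I x) = 2 * fromNat x + 1
```

In GHCi, fromNat (add (toNat 4 5) (toNat 4 6)) should give 11, while fromNat (add (toNat 4 1) (toNat 4 15)) should give 0 because of the wrap-around.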

What is wrong with the textbook FFT

Designing optimal functions is extremely tricky, because it takes very little to interrupt fusion and get an exponential slowdown. For example, if we change the inc Haskell function above to let the Nat grow in size:

inc :: Nat -> Nat
inc E     = I E -- allows the bitstring to grow
inc (O x) = I x
inc (I x) = O (inc x)

Then, the fact we're now using the "I" constructor twice will prevent us from fusing the corresponding function on the HVM. A single character will downgrade it from linear to exponential! There are a bazillion ways to ruin fusion. Passing a state as an argument, placing a lambda in the wrong location, using an argument more than once... in fact, it feels like anything that deviates from a "direct linear recursion" will refuse to fuse. With that in mind, let's examine the textbook FFT:

Note: you do not need to understand FFT to read this article - just view this as a random snippet we want to optimize.

-- FFT receives a polynomial P, represented as a list of complex coefficients,
-- and returns the evaluation of P on len(P) points of the unit circle. Example:
-- - input  = [A,B,C,D,E,F,G,H]
-- - output = [P(e(0/8)),P(e(1/8)),P(e(2/8)),P(e(3/8)),P(e(4/8)),P(e(5/8)),P(e(6/8)),P(e(7/8))]
--          where P(x) = A + Bx¹ + Cx² + Dx³ + Ex⁴ + Fx⁵ + Gx⁶ + Hx⁷
--                e(x) = e^(2πxi)

-- When len=1, we must eval P(e^(0*2πi)) where P(x) = A. That's just A.
fft [x] = [x]

-- When len>1...
fft xs = pts where

  -- If len=8, we'll eval P = A + Bx¹ + Cx² + Dx³ + Ex⁴ + Fx⁵ + Gx⁶ + Hx⁷
  -- on x ∈ [e(0/8),e(1/8),e(2/8),e(3/8),e(4/8),e(5/8),e(6/8),e(7/8)]
  len = length xs

  -- We first split it on two polynomials, based on even/odd exponents:
  eve = split 0 xs -- EVE = A   + Cx² + Ex⁴ + Gx⁶
  odd = split 1 xs -- ODD = Bx¹ + Dx³ + Fx⁵ + Hx⁷

  -- We then call FFT recursively on EVE and ODD
  pt0 = fft eve -- PT0 = A + Cx¹ + Ex² + Gx³ evaluated on x ∈ [0/4𝜏,1/4𝜏,2/4𝜏,3/4𝜏] (by induction)
                -- PT0 = A + Cx² + Ex⁴ + Gx⁶ evaluated on x ∈ [0/8𝜏,1/8𝜏,2/8𝜏,3/8𝜏] (by equivalence)
  pt1 = fft odd -- PT1 = B + Dx¹ + Fx² + Hx³ evaluated on x ∈ [0/4𝜏,1/4𝜏,2/4𝜏,3/4𝜏] (by induction)
                -- PT1 = B + Dx² + Fx⁴ + Hx⁶ evaluated on x ∈ [0/8𝜏,1/8𝜏,2/8𝜏,3/8𝜏] (by equivalence)

  -- We then compute e(x) for each angle, which are the "twiddle factors"
  twi = [cPol (2 * pi * fromIntegral k / fromIntegral len) | k <- [0..len `div` 2 - 1]]

  -- We then multiply PT1 by the twiddle factors
  pt2 = zipWith cMul twi pt1 -- PT2 = Bx¹ + Dx³ + Fx⁵ + Hx⁷ evaluated on x ∈ [0/8𝜏,1/8𝜏,2/8𝜏,3/8𝜏]

  -- Finally, we obtain all points as PT0 +- PT2
  -- This exploits the symmetry of polynomials
  ptl = zipWith cAdd pt0 pt2 -- PTL = A + Bx¹ + Cx² + Dx³ + Ex⁴ + Fx⁵ + Gx⁶ + Hx⁷ evaluated on [0/8𝜏,1/8𝜏,2/8𝜏,3/8𝜏]
  ptr = zipWith cSub pt0 pt2 -- PTR = A + Bx¹ + Cx² + Dx³ + Ex⁴ + Fx⁵ + Gx⁶ + Hx⁷ evaluated on [4/8𝜏,5/8𝜏,6/8𝜏,7/8𝜏]

  -- The result just combines PTL and PTR to get all unit circle points
  pts = ptl ++ ptr

Complete code.

So, what is wrong with this? Well, when it comes to optimality, everything. First, it traverses the list just to compute its length. Then, it copies the entire list 2 more times to split even/odd indices. Then, it traverses the odds list twice. Then it copies it. Then it re-generates the same twiddle factors over and over. Then it traverses everything a few more times with zips. All these things block fusion, and, if that wasn't bad enough, it performs arithmetic on machine floats in its inner loop, which removes any remaining hope we could have. So, how can it be improved?
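For reference, the textbook version can be made runnable with a few assumptions. In the sketch below, Data.Complex stands in for the cPol, cMul, cAdd and cSub helpers (cis, (*), (+) and (-) respectively), and split is a guess at the helper that selects even or odd indices; the linked complete code may define these differently.

```haskell
import Data.Complex (Complex, cis)

-- Assumed helper: 'split 0' keeps even-indexed elements, 'split 1' odd-indexed.
split :: Int -> [a] -> [a]
split _ []       = []
split 0 (x : xs) = x : split 1 xs
split _ (_ : xs) = split 0 xs

-- Textbook recursive FFT over lists; the length must be a power of two.
-- Uses the convention e(x) = e^(2*pi*x*i), as in the comments above.
fft :: [Complex Double] -> [Complex Double]
fft []  = []
fft [x] = [x]
fft xs  = ptl ++ ptr
  where
    len = length xs
    pt0 = fft (split 0 xs)
    pt1 = fft (split 1 xs)
    twi = [cis (2 * pi * fromIntegral k / fromIntegral len) | k <- [0 .. len `div` 2 - 1]]
    pt2 = zipWith (*) twi pt1
    ptl = zipWith (+) pt0 pt2
    ptr = zipWith (-) pt0 pt2
```

For the polynomial 1 + 2x + 3x² + 4x³, fft [1, 2, 3, 4] should return, up to rounding, the values at 1, i, -1 and -i, that is 10, -2-2i, -2 and -2+2i.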

Improving the structure of FFT

The first insight is to replace lists by balanced binary trees, and store elements on nodes such that, by starting from any element of the tree, walking up to the root, and annotating the branches we passed through as 0 for left and 1 for right, we'll get the binary representation of the index. For example:

[0, 1, 2, 3, 4, 5, 6, 7]

Is represented as:

(B (B (B (L 0) (L 4)) (B (L 2) (L 6))) (B (B (L 1) (L 5)) (B (L 3) (L 7))))

If we start from 6, and move up until the root, we'll pass through right (1), right (1) and left (0) branches of a node, which agrees with the fact 110 is 6 in binary.
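A minimal sketch of how such a tree could be built from a list follows. The names fromList, toList, halves and interleave are hypothetical helpers, not part of the original code, and the list length is assumed to be a power of two.

```haskell
data Tree a = L a | B (Tree a) (Tree a) deriving (Show, Eq)

-- Splits a list into its even-indexed and odd-indexed elements.
halves :: [a] -> ([a], [a])
halves (x : y : zs) = let (es, os) = halves zs in (x : es, y : os)
halves xs           = (xs, [])

-- Builds the tree: even indices go left, odd indices go right, recursively.
-- Assumes a non-empty list whose length is a power of two.
fromList :: [a] -> Tree a
fromList []  = error "fromList: empty list"
fromList [x] = L x
fromList xs  = let (es, os) = halves xs in B (fromList es) (fromList os)

-- Alternates elements of two lists, starting with the first one.
interleave :: [a] -> [a] -> [a]
interleave []       ys = ys
interleave (x : xs) ys = x : interleave ys xs

-- Recovers the original list order from the tree.
toList :: Tree a -> [a]
toList (L x)   = [x]
toList (B l r) = interleave (toList l) (toList r)
```

With these definitions, fromList [0 .. 7] produces exactly the tree shown above, and toList undoes it.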

Storing elements that way has two benefits. First, we can now split a tree into even/odd indices in O(1) by just taking the left/right branches, removing the need to call the "split" function, which is O(N), on every recursive call. For example, by taking the left branch of that tree, we get:

(B (B (L 0) (L 4)) (B (L 2) (L 6)))

Which corresponds to the list [0, 2, 4, 6], i.e., the even indices of the original list. And if we take the left branch again, we get:

(B (L 0) (L 4))

Which corresponds to the list [0, 4], which is, once again, the even indices of the list above. The second benefit is that we're now able to replace the 3 calls to zipWith by a single call to mix, which combines evens, odds and the twiddle factors in a single pass. Here is the improved version:

data Complex = C Double Double deriving Show
data Nat     = E | O Nat | I Nat deriving Show
data Tree a  = L a | B (Tree a) (Tree a) deriving Show

fft :: Nat -> Tree Complex -> Tree Complex
fft ang (B eve odd) =
  let pt0 = fft (O ang) eve
      pt1 = fft (O ang) odd
  in mix ang pt0 pt1
fft ang (L x) = L x

mix :: Nat -> Tree Complex -> Tree Complex -> Tree Complex
mix ang (B ax ay) (B bx by) =
  let ptl = mix (O ang) ax bx
      ptr = mix (I ang) ay by
  in B ptl ptr
mix ang (L pt0) (L pt1) =
  let pt2 = cMul pt1 (cPol (2 * pi * natVal ang 0 / 2 ** fromIntegral (natLen ang + 1)))
      ptl = L (cAdd pt0 pt2)
      ptr = L (cSub pt0 pt2)
  in B ptl ptr

Complete code.

The way it works is similar, with all unnecessary work removed. The fft function receives a tree with coefficients, and calls itself recursively on the left and right branches, which corresponds to even and odd coefficients on the original polynomial. Then, it combines the set of recursive points in a single pass by calling the mix function. Finally, it doesn't generate a list of twiddle factors. Instead, it stores an angle (as a natural number, ranging from 0 to N, where N is the number of points) which is then used to generate the twiddle factor at the bottom of the recursion.

While this algorithm is less pedagogical and readable, it is fully linear and way less contrived for the runtime. It is in a healthy shape for optimality, but there is still a problem: it represents complex numbers using floats, thus requires machine numbers that don't fuse. Of course, we could perform FFT over other fields, but in some cases we really need complex FFT.

Implementing integers with datatypes

Our ultimate goal is to come up with an implementation of Complex that can be used on FFT, and that is based purely on ADTs, i.e., no native floats, in such a way that is simple and direct enough to possibly fuse. Let's begin with a simpler goal: integers. Let's recall how we implemented Nat and inc:

data Nat = O Nat | I Nat | E deriving (Show)

inc :: Nat -> Nat
inc E     = O E
inc (O x) = I x
inc (I x) = O (inc x)

So, what about Int? We could try implementing it by using a Nat and a sign:

-- sign = True  corresponds to  0,  1,  2,  3, ...
-- sign = False corresponds to -1, -2, -3, -4, ...
data Int = Int Bool Nat

intInc :: Int -> Int
intInc (Int True  n) = Int True (inc n)
intInc (Int False n)
    | isMinusOne n = Int True natZero
    | otherwise    = Int False (dec n)

But this is actually a bad idea since, by pattern-matching on the boolean to treat cases separately we inhibit fusion. Also, calling isMinusOne makes it non-linear, which also inhibits fusion. In general, this is exactly the shape of code that does not fuse. We're looking for something more uniform, that doesn't separate the sign.

Fortunately, there is a pretty elegant way to do it: balanced ternary. It is similar to ternary numbers, except its digits are not 0,1,2, but -1,0,1. The -1 digit is usually written as T, so, for example, the string 11T represents 9 + 3 - 1, which is 11. Amazingly, all integers can be uniquely represented in this system. Let's count:

num   trits
--- | -----
−13 |   𝖳𝖳𝖳
−12 |   𝖳𝖳0
−11 |   𝖳𝖳1
−10 |   𝖳0𝖳
 −9 |   𝖳00
 −8 |   𝖳01
 −7 |   𝖳1𝖳
 −6 |   𝖳10
 −5 |   𝖳11
 −4 |    𝖳𝖳 
 −3 |    𝖳0
 −2 |    𝖳1
 −1 |     𝖳
  0 |     0
  1 |     1
  2 |    1𝖳
  3 |    10
  4 |    11
  5 |   1𝖳𝖳
  6 |   1𝖳0
  7 |   1𝖳1
  8 |   10𝖳
  9 |   100
 10 |   101
 11 |   11𝖳
 12 |   110
 13 |   111

Beautiful, isn't it? Its symmetric properties and simple arithmetic are quite important for fusion. Here is an implementation of Int and inc:

-- T stands for digit -1
-- O stands for digit  0
-- I stands for digit  1
-- E stands for end-of-string
data Int = T Int | O Int | I Int | E

inc :: Int -> Int
inc E     = E
inc (T x) = O x
inc (O x) = I x
inc (I x) = T (inc x)

As you can imagine, this representation of Int, and its arithmetic operations, can be implemented on HVM in a way that fuses nicely, solving part of the problem. But what about real and complex numbers?
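To make the encoding concrete, here is a small sketch with conversions to and from Integer. The type is named BInt here only to avoid clashing with the Prelude's Int, and bintFromInteger and bintToInteger are hypothetical helpers. The outermost constructor is assumed to be the least significant trit, so the string 11T is written T (I (I E)), and inc wraps around on overflow because inc E = E.

```haskell
-- Balanced ternary integers, least significant trit first.
data BInt = T BInt | O BInt | I BInt | E deriving (Show, Eq)

-- Increment with carry; a fixed-width number wraps around on overflow.
inc :: BInt -> BInt
inc E     = E
inc (T x) = O x
inc (O x) = I x
inc (I x) = T (inc x)

-- Hypothetical helper: decodes a BInt (T = -1, O = 0, I = 1).
bintToInteger :: BInt -> Integer
bintToInteger E     = 0
bintToInteger (T x) = 3 * bintToInteger x - 1
bintToInteger (O x) = 3 * bintToInteger x
bintToInteger (I x) = 3 * bintToInteger x + 1

-- Hypothetical helper: encodes n using exactly w trits.
bintFromInteger :: Int -> Integer -> BInt
bintFromInteger 0 _ = E
bintFromInteger w n = case n `mod` 3 of
  0 -> O (bintFromInteger (w - 1) (n `div` 3))
  1 -> I (bintFromInteger (w - 1) ((n - 1) `div` 3))
  _ -> T (bintFromInteger (w - 1) ((n + 1) `div` 3))
```

For instance, bintFromInteger 3 11 gives T (I (I E)), matching the string 11T from the table, and incrementing the 3-trit encoding of 13 wraps around to -13.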

Implementing complex numbers with datatypes

Sadly, it took me a long time to figure out the proper way to do Int, which is astronomically simpler than these. Just to think of the complexity of floats (which include sign, mantissa, exponent, NaNs...), Dedekind Cuts, Cauchy Sequences and the like would give us the feeling we'll never find a representation of real numbers that is as nice and uniform as the Int type above. Yet, for the sake of FFT, there is some light at the end of the tunnel.

The insight is that we don't actually need all complex numbers; we just need enough of them to represent roots of unity. In fact, if we wanted to perform FFT on lists of 2 elements, integers would be enough for us, since we only need to evaluate polynomials on 2 complex points, [1, -1], which are both just ints!

[Figure: G0]

But what if there were 4 elements? In this case, we would actually need 4 points: [1, i, -1, -i]... which is beyond the set of integers. But there is a common extension of integers that could help: Gaussian Integers, which add the square root of -1, i, to Int. It forms a sort of quantized grid over the complex plane.

We could implement that type as follows:

-- Represents (A + B*i)
data Gauss = G Int Int

This would give us 4 roots of unity, and, thus, the ability to perform FFT on lists of 4 elements!

What about 8 elements though? Well, the idea here is to simply keep going, and add a new constant, j, which is the square root of i. Once we do that, we need 4 ints to represent a number; that's because ij is also part of the set, so any number can be represented as A + B*i + C*j + D*ij. We can implement this extended Gaussian Integer as:

-- Represents (A + B*i + C*j + D*ij)
data ExtGauss = G Int Int Int Int

This would give us 8 roots of unity, and, thus, the ability to perform FFT on lists of 8 elements.

And it would allow us to represent "fractional" complex numbers with just integers. For example, the number 1 + 2*i + 3*j + 4*ij, which is approximately 0.293 + 6.950 * i (taking j = e^(πi/4), so that ij = e^(3πi/4)), can be constructed as:

num :: ExtGauss
num = G 1 2 3 4

As you can see, we can keep going and generalize this to make the set G(N), with G(0) being Integers, G(1) being Gaussian Integers, G(2) being Extended Gaussian Integers, and so on. To implement it, we can use a tree:

data Tree a = L a | B (Tree a) (Tree a)
type GN     = Tree Int
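To see which ordinary complex number a GN value stands for, one can evaluate it numerically. The sketch below is an illustration only: toComplex and depth are hypothetical helpers, and it assumes that a node B a b at depth n means a + b*g, where g = e^(πi/2^n) is the smallest base of G(n) (so g is i for n=1, j for n=2, k for n=3).

```haskell
import Data.Complex (Complex, cis)

data Tree a = L a | B (Tree a) (Tree a) deriving (Show, Eq)
type GN     = Tree Int

-- Depth of a perfectly balanced tree (only the left spine is inspected).
depth :: Tree a -> Int
depth (L _)   = 0
depth (B a _) = 1 + depth a

-- Evaluates a GN value as a floating-point complex number.
-- Assumption: 'B a b' at depth n denotes a + b * e^(pi*i / 2^n).
toComplex :: GN -> Complex Double
toComplex (L z)     = fromIntegral z
toComplex t@(B a b) = toComplex a + cis (pi / 2 ^ depth t) * toComplex b
```

For example, the G(2) value B (B (L 0) (L 0)) (B (L 1) (L 0)) is the constant j, and toComplex maps it to roughly 0.7071 + 0.7071i.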

For G(3), we have 16 roots of unity, and can perform FFT on lists of 16 elements.

To add two complex numbers of GN, we just do so pairwise:

add :: GN -> GN -> GN
add (L z)   (L w)   = L (z + w)
add (B a b) (B c d) = B (add a c) (add b d)

But what about multiplication? A very elegant algorithm, found by T6 on HOC's discord, allows us to multiply a GN number by bases like i, j, k, etc., with just rotations and negation:

rot :: GN -> GN
rot (L z)   = L (-z)
rot (B a b) = B (rot b) a

This simple function, if applied on a GN element, will multiply it by its smallest base. So, for example, on G(3), rot performs multiplication by k. On G(2), it multiplies by j. On G(1), it multiplies by i. On G(0), which is just a scalar, it negates the number. Very nice! We can use rot to multiply two GN elements as follows (once again, thanks T6):

mul :: GN -> GN -> GN
mul (L z)   (L w)   = L (z * w)
mul (B a b) (B c d) = B (add (mul a c) (rot (mul b d))) (add (mul b c) (mul a d))

Which is, again, very elegant. But, for the sake of FFT, we don't actually need to perform multiplication of two arbitrary numbers during the algorithm; we just need to multiply by twiddle factors on the upper side of the unit circle. As such, we can use this simplified algorithm, which receives an angle on the unit circle, represented by a Nat, and multiplies it by the respective point:

mul :: Nat -> GN -> GN
mul E     x         = x
mul (O p) (L x)     = id  (L x)
mul (I p) (L x)     = rot (L x)
mul (O p) (B ax ay) = id  (B (mul p ax) (mul p ay))
mul (I p) (B ax ay) = rot (B (mul p ax) (mul p ay))

So, for example, mul (I (I (I E))) x will multiply x by ijk, which is approximately -0.923 + 0.382i. This completely removes the need for computing cos, sin, e^ix, and the like, simplifying the pt2 computation from:

pt2 = cMul pt1 (cPol (2 * pi * natVal ang 0 / 2 ** fromIntegral (natLen ang + 1)))

To just:

let pt2 = mul (ang E) pt1

In other words, all the complex floating point arithmetic, including trigonometric and exponentiation functions, is replaced by the mul function above, which is just a simple recursive pass over the tree structure!
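The claim about ijk can be checked numerically. The following sketch repeats rot and the Nat-indexed mul from above and adds hypothetical helpers (zero, one, depth, toComplex) that are not part of the original code. It assumes the Nat angle is read with the outermost bit as the least significant one, and that B a b at depth n denotes a + b*e^(πi/2^n).

```haskell
import Data.Complex (Complex, cis)

data Nat    = E | O Nat | I Nat
data Tree a = L a | B (Tree a) (Tree a) deriving (Show, Eq)
type GN     = Tree Int

-- Multiplies by the smallest base of G(n): -1, i, j, k, ...
rot :: GN -> GN
rot (L z)   = L (negate z)
rot (B a b) = B (rot b) a

-- Multiplies by the unit-circle point selected by the Nat angle.
mul :: Nat -> GN -> GN
mul E     x         = x
mul (O _) (L x)     = L x
mul (I _) (L x)     = rot (L x)
mul (O p) (B ax ay) = B (mul p ax) (mul p ay)
mul (I p) (B ax ay) = rot (B (mul p ax) (mul p ay))

-- Hypothetical helpers: the numbers 0 and 1 as elements of G(n).
zero, one :: Int -> GN
zero 0 = L 0
zero n = B (zero (n - 1)) (zero (n - 1))
one 0  = L 1
one n  = B (one (n - 1)) (zero (n - 1))

-- Hypothetical helpers: numeric evaluation of a GN value.
depth :: Tree a -> Int
depth (L _)   = 0
depth (B a _) = 1 + depth a

toComplex :: GN -> Complex Double
toComplex (L z)     = fromIntegral z
toComplex t@(B a b) = toComplex a + cis (pi / 2 ^ depth t) * toComplex b
```

Here toComplex (mul (I (I (I E))) (one 3)) should agree with cis (7 * pi / 8), which is about -0.924 + 0.383i, and the result tree has a single 1 in the ijk position.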

Implementing FFT with datatypes

After all this, we're finally able to implement a complete FFT algorithm, in 40 lines of Haskell code, with just plain ADTs:

data Nat    = E | O Nat | I Nat
data Tree a = L a | B (Tree a) (Tree a)
type GN     = Tree Int

add :: GN -> GN -> GN
add (L x)     (L y)     = L (x + y)
add (B ax ay) (B bx by) = B (add ax bx) (add ay by)

sub :: GN -> GN -> GN
sub (L x)     (L y)     = L (x - y)
sub (B ax ay) (B bx by) = B (sub ax bx) (sub ay by)

rot :: GN -> GN
rot (L z)   = L (-z)
rot (B x y) = B (rot y) x

mul :: Nat -> GN -> GN
mul E     x         = x
mul (O p) (L x)     = L x
mul (I p) (L x)     = rot (L x)
mul (O p) (B ax ay) = B (mul p ax) (mul p ay)
mul (I p) (B ax ay) = rot (B (mul p ax) (mul p ay))

fft :: (Nat -> Nat) -> Tree GN -> Tree GN
fft ang (L x)   = L x
fft ang (B x y) =
  let pt0 = fft (ang.O) x
      pt1 = fft (ang.O) y
  in mix ang pt0 pt1

mix :: (Nat -> Nat) -> Tree GN -> Tree GN -> Tree GN
mix ang (B ax ay) (B bx by) =
  let ptl = mix (ang.O) ax bx
      ptr = mix (ang.I) ay by
  in B ptl ptr
mix ang (L pt0) (L pt1) =
  let pt2 = mul (ang E) pt1
      ptl = L (add pt0 pt2)
      ptr = L (sub pt0 pt2)
  in B ptl ptr

Yes, this is a complete implementation! Note that in this version I'm actually using the primitive Haskell Int, but, as we've established, it can be implemented with just ADTs via balanced ternary. Here is a complete Haskell file which performs FFT using GNs, with some helper functions to convert it from and to lists of normal complex numbers, for visualization.
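As a quick way to try the algorithm without the linked file, here is a self-contained sketch that runs it on a 4-element input over Gaussian integers. The core functions are copied from the listing above; real, fromList and toList are hypothetical helpers written for this example and may differ from the ones in the linked file. The input length is assumed to be a power of two, and each number must have enough depth for the transform size (G(1) for 4 elements).

```haskell
data Nat    = E | O Nat | I Nat
data Tree a = L a | B (Tree a) (Tree a) deriving (Show, Eq)
type GN     = Tree Int

add, sub :: GN -> GN -> GN
add (L x)     (L y)     = L (x + y)
add (B ax ay) (B bx by) = B (add ax bx) (add ay by)
sub (L x)     (L y)     = L (x - y)
sub (B ax ay) (B bx by) = B (sub ax bx) (sub ay by)

rot :: GN -> GN
rot (L z)   = L (negate z)
rot (B x y) = B (rot y) x

mul :: Nat -> GN -> GN
mul E     x         = x
mul (O _) (L x)     = L x
mul (I _) (L x)     = rot (L x)
mul (O p) (B ax ay) = B (mul p ax) (mul p ay)
mul (I p) (B ax ay) = rot (B (mul p ax) (mul p ay))

fft :: (Nat -> Nat) -> Tree GN -> Tree GN
fft _   (L x)   = L x
fft ang (B x y) = mix ang (fft (ang . O) x) (fft (ang . O) y)

mix :: (Nat -> Nat) -> Tree GN -> Tree GN -> Tree GN
mix ang (B ax ay) (B bx by) = B (mix (ang . O) ax bx) (mix (ang . I) ay by)
mix ang (L pt0)   (L pt1)   =
  let pt2 = mul (ang E) pt1
  in B (L (add pt0 pt2)) (L (sub pt0 pt2))

-- Hypothetical helper: embeds an integer into G(n) with zero imaginary parts.
real :: Int -> Int -> GN
real 0 x = L x
real n x = B (real (n - 1) x) (real (n - 1) 0)

-- Hypothetical helpers: list <-> tree, even indices left, odd indices right.
fromList :: [a] -> Tree a
fromList [x] = L x
fromList xs  = B (fromList (every 0 xs)) (fromList (every 1 xs))
  where every k ys = [y | (i, y) <- zip [0 :: Int ..] ys, i `mod` 2 == k]

toList :: Tree a -> [a]
toList (L x)   = [x]
toList (B l r) = weave (toList l) (toList r)
  where weave (a : as) bs = a : weave bs as
        weave []       bs = bs
```

For the polynomial 1 + 2x + 3x² + 4x³, toList (fft id (fromList (map (real 1) [1, 2, 3, 4]))) should give the values at 1, i, -1 and -i as Gaussian integers: 10, -2-2i, -2 and -2+2i, each encoded as B (L re) (L im).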

So, how does this behave on the HVM? I don't know. I've just finished implementing it on Haskell and will spend some time trying to adjust it for the HVM in the future. There are many things to tweak. For example, sub isn't necessary and can easily be removed; Nat angles can be replaced by a twiddle multiplier; mix could be simplified, perhaps. But, as is, this gets rid of most of the inefficiencies with common FFT, and, as far as I know, is the cleanest version of FFT in a pure functional sense. If you like it, feel encouraged to join the Higher Order Community to talk about it and ask any questions. Thanks for reading!

Edit: read the comments below for a cool surprise :)

@VictorTaelin

VictorTaelin commented May 4, 2023

So, the type of mix seems fishy indeed, here are some possible calls to it:

// - mix 0 0 (R 0) (GN 0) (GN 0) -> GN 2

// - mix 1 0 (R 0) (GN 2) (GN 2) -> GN 4
// - mix 0 1 (R 1) (GN 1) (GN 1) -> GN 3

// - mix 2 0 (R 0) (GN 4) (GN 4) -> GN 6
// - mix 1 1 (R 1) (GN 3) (GN 3) -> GN 5
// - mix 0 2 (R 2) (GN 2) (GN 2) -> GN 4

// - mix 3 0 (R 0) (GN 6) (GN 6) -> GN 8
// - mix 2 1 (R 1) (GN 5) (GN 5) -> GN 7
// - mix 1 2 (R 2) (GN 4) (GN 4) -> GN 6
// - mix 0 3 (R 3) (GN 3) (GN 3) -> GN 5

So I realize its type should be:

mix : ∀ n m -> R n -> G (2n+m) -> G (2n+m) -> G (2+2n+m)

With R being the type of the rotator. This makes more sense now. So we probably just need a ∀n -> R n -> R (S n) now, i.e., updating the rotator to cover another layer, which should be just f . f, but let me think.

Edit: ohh - I think the rotator should be a Church Nat

Edit: that's a pretty stupid way to call mix, there is no need to increment it, just pass the depth as a constant. the type of mix should then be:

mix : ∀ n m -> R n -> G (n + m) -> G (n + m) -> G (2 + n + m)

and we should probably just make G use polymorphic trees anyway, so the actual type is:

mix : ∀ n m -> R n -> T n (G m) -> T n (G m) -> T (1 + n) (G (1 + m))

which makes much more sense

@VictorTaelin

VictorTaelin commented May 5, 2023

The algorithm above was slightly off. First, the type of mix was incorrect, as explained above. The mix function handles two nats: one that counts down (to recurse), and one that counts up (to build a rotator with the correct depth). That's because the trees it receives actually have two halves: the upper half can be seen as the "list" part, and the lower half can be seen as the "number" part. For example, if it receives a tree with depth 6, then depths 0,1,2 are the "list", and depths 3,4,5 are the "number". Thus, it has 8 numbers of 8 components each. This is hard to explain, so here are some complete calls for visualization:

// when n=0, mix args contain 1 (GN 0). that's a (GN 0)
// - A is a (GN 0)
// then it returns 2 (GN 1). that's a (GN 2)
// mix 0 0 A B
// ret [(A+Bi) (A-Bi)]

// when n=1, mix args contain 2 (GN 1). that's a (GN 2)
// - (A+Bi) is a (GN 1), so [(A+Bi) (A-Bi)] is a (GN 2)
// then it returns 4 (GN 2). that's a (GN 4)
// mix 1 0 [(A+Bi) (A-Bi)] [(C+Di) (C-Di)]
// mix 0 1 (A+Bi) (C+Di)
// mix 0 1 (A-Bi) (C-Di)
// ret [(A+Bi+Cj+Dij) (A+Bi-Cj-Dij) (A-Bi+Dj+Cij) (A-Bi-Dj-Cij)]

// when n=2, mix args contain 4 (GN 2). that's a (GN 4)
// then it returns 8 (GN 3). that's a (GN 6)
// mix 2 0 [(A+Bi+Cj+Dij) (A+Bi-Cj-Dij) (A-Bi+Dj+Cij) (A-Bi-Dj-Cij)] [(E+Fi+Gj+Hij) (E+Fi-Gj-Hij) (E-Fi+Hj+Gij) (E-Fi-Hj-Gij)]
// mix 1 1 [(A+Bi+Cj+Dij) (A+Bi-Cj-Dij)] [(E+Fi+Gj+Hij) (E+Fi-Gj-Hij)]
// mix 0 2 (A+Bi+Cj+Dij) (E+Fi+Gj+Hij)
// mix 0 2 (A+Bi-Cj-Dij) (E+Fi-Gj-Hij)
// mix 1 1 [(A-Bi+Dj+Cij) (A-Bi-Dj-Cij)] [(E-Fi+Hj+Gij) (E-Fi-Hj-Gij)]
// mix 0 2 (A-Bi+Dj+Cij) (E-Fi+Hj+Gij)
// mix 0 2 (A-Bi-Dj-Cij) (E-Fi-Hj-Gij)

// note: here I'm not adjusting twiddle factors, so not all signs are correct - this is just for visualization
// note: I afterwards realized the second argument should be just a constant, much cleaner

Notice when we get to the base case of mix (i.e., when n=0), we have passed through the "list" part, and now we're in the number itself, and m will store the depth of the number, which can be used to call rot with the correct depth on it. Now, notice that, at this point, a single rot m will be a function that multiplies by the smallest constant. For example, if m=0, it multiplies by -1; if m=1, it multiplies by i; if m=2, it multiplies by j; if m=3, it multiplies by k, and so on. Now, what we need to do is adjust b by multiplying by the correct twiddle factor, which is e^(t/2^m*pi*i), where t is the index of the element. So, for example, if we are on the element at index 5 of the "list", and if m=3, then we must multiply b by e^(5/8*pi*i). But how do we get that twiddle factor? We just multiply by k 5 times! Or, in other words, we apply (rot 3) five times. That's because 5/8*pi is ik (recall from the circle above), and k*k*k*k*k = j*j*k = i*k. Now, how do we negate b, for the second point? Well, we just multiply b by k again, 2^m times! Notice that's half a circle rotation. And that's all. So, the fixed algorithm would be:

still slightly wrong but closer

G : ℕ → Set
G  Z    = ℤ
G (S p) = C (G p) (G p)

app : ℕ → (a → a) → (a → a)
app  Z    f x = x
app (S p) f x = app p f (f x)

rot : ∀ n → G n → G n
rot  Z     a      = - a
rot (S p) (C x y) = C (rot p y) x

mix : ∀ n m → G (2n+m) → G (2n+m) → G (2+2n+m)
mix  Z    m t  a         b        = C (C a (app t (rot m) b)) (C a (app (t+2^m) (rot m) b))
mix (S p) m t (C ax ay) (C bx by) = C (mix p (S m) (2*t) ax bx) (mix p (S m) (1+2*t) ay by)

fft : ∀ n → G n → G (2*n)
fft  Z     a        = a
fft (S p) (C ax ay) = mix p id (fft p ax) (fft p ay)

Now, it may look like this is not very efficient; but this actually should work really well because rot fuses, so, repeated application of it is very fast. So, for example, instead of a Nat index (to use on app), we could pass a Church Nat f, which is doubled as λs λz (f s (f s z)), and incremented as λs λz (s (f s z)), and then use it to efficiently generate the twiddle factors that perform the correct rotation. There are still some things to tweak and improve here. For example, fft should probably return G (G n), which is the same as G (2*n). I'm also thinking about reversed rotations. There could still be some issues as I've not tested/implemented yet, but I think this is reasonable progress. Coming next I might be able to improve it further and perhaps run tests on HVM.
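The Church Nat idea can be sketched in plain Haskell. This only illustrates the encoding and the "apply rot t times" step; it says nothing about how well it fuses on HVM. The names Church, czero, cinc, cdouble and five are hypothetical, and rot is the Haskell version from the article rather than the depth-indexed one in the pseudocode above.

```haskell
data Tree a = L a | B (Tree a) (Tree a) deriving (Show, Eq)
type GN     = Tree Int

-- Multiplies by the smallest base of G(n).
rot :: GN -> GN
rot (L z)   = L (negate z)
rot (B a b) = B (rot b) a

-- A Church numeral at a fixed type: 'n s z' applies s to z, n times.
type Church a = (a -> a) -> a -> a

czero :: Church a
czero _ z = z

-- Increment: λs λz (s (f s z))
cinc :: Church a -> Church a
cinc f s z = s (f s z)

-- Double: λs λz (f s (f s z))
cdouble :: Church a -> Church a
cdouble f s z = f s (f s z)

-- The index 5, built as ((0 + 1) * 2 * 2) + 1.
five :: Church a
five = cinc (cdouble (cdouble (cinc czero)))
```

Applying five rot to the number 1 in G(3) should yield k⁵ = ik, that is, a tree whose only non-zero leaf is the one in the ik position.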

@VictorTaelin

VictorTaelin commented May 5, 2023

Some notes on IFFT. In order to implement it, we need to reverse the direction of rotation of twiddle factors. This can be done as:

tor : ∀ n → G n → G n
tor  Z     a      = - a
tor (S p) (C x y) = C y (tor p x)

And that's it. That was easy! But there is another issue that demands explanation. In our implementation of FFT, we're doubling the depth of the tree. There is a reason I do that, but perhaps that's not a great idea, so take this as an experiment. But the reason is that, as you can tell, the GN n type takes 2^n space. So, to store a number with depth 8, we need 256 components, versus just 2 of conventional complex numbers. This is the main reason this algorithm might not perform so well, but the hope is that, on FFT, all these nodes will be heavily shared in an optimal evaluator, taking much less space. That said, there is no hope for that if we start with fully expanded nodes, to begin with. For example, if we run FFT on trees with depth 8, then we need to use GN 8. That's 256 numbers with 256 components each, i.e., 65536 ints, versus 512 floats that conventional FFT would use.
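Going back to tor for a moment: that it really undoes rot can be checked with a small Haskell sketch over the Tree-based GN from the article. This is a translation of the pseudocode above without the explicit depth argument, so treat it as an illustration rather than the intended implementation.

```haskell
data Tree a = L a | B (Tree a) (Tree a) deriving (Show, Eq)
type GN     = Tree Int

-- Multiplies by the smallest base of G(n) (counter-clockwise step).
rot :: GN -> GN
rot (L z)   = L (negate z)
rot (B x y) = B (rot y) x

-- Divides by the smallest base of G(n) (clockwise step).
tor :: GN -> GN
tor (L z)   = L (negate z)
tor (B x y) = B y (tor x)

-- An arbitrary element of G(3), used only for the checks below.
sample :: GN
sample = B (B (B (L 1) (L 2)) (B (L 3) (L 4))) (B (B (L 5) (L 6)) (B (L 7) (L 8)))
```

Both tor (rot sample) and rot (tor sample) should give back sample, and since k¹⁶ = 1, applying rot sixteen times to an element of G(3) is also the identity.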

Now, consider the case when we're performing FFT on a polynomial with non-imaginary coefficients. In that case, each of the 256 GN 8 values would be in the shape A + 0i + 0j + 0ij + 0k + 0ik + 0jk + 0ijk + 0l .... Or, in other words, almost all of the 65536 ints would be 0. That sounds like a huge waste, and room for optimization! One way to optimize this would be to have a tree node directly on GN that represents "everything below this is 0", and that might be actually a better solution, but my feeling was that, for a few reasons, a different solution would be better.

The idea I had is that FFT would be restricted so that we can only perform it on a list of non-aligned unit vectors. For example, if we want to perform FFT on 8 values, instead of inputting a list of non-imaginary scalars, like [A, B, C, D, E, F, G, H], we'd actually input a list like [A, Bi, Cj, Dij, Ek, Fik, Gjk, Hijk]. Or, from another point of view, FFT is performed on the GN itself, viewed as a list. Now, since that is always the case, we can just decrease precision, and assume that the bases are implicit! For example, 42 on the first index of the list is actually 42, but on the second index, it is 42i, on the third index, it is 42j, etc. - all that implicitly.

So, for example, when mixing a = [A] and b = [B], we implicitly know we are actually mixing the points a = [A + 0i] and b = [0 + Bi], because there is an implicit *i on b. As such, we can do that by just concatenating (C a b), there is no need to call Add. Or, for example, when mixing a = [(A + Bi), (A - Bi)] with b = [(C + Di), (C - Di)], we know we're actually mixing a = (A + Bi + 0j + 0ij), (A - Bi + 0j + 0ij) with b = (0 + 0i + Ck + Dik), (0 + 0i + Ck - Dik) - this time, there is an implicit '*k' on 'b'. As such, once again, we can add a and b by just concatenating (C a b). In a way, that means 'mix' increases the precision of each point it holds by 1 depth. So, basically, we avoid all these zeroes by expanding the precision as needed in a way that perfectly hides all the unnecessary information.

So, for IFFT, we want the opposite to happen: we want the "inverse mix" function to decrease the precision by 1 depth. This is possible due to some unexplained magic that causes components to cancel and become zero. For example, if we perform FFT on [1, i], we get [(1 + i), (1 - i)] - notice the precision of each number increased. Now, when we perform IFFT of [(1 + i), (1 - i)], we get [(1 + i) + (1 - i), (1 + i) - (1 - i)], which is just [(2 + 0i), (0 + 2i)], i.e., the double of the original list we performed FFT on. Similarly, the FFT of [1, i, j, ij] is [(1+i+j+ij),(1+i-j-ij),(1-i+j+ij),(1-i-j-ij)]. Now, suppose we performed IFFT on that. On the base case of unmix, on the first branch, we'll have unmix a=[(1+i+j+ij)] b=[(1+i-j-ij)]. This results in [((1+i+j+ij)+(1+i-j-ij)), ((1+i+j+ij)-(1+i-j-ij))], which is just [(2+2i+0j+0ij), (0+0i+2j+2ij)]. Notice how the right branch of the first point, and the left branch of the second point, magically got zeroed? We can use that property to decrease the precision by erasing the right and left branch, and returning just [(2+2i), (2j+2ij)]. Or, in other words, unmix removes one layer of each GN it receives. Now, since it adds 1 layer (to concatenate the points) and removes 1 layer (to decrease precision), the depth stays constant. It should look like:

unmix : ∀ n m -> G (n + m) -> G (n + m) -> G (n + m)
unmix  Z    m t (C ax ay) (C bx by) = C (add m ax (app t (tor m) by)) (add m ax (app (t+2^m) (tor m) by))
unmix (S p) m t (C ax ay) (C bx by) = C (mix p (S m) (2*t) ax bx) (mix p (S m) (1+2*t) ay by)

Notice how ay and bx got erased. That is great for lazy / optimal reduction, and is the main reason we should hope this algorithm can actually perform well on HVM. That is, if my feeling is correct, the orchestrated erasure of maximally shared information will cause this algorithm to perform well, despite the supposedly exponential size of GN n. The type of IFFT would be:

ifft : ∀ n → G (2*n) → G n

Notice here we actually need add though. Now, there is a last thing to be said. As you may know, IFFT actually has a 1/N factor. You can see where that is coming from by noticing that the IFFTs I performed above actually returned the double of the original list. Instead of waiting to divide at the end of the recursion, we can actually do that directly on the add function, by removing a bit from the component. Possibly, the precision of the stored Int should also be increased/decreased by mix and unmix? Or perhaps we don't even need Int and can just store a single balanced trit ({-1,0,1}), since 1. performing FFT on unit vectors will generate vectors with unit components, 2. unmix halves the precision at each call, so we never need more than 1 digit? Another possibility would be to just get rid of signs and perform on bools, although I'm not sure what that would mean - but, in a way, G(0) could be seen as Nat, G(1) as Int, G(2) as Gaussian Int, etc.?

@VictorTaelin

VictorTaelin commented May 5, 2023

Just a quick update on the actual types:

// mix  : ∀ n m -> G (n + m) -> G (n + m) -> G (2 + n + m)
// imix : ∀ n m -> G (n + m) -> G (n + m) -> G (n + m)
// fft  : ∀ n m -> G (n + m) -> G (2n + m)
// ifft : ∀ n m -> G (n + m) -> G m

// FFT of (GN 4):
// - mix 0 0 (GN 0) (GN 0) -> GN 2
// - mix 1 1 (GN 2) (GN 2) -> GN 4
// - mix 0 1 (GN 1) (GN 1) -> GN 3
// - mix 2 2 (GN 4) (GN 4) -> GN 6
// - mix 1 2 (GN 3) (GN 3) -> GN 5
// - mix 0 2 (GN 2) (GN 2) -> GN 4
// - mix 3 3 (GN 6) (GN 6) -> GN 8
// - mix 2 3 (GN 5) (GN 5) -> GN 7
// - mix 1 3 (GN 4) (GN 4) -> GN 6
// - mix 0 3 (GN 3) (GN 3) -> GN 5
// Result: (GN 8)

// IFFT of (GN 8):
// - imix 0 4 (GN 4) (GN 4) -> GN 4
// - imix 1 3 (GN 4) (GN 4) -> GN 4
// - imix 0 3 (GN 3) (GN 3) -> GN 4
// - imix 2 2 (GN 4) (GN 4) -> GN 4
// - imix 1 2 (GN 3) (GN 3) -> GN 4
// - imix 0 2 (GN 2) (GN 2) -> GN 4
// - imix 3 1 (GN 4) (GN 4) -> GN 4
// - imix 2 1 (GN 3) (GN 3) -> GN 4
// - imix 1 1 (GN 2) (GN 2) -> GN 4
// - imix 0 1 (GN 1) (GN 1) -> GN 4
// Result: (GN 4)

@VictorTaelin
Author

VictorTaelin commented May 6, 2023

FFT on HVM:

// Church
(IZ)   = λs λz z
(IS n) = λs λz (s (n s z))
(IO f) = λs λz let g = (f s); (g (g z))
(II f) = (IS (IO f))
(IN 0) = IZ
(IN s) = (IS (IN (- s 1)))

// Nat
Z     = λz λs z
(S n) = λz λs (s n)
(N 0) = Z
(N s) = (S (N (- s 1)))
(A s) = (s λb(b) λpλb(S ((A p) b)))

// Units
T = λt λo λi t
O = λt λo λi o
I = λt λo λi i

//G : ℕ → Set
//G  Z    = ℤ
//G (S p) = C (G p) (G p)
(C ax ay) = λc (c ax ay)

// Zero : ∀ n → G n
(Zero n) =
  let case_zero = λtλoλi(o)
  let case_succ = λp (let z = (Zero p); (C z z))
  (n case_zero case_succ)

// All : ∀ n → G n
(All n) =
  let case_zero = λtλoλi(i)
  let case_succ = λp (let a = (All p); (C a a))
  (n case_zero case_succ)

// Unit : ∀ n → G n
(Unit s) =
  let case_zero = λtλoλi(i)
  let case_succ = λp (C (Unit p) (Zero p))
  (s case_zero case_succ)

// Rot : ∀ n -> G n -> G n
(Rot s a) =
  let case_zero = λa λtλoλi(a i o t)
  let case_succ = λp λa λc (a λax λay (c (Rot p ay) ax))
  (s case_zero case_succ a)

// IRot : ∀ n -> G n -> G n
(IRot s a) =
  let case_zero = λa λtλoλi(a i o t)
  let case_succ = λp λa λc (a λax λay (c ay (IRot p ax)))
  (s case_zero case_succ a)

// Neg : ∀ n -> G n -> G n
(Neg s a) =
  let case_zero = λa λtλoλi(a i o t)
  let case_succ = λp λa λc (a λax λay (c (Neg p ax) (Neg p ay)))
  (s case_zero case_succ a)

// Add : ∀ n -> G n -> G n -> G n
(Add s a b) =
  let case_zero = λa λb (a λbλtλoλi(b i t o) λb(b) λbλtλoλi(b o i t) b)
  let case_succ = λp λa λb λc (a λax λay (b λbx λby
    let cx = (Add p ax bx)
    let cy = (Add p ay by)
    (c cx cy)))
  (s case_zero case_succ a b)

// mul : ∀ n -> G n -> G n -> G n
(Mul s a b) =
  let case_zero = λa λb (a λbλtλoλi(b i o t) λbλtλoλi(o) λb(b) b)
  let case_succ = λp λa λb λc (a λax λay (b λbx λby
    let cx = (Add p (Mul p ax bx) (Rot p (Mul p ay by)))
    let cy = (Add p (Mul p ay bx) (Mul p ax by))
    (c cx cy)))
  (s case_zero case_succ a b)

// mix : ∀ n m -> G (n + m) -> G (n + m) -> G (2 + n + m)
(Mix s m t a b) =
  let case_zero = λm λt λa λb λc
    let b = (t λx(Rot m x) b)
    (c (C a b) (C a (Neg m b)))
  let case_succ = λp λm λt λa λb λc (a λax λay (b λbx λby
    let p0 = (Mix p m (IO t) ax bx)
    let p1 = (Mix p m (II t) ay by)
    (c p0 p1)))
  (s case_zero case_succ m t a b)

// imix : ∀ n m -> G (n + m) -> G (n + m) -> G (n + m)
(IMix s m t a b) =
  let case_zero = λm λt λa λb λc
    let b = (t λx(IRot m x) b)
    let u = (Add m a b)
    let v = (Add m a (Neg m b))
    (u λux λuy (v λvx λvy (c ux vy)))
  let case_succ = λp λm λt λa λb λc (a λax λay (b λbx λby
    let p0 = (IMix p m (IO t) ax bx)
    let p1 = (IMix p m (II t) ay by)
    (c p0 p1)))
  (s case_zero case_succ m t a b)

// fft : ∀ n m -> G (n + m) -> G (2n + m)
(FFT s m a) =
  let case_zero = λa λm a
  let case_succ = λp λa λm (a λax λay
    let p0 = (FFT p m ax)
    let p1 = (FFT p m ay)
    (Mix p ((A m) p) IZ p0 p1))
  (s case_zero case_succ a m)

// ifft : ∀ n m -> G (n + m) -> G m
(IFFT s m a) =
  let case_zero = λa λm a
  let case_succ = λp λa λm (a λax λay
    let sm = (S m)
    let p0 = (IFFT p sm ax)
    let p1 = (IFFT p sm ay)
    (IMix p sm IZ p0 p1))
  (s case_zero case_succ a m)

// Visualization
(Show s a) =
  let case_zero = λa (a (- 0.0 1.0) 0.0 1.0)
  let case_succ = λp λa (a λaxλay[(Show p ax) (Show p ay)])
  (s case_zero case_succ a)

Main =
  let s = 5
  let x = (All (N s))
  let x = (FFT  (N s) (N 0) x)
  let x = (IFFT (N s) (N 0) x)
  (Show (N s) x)

I think there are still some issues, and it doesn't have ints yet (just a Trit type for elements). That was a lot of work. Now we can start a multi-year effort to attempt to fuse this beast (:
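For readers more comfortable with Haskell, the element-level pieces of the HVM listing can be sketched with ordinary datatypes. This is a hedged transliteration, not the HVM code itself: the listing uses Scott-encoded trits and tracks depth with a separate Nat argument, whereas this sketch uses a hypothetical leaf constructor `L` and the hypothetical names `negT`, `addT`, `mulT`, `neg`, `rot` and `irot`. The trit tables follow the base cases of `Neg`, `Add` and `Mul`, reading `T`, `O`, `I` as -1, 0, 1 as `Show` does.

```haskell
-- Balanced trit: T = -1, O = 0, I = 1.
data Trit = T | O | I deriving (Eq, Show)

negT :: Trit -> Trit
negT T = I
negT O = O
negT I = T

-- Addition wraps around (mod 3), as in the base case of Add.
addT :: Trit -> Trit -> Trit
addT T b = case b of { T -> I; O -> T; I -> O }
addT O b = b
addT I b = case b of { T -> O; O -> I; I -> T }

mulT :: Trit -> Trit -> Trit
mulT T b = negT b
mulT O _ = O
mulT I b = b

-- Depth 0 is a single trit; depth n+1 is a pair of depth-n values.
data G = L Trit | C G G deriving (Eq, Show)

neg :: G -> G
neg (L a)   = L (negT a)
neg (C x y) = C (neg x) (neg y)

-- Rot: negation at depth 0, otherwise swap the halves and rotate one.
rot :: G -> G
rot (L a)   = L (negT a)
rot (C x y) = C (rot y) x

-- IRot: the inverse of rot.
irot :: G -> G
irot (L a)   = L (negT a)
irot (C x y) = C y (irot x)
```

On a depth-2 value, applying `rot` four times gives the same result as `neg`, and applying it eight times gives back the original value, which is what one would expect from multiplication by a primitive 8th root of unity.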

@astump

astump commented May 16, 2023

Beyond amazing. So much fantastic CS here, and totally inspiring to see mathematical algorithms using only numbers represented by ADTs. Can't wait to hear the performance punchline.

@VictorTaelin
Author

VictorTaelin commented May 16, 2023

Oh, thanks for the kind words, Aaron! I had to stop after completing the working HVM version because I spent a few days immersed in this problem and had a lot of accumulated work :')

Preliminary results are that, indeed, FFT fuses and seems to be as efficient as expected, i.e., O(n * log(n)) and NOT quadratic (which would be the case in eager evaluators, considering that just constructing the tree of GNs would take O(n^2) time). In fact, it is quadratic even on GHC, because I exploit the fact that Rot fuses to perform fast multiplication by twiddle factors on HVM (on the let b = (t λx(Rot m x) b) line of the Mix function). So, HVM's optimal reduction allows this elegant algorithm to be fast, which is very cool to me. Sadly, the inverse, IFFT, seems to break fusion because of the call to Add. It is probably wrong too, as I didn't have much time to think about it. So, right now IFFT . FFT is O(n^2), but I believe it could go down to O(n * log(n)) once we put more thought into IFFT.
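The twiddle exponent `t` that Mix applies through `(t λx(Rot m x) b)` is a Church numeral built bit by bit with `IO` and `II`. The sketch below mirrors the `IZ`/`IS`/`IO`/`II` definitions from the HVM listing in Haskell. The names `toInt` and `twiddles` are hypothetical helpers, and the traversal assumes, as in Mix, that the left half receives `IO t` and the right half `II t`. It shows only which exponent each leaf ends up with and says nothing about fusion itself.

```haskell
-- Church numeral: apply a function some number of times.
type Church a = (a -> a) -> a -> a

iz :: Church a
iz _ z = z

is :: Church a -> Church a
is n s z = s (n s z)

-- Doubling: compose (f s) with itself, as in (IO f).
io :: Church a -> Church a
io f s = let g = f s in g . g

-- Doubling plus one, as in (II f) = (IS (IO f)).
ii :: Church a -> Church a
ii f = is (io f)

-- Read a numeral back as an Int by counting applications.
toInt :: Church Int -> Int
toInt n = n (+ 1) 0

-- Exponents reached at the leaves after `d` levels of recursion,
-- left half taking (io t) and right half taking (ii t).
twiddles :: Int -> Church Int -> [Int]
twiddles 0 t = [toInt t]
twiddles d t = twiddles (d - 1) (io t) ++ twiddles (d - 1) (ii t)
```

Starting from `iz`, two levels of recursion give the exponents `[0,1,2,3]` from left to right, so each leaf rotates its element by its own position.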

Some other things I want to try include: just implementing it over fixed-precision GN (i.e., without the whole increase/decrease precision scheme), doing it over GN with a "leaf" constructor (i.e., making it a normal tree rather than an indexed-depth tree, and then recursing on the structure of the GN instead of recursing on a Nat), replacing recursion by Church Nat (i.e., applying N to the fixed point of FFT/Mix, and trying to fuse the structure of the recursion itself), doing FFT over other fields (like integers mod P, perhaps p-adics), and thinking of other ways to represent complex numbers. Also, it has been pointed out to me that FFT is just a special case of the Linear Canonical Transformation, so I'd like to investigate that too. Sadly, I don't have much time for that right now, but one day.

@pedroth

pedroth commented Aug 20, 2024

Amazing post! Even though I couldn't follow everything, I find it amazing that roots of unity could be modeled using just integers.
