Theory FFT

(*  Title:      Fast Fourier Transform
    Author:     Clemens Ballarin <ballarin at in.tum.de>, started 12 April 2005
    Maintainer: Clemens Ballarin <ballarin at in.tum.de>
*)

theory FFT
imports Complex_Main
begin

text ‹We formalise a functional implementation of the FFT algorithm
  over the complex numbers, and its inverse.  Both are shown
  equivalent to the usual definitions
  of these operations through Vandermonde matrices.  They are also
  shown to be inverse to each other, more precisely, that composition
  of the inverse and the transformation yields the identity up to a
  scalar.

  The presentation closely follows Section 30.2 of Cormen \textit{et
  al.}, \emph{Introduction to Algorithms}, 2nd edition, MIT Press,
  2003.›


section ‹Preliminaries›

text ‹The following two lemmas are useful for experimenting with the
  transformations, at a vector length of four.›

(* Enumerates the index interval {0..<4} explicitly as {0, 1, 2, 3};
   Sum4 below uses it to expand four-term sums. *)
lemma Ivl4:
  "{0..<4::nat} = {0, 1, 2, 3}"
proof -
  (* Rewrite the numeral 4 as a tower of Suc so that the following simp
     step (with atLeastLessThanSuc) can unfold the interval elementwise. *)
  have "{0..<4::nat} = {0..<Suc (Suc (Suc (Suc 0)))}" by (simp add: eval_nat_numeral)
  also have "... = {0, 1, 2, 3}"
    by (simp add: atLeastLessThanSuc eval_nat_numeral insert_commute)
  finally show ?thesis .
qed

(* Expands a sum over the index set {0..<4} into an explicit four-term
   addition; the enumeration of the index set comes from Ivl4. *)
lemma Sum4:
  "(∑i=0..<4::nat. x i) = x 0 + x 1 + x 2 + x 3"
  by (simp add: Ivl4 eval_nat_numeral)


text ‹A number of specialised lemmas for the summation operator,
  where the index set is the natural numbers.›

(* Peels the first index m off a sum over {m..<n}: for m < n the sum
   equals f m plus the sum over the open interval {m<..<n}. *)
lemma sum_add_nat_ivl_singleton:
  assumes less: "m < (n::nat)"
  shows "f m + sum f {m<..<n} = sum f {m..<n}"
proof -
  (* {m} and {m<..<n} are disjoint, so the sum splits over their union. *)
  have "f m + sum f {m<..<n} = sum f ({m} ∪ {m<..<n})"
    by (simp add: sum.union_disjoint ivl_disj_int)
  (* Under m < n, that union is exactly {m..<n}. *)
  also from less have "... = sum f {m..<n}"
    by (simp only: ivl_disj_un)
  finally show ?thesis .
qed

(* Variant of sum_add_nat_ivl_singleton in which the sum over the open
   interval {m<..<n} uses a function g that agrees with f pointwise on
   that interval (assumption g). *)
lemma sum_add_split_nat_ivl_singleton:
  assumes less: "m < (n::nat)"
    and g: "!!i. [| m < i; i < n |] ==> g i = f i"
  shows "f m + sum g {m<..<n} = sum f {m..<n}"
  using less g
  (* sum.cong_simp lets the simplifier exploit the pointwise agreement
     of g and f under the interval bounds. *)
  by(simp add: sum_add_nat_ivl_singleton cong: sum.cong_simp)

(* Splits a sum over {m..<n} at an intermediate point k (m <= k <= n).
   The two pieces may use functions g and h that agree with f on
   {m..<k} and {k..<n} respectively. *)
lemma sum_add_split_nat_ivl:
  assumes le: "m <= (k::nat)" "k <= n"
    and g: "!!i. [| m <= i; i < k |] ==> g i = f i"
    and h: "!!i. [| k <= i; i < n |] ==> h i = f i"
  shows "sum g {m..<k} + sum h {k..<n} = sum f {m..<n}"
  using le g h by (simp add: sum.atLeastLessThan_concat cong: sum.cong_simp)

(* The interval {0..<2n} is the union of the even numbers 2*i and the odd
   numbers Suc (2*i), for i ranging over {0..<n}. *)
lemma ivl_splice_Un:
  "{0..<2*n::nat} = ((*) 2 ` {0..<n}) ∪ ((%i. Suc (2*i)) ` {0..<n})"
  apply (unfold image_def Bex_def)
  apply auto
  (* The remaining goal is the parity case split on an element of
     {0..<2n}, which arith closes. *)
  apply arith
  done

(* The even-index image and the odd-index image of {0..<n} are disjoint;
   together with ivl_splice_Un this lets sum_splice use
   sum.union_disjoint. *)
lemma ivl_splice_Int:
  "((*) 2 ` {0..<n}) ∩ ((%i. Suc (2*i)) ` {0..<n}) = {}"
  by auto arith

(* Doubling is injective on any set of naturals; sum_splice needs this to
   reindex the even-indexed part of a sum. *)
lemma double_inj_on:
  "inj_on (%i. 2*i::nat) A"
  by (simp add: inj_onI)

(* The map i |-> Suc (2*i) is injective on any set of naturals; sum_splice
   needs this to reindex the odd-indexed part of a sum. *)
lemma Suc_double_inj_on:
  "inj_on (%i. Suc (2*i)) A"
  by (rule inj_onI) simp

(* Splits a sum over {0..<2n} into the sum over even indices 2*i and the
   sum over odd indices 2*i+1, with i ranging over {0..<n}.  This is the
   decomposition used by DFT_lower/DFT_upper and their inverses. *)
lemma sum_splice:
  "(∑i::nat = 0..<2*n. f i) = (∑i = 0..<n. f (2*i)) + (∑i = 0..<n. f (2*i+1))"
proof -
  (* Split the index set into the disjoint even and odd images. *)
  have "(∑i::nat = 0..<2*n. f i) =
    sum f ((*) 2 ` {0..<n}) + sum f ((%i. 2*i+1) ` {0..<n})"
    by (simp add: ivl_splice_Un ivl_splice_Int sum.union_disjoint)
  (* Reindex each part back over {0..<n} using injectivity of the maps. *)
  also have "... = (∑i = 0..<n. f (2*i)) + (∑i = 0..<n. f (2*i+1))"
    by (simp add: sum.reindex [OF double_inj_on]
      sum.reindex [OF Suc_double_inj_on])
  finally show ?thesis .
qed


section ‹Complex Roots of Unity›

text ‹The function @{term cis} from the complex library returns the
  point on the unit circle corresponding to the argument angle.  It
  is the base for our definition of ‹root›.  The main property,
  De Moivre's formula, is already there in the library.›

(* root n is the point cis (2 * pi / n) on the unit circle; lemma
   root_unity below shows it is an n-th root of unity.  For n = 0 the
   angle evaluates to 0, so root 0 = 1 (lemma root0). *)
definition root :: "nat ⇒ complex" where
  "root n ≡ cis (2 * pi / real n)"

(* Shifting the argument of sin by pi negates the value; a restatement of
   the library fact sin_minus_pi. *)
lemma sin_periodic_pi_diff: "sin (x - pi) = - sin x"
  by (fact sin_minus_pi)

(* For 0 < x < 2 * pi the point (cos x, sin x) is not (1, 0): either
   sin x is nonzero or cos x differs from 1.  root_summation and
   root_summation_inv use this to show the geometric-series ratio is not 1. *)
lemma sin_cos_between_zero_two_pi:
  assumes 0: "0 < x" and pi: "x < 2 * pi"
  shows "sin x ≠ 0 ∨ cos x ≠ 1"
proof -
  (* Case 0 < x < pi: sin is strictly positive. *)
  { assume "0 < x" and "x < pi"
    then have "sin x ≠ 0" by (auto dest: sin_gt_zero) }
  moreover
  (* Case x = pi: cos pi = -1. *)
  { assume "x = pi"
    then have "cos x ≠ 1" by simp }
  moreover
  (* Case pi < x < 2 * pi: shift by pi into (0, pi) and use that
     sin (x - pi) = - sin x. *)
  { assume pi1: "pi < x" and pi2: "x < 2 * pi"
    then have "0 < x - pi" and "x - pi < pi" by arith+
    then have "sin (x - pi) ≠ 0" using sin_gt_zero by fastforce
    with pi1 pi2 have "sin x ≠ 0" by simp }
  ultimately show ?thesis using 0 pi by arith
qed


subsection ‹Basic Lemmas›

(* root n is never zero; this is needed wherever the theory divides by
   powers of root n (IDFT, the *_inv lemmas). *)
lemma root_nonzero: "root n ≠ 0"
  by (auto simp add: complex_eq_iff root_def dest: sin_zero_abs_cos_one)

(* root n is an n-th root of unity; follows from De Moivre's formula. *)
lemma root_unity: "root n ^ n = 1"
  by (simp add: complex_eq_iff root_def DeMoivre)

(* Cancellation: for d > 0, a common factor d in the order of the root
   and in the exponent cancels. *)
lemma root_cancel: "0 < d ==> root (d * n) ^ (d * k) = root n ^ k"
  apply (unfold root_def)
  apply (simp add: DeMoivre)
  done

(* Summation lemma: for 0 < k < n, the powers (root n ^ k) ^ i for
   i in {0..<n} sum to zero.  The sum is a geometric series whose ratio
   root n ^ k is not 1, and whose n-th power is 1 by root_unity. *)
lemma root_summation:
  assumes k: "0 < k" "k < n"
  shows "(∑i=0..<n. (root n ^ k) ^ i) = 0"
proof -
  (* The angle k * 2 * pi / n lies strictly between 0 and 2 * pi. *)
  from k have real0: "0 < real k * (2 * pi) / real n"
    by (simp add: zero_less_divide_iff
      mult_strict_right_mono [where a = 0, simplified])
  from k mult_strict_right_mono [where a = "real k" and
    b = "real n" and c = "2 * pi / real n", simplified]
  have realk: "real k * (2 * pi) / real n < 2 * pi"
    by (simp add: zero_less_divide_iff)
  txt ‹Main part of the proof›
  (* Closed form of the geometric sum; the side goal that the ratio is not
     1 is discharged from the angle bounds via sin_cos_between_zero_two_pi. *)
  have "(∑i=0..<n. (root n ^ k) ^ i) =
    ((root n ^ k) ^ n - 1) / (root n ^ k - 1)"
    unfolding atLeast0LessThan
    apply (rule geometric_sum)
    apply (unfold root_def)
    apply (simp add: DeMoivre)
    using real0 realk sin_cos_between_zero_two_pi
    apply (auto simp add: complex_eq_iff)
    done
  also have "... = ((root n ^ n) ^ k - 1) / (root n ^ k - 1)"
    by (simp add: power_mult [THEN sym] ac_simps)
  also have "... = 0"
    by (simp add: root_unity)
  finally show ?thesis .
qed

(* Summation lemma for the inverse root: for 0 < k < n, the powers
   ((1 / root n) ^ k) ^ i for i in {0..<n} sum to zero.  Same geometric-
   series argument as root_summation, applied to 1 / root n. *)
lemma root_summation_inv:
  assumes k: "0 < k" "k < n"
  shows "(∑i=0..<n. ((1 / root n) ^ k) ^ i) = 0"
proof -
  (* The angle k * 2 * pi / n lies strictly between 0 and 2 * pi. *)
  from k have real0: "0 < real k * (2 * pi) / real n"
    by (simp add: zero_less_divide_iff
      mult_strict_right_mono [where a = 0, simplified])
  from k mult_strict_right_mono [where a = "real k" and
    b = "real n" and c = "2 * pi / real n", simplified]
  have realk: "real k * (2 * pi) / real n < 2 * pi"
    by (simp add: zero_less_divide_iff)
  txt ‹Main part of the proof›
  (* Closed form of the geometric sum; 1 / root n is first rewritten as an
     inverse (using root_nonzero) before unfolding root_def. *)
  have "(∑i=0..<n. ((1 / root n) ^ k) ^ i) =
    (((1 / root n) ^ k) ^ n - 1) / ((1 / root n) ^ k - 1)"
    unfolding atLeast0LessThan
    apply (rule geometric_sum)
    apply (simp add: nonzero_inverse_eq_divide [THEN sym] root_nonzero)
    apply (unfold root_def)
    apply (simp add: DeMoivre)
    using real0 realk sin_cos_between_zero_two_pi
    apply (auto simp add: complex_eq_iff)
    done
  also have "... = (((1 / root n) ^ n) ^ k - 1) / ((1 / root n) ^ k - 1)"
    by (simp add: power_mult [THEN sym] ac_simps)
  also have "... = 0"
    by (simp add: power_divide root_unity)
  finally show ?thesis .
qed

(* Degenerate case: root 0 = 1 (the angle 2 * pi / 0 evaluates to 0). *)
lemma root0 [simp]:
  "root 0 = 1"
  by (simp add: complex_eq_iff root_def)

(* The first root of unity is 1. *)
lemma root1 [simp]:
  "root 1 = 1"
  by (simp add: complex_eq_iff root_def)

(* The principal square root of unity used here is -1. *)
lemma root2 [simp]:
  "root 2 = -1"
  by (simp add: complex_eq_iff root_def)

(* The fourth root used here is the imaginary unit. *)
lemma root4 [simp]:
  "root 4 = 𝗂"
  by (simp add: complex_eq_iff root_def)


subsection ‹Derived Lemmas›

(* Instance of root_cancel with d = 2, stated with the exponent shape
   i * (2 * j) that appears after splitting sums into even indices. *)
lemma root_cancel1:
  "root (2 * m) ^ (i * (2 * j)) = root m ^ (i * j)"
proof -
  (* Move the factor 2 to the front of the exponent to match root_cancel. *)
  have "root (2 * m) ^ (i * (2 * j)) = root (2 * m) ^ (2 * (i * j))"
    by (simp add: ac_simps)
  also have "... = root m ^ (i * j)"
    by (simp add: root_cancel)
  finally show ?thesis .
qed

(* For n > 0, raising the (2n)-th root to the n-th power gives -1; this
   produces the sign flip between lower and upper FFT butterflies. *)
lemma root_cancel2:
  "0 < n ==> root (2 * n) ^ n = - 1"
  txt ‹Note the space between ‹-› and ‹1›.›
  using root_cancel [where n = 2 and k = 1]
  by (simp add: complex_eq_iff ac_simps)


section ‹Discrete Fourier Transformation›

text ‹
  We define operations ‹DFT› and ‹IDFT› for the discrete
  Fourier Transform and its inverse.  Vectors are simply functions of
  type ‹nat => complex›.›

text ‹‹DFT n a› is the transform of vector ‹a›
  of length ‹n›, ‹IDFT› its inverse.›

(* DFT n a: component i of the discrete Fourier transform of the length-n
   vector a, i.e. the sum over j in {0..<n} of root n ^ (i * j) * a j. *)
definition DFT :: "nat => (nat => complex) => (nat => complex)" where
  "DFT n a == (%i. ∑j=0..<n. (root n) ^ (i * j) * (a j))"

(* IDFT n a: unnormalised inverse transform, dividing by root n ^ (i * k)
   instead of multiplying.  DFT_inverse shows IDFT n (DFT n a) i equals
   of_nat n * a i, i.e. the identity up to the scalar n. *)
definition IDFT :: "nat => (nat => complex) => (nat => complex)" where
  "IDFT n a == (%i. (∑k=0..<n. (a k) / (root n) ^ (i * k)))"

(* Experiment: evaluate DFT 4 a symbolically at indices 0..3; the
   simplified result is bound to the schematic variable ?x. *)
schematic_goal "map (DFT 4 a) [0, 1, 2, 3] = ?x"
  by(simp add: DFT_def Sum4)

text ‹Lemmas for the correctness proof.›

(* Butterfly, additive form: a DFT of length 2m at index i equals the DFT
   of the even-indexed subvector plus root (2m) ^ i times the DFT of the
   odd-indexed subvector.  Holds for every i; DFT_FFT uses it for the
   lower half of the indices. *)
lemma DFT_lower:
  "DFT (2 * m) a i =
  DFT m (%i. a (2 * i)) i +
  (root (2 * m)) ^ i * DFT m (%i. a (2 * i + 1)) i"
proof (unfold DFT_def)
  (* Split the sum into even and odd indices. *)
  have "(∑j = 0..<2 * m. root (2 * m) ^ (i * j) * a j) =
    (∑j = 0..<m. root (2 * m) ^ (i * (2 * j)) * a (2 * j)) +
    (∑j = 0..<m. root (2 * m) ^ (i * (2 * j + 1)) * a (2 * j + 1))"
    (is "?s = _")
    by (simp add: sum_splice)
  (* Cancel the factor 2 in each exponent and pull the twiddle factor
     root (2m) ^ i out of the odd sum. *)
  also have "... = (∑j = 0..<m. root m ^ (i * j) * a (2 * j)) +
    root (2 * m) ^ i *
    (∑j = 0..<m. root m ^ (i * j) * a (2 * j + 1))"
    (is "_ = ?t")
    txt ‹First pair of sums›
    apply (simp add: root_cancel1)
    txt ‹Second pair of sums›
    apply (simp add: sum_distrib_left)
    apply (simp add: power_add)
    apply (simp add: root_cancel1)
    apply (simp add: ac_simps)
    done
  finally show "?s = ?t" .
qed

(* Butterfly, subtractive form: for 0 < m and m <= i, a DFT of length 2m at
   index i equals the even-part DFT at i - m minus root (2m) ^ (i - m)
   times the odd-part DFT at i - m. *)
lemma DFT_upper:
  assumes mbound: "0 < m" and ibound: "m <= i"
  shows "DFT (2 * m) a i =
    DFT m (%i. a (2 * i)) (i - m) -
    root (2 * m) ^ (i - m) * DFT m (%i. a (2 * i + 1)) (i - m)"
proof (unfold DFT_def)
  (* Split the sum into even and odd indices. *)
  have "(∑j = 0..<2 * m. root (2 * m) ^ (i * j) * a j) =
    (∑j = 0..<m. root (2 * m) ^ (i * (2 * j)) * a (2 * j)) +
    (∑j = 0..<m. root (2 * m) ^ (i * (2 * j + 1)) * a (2 * j + 1))"
    (is "?s = _")
    by (simp add: sum_splice)
  (* Shift the index by m using root_unity, and turn root (2m) ^ m into -1
     (root_cancel2) to obtain the minus sign. *)
  also have "... =
    (∑j = 0..<m. root m ^ ((i - m) * j) * a (2 * j)) -
    root (2 * m) ^ (i - m) *
    (∑j = 0..<m. root m ^ ((i - m) * j) * a (2 * j + 1))"
    (is "_ = ?t")
    txt ‹First pair of sums›
    apply (simp add: root_cancel1)
    apply (simp add: root_unity ibound root_nonzero power_diff power_mult)
    txt ‹Second pair of sums›
    apply (simp add: mbound root_cancel2)
    apply (simp add: sum_distrib_left)
    apply (simp add: power_add)
    apply (simp add: root_cancel1)
    apply (simp add: power_mult)
    apply (simp add: ac_simps)
    done
  finally show "?s = ?t" .
qed

(* Inverse butterfly, additive form: IDFT of length 2m at index i equals
   the even-part IDFT plus (1 / root (2m)) ^ i times the odd-part IDFT. *)
lemma IDFT_lower:
  "IDFT (2 * m) a i =
  IDFT m (%i. a (2 * i)) i +
  (1 / root (2 * m)) ^ i * IDFT m (%i. a (2 * i + 1)) i"
proof (unfold IDFT_def)
  (* Split the sum into even and odd indices. *)
  have "(∑j = 0..<2 * m. a j / root (2 * m) ^ (i * j)) =
    (∑j = 0..<m. a (2 * j) / root (2 * m) ^ (i * (2 * j))) +
    (∑j = 0..<m. a (2 * j + 1) / root (2 * m) ^ (i * (2 * j + 1)))"
    (is "?s = _")
    by (simp add: sum_splice)
  (* Cancel the factor 2 in the exponents and factor out the inverse
     twiddle (1 / root (2m)) ^ i from the odd sum. *)
  also have "... = (∑j = 0..<m. a (2 * j) / root m ^ (i * j)) +
    (1 / root (2 * m)) ^ i *
    (∑j = 0..<m. a (2 * j + 1) / root m ^ (i * j))"
    (is "_ = ?t")
    txt ‹First pair of sums›
    apply (simp add: root_cancel1)
    txt ‹Second pair of sums›
    apply (simp add: sum_distrib_left)
    apply (simp add: power_add)
    apply (simp add: power_divide root_nonzero)
    apply (simp add: root_cancel1)
    done
  finally show "?s = ?t" .
qed

(* Inverse butterfly, subtractive form: for 0 < m and m <= i, IDFT of length
   2m at index i equals the even-part IDFT at i - m minus
   (1 / root (2m)) ^ (i - m) times the odd-part IDFT at i - m. *)
lemma IDFT_upper:
  assumes mbound: "0 < m" and ibound: "m <= i"
  shows "IDFT (2 * m) a i =
    IDFT m (%i. a (2 * i)) (i - m) -
    (1 / root (2 * m)) ^ (i - m) *
    IDFT m (%i. a (2 * i + 1)) (i - m)"
proof (unfold IDFT_def)
  (* Split the sum into even and odd indices. *)
  have "(∑j = 0..<2 * m. a j / root (2 * m) ^ (i * j)) =
    (∑j = 0..<m. a (2 * j) / root (2 * m) ^ (i * (2 * j))) +
    (∑j = 0..<m. a (2 * j + 1) / root (2 * m) ^ (i * (2 * j + 1)))"
    (is "?s = _")
    by (simp add: sum_splice)
  (* Shift the index by m via root_unity; root_cancel2 supplies the sign. *)
  also have "... =
    (∑j = 0..<m. a (2 * j) / root m ^ ((i - m) * j)) -
    (1 / root (2 * m)) ^ (i - m) *
    (∑j = 0..<m. a (2 * j + 1) / root m ^ ((i - m) * j))"
    (is "_ = ?t")
    txt ‹First pair of sums›
    apply (simp add: root_cancel1)
    apply (simp add: root_unity ibound root_nonzero power_diff power_mult)
    txt ‹Second pair of sums›
    apply (simp add: power_divide root_nonzero)
    apply (simp add: mbound root_cancel2)
    apply (simp add: sum_divide_distrib)
    apply (simp add: power_add)
    apply (simp add: root_cancel1)
    apply (simp add: power_mult)
    apply (simp add: ac_simps)
    done
  finally show "?s = ?t" .
qed

text ‹‹DFT› and ‹IDFT› are inverses.›

(* Remove these two division rewrites from the default simp set for the
   rest of the theory.  NOTE(review): presumably they would rewrite away
   the quotient forms that power_diff_rev_if and DFT_inverse rely on. *)
declare divide_divide_eq_right [simp del]
  divide_divide_eq_left [simp del]

(* For nonzero a and m <= n, the inverse of a raised to n - m equals
   a ^ m / a ^ n.  Proved by induction along diff_induct on n and m. *)
lemma power_diff_inverse:
  assumes nz: "(a::'a::field) ~= 0"
  shows "m <= n ==> (inverse a) ^ (n-m) = (a^m) / (a^n)"
  apply (induct n m rule: diff_induct)
    apply (simp add: power_inverse
      nonzero_inverse_eq_divide [THEN sym] nz)
   apply simp
  apply (simp add: nz)
  done

(* Rewrites a quotient of powers of a nonzero a into a single power:
   a ^ (m - n) when n <= m, otherwise (1 / a) ^ (n - m). *)
lemma power_diff_rev_if:
  assumes nz: "(a::'a::field) ~= 0"
  shows "(a^m) / (a^n) = (if n <= m then a ^ (m-n) else (1/a) ^ (n-m))"
proof (cases "n <= m")
  case True with nz show ?thesis
    by (simp add: power_diff)
next
  case False with nz show ?thesis
    by (simp add: power_diff_inverse nonzero_inverse_eq_divide [THEN sym])
qed

(* For nonzero a, the quotient a ^ (i * j) / a ^ (k * i) is the i-th power
   of a ^ j / a ^ k; DFT_inverse uses it to regroup terms of the double sum. *)
lemma power_divides_special:
  "(a::'a::field) ~= 0 ==>
  a ^ (i * j) / a ^ (k * i) = (a ^ j / a ^ k) ^ i"
  by (simp add: power_divide power_mult [THEN sym] ac_simps)

(* Inversion theorem: for i < n, applying IDFT after DFT yields the
   original component a i scaled by the vector length n. *)
theorem DFT_inverse:
  assumes i_less: "i < n"
  shows  "IDFT n (DFT n a) i = of_nat n * a i"
  (* Disables linarith's case splitting in this proof; presumably keeps the
     simplifier calls below from splitting on natural subtraction. *)
  using [[linarith_split_limit = 0]]
  (* Unfold both transforms, swap the order of the double sum, and write
     each summand as a power of a single ratio of roots. *)
  apply (unfold DFT_def IDFT_def)
  apply (simp add: sum_divide_distrib)
  apply (subst sum.swap)
  apply (simp only: times_divide_eq_left [THEN sym])
  apply (simp only: power_divides_special [OF root_nonzero])
  apply (simp add: power_diff_rev_if root_nonzero)
  apply (simp add: sum_divide_distrib [THEN sym]
    sum_distrib_right [THEN sym])
  proof -
    from i_less have i_diff: "!!k. i - k < n" by arith
    have diff_i: "!!k. k < n ==> k - i < n" by arith

    (* ?sum is the inner geometric sum for outer index j; ?sum1 and ?sum2
       are its two branches (j >= i and j < i respectively). *)
    let ?sum = "%i j n. sum ((^) (if i <= j then root n ^ (j - i)
                  else (1 / root n) ^ (i - j))) {0..<n} * a j"
    let ?sum1 = "%i j n. sum ((^) (root n ^ (j - i))) {0..<n} * a j"
    let ?sum2 = "%i j n. sum ((^) ((1 / root n) ^ (i - j))) {0..<n} * a j"

    (* Split the outer sum at i so each part has a fixed branch. *)
    from i_less have "(∑j = 0..<n. ?sum i j n) =
      (∑j = 0..<i. ?sum2 i j n) + (∑j = i..<n. ?sum1 i j n)"
      (is "?s = _")
      by (simp add: root_summation_inv nat_dvd_not_less
        sum_add_split_nat_ivl [where f = "%j. ?sum i j n"])
    (* Terms with j < i vanish by root_summation_inv. *)
    also from i_less i_diff
    have "... = (∑j = i..<n. ?sum1 i j n)"
      by (simp add: root_summation_inv nat_dvd_not_less)
    (* Separate the diagonal term j = i from the rest. *)
    also from i_less have "... =
      (∑j∈{i} ∪ {i<..<n}. ?sum1 i j n)"
      by (simp only: ivl_disj_un)
    also have "... =
      (?sum1 i i n + (∑j∈{i<..<n}. ?sum1 i j n))"
      by (simp add: sum.union_disjoint ivl_disj_int)
    (* Terms with j > i vanish by root_summation. *)
    also from i_less diff_i have "... = ?sum1 i i n"
      by (simp add: root_summation nat_dvd_not_less)
    (* The diagonal term is n copies of a i. *)
    also from i_less have "... = of_nat n * a i" (is "_ = ?t")
      by simp
    finally show "?s = ?t" .
  qed


section ‹Discrete, Fast Fourier Transformation›

text ‹‹FFT k a› is the transform of vector ‹a›
  of length ‹2 ^ k›, ‹IFFT› its inverse.›

(* FFT k a: recursive transform of a vector of length 2 ^ k.  x and y are
   the transforms of the even- and odd-indexed subvectors; indices below
   2 ^ k combine them with a plus and the twiddle factor
   root (2 ^ Suc k) ^ i, indices from 2 ^ k upward with a minus at the
   shifted index i - 2 ^ k. *)
primrec FFT :: "nat => (nat => complex) => (nat => complex)" where
  "FFT 0 a = a"
| "FFT (Suc k) a =
     (let (x, y) = (FFT k (%i. a (2*i)), FFT k (%i. a (2*i+1)))
      in (%i. if i < 2^k
            then x i + (root (2 ^ (Suc k))) ^ i * y i
            else x (i- 2^k) - (root (2 ^ (Suc k))) ^ (i- 2^k) * y (i- 2^k)))"

(* IFFT k a: recursive inverse transform, structured like FFT but with
   twiddle factors (1 / root (2 ^ Suc k)) ^ i. *)
primrec IFFT :: "nat => (nat => complex) => (nat => complex)" where
  "IFFT 0 a = a"
| "IFFT (Suc k) a =
     (let (x, y) = (IFFT k (%i. a (2*i)), IFFT k (%i. a (2*i+1)))
      in (%i. if i < 2^k
            then x i + (1 / root (2 ^ (Suc k))) ^ i * y i
            else x (i - 2^k) -
              (1 / root (2 ^ (Suc k))) ^ (i - 2^k) * y (i - 2^k)))"

text ‹Finally, for vectors of length ‹2 ^ k›,
  ‹DFT› and ‹FFT›, and ‹IDFT› and
  ‹IFFT› are equivalent.›

(* For vectors of length 2 ^ k, DFT agrees with FFT at every index below
   2 ^ k.  Induction on k: indices in the lower half follow from DFT_lower,
   indices in the upper half from DFT_upper, each with the induction
   hypothesis. *)
theorem DFT_FFT:
  "!!a i. i < 2 ^ k ==> DFT (2 ^ k) a i = FFT k a i"
proof (induct k)
  case 0
  then show ?case by (simp add: DFT_def)
next
  case (Suc k)
  assume i: "i < 2 ^ Suc k"
  show ?case proof (cases "i < 2 ^ k")
    case True
    then show ?thesis apply simp apply (simp add: DFT_lower)
      apply (simp add: Suc) done
  next
    case False
    (* Bound on the shifted index, so the induction hypothesis applies
       at i - 2 ^ k. *)
    from i have "i - 2 ^ k < 2 ^ k" by simp
    with False i show ?thesis apply simp apply (simp add: DFT_upper)
      apply (simp add: Suc) done
  qed
qed

(* For vectors of length 2 ^ k, IDFT agrees with IFFT at every index below
   2 ^ k.  Same induction as DFT_FFT, using IDFT_lower and IDFT_upper. *)
theorem IDFT_IFFT:
  "!!a i. i < 2 ^ k ==> IDFT (2 ^ k) a i = IFFT k a i"
proof (induct k)
  case 0
  then show ?case by (simp add: IDFT_def)
next
  case (Suc k)
  assume i: "i < 2 ^ Suc k"
  show ?case proof (cases "i < 2 ^ k")
    case True
    then show ?thesis apply simp apply (simp add: IDFT_lower)
      apply (simp add: Suc) done
  next
    case False
    (* Bound on the shifted index, so the induction hypothesis applies
       at i - 2 ^ k. *)
    from i have "i - 2 ^ k < 2 ^ k" by simp
    with False i show ?thesis apply simp apply (simp add: IDFT_upper)
      apply (simp add: Suc) done
  qed
qed

(* Experiment: evaluate the recursive FFT for k = 2 (length 4) at indices
   0..3 by simplification; the result is bound to ?x. *)
schematic_goal "map (FFT (Suc (Suc 0)) a) [0, 1, 2, 3] = ?x"
  by simp

end