(* Theory Strassen_Algorithm *)
(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Strassen's algorithm for matrix multiplication.›

text ‹We define the algorithm for arbitrary matrices over rings.
  Odd dimensions are handled by aligning them to even numbers
  throughout the recursion.›

theory Strassen_Algorithm
imports 
  Matrix
begin

text ‹With @{const four_block_mat} and @{const split_block} we can define Strassen's 
  multiplication algorithm.›

text ‹We start with a simple heuristic on when to switch to the basic algorithm.›

(* Threshold for the recursion: as soon as one of the relevant dimensions
   is below this value (see strassen_too_small), the ordinary matrix
   multiplication is used instead of another Strassen step. *)
definition strassen_constant :: nat where
  [code_unfold]: "strassen_constant ≡ 20"

(* Holds iff one of the three dimensions of the product A * B (rows of A,
   columns of A, columns of B) is below strassen_constant; in that case
   the basic multiplication is used. *)
definition strassen_too_small :: "'a mat ⇒ 'b mat ⇒ bool" where
  "strassen_too_small A B ≡
     dim_row A < strassen_constant
   ∨ dim_col A < strassen_constant
   ∨ dim_col B < strassen_constant"

text ‹We have to make a case analysis on whether all dimensions are even.›
(* Holds iff rows of A, columns of A and columns of B are all even, so
   that both matrices can be split exactly in half. *)
definition strassen_even :: "'a mat ⇒ 'b mat ⇒ bool" where
  "strassen_even A B ≡
     even (dim_row A) ∧ even (dim_col A) ∧ even (dim_col B)"

text ‹And then we can define the algorithm.›

(* Strassen's recursive multiplication of A (nr × n) with B (n × nc).
   The result is only meaningful if dim_col A = dim_row B; for that case
   the lemma strassen_mat_mult below shows that it equals A * B.
   - Small inputs (strassen_too_small A B): ordinary multiplication A * B.
   - All dimensions even: A and B are split into four equally sized blocks,
     the seven products M1 ... M7 of block sums and differences are computed
     recursively, and the result blocks C1 ... C4 are combined from them
     (7 block multiplications instead of 8).
   - Otherwise: every dimension is rounded down to an even number. Only the
     upper-left product A1 * B1, whose dimensions are all even, is computed
     recursively; each of the other seven block products involves a block
     with at most one row or column and uses the ordinary multiplication. *)
function strassen_mat_mult :: "'a :: ring mat ⇒ 'a mat ⇒ 'a mat" where
  "strassen_mat_mult A B = (let nr = dim_row A; n = dim_col A; nc = dim_col B in
    if strassen_too_small A B then A * B else 
    if strassen_even A B then let
      nr2 = nr div 2;
      n2 = n div 2;
      nc2 = nc div 2;
      (A1,A2,A3,A4) = split_block A nr2 n2;
      (B1,B2,B3,B4) = split_block B n2 nc2;
      M1 = strassen_mat_mult (A1 + A4) (B1 + B4);
      M2 = strassen_mat_mult (A3 + A4) B1;
      M3 = strassen_mat_mult A1 (B2 - B4);
      M4 = strassen_mat_mult A4 (B3 - B1);
      M5 = strassen_mat_mult (A1 + A2) B4;
      M6 = strassen_mat_mult (A3 - A1) (B1 + B2);
      M7 = strassen_mat_mult (A2 - A4) (B3 + B4);
      C1 = M1 + M4 - M5 + M7;
      C2 = M3 + M5;
      C3 = M2 + M4;
      C4 = M1 - M2 + M3 + M6
    in four_block_mat C1 C2 C3 C4 else 
    let 
     nr' = (nr div 2) * 2;
     n' = (n div 2) * 2;
     nc' = (nc div 2) * 2;
     (A1,A2,A3,A4) = split_block A nr' n';
     (B1,B2,B3,B4) = split_block B n' nc';
     C1 = strassen_mat_mult A1 B1 + A2 * B3;
     C2 = A1 * B2 + A2 * B4;
     C3 = A3 * B1 + A4 * B3;
     C4 = A3 * B2 + A4 * B4
     in four_block_mat C1 C2 C3 C4)"
  by pat_completeness auto

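text ‹For termination, we use the following measure: twice the sum of the
  relevant dimensions, plus 1 if not all of them are even. The recursive call
  in the odd case does not increase any dimension but makes all of them even,
  so the parity summand drops from 1 to 0. The recursive calls in the even case
  strictly decrease the sum of the dimensions. Because the sum is doubled, this
  decrease outweighs a possible increase of the parity summand.›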

(* Termination measure on argument pairs: twice the sum of rows of A,
   columns of A and columns of B, plus a parity summand that is 1 exactly
   when not all three dimensions are even. *)
definition strassen_measure :: "'a mat × 'b mat ⇒ nat" where
  "strassen_measure ≡ λ (A,B). 2 * (dim_row A + dim_col A + dim_col B)
    + (if strassen_even A B then 0 else 1)"

(* The measure only looks at dimensions. Sums and differences have the
   measure of their right operand, and unary minus does not change the
   measure. Hence the block sums and differences passed to the recursive
   calls in the even case of strassen_mat_mult can be reduced to plain
   blocks when proving termination (these are simp rules). *)
lemma strassen_measure_add[simp]: 
  "strassen_measure (A + B, C) = strassen_measure (B,C)" 
  "strassen_measure (A, B + C) = strassen_measure (A,C)" 
  "strassen_measure (A - B, C) = strassen_measure (B,C)" 
  "strassen_measure (A, B - C) = strassen_measure (A,C)" 
  "strassen_measure (- A, B) = strassen_measure (A,B)"
  "strassen_measure (A, - B) = strassen_measure (A,B)"
  unfolding strassen_measure_def strassen_even_def by auto

(* Every block pair used by a recursive call of the even case of
   strassen_mat_mult has a smaller measure than (A, B). Since A is not too
   small, it has at least 2 rows, so every block has strictly fewer rows
   than A. No block has more columns than the matrix it comes from. *)
lemma strassen_measure_div_2: assumes "(A1, A2, A3, A4) = split_block A (dim_row A div 2) (dim_col A div 2)"
  "(B1, B2, B3, B4) = split_block B (dim_col A div 2) (dim_col B div 2)"  
  and large: "¬ strassen_too_small A B"
  shows 
  "strassen_measure (A1,B4) < strassen_measure (A,B)"
  "strassen_measure (A1,B2) < strassen_measure (A,B)"
  "strassen_measure (A2,B4) < strassen_measure (A,B)"
  "strassen_measure (A3,B2) < strassen_measure (A,B)"
  "strassen_measure (A4,B1) < strassen_measure (A,B)"
  "strassen_measure (A4,B3) < strassen_measure (A,B)"
  "strassen_measure (A4,B4) < strassen_measure (A,B)"
proof -
  {
    (* generic statement for an arbitrary block of A and of B *)
    fix Ai Bi
    assume Ai: "Ai ∈ {A1,A2,A3,A4}" and Bi: "Bi ∈ {B1,B2,B3,B4}"
    from large[unfolded strassen_too_small_def strassen_constant_def]
    have "¬ dim_row A < 2" by auto 
    with assms Ai Bi have Ar:
      "dim_row Ai < dim_row A"
      "dim_col Ai ≤ dim_col A"
      "dim_col Bi ≤ dim_col B" 
      unfolding split_block_def Let_def by auto    
    hence "strassen_measure (Ai,Bi) < strassen_measure (A,B)"
      unfolding strassen_measure_def split by auto
  }
  (* instantiate it for each goal, listed in the same order as in "shows" *)
  thus
    "strassen_measure (A1,B4) < strassen_measure (A,B)"
    "strassen_measure (A1,B2) < strassen_measure (A,B)"
    "strassen_measure (A2,B4) < strassen_measure (A,B)"
    "strassen_measure (A3,B2) < strassen_measure (A,B)"
    "strassen_measure (A4,B1) < strassen_measure (A,B)"
    "strassen_measure (A4,B3) < strassen_measure (A,B)"
    "strassen_measure (A4,B4) < strassen_measure (A,B)"
    by auto
qed

(* The single recursive call of the odd case of strassen_mat_mult decreases
   the measure. Each dimension of A1 and B1 either strictly decreases, or
   stays the same when it was already even. Since not all dimensions of A
   and B are even, the parity summand of (A, B) is 1. *)
lemma strassen_measure_odd: assumes "(A1, A2, A3, A4) = split_block A ((dim_row A div 2) * 2) ((dim_col A div 2) * 2)"  
  and "(B1, B2, B3, B4) = split_block B ((dim_col A div 2) * 2) ((dim_col B div 2) * 2)"
  and odd: "¬ strassen_even A B"
  shows "strassen_measure (A1,B1) < strassen_measure (A,B)"
proof -
  (* rows of A1: strictly fewer, or unchanged because already even *)
  from assms have Ar:
    "dim_row A1 < dim_row A ∨ dim_row A1 = dim_row A ∧ even (dim_row A)" 
    unfolding split_block_def Let_def by auto presburger
  (* columns of A1: same dichotomy *)
  from assms have Ac:
    "dim_col A1 < dim_col A ∨ dim_col A1 = dim_col A ∧ even (dim_col A)" 
    unfolding split_block_def Let_def by auto presburger
  (* columns of B1: same dichotomy *)
  from assms have Bc:
    "dim_col B1 < dim_col B ∨ dim_col B1 = dim_col B ∧ even (dim_col B)" 
    unfolding split_block_def Let_def by auto presburger
  from Ar Ac Bc odd show ?thesis unfolding strassen_measure_def strassen_even_def split
    by (auto split: if_splits)
qed

(* Termination of strassen_mat_mult w.r.t. strassen_measure. The even-case
   calls are covered by strassen_measure_div_2; the block sums and
   differences are removed by the simp rules strassen_measure_add. The
   odd-case call is covered by strassen_measure_odd. *)
termination by (relation "measure strassen_measure", 
   auto elim: strassen_measure_div_2 strassen_measure_odd)


(* Correctness: for compatible dimensions, Strassen's algorithm computes the
   ordinary matrix product. The proof is by induction along the recursion of
   strassen_mat_mult, with one case per branch of its definition. *)
lemma strassen_mat_mult: 
  "dim_col A = dim_row B ⟹ strassen_mat_mult A B = A * B"
proof (induct A B rule: strassen_mat_mult.induct)
  case (1 A B)
  let ?nr = "dim_row A"
  let ?nc = "dim_col B"
  let ?n = "dim_col A"
  show ?case
  proof (cases "strassen_too_small A B")
    case False note large = this
    let ?smm = strassen_mat_mult
    (* 1(1-8) are the induction hypotheses for the eight recursive calls
       (seven in the even branch, one in the odd branch). Their premises
       about the let-bound dimensions and the size test are discharged here.
       1(9) is the assumption dim_col A = dim_row B. *)
    note IH = 1(1-8)[OF refl refl refl False _ refl refl refl _ refl refl refl _ refl refl refl]
    show ?thesis
    proof (cases "strassen_even A B")
      case True
      (* Even case: all dimensions split exactly in half. *)
      note even = True[unfolded strassen_even_def]
      let ?nr2 = "?nr div 2"
      let ?n2 = "?n div 2"
      let ?nc2 = "?nc div 2"
      from even have nr: "?nr = ?nr2 + ?nr2" by presburger 
      from even have n: "?n = ?n2 + ?n2" by presburger 
      from even have nc: "?nc = ?nc2 + ?nc2" by presburger 
      from 1(9) even have n': "dim_row B = ?n2 + ?n2"
        by auto
      (* name the four blocks of A and of B *)
      obtain A1 A2 A3 A4 where splitA: 
        "split_block A ?nr2 ?n2 = (A1,A2,A3,A4)" by (rule prod_cases4)
      obtain B1 B2 B3 B4 where splitB: 
        "split_block B ?n2 ?nc2 = (B1,B2,B3,B4)" by (rule prod_cases4)
      (* IH(1-7): hypotheses for the seven Strassen products *)
      note IH = IH(1-7)[OF True splitA[symmetric] splitB[symmetric] ]
      (* A and B are the four-block compositions of their half-size blocks *)
      from split_block[OF splitA nr n]
      have blockA: "A = four_block_mat A1 A2 A3 A4"
        and A1: "A1 ∈ carrier_mat ?nr2 ?n2" 
        and A2: "A2 ∈ carrier_mat ?nr2 ?n2" 
        and A3: "A3 ∈ carrier_mat ?nr2 ?n2" 
        and A4: "A4 ∈ carrier_mat ?nr2 ?n2" 
        by blast+
      from split_block[OF splitB n' nc]
      have blockB: "B = four_block_mat B1 B2 B3 B4"
        and B1: "B1 ∈ carrier_mat ?n2 ?nc2" 
        and B2: "B2 ∈ carrier_mat ?n2 ?nc2" 
        and B3: "B3 ∈ carrier_mat ?n2 ?nc2" 
        and B4: "B4 ∈ carrier_mat ?n2 ?nc2" 
        by blast+
      note carr = A1 A2 A3 A4 B1 B2 B3 B4
      (* left and right factors of the seven Strassen products *)
      let ?M11 = "A1 + A4" let ?M12 = "B1 + B4"
      let ?M21 = "A3 + A4" let ?M22 = "B1"
      let ?M31 = "A1" let ?M32 = "B2 - B4"
      let ?M41 = "A4" let ?M42 = "B3 - B1"
      let ?M51 = "A1 + A2" let ?M52 = "B4"
      let ?M61 = "A3 - A1" let ?M62 = "B1 + B2"
      let ?M71 = "A2 - A4" let ?M72 = "B3 + B4"
      let ?M1 = "?smm ?M11 ?M12"
      let ?M2 = "?smm ?M21 ?M22"
      let ?M3 = "?smm ?M31 ?M32"
      let ?M4 = "?smm ?M41 ?M42"
      let ?M5 = "?smm ?M51 ?M52"
      let ?M6 = "?smm ?M61 ?M62"
      let ?M7 = "?smm ?M71 ?M72"
      let ?C1 = "?M1 + ?M4 - ?M5 + ?M7"
      let ?C2 = "?M3 + ?M5"
      let ?C3 = "?M2 + ?M4"
      let ?C4 = "?M1 - ?M2 + ?M3 + ?M6"
      (* one unfolding step of the algorithm in the even branch *)
      have res: "?smm A B = four_block_mat ?C1 ?C2 ?C3 ?C4"
        using large True
        unfolding strassen_mat_mult.simps[of A B] Let_def splitA splitB split
        by auto
      (* Each recursive product equals the ordinary product. After each step
         the remaining hypotheses are specialised with refl, which
         discharges their next premise by reflexivity. *)
      have M1: "?M1 = ?M11 * ?M12"
        by (rule IH(1), insert carr, auto)
      note IH = IH(2-)[OF refl]
      have M2: "?M2 = ?M21 * ?M22"
        by (rule IH(1), insert carr, auto)
      note IH = IH(2-)[OF refl]
      have M3: "?M3 = ?M31 * ?M32"
        by (rule IH(1), insert carr, auto)
      note IH = IH(2-)[OF refl]
      have M4: "?M4 = ?M41 * ?M42"
        by (rule IH(1), insert carr, auto)
      note IH = IH(2-)[OF refl]
      have M5: "?M5 = ?M51 * ?M52"
        by (rule IH(1), insert carr, auto)
      note IH = IH(2-)[OF refl]
      have M6: "?M6 = ?M61 * ?M62"
        by (rule IH(1), insert carr, auto)
      note IH = IH(2-)[OF refl]
      have M7: "?M7 = ?M71 * ?M72"
        by (rule IH(1), insert carr, auto)
      (* distributivity, closure and AC rules at the half dimensions, used
         to verify Strassen's identities for each result block *)
      note distr = 
        add_mult_distrib_mat[of _ ?nr2 ?n2 _ _ ?nc2]
        minus_mult_distrib_mat[of _ ?nr2 ?n2 _ _ ?nc2]
        mult_add_distrib_mat[of _ ?nr2 ?n2 _ ?nc2]
        mult_minus_distrib_mat[of _ ?nr2 ?n2 _ ?nc2]
      note closed = add_carrier_mat[of _ ?nr2 ?nc2]
         uminus_carrier_iff_mat[of _ ?nr2 ?nc2]
      note ac = assoc_add_mat[of _ ?nr2 ?nc2] comm_add_mat[of _ ?nr2 ?nc2]
      (* compare the Strassen blocks with the blocks of the ordinary
         block product, one block at a time *)
      show ?thesis unfolding res M1 M2 M3 M4 M5 M6 M7
        unfolding blockA blockB
          mult_four_block_mat[OF carr]
        by (rule cong_four_block_mat)
           (insert carr, auto simp: distr ac closed)
    next
      case False
      (* Odd case: round every dimension down to an even number; the
         remainders ?nr2', ?n2', ?nc2' are at most 1. *)
      let ?nr2 = "?nr div 2 * 2" let ?nr2' = "?nr - ?nr2"
      let ?n2 = "?n div 2 * 2"   let ?n2' = "?n - ?n2"
      let ?nc2 = "?nc div 2 * 2" let ?nc2' = "?nc - ?nc2"
      have nr: "?nr = ?nr2 + ?nr2'" by presburger 
      have n: "?n = ?n2 + ?n2'" by presburger 
      have nc: "?nc = ?nc2 + ?nc2'" by presburger 
      from 1(9) have n': "dim_row B = ?n2 + ?n2'" by auto   
      obtain A1 A2 A3 A4 where splitA: 
        "split_block A ?nr2 ?n2 = (A1,A2,A3,A4)" by (rule prod_cases4)
      obtain B1 B2 B3 B4 where splitB: 
        "split_block B ?n2 ?nc2 = (B1,B2,B3,B4)" by (rule prod_cases4)
      (* IH(8): hypothesis for the only recursive call, A1 times B1 *)
      note IH = IH(8)[OF False splitA[symmetric] splitB[symmetric]]
      from split_block[OF splitA nr n]
      have blockA: "A = four_block_mat A1 A2 A3 A4"
        and A1: "A1 ∈ carrier_mat ?nr2 ?n2" 
        and A2: "A2 ∈ carrier_mat ?nr2 ?n2'" 
        and A3: "A3 ∈ carrier_mat ?nr2' ?n2" 
        and A4: "A4 ∈ carrier_mat ?nr2' ?n2'" 
        by blast+
      from split_block[OF splitB n' nc]
      have blockB: "B = four_block_mat B1 B2 B3 B4"
        and B1: "B1 ∈ carrier_mat ?n2 ?nc2" 
        and B2: "B2 ∈ carrier_mat ?n2 ?nc2'" 
        and B3: "B3 ∈ carrier_mat ?n2' ?nc2" 
        and B4: "B4 ∈ carrier_mat ?n2' ?nc2'" 
        by blast+      
      note carr = A1 A2 A3 A4 B1 B2 B3 B4
      from carr have "dim_col A1 = dim_row B1" by simp
      note IH = IH[OF this]
      (* one unfolding step, with the recursive product replaced by IH *)
      have "?smm A B = four_block_mat 
        (A1 * B1 + A2 * B3)
        (A1 * B2 + A2 * B4)
        (A3 * B1 + A4 * B3)
        (A3 * B2 + A4 * B4)"
        unfolding strassen_mat_mult.simps[of A B] Let_def 
          splitA splitB split IH using large False by auto
      (* this is exactly the block-wise ordinary product *)
      also have "… = A * B"
        unfolding blockA blockB
         mult_four_block_mat[OF carr] by simp
      finally show ?thesis by simp
    qed
  (* too-small case: the algorithm directly returns A * B *)
  qed simp
qed

end