(* Theory Deutsch_Jozsa *)

(*
Authors: 
  Hanna Lachnitt, TU Wien, lachnitt@student.tuwien.ac.at
  Anthony Bordg, University of Cambridge, apdb3@cam.ac.uk
*) 

section ‹The Deutsch-Jozsa Algorithm›

theory Deutsch_Jozsa
imports
  Deutsch
  More_Tensor
  Binary_Nat
begin


text ‹
Given a function $f:\{0,1\}^n \to \{0,1\}$, the Deutsch-Jozsa algorithm decides whether this function
is constant or balanced using a single evaluation of a quantum circuit for $f$, which evaluates the
function on all values of $x$ simultaneously. The algorithm makes use of quantum parallelism and
quantum interference.

A constant function with values in $\{0,1\}$ returns either always 0 or always 1. 
A balanced function is 0 for half of the inputs and 1 for the other half. 
›

(* The common setting: a function f on the inputs {i. i < 2^n}, taken from the extensional
   function space (→⇩E) with values in {0,1}, together with n ≥ 1 input bits. *)
locale bob_fun =
  fixes f:: "nat ⇒ nat" and n:: "nat"
  assumes dom: "f ∈ ({(i::nat). i < 2^n} →⇩E {0,1})"
  assumes dim: "n ≥ 1"

context bob_fun
begin

(* const c holds iff f takes the value c on every input i < 2^n. *)
definition const:: "nat ⇒ bool" where 
"const c = (∀x∈{i::nat. i<2^n}. f x = c)"

(* f is constant: always 0 or always 1 on the inputs. *)
definition is_const:: bool where 
"is_const ≡ const 0 ∨ const 1"

(* f is balanced: two subsets A and B of the inputs, each of size 2^(n-1), such that f is 0 on A
   and 1 on B. *)
definition is_balanced:: bool where
"is_balanced ≡ ∃A B ::nat set. A ⊆ {i::nat. i < 2^n} ∧ B ⊆ {i::nat. i < 2^n}
                    ∧ card A = 2^(n-1) ∧ card B = 2^(n-1)  
                    ∧ (∀x∈A. f x = 0) ∧ (∀x∈B. f x = 1)"

(* A set on which f is 0 and a set on which f is 1 are disjoint. *)
lemma is_balanced_inter: 
  fixes A B:: "nat set"
  assumes "∀x ∈ A. f x = 0" and "∀x ∈ B. f x = 1" 
  shows "A ∩ B = {}" 
  using assms by auto

(* Two disjoint subsets of the inputs, each of size 2^(n-1), cover all 2^n inputs
   (card_subset_eq: equal finite cardinality plus inclusion gives equality). *)
lemma is_balanced_union:
  fixes A B:: "nat set"
  assumes "A ⊆ {i::nat. i < 2^n}" and "B ⊆ {i::nat. i < 2^n}" 
      and "card A = 2^(n-1)" and "card B = 2^(n-1)" 
      and "A ∩ B = {}"
  shows "A ∪ B = {i::nat. i < 2^n}"
proof-
  have "finite A" and "finite B" 
    by (simp add: assms(3) card_ge_0_finite)
      (simp add: assms(4) card_ge_0_finite)
  then have "card(A ∪ B) = 2 * 2^(n-1)" 
    using assms(3-5) by (simp add: card_Un_disjoint)
  then have "card(A ∪ B) = 2^n"
    by (metis Nat.nat.simps(3) One_nat_def dim le_0_eq power_eq_if)
  moreover have "… = card({i::nat. i < 2^n})" by simp
  moreover have "A ∪ B ⊆ {i::nat. i < 2^n}" 
    using assms(1,2) by simp
  moreover have "finite ({i::nat. i < 2^n})" by simp
  ultimately show ?thesis 
    using card_subset_eq[of "{i::nat. i < 2^n}" "A ∪ B"] by simp
qed

(* Trivial for natural-number values; kept as a named fact. *)
lemma f_ge_0: "∀x. f x ≥ 0" by simp

(* Restatement of the domain assumption with the condition n ≥ 1 folded into the domain. *)
lemma f_dom_not_zero: 
  shows "f ∈ ({i::nat. n ≥ 1 ∧ i < 2^n} →⇩E {0,1})" 
  using dim dom by simp

(* On every input, f is 0 or 1. *)
lemma f_values: "∀x ∈ {(i::nat). i < 2^n} . f x = 0 ∨ f x = 1" 
  using dom by auto

end (* bob_fun *)

text ‹The input function has to be constant or balanced.›

(* A Deutsch-Jozsa instance: the promise that f is either constant or balanced. *)
locale jozsa = bob_fun +
  assumes const_or_balanced: "is_const ∨ is_balanced "

text ‹
Introduce a customised rule for case distinctions over a disjunction with four disjuncts. Induction
starting from one instead of zero is done with the library rule nat_induct_at_least.
›

(* To deal with Uf it is often necessary to do a case distinction with four different cases.
   Used as "proof (induct rule: disj_four_cases)": the first goal is the disjunction itself, and
   each remaining goal handles one disjunct. *)
lemma disj_four_cases:
  assumes "A ∨ B ∨ C ∨ D" and "A ⟹ P" and "B ⟹ P" and "C ⟹ P" and "D ⟹ P"
  shows "P" 
  using assms by auto

text ‹The unitary transform @{term Uf}.›

(* The oracle Uf as a 2^(n+1) × 2^(n+1) matrix. Indices are grouped in pairs (2k, 2k+1), k = i div 2.
   Each 2×2 diagonal block is [[1 - f k, f k], [f k, 1 - f k]]: the identity when f k = 0 and the
   swap of the two basis states when f k = 1. All entries outside these blocks are 0. *)
definition (in jozsa) jozsa_transform:: "complex Matrix.mat" (‹Uf›) where 
"Uf ≡ Matrix.mat (2^(n+1)) (2^(n+1)) (λ(i,j). 
  if i = j then (1-f(i div 2)) else 
    if i = j + 1 ∧ odd i then f(i div 2) else
      if i = j - 1 ∧ even i ∧ j≥1 then f(i div 2) else 0)"

(* Uf is a square matrix of dimension 2^(n+1). *)
lemma (in jozsa) jozsa_transform_dim [simp]:
  shows "dim_row Uf = 2^(n+1)" and "dim_col Uf = 2^(n+1)" 
  by (auto simp add: jozsa_transform_def)

(* Entries of Uf outside the 2×2 diagonal blocks (not diagonal, not (2k+1,2k), not (2k,2k+1)) are 0. *)
lemma (in jozsa) jozsa_transform_coeff_is_zero [simp]:
  assumes "i < dim_row Uf ∧ j < dim_col Uf"
  shows "(i≠j ∧ ¬(i=j+1 ∧ odd i) ∧ ¬ (i=j-1 ∧ even i ∧ j≥1)) ⟶ Uf $$ (i,j) = 0"
  using jozsa_transform_def assms by auto

(* Entries of Uf inside the 2×2 diagonal blocks: the diagonal holds 1 - f (i div 2), and both
   off-diagonal positions of a block hold f (i div 2). *)
lemma (in jozsa) jozsa_transform_coeff [simp]: 
  assumes "i < dim_row Uf ∧ j < dim_col Uf"
  shows "i = j ⟶ Uf $$ (i,j) = 1 - f (i div 2)"
  and "i = j + 1 ∧ odd i ⟶ Uf $$ (i,j) = f (i div 2)"
  and "j ≥ 1 ∧ i = j - 1 ∧ even i ⟶ Uf $$ (i,j) = f (i div 2)" 
  using jozsa_transform_def assms by auto

(* For an even row index i, only columns i and i+1 of Uf can be nonzero. The sum defining entry
   (i,j) of Uf * A therefore reduces to these two summands. *)
lemma (in jozsa) Uf_mult_without_empty_summands_sum_even:
  fixes i j A
  assumes "i < dim_row Uf" and "j < dim_col A" and "even i" and "dim_col Uf = dim_row A"
  shows "(∑k∈{0..< dim_row A}. Uf $$ (i,k) * A $$ (k,j)) =(∑k∈{i,i+1}. Uf $$ (i,k) * A $$ (k,j))"
proof-
  (* Split the index range into the disjoint parts [0,i), {i,i+1} and [i+2, 2^(n+1)). *)
  have "(∑k ∈ {0..< 2^(n+1)}. Uf $$ (i,k) * A $$ (k,j)) = 
             (∑k ∈ {0..<i}. Uf $$ (i,k) * A $$ (k,j)) +
             (∑k ∈ {i,i+1}. Uf $$ (i,k) * A $$ (k,j)) +
             (∑k ∈ {(i+2)..< 2^(n+1)}. Uf $$ (i,k) * A $$ (k,j))" 
  proof- 
    have "{0..< 2^(n+1)} = {0..<i} ∪ {i..< 2^(n+1)} 
          ∧ {i..< 2^(n+1)} = {i,i+1} ∪ {(i+2)..<2^(n+1)}" using assms(1-3) by auto
    moreover have "{0..<i} ∩ {i,i+1} = {} 
                  ∧ {i,i+1} ∩ {(i+2)..< 2^(n+1)} = {} 
                  ∧ {0..<i} ∩ {(i+2)..< 2^(n+1)} = {}" using assms by simp
    ultimately show ?thesis
      using sum.union_disjoint
      by (metis (no_types, lifting) finite_Un finite_atLeastLessThan is_num_normalize(1) ivl_disj_int_two(3))
  qed
  (* Columns below i lie outside the 2×2 block of row i. *)
  moreover have "(∑k ∈ {0..<i}. Uf $$ (i,k) * A $$ (k,j)) = 0" 
  proof-
    have "k ∈ {0..<i} ⟹ (i≠k ∧ ¬(i=k+1 ∧ odd i) ∧ ¬ (i=k-1 ∧ even i ∧ k≥1))" for k 
      using assms by auto
    then have "k ∈ {0..<i} ⟹ Uf $$ (i,k) = 0" for k
      using assms(1) by auto
    then show ?thesis by simp
  qed
  (* Columns from i+2 on lie outside the 2×2 block of row i. *)
  moreover have "(∑k ∈ {(i+2)..< 2^(n+1)}. Uf $$ (i,k) * A $$ (k,j)) = 0" 
  proof- 
    have "k∈{(i+2)..< 2^(n+1)} ⟹ (i≠k ∧ ¬(i=k+1 ∧ odd i) ∧ ¬ (i=k-1 ∧ even i ∧ k≥1))" for k by auto
    then have "k ∈ {(i+2)..< 2^(n+1)} ⟹ Uf $$ (i,k) = 0" for k
      using assms(1) by auto
    then show ?thesis by simp
  qed
  moreover have  "dim_row A = 2^(n+1)" using assms(4) by simp
  ultimately show "?thesis" by(metis (no_types, lifting) add.left_neutral add.right_neutral)
qed

(* Entry (i,j) of Uf * A for an even row index i, expressed through the only two summands that can
   be nonzero. *)
lemma (in jozsa) Uf_mult_without_empty_summands_even: 
  fixes i j A
  assumes "i < dim_row Uf" and "j < dim_col A" and "even i" and "dim_col Uf = dim_row A"
  shows "(Uf * A) $$ (i,j) = (∑k ∈ {i,i+1}. Uf $$ (i,k) * A $$ (k,j))"
proof-
  have "(Uf * A) $$ (i,j) = (∑ k∈{0..< dim_row A}. (Uf $$ (i,k)) * (A $$ (k,j)))"
    using assms(1,2,4) index_matrix_prod by (simp add: atLeast0LessThan)
  then show ?thesis
    using assms Uf_mult_without_empty_summands_sum_even by simp
qed

(* For an odd row index i, only columns i-1 and i of Uf can be nonzero. The sum defining entry (i,j)
   of Uf * A therefore reduces to these two summands. *)
lemma (in jozsa) Uf_mult_without_empty_summands_sum_odd:
  fixes i j A
  assumes "i < dim_row Uf" and "j < dim_col A" and "odd i" and "dim_col Uf = dim_row A"
  shows "(∑k∈{0..< dim_row A}. Uf $$ (i,k) * A $$ (k,j)) =(∑k∈{i-1,i}. Uf $$ (i,k) * A $$ (k,j))"
proof-
  (* Split the index range into the disjoint parts [0,i-1), {i-1,i} and [i+1, 2^(n+1)). *)
  have "(∑k∈{0..< 2^(n+1)}. Uf $$ (i,k) * A $$ (k,j)) = 
             (∑k ∈ {0..<i-1}. Uf $$ (i,k) * A $$ (k,j)) +
             (∑k ∈ {i-1,i}. Uf $$ (i,k) * A $$ (k,j)) +
             (∑k ∈ {i+1..< 2^(n+1)}. Uf $$ (i,k) * A $$ (k,j))" 
  proof- 
    have "{0..< 2^(n+1)} = {0..<i-1} ∪ {i-1..< 2^(n+1)} 
          ∧ {i-1..< 2^(n+1)} = {i-1,i} ∪ {i+1..<2^(n+1)}" using assms(1-3) by auto
    moreover have "{0..<i-1} ∩ {i-1,i} = {} 
                  ∧ {i-1,i} ∩ {i+1..< 2^(n+1)} = {} 
                  ∧ {0..<i-1} ∩ {i+1..< 2^(n+1)} = {}" using assms by simp
    ultimately show ?thesis
      using sum.union_disjoint 
      by(metis (no_types, lifting) finite_Un finite_atLeastLessThan is_num_normalize(1) ivl_disj_int_two(3))
  qed
  (* Columns below i-1 lie outside the 2×2 block of row i. *)
  moreover have "(∑k ∈ {0..<i-1}. Uf $$ (i,k) * A $$ (k,j)) = 0"
  proof-
    have "k ∈ {0..<i-1} ⟹ (i≠k ∧ ¬(i=k+1 ∧ odd i) ∧ ¬ (i=k-1 ∧ even i ∧ k≥1))" for k by auto
    then have "k ∈ {0..<i-1} ⟹ Uf $$ (i,k) = 0" for k
      using assms(1) by auto
    then show ?thesis by simp
  qed
  (* Columns above i lie outside the 2×2 block of row i. *)
  moreover have "(∑k ∈ {i+1..< 2^(n+1)}. Uf $$ (i,k) * A $$ (k,j)) = 0" 
    using assms(3) by auto 
  moreover have  "dim_row A = 2^(n+1)" using assms(4) by simp
  ultimately show "?thesis" by(metis (no_types, lifting) add.left_neutral add.right_neutral)
qed

(* Entry (i,j) of Uf * A for an odd row index i, expressed through the only two summands that can
   be nonzero. *)
lemma (in jozsa) Uf_mult_without_empty_summands_odd: 
  fixes i j A
  assumes "i < dim_row Uf" and "j < dim_col A" and "odd i" and "dim_col Uf = dim_row A"
  shows "(Uf * A) $$ (i,j) = (∑k ∈ {i-1,i}. Uf $$ (i,k) * A $$ (k,j)) "
proof-
  have "(Uf * A) $$ (i,j) = (∑k ∈ {0 ..< dim_row A}. (Uf $$ (i,k)) * (A $$ (k,j)))"
    using assms(1,2,4) index_matrix_prod by (simp add: atLeast0LessThan)
  then show "?thesis" 
    using assms Uf_mult_without_empty_summands_sum_odd by auto
qed

text ‹@{term Uf} is a gate.›

(* Uf is symmetric. The proof is a four-way case split on the position of (i,j): diagonal, lower
   entry of a 2×2 block, upper entry of a 2×2 block, or outside every block. *)
lemma (in jozsa) transpose_of_jozsa_transform:
  shows "(Uf)⇧t = Uf"
proof
  show "dim_row (Uf⇧t) = dim_row Uf" by simp
next
  show "dim_col (Uf⇧t) = dim_col Uf" by simp
next
  fix i j:: nat
  assume a0: "i < dim_row Uf" and a1: "j < dim_col Uf"
  then show "Uf⇧t $$ (i, j) = Uf $$ (i, j)" 
  proof (induct rule: disj_four_cases)
    show "i=j ∨ (i=j+1 ∧ odd i) ∨ (i=j-1 ∧ even i ∧ j≥1) ∨ (i≠j ∧ ¬(i=j+1 ∧ odd i) ∧ ¬ (i=j-1 ∧ even i ∧ j≥1))" 
      by linarith
  next
    assume "i = j"
    then show "Uf⇧t $$ (i,j) = Uf $$ (i,j)" using a0 by simp
  next
    assume "(i=j+1 ∧ odd i)"
    then show "Uf⇧t $$ (i,j) = Uf $$ (i,j)" using transpose_mat_def a0 a1 by auto
  next
    assume a2:"(i=j-1 ∧ even i ∧ j≥1)"
    then have "Uf $$ (i,j) = f (i div 2)" 
      using a0 a1 jozsa_transform_coeff by auto
    (* The mirrored entry (j,i) = (i+1,i) lies in the same 2×2 block and carries the same value. *)
    moreover have "Uf $$ (j,i) = f (i div 2)" 
      using a0 a1 a2 jozsa_transform_coeff
      by (metis add_diff_assoc2 diff_add_inverse2 even_plus_one_iff even_succ_div_two jozsa_transform_dim)
    ultimately show "?thesis"
      using transpose_mat_def a0 a1 by simp
  next 
    assume a2:"(i≠j ∧ ¬(i=j+1 ∧ odd i) ∧ ¬ (i=j-1 ∧ even i ∧ j≥1))"
    (* If (i,j) is outside every block, so is (j,i). *)
    then have "(j≠i ∧ ¬(j=i+1 ∧ odd j) ∧ ¬ (j=i-1 ∧ even j ∧ i≥1))" 
      by (metis le_imp_diff_is_add diff_add_inverse even_plus_one_iff le_add1)
    then have "Uf $$ (j,i) = 0" 
      using jozsa_transform_coeff_is_zero a0 a1 by auto
    moreover have "Uf $$ (i,j) = 0" 
      using jozsa_transform_coeff_is_zero a0 a1 a2 by auto
    ultimately show "Uf⇧t $$ (i,j) = Uf $$ (i,j)"
      using transpose_mat_def a0 a1 by simp
  qed 
qed

(* Uf is self-adjoint. Its entries are 1 - f and f coerced from nat, so conjugation does not change
   them; together with symmetry this gives Uf⇧† = Uf. Same four-way case split as in
   transpose_of_jozsa_transform. *)
lemma (in jozsa) adjoint_of_jozsa_transform: 
  shows "(Uf)⇧† = Uf"
proof
  show "dim_row (Uf⇧†) = dim_row Uf" by simp
next
  show "dim_col (Uf⇧†) = dim_col Uf" by simp
next
  fix i j:: nat
  assume a0: "i < dim_row Uf" and a1: "j < dim_col Uf"
  then show "Uf⇧† $$ (i,j) = Uf $$ (i,j)"
  proof (induct rule: disj_four_cases)
  show "i=j ∨ (i=j+1 ∧ odd i) ∨ (i=j-1 ∧ even i ∧ j≥1) ∨ (i≠j ∧ ¬(i=j+1 ∧ odd i) ∧ ¬ (i=j-1 ∧ even i ∧ j≥1))"
    by linarith
  next
    assume "i=j"
    then show "Uf⇧† $$ (i,j) = Uf $$ (i,j)" using a0 dagger_def by simp
  next
    assume "(i=j+1 ∧ odd i)"
    then show "Uf⇧† $$ (i,j) = Uf $$ (i,j)" using a0 dagger_def by auto
  next
    assume a2:"(i=j-1 ∧ even i ∧ j≥1)"
    then have "Uf $$ (i,j) = f (i div 2)" 
      using a0 a1 jozsa_transform_coeff by auto
    moreover have "Uf⇧†  $$ (j,i) = f (i div 2)" 
      using a1 a2 jozsa_transform_coeff dagger_def by auto
    ultimately show "Uf⇧† $$ (i,j) = Uf $$ (i,j)"
      by(metis a0 a1 cnj_transpose_is_dagger dim_row_of_dagger index_transpose_mat dagger_of_transpose_is_cnj transpose_of_jozsa_transform)
  next 
    assume a2: "(i≠j ∧ ¬(i=j+1 ∧ odd i) ∧ ¬ (i=j-1 ∧ even i ∧ j≥1))"
    (* If (i,j) is outside every block, so is (j,i); hence both Uf entries and the conjugate vanish. *)
    then have f0:"(i≠j ∧ ¬(j=i+1 ∧ odd j) ∧ ¬ (j=i-1 ∧ even j ∧ i≥1))" 
      by (metis le_imp_diff_is_add diff_add_inverse even_plus_one_iff le_add1)
    then have "Uf $$ (j,i) = 0" and "cnj 0 = 0"
      using jozsa_transform_coeff_is_zero a0 a1 a2 by auto
    then have "Uf⇧† $$ (i,j) = 0" 
      using a0 a1 dagger_def by simp
    then show "Uf⇧† $$ (i, j) = Uf $$ (i, j)" 
      using a0 a1 a2 jozsa_transform_coeff_is_zero by auto
  qed 
qed

(* For an even row index i, row i of Uf * Uf equals row i of the identity. Only the summands
   k = i and k = i+1 contribute, giving
   (1 - f(i div 2)) * Uf $$ (i,j) + f(i div 2) * Uf $$ (i+1,j), which is then evaluated for each
   position of j. *)
lemma (in jozsa) jozsa_transform_is_unitary_index_even:
  fixes i j:: nat
  assumes "i < dim_row Uf" and "j < dim_col Uf" and "even i"
  shows "(Uf * Uf) $$ (i,j) = 1⇩m (dim_col Uf) $$ (i,j)"
proof-
  have "(Uf * Uf) $$ (i,j) = (∑k ∈ {i,i+1}. Uf $$ (i,k) * Uf $$ (k,j)) " 
    using Uf_mult_without_empty_summands_even[of i j Uf ] assms by simp
  moreover have "Uf $$ (i,i) * Uf $$ (i,j) = (1-f(i div 2)) * Uf $$ (i,j)"
    using assms(1,3) by simp
  moreover have f0: "Uf $$ (i,i+1) * Uf $$ (i+1,j) = f(i div 2) * Uf $$ (i+1,j)"
    by (metis One_nat_def Suc_leI add.right_neutral add_Suc_right assms(1) assms(3) diff_add_inverse2 
even_add even_mult_iff jozsa_transform_coeff(3) jozsa_transform_dim le_add2 le_eq_less_or_eq odd_one 
one_add_one power.simps(2))
  ultimately have f1: "(Uf * Uf) $$ (i,j) = (1-f(i div 2)) * Uf $$ (i,j) +  f(i div 2) * Uf $$ (i+1,j)" by auto
  thus ?thesis
  proof (induct rule: disj_four_cases)
    show "j=i ∨ (j=i+1 ∧ odd j) ∨ (j=i-1 ∧ even j ∧ i≥1) ∨ (j≠i ∧ ¬(j=i+1 ∧ odd j) ∧ ¬ (j=i-1 ∧ even j ∧ i≥1))"
      by linarith
  next
    (* Diagonal: (1-f)^2 + f^2 = 1 because f takes values in {0,1}. *)
    assume a0:"j=i"
    then have "Uf $$ (i,j) = (1-f(i div 2))" 
      using assms(1,2) a0 by simp
    moreover have "Uf $$ (i+1,j) = f(i div 2)"
      using assms(1,3) a0 by auto
    ultimately have "(Uf * Uf) $$ (i,j) = (1-f(i div 2)) * (1-f(i div 2)) +  f(i div 2) * f(i div 2)" 
      using f1 by simp
    moreover have "(1-f(i div 2)) * (1-f(i div 2)) + f(i div 2) * f(i div 2) = 1" 
      using f_values assms(1)
      by (metis (no_types, lifting) Nat.minus_nat.diff_0 diff_add_0 diff_add_inverse jozsa_transform_dim(1) 
          less_power_add_imp_div_less mem_Collect_eq mult_eq_if one_power2 power2_eq_square power_one_right) 
    ultimately show "(Uf * Uf) $$ (i,j) = 1⇩m (dim_col Uf) $$ (i,j)"  by(metis assms(2) a0 index_one_mat(1) of_nat_1)
  next
    (* Off-diagonal inside the block: (1-f)*f + f*(1-f) = 0. *)
    assume a0: "(j=i+1 ∧ odd j)"
    then have "Uf $$ (i,j) = f(i div 2)" 
      using assms(1,2) a0 by simp
    moreover have "Uf $$ (i+1,j) = (1-f(i div 2))"
      using assms(2,3) a0 by simp
    ultimately have "(Uf * Uf) $$ (i,j) = (1-f(i div 2)) * f(i div 2) + f(i div 2) * (1-f(i div 2))"
      using f0 f1 assms by simp
    then show "(Uf * Uf) $$ (i,j) = 1⇩m (dim_col Uf) $$ (i,j)" 
      using assms(1,2) a0 by auto
  next
    (* Impossible here: i is even, so j = i-1 cannot be even with i ≥ 1. *)
    assume "(j=i-1 ∧ even j ∧ i≥1)"
    then show "(Uf * Uf) $$ (i,j) = 1⇩m (dim_col Uf) $$ (i,j)" 
      using assms(3) dvd_diffD1 odd_one by blast
  next 
    (* Outside the block: both relevant entries of Uf are 0. *)
    assume a0:"(j≠i ∧ ¬(j=i+1 ∧ odd j) ∧ ¬ (j=i-1 ∧ even j ∧ i≥1))"
    then have "Uf $$ (i,j) = 0" 
      using assms(1,2) by(metis index_transpose_mat(1) jozsa_transform_coeff_is_zero jozsa_transform_dim transpose_of_jozsa_transform)
    moreover have "Uf $$ (i+1,j) = 0" 
      using assms a0 by auto
    ultimately have "(Uf * Uf) $$ (i,j) = (1-f(i div 2)) * 0 +  f(i div 2) * 0" 
      by (simp add: f1)
    then show "(Uf * Uf) $$ (i,j) = 1⇩m (dim_col Uf) $$ (i,j)" 
      using a0 assms(1,2) by(metis add.left_neutral index_one_mat(1) jozsa_transform_dim mult_0_right of_nat_0)
  qed
qed

(* For an odd row index i, row i of Uf * Uf equals row i of the identity. Only the summands
   k = i-1 and k = i contribute, giving
   f(i div 2) * Uf $$ (i-1,j) + (1 - f(i div 2)) * Uf $$ (i,j), which is then evaluated for each
   position of j. *)
lemma (in jozsa) jozsa_transform_is_unitary_index_odd:
  fixes i j:: nat
  assumes "i < dim_row Uf" and "j < dim_col Uf" and "odd i"
  shows "(Uf * Uf) $$ (i,j) = 1⇩m (dim_col Uf) $$ (i,j)"
proof-
  have f0: "i ≥ 1"  
    using linorder_not_less assms(3) by auto
  have "(Uf * Uf) $$ (i,j) = (∑k ∈ {i-1,i}. Uf $$ (i,k) * Uf $$ (k,j)) " 
    using Uf_mult_without_empty_summands_odd[of i j Uf ] assms by simp
  moreover have "(∑k ∈ {i-1,i}. Uf $$ (i,k) * Uf $$ (k,j)) 
                 = Uf $$ (i,i-1) * Uf $$ (i-1,j) +  Uf $$ (i,i) * Uf $$ (i,j)"
    using f0 by simp
  moreover have "Uf $$ (i,i) * Uf $$ (i,j) = (1-f(i div 2)) * Uf $$ (i,j)" 
    using assms(1,2) by simp
  moreover have f1: "Uf $$ (i,i-1) * Uf $$ (i-1,j) = f(i div 2) * Uf $$ (i-1,j)" 
    using assms(1) assms(3) by simp
  ultimately have f2: "(Uf * Uf) $$ (i,j) = f(i div 2) * Uf $$ (i-1,j) + (1-f(i div 2)) * Uf $$ (i,j)" by simp
  then show "?thesis"
  proof (induct rule: disj_four_cases)
    show "j=i ∨ (j=i+1 ∧ odd j) ∨ (j=i-1 ∧ even j ∧ i≥1) ∨ (j≠i ∧ ¬(j=i+1 ∧ odd j) ∧ ¬ (j=i-1 ∧ even j ∧ i≥1))"
      by linarith
  next
    (* Diagonal: f^2 + (1-f)^2 = 1 because f takes values in {0,1}. *)
    assume a0:"j=i"
    then have "Uf $$ (i,j) = (1-f(i div 2))"
      using assms(1,2) by simp
    moreover have "Uf $$ (i-1,j) = f(i div 2)"
      using a0 assms
      by (metis index_transpose_mat(1) jozsa_transform_coeff(2) less_imp_diff_less odd_two_times_div_two_nat 
          odd_two_times_div_two_succ transpose_of_jozsa_transform)
    ultimately have "(Uf * Uf) $$ (i,j) = f(i div 2) * f(i div 2) + (1-f(i div 2)) * (1-f(i div 2))"
      using f2 by simp
    moreover have "f(i div 2) * f(i div 2) + (1-f(i div 2)) * (1-f(i div 2)) = 1" 
      using f_values assms(1)
      by (metis (no_types, lifting) Nat.minus_nat.diff_0 diff_add_0 diff_add_inverse jozsa_transform_dim(1) 
          less_power_add_imp_div_less mem_Collect_eq mult_eq_if one_power2 power2_eq_square power_one_right) 
    ultimately show "(Uf * Uf) $$ (i,j) = 1⇩m (dim_col Uf) $$ (i,j)" by(metis assms(2) a0 index_one_mat(1) of_nat_1)
  next
    (* Impossible here: i is odd, so i+1 cannot be odd. *)
    assume a0:"(j=i+1 ∧ odd j)"
    then show "(Uf * Uf) $$ (i,j) = 1⇩m (dim_col Uf) $$ (i,j)" 
      using assms(3) dvd_diffD1 odd_one even_plus_one_iff by blast
  next
    (* Off-diagonal inside the block: f*(1-f) + (1-f)*f = 0. *)
    assume a0:"(j=i-1 ∧ even j ∧ i≥1)"
    then have "(Uf * Uf) $$ (i,j) = f(i div 2) * (1-f(i div 2)) + (1-f(i div 2)) * f(i div 2)" 
      using f0 f1 f2 assms
      by (metis jozsa_transform_coeff(1) Groups.ab_semigroup_mult_class.mult.commute even_succ_div_two f2 
          jozsa_transform_dim odd_two_times_div_two_nat odd_two_times_div_two_succ of_nat_add of_nat_mult)
    then show "(Uf * Uf) $$ (i,j) = 1⇩m (dim_col Uf) $$ (i,j)" 
      using assms(1) a0 by auto
  next 
    (* Outside the block: both relevant entries of Uf are 0. *)
    assume a0:"j≠i ∧ ¬(j=i+1 ∧ odd j) ∧ ¬ (j=i-1 ∧ even j ∧ i≥1)"
    then have "Uf $$ (i,j) = 0" 
      by (metis assms(1,2) index_transpose_mat(1) jozsa_transform_coeff_is_zero jozsa_transform_dim transpose_of_jozsa_transform)
    moreover have "Uf $$ (i-1,j) = 0" 
      using assms a0 f0 
      by auto (smt (verit) One_nat_def Suc_n_not_le_n add_diff_inverse_nat assms(1) assms(2) diff_Suc_less even_add 
jozsa_transform_coeff_is_zero jozsa_axioms less_imp_le less_le_trans less_one odd_one)
    ultimately have "(Uf * Uf) $$ (i,j) = (1-f(i div 2)) * 0 +  f(i div 2) * 0" 
      using f2 by simp
    then show "(Uf * Uf) $$ (i,j) = 1⇩m (dim_col Uf) $$ (i,j)" 
      using a0 assms by (metis add.left_neutral index_one_mat(1) jozsa_transform_dim mult_0_right of_nat_0)
  qed
qed

(* Uf is a gate on n+1 qubits. Because Uf is self-adjoint (adjoint_of_jozsa_transform), unitarity
   reduces to Uf * Uf = 1⇩m, which is checked entry-wise for even and odd row indices. *)
lemma (in jozsa) jozsa_transform_is_gate:
  shows "gate (n+1) Uf"
proof
  show "dim_row Uf = 2^(n+1)" by simp
next
  show "square_mat Uf" by simp
next
  show "unitary Uf"
  proof-
    have "Uf * Uf = 1⇩m (dim_col Uf)"
    proof
      show "dim_row (Uf * Uf) = dim_row (1⇩m (dim_col Uf))" by simp
      show "dim_col (Uf * Uf) = dim_col (1⇩m (dim_col Uf))" by simp
      fix i j:: nat
      assume "i < dim_row (1⇩m (dim_col Uf))" and "j < dim_col (1⇩m (dim_col Uf))"
      then have "i < dim_row Uf" and "j < dim_col Uf" by auto
      then show "(Uf * Uf) $$ (i,j) = 1⇩m (dim_col Uf) $$ (i,j)" 
        using jozsa_transform_is_unitary_index_odd jozsa_transform_is_unitary_index_even by blast
    qed
    thus ?thesis by (simp add: adjoint_of_jozsa_transform unitary_def)
  qed
qed

text ‹N-fold application of the tensor product›

(* A ⊗⇗k⇖ is the k-fold tensor product A ⨂ (A ⨂ … A) for k ≥ 1. There is no equation for k = 0,
   so that case is unspecified. *)
fun iter_tensor:: "complex Matrix.mat ⇒ nat ⇒ complex Matrix.mat" (‹_ ⊗⇗_⇖› 75)  where
  "A ⊗⇗(Suc 0)⇖ = A"  
| "A ⊗⇗(Suc k)⇖ = A ⨂ (A ⊗⇗k⇖)"

(* The one-fold tensor product of A is A itself. *)
lemma one_tensor_is_id [simp]:
  fixes A
  shows "A ⊗⇗1⇖ = A"
  using one_mat_def by simp

(* Unfolding step for n ≥ 1: A ⊗⇗Suc n⇖ = A ⨂ (A ⊗⇗n⇖). The restriction n ≥ 1 is needed because
   the function's second equation only applies from Suc (Suc _) on. *)
lemma iter_tensor_suc: 
  fixes n
  assumes "n ≥ 1"
  shows " A ⊗⇗(Suc n)⇖ = A ⨂ (A ⊗⇗n⇖)" 
  using assms by (metis Deutsch_Jozsa.iter_tensor.simps(2) One_nat_def Suc_le_D)

(* The row count of A ⊗⇗n⇖ is (dim_row A)^n, for n ≥ 1 (induction from 1). *)
lemma dim_row_of_iter_tensor [simp]:
  fixes A n
  assumes "n ≥ 1"
  shows "dim_row(A ⊗⇗n⇖) = (dim_row A)^n"
  using assms
proof (rule nat_induct_at_least)
  show "dim_row (A ⊗⇗1⇖) = (dim_row A)^1"
    using one_tensor_is_id by simp
  fix n:: nat
  assume "n ≥ 1" and "dim_row (A ⊗⇗n⇖) = (dim_row A)^n"
  then show "dim_row (A ⊗⇗Suc n⇖) = (dim_row A)^Suc n"
    using iter_tensor_suc assms dim_row_tensor_mat by simp
qed

(* The column count of A ⊗⇗n⇖ is (dim_col A)^n, for n ≥ 1 (induction from 1). *)
lemma dim_col_of_iter_tensor [simp]:
  fixes A n
  assumes "n ≥ 1"
  shows "dim_col(A ⊗⇗n⇖) = (dim_col A)^n"
  using assms
proof (rule nat_induct_at_least)
  show "dim_col (A ⊗⇗1⇖) = (dim_col A)^1"
    using one_tensor_is_id by simp
  fix n:: nat
  assume "n ≥ 1" and "dim_col (A ⊗⇗n⇖) = (dim_col A)^n"
  then show "dim_col (A ⊗⇗Suc n⇖) = (dim_col A)^Suc n"
    using iter_tensor_suc assms dim_col_tensor_mat by simp
qed

(* Entry-wise form of iter_tensor_suc. *)
lemma iter_tensor_values:
  fixes A n i j
  assumes "n ≥ 1" and "i < dim_row (A ⨂ (A ⊗⇗n⇖))" and "j < dim_col (A ⨂ (A ⊗⇗n⇖))"
  shows "(A ⊗⇗(Suc n)⇖) $$ (i,j) = (A ⨂ (A ⊗⇗n⇖)) $$ (i,j)"
  using assms by (metis One_nat_def le_0_eq not0_implies_Suc iter_tensor.simps(2))

(* One unfolding step of the mixed-product property for iterated tensor products:
   (A ⊗⇗Suc n⇖) * (B ⊗⇗Suc n⇖) = (A * B) ⨂ ((A ⊗⇗n⇖) * (B ⊗⇗n⇖)). *)
lemma iter_tensor_mult_distr:
  assumes "n ≥ 1" and "dim_col A = dim_row B" and "dim_col A > 0" and "dim_col B > 0"
  shows "(A ⊗⇗(Suc n)⇖) * (B ⊗⇗(Suc n)⇖) = (A * B) ⨂ ((A ⊗⇗n⇖) * (B ⊗⇗n⇖))" 
proof-
  have "(A ⊗⇗(Suc n)⇖) * (B ⊗⇗(Suc n)⇖) = (A ⨂ (A ⊗⇗n⇖)) * (B ⨂ (B ⊗⇗n⇖))" 
    using Suc_le_D assms(1) by fastforce
  then show "?thesis" 
    using mult_distr_tensor[of "A" "B" "(iter_tensor A n)" "(iter_tensor B n)"] assms by simp
qed

(* For a 2×1 column A, the lower half of A ⨂ B (rows dim_row B ≤ i < 2 * dim_row B) equals
   A $$ (1,0) times B shifted up by dim_row B. *)
lemma index_tensor_mat_with_vec2_row_cond:
  fixes A B:: "complex Matrix.mat" and i:: "nat" 
  assumes "i < 2 * (dim_row B)" and "i ≥ dim_row B" and "dim_col B > 0"
and "dim_row A = 2" and "dim_col A = 1"
  shows "(A ⨂ B) $$ (i,0) = (A $$ (1,0)) * (B $$ (i-dim_row B,0))"
proof-
  have "(A ⨂ B) $$ (i,0) = A $$ (i div (dim_row B),0) * B $$ (i mod (dim_row B),0)"
    using assms index_tensor_mat[of A "dim_row A" "dim_col A" B "dim_row B" "dim_col B" i 0] by simp
  moreover have "i div (dim_row B) = 1" 
    using assms(1,2,4) by simp
  then have "i mod (dim_row B) = i - (dim_row B)" 
    by (simp add: modulo_nat_def)
  (* The div fact is not part of the calculation chain, so it is referenced explicitly. *)
  ultimately show "(A ⨂ B) $$ (i,0) = (A $$ (1,0)) * (B $$ (i-dim_row B,0))" 
    by (simp add: ‹i div dim_row B = 1›)
qed

(* The n-fold tensor product (n ≥ 1) of an m-qubit gate is an (m*n)-qubit gate. *)
lemma iter_tensor_of_gate_is_gate:
  fixes A:: "complex Matrix.mat" and n m:: "nat" 
  assumes "gate m A" and "n ≥ 1" 
  shows "gate (m*n) (A ⊗⇗n⇖)"
  using assms(2)
proof(rule nat_induct_at_least)
  show "gate (m * 1) (A ⊗⇗1⇖)" using assms(1) by simp
  fix n:: nat
  assume "n ≥ 1" and IH:"gate (m * n) (A ⊗⇗n⇖)"
  then have "A ⊗⇗(Suc n)⇖ = A ⨂ (A ⊗⇗n⇖)" 
    by (simp add: iter_tensor_suc)
  moreover have "gate (m*n + m) (A ⊗⇗(Suc n)⇖)"  
    using tensor_gate assms(1) by (simp add: IH add.commute calculation(1))
  then show "gate (m*(Suc n)) (A ⊗⇗(Suc n)⇖)"
    by (simp add: add.commute)
qed

(* The n-fold tensor product (n ≥ 1) of an m-qubit state is an (m*n)-qubit state. *)
lemma iter_tensor_of_state_is_state:
  fixes A:: "complex Matrix.mat" and n m:: "nat" 
  assumes "state m A" and "n≥1" 
  shows "state (m*n) (A ⊗⇗n⇖)"
  using assms(2)
proof(rule nat_induct_at_least)
  show "state (m * 1) (A ⊗⇗1⇖)"
    using one_tensor_is_id assms(1) by simp
  fix n:: nat
  assume "n ≥ 1" and IH:"state (m * n) (A ⊗⇗n⇖)"
  then have "A ⊗⇗(Suc n)⇖ = A ⨂ (A ⊗⇗n⇖)" 
    by (simp add: iter_tensor_suc)
  moreover have "state (m*n + m) (A ⊗⇗(Suc n)⇖)"  
    using tensor_gate assms(1) by (simp add: IH add.commute calculation)
  then show "state (m*(Suc n)) (A ⊗⇗(Suc n)⇖)" 
    by (simp add: add.commute)
qed

text ‹
We prepare n+1 qubits: the first n qubits in the state $|0\rangle$ and the last one in the state 
$|1\rangle$.
›

(* ψ10 n: a 2^n × 1 column whose entries are all 1/√2^n (uniform superposition on n qubits). *)
abbreviation ψ10:: "nat ⇒ complex Matrix.mat" where
"ψ10 n ≡ Matrix.mat (2^n) 1 (λ(i,j). 1/(sqrt 2)^n)" 

(* Every in-bounds entry of ψ10 n equals 1/√2^n. *)
lemma ψ10_values:
  fixes i j n
  assumes "i < dim_row (ψ10 n)" and "j < dim_col (ψ10 n)" 
  shows "(ψ10 n) $$ (i,j) = 1/(sqrt 2)^n" 
  using assms case_prod_conv by simp

text ‹$H^{\otimes n}$ is applied to $|0\rangle^{\otimes n}$.›

(* H applied to |zero⟩ is ψ10 1. Checked entry-wise for the two rows of the result. *)
lemma H_on_ket_zero: 
  shows "(H * |zero⟩) = ψ10 1"
proof 
  fix i j:: nat
  assume "i < dim_row (ψ10 1)" and "j < dim_col (ψ10 1)"
  then have f1: "i ∈ {0,1} ∧ j = 0" by (simp add: less_2_cases)
  then show "(H * |zero⟩) $$ (i,j) = (ψ10 1) $$ (i,j)"
    by (auto simp add: times_mat_def scalar_prod_def H_def ket_vec_def)
next
  show "dim_row (H * |zero⟩) = dim_row (ψ10 1)"  by (simp add: H_def)
  show "dim_col (H * |zero⟩) = dim_col (ψ10 1)" using H_def  
    by (simp add: ket_vec_def)
qed

(* ψ10 1 ⨂ ψ10 n = ψ10 (Suc n). Entries are compared separately in the upper half
   (i < 2^n) and the lower half (i ≥ 2^n) of the product. *)
lemma ψ10_tensor: 
  assumes "n ≥ 1"
  shows "(ψ10 1) ⨂ (ψ10 n) = (ψ10 (Suc n))"
proof
  have "dim_row (ψ10 1) * dim_row (ψ10 n) = 2^(Suc n)" by simp 
  then show "dim_row ((ψ10 1) ⨂ (ψ10 n)) = dim_row (ψ10 (Suc n))" by simp
  have "dim_col (ψ10 1) * dim_col (ψ10 n) = 1" by simp
  then show "dim_col ((ψ10 1) ⨂ (ψ10 n)) = dim_col (ψ10 (Suc n))" by simp
next
  fix i j:: nat
  assume a0: "i < dim_row (ψ10 (Suc n))" and a1: "j < dim_col (ψ10 (Suc n))"
  then have f0: "j = 0" and f1: "i < 2^(Suc n)" by auto
  then have f2:"(ψ10 (Suc n)) $$ (i,j) = 1/(sqrt 2)^(Suc n)" 
    using ψ10_values[of "i" "(Suc n)" "j"] a0 a1 by simp
  show "((ψ10 1) ⨂ (ψ10 n)) $$ (i,j) = (ψ10 (Suc n)) $$ (i,j)" 
  proof (rule disjE) (*case distinction*)
    show "i < dim_row (ψ10 n) ∨ i ≥ dim_row (ψ10 n)" by linarith
  next (* case i < dim_row (ψ10 n) *)
    assume a2: "i < dim_row (ψ10 n)"
    then have "((ψ10 1) ⨂ (ψ10 n)) $$ (i,j) = (ψ10 1) $$ (0,0) * (ψ10 n) $$ (i,0)"
      using index_tensor_mat f0 assms by simp
    also have "... = 1/sqrt(2) * 1/(sqrt(2)^n)"
      using ψ10_values a2 assms by simp
    finally show "((ψ10 1) ⨂ (ψ10 n)) $$ (i,j) = (ψ10 (Suc n)) $$ (i,j)" 
      using f2 divide_divide_eq_left power_Suc by simp
  next (* case i ≥ dim_row (ψ10 n) *)
    assume "i ≥ dim_row (ψ10 n)"
    then have "((ψ10 1) ⨂ (ψ10 n)) $$ (i,0) = ((ψ10 1) $$ (1, 0)) * ((ψ10 n) $$ ( i -dim_row (ψ10 n),0))"
      using index_tensor_mat_with_vec2_row_cond[of i "(ψ10 1)" "(ψ10 n)" ] a0 a1 f0
      by (metis dim_col_mat(1) dim_row_mat(1) index_tensor_mat_with_vec2_row_cond power_Suc power_one_right)  
    then have "((ψ10 1) ⨂ (ψ10 n)) $$ (i,0) = 1/sqrt(2) * 1/(sqrt 2)^n"
      using ψ10_values[of "i -dim_row (ψ10 n)" "n" "j"] a0 a1 by simp
    then show  "((ψ10 1) ⨂ (ψ10 n)) $$ (i,j) = (ψ10 (Suc n)) $$ (i,j)" 
      using f0 f1 divide_divide_eq_left power_Suc by simp
  qed
qed

(* |zero⟩ ⊗⇗n⇖ is an n-qubit state, for n ≥ 1. *)
lemma ψ10_tensor_is_state:
  assumes "n ≥ 1"
  shows "state n ( |zero⟩ ⊗⇗n⇖)"  
  using iter_tensor_of_state_is_state ket_zero_is_state assms by fastforce

(* H ⊗⇗n⇖ is an n-qubit gate, for n ≥ 1. *)
lemma iter_tensor_of_H_is_gate:
  assumes "n ≥ 1"
  shows "gate n (H ⊗⇗n⇖)" 
  using iter_tensor_of_gate_is_gate H_is_gate assms by fastforce

(* (H ⊗⇗n⇖) * (|zero⟩ ⊗⇗n⇖) = ψ10 n. Proved by induction from n = 1, using the mixed-product step
   iter_tensor_mult_distr and ψ10_tensor. *)
lemma iter_tensor_of_H_on_zero_tensor: 
  assumes "n ≥ 1"
  shows "(H ⊗⇗n⇖) * ( |zero⟩ ⊗⇗n⇖) = ψ10 n"
  using assms
proof(rule nat_induct_at_least)
  show "(H ⊗⇗1⇖) * ( |zero⟩ ⊗⇗1⇖) = ψ10 1"
    using H_on_ket_zero by simp
next
  fix n:: nat
  assume a0: "n ≥ 1" and IH: "(H ⊗⇗n⇖) * ( |zero⟩ ⊗⇗n⇖) = ψ10 n"
  then have "(H ⊗⇗(Suc n)⇖) * ( |zero⟩ ⊗⇗(Suc n)⇖) = (H * |zero⟩) ⨂ ((H ⊗⇗n⇖) * ( |zero⟩ ⊗⇗n⇖))" 
    using iter_tensor_mult_distr[of "n" "H" "|zero⟩"] a0 ket_vec_def H_def by(simp add: H_def) 
  also have  "... = (H * |zero⟩) ⨂ (ψ10 n)" using IH by simp 
  also have "... = (ψ10 1) ⨂ (ψ10 n)" using H_on_ket_zero by simp
  also have "... = (ψ10 (Suc n))" using ψ10_tensor a0 by simp
  finally show "(H ⊗⇗(Suc n)⇖) * ( |zero⟩ ⊗⇗(Suc n)⇖) = (ψ10 (Suc n))" by simp
qed

(* ψ10 n is an n-qubit state: the gate H ⊗⇗n⇖ applied to the state |zero⟩ ⊗⇗n⇖. *)
lemma ψ10_is_state:
  assumes "n ≥ 1"
  shows "state n (ψ10 n)"
  using iter_tensor_of_H_is_gate ψ10_tensor_is_state assms gate_on_state_is_state iter_tensor_of_H_on_zero_tensor assms by metis

(* ψ11: the 2 × 1 column (1/√2, -1/√2). *)
abbreviation ψ11:: "complex Matrix.mat" where
"ψ11 ≡ Matrix.mat 2 1 (λ(i,j). if i=0 then 1/sqrt(2) else -1/sqrt(2))"

(* H applied to |one⟩ is ψ11. Checked entry-wise for the two rows of the result. *)
lemma H_on_ket_one_is_ψ11: 
  shows "(H * |one⟩) = ψ11"
proof 
  fix i j:: nat
  assume "i < dim_row ψ11" and "j < dim_col ψ11"
  then have "i ∈ {0,1} ∧ j = 0" by (simp add: less_2_cases)
  then show "(H * |one⟩) $$ (i,j) = ψ11 $$ (i,j)"
    by (auto simp add: times_mat_def scalar_prod_def H_def ket_vec_def)
next
  show "dim_row (H * |one⟩) = dim_row ψ11" by (simp add: H_def)
next 
  show "dim_col (H * |one⟩) = dim_col ψ11" by (simp add: H_def ket_vec_def)
qed

(* ψ1 n: a 2^(n+1) × 1 column with entry 1/√2^(n+1) at even indices and -1/√2^(n+1) at odd
   indices. *)
abbreviation ψ1:: "nat ⇒ complex Matrix.mat" where
"ψ1 n ≡ Matrix.mat (2^(n+1)) 1 (λ(i,j). if even i then 1/(sqrt 2)^(n+1) else -1/(sqrt 2)^(n+1))"

(* In-bounds entries of ψ1 n at even indices equal 1/√2^(n+1). *)
lemma ψ1_values_even[simp]:
  fixes i j n
  assumes "i < dim_row (ψ1 n)" and "j < dim_col (ψ1 n)" and "even i"
  shows "(ψ1 n) $$ (i,j) = 1/(sqrt 2)^(n+1)" 
  using assms case_prod_conv by simp

(* In-bounds entries of ψ1 n at odd indices equal -1/√2^(n+1). *)
lemma ψ1_values_odd [simp]:
  fixes i j n
  assumes "i < dim_row (ψ1 n)" and "j < dim_col (ψ1 n)" and "odd i"
  shows "(ψ1 n) $$ (i,j) = -1/(sqrt 2)^(n+1)" 
  using assms case_prod_conv by simp

(* ψ10 n ⨂ ψ11 = ψ1 n. Entry i of the product is 1/√2^n times entry (i mod 2) of ψ11. *)
lemma "ψ10_tensor_ψ11_is_ψ1":
  assumes "n ≥ 1"
  shows "(ψ10 n) ⨂ ψ11 = ψ1 n" 
proof 
 show "dim_col ((ψ10 n) ⨂ ψ11) = dim_col (ψ1 n)" by simp
next
  show "dim_row ((ψ10 n) ⨂ ψ11) = dim_row (ψ1 n)" by simp
next
  fix i j:: nat
  assume a0: "i < dim_row (ψ1 n)" and a1: "j < dim_col (ψ1 n)"
  then have "i < 2^(n+1)" and "j = 0" by auto 
  then have f0: "((ψ10 n) ⨂ ψ11) $$ (i,j) = 1/(sqrt 2)^n * ψ11 $$ (i mod 2, j)" 
    using ψ10_values[of "i div 2" n "j div 1"] a0 a1 by simp
  show "((ψ10 n) ⨂ ψ11) $$ (i,j) = (ψ1 n) $$ (i,j)" 
    using f0 ψ1_values_even ψ1_values_odd a0 a1 by auto 
qed

(* ψ1 n is an (n+1)-qubit state: the tensor product of the n-qubit state ψ10 n and the one-qubit
   state ψ11 = H * |one⟩. *)
lemma ψ1_is_state:
  assumes "n ≥ 1"
  shows "state (n+1) (ψ1 n)" 
  using assms ψ10_tensor_ψ11_is_ψ1 ψ10_is_state H_on_ket_one_is_state H_on_ket_one_is_ψ11 tensor_state by metis

(* ψ2: a 2^(n+1) × 1 column. Entry i is (-1)^f(i div 2)/√2^(n+1) for even i and
   (-1)^(f(i div 2)+1)/√2^(n+1) for odd i, so the value of f appears as a sign of the amplitude. *)
abbreviation (in jozsa) ψ2:: "complex Matrix.mat" where
"ψ2 ≡ Matrix.mat (2^(n+1)) 1 (λ(i,j). if even i then (-1)^f(i div 2)/(sqrt 2)^(n+1) 
                                        else (-1)^(f(i div 2)+1)/(sqrt 2)^(n+1))"

(* In-bounds entries of ψ2 at even indices. *)
lemma (in jozsa) ψ2_values_even [simp]:
  fixes i j 
  assumes "i < dim_row ψ2 " and "j < dim_col ψ2" and "even i"
  shows "ψ2 $$ (i,j) = (-1)^f(i div 2)/(sqrt 2)^(n+1)" 
  using assms case_prod_conv by simp

(* In-bounds entries of ψ2 at odd indices. *)
lemma (in jozsa) ψ2_values_odd [simp]:
  fixes i j 
  assumes "i < dim_row ψ2" and "j < dim_col ψ2" and "odd i"
  shows "ψ2 $$ (i,j) = (-1)^(f(i div 2)+1)/(sqrt 2)^(n+1)" 
  using assms case_prod_conv by simp

(* The odd-index case of ψ2 stated for indices written explicitly as 2*k+1. *)
lemma (in jozsa) ψ2_values_odd_hidden [simp]:
  assumes "2*k+1 < dim_row ψ2" and "j < dim_col ψ2" 
  shows "ψ2 $$ (2*k+1,j) = ((-1)^(f((2*k+1) div 2)+1))/(sqrt 2)^(n+1)" 
  using assms by simp

(* Converts the two amplitude shapes arising in Uf * ψ1 n into the closed forms used in ψ2.
   Both directions rely on f taking values in {0,1} (f_values). *)
lemma (in jozsa) snd_rep_of_ψ2:
  assumes "i < dim_row ψ2"
  shows "((1-f(i div 2)) + -f(i div 2)) * 1/(sqrt 2)^(n+1) = (-1)^f(i div 2)/(sqrt 2)^(n+1)"
    and "(-(1-f(i div 2))+(f(i div 2)))* 1/(sqrt 2)^(n+1) = (-1)^(f(i div 2)+1)/(sqrt 2)^(n+1)"
proof- 
  have "i div 2 ∈ {i. i < 2 ^ n}" 
    using assms by auto
  then have "real (Suc 0 - f (i div 2)) - real (f (i div 2)) = (- 1) ^ f (i div 2)" 
    using assms f_values by auto
  thus "((1-f(i div 2)) + -f(i div 2)) * 1/(sqrt 2)^(n+1) = (-1)^f(i div 2)/(sqrt 2)^(n+1)" by auto
next
  have "i div 2 ∈ {i. i < 2^n}" 
    using assms by simp
  then have "(real (f (i div 2)) - real (Suc 0 - f (i div 2))) / (sqrt 2 ^ (n+1)) =
           - ((- 1) ^ f (i div 2) / (sqrt 2 ^ (n+1)))" 
   using assms f_values by fastforce
  then show "(-(1-f(i div 2))+(f(i div 2)))* 1/(sqrt 2)^(n+1) = (-1)^(f(i div 2)+1)/(sqrt 2)^(n+1)" by simp
qed

(* Uf * ψ1 n = ψ2. For each row only the two summands inside the row's 2×2 block contribute; the
   resulting expressions are rewritten with snd_rep_of_ψ2. *)
lemma (in jozsa) jozsa_transform_times_ψ1_is_ψ2:
  shows "Uf * (ψ1 n) = ψ2" 
proof 
  show "dim_row (Uf * (ψ1 n)) = dim_row ψ2" by simp
next
  show "dim_col (Uf * (ψ1 n)) = dim_col ψ2" by simp
next
  fix i j ::nat
  assume a0: "i < dim_row ψ2" and a1: "j < dim_col ψ2"
  then have f0:"i ∈ {0..2^(n+1)} ∧ j=0" by simp
  then have f1: "i < dim_row Uf ∧ j < dim_col Uf " using a0 by simp
  have f2: "i < dim_row (ψ1 n) ∧ j < dim_col (ψ1 n)" using a0 a1 by simp
  show "(Uf * (ψ1 n)) $$ (i,j) = ψ2 $$ (i,j)"
  proof (rule disjE)
    show "even i ∨ odd i" by auto
  next
    (* Even row: summands k = i and k = i+1. *)
    assume a2: "even i"
    then have "(Uf * (ψ1 n)) $$ (i,j) = (∑k ∈ {i,i+1}. Uf $$ (i,k) * (ψ1 n) $$ (k,j))"
      using f1 f2 Uf_mult_without_empty_summands_even[of i j "(ψ1 n)"] by simp 
    moreover have "Uf $$ (i,i) * (ψ1 n) $$ (i,j) = (1-f(i div 2))* 1/(sqrt 2)^(n+1)" 
      using f0 f1 a2 by simp
    moreover have "Uf $$ (i,i+1) * (ψ1 n) $$ (i+1,j) = (-f(i div 2))* 1/(sqrt 2)^(n+1)" 
      using f0 f1 a2 by auto
    ultimately have "(Uf * (ψ1 n)) $$ (i,j) = (1-f(i div 2))* 1/(sqrt 2)^(n+1) + (-f(i div 2))* 1/(sqrt 2)^(n+1)" by simp
    also have "... = ((1-f(i div 2))+-f(i div 2)) * 1/(sqrt 2)^(n+1)" 
      using add_divide_distrib 
      by (metis (no_types, opaque_lifting) mult.right_neutral of_int_add of_int_of_nat_eq)
    also have "... = ψ2 $$ (i,j)" 
      using a0 a1 a2 snd_rep_of_ψ2 by simp
    finally show "(Uf * (ψ1 n)) $$ (i,j) = ψ2 $$ (i,j)" by simp
  next 
    (* Odd row: summands k = i-1 and k = i. *)
    assume a2: "odd i"
    then have f6: "i≥1"  
    using linorder_not_less by auto
    have "(Uf * (ψ1 n)) $$ (i,j) = (∑k ∈ {i-1,i}. Uf $$ (i,k) * (ψ1 n) $$ (k,j))"
      using f1 f2 a2 Uf_mult_without_empty_summands_odd[of i j "(ψ1 n)"]  
      by (metis dim_row_mat(1) jozsa_transform_dim(2)) 
    moreover have "(∑k ∈ {i-1,i}. Uf $$ (i,k) * (ψ1 n) $$ (k,j)) 
                 = Uf $$ (i,i-1) * (ψ1 n) $$ (i-1,j) +  Uf $$ (i,i) * (ψ1 n) $$ (i,j)" 
      using a2 f6 by simp
    moreover have  "Uf $$ (i,i) * (ψ1 n) $$ (i,j) = (1-f(i div 2))* -1/(sqrt 2)^(n+1)" 
      using f1 f2 a2 by simp
    moreover have "Uf $$ (i,i-1) * (ψ1 n) $$ (i-1,j) = f(i div 2)* 1/(sqrt 2)^(n+1)" 
      using a0 a1 a2 by simp
    ultimately have "(Uf * (ψ1 n)) $$ (i,j) = (1-f(i div 2))* -1/(sqrt 2)^(n+1) +(f(i div 2))* 1/(sqrt 2)^(n+1)" 
      using of_real_add by simp
    also have "... = (-(1-f(i div 2)) + (f(i div 2))) * 1/(sqrt 2)^(n+1)" 
      by (metis (no_types, opaque_lifting) mult.right_neutral add_divide_distrib mult_minus1_right 
          of_int_add of_int_of_nat_eq)
    also have "... = (-1)^(f(i div 2)+1)/(sqrt 2)^(n+1)" 
      using a0 a1 a2 snd_rep_of_ψ2 by simp
    finally show "(Uf * (ψ1 n)) $$ (i,j) = ψ2 $$ (i,j)" 
      using a0 a1 a2 by simp
  qed
qed

(* psi2 is a state on n+1 qubits.
   It is obtained from the state psi1 by multiplying with Uf (jozsa_transform_times_psi1_is_psi2).
   gate_on_state_is_state carries the state property through Uf, which is a gate according to
   jozsa_transform_is_gate. *)
lemma (in jozsa) ψ2_is_state:
  shows "state (n+1) ψ2" 
  using jozsa_transform_times_ψ1_is_ψ2 jozsa_transform_is_gate ψ1_is_state dim gate_on_state_is_state by fastforce

text ‹@{text "H^⊗ n"} denotes the explicit matrix of the n-th tensor power of the Hadamard gate H.›

(* Explicit 2^n x 2^n matrix, written with the H^(x) n notation.
   Entry (i,j) is (-1)^(i . j) / (sqrt 2)^n, where i . j is the bitwise inner product of i and j
   over n bits; in the proofs below it unfolds to bip.
   iter_tensor_of_H_rep_is_correct shows that, for n >= 1, this matrix equals the n-fold tensor
   power of H. *)
abbreviation iter_tensor_of_H_rep:: "nat  complex Matrix.mat" (H^ _›) where
"iter_tensor_of_H_rep n  Matrix.mat (2^n) (2^n) (λ(i,j).(-1)^(i ⋅⇘nj)/(sqrt 2)^n)"

(* Entry lookup for H^(x) n at an in-bounds index pair.
   It is declared simp so that later proofs rewrite such entries automatically. *)
lemma tensor_of_H_values [simp]:
  fixes n i j:: nat
  assumes "i < dim_row (H^ n)" and "j < dim_col (H^ n)"
  shows "(H^ n) $$ (i,j) = (-1)^(i ⋅⇘nj)/(sqrt 2)^n"
  using assms by simp

(* For n >= 1, H^(x) n has more than one row, since its row count is 2^n. *)
lemma dim_row_of_iter_tensor_of_H [simp]:
  assumes "n  1"
  shows "1 < dim_row (H^ n)" 
  using assms by(metis One_nat_def Suc_1 dim_row_mat(1) le_trans lessI linorder_not_less one_less_power)

(* Recursion for the entries of H^(x) (n+1) when i < 2^n or j < 2^n.
   In that case the bitwise inner product over n+1 bits equals the one over n bits of the residues
   mod 2^n (bitwise_inner_prod_fst_el_0). Hence the entry is 1/sqrt 2 times the entry of H^(x) n
   at (i mod 2^n, j mod 2^n). *)
lemma iter_tensor_of_H_fst_pos:
  fixes n i j:: nat
  assumes "i < 2^n  j < 2^n" and "i < 2^(n+1)  j < 2^(n+1)"
  shows "(H^ (Suc n)) $$ (i,j) = 1/sqrt(2) * ((H^ n) $$ (i mod 2^n, j mod 2^n))"
proof-
  have "(H^ (Suc n)) $$ (i,j) = (-1)^(bip i (Suc n) j)/(sqrt 2)^(Suc n)"
    using assms by simp
  moreover have "bip i (Suc n) j = bip (i mod 2^n) n (j mod 2^n)" 
    using bitwise_inner_prod_fst_el_0 assms(1) by simp 
  ultimately show ?thesis 
    using bitwise_inner_prod_def by simp
qed

(* Recursion for the entries of H^(x) (n+1) when both i >= 2^n and j >= 2^n.
   Here the bitwise inner product gains an extra 1 (bitwise_inner_prod_fst_el_is_1). Hence the
   entry is -1/sqrt 2 times the entry of H^(x) n at (i mod 2^n, j mod 2^n). *)
lemma iter_tensor_of_H_fst_neg:
  fixes n i j:: nat
  assumes "i  2^n  j  2^n" and "i < 2^(n+1)  j < 2^(n+1)"
  shows "(H^ (Suc n)) $$ (i,j) = -1/sqrt(2) * (H^ n) $$ (i mod 2^n, j mod 2^n)"
proof-
  have "(H^ (Suc n)) $$ (i,j) = (-1)^(bip i (n+1) j)/(sqrt 2)^(n+1)" 
    using assms(2) by simp
  moreover have "bip i (n+1) j = 1 + bip (i mod 2^n) n (j mod 2^n)" 
    using bitwise_inner_prod_fst_el_is_1 assms by simp
  ultimately show ?thesis by simp
qed 

(* H tensored with H^(x) n equals H^(x) (n+1). The proof is entrywise.
   By fact f1, an entry of the Kronecker product is
     H at (i div 2^n, j div 2^n)  times  H^(x) n at (i mod 2^n, j mod 2^n).
   The proof then splits on whether i < 2^n or j < 2^n, and compares with
   iter_tensor_of_H_fst_pos and iter_tensor_of_H_fst_neg. *)
lemma H_tensor_iter_tensor_of_H:   
  fixes n:: nat
  shows  "(H  H^ n) = H^ (Suc n)" 
proof
  fix i j:: nat
  assume a0: "i < dim_row (H^ (Suc n))" and a1: "j < dim_col (H^ (Suc n))"
  then have f0: "i  {0..<2^(n+1)}  j  {0..<2^(n+1)}" by simp
  then have f1: "(H  H^ n) $$ (i,j) = H $$ (i div (dim_row (H^ n)),j div (dim_col (H^ n))) 
                                       * (H^ n) $$ (i mod (dim_row (H^ n)),j mod (dim_col (H^ n)))"
    by (simp add: H_without_scalar_prod)
  show "(H  H^ n) $$ (i,j) = (H^ (Suc n)) $$ (i,j)"
  proof (rule disjE) 
    show "(i < 2^n  j < 2^n)  ¬(i < 2^n  j < 2^n)" by auto
  next
    (* Case i < 2^n or j < 2^n: the H factor is the entry 1/sqrt 2. *)
    assume a2: "(i < 2^n  j < 2^n)"
    then have "(H^ (Suc n)) $$ (i,j) = 1/sqrt(2) * ((H^ n) $$ (i mod 2^n, j mod 2^n))" 
      using a0 a1 f0 iter_tensor_of_H_fst_pos by (metis (mono_tags, lifting) atLeastLessThan_iff)
    moreover have "H $$ (i div (dim_row (H^ n)),j div (dim_col (H^ n))) = 1/sqrt 2"
      using a0 a1 f0 H_without_scalar_prod H_values a2
      by (metis (no_types, lifting) dim_col_mat(1) dim_row_mat(1) div_less le_eq_less_or_eq 
          le_numeral_extra(2) less_power_add_imp_div_less plus_1_eq_Suc power_one_right) 
    ultimately show "(H  H^ n) $$ (i,j) = (H^ (Suc n)) $$ (i,j)" 
      using f1 by simp
  next 
    (* Case i >= 2^n and j >= 2^n: the H factor is the bottom-right entry -1/sqrt 2. *)
    assume a2: "¬(i < 2^n  j < 2^n)"
    then have "i  2^n  j  2^n" by simp
    then have f2:"(H^ (Suc n)) $$ (i,j) = -1/sqrt(2) * ((H^ n) $$ (i mod 2^n, j mod 2^n))" 
      using a0 a1 f0 iter_tensor_of_H_fst_neg by simp
    have "i div (dim_row (H^ n)) =1" and "j div (dim_row (H^ n)) = 1"  
      using a2 a0 a1 by auto
    then have "H $$ (i div (dim_row (H^ n)),j div (dim_col (H^ n))) = -1/sqrt 2"
      using a0 a1 f0 H_values_right_bottom[of "i div (dim_row (H^ n))" "j div (dim_col (H^ n))"] a2 
      by fastforce
    then show "(H  H^ n) $$ (i,j) = (H^ (Suc n)) $$ (i,j)" 
      using f1 f2 by simp
  qed
next
  show "dim_row (H  H^ n) = dim_row (H^ (Suc n))" 
    by (simp add: H_without_scalar_prod) 
next
  show "dim_col (H  H^ n) = dim_col (H^ (Suc n))" 
    by (simp add: H_without_scalar_prod) 
qed

text ‹
We prove that @{term "H^⊗ n"} is indeed the matrix representation of @{term "H ⊗⇗n⇖"}, the n-fold
iterated tensor product of the Hadamard gate H.
›

(* Base case: H^(x) 1 is the Hadamard matrix H.
   The dimensions are compared first. Then all four entries of H^(x) 1 are computed explicitly
   and matched against the entries of H. *)
lemma one_tensor_of_H_is_H:
  shows "(H^ 1) = H"
proof(rule eq_matI)
  show "dim_row (H^ 1) = dim_row H"
    by (simp add: H_without_scalar_prod)
  show "dim_col (H^ 1) = dim_col H"
    by (simp add: H_without_scalar_prod)
next
  fix i j:: nat
  assume a0:"i < dim_row H" and a1:"j < dim_col H"
  then show "(H^ 1) $$ (i,j) = H $$ (i,j)"
  proof-
    have "(H^ 1) $$ (0, 0) = 1/sqrt(2)" 
       using bitwise_inner_prod_def bin_rep_def by simp 
    moreover have "(H^ 1) $$ (0,1) = 1/sqrt(2)" 
       using bitwise_inner_prod_def bin_rep_def by simp 
    moreover have "(H^ 1) $$ (1,0) = 1/sqrt(2)" 
       using bitwise_inner_prod_def bin_rep_def by simp 
    moreover have "(H^ 1) $$ (1,1) = -1/sqrt(2)" 
       using bitwise_inner_prod_def bin_rep_def by simp 
     ultimately show "(H^ 1) $$ (i,j) = H $$ (i,j)" 
       using a0 a1 H_values H_values_right_bottom
       by (metis (no_types, lifting) H_without_scalar_prod One_nat_def dim_col_mat(1) dim_row_mat(1) 
divide_minus_left less_2_cases)
  qed
qed

(* For n >= 1, the iterated tensor product of H with itself n times equals the explicit matrix
   H^(x) n. The proof is by induction starting at n = 1:
   - the base case uses one_tensor_of_H_is_H;
   - the step peels off one factor with iter_tensor_suc and applies H_tensor_iter_tensor_of_H. *)
lemma iter_tensor_of_H_rep_is_correct:
  fixes n:: nat
  assumes "n  1"
  shows "(H ⊗⇗n) = H^ n"
  using assms
proof(rule nat_induct_at_least)
  show "(H ⊗⇗1) = H^ 1" 
    using one_tensor_is_id one_tensor_of_H_is_H by simp
next
  fix n:: nat
  assume a0:"n  1" and IH:"(H ⊗⇗n) = H^ n"
  then have "(H ⊗⇗(Suc n)) = H  (H ⊗⇗n)" 
    using iter_tensor_suc Nat.Suc_eq_plus1 by metis
  also have "... = H  (H^ n)" 
    using IH by simp
  also have "... = H^ (Suc n)" 
    using a0 H_tensor_iter_tensor_of_H by simp
  finally show "(H ⊗⇗(Suc n)) = H^ (Suc n)" 
    by simp
qed

text ‹@{text "HId^⊗ n"} is the tensor product of the n-th tensor power of H with Id 1.›

(* Explicit 2^(n+1) x 2^(n+1) matrix, written with the HId^(x) n notation.
   Entry (i,j) is 0 unless i and j have the same parity. When they do, it equals the H^(x) n
   entry at (i div 2, j div 2).
   iter_tensor_of_H_tensor_Id_is_HId shows that this matrix is (H^(x) n) tensored with Id 1. *)
abbreviation tensor_of_H_tensor_Id:: "nat  complex Matrix.mat" (HId^ _›) where
"tensor_of_H_tensor_Id n  Matrix.mat (2^(n+1)) (2^(n+1)) (λ(i,j).
  if (i mod 2 = j mod 2) then (-1)^((i div 2) ⋅⇘n(j div 2))/(sqrt 2)^n else 0)"

(* Relates "i and j are both even or both odd" to "i mod 2 = j mod 2".
   It is used to switch between the parity formulations of the HId^(x) n entries. *)
lemma mod_2_is_both_even_or_odd:
  "((even i  even j)  (odd i  odd j))  (i mod 2 = j mod 2)" 
  by (metis even_iff_mod_2_eq_zero odd_iff_mod_2_eq_one)
  
(* Simp rules for the in-bounds entries of HId^(x) n.
   - Same parity (stated as both even, both odd, or equal mod 2): the entry is
     (-1)^((i div 2) . (j div 2)) / (sqrt 2)^n.
   - Different parity: the entry is 0. *)
lemma HId_values [simp]:
  assumes "n  1" and "i < dim_row (HId^ n)" and "j < dim_col (HId^ n)"
  shows "even i  even j  (HId^ n) $$ (i,j) = (-1)^((i div 2) ⋅⇘n(j div 2))/(sqrt 2)^n"
and "odd i  odd j  (HId^ n) $$ (i,j) = (-1)^((i div 2) ⋅⇘n(j div 2))/(sqrt 2)^n"
and "(i mod 2 = j mod 2)  (HId^ n) $$ (i,j) = (-1)^((i div 2) ⋅⇘n(j div 2))/(sqrt 2)^n"
and "¬(i mod 2 = j mod 2)  (HId^ n) $$ (i,j) = 0"
  using assms mod_2_is_both_even_or_odd by auto

(* (H^(x) n) tensored with Id 1 equals HId^(x) n. The proof is entrywise.
   By fact f1, an entry of the Kronecker product is
     H^(x) n at (i div 2, j div 2)  times  Id 1 at (i mod 2, j mod 2).
   The Id 1 factor is 1 when i and j have the same parity and 0 otherwise, matching the
   definition of HId^(x) n. *)
lemma iter_tensor_of_H_tensor_Id_is_HId:
  shows "(H^ n)  Id 1 = HId^ n"
proof
  show "dim_row ((H^ n)  Id 1) = dim_row (HId^ n)" 
    by (simp add: Quantum.Id_def)
  show "dim_col ((H^ n)  Id 1) = dim_col (HId^ n)" 
    by (simp add: Quantum.Id_def)
next
  fix i j:: nat
  assume a0: "i < dim_row (HId^ n)" and a1: "j < dim_col (HId^ n)"
  then have f0: "i < (2^(n+1))  j < (2^(n+1))" by simp
  then have "i < dim_row (H^ n) * dim_row (Id 1)  j < dim_col (H^ n) * dim_col (Id 1)"   
    using Id_def by simp
  moreover have "dim_col (H^ n)  0  dim_col (Id 1)  0"  
    using Id_def by simp
  ultimately have f1: "((H^ n)  (Id 1)) $$ (i,j) 
    = (H^ n) $$ (i div (dim_row (Id 1)),j div (dim_col (Id 1))) * 
      (Id 1) $$ (i mod (dim_row (Id 1)),j mod (dim_col (Id 1)))"
    by (simp add: Quantum.Id_def)
  show "((H^ n)Id 1) $$ (i,j) = (HId^ n) $$ (i,j)" 
  proof (rule disjE)
    show "(i mod 2 = j mod 2)  ¬ (i mod 2 = j mod 2)" by simp
  next
    (* Same parity: the Id 1 factor is 1. *)
    assume a2:"(i mod 2 = j mod 2)"
    then have "(Id 1) $$ (i mod (dim_row (Id 1)),j mod (dim_col (Id 1))) = 1" 
      by (simp add: Quantum.Id_def)
    moreover have "(H^ n) $$ (i div (dim_row (Id 1)), j div (dim_col (Id 1))) 
                    = (-1)^((i div (dim_row (Id 1))) ⋅⇘n(j div (dim_col (Id 1))))/(sqrt 2)^n" 
      using tensor_of_H_values Id_def f0 less_mult_imp_div_less by simp
    ultimately show "((H^ n)  Id 1) $$ (i,j) = (HId^ n) $$ (i,j)" 
      using a2 f0 f1 Id_def by simp
  next
    (* Different parity: the Id 1 factor is 0. *)
    assume a2: "¬(i mod 2 = j mod 2)" 
    then have "(Id 1) $$ (i mod (dim_row (Id 1)),j mod (dim_col (Id 1))) = 0" 
      by (simp add: Quantum.Id_def)
    then show "((H^ n)  Id 1) $$ (i,j) = (HId^ n) $$ (i,j)" 
      using a2 f0 f1 by simp
  qed
qed

(* For n >= 1, HId^(x) n is a gate on n+1 qubits. It is the tensor product (tensor_gate) of:
   - H^(x) n, a gate on n qubits, via iter_tensor_of_H_rep_is_correct and
     iter_tensor_of_gate_is_gate;
   - the one-qubit gate Id 1. *)
lemma HId_is_gate:
  assumes "n  1"
  shows "gate (n+1) (HId^ n)" 
proof- 
  have "(HId^ n) = (H^ n)  Id 1" 
    using iter_tensor_of_H_tensor_Id_is_HId by simp
  moreover have "gate 1 (Id 1)" 
    using id_is_gate by simp
  moreover have "gate n (H^ n)"
    using H_is_gate iter_tensor_of_gate_is_gate[of 1 H n] assms by(simp add: iter_tensor_of_H_rep_is_correct)
  ultimately show "gate (n+1) (HId^ n)" 
    using tensor_gate by presburger
qed

text ‹State @{term "ψ3"} is obtained by multiplying @{term "HId^⊗ n"} with @{term "ψ2"}.›

(* Third state psi3: a 2^(n+1) x 1 column vector.
   - Even row i: the sum over k < 2^n of
       (-1)^(f k + (i div 2) . k) / ((sqrt 2)^n (sqrt 2)^(n+1)).
   - Odd row i: the same sum with an extra +1 in the exponent, i.e. with the sign flipped.
   iter_tensor_of_H_times_psi2_is_psi3 shows that this is the product of H^(x) n tensor Id 1
   with psi2. *)
abbreviation (in jozsa) ψ3:: "complex Matrix.mat" where
"ψ3  Matrix.mat (2^(n+1)) 1 (λ(i,j). 
if even i 
  then (k<2^n. (-1)^(f(k) + ((i div 2) ⋅⇘nk))/((sqrt 2)^n * (sqrt 2)^(n+1))) 
    else  (k<2^n. (-1)^(f(k)+ 1 + ((i div 2) ⋅⇘nk)) /((sqrt 2)^n * (sqrt 2)^(n+1))))"

(* Value of psi3 at an odd row index. Only the odd case is stated here. *)
lemma (in jozsa) ψ3_values:
  assumes "i < dim_row ψ3"
  shows "odd i  ψ3 $$ (i,0) = (k<2^n. (-1)^(f(k) + 1 + ((i div 2) ⋅⇘nk))/((sqrt 2)^n * (sqrt 2)^(n+1)))"
  using assms by simp

(* psi3 has more than one row (it has 2^(n+1) rows), so row index 1 is in range.
   Declared simp. *)
lemma (in jozsa) ψ3_dim [simp]:
  shows "1 < dim_row ψ3"
  using dim_row_mat(1) nat_neq_iff by fastforce

(* Let n >= 1 and let f : nat -> complex vanish at every odd index below 2^(n+1).
   Then the sum of f over [0, 2^(n+1)) equals the sum of f (2k) over k in [0, 2^n).
   Proof by induction on n starting at 1:
   - the base case expands the four summands;
   - the step splits the range at 2^(n+1), shifts the upper half down, and applies the
     induction hypothesis to both halves. *)
lemma sum_every_odd_summand_is_zero:
  fixes n:: nat 
  assumes "n  1"
  shows "f::(nat  complex).(i. i<2^(n+1)  odd i  f i = 0)  
            (k{0..<2^(n+1)}. f k) = (k{0..<2^n}. f (2*k))"
  using assms
proof(rule nat_induct_at_least)
  (* Base case n = 1: the range is [0, 4), and f 1 = f 3 = 0. *)
  show "f::(nat  complex).(i. i<2^(1+1)  odd i  f i = 0) 
            (k{0..<2^(1+1)}. f k) = (k  {0..<2^1}. f (2*k))"
  proof(rule allI,rule impI)
    fix f:: "(nat  complex)"
    assume asm: "(i. i<2^(1+1)  odd i  f i = 0)" 
    moreover have "(k{0..<4}. f k) = f 0 + f 1 + f 2 + f 3" 
      by (simp add: add.commute add.left_commute)
    moreover have "f 1 = 0" 
      using asm by simp 
    moreover have "f 3 = 0" 
      using asm by simp 
    moreover have "(k{0..<2^1}. f (2*k)) = f 0 + f 2" 
      using add.commute add.left_commute by simp
    ultimately show "(k{0..<2^(1+1)}. f k) = (k{0..<2^1}. f (2*k))" 
      by simp
  qed
next
  (* Inductive step: split [0, 2^(n+2)) into two halves of length 2^(n+1). *)
  fix n:: nat
  assume "n  1"
  and IH: "f::(nat complex).(i. i<2^(n+1)  odd i  f i = 0) 
(k{0..<2^(n+1)}. f k) = (k{0..<2^n}. f (2*k))" 
  show "f::(nat complex).(i. i<2^(Suc n +1)  odd i  f i = 0) 
(k{0..<2^(Suc n +1)}. f k) = (k{0..< 2^(Suc n)}. f (2*k))" 
  proof (rule allI,rule impI)
    fix f::"nat  complex"
    assume asm: "(i. i<2^(Suc n +1)  odd i  f i = 0)"
    have f0: "(k{0..<2^(n+1)}. f k) = (k{0..<2^n}. f (2*k))" 
      using asm IH by simp
    have f1: "(k{0..<2^(n+1)}. (λx. f (x+2^(n+1))) k) = (k{0..< 2^n}. (λx. f (x+2^(n+1))) (2*k))" 
      using asm IH by simp
    have "(k{0..<2^(n+2)}. f k) = (k{0..<2^(n+1)}. f k) + (k{2^(n+1)..<2^(n+2)}. f k)"
      by (simp add: sum.atLeastLessThan_concat)
    also have "... = (k{0..<2^n}. f (2*k)) + (k{2^(n+1)..<2^(n+2)}. f k)"  
      using f0 by simp
    also have "... = (k{0..<2^n}. f (2*k)) + (k{0..<2^(n+1)}. f (k+2^(n+1)))"  
      using sum.shift_bounds_nat_ivl[of "f" "0" "2^(n+1)" "2^(n+1)"] by simp
    also have "... = (k{0..<2^n}. f (2*k)) + (k{0..< 2^n}. (λx. f (x+2^(n+1))) (2*k))"
      using f1 by simp
    also have "... = (k{0..<2^n}. f (2*k)) + (k{2^n..< 2^(n+1)}. f (2 *k))"
      using sum.shift_bounds_nat_ivl[of "λx. (f::natcomplex) (2*(x-2^n)+2^(n+1))" "0" "2^n" "2^n"] 
      by (simp add: mult_2)
    also have "... = (k  {0..<2^(n+1)}. f (2*k))" 
      by (metis Suc_eq_plus1 lessI less_imp_le_nat one_le_numeral power_increasing sum.atLeastLessThan_concat zero_le)
    finally show "(k{0..<2^((Suc n)+1)}. f k) = (k{0..< 2^(Suc n)}. f (2*k))"
      by (metis Suc_eq_plus1 add_2_eq_Suc')
  qed
qed

(* Let n >= 1 and let f : nat -> complex vanish at every even index below 2^(n+1).
   Then the sum of f over [0, 2^(n+1)) equals the sum of f (2k+1) over k in [0, 2^n).
   The proof mirrors sum_every_odd_summand_is_zero: induction on n starting at 1, splitting the
   range at 2^(n+1) and shifting the upper half. *)
lemma sum_every_even_summand_is_zero:
  fixes n:: nat 
  assumes "n  1"
  shows "f::(nat  complex).(i. i<2^(n+1)  even i  f i = 0)  
            (k{0..<2^(n+1)}. f k) = (k{0..< 2^n}. f (2*k+1))"
  using assms
proof(rule nat_induct_at_least)
  (* Base case n = 1: the range is [0, 4), and f 0 = f 2 = 0. *)
  show "f::(nat  complex).(i. i<2^(1+1)  even i  f i = 0)  
            (k{0..<2^(1+1)}. f k) = (k{0..< 2^1}. f (2*k+1))"
  proof(rule allI,rule impI)
    fix f:: "nat complex"
    assume asm: "(i. i<2^(1+1)  even i  f i = 0)" 
    moreover have "(k{0..<4}. f k) = f 0 + f 1 + f 2 + f 3" 
      by (simp add: add.commute add.left_commute)
    moreover have "f 0 = 0" using asm by simp 
    moreover have "f 2 = 0" using asm by simp 
    moreover have "(k  {0..< 2^1}. f (2*k+1)) = f 1 + f 3" 
      using add.commute add.left_commute by simp
    ultimately show "(k{0..<2^(1+1)}. f k) = (k{0..< 2^1}. f (2*k+1))" by simp
  qed
next
  (* Inductive step: split [0, 2^(n+2)) into two halves of length 2^(n+1). *)
  fix n:: nat
  assume "n  1"
  and IH: "f::(nat complex).(i. i<2^(n+1)  even i  f i = 0) 
(k{0..<2^(n+1)}. f k) = (k{0..< 2^n}. f (2*k+1))" 
  show "f::(nat complex).(i. i<2^((Suc n)+1)  even i  f i = 0) 
(k{0..<2^((Suc n)+1)}. f k) = (k{0..< 2^(Suc n)}. f (2*k+1))" 
  proof (rule allI,rule impI)
    fix f::"nat complex"
    assume asm: "(i. i<2^((Suc n)+1)  even i  f i = 0)"
    have f0: "(k {0..<2^(n+1)}. f k) = (k  {0..< 2^n}. f (2*k+1))" 
      using asm IH by simp
    have f1: "(k{0..<2^(n+1)}. (λx. f (x+2^(n+1))) k) 
              = (k{0..< 2^n}. (λx. f (x+2^(n+1))) (2*k+1))" 
      using asm IH by simp
    have "(k{0..<2^(n+2)}. f k) 
               = (k{0..<2^(n+1)}. f k) + (k{2^(n+1)..<2^(n+2)}. f k)"
      by (simp add: sum.atLeastLessThan_concat)
    also have "... = (k{0..< 2^n}. f (2*k+1)) + (k{2^(n+1)..<2^(n+2)}. f k)"  
      using f0 by simp
    also have "... = (k{0..< 2^n}. f (2*k+1)) + (k{0..<2^(n+1)}. f (k+(2^(n+1))))"  
      using sum.shift_bounds_nat_ivl[of "f" "0" "2^(n+1)" "2^(n+1)"] by simp
    also have "... = (k{0..< 2^n}. f (2*k+1)) + (k{0..< 2^n}. (λx. f (x+2^(n+1))) (2*k+1))"
      using f1 by simp
    also have "... = (k{0..< 2^n}. f (2*k+1)) + (k{2^n..< 2^(n+1)}. f (2 *k+1))"
      using sum.shift_bounds_nat_ivl[of "λx. (f::natcomplex) (2*(x-2^n)+1+2^(n+1))" "0" "2^n" "2^n"] 
      by (simp add: mult_2)
    also have "... = (k{0..< 2^(n+1)}. f (2*k+1))" 
      by (metis Suc_eq_plus1 lessI less_imp_le_nat one_le_numeral power_increasing sum.atLeastLessThan_concat zero_le)
    finally show "(k{0..<2^((Suc n)+1)}. f k) = (k{0..< 2^(Suc n)}. f (2*k+1))"
      by (metis Suc_eq_plus1 add_2_eq_Suc')
  qed
qed

(* (H^(x) n tensor Id 1) times psi2 equals psi3.
   The proof computes with HId^(x) n, which equals H^(x) n tensor Id 1 by
   iter_tensor_of_H_tensor_Id_is_HId. For row i:
   - summands whose column index k has a parity different from i vanish;
   - sum_every_odd_summand_is_zero / sum_every_even_summand_is_zero therefore reduce the sum to
     the 2^n indices of matching parity;
   - the remaining entries of HId^(x) n and psi2 are then multiplied out and combined into the
     exponent form used by psi3. *)
lemma (in jozsa) iter_tensor_of_H_times_ψ2_is_ψ3:
  shows "((H^ n)  Id 1) * ψ2 = ψ3"
proof
  fix i j
  assume a0:"i < dim_row ψ3" and a1:"j < dim_col ψ3" 
  then have f0: "i < (2^(n+1))  j = 0" by simp
  have f1: "((HId^ n)* ψ2) $$ (i,j) = (k<(2^(n+1)). ((HId^ n) $$ (i,k)) * (ψ2 $$ (k,j)))" 
    using a1 f0 by (simp add: atLeast0LessThan)
  show "(((H^ n)  Id 1) * ψ2) $$ (i,j) = ψ3 $$ (i,j)"
  proof(rule disjE)
    show "even i  odd i" by simp
  next
    (* Even row i: only even column indices k = 2k' contribute. *)
    assume a2: "even i"
    have "(¬(i mod 2 = k mod 2)  k<dim_col (HId^ n))  ((HId^ n) $$ (i,k)) * (ψ2 $$ (k,j)) = 0" for k 
      using f0 by simp
    then have "k<(2^(n+1))  odd k  ((HId^ n) $$ (i,k)) * (ψ2 $$ (k,j)) = 0" for k 
      using a2 mod_2_is_both_even_or_odd f0 by (metis (no_types, lifting) dim_col_mat(1))
    then have "(k{(0::nat)..<(2^(n+1))}. ((HId^ n) $$ (i,k)) * (ψ2 $$ (k,j)))
             = (k{(0::nat)..< (2^n)}. ((HId^ n) $$ (i,2*k)) * (ψ2 $$ (2*k,j)))" 
      using sum_every_odd_summand_is_zero dim by simp
    moreover have "(k<2^n. ((HId^ n) $$ (i,2*k)) * (ψ2 $$ (2*k,j))) 
                 = (k<2^n.(-1)^((i div 2) ⋅⇘nk)/(sqrt(2)^n) *((-1)^f(k))/(sqrt(2)^(n+1)))" 
    proof-
        have "(even k  k<dim_row ψ2)  (ψ2 $$ (k,j)) = ((-1)^f(k div 2))/(sqrt(2)^(n+1))" for k 
          using a0 a1 by simp
      then have "(k<2^n. ((HId^ n) $$ (i,2*k)) * (ψ2 $$ (2*k,j))) 
               = (k<2^n. ((HId^ n) $$ (i,2*k)) *((-1)^f((2*k) div 2))/(sqrt(2)^(n+1)))" 
        by simp
      moreover have "(even k  k<dim_col (HId^ n))
                  ((HId^ n) $$ (i,k)) = (-1)^ ((i div 2) ⋅⇘n(k div 2))/(sqrt(2)^n)" for k
        using a2 a0 a1 by simp
      ultimately have "(k<2^n. ((HId^ n) $$ (i,2*k)) * (ψ2 $$ (2*k,j))) 
                     = (k<2^n. (-1)^((i div 2) ⋅⇘n((2*k) div 2))/(sqrt(2)^n) * 
                                   ((-1)^f((2*k) div 2))/(sqrt(2)^(n+1)))" 
      by simp
      then show "(k<2^n. ((HId^ n) $$ (i,2*k)) * (ψ2 $$ (2*k,j))) 
               = (k<2^n. (-1)^((i div 2) ⋅⇘nk)/(sqrt(2)^n) *((-1)^f(k))/(sqrt(2)^(n+1)))" 
        by simp
    qed
    ultimately have "((HId^ n)* ψ2) $$ (i,j) = (k<2^n. (-1)^((i div 2) ⋅⇘nk)/(sqrt(2)^n) 
                                                        * ((-1)^f(k))/(sqrt(2)^(n+1)))" 
      using f1 by (metis atLeast0LessThan) 
    also have "... =  (k<2^n. (-1)^(f(k)+((i div 2) ⋅⇘nk))/((sqrt(2)^n)*(sqrt(2)^(n+1))))" 
      by (simp add: power_add mult.commute)
    finally have "((HId^ n)* ψ2) $$ (i,j) = (k<2^n. (-1)^(f(k)+((i div 2) ⋅⇘nk))/((sqrt(2)^n)*(sqrt(2)^(n+1))))" 
       by simp
    moreover have "ψ3 $$ (i,j) = (k<2^n. (-1)^(f(k) + ((i div 2) ⋅⇘nk))/(sqrt(2)^n * sqrt(2)^(n+1)))" 
      using a0 a1 a2 by simp
    ultimately show "(((H^ n)  Id 1)* ψ2) $$ (i,j) = ψ3 $$ (i,j)" 
      using iter_tensor_of_H_tensor_Id_is_HId dim by simp
  next
    (* Odd row i: only odd column indices k = 2k'+1 contribute. *)
    assume a2: "odd i"
    have "(¬(i mod 2 = k mod 2)  k<dim_col (HId^ n))  ((HId^ n) $$ (i,k)) * (ψ2 $$ (k,j)) = 0" for k 
      using f0 by simp
    then have "k<(2^(n+1))  even k  ((HId^ n) $$ (i,k)) * (ψ2 $$ (k,j)) = 0" for k 
      using a2 mod_2_is_both_even_or_odd f0 by (metis (no_types, lifting) dim_col_mat(1))
    then have "(k{0..<2^(n+1)}. ((HId^ n) $$ (i,k)) * (ψ2 $$ (k,j)))
             = (k{0..<2^n}. ((HId^ n) $$ (i,2*k+1)) * (ψ2 $$ (2*k+1,j)))" 
      using sum_every_even_summand_is_zero dim by simp
    moreover have "(k<2^n. ((HId^ n) $$ (i,2*k+1)) * (ψ2 $$ (2*k+1,j))) 
                 = ( k<2^n. (-1)^((i div 2) ⋅⇘nk)/(sqrt(2)^n) * ((-1)^(f(k)+1))/(sqrt(2)^(n+1)))" 
    proof-
      have "(odd k  k<dim_row ψ2)  (ψ2 $$ (k,j)) = ((-1)^(f(k div 2)+1))/(sqrt(2)^(n+1))" for k 
        using a0 a1 a2 by simp
      then have f2:"(k<2^n. ((HId^ n) $$ (i,2*k+1)) * (ψ2 $$ (2*k+1,j))) 
                  = (k<2^n. ((HId^ n) $$ (i,2*k+1)) * ((-1)^(f((2*k+1) div 2)+1))/(sqrt(2)^(n+1)))" 
        by simp
      have "i < dim_row (HId^ n)" 
        using f0 a2 mod_2_is_both_even_or_odd by simp
      then have "((i mod 2 = k mod 2)  k<dim_col (HId^ n))
                  ((HId^ n) $$ (i,k)) = (-1)^((i div 2) ⋅⇘n(k div 2))/(sqrt(2)^n) " for k
        using a2 a0 a1 f0 dim HId_values by simp
      moreover have "odd k  (i mod 2 = k mod 2)" for k 
        using a2 mod_2_is_both_even_or_odd by auto
      ultimately have "(odd k  k<dim_col (HId^ n))
                  ((HId^ n) $$ (i,k)) = (-1)^((i div 2) ⋅⇘n(k div 2))/(sqrt(2)^n)" for k
        by simp
      then have "k<2^n  ((HId^ n) $$ (i,2*k+1)) = (-1)^((i div 2) ⋅⇘n((2*k+1) div 2))/(sqrt(2)^n) " for k
        by simp
      then have "(k<2^n. ((HId^ n) $$ (i,2*k+1)) * (ψ2 $$ (2*k+1,j))) 
               = (k<2^n. (-1)^((i div 2) ⋅⇘n((2*k+1) div 2))/(sqrt(2)^n) 
                             * ((-1)^(f((2*k+1) div 2)+1))/(sqrt(2)^(n+1)))" 
        using f2 by simp
      then show "(k<2^n. ((HId^ n) $$ (i,2*k+1)) * (ψ2 $$ (2*k+1,j))) 
               = (k<2^n. (-1)^((i div 2) ⋅⇘nk)/(sqrt(2)^n) *((-1)^(f(k)+1))/(sqrt(2)^(n+1)))" 
        by simp
    qed
    ultimately have "((HId^ n)* ψ2) $$ (i,j) = (k<2^n. (-1)^((i div 2) ⋅⇘nk)/(sqrt(2)^n) 
                * ((-1)^(f(k)+1))/(sqrt(2)^(n+1)))" 
      using f1 by (metis atLeast0LessThan) 
    also have "... = (k<2^n. (-1)^(f(k)+1+((i div 2) ⋅⇘nk))/((sqrt(2)^n)*(sqrt(2)^(n+1))))"
      by (simp add: mult.commute power_add)
    finally have "((HId^ n)* ψ2) $$ (i,j) 
                = (k< 2^n. (-1)^(f(k)+1+((i div 2) ⋅⇘nk))/((sqrt(2)^n)*(sqrt(2)^(n+1))))" 
      by simp
    then show "(((H^ n)  Id 1)* ψ2) $$ (i,j) = ψ3 $$ (i,j)" 
      using iter_tensor_of_H_tensor_Id_is_HId dim a2 a0 a1 by simp
  qed
next
  show "dim_row (((H^ n)  Id 1) * ψ2) = dim_row ψ3"  
    using iter_tensor_of_H_tensor_Id_is_HId dim by simp
next
  show "dim_col (((H^ n)  Id 1)* ψ2) = dim_col ψ3" 
    using iter_tensor_of_H_tensor_Id_is_HId dim by simp
qed

(* psi3 is a state on n+1 qubits. It is the image of the state psi2 under H^(x) n tensor Id 1,
   which is a gate on n+1 qubits by iter_tensor_of_H_tensor_Id_is_HId and HId_is_gate. *)
lemma (in jozsa) ψ3_is_state:
  shows "state (n+1) ψ3"
proof-
  have "((H^ n)  Id 1) * ψ2 = ψ3" 
    using iter_tensor_of_H_times_ψ2_is_ψ3 by simp
  moreover have "gate (n+1) ((H^ n)  Id 1)" 
    using iter_tensor_of_H_tensor_Id_is_HId HId_is_gate dim by simp
  moreover have "state (n+1) ψ2" 
    using ψ2_is_state by simp
  ultimately show "state (n+1) ψ3"
    using gate_on_state_is_state dim by (metis (no_types, lifting))
qed

text ‹
Finally, all steps are put together. The result depends on the function f: if f is constant, the
first n qubits are measured as 0 with certainty; if f is balanced, at least one of the first n
qubits is found in state 1.
The algorithm evaluates f only once and always succeeds.
›

(* The complete circuit:
   1. apply the n-fold tensor power of H to |0...0> and apply H to |1>;
   2. tensor the two results;
   3. apply Uf;
   4. apply (n-fold tensor power of H) tensor Id 1. *)
definition (in jozsa) jozsa_algo:: "complex Matrix.mat" where 
"jozsa_algo  ((H ⊗⇗n)  Id 1) * (Uf * (((H ⊗⇗n) * ( |zero ⊗⇗n))  (H * |one)))"

(* The circuit evaluates to psi3. The proof chains the lemmas relating each intermediate
   product to psi1, psi2 and psi3, and replaces the iterated tensor of H with H^(x) n via
   iter_tensor_of_H_rep_is_correct. Declared simp. *)
lemma (in jozsa) jozsa_algo_result [simp]: 
  shows "jozsa_algo = ψ3" 
  using jozsa_algo_def H_on_ket_one_is_ψ11 iter_tensor_of_H_on_zero_tensor ψ10_tensor_ψ11_is_ψ1
  jozsa_transform_times_ψ1_is_ψ2 iter_tensor_of_H_times_ψ2_is_ψ3 dim iter_tensor_of_H_rep_is_correct 
  by simp

(* The circuit output is an (n+1)-qubit state. jozsa_algo_result (a simp rule) rewrites it to
   psi3, which is a state by psi3_is_state. *)
lemma (in jozsa) jozsa_algo_result_is_state: 
  shows "state (n+1) jozsa_algo" 
  using ψ3_is_state by simp

(* prob0_fst_qubits n of the circuit output is the sum of the squared moduli of its entries at
   rows 0 and 1 (via prob0_fst_qubits_eq). *)
lemma (in jozsa) prob0_fst_qubits_of_jozsa_algo: 
  shows "(prob0_fst_qubits n jozsa_algo) = (j{0,1}. (cmod(jozsa_algo $$ (j,0)))2)"
  using prob0_fst_qubits_eq by simp

text ‹General lemmata required to compute probabilities.›

(* (sqrt 2)^n times (sqrt 2)^n equals 2^n. *)
lemma aux_comp_with_sqrt2:
  shows "(sqrt 2)^n * (sqrt 2)^n = 2^n"
  by (smt (verit) power_mult_distrib real_sqrt_mult_self)

(* 2^n / ((sqrt 2)^n (sqrt 2)^(n+1)) = 1/sqrt 2.
   This is the value of a psi3 entry when all 2^n summands coincide. Declared simp. *)
lemma aux_comp_with_sqrt2_bis [simp]:
  shows "2^n/(sqrt(2)^n * sqrt(2)^(n+1)) = 1/sqrt 2"
  using aux_comp_with_sqrt2 by (simp add: mult.left_commute)

(* Over a finite set A, the sum of the terms (-1)^(g k) lies between -card A and card A,
   because each term is 1 or -1 (neg_one_even_power / neg_one_odd_power). *)
lemma aux_ineq_with_card: 
  fixes g:: "nat  nat" and A:: "nat set"
  assumes "finite A" 
  shows "(kA. (-1)^(g k))  card A" and "(kA. (-1)^(g k))  -card A" 
   apply (smt (verit) assms neg_one_even_power neg_one_odd_power card_eq_sum of_nat_1 of_nat_sum sum_mono)
   apply (smt (verit) assms neg_one_even_power neg_one_odd_power card_eq_sum of_nat_1 of_nat_sum sum_mono sum_negf).

(* Suppose g is constantly 0 or constantly 1 on [0, 2^n). Then the sum of (-1)^(g k) over
   k < 2^n has squared modulus 2^(2n). Each case first shows the squared modulus is (2^n)^2. *)
lemma aux_comp_with_cmod:
  fixes g:: "nat  nat"
  assumes "(x<2^n. g x = 0)  (x<2^n. g x = 1)"
  shows "(cmod (k<2^n. (-1)^(g k)))2  = 2^(2*n)"
proof(rule disjE)
  show "(x<2^n. g x = 0)  (x<2^n. g x = 1)" 
    using assms by simp
next
  assume "x<2^n. g x = 0"
  then have "(cmod (k<2^n. (-1)^(g k)))2 = (2^n)2" 
    by (simp add: norm_power)
  then show "?thesis" 
    by (simp add: power_even_eq)
next 
  assume "x<2^n. g x = 1" 
  then have "(cmod (k<2^n. (-1)^(g k)))2 = (2^n)2" 
    by (simp add: norm_power)
  then show "?thesis" 
    by (simp add: power_even_eq)
qed

(* An integer a with -n < a < n has complex modulus below n. *)
lemma cmod_less:
  fixes a n:: int
  assumes "a < n" and "a > -n"
  shows "cmod a < n" 
  using assms by simp

(* For reals with -n < a < n, the square of a is below the square of n. *)
lemma square_less:
  fixes a n:: real
  assumes "a < n" and "a > -n" 
  shows "a2 < n2"
  using assms by (smt (verit) power2_eq_iff power2_minus power_less_imp_less_base)

(* For a real n viewed as a complex number, the squared modulus equals n squared.
   Declared simp. *)
lemma cmod_square_real [simp]:
  fixes n:: real
  shows "(cmod n)2 = n2" 
  by simp

(* Pulls a common real divisor a out of a sum of integer terms under the squared modulus. *)
lemma aux_comp_sum_divide_cmod:
  fixes n:: nat and g:: "nat  int" and a:: real
  shows "(cmod(complex_of_real(k<n. g k / a)))2 = (cmod (k<n. g k) / a)2"
  by (metis cmod_square_real of_int_sum of_real_of_int_eq power_divide sum_divide_distrib)


text ‹
The function is constant if and only if the first n qubits are measured as 0 with certainty. We
first prove one direction: if the function is constant, then the probability of measuring 0 for
the first n qubits is 1.
›

(* If f is constantly 0, the probability of measuring 0 on the first n qubits is 1.
   The two relevant output entries (rows 0 and 1) each have squared modulus 1/2, and
   1/2 + 1/2 = 1. *)
lemma (in jozsa) prob0_jozsa_algo_of_const_0:
  assumes "const 0"
  shows "prob0_fst_qubits n jozsa_algo = 1"
proof-
  have "prob0_fst_qubits n jozsa_algo = (j{0,1}. (cmod(jozsa_algo $$ (j,0)))2)"
    using prob0_fst_qubits_of_jozsa_algo by simp
  (* Row 0: the inner product with 0 vanishes and f k = 0, so every summand is
     1/((sqrt 2)^n (sqrt 2)^(n+1)). *)
  moreover have "(cmod(jozsa_algo $$ (0,0)))2 = 1/2"
  proof-
    have "k<2^n  ((0 div 2) ⋅⇘nk) = 0" for k::nat 
      using bitwise_inner_prod_with_zero by simp 
    then have "(cmod(jozsa_algo $$ (0,0)))2 = (cmod(k::nat<2^n. 1/(sqrt(2)^n * sqrt(2)^(n+1))))2" 
      using jozsa_algo_result const_def assms by simp
    also have "... = (cmod((2::nat)^n/(sqrt(2)^n * sqrt(2)^(n+1))))2"  by simp
    also have "... = (cmod(1/(sqrt(2))))2" 
      using aux_comp_with_sqrt2_bis by simp
    also have "... = 1/2" 
      by (simp add: norm_divide power2_eq_square)
    finally show "?thesis" by simp
  qed
  (* Row 1: 1 div 2 = 0, so every exponent equals 1 and every summand is
     -1/((sqrt 2)^n (sqrt 2)^(n+1)). *)
  moreover have "(cmod(jozsa_algo $$ (1,0)))2 = 1/2"
  proof-
    have "k<2^n  ((1 div 2) ⋅⇘nk) = 0" for k:: nat
      using bitwise_inner_prod_with_zero by simp
    then have "k<2^n  f k + 1 + ((1 div 2) ⋅⇘nk) = 1" for k::nat 
      using const_def assms by simp
    moreover have "(cmod(jozsa_algo $$ (1,0)))2 
    = (cmod (k::nat<2^n. (-1)^(f k + 1 + ((1 div 2) ⋅⇘nk))/(sqrt(2)^n * sqrt(2)^(n+1))))2"
      using ψ3_dim by simp
    ultimately have "(cmod(jozsa_algo $$ (1,0)))2 = (cmod(k::nat<2^n. -1/(sqrt(2)^n * sqrt(2)^(n+1))))2"
      by (smt (verit) lessThan_iff power_one_right sum.cong)
    also have "... = (cmod(-1/(sqrt(2))))2" 
      using aux_comp_with_sqrt2_bis by simp
    also have "... = 1/2" 
      by (simp add: norm_divide power2_eq_square)
    finally show "?thesis" by simp
  qed
  ultimately have "prob0_fst_qubits n jozsa_algo = 1/2 + 1/2" by simp
  then show  "?thesis" by simp
qed

(* If f is constantly 1, the probability of measuring 0 on the first n qubits is 1.
   As in the constant-0 case, rows 0 and 1 of the output each have squared modulus 1/2. *)
lemma (in jozsa) prob0_jozsa_algo_of_const_1:
  assumes "const 1"
  shows "prob0_fst_qubits n jozsa_algo = 1"
proof-
  have "prob0_fst_qubits n jozsa_algo = (j{0,1}. (cmod(jozsa_algo $$ (j,0)))2)"
    using prob0_fst_qubits_of_jozsa_algo by simp
  (* Row 0: under cmod, the entry behaves like a sum of 2^n equal terms
     1/((sqrt 2)^n (sqrt 2)^(n+1)). *)
  moreover have "(cmod(jozsa_algo $$ (0,0)))2 = 1/2"
  proof-
     have "k<2^n  ((0 div 2) ⋅⇘nk) = 0" for k::nat
      using bitwise_inner_prod_with_zero by simp 
    then have "(cmod(jozsa_algo $$ (0,0)))2 = (cmod(k::nat<2^n. 1/(sqrt(2)^n * sqrt(2)^(n+1))))2" 
      using jozsa_algo_result const_def assms by simp
    also have "... = (cmod((-1)/(sqrt(2))))2 " 
      using aux_comp_with_sqrt2_bis by simp
    also have "... = 1/2" 
      by (simp add: norm_divide power2_eq_square)
    finally show "?thesis" by simp
  qed
  (* Row 1: f k + 1 = 2 for every k, so every summand is 1/((sqrt 2)^n (sqrt 2)^(n+1)). *)
  moreover have "(cmod(jozsa_algo $$ (1,0)))2 = 1/2"
  proof-
    have "k<2^n  ((1 div 2) ⋅⇘nk) = 0" for k::nat
      using bitwise_inner_prod_with_zero by simp
    then have "(k::nat<2^n. (-1)^(f k +1 + ((1 div 2) ⋅⇘nk))/(sqrt(2)^n * sqrt(2)^(n+1)))
             = (k::nat<2^n. 1/(sqrt(2)^n * sqrt(2)^(n+1)))"
      using const_def assms by simp
    moreover have "(cmod(jozsa_algo $$ (1,0)))2 
    = (cmod (k::nat<2^n. (-1)^(f k + 1 + ((1 div 2) ⋅⇘nk))/(sqrt(2)^n * sqrt(2)^(n+1))))2"
      using  ψ3_dim by simp
    ultimately have "(cmod(jozsa_algo $$ (1,0)))2 = (cmod(k::nat<2^n. 1/(sqrt(2)^n * sqrt(2)^(n+1))))2" by simp
    also have "... = (cmod(1/(sqrt(2))))2 " 
      using aux_comp_with_sqrt2_bis by simp
    also have "... = 1/2" 
      by (simp add: norm_divide power2_eq_square)
    finally show "?thesis" by simp
  qed
  ultimately have "prob0_fst_qubits n jozsa_algo = 1/2 + 1/2" by simp
  then show  "?thesis" by simp
qed

text ‹If the probability of measuring 0 for the first n qubits is 1, then the function is constant.›

(* If f is neither constantly 0 nor constantly 1, then the sum of (-1)^(f k) over k < 2^n has
   squared modulus strictly below 2^(2n).
   The sum lies strictly between -2^n and 2^n:
   - some x has f x = 1 (from the negated const 0), which gives the upper bound;
   - some x has f x = 0 (from the negated const 1), which gives the lower bound.
   In each case x is removed from the index set, and aux_ineq_with_card bounds the remaining
   sum. *)
lemma (in jozsa) max_value_of_not_const_less:
  assumes "¬ const 0" and "¬ const 1"
  shows "(cmod (k::nat<2^n. (-(1::nat))^(f k)))2 < (2::nat)^(2*n)"
proof-
  have "cmod (k::nat<2^n. (-(1::nat))^(f k)) < 2^n"
  proof-
    (* Upper bound: a point x with f x = 1 contributes -1 instead of 1. *)
    have "(k::nat<2^n. (-(1::nat))^(f k)) < 2^n"
    proof-
      obtain x where f0:"x < 2^n" and f1:"f x = 1"
        using assms(1) const_def f_values by auto
      then have "(k::nat<2^n. (-(1::nat))^(f k)) < (k{i| i:: nat. i < 2^n}-{x}. (-(1::nat))^(f k))"
      proof-
        have "(-(1::nat))^ f x = -1" using f1 by simp
        moreover have "x{i| i::nat. i<2^n}" using f0 by simp
        moreover have "finite {i| i::nat. i<2^n}" by simp
        moreover have "(k{i| i::nat. i<2^n}. (-(1::nat))^(f k)) < 
(k{i| i:: nat. i<2^n}-{x}. (-(1::nat))^(f k))"
          using calculation(1,2,3) sum_diff1 by (simp add: sum_diff1)
        ultimately show ?thesis by (metis Collect_cong Collect_mem_eq lessThan_iff)
      qed
      moreover have "  int (2^n - 1)"
        using aux_ineq_with_card(1)[of "{i| i:: nat. i<2^n}-{x}"] f0 by simp
      ultimately show ?thesis
        by (meson diff_le_self less_le_trans of_nat_le_numeral_power_cancel_iff)
   qed
   (* Lower bound: a point x with f x = 0 contributes 1 instead of -1. *)
   moreover have "(k::nat<2^n. (-(1::nat))^(f k)) > - (2^n)"
   proof-
      obtain x where f0:"x < 2^n" and f1:"f x = 0"
        using assms(2) const_def f_values by auto
      then have "(k::nat<2^n. (-(1::nat))^(f k)) > (k{i| i:: nat. i < 2^n}-{x}. (-(1::nat))^(f k))"
      proof-
        have "(-(1::nat))^ f x = 1" using f1 by simp
        moreover have "x{i| i::nat. i<2^n}" using f0 by simp
        moreover have "finite {i| i::nat. i<2^n}" by simp
        moreover have "(k{i| i::nat. i<2^n}. (-(1::nat))^(f k)) > 
(k{i| i:: nat. i<2^n}-{x}. (-(1::nat))^(f k))"
          using calculation(1,2,3) sum_diff1 by (simp add: sum_diff1)
        ultimately show ?thesis by (metis Collect_cong Collect_mem_eq lessThan_iff)
      qed
      moreover have "- int (2^n - 1)  (k{i| i:: nat. i < 2^n}-{x}. (-(1::nat))^(f k))"
        using aux_ineq_with_card(2)[of "{i| i:: nat. i<2^n}-{x}"] f0 by simp
      ultimately show ?thesis
        by (smt (verit) diff_le_self of_nat_1 of_nat_add of_nat_power_le_of_nat_cancel_iff one_add_one)
   qed
   ultimately show ?thesis
     using cmod_less of_int_of_nat_eq of_nat_numeral of_nat_power by (metis (no_types, lifting))
  qed
  then have "(cmod (k::nat<2^n. (-(1::nat))^(f k)))2 < (2^n)2"
    using square_less norm_ge_zero by (smt (verit))
  thus ?thesis
    by (simp add: power_even_eq)
qed

(* Variant of max_value_of_not_const_less with exponent f k + 1, i.e. the overall sign is
   flipped. If f is neither constantly 0 nor constantly 1, the squared modulus of the sum of
   (-1)^(f k + 1) over k < 2^n is strictly below 2^(2n). Here:
   - a point with f x = 0 gives the upper bound;
   - a point with f x = 1 gives the lower bound. *)
lemma (in jozsa) max_value_of_not_const_less_bis:
  assumes "¬ const 0" and "¬ const 1"
  shows "(cmod (k::nat<2^n. (-(1::nat))^(f k + 1)))2 < (2::nat)^(2*n)"
proof-
  have "cmod (k::nat<2^n. (-(1::nat))^(f k + 1)) < 2^n"
  proof-
    (* Upper bound: a point x with f x = 0 contributes -1 instead of 1. *)
    have "(k::nat<2^n. (-(1::nat))^(f k + 1)) < 2^n"
    proof-
      obtain x where f0:"x < 2^n" and f1:"f x = 0"
        using assms(2) const_def f_values by auto
      then have "(k::nat<2^n. (-(1::nat))^(f k + 1)) < (k{i| i:: nat. i < 2^n}-{x}. (-(1::nat))^(f k + 1))"
      proof-
        have "(-(1::nat))^ (f x + 1) = -1" using f1 by simp
        moreover have "x{i| i::nat. i<2^n}" using f0 by simp
        moreover have "finite {i| i::nat. i<2^n}" by simp
        moreover have "(k{i| i::nat. i<2^n}. (-(1::nat))^(f k + 1)) < 
(k{i| i:: nat. i<2^n}-{x}. (-(1::nat))^(f k + 1))"
          using calculation(1,2,3) sum_diff1 by (simp add: sum_diff1)
        ultimately show ?thesis by (metis Collect_cong Collect_mem_eq lessThan_iff)
      qed
      moreover have "  int (2^n - 1)"
        using aux_ineq_with_card(1)[of "{i| i:: nat. i<2^n}-{x}" "λk. f k + 1"] f0 by simp
      ultimately show ?thesis
        by (meson diff_le_self less_le_trans of_nat_le_numeral_power_cancel_iff)
   qed
   (* Lower bound: a point x with f x = 1 contributes 1 instead of -1. *)
   moreover have "(k::nat<2^n. (-(1::nat))^(f k + 1)) > - (2^n)"
   proof-
      obtain x where f0:"x < 2^n" and f1:"f x = 1"
        using assms(1) const_def f_values by auto
      then have "(k::nat<2^n. (-(1::nat))^(f k + 1)) > (k{i| i:: nat. i < 2^n}-{x}. (-(1::nat))^(f k + 1))"
      proof-
        have "(-(1::nat))^ (f x + 1) = 1" using f1 by simp
        moreover have "x{i| i::nat. i<2^n}" using f0 by simp
        moreover have "finite {i| i::nat. i<2^n}" by simp
        moreover have "(k{i| i::nat. i<2^n}. (-(1::nat))^(f k + 1)) > 
(k{i| i:: nat. i<2^n}-{x}. (-(1::nat))^(f k + 1))"
          using calculation(1,2,3) sum_diff1 by (simp add: sum_diff1)
        ultimately show ?thesis by (metis Collect_cong Collect_mem_eq lessThan_iff)
      qed
      moreover have "- int (2^n - 1)  (k{i| i:: nat. i < 2^n}-{x}. (-(1::nat))^(f k + 1))"
        using aux_ineq_with_card(2)[of "{i| i:: nat. i<2^n}-{x}" "λk. f k + 1"] f0 by simp
      ultimately show ?thesis
        by (smt (verit) diff_le_self of_nat_1 of_nat_add of_nat_power_le_of_nat_cancel_iff one_add_one)
   qed
   ultimately show ?thesis
     using cmod_less of_int_of_nat_eq of_nat_numeral of_nat_power by (metis (no_types, lifting))
  qed
  then have "(cmod (k::nat<2^n. (-(1::nat))^(f k + 1)))2 < (2^n)2"
    using square_less norm_ge_zero by (smt (verit))
  thus ?thesis
    by (simp add: power_even_eq)
qed

(* If f is constant (always 0 or always 1), both signed sums reach the maximal
   squared modulus 2^(2*n): every term has the same sign, so the modulus of the
   sum is 2^n. Discharged via aux_comp_with_cmod, instantiated once for f and
   once for f + 1. *)
lemma (in jozsa) f_const_has_max_value: 
  assumes "const 0  const 1"
  shows "(cmod (k<(2::nat)^n. (-1)^(f k)))2 = (2::nat)^(2*n)" 
  and "(cmod (k<(2::nat)^n. (-1)^(f k + 1)))2 = (2::nat)^(2*n)" 
  using aux_comp_with_cmod[of n "λk. f k"] aux_comp_with_cmod[of n "λk. f k + 1"] const_def assms by auto

(* Upper bound, valid for every f: both squared moduli are at most 2^(2*n).
   Each part is a case split on whether f is constant:
   - constant: equality holds (f_const_has_max_value);
   - not constant: strict inequality holds (max_value_of_not_const_less and
     max_value_of_not_const_less_bis). *)
lemma (in jozsa) prob0_fst_qubits_leq:
  shows "(cmod (k<(2::nat)^n. (-1)^(f k)))2  (2::nat)^(2*n)" 
    and "(cmod (k<(2::nat)^n. (-1)^(f k + 1)))2  (2::nat)^(2*n)"  
proof-
  show "(cmod (k<(2::nat)^n. (-1)^(f k)))2  (2::nat)^(2*n)" 
  proof(rule disjE)
    show "(const 0  const 1)  (¬ const 0  ¬ const 1)" by auto
  next
    assume "const 0  const 1" 
    then show "(cmod (k<(2::nat)^n. (-1)^(f k)))2  (2::nat)^(2*n)" 
      using f_const_has_max_value by simp
  next
    assume "¬ const 0  ¬ const 1"
    then show "(cmod (k<(2::nat)^n. (-1)^(f k)))2  (2::nat)^(2*n)" 
      using max_value_of_not_const_less by simp
  qed
next
  (* Same case split for the sum with exponent f k + 1. *)
  show "(cmod (k<(2::nat)^n. (-1)^(f k + 1)))2  (2::nat)^(2*n)" 
  proof(rule disjE)
    show "(const 0  const 1)  (¬ const 0  ¬ const 1)" by auto
  next
    assume "const 0  const 1" 
    then show "(cmod (k<(2::nat)^n. (-1)^(f k + 1)))2  (2::nat)^(2*n)" 
      using f_const_has_max_value by simp
  next
    assume "¬ const 0  ¬ const 1"
    then show "(cmod (k<(2::nat)^n. (-1)^(f k + 1)))2  (2::nat)^(2*n)" 
      using max_value_of_not_const_less_bis by simp
  qed
qed

(* If the probability of measuring 0 on the first n qubits is 1, then f is
   constant.
   Proof outline:
   1. the probability equals |a00|^2 + |a10|^2, where a00 and a10 are the
      entries (0,0) and (1,0) of jozsa_algo;
   2. each entry is a signed sum divided by sqrt(2)^n * sqrt(2)^(n+1), whose
      square is 2^(2*n+1);
   3. hence the two squared moduli add up to 2^(2*n+1) = 2^(2*n) + 2^(2*n);
   4. each is at most 2^(2*n) (prob0_fst_qubits_leq), so the first equals
      2^(2*n), which rules out the non-constant case
      (max_value_of_not_const_less). *)
lemma (in jozsa) prob0_jozsa_algo_1_is_const:
  assumes "prob0_fst_qubits n jozsa_algo = 1"
  shows "const 0  const 1"
proof-
  have f0: "(j{0,1}. (cmod(jozsa_algo $$ (j,0)))2) = 1"
    using prob0_fst_qubits_of_jozsa_algo assms by simp
  (* Entry (0,0): the inner-product term in the exponent vanishes. *)
  have "k < 2^n((0 div 2) ⋅⇘nk) = 0" for k::nat
    using bitwise_inner_prod_with_zero by simp 
  then have f1: "(cmod(jozsa_algo $$ (0,0)))2 = (cmod(k<(2::nat)^n. (-1)^(f k)/(sqrt(2)^n * sqrt(2)^(n+1))))2" 
    by simp
  (* Entry (1,0): 1 div 2 = 0, so the inner-product term vanishes as well. *)
  have "k < 2^n((1 div 2) ⋅⇘nk) = 0" for k::nat
    using bitwise_inner_prod_with_zero by simp
  moreover have "(cmod(jozsa_algo $$ (1,0)))2 
               = (cmod (k<(2::nat)^n. (-1)^(f k+ 1 + ((1 div 2) ⋅⇘nk))/(sqrt(2)^n * sqrt(2)^(n+1))))2"
      using ψ3_dim by simp
  ultimately have f2: "(cmod(jozsa_algo $$ (1,0)))2 
                     = (cmod (k<(2::nat)^n. (-1)^(f k + 1)/(sqrt(2)^n * sqrt(2)^(n+1))))2" by simp   
  (* Calculation chain: pull the common divisor out of both moduli and simplify
     it to 2^(2*n+1). *)
  have f3: "1 = (cmod(k::nat<(2::nat)^n.(-1)^(f k)/(sqrt(2)^n * sqrt(2)^(n+1))))2
        + (cmod (k::nat<(2::nat)^n. (-1)^(f k + 1)/(sqrt(2)^n * sqrt(2)^(n+1))))2" 
    using f0 f1 f2 by simp 
  also have "... = ((cmod (k::nat<(2::nat)^n. (-1)^(f k)) ) /(sqrt(2)^n * sqrt(2)^(n+1)))2
                 + ((cmod(k::nat<(2::nat)^n. (-1)^(f k + 1))) /(sqrt(2)^n * sqrt(2)^(n+1)))2"
    using aux_comp_sum_divide_cmod[of "λk. (-1)^(f k)" "(sqrt(2)^n * sqrt(2)^(n+1))" "(2::nat)^n"] 
          aux_comp_sum_divide_cmod[of "λk. (-1)^(f k + 1)" "(sqrt(2)^n * sqrt(2)^(n+1))" "(2::nat)^n"] 
    by simp
  also have "... = ((cmod (k::nat<(2::nat)^n. (-1)^(f k))))2 /((sqrt(2)^n * sqrt(2)^(n+1)))2
                 + ((cmod(k::nat<(2::nat)^n. (-1)^(f k +1))))2 /((sqrt(2)^n * sqrt(2)^(n+1)))2"
    by (simp add: power_divide)
  also have "... = ((cmod (k::nat<(2::nat)^n. (-1)^(f k)) ) )2/(2^(2*n+1))
                 + ((cmod(k::nat<(2::nat)^n. (-1)^(f k + 1))))2 /(2^(2*n+1))"
    by (smt (verit) left_add_twice power2_eq_square power_add power_mult_distrib real_sqrt_pow2)
  also have "... = (((cmod (k::nat<(2::nat)^n. (-1)^(f k))))2 
                 + ((cmod(k::nat<(2::nat)^n. (-1)^(f k + 1))))2)/(2^(2*n+1)) "
    by (simp add: add_divide_distrib)
  finally have "((2::nat)^(2*n+1)) = (((cmod (k::nat<(2::nat)^n. (-1)^(f k))))2 
                 + ((cmod(k::nat<(2::nat)^n. (-1)^(f k + 1))))2)" by simp
  (* Two summands, each at most 2^(2*n), adding up to 2 * 2^(2*n): both must be
     maximal. *)
  moreover have "((2::nat)^(2*n+1)) = 2^(2*n) + 2^(2*n)" by auto
  moreover have "(cmod (k<(2::nat)^n. (-1)^(f k)))2  2^(2*n)" 
    using prob0_fst_qubits_leq by simp 
  moreover have "(cmod (k<(2::nat)^n. (-1)^(f k + 1)))2  2^(2*n)" 
    using prob0_fst_qubits_leq by simp 
  ultimately have "2^(2*n) = ((cmod (k::nat<(2::nat)^n. (-1)^(f k))))2" by simp
  (* A non-constant f would make this strictly smaller (max_value_of_not_const_less). *)
  then show ?thesis
    using  max_value_of_not_const_less by auto
qed

text ‹
The function is balanced if and only if, when the first n qubits are measured, at least one of them
is found to be non-zero with certainty. Equivalently, the function is balanced if and only if the
probability of measuring 0 for all of the first n qubits is 0.
›

(* Splitting a sum over a finite set C into two parts: for any disjoint A and B
   with A ∪ B = C, the sum of g over C is the sum over A plus the sum over B.
   Thin wrapper around sum.union_disjoint, specialised to nat sets and
   int-valued g. *)
lemma sum_union_disjoint_finite_set:
  fixes C::"nat set" and g::"nat  int"
  assumes "finite C"
  shows "A B. A  B = {}  A  B = C  (kC. g k) = (kA. g k) + (kB. g k)" 
  using assms sum.union_disjoint by auto

(* For a balanced f, the sum over k < 2^n of (-1)^(f k) is 0.
   Take the witnesses A (where f = 0) and B (where f = 1) from is_balanced,
   each of cardinality 2^(n-1). They partition {0..<2^n}, with A contributing
   +2^(n-1) and B contributing -2^(n-1). *)
lemma (in jozsa) balanced_pos_and_neg_terms_cancel_out1:
  assumes "is_balanced" 
  shows "(k<(2::nat)^n. (-(1::nat))^(f k)) = 0"
proof-
  (* Prove the claim for any witnesses A and B; then instantiate it with the
     existential in is_balanced_def. *)
  have "A B. A  {i::nat. i < (2::nat)^n}  B  {i::nat. i < (2::nat)^n}
              card A = ((2::nat)^(n-1))  card B = ((2::nat)^(n-1))  
              ((x::nat)  A. f x = (0::nat))   ((x::nat)  B. f x = 1)
         (k<(2::nat)^n. (-(1::nat))^(f k)) = 0"
  proof
    fix A B::"nat set"
    assume asm: "A  {i::nat. i < (2::nat)^n}  B  {i::nat. i < (2::nat)^n}
              card A = ((2::nat)^(n-1))  card B = ((2::nat)^(n-1))  
              ((x::nat)  A. f x = (0::nat))   ((x::nat)  B. f x = 1)" 
    (* A and B are disjoint and cover the whole index range. *)
    then have " A  B = {}" and "{0..<(2::nat)^n} = A  B" 
      using is_balanced_union is_balanced_inter by auto
    then have "(k{0..<(2::nat)^n}. (-(1::nat))^(f k)) =
               (kA. (-(1::nat))^(f k)) 
             + (kB. (-(1::nat))^(f k))" 
      by (metis finite_atLeastLessThan sum_union_disjoint_finite_set)
    (* On A every term is +1; on B every term is -1. *)
    moreover have "(kA. (-1)^(f k)) = ((2::nat)^(n-1))"
      using asm by simp
    moreover have "(kB. (-1)^(f k)) = -((2::nat)^(n-1))" 
      using asm by simp
    ultimately have "(k {0..<(2::nat)^n}. (-(1::nat))^(f k)) = 0" by simp
    then show "(k<(2::nat)^n. (-(1::nat))^(f k)) = 0"
      by (simp add: lessThan_atLeast0)
  qed
  then show ?thesis 
    using assms is_balanced_def by auto
qed

(* For a balanced f, the sum over k < 2^n of (-1)^(f k + 1) is 0.
   This mirrors balanced_pos_and_neg_terms_cancel_out1, with the signs swapped:
   A (f = 0) contributes -2^(n-1) and B (f = 1) contributes +2^(n-1). *)
lemma (in jozsa) balanced_pos_and_neg_terms_cancel_out2:
  assumes "is_balanced" 
  shows "(k<(2::nat)^n. (-(1::nat))^(f k + 1)) = 0"
proof-
  (* Prove the claim for any witnesses A and B; then instantiate it with the
     existential in is_balanced_def. *)
  have "A B. A  {i::nat. i < (2::nat)^n}  B  {i::nat. i < (2::nat)^n}
              card A = ((2::nat)^(n-1))  card B = ((2::nat)^(n-1))  
              ((x::nat)A. f x = (0::nat))   ((x::nat)B. f x = 1)
         (k<(2::nat)^n. (-(1::nat))^(f k + 1)) = 0"
  proof
    fix A B::"nat set"
    assume asm: "A  {i::nat. i < (2::nat)^n}  B  {i::nat. i < (2::nat)^n}
              card A = ((2::nat)^(n-1))  card B = ((2::nat)^(n-1))  
              ((x::nat)  A. f x = (0::nat))   ((x::nat)  B. f x = 1)" 
    (* A and B are disjoint and cover the whole index range. *)
    have "A  B = {}" and "{0..<(2::nat)^n} = A  B" 
      using is_balanced_union is_balanced_inter asm by auto
    then have "(k{0..<(2::nat)^n}. (-(1::nat))^(f k + 1)) =
               (kA. (-(1::nat))^(f k + 1)) 
             + (kB. (-(1::nat))^(f k + 1))" 
      by (metis finite_atLeastLessThan sum_union_disjoint_finite_set)
    (* On A every term is -1; on B every term is +1. *)
    moreover have "(kA. (-1)^(f k + 1)) = -((2::nat)^(n-1))" 
      using asm by simp
    moreover have "(kB. (-1)^(f k + 1)) = ((2::nat)^(n-1))" 
      using asm by simp
    ultimately have "(k{0..<(2::nat)^n}. (-(1::nat))^(f k + 1)) = 0 " by simp
    then show "(k<(2::nat)^n. (-(1::nat))^(f k + 1)) = 0"
      by (simp add: lessThan_atLeast0)
  qed
  then show "(k<(2::nat)^n. (-(1::nat))^(f k + 1)) = 0" 
    using assms is_balanced_def by auto
qed

(* If f is balanced, the probability of measuring 0 on the first n qubits is 0.
   That probability is |a00|^2 + |a10|^2 (prob0_fst_qubits_of_jozsa_algo),
   where a00 and a10 are the entries (0,0) and (1,0) of jozsa_algo. Each entry
   is a signed sum divided by a constant, and the sum vanishes for balanced f
   (balanced_pos_and_neg_terms_cancel_out1 and ..._cancel_out2). *)
lemma (in jozsa) prob0_jozsa_algo_of_balanced:
assumes "is_balanced"
  shows "prob0_fst_qubits n jozsa_algo = 0"
proof-
  have "prob0_fst_qubits n jozsa_algo = (j{0,1}. (cmod(jozsa_algo $$ (j,0)))2)"
    using prob0_fst_qubits_of_jozsa_algo by simp
  (* Entry (0,0) vanishes. *)
  moreover have "(cmod(jozsa_algo $$ (0,0)))2 = 0"
  proof-
     (* NOTE(review): this fact is stated with "1 div 2", whereas the analogous
        step for entry (0,0) in prob0_jozsa_algo_1_is_const uses "0 div 2".
        Both reduce to 0, so simp closes either form, but "0 div 2" may have been
        intended here. Confirm. *)
     have "k < 2^n((1 div 2) ⋅⇘nk) = 0" for k::nat
      using bitwise_inner_prod_with_zero by simp
    then have "(cmod(jozsa_algo $$ (0,0)))2 = (cmod( k < (2::nat)^n. (-1)^(f k)/(sqrt(2)^n * sqrt(2)^(n+1))))2" 
      using ψ3_values by simp
    also have "... = (cmod(k<(2::nat)^n. (-(1::nat))^(f k))/(sqrt(2)^n * sqrt(2)^(n+1)))2" 
      using aux_comp_sum_divide_cmod[of "λk.(-(1::nat))^(f k)" "(sqrt(2)^n * sqrt(2)^(n+1))" "2^n"] 
      by simp
    also have "... = (cmod ((0::int)/(sqrt(2)^n * sqrt(2)^(n+1))))2" 
      using balanced_pos_and_neg_terms_cancel_out1 assms by (simp add: bob_fun_axioms)
    also have "... = 0" by simp
    finally show ?thesis by simp
  qed
  (* Entry (1,0) vanishes. *)
  moreover have "(cmod(jozsa_algo $$ (1,0)))2 = 0" 
  proof-
     have "k < 2^n  (((1::nat) div 2) ⋅⇘nk) = 0" for k::nat
       using bitwise_inner_prod_with_zero by auto
     moreover have "(cmod(jozsa_algo $$ (1,0)))2 
     = (cmod (k<(2::nat)^n. (-(1::nat))^(f k + (1::nat) + ((1 div 2) ⋅⇘nk))/(sqrt(2)^n * sqrt(2)^(n+1))))2"
      using ψ3_dim by simp
    ultimately have "(cmod(jozsa_algo $$ (1,0)))2 
    = (cmod(k<(2::nat)^n. (-(1::nat))^(f k + (1::nat))/(sqrt(2)^n * sqrt(2)^(n+1))))2" 
       by simp
    also have "... = (cmod(k<(2::nat)^n. (-(1::nat))^(f k + 1))/(sqrt(2)^n * sqrt(2)^(n+1)))2" 
      using aux_comp_sum_divide_cmod[of "λk.(-(1::nat))^(f k + 1)" "(sqrt(2)^n * sqrt(2)^(n+1))" "2^n"] 
      by simp
    also have "... = (cmod ((0::int)/(sqrt(2)^n * sqrt(2)^(n+1))))2" 
      using balanced_pos_and_neg_terms_cancel_out2 assms by (simp add: bob_fun_axioms)
    also have "... = 0" by simp
    finally show ?thesis by simp
  qed
  ultimately have "prob0_fst_qubits n jozsa_algo = 0 + 0" by simp
  then show  ?thesis by simp
qed

text ‹If the probability of measuring 0 for all of the first n qubits is 0, then the function is balanced.›

(* A zero probability of measuring 0 on the first n qubits implies f is
   balanced. By const_or_balanced, f is either constant or balanced. The
   constant case would give a non-zero probability
   (prob0_jozsa_algo_of_const_0 / prob0_jozsa_algo_of_const_1), contradicting
   the assumption. *)
lemma (in jozsa) balanced_prob0_jozsa_algo:
  assumes "prob0_fst_qubits n jozsa_algo = 0"
  shows "is_balanced"
proof-
  have "is_const  is_balanced" 
    using const_or_balanced by simp
  moreover have "is_const  ¬ prob0_fst_qubits n jozsa_algo = 0"
    using is_const_def prob0_jozsa_algo_of_const_0 prob0_jozsa_algo_of_const_1 by simp
  ultimately show ?thesis 
    using assms by simp
qed

text ‹We prove the correctness of the algorithm.›

(* The observable outcome of the algorithm: prob0_fst_qubits n jozsa_algo,
   i.e. the probability of measuring 0 on the first n qubits of the final
   state. *)
definition (in jozsa) jozsa_algo_eval:: "real" where
"jozsa_algo_eval  prob0_fst_qubits n jozsa_algo"

(* Correctness of Deutsch-Jozsa: the evaluation is 1 exactly when f is
   constant, and 0 exactly when f is balanced. The two directions of each
   equivalence come from the lemmas listed in the "using" clause (constant and
   balanced cases, plus their converses). *)
theorem (in jozsa) jozsa_algo_is_correct:
  shows "jozsa_algo_eval = 1  is_const" 
    and "jozsa_algo_eval = 0  is_balanced" 
  using prob0_jozsa_algo_of_const_1 prob0_jozsa_algo_of_const_0 jozsa_algo_eval_def
prob0_jozsa_algo_1_is_const is_const_def balanced_prob0_jozsa_algo prob0_jozsa_algo_of_balanced 
  by auto

end