(* Theory Lucas_Lehmer *)

section ‹The Lucas--Lehmer test›
theory Lucas_Lehmer
imports
  Lucas_Lehmer_Auxiliary
  "HOL-Algebra.Ring"
  "Probabilistic_Prime_Tests.Jacobi_Symbol"
  "Pell.Pell" (* only needed for irrationality of sqrt(3) *)
begin

subsection ‹General properties of Mersenne numbers and Mersenne primes›

text ‹
  We mostly follow the proofs given on Wikipedia~\<^cite>‹"wiki:mersenne" and "wiki:lucas_lehmer"› in the
  following sections.

  We first show some basic theorems about Mersenne numbers and Mersenne primes in general,
  beginning with this: Mersenne primes are the only primes of the form $a^n - 1$ for $n > 1$.
›
(* If a ^ n - 1 is prime, then either the exponent is trivial (n = 1) or the base is 2.
   Proof idea: a - 1 always divides a ^ n - 1, so by primality a - 1 is either 1 or
   the whole number a ^ n - 1. *)
lemma prime_power_minus_oneD:
  fixes a n :: nat
  assumes "prime (a ^ n - 1)"
  shows   "n = 1 ∨ a = 2"
proof -
  from assms have "n > 0"
    by (intro Nat.gr0I) auto
  (* a = 0 and a = 1 would make a ^ n - 1 = 0, which is not prime. *)
  have "a ≠ 0" "a ≠ 1"
    by (rule notI, use ‹n > 0› assms in ‹simp add: zero_power›)+
  hence "a > 1" by auto
  have "[a - 1 + 1 = 0 + 1] (mod (a - 1))"
    by (rule cong_add) (auto simp: cong_def)
  hence "[a = 1] (mod (a - 1))"
    using ‹a > 1› by simp
  hence "[a ^ n - 1 = 1 ^ n - 1] (mod (a - 1))"
    using ‹a > 1› by (intro cong_pow cong_diff_nat) auto
  hence "(a - 1) dvd (a ^ n - 1)"
    by (simp add: cong_0_iff)
  have "a - 1 = 1 ∨ a - 1 = a ^ n - 1"
    using ‹prime (a ^ n - 1)› and ‹(a - 1) dvd _› by (rule prime_natD)
  thus ?thesis
  proof
    assume "a - 1 = 1"
    hence "a = 2" by simp
    thus ?thesis by simp
  next
    assume "a - 1 = a ^ n - 1"
    hence "a ^ n = a ^ 1"
      using ‹a > 1› by (simp add: Nat.eq_diff_iff)
    hence "n = 1"
      using ‹a > 1› by (subst (asm) power_inject_exp) auto
    thus ?thesis by simp
  qed
qed

text ‹
  Next, we show that if a prime ‹q› divides a Mersenne number $2^p - 1$ with an odd prime
  exponent ‹p›, then ‹q› must be of the form $q = 1 + 2kp$ for some $k > 0$.
›
(* Any prime divisor q of 2 ^ p - 1 (p an odd prime) satisfies q ≡ 1 (mod 2p).
   The multiplicative order of 2 modulo q is p; Fermat's little theorem then gives
   p dvd q - 1, and q is odd, so the two congruences combine modulo 2p. *)
lemma prime_dvd_mersenneD:
  fixes p q :: nat
  assumes "prime p" "p ≠ 2" "prime q" "q dvd (2 ^ p - 1)"
  shows   "[q = 1] (mod (2 * p))"
proof -
  from assms have "odd p"
    using prime_gt_1_nat[of p] by (intro prime_odd_nat) auto
  have "q ≠ 0" "q ≠ 1" "q ≠ 2"
    using assms by (auto intro!: Nat.gr0I)
  hence "q > 2" by simp
  with ‹prime q› have "odd q"
    by (simp add: prime_odd_nat)

  (* The order of 2 modulo q divides the prime p and is not 1, hence equals p. *)
  have "ord q 2 = p"
  proof -
    from assms have "[2 ^ p - 1 + 1 = 0 + 1] (mod q)"
      by (intro cong_add cong_refl) (auto simp: cong_0_iff)
    hence "[2 ^ p = 1] (mod q)" by simp
    hence "ord q 2 dvd p"
      by (subst (asm) ord_divides)
    hence "ord q 2 = 1 ∨ ord q 2 = p"
      using ‹prime p› and prime_natD by blast
    moreover have "ord q 2 ≠ 1"
      using ord_works[of 2 q] and ‹prime q› by (auto simp: cong_altdef_nat)
    ultimately show "ord q 2 = p" by blast
  qed

  have q_dvd_iff: "q dvd (2 ^ x - 1) ⟷ p dvd x" for x :: nat
  proof -
    have "q dvd (2 ^ x - 1) ⟷ [2 ^ x = 1] (mod q)"
      by (auto simp: cong_altdef_nat)
    also have "… ⟷ ord q 2 dvd x"
      by (rule ord_divides)
    also note ‹ord q 2 = p›
    finally show ?thesis .
  qed

  (* Fermat's little theorem: q dvd 2 ^ (q - 1) - 1, hence p dvd q - 1. *)
  from ‹q > 2› and assms have "¬q dvd 2"
    using primes_dvd_imp_eq two_is_prime_nat by blast
  hence "[2 ^ (q - 1) - 1 = 1 - 1] (mod q)"
    using assms by (intro fermat_theorem cong_diff_nat) auto
  hence "q dvd (2 ^ (q - 1) - 1)"
    by (simp add: cong_0_iff)
  hence "p dvd (q - 1)"
    by (subst (asm) q_dvd_iff)
  hence "[q = 1] (mod p)"
    using ‹q > 2› by (auto simp: cong_altdef_nat prime_gt_1_nat)

  moreover have "[q = 1] (mod 2)"
    using ‹odd q› by (auto simp: cong_def odd_iff_mod_2_eq_one)
  (* p is odd, so p and 2 are coprime and the congruences combine modulo 2 * p. *)
  ultimately show "[q = 1] (mod (2 * p))"
    using ‹odd p› by (intro coprime_cong_mult_nat) auto
qed

(* Explicit-witness form of prime_dvd_mersenneD: q = 1 + 2kp for some k > 0. *)
lemma prime_dvd_mersenneD':
  fixes p q :: nat
  assumes "prime p" "p ≠ 2" "prime q" "q dvd (2 ^ p - 1)"
  shows "∃k>0. q = 1 + 2 * k * p"
proof -
  have "q ≠ 0" "q ≠ 1" "q ≠ 2"
    using assms by (auto intro!: Nat.gr0I)
  hence "q > 2" by simp

  have "[q = 1] (mod (2 * p))"
    by (rule prime_dvd_mersenneD) fact+
  hence "(2 * p) dvd (q - 1)"
    using ‹q > 2› by (auto simp: cong_altdef_nat)
  then obtain k where k: "q - 1 = (2 * p) * k"
    by blast
  hence "q = 1 + 2 * k * p"
    using ‹q > 2› by (simp add: algebra_simps)
  (* k = 0 would force q = 1, contradicting q > 2. *)
  moreover have "k > 0"
    using ‹q > 2› and k by (intro Nat.gr0I) auto
  ultimately show ?thesis by blast
qed


text ‹
  A Mersenne number is any number of the form $2^p - 1$ for a natural number $p$, and a
  Mersenne prime is a Mersenne number that is prime. To make things a bit more pleasant, we
  additionally exclude $2^2 - 1 = 3$, i.e.\ we require $p > 2$. For a Mersenne prime it can
  then be shown that $p$ is always an odd prime.
›
locale mersenne_prime =
  fixes p M :: nat
  defines "M ≡ 2 ^ p - 1"
  assumes p_gt_2: "p > 2" and prime: "prime M"
begin

(* From p ≥ 3 we get 2 ^ p ≥ 8, hence M ≥ 7. *)
lemma M_gt_6: "M > 6"
proof -
  from p_gt_2 have "2 ^ p ≥ (2 ^ 3 :: nat)"
    by (intro power_increasing) auto
  thus ?thesis by (simp add: M_def)
qed

lemma M_odd: "odd M"
  using p_gt_2 by (auto simp: M_def)

(* The exponent of a Mersenne prime is prime: if p = a * b with a, b > 1, then
   2 ^ b - 1 is a divisor of M that is neither 1 nor M. *)
theorem p_prime: "prime p"
proof (rule ccontr)
  assume "¬prime p"
  then obtain a b where ab: "p = a * b" "a > 1" "b > 1"
    using p_gt_2 not_prime_imp_ex_prod_nat[of p] by auto

  (* Geometric sum: (x - 1) * (∑k<a. x ^ k) = x ^ a - 1. *)
  have geometric_sum_aux: "(x - (1 :: int)) * (∑k<a. x ^ k) = x ^ a - 1" for x
    by (induction a) (auto simp: algebra_simps)
  have "(2 ^ b - 1 :: int) * (∑k<a. (2 ^ b) ^ k) = (2 ^ b) ^ a - 1"
    by (rule geometric_sum_aux)
  hence "2 ^ (a*b) - 1 = (2 ^ b - 1 :: int) * (∑k<a. 2 ^ (k*b))"
    by (simp flip: power_mult add: algebra_simps)
  hence "(2 ^ b - 1) dvd (2 ^ (a*b) - 1 :: int)"
    by simp
  (* Move the divisibility statement from int back to nat. *)
  hence "int (2 ^ b - 1) dvd int (2 ^ (a * b) - 1)"
    by (subst of_nat_diff) (auto simp: of_nat_diff)
  hence "(2 ^ b - 1) dvd (2 ^ (a * b) - 1 :: nat)"
    by (subst (asm) int_dvd_int_iff)
  with prime have "2 ^ b - 1 = (1 :: nat) ∨ 2 ^ b - 1 = (2 ^ p - 1 :: nat)"
    unfolding ab M_def by (intro prime_natD) auto
  (* Both alternatives are ruled out by 1 < b < p. *)
  moreover have "2 ^ b > (2 ^ 1 :: nat)"
    using ab by (intro power_strict_increasing) auto
  moreover have "2 ^ b < (2 ^ p :: nat)"
    using ab by (intro power_strict_increasing) auto
  hence "2 ^ b - 1 < (2 ^ p - 1 :: nat)"
    by (subst less_diff_iff) auto
  ultimately show False by auto
qed

lemma p_odd: "odd p"
  using p_prime p_gt_2 prime_odd_nat by auto

text ‹
  We now first show a few more properties of Mersenne primes regarding congruences
  and the Legendre symbol.
›
(* For odd p ≥ 3, 2 ^ p ≡ 8 (mod 12), so M ≡ 7 (mod 12). *)
lemma M_cong_7_mod_12: "[M = 7] (mod 12)"
proof -
  have "[M = 8 - 1] (mod 12)"
    using p_gt_2 p_odd unfolding M_def by (intro cong_diff_nat two_power_odd_mod_12) auto
  thus "[M = 7] (mod 12)" by simp
qed

(* 3 is a quadratic non-residue modulo M (via M ≡ 7 (mod 12)). *)
lemma Legendre_3_M: "Legendre 3 M = -1"
  using prime M_cong_7_mod_12 by (subst Legendre_3_left) (auto simp: cong_def)

(* Since 8 dvd 2 ^ p for p ≥ 3, we get M ≡ -1 ≡ 7 (mod 8). *)
lemma M_cong_7_mod_8: "[M = 7] (mod 8)"
proof -
  have "2 ^ 3 dvd (2 ^ p :: int)"
    using p_gt_2 by (intro le_imp_power_dvd) auto
  hence "[2 ^ p - 1 = 0 - 1] (mod (8 :: int))"
    by (intro cong_diff) (auto simp: cong_def)
  also have "2 ^ p - 1 = int M"
    by (simp add: M_def of_nat_diff)
  finally have "int M mod int 8 = 7"
    by (simp add: cong_def)
  thus "[M = 7] (mod 8)"
    by (subst (asm) zmod_int [symmetric]) (auto simp: cong_def)
qed

(* 2 is a quadratic residue modulo M (second supplement, using M ≡ 7 (mod 8)). *)
lemma Legendre_2_M: "Legendre 2 M = 1"
  using prime M_gt_6 M_cong_7_mod_8
  by (subst supplement2_Legendre') (auto simp: cong_def nat_mod_as_int)

(* 24 = 2 * 2 * 2 * 3, and the prime M > 6 divides neither 2 nor 3. *)
lemma M_not_dvd_24: "¬M dvd 24"
proof
  assume "M dvd 24"
  hence "M dvd 2 * 2 * 2 * 3"
    by simp
  also have "?this ⟷ M dvd 2 ∨ M dvd 3"
    using prime by (simp only: prime_dvd_mult_iff) auto
  finally show False using M_gt_6 by (auto dest: dvd_imp_le)
qed

end


subsection ‹The Lucas--Lehmer sequence›

text ‹
  We now define the generalised Lucas--Lehmer sequence $a_{n+1} = a_n^2 - 2$, which takes its
  starting value $a_0$ as a parameter. The starting value we will always use is $a_0 = 4$.
›
(* Generalised Lucas--Lehmer sequence with starting value a:
   the 0-th term is a and each further term is the square of its predecessor minus 2. *)
primrec gen_lucas_lehmer_sequence :: "int ⇒ nat ⇒ int" where
  "gen_lucas_lehmer_sequence a 0 = a"
| "gen_lucas_lehmer_sequence a (Suc n) = gen_lucas_lehmer_sequence a n ^ 2 - 2"

(* Alternative recursion: one step can be taken at the front by replacing the
   starting value a with a ^ 2 - 2. The induction generalises over a because the
   starting value changes in the right-hand side. *)
lemma gen_lucas_lehmer_sequence_Suc':
  "gen_lucas_lehmer_sequence a (Suc n) = gen_lucas_lehmer_sequence (a ^ 2 - 2) n"
  by (induction n arbitrary: a) auto

(* Code equations for the sequence. With the primed Suc rule, the recursive call is in
   tail position: the generated code updates the starting value and then recurses on n. *)
lemmas gen_lucas_lehmer_code [code] =
  gen_lucas_lehmer_sequence.simps(1) gen_lucas_lehmer_sequence_Suc'

text ‹
  For $a_0 = 4$, the recurrence has the closed form $a_{4,n} = \omega^{2^n} + \bar\omega^{2^n}$
  with $\omega = 2 + \sqrt{3}$ and $\bar\omega = 2 - \sqrt{3}$.
›
(* Closed form of the sequence starting at 4, as a real number:
   (2 + √3) ^ (2 ^ n) + (2 - √3) ^ (2 ^ n). Proved by induction on n. *)
lemma gen_lucas_lehmer_sequence_4_closed_form1:
  "real_of_int (gen_lucas_lehmer_sequence 4 n) = (2 + sqrt 3) ^ (2 ^ n) + (2 - sqrt 3) ^ (2 ^ n)"
  by (induction n)
     (auto simp: algebra_simps power2_eq_square power_mult simp flip: power_mult_distrib)

(* The conjugate term (2 - √3) ^ (2 ^ n) lies in [0, 1/2). Both the rounding and the
   ceiling characterisations below rely only on these two bounds. *)
lemma gen_lucas_lehmer_sequence_4_conj_bounds:
  "(2 - sqrt 3) ^ (2 ^ n) ≥ 0" "(2 - sqrt 3) ^ (2 ^ n) < 1 / 2"
proof -
  show "(2 - sqrt 3) ^ (2 ^ n) ≥ 0"
    by (intro zero_le_power) (auto simp: real_le_lsqrt)
  (* 5/3 < √3 gives 0 ≤ 2 - √3 < 1/3, and (1/3) ^ (2 ^ n) ≤ 1/3 < 1/2. *)
  have "5 / 3 < sqrt (3 :: real)"
    by (rule real_less_rsqrt) (auto simp: power2_eq_square)
  hence "(2 - sqrt 3) ^ (2 ^ n) < (1 / 3) ^ (2 ^ n)"
    by (intro power_strict_mono) (auto simp: real_le_lsqrt)
  also have "… ≤ (1 / 3) ^ 1"
    by (intro power_decreasing) auto
  finally show "(2 - sqrt 3) ^ (2 ^ n) < 1 / 2" by simp
qed

(* The n-th term (start 4) is (2 + √3) ^ (2 ^ n) rounded to the nearest integer. *)
lemma gen_lucas_lehmer_sequence_4_closed_form2:
  "gen_lucas_lehmer_sequence 4 n = round ((2 + sqrt 3) ^ (2 ^ n))"
proof (rule sym, rule round_unique')
  show "¦(2 + sqrt 3) ^ 2 ^ n - real_of_int (gen_lucas_lehmer_sequence 4 n)¦ < 1 / 2"
    using gen_lucas_lehmer_sequence_4_conj_bounds[of n]
    unfolding gen_lucas_lehmer_sequence_4_closed_form1 by linarith
qed

(* The n-th term (start 4) is the ceiling of (2 + √3) ^ (2 ^ n). *)
lemma gen_lucas_lehmer_sequence_4_closed_form3:
  "gen_lucas_lehmer_sequence 4 n = ⌈(2 + sqrt 3) ^ (2 ^ n)⌉"
proof (rule sym, rule ceiling_unique)
  show "real_of_int (gen_lucas_lehmer_sequence 4 n) ≥ (2 + sqrt 3) ^ 2 ^ n"
    unfolding gen_lucas_lehmer_sequence_4_closed_form1 by (auto intro!: zero_le_power real_le_lsqrt)
next
  show "real_of_int (gen_lucas_lehmer_sequence 4 n) - 1 < (2 + sqrt 3) ^ 2 ^ n"
    using gen_lucas_lehmer_sequence_4_conj_bounds[of n]
    unfolding gen_lucas_lehmer_sequence_4_closed_form1 by linarith
qed


subsection ‹The ring $\mathbb{Z}[\sqrt{3}]$›

text ‹
  To relate this sequence to Mersenne primes, we now first need to define the ring
  $\mathbb{Z}[\sqrt{3}]$, which is a subring of $\mathbb{R}$. As an additive group, this ring is
  the free $\mathbb{Z}$-module generated by $1$ and $\sqrt{3}$.

  It is, however, more convenient to explicitly describe it as a ring structure over the 
  set $\mathbb{Z}\times\mathbb{Z}$ with a corresponding injective homomorphism
  $\mathbb{Z}\times\mathbb{Z} \to \mathbb{R}$.
›

(* Addition in ℤ[√3] on pairs (a, b) ~ a + b√3: componentwise. *)
definition lucas_lehmer_add' :: "int × int ⇒ int × int ⇒ int × int" where
  "lucas_lehmer_add' = (λ(a,b) (c,d). (a + c, b + d))"

(* Multiplication in ℤ[√3]: (a + b√3)(c + d√3) = (ac + 3bd) + (ad + bc)√3. *)
definition lucas_lehmer_mult' :: "int × int ⇒ int × int ⇒ int × int" where
  "lucas_lehmer_mult' = (λ(a,b) (c,d). (a * c + 3 * b * d, a * d + b * c))"

(* The ring ℤ[√3] as an HOL-Algebra ring record over all of int × int. *)
definition lucas_lehmer_ring :: "(int × int) ring" where
  "lucas_lehmer_ring =
     ⦇carrier = UNIV,
      monoid.mult = lucas_lehmer_mult',
      one = (1, 0),
      ring.zero = (0, 0),
      add = lucas_lehmer_add'⦈"

(* The carrier of lucas_lehmer_ring is all of int × int. *)
lemma carrier_lucas_lehmer_ring [simp]: "carrier lucas_lehmer_ring = UNIV"
  by (simp add: lucas_lehmer_ring_def)

(* lucas_lehmer_ring is a commutative ring. The only non-trivial obligation is the
   existence of additive inverses; the inverse of (a, b) is (-a, -b). *)
lemma cring_lucas_lehmer_ring [intro]: "cring (lucas_lehmer_ring)"
proof
  have "∃aa ba. lucas_lehmer_add' (aa, ba) (a, b) = (0, 0) ∧
                lucas_lehmer_add' (a, b) (aa, ba) = (0, 0)" for a b
    by (rule exI[of _ "-a"], rule exI[of _ "-b"]) (auto simp: lucas_lehmer_add'_def)
  thus "carrier (add_monoid lucas_lehmer_ring) ⊆ Units (add_monoid lucas_lehmer_ring)"
    by (auto simp: Units_def lucas_lehmer_ring_def)
qed (auto simp: lucas_lehmer_ring_def lucas_lehmer_add'_def lucas_lehmer_mult'_def algebra_simps)


subsection ‹The ring $(\mathbb{Z}/m\mathbb{Z})[\sqrt{3}]$›

text ‹
  We shall also need the ring $(\mathbb{Z}/m\mathbb{Z})[\sqrt{3}]$, which is obtained from
  $\mathbb{Z}[\sqrt{3}]$ by reducing each component separately modulo $m$. In other words,
  $a + b\sqrt{3}$ and $c + d\sqrt{3}$ are identified whenever $a \equiv c \pmod{m}$ and
  $b \equiv d \pmod{m}$.
›
(* Multiplication in (ℤ/mℤ)[√3]: the ℤ[√3] product rule with both components reduced mod m. *)
definition lucas_lehmer_mult :: "nat ⇒ nat × nat ⇒ nat × nat ⇒ nat × nat" where
  "lucas_lehmer_mult m = (λ(a,b) (c,d). ((a * c + 3 * b * d) mod m, (a * d + b * c) mod m))"

(* Addition in (ℤ/mℤ)[√3]: componentwise addition reduced mod m. *)
definition lucas_lehmer_add :: "nat ⇒ nat × nat ⇒ nat × nat ⇒ nat × nat" where
  "lucas_lehmer_add m = (λ(a,b) (c,d). ((a + c) mod m, (b + d) mod m))"

(* The ring (ℤ/mℤ)[√3]: the carrier consists of the pairs with both components < m. *)
definition lucas_lehmer_ring_mod :: "nat ⇒ (nat × nat) ring" where
  "lucas_lehmer_ring_mod m =
     ⦇carrier = {..<m} × {..<m},
      monoid.mult = lucas_lehmer_mult m,
      one = (1, 0),
      ring.zero = (0, 0),
      add = lucas_lehmer_add m⦈"

(* For m > 0, addition always yields reduced pairs (both components < m). *)
lemma lucas_lehmer_add_in_carrier: "m > 0 ⟹ lucas_lehmer_add m x y ∈ {..<m} × {..<m}"
  by (auto simp: lucas_lehmer_add_def split: prod.splits)

(* For m > 0, multiplication always yields reduced pairs (both components < m). *)
lemma lucas_lehmer_mult_in_carrier: "m > 0 ⟹ lucas_lehmer_mult m x y ∈ {..<m} × {..<m}"
  by (auto simp: lucas_lehmer_mult_def split: prod.splits)

(* Each component of lucas_lehmer_add agrees with plain addition modulo m. *)
lemma lucas_lehmer_add_cong:
  "[fst (lucas_lehmer_add m x y) = fst x + fst y] (mod m)"
  "[snd (lucas_lehmer_add m x y) = snd x + snd y] (mod m)"
  by (simp_all add: lucas_lehmer_add_def cong_def case_prod_unfold)

(* Each component of lucas_lehmer_mult agrees, modulo m, with the unreduced
   product rule (ac + 3bd, ad + bc). *)
lemma lucas_lehmer_mult_cong:
  "[fst (lucas_lehmer_mult m x y) = fst x * fst y + 3 * snd x * snd y] (mod m)"
  "[snd (lucas_lehmer_mult m x y) = fst x * snd y + snd x * fst y] (mod m)"
  by (simp_all add: lucas_lehmer_mult_def cong_def case_prod_unfold)

(* (0, 0) is a two-sided additive identity on reduced pairs (both components < m). *)
lemma lucas_lehmer_add_neutral [simp]:
  assumes "fst x < m" "snd x < m"
  shows   "lucas_lehmer_add m (0, 0) x = x"
    and   "lucas_lehmer_add m x (0, 0) = x"
  using assms by (auto simp: lucas_lehmer_add_def case_prod_unfold)

(* (Suc 0, 0), i.e. (1, 0), is a two-sided multiplicative identity on reduced pairs
   (both components < m). *)
lemma lucas_lehmer_mult_neutral [simp]:
  assumes "fst x < m" "snd x < m"
  shows   "lucas_lehmer_mult m (Suc 0, 0) x = x"
    and   "lucas_lehmer_mult m x (Suc 0, 0) = x"
  using assms by (auto simp: lucas_lehmer_mult_def case_prod_unfold)

(* Addition modulo m is commutative (no reducedness assumption needed). *)
lemma lucas_lehmer_add_commute: "lucas_lehmer_add m x y = lucas_lehmer_add m y x"
  by (simp add: lucas_lehmer_add_def algebra_simps case_prod_unfold)

(* Multiplication modulo m is commutative (no reducedness assumption needed). *)
lemma lucas_lehmer_mult_commute: "lucas_lehmer_mult m x y = lucas_lehmer_mult m y x"
  by (simp add: lucas_lehmer_mult_def algebra_simps case_prod_unfold)

(* Associativity of addition modulo m. Each component is shown congruent to the same
   unreduced sum; both sides are < m, so the congruence becomes an equality. *)
lemma lucas_lehmer_add_assoc:
  assumes m: "m > 0"
  shows   "lucas_lehmer_add m x (lucas_lehmer_add m y z) =
           lucas_lehmer_add m (lucas_lehmer_add m x y) z"
proof (rule prod_eqI)
  let ?add = "lucas_lehmer_add m"
  have "[fst (?add x (?add y z)) = fst x + (fst y + fst z)] (mod m)"
    by (rule lucas_lehmer_add_cong[THEN cong_trans] cong_add cong_mult cong_refl)+
  also have "fst x + (fst y + fst z) = (fst x + fst y) + fst z"
    by (simp add: add_ac)
  also have "[… = fst (?add (?add x y) z)] (mod m)"
    by (rule cong_sym, (rule lucas_lehmer_add_cong[THEN cong_trans] cong_add cong_mult cong_refl)+)
  finally show "fst (?add x (?add y z)) = fst (?add (?add x y) z)"
    by (rule cong_less_modulus_unique_nat)
       (use m in ‹auto simp: lucas_lehmer_add_def case_prod_unfold›)

  have "[snd (?add x (?add y z)) = snd x + (snd y + snd z)] (mod m)"
    by (rule lucas_lehmer_add_cong[THEN cong_trans] cong_add cong_mult cong_refl)+
  also have "snd x + (snd y + snd z) = (snd x + snd y) + snd z"
    by (simp add: add_ac)
  also have "[… = snd (?add (?add x y) z)] (mod m)"
    by (rule cong_sym, (rule lucas_lehmer_add_cong[THEN cong_trans] cong_add cong_mult cong_refl)+)
  finally show "snd (?add x (?add y z)) = snd (?add (?add x y) z)"
    by (rule cong_less_modulus_unique_nat)
       (use m in ‹auto simp: lucas_lehmer_add_def case_prod_unfold›)
qed

(* Associativity of multiplication modulo m. Both sides are congruent to the same
   unreduced expression and are < m, so they are equal. *)
lemma lucas_lehmer_mult_assoc:
  assumes m: "m > 0"
  shows   "lucas_lehmer_mult m x (lucas_lehmer_mult m y z) =
           lucas_lehmer_mult m (lucas_lehmer_mult m x y) z"
proof (rule prod_eqI)
  let ?mul = "lucas_lehmer_mult m"
  have "[fst (?mul x (?mul y z)) = fst x * (fst y * fst z + 3 * snd y * snd z) +
           3 * snd x * (fst y * snd z + snd y * fst z)] (mod m)"
    by (rule lucas_lehmer_mult_cong[THEN cong_trans] cong_add cong_mult cong_refl)+
  also have "fst x * (fst y * fst z + 3 * snd y * snd z) +
               3 * snd x * (fst y * snd z + snd y * fst z) =
             (fst x * fst y + 3 * snd x * snd y) * fst z +
               3 * (fst x * snd y + snd x * fst y) * snd z"
    by (simp add: algebra_simps)
  also have "[… = fst (?mul (?mul x y) z)] (mod m)"
    by (rule cong_sym, (rule lucas_lehmer_mult_cong[THEN cong_trans] cong_add cong_mult cong_refl)+)
  finally show "fst (?mul x (?mul y z)) = fst (?mul (?mul x y) z)"
    by (rule cong_less_modulus_unique_nat)
       (use m in ‹auto simp: lucas_lehmer_mult_def case_prod_unfold›)

  have "[snd (?mul x (?mul y z)) = fst x * (fst y * snd z + snd y * fst z) +
     snd x * (fst y * fst z + 3 * snd y * snd z)] (mod m)"
    by (rule lucas_lehmer_mult_cong[THEN cong_trans] cong_add cong_mult cong_refl)+
  also have "fst x * (fst y * snd z + snd y * fst z) + snd x * (fst y * fst z + 3 * snd y * snd z) =
             (fst x * fst y + 3 * snd x * snd y) * snd z + (fst x * snd y + snd x * fst y) * fst z"
    by (simp add: algebra_simps)
  also have "[… = snd (?mul (?mul x y) z)] (mod m)"
    by (rule cong_sym, (rule lucas_lehmer_mult_cong[THEN cong_trans] cong_add cong_mult cong_refl)+)
  finally show "snd (?mul x (?mul y z)) = snd (?mul (?mul x y) z)"
    by (rule cong_less_modulus_unique_nat)
       (use m in ‹auto simp: lucas_lehmer_mult_def case_prod_unfold›)
qed

(* Right distributivity modulo m, by the same congruence-then-uniqueness pattern
   as the associativity lemmas. *)
lemma lucas_lehmer_distrib_right:
  assumes m: "m > 1"
  shows "lucas_lehmer_mult m (lucas_lehmer_add m x y) z =
         lucas_lehmer_add m (lucas_lehmer_mult m x z) (lucas_lehmer_mult m y z)"
proof (rule prod_eqI)
  let ?mul = "lucas_lehmer_mult m" and ?add = "lucas_lehmer_add m"
  have "[fst (?mul (?add x y) z) = (fst x + fst y) * fst z + 3 * (snd x + snd y) * snd z] (mod m)"
    by (rule lucas_lehmer_mult_cong[THEN cong_trans] lucas_lehmer_add_cong[THEN cong_trans]
             cong_add cong_mult cong_refl)+
  also have "(fst x + fst y) * fst z + 3 * (snd x + snd y) * snd z =
               (fst x * fst z + 3 * snd x * snd z) + (fst y * fst z + 3 * snd y * snd z)"
    by (simp add: algebra_simps)
  also have "[… = fst (?add (?mul x z) (?mul y z))] (mod m)"
    by (rule cong_sym, (rule lucas_lehmer_mult_cong[THEN cong_trans]
          lucas_lehmer_add_cong[THEN cong_trans] cong_add cong_mult cong_refl)+)
  finally show "fst (?mul (?add x y) z) = fst (?add (?mul x z) (?mul y z))"
    by (rule cong_less_modulus_unique_nat)
       (use m in ‹auto simp: lucas_lehmer_add_def lucas_lehmer_mult_def case_prod_unfold›)

  have "[snd (?mul (?add x y) z) = (fst x + fst y) * snd z + (snd x + snd y) * fst z] (mod m)"
    by (rule lucas_lehmer_mult_cong[THEN cong_trans] lucas_lehmer_add_cong[THEN cong_trans]
             cong_add cong_mult cong_refl)+
  also have "(fst x + fst y) * snd z + (snd x + snd y) * fst z =
               (fst x * snd z + snd x * fst z) + (fst y * snd z + snd y * fst z)"
    by (simp add: algebra_simps)
  also have "[… = snd (?add (?mul x z) (?mul y z))] (mod m)"
    by (rule cong_sym, (rule lucas_lehmer_mult_cong[THEN cong_trans]
          lucas_lehmer_add_cong[THEN cong_trans] cong_add cong_mult cong_refl)+)
  finally show "snd (?mul (?add x y) z) = snd (?add (?mul x z) (?mul y z))"
    by (rule cong_less_modulus_unique_nat)
       (use m in ‹auto simp: lucas_lehmer_add_def lucas_lehmer_mult_def case_prod_unfold›)
qed

(* Left distributivity, obtained from right distributivity and commutativity
   of multiplication. *)
lemma lucas_lehmer_distrib_left:
  assumes "m > 1"
  shows "lucas_lehmer_mult m z (lucas_lehmer_add m x y) =
         lucas_lehmer_add m (lucas_lehmer_mult m z x) (lucas_lehmer_mult m z y)"
  using lucas_lehmer_distrib_right[of m x y z] assms
  by (simp add: lucas_lehmer_mult_commute)

(* For m > 1, (ℤ/mℤ)[√3] is a commutative ring. *)
lemma cring_lucas_lehmer_ring_mod [intro]:
  assumes "m > 1"
  shows   "cring (lucas_lehmer_ring_mod m)"
proof unfold_locales
  (* Componentwise additive inverse in {..<m}: 0 is its own inverse, otherwise m - x. *)
  let ?neg = "λx. if x = 0 then 0 else m - x"
  have "∃x∈carrier (lucas_lehmer_ring_mod m).
           x ⊕⇘lucas_lehmer_ring_mod m⇙ (a, b) = 𝟬⇘lucas_lehmer_ring_mod m⇙ ∧
           (a, b) ⊕⇘lucas_lehmer_ring_mod m⇙ x = 𝟬⇘lucas_lehmer_ring_mod m⇙"
    if "(a, b) ∈ carrier (lucas_lehmer_ring_mod m)" for a b
    using that assms
    by (intro bexI[of _ "(?neg a, ?neg b)"])
       (auto simp: lucas_lehmer_ring_mod_def lucas_lehmer_add_def)
  thus "carrier (add_monoid (lucas_lehmer_ring_mod m)) ⊆ Units (add_monoid (lucas_lehmer_ring_mod m))"
    by (auto simp: Units_def)
qed (insert assms,
     auto simp: lucas_lehmer_ring_mod_def algebra_simps lucas_lehmer_mult_assoc
                lucas_lehmer_add_assoc lucas_lehmer_distrib_right lucas_lehmer_distrib_left
          intro: lucas_lehmer_mult_in_carrier lucas_lehmer_add_in_carrier
                 lucas_lehmer_add_commute lucas_lehmer_mult_commute)

text ‹
  Since $0$ is clearly not a unit in the ring and its carrier has size $m ^ 2$, the 
  number of units is strictly less than $m ^ 2$.
›
(* (ℤ/mℤ)[√3] has fewer than m ^ 2 units: they all lie in the carrier minus (0, 0). *)
lemma card_lucas_lehmer_Units:
  assumes "m > 1"
  shows   "card (Units (lucas_lehmer_ring_mod m)) < m ^ 2"
proof -
  interpret cring "lucas_lehmer_ring_mod m"
    using assms by auto
  have "m ^ 2 > 0"
    using assms by auto
  from assms have "card (Units (lucas_lehmer_ring_mod m)) ≤ card ({..<m} × {..<m} - {(0, 0)})"
    by (intro card_mono) (auto simp: Units_def lucas_lehmer_ring_mod_def lucas_lehmer_mult_def)
  also have "… = m ^ 2 - 1"
    using assms by (subst card_Diff_subset) (auto simp: power2_eq_square)
  finally show ?thesis using ‹m ^ 2 > 0› by linarith
qed

text ‹
  Consider now the case of a prime modulus $p$: Since $\mathbb{Z}/p\mathbb{Z} = \text{GF}(p)$
  is a field, any nonzero element of $\mathbb{Z}/p\mathbb{Z}$ is a unit in
  $(\mathbb{Z}/p\mathbb{Z})[\sqrt{3}]$.
›
(* For prime p and 0 < x < p, the element (x, 0) is a unit of (ℤ/pℤ)[√3].
   Its inverse is (x ^ (p - 2) mod p, 0), by Fermat's little theorem. *)
lemma int_in_Units_lucas_lehmer_ring_mod:
  assumes "prime p"
  assumes "x > 0" "x < p"
  shows   "(x, 0) ∈ Units (lucas_lehmer_ring_mod p)"
proof -
  define R where "R = lucas_lehmer_ring_mod p"
  have "[x * (x ^ (p - 2) mod p) = x * x ^ (p - 2)] (mod p)"
    by (intro cong_mult) (auto simp: cong_def)
  also have "x * x ^ (p - 2) = x ^ (Suc (p - 2))"
    by (simp add: mult_ac)
  also have "Suc (p - 2) = p - 1"
    using prime_gt_1_nat[of p] assms by simp
  also have "[x ^ (p - 1) = 1] (mod p)"
    using assms by (intro fermat_theorem) (auto dest: dvd_imp_le)
  finally have "(x, 0) ⊗⇘R⇙ (x ^ (p - 2) mod p, 0) = 𝟭⇘R⇙"
               "(x ^ (p - 2) mod p, 0) ⊗⇘R⇙ (x, 0) = 𝟭⇘R⇙"
               "(x ^ (p - 2) mod p, 0) ∈ carrier R"
    using prime_gt_1_nat[of p] assms
    by (auto simp: lucas_lehmer_mult_def cong_def lucas_lehmer_ring_mod_def mult_ac R_def)
  moreover from assms have "(x, 0) ∈ carrier R"
    by (auto simp: R_def lucas_lehmer_ring_mod_def)
  ultimately show ?thesis using assms
    by (auto simp: Units_def R_def)
qed


subsection ‹$\mathbb{Z}[\sqrt{3}]$ as a subring of $\mathbb{R}$›

text ‹
  We now define the homomorphism from $\mathbb{Z}[\sqrt{3}]$ into the reals:
›
(* Embedding of ℤ[√3] into the reals: (a, b) ↦ a + b√3. *)
definition lucas_lehmer_to_real :: "int × int ⇒ real" where
  "lucas_lehmer_to_real = (λ(a,b). real_of_int a + real_of_int b * sqrt 3)"

context
begin

interpretation cring lucas_lehmer_ring ..

(* Additive inverse in ℤ[√3] is componentwise negation. *)
lemma minus_lucas_lehmer_ring: "⊖⇘lucas_lehmer_ring⇙ x = (case x of (a, b) ⇒ (-a, -b))"
  by (rule sym, rule sum_zero_eq_neg)
     (auto simp: case_prod_unfold lucas_lehmer_ring_def lucas_lehmer_add'_def)

(* lucas_lehmer_to_real commutes with the ring operations and constants. *)
lemma lucas_lehmer_to_real_simps1:
      "lucas_lehmer_to_real (a, b) = of_int a + of_int b * sqrt 3"
      "lucas_lehmer_to_real (x ⊕⇘lucas_lehmer_ring⇙ y) =
       lucas_lehmer_to_real x + lucas_lehmer_to_real y"
      "lucas_lehmer_to_real (x ⊗⇘lucas_lehmer_ring⇙ y) =
       lucas_lehmer_to_real x * lucas_lehmer_to_real y"
      "lucas_lehmer_to_real (⊖⇘lucas_lehmer_ring⇙ x) = -lucas_lehmer_to_real x"
      "lucas_lehmer_to_real (𝟬⇘lucas_lehmer_ring⇙) = 0"
      "lucas_lehmer_to_real (𝟭⇘lucas_lehmer_ring⇙) = 1"
  using minus_lucas_lehmer_ring
  by (simp_all add: lucas_lehmer_to_real_def lucas_lehmer_add'_def lucas_lehmer_mult'_def
                    case_prod_unfold algebra_simps lucas_lehmer_ring_def)

(* Natural-number multiples in the ring map to real scalar multiples. *)
lemma lucas_lehmer_to_add_pow_nat:
  "lucas_lehmer_to_real ([n] ⋅⇘lucas_lehmer_ring⇙ x) = of_nat n * lucas_lehmer_to_real x"
  by (induction n) (auto simp: lucas_lehmer_to_real_simps1 algebra_simps)

(* Integer multiples: reduce to the nat case, splitting on the sign of n. *)
lemma lucas_lehmer_to_add_pow_int:
  "lucas_lehmer_to_real ([n] ⋅⇘lucas_lehmer_ring⇙ x) = of_int n * lucas_lehmer_to_real x"
proof (cases "n ≥ 0")
  case True
  hence "lucas_lehmer_to_real ([n] ⋅⇘lucas_lehmer_ring⇙ x) =
         lucas_lehmer_to_real ([int (nat n)] ⋅⇘lucas_lehmer_ring⇙ x)"
    by simp
  also have "… = lucas_lehmer_to_real ([nat n] ⋅⇘lucas_lehmer_ring⇙ x)"
    by (simp add: add_pow_int_ge)
  also have "… = of_int n * lucas_lehmer_to_real x" using True
    by (simp add: lucas_lehmer_to_add_pow_nat algebra_simps)
  finally show ?thesis .
next
  case False
  hence "lucas_lehmer_to_real ([n] ⋅⇘lucas_lehmer_ring⇙ x) =
         lucas_lehmer_to_real (add_pow lucas_lehmer_ring (-int (nat (-n))) x)"
    by simp
  also have "add_pow lucas_lehmer_ring (-int (nat (-n))) x =
              ⊖⇘lucas_lehmer_ring⇙ (add_pow lucas_lehmer_ring (nat (-n)) x)"
    using False by (subst add.int_pow_neg_int) (auto simp: lucas_lehmer_ring_def)
  also have "lucas_lehmer_to_real … = of_int n * lucas_lehmer_to_real x" using False
    by (simp add: lucas_lehmer_to_add_pow_nat lucas_lehmer_to_real_simps1 algebra_simps)
  finally show ?thesis .
qed

(* Ring powers map to real powers. *)
lemma lucas_lehmer_to_real_power:
  "lucas_lehmer_to_real (x [^]⇘lucas_lehmer_ring⇙ (n :: nat)) = lucas_lehmer_to_real x ^ n"
  by (induction n) (auto simp: lucas_lehmer_to_real_simps1)

lemmas lucas_lehmer_to_real_simps =
  lucas_lehmer_to_real_simps1 lucas_lehmer_to_real_power
  lucas_lehmer_to_add_pow_nat lucas_lehmer_to_add_pow_int

end

(* The embedding into ℝ is injective. If b ≠ d, then √3 = (c - a) / (b - d) would be
   rational, contradicting the irrationality of √3. *)
lemma lucas_lehmer_to_real_inj: "inj lucas_lehmer_to_real"
proof (rule injI, clarify)
  fix a b c d :: int
  assume eq: "lucas_lehmer_to_real (a, b) = lucas_lehmer_to_real (c, d)"
  have "b = d"
  proof (rule ccontr)
    assume "b ≠ d"
    hence "sqrt 3 = (c - a) / (b - d)"
      using eq by (simp add: lucas_lehmer_to_real_def field_simps)
    also have "… ∈ ℚ" by auto
    finally have "sqrt 3 ∈ ℚ" .
    moreover have "sqrt 3 ∉ ℚ"
      using is_nth_power_prime_power_nat_iff[of 3 2 1] irrat_sqrt_nonsquare[of 3] by auto
    ultimately show False by contradiction
  qed
  moreover from this and eq have "a = c"
    by (auto simp: lucas_lehmer_to_real_def)
  ultimately show "a = c ∧ b = d" by blast
qed


subsection ‹The canonical homomorphism $\mathbb{Z}[\sqrt 3] \to (\mathbb{Z}/m\mathbb{Z})[\sqrt 3]$›

text ‹
  Next, we show that reduction modulo $m$ is indeed a homomorphism.
›
(* Reduction ℤ[√3] → (ℤ/mℤ)[√3]: reduce each integer component mod m and convert to nat
   (int mod by a positive modulus is non-negative). *)
definition lucas_lehmer_hom :: "nat ⇒ (int × int) ⇒ (nat × nat)" where
  "lucas_lehmer_hom m = (λ(x,y). (nat (x mod m), nat (y mod m)))"

(* Componentwise-congruent pairs have the same image under the reduction map. *)
lemma lucas_lehmer_hom_cong:
  "[fst x = fst y] (mod int m) ⟹ [snd x = snd y] (mod int m) ⟹
   lucas_lehmer_hom m x = lucas_lehmer_hom m y"
  by (auto simp: lucas_lehmer_hom_def cong_def case_prod_unfold)

(* Variant of lucas_lehmer_hom_cong stated on explicit pairs. *)
lemma lucas_lehmer_hom_cong':
  "[a = b] (mod int m) ⟹ [c = d] (mod int m) ⟹
   lucas_lehmer_hom m (a, c) = lucas_lehmer_hom m (b, d)"
  by (auto simp: lucas_lehmer_hom_def cong_def)

context
  fixes m :: nat
  assumes m: "m > 1"
begin

(* The reduction map lands in the carrier of (ℤ/mℤ)[√3]. *)
lemma lucas_lehmer_hom_in_carrier: "lucas_lehmer_hom m x ∈ {..<m} × {..<m}"
  using m nat_less_iff by (auto simp: lucas_lehmer_hom_def case_prod_unfold)

(* The reduction map is additive. *)
lemma lucas_lehmer_hom_add:
  "lucas_lehmer_hom m (lucas_lehmer_add' x y) =
   lucas_lehmer_add m (lucas_lehmer_hom m x) (lucas_lehmer_hom m y)"
proof (rule prod_eqI)
  let ?add1 = "lucas_lehmer_add'" and ?add2 = "lucas_lehmer_add m"
  let ?φ = "lucas_lehmer_hom m"
  have "fst (?φ (?add1 x y)) = nat ((fst x + fst y) mod int m)"
    by (simp add: lucas_lehmer_hom_def lucas_lehmer_add'_def case_prod_unfold)
  also have "(fst x + fst y) mod int m = ((fst x mod m) + (fst y mod m)) mod int m"
    by (simp add: mod_add_eq)
  also have "nat … = (nat (fst x mod int m) + nat (fst y mod int m)) mod m"
    using m nat_add_distrib nat_mod_distrib by auto
  also have "… = fst (?add2 (?φ x) (?φ y))"
    by (auto simp: lucas_lehmer_hom_def lucas_lehmer_add_def case_prod_unfold)
  finally show "fst (?φ (?add1 x y)) = fst (?add2 (?φ x) (?φ y))" .

  have "snd (?φ (?add1 x y)) = nat ((snd x + snd y) mod int m)"
    by (simp add: lucas_lehmer_hom_def lucas_lehmer_add'_def case_prod_unfold)
  also have "(snd x + snd y) mod int m = ((snd x mod m) + (snd y mod m)) mod int m"
    by (simp add: mod_add_eq)
  also have "nat … = (nat (snd x mod int m) + nat (snd y mod int m)) mod m"
    using m nat_add_distrib nat_mod_distrib by auto
  also have "… = snd (?add2 (?φ x) (?φ y))"
    by (auto simp: lucas_lehmer_hom_def lucas_lehmer_add_def case_prod_unfold)
  finally show "snd (?φ (?add1 x y)) = snd (?add2 (?φ x) (?φ y))" .
qed

(* The reduction map is multiplicative. Reduce the factors inside the mod, then move
   the whole expression from int to nat. *)
lemma lucas_lehmer_hom_mult:
  "lucas_lehmer_hom m (lucas_lehmer_mult' x y) =
   lucas_lehmer_mult m (lucas_lehmer_hom m x) (lucas_lehmer_hom m y)"
proof (rule prod_eqI)
  let ?mul1 = "lucas_lehmer_mult'" and ?mul2 = "lucas_lehmer_mult m"
  let ?φ = "lucas_lehmer_hom m"
  have "fst (?φ (?mul1 x y)) = nat ((fst x * fst y + 3 * snd x * snd y) mod int m)"
    by (simp add: lucas_lehmer_hom_def lucas_lehmer_mult'_def case_prod_unfold)
  also have "(fst x * fst y + 3 * snd x * snd y) mod int m =
               ((fst x mod int m) * (fst y mod int m) +
                3 * (snd x mod int m) * (snd y mod int m)) mod m"
    by (intro congD cong_mult cong_add cong_refl) (auto simp: cong_def)
  also have "… = int (nat (((fst x mod int m) * (fst y mod int m) +
                3 * (snd x mod int m) * (snd y mod int m)) mod m))"
    using m by (subst of_nat_nat) auto
  also have "… = int (nat (fst x mod int m) * nat (fst y mod int m) +
                3 * (nat (snd x mod int m)) * nat (snd y mod int m)) mod m"
    using m by simp
  also have "nat … = (nat (fst x mod int m) * nat (fst y mod int m) +
           3 * nat (snd x mod int m) * nat (snd y mod int m)) mod m"
    using m by (metis nat_int zmod_int)
  also have "… = fst (?mul2 (?φ x) (?φ y))"
    by (simp add: lucas_lehmer_hom_def lucas_lehmer_mult_def case_prod_unfold)
  finally show "fst (?φ (?mul1 x y)) = fst (?mul2 (?φ x) (?φ y))" .

  have "snd (?φ (?mul1 x y)) = nat ((fst x * snd y + snd x * fst y) mod int m)"
    by (simp add: lucas_lehmer_hom_def lucas_lehmer_mult'_def case_prod_unfold)
  also have "(fst x * snd y + snd x * fst y) mod int m =
               ((fst x mod int m) * (snd y mod int m) +
                (snd x mod int m) * (fst y mod int m)) mod m"
    by (intro congD cong_mult cong_add cong_refl) (auto simp: cong_def)
  also have "… = int (nat (((fst x mod int m) * (snd y mod int m) +
                (snd x mod int m) * (fst y mod int m)) mod m))"
    using m by (subst of_nat_nat) auto
  also have "… = int (nat (fst x mod int m) * nat (snd y mod int m) +
                    (nat (snd x mod int m)) * nat (fst y mod int m)) mod m"
    using m by simp
  also have "nat … = (nat (fst x mod int m) * nat (snd y mod int m) +
               nat (snd x mod int m) * nat (fst y mod int m)) mod m"
    using m by (metis nat_int zmod_int)
  also have "… = snd (?mul2 (?φ x) (?φ y))"
    by (simp add: lucas_lehmer_hom_def lucas_lehmer_mult_def case_prod_unfold)
  finally show "snd (?φ (?mul1 x y)) = snd (?mul2 (?φ x) (?φ y))" .
qed

(* The reduction map sends the unit (1, 0) to (1, 0). *)
lemma lucas_lehmer_hom_1 [simp]: "lucas_lehmer_hom m (1, 0) = (1, 0)"
  using m by (simp add: lucas_lehmer_hom_def)

(* Summary: the reduction map is a ring homomorphism ℤ[√3] → (ℤ/mℤ)[√3]. *)
lemma ring_hom_lucas_lehmer_hom:
  "lucas_lehmer_hom m ∈ ring_hom lucas_lehmer_ring (lucas_lehmer_ring_mod m)"
proof -
  interpret R: cring lucas_lehmer_ring ..
  from m interpret S: cring "lucas_lehmer_ring_mod m" ..
  show ?thesis
    unfolding ring_hom_def using lucas_lehmer_hom_in_carrier m
    by (auto simp: lucas_lehmer_ring_mod_def lucas_lehmer_hom_add
                   lucas_lehmer_ring_def lucas_lehmer_hom_mult)
qed

end


subsection ‹Correctness of the Lucas--Lehmer test›

text ‹
  In this section, we will prove that the Lucas--Lehmer test is both a necessary and sufficient
  condition for the primality of a Mersenne number of the form $2^p - 1$ for an odd prime $p$.
  The proof that shall be given here is rather explicit and heavily draws from the Wikipedia
  article on the Lucas--Lehmer test~\<^cite>‹"wiki:lucas_lehmer"›.

  A shorter and more high-level proof of a more general statement can be obtained using more
  theory of finite fields (in particular of the field $\text{GF}(q^2)$); cf.\ e.\,g.\
  Rödseth~cite"roedseth94".
›

(* The Lucas--Lehmer test for exponent p holds when p > 2 and 2^p - 1 divides the term
   with index p - 2 of gen_lucas_lehmer_sequence started at 4.  An executable version is
   given by lucas_lehmer_test_code_arithmetic below. *)
definition lucas_lehmer_test where
  "lucas_lehmer_test p = (p > 2 
     (2 ^ p - 1) dvd gen_lucas_lehmer_sequence 4 (p - 2))"

text ‹
  We can now prove that if $p$ is an odd prime and the Mersenne number $2^p - 1$ passes the
  Lucas--Lehmer test, then $2^p - 1$ is prime. We follow the simple argument given by
  Bruce~cite"bruce93", which is also given on Wikipedia~cite"wiki:lucas_lehmer".
›
(* Sufficiency of the test, proved by contradiction (Bruce's argument).
   1. Assume 2^p - 1 is not prime and let q be its smallest prime factor, so q^2 <= 2^p - 1.
   2. Work in S = lucas_lehmer_ring_mod q.  There, the divisibility hypothesis forces the
      image of (2, 1) to have multiplicative order exactly 2^p in the unit group of S.
   3. The unit group has fewer than q^2 elements, so 2^p < q^2 <= 2^p - 1, which is absurd. *)
theorem lucas_lehmer_sufficient:
  assumes "prime p" "odd p"
  assumes "(2 ^ p - 1) dvd gen_lucas_lehmer_sequence 4 (p - 2)"
  shows   "prime (2 ^ p - 1 :: nat)"
proof (rule ccontr)
  assume not_prime: "¬prime (2 ^ p - 1 :: nat)"
  from assms obtain k :: int where k: "gen_lucas_lehmer_sequence 4 (p - 2) = k * (2 ^ p - 1)"
    by (elim dvdE) (auto simp: mult_ac)
  from assms have "p > 2"
    using odd_prime_gt_2_nat by blast
  from p > 2 have "2 ^ p  (2 ^ 3 :: nat)" by (intro power_increasing) auto
  hence "2 ^ p  (8 :: nat)" by simp

  (* q is the smallest prime factor of 2^p - 1; the set of prime factors is non-empty
     because 2^p >= 8. *)
  define q :: nat where "q = Min (prime_factors (2 ^ p - 1))"
  have "q  prime_factors (2 ^ p - 1)" using 2 ^ p  8
    unfolding q_def by (intro Min_in) (auto simp: prime_factorization_empty_iff)
  hence q: "prime q" "q dvd (2 ^ p - 1 :: nat)"
    by (auto simp: in_prime_factors_iff)
  have q_minimal: "q  q'" if "q'  prime_factors (2 ^ p - 1)" for q'
    unfolding q_def by (rule Min_le) (use that in auto)

  (* Because 2^p - 1 is not prime, its factorisation is not just {#q#}.  So the cofactor k
     has a prime factor q' >= q, which gives q^2 <= q * q' <= 2^p - 1. *)
  have "2 ^ p - 1  q ^ 2"
  proof -
    from q obtain k where k: "2 ^ p - 1 = q * k" by auto
    have "prime_factorization (2 ^ p - 1 :: nat)  {#q#}"
    proof
      assume *: "prime_factorization (2 ^ p - 1 :: nat) = {#q#}"
      have "2 ^ p - 1 = prod_mset (prime_factorization (2 ^ p - 1 :: nat))"
        using 2 ^ p  8 by (subst prod_mset_prime_factorization_nat) auto
      also have " = q" by (subst *) auto
      finally show False using not_prime q by simp
    qed
    hence "prime_factorization k  {#}" using q k 2 ^ p  8
      by (subst (asm) k, subst (asm) prime_factorization_mult)
         (auto intro!: Nat.gr0I simp: prime_factorization_prime)
    hence "k  1" by (auto simp: prime_factorization_empty_iff)
    then obtain q' where q': "prime q'" "q' dvd k"
      using prime_factor_nat by blast
    from q' k 2 ^ p  8 have "q  q'"
      by (intro q_minimal) (auto simp: in_prime_factors_iff intro!: Nat.gr0I)
    hence "q ^ 2  q * q'"
      unfolding power2_eq_square by (intro mult_mono) auto
    also have "q * q'  2 ^ p - 1"
      using q q' k 2 ^ p  8 by (intro dvd_imp_le) (auto intro!: Nat.gr0I)
    finally show "2 ^ p - 1  q ^ 2" .
  qed

  (* q divides the odd number 2^p - 1, so q <> 2; being prime, q > 2. *)
  have "q  2" using q p > 2 by auto
  moreover from q have "q  0" "q  1" by auto
  ultimately have "q > 2" by auto

  (* Notation: R is lucas_lehmer_ring, S is lucas_lehmer_ring_mod q, S' is the unit group
     of S, and φ = lucas_lehmer_hom q is the ring homomorphism from R to S. *)
  write lucas_lehmer_ring ("R")
  define S where "S = lucas_lehmer_ring_mod q"
  define S' where "S' = units_of S"
  define φ where "φ = lucas_lehmer_hom q"

  interpret R: cring R ..
  interpret S: cring S
    unfolding S_def by (rule cring_lucas_lehmer_ring_mod) (use q > 2 in auto)
  interpret S': comm_group S'
    unfolding S'_def by (rule S.units_comm_group)
  have "φ  ring_hom R S"
    unfolding φ_def S_def by (rule ring_hom_lucas_lehmer_hom) (use q > 2 in auto)
  interpret φ: ring_hom_cring R S φ
    by standard fact

  (* Over the reals, the closed form of the sequence and the witness k give
       (2 + sqrt 3)^(2^(p-1)) = k (2^p - 1) (2 + sqrt 3)^(2^(p-2)) - 1.
     This identity is moved into R using the injectivity of lucas_lehmer_to_real. *)
  have "(2 + sqrt 3) ^ (2 ^ (p - 2)) + (2 - sqrt 3) ^ (2 ^ (p - 2)) =
          real_of_int (gen_lucas_lehmer_sequence 4 (p - 2))"
    unfolding gen_lucas_lehmer_sequence_4_closed_form1 ..
  also have " = real_of_int k * (2 ^ p - 1)"
    by (simp add: k)
  finally have *: "(2 + sqrt 3) ^ (2 ^ (p - 2)) =
                     real_of_int k * (2 ^ p - 1) - (2 - sqrt 3) ^ (2 ^ (p - 2))"
    by (simp add: algebra_simps)
  have "((2 + sqrt 3) ^ (2 ^ (p - 2))) ^ 2 =
             real_of_int k * (2 ^ p - 1) * (2 + sqrt 3) ^ (2 ^ (p - 2)) -
             (2 - sqrt 3) ^ (2 ^ (p - 2)) * (2 + sqrt 3) ^ (2 ^ (p - 2))"
    unfolding power2_eq_square by (subst *) (simp add: algebra_simps)
  also have "((2 + sqrt 3) ^ (2 ^ (p - 2))) ^ 2 = (2 + sqrt 3) ^ (2 * 2 ^ (p - 2))"
    by (simp flip: power_mult add: mult_ac)
  also have "2 * 2 ^ (p - 2) = 2 ^ (Suc (p - 2))"
    by simp
  also from p > 2 have "Suc (p - 2) = p - 1"
    by linarith
  also have "(2 - sqrt 3) ^ (2 ^ (p - 2)) * (2 + sqrt 3) ^ (2 ^ (p - 2)) = 1"
    by (subst power_mult_distrib [symmetric]) (auto simp: algebra_simps)
  finally have "(2 + sqrt 3) ^ (2 ^ (p - 1)) =
                  real_of_int k * (2 ^ p - 1) * (2 + sqrt 3) ^ (2 ^ (p - 2)) - 1" .

  also have "(2 + sqrt 3) ^ (2 ^ (p - 1)) =
               lucas_lehmer_to_real ((2, 1) [^]R(2 ^ (p - 1) :: nat))"
    by (simp add: lucas_lehmer_to_real_simps)
  also have "real_of_int k * (2 ^ p - 1) * (2 + sqrt 3) ^ (2 ^ (p - 2)) - 1 =
               lucas_lehmer_to_real ((k * (2 ^ p - 1), 0) R(2, 1) [^]R(2 ^ (p - 2) :: nat) RR𝟭R)"
    by (simp add: lucas_lehmer_to_real_simps)
  finally have "((2, 1) [^]R(2 ^ (p - 1) :: nat)) =
                ((k * (2 ^ p - 1), 0) R(2, 1) [^]R(2 ^ (p - 2) :: nat) RR𝟭R)"
    by (rule injD[OF lucas_lehmer_to_real_inj])

  (* Apply φ.  The summand with factor k (2^p - 1) maps to zero because q divides 2^p - 1.
     This leaves φ (2, 1)^(2^(p-1)) = -1, and squaring gives φ (2, 1)^(2^p) = 1. *)
  hence "φ ((2, 1) [^]R(2 ^ (p - 1) :: nat)) =
           φ ((k * (2 ^ p - 1), 0) R(2, 1) [^]R(2 ^ (p - 2) :: nat) RR𝟭R)"
    by (simp only:)
  also have "φ ((2, 1) [^]R(2 ^ (p - 1) :: nat)) = φ (2, 1) [^]S(2 ^ (p - 1) :: nat)"
    by simp
  also {
    have "int q dvd int (2 ^ p - 1)"
      by (subst int_dvd_int_iff) (use q in auto)
    also have "int (2 ^ p - 1) = 2 ^ p - 1"
      by (simp add: of_nat_diff)
    finally have "φ (k * (2 ^ p - 1), 0) = 𝟬S⇙"
      by (simp add: φ_def lucas_lehmer_hom_def S_def lucas_lehmer_ring_mod_def)
  }
  hence "φ ((k * (2 ^ p - 1), 0) R(2, 1) [^]R(2 ^ (p - 2) :: nat) RR𝟭R) = 
           S𝟭S⇙"
    by simp
  finally have eq: "φ (2, 1) [^]S(2 ^ (p - 1) :: nat) = S𝟭S⇙" .

  have "φ (2, 1) [^]S(2 ^ p :: nat) = φ (2, 1) [^]S(2 ^ (p - 1) * 2 :: nat)"
    using p > 2 by (cases p) (auto simp: mult_ac)
  also have " = (φ (2, 1) [^]S(2 ^ (p - 1) :: nat)) [^]S(2 :: nat)"
    by (subst S.nat_pow_pow) auto
  also have " = 𝟭S⇙"
    by (subst eq) (auto simp: numeral_2_eq_2 S.l_minus)
  finally have eq': "φ (2, 1) [^]S(2 ^ p :: nat) = 𝟭S⇙" .

  (* Some power of φ (2, 1) equals 1, so φ (2, 1) is a unit.  Also -1 <> 1 in S, since
     1 + 1 = 0 would fail for q > 2. *)
  from eq' have unit: "φ (2, 1)  Units S"
    by (rule S.pow_nat_eq_1_imp_unit) auto

  have neg_one_not_one: "S𝟭S 𝟭S⇙"
  proof
    assume *: "S𝟭S= 𝟭S⇙"
    have "(S𝟭S) S𝟭S= 𝟬S⇙"
      by (rule S.l_neg) auto
    hence "𝟭SS𝟭S= 𝟬S⇙"
      by (simp only: *)
    thus False using q > 2
      by (auto simp: S_def lucas_lehmer_ring_mod_def lucas_lehmer_add_def)
  qed

  have fin: "finite (Units S)"
    by (rule finite_subset[of _ "carrier S"]) (auto simp: Units_def S_def lucas_lehmer_ring_mod_def)

  (* From eq and eq' (with -1 <> 1), φ (2, 1) has order exactly 2^p in S'.  Then
       2^p = card of the generated subgroup <= card of S' < q^2 <= 2^p - 1,
     which is a contradiction. *)
  have "group.ord S' (φ (2, 1)) = 2 ^ p"
    using p > 2 eq eq' unit neg_one_not_one
    by (intro S'.ord_eqI_prime_factors)
       (auto simp: prime_factors_power prime_factorization_prime
                   S'_def S.units_of_pow units_of_carrier units_of_one power_diff)
  hence "2 ^ p = group.ord S' (φ (2, 1))"
    by simp
  also have " = card (generate S' {φ (2, 1)})"
    using unit fin
    by (intro S'.generate_pow_card) (auto simp: S'_def units_of_carrier)
  also have "  card (carrier S')"
    using fin unit by (intro card_mono S'.generate_incl) (auto simp: S'_def units_of_carrier)
  also have " < q ^ 2"
    unfolding S'_def S_def using card_lucas_lehmer_Units[of q] q > 2
    by (auto simp: units_of_carrier)
  also note q ^ 2  2 ^ p - 1
  finally show False by simp
qed

text ‹
  Next, we show that any Mersenne prime passes the Lucas--Lehmer test. We again follow the
  rather explicit proof outlined on Wikipedia~cite"wiki:lucas_lehmer", which is a simplified
  (but less general and less abstract) version of the proof by Rödseth~cite"roedseth94".
›
(* Necessity of the test, proved inside the mersenne_prime locale (M is prime and
   M_def relates M to 2^p - 1).
   1. Work in S = lucas_lehmer_ring_mod M, where M * 1 = 0.
   2. With omega = (2, 1) and omega' = (2, -1), show that phi (omega^((M+1)/2)) = -1.
   3. Deduce that phi (omega^(2^(p-2)) + omega'^(2^(p-2))) = 0.
   4. By the closed form, that sum is (gen_lucas_lehmer_sequence 4 (p - 2), 0), so M
      divides the sequence term. *)
theorem (in mersenne_prime) lucas_lehmer_necessary:
  "(2 ^ p - 1) dvd gen_lucas_lehmer_sequence 4 (p - 2)"
proof -
  write lucas_lehmer_ring ("R")
  define S where "S = lucas_lehmer_ring_mod M"
  define S' where "S' = units_of S"
  define φ where "φ = lucas_lehmer_hom M"

  interpret R: cring R ..
  interpret S: cring S unfolding S_def
    by (rule cring_lucas_lehmer_ring_mod) (use M_gt_6 in auto)
  interpret S': comm_group S'
    unfolding S'_def by (rule S.units_comm_group)
  have "φ  ring_hom R S" unfolding φ_def S_def
    by (rule ring_hom_lucas_lehmer_hom) (use M_gt_6 in auto)
  interpret φ: ring_hom_cring R S φ
    by standard fact

  (* Powers of "integer" elements (n, 0) of R are again integer elements: (n^m, 0). *)
  have R_pow_int: "(n, 0) [^]Rm = (n ^ m, 0)" for n :: int and m :: nat
    by (induction m; simp; simp add: lucas_lehmer_ring_def lucas_lehmer_mult'_def)

  (* Adding 1 to itself M times gives 0 in S.  This is the fact needed for the
     Frobenius-style binomial expansion (S.binomial_finite_char) below. *)
  have "add_pow R n 𝟭R= (int n, 0)" for n
    by (induction n; simp; simp add: lucas_lehmer_ring_def lucas_lehmer_add'_def)
  hence "add_pow R M 𝟭R= (int M, 0)"
    by simp
  also have "φ  = 𝟬S⇙"
    by (simp add: φ_def S_def lucas_lehmer_ring_mod_def lucas_lehmer_hom_def)
  finally have "add_pow S M 𝟭S= 𝟬S⇙"
    by (simp add: φ.hom_add_pow_nat)

  (* eq1: phi ((6, 2)^M) = phi (6, -2).  The steps are:
       - split (6, 2) = (6, 0) + sigma with sigma = (0, 2), and expand the M-th power
         additively in S;
       - reduce 6^M and 2^M with Fermat's little theorem;
       - rewrite (0, 1)^M = (0, 1) * 3^((M-1)/2), which is congruent to -(0, 1) because
         Legendre 3 M = -1 (Euler's criterion). *)
  define σ :: "int × int" where "σ = (0, 2)"
  have eq1: "φ ((6, 2) [^]RM) = φ (6, -2)"
  proof -
    have "(6, 2) = (6, 0) Rσ"
      by (simp add: lucas_lehmer_ring_def σ_def lucas_lehmer_add'_def)
    also have "φ ( [^]RM) = φ ((6, 0) [^]RM) Sφ (σ [^]RM)"
      using prime and add_pow S M 𝟭S= 𝟬S⇙›
      by (simp add: S.binomial_finite_char)
    also have "(6, 0) [^]RM = (6 ^ M, 0)"
      by (simp add: R_pow_int)
    also have "[6 ^ M = 6] (mod (int M))" using M_gt_6
      by (intro little_Fermat_int) (use prime in auto simp flip: dvd_nat_abs_iff)
    hence "φ (6 ^ M, 0) = φ (6, 0)"
      unfolding φ_def by (intro lucas_lehmer_hom_cong) auto
    also have "σ = (2, 0) R(0, 1)"
      by (simp add: σ_def lucas_lehmer_ring_def lucas_lehmer_mult'_def)
    hence "φ (σ [^]RM) = φ ((2, 0) [^]RM R(0, 1) [^]RM)"
      by (subst R.nat_pow_distrib [symmetric]) auto
    also have " = φ ((2, 0) [^]RM) Sφ ((0, 1) [^]RM)"
      by simp
    also have "(2, 0) [^]RM = (2 ^ M, 0)"
      by (simp add: R_pow_int)
    also have "[2 ^ M = 2] (mod int M)" using M_gt_6 prime
      by (intro little_Fermat_int) (auto simp flip: dvd_nat_abs_iff dest: dvd_imp_le)
    hence "φ (2 ^ M, 0) = φ (2, 0)"
      unfolding φ_def by (intro lucas_lehmer_hom_cong) auto
    also have M_eq: "M = Suc (2 * ((M - 1) div 2))"
      using M_odd by auto
    have "(0, 1) [^]RM = (0, 1) R((0, 1) [^]R(2::nat)) [^]R((M - 1) div 2)"
      by (subst M_eq) (auto simp: R.nat_pow_mult R.nat_pow_pow R.cring_simprules)
    also have "(0, 1) [^]R(2::nat) = (3, 0)"
      by (simp add: eval_nat_numeral) (simp add: lucas_lehmer_ring_def lucas_lehmer_mult'_def)
    also have "φ ((0, 1) R(3, 0) [^]R((M - 1) div 2)) = 
                 φ ((3, 0) [^]R((M - 1) div 2)) Sφ (0, 1)"
      by (simp add: S.cring_simprules)
    also have "(3, 0) [^]R((M - 1) div 2) = (3 ^ ((M - 1) div 2), 0)"
      by (simp add: R_pow_int)
    also have "φ (3 ^ ((M - 1) div 2), 0) = φ (-1, 0)"
      unfolding φ_def
    proof (intro lucas_lehmer_hom_cong')
      have "[3 ^ ((M - 1) div 2) = Legendre 3 M] (mod int M)"
        by (rule cong_sym, rule euler_criterion) (use prime M_gt_6 in auto)
      thus "[3 ^ ((M - 1) div 2) = -1] (mod int M)"
        by (simp add: Legendre_3_M)
    qed auto
    also have "φ (2, 0) S(φ (- 1, 0) Sφ (0, 1)) = φ ((2, 0) R(-1, 0) R(0, 1))"
      by (simp add: R.cring_simprules S.cring_simprules)
    also have "φ (6, 0) Sφ ((2, 0) R(- 1, 0) R(0, 1)) =
                 φ ((6, 0) R(2, 0) R(- 1, 0) R(0, 1))"
      by simp
    also have " = φ (6, -2)" unfolding φ_def
      by (intro lucas_lehmer_hom_cong)
         (auto simp: lucas_lehmer_ring_def lucas_lehmer_mult'_def lucas_lehmer_add'_def)
    finally show "φ ((6, 2) [^]RM) = φ (6, -2)"
      by (simp add: R.cring_simprules S.cring_simprules)
  qed

  (* eq2: phi ((24, 0)^((M-1)/2)) = -1.  Write 24 = 2^3 * 3; by Euler's criterion,
     2^((M-1)/2) is congruent to 1 (Legendre_2_M) and 3^((M-1)/2) to -1 (Legendre_3_M). *)
  have eq2: "φ ((24, 0) [^]R((M - 1) div 2)) = S𝟭S⇙"
  proof -
    have "(24, 0) = (2, 0) [^]R(3::nat) R(3, 0)"
      by (simp add: eval_nat_numeral) (auto simp: lucas_lehmer_ring_def lucas_lehmer_mult'_def)
    also have " [^]R((M - 1) div 2) =
                 ((2, 0) [^]R((M - 1) div 2)) [^]R(3::nat) R(3, 0) [^]R((M - 1) div 2)"
      by (simp add: R.cring_simprules R.nat_pow_distrib R.nat_pow_pow mult_ac)
    also have "φ  = (φ ((2, 0) [^]R((M - 1) div 2))) [^]S(3::nat) Sφ ((3, 0) [^]R((M - 1) div 2))" by simp
    also have "(2, 0) [^]R((M - 1) div 2) = (2 ^ ((M - 1) div 2), 0)"
      by (simp add: R_pow_int)
    also have "φ  = φ (1, 0)"
      unfolding φ_def
    proof (intro lucas_lehmer_hom_cong')
      have "[2 ^ ((M - 1) div 2) = Legendre 2 M] (mod int M)"
        by (rule cong_sym, rule euler_criterion) (use prime M_gt_6 in auto)
      thus "[2 ^ ((M - 1) div 2) = 1] (mod int M)"
        using Legendre_2_M by simp
    qed auto
    also have "(1, 0) = 𝟭R⇙"
      by (simp add: lucas_lehmer_ring_def)
    also have "(3, 0) [^]R((M - 1) div 2) = (3 ^ ((M - 1) div 2), 0)"
      by (simp add: R_pow_int)
    also have "φ  = φ (-1, 0)"
      unfolding φ_def
    proof  (intro lucas_lehmer_hom_cong')
      have "[3 ^ ((M - 1) div 2) = Legendre 3 M] (mod int M)"
        by (rule cong_sym, rule euler_criterion) (use prime M_gt_6 in auto)
      thus "[3 ^ ((M - 1) div 2) = -1] (mod int M)"
        using Legendre_3_M by simp
    qed auto
    also have "(-1, 0) = R𝟭R⇙"
      using minus_lucas_lehmer_ring by (simp add: lucas_lehmer_ring_def)
    finally show "φ ((24, 0) [^]R((M - 1) div 2)) = S𝟭S⇙"
      by simp
  qed

  (* eq3: phi (omega^((M+1)/2)) = -1 for omega = (2, 1).  The steps are:
       - (6, 2)^2 and (24, 0) * omega have the same image under phi;
       - raise this to the power (M+1)/2 and rewrite with eq1 and eq2;
       - cancel the unit (24 mod M, 0), which is a unit because M does not divide 24. *)
  define ω ω' :: "int × int" where "ω = (2, 1)" and "ω' = (2, -1)"
  have eq3: "φ (ω [^]R((M + 1) div 2)) = S𝟭S⇙"
  proof -
    have "(M + 1) div 2 = Suc ((M - 1) div 2)"
      using M_odd M_gt_6 by (auto elim!: oddE)  
    have *: "φ ((24, 0) Rω) = φ ((6, 2) [^]R(2 :: nat))" unfolding φ_def
      by (intro lucas_lehmer_hom_cong)
         (simp_all add: eval_nat_numeral,
          auto simp: lucas_lehmer_ring_def lucas_lehmer_mult'_def ω_def)
    have "φ (R𝟭R) Sφ ((24, 0) Rω) [^]S((M + 1) div 2) =
            φ (R𝟭R) Sφ ((6, 2) [^]R(2 :: nat)) [^]S((M + 1) div 2)"
      by (subst *) auto
    hence "φ (R𝟭R) Sφ ((24, 0) [^]R((M + 1) div 2)) Sφ (ω [^]R((M + 1) div 2)) =
            φ (R𝟭R) Sφ ((6, 2) [^]R(2 * ((M + 1) div 2)))"
      by (simp add: R.nat_pow_distrib S.nat_pow_distrib R.nat_pow_pow
                    S.nat_pow_pow R.cring_simprules S.cring_simprules)
    also have "2 * ((M + 1) div 2) = M + 1"
      using M_odd by auto
    finally have "φ (24, 0) S(φ (R𝟭R) Sφ ((24, 0) [^]R((M - 1) div 2))) Sφ (ω [^]R((M + 1) div 2)) =
                  φ (R𝟭R) S(φ (6, 2) Sφ ((6, 2) [^]RM))"
      by (subst (asm) (M + 1) div 2 = _) (simp add: S.cring_simprules R.cring_simprules)
  
    also have "φ ((24, 0) [^]R((M - 1) div 2)) = S𝟭S⇙"
      by (subst eq2) auto
    also have "(φ (R𝟭R) SS𝟭S) = 𝟭S⇙"
      by (simp add: S.cring_simprules)
    also have "φ ((6, 2) [^]RM) = φ (6, -2)"
      by (subst eq1) auto
    also have "φ (6, 2) Sφ (6, -2) = φ ((6, 2) R(6, -2))"
      by simp
    also have " = φ (24, 0)" unfolding φ_def
      by (intro lucas_lehmer_hom_cong) (auto simp: lucas_lehmer_ring_def lucas_lehmer_mult'_def)
    finally have "φ (24, 0) S(φ (ω [^]R((M + 1) div 2))) =
                  φ (24, 0) Sφ (R𝟭R)"
      by (simp add: S.cring_simprules)
    also have "φ (24, 0) = (24 mod M, 0)"
      by (simp add: φ_def lucas_lehmer_hom_def nat_mod_as_int)
    finally have "(24 mod M, 0) S(φ (ω [^]R((M + 1) div 2))) =
                  (24 mod M, 0) Sφ (R𝟭R)" .
    moreover have "(24 mod M, 0)  Units S"
      unfolding S_def using M_gt_6 prime M_not_dvd_24
      by (intro int_in_Units_lucas_lehmer_ring_mod) (auto simp: dvd_mod_iff intro!: Nat.gr0I)
    ultimately show "φ (ω [^]R((M + 1) div 2)) = S𝟭S⇙"
      by (subst (asm) S.Units_l_cancel) auto
  qed

  (* eq4: phi (omega^(2^(p-2)) + omega'^(2^(p-2))) = 0.  The steps are:
       - multiply eq3 by omega'^((M+1)/4);
       - use (M + 1)/2 = 2 * ((M + 1)/4), which holds because 4 divides M + 1 = 2^p;
       - use that omega * omega' and 1 have the same image under phi;
       - use (M + 1)/4 = 2^(p-2). *)
  have eq4: "φ (ω [^]R(2 ^ (p - 2) :: nat) Rω' [^]R(2 ^ (p - 2) :: nat)) = 𝟬S⇙"
    (is "φ ?lhs = _")
  proof -
    have "φ (ω [^]R((M + 1) div 2)) Sφ (ω' [^]R((M + 1) div 4)) Sφ (ω' [^]R((M + 1) div 4)) = 𝟬S⇙"
      by (subst eq3) (auto simp: S.cring_simprules)
    also have "2 ^ 2 dvd (2 ^ p :: nat)"
      by (intro le_imp_power_dvd) (use p_gt_2 in auto)
    hence "4 dvd (M + 1)" by (auto simp: M_def)
    hence "(M + 1) div 2 = (M + 1) div 4 + (M + 1) div 4"
      by presburger
    also have "φ (ω [^]R) Sφ (ω' [^]R((M + 1) div 4)) =
               φ (ω Rω') [^]S((M + 1) div 4) Sφ (ω [^]R((M + 1) div 4))"
      by (simp add: S.cring_simprules S.nat_pow_distrib flip: S.nat_pow_mult)
    also have "φ (ω Rω') = φ 𝟭R⇙" unfolding φ_def
      by (intro lucas_lehmer_hom_cong)
         (auto simp: ω_def ω'_def lucas_lehmer_ring_def lucas_lehmer_mult'_def)
    also have "(M + 1) div 4 = 2 ^ (p - 2)"
      using p_gt_2 by (auto simp: M_def power_diff)
    finally show eq4: "φ (ω [^]R(2 ^ (p - 2) :: nat) Rω' [^]R(2 ^ (p - 2) :: nat)) = 𝟬S⇙"
      by simp
  qed

  (* By the closed form (transferred through the injective map lucas_lehmer_to_real), the
     sum in eq4 is the integer element (gen_lucas_lehmer_sequence 4 (p - 2), 0).  So that
     term is 0 mod M, i.e. 2^p - 1 divides it. *)
  have "φ ?lhs = 𝟬S⇙"
    by (rule eq4)
  also have "lucas_lehmer_to_real ?lhs =
             lucas_lehmer_to_real (gen_lucas_lehmer_sequence 4 (p - 2), 0)"
    by (simp add: ω_def ω'_def lucas_lehmer_to_real_simps gen_lucas_lehmer_sequence_4_closed_form1)
  hence "?lhs = (gen_lucas_lehmer_sequence 4 (p - 2), 0)"
    by (rule injD[OF lucas_lehmer_to_real_inj])
  finally have "gen_lucas_lehmer_sequence 4 (p - 2) mod M = 0" using M_gt_6
    by (auto simp: φ_def lucas_lehmer_hom_def S_def lucas_lehmer_ring_mod_def)
  thus "(2 ^ p - 1) dvd gen_lucas_lehmer_sequence 4 (p - 2)"
    by (simp add: M_def mod_eq_0_iff_dvd of_nat_diff)
qed

(* Characterisation of Mersenne primes: 2^p - 1 is prime iff p is prime and either p = 2
   or 2^p - 1 divides gen_lucas_lehmer_sequence 4 (p - 2).
   Forward direction: p = 0 and p = 1 are excluded, p = 2 is handled directly, and
   p > 2 uses lucas_lehmer_necessary through the mersenne_prime locale.
   Backward direction: a prime p <> 2 is odd, so lucas_lehmer_sufficient applies. *)
corollary lucas_lehmer_correct:
  "prime (2 ^ p - 1 :: nat) 
     prime p  (p = 2  (2 ^ p - 1) dvd gen_lucas_lehmer_sequence 4 (p - 2))"
proof (intro iffI; (elim conjE)?)
  assume prime: "prime (2 ^ p - 1 :: nat)"
  from prime have "p  0" "p  1"
    by (auto intro!: Nat.gr0I)
  hence "p = 2  p > 2" by auto
  thus "prime p  (p = 2  (2 ^ p - 1) dvd gen_lucas_lehmer_sequence 4 (p - 2))"
  proof (elim disjE)
    assume "p > 2"
    with prime interpret mersenne_prime p "2 ^ p - 1"
      by unfold_locales
    from lucas_lehmer_necessary p_prime show ?thesis by auto
  qed auto
next
  assume prime: "prime p" and *: "p = 2  (2 ^ p - 1) dvd gen_lucas_lehmer_sequence 4 (p - 2)"
  from * consider "p = 2" | "p  2" "(2 ^ p - 1) dvd gen_lucas_lehmer_sequence 4 (p - 2)"
    by auto
  thus "prime (2 ^ p - 1 :: nat)"
  proof cases
    assume "p  2" and dvd: "(2 ^ p - 1) dvd gen_lucas_lehmer_sequence 4 (p - 2)"
    from prime p and p  2 have "p > 2"
      using prime_gt_1_nat[of p] by auto
    with prime have "odd p" by (auto simp: prime_odd_nat)
    with prime dvd show ?thesis      
      by (intro lucas_lehmer_sufficient)
  qed auto
qed

(* The same characterisation as lucas_lehmer_correct, stated with lucas_lehmer_test.
   The extra condition p > 2 in the test is covered by p being prime and p <> 2. *)
corollary lucas_lehmer_correct':
  "prime (2 ^ p - 1 :: nat)  prime p  (p = 2  lucas_lehmer_test p)"
  using lucas_lehmer_correct[of p] prime_gt_1_nat[of p]
  by (auto simp: lucas_lehmer_test_def)


subsection ‹A first executable version of the Lucas--Lehmer test›

text ‹
  The following is an implementation of the Lucas--Lehmer test using modular
  arithmetic on the integers. This is not the most efficient implementation --
  the modular arithmetic can be replaced by much cheaper bitwise operations,
  and we will do that in the next section.
›

(* gen_lucas_lehmer_sequence' m a n applies the step x |-> (x^2 - 2) mod m to a, n times.
   The reduction mod m happens at every step, so intermediate values never grow beyond the
   modulus.  The recursion is tail-recursive, with a as the accumulator. *)
primrec gen_lucas_lehmer_sequence' :: "int  int  nat  int" where
  "gen_lucas_lehmer_sequence' m a 0 = a"
| "gen_lucas_lehmer_sequence' m a (Suc n) = gen_lucas_lehmer_sequence' m ((a ^ 2 - 2) mod m) n"

(* Alternative unfolding that applies the last step outermost instead of first.
   Proved by induction on n with a generalised. *)
lemma gen_lucas_lehmer_sequence'_Suc':
  "gen_lucas_lehmer_sequence' m a (Suc n) = (gen_lucas_lehmer_sequence' m a n ^ 2 - 2) mod m"
  by (induction n arbitrary: a) auto

(* For a start value a in {0..<m}, the modular iteration agrees with the unreduced
   sequence taken mod m.  The range assumption makes a mod m = a for the base case.  The
   step case pushes the inner "mod m" through squaring and subtraction using congruence
   rules. *)
lemma gen_lucas_lehmer_sequence'_correct:
  assumes "a  {0..<m}"
  shows   "gen_lucas_lehmer_sequence' m a n = gen_lucas_lehmer_sequence a n mod m"
  using assms
proof (induction n)
  case (Suc n)
  have "gen_lucas_lehmer_sequence' m a (Suc n) =
          ((gen_lucas_lehmer_sequence a n mod m)2 - 2) mod m"
    using Suc unfolding gen_lucas_lehmer_sequence'_Suc' by simp
  also have " = ((gen_lucas_lehmer_sequence a n)2 - 2) mod m"
    by (intro congD cong_diff cong_pow cong_refl) (auto simp: cong_def)
  finally show ?case by simp
qed auto

(* Code equation for lucas_lehmer_test.  It replaces the divisibility check on the fast-
   growing sequence by the modular iteration gen_lucas_lehmer_sequence'.  For p > 2 we
   have 2^p >= 8, so the start value 4 lies in {0..<2^p - 1} and
   gen_lucas_lehmer_sequence'_correct applies. *)
lemma lucas_lehmer_test_code_arithmetic [code]:
  "lucas_lehmer_test p = (p > 2 
     gen_lucas_lehmer_sequence' (2 ^ p - 1) 4 (p - 2) = 0)"
  unfolding lucas_lehmer_test_def
proof (intro conj_cong refl)
  assume p: "p > 2"
  from p have "2 ^ p  (2 ^ 3 :: int)" by (intro power_increasing) auto
  have "(2 ^ p - 1 dvd gen_lucas_lehmer_sequence 4 (p - 2)) 
          gen_lucas_lehmer_sequence 4 (p - 2) mod (2 ^ p - 1) = 0"
    by auto
  also have "gen_lucas_lehmer_sequence 4 (p - 2) mod (2 ^ p - 1) =
             gen_lucas_lehmer_sequence' (2 ^ p - 1) 4 (p - 2)"
    using 2 ^ p  2 ^ 3
    by (intro gen_lucas_lehmer_sequence'_correct [symmetric]) auto
  finally show "(2 ^ p - 1 dvd gen_lucas_lehmer_sequence 4 (p - 2)) =
                (gen_lucas_lehmer_sequence' (2 ^ p - 1) 4 (p - 2) = 0)" .
qed

(* Unfolds mersenne_prime p: p > 2 and 2^p - 1 is prime. *)
lemma mersenne_prime_iff: "mersenne_prime p  p > 2  prime (2 ^ p - 1 :: nat)"
  by (simp add: mersenne_prime_def)

(* Code equation for mersenne_prime p: check primality of p and run the Lucas--Lehmer
   test.  Derived from mersenne_prime_iff and lucas_lehmer_correct'; the test itself
   already requires p > 2. *)
lemma mersenne_prime_code [code]:
  "mersenne_prime p  prime p  lucas_lehmer_test p"
  unfolding mersenne_prime_iff using lucas_lehmer_correct'[of p]
  by (auto simp: lucas_lehmer_test_def)

end