(* Theory PAC_Polynomials *)

theory PAC_Polynomials
  imports PAC_Specification Finite_Map_Multiset
begin


section ‹Polynomials of strings›

text ‹

  Isabelle's definition of polynomials only works with variables of type
  \<^typ>‹nat›. Therefore, we introduce a version that uses strings, by means of an injective function
  that converts a string to a natural number. Such a function exists because strings are countable. Note that
  the whole development is independent of the choice of this function.

›

subsection ‹Polynomials and Variables›

text ‹There is a bijection between strings and natural numbers. It is obtained from
  countableE_infinite applied to the (infinite) universe of strings.›
lemma poly_embed_EX:
  ‹∃φ. bij (φ :: string ⇒ nat)›
  by (rule countableE_infinite[of ‹UNIV :: string set›])
     (auto intro!: infinite_UNIV_listI)

text ‹Using a multiset instead of a list has some advantages from an abstract point of view. First,
  we can have monomials that appear several times, and the coefficient can also be zero. Basically,
  we can represent un-normalised polynomials, which is very useful for talking about intermediate states
  in our program.
›
(* A monomial's variables: a multiset of variable names. *)
type_synonym term_poly = ‹string multiset›
(* A polynomial: a multiset of (variables, integer coefficient) pairs. *)
type_synonym mset_polynomial =
  ‹(term_poly * int) multiset›

text ‹A polynomial is normalized when no two of its monomials share the same variables
  and no coefficient is zero.›
definition normalized_poly :: ‹mset_polynomial ⇒ bool› where
  ‹normalized_poly p ⟷
     distinct_mset (fst `# p) ∧
     0 ∉# snd `# p›

text ‹Simplification rules for normalized_poly on the empty polynomial and on add_mset.›
lemma normalized_poly_simps[simp]:
  ‹normalized_poly {#}›
  ‹normalized_poly (add_mset t p) ⟷ snd t ≠ 0 ∧
    fst t ∉# fst `# p ∧ normalized_poly p›
  by (auto simp: normalized_poly_def)

text ‹Every sub-multiset of a normalized polynomial is normalized.›
lemma normalized_poly_mono:
  ‹normalized_poly B ⟹ A ⊆# B ⟹ normalized_poly A›
  unfolding normalized_poly_def
  by (auto intro: distinct_mset_mono image_mset_subseteq_mono)

text ‹Multiply every monomial of the polynomial q by the monomial ys: the variable
  multisets are added and the coefficients are multiplied.›
definition mult_poly_by_monom :: ‹term_poly * int ⇒ mset_polynomial ⇒ mset_polynomial› where
  ‹mult_poly_by_monom = (λys q. image_mset (λxs. (fst xs + fst ys, snd ys * snd xs)) q)›


text ‹Un-normalised product: the sum, over all monomials y of p, of q multiplied by y.›
definition mult_poly_raw :: ‹mset_polynomial ⇒ mset_polynomial ⇒ mset_polynomial› where
  ‹mult_poly_raw p q =
    (sum_mset ((λy. mult_poly_by_monom y q) `# p))›


text ‹Remove repeated variables inside each monomial (remdups_mset on the variable part);
  coefficients are left unchanged.›
definition remove_powers :: ‹mset_polynomial ⇒ mset_polynomial› where
  ‹remove_powers xs = image_mset (apfst remdups_mset) xs›


text ‹All variable occurrences of a polynomial: the sum of the variable multisets of its monomials.›
definition all_vars_mset :: ‹mset_polynomial ⇒ string multiset› where
  ‹all_vars_mset p = ∑⇩# (fst `# p)›

text ‹The set of variables of a polynomial, i.e. the underlying set of all_vars_mset.›
abbreviation all_vars :: ‹mset_polynomial ⇒ string set› where
  ‹all_vars p ≡ set_mset (all_vars_mset p)›

text ‹Insert the monomial (a, n) into b: all monomials of b over the variables a are removed
  and replaced by a single monomial whose coefficient is n plus the sum of their coefficients;
  nothing is inserted if that coefficient is 0.›
definition add_to_coefficient :: ‹_ ⇒ mset_polynomial ⇒ mset_polynomial› where
  ‹add_to_coefficient = (λ(a, n) b. {#(a', _) ∈# b. a' ≠ a#} +
             (if n + sum_mset (snd `# {#(a', _) ∈# b. a' = a#}) = 0 then {#}
               else {#(a, n + sum_mset (snd `# {#(a', _) ∈# b. a' = a#}))#}))›

text ‹Normalisation: fold add_to_coefficient over the monomials of p, starting from the empty polynomial.›
definition normalize_poly :: ‹mset_polynomial ⇒ mset_polynomial› where
  ‹normalize_poly p = fold_mset add_to_coefficient {#} p›

text ‹Unfolding of add_to_coefficient: once split on whether the resulting coefficient is
  zero (add_to_coefficient_simps), and once with an explicit if-then-else
  (add_to_coefficient_simps_If).›
lemma add_to_coefficient_simps:
  ‹n + sum_mset (snd `# {#(a', _) ∈# b. a' = a#}) ≠ 0 ⟹
    add_to_coefficient (a, n) b = {#(a', _) ∈# b. a' ≠ a#} +
             {#(a, n + sum_mset (snd `# {#(a', _) ∈# b. a' = a#}))#}›
  ‹n + sum_mset (snd `# {#(a', _) ∈# b. a' = a#}) = 0 ⟹
    add_to_coefficient (a, n) b = {#(a', _) ∈# b. a' ≠ a#}› and
  add_to_coefficient_simps_If:
  ‹add_to_coefficient (a, n) b = {#(a', _) ∈# b. a' ≠ a#} +
             (if n + sum_mset (snd `# {#(a', _) ∈# b. a' = a#}) = 0 then {#}
               else {#(a, n + sum_mset (snd `# {#(a', _) ∈# b. a' = a#}))#})›
  by (auto simp: add_to_coefficient_def)

text ‹add_to_coefficient is an instance of comp_fun_commute, which makes the fold_mset
  theory available for it.›
interpretation comp_fun_commute ‹add_to_coefficient›
proof -
  (* Filtering on "= a" and then on "≠ aa" is the same as filtering on "= a" when a ≠ aa. *)
  have [iff]:
    ‹a ≠ aa ⟹
    ((case x of (a', _) ⇒ a' = a) ∧ (case x of (a', _) ⇒ a' ≠ aa)) ⟷
    (case x of (a', _) ⇒ a' = a)› for a' aa a x
    by auto
  show ‹comp_fun_commute add_to_coefficient›
    unfolding add_to_coefficient_def
    by standard
     (auto intro!: ext simp: filter_filter_mset ac_simps add_eq_0_iff)
qed

text ‹The result of normalize_poly is always a normalized polynomial.›
lemma normalized_poly_normalize_poly[simp]:
  ‹normalized_poly (normalize_poly p)›
  unfolding normalize_poly_def
  apply (induction p)
  subgoal by auto
  subgoal for x p
    by (cases x)
      (auto simp: add_to_coefficient_simps_If
      intro: normalized_poly_mono)
  done


subsection ‹Addition›

text ‹Addition as a transition relation on triples (p, q, r). A step either moves a monomial
  from p or from q into r, or merges a monomial of p or q with a monomial of r over the same
  variables by adding the coefficients, or removes a monomial with coefficient 0 from r.›
inductive add_poly_p :: ‹mset_polynomial × mset_polynomial × mset_polynomial ⇒ mset_polynomial × mset_polynomial × mset_polynomial ⇒ bool› where
add_new_coeff_r:
    ‹add_poly_p (p, add_mset x q, r) (p, q, add_mset x r)› |
add_new_coeff_l:
    ‹add_poly_p (add_mset x p, q, r) (p, q, add_mset x r)› |
add_same_coeff_l:
    ‹add_poly_p (add_mset (x, n) p, q, add_mset (x, m) r) (p, q, add_mset (x, n + m) r)› |
add_same_coeff_r:
    ‹add_poly_p (p, add_mset (x, n) q, add_mset (x, m) r) (p, q, add_mset (x, n + m) r)› |
rem_0_coeff:
    ‹add_poly_p (p, q, add_mset (x, 0) r) (p, q, r)›

(* Elimination rule for a single add_poly_p step between arbitrary states S and T. *)
inductive_cases add_poly_pE: ‹add_poly_p S T›

(* Induction rule of add_poly_p with the tupled states split into their components. *)
lemmas add_poly_p_induct = add_poly_p.induct[split_format(complete)]

text ‹The left polynomial can always be emptied completely into the result component.›
lemma add_poly_p_empty_l:
  ‹add_poly_p⇧*⇧* (p, q, r) ({#}, q, p + r)›
  apply (induction p arbitrary: r)
  subgoal by auto
  subgoal
    by (metis (no_types, lifting) add_new_coeff_l r_into_rtranclp
      rtranclp_trans union_mset_add_mset_left union_mset_add_mset_right)
  done

text ‹The right polynomial can always be emptied completely into the result component.›
lemma add_poly_p_empty_r:
  ‹add_poly_p⇧*⇧* (p, q, r) (p, {#}, q + r)›
  apply (induction q arbitrary: r)
  subgoal by auto
  subgoal
    by (metis (no_types, lifting) add_new_coeff_r r_into_rtranclp
      rtranclp_trans union_mset_add_mset_left union_mset_add_mset_right)
  done

text ‹add_poly_p is symmetric in its two input polynomials.›
lemma add_poly_p_sym:
  ‹add_poly_p (p, q, r) (p', q', r') ⟷ add_poly_p (q, p, r) (q', p', r')›
  apply (rule iffI)
  subgoal
    by (cases rule: add_poly_p.cases, assumption)
      (auto intro: add_poly_p.intros)
  subgoal
    by (cases rule: add_poly_p.cases, assumption)
      (auto intro: add_poly_p.intros)
  done

text ‹A relation S is well-founded if some function ν maps every S-related pair into a
  well-founded relation R.›
lemma wf_if_measure_in_wf:
  ‹wf R ⟹ (⋀a b. (a, b) ∈ S ⟹ (ν a, ν b) ∈ R) ⟹ wf S›
  by (metis in_inv_image wfE_min wfI_min wf_inv_image)

text ‹Unfolding of lexn on two non-empty lists for a positive length n: both tails have
  length n - 1, and either the heads are related by r or the heads are equal and the tails
  are related by lexn r (n - 1).›
lemma lexn_n:
  ‹n > 0 ⟹ (x # xs, y # ys) ∈ lexn r n ⟷
  (length xs = n-1 ∧ length ys = n-1) ∧ ((x, y) ∈ r ∨ (x = y ∧ (xs, ys) ∈ lexn r (n - 1)))›
  apply (cases n)
   apply simp
  by (auto simp: map_prod_def image_iff lex_prod_def)

text ‹The converse of add_poly_p is well-founded. The measure is the list of the sizes of the
  three components, compared with lexn less_than 3.›
lemma wf_add_poly_p:
  ‹wf {(x, y). add_poly_p y x}›
  by (rule wf_if_measure_in_wf[where R = ‹lexn less_than 3› and
     ν = ‹λ(a,b,c). [size a , size b, size c]›])
    (auto simp: add_poly_p.simps wf_lexn lexn_n
     simp del: lexn.simps(2))

text ‹mult_poly_by_monom on the empty polynomial, on a sum of polynomials, and on add_mset.›
lemma mult_poly_by_monom_simps[simp]:
  ‹mult_poly_by_monom t {#} = {#}›
  ‹mult_poly_by_monom t (ps + qs) =  mult_poly_by_monom t ps + mult_poly_by_monom t qs›
  ‹mult_poly_by_monom a (add_mset p ps) = add_mset (fst a + fst p, snd a * snd p) (mult_poly_by_monom a ps)›
proof -
  interpret comp_fun_commute ‹(λxs. add_mset (xs + t))› for t
    by standard auto
  show
    ‹mult_poly_by_monom t (ps + qs) =  mult_poly_by_monom t ps + mult_poly_by_monom t qs› for t
    by (induction ps)
      (auto simp: mult_poly_by_monom_def)
  show
    ‹mult_poly_by_monom a (add_mset p ps) = add_mset (fst a + fst p, snd a * snd p) (mult_poly_by_monom a ps)›
    ‹mult_poly_by_monom t {#} = {#}› for t
    by (auto simp: mult_poly_by_monom_def)
qed

text ‹Multiplication by a fixed polynomial q as a transition relation on pairs (p, r). A step
  removes one monomial (xs, n) from p and adds to r the image of q in which each monomial
  (ys, m) becomes (remdups_mset (xs + ys), n * m).›
inductive mult_poly_p :: ‹mset_polynomial ⇒ mset_polynomial × mset_polynomial ⇒ mset_polynomial × mset_polynomial ⇒ bool›
  for q :: mset_polynomial where
mult_step:
    ‹mult_poly_p q (add_mset (xs, n) p, r) (p, (λ(ys, m). (remdups_mset (xs + ys), n * m)) `# q + r)›


(* Induction rule of mult_poly_p with the tupled states split into their components. *)
lemmas mult_poly_p_induct =
  mult_poly_p.induct[split_format(complete)]

subsection ‹Normalisation›

text ‹Normalisation as a relation between polynomials: monomials with coefficient 0 may be
  dropped, two monomials over the same variables may be merged by adding their coefficients,
  a polynomial is related to itself, and the same monomial may be kept on both sides.›
inductive normalize_poly_p :: ‹mset_polynomial ⇒ mset_polynomial ⇒ bool› where
rem_0_coeff[simp, intro]:
    ‹normalize_poly_p p q ⟹ normalize_poly_p (add_mset (xs, 0) p) q› |
merge_dup_coeff[simp, intro]:
    ‹normalize_poly_p p q ⟹ normalize_poly_p (add_mset (xs, m) (add_mset (xs, n) p)) (add_mset (xs, m + n) q)› |
same[simp, intro]:
    ‹normalize_poly_p p p› |
keep_coeff[simp, intro]:
    ‹normalize_poly_p p q ⟹ normalize_poly_p (add_mset x p) (add_mset x q)›


subsection ‹Correctness›
text ‹
  This locale maps string polynomials to real polynomials.
›
(* Fixes an injective function φ from variable names (strings) to natural numbers. *)
locale poly_embed =
  fixes φ :: ‹string ⇒ nat›
  assumes φ_inj: ‹inj φ›
begin

text ‹The polynomial of a multiset of variables: the product, starting from 1, of Var (φ a)
  for every occurrence a in xs.›
definition poly_of_vars :: "term_poly ⇒ ('a :: {comm_semiring_1}) mpoly" where
  ‹poly_of_vars xs = fold_mset (λa b. Var (φ a) * b) (1 :: 'a mpoly) xs›

text ‹poly_of_vars turns add_mset into multiplication by a variable and multiset union into
  multiplication of polynomials.›
lemma poly_of_vars_simps[simp]:
  shows
    ‹poly_of_vars (add_mset x xs) = Var (φ x) * (poly_of_vars xs :: ('a :: {comm_semiring_1}) mpoly)› (is ?A) and
    ‹poly_of_vars (xs + ys) = poly_of_vars xs * (poly_of_vars ys :: ('a :: {comm_semiring_1}) mpoly)› (is ?B)
proof -
  interpret comp_fun_commute ‹(λa b. (b :: 'a :: {comm_semiring_1} mpoly) * Var (φ a))›
    by standard
      (auto simp: algebra_simps ac_simps
         Var_def times_monomial_monomial intro!: ext)

  show ?A
    by (auto simp: poly_of_vars_def comp_fun_commute_axioms fold_mset_fusion
      ac_simps)
  show ?B
    apply (auto simp: poly_of_vars_def ac_simps)
    by (simp add: local.comp_fun_commute_axioms local.fold_mset_fusion
      semiring_normalization_rules(18))
qed


text ‹A monomial (xs, n) is mapped to the function that adds Const n * poly_of_vars xs
  to its argument.›
definition mononom_of_vars where
  ‹mononom_of_vars ≡ (λ(xs, n). (+) (Const n * poly_of_vars xs))›

(* mononom_of_vars is an instance of comp_fun_commute. *)
interpretation comp_fun_commute ‹mononom_of_vars›
  by standard
    (auto simp: algebra_simps ac_simps mononom_of_vars_def
       Var_def times_monomial_monomial
     intro!: ext)

(* The empty multiset of variables denotes the polynomial 1. *)
lemma [simp]:
  ‹poly_of_vars {#} = 1›
  by (auto simp: poly_of_vars_def)

text ‹Unfolding of mononom_of_vars applied to an accumulator b; the NO_MATCH premise keeps
  the rule from firing when b is literally 0.›
lemma mononom_of_vars_add[simp]:
  ‹NO_MATCH 0 b ⟹ mononom_of_vars xs b = Const (snd xs) * poly_of_vars (fst xs) + b›
  by (cases xs)
    (auto simp: ac_simps mononom_of_vars_def)

text ‹The polynomial denoted by a multiset polynomial: the sum of the mononom_of_vars
  functions of its monomials, applied to 0.›
definition polynomial_of_mset :: ‹mset_polynomial ⇒ _› where
  ‹polynomial_of_mset p = sum_mset (mononom_of_vars `# p) 0›

text ‹polynomial_of_mset maps multiset union to addition of polynomials.›
lemma polynomial_of_mset_append[simp]:
  ‹polynomial_of_mset (xs + ys) = polynomial_of_mset xs + polynomial_of_mset ys›
  by (auto simp: ac_simps Const_def polynomial_of_mset_def)

text ‹Adding one monomial x adds Const (snd x) * poly_of_vars (fst x) to the denoted polynomial.›
lemma polynomial_of_mset_Cons[simp]:
  ‹polynomial_of_mset (add_mset x ys) = Const (snd x) * poly_of_vars (fst x) + polynomial_of_mset ys›
  by (cases x)
    (auto simp: ac_simps polynomial_of_mset_def mononom_of_vars_def)

(* The empty multiset polynomial denotes 0. *)
lemma polynomial_of_mset_empty[simp]:
  ‹polynomial_of_mset {#} = 0›
  by (auto simp: polynomial_of_mset_def)

text ‹mult_poly_by_monom denotes multiplication by the polynomial of the monomial x.›
lemma polynomial_of_mset_mult_poly_by_monom[simp]:
  ‹polynomial_of_mset (mult_poly_by_monom x ys) =
       (Const (snd x) * poly_of_vars (fst x) * polynomial_of_mset ys)›
 by (induction ys)
   (auto simp: Const_mult algebra_simps)

text ‹mult_poly_raw denotes the product of the two denoted polynomials.›
lemma polynomial_of_mset_mult_poly_raw[simp]:
  ‹polynomial_of_mset (mult_poly_raw xs ys) = polynomial_of_mset xs * polynomial_of_mset ys›
  unfolding mult_poly_raw_def
  by (induction xs arbitrary: ys)
   (auto simp: Const_mult algebra_simps)

text ‹Negating every coefficient negates the denoted polynomial.›
lemma polynomial_of_mset_uminus:
  ‹polynomial_of_mset {#case x of (a, b) ⇒ (a, - b). x ∈# za#} =
    - polynomial_of_mset za›
  by (induction za)
    auto


text ‹For every variable x1 and polynomial p, the difference x1 * (x1 * p) - x1 * p lies in
  the ideal generated by polynomial_bool.›
lemma X2_X_polynomial_bool_mult_in:
  ‹Var (x1) * (Var (x1) * p) -  Var (x1) * p ∈ More_Modules.ideal polynomial_bool›
  using ideal_mult_right_in[OF  X2_X_in_pac_ideal[of x1 ‹{}›, unfolded pac_ideal_def], of p]
  by (auto simp: right_diff_distrib ac_simps power2_eq_square)


text ‹Removing repeated variables with remove_powers changes the denoted polynomial only by
  an element of the ideal generated by polynomial_bool.›
lemma polynomial_of_list_remove_powers_polynomial_bool:
  ‹(polynomial_of_mset xs) - polynomial_of_mset (remove_powers xs) ∈ ideal polynomial_bool›
proof (induction xs)
  case empty
  then show ?case by (auto simp: remove_powers_def ideal.span_zero)
next
  case (add x xs)
  (* Multiplying by a variable that already occurs in x2 does not change ideal membership. *)
  have H1: ‹x1 ∈# x2 ⟹
       Var (φ x1) * poly_of_vars x2 - p ∈ More_Modules.ideal polynomial_bool ⟷
       poly_of_vars x2 - p ∈ More_Modules.ideal polynomial_bool› for x1 x2 p
    apply (subst (2) ideal.span_add_eq[symmetric,
      of ‹Var (φ x1) * poly_of_vars x2 - poly_of_vars x2›])
    apply (drule multi_member_split)
    apply (auto simp: X2_X_polynomial_bool_mult_in)
    done

  (* The statement for a single multiset of variables. *)
  have diff: ‹poly_of_vars (x) - poly_of_vars (remdups_mset (x)) ∈ ideal polynomial_bool› for x
    by (induction x)
     (auto simp: remove_powers_def ideal.span_zero H1
      simp flip: right_diff_distrib intro!: ideal.span_scale)
  have [simp]: ‹polynomial_of_mset xs -
    polynomial_of_mset (apfst remdups_mset `# xs)
    ∈ More_Modules.ideal polynomial_bool ⟹
    poly_of_vars ys * poly_of_vars ys -
    poly_of_vars ys * poly_of_vars (remdups_mset ys)
    ∈ More_Modules.ideal polynomial_bool ⟹
    polynomial_of_mset xs + Const y * poly_of_vars ys -
    (polynomial_of_mset (apfst remdups_mset `# xs) +
    Const y * poly_of_vars (remdups_mset ys))
    ∈ More_Modules.ideal polynomial_bool› for y ys
    by (metis add_diff_add diff ideal.scale_right_diff_distrib ideal.span_add ideal.span_scale)
  show ?case
    using add
    apply (cases x)
    subgoal for ys y
      using ideal_mult_right_in2[OF diff, of ‹poly_of_vars ys› ys]
      by (auto simp: remove_powers_def right_diff_distrib
        ideal.span_diff ideal.span_add field_simps)
    done
qed

text ‹One add_poly_p step preserves the sum of the polynomials denoted by the three components.›
lemma add_poly_p_polynomial_of_mset:
  ‹add_poly_p (p, q, r) (p', q', r') ⟹
    polynomial_of_mset r + (polynomial_of_mset p + polynomial_of_mset q) =
    polynomial_of_mset r' + (polynomial_of_mset p' + polynomial_of_mset q')›
  apply (induction rule: add_poly_p_induct)
  subgoal
    by auto
  subgoal
    by auto
  subgoal
    by (auto simp: algebra_simps Const_add)
  subgoal
    by (auto simp: algebra_simps Const_add)
  subgoal
    by (auto simp: algebra_simps Const_add)
  done

text ‹Any number of add_poly_p steps preserves the sum of the three denoted polynomials.›
lemma rtranclp_add_poly_p_polynomial_of_mset:
  ‹add_poly_p⇧*⇧* (p, q, r) (p', q', r') ⟹
    polynomial_of_mset r + (polynomial_of_mset p + polynomial_of_mset q) =
    polynomial_of_mset r' + (polynomial_of_mset p' + polynomial_of_mset q')›
  by (induction rule: rtranclp_induct[of add_poly_p ‹(_, _, _)› ‹(_, _, _)›, split_format(complete), of for r])
    (auto dest: add_poly_p_polynomial_of_mset)


text ‹A complete run, from an empty result to two empty inputs, computes the sum of the two
  denoted polynomials.›
lemma rtranclp_add_poly_p_polynomial_of_mset_full:
  ‹add_poly_p⇧*⇧* (p, q, {#}) ({#}, {#}, r') ⟹
    polynomial_of_mset r' = (polynomial_of_mset p + polynomial_of_mset q)›
  by (drule rtranclp_add_poly_p_polynomial_of_mset)
    (auto simp: ac_simps add_eq_0_iff)

text ‹Removing duplicate variables from a monomial changes its polynomial only by an element
  of the ideal generated by polynomial_bool.›
lemma poly_of_vars_remdups_mset:
  ‹poly_of_vars (remdups_mset (xs)) - (poly_of_vars xs)
    ∈ More_Modules.ideal polynomial_bool›
  apply (induction xs)
  subgoal by (auto simp: ideal.span_zero)
  subgoal for x xs
    apply (cases ‹x ∈# xs›)
     apply (metis (no_types, lifting) X2_X_polynomial_bool_mult_in diff_add_cancel diff_diff_eq2
        ideal.span_diff insert_DiffM poly_of_vars_simps(1) remdups_mset_singleton_sum)
    by (metis (no_types, lifting) ideal.span_scale poly_of_vars_simps(1) remdups_mset_singleton_sum
        right_diff_distrib)
  done

text ‹Mapping every monomial (ys, n) of q to (remdups_mset (ys + xs), n * m) agrees, modulo the
  ideal generated by polynomial_bool, with multiplying q by Const m * poly_of_vars xs.›
lemma polynomial_of_mset_mult_map:
  ‹polynomial_of_mset
     {#case x of (ys, n) ⇒ (remdups_mset (ys + xs), n * m). x ∈# q#} -
    Const m * (poly_of_vars xs * polynomial_of_mset q)
    ∈ More_Modules.ideal polynomial_bool›
  (is ‹?P q ∈ _›)
proof (induction q)
  case empty
  then show ?case by (auto simp: algebra_simps ideal.span_zero)
next
  case (add x q)
  then have uP: ‹-?P q ∈ More_Modules.ideal polynomial_bool›
    using ideal.span_neg by blast
  have ‹Const b * (Const m * poly_of_vars (remdups_mset (a + xs))) -
           Const b * (Const m * (poly_of_vars a * poly_of_vars xs))
           ∈ More_Modules.ideal polynomial_bool› for a b
    by (auto simp: Const_mult simp flip: right_diff_distrib' poly_of_vars_simps
        intro!: ideal.span_scale poly_of_vars_remdups_mset)
  then show ?case
    apply (subst ideal.span_add_eq2[symmetric, OF uP])
    apply (cases x)
    apply (auto simp: field_simps Const_mult
        intro!: ideal.span_scale poly_of_vars_remdups_mset)
    done
qed

text ‹One mult_poly_p step changes the value p * q + r only by an element of the ideal
  generated by polynomial_bool.›
lemma mult_poly_p_mult_ideal:
  ‹mult_poly_p q (p, r) (p', r') ⟹
     (polynomial_of_mset p' * polynomial_of_mset q + polynomial_of_mset r') - (polynomial_of_mset p * polynomial_of_mset q + polynomial_of_mset r)
       ∈ ideal polynomial_bool›
proof (induction rule: mult_poly_p_induct)
  case (mult_step xs n p r)
  show ?case
    by (auto simp: algebra_simps polynomial_of_mset_mult_map)
qed

text ‹Any number of mult_poly_p steps changes the value p * q + r only by an element of the
  ideal generated by polynomial_bool.›
lemma rtranclp_mult_poly_p_mult_ideal:
  ‹(mult_poly_p q)⇧*⇧* (p, r) (p', r') ⟹
     (polynomial_of_mset p' * polynomial_of_mset q + polynomial_of_mset r') - (polynomial_of_mset p * polynomial_of_mset q + polynomial_of_mset r)
       ∈ ideal polynomial_bool›
  apply (induction p' r' rule: rtranclp_induct[of ‹mult_poly_p q› ‹(p, r)› ‹(p', q')› for p' q', split_format(complete)])
  subgoal
    by (auto simp: ideal.span_zero)
  subgoal for a b aa ba
    apply (drule mult_poly_p_mult_ideal)
    apply (drule ideal.span_add)
    apply assumption
    by (auto simp: group_add_class.diff_add_eq_diff_diff_swap
        add.inverse_distrib_swap ac_simps add_diff_eq
      simp flip:  diff_add_eq_diff_diff_swap)
  done

text ‹A complete run, from an empty result to an empty input, yields a polynomial equal to
  p * q modulo the ideal generated by polynomial_bool.›
lemma rtranclp_mult_poly_p_mult_ideal_final:
  ‹(mult_poly_p q)⇧*⇧* (p, {#}) ({#}, r) ⟹
    (polynomial_of_mset r) - (polynomial_of_mset p * polynomial_of_mset q)
       ∈ ideal polynomial_bool›
  by (drule rtranclp_mult_poly_p_mult_ideal) auto

text ‹Polynomials related by normalize_poly_p denote the same polynomial.›
lemma normalize_poly_p_poly_of_mset:
  ‹normalize_poly_p p q ⟹ polynomial_of_mset p = polynomial_of_mset q›
  apply (induction rule: normalize_poly_p.induct)
  apply (auto simp: Const_add algebra_simps)
  done


text ‹The same holds for the reflexive transitive closure of normalize_poly_p.›
lemma rtranclp_normalize_poly_p_poly_of_mset:
  ‹normalize_poly_p⇧*⇧* p q ⟹ polynomial_of_mset p = polynomial_of_mset q›
  by (induction rule: rtranclp_induct)
    (auto simp: normalize_poly_p_poly_of_mset)

end


text ‹It would be nice to have the property in the other direction too, but this requires a deep
dive into the definitions of polynomials.›
(* Extends poly_embed with two sets V and N between which φ is a bijection. *)
locale poly_embed_bij = poly_embed +
  fixes V N
  assumes φ_bij: ‹bij_betw φ V N›
begin

text ‹The inverse of φ on V, defined with the_inv_into.›
definition φ' :: ‹nat ⇒ string› where
  ‹φ' = the_inv_into V φ›

text ‹φ' undoes φ on V.›
lemma φ'_φ[simp]:
  ‹x ∈ V ⟹ φ' (φ x) = x›
  using φ_bij unfolding φ'_def
  by (meson bij_betw_imp_inj_on the_inv_into_f_f)

text ‹φ undoes φ' on N.›
lemma φ_φ'[simp]:
  ‹x ∈ N ⟹ φ (φ' x) = x›
  using φ_bij unfolding φ'_def
  by (meson f_the_inv_into_f_bij_betw)

end

end