(* Theory PAC_Checker_Init *)

(*
  File:         PAC_Checker_Init.thy
  Author:       Mathias Fleury, Daniela Kaufmann, JKU
  Maintainer:   Mathias Fleury, JKU
*)
theory PAC_Checker_Init
  imports  PAC_Checker WB_Sort PAC_Checker_Relation
begin

section ‹Initial Normalisation of Polynomials›

subsection ‹Sorting›

text ‹Adapted from the theory @{text ‹HOL-ex.MergeSort›} by Tobias Nipkow. We did not change much, but
   we refine it to executable code and try to improve efficiency.
›

(* Merge two lists according to the comparison f: the head x of the first list
   is emitted first whenever f x y holds, otherwise the head y of the second;
   when one list is exhausted the other is appended unchanged. *)
fun merge :: "_ ⇒ 'a list ⇒ 'a list ⇒ 'a list"
where
  "merge f (x#xs) (y#ys) =
         (if f x y then x # merge f xs (y#ys) else y # merge f (x#xs) ys)"
| "merge f xs [] = xs"
| "merge f [] ys = ys"

(* Merging keeps exactly the elements of both inputs (as a multiset). *)
lemma mset_merge [simp]:
  "mset (merge f xs ys) = mset xs + mset ys"
  by (induct f xs ys rule: merge.induct) (simp_all add: ac_simps)

(* Merging keeps exactly the elements of both inputs (as a set). *)
lemma set_merge [simp]:
  "set (merge f xs ys) = set xs ∪ set ys"
  by (induct f xs ys rule: merge.induct) auto

(* For a transitive and total comparison f, the merge of two lists is sorted
   with respect to f exactly when both inputs are. *)
lemma sorted_merge:
  "transp f ⟹ (⋀x y. f x y ∨ f y x) ⟹
   sorted_wrt f (merge f xs ys) ⟷ sorted_wrt f xs ∧ sorted_wrt f ys"
  apply (induct f xs ys rule: merge.induct)
  apply (auto simp add: ball_Un not_le less_le dest: transpD)
  apply blast
  apply (blast dest: transpD)
  done

(* Top-down merge sort parameterised by the comparison f: lists of length at
   least two are split at size xs div 2, both halves are sorted recursively,
   and the results are merged with merge f. *)
fun msort :: "_ ⇒ 'a list ⇒ 'a list"
where
  "msort f [] = []"
| "msort f [x] = [x]"
| "msort f xs = merge f
                      (msort f (take (size xs div 2) xs))
                      (msort f (drop (size xs div 2) xs))"

(* swap_ternary f m n returns a function on triples that, for (m, n) being
   (0,1), (0,2) or (1,2), swaps the components at positions m and n unless
   f already orders them; for any other (m, n) it is the identity. *)
fun swap_ternary :: ‹_⇒nat⇒nat⇒ ('a × 'a × 'a) ⇒ ('a × 'a × 'a)› where
  ‹swap_ternary f m n  =
    (if (m = 0 ∧ n = 1)
    then (λ(a, b, c). if f a b then (a, b, c)
      else (b,a,c))
    else if (m = 0 ∧ n = 2)
    then (λ(a, b, c). if f a c then (a, b, c)
      else (c,b,a))
    else if (m = 1 ∧ n = 2)
    then (λ(a, b, c). if f b c then (a, b, c)
      else (a,c,b))
    else (λ(a, b, c). (a,b,c)))›

(* Variant of msort with a dedicated equation for two-element lists; it is the
   version exported as code (see msort_msort2 for the equivalence with msort).
   NOTE(review): the recursive calls go to msort, not msort2, so the
   two-element shortcut only applies at the top level. Confirm this is intended. *)
fun msort2 :: "_ ⇒ 'a list ⇒ 'a list"
where
  "msort2 f [] = []"
| "msort2 f [x] = [x]"
| "msort2 f [x,y] = (if f x y then [x,y] else [y,x])"
| "msort2 f xs = merge f
                      (msort f (take (size xs div 2) xs))
                      (msort f (drop (size xs div 2) xs))"

(* Replace the generated code equations of msort2 by a simplified version.
   NOTE(review): msort2.simps does not mention swap_ternary, so unfolding
   swap_ternary.simps here appears to have no effect. Confirm. *)
lemmas [code del] =
  msort2.simps

declare msort2.simps[simp del]
lemmas [code] =
  msort2.simps[unfolded swap_ternary.simps, simplified]

declare msort2.simps[simp]

(* On a linear order, msort and msort2 compute the same list. *)
lemma msort_msort2:
  fixes xs :: ‹'a :: linorder list›
  shows ‹msort (≤) xs = msort2 (≤) xs›
  apply (induction  ‹(≤) :: 'a ⇒ 'a ⇒ bool› xs rule: msort2.induct)
  apply (auto dest: transpD)
  done

(* msort returns a list sorted w.r.t. any transitive, total comparison f. *)
lemma sorted_msort:
  "transp f ⟹ (⋀x y. f x y ∨ f y x) ⟹
   sorted_wrt f (msort f xs)"
  by (induct f xs rule: msort.induct) (simp_all add: sorted_merge)

(* msort is a permutation of its input. *)
lemma mset_msort[simp]:
  "mset (msort f xs) = mset xs"
  by (induction f xs rule: msort.induct)
    (simp_all add: union_code)


subsection ‹Sorting applied to monomials›

(* Recursive (RECT) formulation of merge_coeffs for Sepref: two adjacent
   monomials with the same variable list are combined by adding their
   coefficients, and the combined monomial is dropped when the sum is 0. *)
lemma merge_coeffs_alt_def:
  ‹(RETURN o merge_coeffs) p =
   RECT(λf p.
     (case p of
       [] ⇒ RETURN []
     | [_] => RETURN p
     | ((xs, n) # (ys, m) # p) ⇒
      (if xs = ys
       then if n + m ≠ 0 then f ((xs, n + m) # p) else f p
       else do {p ← f ((ys, m) # p); RETURN ((xs, n) # p)})))
    p›
  apply (induction p rule: merge_coeffs.induct)
  subgoal by (subst RECT_unfold, refine_mono) auto
  subgoal by (subst RECT_unfold, refine_mono) auto
  subgoal for x p y q
    by (subst RECT_unfold, refine_mono)
     (smt case_prod_conv list.simps(5) merge_coeffs.simps(3) nres_monad1
      push_in_let_conv(2))
  done

(* For a pure assertion R, the invalidated assertion equals R * true. *)
lemma hn_invalid_recover:
  ‹is_pure R ⟹ hn_invalid R = (λx y. R x y * true)›
  ‹is_pure R ⟹ invalid_assn R = (λx y. R x y * true)›
  by (auto simp: is_pure_conv invalid_pure_recover hn_ctxt_def intro!: ext)

(* The polynomial, monomial and string assertions are pure; they are
   registered as safe constraint rules. *)
lemma safe_poly_vars:
  shows
    [safe_constraint_rules]:
      "is_pure (poly_assn)" and
    [safe_constraint_rules]:
      "is_pure (monom_assn)" and
    [safe_constraint_rules]:
      "is_pure (monomial_assn)" and
    [safe_constraint_rules]:
      "is_pure string_assn"
  by (auto intro!: pure_prod list_assn_pure simp: prod_assn_pure_conv)

(* Invalidating each component of the product monom_assn ×⇩a int_assn is the
   same as invalidating the whole product. *)
lemma invalid_assn_distrib:
  ‹invalid_assn monom_assn ×⇩a invalid_assn int_assn = invalid_assn (monom_assn ×⇩a int_assn)›
    apply (simp add: invalid_pure_recover hn_invalid_recover
      safe_constraint_rules)
    apply (subst hn_invalid_recover)
    apply (rule safe_poly_vars(2))
    apply (subst hn_invalid_recover)
    apply (rule safe_poly_vars)
    apply (auto intro!: ext)
    done

(* Either the invalidated product assertion or the valid monomial assertion
   entails (up to true) the valid monomial assertion for the same values. *)
lemma WTF_RF_recover:
  ‹hn_ctxt (invalid_assn monom_assn ×⇩a invalid_assn int_assn) xb
        x'a ∨⇩A
       hn_ctxt monomial_assn xb x'a ⟹⇩t
       hn_ctxt (monomial_assn) xb x'a›
  by (smt assn_aci(5) hn_ctxt_def invalid_assn_distrib invalid_pure_recover is_pure_conv
    merge_thms(4) merge_true_star reorder_enttI safe_poly_vars(3) star_aci(2) star_aci(3))

(* Two frame entailments (up to true) that Sepref does not find on its own;
   merge_coeffs_impl and merge_coeffs0_impl try them via (rule WTF_RF)
   during their translation phase. *)
lemma WTF_RF:
  ‹hn_ctxt (invalid_assn monom_assn ×⇩a invalid_assn int_assn) xb x'a *
       (hn_invalid poly_assn la l'a * hn_invalid int_assn a2' a2 *
        hn_invalid monom_assn a1' a1 *
        hn_invalid poly_assn l l' *
        hn_invalid monomial_assn xa x' *
        hn_invalid poly_assn ax px) ⟹⇩t
       hn_ctxt (monomial_assn) xb x'a *
       hn_ctxt poly_assn
        la l'a *
       hn_ctxt poly_assn l l' *
       (hn_invalid int_assn a2' a2 *
        hn_invalid monom_assn a1' a1 *
        hn_invalid monomial_assn xa x' *
        hn_invalid poly_assn ax px)›
  ‹hn_ctxt (invalid_assn monom_assn ×⇩a invalid_assn int_assn) xa x' *
       (hn_ctxt poly_assn l l' * hn_invalid poly_assn ax px) ⟹⇩t
       hn_ctxt (monomial_assn) xa x' *
       hn_ctxt poly_assn l l' *
       hn_ctxt poly_assn ax px *
       emp›
  by sepref_dbg_trans_step+

text ‹The refinement framework is completely lost here when synthesizing the constants -- it does
  not understand what is pure (actually everything) and what must be destroyed.›
sepref_definition merge_coeffs_impl
  is ‹RETURN o merge_coeffs›
  :: ‹poly_assn⇧d →⇩a poly_assn›
  supply [[goals_limit=1]]
  unfolding merge_coeffs_alt_def
    HOL_list.fold_custom_empty poly_assn_alt_def
  apply (rewrite in ‹_› annotate_assn[where A=‹poly_assn›])
  (* Run the Sepref phases one by one, so that the hand-proved frame rules
     WTF_RF can be tried before each generic translation step. *)
  apply sepref_dbg_preproc
  apply sepref_dbg_cons_init
  apply sepref_dbg_id
  apply sepref_dbg_monadify
  apply sepref_dbg_opt_init
  apply (rule WTF_RF | sepref_dbg_trans_step)+
  apply sepref_dbg_opt
  apply sepref_dbg_cons_solve
  apply sepref_dbg_cons_solve
  apply sepref_dbg_constraints
  done

(* Quicksort of a polynomial, comparing the variable lists (fst) of the
   monomials with term_order_rel or equality. *)
definition full_quicksort_poly where
  ‹full_quicksort_poly = full_quicksort_ref (λx y. x = y ∨ (x, y) ∈ term_order_rel) fst›

(* Down-conversion along the identity list relation is the identity. *)
lemma down_eq_id_list_rel: ‹⇓(⟨Id⟩list_rel) x = x›
  by auto

(* Instances of the generic quicksort / partition refinements for polynomials,
   ordering monomials by their first component (the variable list) with (≤).
   The leftover diagnostic command term partition_between_ref was removed. *)
definition quicksort_poly:: ‹nat ⇒ nat ⇒ llist_polynomial ⇒ (llist_polynomial) nres› where
  ‹quicksort_poly x y  z = quicksort_ref (≤) fst (x, y, z)›

definition partition_between_poly :: ‹nat ⇒ nat ⇒ llist_polynomial ⇒ (llist_polynomial × nat) nres› where
  ‹partition_between_poly = partition_between_ref (≤) fst›

definition partition_main_poly :: ‹nat ⇒ nat ⇒ llist_polynomial ⇒ (llist_polynomial × nat) nres› where
  ‹partition_main_poly = partition_main (≤)  fst›

(* Transitivity of the nested lexicographic order on lists of character lists. *)
lemma string_list_trans:
  ‹(xa ::char list list, ya) ∈ lexord (lexord {(x, y). x < y}) ⟹
  (ya, z) ∈ lexord (lexord {(x, y). x < y}) ⟹
    (xa, z) ∈ lexord (lexord {(x, y). x < y})›
  by (smt (verit) less_char_def char.less_trans less_than_char_def lexord_partial_trans p2rel_def)

(* The polynomial quicksort refines the sorting specification sort_poly_spec. *)
lemma full_quicksort_sort_poly_spec:
  ‹(full_quicksort_poly, sort_poly_spec) ∈ ⟨Id⟩list_rel →⇩f ⟨⟨Id⟩list_rel⟩nres_rel›
proof -
  have xs: ‹(xs, xs) ∈ ⟨Id⟩list_rel› and ‹⇓(⟨Id⟩list_rel) x = x› for x xs
    by auto
  show ?thesis
    apply (intro frefI nres_relI)
    unfolding full_quicksort_poly_def
    apply (rule full_quicksort_ref_full_quicksort[THEN fref_to_Down_curry, THEN order_trans])
    subgoal
      by (auto simp: rel2p_def var_order_rel_def p2rel_def Relation.total_on_def
        dest: string_list_trans)
    subgoal
      using total_on_lexord_less_than_char_linear[unfolded var_order_rel_def]
      apply (auto simp: rel2p_def var_order_rel_def p2rel_def Relation.total_on_def less_char_def)
      done
    subgoal by fast
    apply (rule xs)
    apply (subst down_eq_id_list_rel)
    unfolding sorted_wrt_map sort_poly_spec_def
    apply (rule full_quicksort_correct_sorted[where R = ‹(λx y. x = y ∨ (x, y) ∈ term_order_rel)› and h = fst,
       THEN order_trans])
    subgoal
      by (auto simp: rel2p_def var_order_rel_def p2rel_def Relation.total_on_def dest: string_list_trans)
    subgoal for x y
      using total_on_lexord_less_than_char_linear[unfolded var_order_rel_def]
      apply (auto simp: rel2p_def var_order_rel_def p2rel_def Relation.total_on_def
        less_char_def)
      done
    subgoal
      by (auto simp: rel2p_def p2rel_def)
    done
qed


subsection ‹Lifting to polynomials›

(* Merge-sort instances: polynomials are compared on the variable list (fst)
   of their monomials, monomials with (≤) on their variables. The *_impl
   versions are stated on String.literal for code export. *)
definition merge_sort_poly :: ‹_› where
‹merge_sort_poly = msort (λa b. fst a ≤ fst b)›

definition merge_monoms_poly :: ‹_› where
‹merge_monoms_poly = msort (≤)›

definition merge_poly :: ‹_› where
‹merge_poly = merge (λa b. fst a ≤ fst b)›

definition merge_monoms :: ‹_› where
‹merge_monoms = merge (≤)›

definition msort_poly_impl :: ‹(String.literal list × int) list ⇒ _› where
‹msort_poly_impl = msort (λa b. fst a ≤ fst b)›

definition msort_monoms_impl :: ‹(String.literal list) ⇒ _› where
‹msort_monoms_impl = msort (≤)›

(* Unfolded equation for msort_poly_impl with an explicit two-element case. *)
lemma msort_poly_impl_alt_def:
  ‹msort_poly_impl xs =
    (case xs of
      [] ⇒ []
     | [a] ⇒ [a]
     | [a,b] ⇒ if fst a ≤ fst b then [a,b]else [b,a]
     | xs ⇒ merge_poly
                      (msort_poly_impl (take ((length xs) div 2) xs))
                      (msort_poly_impl (drop ((length xs) div 2) xs)))›
  unfolding msort_poly_impl_def
  apply (auto split: list.splits simp: merge_poly_def)
  done

(* (≤) on terms is equality or the strict order term_order_rel'. *)
lemma le_term_order_rel':
  ‹(≤) = (λx y. x = y ∨ term_order_rel' x y)›
  apply (intro ext)
  apply (auto simp add: less_list_def less_eq_list_def
    lexordp_eq_conv_lexord lexordp_def)
  using term_order_rel'_alt_def_lexord term_order_rel'_def apply blast
  using term_order_rel'_alt_def_lexord term_order_rel'_def apply blast
  done

(* Executable non-strict lexicographic comparison: [] is below everything;
   otherwise compare heads with < and recurse on equal heads; a non-empty list
   is never below []. *)
fun lexord_eq where
  ‹lexord_eq [] _ = True› |
  ‹lexord_eq (x # xs) (y # ys) = (x < y ∨ (x = y ∧ lexord_eq xs ys))› |
  ‹lexord_eq _ _ = False›

lemma [simp]:
  ‹lexord_eq [] [] = True›
  ‹lexord_eq (a # b)[] = False›
  ‹lexord_eq [] (a # b) = True›
  apply auto
  done

(* (≤) on variable lists is equality or membership in var_order_rel. *)
lemma var_order_rel':
  ‹(≤) = (λx y. x = y ∨ (x,y) ∈ var_order_rel)›
  by (intro ext)
   (auto simp add: less_list_def less_eq_list_def
    lexordp_eq_conv_lexord lexordp_def var_order_rel_def
    lexordp_conv_lexord p2rel_def)


(* var_order_rel is the strict order (<). *)
lemma var_order_rel'':
  ‹(x,y) ∈ var_order_rel ⟷ x < y›
  by (metis leD less_than_char_linear lexord_linear neq_iff var_order_rel' var_order_rel_antisym
      var_order_rel_def)

(* On lists of string literals, (≤) coincides with lexord_eq. *)
lemma lexord_eq_alt_def1:
  ‹a ≤ b = lexord_eq a b› for a b :: ‹String.literal list›
  unfolding le_term_order_rel'
  apply (induction a b rule: lexord_eq.induct)
  apply (auto simp: var_order_rel'' less_eq_list_def)
  done

(* Recursive (RECT) formulation of lexord_eq for Sepref. *)
lemma lexord_eq_alt_def2:
  ‹(RETURN oo lexord_eq) xs ys =
     RECT (λf (xs, ys).
        case (xs, ys) of
           ([], _) ⇒ RETURN True
         | (x # xs, y # ys) ⇒
            if x < y then RETURN True
            else if x = y then f (xs, ys) else RETURN False
        | _ ⇒ RETURN False)
        (xs, ys)›
  apply (subst eq_commute)
  apply (induction xs ys rule: lexord_eq.induct)
  subgoal by (subst RECT_unfold, refine_mono) auto
  subgoal by (subst RECT_unfold, refine_mono) auto
  subgoal by (subst RECT_unfold, refine_mono) auto
  done


(* Alias of var_order. The [def_pat_rules] lemma below rewrites membership
   (x,y) in var_order_rel into a call var_order' x y. *)
definition var_order' where
  [simp]: ‹var_order' = var_order›

lemma var_order_rel[def_pat_rules]:
  ‹(∈)$(x,y)$var_order_rel ≡ var_order'$x$y›
  by (auto simp: p2rel_def rel2p_def)

(* var_order_rel is the relation induced by char.lexordp. *)
lemma var_order_rel_alt_def:
  ‹var_order_rel = p2rel char.lexordp›
  apply (auto simp: p2rel_def char.lexordp_conv_lexord var_order_rel_def)
  using char.lexordp_conv_lexord apply auto
  done

lemma var_order_rel_var_order:
  ‹(x, y) ∈ var_order_rel ⟷ var_order x y›
  by (auto simp: rel2p_def)

(* Parametricity: (<) on string literals implements var_order' on strings. *)
lemma var_order_string_le[sepref_import_param]:
  ‹((<), var_order') ∈ string_rel → string_rel → bool_rel›
  apply (auto intro!: frefI simp: string_rel_def String.less_literal_def
     rel2p_def linorder.lexordp_conv_lexord[OF char.linorder_axioms,
      unfolded less_eq_char_def] var_order_rel_def
      p2rel_def
      simp flip: PAC_Polynomials_Term.less_char_def)
  using char.lexordp_conv_lexord apply auto
  done

(* Parametricity of (≤) on monomials. *)
lemma [sepref_import_param]:
  ‹( (≤), (≤)) ∈ monom_rel → monom_rel →bool_rel›
  apply (intro fun_relI)
  using list_rel_list_rel_order_iff by fastforce

(* Parametricity of (<) on strings w.r.t. string literals. *)
lemma [sepref_import_param]:
  ‹( (<), (<)) ∈ string_rel → string_rel →bool_rel›
proof -
  have [iff]: ‹ord.lexordp (<) (literal.explode a) (literal.explode aa) ⟷
       List.lexordp (<) (literal.explode a) (literal.explode aa)› for a aa
    apply (rule iffI)
     apply (metis PAC_Checker_Relation.less_char_def char.lexordp_conv_lexord less_list_def
        p2rel_def var_order_rel'' var_order_rel_def)
    apply (metis PAC_Checker_Relation.less_char_def char.lexordp_conv_lexord less_list_def
        p2rel_def var_order_rel'' var_order_rel_def)
    done
  show ?thesis
    unfolding string_rel_def less_literal.rep_eq less_than_char_def
      less_eq_list_def PAC_Polynomials_Term.less_char_def[symmetric]
    by (intro fun_relI)
     (auto simp: string_rel_def less_literal.rep_eq
        less_list_def char.lexordp_conv_lexord lexordp_eq_refl
        lexordp_eq_conv_lexord)
qed


(* The class-based lexordp and char.lexordp are the same predicate. *)
lemma lexordp_char_char: ‹ord_class.lexordp = char.lexordp›
  unfolding char.lexordp_def ord_class.lexordp_def
  by (rule arg_cong[of _ _ lfp])
    (auto intro!: ext)

(* Parametricity of (≤) on strings w.r.t. string literals. *)
lemma [sepref_import_param]:
  ‹( (≤), (≤)) ∈ string_rel → string_rel →bool_rel›
  unfolding string_rel_def less_eq_literal.rep_eq less_than_char_def
    less_eq_list_def PAC_Polynomials_Term.less_char_def[symmetric]
  by (intro fun_relI)
   (auto simp: string_rel_def less_eq_literal.rep_eq less_than_char_def
    less_eq_list_def char.lexordp_eq_conv_lexord lexordp_eq_refl
    lexordp_eq_conv_lexord lexordp_char_char
    simp flip: less_char_def[abs_def])

sepref_register lexord_eq
(* Imperative implementation of lexord_eq on monomials, synthesised from the
   RECT formulation lexord_eq_alt_def2. *)
sepref_definition lexord_eq_term
  is ‹uncurry (RETURN oo lexord_eq)›
  :: ‹monom_assn⇧k *⇩a monom_assn⇧k →⇩a bool_assn›
  supply[[goals_limit=1]]
  unfolding lexord_eq_alt_def2
  by sepref

declare lexord_eq_term.refine[sepref_fr_rules]


(* Code equations: in msort_poly_impl the comparison (≤) is rewritten to
   lexord_eq; msort_monoms_impl is exported via msort2. *)
lemmas [code del] = msort_poly_impl_def msort_monoms_impl_def
lemmas [code] =
  msort_poly_impl_def[unfolded lexord_eq_alt_def1[abs_def]]
  msort_monoms_impl_def[unfolded msort_msort2]

(* term_order_rel is transitive. *)
lemma term_order_rel_trans:
  ‹(a, aa) ∈ term_order_rel ⟹
       (aa, ab) ∈ term_order_rel ⟹ (a, ab) ∈ term_order_rel›
  by (metis PAC_Checker_Relation.less_char_def p2rel_def string_list_trans var_order_rel_def)

(* Merge sort on polynomials refines sort_poly_spec. *)
lemma merge_sort_poly_sort_poly_spec:
  ‹(RETURN o merge_sort_poly, sort_poly_spec) ∈ ⟨Id⟩list_rel →⇩f ⟨⟨Id⟩list_rel⟩nres_rel›
  unfolding sort_poly_spec_def merge_sort_poly_def
  apply (intro frefI nres_relI)
  using total_on_lexord_less_than_char_linear var_order_rel_def
  by (auto intro!: sorted_msort simp: sorted_wrt_map rel2p_def
    le_term_order_rel' transp_def dest: term_order_rel_trans)

(* Recursive (RECT) formulation of msort in the nres monad. *)
lemma msort_alt_def:
  ‹RETURN o (msort f) =
     RECT (λg xs.
        case xs of
          [] ⇒ RETURN []
        | [x] ⇒ RETURN [x]
        | _ ⇒ do {
           a ← g (take (size xs div 2) xs);
           b ← g (drop (size xs div 2) xs);
           RETURN (merge f a b)})›
  apply (intro ext)
  unfolding comp_def
  apply (induct_tac f x rule: msort.induct)
  subgoal by (subst RECT_unfold, refine_mono) auto
  subgoal by (subst RECT_unfold, refine_mono) auto
  subgoal
    by (subst RECT_unfold, refine_mono)
     (smt (verit) let_to_bind_conv list.simps(5) msort.simps(3))
  done

(* monomial_rel preserves the order on variable lists. *)
lemma monomial_rel_order_map:
  ‹(x, a, b) ∈ monomial_rel ⟹
       (y, aa, bb) ∈ monomial_rel ⟹
       fst x ≤ fst y ⟷ a ≤ aa›
  apply (cases x; cases y)
  apply auto
  using list_rel_list_rel_order_iff by fastforce+


(* The step, monomial and polynomial assertions are pure relations. *)
lemma step_rewrite_pure:
  fixes K :: ‹('olbl × 'lbl) set›
  shows
    ‹pure (p2rel (⟨K, V, R⟩pac_step_rel_raw)) = pac_step_rel_assn (pure K) (pure V) (pure R)›
    ‹monomial_assn = pure (monom_rel ×⇩r int_rel)› and
  poly_assn_list:
    ‹poly_assn = pure (⟨monom_rel ×⇩r int_rel⟩list_rel)›
  subgoal
    apply (intro ext)
    apply (case_tac x; case_tac xa)
    apply (auto simp: relAPP_def p2rel_def pure_def)
    done
  subgoal H
    apply (intro ext)
    apply (case_tac x; case_tac xa)
    by (simp add: list_assn_pure_conv)
  subgoal
    unfolding H
    by (simp add: list_assn_pure_conv relAPP_def)
  done

(* pac_step_rel_assn is pure when all its parameters are. *)
lemma safe_pac_step_rel_assn[safe_constraint_rules]:
  "is_pure K ⟹ is_pure V ⟹ is_pure R ⟹ is_pure (pac_step_rel_assn K V R)"
  by (auto simp: step_rewrite_pure(1)[symmetric] is_pure_conv)


(* merge_poly is parametric in poly_rel. *)
lemma merge_poly_merge_poly:
  ‹(merge_poly, merge_poly)
   ∈ poly_rel → poly_rel → poly_rel›
  unfolding merge_poly_def
  apply (intro fun_relI)
  subgoal for a a' aa a'a
    apply (induction ‹(λ(a :: String.literal list × int)
      (b :: String.literal list × int). fst a ≤ fst b)› a aa
      arbitrary: a' a'a
      rule: merge.induct)
    subgoal
      by (auto elim!: list_relE3 list_relE4 list_relE list_relE2
        simp: monomial_rel_order_map)
    subgoal
      by (auto elim!: list_relE3 list_relE)
    subgoal
      by (auto elim!: list_relE3 list_relE4 list_relE list_relE2)
    done
  done

lemmas [fcomp_norm_unfold] =
  poly_assn_list[symmetric]
  step_rewrite_pure(1)

(* Pointwise form of merge_poly_merge_poly. *)
lemma merge_poly_merge_poly2:
  ‹(a, b) ∈ poly_rel ⟹ (a', b') ∈ poly_rel ⟹
    (merge_poly a a', merge_poly b b') ∈ poly_rel›
  using merge_poly_merge_poly
  unfolding fun_rel_def
  by auto

(* take and drop are compatible with list_rel. *)
lemma list_rel_takeD:
  ‹(a, b) ∈ ⟨R⟩list_rel ⟹ (n, n')∈ Id ⟹ (take n a, take n' b) ∈ ⟨R⟩list_rel›
  by (simp add: list_rel_eq_listrel listrel_iff_nth relAPP_def)

lemma list_rel_dropD:
  ‹(a, b) ∈ ⟨R⟩list_rel ⟹ (n, n')∈ Id ⟹ (drop n a, drop n' b) ∈ ⟨R⟩list_rel›
  by (simp add: list_rel_eq_listrel listrel_iff_nth relAPP_def)

(* msort_poly_impl on string literals implements merge_sort_poly. *)
lemma merge_sort_poly[sepref_import_param]:
  ‹(msort_poly_impl, merge_sort_poly)
   ∈ poly_rel → poly_rel›
  unfolding merge_sort_poly_def msort_poly_impl_def
  apply (intro fun_relI)
  subgoal for a a'
    apply (induction ‹(λ(a :: String.literal list × int)
      (b :: String.literal list × int). fst a ≤ fst b)› a
      arbitrary: a'
      rule: msort.induct)
    subgoal
      by auto
    subgoal
      by (auto elim!: list_relE3 list_relE)
    subgoal premises p
      using p
      by (auto elim!: list_relE3 list_relE4 list_relE list_relE2
        simp: merge_poly_def[symmetric]
        intro!: list_rel_takeD list_rel_dropD
        intro!: merge_poly_merge_poly2 p(1)[simplified] p(2)[simplified],
        auto simp: list_rel_imp_same_length)
    done
  done



lemmas [sepref_fr_rules] = merge_sort_poly[FCOMP merge_sort_poly_sort_poly_spec]

(* Imperative implementation of partition_main_poly. *)
sepref_definition partition_main_poly_impl
  is ‹uncurry2 partition_main_poly›
  :: ‹nat_assn⇧k *⇩a nat_assn⇧k *⇩a poly_assn⇧k →⇩a prod_assn poly_assn nat_assn›
  unfolding partition_main_poly_def partition_main_def
    term_order_rel'_def[symmetric]
    term_order_rel'_alt_def
    le_term_order_rel'
  by sepref

declare partition_main_poly_impl.refine[sepref_fr_rules]

(* Imperative implementation of partition_between_poly (pivot choice plus partition). *)
sepref_definition partition_between_poly_impl
  is ‹uncurry2 partition_between_poly›
  :: ‹nat_assn⇧k *⇩a nat_assn⇧k *⇩a poly_assn⇧k →⇩a prod_assn poly_assn nat_assn›
  unfolding partition_between_poly_def partition_between_ref_def
    partition_main_poly_def[symmetric]
  unfolding choose_pivot3_def
    term_order_rel'_def[symmetric]
    term_order_rel'_alt_def choose_pivot_def
    lexord_eq_alt_def1
  by sepref

declare partition_between_poly_impl.refine[sepref_fr_rules]

(* Imperative quicksort on a slice of a polynomial. *)
sepref_definition quicksort_poly_impl
  is ‹uncurry2 quicksort_poly›
  :: ‹nat_assn⇧k *⇩a nat_assn⇧k *⇩a poly_assn⇧k →⇩a poly_assn›
  unfolding partition_main_poly_def quicksort_ref_def quicksort_poly_def
    partition_between_poly_def[symmetric]
  by sepref

lemmas [sepref_fr_rules] = quicksort_poly_impl.refine

sepref_register quicksort_poly
(* Imperative quicksort on a whole polynomial. *)
sepref_definition full_quicksort_poly_impl
  is ‹full_quicksort_poly›
  :: ‹poly_assn⇧k →⇩a poly_assn›
  unfolding full_quicksort_poly_def full_quicksort_ref_def
    quicksort_poly_def[symmetric]
    le_term_order_rel'[symmetric]
    term_order_rel'_def[symmetric]
    List.null_def
  by sepref


lemmas sort_poly_spec_hnr =
  full_quicksort_poly_impl.refine[FCOMP full_quicksort_sort_poly_spec]

declare merge_coeffs_impl.refine[sepref_fr_rules]

(* Imperative implementation of normalize_poly. *)
sepref_definition normalize_poly_impl
  is ‹normalize_poly›
  :: ‹poly_assn⇧k →⇩a poly_assn›
  supply [[goals_limit=1]]
  unfolding normalize_poly_def
  by sepref

declare normalize_poly_impl.refine[sepref_fr_rules]


(* Quicksort and partition instances for lists of variables (strings),
   using the identity as key and var_order_rel / (≤) as order. *)
definition full_quicksort_vars where
  ‹full_quicksort_vars = full_quicksort_ref (λx y. x = y ∨ (x, y) ∈ var_order_rel) id›


definition quicksort_vars:: ‹nat ⇒ nat ⇒ string list ⇒ (string list) nres› where
  ‹quicksort_vars x y  z = quicksort_ref (≤) id (x, y, z)›


definition partition_between_vars :: ‹nat ⇒ nat ⇒ string list ⇒ (string list × nat) nres› where
  ‹partition_between_vars = partition_between_ref (≤) id›

definition partition_main_vars :: ‹nat ⇒ nat ⇒ string list ⇒ (string list × nat) nres› where
  ‹partition_main_vars = partition_main (≤) id›

(* For distinct strings, exactly one of the two lexicographic comparisons holds. *)
lemma total_on_lexord_less_than_char_linear2:
  ‹xs ≠ ys ⟹ (xs, ys) ∉ lexord (less_than_char) ⟷
       (ys, xs) ∈ lexord less_than_char›
  using lexord_linear[of less_than_char xs ys]
  using lexord_linear[of less_than_char] less_than_char_linear
  apply (auto simp: Relation.total_on_def)
  using lexord_irrefl[OF irrefl_less_than_char]
    antisym_lexord[OF antisym_less_than_char irrefl_less_than_char]
  apply (auto simp: antisym_def)
  done

(* Transitivity of the lexicographic order on character lists. *)
lemma string_trans:
  ‹(xa, ya) ∈ lexord {(x::char, y::char). x < y} ⟹
  (ya, z) ∈ lexord {(x::char, y::char). x < y} ⟹
  (xa, z) ∈ lexord {(x::char, y::char). x < y}›
  by (smt (verit) less_char_def char.less_trans less_than_char_def lexord_partial_trans p2rel_def)

(* The variable quicksort refines the specification sort_coeff. *)
lemma full_quicksort_sort_vars_spec:
  ‹(full_quicksort_vars, sort_coeff) ∈ ⟨Id⟩list_rel →⇩f ⟨⟨Id⟩list_rel⟩nres_rel›
proof -
  have xs: ‹(xs, xs) ∈ ⟨Id⟩list_rel› and ‹⇓(⟨Id⟩list_rel) x = x› for x xs
    by auto
  show ?thesis
    apply (intro frefI nres_relI)
    unfolding full_quicksort_vars_def
    apply (rule full_quicksort_ref_full_quicksort[THEN fref_to_Down_curry, THEN order_trans])
    subgoal
      by (auto simp: rel2p_def var_order_rel_def p2rel_def Relation.total_on_def
        dest: string_trans)
    subgoal
      using total_on_lexord_less_than_char_linear2[unfolded var_order_rel_def]
      apply (auto simp: rel2p_def var_order_rel_def p2rel_def Relation.total_on_def less_char_def)
      done
    subgoal by fast
    apply (rule xs)
    apply (subst down_eq_id_list_rel)
    unfolding sorted_wrt_map sort_coeff_def
    apply (rule full_quicksort_correct_sorted[where R = ‹(λx y. x = y ∨ (x, y) ∈ var_order_rel)› and h = id,
       THEN order_trans])
    subgoal
      by (auto simp: rel2p_def var_order_rel_def p2rel_def Relation.total_on_def dest: string_trans)
    subgoal for x y
      using total_on_lexord_less_than_char_linear2[unfolded var_order_rel_def]
      by (auto simp: rel2p_def var_order_rel_def p2rel_def Relation.total_on_def
        less_char_def)
    subgoal
      by (auto simp: rel2p_def p2rel_def rel2p_def[abs_def])
    done
qed


(* Imperative implementations of the partition / quicksort instances for
   variable lists (monomials). *)
sepref_definition partition_main_vars_impl
  is ‹uncurry2 partition_main_vars›
  :: ‹nat_assn⇧k *⇩a nat_assn⇧k *⇩a (monom_assn)⇧k →⇩a prod_assn (monom_assn) nat_assn›
  unfolding partition_main_vars_def partition_main_def
    var_order_rel_var_order
    var_order'_def[symmetric]
    term_order_rel'_alt_def
    le_term_order_rel'
    id_apply
  by sepref

declare partition_main_vars_impl.refine[sepref_fr_rules]

sepref_definition partition_between_vars_impl
  is ‹uncurry2 partition_between_vars›
  :: ‹nat_assn⇧k *⇩a nat_assn⇧k *⇩a monom_assn⇧k →⇩a prod_assn monom_assn nat_assn›
  unfolding partition_between_vars_def partition_between_ref_def
    partition_main_vars_def[symmetric]
  unfolding choose_pivot3_def
    term_order_rel'_def[symmetric]
    term_order_rel'_alt_def choose_pivot_def
    le_term_order_rel' id_apply
  by sepref

declare partition_between_vars_impl.refine[sepref_fr_rules]

sepref_definition quicksort_vars_impl
  is ‹uncurry2 quicksort_vars›
  :: ‹nat_assn⇧k *⇩a nat_assn⇧k *⇩a monom_assn⇧k →⇩a monom_assn›
  unfolding partition_main_vars_def quicksort_ref_def quicksort_vars_def
    partition_between_vars_def[symmetric]
  by sepref

lemmas [sepref_fr_rules] = quicksort_vars_impl.refine

sepref_register quicksort_vars


(* (≤) on strings is equality or membership in var_order_rel. *)
lemma le_var_order_rel:
  ‹(≤) = (λx y. x = y ∨ (x, y) ∈ var_order_rel)›
  by (intro ext)
   (auto simp add: less_list_def less_eq_list_def rel2p_def
      p2rel_def lexordp_conv_lexord p2rel_def var_order_rel_def
    lexordp_eq_conv_lexord lexordp_def)

(* Imperative quicksort of a whole monomial (list of variables). *)
sepref_definition full_quicksort_vars_impl
  is ‹full_quicksort_vars›
  :: ‹monom_assn⇧k →⇩a monom_assn›
  unfolding full_quicksort_vars_def full_quicksort_ref_def
    quicksort_vars_def[symmetric]
    le_var_order_rel[symmetric]
    term_order_rel'_def[symmetric]
    List.null_def
  by sepref


lemmas sort_vars_spec_hnr =
  full_quicksort_vars_impl.refine[FCOMP full_quicksort_sort_vars_spec]

(* string_rel preserves (≤). *)
lemma string_rel_order_map:
  ‹(x, a) ∈ string_rel ⟹
       (y, aa) ∈ string_rel ⟹
       x ≤ y ⟷ a ≤ aa›
  unfolding string_rel_def less_eq_literal.rep_eq less_than_char_def
    less_eq_list_def PAC_Polynomials_Term.less_char_def[symmetric]
  by (auto simp: string_rel_def less_eq_literal.rep_eq less_than_char_def
    less_eq_list_def char.lexordp_eq_conv_lexord lexordp_eq_refl
    lexordp_char_char lexordp_eq_conv_lexord
    simp flip: less_char_def[abs_def])

(* merge_monoms is parametric in monom_rel. *)
lemma merge_monoms_merge_monoms:
  ‹(merge_monoms, merge_monoms) ∈ monom_rel → monom_rel → monom_rel›
  unfolding merge_monoms_def
  apply (intro fun_relI)
  subgoal for a a' aa a'a
    apply (induction ‹(λ(a :: String.literal)
      (b :: String.literal). a ≤ b)› a aa
      arbitrary: a' a'a
      rule: merge.induct)
    subgoal
      by (auto elim!: list_relE3 list_relE4 list_relE list_relE2
        simp: string_rel_order_map)
    subgoal
      by (auto elim!: list_relE3 list_relE)
    subgoal
      by (auto elim!: list_relE3 list_relE4 list_relE list_relE2)
    done
  done

(* Pointwise form of merge_monoms_merge_monoms. *)
lemma merge_monoms_merge_monoms2:
  ‹(a, b) ∈ monom_rel ⟹ (a', b') ∈ monom_rel ⟹
    (merge_monoms a a', merge_monoms b b') ∈ monom_rel›
  using merge_monoms_merge_monoms
  unfolding fun_rel_def merge_monoms_def
  by auto


(* msort_monoms_impl on string literals implements merge_monoms_poly. *)
lemma msort_monoms_impl:
  ‹(msort_monoms_impl, merge_monoms_poly)
   ∈ monom_rel → monom_rel›
  unfolding msort_monoms_impl_def merge_monoms_poly_def
  apply (intro fun_relI)
  subgoal for a a'
    apply (induction ‹(λ(a :: String.literal)
      (b :: String.literal). a ≤ b)› a
      arbitrary: a'
      rule: msort.induct)
    subgoal
      by auto
    subgoal
      by (auto elim!: list_relE3 list_relE)
    subgoal premises p
      using p
      by (auto elim!: list_relE3 list_relE4 list_relE list_relE2
        simp: merge_monoms_def[symmetric] intro!: list_rel_takeD list_rel_dropD
        intro!: merge_monoms_merge_monoms2 p(1)[simplified] p(2)[simplified])
        (simp_all add: list_rel_imp_same_length)
    done
  done

(* Merge sort on monomials refines sort_coeff. *)
lemma merge_sort_monoms_sort_monoms_spec:
  ‹(RETURN o merge_monoms_poly, sort_coeff) ∈ ⟨Id⟩list_rel →⇩f ⟨⟨Id⟩list_rel⟩nres_rel›
  unfolding merge_monoms_poly_def sort_coeff_def
  by (intro frefI nres_relI)
    (auto intro!: sorted_msort simp: sorted_wrt_map rel2p_def
     le_term_order_rel' transp_def rel2p_def[abs_def]
     simp flip: le_var_order_rel)

sepref_register sort_coeff
(* sort_coeff is implemented by msort_monoms_impl. *)
lemma  [sepref_fr_rules]:
  ‹(return o msort_monoms_impl, sort_coeff) ∈ monom_assn⇧k →⇩a monom_assn›
  using msort_monoms_impl[sepref_param, FCOMP merge_sort_monoms_sort_monoms_spec]
  by auto

(* Imperative implementation of sort_all_coeffs. *)
sepref_definition sort_all_coeffs_impl
  is ‹sort_all_coeffs›
  :: ‹poly_assn⇧k →⇩a poly_assn›
  unfolding sort_all_coeffs_def
    HOL_list.fold_custom_empty
  by sepref

declare sort_all_coeffs_impl.refine[sepref_fr_rules]

(* Recursive (RECT) formulation of merge_coeffs0 for Sepref: like
   merge_coeffs, but monomials with coefficient 0 are also removed. *)
lemma merge_coeffs0_alt_def:
  ‹(RETURN o merge_coeffs0) p =
   RECT(λf p.
     (case p of
       [] ⇒ RETURN []
     | [p] => if snd p = 0 then RETURN [] else RETURN [p]
     | ((xs, n) # (ys, m) # p) ⇒
      (if xs = ys
       then if n + m ≠ 0 then f ((xs, n + m) # p) else f p
       else if n = 0 then
          do {p ← f ((ys, m) # p);
            RETURN p}
       else do {p ← f ((ys, m) # p);
            RETURN ((xs, n) # p)})))
    p›
  apply (subst eq_commute)
  apply (induction p rule: merge_coeffs0.induct)
  subgoal by (subst RECT_unfold, refine_mono) auto
  subgoal by (subst RECT_unfold, refine_mono) auto
  subgoal by (subst RECT_unfold, refine_mono) (auto simp: let_to_bind_conv)
  done

text ‹Again, Sepref does not understand what is going on here.›
sepref_definition merge_coeffs0_impl
  is ‹RETURN o merge_coeffs0›
  :: ‹poly_assn⇧k →⇩a poly_assn›
  supply [[goals_limit=1]]
  unfolding merge_coeffs0_alt_def
    HOL_list.fold_custom_empty
  (* Run the Sepref phases one by one, so that the hand-proved frame rules
     WTF_RF can be tried before each generic translation step. *)
  apply sepref_dbg_preproc
  apply sepref_dbg_cons_init
  apply sepref_dbg_id
  apply sepref_dbg_monadify
  apply sepref_dbg_opt_init
  apply (rule WTF_RF | sepref_dbg_trans_step)+
  apply sepref_dbg_opt
  apply sepref_dbg_cons_solve
  apply sepref_dbg_cons_solve
  apply sepref_dbg_constraints
  done


declare merge_coeffs0_impl.refine[sepref_fr_rules]

(* Imperative implementation of full_normalize_poly. *)
sepref_definition fully_normalize_poly_impl
  is ‹full_normalize_poly›
  :: ‹poly_assn⇧k →⇩a poly_assn›
  supply [[goals_limit=1]]
  unfolding full_normalize_poly_def
  by sepref

declare fully_normalize_poly_impl.refine[sepref_fr_rules]


end