(* Theory First_Order_Clause.Monomorphic_Typing *)

theory Monomorphic_Typing
  imports
    Nonground_Term_Typing
    Ground_Term_Typing
    IsaFoR_Nonground_Term
begin

(* The ground typing judgement loses its plain turnstile notation and gets a
   ":⇩G" variant instead, because the same turnstile template is declared
   below for the non-ground judgement (notation welltyped). *)
no_notation Ground_Term_Typing.welltyped (‹_ ⊢ _ : _› [1000, 0, 50] 50)
notation Ground_Term_Typing.welltyped (‹_ ⊢ _ :⇩G _› [1000, 0, 50] 50)

(* Monomorphic typing of first-order (IsaFoR) terms.
   ℱ f n = Some (τs, τ) gives the argument types τs and the result type τ
   of the function symbol f applied to n arguments. *)
locale monomorphic_term_typing =
  fixes ℱ :: "('f, 'ty) fun_types"
begin

(* welltyped 𝒱 t τ: the term t has type τ under the variable typing 𝒱.
   A variable x has exactly the type 𝒱 x; an application Fun f ts has type τ
   if ℱ f (length ts) = Some (τs, τ) and the arguments ts have the types τs. *)
inductive welltyped :: "('v, 'ty) var_types ⇒ ('f,'v) term ⇒ 'ty ⇒ bool"
  for 𝒱 where
    Var: "welltyped 𝒱 (Var x) τ" if
    "𝒱 x = τ"
  | Fun: "welltyped 𝒱 (Fun f ts) τ" if
    "ℱ f (length ts) = Some (τs, τ)"
    "list_all2 (welltyped 𝒱) ts τs"

(* TODO: Introduce notations also for lifting + substitution *)
notation welltyped (‹_ ⊢ _ : _› [1000, 0, 50] 50)

(* Under a fixed 𝒱 every term has at most one type. *)
sublocale "term": base_typing where welltyped = "welltyped 𝒱" for 𝒱
proof unfold_locales
  show "right_unique (welltyped 𝒱)"
  proof (rule right_uniqueI)
    fix t τ1 τ2
    assume "𝒱 ⊢ t : τ1" and "𝒱 ⊢ t : τ2"
    thus "τ1 = τ2"
      by (auto elim!: welltyped.cases)
  qed
qed

(* Typing is compatible with contexts. Replacing a subterm by another term of
   the same type keeps the type of the whole term, and every subterm of a
   well-typed term is itself well-typed. *)
sublocale "term": term_typing where
  welltyped = "welltyped 𝒱" and apply_context = ctxt_apply_term for 𝒱
proof unfold_locales
  fix t t' c τ τ'

  assume
    welltyped_t_t': "𝒱 ⊢ t : τ'" "𝒱 ⊢ t' : τ'" and
    welltyped_c_t: "𝒱 ⊢ c⟨t⟩ : τ"

  from welltyped_c_t show "𝒱 ⊢ c⟨t'⟩ : τ"
  proof (induction c arbitrary: τ)
    case Hole
    then show ?case
      using welltyped_t_t'
      by auto
  next
    case (More f ss1 c ss2)

    have "𝒱 ⊢ Fun f (ss1 @ c⟨t⟩ # ss2) : τ"
      using More.prems
      by simp

    hence "𝒱 ⊢ Fun f (ss1 @ c⟨t'⟩ # ss2) : τ"
    proof (cases 𝒱 "Fun f (ss1 @ c⟨t⟩ # ss2)" τ rule: welltyped.cases)
      case (Fun τs)

      show ?thesis
      proof (rule welltyped.Fun)
        show "ℱ f (length (ss1 @ c⟨t'⟩ # ss2)) = Some (τs, τ)"
          using Fun
          by simp
      next
        (* The hole argument changes from c⟨t⟩ to c⟨t'⟩; the induction
           hypothesis keeps its type, and all other arguments are untouched. *)
        show "list_all2 (welltyped 𝒱) (ss1 @ c⟨t'⟩ # ss2) τs"
          using ‹list_all2 (welltyped 𝒱) (ss1 @ c⟨t⟩ # ss2) τs›
          using More.IH
          by (smt (verit, del_insts) list_all2_Cons1 list_all2_append1 list_all2_lengthD)
      qed
    qed

    thus ?case
      by simp
  qed
next
  fix c t τ
  assume "𝒱 ⊢ c⟨t⟩ : τ"
  then show "∃τ'. 𝒱 ⊢ t : τ'"
    by
      (induction c arbitrary: τ)
      (auto simp: welltyped.simps list_all2_Cons1 list_all2_append1)
qed

(* Interaction of typing with substitutions, variable typings and renamings
   for IsaFoR terms (Var as identity, ∘⇩s as composition, ⋅ as application). *)
sublocale "term": base_typing_properties where
  id_subst = "Var :: 'v ⇒ ('f, 'v) term" and comp_subst = "(∘⇩s)" and subst = "(⋅)" and
  vars = term.vars and welltyped = welltyped and to_ground = term.to_ground and
  from_ground = term.from_ground
proof(unfold_locales; (intro welltyped.Var refl)?)
  (* A substitution that is type-preserving on the variables of t
     both preserves and reflects the type of t. *)
  fix t :: "('f, 'v) term" and 𝒱 σ τ
  assume type_preserving_σ: "∀x∈term.vars t. 𝒱 ⊢ Var x : 𝒱 x ⟶ 𝒱 ⊢ σ x : 𝒱 x"

  show "𝒱 ⊢ t ⋅ σ : τ ⟷ 𝒱 ⊢ t : τ"
  proof (rule iffI)

    assume "𝒱 ⊢ t : τ"

    then show "𝒱 ⊢ t ⋅ σ : τ"
      using type_preserving_σ
      by
        (induction rule: welltyped.induct)
        (auto simp: list.rel_mono_strong list_all2_map1 welltyped.simps)
  next

    assume "𝒱 ⊢ t ⋅ σ : τ"

    then show "𝒱 ⊢ t : τ"
      using type_preserving_σ
    proof (induction "t ⋅ σ" τ arbitrary: t rule: welltyped.induct)
      case (Var x τ)

      then obtain x' where t: "t = Var x'"
        by (metis subst_apply_eq_Var)

      have "𝒱 ⊢ t : 𝒱 x'"
        unfolding t
        by (simp add: welltyped.Var)

      moreover have "𝒱 ⊢ t : 𝒱 x"
        using Var
        unfolding t
        by (simp add: welltyped.simps)

      ultimately have 𝒱_x': "τ = 𝒱 x'"
        using Var.hyps
        by blast

      show ?case
        unfolding t 𝒱_x'
        by (simp add: welltyped.Var)
    next
      case (Fun f τs τ ts)

      then show ?case
        by (cases t) (simp_all add: list.rel_mono_strong list_all2_map1 welltyped.simps)
    qed
  qed

next
  (* The type of t depends only on the types 𝒱 assigns to the variables of t. *)
  fix t :: "('f, 'v) term" and 𝒱 𝒱' τ

  assume "𝒱 ⊢ t : τ" "∀x∈term.vars t. 𝒱 x = 𝒱' x"

  then show "𝒱' ⊢ t : τ"
    by (induction rule: welltyped.induct) (simp_all add: welltyped.simps list.rel_mono_strong)
next
  (* Renamings: t ⋅ ρ has type τ under 𝒱' iff t has type τ under 𝒱, provided
     every variable x of t has the same type under 𝒱 as its renaming has
     under 𝒱'. *)
  fix 𝒱 𝒱' :: "('v, 'ty) var_types" and t :: "('f, 'v) term" and ρ :: "('f, 'v) subst" and τ

  assume
    renaming: "term_subst.is_renaming ρ" and
    𝒱: "∀x∈term.vars t. 𝒱 x = 𝒱' (term.rename ρ x)"

  then show "𝒱' ⊢ t ⋅ ρ : τ ⟷ 𝒱 ⊢ t : τ"
  proof(intro iffI)

    assume "𝒱' ⊢ t ⋅ ρ : τ"

    with 𝒱 show "𝒱 ⊢ t : τ"
    proof(induction t arbitrary: τ)
      case (Var x)

      then have "𝒱' (term.rename ρ x) = τ"
        using renaming term.id_subst_rename[OF renaming]
        by (metis eval_term.simps(1) monomorphic_term_typing.Var term.right_uniqueD)

      then have "𝒱 x = τ"
        by (simp add: Var.prems(1))

      then show ?case
        by(rule welltyped.Var)
    next
      case (Fun f ts)

      then have "𝒱' ⊢ Fun f (map (λs. s ⋅ ρ) ts) : τ"
        by auto

      then obtain τs where τs:
        "list_all2 (welltyped 𝒱') (map (λs. s ⋅ ρ) ts) τs"
        "ℱ f (length (map (λs. s ⋅ ρ) ts)) = Some (τs, τ)"
        using welltyped.simps
        by blast

      then have ℱ: "ℱ f (length ts) = Some (τs, τ)"
        by simp

      show ?case
      proof(rule welltyped.Fun[OF ℱ])

        show "list_all2 (welltyped 𝒱) ts τs"
          using τs(1) Fun.IH
          by (smt (verit, ccfv_SIG) Fun.prems(1) eval_term.simps(2) in_set_conv_nth length_map
              list_all2_conv_all_nth nth_map term.set_intros(4))
      qed
    qed
  next
    assume "𝒱 ⊢ t : τ"

    then show "𝒱' ⊢ t ⋅ ρ : τ"
      using 𝒱
    proof(induction rule: welltyped.induct)
      case (Var x τ)

      then have "𝒱' (term.rename ρ x) = τ"
        by simp

      then show ?case
        using term.id_subst_rename[OF renaming]
        by (metis eval_term.simps(1) welltyped.Var)
    next
      case (Fun f ts τs τ)

      have "list_all2 (welltyped 𝒱') (map (λs. s ⋅ ρ) ts) τs"
        using Fun
        by (auto simp: list.rel_mono_strong list_all2_map1)

      then show ?case
        by (simp add: Fun.hyps welltyped.simps)
    qed
  qed
next
  (* A variable x always has its assigned type 𝒱 x, by rule Var. *)
  fix 𝒱 :: "('v, 'ty) var_types"  and x

  show "𝒱 ⊢ Var x : 𝒱 x ∨ ¬ term.is_welltyped 𝒱 (Var x)"
    by (simp add: Var)
qed

(* TODO: Try to put in other file *)
(* A type-preserving unifier of Fun f ts and Fun f ts' is a type-preserving
   unifier of every pair of corresponding arguments. *)
lemma unify_subterms:
  assumes "term.type_preserving_unifier 𝒱 υ (Fun f ts) (Fun f ts')"
  shows "list_all2 (term.type_preserving_unifier 𝒱 υ) ts ts'"
  using assms
  unfolding list_all2_iff
  by (auto elim: in_set_zipE intro: map_eq_imp_length_eq)

(* The single-variable substitution subst x t is type-preserving
   when Var x and t have the same type. *)
lemma type_preserving_subst:
  assumes "𝒱 ⊢ Var x : τ" "𝒱 ⊢ t : τ"
  shows "term.type_preserving 𝒱 (subst x t)"
  using assms
  unfolding subst_def
  by auto

(* Suppose υ type-preservingly unifies every equation of (Var x, t) # es.
   Then it also does so for every equation of es once subst x t has been
   applied to both sides. *)
lemma type_preserving_unifier_subst:
  assumes "∀(s, s') ∈ set ((Var x, t) # es). term.type_preserving_unifier 𝒱 υ s s'"
  shows "∀(s, s') ∈ set es. term.type_preserving_unifier 𝒱 υ (s ⋅ subst x t) (s' ⋅ subst x t)"
proof (intro ballI2)
  fix s s'
  assume s_s': "(s, s') ∈ set es"

  then have "term.type_preserving_on (term.vars (s ⋅ subst x t) ∪ term.vars (s' ⋅ subst x t)) 𝒱 υ"
    using assms term.vars_id_subst_update
    unfolding subst_def
    by (smt (verit, del_insts) UnCI UnE case_prodD list.set_intros(1,2) subset_iff)

  then show "term.type_preserving_unifier 𝒱 υ (s ⋅ subst x t) (s' ⋅ subst x t)"
    using assms s_s'
    by auto
qed

(* Variant of type_preserving_unifier_subst phrased for the equation list
   subst_list (subst x t) es. *)
lemma type_preserving_unifier_subst_list:
  assumes "∀(s, s') ∈ set ((Var x, t) # es). term.type_preserving_unifier 𝒱 υ s s'"
  shows "∀(s, s') ∈ set (subst_list (subst x t) es). term.type_preserving_unifier 𝒱 υ s s'"
  using type_preserving_unifier_subst[OF assms]
  unfolding subst_list_def
  by (smt (verit, best) case_prod_conv image_iff list.set_map prod.collapse)

(* unify_subterms, stated for the equation list produced by zip_option ts ts'. *)
lemma unify_subterms_zip_option:
  assumes
    type_preserving_unifier: "term.type_preserving_unifier 𝒱 υ (Fun f ts) (Fun f ts')" and
    zip_option: "zip_option ts ts' = Some es"
  shows
    "∀(t, t') ∈ set es. term.type_preserving_unifier 𝒱 υ t t'"
  using unify_subterms[OF type_preserving_unifier] zip_option
  unfolding zip_option_zip_conv list_all2_iff
  by argo

(* Every equation produced by decomposing two applications is type-preservingly
   unified by any type-preserving unifier of the two applications. *)
lemma type_preserving_unifier_decompose_Fun:
  assumes
    type_preserving_unifier: "term.type_preserving_unifier 𝒱 υ (Fun f ts) (Fun g ss)" and
    decompose: "decompose (Fun f ts) (Fun g ss) = Some es"
  shows "∀(t, t') ∈ set es. term.type_preserving_unifier 𝒱 υ t t'"
  using type_preserving_unifier decompose_Some[OF decompose]
  by (metis (mono_tags, lifting) list_all2_iff unify_subterms)

(* The same for arbitrary terms: decompose only succeeds on two applications
   of the same function symbol. *)
lemma type_preserving_unifier_decompose:
  assumes
    type_preserving_unifier: "term.type_preserving_unifier 𝒱 υ f g" and
    decompose: "decompose f g = Some es"
  shows "∀(t, t') ∈ set es. term.type_preserving_unifier 𝒱 υ t t'"
proof -

  obtain f' fs gs where Fun: "f = Fun f' fs" "g = Fun f' gs"
    using decompose
    unfolding decompose_def
    by (auto split: term.splits if_splits)

  show ?thesis
    using type_preserving_unifier_decompose_Fun[OF assms[unfolded Fun]] .
qed

(* Suppose unify succeeds and some υ type-preservingly unifies all pending
   equations. If the initial accumulated substitution subst_of bs is
   type-preserving, then so is the resulting substitution. The proof follows
   the defining equations of unify. *)
lemma type_preserving_unify:
  assumes
    "unify es bs = Some unifier"
    "∀(t, t') ∈ set es. term.type_preserving_unifier 𝒱 υ t t'"
    "term.type_preserving 𝒱 (subst_of bs)"
  shows "term.type_preserving 𝒱 (subst_of unifier)"
  using assms
proof(induction es bs rule: unify.induct)
  case 1

  then show ?case
    by simp
next
  case (2 f ss g ts es bs)

  obtain es' where es': "decompose (Fun f ss) (Fun g ts) = Some es'"
    using 2(2)
    by (simp split: option.splits)

  show ?case
  proof (rule "2.IH"[OF es' _ _ "2.prems"(3)])

    show "unify (es' @ es) bs = Some unifier"
      using es' "2.prems"(1)
      by auto
  next

    have
      "∀(t, t')∈set es'. term.type_preserving_unifier 𝒱 υ t t'"
      "∀(t, t')∈set es. term.type_preserving_unifier 𝒱 υ t t'"
      using type_preserving_unifier_decompose[OF _ es'] "2.prems"(2)
      by auto

    then show "∀(t, t')∈set (es' @ es). term.type_preserving_unifier 𝒱 υ t t'"
      by auto
  qed
next
  case (3 x t es bs)

  show ?case
  proof(cases "t = Var x")
    case True

    then show ?thesis
      using 3
      by simp
  next
    case False

    then show ?thesis
    proof (rule 3(2))

      (* If x occurred in t, unify would have failed. *)
      show "x ∉ term.vars t"
        using "3.prems"(1) False
        by auto
    next

      show "unify (subst_list (subst x t) es) ((x, t) # bs) = Some unifier"
        using False 3(3)
        by (auto split: if_splits)
    next

      show "∀(s, s') ∈ set (subst_list (subst x t) es). term.type_preserving_unifier 𝒱 υ s s'"
        using type_preserving_unifier_subst_list[OF 3(4)] .
    next

      have "term.type_preserving 𝒱 (subst x t)"
        using 3(4) type_preserving_subst term.type_preserving_unifier
        by (smt (verit, del_insts) fun_upd_other list.set_intros(1) prod.case subst_def)

      then show "term.type_preserving 𝒱 (subst_of ((x, t) # bs))"
        using 3(5)
        by  (simp add: subst_compose_def)
    qed
  qed
next
  case (4 f ts x es bs)

  let ?t = "Fun f ts"

  show ?case
  proof (rule 4(1))

    (* If x occurred in Fun f ts, unify would have failed. *)
    show "x ∉ term.vars ?t"
      using "4.prems"
      by fastforce
  next

    show "unify (subst_list (subst x ?t) es) ((x, ?t) # bs) = Some unifier"
      using "4.prems"(1)
      by (auto split: if_splits)
  next

    show "∀(s, s') ∈ set (subst_list (subst x ?t) es). term.type_preserving_unifier 𝒱 υ s s'"
    proof (rule type_preserving_unifier_subst_list)
      show "∀(s, s')∈set ((Var x, ?t) # es). term.type_preserving_unifier 𝒱 υ s s'"
        using "4.prems"(2)
        by auto
    qed
  next
    have "term.type_preserving 𝒱 (subst x ?t)"
      using 4(3) type_preserving_subst term.type_preserving_unifier
      by (smt (verit, del_insts) fun_upd_other list.set_intros(1) prod.case subst_def)

    then show "term.type_preserving 𝒱 (subst_of ((x, ?t) # bs))"
      using 4(4)
      by (simp add: subst_compose_def)
  qed
qed

(* type_preserving_unify for a single equation and an empty initial substitution. *)
lemma type_preserving_unify_single:
  assumes
    unify: "unify [(t, t')] [] = Some unifier" and
    unifier: "term.type_preserving_unifier 𝒱 υ t t'"
  shows "term.type_preserving 𝒱 (subst_of unifier)"
  using type_preserving_unify[OF unify] unifier
  by simp

(* If t and t' have a type-preserving unifier, the_mgu t t' is type-preserving. *)
lemma type_preserving_the_mgu:
  assumes
    the_mgu: "the_mgu t t' = μ" and
    unifier: "term.type_preserving_unifier 𝒱 υ t t'"
  shows "term.type_preserving 𝒱 μ"
  using the_mgu type_preserving_unify_single[OF _ unifier]
  unfolding the_mgu_def mgu_def
  by (metis (mono_tags, lifting) option.exhaust option.simps(4,5))

(* Discharged by the_mgu_term_subst_is_imgu together with type_preserving_the_mgu. *)
sublocale type_preserving_imgu where
  id_subst = "Var :: 'v ⇒ ('f, 'v) term" and comp_subst = "(∘⇩s)" and subst = "(⋅)" and
  vars = term.vars and welltyped = welltyped
  by unfold_locales
     (metis (full_types) the_mgu the_mgu_term_subst_is_imgu
      type_preserving_the_mgu)

end

(* Monomorphic typing in which every type τ is inhabited by a constant, i.e. a
   function symbol f with ℱ f 0 = Some ([], τ). *)
locale witnessed_monomorphic_term_typing =
  monomorphic_term_typing where ℱ = ℱ for ℱ :: "('f, 'ty) fun_types" +
assumes types_witnessed: "∀τ. ∃f. ℱ f 0 = Some ([], τ)"
begin

(* Consequently every type has a ground well-typed term: the constant
   Fun f [] witnessing it. *)
sublocale "term": base_witnessed_typing_properties where
  id_subst = "Var :: 'v ⇒ ('f, 'v) term" and comp_subst = "(∘⇩s)" and subst = "(⋅)" and
  vars = term.vars and welltyped = welltyped and to_ground = term.to_ground and
  from_ground = term.from_ground
proof unfold_locales
  fix 𝒱 :: "('v, 'ty) var_types" and τ

  obtain f where f: "ℱ f 0 = Some ([], τ)"
    using types_witnessed
    by blast

  show "∃t. term.is_ground t ∧ 𝒱 ⊢ t : τ"
  proof(rule exI[of _ "Fun f []"], intro conjI welltyped.Fun)

    show "term.is_ground (Fun f [])"
      by simp
  next

    show "ℱ f (length []) = Some ([], τ)"
      using f
      by simp
  next

    show "list_all2 (welltyped 𝒱) [] []"
      by simp
  qed
qed

end

end