(* Theory Term_Variants *)

(*  Title:      Term_Variants.thy
    Author:     Andreas Viktor Hess, DTU
    Author:     Sebastian A. Mödersheim, DTU
    Author:     Achim D. Brucker, University of Exeter
    Author:     Anders Schlichtkrull, DTU
    SPDX-License-Identifier: BSD-3-Clause
*)

section‹Term Variants›
theory Term_Variants
  imports Stateful_Protocol_Composition_and_Typing.Intruder_Deduction
begin

(* term_variants P t enumerates, as a list, the variants of the term t under the
   replacement map P. A variable has only itself as a variant. For a node Fun f T, the
   argument lists range over all combinations (product_lists) of the variants of the
   arguments in T, and the head symbol is either f itself or any symbol g in P f. The
   entries headed by f come first, followed by those for each g in P f, in list order.
   The inductive predicate term_variants_pred below describes the same relation; see
   term_variants_pred_iff_in_term_variants. *)
fun term_variants where
  "term_variants P (Var x) = [Var x]"
| "term_variants P (Fun f T) = (
  let S = product_lists (map (term_variants P) T)
  in map (Fun f) S@concat (map (λg. map (Fun g) S) (P f)))"

(* term_variants_pred P t s holds when s arises from t by replacing, independently at each
   function node, the head symbol f with one of the symbols listed in P f (or keeping it).
   Variables are only related to themselves (term_variants_Var). A node Fun f T is related
   to Fun g S with g in P f (term_variants_P) or to Fun f S (term_variants_Fun). In both
   cases T and S must have the same length and be related position by position. *)
inductive term_variants_pred for P where
  term_variants_Var:
  "term_variants_pred P (Var x) (Var x)"
| term_variants_P:
  "⟦length T = length S; ⋀i. i < length T ⟹ term_variants_pred P (T ! i) (S ! i); g ∈ set (P f)⟧
    ⟹ term_variants_pred P (Fun f T) (Fun g S)"
| term_variants_Fun:
  "⟦length T = length S; ⋀i. i < length T ⟹ term_variants_pred P (T ! i) (S ! i)⟧
    ⟹ term_variants_pred P (Fun f T) (Fun f S)"

(* Inversion for two related function terms. The argument lists have equal length and are
   related position by position. A change of head symbol from f to h is only possible
   when h is listed in P f. *)
lemma term_variants_pred_inv:
  assumes "term_variants_pred P (Fun f T) (Fun h S)"
  shows "length T = length S"
    and "⋀i. i < length T ⟹ term_variants_pred P (T ! i) (S ! i)"
    and "f ≠ h ⟹ h ∈ set (P f)"
using assms by (auto elim: term_variants_pred.cases)

(* Inversion when only the left-hand side is known to be a function term Fun f T. The
   right-hand side t is then also a function term, and its arguments (args t) are related
   position by position to T. If its head symbol the_Fun t differs from f, it must be
   listed in P f. For the single-replacement map (λ_. [])(g := [h]), a changed head symbol
   forces f = g and the_Fun t = h. *)
lemma term_variants_pred_inv':
  assumes "term_variants_pred P (Fun f T) t"
  shows "is_Fun t"
    and "length T = length (args t)"
    and "⋀i. i < length T ⟹ term_variants_pred P (T ! i) (args t ! i)"
    and "f ≠ the_Fun t ⟹ the_Fun t ∈ set (P f)"
    and "P ≡ (λ_. [])(g := [h]) ⟹ f ≠ the_Fun t ⟹ f = g ∧ the_Fun t = h"
using assms by (auto elim: term_variants_pred.cases)

(* Mirror image of the previous inversion lemma: here only the right-hand side is known to
   be a function term Fun f T. The left-hand side t is then also a function term with
   related arguments. If f differs from the_Fun t, then f is listed in P (the_Fun t). For
   the single-replacement map (λ_. [])(g := [h]), a changed head symbol forces f = h and
   the_Fun t = g. *)
lemma term_variants_pred_inv'':
  assumes "term_variants_pred P t (Fun f T)"
  shows "is_Fun t"
    and "length T = length (args t)"
    and "⋀i. i < length T ⟹ term_variants_pred P (args t ! i) (T ! i)"
    and "f ≠ the_Fun t ⟹ f ∈ set (P (the_Fun t))"
    and "P ≡ (λ_. [])(g := [h]) ⟹ f ≠ the_Fun t ⟹ f = h ∧ the_Fun t = g"
using assms by (auto elim: term_variants_pred.cases)

(* A variable is related exactly to itself, in either argument position. *)
lemma term_variants_pred_inv_Var:
  "term_variants_pred P (Var x) t ⟷ t = Var x"
  "term_variants_pred P t (Var x) ⟷ t = Var x"
by (auto intro: term_variants_Var elim: term_variants_pred.cases)

(* A constant Fun c [] is related exactly to the constants Fun g [] with g listed in P c,
   and to itself. *)
lemma term_variants_pred_inv_const:
  "term_variants_pred P (Fun c []) t ⟷ ((∃g ∈ set (P c). t = Fun g []) ∨ (t = Fun c []))"
by (auto intro: term_variants_P term_variants_Fun elim: term_variants_pred.cases)

(* term_variants_pred is reflexive for every replacement map P. The proof is by structural
   induction on t using the introduction rules of term_variants_pred. *)
lemma term_variants_pred_refl: "term_variants_pred P t t"
by (induct t) (auto intro: term_variants_pred.intros)

(* Converse of reflexivity: if P maps every symbol f only to f itself, then related terms
   are equal. *)
lemma term_variants_pred_refl_inv:
  assumes st: "term_variants_pred P s t"
    and P: "∀f. ∀g ∈ set (P f). f = g"
  shows "s = t"
  using st P
proof (induction s t rule: term_variants_pred.induct)
  case (term_variants_Var x) thus ?case by blast
next
  case (term_variants_P T S g f)
  hence "T ! i = S ! i" when i: "i < length T" for i using i by blast
  hence "T = S" using term_variants_P.hyps(1) by (simp add: nth_equalityI)
  (* the new head symbol g is listed in P f, so the assumption P yields f = g *)
  thus ?case using term_variants_P.prems term_variants_P.hyps(3) by fast
next
  case (term_variants_Fun T S f)
  hence "T ! i = S ! i" when i: "i < length T" for i using i by blast
  hence "T = S" using term_variants_Fun.hyps(1) by (simp add: nth_equalityI)
  thus ?case by fast
qed

(* A constant Fun a [] may be replaced by any constant Fun b [] with b listed in P a. *)
lemma term_variants_pred_const:
  assumes "b ∈ set (P a)"
  shows "term_variants_pred P (Fun a []) (Fun b [])"
using term_variants_P[of "[]" "[]"] assms by simp

(* Case split on whether the constant a has any replacements. If P a is non-empty,
   Fun a [] is related to itself and to each Fun b [] with b in P a. If P a is empty,
   it is related only to itself. *)
lemma term_variants_pred_const_cases:
  "P a ≠ [] ⟹ term_variants_pred P (Fun a []) t ⟷
                 (t = Fun a [] ∨ (∃b ∈ set (P a). t = Fun b []))"
  "P a = [] ⟹ term_variants_pred P (Fun a []) t ⟷ t = Fun a []"
using term_variants_pred_inv_const[of P] by auto

(* Replacement at a single argument position. If t is related to s, then Fun f (S@t#T) is
   related to Fun g (S@s#T), where the remaining arguments are related to themselves by
   reflexivity. The head symbol is either kept (f = g) or replaced by some g listed in
   P f. *)
lemma term_variants_pred_param:
  assumes "term_variants_pred P t s"
    and fg: "f = g ∨ g ∈ set (P f)"
  shows "term_variants_pred P (Fun f (S@t#T)) (Fun g (S@s#T))"
proof -
  have 1: "length (S@t#T) = length (S@s#T)" by simp

  have "term_variants_pred P (T ! i) (T ! i)" "term_variants_pred P (S ! i) (S ! i)" for i
    by (metis term_variants_pred_refl)+
  hence 2: "term_variants_pred P ((S@t#T) ! i) ((S@s#T) ! i)" for i
    by (simp add: assms nth_Cons' nth_append)

  (* choose term_variants_Fun or term_variants_P according to fg *)
  show ?thesis by (metis term_variants_Fun[OF 1 2] term_variants_P[OF 1 2] fg)
qed

(* Extending related argument lists at the front. If t is related to s and Fun f T is
   related to Fun f S, then Fun f (t#T) is related to Fun g (s#S), where g is f or a symbol
   listed in P f. *)
lemma term_variants_pred_Cons:
  assumes t: "term_variants_pred P t s"
    and T: "term_variants_pred P (Fun f T) (Fun f S)"
    and fg: "f = g ∨ g ∈ set (P f)"
  shows "term_variants_pred P (Fun f (t#T)) (Fun g (s#S))"
proof -
  have 1: "length (t#T) = length (s#S)"
       and "⋀i. i < length T ⟹ term_variants_pred P (T ! i) (S ! i)"
    using term_variants_pred_inv[OF T] by simp_all
  hence 2: "⋀i. i < length (t#T) ⟹ term_variants_pred P ((t#T) ! i) ((s#S) ! i)"
    by (metis t One_nat_def diff_less length_Cons less_Suc_eq less_imp_diff_less nth_Cons'
              zero_less_Suc)

  show ?thesis using 1 2 fg by (auto intro: term_variants_pred.intros)
qed

(* Density. Suppose u is related to t by a step that replaces symbols in P by symbols in
   fs. Then the step can be split into two steps through an intermediate term s:
     - first, replace symbols in P by symbols in gs;
     - then, replace symbols in Q by symbols in fs.
   The splitting requires a symbol g that lies in Q and in gs, which serves as the
   intermediate symbol. *)
lemma term_variants_pred_dense:
  fixes P Q::"'a set" and fs gs::"'a list"
  defines "P_fs x ≡ if x ∈ P then fs else []"
    and "P_gs x ≡ if x ∈ P then gs else []"
    and "Q_fs x ≡ if x ∈ Q then fs else []"
  assumes ut: "term_variants_pred P_fs u t"
    and g: "g ∈ Q" "g ∈ set gs"
  shows "∃s. term_variants_pred P_gs u s ∧ term_variants_pred Q_fs s t"
proof -
  (* generic constructor for the replacement maps above, e.g. P_gs = F P gs *)
  define F where "F ≡ λ(P::'a set) (fs::'a list) x. if x ∈ P then fs else []"

  show ?thesis using ut g P_fs_def unfolding P_gs_def Q_fs_def
  proof (induction u t arbitrary: g gs rule: term_variants_pred.induct)
    case (term_variants_Var h x) thus ?case
      by (auto intro: term_variants_pred.term_variants_Var)
  next
    case (term_variants_P T S h' h g gs)
    note hyps = term_variants_P.hyps(1,2,4,5,6,7)
    note IH = term_variants_P.hyps(3)

    have "∃s. term_variants_pred (F P gs) (T ! i) s ∧ term_variants_pred (F Q fs) s (S ! i)"
      when i: "i < length T" for i
      using IH[OF i hyps(4,5,6)] unfolding F_def by presburger
    then obtain U where U:
        "length T = length U" "⋀i. i < length T ⟹ term_variants_pred (F P gs) (T ! i) (U ! i)"
        "length U = length S" "⋀i. i < length U ⟹ term_variants_pred (F Q fs) (U ! i) (S ! i)"
      using hyps(1) Skolem_list_nth[of _ "λi s. term_variants_pred (F P gs) (T ! i) s ∧
                                                term_variants_pred (F Q fs) s (S ! i)"]
      by (metis (no_types))

    show ?case
      using term_variants_pred.term_variants_P[OF U(1,2), of g h]
            term_variants_pred.term_variants_P[OF U(3,4), of h' g]
            hyps(3)[unfolded hyps(6)] hyps(4,5)
      unfolding F_def by force
  next
    case (term_variants_Fun T S h' g gs)
    note hyps = term_variants_Fun.hyps(1,2,4,5,6)
    note IH = term_variants_Fun.hyps(3)

    have "∃s. term_variants_pred (F P gs) (T ! i) s ∧ term_variants_pred (F Q fs) s (S ! i)"
      when i: "i < length T" for i
      using IH[OF i hyps(3,4,5)] unfolding F_def by presburger
    then obtain U where U:
        "length T = length U" "⋀i. i < length T ⟹ term_variants_pred (F P gs) (T ! i) (U ! i)"
        "length U = length S" "⋀i. i < length U ⟹ term_variants_pred (F Q fs) (U ! i) (S ! i)"
      using hyps(1) Skolem_list_nth[of _ "λi s. term_variants_pred (F P gs) (T ! i) s ∧
                                                term_variants_pred (F Q fs) s (S ! i)"]
      by (metis (no_types))

    thus ?case
      using term_variants_pred.term_variants_Fun[OF U(1,2)]
            term_variants_pred.term_variants_Fun[OF U(3,4)]
      unfolding F_def by meson
  qed
qed

(* Single-symbol instance of term_variants_pred_dense. A step replacing a by b splits into
   a step replacing a by c, followed by a step replacing c by b. *)
lemma term_variants_pred_dense':
  assumes ut: "term_variants_pred ((λ_. [])(a := [b])) u t"
  shows "∃s. term_variants_pred ((λ_. [])(a := [c])) u s ∧
             term_variants_pred ((λ_. [])(c := [b])) s t"
using ut term_variants_pred_dense[of "{a}" "[b]" u t c "{c}" "[c]"]
unfolding fun_upd_def by simp

(* If no function symbol occurring in t has any replacement, then t is related only to
   itself. *)
lemma term_variants_pred_eq_case:
  fixes t s::"('a,'b) term"
  assumes "term_variants_pred P t s" "∀f ∈ funs_term t. P f = []"
  shows "t = s"
using assms
proof (induction t s rule: term_variants_pred.induct)
  case (term_variants_Fun T S f) thus ?case
    using subtermeq_imp_funs_term_subset[OF Fun_param_in_subterms[OF nth_mem], of _ T f]
          nth_equalityI[of T S]
    by blast
qed (simp_all add: term_variants_pred_refl)

(* term_variants_pred is closed under applying one substitution δ to both sides. *)
lemma term_variants_pred_subst:
  assumes "term_variants_pred P t s"
  shows "term_variants_pred P (t ⋅ δ) (s ⋅ δ)"
using assms
proof (induction t s rule: term_variants_pred.induct)
  case (term_variants_P T S f g)
  have 1: "length (map (λt. t ⋅ δ) T) = length (map (λt. t ⋅ δ) S)"
    using term_variants_P.hyps
    by simp

  have 2: "term_variants_pred P ((map (λt. t ⋅ δ) T) ! i) ((map (λt. t ⋅ δ) S) ! i)"
    when "i < length (map (λt. t ⋅ δ) T)" for i
    using term_variants_P that
    by fastforce

  show ?case
    using term_variants_pred.term_variants_P[OF 1 2 term_variants_P.hyps(3)]
    by fastforce
next
  case (term_variants_Fun T S f)
  have 1: "length (map (λt. t ⋅ δ) T) = length (map (λt. t ⋅ δ) S)"
    using term_variants_Fun.hyps
    by simp

  have 2: "term_variants_pred P ((map (λt. t ⋅ δ) T) ! i) ((map (λt. t ⋅ δ) S) ! i)"
    when "i < length (map (λt. t ⋅ δ) T)" for i
    using term_variants_Fun that
    by fastforce

  show ?case
    using term_variants_pred.term_variants_Fun[OF 1 2]
    by fastforce
qed (simp add: term_variants_pred_refl)

(* Lifting a variant of an instance back to the uninstantiated term. Suppose t ⋅ δ is
   related to s, and δ maps every free variable of t and of s either to a variable or to
   a constant without replacements. Then s is the δ-instance u ⋅ δ of some variant u
   of t. *)
lemma term_variants_pred_subst':
  fixes t s::"('a,'b) term" and δ::"('a,'b) subst"
  assumes "term_variants_pred P (t ⋅ δ) s"
    and "∀x ∈ fv t ∪ fv s. (∃y. δ x = Var y) ∨ (∃f. δ x = Fun f [] ∧ P f = [])"
  shows "∃u. term_variants_pred P t u ∧ s = u ⋅ δ"
using assms
proof (induction "t ⋅ δ" s arbitrary: t rule: term_variants_pred.induct)
  case (term_variants_Var x g) thus ?case using term_variants_pred_refl by fast
next
  case (term_variants_P T S g f) show ?case
  proof (cases t)
    case (Var x) thus ?thesis
      using term_variants_P.hyps(4,5) term_variants_P.prems
      by fastforce
  next
    case (Fun h U)
    hence 1: "h = f" "T = map (λs. s ⋅ δ) U" "length U = length T"
      using term_variants_P.hyps(5) by simp_all
    hence 2: "T ! i = U ! i ⋅ δ" when "i < length T" for i
      using that by simp

    have "∀x ∈ fv (U ! i) ∪ fv (S ! i). (∃y. δ x = Var y) ∨ (∃f. δ x = Fun f [] ∧ P f = [])"
      when "i < length U" for i
      using that Fun term_variants_P.prems term_variants_P.hyps(1) 1(3)
      by force
    hence IH: "∀i < length U. ∃u. term_variants_pred P (U ! i) u ∧ S ! i = u ⋅ δ"
      by (metis 1(3) term_variants_P.hyps(3)[OF _ 2])

    (* collect the pointwise witnesses u into a single argument list V *)
    have "∃V. length U = length V ∧ S = map (λv. v ⋅ δ) V ∧
               (∀i < length U. term_variants_pred P (U ! i) (V ! i))"
      using term_variants_P.hyps(1) 1(3) subst_term_list_obtain[OF IH] by metis
    then obtain V where V: "length U = length V" "S = map (λv. v ⋅ δ) V"
                           "⋀i. i < length U ⟹ term_variants_pred P (U ! i) (V ! i)"
      by blast

    have "term_variants_pred P (Fun f U) (Fun g V)"
      by (metis term_variants_pred.term_variants_P[OF V(1,3) term_variants_P.hyps(4)])
    moreover have "Fun g S = Fun g V ⋅ δ" using V(2) by simp
    ultimately show ?thesis using term_variants_P.hyps(1,4) Fun 1 by blast
  qed
next
  case (term_variants_Fun T S f t) show ?case
  proof (cases t)
    case (Var x)
    hence "T = []" "P f = []" using term_variants_Fun.hyps(4) term_variants_Fun.prems by fastforce+
    thus ?thesis using term_variants_pred_refl Var term_variants_Fun.hyps(1,4) by fastforce
  next
    case (Fun h U)
    hence 1: "h = f" "T = map (λs. s ⋅ δ) U" "length U = length T"
      using term_variants_Fun.hyps(4) by simp_all
    hence 2: "T ! i = U ! i ⋅ δ" when "i < length T" for i
      using that by simp

    have "∀x ∈ fv (U ! i) ∪ fv (S ! i). (∃y. δ x = Var y) ∨ (∃f. δ x = Fun f [] ∧ P f = [])"
      when "i < length U" for i
      using that Fun term_variants_Fun.prems term_variants_Fun.hyps(1) 1(3)
      by force
    hence IH: "∀i < length U. ∃u. term_variants_pred P (U ! i) u ∧ S ! i = u ⋅ δ"
      by (metis 1(3) term_variants_Fun.hyps(3)[OF _ 2 ])

    have "∃V. length U = length V ∧ S = map (λv. v ⋅ δ) V ∧
               (∀i < length U. term_variants_pred P (U ! i) (V ! i))"
      using term_variants_Fun.hyps(1) 1(3) subst_term_list_obtain[OF IH] by metis
    then obtain V where V: "length U = length V" "S = map (λv. v ⋅ δ) V"
                           "⋀i. i < length U ⟹ term_variants_pred P (U ! i) (V ! i)"
      by blast

    have "term_variants_pred P (Fun f U) (Fun f V)"
      by (metis term_variants_pred.term_variants_Fun[OF V(1,3)])
    moreover have "Fun f S = Fun f V ⋅ δ" using V(2) by simp
    ultimately show ?thesis using term_variants_Fun.hyps(1) Fun 1 by blast
  qed
qed

(* Congruence for substitutions. If δ x is related to θ x for every free variable x of t,
   then t ⋅ δ is related to t ⋅ θ. *)
lemma term_variants_pred_subst'':
  assumes "∀x ∈ fv t. term_variants_pred P (δ x) (θ x)"
  shows "term_variants_pred P (t ⋅ δ) (t ⋅ θ)"
using assms
proof (induction t)
  case (Fun f ts) thus ?case
    using term_variants_Fun[of "map (λt. t ⋅ δ) ts" "map (λt. t ⋅ θ) ts" P f] by force
qed simp

(* The inductive predicate term_variants_pred and the enumeration function term_variants
   describe the same relation: s is a variant of t exactly when s occurs in the list
   term_variants P t. *)
lemma term_variants_pred_iff_in_term_variants:
  fixes t::"('a,'b) term"
  shows "term_variants_pred P t s ⟷ s ∈ set (term_variants P t)"
    (is "?A t s ⟷ ?B t s")
proof
  (* U P T is the list of argument lists that term_variants builds for arguments T *)
  define U where "U ≡ λP (T::('a,'b) term list). product_lists (map (term_variants P) T)"

  have a:
      "g ∈ set (P f) ⟹ set (map (Fun g) (U P T)) ⊆ set (term_variants P (Fun f T))"
      "set (map (Fun f) (U P T)) ⊆ set (term_variants P (Fun f T))"
    for f P g and T::"('a,'b) term list"
    using term_variants.simps(2)[of P f T]
    unfolding U_def Let_def by auto

  have b: "∃S ∈ set (U P T). s = Fun f S ∨ (∃g ∈ set (P f). s = Fun g S)"
    when "s ∈ set (term_variants P (Fun f T))" for P T f s
    using that by (cases "P f") (auto simp add: U_def Let_def)

  have c: "length T = length S" when "S ∈ set (U P T)" for S P T
    using that unfolding U_def
    by (simp add: in_set_product_lists_length)

  show "?A t s ⟹ ?B t s"
  proof (induction t s rule: term_variants_pred.induct)
    case (term_variants_P T S g f)
    note hyps = term_variants_P.hyps
    note IH = term_variants_P.IH

    have "S ∈ set (U P T)"
      using IH hyps(1) product_lists_in_set_nth'[of _ S]
      unfolding U_def by simp
    thus ?case using a(1)[of _ P, OF hyps(3)] by auto
  next
    case (term_variants_Fun T S f)
    note hyps = term_variants_Fun.hyps
    note IH = term_variants_Fun.IH

    have "S ∈ set (U P T)"
      using IH hyps(1) product_lists_in_set_nth'[of _ S]
      unfolding U_def by simp
    thus ?case using a(2)[of f P T] by (cases "P f") auto
  qed (simp add: term_variants_Var)

  show "?B t s ⟹ ?A t s"
  proof (induction P t arbitrary: s rule: term_variants.induct)
    case (2 P f T)
    obtain S where S:
        "s = Fun f S ∨ (∃g ∈ set (P f). s = Fun g S)"
        "S ∈ set (U P T)" "length T = length S"
      using c b[OF "2.prems"] by blast

    have "∀i < length T. term_variants_pred P (T ! i) (S ! i)"
      using "2.IH" S product_lists_in_set_nth by (fastforce simp add: U_def)
    thus ?case using S by (auto intro: term_variants_pred.intros)
  qed (simp add: term_variants_Var)
qed

(* Every term has only finitely many variants. By term_variants_pred_iff_in_term_variants,
   all of them occur in the list term_variants P t. *)
lemma term_variants_pred_finite:
  "finite {s. term_variants_pred P t s}"
using term_variants_pred_iff_in_term_variants[of P t]
by simp

(* Related terms have the same free variables. The rules only change function symbols;
   variables are related only to themselves. *)
lemma term_variants_pred_fv_eq:
  assumes "term_variants_pred P s t"
  shows "fv s = fv t"
using assms
by (induct rule: term_variants_pred.induct)
   (metis, metis fv_eq_FunI, metis fv_eq_FunI)

(* In the intruder_model locale: if every replacement g listed in P f has the same arity
   as f, then replacement preserves well-formedness (wf⇩t⇩r⇩m) of terms. *)
lemma (in intruder_model) term_variants_pred_wf_trms:
  assumes "term_variants_pred P s t"
    and "⋀f g. g ∈ set (P f) ⟹ arity f = arity g"
    and "wf⇩t⇩r⇩m s"
  shows "wf⇩t⇩r⇩m t"
using assms
apply (induction rule: term_variants_pred.induct, simp)
by (metis (no_types) wf_trmI wf_trm_arity in_set_conv_nth wf_trm_param_idx)+

(* Every function symbol f of the right-hand side t either already occurs in s, or is a
   replacement of some symbol g of s, i.e. f is listed in P g. *)
lemma term_variants_pred_funs_term:
  assumes "term_variants_pred P s t"
    and "f ∈ funs_term t"
  shows "f ∈ funs_term s ∨ (∃g ∈ funs_term s. f ∈ set (P g))"
  using assms
proof (induction rule: term_variants_pred.induct)
  case (term_variants_P T S g h) thus ?case
  proof (cases "f = g")
    case False
    (* f is not the new head symbol, so it occurs in one of the arguments S *)
    then obtain s where "s ∈ set S" "f ∈ funs_term s"
      using funs_term_subterms_eq(1)[of "Fun g S"] term_variants_P.prems by auto
    thus ?thesis
      using term_variants_P.IH term_variants_P.hyps(1) in_set_conv_nth[of s S] by force
  qed simp
next
  case (term_variants_Fun T S h) thus ?case
  proof (cases "f = h")
    case False
    then obtain s where "s ∈ set S" "f ∈ funs_term s"
      using funs_term_subterms_eq(1)[of "Fun h S"] term_variants_Fun.prems by auto
    thus ?thesis
      using term_variants_Fun.IH term_variants_Fun.hyps(1) in_set_conv_nth[of s S] by force
  qed simp
qed fast

end