(* Theory Refine_Imperative_HOL.IICF_Matrix *)

section ‹Matrices›
theory IICF_Matrix
imports "../../Sepref"
begin
  subsection ‹Relator and Interface›
  (* Matrices are modelled as HOL functions from (row, column) index pairs to
     entries. The relator lifts an element relation A to such functions,
     relating the index pairs by nat_rel ×⇩r nat_rel. *)
  definition [to_relAPP]: "mtx_rel A ≡ nat_rel ×⇩r nat_rel → A"

  (* Lifting the identity relation yields the identity relation. *)
  lemma mtx_rel_id[simp]: "⟨Id⟩mtx_rel = Id" unfolding mtx_rel_def by auto
  
  (* Abstract matrix type and the corresponding Sepref interface type. *)
  type_synonym 'a mtx = "nat×nat ⇒ 'a"
  sepref_decl_intf 'a i_mtx is "nat×nat ⇒ 'a"

  (* Interface synthesis: if A has interface 'a, then ⟨A⟩mtx_rel has
     interface 'a i_mtx. *)
  lemma [synth_rules]: "INTF_OF_REL A TYPE('a) ⟹ INTF_OF_REL (⟨A⟩mtx_rel) TYPE('a i_mtx)"
    by simp
  
  subsection ‹Operations›  

  (* Matrix creation: op_mtx_new is the identity on 'a mtx. It is declared as
     its own Sepref operation whose argument is related by the plain function
     relation (nat_rel×⇩rnat_rel → A) and whose result is related by
     ⟨A⟩mtx_rel. *)
  definition op_mtx_new :: "'a mtx ⇒ 'a mtx" where [simp]: "op_mtx_new c ≡ c"

  sepref_decl_op (no_def) mtx_new: "op_mtx_new" :: "(nat_rel×⇩rnat_rel → A) → ⟨A⟩mtx_rel"
    apply (rule fref_ncI) unfolding op_mtx_new_def[abs_def] mtx_rel_def 
    by parametricity

  (* TODO: Ad-hoc rule *)
  (* Frame-matching rule: a pure hn_val assertion over the function relation
     (nat_rel×⇩rnat_rel → A) entails the same assertion with A replaced by
     the_pure (pure A). Presumably needed when synthesising code for the
     argument of op_mtx_new -- TODO confirm. *)
  lemma mtx_init_adhoc_frame_match_rule[sepref_frame_match_rules]:
    "hn_val (nat_rel×⇩rnat_rel → A) x y ⟹⇩t hn_val (nat_rel×⇩rnat_rel → the_pure (pure A)) x y"
    by simp

  (* Matrix copy: abstractly the identity on 'a mtx. It is registered as a
     separate operation, presumably so that implementations can supply a
     physical copy. *)
  definition op_mtx_copy :: "'a mtx ⇒ 'a mtx" where [simp]: "op_mtx_copy c ≡ c"

  sepref_decl_op (no_def) mtx_copy: "op_mtx_copy" :: "⟨A⟩mtx_rel → ⟨A⟩mtx_rel" .

  (* Entry lookup: op_mtx_get c ij is the matrix function c applied to the
     index pair ij. *)
  sepref_decl_op mtx_get: "λ(c::'a mtx) ij. c ij" :: "⟨A⟩mtx_rel → (nat_rel×⇩rnat_rel) → A"
    apply (rule fref_ncI) unfolding mtx_rel_def
    by parametricity
    
  (* Entry update: op_mtx_set is fun_upd on the matrix function. *)
  sepref_decl_op mtx_set: "fun_upd::'a mtx ⇒ _" :: "⟨A⟩mtx_rel → (nat_rel×⇩rnat_rel) → A → ⟨A⟩mtx_rel"
    apply (rule fref_ncI) 
    unfolding mtx_rel_def
  proof goal_cases case 1  
    (* Parametricity of equality on index pairs; presumably required because
       fun_upd compares indices. *)
    have [param]: "((=), (=)) ∈ nat_rel ×⇩r nat_rel → nat_rel ×⇩r nat_rel → bool_rel" by simp
    show ?case by parametricity
  qed

  (* Support of a matrix: the set of index pairs whose entry is non-zero. *)
  definition mtx_nonzero :: "_ mtx ⇒ (nat×nat) set" where "mtx_nonzero m ≡ {(i,j). m (i,j)≠0}"

  (* mtx_nonzero as a Sepref operation. It is only declared for identity
     element relations (IS_ID A), since the definition compares entries
     with 0. *)
  sepref_decl_op mtx_nonzero: "mtx_nonzero" :: "⟨A⟩mtx_rel → ⟨nat_rel×⇩rnat_rel⟩set_rel"
    where "IS_ID (A::(_×(_::zero)) set)"
  proof goal_cases
    case 1
    assume "IS_ID A"
    hence U: "A=Id" by (simp only: IS_ID_def)
    (* With A = Id, equality on entries is trivially parametric. *)
    have [param]: "((=),(=))∈A→A→bool_rel" using U by simp
    show ?case
      apply (rule fref_ncI)
      unfolding mtx_rel_def
      apply parametricity
      unfolding U by simp_all
  qed

  subsection ‹Patterns›
  (* Rewrite patterns: function application and fun_upd on matrices are
     recognised as op_mtx_get and op_mtx_set, respectively. *)
  lemma pat_amtx_get: "c$e≡op_mtx_get$'c$'e" by simp
  lemma pat_amtx_set: "fun_upd$c$e$v≡op_mtx_set$'c$'e$'v" by simp

  lemmas amtx_pats[pat_rules] = pat_amtx_get pat_amtx_set

  subsection ‹Pointwise Operations›
  subsubsection ‹Auxiliary Definitions and Lemmas›
  (* View f p as an update of state s at position p, and q s p as the value
     of s at p.
     upd_indep1: an update at p does not change the value at any other
                 position p'.
     upd_indep2: a preceding update at p' ≠ p does not change the value that
                 an update at p produces at p. *)
  locale pointwise_op =
    fixes f :: "'p ⇒ 's ⇒ 's"
    fixes q :: "'s ⇒ 'p ⇒ 'a"
    assumes upd_indep1[simp, intro]: "p≠p' ⟹ q (f p s) p' = q s p'"
    assumes upd_indep2[simp, intro]: "p≠p' ⟹ q (f p (f p' s)) p = q (f p s) p"
  begin
    (* Folding updates over distinct positions: the value at p is the result
       of the single update at p if p occurs in ps, and unchanged otherwise. *)
    lemma pointwise_upd_fold: "distinct ps ⟹ 
      q (fold f ps s) p = (if p∈set ps then q (f p s) p else q s p)"
      by (induction ps arbitrary: s) auto
  
  end
  
  (* Instance of pointwise_upd_fold for function-valued states, where the
     value at a position is plain function application. *)
  lemma pointwise_fun_fold: 
    fixes f :: "'a ⇒ ('a ⇒ 'b) ⇒ ('a ⇒ 'b)"
    fixes s :: "'a ⇒ 'b"
    assumes indep1: "⋀x x' s. x ≠ x' ⟹ f x s x' = s x'"
    assumes indep2:  "⋀x x' s. x ≠ x' ⟹ f x (f x' s) x = f x s x"
    assumes [simp]: "distinct xs"
    shows "fold f xs s x = (if x ∈ set xs then f x s x else s x)"
  proof -
    (* The locale assumptions are exactly indep1 and indep2. *)
    interpret pointwise_op f "λs. s"
      by unfold_locales fact+
  
    show ?thesis  
      using pointwise_upd_fold[of xs s x]
      by auto
  qed

  (* Row-major enumeration: the index pairs of an M×N range, in List.product
     order, are the linear indices [0..<N*M] mapped through (div N, mod N). *)
  lemma list_prod_divmod_eq: "List.product [0..<M] [0..<N] = map (λi. (i div N, i mod N)) [0..<N*M]"
  proof -
    have [simp]: "i < m*n ⟹ (i::nat) div n < m" for i m n
      by (metis mult.commute div_eq_0_iff div_mult2_eq gr_implies_not_zero mult_not_zero)

    (* A valid linear index exists only if both dimensions are positive. *)
    have [simp]: "i<N*M ⟹ N>0 ∧ M>0" for i
      by (cases N; cases M; auto)

    show ?thesis  
      by (rule nth_equalityI) (auto simp add: product_nth algebra_simps)
  qed    


  (* An nfoldli over the index pairs List.product [0..<N] [0..<M] equals an
     nfoldli over the linear range [0..<N*M], where each linear index i is
     decoded as the pair (i div M, i mod M). *)
  lemma nfoldli_prod_divmod_conv: 
    "nfoldli (List.product [0..<N] [0..<M]) ctd (λ(i,j). f i j) = nfoldli [0..<N*M] ctd (λi. f (i div M) (i mod M))"
    apply (intro ext)
    (* Replace the product list by the mapped linear range. *)
    apply (subst list_prod_divmod_eq)
    apply (simp add: nfoldli_map)
    apply (fo_rule cong)+
    apply (auto simp: algebra_simps)
    done

  (* Two nested nfoldli loops (outer over [0..<M], inner over [0..<N]) equal a
     single nfoldli over [0..<N*M], with linear index i decoded as
     (i div N, i mod N). *)
  lemma nfoldli_prod_divmod_conv': 
    "nfoldli [0..<M] ctd (λi. nfoldli [0..<N] ctd (f i)) = nfoldli [0..<N*M] ctd (λi. f (i div N) (i mod N))"
    apply (intro ext)
    (* Presumably turns the nested loops into a loop over List.product,
       which nfoldli_prod_divmod_conv then linearises -- TODO confirm. *)
    apply (subst nfoldli_nfoldli_prod_conv)
    by (simp add: nfoldli_prod_divmod_conv algebra_simps)

  (* foldli counterpart of nfoldli_prod_divmod_conv': nested foldli loops over
     [0..<M] and [0..<N] equal one foldli over [0..<N*M]. *)
  lemma foldli_prod_divmod_conv': 
    "foldli [0..<M] ctd (λi. foldli [0..<N] ctd (f i)) = foldli [0..<N*M] ctd (λi. f (i div N) (i mod N))"
    (is "?lhs=?rhs")  
  proof -
    (* Prove the equation under RETURN by converting foldli to nfoldli and
       reusing the nfoldli result. *)
    have "RETURN (?lhs s) = RETURN (?rhs s)" for s 
      apply (subst foldli_eq_nfoldli)+
      apply (subst nfoldli_prod_divmod_conv')
      ..
    thus ?thesis by auto
  qed  
    
  (* fold counterpart: nested folds over [0..<M] and [0..<N] equal one fold
     over [0..<N*M]. Obtained from foldli_prod_divmod_conv' with the
     continuation condition λ_. True. *)
  lemma fold_prod_divmod_conv': "fold (λi. fold (f i) [0..<N]) [0..<M] = fold (λi. f (i div N) (i mod N)) [0..<N*M]"
    using foldli_prod_divmod_conv'[of M "λ_. True" N f, THEN fun_cong]
    apply (intro ext)
    apply (simp add: foldli_foldl foldl_conv_fold)
    done
    


  (* Case split on an entry: either its index pair is in the support of m,
     or the entry is zero. *)
  lemma mtx_nonzero_cases[consumes 0, case_names nonzero zero]:
    obtains "(i,j)∈mtx_nonzero m" | "m (i,j) = 0"
    by (auto simp: mtx_nonzero_def)
  


  subsubsection ‹Unary Pointwise›
  (* Apply f to every entry. f also receives the entry's index pair. *)
  definition mtx_pointwise_unop :: "(nat×nat ⇒ 'a ⇒ 'a) ⇒ 'a mtx ⇒ 'a mtx" where
    "mtx_pointwise_unop f m ≡ λ(i,j). f (i,j) (m(i,j))"
  
  (* Register mtx_pointwise_unop f, for a fixed f, as a Sepref operation on
     'a i_mtx, and add the matching definition pattern. *)
  context fixes f :: "nat×nat ⇒ 'a ⇒ 'a" begin
    sepref_register "PR_CONST (mtx_pointwise_unop f)" :: "'a i_mtx ⇒ 'a i_mtx"
    lemma [def_pat_rules]: "mtx_pointwise_unop$f ≡ UNPROTECT (mtx_pointwise_unop f)" by simp
  end
  
  (* Fold-based characterisation of mtx_pointwise_unop for matrices whose
     support lies within {0..<N}×{0..<M}. Outside that range the entries are 0,
     and pres_zero requires f to map them to 0. *)
  locale mtx_pointwise_unop_loc =
    fixes N :: nat and M :: nat
    fixes f :: "(nat×nat) ⇒ 'a::{zero} ⇒ 'a"
    assumes pres_zero[simp]: "⟦ i≥N ∨ j≥M ⟧ ⟹ f (i,j) 0 = 0"
  begin  
    (* Apply f to each entry (i,j) with i<N and j<M, one at a time. The outer
       loop runs over rows i, the inner loop over columns j. *)
    definition "opr_fold_impl ≡ fold (λi. fold (λj m. m( (i,j) := f (i,j) (m(i,j)) )) [0..<M]) [0..<N]"
    
    lemma opr_fold_impl_eq:
      assumes "mtx_nonzero m ⊆ {0..<N}×{0..<M}"
      shows "mtx_pointwise_unop f m = opr_fold_impl m"
      apply (rule ext)
      unfolding opr_fold_impl_def
      apply (simp add: fold_fold_prod_conv)
      apply (subst pointwise_fun_fold)
      apply (auto simp: mtx_pointwise_unop_def distinct_product) [3]
      apply clarsimp
      (* Split on whether (a,b) is in the support of m or m (a,b) = 0. *)
      subgoal for a b
        apply (cases a b m rule: mtx_nonzero_cases)
        using assms
        apply (auto simp: mtx_pointwise_unop_def)
        done
      done  
  
    (* Refinement: opr_fold_impl implements mtx_pointwise_unop f under the
       support precondition. *)
    lemma opr_fold_impl_refine: "(opr_fold_impl, mtx_pointwise_unop f) ∈ [λm. mtx_nonzero m ⊆ {0..<N}×{0..<M}]⇩f Id → Id"  
      apply (rule frefI)
      using opr_fold_impl_eq
      by auto
  
  end
  
  (* Generic imperative implementation of mtx_pointwise_unop. It is
     parameterised by a matrix assertion assn, an element assertion A, entry
     access implementations get_impl/set_impl, and an implementation fi of f. *)
  locale mtx_pointwise_unop_gen_impl = mtx_pointwise_unop_loc +
    fixes assn :: "'a mtx ⇒ 'i ⇒ assn"
    fixes A :: "'a ⇒ 'ai ⇒ assn"
    fixes get_impl :: "'i ⇒ nat×nat ⇒ 'ai Heap"
    fixes set_impl :: "'i ⇒ nat×nat ⇒ 'ai ⇒ 'i Heap"
    fixes fi :: "nat×nat ⇒ 'ai ⇒ 'ai Heap"
    (* Every matrix in the domain of assn has its support inside N×M. *)
    assumes assn_range: "rdomp assn m ⟹ mtx_nonzero m ⊆ {0..<N}×{0..<M}"
    (* Indices are asserted via nbn_assn N / nbn_assn M, presumably
       restricting get/set to in-range indices. *)
    assumes get_impl_hnr: 
      "(uncurry get_impl,uncurry (RETURN oo op_mtx_get)) ∈ assn⇧k *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k →⇩a A"
    assumes set_impl_hnr: 
      "(uncurry2 set_impl,uncurry2 (RETURN ooo op_mtx_set)) ∈ assn⇧d *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k *⇩a A⇧k →⇩a assn"
    assumes fi_hnr:
      "(uncurry fi,uncurry (RETURN oo f)) ∈ (prod_assn nat_assn nat_assn)⇧k *⇩a A⇧k →⇩a A"  
  begin
  
    lemma this_loc: "mtx_pointwise_unop_gen_impl N M f assn A get_impl set_impl fi"
      by unfold_locales
  
    (* Synthesise an imperative version of opr_fold_impl from the rules
       assumed above. *)
    context 
      notes [[sepref_register_adhoc f N M]]
      notes [intf_of_assn] = intf_of_assnI[where R=assn and 'a="'a i_mtx"]
      notes [sepref_import_param] = IdI[of N] IdI[of M]
      notes [sepref_fr_rules] = get_impl_hnr set_impl_hnr fi_hnr
    begin
      sepref_thm opr_fold_impl1 is "RETURN o opr_fold_impl" :: "assn⇧d →⇩a assn"
        unfolding opr_fold_impl_def
        supply [[goals_limit = 1]]
        by sepref
    end    
  
    (* NOTE: the constant is spelled "unnop" (sic), unlike the binop/cmpop
       siblings. It is kept as is because it is an externally visible name. *)
    concrete_definition (in -) mtx_pointwise_unnop_fold_impl1 uses mtx_pointwise_unop_gen_impl.opr_fold_impl1.refine_raw
    prepare_code_thms (in -) mtx_pointwise_unnop_fold_impl1_def
  
    (* Final rule: compose the synthesised code with opr_fold_impl_refine.
       The support precondition is discharged by assn_range. *)
    lemma op_hnr[sepref_fr_rules]: "(mtx_pointwise_unnop_fold_impl1 N M get_impl set_impl fi, RETURN ∘ PR_CONST (mtx_pointwise_unop f)) ∈ assn⇧d →⇩a assn"
      unfolding PR_CONST_def
      apply (rule hfref_weaken_pre'[OF _ mtx_pointwise_unnop_fold_impl1.refine[OF this_loc,FCOMP opr_fold_impl_refine]])
      by (simp add: assn_range)
  
  end

  subsubsection ‹Binary Pointwise›
  (* Combine two matrices entry by entry with f. *)
  definition mtx_pointwise_binop :: "('a ⇒ 'a ⇒ 'a) ⇒ 'a mtx ⇒ 'a mtx ⇒ 'a mtx" where
    "mtx_pointwise_binop f m n ≡ λ(i,j). f (m(i,j)) (n(i,j))"
  (* Register mtx_pointwise_binop f, for a fixed f, as a Sepref operation on
     'a i_mtx, and add the matching definition pattern. *)
  context fixes f :: "'a ⇒ 'a ⇒ 'a" begin
    sepref_register "PR_CONST (mtx_pointwise_binop f)" :: "'a i_mtx ⇒ 'a i_mtx ⇒ 'a i_mtx"
    lemma [def_pat_rules]: "mtx_pointwise_binop$f ≡ UNPROTECT (mtx_pointwise_binop f)" by simp
  end
  
  (* Fold-based characterisation of mtx_pointwise_binop for matrices whose
     supports lie within {0..<N}×{0..<M}. Since f 0 0 = 0, entries outside the
     range remain 0. *)
  locale mtx_pointwise_binop_loc =
    fixes N :: nat and M :: nat
    fixes f :: "'a::{zero} ⇒ 'a ⇒ 'a"
    assumes pres_zero[simp]: "f 0 0 = 0"
  begin  
    (* Starting from m, replace each in-range entry (i,j) by f applied to the
       accumulator's entry and n's entry. The accumulator is also named m
       inside the lambda. *)
    definition "opr_fold_impl m n ≡ fold (λi. fold (λj m. m( (i,j) := f (m(i,j)) (n(i,j)) )) [0..<M]) [0..<N] m"
    
    lemma opr_fold_impl_eq:
      assumes "mtx_nonzero m ⊆ {0..<N}×{0..<M}"
      assumes "mtx_nonzero n ⊆ {0..<N}×{0..<M}"
      shows "mtx_pointwise_binop f m n = opr_fold_impl m n"
      apply (rule ext)
      unfolding opr_fold_impl_def
      apply (simp add: fold_fold_prod_conv)
      apply (subst pointwise_fun_fold)
      apply (auto simp: mtx_pointwise_binop_def distinct_product) [3]
      apply clarsimp
      (* Split on whether (a,b) is in the support of m and of n. *)
      subgoal for a b
        apply (cases a b m rule: mtx_nonzero_cases; cases a b n rule: mtx_nonzero_cases)
        using assms
        apply (auto simp: mtx_pointwise_binop_def)
        done
      done  
  
    (* Refinement: uncurried opr_fold_impl implements mtx_pointwise_binop f
       when both arguments satisfy the support precondition. *)
    lemma opr_fold_impl_refine: "(uncurry opr_fold_impl, uncurry (mtx_pointwise_binop f)) ∈ [λ(m,n). mtx_nonzero m ⊆ {0..<N}×{0..<M} ∧ mtx_nonzero n ⊆ {0..<N}×{0..<M}]⇩f Id×⇩rId → Id"  
      apply (rule frefI)
      using opr_fold_impl_eq
      by auto
  
  end
  
  (* Generic imperative implementation of mtx_pointwise_binop. It is
     parameterised by a matrix assertion assn, an element assertion A, entry
     access implementations get_impl/set_impl, and an implementation fi of f. *)
  locale mtx_pointwise_binop_gen_impl = mtx_pointwise_binop_loc +
    fixes assn :: "'a mtx ⇒ 'i ⇒ assn"
    fixes A :: "'a ⇒ 'ai ⇒ assn"
    fixes get_impl :: "'i ⇒ nat×nat ⇒ 'ai Heap"
    fixes set_impl :: "'i ⇒ nat×nat ⇒ 'ai ⇒ 'i Heap"
    fixes fi :: "'ai ⇒ 'ai ⇒ 'ai Heap"
    (* Every matrix in the domain of assn has its support inside N×M. *)
    assumes assn_range: "rdomp assn m ⟹ mtx_nonzero m ⊆ {0..<N}×{0..<M}"
    assumes get_impl_hnr: 
      "(uncurry get_impl,uncurry (RETURN oo op_mtx_get)) ∈ assn⇧k *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k →⇩a A"
    assumes set_impl_hnr: 
      "(uncurry2 set_impl,uncurry2 (RETURN ooo op_mtx_set)) ∈ assn⇧d *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k *⇩a A⇧k →⇩a assn"
    assumes fi_hnr:
      "(uncurry fi,uncurry (RETURN oo f)) ∈ A⇧k *⇩a A⇧k →⇩a A"  
  begin
  
    lemma this_loc: "mtx_pointwise_binop_gen_impl N M f assn A get_impl set_impl fi"
      by unfold_locales
  
    (* Synthesise code that updates the first matrix (destroyed) using the
       second (kept). *)
    context 
      notes [[sepref_register_adhoc f N M]]
      notes [intf_of_assn] = intf_of_assnI[where R=assn and 'a="'a i_mtx"]
      notes [sepref_import_param] = IdI[of N] IdI[of M]
      notes [sepref_fr_rules] = get_impl_hnr set_impl_hnr fi_hnr
    begin
      sepref_thm opr_fold_impl1 is "uncurry (RETURN oo opr_fold_impl)" :: "assn⇧d *⇩a assn⇧k →⇩a assn"
        unfolding opr_fold_impl_def[abs_def]
        by sepref
        
    end    
  
    concrete_definition (in -) mtx_pointwise_binop_fold_impl1 
      uses mtx_pointwise_binop_gen_impl.opr_fold_impl1.refine_raw is "(uncurry ?f,_)∈_"
    prepare_code_thms (in -) mtx_pointwise_binop_fold_impl1_def
  
    (* Final rule: compose the synthesised code with opr_fold_impl_refine.
       The support preconditions for both arguments follow from assn_range. *)
    lemma op_hnr[sepref_fr_rules]: "(uncurry (mtx_pointwise_binop_fold_impl1 N M get_impl set_impl fi), uncurry (RETURN oo PR_CONST (mtx_pointwise_binop f))) ∈ assn⇧d *⇩a assn⇧k →⇩a assn"
      unfolding PR_CONST_def
      apply (rule hfref_weaken_pre'[OF _ mtx_pointwise_binop_fold_impl1.refine[OF this_loc,FCOMP opr_fold_impl_refine]])
      apply (auto dest: assn_range)
      done
  
  end



  subsubsection ‹Compare Pointwise›
  (* Pointwise comparison: holds iff f relates every pair of corresponding
     entries and g relates at least one such pair. *)
  definition mtx_pointwise_cmpop :: "('a ⇒ 'a ⇒ bool) ⇒ ('a ⇒ 'a ⇒ bool) ⇒ 'a mtx ⇒ 'a mtx ⇒ bool" where
    "mtx_pointwise_cmpop f g m n ≡ (∀i j. f (m(i,j)) (n(i,j))) ∧ (∃i j. g (m(i,j)) (n(i,j)))"
  (* Register mtx_pointwise_cmpop f g, for fixed f and g, as a Sepref
     operation, and add the matching definition pattern. *)
  context fixes f g :: "'a ⇒ 'a ⇒ bool" begin
    sepref_register "PR_CONST (mtx_pointwise_cmpop f g)" :: "'a i_mtx ⇒ 'a i_mtx ⇒ bool"
    lemma [def_pat_rules]: "mtx_pointwise_cmpop$f$g ≡ UNPROTECT (mtx_pointwise_cmpop f g)" by simp
  end

  (* TODO: Move *)  
  (* If the support of m lies within N×M, then any entry with an out-of-range
     row or column index is zero. *)
  lemma mtx_nonzeroD:
    "⟦¬i<N; mtx_nonzero m ⊆ {0..<N}×{0..<M}⟧ ⟹ m(i,j) = 0"
    "⟦¬j<M; mtx_nonzero m ⊆ {0..<N}×{0..<M}⟧ ⟹ m(i,j) = 0"
    by (auto simp: mtx_nonzero_def)


  (* Loop-based implementation of mtx_pointwise_cmpop for matrices whose
     supports lie within {0..<N}×{0..<M}. Zero entries satisfy f and violate g
     (pres_zero), so only the in-range positions need to be inspected. *)
  locale mtx_pointwise_cmpop_loc =
    fixes N :: nat and M :: nat
    fixes f g :: "'a::{zero} ⇒ 'a ⇒ bool"
    assumes pres_zero[simp]: "f 0 0 = True" "g 0 0 = False"
  begin  
    (* Scan the in-range index pairs, tracking a state s:
         0: f held at every visited pair, g at none;
         1: f held at every visited pair, g at some;
         2: f failed at some pair (the scan continues only while s≠2).
       The result is s=1. *)
    definition "opr_fold_impl m n ≡ do {
      s ← nfoldli (List.product [0..<N] [0..<M]) (λs. s≠2) (λ(i,j) s. do {
        if f (m(i,j)) (n(i,j)) then
          if s=0 then
            if g (m(i,j)) (n(i,j)) then RETURN 1 else RETURN s
          else
            RETURN s

        else RETURN 2
      }) (0::nat);
      RETURN (s=1)
    }"

    lemma opr_fold_impl_eq:
      assumes "mtx_nonzero m ⊆ {0..<N}×{0..<M}"
      assumes "mtx_nonzero n ⊆ {0..<N}×{0..<M}"
      shows "opr_fold_impl m n ≤ RETURN (mtx_pointwise_cmpop f g m n)"
    proof -
      (* Outside N×M both entries are 0, so the quantifiers of
         mtx_pointwise_cmpop can be restricted to the in-range indices. *)
      have "(∀i<N. ∀j<M. f (m (i, j)) (n (i, j))) ⟹ f (m (i, j)) (n (i, j))" for i j
        apply (cases "i<N"; cases "j<M")
        using assms by (auto simp: mtx_nonzeroD)
      moreover have "g (m (i, j)) (n (i, j)) ⟹ (∃i<N. ∃j<M. g (m (i, j)) (n (i, j)))" for i j
        apply (cases "i<N"; cases "j<M")
        using assms by (auto simp: mtx_nonzeroD)
      ultimately have EQ: "mtx_pointwise_cmpop f g m n ⟷
         (∀i<N. ∀j<M. f (m(i,j)) (n(i,j))) ∧ (∃i<N. ∃j<M. g (m(i,j)) (n(i,j)))"
        unfolding mtx_pointwise_cmpop_def by meson
        
      (* Every position visited by the loop is in range. *)
      have aux: "List.product [0..<N] [0..<M] = l1 @ (i, j) # l2 ⟹ i<N ∧ j<M" for l1 i j l2
      proof -
        assume "List.product [0..<N] [0..<M] = l1 @ (i, j) # l2"
        hence "(i,j)∈set (List.product [0..<N] [0..<M])" by simp
        thus ?thesis by simp
      qed  

      (* Loop invariant over the visited prefix l1, following the state
         encoding documented at opr_fold_impl. *)
      show ?thesis  
        unfolding opr_fold_impl_def
        apply (refine_vcg
          nfoldli_rule[where I="λl1 _ s. 
              if s=2 then ∃i<N. ∃j<M. ¬f (m(i,j)) (n(i,j)) 
              else (
                (s=0 ∨ s=1) ∧
                (∀(i,j)∈set l1. f (m(i,j)) (n(i,j))) ∧
                (s=1 ⟷ (∃(i,j)∈set l1. g (m(i,j)) (n(i,j))))
              )"]
          )
        apply (vc_solve dest: aux solve: asm_rl simp: EQ) [6]
        apply (fastforce simp: EQ)
        done
    qed    
  
    (* Refinement: uncurried opr_fold_impl implements mtx_pointwise_cmpop f g
       when both arguments satisfy the support precondition. *)
    lemma opr_fold_impl_refine: 
      "(uncurry opr_fold_impl, uncurry (RETURN oo mtx_pointwise_cmpop f g)) ∈ [λ(m,n). mtx_nonzero m ⊆ {0..<N}×{0..<M} ∧ mtx_nonzero n ⊆ {0..<N}×{0..<M}]⇩f Id×⇩rId → ⟨bool_rel⟩nres_rel"  
      apply (rule frefI)
      using opr_fold_impl_eq
      by (auto intro: nres_relI)
  
  end
  
  (* Generic imperative implementation of mtx_pointwise_cmpop. It is
     parameterised by a matrix assertion assn, an element assertion A, an entry
     lookup get_impl, and implementations fi/gi of f/g. *)
  locale mtx_pointwise_cmpop_gen_impl = mtx_pointwise_cmpop_loc +
    fixes assn :: "'a mtx ⇒ 'i ⇒ assn"
    fixes A :: "'a ⇒ 'ai ⇒ assn"
    fixes get_impl :: "'i ⇒ nat×nat ⇒ 'ai Heap"
    fixes fi :: "'ai ⇒ 'ai ⇒ bool Heap"
    fixes gi :: "'ai ⇒ 'ai ⇒ bool Heap"
    (* Every matrix in the domain of assn has its support inside N×M. *)
    assumes assn_range: "rdomp assn m ⟹ mtx_nonzero m ⊆ {0..<N}×{0..<M}"
    assumes get_impl_hnr: 
      "(uncurry get_impl,uncurry (RETURN oo op_mtx_get)) ∈ assn⇧k *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k →⇩a A"
    assumes fi_hnr:
      "(uncurry fi,uncurry (RETURN oo f)) ∈ A⇧k *⇩a A⇧k →⇩a bool_assn"  
    assumes gi_hnr:
      "(uncurry gi,uncurry (RETURN oo g)) ∈ A⇧k *⇩a A⇧k →⇩a bool_assn"  
  begin
  
    lemma this_loc: "mtx_pointwise_cmpop_gen_impl N M f g assn A get_impl fi gi"
      by unfold_locales
  
    (* Synthesise code for opr_fold_impl. The nested-loop form is restored from
       the product loop via nfoldli_nfoldli_prod_conv[symmetric].
       NOTE(review): the first matrix is declared destroyed (⇧d) although only
       get_impl (which keeps its argument) is used on it; ⇧k may be
       achievable -- TODO confirm before changing the public signature. *)
    context 
      notes [[sepref_register_adhoc f g N M]]
      notes [intf_of_assn] = intf_of_assnI[where R=assn and 'a="'a i_mtx"]
      notes [sepref_import_param] = IdI[of N] IdI[of M]
      notes [sepref_fr_rules] = get_impl_hnr fi_hnr gi_hnr
    begin
      sepref_thm opr_fold_impl1 is "uncurry opr_fold_impl" :: "assn⇧d *⇩a assn⇧k →⇩a bool_assn"
        unfolding opr_fold_impl_def[abs_def] nfoldli_nfoldli_prod_conv[symmetric]
        by sepref
        
    end    
  
    concrete_definition (in -) mtx_pointwise_cmpop_fold_impl1 
      uses mtx_pointwise_cmpop_gen_impl.opr_fold_impl1.refine_raw is "(uncurry ?f,_)∈_"
    prepare_code_thms (in -) mtx_pointwise_cmpop_fold_impl1_def
  
    (* Final rule: compose the synthesised code with opr_fold_impl_refine.
       The support preconditions follow from assn_range. *)
    lemma op_hnr[sepref_fr_rules]: "(uncurry (mtx_pointwise_cmpop_fold_impl1 N M get_impl fi gi), uncurry (RETURN oo PR_CONST (mtx_pointwise_cmpop f g))) ∈ assn⇧d *⇩a assn⇧k →⇩a bool_assn"
      unfolding PR_CONST_def
      apply (rule hfref_weaken_pre'[OF _ mtx_pointwise_cmpop_fold_impl1.refine[OF this_loc,FCOMP opr_fold_impl_refine]])
      apply (auto dest: assn_range)
      done
  
  end

end