(* Theory IICF_Matrix *)
section ‹Matrices›
theory IICF_Matrix
imports "../../Sepref"
begin
subsection ‹Relator and Interface›
(* Relator for matrices: the function relation that maps index pairs related
   by nat_rel ×⇩r nat_rel to entries related by A.  It is used in
   relator-application form, i.e. written as ⟨A⟩mtx_rel, in this theory. *)
definition [to_relAPP]: "mtx_rel A ≡ nat_rel ×⇩r nat_rel → A"
(* Instantiating the entry relation with Id collapses the matrix relator to Id.
   Declared as a simp rule. *)
lemma mtx_rel_id[simp]:
  "⟨Id⟩mtx_rel = Id"
  by (auto simp: mtx_rel_def)
(* Abstract matrices are total functions from index pairs (nat×nat) to entries. *)
type_synonym 'a mtx = "nat×nat ⇒ 'a"
(* Interface type for matrices, declared for the same function type. *)
sepref_decl_intf 'a i_mtx is "nat×nat ⇒ 'a"
(* Interface synthesis rule: if the entry relation A has interface type 'a,
   then ⟨A⟩mtx_rel has interface type 'a i_mtx. *)
lemma [synth_rules]: "INTF_OF_REL A TYPE('a) ⟹ INTF_OF_REL (⟨A⟩mtx_rel) TYPE('a i_mtx)"
by simp
subsection ‹Operations›
(* Matrix creation from a function: on the abstract level this is the identity
   (op_mtx_new c ≡ c); the defining equation is a simp rule. *)
definition op_mtx_new :: "'a mtx ⇒ 'a mtx" where [simp]: "op_mtx_new c ≡ c"
(* Declares the operation with its relational signature: the argument is a
   function related by nat_rel×⇩rnat_rel → A, the result a matrix related by
   ⟨A⟩mtx_rel.  NOTE(review): (no_def) is presumably used because the constant
   is already defined by the definition above -- confirm. *)
sepref_decl_op (no_def) mtx_new: "op_mtx_new" :: "(nat_rel×⇩rnat_rel → A) → ⟨A⟩mtx_rel"
(* After unfolding both definitions, the goal is solved by parametricity. *)
apply (rule fref_ncI) unfolding op_mtx_new_def[abs_def] mtx_rel_def
by parametricity
(* Ad-hoc frame matching rule: an hn_val over the function relation with entry
   relation A entails the same hn_val with A replaced by the_pure (pure A).
   Registered in sepref_frame_match_rules; proved by simp. *)
lemma mtx_init_adhoc_frame_match_rule[sepref_frame_match_rules]:
"hn_val (nat_rel×⇩rnat_rel → A) x y ⟹⇩t hn_val (nat_rel×⇩rnat_rel → the_pure (pure A)) x y"
by simp
(* Matrix copy: on the abstract level this is the identity (op_mtx_copy c ≡ c);
   the defining equation is a simp rule. *)
definition op_mtx_copy :: "'a mtx ⇒ 'a mtx" where [simp]: "op_mtx_copy c ≡ c"
(* Declares the operation with signature ⟨A⟩mtx_rel → ⟨A⟩mtx_rel; the proof
   obligation is closed by the immediate proof "." *)
sepref_decl_op (no_def) mtx_copy: "op_mtx_copy" :: "⟨A⟩mtx_rel → ⟨A⟩mtx_rel" .
(* Entry lookup: applies the matrix c to the index pair ij.
   Signature: matrix, index pair (nat_rel×⇩rnat_rel), result related by A. *)
sepref_decl_op mtx_get: "λ(c::'a mtx) ij. c ij" :: "⟨A⟩mtx_rel → (nat_rel×⇩rnat_rel) → A"
(* After unfolding mtx_rel_def, the goal is solved by parametricity. *)
apply (rule fref_ncI) unfolding mtx_rel_def
by parametricity
(* Entry update: fun_upd restricted to the matrix type.
   Signature: matrix, index pair, new entry (related by A), updated matrix. *)
sepref_decl_op mtx_set: "fun_upd::'a mtx ⇒ _" :: "⟨A⟩mtx_rel → (nat_rel×⇩rnat_rel) → A → ⟨A⟩mtx_rel"
apply (rule fref_ncI)
unfolding mtx_rel_def
proof goal_cases case 1
(* Local parametricity rule for equality on index pairs, made available to
   the parametricity method via the param attribute. *)
have [param]: "((=), (=)) ∈ nat_rel ×⇩r nat_rel → nat_rel ×⇩r nat_rel → bool_rel" by simp
show ?case by parametricity
qed
(* The set of index pairs (i,j) at which the matrix m has an entry other than 0. *)
definition mtx_nonzero :: "_ mtx ⇒ (nat×nat) set" where "mtx_nonzero m ≡ {(i,j). m (i,j)≠0}"
(* Declares mtx_nonzero as an operation returning a set of index pairs.
   The declaration carries the side condition IS_ID A, where the entry type
   must be of class zero. *)
sepref_decl_op mtx_nonzero: "mtx_nonzero" :: "⟨A⟩mtx_rel → ⟨nat_rel×⇩rnat_rel⟩set_rel"
where "IS_ID (A::(_×(_::zero)) set)"
proof goal_cases
case 1
assume "IS_ID A"
(* From the side condition, the entry relation is the identity. *)
hence U: "A=Id" by (simp only: IS_ID_def)
(* With A = Id, equality on entries is parametric w.r.t. A. *)
have [param]: "((=),(=))∈A→A→bool_rel" using U by simp
show ?case
apply (rule fref_ncI)
unfolding mtx_rel_def
apply parametricity
(* Remaining goals are closed after replacing A by Id. *)
unfolding U by simp_all
qed
subsection ‹Patterns›
(* Pattern rule: an application c$e is rewritten to the matrix lookup
   operation op_mtx_get. *)
lemma pat_amtx_get: "c$e≡op_mtx_get$'c$'e" by simp
(* Pattern rule: a function update fun_upd$c$e$v is rewritten to the matrix
   update operation op_mtx_set. *)
lemma pat_amtx_set: "fun_upd$c$e$v≡op_mtx_set$'c$'e$'v" by simp
(* Both rules are bundled and registered as pat_rules. *)
lemmas amtx_pats[pat_rules] = pat_amtx_get pat_amtx_set
subsection ‹Pointwise Operations›
subsubsection ‹Auxiliary Definitions and Lemmas›
(* Abstract setting for pointwise updates:
     f p s  -- update the state s at position p,
     q s p  -- query the state s at position p.
   upd_indep1: updating position p does not change the query result at a
               different position p'.
   upd_indep2: the query result at p after updating p does not depend on a
               preceding update at a different position p'. *)
locale pointwise_op =
fixes f :: "'p ⇒ 's ⇒ 's"
fixes q :: "'s ⇒ 'p ⇒ 'a"
assumes upd_indep1[simp, intro]: "p≠p' ⟹ q (f p s) p' = q s p'"
assumes upd_indep2[simp, intro]: "p≠p' ⟹ q (f p (f p' s)) p = q (f p s) p"
begin
(* Folding the update over a distinct list of positions: the query at p
   yields q (f p s) p if p is in the list, and q s p otherwise.
   Proved by induction on the list, generalizing the state s. *)
lemma pointwise_upd_fold: "distinct ps ⟹
q (fold f ps s) p = (if p∈set ps then q (f p s) p else q s p)"
by (induction ps arbitrary: s) auto
end
(* Specialisation of pointwise_op.pointwise_upd_fold to states that are
   functions and are queried by plain application: folding f over a distinct
   list xs leaves the value at x as f x s x if x occurs in xs, and as s x
   otherwise. *)
lemma pointwise_fun_fold:
  fixes f :: "'a ⇒ ('a ⇒ 'b) ⇒ ('a ⇒ 'b)"
  fixes s :: "'a ⇒ 'b"
  assumes indep1: "⋀x x' s. x ≠ x' ⟹ f x s x' = s x'"
  assumes indep2: "⋀x x' s. x ≠ x' ⟹ f x (f x' s) x = f x s x"
  assumes [simp]: "distinct xs"
  shows "fold f xs s x = (if x ∈ set xs then f x s x else s x)"
proof -
  (* The query function is the identity on states, so a query is just
     function application; the locale assumptions are indep1 and indep2. *)
  interpret pw: pointwise_op f "λt. t"
    by unfold_locales fact+
  from pw.pointwise_upd_fold[of xs s x] show ?thesis
    by auto
qed
(* The product list of the ranges [0..<M] and [0..<N] equals the list obtained
   by mapping each i in [0..<N*M] to the pair (i div N, i mod N). *)
lemma list_prod_divmod_eq: "List.product [0..<M] [0..<N] = map (λi. (i div N, i mod N)) [0..<N*M]"
proof -
(* Local simp rule: for i below m*n, the quotient i div n is below m. *)
have [simp]: "i < m*n ⟹ (i::nat) div n < m" for i m n
by (metis mult.commute div_eq_0_iff div_mult2_eq gr_implies_not_zero mult_not_zero)
(* Local simp rule: if some i is below N*M, both factors are positive. *)
have [simp]: "i<N*M ⟹ N>0 ∧ M>0" for i
by (cases N; cases M; auto)
(* The two lists are compared element-wise via nth_equalityI. *)
show ?thesis
by (rule nth_equalityI) (auto simp add: product_nth algebra_simps)
qed
(* An nfoldli over the product list of [0..<N] and [0..<M] equals an nfoldli
   over the single range [0..<N*M], where each index i is decoded into the
   pair (i div M, i mod M).  Uses list_prod_divmod_eq to rewrite the list. *)
lemma nfoldli_prod_divmod_conv:
"nfoldli (List.product [0..<N] [0..<M]) ctd (λ(i,j). f i j) = nfoldli [0..<N*M] ctd (λi. f (i div M) (i mod M))"
apply (intro ext)
apply (subst list_prod_divmod_eq)
apply (simp add: nfoldli_map)
apply (fo_rule cong)+
apply (auto simp: algebra_simps)
done
(* Two nested nfoldli loops over [0..<M] and [0..<N] equal one nfoldli over
   [0..<N*M], with i decoded as (i div N, i mod N).  The nested loops are
   first rewritten with nfoldli_nfoldli_prod_conv, then
   nfoldli_prod_divmod_conv is applied by simp. *)
lemma nfoldli_prod_divmod_conv':
"nfoldli [0..<M] ctd (λi. nfoldli [0..<N] ctd (f i)) = nfoldli [0..<N*M] ctd (λi. f (i div N) (i mod N))"
apply (intro ext)
apply (subst nfoldli_nfoldli_prod_conv)
by (simp add: nfoldli_prod_divmod_conv algebra_simps)
(* The same nested-loop flattening as nfoldli_prod_divmod_conv', stated for
   foldli instead of nfoldli. *)
lemma foldli_prod_divmod_conv':
"foldli [0..<M] ctd (λi. foldli [0..<N] ctd (f i)) = foldli [0..<N*M] ctd (λi. f (i div N) (i mod N))"
(is "?lhs=?rhs")
proof -
(* Both sides are wrapped in RETURN, rewritten to nfoldli with
   foldli_eq_nfoldli, and then related by nfoldli_prod_divmod_conv'. *)
have "RETURN (?lhs s) = RETURN (?rhs s)" for s
apply (subst foldli_eq_nfoldli)+
apply (subst nfoldli_prod_divmod_conv')
..
thus ?thesis by auto
qed
(* The nested-loop flattening for plain fold.  Obtained from
   foldli_prod_divmod_conv' instantiated with the always-true continuation
   condition (λ_. True), then simplified with foldli_foldl and
   foldl_conv_fold. *)
lemma fold_prod_divmod_conv': "fold (λi. fold (f i) [0..<N]) [0..<M] = fold (λi. f (i div N) (i mod N)) [0..<N*M]"
using foldli_prod_divmod_conv'[of M "λ_. True" N f, THEN fun_cong]
apply (intro ext)
apply (simp add: foldli_foldl foldl_conv_fold)
done
(* Case distinction on a matrix entry: either the index pair (i,j) belongs to
   the non-zero set of m (case nonzero), or the entry at (i,j) is 0
   (case zero). *)
lemma mtx_nonzero_cases[consumes 0, case_names nonzero zero]:
  obtains "(i,j)∈mtx_nonzero m" | "m (i,j) = 0"
  unfolding mtx_nonzero_def
  by auto
subsubsection ‹Unary Pointwise›
(* Pointwise unary operation: the entry at (i,j) of the result is
   f (i,j) (m(i,j)), i.e. f receives both the index pair and the old entry. *)
definition mtx_pointwise_unop :: "(nat×nat ⇒ 'a ⇒ 'a) ⇒ 'a mtx ⇒ 'a mtx" where
"mtx_pointwise_unop f m ≡ λ(i,j). f (i,j) (m(i,j))"
(* For a fixed f, the operation is registered as a PR_CONST-wrapped constant
   with interface type 'a i_mtx ⇒ 'a i_mtx, together with a def_pat_rules
   rule that rewrites mtx_pointwise_unop$f to its UNPROTECT form. *)
context fixes f :: "nat×nat ⇒ 'a ⇒ 'a" begin
sepref_register "PR_CONST (mtx_pointwise_unop f)" :: "'a i_mtx ⇒ 'a i_mtx"
lemma [def_pat_rules]: "mtx_pointwise_unop$f ≡ UNPROTECT (mtx_pointwise_unop f)" by simp
end
(* Setting for implementing a pointwise unary operation by loops over the
   index ranges [0..<N] and [0..<M].
   pres_zero: for any index with i≥N or j≥M, f maps the entry 0 to 0. *)
locale mtx_pointwise_unop_loc =
fixes N :: nat and M :: nat
fixes f :: "(nat×nat) ⇒ 'a::{zero} ⇒ 'a"
assumes pres_zero[simp]: "⟦ i≥N ∨ j≥M ⟧ ⟹ f (i,j) 0 = 0"
begin
(* Loop version: two nested folds, the outer one over i in [0..<N], the inner
   one over j in [0..<M]; each step updates the entry at (i,j) to
   f (i,j) (m(i,j)). *)
definition "opr_fold_impl ≡ fold (λi. fold (λj m. m( (i,j) := f (i,j) (m(i,j)) )) [0..<M]) [0..<N]"
(* If all non-zero entries of m lie inside {0..<N}×{0..<M}, the loop version
   coincides with the pointwise specification. *)
lemma opr_fold_impl_eq:
assumes "mtx_nonzero m ⊆ {0..<N}×{0..<M}"
shows "mtx_pointwise_unop f m = opr_fold_impl m"
apply (rule ext)
unfolding opr_fold_impl_def
(* The nested folds are turned into one fold over the product list, to which
   pointwise_fun_fold is applied. *)
apply (simp add: fold_fold_prod_conv)
apply (subst pointwise_fun_fold)
apply (auto simp: mtx_pointwise_unop_def distinct_product) [3]
apply clarsimp
subgoal for a b
(* Case split on whether the entry at (a,b) is non-zero or 0. *)
apply (cases a b m rule: mtx_nonzero_cases)
using assms
apply (auto simp: mtx_pointwise_unop_def)
done
done
(* opr_fold_impl_eq restated as a refinement with the range condition as
   precondition and Id as argument and result relation. *)
lemma opr_fold_impl_refine: "(opr_fold_impl, mtx_pointwise_unop f) ∈ [λm. mtx_nonzero m ⊆ {0..<N}×{0..<M}]⇩f Id → Id"
apply (rule frefI)
using opr_fold_impl_eq
by auto
end
(* Generic implementation of a pointwise unary operation, parameterized over a
   concrete matrix representation:
     assn      -- assertion relating abstract matrices to implementations 'i,
     A         -- assertion relating abstract entries to implementations 'ai,
     get_impl  -- implementation of op_mtx_get,
     set_impl  -- implementation of op_mtx_set,
     fi        -- implementation of the entry function f.
   assn_range states that every matrix in the domain of assn has its non-zero
   entries inside {0..<N}×{0..<M}.  The hnr assumptions for get_impl and
   set_impl take indexes through nbn_assn N and nbn_assn M. *)
locale mtx_pointwise_unop_gen_impl = mtx_pointwise_unop_loc +
fixes assn :: "'a mtx ⇒ 'i ⇒ assn"
fixes A :: "'a ⇒ 'ai ⇒ assn"
fixes get_impl :: "'i ⇒ nat×nat ⇒ 'ai Heap"
fixes set_impl :: "'i ⇒ nat×nat ⇒ 'ai ⇒ 'i Heap"
fixes fi :: "nat×nat ⇒ 'ai ⇒ 'ai Heap"
assumes assn_range: "rdomp assn m ⟹ mtx_nonzero m ⊆ {0..<N}×{0..<M}"
assumes get_impl_hnr:
"(uncurry get_impl,uncurry (RETURN oo op_mtx_get)) ∈ assn⇧k *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k →⇩a A"
assumes set_impl_hnr:
"(uncurry2 set_impl,uncurry2 (RETURN ooo op_mtx_set)) ∈ assn⇧d *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k *⇩a A⇧k →⇩a assn"
assumes fi_hnr:
"(uncurry fi,uncurry (RETURN oo f)) ∈ (prod_assn nat_assn nat_assn)⇧k *⇩a A⇧k →⇩a A"
begin
(* The locale predicate for the current parameters; used below to
   instantiate the refinement theorem of the extracted constant. *)
lemma this_loc: "mtx_pointwise_unop_gen_impl N M f assn A get_impl set_impl fi"
by unfold_locales
(* Local Sepref setup: f, N and M are registered ad hoc, assn gets the
   interface type 'a i_mtx, N and M are imported as parameters, and the three
   hnr assumptions are used as refinement rules. *)
context
notes [[sepref_register_adhoc f N M]]
notes [intf_of_assn] = intf_of_assnI[where R=assn and 'a="'a i_mtx"]
notes [sepref_import_param] = IdI[of N] IdI[of M]
notes [sepref_fr_rules] = get_impl_hnr set_impl_hnr fi_hnr
begin
(* Refinement of the loop version opr_fold_impl, with signature assn⇧d →⇩a assn,
   obtained by the sepref method after unfolding opr_fold_impl_def. *)
sepref_thm opr_fold_impl1 is "RETURN o opr_fold_impl" :: "assn⇧d →⇩a assn"
unfolding opr_fold_impl_def
supply [[goals_limit = 1]]
by sepref
end
(* Introduces the constant mtx_pointwise_unnop_fold_impl1 from the raw
   refinement theorem (the spelling "unnop" is the actual constant name) and
   prepares code theorems for its definition. *)
concrete_definition (in -) mtx_pointwise_unnop_fold_impl1 uses mtx_pointwise_unop_gen_impl.opr_fold_impl1.refine_raw
prepare_code_thms (in -) mtx_pointwise_unnop_fold_impl1_def
(* Final refinement rule for PR_CONST (mtx_pointwise_unop f): the refinement
   theorem of the extracted constant is composed (FCOMP) with
   opr_fold_impl_refine, and the resulting precondition is discharged with
   assn_range. *)
lemma op_hnr[sepref_fr_rules]: "(mtx_pointwise_unnop_fold_impl1 N M get_impl set_impl fi, RETURN ∘ PR_CONST (mtx_pointwise_unop f)) ∈ assn⇧d →⇩a assn"
unfolding PR_CONST_def
apply (rule hfref_weaken_pre'[OF _ mtx_pointwise_unnop_fold_impl1.refine[OF this_loc,FCOMP opr_fold_impl_refine]])
by (simp add: assn_range)
end
subsubsection ‹Binary Pointwise›
(* Pointwise binary operation: the entry at (i,j) of the result is f applied
   to the entries of m and n at (i,j). *)
definition mtx_pointwise_binop :: "('a ⇒ 'a ⇒ 'a) ⇒ 'a mtx ⇒ 'a mtx ⇒ 'a mtx" where
"mtx_pointwise_binop f m n ≡ λ(i,j). f (m(i,j)) (n(i,j))"
(* For a fixed f, the operation is registered as a PR_CONST-wrapped constant
   taking two matrices, together with a def_pat_rules rule that rewrites
   mtx_pointwise_binop$f to its UNPROTECT form. *)
context fixes f :: "'a ⇒ 'a ⇒ 'a" begin
sepref_register "PR_CONST (mtx_pointwise_binop f)" :: "'a i_mtx ⇒ 'a i_mtx ⇒ 'a i_mtx"
lemma [def_pat_rules]: "mtx_pointwise_binop$f ≡ UNPROTECT (mtx_pointwise_binop f)" by simp
end
(* Setting for implementing a pointwise binary operation by loops over the
   index ranges [0..<N] and [0..<M].
   pres_zero: f maps two zero entries to zero. *)
locale mtx_pointwise_binop_loc =
fixes N :: nat and M :: nat
fixes f :: "'a::{zero} ⇒ 'a ⇒ 'a"
assumes pres_zero[simp]: "f 0 0 = 0"
begin
(* Loop version: two nested folds (outer over i in [0..<N], inner over j in
   [0..<M]) starting from m; each step updates the entry at (i,j) of the
   accumulated matrix to f applied to its current entry and n's entry. *)
definition "opr_fold_impl m n ≡ fold (λi. fold (λj m. m( (i,j) := f (m(i,j)) (n(i,j)) )) [0..<M]) [0..<N] m"
(* If the non-zero entries of both m and n lie inside {0..<N}×{0..<M}, the
   loop version coincides with the pointwise specification. *)
lemma opr_fold_impl_eq:
assumes "mtx_nonzero m ⊆ {0..<N}×{0..<M}"
assumes "mtx_nonzero n ⊆ {0..<N}×{0..<M}"
shows "mtx_pointwise_binop f m n = opr_fold_impl m n"
apply (rule ext)
unfolding opr_fold_impl_def
(* The nested folds are turned into one fold over the product list, to which
   pointwise_fun_fold is applied. *)
apply (simp add: fold_fold_prod_conv)
apply (subst pointwise_fun_fold)
apply (auto simp: mtx_pointwise_binop_def distinct_product) [3]
apply clarsimp
subgoal for a b
(* Case split on the entries of both m and n at (a,b). *)
apply (cases a b m rule: mtx_nonzero_cases; cases a b n rule: mtx_nonzero_cases)
using assms
apply (auto simp: mtx_pointwise_binop_def)
done
done
(* opr_fold_impl_eq restated as a refinement in uncurried form, with the two
   range conditions as precondition. *)
lemma opr_fold_impl_refine: "(uncurry opr_fold_impl, uncurry (mtx_pointwise_binop f)) ∈ [λ(m,n). mtx_nonzero m ⊆ {0..<N}×{0..<M} ∧ mtx_nonzero n ⊆ {0..<N}×{0..<M}]⇩f Id×⇩rId → Id"
apply (rule frefI)
using opr_fold_impl_eq
by auto
end
(* Generic implementation of a pointwise binary operation, parameterized over
   a concrete matrix representation:
     assn      -- assertion relating abstract matrices to implementations 'i,
     A         -- assertion relating abstract entries to implementations 'ai,
     get_impl  -- implementation of op_mtx_get,
     set_impl  -- implementation of op_mtx_set,
     fi        -- implementation of the entry function f.
   assn_range states that every matrix in the domain of assn has its non-zero
   entries inside {0..<N}×{0..<M}. *)
locale mtx_pointwise_binop_gen_impl = mtx_pointwise_binop_loc +
fixes assn :: "'a mtx ⇒ 'i ⇒ assn"
fixes A :: "'a ⇒ 'ai ⇒ assn"
fixes get_impl :: "'i ⇒ nat×nat ⇒ 'ai Heap"
fixes set_impl :: "'i ⇒ nat×nat ⇒ 'ai ⇒ 'i Heap"
fixes fi :: "'ai ⇒ 'ai ⇒ 'ai Heap"
assumes assn_range: "rdomp assn m ⟹ mtx_nonzero m ⊆ {0..<N}×{0..<M}"
assumes get_impl_hnr:
"(uncurry get_impl,uncurry (RETURN oo op_mtx_get)) ∈ assn⇧k *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k →⇩a A"
assumes set_impl_hnr:
"(uncurry2 set_impl,uncurry2 (RETURN ooo op_mtx_set)) ∈ assn⇧d *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k *⇩a A⇧k →⇩a assn"
assumes fi_hnr:
"(uncurry fi,uncurry (RETURN oo f)) ∈ A⇧k *⇩a A⇧k →⇩a A"
begin
(* The locale predicate for the current parameters; used below to
   instantiate the refinement theorem of the extracted constant. *)
lemma this_loc: "mtx_pointwise_binop_gen_impl N M f assn A get_impl set_impl fi"
by unfold_locales
(* Local Sepref setup: f, N and M are registered ad hoc, assn gets the
   interface type 'a i_mtx, N and M are imported as parameters, and the three
   hnr assumptions are used as refinement rules. *)
context
notes [[sepref_register_adhoc f N M]]
notes [intf_of_assn] = intf_of_assnI[where R=assn and 'a="'a i_mtx"]
notes [sepref_import_param] = IdI[of N] IdI[of M]
notes [sepref_fr_rules] = get_impl_hnr set_impl_hnr fi_hnr
begin
(* Refinement of the loop version: the first matrix is passed with ⇧d, the
   second with ⇧k; obtained by the sepref method after unfolding the
   definition. *)
sepref_thm opr_fold_impl1 is "uncurry (RETURN oo opr_fold_impl)" :: "assn⇧d*⇩aassn⇧k →⇩a assn"
unfolding opr_fold_impl_def[abs_def]
by sepref
end
(* Introduces the constant mtx_pointwise_binop_fold_impl1 from the raw
   refinement theorem, matching it against the uncurried pattern, and prepares
   code theorems for its definition. *)
concrete_definition (in -) mtx_pointwise_binop_fold_impl1
uses mtx_pointwise_binop_gen_impl.opr_fold_impl1.refine_raw is "(uncurry ?f,_)∈_"
prepare_code_thms (in -) mtx_pointwise_binop_fold_impl1_def
(* Final refinement rule for PR_CONST (mtx_pointwise_binop f): the refinement
   theorem of the extracted constant is composed (FCOMP) with
   opr_fold_impl_refine, and the resulting precondition is discharged with
   assn_range. *)
lemma op_hnr[sepref_fr_rules]: "(uncurry (mtx_pointwise_binop_fold_impl1 N M get_impl set_impl fi), uncurry (RETURN oo PR_CONST (mtx_pointwise_binop f))) ∈ assn⇧d *⇩a assn⇧k →⇩a assn"
unfolding PR_CONST_def
apply (rule hfref_weaken_pre'[OF _ mtx_pointwise_binop_fold_impl1.refine[OF this_loc,FCOMP opr_fold_impl_refine]])
apply (auto dest: assn_range)
done
end
subsubsection ‹Compare Pointwise›
(* Pointwise comparison of two matrices with two entry predicates: the result
   is true iff f holds for the entries at every index pair and g holds for
   the entries at some index pair. *)
definition mtx_pointwise_cmpop :: "('a ⇒ 'a ⇒ bool) ⇒ ('a ⇒ 'a ⇒ bool) ⇒ 'a mtx ⇒ 'a mtx ⇒ bool" where
"mtx_pointwise_cmpop f g m n ≡ (∀i j. f (m(i,j)) (n(i,j))) ∧ (∃i j. g (m(i,j)) (n(i,j)))"
(* For fixed f and g, the operation is registered as a PR_CONST-wrapped
   constant taking two matrices and returning bool, together with a
   def_pat_rules rule that rewrites mtx_pointwise_cmpop$f$g to its UNPROTECT
   form. *)
context fixes f g :: "'a ⇒ 'a ⇒ bool" begin
sepref_register "PR_CONST (mtx_pointwise_cmpop f g)" :: "'a i_mtx ⇒ 'a i_mtx ⇒ bool"
lemma [def_pat_rules]: "mtx_pointwise_cmpop$f$g ≡ UNPROTECT (mtx_pointwise_cmpop f g)" by simp
end
(* Destruction rules for the range condition: if all non-zero entries of m
   lie inside {0..<N}×{0..<M}, then any entry whose row index is not below N
   (first rule) or whose column index is not below M (second rule) is 0. *)
lemma mtx_nonzeroD:
"⟦¬i<N; mtx_nonzero m ⊆ {0..<N}×{0..<M}⟧ ⟹ m(i,j) = 0"
"⟦¬j<M; mtx_nonzero m ⊆ {0..<N}×{0..<M}⟧ ⟹ m(i,j) = 0"
by (auto simp: mtx_nonzero_def)
(* Setting for implementing the pointwise comparison by a loop over the index
   ranges [0..<N] and [0..<M].
   pres_zero: on two zero entries, f is True and g is False. *)
locale mtx_pointwise_cmpop_loc =
fixes N :: nat and M :: nat
fixes f g :: "'a::{zero} ⇒ 'a ⇒ bool"
assumes pres_zero[simp]: "f 0 0 = True" "g 0 0 = False"
begin
(* Loop version: an nfoldli over the product list of [0..<N] and [0..<M] with
   a nat state s, started at 0 and continued only while s≠2.  Per index (i,j):
     - if f fails on the two entries, the state becomes 2;
     - if f holds, the state is 0 and g holds, the state becomes 1;
     - otherwise the state is unchanged.
   The result is whether the final state equals 1. *)
definition "opr_fold_impl m n ≡ do {
s ← nfoldli (List.product [0..<N] [0..<M]) (λs. s≠2) (λ(i,j) s. do {
if f (m(i,j)) (n(i,j)) then
if s=0 then
if g (m(i,j)) (n(i,j)) then RETURN 1 else RETURN s
else
RETURN s
else RETURN 2
}) (0::nat);
RETURN (s=1)
}"
(* If the non-zero entries of both m and n lie inside {0..<N}×{0..<M}, the
   loop version refines the pointwise comparison. *)
lemma opr_fold_impl_eq:
assumes "mtx_nonzero m ⊆ {0..<N}×{0..<M}"
assumes "mtx_nonzero n ⊆ {0..<N}×{0..<M}"
shows "opr_fold_impl m n ≤ RETURN (mtx_pointwise_cmpop f g m n)"
proof -
(* If f holds on all in-range indexes, it holds on every index. *)
have "(∀i<N. ∀j<M. f (m (i, j)) (n (i, j))) ⟹ f (m (i, j)) (n (i, j))" for i j
apply (cases "i<N"; cases "j<M")
using assms by (auto simp: mtx_nonzeroD)
(* If g holds on some index, it holds on some in-range index. *)
moreover have "g (m (i, j)) (n (i, j)) ⟹ (∃i<N. ∃j<M. g (m (i, j)) (n (i, j)))" for i j
apply (cases "i<N"; cases "j<M")
using assms by (auto simp: mtx_nonzeroD)
(* Hence the specification can be restated with quantifiers bounded by N
   and M. *)
ultimately have EQ: "mtx_pointwise_cmpop f g m n
⟷ (∀i<N. ∀j<M. f (m(i,j)) (n(i,j))) ∧ (∃i<N. ∃j<M. g (m(i,j)) (n(i,j)))"
unfolding mtx_pointwise_cmpop_def by meson
(* Every element of the product list is an in-range index pair. *)
have aux: "List.product [0..<N] [0..<M] = l1 @ (i, j) # l2 ⟹ i<N ∧ j<M" for l1 i j l2
proof -
assume "List.product [0..<N] [0..<M] = l1 @ (i, j) # l2"
hence "(i,j)∈set (List.product [0..<N] [0..<M])" by simp
thus ?thesis by simp
qed
(* Loop invariant over the processed prefix l1 and state s:
     - s=2: some in-range index violates f;
     - otherwise: s is 0 or 1, f holds on all indexes in l1, and s=1 iff g
       holds on some index in l1. *)
show ?thesis
unfolding opr_fold_impl_def
apply (refine_vcg
nfoldli_rule[where I="λl1 _ s.
if s=2 then ∃i<N. ∃j<M. ¬f (m(i,j)) (n(i,j))
else (
(s=0 ∨ s=1) ∧
(∀(i,j)∈set l1. f (m(i,j)) (n(i,j))) ∧
(s=1 ⟷ (∃(i,j)∈set l1. g (m(i,j)) (n(i,j))))
)"]
)
apply (vc_solve dest: aux solve: asm_rl simp: EQ) [6]
apply (fastforce simp: EQ)
done
qed
(* opr_fold_impl_eq restated as a refinement in uncurried form, with the two
   range conditions as precondition and a bool_rel result in nres_rel. *)
lemma opr_fold_impl_refine:
"(uncurry opr_fold_impl, uncurry (RETURN oo mtx_pointwise_cmpop f g)) ∈ [λ(m,n). mtx_nonzero m ⊆ {0..<N}×{0..<M} ∧ mtx_nonzero n ⊆ {0..<N}×{0..<M}]⇩f Id×⇩rId → ⟨bool_rel⟩nres_rel"
apply (rule frefI)
using opr_fold_impl_eq
by (auto intro: nres_relI)
end
(* Generic implementation of the pointwise comparison, parameterized over a
   concrete matrix representation:
     assn      -- assertion relating abstract matrices to implementations 'i,
     A         -- assertion relating abstract entries to implementations 'ai,
     get_impl  -- implementation of op_mtx_get,
     fi, gi    -- implementations of the entry predicates f and g.
   assn_range states that every matrix in the domain of assn has its non-zero
   entries inside {0..<N}×{0..<M}.  No set_impl is fixed in this locale. *)
locale mtx_pointwise_cmpop_gen_impl = mtx_pointwise_cmpop_loc +
fixes assn :: "'a mtx ⇒ 'i ⇒ assn"
fixes A :: "'a ⇒ 'ai ⇒ assn"
fixes get_impl :: "'i ⇒ nat×nat ⇒ 'ai Heap"
fixes fi :: "'ai ⇒ 'ai ⇒ bool Heap"
fixes gi :: "'ai ⇒ 'ai ⇒ bool Heap"
assumes assn_range: "rdomp assn m ⟹ mtx_nonzero m ⊆ {0..<N}×{0..<M}"
assumes get_impl_hnr:
"(uncurry get_impl,uncurry (RETURN oo op_mtx_get)) ∈ assn⇧k *⇩a (prod_assn (nbn_assn N) (nbn_assn M))⇧k →⇩a A"
assumes fi_hnr:
"(uncurry fi,uncurry (RETURN oo f)) ∈ A⇧k *⇩a A⇧k →⇩a bool_assn"
assumes gi_hnr:
"(uncurry gi,uncurry (RETURN oo g)) ∈ A⇧k *⇩a A⇧k →⇩a bool_assn"
begin
(* The locale predicate for the current parameters; used below to
   instantiate the refinement theorem of the extracted constant. *)
lemma this_loc: "mtx_pointwise_cmpop_gen_impl N M f g assn A get_impl fi gi"
by unfold_locales
(* Local Sepref setup: f, g, N and M are registered ad hoc, assn gets the
   interface type 'a i_mtx, N and M are imported as parameters, and the three
   hnr assumptions are used as refinement rules. *)
context
notes [[sepref_register_adhoc f g N M]]
notes [intf_of_assn] = intf_of_assnI[where R=assn and 'a="'a i_mtx"]
notes [sepref_import_param] = IdI[of N] IdI[of M]
notes [sepref_fr_rules] = get_impl_hnr fi_hnr gi_hnr
begin
(* Refinement of the loop version.  Before running sepref, the loop over the
   product list is rewritten into nested loops by applying
   nfoldli_nfoldli_prod_conv from right to left.  The first matrix is passed
   with ⇧d, the second with ⇧k. *)
sepref_thm opr_fold_impl1 is "uncurry opr_fold_impl" :: "assn⇧d*⇩aassn⇧k →⇩a bool_assn"
unfolding opr_fold_impl_def[abs_def] nfoldli_nfoldli_prod_conv[symmetric]
by sepref
end
(* Introduces the constant mtx_pointwise_cmpop_fold_impl1 from the raw
   refinement theorem, matching it against the uncurried pattern, and prepares
   code theorems for its definition. *)
concrete_definition (in -) mtx_pointwise_cmpop_fold_impl1
uses mtx_pointwise_cmpop_gen_impl.opr_fold_impl1.refine_raw is "(uncurry ?f,_)∈_"
prepare_code_thms (in -) mtx_pointwise_cmpop_fold_impl1_def
(* Final refinement rule for PR_CONST (mtx_pointwise_cmpop f g): the
   refinement theorem of the extracted constant is composed (FCOMP) with
   opr_fold_impl_refine, and the resulting precondition is discharged with
   assn_range. *)
lemma op_hnr[sepref_fr_rules]: "(uncurry (mtx_pointwise_cmpop_fold_impl1 N M get_impl fi gi), uncurry (RETURN oo PR_CONST (mtx_pointwise_cmpop f g))) ∈ assn⇧d *⇩a assn⇧k →⇩a bool_assn"
unfolding PR_CONST_def
apply (rule hfref_weaken_pre'[OF _ mtx_pointwise_cmpop_fold_impl1.refine[OF this_loc,FCOMP opr_fold_impl_refine]])
apply (auto dest: assn_range)
done
end
end