# Theory Hybrid_Multiv_Matrix

(* This file includes the algorithm to construct the multivariate matrix equation.

As a naming convention, we try to maintain some consistency with prior work
for the univariate case.  When we write a function with a univariate analog in our
multivariate setting, we use _M to indicate that the function is now multivariate.
*)

theory Hybrid_Multiv_Matrix
imports
(* This entry is useful for the mpoly, mpoly poly connection *)
"Factor_Algebraic_Polynomial.Poly_Connection"
Multiv_Pseudo_Remainder_Sequence
"BenOr_Kozen_Reif.More_Matrix"
"HOL-Library.Mapping"
"BenOr_Kozen_Reif.Renegar_Algorithm"

begin

section "Find CSAS to qs at zeros of p"

subsection "Towards Tarski Queries"
(* Should only be called with a degree list that is as long as sturm_seq *)
fun sminus:: "nat list ⇒ rat list ⇒ int" where
"sminus degree_list sturm_seq = changes (map (λi. (-1)^(nth degree_list i)*(nth sturm_seq i)) [0..< length degree_list]) "

definition changes_R_smods_multiv:: "rat list ⇒ nat list ⇒ int"
where "changes_R_smods_multiv signs_list degree_list ≡ (sminus degree_list signs_list) - (changes signs_list)"

definition changes_R_smods_multiv_val:: "real mpoly Polynomial.poly list ⇒ real list ⇒ int" where
"changes_R_smods_multiv_val sturm_seq val ≡ (let (eval_ss::real Polynomial.poly list) = (eval_mpoly_poly_list val sturm_seq) in (changes_poly_neg_inf eval_ss - changes_poly_pos_inf eval_ss))"

subsection "Building the Matrix Equation"

type_synonym rmpoly = "real mpoly Polynomial.poly"
type_synonym assumps = "(real mpoly × rat) list"
type_synonym matrix_equation = "(rat mat × ((nat list * nat list) list × rat list list))"

definition base_case_info_M:: "(assumps × matrix_equation) list"
where "base_case_info_M = [([], base_case_info_R)]"

definition base_case_info_M_assumps:: "assumps ⇒ (assumps × matrix_equation) list"
where "base_case_info_M_assumps init_assumps = [(init_assumps, base_case_info_R)]"

fun combine_systems_single_M:: "rmpoly ⇒ rmpoly list ⇒ (assumps × matrix_equation) ⇒ rmpoly list ⇒ (assumps × matrix_equation) ⇒ (assumps × matrix_equation)"
where "combine_systems_single_M p q1 (a1, m1) q2 (a2,m2) =
(append a1 a2, snd (combine_systems_R p (q1, m1) (q2, m2)))"

fun combine_systems_M:: "rmpoly ⇒ rmpoly list ⇒ (assumps × matrix_equation) list ⇒ rmpoly list ⇒
(assumps × matrix_equation) list => rmpoly list × ((assumps × matrix_equation) list)"
where "combine_systems_M p q1 list1 q2 list2 =
(append q1 q2, concat (map (λl1. (map (λl2. combine_systems_single_M p q1 l1 q2 l2) list2)) list1))"

(* returns list of (assumps \times sturm sequence)*)
definition construct_NofI_R_spmods:: "rmpoly ⇒ assumps ⇒ rmpoly list ⇒ rmpoly list ⇒ (assumps × (rmpoly list)) list"
where "construct_NofI_R_spmods p assumps I1 I2 = (
let new_p = sum_list (map (λx. x^2) (p # I1)) in
spmods_multiv new_p ((pderiv new_p)*(prod_list I2))) assumps"

fun construct_NofI_single_M:: "(assumps × (rmpoly list)) ⇒
(assumps × rat)"
where "construct_NofI_single_M (input_assumps, ss)  =
sa_list = map (λlc. lookup_assump lc input_assumps) lcs;
degrees_list = degrees ss in
(input_assumps, rat_of_int (changes_R_smods_multiv sa_list degrees_list)))"

fun construct_NofI_M:: "rmpoly ⇒ assumps ⇒ rmpoly list ⇒ rmpoly list => (assumps × rat) list"
where "construct_NofI_M p assumps I1 I2 =
(let ss_list = construct_NofI_R_spmods p assumps I1 I2 in
map construct_NofI_single_M ss_list)"

fun pull_out_pairs:: "rmpoly list ⇒ (nat list * nat list) list ⇒ (rmpoly list × rmpoly list) list"
where "pull_out_pairs qs Is =
map (λ(I1, I2). ((retrieve_polys qs I1), (retrieve_polys qs I2))) Is"

fun construct_rhs_vector_rec_M:: "rmpoly ⇒ assumps ⇒ (rmpoly list × rmpoly list) list ⇒ (assumps × rat list) list"
where "construct_rhs_vector_rec_M p assumps [] = [(assumps, [])]"
| "construct_rhs_vector_rec_M p assumps ((qs1, qs2)#[]) =
(let TQ_list = construct_NofI_M p assumps qs1 qs2 in
map (λ(new_assumps, tq). (new_assumps, [tq])) TQ_list)"
| "construct_rhs_vector_rec_M p assumps ((qs1, qs2)#T) =
concat (let TQ_list = construct_NofI_M p assumps qs1 qs2 in
(map (λ(new_assumps, tq). (let rec = construct_rhs_vector_rec_M p new_assumps T in
map (λr. (fst r,  tq#snd r)) rec)) TQ_list))"

definition construct_rhs_vector_M:: "rmpoly ⇒ assumps ⇒ rmpoly list ⇒ (nat list * nat list) list ⇒ (assumps × rat vec) list"
where "construct_rhs_vector_M p assumps qs Is =
map (λres. (fst res, vec_of_list (snd res))) (construct_rhs_vector_rec_M p assumps (pull_out_pairs qs Is))"

definition solve_for_lhs_single_M:: "rmpoly ⇒ rmpoly list ⇒ (nat list * nat list) list ⇒ rat mat ⇒ rat vec ⇒ rat vec"
where "solve_for_lhs_single_M p qs subsets matr rhs_vector =
mult_mat_vec (matr_option (dim_row matr) (mat_inverse_var matr)) rhs_vector"

definition solve_for_lhs_M:: "rmpoly ⇒ assumps ⇒ rmpoly list ⇒ (nat list * nat list) list ⇒ rat mat ⇒ (assumps × rat vec) list"
where "solve_for_lhs_M p assumps qs subsets matr =
map (λrhs. (fst rhs, solve_for_lhs_single_M p qs subsets matr (snd rhs))) (construct_rhs_vector_M p assumps qs subsets)"

subsection "Reduction"
fun reduce_system_single_M:: "rmpoly ⇒ rmpoly list ⇒ (assumps × matrix_equation) ⇒ (assumps × matrix_equation) list"
where "reduce_system_single_M p qs (assumps, (m,subs,signs)) =
map (λlhs. (fst lhs, reduction_step_R m signs subs (snd lhs))) (solve_for_lhs_M p assumps qs subs m)"

fun reduce_system_M:: "rmpoly ⇒ rmpoly list ⇒ (assumps × matrix_equation) list ⇒ (assumps × matrix_equation) list"
where "reduce_system_M p qs input_list = concat (map (reduce_system_single_M p qs) input_list)"

subsection "Top-level Function"
fun calculate_data_M:: "rmpoly ⇒ rmpoly list ⇒ (assumps × matrix_equation) list"
where
"calculate_data_M p qs =
( let len = length qs in
if len = 0 then map (λ(assumps,(a,(b,c))). (assumps, (a,b,map (drop 1) c))) (reduce_system_M p [1] base_case_info_M)
else if len ≤ 1 then reduce_system_M p qs base_case_info_M
else
(let q1 = take (len div 2) qs; left = calculate_data_M p q1;
q2 = drop (len div 2) qs; right = calculate_data_M p q2;
comb = combine_systems_M p q1 left q2 right in
reduce_system_M p (fst comb) (snd comb)
)
)"

(* Very similar to calculate_data_M, but takes assumptions as an input *)
(* The top-level function we use to construct the multivariate matrix equation *)
fun calculate_data_assumps_M:: "rmpoly ⇒ rmpoly list ⇒ assumps ⇒ (assumps × matrix_equation) list"
where
"calculate_data_assumps_M p qs init_assumps =
( let len = length qs in
if len = 0 then map (λ(assumps,(a,(b,c))). (assumps, (a,b,map (drop 1) c))) (reduce_system_M p [1] (base_case_info_M_assumps init_assumps))
else if len ≤ 1 then reduce_system_M p qs (base_case_info_M_assumps init_assumps)
else
(let q1 = take (len div 2) qs; left = calculate_data_assumps_M p q1 init_assumps;
q2 = drop (len div 2) qs; right = calculate_data_assumps_M p q2 init_assumps;
comb = combine_systems_M p q1 left q2 right in
reduce_system_M p (fst comb) (snd comb)
)
)"

(* export_code vars calculate_data_assumps_M
in SML module_name export *)

end