File ‹pratt.ML›
(* Interface for generating, checking, parsing, printing and replaying Pratt certificates. *)
signature PRATT =
sig
(* association list from a number to a theorem stored for it *)
type prime_thm_cache = (int * thm) list
(* tactic options: theorem cache, whether to print messages, whether to use the code conversion *)
type tac_config = {cache : prime_thm_cache, verbose : bool, code : bool}
(* Pratt_Node (n, a, ts): certificate for n with witness a and sub-certificates ts *)
datatype cert = Pratt_Node of int * int * cert list
(* raised when a certificate is rejected *)
exception INVALID_CERT of cert
(* first component of the certificate node *)
val get_cert_number : cert -> int
(* search for a certificate; NONE if no witness is found *)
val mk_cert : int -> cert option
(* check a certificate and, recursively, all of its sub-certificates *)
val check_cert : cert -> bool
(* turn a certificate into a theorem, returning the extended theorem cache *)
val replay_cert : prime_thm_cache -> cert -> Proof.context -> thm * prime_thm_cache
(* turn a certificate into a theorem via the conversion registered with
   setup_valid_cert_code_conv; raises INVALID_CERT on failure *)
val replay_cert_code : cert -> Proof.context -> thm
(* mk_cert followed by replay_cert; the theorem is NONE if no certificate was found *)
val prove_prime : prime_thm_cache -> int -> Proof.context -> thm option * prime_thm_cache
(* HOL type of certificates *)
val certT : typ
(* conversions between ML certificates and HOL terms *)
val termify_cert : cert -> term
val untermify_cert : term -> cert
(* compact textual form; a node (n, 1, []) is printed as the bare number n *)
val pretty_cert : cert -> Pretty.T
(* read one certificate / a sequence of certificates from source text *)
val read_cert : Input.source -> cert
val read_certs : Input.source -> cert list
(* the same readers, taking their input from a cartouche argument *)
val cert_cartouche : cert parser
val cert_cartouches : cert list parser
(* optional parenthesised list of the options "silent" and "code" *)
val tac_config_parser : tac_config parser
(* tactic for the given subgoal; uses the given certificate or searches for one *)
val tac : tac_config -> cert option -> Proof.context -> int -> tactic
(* register the conversion used by replay_cert_code *)
val setup_valid_cert_code_conv : (Proof.context -> conv) -> Context.generic -> Context.generic
(* association list from a number to its certificate *)
type cert_cache = (int * cert) list
(* replace abbreviated sub-certificates (n, 1, []) by the full certificates given in the list *)
val augment_certs: cert list -> cert_cache
(* flatten certificates into one entry per number, with sub-certificates abbreviated *)
val reduce_certs: cert list -> cert_cache
(* command parser: options, a binding, and a cartouche of certificates *)
val check_certs_parser: (Toplevel.transition -> Toplevel.transition) parser
end
structure Pratt : PRATT =
struct
(* Modular exponentiation by repeated squaring: base^expo mod m.
   An exponent of 0 yields 1, except for modulus 1 where the result is 0. *)
fun mod_exp base expo m =
  if expo = 0 then (if m = 1 then 0 else 1)
  else
    let
      val (half, parity) = Integer.div_mod expo 2
      (* power of the squared base for the halved exponent *)
      val rest = mod_exp ((base * base) mod m) half m
    in
      if parity = 0 then rest else (base * rest) mod m
    end
local
  (* Extend the ascending prime list `found` by trial division, starting at candidate `cand`.
     `budget` counts remaining steps: every candidate consumes one step when `mode` is false,
     only newly found primes do when `mode` is true. *)
  fun calc_primes mode found cand budget =
    if budget = 0 then found
    else
      let
        val composite = exists (fn p => cand mod p = 0) found
      in
        if composite then
          calc_primes mode found (cand + 1) (if mode then budget else budget - 1)
        else
          calc_primes mode (found @ [cand]) (cand + 1) (budget - 1)
      end;
in
  (* All primes p with p <= n, in ascending order. *)
  fun primes_up_to n =
    if n < 2 then []
    else calc_primes false [2] 3 (n - 2);
end;
(* trial divisors tried first by factorise *)
val small_primes = primes_up_to 100

(* Factorise n by trial division into a list of (divisor, multiplicity) pairs in ascending
   order of divisor. Numbers n <= 1 yield the empty list. *)
fun factorise n =
  let
    (* Divisor state: remaining small primes, next candidate beyond them, and a flag
       selecting the next step (alternating +2 / +4, starting from 101). *)
    fun current (p :: _, _, _) = p
      | current ([], k, _) = k
    fun advance (_ :: rest, k, flag) = (rest, k, flag)
      | advance ([], k, flag) = ([], if flag then k + 4 else k + 2, not flag)
    (* divide m by d as often as possible, counting the divisions *)
    fun strip d m count =
      if m mod d = 0 then strip d (m div d) (count + 1) else (m, count)
    fun go state m found =
      if m <= 1 then rev found
      else
        let
          val d = current state
        in
          (* no divisor up to sqrt m is left, so m itself is the last factor *)
          if d * d > m then rev ((m, 1) :: found)
          else
            let
              val (m', k) = strip d m 0
            in
              go (advance state) m' (if k = 0 then found else (d, k) :: found)
            end
        end
  in
    go (small_primes, 101, false) n []
  end
(* association list from a number to a theorem stored for it *)
type prime_thm_cache = (int * thm) list
(* Pratt_Node (n, a, ts): certificate for the number n with witness a and sub-certificates ts *)
datatype cert = Pratt_Node of int * int * cert list
(* association list from a number to a certificate for it *)
type cert_cache = (int * cert) list
(* raised when a certificate is rejected; carries the rejected certificate *)
exception INVALID_CERT of cert
(* the number a certificate is about, i.e. the first component of its root node *)
fun get_cert_number (Pratt_Node (n, _, _)) = n
(* Search for a Pratt certificate for n.
   Returns NONE when no witness a in 1..n satisfies a^(n-1) = 1 (mod n) and
   a^((n-1)/p) <> 1 (mod n) for every prime factor p of n - 1. *)
fun mk_cert n =
  let
    exception PRATT
    (* smallest value in lo..hi satisfying pred, if any *)
    fun first_in pred lo hi =
      if hi < lo then NONE
      else if pred lo then SOME lo
      else first_in pred (lo + 1) hi
    (* add certificates for m and, recursively, the prime factors of m - 1 to `known` *)
    fun build m known =
      if AList.defined op= known m then known
      else
        let
          val prime_factors = map fst (factorise (m - 1))
          fun is_witness a =
            mod_exp a (m - 1) m = 1 andalso
            forall (fn p => mod_exp a ((m - 1) div p) m <> 1) prime_factors
          (* the witness is determined before recursing, so a failed search aborts early *)
          val witness =
            (case first_in is_witness 1 m of
              SOME a => a
            | NONE => raise PRATT)
          val known' = fold build prime_factors known
          val children = map (the o AList.lookup op= known') prime_factors
        in
          (m, Pratt_Node (m, witness, children)) :: known'
        end
  in
    AList.lookup op= (build n []) n
    handle PRATT => NONE
  end
(* Combine the theorems in thms, in order, into one theorem using list_all_ConsI,
   starting from list.pred_inject(1) instantiated with `property`. *)
fun prove_list_all ctxt property thms =
  let
    val nil_case =
      @{thm list.pred_inject(1)}
      |> Drule.infer_instantiate' ctxt [SOME (Thm.cterm_of ctxt property)]
    fun cons_case elem_thm tail_thm = @{thm list_all_ConsI} OF [elem_thm, tail_thm]
  in
    (* the last theorem is combined first, so the result follows the order of thms *)
    fold_rev cons_case thms nil_case
  end
(* Check that n can be divided down to 1 using only the candidate factors in the list,
   i.e. that every prime factor of n occurs among the candidates.
   Returns false for n = 0 and whenever a candidate below 2 is reached: such a value is
   never a prime factor, and dividing by it would loop forever (1, negative values) or
   raise Div (0). *)
fun check_prime_factors_subset 0 _ = false
  | check_prime_factors_subset n [] = n = 1
  | check_prime_factors_subset n (p :: ps) =
      if p < 2 then false
      else if n mod p = 0 then
        (* keep p: it may divide n more than once *)
        check_prime_factors_subset (n div p) (p :: ps)
      else
        check_prime_factors_subset n ps
(* Check a single certificate node without descending into its sub-certificates:
   the numbers of the sub-certificates must cover the prime factors of n - 1,
   a^((n-1)/p) <> 1 (mod n) must hold for each such number p, and a^(n-1) = 1 (mod n). *)
fun check_cert' (Pratt_Node (n, a, ts)) =
  let
    val ps = map get_cert_number ts
    fun power_is_not_one p = mod_exp a ((n - 1) div p) n <> 1
  in
    if check_prime_factors_subset (n - 1) ps then
      (if forall power_is_not_one ps then mod_exp a (n - 1) n = 1 else false)
    else false
  end
(* Check a certificate node and, recursively, all of its sub-certificates. *)
fun check_cert (node as Pratt_Node (_, _, ts)) =
  check_cert' node andalso forall check_cert ts
(* Replay a certificate as a theorem stating that its number is prime.
   cache: theorems already available, looked up by number before any work is done.
   Returns the theorem together with the cache extended by every theorem proved on the way.
   Raises INVALID_CERT (carrying the offending node) if a node that has to be proved fails
   check_cert', and THM if the proof attempt for a node fails. *)
fun replay_cert cache cert ctxt =
  let
    val mk_nat = HOLogic.mk_number @{typ "Nat.nat"}
    val mk_eq_thm = Thm.cterm_of ctxt #> Thm.reflexive
    fun replay (node as Pratt_Node (n, a, ts)) cache =
      case AList.lookup op= cache n of
        SOME thm => (thm, cache)
      | NONE =>
          let
            (* validate the node being replayed, not just the root certificate *)
            val _ = if check_cert' node then () else raise INVALID_CERT node
            val (prime_thms, cache) = fold_map replay ts cache
            val (n', a') = apply2 mk_nat (n, a)
            val prime_thm = prove_list_all ctxt @{term "prime :: nat ⇒ bool"} prime_thms
            val thm =
              (@{thm lehmers_theorem'} OF [prime_thm, mk_eq_thm a', mk_eq_thm n'])
            fun mk_thm () =
              Goal.prove ctxt [] []
                (HOLogic.mk_Trueprop (@{term "prime :: nat ⇒ bool"} $ mk_nat n))
                (fn {context = ctxt, ...} =>
                  HEADGOAL (resolve_tac ctxt [thm])
                  THEN ALLGOALS (TRY o REPEAT_ALL_NEW
                    (resolve_tac ctxt @{thms list_all_ConsI list.pred_inject(1)}))
                  THEN PARALLEL_ALLGOALS (Simplifier.simp_tac ctxt))
          in
            case try mk_thm () of
              NONE => raise THM ("replay_cert", 0, [thm])
            | SOME thm => (thm, (n, thm) :: cache)
          end
  in
    replay cert cache
  end
(* Search for a certificate for n and replay it.
   Returns (NONE, cache) unchanged when no certificate is found, otherwise the theorem
   together with the extended cache. *)
fun prove_prime cache n ctxt =
  (case mk_cert n of
    SOME cert =>
      let
        val (thm, cache') = replay_cert cache cert ctxt
      in
        (SOME thm, cache')
      end
  | NONE => (NONE, cache))
(* lexical tokens of the certificate syntax, each paired with its source position *)
datatype token_kind = Nat of int | Comma | Open_Brace | Close_Brace | Space | EOF
datatype token = Token of token_kind * Position.T

fun pos_of (Token (_, pos)) = pos

fun is_space tok =
  (case tok of
    Token (Space, _) => true
  | _ => false)

fun is_eof tok =
  (case tok of
    Token (EOF, _) => true
  | _ => false)

fun mk_eof pos = Token (EOF, pos)

(* human-readable name of a token kind, used in error messages *)
fun token_kind_name kind =
  (case kind of
    Nat _ => "natural number"
  | Comma => "comma"
  | Open_Brace => "opening curly brace"
  | Close_Brace => "closing curly brace"
  | Space => "whitespace"
  | EOF => "end of input")
(* stopper for finite token scans: an EOF token placed at the position of the last token,
   or at no position when the token list is empty *)
val stopper =
  let
    fun eof_for toks =
      mk_eof (if null toks then Position.none else pos_of (List.last toks))
  in
    Scan.stopper eof_for is_eof
  end
local
  (* blank symbols other than newline *)
  fun is_inline_blank ((sym, _): Symbol_Pos.T) = sym <> "\n" andalso Symbol.is_blank sym
  val newline = Symbol_Pos.$$$ "\n"
  (* a non-empty run of blanks with an optional newline, or a possibly empty run
     of blanks followed by a newline *)
  val scan_space =
    Scan.many1 is_inline_blank @@@ Scan.optional newline [] ||
    Scan.many is_inline_blank @@@ newline
  (* token of the given kind covering the range of the scanned symbols *)
  fun mk_token kind (ss: Symbol_Pos.T list) =
    Token (kind, Position.range_position (Symbol_Pos.range ss))
  fun nat_token (ss: Symbol_Pos.T list) =
    mk_token (Nat (#1 (Library.read_int (map #1 ss)))) ss
  val scan_token =
    Symbol_Pos.scan_nat >> nat_token ||
    Symbol_Pos.$$$ "," >> mk_token Comma ||
    Symbol_Pos.$$$ "{" >> mk_token Open_Brace ||
    Symbol_Pos.$$$ "}" >> mk_token Close_Brace ||
    scan_space >> mk_token Space
  val end_of_input = Scan.ahead (Scan.one Symbol_Pos.is_eof)
  (* all tokens up to the end of the input; anything left over is a lexical error *)
  val scan_all_tokens =
    Scan.repeat scan_token --| Symbol_Pos.!!! (fn () => "Lexical error") end_of_input
in
  (* split source text into tokens (whitespace tokens included) *)
  val tokenize =
    Input.source_explode
    #> Scan.error (Scan.finite Symbol_Pos.stopper scan_all_tokens)
    #> #1
end;
(* Parser for certificates over the token stream produced by tokenize.
   A certificate is written either as "{n, a, {c1, ..., ck}}" with sub-certificates ci,
   or as a bare number n, which stands for Pratt_Node (n, 1, []). *)
local
type 'a parser = token list -> 'a * token list
(* message of the form "<expected> expected, but <found> was found" *)
fun err_msg expected toks () =
let
fun found [] = "end of input"
| found (Token (kind, _) :: _) = token_kind_name kind
in
expected ^ " expected, but " ^ found toks ^ " was found"
end
(* wrap a scanner with Scan.!!, reporting "Syntax error" at the position of the next token *)
fun !!! (scan: 'a parser) =
let
fun get_pos [] = " (end-of-input)"
| get_pos (tok :: _) = Position.here (pos_of tok)
fun err (toks, msg) () =
"Syntax error" ^ get_pos toks ^ (case msg of NONE => "" | SOME m => ": " ^ m ())
in Scan.!! err scan end;
(* accept exactly one token of the given kind *)
fun one kind =
Scan.some (fn Token (kind', _) => if kind = kind' then SOME () else NONE)
|| Scan.fail_with (err_msg (token_kind_name kind))
(* accept a number token and return its value *)
val nat =
Scan.some (fn Token (Nat n, _) => SOME n | _ => NONE)
|| Scan.fail_with (err_msg "natural number")
val comma = one Comma
val open_brace = one Open_Brace
val close_brace = one Close_Brace
(* comma-separated sequences: enum1 needs at least one item, enum may be empty *)
fun enum1 scan = scan ::: Scan.repeat (comma |-- scan)
fun enum scan = enum1 scan || Scan.succeed []
(* a comma-separated sequence enclosed in curly braces *)
fun list scan = !!! open_brace |-- enum scan --| !!! close_brace
(* NOTE(review): the second alternative wraps nat in !!!; if Scan.!! turns its failure into
   a hard error, the final Scan.fail_with alternative is never reached, and an empty
   sub-certificate list "{}" would be reported as a syntax error - confirm against Scan. *)
fun parse toks =
((open_brace |-- !!! (nat --| !!! comma -- !!! nat --| !!! comma -- list parse
--| !!! close_brace)) >> (fn ((a, b), c) => Pratt_Node (a, b, c))
|| !!! nat >> (fn a => Pratt_Node (a, 1, []))
|| Scan.fail_with (err_msg "opening curly brace or natural number")) toks
(* tokenize the input, drop whitespace tokens, and run the parser up to the end of input *)
fun gen_read_cert parse input =
let val toks = filter_out is_space (tokenize input)
in #1 (Scan.error (Scan.finite stopper (parse --| !!! (Scan.ahead (one EOF)))) toks) end
in
(* read exactly one certificate *)
val read_cert = gen_read_cert parse
(* read a sequence of certificates using Scan.bulk *)
val read_certs = gen_read_cert (Scan.bulk parse)
end
(* argument parsers reading one certificate / a sequence of certificates from a cartouche *)
val cert_cartouche = Args.cartouche_input >> read_cert
val cert_cartouches = Args.cartouche_input >> read_certs
(* HOL type of certificate terms; element type of the sub-certificate list in termify_cert *)
val certT = @{typ "Pratt_Certificate.pratt_tree"}
local
  val mk_nat = HOLogic.mk_number @{typ nat}
  val dest_nat = snd o HOLogic.dest_number
in
(* HOL term for a certificate: Pratt_Node applied to the tuple (n, a, sub-certificates) *)
fun termify_cert (Pratt_Node (n, a, ts)) =
  let
    val children = HOLogic.mk_list certT (map termify_cert ts)
  in
    @{term Pratt_Node} $ HOLogic.mk_tuple [mk_nat n, mk_nat a, children]
  end

(* inverse of termify_cert; raises TERM for terms that do not have that shape *)
fun untermify_cert t =
  (case t of
    @{term Pratt_Node} $ arg =>
      (case HOLogic.strip_tuple arg of
        [n, a, ts] =>
          Pratt_Node (dest_nat n, dest_nat a, map untermify_cert (HOLogic.dest_list ts))
      | _ => raise TERM ("untermify_cert", [t]))
  | _ => raise TERM ("untermify_cert", [t]))
end
(* Context data holding the optional conversion used by replay_cert_code. *)
structure Data = Generic_Data
(
  type T = (Proof.context -> conv) option
  val empty : T = NONE
  (* Prefer the second entry, but fall back to the first one when the second context has
     no conversion registered, so that merging never drops an existing registration. *)
  fun merge (conv1, conv2) : T = if is_some conv2 then conv2 else conv1
)
(* register conv as the conversion used by replay_cert_code *)
fun setup_valid_cert_code_conv conv ctxt =
  ctxt |> Data.put (SOME conv)
(* whether a code conversion has been registered in this context *)
fun has_code_conv ctxt =
  is_some (Data.get (Context.Proof ctxt))
(* the registered code conversion for this context; without a registration, a conversion
   that raises CTERM on its argument *)
fun valid_cert_code_conv ctxt =
  let
    fun unavailable ct = raise CTERM ("valid_cert_code_conv", [ct])
  in
    (case Data.get (Context.Proof ctxt) of
      NONE => unavailable
    | SOME conv => conv ctxt)
  end
(* Replay a certificate through the registered code conversion: the conversion is applied
   to the proposition "valid_pratt_tree <cert>" and its result is passed to
   valid_pratt_tree_imp_prime'. TERM, CTERM and THM exceptions are turned into INVALID_CERT. *)
fun replay_cert_code cert ctxt =
  (termify_cert cert
    |> (fn t => @{term valid_pratt_tree} $ t)
    |> HOLogic.mk_Trueprop
    |> Thm.cterm_of ctxt
    |> valid_cert_code_conv ctxt
    |> (fn valid_thm => @{thm valid_pratt_tree_imp_prime'} OF [valid_thm]))
  handle TERM _ => raise INVALID_CERT cert
       | CTERM _ => raise INVALID_CERT cert
       | THM _ => raise INVALID_CERT cert
local
  val pretty_int = Pretty.str o string_of_int
in
(* Print a certificate in the syntax accepted by read_cert: a node with witness 1 and no
   sub-certificates is abbreviated to its number, every other node is "{n, a, {...}}". *)
fun pretty_cert (Pratt_Node (n, a, ts)) =
  if a = 1 andalso null ts then pretty_int n
  else
    Pretty.list "{" "}"
      [pretty_int n, pretty_int a, Pretty.enum "," "{" "}" (map pretty_cert ts)]
end
(* tactic options: theorem cache, whether to print messages, whether to use the code conversion *)
type tac_config = {cache : prime_thm_cache, verbose : bool, code : bool}
(* raised internally when the code option is requested but no code conversion is registered *)
exception NO_CODE
local
  (* warn about an invalid certificate (if verbose) and fail *)
  fun cert_err (config: tac_config) cert =
    let
      val _ =
        if #verbose config then
          Pretty.chunks [Pretty.str "Invalid Pratt certificate:",
            Pretty.indent 2 (pretty_cert cert)]
          |> Pretty.string_of
          |> warning
        else
          ()
    in
      no_tac
    end
in
(* Tactic for subgoal i, whose conclusion applies a predicate to a numeral.
   cert: certificate to replay; if NONE, one is searched for with mk_cert and, in verbose
   mode, printed as a ready-to-paste proof method invocation.
   The tactic fails (no results) if the number has no certificate, the certificate is
   rejected with INVALID_CERT, or the code option is set without a registered conversion.

   The goal state is taken as an explicit argument: the certificate is replayed only when
   the tactic is applied to a state, so the exception handlers below must enclose that
   application rather than just the construction of the tactic. *)
fun tac (config: tac_config) cert ctxt i st =
  let
    val cmd =
      Pretty.block ([Pretty.str "pratt"] @
        (if #code config then [Pretty.str " (", Pretty.keyword1 "code", Pretty.str ")"] else []))
    fun print_cert cert =
      [Pretty.keyword1 "by", Pretty.brk 1, Pretty.str "(", cmd, Pretty.str " ",
        Pretty.blk (2, [Pretty.cartouche (pretty_cert cert)]), Pretty.str ")"]
      |> Pretty.blk o pair 4
      |> Pretty.string_of
      |> Active.sendback_markup_command
      |> prefix "To repeat this proof with a pre-computed certificate, use:\n"
      |> Output.information
    fun not_prime_err n =
      let
        val _ = if #verbose config then warning ("Not a prime number: " ^ Int.toString n) else ()
      in
        NONE
      end
    (* use the supplied certificate, or compute one for the numeral in the conclusion *)
    fun certify p =
      case cert of
        SOME cert => SOME cert
      | NONE =>
          let
            val p' = p |> HOLogic.dest_Trueprop |> dest_comb |> snd |> HOLogic.dest_number |> snd
          in
            case mk_cert p' of
              SOME cert =>
                let val _ = if #verbose config then print_cert cert else () in SOME cert end
            | NONE => not_prime_err p'
          end
    val replay =
      if #code config then
        if has_code_conv ctxt then
          replay_cert_code
        else
          let
            val _ =
              if #verbose config then
                warning
                  ("Code for Pratt certificates was not set up yet. " ^
                   "Load the theory Pratt_Certificate_Code to do this.")
              else
                ()
          in
            raise NO_CODE
          end
      else
        fst oo replay_cert (#cache config)
  in
    Subgoal.FOCUS_PARAMS (fn {concl, ...} =>
      case certify (Thm.term_of concl) of
        NONE => no_tac
      | SOME cert =>
          HEADGOAL (resolve_tac ctxt [replay cert ctxt])
    ) ctxt i st
  end
    handle INVALID_CERT cert => cert_err config cert st
         | NO_CODE => no_tac st
end
(* configuration used when no options are given *)
val default_config = {verbose = true, code = false, cache = []}
local
  fun make_silent ({code, cache, ...}: tac_config): tac_config =
    {verbose = false, code = code, cache = cache}
  fun enable_code ({verbose, cache, ...}: tac_config): tac_config =
    {verbose = verbose, code = true, cache = cache}
  val silent : (tac_config -> tac_config) parser = Args.$$$ "silent" >> K make_silent
  val code : (tac_config -> tac_config) parser = Args.$$$ "code" >> K enable_code
  val option = silent || code
  (* optional parenthesised list of options, composed into one configuration update *)
  val options =
    Scan.optional
      (Args.parens (Parse.list option) >> (fn fs => fold (curry op o) fs I)) I
in
(* parse the options and apply them to the default configuration *)
val tac_config_parser = options >> (fn f => f default_config)
end
local
(* the number a certificate is about *)
fun prime_of cert = (case cert of Pratt_Node (p, _, _) => p)
(* abbreviated certificate (p, 1, []) for the number of the given certificate *)
fun triv_cert_of n =
  let val p = prime_of n
  in Pratt_Node (p, 1, []) end
(* Replay all certificates, either through the code conversion or through replay_cert
   with a theorem cache that starts empty and is shared across the list. *)
fun thms_from_certs is_code ctxt certs =
  let
    fun via_code c = replay_cert_code c ctxt
    fun via_replay c cache = replay_cert cache c ctxt
  in
    if is_code then map via_code certs
    else fst (fold_map via_replay certs [])
  end
(* replay the certificates and note the resulting theorems under the given binding *)
fun notes_from_certs is_code binding certs lthy =
  lthy
  |> Local_Theory.note ((binding, []), thms_from_certs is_code lthy certs)
  |> snd
(* sort a certificate cache by ascending number *)
val sort_cert_cache = sort (int_ord o apply2 fst)

(* Flatten certificates into a cache with one entry per number occurring anywhere in them.
   Each entry keeps its witness but abbreviates its sub-certificates via triv_cert_of;
   for a number occurring several times the first occurrence wins. *)
fun reduce_certs (certs: cert list): cert_cache =
  let
    fun visit (Pratt_Node (p, a, children)) seen =
      if AList.defined (op =) seen p then seen
      else (p, Pratt_Node (p, a, map triv_cert_of children)) :: fold visit children seen
  in
    fold visit certs [] |> sort_cert_cache
  end
(* Build the toplevel transition that notes the theorems for the certificates.
   With is_full the certificates are printed as given; with is_reduce their reduced
   form is printed as well. *)
fun trans_from_certs (is_code, is_full, is_reduce) binding certs =
  let
    fun print_certs cs = Pretty.writeln (Pretty.chunks (map pretty_cert cs))
    val _ = if is_full then print_certs certs else ()
    val _ = if is_reduce then print_certs (map snd (reduce_certs certs)) else ()
  in
    Toplevel.local_theory NONE NONE (notes_from_certs is_code binding certs)
  end
(* Parse an optional parenthesised list of mode names and return the flags
   (is_code, is_full, is_reduce); any other name is rejected with an error. *)
val modes_parser =
  let
    val known_modes = ["code", "full", "reduce"]
    fun mk_opts modes =
      let
        val has = member (op =) modes
        val _ =
          if forall (member (op =) known_modes) modes then ()
          else error "Unknown option!"
      in
        (has "code", has "full", has "reduce")
      end
    val parenthesized_modes =
      \<^keyword>‹(› |-- Parse.!!! (Scan.repeat1 Parse.name --| \<^keyword>‹)›)
  in
    Scan.optional parenthesized_modes [] >> mk_opts
  end
(* Expand abbreviated nodes (n, 1, []) inside a certificate by the entry for n in
   cert_cache; the node (2, 1, []) is kept as it is. A missing entry is an error. *)
fun augment_cert cert_cache cert =
  (case cert of
    Pratt_Node (n, 1, []) =>
      if n = 2 then cert
      else
        (case AList.lookup (op =) cert_cache n of
          SOME full => full
        | NONE => error ("Missing certificate for " ^ string_of_int n))
  | Pratt_Node (ps, a, certs) =>
      Pratt_Node (ps, a, map (augment_cert cert_cache) certs))
(* Augment the entries of a cache from front to back; each entry may only refer to
   entries that precede it. The order of the entries is preserved. *)
fun augment_certs0 (cert_cache: cert_cache): cert_cache =
  let
    fun step (i, c) done = (i, augment_cert done c) :: done
  in
    rev (fold step cert_cache [])
  end
in
(* export the locally defined reduce_certs *)
val reduce_certs = reduce_certs

(* Index the certificates by their number, sort them in ascending order and expand
   abbreviated sub-certificates using the entries for smaller numbers. *)
fun augment_certs (certs: cert list): cert_cache =
  certs
  |> map (fn c => (prime_of c, c))
  |> sort_cert_cache
  |> augment_certs0
(* command parser: modes, a binding and a cartouche of certificates; the certificates are
   augmented before the transition is built *)
val check_certs_parser =
  let
    fun mk_transition ((opts, binding), certs) =
      trans_from_certs opts binding (map snd (augment_certs certs))
  in
    modes_parser -- Parse.binding -- cert_cartouches >> mk_transition
  end
end
end