(* Theory Stream_Fusion_List *)

(* Title: Stream_Fusion_List.thy
 Authors: Alexandra Maximova, ETH Zurich
          Andreas Lochbihler, ETH Zurich
*)

section ‹Stream fusion for finite lists›

theory Stream_Fusion_List
imports Stream_Fusion
begin

(* Monotonicity rule for the partial_function package: post-composing a monotone
   option-valued functional with map_option preserves monotonicity. *)
lemma map_option_mono [partial_function_mono]: (* To be moved to HOL *)
  "mono_option f ⟹ mono_option (λx. map_option g (f x))"
apply (rule monotoneI)
apply (drule (1) monotoneD)
apply (auto simp add: flat_ord_def split: option.split)
done

subsection ‹The type of generators for finite lists›

(* A single step of a generator with element type 'a and state type 's:
   Done (no further step), Skip s (continue in state s without an element),
   Yield a s (element a, continue in state s). *)
datatype ('a, 's) step = Done | is_Skip: Skip 's | is_Yield: Yield 'a 's

(* A raw generator maps a state to the step taken from that state. *)
type_synonym ('a, 's) raw_generator = "'s ⇒ ('a,'s) step"

text ‹
  Raw generators need not ever reach @{const Done}; they may instead produce infinitely many
  @{const Yield}s in a row. Such a generator cannot be converted to a finite list, because it
  corresponds to an infinite list. Therefore, we introduce the type of generators that always
  reach @{const Done} after finitely many steps.
›

(* The set of states from which generator g reaches Done after finitely many
   Skip/Yield steps. *)
inductive_set terminates_on :: "('a, 's) raw_generator ⇒ 's set"
  for g :: "('a, 's) raw_generator"
where
  stop: "g s = Done ⟹ s ∈ terminates_on g"
| pause: "⟦ g s = Skip s'; s' ∈ terminates_on g ⟧ ⟹ s ∈ terminates_on g"
| unfold: "⟦ g s = Yield a s'; s' ∈ terminates_on g ⟧ ⟹ s ∈ terminates_on g"

(* A generator terminates if it terminates on every state. *)
definition terminates :: "('a, 's) raw_generator ⇒ bool"
where "terminates g ⟷ (terminates_on g = UNIV)"

(* Introduction rule: show termination by showing it for an arbitrary state. *)
lemma terminatesI [intro?]:
  "(⋀s. s ∈ terminates_on g) ⟹ terminates g"
by (auto simp add: terminates_def)

(* Destruction rule: a terminating generator terminates on any given state. *)
lemma terminatesD:
  "terminates g ⟹ s ∈ terminates_on g"
by (auto simp add: terminates_def)

(* The generator that is Done in every state terminates on all states. *)
lemma terminates_on_stop:
  "terminates_on (λ_. Done) = UNIV"
by (auto intro: terminates_on.stop)

(* A generator terminates if every Skip and every Yield step moves to a state
   that is smaller w.r.t. a well-founded relation R. *)
lemma wf_terminates:
  assumes "wf R"
  and skip: "⋀s s'. g s = Skip s' ⟹ (s',s) ∈ R"
  and yield: "⋀s s' a. g s = Yield a s' ⟹ (s',s) ∈ R"
  shows "terminates g"
proof (rule terminatesI)
  fix s
  from ‹wf R› show "s ∈ terminates_on g"
  proof (induction rule: wf_induct [rule_format, consumes 1, case_names wf])
    case (wf s)
    show ?case
    proof (cases "g s")
      case (Skip s')
      hence "(s', s) ∈ R" by (rule skip)
      hence "s' ∈ terminates_on g" by (rule wf.IH)
      with Skip show ?thesis by (rule terminates_on.pause)
    next
      case (Yield a s')
      hence "(s', s) ∈ R" by (rule yield)
      hence "s' ∈ terminates_on g" by (rule wf.IH)
      with Yield show ?thesis by (rule terminates_on.unfold)
    qed (rule terminates_on.stop)
  qed
qed

context fixes g :: "('a, 's) raw_generator" begin

(* Number of Skip/Yield steps g takes from state s before reaching Done, as a
   partial function in the option monad. *)
partial_function (option) terminates_within :: "'s ⇒ nat option" where
  "terminates_within s = (case g s of
     Done ⇒ Some 0
  | Skip s' ⇒ map_option (λn. n + 1) (terminates_within s')
  | Yield a s' ⇒ map_option (λn. n + 1) (terminates_within s'))"

(* The states on which g terminates are exactly those for which
   terminates_within returns some step count. *)
lemma terminates_on_conv_dom_terminates_within:
  "terminates_on g = dom terminates_within"
proof (rule set_eqI iffI)+
  fix s
  assume "s ∈ terminates_on g"
  hence "∃n. terminates_within s = Some n"
    by induction (subst terminates_within.simps, simp add: split_beta)+
  then show "s ∈ dom terminates_within" by blast
next
  fix s
  assume "s ∈ dom terminates_within"
  then obtain n where "terminates_within s = Some n" by blast
  then show "s ∈ terminates_on g"
  proof (induction rule: terminates_within.raw_induct[rotated 1, consumes 1])
    case (1 terminates_within s s')
    show ?case
    proof(cases "g s")
      case Done
      thus ?thesis by (simp add: terminates_on.stop)
    next
      case (Skip s')
      hence "s' ∈ terminates_on g" using 1 by(auto)
      thus ?thesis using ‹g s = Skip s'› by (simp add: terminates_on.pause)
    next
      case (Yield a s')
      hence "s' ∈ terminates_on g" using 1 by(auto)
      thus ?thesis using ‹g s = Yield a s'› by (auto intro: terminates_on.unfold)
    qed
  qed
qed

end

(* Converse of wf_terminates: for a terminating generator there is a well-founded
   relation (here the measure induced by terminates_within) that every Skip and
   Yield step decreases. *)
lemma terminates_wfE:
  assumes "terminates g"
  obtains R 
  where "wf R"
    "⋀s s'. (g s = Skip s') ⟹ (s',s) ∈ R"
    "⋀s a s'. (g s = Yield a s') ⟹ (s',s) ∈ R"
proof -
  let ?R = "measure (λs. the (terminates_within g s)) :: ('a × 'a) set"
  have "wf ?R" by simp
  moreover {
    fix s s'
    assume "g s = Skip s'"
    moreover from assms have "s' ∈ terminates_on g" by (rule terminatesD)
    then obtain n where "terminates_within g s' = Some n"
      unfolding terminates_on_conv_dom_terminates_within by (auto)
    ultimately have "the (terminates_within g s') < the (terminates_within g s)"
      by (simp add: terminates_within.simps)
    hence "(s',s) ∈ ?R" by (auto)
  } moreover {
    fix s s' a
    assume 2: "g s = Yield a s'"
    moreover from assms have "s' ∈ terminates_on g" by (rule terminatesD)
    then obtain n where "terminates_within g s' = Some n"
      unfolding terminates_on_conv_dom_terminates_within by (auto)
    ultimately have "(s',s) ∈ ?R"
      by simp (subst terminates_within.simps, simp add: split_beta)
  } ultimately 
  show thesis by (rule that)
qed

(* The type of terminating generators; it is non-empty because the generator
   that is always Done terminates. *)
typedef ('a,'s) generator = "{g :: ('a,'s) raw_generator. terminates g}"
  morphisms generator Generator
proof
  show "(λ_. Done) ∈ ?generator"
    by (simp add: terminates_on_stop terminates_def)
qed

setup_lifting type_definition_generator

subsection ‹Conversion to @{typ "'a list"}›

context fixes g :: "('a, 's) generator" begin

(* Run generator g from state s and collect the yielded elements in a list.
   Termination follows from the well-founded relation given by terminates_wfE. *)
function unstream :: "'s ⇒ 'a list"
where
  "unstream s = (case generator g s of
     Done ⇒ []
   | Skip s' ⇒ unstream s'
   | Yield x s' ⇒ x # unstream s')"
by pat_completeness auto
termination
proof -
  have "terminates (generator g)" using generator[of g] by simp
  thus ?thesis by(rule terminates_wfE)(erule "termination")
qed

(* Conditional simp rules, one per kind of step; used instead of unstream.simps. *)
lemma unstream_simps [simp]:
  "generator g s = Done ⟹ unstream s = []"
  "generator g s = Skip s' ⟹ unstream s = unstream s'"
  "generator g s = Yield x s' ⟹ unstream s = x # unstream s'"
by(simp_all)

declare unstream.simps[simp del]

(* Advance past Skip steps: returns the first yielded element together with the
   successor state, or None if Done is reached first. *)
function force :: "'s ⇒ ('a × 's) option"
where
  "force s = (case generator g s of Done ⇒ None 
     | Skip s' ⇒ force s'
     | Yield x s' ⇒ Some (x, s'))"
by pat_completeness auto
termination
proof -
  have "terminates (generator g)" using generator[of g] by simp
  thus ?thesis by(rule terminates_wfE)(rule "termination")
qed

(* Conditional simp rules, one per kind of step; used instead of force.simps. *)
lemma force_simps [simp]:
  "generator g s = Done ⟹ force s = None"
  "generator g s = Skip s' ⟹ force s = force s'"
  "generator g s = Yield x s' ⟹ force s = Some (x, s')"
by(simp_all)

declare force.simps[simp del]

(* Relating force and unstream. *)
lemma unstream_force_None [simp]: "force s = None ⟹ unstream s = []"
proof(induction s rule: force.induct)
  case (1 s)
  thus ?case by(cases "generator g s") simp_all
qed

lemma unstream_force_Some [simp]: "force s = Some (x, s') ⟹ unstream s = x # unstream s'"
proof(induction s rule: force.induct)
  case (1 s)
  thus ?case by(cases "generator g s") simp_all
qed

end

(* Register unstream with the stream fusion setup from theory Stream_Fusion. *)
setup ‹Context.theory_map (Stream_Fusion.add_unstream @{const_name unstream})›

subsection ‹Producers›

subsubsection ‹Conversion to streams›

(* Generator over a list: the state is the remaining list; each step yields its head. *)
fun stream_raw :: "'a list ⇒ ('a, 'a list) step"
where
  "stream_raw [] = Done"
| "stream_raw (x # xs) = Yield x xs"

(* stream_raw terminates, by induction on the list that serves as state. *)
lemma terminates_stream_raw: "terminates stream_raw"
proof (rule terminatesI)
  fix s :: "'a list"
  show "s ∈ terminates_on stream_raw"
    by(induction s)(auto intro: terminates_on.intros)
qed

(* stream_raw lifted to the generator type; unstream of it returns the list unchanged. *)
lift_definition stream :: "('a, 'a list) generator" is "stream_raw" by(rule terminates_stream_raw)

lemma unstream_stream: "unstream stream xs = xs"
by(induction xs)(auto simp add: stream.rep_eq)

subsubsection ‹@{const replicate}›

(* Generator for replicate: the state counts how many copies of a remain to be yielded. *)
fun replicate_raw :: "'a ⇒ ('a, nat) raw_generator"
where
  "replicate_raw a 0 = Done"
| "replicate_raw a (Suc n) = Yield a n"
 
(* replicate_raw terminates, by induction on the counter. *)
lemma terminates_replicate_raw: "terminates (replicate_raw a)"
proof (rule terminatesI)
  fix s :: "nat"
  show "s ∈ terminates_on (replicate_raw a)"
    by(induction s)(auto intro: terminates_on.intros)
qed

(* Lifted producer for replicate and its fusion equation. *)
lift_definition replicate_prod :: "'a ⇒ ('a, nat) generator" is "replicate_raw"
by(rule terminates_replicate_raw)

lemma unstream_replicate_prod [stream_fusion]: "unstream (replicate_prod x) n = replicate n x"
by(induction n)(simp_all add: replicate_prod.rep_eq)

subsubsection ‹@{const upt}›

(* Generator for upt: the state m is the next number to yield; stop once the
   (exclusive) upper bound n is reached. *)
definition upt_raw :: "nat ⇒ (nat, nat) raw_generator"
where "upt_raw n m = (if m ≥ n then Done else Yield m (Suc m))"

(* upt_raw terminates, by induction on the distance n - s to the upper bound. *)
lemma terminates_upt_raw: "terminates (upt_raw n)"
proof (rule terminatesI)
  fix s :: nat
  show "s ∈ terminates_on (upt_raw n)"
    by(induction "n-s" arbitrary: s rule: nat.induct)(auto 4 3 simp add: upt_raw_def intro: terminates_on.intros)
qed

(* Lifted producer for upt and its fusion equation. *)
lift_definition upt_prod :: "nat ⇒ (nat, nat) generator" is "upt_raw" by(rule terminates_upt_raw)

lemma unstream_upt_prod [stream_fusion]: "unstream (upt_prod n) m = upt m n"
by(induction "n-m" arbitrary: n m)(simp_all add: upt_prod.rep_eq upt_conv_Cons upt_raw_def unstream.simps)


subsubsection ‹@{const upto}›

(* Generator for upto: the state m is the next integer to yield; the upper bound n
   is inclusive. *)
definition upto_raw :: "int ⇒ (int, int) raw_generator"
where "upto_raw n m = (if m ≤ n then Yield m (m + 1) else Done)"

(* upto_raw terminates, by induction on nat (n - s + 1). *)
lemma terminates_upto_raw: "terminates (upto_raw n)"
proof (rule terminatesI)
  fix s :: int
  show "s ∈ terminates_on (upto_raw n)"
    by(induction "nat(n-s+1)" arbitrary: s)(auto 4 3 simp add: upto_raw_def intro: terminates_on.intros)
qed

(* Lifted producer for upto and its fusion equation. *)
lift_definition upto_prod :: "int ⇒ (int, int) generator" is "upto_raw" by (rule terminates_upto_raw)

lemma unstream_upto_prod [stream_fusion]: "unstream (upto_prod n) m = upto m n"
by(induction "nat (n - m + 1)" arbitrary: m)(simp_all add: upto_prod.rep_eq upto.simps upto_raw_def)

subsubsection ‹@{term "[]"}›

(* Producer for the empty list: a generator over unit that is immediately Done. *)
lift_definition Nil_prod :: "('a, unit) generator" is "λ_. Done"
by(auto simp add: terminates_def intro: terminates_on.intros)

lemma generator_Nil_prod: "generator Nil_prod = (λ_. Done)"
by(fact Nil_prod.rep_eq)

lemma unstream_Nil_prod [stream_fusion]: "unstream Nil_prod () = []"
by(simp add: generator_Nil_prod)

subsection ‹Consumers›

subsubsection ‹@{const nth}›

context fixes g :: "('a, 's) generator" begin

(* Consumer for list indexing: the n-th element of the list generated from state s. *)
definition nth_cons :: "'s ⇒ nat ⇒ 'a" 
where [stream_fusion]: "nth_cons s n = unstream g s ! n"

(* Code equation that steps the generator directly instead of building the list. *)
lemma nth_cons_code [code]:
  "nth_cons s n =
  (case generator g s of Done => undefined n
    | Skip s' => nth_cons s' n
    | Yield x s' => (case n of 0 => x | Suc n' => nth_cons s' n'))"
by(cases "generator g s")(simp_all add: nth_cons_def nth_def split: nat.split)

end

subsubsection ‹@{term length}›

context fixes g :: "('a, 's) generator" begin

(* Consumer for length. *)
definition length_cons :: "'s ⇒ nat"
where "length_cons s = length (unstream g s)"

lemma length_cons_code [code]:
  "length_cons s =
    (case generator g s of
      Done ⇒ 0
    | Skip s' ⇒ length_cons s'
    | Yield a s' ⇒ 1 + length_cons s')"
by(cases "generator g s")(simp_all add: length_cons_def)

(* Variant with an accumulator n, so the code equation passes the running count
   as an argument. *)
definition gen_length_cons :: "nat ⇒ 's ⇒ nat"
where "gen_length_cons n s = n + length (unstream g s)"

lemma gen_length_cons_code [code]:
  "gen_length_cons n s = (case generator g s of
     Done ⇒ n | Skip s' ⇒ gen_length_cons n s' | Yield a s' ⇒ gen_length_cons (Suc n) s')"
by(simp add: gen_length_cons_def split: step.split)

(* Fusion equations for length and List.gen_length. *)
lemma unstream_gen_length [stream_fusion]: "gen_length_cons 0 s = length (unstream g s)"
by(simp add: gen_length_cons_def)

lemma unstream_gen_length2 [stream_fusion]: "gen_length_cons n s = List.gen_length n (unstream g s)"
by(simp add: List.gen_length_def gen_length_cons_def)

end

subsubsection ‹@{const foldr}›

context 
  fixes g :: "('a, 's) generator"
  and f :: "'a ⇒ 'b ⇒ 'b"
  and z :: "'b"
begin

(* Consumer for foldr with function f and start value z. *)
definition foldr_cons :: "'s ⇒ 'b"
where [stream_fusion]: "foldr_cons s = foldr f (unstream g s) z"

lemma foldr_cons_code [code]:
  "foldr_cons s =
    (case generator g s of
      Done ⇒ z
    | Skip s' ⇒ foldr_cons s'
    | Yield a s' ⇒ f a (foldr_cons s'))"
by(cases "generator g s")(simp_all add: foldr_cons_def)

end

subsubsection ‹@{const foldl}›

context
  fixes g :: "('b, 's) generator"
  and f :: "'a ⇒ 'b ⇒ 'a"
begin

(* Consumer for foldl; z is the accumulator. *)
definition foldl_cons :: "'a ⇒ 's ⇒ 'a"
where [stream_fusion]: "foldl_cons z s = foldl f z (unstream g s)"

lemma foldl_cons_code [code]:
  "foldl_cons z s =
    (case generator g s of
      Done ⇒ z
    | Skip s' ⇒ foldl_cons z s'
    | Yield a s' ⇒ foldl_cons (f z a) s')"
by (cases "generator g s")(simp_all add: foldl_cons_def)

end

subsubsection ‹@{const fold}›

context
  fixes g :: "('a, 's) generator"
  and f :: "'a ⇒ 'b ⇒ 'b"
begin

(* Consumer for fold; z is the accumulator. *)
definition fold_cons :: "'b ⇒ 's ⇒ 'b"
where [stream_fusion]: "fold_cons z s = fold f (unstream g s) z"

lemma fold_cons_code [code]:
  "fold_cons z s =
    (case generator g s of
      Done ⇒ z
    | Skip s' ⇒ fold_cons z s'
    | Yield a s' ⇒ fold_cons (f a z) s')"
by (cases "generator g s")(simp_all add: fold_cons_def)

end

subsubsection ‹@{const List.null}›

(* Consumer for the emptiness test: stops at the first Yield or Done. *)
definition null_cons :: "('a, 's) generator ⇒ 's ⇒ bool"
where [stream_fusion]: "null_cons g s = List.null (unstream g s)"

lemma null_cons_code [code]:
  "null_cons g s = (case generator g s of Done ⇒ True | Skip s' ⇒ null_cons g s' | Yield _ _ ⇒ False)"
by(cases "generator g s")(simp_all add: null_cons_def null_def)

subsubsection ‹@{const hd}›

context fixes g :: "('a, 's) generator" begin

(* Consumer for hd: the first yielded element; undefined if the generator is Done first. *)
definition hd_cons :: "'s ⇒ 'a"
where [stream_fusion]: "hd_cons s = hd (unstream g s)"

lemma hd_cons_code [code]:
  "hd_cons s =
    (case generator g s of
      Done ⇒ undefined
    | Skip s' ⇒ hd_cons s'
    | Yield a s' ⇒ a)"
by (cases "generator g s")(simp_all add: hd_cons_def hd_def)

end

subsubsection ‹@{const last}›

context fixes g :: "('a, 's) generator" begin

(* Consumer for last; the option argument caches the most recently yielded element. *)
definition last_cons :: "'a option ⇒ 's ⇒ 'a"
where "last_cons x s = (if unstream g s = [] then the x else last (unstream g s))"

lemma last_cons_code [code]:
  "last_cons x s =
  (case generator g s of Done ⇒ the x
             | Skip s' ⇒ last_cons x s'
             | Yield a s' ⇒ last_cons (Some a) s')"
by (cases "generator g s")(simp_all add: last_cons_def)

(* Fusion equation: start with an empty cache. *)
lemma unstream_last_cons [stream_fusion]: "last_cons None s = last (unstream g s)"
by (simp add: last_cons_def last_def option.the_def)

end

subsubsection ‹@{const sum_list}›

context fixes g :: "('a :: monoid_add, 's) generator" begin

(* Consumer for sum_list. *)
definition sum_list_cons :: "'s ⇒ 'a"
where [stream_fusion]: "sum_list_cons s = sum_list (unstream g s)"

lemma sum_list_cons_code [code]:
  "sum_list_cons s =
    (case generator g s of
      Done ⇒ 0
    | Skip s' ⇒ sum_list_cons s'
    | Yield a s' ⇒ a + sum_list_cons s')"
by (cases "generator g s")(simp_all add: sum_list_cons_def)

end

subsubsection ‹@{const list_all2}›

context
  fixes g :: "('a, 's1) generator"
  and h :: "('b, 's2) generator"
  and P :: "'a ⇒ 'b ⇒ bool"
begin

(* Consumer for list_all2 over two generators. *)
definition list_all2_cons :: "'s1 ⇒ 's2 ⇒ bool"
where [stream_fusion]: "list_all2_cons sg sh = list_all2 P (unstream g sg) (unstream h sh)"

(* Auxiliary consumer: the left generator has already yielded x, and the right
   generator still has to produce the matching element. *)
definition list_all2_cons1 :: "'a ⇒ 's1 ⇒ 's2 ⇒ bool"
where "list_all2_cons1 x sg' sh = list_all2 P (x # unstream g sg') (unstream h sh)"

lemma list_all2_cons_code [code]:
  "list_all2_cons sg sh = 
  (case generator g sg of
     Done ⇒ null_cons h sh
   | Skip sg' ⇒ list_all2_cons sg' sh
   | Yield a sg' ⇒ list_all2_cons1 a sg' sh)"
by(simp split: step.split add: list_all2_cons_def null_cons_def List.null_def list_all2_cons1_def)

lemma list_all2_cons1_code [code]:
  "list_all2_cons1 x sg' sh = 
  (case generator h sh of
     Done ⇒ False
   | Skip sh' ⇒ list_all2_cons1 x sg' sh'
   | Yield y sh' ⇒ P x y ∧ list_all2_cons sg' sh')"
by(simp split: step.split add: list_all2_cons_def null_cons_def List.null_def list_all2_cons1_def)

end

subsubsection ‹@{const list_all}›

context
  fixes g :: "('a, 's) generator"
  and P :: "'a ⇒ bool"
begin

(* Consumer for list_all with predicate P. *)
definition list_all_cons :: "'s ⇒ bool"
where [stream_fusion]: "list_all_cons s = list_all P (unstream g s)"

lemma list_all_cons_code [code]:
  "list_all_cons s ⟷
  (case generator g s of
    Done ⇒ True | Skip s' ⇒ list_all_cons s' | Yield x s' ⇒ P x ∧ list_all_cons s')"
by(simp add: list_all_cons_def split: step.split)

end

subsubsection ‹@{const ord.lexordp}›

context ord begin

(* Consumers for the strict and the reflexive lexicographic order on the lists
   produced by two generators. *)
definition lexord_fusion :: "('a, 's1) generator ⇒ ('a, 's2) generator ⇒ 's1 ⇒ 's2 ⇒ bool"
where [code del]: "lexord_fusion g1 g2 s1 s2 = ord_class.lexordp (unstream g1 s1) (unstream g2 s2)"

definition lexord_eq_fusion :: "('a, 's1) generator ⇒ ('a, 's2) generator ⇒ 's1 ⇒ 's2 ⇒ bool"
where [code del]: "lexord_eq_fusion g1 g2 s1 s2 = lexordp_eq (unstream g1 s1) (unstream g2 s2)"

(* Code equations: step the first generator; once it yields x, use force to
   obtain the next element of the second generator and compare. *)
lemma lexord_fusion_code:
  "lexord_fusion g1 g2 s1 s2 ⟷
  (case generator g1 s1 of
     Done ⇒ ¬ null_cons g2 s2
   | Skip s1' ⇒ lexord_fusion g1 g2 s1' s2
   | Yield x s1' ⇒ 
     (case force g2 s2 of
        None ⇒ False
      | Some (y, s2') ⇒ x < y ∨ ¬ y < x ∧ lexord_fusion g1 g2 s1' s2'))"
unfolding lexord_fusion_def
by(cases "generator g1 s1" "force g2 s2" rule: step.exhaust[case_product option.exhaust])(auto simp add: null_cons_def null_def)

lemma lexord_eq_fusion_code:
  "lexord_eq_fusion g1 g2 s1 s2 ⟷
  (case generator g1 s1 of
     Done ⇒ True
   | Skip s1' ⇒ lexord_eq_fusion g1 g2 s1' s2
   | Yield x s1' ⇒
     (case force g2 s2 of
        None ⇒ False
      | Some (y, s2') ⇒ x < y ∨ ¬ y < x ∧ lexord_eq_fusion g1 g2 s1' s2'))"
unfolding lexord_eq_fusion_def
by(cases "generator g1 s1" "force g2 s2" rule: step.exhaust[case_product option.exhaust]) auto

end

(* Register the step-wise equations as code equations and the definitions as
   stream-fusion rules, each in both the plain and the ord.-qualified form. *)
lemmas [code] =
  lexord_fusion_code ord.lexord_fusion_code
  lexord_eq_fusion_code ord.lexord_eq_fusion_code

lemmas [stream_fusion] =
  lexord_fusion_def ord.lexord_fusion_def
  lexord_eq_fusion_def ord.lexord_eq_fusion_def

subsection ‹Transformers›

subsubsection ‹@{const map}›

(* Transformer for map: applies f to every yielded element; the state is unchanged. *)
definition map_raw :: "('a ⇒ 'b) ⇒ ('a, 's) raw_generator ⇒ ('b, 's) raw_generator"
where
  "map_raw f g s = (case g s of
     Done ⇒ Done
   | Skip s' ⇒ Skip s'
   | Yield a s' ⇒ Yield (f a) s')"

(* map_raw preserves termination. *)
lemma terminates_map_raw: 
  assumes "terminates g"
  shows "terminates (map_raw f g)"
proof (rule terminatesI)
  fix s
  from assms
  have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "s ∈ terminates_on (map_raw f g)"
    by (induction s)(auto intro: terminates_on.intros simp add: map_raw_def)
qed

(* Lifted transformer for map and its fusion equation. *)
lift_definition map_trans :: "('a ⇒ 'b) ⇒ ('a, 's) generator ⇒ ('b, 's) generator" is "map_raw"
by (rule terminates_map_raw)

lemma unstream_map_trans [stream_fusion]: "unstream (map_trans f g) s = map f (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  show ?case using "1.IH" by (cases "generator g s")(simp_all add: map_trans.rep_eq map_raw_def)
qed

subsubsection ‹@{const drop}›

(* Transformer for drop: the nat component of the state counts how many yielded
   elements still have to be discarded (turned into Skip). *)
fun drop_raw :: "('a, 's) raw_generator ⇒ ('a, (nat × 's)) raw_generator"
where
  "drop_raw g (n, s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (n, s')
   | Yield a s' ⇒ (case n of 0 ⇒ Yield a (0, s') | Suc n ⇒ Skip (n, s')))"

(* drop_raw preserves termination. *)
lemma terminates_drop_raw:
  assumes "terminates g"
  shows "terminates (drop_raw g)"
proof (rule terminatesI)
  fix st :: "nat × 'a"
  obtain n s where "st = (n, s)" by(cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  thus "st ∈ terminates_on (drop_raw g)" unfolding ‹st = (n, s)›
    apply(induction arbitrary: n)
    apply(case_tac [!] n)
    apply(auto intro: terminates_on.intros)
    done
qed

(* Lifted transformer for drop and its fusion equation. *)
lift_definition drop_trans :: "('a, 's) generator ⇒ ('a, nat × 's) generator" is "drop_raw"
by (rule terminates_drop_raw)

lemma unstream_drop_trans [stream_fusion]: "unstream (drop_trans g) (n, s) = drop n (unstream g s)"
proof (induction s arbitrary: n taking: g rule: unstream.induct)
  case (1 s)
  show ?case using "1.IH"(1)[of _ n] "1.IH"(2)[of _ _ n] "1.IH"(2)[of _ _ "n - 1"]
    by(cases "generator g s" "n" rule: step.exhaust[case_product nat.exhaust])
      (simp_all add: drop_trans.rep_eq)
qed

subsubsection ‹@{const dropWhile}›

fun dropWhile_raw :: "('a ⇒ bool) ⇒ ('a, 's) raw_generator ⇒ ('a, bool × 's) raw_generator"
  ― ‹Boolean flag indicates whether we are still in dropping phase›
where
  "dropWhile_raw P g (True, s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (True, s')
   | Yield a s' ⇒ (if P a then Skip (True, s') else Yield a (False, s')))"
| "dropWhile_raw P g (False, s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (False, s') | Yield a s' ⇒ Yield a (False, s'))"

(* dropWhile_raw preserves termination, for either value of the flag. *)
lemma terminates_dropWhile_raw:
  assumes "terminates g"
  shows "terminates (dropWhile_raw P g)"
proof (rule terminatesI)
  fix st :: "bool × 'a"
  obtain b s where "st = (b, s)" by (cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "st ∈ terminates_on (dropWhile_raw P g)" unfolding ‹st = (b, s)›
  proof (induction s arbitrary: b)
    case (stop s b)
    then show ?case by (cases b)(simp_all add: terminates_on.stop)
  next
    case (pause s s' b)
    then show ?case by (cases b)(simp_all add: terminates_on.pause)
  next
    case (unfold s a s' b)
    then show ?case
      by(cases b)(cases "P a", auto intro: terminates_on.pause terminates_on.unfold)
   qed
qed

(* Lifted transformer for dropWhile. *)
lift_definition dropWhile_trans :: "('a ⇒ bool) ⇒ ('a, 's) generator ⇒ ('a, bool × 's) generator"
is "dropWhile_raw" by (rule terminates_dropWhile_raw)

(* With the flag False, the transformer behaves like the underlying generator. *)
lemma unstream_dropWhile_trans_False:
  "unstream (dropWhile_trans P g) (False, s) = unstream g s"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case by (cases "generator g s")(simp_all add: dropWhile_trans.rep_eq)
qed

(* Fusion equation: start with the flag True. *)
lemma unstream_dropWhile_trans [stream_fusion]:
  "unstream (dropWhile_trans P g) (True, s) = dropWhile P (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case
  proof(cases "generator g s")
    case (Yield a s')
    then show ?thesis using "1.IH"(2) unstream_dropWhile_trans_False
      by (cases "P a")(simp_all add: dropWhile_trans.rep_eq)
  qed(simp_all add: dropWhile_trans.rep_eq)
qed

subsubsection ‹@{const take}›

(* Transformer for take: the nat component counts how many elements may still be
   yielded; at 0 the generator is Done without consulting g. *)
fun take_raw :: "('a, 's) raw_generator ⇒ ('a, (nat × 's)) raw_generator"
where
  "take_raw g (0, s) = Done"
| "take_raw g (Suc n, s) = (case g s of 
     Done ⇒ Done | Skip s' ⇒ Skip (Suc n, s') | Yield a s' ⇒ Yield a (n, s'))"

(* take_raw preserves termination. *)
lemma terminates_take_raw:
  assumes "terminates g"
  shows "terminates (take_raw g)"
proof (rule terminatesI)
  fix st :: "nat × 'a"
  obtain n s where "st = (n, s)" by(cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  thus "st ∈ terminates_on (take_raw g)" unfolding ‹st = (n, s)›
    apply(induction s arbitrary: n)
    apply(case_tac [!] n)
    apply(auto intro: terminates_on.intros)
    done
qed

(* Lifted transformer for take and its fusion equation. *)
lift_definition take_trans :: "('a, 's) generator ⇒ ('a, nat × 's) generator" is "take_raw"
by (rule terminates_take_raw)

lemma unstream_take_trans [stream_fusion]: "unstream (take_trans g) (n, s) = take n (unstream g s)" 
proof (induction s arbitrary: n taking: g rule: unstream.induct)
  case (1 s)
  show ?case using "1.IH"(1)[of _ n] "1.IH"(2)
    by(cases "generator g s" n rule: step.exhaust[case_product nat.exhaust])
      (simp_all add: take_trans.rep_eq)
qed

subsubsection ‹@{const takeWhile}›

(* Transformer for takeWhile: becomes Done at the first yielded element that
   violates P. *)
definition takeWhile_raw :: "('a ⇒ bool) ⇒ ('a, 's) raw_generator ⇒ ('a, 's) raw_generator"
where
  "takeWhile_raw P g s = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip s' | Yield a s' ⇒ if P a then Yield a s' else Done)"

(* takeWhile_raw preserves termination. *)
lemma terminates_takeWhile_raw: 
  assumes "terminates g"
  shows "terminates (takeWhile_raw P g)"
proof (rule terminatesI)
  fix s
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  thus "s ∈ terminates_on (takeWhile_raw P g)"
  proof (induction s rule: terminates_on.induct)
    case (unfold s a s')
    then show ?case by(cases "P a")(auto simp add: takeWhile_raw_def intro: terminates_on.intros)
  qed(auto intro: terminates_on.intros simp add: takeWhile_raw_def)
qed

(* Lifted transformer for takeWhile and its fusion equation. *)
lift_definition takeWhile_trans :: "('a ⇒ bool) ⇒ ('a, 's) generator ⇒ ('a, 's) generator"
is "takeWhile_raw" by (rule terminates_takeWhile_raw)

lemma unstream_takeWhile_trans [stream_fusion]:
  "unstream (takeWhile_trans P g) s = takeWhile P (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case by(cases "generator g s")(simp_all add: takeWhile_trans.rep_eq takeWhile_raw_def)
qed

subsubsection ‹@{const append}›

(* Transformer for append: Inl states run the first generator g; when g is Done,
   switch to the second generator h at its start state sh_start (Inr states). *)
fun append_raw :: "('a, 'sg) raw_generator ⇒ ('a, 'sh) raw_generator ⇒ 'sh ⇒ ('a, 'sg + 'sh) raw_generator"
where
  "append_raw g h sh_start (Inl sg) = (case g sg of
     Done ⇒ Skip (Inr sh_start) | Skip sg' ⇒ Skip (Inl sg') | Yield a sg' ⇒ Yield a (Inl sg'))"
| "append_raw g h sh_start (Inr sh) = (case h sh of
     Done ⇒ Done | Skip sh' ⇒ Skip (Inr sh') | Yield a sh' ⇒ Yield a (Inr sh'))"

(* Once in the second generator (Inr), append_raw terminates if h does. *)
lemma terminates_on_append_raw_Inr: 
  assumes "terminates h"
  shows "Inr sh ∈ terminates_on (append_raw g h sh_start)"
proof -
  from assms have "sh ∈ terminates_on h" by (simp add: terminates_def)
  thus ?thesis by(induction sh)(auto intro: terminates_on.intros)
qed

(* append_raw terminates if both underlying generators do. *)
lemma terminates_append_raw:
  assumes "terminates g" "terminates h"
  shows "terminates (append_raw g h sh_start)"
proof (rule terminatesI)
  fix s
  show "s ∈ terminates_on (append_raw g h sh_start)"
  proof (cases s)
    case (Inl sg)
    from ‹terminates g› have "sg ∈ terminates_on g" by (simp add: terminates_def)
    thus "s ∈ terminates_on (append_raw g h sh_start)" unfolding Inl
      by induction(auto intro: terminates_on.intros terminates_on_append_raw_Inr[OF ‹terminates h›])
  qed(simp add: terminates_on_append_raw_Inr[OF ‹terminates h›])
qed

(* Lifted transformer for append. *)
lift_definition append_trans :: "('a, 'sg) generator ⇒ ('a, 'sh) generator ⇒ 'sh ⇒ ('a, 'sg + 'sh) generator"
is "append_raw" by (rule terminates_append_raw)

(* In an Inr state the transformer behaves like the second generator. *)
lemma unstream_append_trans_Inr: "unstream (append_trans g h sh) (Inr sh') = unstream h sh'"
proof (induction sh' taking: h rule: unstream.induct)
  case (1 sh')
  then show ?case by (cases "generator h sh'")(simp_all add: append_trans.rep_eq)
qed

(* Fusion equation: start in an Inl state. *)
lemma unstream_append_trans [stream_fusion]:
  "unstream (append_trans g h sh) (Inl sg) = append (unstream g sg) (unstream h sh)"
proof(induction sg taking: g rule: unstream.induct)
  case (1 sg)
  then show ?case using unstream_append_trans_Inr 
    by (cases "generator g sg")(simp_all add: append_trans.rep_eq)
qed

subsubsection ‹@{const filter}›

(* Transformer for filter: yielded elements that violate P become Skip steps. *)
definition filter_raw :: "('a ⇒ bool) ⇒ ('a, 's) raw_generator ⇒ ('a, 's) raw_generator"
where 
  "filter_raw P g s = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip s' | Yield a s' ⇒ if P a then Yield a s' else Skip s')"

(* filter_raw preserves termination. *)
lemma terminates_filter_raw:
  assumes "terminates g"
  shows "terminates (filter_raw P g)"
proof (rule terminatesI)
  fix s
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  thus "s ∈ terminates_on (filter_raw P g)"
  proof(induction s)
    case (unfold s a s')
    thus ?case
      by(cases "P a")(auto intro: terminates_on.intros simp add: filter_raw_def)
  qed(auto intro: terminates_on.intros simp add: filter_raw_def)
qed

(* Lifted transformer for filter and its fusion equation. *)
lift_definition filter_trans :: "('a ⇒ bool) ⇒ ('a,'s) generator ⇒ ('a,'s) generator"
is "filter_raw" by (rule terminates_filter_raw)

lemma unstream_filter_trans [stream_fusion]: "unstream (filter_trans P g) s = filter P (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case by(cases "generator g s")(simp_all add: filter_trans.rep_eq filter_raw_def)
qed

subsubsection ‹@{const zip}›

fun zip_raw :: "('a, 'sg) raw_generator ⇒ ('b, 'sh) raw_generator ⇒ ('a × 'b, 'sg × 'sh × 'a option) raw_generator"
  ― ‹We search first the left list for the next element and cache it in the @{typ "'a option"}
        part of the state once we found one›
where
  "zip_raw g h (sg, sh, None) = (case g sg of
      Done ⇒ Done | Skip sg' ⇒ Skip (sg', sh, None) | Yield a sg' ⇒ Skip (sg', sh, Some a))"
| "zip_raw g h (sg, sh, Some a) = (case h sh of
      Done ⇒ Done | Skip sh' ⇒ Skip (sg, sh', Some a) | Yield b sh' ⇒ Yield (a, b) (sg, sh', None))"

(* zip_raw terminates if both underlying generators do; the proof distinguishes
   whether an element of the left generator is currently cached. *)
lemma terminates_zip_raw: 
  assumes "terminates g" "terminates h"
  shows "terminates (zip_raw g h)"
proof (rule terminatesI)
  fix s :: "'a × 'c × 'b option"
  obtain sg sh m where "s = (sg, sh, m)" by(cases s)
  show "s ∈ terminates_on (zip_raw g h)" 
  proof(cases m)
    case None
    from ‹terminates g› have "sg ∈ terminates_on g" by (simp add: terminates_def)
    then show ?thesis unfolding ‹s = (sg, sh, m)› None
    proof (induction sg arbitrary: sh)
      case (unfold sg a sg')
      from ‹terminates h› have "sh ∈ terminates_on h" by (simp add: terminates_def)
      hence "(sg', sh, Some a) ∈ terminates_on (zip_raw g h)"
        by induction(auto intro: terminates_on.intros unfold.IH)
      thus ?case using unfold.hyps by(auto intro: terminates_on.pause)
    qed(simp_all add: terminates_on.stop terminates_on.pause)
  next
    case (Some a')
    from ‹terminates h› have "sh ∈ terminates_on h" by (simp add: terminates_def)
    thus ?thesis unfolding ‹s = (sg, sh, m)› Some
    proof (induction sh arbitrary: sg a')
      case (unfold sh b sh')
      from ‹terminates g› have "sg ∈ terminates_on g" by (simp add: terminates_def)
      hence "(sg, sh', None) ∈ terminates_on (zip_raw g h)"
        by induction(auto intro: terminates_on.intros unfold.IH)
      thus ?case using unfold.hyps by(auto intro: terminates_on.unfold)
    qed(simp_all add: terminates_on.stop terminates_on.pause)
  qed
qed

(* Lifted transformer for zip and its fusion equation (start with an empty cache). *)
lift_definition zip_trans :: "('a, 'sg) generator ⇒ ('b, 'sh) generator ⇒ ('a × 'b,'sg × 'sh × 'a option) generator"
is "zip_raw" by (rule terminates_zip_raw)

lemma unstream_zip_trans [stream_fusion]:
  "unstream (zip_trans g h) (sg, sh, None) = zip (unstream g sg) (unstream h sh)"        
proof (induction sg arbitrary: sh taking: g rule: unstream.induct)
  case (1 sg)
  then show ?case
  proof (cases "generator g sg")
    case (Yield a sg')
    note IH = "1.IH"(2)[OF Yield]
    have "unstream (zip_trans g h) (sg', sh, Some a) = zip (a # (unstream g sg')) (unstream h sh)"
    proof(induction sh taking: h rule: unstream.induct)
      case (1 sh)
      then show ?case using IH by(cases "generator h sh")(simp_all add: zip_trans.rep_eq)
    qed
    then show ?thesis using Yield by (simp add: zip_trans.rep_eq)
  qed(simp_all add: zip_trans.rep_eq)
qed

subsubsection ‹@{const tl}›

fun tl_raw :: "('a, 'sg) raw_generator ⇒ ('a, bool × 'sg) raw_generator"
  ― ‹The Boolean flag stores whether we have already skipped the first element›
where
  "tl_raw g (False, sg) = (case g sg of
      Done ⇒ Done | Skip sg' ⇒ Skip (False, sg') | Yield a sg' ⇒ Skip (True,sg'))"
| "tl_raw g (True, sg) = (case g sg of
      Done ⇒ Done | Skip sg' ⇒ Skip (True, sg') | Yield a sg' ⇒ Yield a (True, sg'))"

(* tl_raw preserves termination: first for states with flag True, then for
   states with flag False using the former. *)
lemma terminates_tl_raw: 
  assumes "terminates g"
  shows "terminates (tl_raw g)"
proof (rule terminatesI)
  fix s :: "bool × 'a"
  obtain b sg where "s = (b, sg)" by(cases s)
  { fix sg
    from assms have "sg ∈ terminates_on g" by(simp add: terminates_def)
    hence "(True, sg) ∈ terminates_on (tl_raw g)"
      by(induction sg)(auto intro: terminates_on.intros) }
  moreover from assms have "sg ∈ terminates_on g" by(simp add: terminates_def)
  hence "(False, sg) ∈ terminates_on (tl_raw g)"
    by(induction sg)(auto intro: terminates_on.intros calculation)
  ultimately show "s ∈ terminates_on (tl_raw g)" using ‹s = (b, sg)›
    by(cases b) simp_all
qed

(* Lifted transformer for tl. *)
lift_definition tl_trans :: "('a, 'sg) generator ⇒ ('a, bool × 'sg) generator"
is "tl_raw" by(rule terminates_tl_raw)

(* With the flag True, the transformer behaves like the underlying generator. *)
lemma unstream_tl_trans_True: "unstream (tl_trans g) (True, s) = unstream g s"
proof(induction s taking: g rule: unstream.induct)
  case (1 s)
  show ?case using "1.IH" by (cases "generator g s")(simp_all add: tl_trans.rep_eq)
qed

(* Fusion equation: start with the flag False. *)
lemma unstream_tl_trans [stream_fusion]: "unstream (tl_trans g) (False, s) = tl (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case using unstream_tl_trans_True
    by (cases "generator g s")(simp_all add: tl_trans.rep_eq)
qed

subsubsection ‹@{const butlast}›

fun butlast_raw :: "('a, 's) raw_generator ⇒ ('a, 'a option × 's) raw_generator"
  ― ‹The @{typ "'a option"} caches the previous element we have seen›
where
  "butlast_raw g (None,s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (None, s') | Yield a s' ⇒ Skip (Some a, s'))"
| "butlast_raw g (Some b, s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (Some b, s') | Yield a s' ⇒ Yield b (Some a, s'))"

(* butlast_raw preserves termination. *)
lemma terminates_butlast_raw:
  assumes "terminates g"
  shows "terminates (butlast_raw g)"
proof (rule terminatesI)
  fix st :: "'b option × 'a"
  obtain ma s where "st = (ma,s)" by (cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "st ∈ terminates_on (butlast_raw g)" unfolding ‹st = (ma, s)›
    apply(induction s arbitrary: ma)
    apply(case_tac [!] ma)
    apply(auto intro: terminates_on.intros)
    done
qed

(* Lifts butlast_raw to the type of generators; the proof obligation of the
   lifting is discharged by terminates_butlast_raw. *)
lift_definition butlast_trans :: "('a,'s) generator ⇒ ('a, 'a option × 's) generator"
is "butlast_raw" by (rule terminates_butlast_raw)

(* With b cached, butlast_trans g produces butlast of b followed by the
   list of g. *)
lemma unstream_butlast_trans_Some:
  "unstream (butlast_trans g) (Some b,s) = butlast (b # (unstream g s))"
proof (induction s arbitrary: b taking: g rule: unstream.induct)
  case (1 s)
  (* The cached element b is generalised because a Yield replaces it. *)
  from 1 show ?case
    by (cases "generator g s") (simp_all add: butlast_trans.rep_eq)
qed

(* Fusion rule for butlast: started with an empty cache, butlast_trans g
   produces butlast of the list of g. *)
lemma unstream_butlast_trans [stream_fusion]:
  "unstream (butlast_trans g) (None, s) = butlast (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  (* The facts of case 1 are chained once via "then"; the Yield case moves
     to a state with a filled cache, covered by unstream_butlast_trans_Some. *)
  then show ?case using unstream_butlast_trans_Some[of g]
    by (cases "generator g s")(simp_all add: butlast_trans.rep_eq)
qed

subsubsection ‹@{const concat}›

text ‹
  We only do the easy version here where
  the generator has type @{typ "('a list,'s) generator"}, not @{typ "(('a, 'si) generator, 's) generator"}
›

(* Raw generator transformer for concat.
   The state pairs the list of elements still to be emitted with the state
   of g:
   - ([], s): the buffer is empty, so g is queried; a yielded list xs becomes
     the new buffer (via Skip), Skip and Done are forwarded.
   - (x # xs, s): the head of the buffer is emitted without querying g. *)
fun concat_raw :: "('a list, 's) raw_generator ⇒ ('a, 'a list × 's) raw_generator"
where
  "concat_raw g ([], s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip ([], s') | Yield xs s' ⇒ Skip (xs, s'))"
| "concat_raw g (x # xs, s) = Yield x (xs, s)"

(* concat_raw preserves termination: if g terminates from every state,
   so does concat_raw g. *)
lemma terminates_concat_raw: 
  assumes "terminates g"
  shows "terminates (concat_raw g)"
proof (rule terminatesI)
  fix st :: "'b list × 'a"
  obtain xs s where "st = (xs, s)" by (cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  (* Outer induction over the derivation of s ∈ terminates_on g with the
     buffer xs generalised; in every case an inner induction on the buffer
     emits its elements one by one before g is queried again. *)
  then show "st ∈ terminates_on (concat_raw g)" unfolding ‹st = (xs, s)›
  proof (induction s arbitrary: xs)
    case (stop s xs)
    then show ?case by (induction xs)(auto intro: terminates_on.stop terminates_on.unfold)
  next
    case (pause s s' xs)
    then show ?case by (induction xs)(auto intro: terminates_on.pause terminates_on.unfold)
  next
    case (unfold s a s' xs)
    then show ?case by (induction xs)(auto intro: terminates_on.pause terminates_on.unfold)
  qed
qed

(* Lifts concat_raw to the type of generators; the proof obligation of the
   lifting is discharged by terminates_concat_raw. *)
lift_definition concat_trans :: "('a list, 's) generator ⇒ ('a, 'a list × 's) generator"
is "concat_raw" by (rule terminates_concat_raw)

(* Generalised correctness statement for concat_trans: with buffer xs, the
   result is xs followed by the concatenation of the lists produced by g. *)
lemma unstream_concat_trans_gen:
  "unstream (concat_trans g) (xs, s) = xs @ (concat (unstream g s))"
proof (induction s arbitrary: xs taking: g rule: unstream.induct)
  case (1 s)
  then show ?case
  proof (cases "generator g s")
    case Done
    (* Inner induction on the buffer: it is emitted element by element. *)
    then show ?thesis by (induction xs) (simp_all add: concat_trans.rep_eq)
  next
    case (Skip s')
    (* After the buffer has been emitted, g moves on with an empty buffer. *)
    then show ?thesis using "1.IH"(1)[of s' "[]"]
      by (induction xs) (simp_all add: concat_trans.rep_eq)
  next
    case (Yield a s')
    (* After the buffer has been emitted, the yielded list a becomes the
       new buffer. *)
    then show ?thesis using "1.IH"(2)[of a s' a]
      by (induction xs) (simp_all add: concat_trans.rep_eq)
  qed
qed

(* Fusion rule for concat: the instance of unstream_concat_trans_gen for
   the empty buffer. *)
lemma unstream_concat_trans [stream_fusion]:
  "unstream (concat_trans g) ([], s) = concat (unstream g s)"
by (subst unstream_concat_trans_gen) (rule append_Nil)

subsubsection ‹@{const splice}›

(* State of the splice transformer: two constructors carrying both
   component states and two carrying only one of them. *)
datatype ('a, 'b) splice_state
  = Left 'a 'b
  | Right 'a 'b
  | Left_only 'a
  | Right_only 'b

(* Raw generator transformer for splice, alternating between g and h.
   - Left sg sh: query g; a yielded element is emitted and the turn passes
     to h (Right); if g is Done, only h remains (Right_only).
   - Right sg sh: symmetric, querying h; if h is Done, only g remains
     (Left_only).
   - Left_only sg / Right_only sh: forward the steps of the one remaining
     generator.
   A Skip never changes whose turn it is. *)
fun splice_raw :: "('a, 'sg) raw_generator ⇒ ('a, 'sh) raw_generator ⇒ ('a, ('sg, 'sh) splice_state) raw_generator"
where
  "splice_raw g h (Left_only sg) = (case g sg of
     Done ⇒ Done | Skip sg' ⇒ Skip (Left_only sg') | Yield a sg' ⇒ Yield a (Left_only sg'))"
| "splice_raw g h (Left sg sh) = (case g sg of
     Done ⇒ Skip (Right_only sh) | Skip sg' ⇒ Skip (Left sg' sh) | Yield a sg' ⇒ Yield a (Right sg' sh))"
| "splice_raw g h (Right_only sh) = (case h sh of
     Done ⇒ Done | Skip sh' ⇒ Skip (Right_only sh') | Yield a sh' ⇒ Yield a (Right_only sh'))"
| "splice_raw g h (Right sg sh) = (case h sh of
     Done ⇒ Skip (Left_only sg) | Skip sh' ⇒ Skip (Right sg sh') | Yield a sh' ⇒ Yield a (Left sg sh'))"

(* splice_raw preserves termination: if both g and h terminate from every
   state, so does splice_raw g h. *)
lemma terminates_splice_raw: 
  assumes g: "terminates g" and h: "terminates h"
  shows "terminates (splice_raw g h)"
proof (rule terminatesI)
  fix s
  (* One block per constructor of the state; the facts are accumulated with
     "moreover", and later blocks use the earlier ones via "calculation". *)
  { fix sg
    from g have "sg ∈ terminates_on g" by (simp add: terminates_def)
    hence "Left_only sg ∈ terminates_on (splice_raw g h)"
      by induction(auto intro: terminates_on.intros)
  } moreover {
    fix sh
    from h have "sh ∈ terminates_on h" by (simp add: terminates_def)
    hence "Right_only sh ∈ terminates_on (splice_raw g h)"
      by induction(auto intro: terminates_on.intros)
  } moreover {
    fix sg sh
    from g have "sg ∈ terminates_on g" by (simp add: terminates_def)
    hence "Left sg sh ∈ terminates_on (splice_raw g h)"
    proof (induction sg arbitrary: sh)
      case (unfold sg a sg')
      (* After g yields, the turn passes to h: a nested induction on h's
         termination shows that Right sg' sh terminates. *)
      from h have "sh ∈ terminates_on h" by (simp add: terminates_def)
      hence "Right sg' sh ∈ terminates_on (splice_raw g h)"
        by induction(auto intro: terminates_on.intros unfold.IH calculation)
      thus ?case using unfold.hyps by (auto intro: terminates_on.unfold)
    qed(auto intro: terminates_on.intros calculation)
  } moreover {
    fix sg sh
    from h have "sh ∈ terminates_on h" by (simp add: terminates_def)
    hence "Right sg sh ∈ terminates_on (splice_raw g h)"
      by(induction sh arbitrary: sg)(auto intro: terminates_on.intros calculation) }
  ultimately show "s ∈ terminates_on (splice_raw g h)" by(cases s)(simp_all)
qed

(* Lifts splice_raw to the type of generators; the proof obligation of the
   lifting is discharged by terminates_splice_raw. *)
lift_definition splice_trans :: "('a, 'sg) generator ⇒ ('a, 'sh) generator ⇒ ('a, ('sg, 'sh) splice_state) generator"
is "splice_raw" by (rule terminates_splice_raw)

(* In state Right_only, splice_trans g h produces exactly the list of h. *)
lemma unstream_splice_trans_Right_only:
  "unstream (splice_trans g h) (Right_only sh) = unstream h sh"
proof (induction sh taking: h rule: unstream.induct)
  case (1 sh)
  from 1 show ?case
    by (cases "generator h sh") (simp_all add: splice_trans.rep_eq)
qed

(* In state Left_only, splice_trans g h produces exactly the list of g. *)
lemma unstream_splice_trans_Left_only:
  "unstream (splice_trans g h) (Left_only sg) = unstream g sg"
proof (induction sg taking: g rule: unstream.induct)
  case (1 sg)
  from 1 show ?case
    by (cases "generator g sg") (simp_all add: splice_trans.rep_eq)
qed

(* Fusion rule for splice: started in state Left, splice_trans g h produces
   the splice of the lists produced by g and h. *)
lemma unstream_splice_trans [stream_fusion]:
  "unstream (splice_trans g h) (Left sg sh) = splice (unstream g sg) (unstream h sh)"
proof (induction sg arbitrary: sh taking: g rule: unstream.induct)
  case (1 sg sh)
  then show ?case
  proof (cases "generator g sg")
    case Done
    (* g is exhausted: the transformer continues in state Right_only. *)
    with unstream_splice_trans_Right_only[of g h]
    show ?thesis by (simp add: splice_trans.rep_eq)
  next
    case (Skip sg')
    then show ?thesis using "1.IH"(1) by (simp add: splice_trans.rep_eq)
  next
    case (Yield a sg')
    note IH = "1.IH"(2)[OF Yield]

    (* g has yielded a, so the turn passes to h; the state Right sg' sh is
       handled by a nested induction over the unfolding of h. *)
    have "a # (unstream (splice_trans g h) (Right sg' sh)) = splice (unstream g sg) (unstream h sh)"
    proof (induction sh taking: h rule: unstream.induct)
      case (1 sh)
      show ?case
      proof (cases "generator h sh")
        case Done
        with unstream_splice_trans_Left_only[of g h sg']
        show ?thesis using Yield by (simp add: splice_trans.rep_eq)
      next
        case (Skip sh')
        then show ?thesis using Yield "1.IH"(1) "1.prems" by(simp add: splice_trans.rep_eq)
      next
        case (Yield b sh')
        (* The inner case name Yield shadows the outer one, so the outer
           fact is referred to literally. *)
        then show ?thesis using IH ‹generator g sg = Yield a sg'›
          by (simp add: splice_trans.rep_eq)
      qed
    qed
    then show ?thesis using Yield by (simp add: splice_trans.rep_eq)
  qed
qed


subsubsection ‹@{const list_update}›

(* Raw generator transformer for list_update.
   The state pairs a counter n with the state of g.  On a Yield of g:
   - n = 0: the element is passed through and the counter stays 0;
   - n = 1: the element is replaced by b and the counter becomes 0;
   - otherwise the element is passed through and the counter is decremented.
   Skip keeps the counter unchanged; Done ends the stream. *)
fun list_update_raw :: "('a,'s) raw_generator ⇒ 'a ⇒ ('a, nat × 's) raw_generator"
where
  "list_update_raw g b (n, s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (n, s') 
   | Yield a s' ⇒ if n = 0 then Yield a (0,s')
                   else if n = 1 then Yield b (0, s')
                   else Yield a (n - 1, s'))"

(* list_update_raw preserves termination: if g terminates from every state,
   so does list_update_raw g b. *)
lemma terminates_list_update_raw:
  assumes "terminates g"
  shows "terminates (list_update_raw g b)"
proof (rule terminatesI)
  fix st :: "nat × 'a"
  obtain n s where "st = (n, s)" by (cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  (* Induction over the derivation of s ∈ terminates_on g with the counter
     generalised, because a Yield may change it. *)
  then show "st ∈ terminates_on (list_update_raw g b)" unfolding ‹st = (n, s)›
  proof (induction s arbitrary: n)
    case (unfold s a s' n)
    (* Case split on the counter values that list_update_raw treats
       specially. *)
    then show "(n, s) ∈ terminates_on (list_update_raw g b)"
      by(cases "n = 0 ∨ n = 1")(auto intro: terminates_on.unfold)
  qed(simp_all add: terminates_on.stop terminates_on.pause)
qed

(* Lifts list_update_raw to the type of generators; the proof obligation of
   the lifting is discharged by terminates_list_update_raw. *)
lift_definition list_update_trans :: "('a,'s) generator ⇒ 'a ⇒ ('a, nat × 's) generator"
is "list_update_raw" by (rule terminates_list_update_raw)

(* With counter 0, list_update_trans g b produces exactly the list of g. *)
lemma unstream_lift_update_trans_None:
  "unstream (list_update_trans g b) (0, s) = unstream g s"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  from 1 show ?case
    by (cases "generator g s") (simp_all add: list_update_trans.rep_eq)
qed

(* Fusion rule for list_update: started with counter Suc n, the element at
   index n of the list produced by g is replaced by b. *)
lemma unstream_list_update_trans [stream_fusion]:
  "unstream (list_update_trans g b) (Suc n, s) = list_update (unstream g s) n b"
proof (induction s arbitrary: n taking: g rule: unstream.induct)
  case (1 s)
  thus ?case
  proof (cases "generator g s")
    case Done
    thus ?thesis by (simp add: list_update_trans.rep_eq)
  next
    case (Skip s')
    thus ?thesis using "1.IH"(1) by (simp add: list_update_trans.rep_eq)
  next
    case (Yield a s')
    (* For n = 0 the current element is replaced and the counter drops to 0,
       where unstream_lift_update_trans_None applies; for n = Suc m the
       induction hypothesis is used. *)
    thus ?thesis using unstream_lift_update_trans_None[of g b s'] "1.IH"(2)
      by (cases n) (simp_all add: list_update_trans.rep_eq)
  qed
qed

subsubsection ‹@{const removeAll}›

(* Raw generator transformer for removeAll: a yielded element equal to b is
   turned into a Skip; all other steps of g are forwarded unchanged.  The
   state type of g is kept as is. *)
definition removeAll_raw :: "'a ⇒ ('a, 's) raw_generator ⇒ ('a, 's) raw_generator"
where
 "removeAll_raw b g s = (case g s of
    Done ⇒ Done | Skip s' ⇒ Skip s' | Yield a s' ⇒ if a = b then Skip s' else Yield a s')"

(* removeAll_raw preserves termination: if g terminates from every state,
   so does removeAll_raw b g. *)
lemma terminates_removeAll_raw:
  assumes "terminates g"
  shows "terminates (removeAll_raw b g)"
proof (rule terminatesI)
  fix s
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  then show "s ∈ terminates_on (removeAll_raw b g)"
  proof(induction s)
    case (unfold s a s')
    (* A Yield of g becomes a Skip or stays a Yield, depending on a = b. *)
    then show ?case
      by(cases "a = b")(auto intro: terminates_on.intros simp add: removeAll_raw_def)
  qed(auto intro: terminates_on.intros simp add: removeAll_raw_def)
qed

(* Lifts removeAll_raw to the type of generators; the proof obligation of
   the lifting is discharged by terminates_removeAll_raw. *)
lift_definition removeAll_trans :: "'a ⇒ ('a, 's) generator ⇒ ('a, 's) generator"
is "removeAll_raw" by (rule terminates_removeAll_raw)

(* Fusion rule for removeAll. *)
lemma unstream_removeAll_trans [stream_fusion]:
  "unstream (removeAll_trans b g) s = removeAll b (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  thus ?case
  proof (cases "generator g s")
    case (Yield a s')
    (* The yielded element is dropped exactly when it equals b. *)
    thus ?thesis using "1.IH"(2)
      by (cases "a = b") (simp_all add: removeAll_trans.rep_eq removeAll_raw_def)
  qed (auto simp add: removeAll_trans.rep_eq removeAll_raw_def)
qed

subsubsection ‹@{const remove1}›

(* Raw generator transformer for remove1.
   The state pairs a Boolean flag b with the state of g.  A yielded element
   y is dropped (turned into a Skip) only if the flag is set and x = y; the
   flag is then cleared, so no later element is dropped.  Every other step
   of g is forwarded with the flag unchanged. *)
fun remove1_raw :: "'a ⇒ ('a, 's) raw_generator ⇒ ('a, bool × 's) raw_generator"
where
  "remove1_raw x g (b, s) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (b, s') 
   | Yield y s' ⇒ if b ∧ x = y then Skip (False, s') else Yield y (b, s'))"

(* remove1_raw preserves termination: if g terminates from every state,
   so does remove1_raw b g. *)
lemma terminates_remove1_raw: 
  assumes "terminates g"
  shows "terminates (remove1_raw b g)"
proof (rule terminatesI)
  fix st :: "bool × 'a"
  obtain c s where "st = (c, s)" by (cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  (* Induction over the derivation of s ∈ terminates_on g with the flag
     generalised, because a Yield may clear it. *)
  then show "st ∈ terminates_on (remove1_raw b g)" unfolding ‹st = (c, s)›
  proof (induction s arbitrary: c)
    case (stop s)
    then show ?case by (cases c)(simp_all add: terminates_on.stop)
  next
    case (pause s s')
    then show ?case by (cases c)(simp_all add: terminates_on.pause)
  next
    case (unfold s a s')
    then show ?case
      by(cases c)(cases "a = b", auto intro: terminates_on.intros)
   qed
qed

(* Lifts remove1_raw to the type of generators; the proof obligation of the
   lifting is discharged by terminates_remove1_raw. *)
lift_definition remove1_trans :: "'a ⇒ ('a, 's) generator ⇒ ('a, bool × 's) generator"
is "remove1_raw" by (rule terminates_remove1_raw)

(* With the flag cleared, remove1_trans b g produces exactly the list
   of g. *)
lemma unstream_remove1_trans_False:
  "unstream (remove1_trans b g) (False, s) = unstream g s"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  from 1 show ?case
    by (cases "generator g s") (simp_all add: remove1_trans.rep_eq)
qed

(* Fusion rule for remove1: started with the flag set, the first element
   equal to b is removed from the list produced by g. *)
lemma unstream_remove1_trans [stream_fusion]:
  "unstream (remove1_trans b g) (True, s) = remove1 b (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  thus ?case
  proof (cases "generator g s")
    case (Yield a s')
    (* If a = b, the flag is cleared and unstream_remove1_trans_False
       applies; otherwise the induction hypothesis does. *)
    thus ?thesis
      using "1.IH"(2) unstream_remove1_trans_False[of b g]
      by (cases "a = b") (simp_all add: remove1_trans.rep_eq)
  qed (simp_all add: remove1_trans.rep_eq)
qed

subsubsection ‹@{term "(#)"}›

(* Raw generator transformer for prepending an element.
   The state pairs a Boolean flag b with the state of g.  While the flag is
   set, x is emitted without querying g and the flag is cleared; afterwards
   every step of g is forwarded with the flag kept at False. *)
fun Cons_raw :: "'a ⇒ ('a, 's) raw_generator ⇒ ('a, bool × 's) raw_generator"
where
  "Cons_raw x g (b, s) = (if b then Yield x (False, s) else case g s of
    Done ⇒ Done | Skip s' ⇒ Skip (False, s') | Yield y s' ⇒ Yield y (False, s'))"

(* Cons_raw preserves termination: if g terminates from every state,
   so does Cons_raw x g. *)
lemma terminates_Cons_raw: 
  assumes "terminates g"
  shows "terminates (Cons_raw x g)"
proof (rule terminatesI)
  fix st :: "bool × 'a"
  obtain b s where "st = (b, s)" by (cases st)
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  (* First the states with a cleared flag; a state with the flag set takes
     one Yield step into such a state. *)
  hence "(False, s) ∈ terminates_on (Cons_raw x g)"
    by(induction s arbitrary: b)(auto intro: terminates_on.intros)
  then show "st ∈ terminates_on (Cons_raw x g)" unfolding ‹st = (b, s)›
    by(cases b)(auto intro: terminates_on.intros)
qed

(* Lifts Cons_raw to the type of generators; the proof obligation of the
   lifting is discharged by terminates_Cons_raw. *)
lift_definition Cons_trans :: "'a ⇒ ('a, 's) generator ⇒ ('a, bool × 's) generator"
is Cons_raw by(rule terminates_Cons_raw)

(* With the flag cleared, Cons_trans x g produces exactly the list of g. *)
lemma unstream_Cons_trans_False:
  "unstream (Cons_trans x g) (False, s) = unstream g s"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  from 1 show ?case
    by (cases "generator g s") (auto simp add: Cons_trans.rep_eq)
qed

text ‹
  We do not declare @{const Cons_trans} as a transformer.
  Otherwise, literal lists would be transformed into streams which adds a significant overhead
  to the stream state.
›
(* Started with the flag set, Cons_trans x g produces x followed by the
   list of g. *)
lemma unstream_Cons_trans:
  "unstream (Cons_trans x g) (True, s) = x # unstream g s"
proof -
  (* After emitting x the transformer is in the state with a cleared flag. *)
  have "unstream (Cons_trans x g) (False, s) = unstream g s"
    by (rule unstream_Cons_trans_False)
  then show ?thesis by (simp add: Cons_trans.rep_eq)
qed

subsubsection ‹@{const List.maps}›

text ‹Stream version based on Coutts @{cite "Coutts2010PhD"}.›

text ‹
  We restrict the function for generating the inner lists to terminating
  generators because the code generator does not directly support nesting abstract
  datatypes in other types.
›

(* Raw generator transformer for List.maps.
   The state pairs the state of the outer generator g with an optional
   inner generator together with its state:
   - (s, None): g is queried; for a yielded x the inner generator and start
     state f x are installed (via Skip).
   - (s, Some (g'', s'')): the inner generator is run; its elements are
     emitted, and when it is Done the transformer returns to (s, None). *)
fun maps_raw
  :: "('a ⇒ ('b, 'sg) generator × 'sg) ⇒ ('a, 's) raw_generator
  ⇒ ('b, 's × (('b, 'sg) generator × 'sg) option) raw_generator"
where
  "maps_raw f g (s, None) = (case g s of
    Done ⇒ Done | Skip s' ⇒ Skip (s', None) | Yield x s' ⇒ Skip (s', Some (f x)))"
| "maps_raw f g (s, Some (g'', s'')) = (case generator g'' s'' of
    Done ⇒ Skip (s, None) | Skip s' ⇒ Skip (s, Some (g'', s')) | Yield x s' ⇒ Yield x (s, Some (g'', s')))"

(* If maps_raw f g terminates from (s, None), it also terminates from any
   state (s, Some (g'', s'')) with an inner generator installed. *)
lemma terminates_on_maps_raw_Some: 
  assumes "(s, None) ∈ terminates_on (maps_raw f g)"
  shows "(s, Some (g'', s'')) ∈ terminates_on (maps_raw f g)"
proof -
  (* The fact named generator gives termination of the inner generator;
     the induction then runs over that termination derivation. *)
  from generator[of g''] have "s'' ∈ terminates_on (generator g'')" by (simp add: terminates_def)
  thus ?thesis by(induction)(auto intro: terminates_on.intros assms)
qed

(* maps_raw preserves termination: if the outer raw generator g terminates
   from every state, so does maps_raw f g. *)
lemma terminates_maps_raw: 
  assumes "terminates g"
  shows "terminates (maps_raw f g)"
proof
  fix st :: "'a × (('c, 'd) generator × 'd) option"
  obtain s mgs where "st = (s, mgs)" by(cases st) 
  from assms have "s ∈ terminates_on g" by (simp add: terminates_def)
  (* Induction over the termination derivation of the outer state with the
     optional inner component generalised; states with an inner generator
     are reduced to the None case by terminates_on_maps_raw_Some. *)
  then show "st ∈ terminates_on (maps_raw f g)" unfolding ‹st = (s, mgs)›
    apply(induction arbitrary: mgs)
    apply(case_tac [!] mgs)
    apply(auto intro: terminates_on.intros intro!: terminates_on_maps_raw_Some)
    done
qed

(* Lifts maps_raw to the type of generators; the proof obligation of the
   lifting is discharged by terminates_maps_raw. *)
lift_definition maps_trans :: "('a ⇒ ('b, 'sg) generator × 'sg) ⇒ ('a, 's) generator
  ⇒ ('b, 's × (('b, 'sg) generator × 'sg) option) generator"
is "maps_raw" by(rule terminates_maps_raw)

(* With an inner generator installed, maps_trans f g first produces the
   list of the inner generator and then continues from (s, None). *)
lemma unstream_maps_trans_Some:
  "unstream (maps_trans f g) (s, Some (g'', s'')) = unstream g'' s'' @ unstream (maps_trans f g) (s, None)"
proof (induction s'' taking: g'' rule: unstream.induct)
  case (1 s'')
  from 1 show ?case
    by (cases "generator g'' s''") (simp_all add: maps_trans.rep_eq)
qed

(* Correctness of maps_trans: started without an inner generator, it
   produces List.maps over the list of g, where each element x is mapped to
   the list obtained by unstreaming the pair f x. *)
lemma unstream_maps_trans:
  "unstream (maps_trans f g) (s, None) = List.maps (case_prod unstream ∘ f) (unstream g s)"
proof(induction s taking: g rule: unstream.induct)
  case (1 s)
  thus ?case
  proof(cases "generator g s")
    case (Yield x s')
    (* The yielded x installs the inner generator f x; its contribution is
       split off with unstream_maps_trans_Some. *)
    with "1.IH"(2)[OF this] show ?thesis
      using unstream_maps_trans_Some[of f g _ "fst (f x)" "snd (f x)"]
      by(simp add: maps_trans.rep_eq maps_simps split_def)
  qed(simp_all add: maps_trans.rep_eq maps_simps)
qed

text ‹
  The rule @{thm [source] unstream_maps_trans} is too complicated for fusion because of
  @{const case_prod}, which does not arise naturally from stream fusion rules. Moreover,
  according to Farmer et al. @{cite "FarmerHoenerGill2014PEPM"}, this fusion is too general for
  further optimisations because the generators of the inner list are generated by the outer
  generator and therefore compilers may think that it was not known statically. 

  Instead, they propose a weaker version using ‹flatten› below.
  (More precisely, Coutts already mentions this approach in his PhD thesis @{cite "Coutts2010PhD"},
  but dismisses it because it requires a stronger rewriting engine than GHC has. But Isabelle's
  simplifier language is sufficiently powerful.)
›

(* Pairs the successor state of a step with the fixed value a; Done is left
   unchanged and a yielded element is kept. *)
fun fix_step :: "'a ⇒ ('b, 's) step ⇒ ('b, 'a × 's) step"
where
  "fix_step a Done = Done"
| "fix_step a (Skip s) = Skip (a, s)"
| "fix_step a (Yield x s) = Yield x (a, s)"

(* Turns a family g of raw generators into a single raw generator whose
   state carries the index a: each step runs g a and keeps a in the
   successor state (via fix_step). *)
fun fix_gen_raw :: "('a ⇒ ('b, 's) raw_generator) ⇒ ('b, 'a × 's) raw_generator"
where "fix_gen_raw g (a, s) = fix_step a (g a s)"

(* If every member g x of the family terminates, so does fix_gen_raw g. *)
lemma terminates_fix_gen_raw:
  assumes "⋀x. terminates (g x)"
  shows "terminates (fix_gen_raw g)"
proof
  fix st :: "'a × 'b"
  obtain a s where "st = (a, s)" by(cases st)
  (* The index a never changes, so termination of g a from s suffices. *)
  from assms[of a] have "s ∈ terminates_on (g a)" by (simp add: terminates_def)
  then show "st ∈ terminates_on (fix_gen_raw g)" unfolding ‹st = (a, s)›
    by(induction)(auto intro: terminates_on.intros)
qed

(* Lifts fix_gen_raw to the type of generators; the proof obligation of the
   lifting is discharged by terminates_fix_gen_raw. *)
lift_definition fix_gen :: "('a ⇒ ('b, 's) generator) ⇒ ('b, 'a × 's) generator"
is "fix_gen_raw" by(rule terminates_fix_gen_raw)

(* Running fix_gen g from (a, s) produces the same list as running the
   generator g a from s. *)
lemma unstream_fix_gen:
  "unstream (fix_gen g) (a, s) = unstream (g a) s"
proof (induction s taking: "g a" rule: unstream.induct)
  case (1 s)
  from 1 show ?case
    by (cases "generator (g a) s") (simp_all add: fix_gen.rep_eq)
qed

(* Local context for the flatten transformer:
   f    maps an element of the outer stream to a start state of the inner
        generator,
   g''  is the single inner raw generator,
   g    is the outer raw generator. *)
context 
  fixes f :: "('a ⇒ 's')"
  and g'' :: "('b, 's') raw_generator"
  and g :: "('a, 's) raw_generator"
begin

(* The state pairs the state of g with an optional state of g'':
   - (s, None): g is queried; for a yielded x the inner generator is
     started in state f x (via Skip).
   - (s, Some s''): g'' is run; its elements are emitted, and when it is
     Done the transformer returns to (s, None). *)
fun flatten_raw :: "('b, 's × 's' option) raw_generator"
where
  "flatten_raw (s, None) = (case g s of
     Done ⇒ Done | Skip s' ⇒ Skip (s', None) | Yield x s' ⇒ Skip (s', Some (f x)))"
| "flatten_raw (s, Some s'') = (case g'' s'' of
     Done ⇒ Skip (s, None) | Skip s' ⇒ Skip (s, Some s') | Yield x s' ⇒ Yield x (s, Some s'))"

(* flatten_raw terminates if both the inner and the outer generator do. *)
lemma terminates_flatten_raw: 
  assumes "terminates g''" "terminates g"
  shows "terminates flatten_raw"
proof
  fix st :: "'s × 's' option"
  obtain s ms where "st = (s, ms)" by(cases st)
  (* Auxiliary fact: termination from (s, None) implies termination from
     any (s, Some s''), by induction over the termination of g''. *)
  { fix s s''
    assume s: "(s, None) ∈ terminates_on flatten_raw"
    from ‹terminates g''› have "s'' ∈ terminates_on g''" by (simp add: terminates_def)
    hence "(s, Some s'') ∈ terminates_on flatten_raw"
      by(induction)(auto intro: terminates_on.intros s) }
  note Some = this
  from ‹terminates g› have "s ∈ terminates_on g" by (simp add: terminates_def)
  (* Induction over the termination derivation of the outer state with the
     inner component generalised. *)
  then show "st ∈ terminates_on flatten_raw" unfolding ‹st = (s, ms)›
    apply(induction arbitrary: ms)
    apply(case_tac [!] ms)
    apply(auto intro: terminates_on.intros intro!: Some)
    done
qed

end

(* Lifts flatten_raw to the type of generators; the proof obligation of the
   lifting is discharged by terminates_flatten_raw. *)
lift_definition flatten :: "('a ⇒ 's') ⇒ ('b, 's') generator ⇒ ('a, 's) generator ⇒ ('b, 's × 's' option) generator"
is "flatten_raw" by(fact terminates_flatten_raw)

(* With the inner generator running, flatten first produces the rest of the
   inner list and then continues from (s, None). *)
lemma unstream_flatten_Some:
  "unstream (flatten f g'' g) (s, Some s') = unstream g'' s' @ unstream (flatten f g'' g) (s, None)"
proof (induction s' taking: g'' rule: unstream.induct)
  case (1 s')
  from 1 show ?case
    by (cases "generator g'' s'") (simp_all add: flatten.rep_eq)
qed

text ‹HO rewrite equations can express the variable capture in the generator unlike GHC rules›

(* Fusion rule for List.maps where the inner generator g'' s' may depend on
   the element s' of the outer stream: the element is kept in the inner
   state by fix_gen, next to the start state f s'. *)
lemma unstream_flatten_fix_gen [stream_fusion]:
  "unstream (flatten (λs. (s, f s)) (fix_gen g'') g) (s, None) =
   List.maps (λs'. unstream (g'' s') (f s')) (unstream g s)"
proof(induction s taking: g rule: unstream.induct)
  case (1 s)
  thus ?case
  proof(cases "generator g s")
    case (Yield x s')
    (* unstream is unfolded only at the selected occurrences; the inner run
       is split off with unstream_flatten_Some and then related to g'' x by
       unstream_fix_gen. *)
    with "1.IH"(2)[OF this] unstream_flatten_Some[of "λs. (s, f s)" "fix_gen g''" g]
    show ?thesis
      by(subst (1 3) unstream.simps)(simp add: flatten.rep_eq maps_simps unstream_fix_gen)
  qed(simp_all add: flatten.rep_eq maps_simps)
qed

text ‹
  Separate fusion rule when the inner generator does not depend on the elements of the outer stream.
›
(* Fusion rule for List.maps with a fixed inner generator g'': only the
   start state f s' depends on the element s' of the outer stream. *)
lemma unstream_flatten [stream_fusion]:
  "unstream (flatten f g'' g) (s, None) = List.maps (λs'. unstream g'' (f s')) (unstream g s)"
proof (induction s taking: g rule: unstream.induct)
  case (1 s)
  then show ?case
  proof (cases "generator g s")
    case (Yield x s')
    (* The yielded x starts the inner generator in state f x; its
       contribution is split off with unstream_flatten_Some. *)
    with "1.IH"(2)[OF this] show ?thesis
      using unstream_flatten_Some[of f g'' g s' "f x"]
      by (simp add: flatten.rep_eq maps_simps o_def)
  qed (simp_all add: maps_simps flatten.rep_eq)
qed

end