(*  Title:   HyperdualFunctionExtension.thy
    Authors: Jacques D. Fleuriot and Filip Smola, University of Edinburgh, 2021
*)

section ‹Hyperdual Extension of Functions›

theory HyperdualFunctionExtension
  imports Hyperdual TwiceFieldDifferentiable
begin

text‹The following is an important fact in the derivation of the hyperdual extension:
  a hyperdual number with zero base component vanishes when raised to any power of at least three.›
lemma hyperdual_power_Base_zero:
    fixes x :: "('a :: comm_ring_1) hyperdual" and n :: nat
  assumes "Base x = 0"
    shows "x ^ (n + 3) = 0"
proof (induct n)
  case 0
  then show ?case
    using assms hyperdual_power[of x 3] by simp
next
  case (Suc n)
  (* x ^ (Suc n + 3) = x * x ^ (n + 3), and the induction hypothesis makes the second factor 0 *)
  then show ?case
    using assms power_Suc[of x "n + 3"] mult_zero_right add_Suc by simp
qed

text‹We define the extension of a function to the hyperdual numbers.›
(* Both first-order components are scaled by the first derivative at the base; the mixed
   component additionally receives the second derivative weighted by Eps1 x * Eps2 x. *)
primcorec hypext :: "(('a :: real_normed_field) ⇒ 'a) ⇒ 'a hyperdual ⇒ 'a hyperdual" (‹*h* _› [80] 80)
  where
    "Base ((*h* f) x) = f (Base x)"
  | "Eps1 ((*h* f) x) = Eps1 x * deriv f (Base x)"
  | "Eps2 ((*h* f) x) = Eps2 x * deriv f (Base x)"
  | "Eps12 ((*h* f) x) = Eps12 x * deriv f (Base x) + Eps1 x * Eps2 x * deriv (deriv f) (Base x)"

text‹This has the expected behaviour when expressed in terms of the units.›
(* Constructor form: unfold the primcorec code equation. *)
lemma hypext_Hyperdual_eq:
  "(*h* f) (Hyperdual a b c d) =
     Hyperdual (f a) (b * deriv f a) (c * deriv f a) (d * deriv f a + b * c * deriv (deriv f) a)"
  by (simp add: hypext.code)

(* Unit form: the same result written as a combination of ba, e1, e2 and e12. *)
lemma hypext_Hyperdual_eq_parts:
  "(*h* f) (Hyperdual a b c d) =
      f a *⇩H ba + (b * deriv f a) *⇩H e1 + (c * deriv f a) *⇩H e2 +
         (d * deriv f a + b * c * deriv (deriv f) a) *⇩H e12"
  by (metis Hyperdual_eq hypext_Hyperdual_eq)

text‹The extension can be used to extract the function value, and first and second derivatives at x
  when applied to @{term "x *⇩H ba + e1 + e2 + 0 *⇩H e12"}, which we denote by @{term "β x"}.›
(* β x has base x, both first-order parts equal to 1 and a zero mixed part. *)
definition hyperdualx :: "('a :: real_normed_field) ⇒ 'a hyperdual" (‹β›)
  where "β x = (Hyperdual x 1 1 0)"

(* Simp rules for the four selectors applied to β x, obtained by unfolding hyperdualx_def. *)
lemma hyperdualx_sel [simp]:
  shows "Base (β x) = x"
    and "Eps1 (β x) = 1"
    and "Eps2 (β x) = 1"
    and "Eps12 (β x) = 0"
  by (simp_all add: hyperdualx_def)

(* Evaluating the extension at β x places f x, deriv f x (in both first-order parts) and
   deriv (deriv f) x into the four unit components. *)
lemma hypext_extract_eq:
  "(*h* f) (β x) = f x *⇩H ba + deriv f x *⇩H e1 + deriv f x *⇩H e2 + deriv (deriv f) x *⇩H e12"
  by (simp add: hypext_Hyperdual_eq_parts hyperdualx_def)

(* Component-wise extraction at β x: the base gives f x, each first-order part gives
   deriv f x, and the mixed part gives deriv (deriv f) x. *)
lemma Base_hypext:
  "Base ((*h* f) (β x)) = f x"
  by (simp add: hyperdualx_def)

lemma Eps1_hypext:
  "Eps1 ((*h* f) (β x)) = deriv f x"
  by (simp add: hyperdualx_def)

lemma Eps2_hypext:
  "Eps2 ((*h* f) (β x)) = deriv f x"
  by (simp add: hyperdualx_def)

lemma Eps12_hypext:
  "Eps12 ((*h* f) (β x)) = deriv (deriv f) x"
  by (simp add: hyperdualx_def)

subsubsection‹Convenience Interface›

text‹Define a datatype to hold the function value, and the first and second derivative values.›
(* Selectors: Value, First and Second, one per stored component. *)
datatype ('a :: real_normed_field) derivs = Derivs (Value: 'a) (First: 'a) (Second: 'a)

text‹Then we convert a hyperdual number to derivative values by extracting the base component, one of
  the first-order components, and the second-order component.›
(* The Eps2 component is dropped; Eps1 is used as the first-derivative slot. *)
fun hyperdual_to_derivs :: "('a :: real_normed_field) hyperdual ⇒ 'a derivs"
  where "hyperdual_to_derivs x = Derivs (Base x) (Eps1 x) (Eps12 x)"

text‹Finally we define a way of converting any compatible function into one that yields the value and
  the first and second derivatives of that function at a given argument.›
(* Evaluate the hyperdual extension of f at β x and read off the derivative values. *)
fun autodiff :: "('a :: real_normed_field ⇒ 'a) ⇒ 'a ⇒ 'a derivs"
  where "autodiff f = (λx. hyperdual_to_derivs ((*h* f) (β x)))"

(* The autodiff selectors reduce to the Base, Eps1 and Eps12 components of the extension at β x. *)
lemma autodiff_sel:
  "Value (autodiff f x) = Base ((*h* f) (β x))"
  "First (autodiff f x) = Eps1 ((*h* f) (β x))"
  "Second (autodiff f x) = Eps12 ((*h* f) (β x))"
  by simp_all

text‹The result contains the expected values.›
(* hypext.simps is removed from the simp set in these proofs; they rely on the
   Base_hypext / Eps1_hypext / Eps12_hypext lemmas instead. *)
lemma autodiff_extract_value:
  "Value (autodiff f x) = f x"
  by (simp del: hypext.simps add: Base_hypext)

lemma autodiff_extract_first:
  "First (autodiff f x) = deriv f x"
  by (simp del: hypext.simps add: Eps1_hypext)

lemma autodiff_extract_second:
  "Second (autodiff f x) = deriv (deriv f) x"
  by (simp del: hypext.simps add: Eps12_hypext)

text‹The derivative components of the result are actual derivatives if the function is sufficiently
  differentiable at that argument.›
(* Under field differentiability at x, the First slot is a genuine derivative of f at x. *)
lemma autodiff_first_derivative:
  assumes "f field_differentiable (at x)"
  shows "(f has_field_derivative First (autodiff f x)) (at x)"
  by (simp add: autodiff_extract_first DERIV_deriv_iff_field_differentiable assms)

(* Under twice differentiability at x, the Second slot is a genuine derivative of deriv f at x. *)
lemma autodiff_second_derivative:
  assumes "f twice_field_differentiable_at x"
  shows "((deriv f) has_field_derivative Second (autodiff f x)) (at x)"
  by (simp add: autodiff_extract_second DERIV_deriv_iff_field_differentiable assms deriv_field_differentiable_at)


text‹Composition of hyperdual extensions is the hyperdual extension of composition:›
lemma hypext_compose:
  assumes "f twice_field_differentiable_at (Base x)"
      and "g twice_field_differentiable_at (f (Base x))"
    shows "(*h* (λx. g (f x))) x = (*h* g) ((*h* f) x)"
proof (simp add: hyperdual_eq_iff, intro conjI disjI2)
  (* First-order components: derivative of the composition via deriv_chain. *)
  show goal1: "deriv (λx. g (f x)) (Base x) = deriv f (Base x) * deriv g (f (Base x))"
  proof -
    have "deriv (λx. g (f x)) (Base x) = deriv (g ∘ f) (Base x)"
      by (simp add: comp_def)
    also have "... = deriv g (f (Base x)) * deriv f (Base x)"
      using assms by (simp add: deriv_chain once_field_differentiable_at)
    finally show ?thesis
      by (simp add: mult.commute deriv_chain)
  qed
  (* The other first-order goal has the same statement, so goal1 closes it. *)
  then show "deriv (λx. g (f x)) (Base x) = deriv f (Base x) * deriv g (f (Base x))" .

  (* Mixed component: differentiate the first derivative of the composition once more. *)
  have first_diff: "(λx. deriv g (f x)) field_differentiable at (Base x)"
    by (metis DERIV_chain2 assms deriv_field_differentiable_at field_differentiable_def once_field_differentiable_at)

  have "deriv (deriv g ∘ f) (Base x) = deriv (deriv g) (f (Base x)) * deriv f (Base x)"
    using deriv_chain assms once_field_differentiable_at deriv_field_differentiable_at
    by blast
  then have deriv_deriv_comp: "deriv (λx. deriv g (f x)) (Base x) = deriv (deriv g) (f (Base x)) * deriv f (Base x)"
    by (simp add: comp_def)

  have "deriv (deriv (λx. g (f x))) (Base x) = deriv ((λx. deriv f x * deriv g (f x))) (Base x)"
    using assms eventually_deriv_compose'[of f "Base x" g]
    by (simp add: mult.commute deriv_cong_ev)
  also have "... = deriv f (Base x) * deriv (λx. deriv g (f x)) (Base x) + deriv (deriv f) (Base x) * deriv g (f (Base x))"
    using assms(1) first_diff by (simp add: deriv_field_differentiable_at)
  also have "... = deriv f (Base x) * deriv (deriv g) (f (Base x)) * deriv f (Base x) + deriv (deriv f) (Base x) * deriv g (f (Base x))"
    using deriv_deriv_comp by simp
  finally show "Eps12 x * deriv (λx. g (f x)) (Base x) + Eps1 x * Eps2 x * deriv (deriv (λx. g (f x))) (Base x) =
    (Eps12 x * deriv f (Base x) + Eps1 x * Eps2 x * deriv (deriv f) (Base x)) * deriv g (f (Base x)) +
    Eps1 x * deriv f (Base x) * (Eps2 x * deriv f (Base x)) * deriv (deriv g) (f (Base x))"
    by (simp add: goal1 field_simps)
qed

subsection‹Concrete Instances›


text‹Component embedding is an extension of the constant function.›
lemma hypext_const [simp]:
  "(*h* (λx. a)) x = of_comp a"
  by (simp add: of_comp_def hyperdual_eq_iff)

(* Concrete autodiff result for a constant function: both derivative slots are 0. *)
lemma "autodiff (λx. a) = (λx. Derivs a 0 0)"
  by simp


text‹Identity is an extension of the component identity.›
lemma hypext_ident:
  "(*h* (λx. x)) x = x"
  by (simp add: hyperdual_eq_iff)

subsubsection‹Component Scalar Multiplication›

text‹Component scaling is an extension of component constant multiplication:›
lemma hypext_scaleH:
  "(*h* (λx. k * x)) x = k *⇩H x"
  by (simp add: hyperdual_eq_iff)

(* Lift the previous fact through an arbitrary twice-differentiable f via hypext_compose. *)
lemma hypext_fun_scaleH:
  assumes "f twice_field_differentiable_at (Base x)"
  shows "(*h* (λx. k * f x)) x = k *⇩H (*h* f) x"
  using assms by (simp add: hypext_compose hypext_scaleH)

text‹Unary minus is just an instance of constant multiplication:›
lemma hypext_uminus:
  "(*h* uminus) x = - x"
  using hypext_scaleH[of "-1" x] by simp

subsubsection‹Real Scalar Multiplication›

text‹Real scaling is an extension of component real scaling:›
lemma hypext_scaleR:
  "(*h* (λx. k *⇩R x)) x = k *⇩R x"
  by (auto simp add: hyperdual_eq_iff)

(* Lift the previous fact through an arbitrary twice-differentiable f via hypext_compose. *)
lemma hypext_fun_scaleR:
  assumes "f twice_field_differentiable_at (Base x)"
  shows "(*h* (λx. k *⇩R f x)) x = k *⇩R (*h* f) x"
  using assms by (simp add: hypext_compose hypext_scaleR)


text‹Addition of hyperdual extensions is a hyperdual extension of addition of functions.›
lemma hypext_fun_add:
  assumes "f twice_field_differentiable_at (Base x)"
      and "g twice_field_differentiable_at (Base x)"
    shows "(*h* (λx. f x + g x)) x = (*h* f) x + (*h* g) x"
proof (simp add: hyperdual_eq_iff distrib_left[symmetric], intro conjI disjI2)
  (* First-order components: the derivative of a sum. *)
  show goal1: "deriv (λx. f x + g x) (Base x) = deriv f (Base x) + deriv g (Base x)"
    by (simp add: assms once_field_differentiable_at distrib_left)
  then show "deriv (λx. f x + g x) (Base x) = deriv f (Base x) + deriv g (Base x)" .

  (* Mixed component: rewrite deriv (deriv (f + g)) using deriv_cong_ev / eventually_deriv_add. *)
  have "deriv (deriv (λx. f x + g x)) (Base x) = deriv (λw. deriv f w + deriv g w) (Base x)"
    by (simp add: assms deriv_cong_ev eventually_deriv_add)
  moreover have "Eps12 x * deriv f (Base x) + Eps12 x * deriv g (Base x) = Eps12 x * deriv (λx. f x + g x) (Base x)"
    by (metis distrib_left goal1)
  ultimately show "Eps12 x * deriv (λx. f x + g x) (Base x) +
    Eps1 x * Eps2 x * deriv (deriv (λx. f x + g x)) (Base x) =
    Eps12 x * deriv f (Base x) + Eps1 x * Eps2 x * deriv (deriv f) (Base x) +
    (Eps12 x * deriv g (Base x) + Eps1 x * Eps2 x * deriv (deriv g) (Base x))"
    using deriv_add[OF deriv_field_differentiable_at deriv_field_differentiable_at, OF assms]
    by (simp add: distrib_left add.left_commute)
qed

(* Adding a constant commutes with the extension; the constant is embedded via of_comp. *)
lemma hypext_cadd [simp]:
  "(*h* (λx. x + a)) x = x + of_comp a"
  by (auto simp add: hyperdual_eq_iff of_comp_def)

(* Same fact composed with a twice-differentiable f, via hypext_compose. *)
lemma hypext_fun_cadd:
  assumes "f twice_field_differentiable_at (Base x)"
  shows "(*h* (λx. f x + a)) x = (*h* f) x + of_comp a"
  using assms hypext_compose[of f x "λx. x + a"] by simp

subsubsection‹Component Linear Function›

text‹Hyperdual linear function is an extension of the component linear function:›
(* Obtained from hypext_fun_add applied to the scaling map and a constant function. *)
lemma hypext_linear:
  "(*h* (λx. k * x + a)) x = k *⇩H x + of_comp a"
  using hypext_fun_add[of "(*) k" x "λx. a"]
  by (simp add: hypext_scaleH)

lemma hypext_fun_linear:
  assumes "f twice_field_differentiable_at (Base x)"
    shows "(*h* (λx. k * f x + a)) x = k *⇩H (*h* f) x + of_comp a"
  using assms hypext_compose[of f x "λx. k * x + a"] by (simp add: hypext_linear)

subsubsection‹Real Linear Function›

text‹We have the same for real scaling instead of component multiplication:›
(* Obtained from hypext_fun_add applied to real scaling and a constant function. *)
lemma hypext_linearR:
  "(*h* (λx. k *⇩R x + a)) x = k *⇩R x + of_comp a"
  using hypext_fun_add[of "(*⇩R) k" x "λx. a"]
  by (simp add: hypext_scaleR)

lemma hypext_fun_linearR:
  assumes "f twice_field_differentiable_at (Base x)"
    shows "(*h* (λx. k *⇩R f x + a)) x = k *⇩R (*h* f) x + of_comp a"
  using assms hypext_compose[of f x "λx. k *⇩R x + a"] by (simp add: hypext_linearR)


text‹Extension of multiplication is multiplication of the functions' extensions.›
lemma hypext_fun_mult:
  assumes "f twice_field_differentiable_at (Base x)"
      and "g twice_field_differentiable_at (Base x)"
    shows "(*h* (λz. f z * g z)) x = (*h* f) x * (*h* g) x"
proof (simp add: hyperdual_eq_iff distrib_left[symmetric], intro conjI)
  (* First-order components: the product rule. *)
  show "Eps1 x * deriv (λz. f z * g z) (Base x) =
        f (Base x) * (Eps1 x * deriv g (Base x)) + Eps1 x * deriv f (Base x) * g (Base x)"
   and "Eps2 x * deriv (λz. f z * g z) (Base x) =
        f (Base x) * (Eps2 x * deriv g (Base x)) + Eps2 x * deriv f (Base x) * g (Base x)"
    using assms by (simp_all add: once_field_differentiable_at distrib_left)

  (* Second derivative of a product: (f g)'' = f g'' + 2 f' g' + f'' g. *)
  have
    "deriv (deriv (λz. f z * g z)) (Base x) =
     f (Base x) * deriv (deriv g) (Base x) + 2 * deriv f (Base x) * deriv g (Base x) + deriv (deriv f) (Base x) * g (Base x)"
  proof -
    have "deriv (deriv (λz. f z * g z)) (Base x) = deriv (λz. f z * deriv g z + deriv f z * g z) (Base x)"
      using assms by (simp add: eventually_deriv_mult deriv_cong_ev)
    also have "... = (λz. f z * deriv (deriv g) z + deriv f z * deriv g z + deriv f z * deriv g z + deriv (deriv f) z * g z) (Base x)"
      by (simp add: assms deriv_field_differentiable_at field_differentiable_mult once_field_differentiable_at)
    finally show ?thesis
      by simp
  qed
  then show
    "Eps12 x * deriv (λz. f z * g z) (Base x) + Eps1 x * Eps2 x * deriv (deriv (λz. f z * g z)) (Base x) =
     2 * (Eps1 x * (Eps2 x * (deriv f (Base x) * deriv g (Base x)))) +
     f (Base x) * (Eps12 x * deriv g (Base x) + Eps1 x * Eps2 x * deriv (deriv g) (Base x)) +
     (Eps12 x * deriv f (Base x) + Eps1 x * Eps2 x * deriv (deriv f) (Base x)) * g (Base x)"
     using assms by (simp add: once_field_differentiable_at field_simps)
qed

subsubsection‹Sine and Cosine›

text‹The extended sin and cos at an arbitrary hyperdual.›

lemma hypext_sin_Hyperdual:
  "(*h* sin) (Hyperdual a b c d) = sin a *⇩H ba + (b * cos a) *⇩H e1 + (c * cos a) *⇩H e2 + (d * cos a - b * c * sin a) *⇩H e12"
  by (simp add: hypext_Hyperdual_eq_parts)

lemma hypext_cos_Hyperdual:
  "(*h* cos) (Hyperdual a b c d) = cos a *⇩H ba - (b * sin a) *⇩H e1 - (c * sin a) *⇩H e2 - (d * sin a + b * c * cos a) *⇩H e12"
proof -
  (* Move the negation of the mixed coefficient outside of the unit so the goal is stated with '-'. *)
  have "of_comp (- (d * sin a) - b * c * cos a) * e12 = - (of_comp (d * sin a + b * c * cos a) * e12)"
    by (metis add_uminus_conv_diff minus_add_distrib mult_minus_left of_comp_minus)
  then show ?thesis
    by (simp add: hypext_Hyperdual_eq_parts of_comp_minus scaleH_times)
qed

(* Component simp rules for the extended sine at an arbitrary hyperdual x. *)
lemma Eps1_hypext_sin [simp]:
  "Eps1 ((*h* sin) x) = Eps1 x * cos (Base x)"
  by simp

lemma Eps2_hypext_sin [simp]:
  "Eps2 ((*h* sin) x) = Eps2 x * cos (Base x)"
  by simp

lemma Eps12_hypext_sin [simp]:
  "Eps12 ((*h* sin) x) = Eps12 x * cos (Base x) - Eps1 x * Eps2 x * sin (Base x)"
  by simp

(* Extended sin and cos evaluated at multiples of the units e1, e2 and e12. *)
lemma hypext_sin_e1 [simp]:
  "(*h* sin) (x * e1) = e1 * x"
  by (simp add: e1_def hyperdual_eq_iff one_hyperdual_def)

lemma hypext_sin_e2 [simp]:
  "(*h* sin) (x * e2) = e2 * x"
  by (simp add: e2_def hyperdual_eq_iff one_hyperdual_def)

lemma hypext_sin_e12 [simp]:
  "(*h* sin) (x * e12) = e12 * x"
  by (simp add: e12_def hyperdual_eq_iff one_hyperdual_def)

lemma hypext_cos_e1 [simp]:
  "(*h* cos) (x * e1) = 1"
  by (simp add: e1_def hyperdual_eq_iff one_hyperdual_def)

lemma hypext_cos_e2 [simp]:
  "(*h* cos) (x * e2) = 1"
  by (simp add: e2_def hyperdual_eq_iff one_hyperdual_def)

lemma hypext_cos_e12 [simp]:
  "(*h* cos) (x * e12) = 1"
  by (simp add: e12_def hyperdual_eq_iff one_hyperdual_def)

text‹The extended sin and cos at @{term "β x"}.›

lemma hypext_sin_extract:
  "(*h* sin) (β x) = sin x *⇩H ba + cos x *⇩H e1 + cos x *⇩H e2 - sin x *⇩H e12"
  by (simp add: hypext_sin_Hyperdual of_comp_minus scaleH_times hyperdualx_def)

lemma hypext_cos_extract:
  "(*h* cos) (β x) = cos x *⇩H ba - sin x *⇩H e1 - sin x *⇩H e2 - cos x *⇩H e12"
  by (simp add: hypext_cos_Hyperdual hyperdualx_def)

text‹Extracting the extended sin components at @{term "β x"}.›

lemma Base_hypext_sin_extract [simp]:
  "Base ((*h* sin) (β x)) = sin x"
  by (rule Base_hypext)

(* Added for parity with the exp and sqrt extraction lemmas, which cover all three derivative parts. *)
lemma Eps1_hypext_sin_extract [simp]:
  "Eps1 ((*h* sin) (β x)) = cos x"
  using Eps1_hypext[of sin] by simp

lemma Eps2_hypext_sin_extract [simp]:
  "Eps2 ((*h* sin) (β x)) = cos x"
  using Eps2_hypext[of sin] by simp

lemma Eps12_hypext_sin_extract [simp]:
  "Eps12 ((*h* sin) (β x)) = - sin x"
  using Eps12_hypext[of sin] by simp

text‹Extracting the extended cos components at @{term "β x"}.›

lemma Base_hypext_cos_extract [simp]:
  "Base ((*h* cos) (β x)) = cos x"
  by (rule Base_hypext)

(* Added for parity with the exp and sqrt extraction lemmas, which cover all three derivative parts. *)
lemma Eps1_hypext_cos_extract [simp]:
  "Eps1 ((*h* cos) (β x)) = - sin x"
  using Eps1_hypext[of cos] by simp

lemma Eps2_hypext_cos_extract [simp]:
  "Eps2 ((*h* cos) (β x)) = - sin x"
  using Eps2_hypext[of cos] by simp

lemma Eps12_hypext_cos_extract [simp]:
  "Eps12 ((*h* cos) (β x)) = - cos x"
  using Eps12_hypext[of cos] by simp

text‹We get one of the key trigonometric properties for the extensions of sin and cos.›

(* sin² + cos² = 1 lifts to the hyperdual extensions at any hyperdual x. *)
lemma "((*h* sin) x)⇧2 + ((*h* cos) x)⇧2 = 1"
  by (simp add: hyperdual_eq_iff one_hyperdual_def power2_eq_square field_simps)

(* example *)
lemma "(*h* sin) x + (*h* cos) x = (*h* (λx. sin x + cos x)) x"
  by (simp add: hypext_fun_add)


text‹The exponential function extension behaves as expected.›

lemma hypext_exp_Hyperdual:
  "(*h* exp) (Hyperdual a b c d) =
       exp a *⇩H ba + (b * exp a) *⇩H e1 + (c * exp a) *⇩H e2 + (d * exp a + b * c * exp a) *⇩H e12"
  by (simp add: hypext_Hyperdual_eq_parts)

(* At β x every component equals exp x. *)
lemma hypext_exp_extract:
  "(*h* exp) (β x) = exp x *⇩H ba + exp x *⇩H e1 + exp x *⇩H e2 + exp x *⇩H e12"
  by (simp add: hypext_extract_eq)

(* Extended exp evaluated at multiples of the units e1, e2 and e12. *)
lemma hypext_exp_e1 [simp]:
  "(*h* exp) (x * e1) = 1 + e1 * x"
  by (simp add: e1_def hyperdual_eq_iff)

lemma hypext_exp_e2 [simp]:
  "(*h* exp) (x * e2) = 1 + e2 * x"
  by (simp add: e2_def hyperdual_eq_iff)

lemma hypext_exp_e12 [simp]:
  "(*h* exp) (x * e12) = 1 + e12 * x"
  by (simp add: e12_def hyperdual_eq_iff)

text‹Extracting the parts for the exponential function extension.›

(* Each derivative component at β x is exp x, derived from the generic Eps*_hypext lemmas. *)
lemma Eps1_hypext_exp_extract [simp]:
  "Eps1 ((*h* exp) (β x)) = exp x"
  using Eps1_hypext[of exp] by simp

lemma Eps2_hypext_exp_extract [simp]:
  "Eps2 ((*h* exp) (β x)) = exp x"
  using Eps2_hypext[of exp] by simp

lemma Eps12_hypext_exp_extract [simp]:
  "Eps12 ((*h* exp) (β x)) = exp x"
  using Eps12_hypext[of exp] by simp

subsubsection‹Square Root›
text‹Square root function extension.›

(* The positivity assumption a > 0 is needed because sqrt is not differentiable at 0. *)
lemma hypext_sqrt_Hyperdual_Hyperdual:
  assumes "a > 0"
  shows "(*h* sqrt) (Hyperdual a b c d) =
         Hyperdual (sqrt a) (b * inverse (sqrt a) / 2) (c * inverse (sqrt a) / 2)
           (d * inverse (sqrt a) / 2 - b * c * inverse (sqrt a ^ 3) / 4)"
  by (simp add: assms hypext_Hyperdual_eq)

lemma hypext_sqrt_Hyperdual:
      "a > 0 ⟹ (*h* sqrt) (Hyperdual a b c d) =
       sqrt a *⇩H ba + (b * inverse (sqrt a) / 2) *⇩H e1 + (c * inverse (sqrt a) / 2) *⇩H e2 +
       (d * inverse (sqrt a) / 2 - b * c * inverse (sqrt a ^ 3) / 4) *⇩H e12"
  by (auto simp add: hypext_Hyperdual_eq_parts)

lemma hypext_sqrt_extract:
  "x > 0 ⟹ (*h* sqrt) (β x) = sqrt x *⇩H ba + (inverse (sqrt x) / 2) *⇩H e1 +
        (inverse (sqrt x) / 2) *⇩H e2 - (inverse (sqrt x ^ 3) / 4) *⇩H e12"
  by (simp add: hypext_sqrt_Hyperdual hyperdualx_def of_comp_minus scaleH_times)

text‹Extracting the parts for the square root extension.›

lemma Eps1_hypext_sqrt_extract [simp]:
  "x > 0 ⟹ Eps1 ((*h* sqrt) (β x)) = inverse (sqrt x) / 2"
  using Eps1_hypext[of sqrt] by simp

lemma Eps2_hypext_sqrt_extract [simp]:
  "x > 0 ⟹ Eps2 ((*h* sqrt) (β x)) = inverse (sqrt x) / 2"
  using Eps2_hypext[of sqrt] by simp

lemma Eps12_hypext_sqrt_extract [simp]:
  "x > 0 ⟹ Eps12 ((*h* sqrt) (β x)) = - (inverse (sqrt x ^ 3) / 4)"
  using Eps12_hypext[of sqrt] by simp

(* example *)
lemma "Base x > 0 ⟹ (*h* sin) x + (*h* sqrt) x = (*h* (λx. sin x + sqrt x)) x"
  by (simp add: hypext_fun_add)

subsubsection‹Natural Power›

(* The extension of x ^ n is the hyperdual power itself (via hyperdual_power). *)
lemma hypext_power:
  "(*h* (λx. x ^ n)) x = x ^ n"
  by (simp add: hyperdual_eq_iff hyperdual_power)

lemma hypext_fun_power:
  assumes "f twice_field_differentiable_at (Base x)"
    shows "(*h* (λx. (f x) ^ n)) x = ((*h* f) x) ^ n"
  using assms hypext_compose[of f x "λx. x ^ n"] by (simp add: hypext_power)

lemma hypext_power_Hyperdual:
  "(*h* (λx. x ^ n)) (Hyperdual a b c d) =
        a ^ n *⇩H ba + (of_nat n * b * a ^ (n - 1)) *⇩H e1 + (of_nat n * c * a ^ (n - 1)) *⇩H e2 +
        (d * (of_nat n * a ^ (n - 1)) + b * c * (of_nat n * of_nat (n - 1) * a ^ (n - 2))) *⇩H e12"
  by (simp add: hypext_Hyperdual_eq_parts algebra_simps)

lemma hypext_power_Hyperdual_parts:
  "(*h* (λx. x ^ n)) (a *⇩H ba + b *⇩H e1 + c *⇩H e2 + d *⇩H e12) =
        a ^ n *⇩H ba + (of_nat n * b * a ^ (n - 1)) *⇩H e1 + (of_nat n * c * a ^ (n - 1)) *⇩H e2 +
        (d * (of_nat n * a ^ (n - 1)) + b * c * (of_nat n * of_nat (n - 1) * a ^ (n - 2))) *⇩H e12"
  by (simp add: Hyperdual_eq [symmetric] hypext_power_Hyperdual)

lemma hypext_power_extract:
  "(*h* (λx. x ^ n)) (β x) =
      x ^ n *⇩H ba + (of_nat n * x ^ (n - 1)) *⇩H e1 + (of_nat n * x ^ (n - 1)) *⇩H e2 +
      (of_nat n * of_nat (n - 1) * x ^ (n - 2)) *⇩H e12"
  by (simp add: hypext_extract_eq)

(* Component simp rules for the extension of x ^ n at an arbitrary hyperdual x. *)
lemma Eps1_hypext_power [simp]:
  "Eps1 ((*h* (λx. x ^ n)) x) = of_nat n * Eps1 x * (Base x) ^ (n - 1)"
  by simp

lemma Eps2_hypext_power [simp]:
  "Eps2 ((*h* (λx. x ^ n)) x) = of_nat n * Eps2 x * (Base x) ^ (n - 1)"
  by simp

lemma Eps12_hypext_power [simp]:
  "Eps12 ((*h* (λx. x ^ n)) x) =
   Eps12 x * (of_nat n * Base x ^ (n - 1)) + Eps1 x * Eps2 x * (of_nat n * of_nat (n - 1) * Base x ^ (n - 2))"
  by simp


(* The extension of inverse agrees with hyperdual inversion when the base is non-zero. *)
lemma hypext_inverse:
  assumes "Base x ≠ 0"
  shows "(*h* inverse) x = inverse x"
  using assms by (simp add: hyperdual_eq_iff inverse_eq_divide)

lemma hypext_fun_inverse:
  assumes "f twice_field_differentiable_at (Base x)"
      and "f (Base x) ≠ 0"
    shows "(*h* (λx. inverse (f x))) x = inverse ((*h* f) x)"
  using assms hypext_compose[of f x inverse] by (simp add: hypext_inverse)

lemma hypext_inverse_Hyperdual:
  "a ≠ 0 ⟹
    (*h* inverse) (Hyperdual a b c d) =
    Hyperdual (inverse a) (- (b / a⇧2)) (- (c / a⇧2)) (2 * b * c / (a ^ 3) - d / a⇧2)"
  by (simp add: hypext_Hyperdual_eq divide_inverse)

lemma hypext_inverse_Hyperdual_parts:
  "a ≠ 0 ⟹
    (*h* inverse) (a *⇩H ba + b *⇩H e1 + c *⇩H e2 + d *⇩H e12) =
    inverse a *⇩H ba + - (b / a⇧2) *⇩H e1 + - (c / a⇧2) *⇩H e2 + (2 * b * c / a ^ 3 - d / a⇧2) *⇩H e12"
  by (metis Hyperdual_eq hypext_inverse_Hyperdual)

lemma inverse_Hyperdual_parts:
  "(a::'a::real_normed_field) ≠ 0 ⟹
    inverse (a *⇩H ba + b *⇩H e1 + c *⇩H e2 + d *⇩H e12) =
    inverse a *⇩H ba + - (b / a⇧2) *⇩H e1 + - (c / a⇧2) *⇩H e2 + (2 * b * c / a ^ 3 - d / a⇧2) *⇩H e12"
  by (metis Hyperdual_eq hyperdual.sel(1) hypext_inverse hypext_inverse_Hyperdual_parts)

lemma hypext_inverse_extract:
  "x ≠ 0 ⟹ (*h* inverse) (β x) = inverse x *⇩H ba - (1 / x⇧2) *⇩H e1 - (1 / x⇧2) *⇩H e2 + (2 / x ^ 3) *⇩H e12"
  by (simp add: hypext_extract_eq divide_inverse of_comp_minus scaleH_times)

lemma inverse_extract:
  "x ≠ 0 ⟹ inverse (β x) = inverse x *⇩H ba - (1 / x⇧2) *⇩H e1 - (1 / x⇧2) *⇩H e2 + (2 / x ^ 3) *⇩H e12"
  by (metis hyperdual.sel(1) hyperdualx_def hypext_inverse hypext_inverse_extract)

(* Component simp rules; note that the Eps2/Eps12 lemmas below are stated for inverse x,
   not for (*h* inverse) x, despite their names. *)
lemma Eps1_hypext_inverse [simp]:
  "Base x ≠ 0 ⟹ Eps1 ((*h* inverse) x) = - Eps1 x * (1 / (Base x)⇧2)"
  by simp
lemma Eps1_inverse [simp]:
  "Base (x::'a::real_normed_field hyperdual) ≠ 0 ⟹ Eps1 (inverse x) = - Eps1 x * (1 / (Base x)⇧2)"
  by simp

lemma Eps2_hypext_inverse [simp]:
  "Base (x::'a::real_normed_field hyperdual) ≠ 0 ⟹ Eps2 (inverse x) = - Eps2 x * (1 / (Base x)⇧2)"
  by simp

lemma Eps12_hypext_inverse [simp]:
  "Base (x::'a::real_normed_field hyperdual) ≠ 0 ⟹
    Eps12 (inverse x) = Eps1 x * Eps2 x * (2/ (Base x ^ 3)) - Eps12 x / (Base x)⇧2"
  by simp


(* Division is handled as multiplication by the extended inverse of g. *)
lemma hypext_fun_divide:
  assumes "f twice_field_differentiable_at (Base x)"
      and "g twice_field_differentiable_at (Base x)"
      and "g (Base x) ≠ 0"
    shows "(*h* (λx. f x / g x)) x = (*h* f) x / (*h* g) x"
proof -
  have "(λx. inverse (g x)) twice_field_differentiable_at Base x"
    by (simp add: assms(2) assms(3) twice_field_differentiable_at_compose)
  moreover have "(*h* f) x * (*h* (λx. inverse (g x))) x = (*h* f) x * inverse ((*h* g) x)"
    by (simp add: assms(2) assms(3) hypext_fun_inverse)
  ultimately have "(*h* (λx. f x * inverse (g x))) x = (*h* f) x * inverse ((*h* g) x)"
    by (simp add: assms(1) hypext_fun_mult)
  then show ?thesis
    by (simp add: divide_inverse hyp_divide_inverse)
qed


(* Polynomials: the extension of a finite power sum is the same sum over hyperdual powers. *)
lemma hypext_polyn:
  fixes coef :: "nat ⇒ 'a :: {real_normed_field}"
    and n :: nat
  shows "(*h* (λx. ∑i<n. coef i * x^i)) x = (∑i<n. (coef i) *⇩H (x^i))"
proof (induction n)
  case 0
  then show ?case
    by (simp add: zero_hyperdual_def)
next
  case hyp: (Suc n)

  (* Split off the top-degree term on both sides. *)
  have "(λx. ∑i<Suc n. coef i * x ^ i) = (λx. (∑i<n. coef i * x ^ i) + coef n * x ^ n)"
   and "(λx. ∑i<Suc n. coef i *⇩H x ^ i) = (λx. (∑i<n. coef i *⇩H x ^ i) + coef n *⇩H x ^ n)"
    by (simp_all add: field_simps)

  then show ?case
  proof (simp)
    have "(λx. coef n * x ^ n) twice_field_differentiable_at Base x"
      using twice_field_differentiable_at_compose[of "λx. x ^ n" "Base x" "(*) (coef n)"]
      by simp
    then have "(*h* (λx. (∑i<n. coef i * x ^ i) + coef n * x ^ n)) x =
          (*h* (λx. (∑i<n. coef i * x ^ i))) x + (*h* (λx. coef n * x ^ n)) x"
      by (simp add: hypext_fun_add)
    moreover have "(*h* (λx. coef n * x ^ n)) x = coef n *⇩H x ^ n"
      by (simp add: hypext_fun_scaleH hypext_power)
    ultimately have "(*h* (λx. (∑i<n. coef i * x ^ i) + coef n * x ^ n)) x = (∑i<n. coef i *⇩H x ^ i) + coef n *⇩H x ^ n"
      using hyp by simp
    then show "(*h* (λx. (∑i<n. coef i * x ^ i) + coef n * x ^ n)) x = (∑i<n. coef i *⇩H x ^ i) + coef n *⇩H x ^ n"
      by simp
  qed
qed

(* Polynomial composed with a twice-differentiable f, via hypext_compose and hypext_polyn. *)
lemma hypext_fun_polyn:
    fixes coef :: "nat ⇒ 'a :: {real_normed_field}"
      and n :: nat
  assumes "f twice_field_differentiable_at (Base x)"
    shows "(*h* (λx. ∑i<n. coef i * (f x)^i)) x = (∑i<n. (coef i) *⇩H (((*h* f) x)^i))"
  using assms hypext_compose[of f x "λx. (∑i<n. coef i * x^i)"] by (simp add: hypext_polyn)

end