(* Theory HyperdualFunctionExtension *)

(*  Title:   HyperdualFunctionExtension.thy
    Authors: Jacques D. Fleuriot and Filip Smola, University of Edinburgh, 2021
*)

section ‹Hyperdual Extension of Functions›

theory HyperdualFunctionExtension
  imports Hyperdual TwiceFieldDifferentiable
begin

text‹The following is an important fact in the derivation of the hyperdual extension.›
(* A hyperdual number whose base component is zero vanishes from the third power onward.
   Stated as x ^ (n + 3) = 0 for every natural n and proved by induction on n:
   the base case is hyperdual_power instantiated at exponent 3, and the step peels off
   one factor of x with power_Suc and multiplies by the zero obtained so far.
   Named so that it can be cited by later developments. *)
lemma hyperdual_Base_zero_power_vanishes:
    fixes x :: "('a :: comm_ring_1) hyperdual" and n :: nat
  assumes "Base x = 0"
    shows "x ^ (n + 3) = 0"
proof (induct n)
  case 0
  then show ?case
    using assms hyperdual_power[of x 3] by simp
next
  case (Suc n)
  then show ?case
    using assms power_Suc[of x "n + 3"] mult_zero_right add_Suc by simp
qed

text‹We define the extension of a function to the hyperdual numbers.›
(* Hyperdual extension of a component function f, given by its four selector equations:
   the base part is f of the base; the Eps1 and Eps2 parts are each scaled by deriv f at
   the base; the Eps12 part is Eps12 times deriv f plus Eps1 times Eps2 times the second
   derivative deriv (deriv f), all evaluated at the base. *)
primcorec hypext :: "(('a :: real_normed_field)  'a)  'a hyperdual  'a hyperdual" (*h* _› [80] 80)
  where
    "Base ((*h* f) x) = f (Base x)"
  | "Eps1 ((*h* f) x) = Eps1 x * deriv f (Base x)"
  | "Eps2 ((*h* f) x) = Eps2 x * deriv f (Base x)"
  | "Eps12 ((*h* f) x) = Eps12 x * deriv f (Base x) + Eps1 x * Eps2 x * deriv (deriv f) (Base x)"

text‹This has the expected behaviour when expressed in terms of the units.›
(* The extension applied to an explicit constructor term Hyperdual a b c d, obtained
   from the code equation generated for hypext. *)
lemma hypext_Hyperdual_eq:
  "(*h* f) (Hyperdual a b c d) =
     Hyperdual (f a) (b * deriv f a) (c * deriv f a) (d * deriv f a + b * c * deriv (deriv f) a)"
  by (simp add: hypext.code)

(* The same result written as a combination of the units ba, e1, e2 and e12. *)
lemma hypext_Hyperdual_eq_parts:
  "(*h* f) (Hyperdual a b c d) =
      f a *H ba + (b * deriv f a) *H e1 + (c * deriv f a) *H e2 +
         (d * deriv f a + b * c * deriv (deriv f) a) *H e12 "
  by (metis Hyperdual_eq hypext_Hyperdual_eq)

text‹
  The extension can be used to extract the function value and the first and second derivatives at x
  when applied to @{term "x *H ba + e1 + e2 + 0 *H e12"}, which we denote by @{term "β x"}.
›
(* β x is the hyperdual with base x, Eps1 and Eps2 parts equal to 1, and Eps12 part equal to 0. *)
definition hyperdualx :: "('a :: real_normed_field)  'a hyperdual" ("β")
  where "β x = (Hyperdual x 1 1 0)"

(* Selector simp rules for β x, so that components of β x simplify without unfolding the definition. *)
lemma hyperdualx_sel [simp]:
  shows "Base (β x) = x"
    and "Eps1 (β x) = 1"
    and "Eps2 (β x) = 1"
    and "Eps12 (β x) = 0"
  by (simp_all add: hyperdualx_def)

(* Evaluating the extension at β x: the base part is f x, both first-order parts are
   deriv f x, and the Eps12 part is deriv (deriv f) x. *)
lemma hypext_extract_eq:
  "(*h* f) (β x) = f x *H ba + deriv f x *H e1 + deriv f x *H e2 + deriv (deriv f) x *H e12"
  by (simp add: hypext_Hyperdual_eq_parts hyperdualx_def)

(* The same facts, one component at a time. *)
lemma Base_hypext:
  "Base ((*h* f) (β x)) = f x"
  by (simp add: hyperdualx_def)

lemma Eps1_hypext:
  "Eps1 ((*h* f) (β x)) = deriv f x"
  by (simp add: hyperdualx_def)

lemma Eps2_hypext:
  "Eps2 ((*h* f) (β x)) = deriv f x"
  by (simp add: hyperdualx_def)

lemma Eps12_hypext:
  "Eps12 ((*h* f) (β x)) = deriv (deriv f) x"
  by (simp add: hyperdualx_def)

subsubsection‹Convenience Interface›

text‹Define a datatype to hold the function value, and the first and second derivative values.›
(* Value holds f x, First the first derivative, and Second the second derivative. *)
datatype ('a :: real_normed_field) derivs = Derivs (Value: 'a) (First: 'a) (Second: 'a)

text‹
  Then we convert a hyperdual number to derivative values by extracting the base component, one of
  the first-order components, and the second-order component.
›
(* Package Base, Eps1 and Eps12 of a hyperdual as derivs. Eps2 is not read: at β x it
   carries the same value as Eps1, as stated by Eps1_hypext and Eps2_hypext. *)
fun hyperdual_to_derivs :: "('a :: real_normed_field) hyperdual  'a derivs"
  where "hyperdual_to_derivs x = Derivs (Base x) (Eps1 x) (Eps12 x)"

text‹
  Finally, we define a way of converting any compatible function into one that yields the value and
  the derivatives.
›
(* autodiff f evaluates the extension of f at β x and packages the result as derivs. *)
fun autodiff :: "('a :: real_normed_field  'a)  'a  'a derivs"
  where "autodiff f = (λx. hyperdual_to_derivs ((*h* f) (β x)))"

(* The selectors of autodiff f x read Base, Eps1 and Eps12 of the extension at β x.
   simp_all suffices because autodiff and hyperdual_to_derivs are defined with fun. *)
lemma autodiff_sel:
  "Value (autodiff f x) = Base ((*h* f) (β x))"
  "First (autodiff f x) = Eps1 ((*h* f) (β x))"
  "Second (autodiff f x) = Eps12 ((*h* f) (β x))"
  by simp_all

text‹The result contains the expected values.›
(* The components of autodiff f x are f x, deriv f x and deriv (deriv f) x.
   hypext.simps is removed from the simp set, presumably so that simp rewrites with the
   dedicated extraction lemmas instead of unfolding the extension. *)
lemma autodiff_extract_value:
  "Value (autodiff f x) = f x"
  by (simp del: hypext.simps add: Base_hypext)

lemma autodiff_extract_first:
  "First (autodiff f x) = deriv f x"
  by (simp del: hypext.simps add: Eps1_hypext)

lemma autodiff_extract_second:
  "Second (autodiff f x) = deriv (deriv f) x"
  by (simp del: hypext.simps add: Eps12_hypext)

text‹
  The derivative components of the result are actual derivatives if the function is sufficiently
  differentiable at that argument.
›
(* If f is differentiable at x, the First component is the derivative of f at x. *)
lemma autodiff_first_derivative:
  assumes "f field_differentiable (at x)"
  shows "(f has_field_derivative First (autodiff f x)) (at x)"
  by (simp add: autodiff_extract_first DERIV_deriv_iff_field_differentiable assms)

(* If f is twice differentiable at x, the Second component is the derivative of deriv f at x. *)
lemma autodiff_second_derivative:
  assumes "f twice_field_differentiable_at x"
  shows "((deriv f) has_field_derivative Second (autodiff f x)) (at x)"
  by (simp add: autodiff_extract_second DERIV_deriv_iff_field_differentiable assms deriv_field_differentiable_at)

subsubsection‹Composition›

text‹Composition of hyperdual extensions is the hyperdual extension of composition:›
(* Chain rule for extensions. f must be twice differentiable at the base of x, and g
   twice differentiable at f of that base. The proof compares the hyperdual components
   via hyperdual_eq_iff. *)
lemma hypext_compose:
  assumes "f twice_field_differentiable_at (Base x)"
      and "g twice_field_differentiable_at (f (Base x))"
    shows "(*h* (λx. g (f x))) x = (*h* g) ((*h* f) x)"
proof (simp add: hyperdual_eq_iff, intro conjI disjI2)
  (* First-order part: the ordinary chain rule. The same statement is discharged twice,
     presumably once for the Eps1 component and once for the Eps2 component. *)
  show goal1: "deriv (λx. g (f x)) (Base x) = deriv f (Base x) * deriv g (f (Base x))"
  proof -
    have "deriv (λx. g (f x)) (Base x) = deriv (g  f) (Base x)"
      by (simp add: comp_def)
    also have "... = deriv g (f (Base x)) * deriv f (Base x)"
      using assms by (simp add: deriv_chain once_field_differentiable_at)
    finally show ?thesis
      by (simp add: mult.commute deriv_chain)
  qed
  then show "deriv (λx. g (f x)) (Base x) = deriv f (Base x) * deriv g (f (Base x))" .

  (* Differentiability of deriv g after f at the base; used by the product rule below. *)
  have first_diff: "(λx. deriv g (f x)) field_differentiable at (Base x)"
    by (metis DERIV_chain2 assms deriv_field_differentiable_at field_differentiable_def once_field_differentiable_at)

  (* Chain rule applied to deriv g composed with f. *)
  have "deriv (deriv g  f) (Base x) = deriv (deriv g) (f (Base x)) * deriv f (Base x)"
    using deriv_chain assms once_field_differentiable_at deriv_field_differentiable_at
    by blast
  then have deriv_deriv_comp: "deriv (λx. deriv g (f x)) (Base x) = deriv (deriv g) (f (Base x)) * deriv f (Base x)"
    by (simp add: comp_def)

  (* Second-order part. The second derivative of the composition is first rewritten as the
     derivative of deriv f times deriv g after f, using the eventual equality of the first
     derivatives near the base. It is then expanded with the product rule and deriv_deriv_comp. *)
  have "deriv (deriv (λx. g (f x))) (Base x) = deriv ((λx. deriv f x * deriv g (f x))) (Base x)"
    using assms eventually_deriv_compose'[of f "Base x" g]
    by (simp add: mult.commute deriv_cong_ev)
  also have "... = deriv f (Base x) * deriv (λx. deriv g (f x)) (Base x) + deriv (deriv f) (Base x) * deriv g (f (Base x))"
    using assms(1) first_diff by (simp add: deriv_field_differentiable_at)
  also have "... = deriv f (Base x) * deriv (deriv g) (f (Base x)) * deriv f (Base x) + deriv (deriv f) (Base x) * deriv g (f (Base x))"
    using deriv_deriv_comp by simp
  finally show "Eps12 x * deriv (λx. g (f x)) (Base x) + Eps1 x * Eps2 x * deriv (deriv (λx. g (f x))) (Base x) =
    (Eps12 x * deriv f (Base x) + Eps1 x * Eps2 x * deriv (deriv f) (Base x)) * deriv g (f (Base x)) +
    Eps1 x * deriv f (Base x) * (Eps2 x * deriv f (Base x)) * deriv (deriv g) (f (Base x))"
    by (simp add: goal1 field_simps)
qed

subsection‹Concrete Instances›

subsubsection‹Constant›

text‹Component embedding is an extension of the constant function.›
(* Extending the constant function a gives the embedding of_comp a, whatever the argument. *)
lemma hypext_const [simp]:
  "(*h* (λx. a)) x = of_comp a"
  by (simp add: of_comp_def hyperdual_eq_iff)

(* Example: autodiff of a constant reports zero first and second derivatives. *)
lemma "autodiff (λx. a) = (λx. Derivs a 0 0)"
  by simp

subsubsection‹Identity›

text‹Identity is an extension of the component identity.›
(* Extending the identity function yields the identity on hyperduals. *)
lemma hypext_ident:
  "(*h* (λx. x)) x = x"
  by (simp add: hyperdual_eq_iff)

subsubsection‹Component Scalar Multiplication›

text‹Component scaling is an extension of component constant multiplication:›
(* Multiplication by a component constant k extends to scaleH by k. *)
lemma hypext_scaleH:
  "(*h* (λx. k * x)) x = k *H x"
  by (simp add: hyperdual_eq_iff)

(* Lifted form for a twice differentiable f, derived from hypext_compose. *)
lemma hypext_fun_scaleH:
  assumes "f twice_field_differentiable_at (Base x)"
  shows "(*h* (λx. k * f x)) x = k *H (*h* f) x"
  using assms by (simp add: hypext_compose hypext_scaleH)

text‹Unary minus is just an instance of constant multiplication:›
lemma hypext_uminus:
  "(*h* uminus) x = - x"
  using hypext_scaleH[of "-1" x] by simp

subsubsection‹Real Scalar Multiplication›

text‹Real scaling is an extension of component real scaling:›
(* Real scaling by k on components extends to real scaling by k on hyperduals. *)
lemma hypext_scaleR:
  "(*h* (λx. k *R x)) x = k *R x"
  by (auto simp add: hyperdual_eq_iff)

(* Lifted form for a twice differentiable f, derived from hypext_compose. *)
lemma hypext_fun_scaleR:
  assumes "f twice_field_differentiable_at (Base x)"
  shows "(*h* (λx. k *R f x)) x = k *R (*h* f) x"
  using assms by (simp add: hypext_compose hypext_scaleR)

subsubsection‹Addition›

text‹Addition of hyperdual extensions is a hyperdual extension of addition of functions.›
(* Sum rule. The first-order parts reduce to linearity of deriv (goal1, shown twice,
   presumably once for each first-order component). The second-order part additionally uses
   that deriv of f + g is eventually deriv f + deriv g near the base. *)
lemma hypext_fun_add:
  assumes "f twice_field_differentiable_at (Base x)"
      and "g twice_field_differentiable_at (Base x)"
    shows "(*h* (λx. f x + g x)) x = (*h* f) x + (*h* g) x"
proof (simp add: hyperdual_eq_iff distrib_left[symmetric], intro conjI disjI2)
  show goal1: "deriv (λx. f x + g x) (Base x) = deriv f (Base x) + deriv g (Base x)"
    by (simp add: assms once_field_differentiable_at distrib_left)
  then show "deriv (λx. f x + g x) (Base x) = deriv f (Base x) + deriv g (Base x)" .

  (* Second-order part: replace the second derivative of the sum by the derivative of
     deriv f + deriv g, then split it with deriv_add. *)
  have "deriv (deriv (λx. f x + g x)) (Base x) = deriv (λw. deriv f w + deriv g w) (Base x)"
    by (simp add: assms deriv_cong_ev eventually_deriv_add)
  moreover have "Eps12 x * deriv f (Base x) + Eps12 x * deriv g (Base x) = Eps12 x * deriv (λx. f x + g x) (Base x)"
    by (metis distrib_left goal1)
  ultimately show "Eps12 x * deriv (λx. f x + g x) (Base x) +
    Eps1 x * Eps2 x * deriv (deriv (λx. f x + g x)) (Base x) =
    Eps12 x * deriv f (Base x) + Eps1 x * Eps2 x * deriv (deriv f) (Base x) +
    (Eps12 x * deriv g (Base x) + Eps1 x * Eps2 x * deriv (deriv g) (Base x))"
    using deriv_add[OF deriv_field_differentiable_at deriv_field_differentiable_at, OF assms]
    by (simp add: distrib_left add.left_commute)
qed

(* Translation by a constant a extends to adding of_comp a. *)
lemma hypext_cadd [simp]:
  "(*h* (λx. x + a)) x = x + of_comp a"
  by (auto simp add: hyperdual_eq_iff of_comp_def)

(* Lifted form for a twice differentiable f, derived from hypext_compose. *)
lemma hypext_fun_cadd:
  assumes "f twice_field_differentiable_at (Base x)"
  shows "(*h* (λx. f x + a)) x = (*h* f) x + of_comp a"
  using assms hypext_compose[of f x "λx. x + a"] by simp

subsubsection‹Component Linear Function›

text‹The hyperdual linear function is an extension of the component linear function:›
(* Obtained from hypext_fun_add, with multiplication by k as the first summand and the
   constant a as the second. *)
lemma hypext_linear:
  "(*h* (λx. k * x + a)) x = k *H x + of_comp a"
  using hypext_fun_add[of "(*) k" x "λx. a"]
  by (simp add: hypext_scaleH)

(* Lifted form for a twice differentiable f, derived from hypext_compose. *)
lemma hypext_fun_linear:
  assumes "f twice_field_differentiable_at (Base x)"
    shows "(*h* (λx. k * f x + a)) x = k *H (*h* f) x + of_comp a"
  using assms hypext_compose[of f x "λx. k * x + a"] by (simp add: hypext_linear)

subsubsection‹Real Linear Function›

text‹We have the same for real scaling instead of component multiplication:›
(* Real-scaling analogue of hypext_linear, also obtained from hypext_fun_add. *)
lemma hypext_linearR:
  "(*h* (λx. k *R x + a)) x = k *R x + of_comp a"
  using hypext_fun_add[of "(*R) k" x "λx. a"]
  by (simp add: hypext_scaleR)

(* Lifted form for a twice differentiable f, derived from hypext_compose. *)
lemma hypext_fun_linearR:
  assumes "f twice_field_differentiable_at (Base x)"
    shows "(*h* (λx. k *R f x + a)) x = k *R (*h* f) x + of_comp a"
  using assms hypext_compose[of f x "λx. k *R x + a"] by (simp add: hypext_linearR)

subsubsection‹Multiplication›

text‹Extension of multiplication is multiplication of the functions' extensions.›
(* Product rule. The first-order parts follow from the derivative of a product. The
   second-order part uses the second derivative f g'' + 2 f' g' + f'' g, which is
   established in the inner proof below. *)
lemma hypext_fun_mult:
  assumes "f twice_field_differentiable_at (Base x)"
      and "g twice_field_differentiable_at (Base x)"
    shows "(*h* (λz. f z * g z)) x = (*h* f) x * (*h* g) x"
proof (simp add: hyperdual_eq_iff distrib_left[symmetric], intro conjI)
  show "Eps1 x * deriv (λz. f z * g z) (Base x) =
        f (Base x) * (Eps1 x * deriv g (Base x)) + Eps1 x * deriv f (Base x) * g (Base x)"
   and "Eps2 x * deriv (λz. f z * g z) (Base x) =
        f (Base x) * (Eps2 x * deriv g (Base x)) + Eps2 x * deriv f (Base x) * g (Base x)"
    using assms by (simp_all add: once_field_differentiable_at distrib_left)

  (* Second derivative of the product. The derivative of the product is eventually
     f g' + f' g near the base, and the product rule is then applied to each summand. *)
  have
    "deriv (deriv (λz. f z * g z)) (Base x) =
     f (Base x) * deriv (deriv g) (Base x) + 2 * deriv f (Base x) * deriv g (Base x) + deriv (deriv f) (Base x) * g (Base x)"
  proof -
    have "deriv (deriv (λz. f z * g z)) (Base x) = deriv (λz. f z * deriv g z + deriv f z * g z) (Base x)"
      using assms by (simp add: eventually_deriv_mult deriv_cong_ev)
    also have "... = (λz. f z * deriv (deriv g) z + deriv f z * deriv g z + deriv f z * deriv g z + deriv (deriv f) z * g z) (Base x)"
      by (simp add: assms deriv_field_differentiable_at field_differentiable_mult once_field_differentiable_at)
    finally show ?thesis
      by simp
  qed
  then show
    "Eps12 x * deriv (λz. f z * g z) (Base x) + Eps1 x * Eps2 x * deriv (deriv (λz. f z * g z)) (Base x) =
     2 * (Eps1 x * (Eps2 x * (deriv f (Base x) * deriv g (Base x)))) +
     f (Base x) * (Eps12 x * deriv g (Base x) + Eps1 x * Eps2 x * deriv (deriv g) (Base x)) +
     (Eps12 x * deriv f (Base x) + Eps1 x * Eps2 x * deriv (deriv f) (Base x)) * g (Base x)"
     using assms by (simp add: once_field_differentiable_at field_simps)
qed

subsubsection‹Sine and Cosine›

text‹The extended sin and cos at an arbitrary hyperdual.›

(* Extended sin at an explicit Hyperdual a b c d, written over the units. *)
lemma hypext_sin_Hyperdual:
  "(*h* sin) (Hyperdual a b c d) = sin a *H ba + (b *cos a) *H e1 + (c * cos a) *H e2 + (d * cos a - b * c * sin a) *H e12 "
  by (simp add: hypext_Hyperdual_eq_parts)

(* Extended cos at an explicit Hyperdual a b c d. The auxiliary fact moves the minus sign
   out of the e12 coefficient, so that the result can be stated with subtractions. *)
lemma hypext_cos_Hyperdual:
  "(*h* cos) (Hyperdual a b c d) = cos a *H ba - (b * sin a) *H e1 - (c * sin a) *H e2 - (d * sin a + b * c * cos a) *H e12 "
proof -
  have "of_comp (- (d * sin a) - b * c * cos a) * e12 = - (of_comp (d * sin a + b * c * cos a) * e12)"
    by (metis add_uminus_conv_diff minus_add_distrib mult_minus_left of_comp_minus)
  then show ?thesis
    by (simp add: hypext_Hyperdual_eq_parts of_comp_minus scaleH_times)
qed

(* Component simp rules for the extended sin at an arbitrary hyperdual. *)
lemma Eps1_hypext_sin [simp]:
  "Eps1 ((*h* sin) x) = Eps1 x * cos (Base x)"
  by simp

lemma Eps2_hypext_sin [simp]:
  "Eps2 ((*h* sin) x) = Eps2 x * cos (Base x)"
  by simp

lemma Eps12_hypext_sin [simp]:
  "Eps12 ((*h* sin) x) = Eps12 x * cos (Base x) - Eps1 x * Eps2 x * sin (Base x)"
  by simp

(* Extended sin and cos at arguments of the form x * e1, x * e2 and x * e12.
   Each lemma is proved by unfolding the unit and comparing components. *)
lemma hypext_sin_e1 [simp]:
  "(*h* sin) (x * e1) = e1 * x"
  by (simp add: e1_def hyperdual_eq_iff one_hyperdual_def)

lemma hypext_sin_e2 [simp]:
  "(*h* sin) (x * e2) = e2 * x"
  by (simp add: e2_def hyperdual_eq_iff one_hyperdual_def)

lemma hypext_sin_e12 [simp]:
  "(*h* sin) (x * e12) = e12 * x"
  by (simp add: e12_def hyperdual_eq_iff one_hyperdual_def)

lemma hypext_cos_e1 [simp]:
  "(*h* cos) (x * e1) = 1"
  by (simp add: e1_def hyperdual_eq_iff one_hyperdual_def)

lemma hypext_cos_e2 [simp]:
  "(*h* cos) (x * e2) = 1"
  by (simp add: e2_def hyperdual_eq_iff one_hyperdual_def)

lemma hypext_cos_e12 [simp]:
  "(*h* cos) (x * e12) = 1"
  by (simp add: e12_def hyperdual_eq_iff one_hyperdual_def)

text‹The extended sin and cos at @{term "β x"}.›

(* Extended sin at β x: sin x, cos x, cos x and minus sin x on the four units. *)
lemma hypext_sin_extract:
  "(*h* sin) (β x) = sin x *H ba + cos x *H e1 + cos x *H e2 - sin x *H e12"
  by (simp add: hypext_sin_Hyperdual of_comp_minus scaleH_times hyperdualx_def)

(* Extended cos at β x: cos x, minus sin x, minus sin x and minus cos x on the four units. *)
lemma hypext_cos_extract:
  "(*h* cos) (β x) = cos x *H ba - sin x *H e1 - sin x *H e2 - cos x *H e12"
  by (simp add: hypext_cos_Hyperdual hyperdualx_def)

text‹Extracting the extended sin components at @{term "β x"}.›

(* NOTE(review): no Eps1 extraction rule is stated for sin or for cos. Presumably Eps1_hypext
   (or Eps1_hypext_sin) together with hyperdualx_sel covers that case through simp. *)
lemma Base_hypext_sin_extract [simp]:
  "Base ((*h* sin) (β x)) = sin x"
  by (rule Base_hypext)

lemma Eps2_hypext_sin_extract [simp]:
  "Eps2 ((*h* sin) (β x)) = cos x"
  using Eps2_hypext[of sin] by simp

lemma Eps12_hypext_sin_extract [simp]:
  "Eps12 ((*h* sin) (β x)) = - sin x"
  using Eps12_hypext[of sin] by simp

text‹Extracting the extended cos components at @{term "β x"}.›

lemma Base_hypext_cos_extract [simp]:
  "Base ((*h* cos) (β x)) = cos x"
  by (rule Base_hypext)

lemma Eps2_hypext_cos_extract [simp]:
  "Eps2 ((*h* cos) (β x)) = - sin x"
  using Eps2_hypext[of cos] by simp

lemma Eps12_hypext_cos_extract [simp]:
  "Eps12 ((*h* cos) (β x)) = - cos x"
  using Eps12_hypext[of cos] by simp

text‹We get one of the key trigonometric properties for the extensions of sin and cos.›

lemma "((*h* sin) x)2 + ((*h* cos) x)2 = 1"
  by (simp add: hyperdual_eq_iff one_hyperdual_def power2_eq_square field_simps)

(* Example: the sum rule hypext_fun_add applied to sin and cos. *)
lemma "(*h* sin) x + (*h* cos) x = (*h* (λx. sin x + cos x)) x"
  by (simp add: hypext_fun_add)

subsubsection‹Exponential›

text‹The exponential function extension behaves as expected.›

(* Extended exp at an explicit Hyperdual a b c d. Every derivative of exp is exp. *)
lemma hypext_exp_Hyperdual:
  "(*h* exp) (Hyperdual a b c d) =
       exp a *H ba + (b * exp a) *H e1 + (c * exp a) *H e2 + (d * exp a + b * c * exp a) *H e12"
  by (simp add: hypext_Hyperdual_eq_parts)

(* At β x all four components are exp x. *)
lemma hypext_exp_extract:
  "(*h* exp) (β x) = exp x *H ba + exp x *H e1 + exp x *H e2 + exp x *H e12"
  by (simp add: hypext_extract_eq)

(* Extended exp at arguments x * e1, x * e2 and x * e12. *)
lemma hypext_exp_e1 [simp]:
  "(*h* exp) (x * e1) = 1 + e1 * x"
  by (simp add: e1_def hyperdual_eq_iff)

lemma hypext_exp_e2 [simp]:
  "(*h* exp) (x * e2) = 1 + e2 * x"
  by (simp add: e2_def hyperdual_eq_iff)

lemma hypext_exp_e12 [simp]:
  "(*h* exp) (x * e12) = 1 + e12 * x"
  by (simp add: e12_def hyperdual_eq_iff)

text‹Extracting the parts for the exponential function extension.›

(* NOTE(review): unlike sin and cos, there is no Base extraction lemma for exp. The general
   Base_hypext gives it directly. *)
lemma Eps1_hypext_exp_extract [simp]:
  "Eps1 ((*h* exp) (β x)) = exp x"
  using Eps1_hypext[of exp] by simp

lemma Eps2_hypext_exp_extract [simp]:
  "Eps2 ((*h* exp) (β x)) = exp x"
  using Eps2_hypext[of exp] by simp

lemma Eps12_hypext_exp_extract [simp]:
  "Eps12 ((*h* exp) (β x)) = exp x"
  using Eps12_hypext[of exp] by simp

subsubsection‹Square Root›
text‹Square root function extension.›

(* Extended sqrt at an explicit Hyperdual a b c d with a positive base. The hypothesis a > 0
   is presumably what lets simp evaluate the first and second derivatives of sqrt. *)
lemma hypext_sqrt_Hyperdual_Hyperdual:
  assumes "a > 0"
  shows "(*h* sqrt) (Hyperdual a b c d) =
         Hyperdual (sqrt a) (b * inverse (sqrt a) / 2) (c * inverse (sqrt a) / 2)
           (d * inverse (sqrt a) / 2 - b * c * inverse (sqrt a ^ 3) / 4)"
  by (simp add: assms hypext_Hyperdual_eq)

(* The same result written over the units ba, e1, e2 and e12. *)
lemma hypext_sqrt_Hyperdual:
      "a > 0  (*h* sqrt) (Hyperdual a b c d) =
       sqrt a *H ba + (b * inverse (sqrt a) / 2) *H e1 + (c * inverse (sqrt a) / 2) *H e2 +
       (d * inverse (sqrt a) / 2 - b * c * inverse (sqrt a ^ 3) / 4) *H e12"
  by (auto simp add: hypext_Hyperdual_eq_parts)

(* Extended sqrt at β x for positive x. *)
lemma hypext_sqrt_extract:
  "x > 0  (*h* sqrt) (β x) = sqrt x *H ba + (inverse (sqrt x) / 2) *H e1 +
        (inverse (sqrt x) / 2) *H e2 - (inverse (sqrt x ^ 3) / 4) *H e12"
  by (simp add: hypext_sqrt_Hyperdual hyperdualx_def of_comp_minus scaleH_times)

text‹Extracting the parts for the square root extension.›

lemma Eps1_hypext_sqrt_extract [simp]:
  "x > 0  Eps1 ((*h* sqrt) (β x)) = inverse (sqrt x) / 2"
  using Eps1_hypext[of sqrt] by simp

lemma Eps2_hypext_sqrt_extract [simp]:
  "x > 0  Eps2 ((*h* sqrt) (β x)) = inverse (sqrt x) / 2"
  using Eps2_hypext[of sqrt] by simp

lemma Eps12_hypext_sqrt_extract [simp]:
  "x > 0  Eps12 ((*h* sqrt) (β x)) = - (inverse (sqrt x ^ 3) / 4)"
  using Eps12_hypext[of sqrt] by simp

(* Example: the sum rule for sin and sqrt, valid when the base is positive. *)
lemma "Base x > 0  (*h* sin) x + (*h* sqrt) x = (*h* (λx. sin x + sqrt x)) x"
  by (simp add: hypext_fun_add)

subsubsection‹Natural Power›

(* Extending the n-th power function yields the n-th power on hyperduals. *)
lemma hypext_power:
  "(*h* (λx. x ^ n)) x = x ^ n"
  by (simp add: hyperdual_eq_iff hyperdual_power)

(* Lifted form for a twice differentiable f, derived from hypext_compose. *)
lemma hypext_fun_power:
  assumes "f twice_field_differentiable_at (Base x)"
    shows "(*h* (λx. (f x) ^ n)) x = ((*h* f) x) ^ n"
  using assms hypext_compose[of f x "λx. x ^ n"] by (simp add: hypext_power)

(* Explicit components of the extended power. The exponents n - 1 and n - 2 use natural
   number subtraction, which truncates at 0. For n = 0 and n = 1 the affected terms are still
   correct, because they are multiplied by of_nat n or by of_nat (n - 1), which are then 0. *)
lemma hypext_power_Hyperdual:
  "(*h* (λx. x ^ n)) (Hyperdual a b c d) =
        a ^ n *H ba + (of_nat n * b * a ^ (n - 1)) *H e1 + (of_nat n * c * a ^ (n - 1)) *H e2 +
        (d * (of_nat n * a ^ (n - 1)) + b * c * (of_nat n * of_nat (n - 1) * a ^ (n - 2))) *H e12"
  by (simp add: hypext_Hyperdual_eq_parts algebra_simps)

lemma hypext_power_Hyperdual_parts:
  "(*h* (λx. x ^ n)) (a *H ba + b *H e1 + c *H e2 + d *H e12) =
        a ^ n *H ba + (of_nat n * b * a ^ (n - 1)) *H e1 + (of_nat n * c * a ^ (n - 1)) *H e2 +
        (d * (of_nat n * a ^ (n - 1)) + b * c * (of_nat n * of_nat (n - 1) * a ^ (n - 2))) *H e12"
  by (simp add: Hyperdual_eq [symmetric] hypext_power_Hyperdual)

lemma hypext_power_extract:
  "(*h* (λx. x ^ n)) (β x) =
      x ^ n *H ba + (of_nat n * x ^ (n - 1)) *H e1 + (of_nat n * x ^ (n - 1)) *H e2 +
      (of_nat n * of_nat (n - 1) * x ^ (n - 2)) *H e12"
  by (simp add: hypext_extract_eq)

(* Component simp rules for the extended power at an arbitrary hyperdual. *)
lemma Eps1_hypext_power [simp]:
  "Eps1 ((*h* (λx. x ^ n)) x) = of_nat n * Eps1 x * (Base x) ^ (n - 1)"
  by simp

lemma Eps2_hypext_power [simp]:
  "Eps2 ((*h* (λx. x ^ n)) x) = of_nat n * Eps2 x * (Base x) ^ (n - 1)"
  by simp

lemma Eps12_hypext_power [simp]:
  "Eps12 ((*h* (λx. x ^ n)) x) =
   Eps12 x * (of_nat n * Base x ^ (n - 1)) + Eps1 x * Eps2 x * (of_nat n * of_nat (n - 1) * Base x ^ (n - 2))"
  by simp

subsubsection‹Inverse›

(* At a nonzero base, extending the component inverse gives the hyperdual inverse. *)
lemma hypext_inverse:
  assumes "Base x  0"
  shows "(*h* inverse) x = inverse x"
  using assms by (simp add: hyperdual_eq_iff inverse_eq_divide)

(* Lifted form for a twice differentiable f that is nonzero at the base, via hypext_compose. *)
lemma hypext_fun_inverse:
  assumes "f twice_field_differentiable_at (Base x)"
      and "f (Base x)  0"
    shows "(*h* (λx. inverse (f x))) x = inverse ((*h* f) x)"
  using assms hypext_compose[of f x inverse] by (simp add: hypext_inverse)

(* Explicit components of the extended inverse, for a nonzero base. *)
lemma hypext_inverse_Hyperdual:
  "a  0 
    (*h* inverse) (Hyperdual a b c d) =
    Hyperdual (inverse a) (- (b / a2)) (- (c / a2)) (2 * b * c / (a ^ 3) - d / a2)"
  by (simp add: hypext_Hyperdual_eq divide_inverse)

lemma hypext_inverse_Hyperdual_parts:
  "a  0 
    (*h* inverse) (a *H ba + b *H e1 + c *H e2 + d *H e12) =
    inverse a *H ba + - (b / a2) *H e1 + - (c / a2) *H e2 + (2 * b * c / a ^ 3 - d / a2) *H e12"
  by (metis Hyperdual_eq hypext_inverse_Hyperdual)

(* The same formula for the hyperdual inverse itself, obtained through hypext_inverse. *)
lemma inverse_Hyperdual_parts:
  "(a::'a::real_normed_field)  0 
    inverse (a *H ba + b *H e1 + c *H e2 + d *H e12) =
    inverse a *H ba + - (b / a2) *H e1 + - (c / a2) *H e2 + (2 * b * c / a ^ 3 - d / a2) *H e12"
  by (metis Hyperdual_eq hyperdual.sel(1) hypext_inverse hypext_inverse_Hyperdual_parts)

(* Extended inverse and hyperdual inverse at β x, for nonzero x. *)
lemma hypext_inverse_extract:
  "x  0  (*h* inverse) (β x) = inverse x *H ba - (1 / x2) *H e1 - (1 / x2) *H e2 + (2 / x ^ 3) *H e12"
  by (simp add: hypext_extract_eq divide_inverse of_comp_minus scaleH_times)

lemma inverse_extract:
  "x  0  inverse (β x) = inverse x *H ba - (1 / x2) *H e1 - (1 / x2) *H e2 + (2 / x ^ 3) *H e12"
  by (metis hyperdual.sel(1) hyperdualx_def hypext_inverse hypext_inverse_extract)

(* Component simp rules. Eps1_hypext_inverse is about the extension of inverse, while
   Eps1_inverse is about the hyperdual inverse. *)
lemma Eps1_hypext_inverse [simp]:
  "Base x  0  Eps1 ((*h* inverse) x) = - Eps1 x * (1 / (Base x)2)"
  by simp
lemma Eps1_inverse [simp]:
  "Base (x::'a::real_normed_field hyperdual)  0  Eps1 (inverse x) = - Eps1 x * (1 / (Base x)2)"
  by simp

(* NOTE(review): despite its name, this lemma states a fact about the hyperdual inverse,
   like Eps1_inverse, not about the extension of inverse. The same holds for
   Eps12_hypext_inverse. The names are kept so that existing references keep working. *)
lemma Eps2_hypext_inverse [simp]:
  "Base (x::'a::real_normed_field hyperdual)  0  Eps2 (inverse x) = - Eps2 x * (1 / (Base x)2)"
  by simp

lemma Eps12_hypext_inverse [simp]:
  "Base  (x::'a::real_normed_field hyperdual)  0
    Eps12 (inverse x) = Eps1 x * Eps2 x * (2/ (Base x ^ 3)) - Eps12 x / (Base x)2"
  by simp

subsubsection‹Division›

(* Quotient rule, derived by treating f / g as f times the inverse of g. The inverse of g
   is twice differentiable at the base because g is and g is nonzero there, so
   hypext_fun_mult and hypext_fun_inverse combine. *)
lemma hypext_fun_divide:
  assumes "f twice_field_differentiable_at (Base x)"
      and "g twice_field_differentiable_at (Base x)"
      and "g (Base x)  0"
    shows "(*h* (λx. f x / g x)) x = (*h* f) x / (*h* g) x"
proof -
  have "(λx. inverse (g x)) twice_field_differentiable_at Base x"
    by (simp add: assms(2) assms(3) twice_field_differentiable_at_compose)
  moreover have "(*h* f) x * (*h* (λx. inverse (g x))) x = (*h* f) x * inverse ((*h* g) x)"
    by (simp add: assms(2) assms(3) hypext_fun_inverse)
  ultimately have "(*h* (λx. f x * inverse (g x))) x = (*h* f) x * inverse ((*h* g) x)"
    by (simp add: assms(1) hypext_fun_mult)
  then show ?thesis
    by (simp add: divide_inverse hyp_divide_inverse)
qed

subsubsection‹Polynomial›

(* The extension of a polynomial with coefficients coef 0 .. coef (n - 1) is the same
   polynomial over hyperduals, with each coefficient acting through scaleH.
   The proof is by induction on n. The step splits off the top term coef n * x ^ n, which is
   twice differentiable, and combines hypext_fun_add, hypext_fun_scaleH and hypext_power. *)
lemma hypext_polyn:
  fixes coef :: "nat  'a :: {real_normed_field}"
    and n :: nat
  shows "(*h* (λx. i<n. coef i * x^i)) x = (i<n. (coef i) *H (x^i))"
proof (induction n)
  case 0
  then show ?case
    by (simp add: zero_hyperdual_def)
next
  case hyp: (Suc n)

  have "(λx. i<Suc n. coef i * x ^ i) = (λx. (i<n. coef i * x ^ i) + coef n * x ^ n)"
   and "(λx. i<Suc n. coef i *H x ^ i) = (λx. (i<n. coef i *H x ^ i) + coef n *H x ^ n)"
    by (simp_all add: field_simps)

  then show ?case
  proof (simp)
    have "(λx. coef n * x ^ n) twice_field_differentiable_at Base x"
      using twice_field_differentiable_at_compose[of "λx. x ^ n" "Base x" "(*) (coef n)"]
      by simp
    then have "(*h* (λx. (i<n. coef i * x ^ i) + coef n * x ^ n)) x =
          (*h* (λx. (i<n. coef i * x ^ i))) x + (*h* (λx. coef n * x ^ n)) x"
      by (simp add: hypext_fun_add)
    moreover have "(*h* (λx. coef n * x ^ n)) x = coef n *H x ^ n"
      by (simp add: hypext_fun_scaleH hypext_power)
    ultimately have "(*h* (λx. (i<n. coef i * x ^ i) + coef n * x ^ n)) x = (i<n. coef i *H x ^ i) + coef n *H x ^ n"
      using hyp by simp
    then show "(*h* (λx. (i<n. coef i * x ^ i) + coef n * x ^ n)) x = (i<n. coef i *H x ^ i) + coef n *H x ^ n"
      by simp
  qed
qed

(* Lifted form, with a twice differentiable f substituted for the variable, via hypext_compose. *)
lemma hypext_fun_polyn:
    fixes coef :: "nat  'a :: {real_normed_field}"
      and n :: nat
  assumes "f twice_field_differentiable_at (Base x)"
    shows "(*h* (λx. i<n. coef i * (f x)^i)) x = (i<n. (coef i) *H (((*h* f) x)^i))"
  using assms hypext_compose[of f x "λx. (i<n. coef i * x^i)"] by (simp add: hypext_polyn)

end