From 37e2b34e0ebe01217124cc5d6d7c10134cd1edd7 Mon Sep 17 00:00:00 2001
From: "Vladimir V. Kisil" <kisilv@maths.leeds.ac.uk>
Date: Mon, 2 Feb 2015 08:30:34 +0000
Subject: [PATCH 2/2] Explicit derivation of functions.

Some function cannot be cleanly differentiated through the chain rule.
For example, it is natural to define derivative of the absolute value as

(abs(f))'=(f'*f.conjugate()+f*f'.conjugate())/2/abs(f)

This patch adds a possibility to define derivatives of functions in this way.
In particular the derivative of abs(), Order(), real_part(), imag_part() and
conjugate() are defined.

For example, conjugate of a derivative with respect of a real symbol
If x is real then U.diff(x)-I*V.diff(x) represents both
conjugate(U+I*V).diff(x) and conjugate((U+I*V).diff(x))
Thus in this patch we use the rule

conjugate(f)'=conjugate(f')

for a derivative with respect to the real symbol.

Signed-off-by: Vladimir V. Kisil <kisilv@maths.leeds.ac.uk>
---
 check/exam_inifcns.cpp  | 66 +++++++++++++++++++++++++++++++++++++++++++++++++
 doc/tutorial/ginac.texi | 26 ++++++++++++++++---
 ginac/function.cppy     | 33 +++++++++++++++++++++----
 ginac/function.hppy     |  6 +++++
 ginac/function.py       |  2 +-
 ginac/inifcns.cpp       | 57 +++++++++++++++++++++++++++++++++++++++++-
 6 files changed, 180 insertions(+), 10 deletions(-)

diff --git a/check/exam_inifcns.cpp b/check/exam_inifcns.cpp
index 19ad1b9..a0acb2d 100644
--- a/check/exam_inifcns.cpp
+++ b/check/exam_inifcns.cpp
@@ -343,6 +343,71 @@ static unsigned inifcns_consist_various()
 	return result;
 }
 
+/* Several tests for derivetives */
+static unsigned inifcns_consist_derivatives()
+{
+	unsigned result = 0;
+	symbol z, w;
+	realsymbol x;
+	ex e, e1;
+
+	e=pow(x,z).conjugate().diff(x);
+	e1=pow(x,z).conjugate()*z.conjugate()/x;
+	if (! (e-e1).normal().is_zero() ) {
+		clog << "ERROR: pow(x,z).conjugate().diff(x) " << e << " != " << e1 << endl;
+		++result;
+	}
+
+	e=pow(w,z).conjugate().diff(w);
+	e1=pow(w,z).conjugate()*z.conjugate()/w;
+	if ( (e-e1).normal().is_zero() ) {
+		clog << "ERROR: pow(w,z).conjugate().diff(w) " << e << " = " << e1 << endl;
+		++result;
+	}
+
+	e=atanh(x).imag_part().diff(x);
+	if (! e.is_zero() ) {
+		clog << "ERROR: atanh(x).imag_part().diff(x) " << e << " != 0" << endl;
+		++result;
+	}
+
+	e=atanh(w).imag_part().diff(w);
+	if ( e.is_zero() ) {
+		clog << "ERROR: atanh(w).imag_part().diff(w) " << e << " = 0" << endl;
+		++result;
+	}
+
+	e=atanh(x).real_part().diff(x);
+	e1=pow(1-x*x,-1);
+	if (! (e-e1).normal().is_zero() ) {
+		clog << "ERROR: atanh(x).real_part().diff(x) " << e << " != " << e1 << endl;
+		++result;
+	}
+
+	e=atanh(w).real_part().diff(w);
+	e1=pow(1-w*w,-1);
+	if ( (e-e1).normal().is_zero() ) {
+		clog << "ERROR: atanh(w).real_part().diff(w) " << e << " = " << e1 << endl;
+		++result;
+	}
+
+	e=abs(log(z)).diff(z);
+	e1=(conjugate(log(z))/z+log(z)/conjugate(z))/abs(log(z))/2;
+	if (! (e-e1).normal().is_zero() ) {
+		clog << "ERROR: abs(log(z)).diff(z) " << e << " != " << e1 << endl;
+		++result;
+	}
+
+	e=Order(pow(x,4)).diff(x);
+	e1=Order(pow(x,3));
+	if (! (e-e1).normal().is_zero() ) {
+		clog << "ERROR: Order(pow(x,4)).diff(x) " << e << " != " << e1 << endl;
+		++result;
+	}
+
+	return result;
+}
+
 unsigned exam_inifcns()
 {
 	unsigned result = 0;
@@ -357,6 +422,7 @@ unsigned exam_inifcns()
 	result += inifcns_consist_exp();  cout << '.' << flush;
 	result += inifcns_consist_log();  cout << '.' << flush;
 	result += inifcns_consist_various();  cout << '.' << flush;
+	result += inifcns_consist_derivatives();  cout << '.' << flush;
 	
 	return result;
 }
diff --git a/doc/tutorial/ginac.texi b/doc/tutorial/ginac.texi
index 21e31b2..3ac5398 100644
--- a/doc/tutorial/ginac.texi
+++ b/doc/tutorial/ginac.texi
@@ -7103,6 +7103,25 @@ specifies which parameter to differentiate in a partial derivative in
 case the function has more than one parameter, and its main application
 is for correct handling of the chain rule.
 
+Derivatives of some functions, for example @code{abs()} and
+@code{Order()}, could not be evaluated through the chain rule. In such
+cases the full derivative may be specified as shown for @code{Order()}:
+
+@example
+static ex Order_expl_derivative(const ex & arg, const symbol & s)
+@{
+	return Order(arg.diff(s));
+@}
+@end example
+
+That is, we need to supply a procedure, which returns the expression of
+derivative with respect to the variable @code{s} for the argument
+@code{arg}. This procedure need to be registered with the function
+through the option @code{expl_derivative_func} (see the next
+Subsection). In contrast, a partial derivative, e.g. as was defined for
+@code{cos()} above, needs to be registered through the option
+@code{derivative_func}. 
+
 An implementation of the series expansion is not needed for @code{cos()} as
 it doesn't have any poles and GiNaC can do Taylor expansion by itself (as
 long as it knows what the derivative of @code{cos()} is). @code{tan()}, on
@@ -7138,14 +7157,15 @@ functions without any special options.
 eval_func(<C++ function>)
 evalf_func(<C++ function>)
 derivative_func(<C++ function>)
+expl_derivative_func(<C++ function>)
 series_func(<C++ function>)
 conjugate_func(<C++ function>)
 @end example
 
 These specify the C++ functions that implement symbolic evaluation,
-numeric evaluation, partial derivatives, and series expansion, respectively.
-They correspond to the GiNaC methods @code{eval()}, @code{evalf()},
-@code{diff()} and @code{series()}.
+numeric evaluation, partial derivatives, explicit derivative, and series
+expansion, respectively.  They correspond to the GiNaC methods
+@code{eval()}, @code{evalf()}, @code{diff()} and @code{series()}.
 
 The @code{eval_func()} function needs to use @code{.hold()} if no further
 automatic evaluation is desired or possible.
diff --git a/ginac/function.cppy b/ginac/function.cppy
index d8a261f..dba9f4e 100644
--- a/ginac/function.cppy
+++ b/ginac/function.cppy
@@ -79,7 +79,7 @@ void function_options::initialize()
 	set_name("unnamed_function", "\\\\mbox{unnamed}");
 	nparams = 0;
 	eval_f = evalf_f = real_part_f = imag_part_f = conjugate_f = expand_f
-		= derivative_f = power_f = series_f = 0;
+		= derivative_f = expl_derivative_f = power_f = series_f = 0;
 	info_f = 0;
 	evalf_params_first = true;
 	use_return_type = false;
@@ -90,6 +90,7 @@ void function_options::initialize()
 	imag_part_use_exvector_args = false;
 	expand_use_exvector_args = false;
 	derivative_use_exvector_args = false;
+	expl_derivative_use_exvector_args = false;
 	power_use_exvector_args = false;
 	series_use_exvector_args = false;
 	print_use_exvector_args = false;
@@ -630,10 +631,10 @@ ex function::derivative(const symbol & s) const
 {
 	ex result;
 
-	if (serial == Order_SERIAL::serial) {
-		// Order Term function only differentiates the argument
-		return Order(seq[0].diff(s));
-	} else {
+	try {
+		// Explicit derivation
+		result = expl_derivative(s);
+	} catch (...) {
 		// Chain rule
 		ex arg_diff;
 		size_t num = seq.size();
@@ -752,6 +753,28 @@ ex function::pderivative(unsigned diff_param) const // partial differentiation
 	throw(std::logic_error("function::pderivative(): no diff function defined"));
 }
 
+ex function::expl_derivative(const symbol & s) const // explicit differentiation
+{
+	GINAC_ASSERT(serial<registered_functions().size());
+	const function_options &opt = registered_functions()[serial];
+
+	// No explicit derivative defined? Then this function shall not be called!
+	if (opt.expl_derivative_f == NULL)
+		throw(std::logic_error("function::expl_derivative(): explicit derivation is called, but no such function defined"));
+
+	current_serial = serial;
+	if (opt.expl_derivative_use_exvector_args)
+		return ((expl_derivative_funcp_exvector)(opt.expl_derivative_f))(seq, s);
+	switch (opt.nparams) {
+		// the following lines have been generated for max. @maxargs@ parameters
++++ for N in range(1, maxargs + 1):
+		case @N@:
+			return ((expl_derivative_funcp_@N@)(opt.expl_derivative_f))(@seq('seq[%(n)d]', N, 0)@, s);
+---
+		// end of generated lines
+	}
+}
+
 ex function::power(const ex & power_param) const // power of function
 {
 	GINAC_ASSERT(serial<registered_functions().size());
diff --git a/ginac/function.hppy b/ginac/function.hppy
index 6259d7a..971786d 100644
--- a/ginac/function.hppy
+++ b/ginac/function.hppy
@@ -59,6 +59,7 @@ typedef ex (* real_part_funcp)();
 typedef ex (* imag_part_funcp)();
 typedef ex (* expand_funcp)();
 typedef ex (* derivative_funcp)();
+typedef ex (* expl_derivative_funcp)();
 typedef ex (* power_funcp)();
 typedef ex (* series_funcp)();
 typedef void (* print_funcp)();
@@ -73,6 +74,7 @@ typedef ex (* real_part_funcp_@N@)( @args@ );
 typedef ex (* imag_part_funcp_@N@)( @args@ );
 typedef ex (* expand_funcp_@N@)( @args@, unsigned );
 typedef ex (* derivative_funcp_@N@)( @args@, unsigned );
+typedef ex (* expl_derivative_funcp_@N@)( @args@, const symbol & );
 typedef ex (* power_funcp_@N@)( @args@, const ex & );
 typedef ex (* series_funcp_@N@)( @args@, const relational &, int, unsigned );
 typedef void (* print_funcp_@N@)( @args@, const print_context & );
@@ -87,6 +89,7 @@ typedef ex (* @fp@_funcp_exvector)(const exvector &);
 ---
 typedef ex (* expand_funcp_exvector)(const exvector &, unsigned);
 typedef ex (* derivative_funcp_exvector)(const exvector &, unsigned);
+typedef ex (* expl_derivative_funcp_exvector)(const exvector &, const symbol &);
 typedef ex (* power_funcp_exvector)(const exvector &, const ex &);
 typedef ex (* series_funcp_exvector)(const exvector &, const relational &, int, unsigned);
 typedef void (* print_funcp_exvector)(const exvector &, const print_context &);
@@ -159,6 +162,7 @@ protected:
 	imag_part_funcp imag_part_f;
 	expand_funcp expand_f;
 	derivative_funcp derivative_f;
+	expl_derivative_funcp expl_derivative_f;
 	power_funcp power_f;
 	series_funcp series_f;
 	std::vector<print_funcp> print_dispatch_table;
@@ -182,6 +186,7 @@ protected:
 	bool imag_part_use_exvector_args;
 	bool expand_use_exvector_args;
 	bool derivative_use_exvector_args;
+	bool expl_derivative_use_exvector_args;
 	bool power_use_exvector_args;
 	bool series_use_exvector_args;
 	bool print_use_exvector_args;
@@ -251,6 +256,7 @@ protected:
 	// non-virtual functions in this class
 protected:
 	ex pderivative(unsigned diff_param) const; // partial differentiation
+	ex expl_derivative(const symbol & s) const; // partial differentiation
 	static std::vector<function_options> & registered_functions();
 	bool lookup_remember_table(ex & result) const;
 	void store_remember_table(ex const & result) const;
diff --git a/ginac/function.py b/ginac/function.py
index 3f5e54e..465976b 100755
--- a/ginac/function.py
+++ b/ginac/function.py
@@ -2,7 +2,7 @@
 # encoding: utf-8
 
 maxargs = 14
-methods = "eval evalf conjugate real_part imag_part expand derivative power series info print".split()
+methods = "eval evalf conjugate real_part imag_part expand derivative expl_derivative power series info print".split()
 
 import sys, os, optparse
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'scripts'))
diff --git a/ginac/inifcns.cpp b/ginac/inifcns.cpp
index 84cd285..a3ec4e0 100644
--- a/ginac/inifcns.cpp
+++ b/ginac/inifcns.cpp
@@ -24,6 +24,7 @@
 #include "ex.h"
 #include "constant.h"
 #include "lst.h"
+#include "fderivative.h"
 #include "matrix.h"
 #include "mul.h"
 #include "power.h"
@@ -66,6 +67,20 @@ static ex conjugate_conjugate(const ex & arg)
 	return arg;
 }
 
+// If x is real then U.diff(x)-I*V.diff(x) represents both conjugate(U+I*V).diff(x) 
+// and conjugate((U+I*V).diff(x))
+static ex conjugate_expl_derivative(const ex & arg, const symbol & s)
+{
+	if (s.info(info_flags::real))
+		return conjugate(arg.diff(s));
+	else {
+		unsigned this_serial = function::current_serial;
+		exvector vec_arg;
+		vec_arg.push_back(arg);
+		return fderivative(this_serial,0,vec_arg).hold()*arg.diff(s);
+	}
+}
+
 static ex conjugate_real_part(const ex & arg)
 {
 	return arg.real_part();
@@ -115,6 +130,7 @@ static bool conjugate_info(const ex & arg, unsigned inf)
 
 REGISTER_FUNCTION(conjugate_function, eval_func(conjugate_eval).
                                       evalf_func(conjugate_evalf).
+                                      expl_derivative_func(conjugate_expl_derivative).
                                       info_func(conjugate_info).
                                       print_func<print_latex>(conjugate_print_latex).
                                       conjugate_func(conjugate_conjugate).
@@ -159,8 +175,22 @@ static ex real_part_imag_part(const ex & arg)
 	return 0;
 }
 
+// If x is real then Re(e).diff(x) is equal to Re(e.diff(x)) 
+static ex real_part_expl_derivative(const ex & arg, const symbol & s)
+{
+	if (s.info(info_flags::real))
+		return real_part_function(arg.diff(s));
+	else {
+		unsigned this_serial = function::current_serial;
+		exvector vec_arg;
+		vec_arg.push_back(arg);
+		return fderivative(this_serial,0,vec_arg).hold()*arg.diff(s);
+	}
+}
+
 REGISTER_FUNCTION(real_part_function, eval_func(real_part_eval).
                                       evalf_func(real_part_evalf).
+                                      expl_derivative_func(real_part_expl_derivative).
                                       print_func<print_latex>(real_part_print_latex).
                                       conjugate_func(real_part_conjugate).
                                       real_part_func(real_part_real_part).
@@ -204,8 +234,22 @@ static ex imag_part_imag_part(const ex & arg)
 	return 0;
 }
 
+// If x is real then Im(e).diff(x) is equal to Im(e.diff(x)) 
+static ex imag_part_expl_derivative(const ex & arg, const symbol & s)
+{
+	if (s.info(info_flags::real))
+		return imag_part_function(arg.diff(s));
+	else {
+		unsigned this_serial = function::current_serial;
+		exvector vec_arg;
+		vec_arg.push_back(arg);
+		return fderivative(this_serial,0,vec_arg).hold()*arg.diff(s);
+	}
+}
+
 REGISTER_FUNCTION(imag_part_function, eval_func(imag_part_eval).
                                       evalf_func(imag_part_evalf).
+                                      expl_derivative_func(imag_part_expl_derivative).
                                       print_func<print_latex>(imag_part_print_latex).
                                       conjugate_func(imag_part_conjugate).
                                       real_part_func(imag_part_real_part).
@@ -275,6 +319,12 @@ static ex abs_expand(const ex & arg, unsigned options)
 		return abs(arg).hold();
 }
 
+static ex abs_expl_derivative(const ex & arg, const symbol & s)
+{
+	ex diff_arg = arg.diff(s);
+	return (diff_arg*arg.conjugate()+arg*diff_arg.conjugate())/2/abs(arg);
+}
+
 static void abs_print_latex(const ex & arg, const print_context & c)
 {
 	c.s << "{|"; arg.print(c); c.s << "|}";
@@ -341,6 +391,7 @@ bool abs_info(const ex & arg, unsigned inf)
 REGISTER_FUNCTION(abs, eval_func(abs_eval).
                        evalf_func(abs_evalf).
                        expand_func(abs_expand).
+                       expl_derivative_func(abs_expl_derivative).
                        info_func(abs_info).
                        print_func<print_latex>(abs_print_latex).
                        print_func<print_csrc_float>(abs_print_csrc_float).
@@ -977,11 +1028,15 @@ static ex Order_imag_part(const ex & x)
 	return Order(x).hold();
 }
 
-// Differentiation is handled in function::derivative because of its special requirements
+static ex Order_expl_derivative(const ex & arg, const symbol & s)
+{
+	return Order(arg.diff(s));
+}
 
 REGISTER_FUNCTION(Order, eval_func(Order_eval).
                          series_func(Order_series).
                          latex_name("\\mathcal{O}").
+                         expl_derivative_func(Order_expl_derivative).
                          conjugate_func(Order_conjugate).
                          real_part_func(Order_real_part).
                          imag_part_func(Order_imag_part));
-- 
2.1.4

