From 900831f0aefb53e414d70dfcd858be16e75281fa Mon Sep 17 00:00:00 2001
From: "Vladimir V. Kisil" <kisilv@maths.leeds.ac.uk>
Date: Mon, 4 Nov 2013 11:05:02 +0000
Subject: [PATCH 1/4] Implicit derivation of functions.

Some function cannot be cleanly differentiated through the chain rule.
For example, it is natural to define derivative of the absolute value as

(abs(f))'=(f'*f.conjugate()+f*f'.conjugate())/2/abs(f)

This patch adds a possibility to define derivatives of functions in this way.
In particular the derivative of abs() is defined.


Signed-off-by: Vladimir V. Kisil <kisilv@maths.leeds.ac.uk>
---
 ginac/function.cppy | 56 +++++++++++++++++++++++++++++++++++++++--------------
 ginac/function.hppy |  8 +++++++-
 ginac/function.py   |  2 +-
 ginac/inifcns.cpp   |  7 +++++++
 4 files changed, 57 insertions(+), 16 deletions(-)

diff --git a/ginac/function.cppy b/ginac/function.cppy
index b8259f9..c9d0cae 100644
--- a/ginac/function.cppy
+++ b/ginac/function.cppy
@@ -79,7 +79,7 @@ void function_options::initialize()
 	set_name("unnamed_function", "\\\\mbox{unnamed}");
 	nparams = 0;
 	eval_f = evalf_f = real_part_f = imag_part_f = conjugate_f = expand_f
-		= derivative_f = power_f = series_f = 0;
+		= derivative_f = impl_derivative_f = power_f = series_f = 0;
 	info_f = 0;
 	evalf_params_first = true;
 	use_return_type = false;
@@ -90,6 +90,7 @@ void function_options::initialize()
 	imag_part_use_exvector_args = false;
 	expand_use_exvector_args = false;
 	derivative_use_exvector_args = false;
+	impl_derivative_use_exvector_args = false;
 	power_use_exvector_args = false;
 	series_use_exvector_args = false;
 	print_use_exvector_args = false;
@@ -634,21 +635,26 @@ ex function::derivative(const symbol & s) const
 		// Order Term function only differentiates the argument
 		return Order(seq[0].diff(s));
 	} else {
-		// Chain rule
-		ex arg_diff;
-		size_t num = seq.size();
-		for (size_t i=0; i<num; i++) {
-			arg_diff = seq[i].diff(s);
-			// We apply the chain rule only when it makes sense.  This is not
-			// just for performance reasons but also to allow functions to
-			// throw when differentiated with respect to one of its arguments
-			// without running into trouble with our automatic full
-			// differentiation:
-			if (!arg_diff.is_zero())
-				result += pderivative(i)*arg_diff;
+		try {
+			// Implicit derivation
+			result = impl_derivative(s);
+		} catch (...) {
+			// Chain rule
+			ex arg_diff;
+			size_t num = seq.size();
+			for (size_t i=0; i<num; i++) {
+				arg_diff = seq[i].diff(s);
+				// We apply the chain rule only when it makes sense.  This is not
+				// just for performance reasons but also to allow functions to
+				// throw when differentiated with respect to one of its arguments
+				// without running into trouble with our automatic full
+				// differentiation:
+				if (!arg_diff.is_zero())
+					result += pderivative(i)*arg_diff;
+			}
 		}
+		return result;
 	}
-	return result;
 }
 
 int function::compare_same_type(const basic & other) const
@@ -752,6 +758,28 @@ ex function::pderivative(unsigned diff_param) const // partial differentiation
 	throw(std::logic_error("function::pderivative(): no diff function defined"));
 }
 
+ex function::impl_derivative(const symbol & s) const // implicit differentiation
+{
+	GINAC_ASSERT(serial<registered_functions().size());
+	const function_options &opt = registered_functions()[serial];
+
+	// No implicit derivative defined? Then this function shall not be called!
+	if (opt.impl_derivative_f == NULL)
+		throw(std::logic_error("function::impl_derivative(): implicit derivation is called, but no such function defined"));
+
+	current_serial = serial;
+	if (opt.impl_derivative_use_exvector_args)
+		return ((impl_derivative_funcp_exvector)(opt.impl_derivative_f))(seq, s);
+	switch (opt.nparams) {
+		// the following lines have been generated for max. @maxargs@ parameters
++++ for N in range(1, maxargs + 1):
+		case @N@:
+			return ((impl_derivative_funcp_@N@)(opt.impl_derivative_f))(@seq('seq[%(n)d]', N, 0)@, s);
+---
+		// end of generated lines
+	}
+}
+
 ex function::power(const ex & power_param) const // power of function
 {
 	GINAC_ASSERT(serial<registered_functions().size());
diff --git a/ginac/function.hppy b/ginac/function.hppy
index 3d953b9..153a4f2 100644
--- a/ginac/function.hppy
+++ b/ginac/function.hppy
@@ -59,6 +59,7 @@ typedef ex (* real_part_funcp)();
 typedef ex (* imag_part_funcp)();
 typedef ex (* expand_funcp)();
 typedef ex (* derivative_funcp)();
+typedef ex (* impl_derivative_funcp)();
 typedef ex (* power_funcp)();
 typedef ex (* series_funcp)();
 typedef void (* print_funcp)();
@@ -73,6 +74,7 @@ typedef ex (* real_part_funcp_@N@)( @args@ );
 typedef ex (* imag_part_funcp_@N@)( @args@ );
 typedef ex (* expand_funcp_@N@)( @args@, unsigned );
 typedef ex (* derivative_funcp_@N@)( @args@, unsigned );
+typedef ex (* impl_derivative_funcp_@N@)( @args@, const symbol & );
 typedef ex (* power_funcp_@N@)( @args@, const ex & );
 typedef ex (* series_funcp_@N@)( @args@, const relational &, int, unsigned );
 typedef void (* print_funcp_@N@)( @args@, const print_context & );
@@ -87,6 +89,7 @@ typedef ex (* @fp@_funcp_exvector)(const exvector &);
 ---
 typedef ex (* expand_funcp_exvector)(const exvector &, unsigned);
 typedef ex (* derivative_funcp_exvector)(const exvector &, unsigned);
+typedef ex (* impl_derivative_funcp_exvector)(const exvector &, const symbol &);
 typedef ex (* power_funcp_exvector)(const exvector &, const ex &);
 typedef ex (* series_funcp_exvector)(const exvector &, const relational &, int, unsigned);
 typedef void (* print_funcp_exvector)(const exvector &, const print_context &);
@@ -142,7 +145,7 @@ public:
 	unsigned get_nparams() const { return nparams; }
 
 protected:
-	bool has_derivative() const { return derivative_f != NULL; }
+	bool has_derivative() const { return (derivative_f != NULL) || (impl_derivative_f != NULL); }
 	bool has_power() const { return power_f != NULL; }
 	void test_and_set_nparams(unsigned n);
 	void set_print_func(unsigned id, print_funcp f);
@@ -159,6 +162,7 @@ protected:
 	imag_part_funcp imag_part_f;
 	expand_funcp expand_f;
 	derivative_funcp derivative_f;
+	impl_derivative_funcp impl_derivative_f;
 	power_funcp power_f;
 	series_funcp series_f;
 	std::vector<print_funcp> print_dispatch_table;
@@ -182,6 +186,7 @@ protected:
 	bool imag_part_use_exvector_args;
 	bool expand_use_exvector_args;
 	bool derivative_use_exvector_args;
+	bool impl_derivative_use_exvector_args;
 	bool power_use_exvector_args;
 	bool series_use_exvector_args;
 	bool print_use_exvector_args;
@@ -251,6 +256,7 @@ protected:
 	// non-virtual functions in this class
 protected:
 	ex pderivative(unsigned diff_param) const; // partial differentiation
+	ex impl_derivative(const symbol & s) const; // partial differentiation
 	static std::vector<function_options> & registered_functions();
 	bool lookup_remember_table(ex & result) const;
 	void store_remember_table(ex const & result) const;
diff --git a/ginac/function.py b/ginac/function.py
index 3f5e54e..26d378a 100755
--- a/ginac/function.py
+++ b/ginac/function.py
@@ -2,7 +2,7 @@
 # encoding: utf-8
 
 maxargs = 14
-methods = "eval evalf conjugate real_part imag_part expand derivative power series info print".split()
+methods = "eval evalf conjugate real_part imag_part expand derivative impl_derivative power series info print".split()
 
 import sys, os, optparse
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'scripts'))
diff --git a/ginac/inifcns.cpp b/ginac/inifcns.cpp
index 02c909f..7692b34 100644
--- a/ginac/inifcns.cpp
+++ b/ginac/inifcns.cpp
@@ -275,6 +275,12 @@ static ex abs_expand(const ex & arg, unsigned options)
 		return abs(arg).hold();
 }
 
+static ex abs_impl_derivative(const ex & arg, const symbol & s)
+{
+	ex diff_arg = arg.diff(s);
+	return (diff_arg*arg.conjugate()+arg*diff_arg.conjugate())/2/abs(arg);
+}
+
 static void abs_print_latex(const ex & arg, const print_context & c)
 {
 	c.s << "{|"; arg.print(c); c.s << "|}";
@@ -339,6 +345,7 @@ bool abs_info(const ex & arg, unsigned inf)
 REGISTER_FUNCTION(abs, eval_func(abs_eval).
                        evalf_func(abs_evalf).
                        expand_func(abs_expand).
+                       impl_derivative_func(abs_impl_derivative).
                        info_func(abs_info).
                        print_func<print_latex>(abs_print_latex).
                        print_func<print_csrc_float>(abs_print_csrc_float).
-- 
2.1.1

