From 47136d078ba37491faf3310b72b0069099480fc8 Mon Sep 17 00:00:00 2001
From: "Vladimir V. Kisil" <kisilv@maths.leeds.ac.uk>
Date: Wed, 31 Jul 2013 15:14:32 +0100
Subject: [PATCH 5/9] Add specific-tailored expand to GiNaC functions. This
 patch adds expand() methods which can be tailored to particular function. For
 example, exp(a*b).expand() -> exp(a)*exp(b) and log(p*q).expand() -> log(p) +
 log(q).

Signed-off-by: Vladimir V. Kisil <kisilv@maths.leeds.ac.uk>
---
 check/exam_inifcns.cpp  | 58 +++++++++++++++++++++++++++++++++++++++-
 ginac/function.cppy     | 42 +++++++++++++++++++++--------
 ginac/function.hppy     |  5 ++++
 ginac/function.py       |  2 +-
 ginac/inifcns.cpp       | 21 +++++++++++++++
 ginac/inifcns_trans.cpp | 70 +++++++++++++++++++++++++++++++++++++++++++++++++
 6 files changed, 185 insertions(+), 13 deletions(-)

diff --git a/check/exam_inifcns.cpp b/check/exam_inifcns.cpp
index 994aa34..424aa26 100644
--- a/check/exam_inifcns.cpp
+++ b/check/exam_inifcns.cpp
@@ -222,6 +222,7 @@ static unsigned inifcns_consist_zeta()
 static unsigned inifcns_consist_abs()
 {
 	unsigned result = 0;
+	symbol z("z");
 	realsymbol a("a"), b("b"), x("x"), y("y");
 	possymbol p("p");
 
@@ -242,6 +243,59 @@ static unsigned inifcns_consist_abs()
 	if (!abs(pow(x+I*y,a+I*b)).eval().is_equal(abs(pow(x+I*y,a+I*b))))
 		++result;
 
+	// check expansion of abs
+	if (!abs(-7*z*a*p).expand().is_equal(7*abs(z)*abs(a)*p))
+		++result;
+
+	return result;
+}
+
+static unsigned inifcns_consist_exp()
+{
+	unsigned result = 0;
+	symbol a("a"), b("b");
+
+	if (!exp(a+b).expand().is_equal(exp(a)*exp(b)))
+		++result;
+
+	// shall not be expanded since the arg is not add
+	if (!exp(pow(a+b,2)).expand().is_equal(exp(pow(a+b,2))))
+		++result;
+
+	// expand now
+	if (!exp(pow(a+b,2)).expand(expand_options::expand_function_args)
+		.is_equal(exp(a*a)*exp(b*b)*exp(2*a*b)))
+		++result;
+
+	return result;
+}
+
+static unsigned inifcns_consist_log()
+{
+	unsigned result = 0;
+	symbol z("a"), w("b");
+	realsymbol a("a"), b("b");
+	possymbol p("p"), q("q");
+
+	// do not expand
+	if (!log(z*w).expand().is_equal(log(z*w)))
+		++result;
+
+	// do not expand
+	if (!log(a*b).expand().is_equal(log(a*b)))
+		++result;
+
+	// shall expand
+	if (!log(p*q).expand().is_equal(log(p) +log(q)))
+		++result;
+
+	// a bit more complicated
+	ex e1=log(-7*p*pow(q,3)*a*pow(b,2)*z*w).expand();
+	ex e2=log(7)+log(p)+log(pow(q,3))+log(-z*a*w*pow(b,2));
+	if (!e1.is_equal(e2))
+		++result;
+
+	return result;
 }
 
 static unsigned inifcns_consist_various()
@@ -253,7 +307,7 @@ static unsigned inifcns_consist_various()
 		clog << "ERROR: binomial(n,0) != 1" << endl;		
 		++result;
 	}
-	
+
 	return result;
 }
 
@@ -268,6 +322,8 @@ unsigned exam_inifcns()
 	result += inifcns_consist_psi();  cout << '.' << flush;
 	result += inifcns_consist_zeta();  cout << '.' << flush;
 	result += inifcns_consist_abs();  cout << '.' << flush;
+	result += inifcns_consist_exp();  cout << '.' << flush;
+	result += inifcns_consist_log();  cout << '.' << flush;
 	result += inifcns_consist_various();  cout << '.' << flush;
 	
 	return result;
diff --git a/ginac/function.cppy b/ginac/function.cppy
index fb83cdc..1b76184 100644
--- a/ginac/function.cppy
+++ b/ginac/function.cppy
@@ -78,8 +78,8 @@ void function_options::initialize()
 {
 	set_name("unnamed_function", "\\\\mbox{unnamed}");
 	nparams = 0;
-	eval_f = evalf_f = real_part_f = imag_part_f = conjugate_f = derivative_f
-		= power_f = series_f = 0;
+	eval_f = evalf_f = real_part_f = imag_part_f = conjugate_f = expand_f
+		= derivative_f = power_f = series_f = 0;
 	evalf_params_first = true;
 	use_return_type = false;
 	eval_use_exvector_args = false;
@@ -87,6 +87,7 @@ void function_options::initialize()
 	conjugate_use_exvector_args = false;
 	real_part_use_exvector_args = false;
 	imag_part_use_exvector_args = false;
+	expand_use_exvector_args = false;
 	derivative_use_exvector_args = false;
 	power_use_exvector_args = false;
 	series_use_exvector_args = false;
@@ -360,15 +361,6 @@ next_context:
 	}
 }
 
-ex function::expand(unsigned options) const
-{
-	// Only expand arguments when asked to do so
-	if (options & expand_options::expand_function_args)
-		return inherited::expand(options);
-	else
-		return (options == 0) ? setflag(status_flags::expanded) : *this;
-}
-
 ex function::eval(int level) const
 {
 	if (level>1) {
@@ -756,6 +748,34 @@ ex function::power(const ex & power_param) const // power of function
 	throw(std::logic_error("function::power(): no power function defined"));
 }
 
+ex function::expand(unsigned options) const
+{
+	GINAC_ASSERT(serial<registered_functions().size());
+	const function_options &opt = registered_functions()[serial];
+	
+	// No expand defined? Then return the same function with expanded arguments (if required)
+	if (opt.expand_f == NULL) {
+		// Only expand arguments when asked to do so
+		if (options & expand_options::expand_function_args)
+			return inherited::expand(options);
+		else
+			return (options == 0) ? setflag(status_flags::expanded) : *this;
+	}
+
+	current_serial = serial;
+	if (opt.expand_use_exvector_args)
+		return ((expand_funcp_exvector)(opt.expand_f))(seq,  options);
+	switch (opt.nparams) {
+		// the following lines have been generated for max. @maxargs@ parameters
++++ for N in range(1, maxargs + 1):
+		case @N@:
+			return ((expand_funcp_@N@)(opt.expand_f))(@seq('seq[%(n)d]', N, 0)@, options);
+---
+		// end of generated lines
+	}
+	throw(std::logic_error("function::expand(): no expand of function defined"));
+}
+
 std::vector<function_options> & function::registered_functions()
 {
 	static std::vector<function_options> rf = std::vector<function_options>();
diff --git a/ginac/function.hppy b/ginac/function.hppy
index e7ec83e..fa17813 100644
--- a/ginac/function.hppy
+++ b/ginac/function.hppy
@@ -57,6 +57,7 @@ typedef ex (* evalf_funcp)();
 typedef ex (* conjugate_funcp)();
 typedef ex (* real_part_funcp)();
 typedef ex (* imag_part_funcp)();
+typedef ex (* expand_funcp)();
 typedef ex (* derivative_funcp)();
 typedef ex (* power_funcp)();
 typedef ex (* series_funcp)();
@@ -69,6 +70,7 @@ typedef ex (* evalf_funcp_@N@)( @args@ );
 typedef ex (* conjugate_funcp_@N@)( @args@ );
 typedef ex (* real_part_funcp_@N@)( @args@ );
 typedef ex (* imag_part_funcp_@N@)( @args@ );
+typedef ex (* expand_funcp_@N@)( @args@, unsigned );
 typedef ex (* derivative_funcp_@N@)( @args@, unsigned );
 typedef ex (* power_funcp_@N@)( @args@, const ex & );
 typedef ex (* series_funcp_@N@)( @args@, const relational &, int, unsigned );
@@ -81,6 +83,7 @@ typedef void (* print_funcp_@N@)( @args@, const print_context & );
 +++ for fp in "eval evalf conjugate real_part imag_part".split():
 typedef ex (* @fp@_funcp_exvector)(const exvector &);
 ---
+typedef ex (* expand_funcp_exvector)(const exvector &, unsigned);
 typedef ex (* derivative_funcp_exvector)(const exvector &, unsigned);
 typedef ex (* power_funcp_exvector)(const exvector &, const ex &);
 typedef ex (* series_funcp_exvector)(const exvector &, const relational &, int, unsigned);
@@ -151,6 +154,7 @@ protected:
 	conjugate_funcp conjugate_f;
 	real_part_funcp real_part_f;
 	imag_part_funcp imag_part_f;
+	expand_funcp expand_f;
 	derivative_funcp derivative_f;
 	power_funcp power_f;
 	series_funcp series_f;
@@ -172,6 +176,7 @@ protected:
 	bool conjugate_use_exvector_args;
 	bool real_part_use_exvector_args;
 	bool imag_part_use_exvector_args;
+	bool expand_use_exvector_args;
 	bool derivative_use_exvector_args;
 	bool power_use_exvector_args;
 	bool series_use_exvector_args;
diff --git a/ginac/function.py b/ginac/function.py
index 6ceb254..0ecb918 100755
--- a/ginac/function.py
+++ b/ginac/function.py
@@ -2,7 +2,7 @@
 # encoding: utf-8
 
 maxargs = 14
-methods = "eval evalf conjugate real_part imag_part derivative power series print".split()
+methods = "eval evalf conjugate real_part imag_part expand derivative power series print".split()
 
 import sys, os, optparse
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'scripts'))
diff --git a/ginac/inifcns.cpp b/ginac/inifcns.cpp
index 08f9b02..6add43d 100644
--- a/ginac/inifcns.cpp
+++ b/ginac/inifcns.cpp
@@ -206,6 +206,26 @@ static ex abs_eval(const ex & arg)
 	return abs(arg).hold();
 }
 
+static ex abs_expand(const ex & arg, unsigned options)
+{
+	if (is_exactly_a<mul>(arg)) {
+		exvector prodseq;
+		prodseq.reserve(arg.nops());
+		for (const_iterator i = arg.begin(); i != arg.end(); ++i) {
+			if (options & expand_options::expand_function_args)
+				prodseq.push_back(abs(i->expand(options)));
+			else
+				prodseq.push_back(abs(*i));
+		}
+		return (new mul(prodseq))->setflag(status_flags::dynallocated | status_flags::expanded);
+	}
+
+	if (options & expand_options::expand_function_args)
+		return abs(arg.expand(options)).hold();
+	else
+		return abs(arg).hold();
+}
+
 static void abs_print_latex(const ex & arg, const print_context & c)
 {
 	c.s << "{|"; arg.print(c); c.s << "|}";
@@ -242,6 +262,7 @@ static ex abs_power(const ex & arg, const ex & exp)
 
 REGISTER_FUNCTION(abs, eval_func(abs_eval).
                        evalf_func(abs_evalf).
+                       expand_func(abs_expand).
                        print_func<print_latex>(abs_print_latex).
                        print_func<print_csrc_float>(abs_print_csrc_float).
                        print_func<print_csrc_double>(abs_print_csrc_float).
diff --git a/ginac/inifcns_trans.cpp b/ginac/inifcns_trans.cpp
index 10a3675..0d4904a 100644
--- a/ginac/inifcns_trans.cpp
+++ b/ginac/inifcns_trans.cpp
@@ -24,6 +24,8 @@
 #include "inifcns.h"
 #include "ex.h"
 #include "constant.h"
+#include "add.h"
+#include "mul.h"
 #include "numeric.h"
 #include "power.h"
 #include "operators.h"
@@ -81,6 +83,26 @@ static ex exp_eval(const ex & x)
 	return exp(x).hold();
 }
 
+static ex exp_expand(const ex & arg, unsigned options)
+{
+	ex exp_arg;
+	if (options & expand_options::expand_function_args)
+		exp_arg = arg.expand(options);
+	else
+		exp_arg=arg;
+
+	if (is_exactly_a<add>(exp_arg)) {
+		exvector prodseq;
+		prodseq.reserve(exp_arg.nops());
+		for (const_iterator i = exp_arg.begin(); i != exp_arg.end(); ++i)
+			prodseq.push_back(exp(*i));
+
+		return (new mul(prodseq))->setflag(status_flags::dynallocated | status_flags::expanded);
+	}
+
+	return exp(exp_arg).hold();
+}
+
 static ex exp_deriv(const ex & x, unsigned deriv_param)
 {
 	GINAC_ASSERT(deriv_param==0);
@@ -107,6 +129,7 @@ static ex exp_conjugate(const ex & x)
 
 REGISTER_FUNCTION(exp, eval_func(exp_eval).
                        evalf_func(exp_evalf).
+                       expand_func(exp_expand).
                        derivative_func(exp_deriv).
                        real_part_func(exp_real_part).
                        imag_part_func(exp_imag_part).
@@ -265,6 +288,52 @@ static ex log_imag_part(const ex & x)
 	return atan2(GiNaC::imag_part(x), GiNaC::real_part(x));
 }
 
+static ex log_expand(const ex & arg, unsigned options)
+{
+	if (is_exactly_a<mul>(arg) && !arg.info(info_flags::indefinite)) {
+		exvector sumseq;
+		exvector prodseq;
+		sumseq.reserve(arg.nops());
+		prodseq.reserve(arg.nops());
+		bool possign=true;
+
+		// searching for positive/negative factors
+		for (const_iterator i = arg.begin(); i != arg.end(); ++i) {
+			ex e;
+			if (options & expand_options::expand_function_args)
+				e=i->expand(options);
+			else
+				e=*i;
+			if (e.info(info_flags::positive))
+				sumseq.push_back(log(e));
+			else if (e.info(info_flags::negative)) {
+				sumseq.push_back(log(-e));
+				possign = !possign;
+			} else
+				prodseq.push_back(e);
+		}
+
+		if (sumseq.size() > 0) {
+			ex newarg;
+			if (options & expand_options::expand_function_args)
+				newarg=((possign?_ex1:_ex_1)*mul(prodseq)).expand(options);
+			else {
+				newarg=(possign?_ex1:_ex_1)*mul(prodseq);
+				ex_to<basic>(newarg).setflag(status_flags::purely_indefinite);
+			}
+			return add(sumseq)+log(newarg);
+		} else {
+			if (!(options & expand_options::expand_function_args))
+				ex_to<basic>(arg).setflag(status_flags::purely_indefinite);
+		}
+	}
+
+	if (options & expand_options::expand_function_args)
+		return log(arg.expand(options)).hold();
+	else
+		return log(arg).hold();
+}
+
 static ex log_conjugate(const ex & x)
 {
 	// conjugate(log(x))==log(conjugate(x)) unless on the branch cut which
@@ -281,6 +350,7 @@ static ex log_conjugate(const ex & x)
 
 REGISTER_FUNCTION(log, eval_func(log_eval).
                        evalf_func(log_evalf).
+                       expand_func(log_expand).
                        derivative_func(log_deriv).
                        series_func(log_series).
                        real_part_func(log_real_part).
-- 
1.8.4.rc3

