24 Symbolic derivatives
This section uses the TermInterface add-on package.
using TermInterface
The ability to break down an expression into operations and their arguments is necessary when trying to apply the differentiation rules. Such rules are applied from the outside in. Identifying the proper “outside” function is usually most of the battle when finding derivatives.
In the following example, we provide a sketch of a framework to differentiate expressions with respect to a chosen symbol, to illustrate how the outer function drives the task of differentiation.
The Symbolics package provides native symbolic manipulation abilities for Julia, similar to SymPy, though without the dependence on Python. The TermInterface package, used by Symbolics, provides a generic interface for expression manipulation that is also implemented for Julia’s expressions and symbols.
An expression is an unevaluated portion of code that, for our purposes below, contains other expressions, symbols, and numeric literals. Expressions are held in the Expr type. A symbol, such as :x, is distinct from a string (e.g. "x") and is useful to the programmer to distinguish the contents a variable points to from the name of the variable. Symbols are fundamental to metaprogramming in Julia. An expression is a specification of some set of statements to execute. A numeric literal is just a number.
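As a small illustration, the following shows how these three kinds of objects look in Base Julia. It is meant to be entered at the REPL, and the comments give the values returned. The Expr fields head and args belong to Julia itself, and nothing here uses TermInterface.

```julia
ex = :(x + 1)           # quoted, so x need not be defined
typeof(ex)              # Expr
ex.head                 # :call, i.e., a function call
ex.args                 # Any[:+, :x, 1]: the function, then its arguments

typeof(:x)              # Symbol
:x == Symbol("x")       # true: a string can be converted to a symbol ...
:x == "x"               # false: ... but a symbol is not a string
ex.args[3] isa Number   # true: the literal 1 is just a number
```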
The three main functions from TermInterface we leverage are isexpr, operation, and arguments. The operation function returns the “outside” function of an expression. For example:
operation(:(sin(x)))
:sin
We see the sin function, referred to by a symbol (:sin). The :(...) above quotes its contents rather than evaluating them, hence x need not be defined. (The : notation is used to create both symbols and expressions.)
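For a function call, the information that operation returns is already stored in the Expr itself. The following sketch uses a hypothetical helper, outer, written with Base only and assuming a :call expression. It shows that only the outermost function is reported, while inner calls remain nested as expressions:

```julia
# Hypothetical helper (not part of TermInterface): the "outside" function
# of a call expression is the first entry of its args field.
outer(ex::Expr) = ex.head == :call ? ex.args[1] : error("not a function call")

ex = :(sin(x^2))
outer(ex)           # :sin, the outermost function
inner = ex.args[2]  # :(x ^ 2), itself an Expr
outer(inner)        # :^
```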
The arguments are the terms that the outside function is called on. For our purposes there may be \(1\) (unary), \(2\) (binary), or more than \(2\) (nary) arguments. (We ignore zero-argument functions.) For example:
arguments(:(-x)), arguments(:(pi^2)), arguments(:(1 + x + x^2))
(Any[:x], Any[:pi, 2], Any[1, :x, :(x ^ 2)])
(The last one may be surprising, but all three arguments are passed to the + function.)
TermInterface has an arity function, defined as length(arguments(ex)), that will be used for dispatch below.
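The arity of the examples above can be computed the same way using only Base. In this sketch, args_of and arity_of are hypothetical stand-ins for arguments and arity, and they are valid for call expressions only. The last line shows that the flattened, three-argument sum comes from the parser, because explicit parentheses produce a nested binary sum instead:

```julia
# Hypothetical Base-only stand-ins for `arguments` and `arity`,
# assuming a call expression: the arguments follow the function name.
args_of(ex::Expr) = ex.args[2:end]
arity_of(ex::Expr) = length(args_of(ex))

arity_of(:(-x))              # 1 (unary)
arity_of(:(pi^2))            # 2 (binary)
arity_of(:(1 + x + x^2))     # 3 (nary)
arity_of(:((1 + x) + x^2))   # 2: the parentheses give a nested binary sum
```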
Differentiation must distinguish between expressions, variables, and numbers. Mathematically, expressions have an “outer” function, whereas variables and numbers can be differentiated directly. The isexpr function in TermInterface returns true when passed an expression, and false when passed a symbol or numeric literal. The latter two may be distinguished by isa(..., Symbol).
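The three-way split can be sketched without TermInterface. Here, isa(ex, Expr) stands in for isexpr, which is an assumption that is adequate for the quoted expressions used in this section. The name kind is hypothetical:

```julia
# Hypothetical classifier mirroring the branches used by D below.
function kind(ex)
    if isa(ex, Expr)
        :expression   # has an outer function to dispatch on
    elseif isa(ex, Symbol)
        :symbol       # a variable
    else
        :literal      # a numeric literal
    end
end

kind(:(sin(x)))   # :expression
kind(:x)          # :symbol
kind(2)           # :literal
kind(:(2.5))      # :literal: quoting a number just returns the number
```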
Here we create a function, D. When it encounters an expression, it dispatches to a specific method of D based on the outer operation and arity. When it instead encounters a symbol or a numeric literal, it does the differentiation itself:
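function D(ex, var=:x)
    if isexpr(ex)
        op, args = operation(ex), arguments(ex)
        D(Val(op), Val(arity(ex)), args, var)   # dispatch on outer function and arity
    elseif isa(ex, Symbol) && ex == var
        1   # the variable of differentiation
    else
        0   # other symbols and numeric literals are constants
    end
end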
D (generic function with 2 methods)
(The use of Val is an idiom of Julia allowing dispatch on certain values such as function names and numbers.)
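Here is a toy example of the idiom that is independent of D. Val(v) creates an instance of the type Val{v}, so methods can be written for particular values, with a fallback for all others. The function describe is hypothetical:

```julia
describe(::Val{:sin}) = "sine"
describe(::Val{2}) = "binary"
describe(::Val) = "something else"   # fallback for any other Val

describe(Val(:sin))   # "sine"
describe(Val(2))      # "binary"
describe(Val(:tan))   # "something else"
```

D relies on the same mechanism, dispatching on two such values (the operation and the arity) at once.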
Next, we develop methods of D for different “outside” functions and arities.
Addition can be unary (:(+x) is a valid quoting, even if it might simplify to the symbol :x when evaluated), binary, or nary. Here we implement the sum rule:
D(::Val{:+}, ::Val{1}, args, var) = D(first(args), var)
function D(::Val{:+}, ::Val{2}, args, var)
    a′, b′ = D.(args, var)
    :($a′ + $b′)
end
function D(::Val{:+}, ::Any, args, var)
    a′s = D.(args, var)
    :(+($(a′s...)))
end
D (generic function with 5 methods)
The args are always held in a container, so the unary method must pull out the first one. The binary case should read as: apply D to each of the two arguments, and then create a quoted expression containing the sum of the results. The dollar signs interpolate the results into the quoted expression. (The “primes” are unicode characters entered via \prime[tab]. They are part of the variable names, not operations.) The nary method (which catches any arity besides 1 and 2) does something similar, only using splatting to produce the sum.
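The interpolation and splatting used in these methods can be tried on their own. In this sketch, the values standing in for computed derivatives are made up for illustration:

```julia
a′, b′ = 1, :(2x)          # stand-ins for two computed derivatives
s2 = :($a′ + $b′)          # :(1 + 2x): the values are spliced into the quote

a′s = [1, 0, :(2x)]        # stand-ins for several derivatives
sn = :(+($(a′s...)))       # :(1 + 0 + 2x): one call to + with three arguments
```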
Subtraction must also be implemented in a similar manner, but not for the nary case, as subtraction is not associative:
function D(::Val{:-}, ::Val{1}, args, var)
    a′ = D(first(args), var)
    :(-$a′)
end
function D(::Val{:-}, ::Val{2}, args, var)
    a′, b′ = D.(args, var)
    :($a′ - $b′)
end
D (generic function with 7 methods)
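Consistent with this, Julia's parser flattens chained sums into one call but nests chained differences as binary calls. As a result, D only sees unary or binary subtraction for such input:

```julia
sum_ex = :(a + b + c)
sum_ex.args        # Any[:+, :a, :b, :c]: a single nary call

diff_ex = :(a - b - c)
diff_ex.args       # Any[:-, :(a - b), :c]: (a - b) - c, a binary call
```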
The product rule is similar to addition, in that \(3\) cases are considered:
D(op::Val{:*}, ::Val{1}, args, var) = D(first(args), var)
function D(::Val{:*}, ::Val{2}, args, var)
    a, b = args
    a′, b′ = D.(args, var)
    :($a′ * $b + $a * $b′)
end
function D(op::Val{:*}, ::Any, args, var)
    a, bs... = args             # first factor, then the rest
    b = :(*($(bs...)))          # product of the remaining factors
    a′ = D(a, var)
    b′ = D(b, var)
    :($a′ * $b + $a * $b′)
end
D (generic function with 10 methods)
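The nary case above peels off the first factor, treats the product of the remaining factors as a single expression, and then applies the binary product rule. Differentiating that remaining product recurses through the same methods.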
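Here is the peeling step on a concrete product. It is followed by a numerical check at one point (x0 = 1.3, an arbitrary choice) that the resulting rule agrees with differentiating sin(x)*x^3 directly:

```julia
args = :(sin(x) * x * x^2).args[2:end]   # three factors
a, bs = first(args), args[2:end]          # peel off the first factor
b = :(*($(bs...)))                        # :(x * x ^ 2), a binary product

# Product rule with the peeled factor: (a*b)' = a'*b + a*b',
# where b' itself comes from the binary product rule.
x0 = 1.3
u, u′ = sin(x0), cos(x0)
v, v′ = x0 * x0^2, 1 * x0^2 + x0 * (2 * x0)
lhs = u′ * v + u * v′
rhs = cos(x0) * x0^3 + sin(x0) * (3 * x0^2)   # derivative of sin(x)*x^3
lhs ≈ rhs                                       # true
```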
Division is only a binary operation, so here we have the quotient rule:
function D(::Val{:/}, ::Val{2}, args, var)
    u, v = args
    u′, v′ = D(u, var), D(v, var)
    :( ($u′*$v - $u*$v′)/$v^2 )
end
D (generic function with 11 methods)
Powers are handled a bit differently. The power rule would require checking that the exponent does not contain the variable of differentiation, and the rule for exponentials would require checking that the base does not contain it. Trying to implement both would be tedious. Instead, we use the fact that \(x = \exp(\log(x))\), so \(a^b = \exp(b\log(a))\) (for \(a\) in the domain of log; more care is necessary if \(a\) is negative), to differentiate:
function D(::Val{:^}, ::Val{2}, args, var)
    a, b = args
    D(:(exp($b*log($a))), var) # a > 0 assumed here
end
D (generic function with 12 methods)
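Here is a quick numerical sanity check of the rewrite at arbitrarily chosen points. The identity holds for a positive base, and the chain rule applied to exp(3*log(x)) reproduces the power rule. A negative base fails because the real log throws a DomainError:

```julia
a, b = 2.0, 3.0
a^b ≈ exp(b * log(a))              # true for a > 0

# d/dx x^3 via exp(3*log(x)): exp(3*log(x)) * (3 * (1/x))
x = 2.0
exp(3 * log(x)) * (3 * (1 / x)) ≈ 3 * x^2   # true, both are about 12

(-2.0)^3                           # -8.0: the power itself is defined ...
negative_fails = try
    exp(3 * log(-2.0))             # ... but the real log is not
    false
catch err
    err isa DomainError
end                                # true
```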
That leaves the task of defining rules to differentiate both exp and log. We do so with unary definitions. In the following, we also implement sin and cos rules:
function D(::Val{:exp}, ::Val{1}, args, var)
    a = only(args)
    a′ = D(a, var)
    :(exp($a) * $a′)
end
function D(::Val{:log}, ::Val{1}, args, var)
    a = only(args)
    a′ = D(a, var)
    :(1/$a * $a′)
end
function D(::Val{:sin}, ::Val{1}, args, var)
    a = only(args)
    a′ = D(a, var)
    :(cos($a) * $a′)
end
function D(::Val{:cos}, ::Val{1}, args, var)
    a = only(args)
    a′ = D(a, var)
    :(-sin($a) * $a′)
end
D (generic function with 16 methods)
The pattern is similar for each. The $a′ factor is needed due to the chain rule. The above illustrates the simple pattern necessary to add a derivative rule for a function.
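To see the pattern in isolation, here is a self-contained sketch that adds a rule for tan, using (tan u)' = (1 + tan(u)^2) u'. So that it runs without TermInterface, it makes three substitutions and assumptions. First, it uses hypothetical Base-only helpers (is_call, op_of, args_of) in place of isexpr, operation, and arguments. Second, it assumes only call expressions occur. Third, it includes just the product rule needed for the test input 2x:

```julia
# Hypothetical Base-only stand-ins for TermInterface's isexpr/operation/arguments.
is_call(ex) = ex isa Expr && ex.head == :call
op_of(ex) = ex.args[1]
args_of(ex) = ex.args[2:end]

function D(ex, var=:x)
    if is_call(ex)
        args = args_of(ex)
        D(Val(op_of(ex)), Val(length(args)), args, var)
    elseif isa(ex, Symbol) && ex == var
        1
    else
        0
    end
end

# Binary product rule, as in the text (needed for inputs such as 2x).
function D(::Val{:*}, ::Val{2}, args, var)
    a, b = args
    a′, b′ = D.(args, var)
    :($a′ * $b + $a * $b′)
end

# The new rule follows the same pattern: outer derivative times a′.
function D(::Val{:tan}, ::Val{1}, args, var)
    a = only(args)
    a′ = D(a, var)
    :((1 + tan($a)^2) * $a′)
end

d = D(:(tan(2x)))               # equals :((1 + tan(2x)^2) * (0 * x + 2 * 1))
f = eval(:(x -> $d))            # turn the derivative into a function
f(0.3) ≈ 2 * (1 + tan(0.6)^2)   # true
```

Only the last method is new. The helpers and the product rule merely make the block runnable on its own.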
Several automatic differentiation packages use a set of rules defined following an interface spelled out in the package ChainRules.jl. By treating operations such as + and * as functions of several variables and using multi-dimensional derivatives, the chain rule becomes the only one of the sum, product, quotient, and chain rules that is needed.
More functions could be included, but for this example the above will suffice, as now the system is ready to be put to work.
ex = :(x + 2/x)
D(ex, :x)
:(1 + (0 * x - 2 * 1) / x ^ 2)
The output is not simplified, so some work is needed to identify 1 - 2/x^2 as the answer.
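Some of that work is mechanical. The following is a minimal sketch, using a hypothetical simp function, and is not a general simplifier. It rewrites binary calls involving 0 and 1 and folds arithmetic on two numeric literals, simplifying arguments before their parent:

```julia
# Operations that are safe to fold when both arguments are numeric literals.
const FOLD = Dict{Symbol,Function}(:+ => (+), :- => (-), :* => (*))

simp(ex) = ex                 # symbols and numbers are left alone
function simp(ex::Expr)
    ex.head == :call || return ex
    op = ex.args[1]
    args = map(simp, ex.args[2:end])     # simplify the arguments first
    if length(args) == 2
        a, b = args
        if op == :*
            (a == 0 || b == 0) && return 0   # 0 * u or u * 0
            a == 1 && return b               # 1 * u
            b == 1 && return a               # u * 1
        elseif op == :+
            a == 0 && return b               # 0 + u
            b == 0 && return a               # u + 0
        elseif op == :-
            b == 0 && return a               # u - 0
        end
        if a isa Number && b isa Number && haskey(FOLD, op)
            return FOLD[op](a, b)            # e.g. 0 - 2 becomes -2
        end
    end
    Expr(:call, op, args...)
end

simp(:(1 + (0 * x - 2 * 1) / x ^ 2))
# equals Expr(:call, :+, 1, Expr(:call, :/, -2, :(x ^ 2))), i.e. 1 + (-2)/x^2
```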
ex = :( (x + sin(x))/sin(x))
D(ex, :x)
:(((1 + cos(x) * 1) * sin(x) - (x + sin(x)) * (cos(x) * 1)) / sin(x) ^ 2)
Again, simplification is not performed.
Finally, we have a second derivative taken below:
ex = :(sin(x) - x - x^3/6)
D(D(ex, :x), :x)
:((((-(sin(x)) * 1) * 1 + cos(x) * 0) - 0) - (((((exp(3 * log(x)) * (0 * log(x) + 3 * ((1 / x) * 1))) * (0 * log(x) + 3 * ((1 / x) * 1)) + exp(3 * log(x)) * ((0 * log(x) + 0 * ((1 / x) * 1)) + (0 * ((1 / x) * 1) + 3 * (((0 * x - 1 * 1) / x ^ 2) * 1 + (1 / x) * 0)))) * 6 + (exp(3 * log(x)) * (0 * log(x) + 3 * ((1 / x) * 1))) * 0) - ((exp(3 * log(x)) * (0 * log(x) + 3 * ((1 / x) * 1))) * 0 + x ^ 3 * 0)) * 6 ^ 2 - ((exp(3 * log(x)) * (0 * log(x) + 3 * ((1 / x) * 1))) * 6 - x ^ 3 * 0) * (exp(2 * log(6)) * (0 * log(6) + 2 * ((1 / 6) * 0)))) / (6 ^ 2) ^ 2)
The length of the expression should lead to further appreciation for simplification steps taken when doing such a computation by hand.