24  Symbolic derivatives

This section uses this add-on package:

using TermInterface

The ability to break down an expression into operations and their arguments is necessary when trying to apply the differentiation rules. Such rules are applied from the outside in. Identifying the proper “outside” function is usually most of the battle when finding derivatives.

In the following example, we provide a sketch of a framework to differentiate expressions by a chosen symbol to illustrate how the outer function drives the task of differentiation.

The Symbolics package provides native symbolic manipulation abilities for Julia, similar to SymPy, though without the dependence on Python. The TermInterface package, used by Symbolics, provides a generic interface for expression manipulation; that interface is also implemented for Julia’s own expressions and symbols.

An expression is an unevaluated portion of code that, for our purposes below, contains other expressions, symbols, and numeric literals; it specifies some set of statements to execute. Expressions are held in the Expr type. A symbol, such as :x, is distinct from a string (e.g. "x") and lets the programmer distinguish the name of a variable from the contents the variable points to. Symbols are fundamental to metaprogramming in Julia. A numeric literal is just a number.
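
The three kinds of objects can be told apart with ordinary type checks. The following is a small sketch using only base Julia (no packages are assumed); the variable names are just for illustration:

```julia
s  = :x            # a Symbol: the name x, not a value
ex = :(x + 1)      # an Expr: unevaluated code
n  = 2             # a numeric literal is just a value

s isa Symbol       # true
ex isa Expr        # true
n isa Number       # true
s == Symbol("x")   # true: a symbol can be built from a string...
s == "x"           # false: ...but it is not itself a string
```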

The three main functions from TermInterface we leverage are isexpr, operation, and arguments. The operation function returns the “outside” function of an expression. For example:

operation(:(sin(x)))
:sin

We see the sin function, referred to by a symbol (:sin). The :(...) above quotes the argument, and does not evaluate it, hence x need not be defined above. (The : notation is used to create both symbols and expressions.)

The arguments are the terms that the outside function is called on. For our purposes there may be \(1\) (unary), \(2\) (binary), or more than \(2\) (nary) arguments. (We ignore zero-argument functions.) For example:

arguments(:(-x)), arguments(:(pi^2)), arguments(:(1 + x + x^2))
(Any[:x], Any[:pi, 2], Any[1, :x, :(x ^ 2)])

(The last one may be surprising, but all three arguments are passed to the + function.)
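
For readers curious about what sits underneath, a call expression in base Julia stores the function name and its arguments together in the args field of the Expr. The sketch below inspects those fields directly; it does not use TermInterface, and we only assume (rather than claim) that operation and arguments read the same information for an Expr:

```julia
ex = :(1 + x + x^2)

ex.head          # :call, since this is a function call
ex.args[1]       # :+, the "outside" function
ex.args[2:end]   # the three arguments: 1, :x, and :(x ^ 2)
```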

Here we define a function to decide the arity of an expression based on the number of arguments it is called with:

function arity(ex)
        n = length(arguments(ex))
        n == 1 ? Val(:unary) :
        n == 2 ? Val(:binary) : Val(:nary)
end
arity (generic function with 1 method)
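
The reason for wrapping the result in Val is that Val(:unary) and Val(:binary) have different types, so Julia can select a method based on them. A minimal sketch of the idea follows; describe is a hypothetical function used only for this illustration:

```julia
# `describe` is hypothetical; it only illustrates dispatch on Val types.
describe(::Val{:unary})  = "one argument"
describe(::Val{:binary}) = "two arguments"
describe(::Val)          = "something else"   # fallback for any other Val

describe(Val(:unary)), describe(Val(:nary))   # ("one argument", "something else")
```

The same mechanism lets a differentiation rule be attached to a specific operation and arity simply by writing a new method.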

Differentiation must distinguish between expressions, variables, and numbers. Mathematically, expressions have an “outer” function, whereas variables and numbers can be directly differentiated. The isexpr function in TermInterface returns true when passed an expression, and false when passed a symbol or numeric literal. The latter two may be distinguished by isa(..., Symbol).
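
The three-way distinction can also be sketched with plain dispatch on the base Julia types. Here kind is a hypothetical helper, not part of TermInterface, and it stands in for the isexpr and isa checks described above:

```julia
# `kind` is a hypothetical helper that names the three cases.
kind(ex::Expr)   = :expression   # has an "outer" function
kind(ex::Symbol) = :variable     # derivative is 1 or 0
kind(ex::Number) = :number       # derivative is 0

kind(:(sin(x))), kind(:x), kind(2)   # (:expression, :variable, :number)
```

Note that a name such as :pi is a symbol, so in this scheme it counts as a variable rather than a number.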

Here we create a function, D. When it encounters an expression, it dispatches to a specific method of D based on the outer operation and arity; when it encounters a symbol or a numeric literal, it does the differentiation directly:

function D(ex, var=:x)
    if isexpr(ex)
        op, args = operation(ex), arguments(ex)
        D(Val(op), arity(ex), args, var)
    elseif isa(ex, Symbol) && ex == var
        1
    else
        0
    end
end
D (generic function with 2 methods)

Now to develop methods for D for different “outside” functions and arities.

Addition can be unary (:(+x) is a valid quoting, even if it might simplify to the symbol :x when evaluated), binary, or nary. Here we implement the sum rule:

D(::Val{:+}, ::Val{:unary}, args, var) = D(first(args), var)

function D(::Val{:+}, ::Val{:binary}, args, var)
    a′, b′ = D.(args, var)
    :($a′ + $b′)
end

function D(::Val{:+}, ::Val{:nary}, args, var)
    a′s = D.(args, var)
    :(+($(a′s...)))
end
D (generic function with 5 methods)

The args are always held in a container, so the unary method must pull out the first one. The binary case should read as: apply D to each of the two arguments, and then create a quoted expression containing the sum of the results. The dollar signs interpolate into the quoting. (The “primes” are Unicode characters entered with \prime[tab]; they are part of the variable names, not operations.) The nary case does something similar, only using splatting to produce the sum.
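
Interpolation and splatting inside a quote can be tried in isolation. The sketch below uses only base Julia; the names a′, b′, and terms are chosen for illustration. The form $(terms...), with the splat inside the parentheses, splices each element in as a separate argument:

```julia
a′ = :(cos(x))               # an Expr to interpolate
b′ = 0                       # numbers interpolate too
:($a′ + $b′)                 # the expression cos(x) + 0

terms = Any[1, :x, :(x^2)]
:(+($(terms...)))            # the expression 1 + x + x ^ 2
```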

Subtraction must also be implemented in a similar manner, though only for the unary and binary cases:

function D(::Val{:-}, ::Val{:unary}, args, var)
    a′ = D(first(args), var)
    :(-$a′)
end
function D(::Val{:-}, ::Val{:binary}, args, var)
    a′, b′ = D.(args, var)
    :($a′ - $b′)
end
D (generic function with 7 methods)
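
No nary method is needed for subtraction because of how Julia parses a chain of differences: a - b - c becomes a binary difference whose first argument is itself a difference, whereas a chain of sums is kept as one call. This can be checked with base Julia alone (the names sub and add are just for illustration):

```julia
sub = :(a - b - c)
sub.args                 # the operator :-, then :(a - b), then :c
length(sub.args) - 1     # 2 arguments

add = :(a + b + c)
length(add.args) - 1     # 3 arguments
```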

The product rule is similar to addition, in that \(3\) cases are considered:

D(::Val{:*}, ::Val{:unary}, args, var) = D(first(args), var)

function D(::Val{:*}, ::Val{:binary}, args, var)
    a, b = args
    a′, b′ = D.(args, var)
    :($a′ * $b + $a * $b′)
end

function D(::Val{:*}, ::Val{:nary}, args, var)
    a, bs... = args
    b = :(*($(bs...)))
    a′ = D(a, var)
    b′ = D(b, var)
    :($a′ * $b + $a * $b′)
end
D (generic function with 10 methods)

The nary case above just peels off the first factor and then uses the binary product rule.

Division is only a binary operation, so here we have the quotient rule:

function D(::Val{:/}, ::Val{:binary}, args, var)
    u,v = args
    u′, v′ = D(u, var), D(v, var)
    :( ($u′*$v - $u*$v′)/$v^2 )
end
D (generic function with 11 methods)

Powers are handled a bit differently. The power rule would require checking that the exponent does not contain the variable of differentiation, and the rule for exponentials would require checking that the base does not contain it. Trying to implement both would be tedious, so we use the fact that \(a^b = \exp(b \log(a))\), which follows from \(x = \exp(\log(x))\) (for x in the domain of log; more care is necessary if the base is negative), to differentiate:

function D(::Val{:^}, ::Val{:binary}, args, var)
    a, b = args
    D(:(exp($b*log($a))), var)  # a > 0 assumed here
end
D (generic function with 12 methods)
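
Written out by hand, and assuming \(a > 0\) with both \(a\) and \(b\) allowed to depend on \(x\), the rewriting leads to a single formula that covers both the power rule and the exponential rule:

\[
\frac{d}{dx} a^b
= \frac{d}{dx} \exp(b \log(a))
= \exp(b \log(a)) \left( b' \log(a) + b \cdot \frac{a'}{a} \right)
= a^b \left( b' \log(a) + \frac{b\, a'}{a} \right).
\]

When \(b\) is constant, \(b' = 0\) and this reduces to \(b\, a^{b-1} a'\); when \(a\) is constant, \(a' = 0\) and it reduces to \(a^b \log(a)\, b'\).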

That leaves the task of defining a rule to differentiate both exp and log. We do so with unary definitions. In the following we also implement sin and cos rules:

function D(::Val{:exp}, ::Val{:unary}, args, var)
    a = first(args)
    a′ = D(a, var)
    :(exp($a) * $a′)
end

function D(::Val{:log}, ::Val{:unary}, args, var)
    a = first(args)
    a′ = D(a, var)
    :(1/$a * $a′)
end

function D(::Val{:sin}, ::Val{:unary}, args, var)
    a = first(args)
    a′ = D(a, var)
    :(cos($a) * $a′)
end

function D(::Val{:cos}, ::Val{:unary}, args, var)
    a = first(args)
    a′ = D(a, var)
    :(-sin($a) * $a′)
end
D (generic function with 16 methods)

The pattern is similar for each. The $a′ factor is needed due to the chain rule. The above illustrates the simple pattern necessary to add a derivative rule for a function.

Note

Several automatic differentiation packages use a set of rules defined following an interface spelled out in the package ChainRules.jl. With multi-dimensional derivatives, the sum, product, and quotient rules all follow from the chain rule, so the chain rule is the only one of the four that is needed.

More functions could be included, but for this example the above will suffice, as now the system is ready to be put to work.

ex₁ = :(x + 2/x)
D(ex₁, :x)
:(1 + (0 * x - 2 * 1) / x ^ 2)

The output does not simplify, so some work is needed to identify 1 - 2/x^2 as the answer.
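
To give a feel for what such work involves, here is a minimal sketch of a simplifier. It is hypothetical and not part of the framework above: it uses only base Julia, handles just a few identities for binary +, -, and *, and folds constants for those three operations. A real simplifier would need many more rules.

```julia
# `simplify` is a hypothetical, deliberately tiny helper.
simplify(ex) = ex                          # symbols and numbers are left alone

function simplify(ex::Expr)
    ex.head == :call || return ex
    op = ex.args[1]
    args = map(simplify, ex.args[2:end])   # simplify from the inside out
    if length(args) == 2
        a, b = args
        if a isa Number && b isa Number && op in (:+, :-, :*)
            return getfield(Base, op)(a, b)    # fold constants
        elseif op == :*
            (a == 0 || b == 0) && return 0     # 0 * b and a * 0
            a == 1 && return b                 # 1 * b
            b == 1 && return a                 # a * 1
        elseif op == :+
            a == 0 && return b                 # 0 + b
            b == 0 && return a                 # a + 0
        elseif op == :-
            b == 0 && return a                 # a - 0
        end
    end
    Expr(:call, op, args...)               # otherwise rebuild the call
end

simplify(:(1 + (0 * x - 2 * 1) / x ^ 2))   # an expression representing 1 + (-2)/x^2
```

Even this small set of rules removes the 0 * x and 2 * 1 terms, though it does not rewrite the sum as a difference.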

ex₂ = :( (x + sin(x))/sin(x))
D(ex₂, :x)
:(((1 + cos(x) * 1) * sin(x) - (x + sin(x)) * (cos(x) * 1)) / sin(x) ^ 2)

Again, simplification is not performed.
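
An unsimplified answer can still be checked numerically. The sketch below, using only base Julia, turns the original expression and the derivative shown above into functions with eval and compares the derivative against a central difference. The point x₀ = 1.0, the step h, and the tolerance are arbitrary choices made for this illustration:

```julia
ex  = :((x + sin(x)) / sin(x))
dex = :(((1 + cos(x) * 1) * sin(x) - (x + sin(x)) * (cos(x) * 1)) / sin(x) ^ 2)

# Turn each quoted expression into an anonymous function of x.
f  = eval(:(x -> $ex))
df = eval(:(x -> $dex))

# Central difference as an independent estimate of the derivative at x₀.
# invokelatest guards against calling a freshly eval-ed function too early.
x₀, h = 1.0, 1e-6
approx = (Base.invokelatest(f, x₀ + h) - Base.invokelatest(f, x₀ - h)) / (2h)

isapprox(Base.invokelatest(df, x₀), approx; rtol = 1e-6)   # true
```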

Finally, we take a second derivative:

ex₃ = :(sin(x) - x - x^3/6)
D(D(ex₃, :x), :x)
:((((-(sin(x)) * 1) * 1 + cos(x) * 0) - 0) - (((((exp(3 * log(x)) * (0 * log(x) + 3 * ((1 / x) * 1))) * (0 * log(x) + 3 * ((1 / x) * 1)) + exp(3 * log(x)) * ((0 * log(x) + 0 * ((1 / x) * 1)) + (0 * ((1 / x) * 1) + 3 * (((0 * x - 1 * 1) / x ^ 2) * 1 + (1 / x) * 0)))) * 6 + (exp(3 * log(x)) * (0 * log(x) + 3 * ((1 / x) * 1))) * 0) - ((exp(3 * log(x)) * (0 * log(x) + 3 * ((1 / x) * 1))) * 0 + x ^ 3 * 0)) * 6 ^ 2 - ((exp(3 * log(x)) * (0 * log(x) + 3 * ((1 / x) * 1))) * 6 - x ^ 3 * 0) * (exp(2 * log(6)) * (0 * log(6) + 2 * ((1 / 6) * 0)))) / (6 ^ 2) ^ 2)

The length of the expression should lead to further appreciation for the simplification steps taken when doing such a computation by hand.