How to implement your own solver

Ronny Bergmann

When you have used a few solvers from Manopt.jl, for example as in the opening tutorial Get started: optimize!, you might come up with the idea of implementing a solver yourself.

After a short introduction of the algorithm we aim to implement, this tutorial first discusses the structural details, for example what a solver consists of and “works with”. Afterwards, we show how to implement the algorithm. Finally, we discuss how to make the algorithm both convenient for the user and initialized in such a way that it can benefit from features already available in Manopt.jl.

Note

If you have implemented your own solver, we would be very happy to have it within Manopt.jl as well, so maybe consider opening a Pull Request.

using Manopt, Manifolds, Random

Our guiding example: a random walk minimization

Since most serious algorithms should be implemented in Manopt.jl directly, we implement a solver that randomly walks on the manifold and keeps track of the point with the lowest cost visited so far. As for the algorithms in Manopt.jl, we aim to implement this generically for any manifold that is implemented using ManifoldsBase.jl.

The random walk minimization

Given:

  • a manifold $\mathcal M$
  • a starting point $p=p^{(0)}$
  • a cost function $f: \mathcal M → ℝ$
  • a parameter $\sigma > 0$
  • a retraction $\operatorname{retr}_p(X)$ that maps $X ∈ T_p\mathcal M$ to the manifold.

The algorithm then runs the following steps:

  1. set $k=0$
  2. set our best point $q = p^{(0)}$
  3. Repeat until a stopping criterion is fulfilled
    1. Choose a random tangent vector $X^{(k)} ∈ T_{p^{(k)}}\mathcal M$ of length $\lVert X^{(k)} \rVert = \sigma$
    2. “Walk” along this direction, that is $p^{(k+1)} = \operatorname{retr}_{p^{(k)}}(X^{(k)})$
    3. If $f(p^{(k+1)}) < f(q)$ set $q = p^{(k+1)}$ as our new best visited point
    4. Set $k = k+1$
  4. Return $q$ as the resulting best point we visited
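Before turning to the Manopt.jl structures, the steps above can be sketched in plain Julia for the special case of the unit sphere in $ℝ^3$. This is only an illustration under stated assumptions: the function name `random_walk_sphere`, the projection onto the tangent space, and the normalization used as a retraction are choices made for this sketch and are not part of Manopt.jl.

```julia
using LinearAlgebra, Random

# Plain-Julia sketch of the random walk on the unit sphere in ℝ³.
# All names here are illustrative and not part of Manopt.jl.
function random_walk_sphere(f, p0; σ=0.1, iterations=200, rng=Random.default_rng())
    p = normalize(p0)          # current iterate p^{(k)}
    q = copy(p)                # best point visited so far
    for k in 1:iterations
        Y = randn(rng, length(p))
        X = Y - dot(p, Y) * p  # project Y onto the tangent space at p
        X .*= σ / norm(X)      # rescale to length σ
        p = normalize(p + X)   # projection retraction back onto the sphere
        if f(p) < f(q)         # keep track of the best point
            q = copy(p)
        end
    end
    return q
end

f(p) = p[3]                    # toy cost: the third coordinate
p0 = [1.0, 0.0, 0.0]
q = random_walk_sphere(f, p0; rng=MersenneTwister(42))
```

By construction, the returned `q` lies on the sphere and its cost is never larger than that of the starting point.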

Preliminaries: elements a solver works on

There are two main ingredients a solver needs: a problem to work on and the state of a solver, which “identifies” the solver and stores intermediate results.

The “task”: an AbstractManoptProblem

A problem in Manopt.jl usually consists of a manifold (an AbstractManifold) and an AbstractManifoldObjective describing the function we have and its features. In our case the objective is (just) a ManifoldCostObjective that stores the cost function f(M, p) -> ℝ. More generally, it might for example also store a gradient function, the Hessian, or any other information we have about our task.

This is independent of the solver itself, since it only identifies the problem we want to solve, not how we want to solve it. In other words, this type contains all information that is static and independent of the specific solver at hand.

The problem variable is usually called mp.
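As a small sketch of how such a problem is assembled and queried, the following lines build a problem on the sphere and read its parts back with the access functions. The cost `f` used here is a toy function chosen only for this illustration.

```julia
using Manopt, Manifolds

M = Sphere(2)
f(M, p) = p[3]                        # hypothetical toy cost for this sketch
mp = DefaultManoptProblem(M, ManifoldCostObjective(f))

get_manifold(mp)                      # the manifold the problem is posed on
get_objective(mp)                     # the ManifoldCostObjective wrapping f
c = get_cost(mp, [0.0, 0.0, 1.0])     # evaluates f(M, p) through the problem
```

A solver only ever talks to the problem through such access functions, which is what later allows the objective to be decorated without changing the solver.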

The solver: an AbstractManoptSolverState

Everything that is needed by a solver during the iterations, all its parameters, interim values that are needed beyond just one iteration, is stored in a subtype of the AbstractManoptSolverState. This identifies the solver uniquely.

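In our case we want to store five things:

  • the current iterate p
  • the best visited point q
  • the step length σ $> 0$
  • the retraction retraction_method to use
  • a stopping criterion stop

We can define this as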

mutable struct RandomWalkState{
    P,
    R<:AbstractRetractionMethod,
    S<:StoppingCriterion,
} <: AbstractManoptSolverState
  p::P
  q::P
  σ::Float64
  retraction_method::R
  stop::S
end

The stopping criterion is usually stored in the state’s stop field. If you have a reason to do otherwise, you have one more function to implement (see next section). For ease of use, a constructor can be provided that, for example, chooses a good default for the retraction based on a given manifold.

function RandomWalkState(M::AbstractManifold, p::P=rand(M);
    σ = 0.1,
    retraction_method::R=default_retraction_method(M),
    stopping_criterion::S=StopAfterIteration(200)
) where {P, R<:AbstractRetractionMethod, S<:StoppingCriterion}
    return RandomWalkState{P,R,S}(p, copy(M, p), σ, retraction_method, stopping_criterion)
end

Parametrizing the state avoids fields with abstract types. The keyword arguments for the retraction and the stopping criterion are the ones usually used in Manopt.jl and provide an easy way to construct this state.
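The effect of the type parameters can be checked in plain Julia. The following sketch compares a struct with an abstractly typed field to a parametrized one; all type names are hypothetical stand-ins invented for this illustration and do not exist in Manopt.jl.

```julia
# Hypothetical stand-ins for a retraction type hierarchy.
abstract type AbstractToyRetraction end
struct ToyProjection <: AbstractToyRetraction end

# The field has an abstract type: the compiler cannot know its layout.
struct LooseState
    retraction_method::AbstractToyRetraction
end

# The field type is a parameter: every instance has a concrete field type.
struct TightState{R<:AbstractToyRetraction}
    retraction_method::R
end

loose = LooseState(ToyProjection())
tight = TightState(ToyProjection())

isconcretetype(fieldtype(typeof(loose), :retraction_method))  # false
isconcretetype(fieldtype(typeof(tight), :retraction_method))  # true
```

With concrete field types, Julia can generate specialized code for calls such as the retraction inside the solver step.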

States usually have a shortened name as their variable; we use rws for our state here.

Implementing your solver

There are four functions we implement for our solver:

  • initialize_solver!(mp, rws) which initializes the solver before the first iteration
  • step_solver!(mp, rws, i) which implements the ith iteration, where i is given to you as the third parameter
  • get_iterate(rws) which accesses the iterate from other places in the solver. Since our iterate is stored in p, the default already does this, so stating it explicitly is optional here.
  • get_solver_result(rws) returning the solver’s final (best) point we reached. By default this would return the last iterate rws.p (or more precisely, it calls get_iterate), but since we randomly walk and remember our best point in q, this has to return rws.q.

The first two functions are in-place functions, that is they modify our solver state rws. You implement these by multiple dispatch on the types after importing said functions from Manopt:

import Manopt: initialize_solver!, step_solver!, get_iterate, get_solver_result

The state above has two fields for which we use the common names from Manopt.jl: the StoppingCriterion is usually stored in stop and the iterate in p. If your choice is different, you need to reimplement

  • stop_solver!(mp, rws, i) to determine whether or not to stop after the ith iteration.
  • get_iterate(rws) to access the current iterate

We recommend following the general scheme with the stop field. If you have specific criteria for when to stop, consider implementing your own stopping criterion instead.
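Before writing a new criterion, it is often enough to combine existing ones. A minimal sketch, assuming the `|` operator and the `StopWhenCostLess` criterion as provided by Manopt.jl; the threshold `1e-2` is an arbitrary value chosen for this illustration.

```julia
using Manopt

# Stop after 500 iterations or as soon as the cost drops below 1e-2,
# whichever happens first.
sc = StopAfterIteration(500) | StopWhenCostLess(1e-2)
```

Such a combined criterion can be passed as the stopping_criterion keyword of the RandomWalkState constructor. Note that the cost-based criterion looks at the current iterate, which for the random walk is p and not the best point q.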

Initialization and iterate access

For our solver, there is not much to initialize. Just to be safe, we copy the initial value in p over to q. We do not have to care about remembering the iterate; that is done by Manopt.jl. For the iterate access we just have to return p, and for the solver result we return q.

function initialize_solver!(mp::AbstractManoptProblem, rws::RandomWalkState)
    M = get_manifold(mp) # the manifold is stored in the problem
    copyto!(M, rws.q, rws.p) # Set q = p^{(0)}
    return rws
end
get_iterate(rws::RandomWalkState) = rws.p
get_solver_result(rws::RandomWalkState) = rws.q

Similarly, we implement the step. Here we make use of the fact that the problem (and in fact also the objective) has access functions for its elements; the one we need is get_cost.

function step_solver!(mp::AbstractManoptProblem, rws::RandomWalkState, i)
    M = get_manifold(mp) # for ease of use get the manifold from the problem
    X = rand(M; vector_at=rws.p)     # generate a direction at the current iterate
    X .*= rws.σ / norm(M, rws.p, X)  # rescale it to length σ
    # Walk
    retract!(M, rws.p, rws.p, X, rws.retraction_method)
    # is the new point better? Then store it in q
    if get_cost(mp, rws.p) < get_cost(mp, rws.q)
        copyto!(M, rws.q, rws.p)
    end
    return rws
end

Performance-wise, we could reduce the number of allocations by making X a field of rws as well, but let’s keep it simple here. We could also store the cost of q in the state, but we shall see how to easily enable this solver to allow for caching as well. In practice, however, it is preferable to cache intermediate values like the cost of q in the state when this can be achieved easily. This way we do not have to deal with the overhead of an external cache.
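For comparison, here is a sketch of what such a variant could look like. The state `CachedRandomWalkState` and its fields `X` and `cost_q` are hypothetical names for this illustration; the sketch assumes that `rand!` with the `vector_at` keyword is available for the manifold, as it is for the sphere.

```julia
using Manopt, Manifolds, Random
import Manopt: initialize_solver!, step_solver!, get_solver_result

# Hypothetical variant of the state that also stores X and the cost of q.
mutable struct CachedRandomWalkState{P,T,R<:AbstractRetractionMethod,S<:StoppingCriterion} <: AbstractManoptSolverState
    p::P
    q::P
    X::T
    cost_q::Float64
    σ::Float64
    retraction_method::R
    stop::S
end

function initialize_solver!(mp::AbstractManoptProblem, s::CachedRandomWalkState)
    M = get_manifold(mp)
    copyto!(M, s.q, s.p)              # q = p^{(0)}
    s.cost_q = get_cost(mp, s.q)      # evaluate f(q) once
    return s
end

function step_solver!(mp::AbstractManoptProblem, s::CachedRandomWalkState, k)
    M = get_manifold(mp)
    rand!(M, s.X; vector_at=s.p)      # reuse the memory of X
    s.X .*= s.σ / norm(M, s.p, s.X)
    retract!(M, s.p, s.p, s.X, s.retraction_method)
    c = get_cost(mp, s.p)             # only one cost evaluation per step
    if c < s.cost_q
        copyto!(M, s.q, s.p)
        s.cost_q = c
    end
    return s
end
get_solver_result(s::CachedRandomWalkState) = s.q

M = Sphere(2)
f(M, p) = p[3]                        # toy cost for this sketch
p0 = [1.0, 0.0, 0.0]
s = CachedRandomWalkState(
    copy(M, p0), copy(M, p0), zero_vector(M, p0), Inf, 0.1,
    default_retraction_method(M), StopAfterIteration(100),
)
solve!(DefaultManoptProblem(M, ManifoldCostObjective(f)), s)
```

Compared to the simple state, each step now evaluates the cost once instead of twice and does not allocate a new tangent vector.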

Now we can already run the solver. We take the same example as in the other tutorials and first define our task, the Riemannian Center of Mass from the Get started: optimize! tutorial.

Random.seed!(23)
n = 100
σ = π / 8
M = Sphere(2)
p = 1 / sqrt(2) * [1.0, 0.0, 1.0]
data = [exp(M, p,  σ * rand(M; vector_at=p)) for i in 1:n];
f(M, p) = sum(1 / (2 * n) * distance.(Ref(M), Ref(p), data) .^ 2)

We can now generate the problem with its objective and the state

mp = DefaultManoptProblem(M, ManifoldCostObjective(f))
s = RandomWalkState(M; σ = 0.2)

solve!(mp, s)
get_solver_result(s)
3-element Vector{Float64}:
 -0.2412674850987521
  0.8608618657176527
 -0.44800317943876844

The function solve! modifies s in place; the last line illustrates how to access the result in general. We could also just look at s.q, but access functions like get_solver_result and get_iterate are also used in several other places.

We could for example easily set up a second solver to work from a specified starting point with a different σ like

s2 = RandomWalkState(M, [1.0, 0.0, 0.0];  σ = 0.1)
solve!(mp, s2)
get_solver_result(s2)
3-element Vector{Float64}:
 1.0
 0.0
 0.0

Ease of use I: a high level interface

Manopt.jl offers a few additional features for solvers in their high-level interfaces, for example the debug= and record= keywords for debug output and recording within solver states, or the count= and cache= keywords for the objective.
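To see what such a keyword does for a solver that already ships with Manopt.jl, the following sketch records the cost during a run of gradient_descent. The cost `f` and its gradient are toy functions chosen for this illustration; on the sphere, the Riemannian gradient is obtained here by projecting the Euclidean gradient onto the tangent space.

```julia
using Manopt, Manifolds

M = Sphere(2)
f(M, p) = p[3]                                   # toy cost for this sketch
grad_f(M, p) = project(M, p, [0.0, 0.0, 1.0])    # its Riemannian gradient

st = gradient_descent(M, f, grad_f, [1.0, 0.0, 0.0];
    record=:Cost,         # store the cost in every iteration
    return_state=true,    # return the (decorated) state, not just the point
)
costs = get_record(st)    # the recorded cost values
```

The solver itself contains no recording code; the keyword wraps the state in a decorator that does the bookkeeping.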

We can introduce these here as well with just a few lines of code. There are usually two steps. We further need three internal functions from Manopt.jl

using Manopt: get_solver_return, indicates_convergence, status_summary

A high level interface using the objective

This could be considered an interim step towards the high-level interface: if the objective, a ManifoldCostObjective, is already initialized, the high-level interface consists of the following steps

  1. possibly decorate the objective
  2. generate the problem
  3. generate and possibly decorate the state
  4. call the solver
  5. determine the return value

We illustrate these steps with an in-place variant here. A variant that keeps the given start point unchanged would just add a copy(M, p) upfront. Manopt.jl provides both variants.

function random_walk_algorithm!(
    M::AbstractManifold,
    mgo::ManifoldCostObjective,
    p;
    σ = 0.1,
    retraction_method::AbstractRetractionMethod=default_retraction_method(M, typeof(p)),
    stopping_criterion::StoppingCriterion=StopAfterIteration(200),
    kwargs...,
)
    # 1. possibly decorate the objective (for example with a cache or counters)
    dmgo = decorate_objective!(M, mgo; kwargs...)
    # 2. generate the problem
    dmp = DefaultManoptProblem(M, dmgo)
    # 3. generate the state from the given start point and parameters, then decorate it
    s = RandomWalkState(M, p;
        σ=σ,
        retraction_method=retraction_method, stopping_criterion=stopping_criterion,
    )
    ds = decorate_state!(s; kwargs...)
    # 4. call the solver
    solve!(dmp, ds)
    # 5. determine the return value
    return get_solver_return(get_objective(dmp), ds)
end
random_walk_algorithm! (generic function with 1 method)

The high level interface

Starting from the last section, the usual call a user would prefer is just passing a manifold M, the cost f, and maybe a start point p.

function random_walk_algorithm!(M::AbstractManifold, f, p=rand(M); kwargs...)
    mgo = ManifoldCostObjective(f)
    return random_walk_algorithm!(M, mgo, p; kwargs...)
end
random_walk_algorithm! (generic function with 3 methods)

Ease of use II: the state summary

If you set return_state=true, the solver returns the state instead of just the resulting point, and this state should provide a summary of the run. When a show method is provided, users can easily read such a summary in a terminal. It should reflect the main parameters of the solver, if they are not too verbose, and provide information about the reason it stopped and whether this indicates convergence.

Here it would for example look like

import Base: show
function show(io::IO, rws::RandomWalkState)
    i = get_count(rws, :Iterations)
    Iter = (i > 0) ? "After $i iterations\n" : ""
    Conv = indicates_convergence(rws.stop) ? "Yes" : "No"
    s = """
    # Solver state for `Manopt.jl`s Tutorial Random Walk
    $Iter
    ## Parameters
    * retraction method: $(rws.retraction_method)
    * σ                : $(rws.σ)

    ## Stopping criterion

    $(status_summary(rws.stop))
    This indicates convergence: $Conv"""
    return print(io, s)
end

Now the algorithm can be easily called and provides all features of a Manopt.jl algorithm. For example to see the summary, we could now just call

q = random_walk_algorithm!(M, f; return_state=true)
# Solver state for `Manopt.jl`s Tutorial Random Walk
After 200 iterations

## Parameters
* retraction method: ExponentialRetraction()
* σ                : 0.1

## Stopping criterion

Max Iteration 200:  reached
This indicates convergence: No

Conclusion & Beyond

We saw in this tutorial how to implement a simple cost-based algorithm to illustrate how optimization algorithms are covered in Manopt.jl.

One feature we did not cover is that most algorithms allow for both in-place and allocating functions as soon as they work on more than just the cost, for example when they use gradients, proximal maps, or Hessians. This is usually a keyword argument of the objective and hence also part of the high-level interfaces.
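As a brief sketch of this distinction, the same toy gradient can be provided in an allocating and in an in-place form and passed to gradient_descent. The cost and gradient are made up for this illustration; the evaluation keyword with InplaceEvaluation() selects the in-place signature.

```julia
using Manopt, Manifolds

M = Sphere(2)
f(M, p) = p[3]                                          # toy cost for this sketch
grad_f(M, p) = project(M, p, [0.0, 0.0, 1.0])           # allocating: returns a new X
grad_f!(M, X, p) = project!(M, X, p, [0.0, 0.0, 1.0])   # in-place: writes into X

p0 = [1.0, 0.0, 0.0]
q1 = gradient_descent(M, f, grad_f, p0)
q2 = gradient_descent(M, f, grad_f!, p0; evaluation=InplaceEvaluation())
```

Both calls run the same algorithm; the in-place form only avoids allocating a new tangent vector in every gradient evaluation.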