# How to implement your own solver

Ronny Bergmann

When you have used a few solvers from Manopt.jl, for example as in the opening tutorial Get Started: Optimize!, you might get the idea of implementing a solver yourself.

After a short introduction of the algorithm we will implement, this tutorial first discusses the structural details, that is, what a solver consists of and “works with”. Afterwards, we show how to implement the algorithm. Finally, we discuss how to make the algorithm both convenient for the user and initialized in a way that it can benefit from features already available in Manopt.jl.

Note

If you have implemented your own solver, we would be very happy to have it within Manopt.jl as well, so maybe consider opening a Pull Request.

```julia
using Manopt, Manifolds, Random
```

## Our Guiding Example: A Random Walk Minimization

Since most serious algorithms should be implemented in Manopt.jl itself directly, we will implement a solver that randomly walks on the manifold and keeps track of the lowest point visited. As for all algorithms in Manopt.jl, we aim to implement this generically for any manifold that is implemented using ManifoldsBase.jl.

The Random Walk Minimization

Given:

• a manifold $\mathcal M$
• a starting point $p=p^{(0)}$
• a cost function $f: \mathcal M \to\mathbb R$.
• a parameter $\sigma > 0$.
• a retraction $\operatorname{retr}_p(X)$ that maps $X\in T_p\mathcal M$ to the manifold.

We can run the following steps of the algorithm

1. set $k=0$
2. set our best point $q = p^{(0)}$
3. Repeat until a stopping criterion is fulfilled
   1. Choose a random tangent vector $X^{(k)} \in T_{p^{(k)}}\mathcal M$ of length $\lVert X^{(k)} \rVert = \sigma$
   2. “Walk” along this direction, that is $p^{(k+1)} = \operatorname{retr}_{p^{(k)}}(X^{(k)})$
   3. If $f(p^{(k+1)}) < f(q)$, set $q = p^{(k+1)}$ as our new best visited point
4. Return $q$ as the resulting best point we visited

## Preliminaries – Elements a Solver works on

There are two main ingredients a solver needs: a problem to work on and the state of a solver, which “identifies” the solver and stores intermediate results.

### The “Task” – An `AbstractManoptProblem`

A problem in Manopt.jl usually consists of a manifold (an `AbstractManifold`) and an `AbstractManifoldObjective` describing the function we have and its features. In our case the objective is (just) a `ManifoldCostObjective` that stores the cost function `f(M, p) = ...`. More generally, the objective might for example store a gradient function, the Hessian, or any other information we have about our task.

This is independent of the solver itself, since it only identifies the problem we want to solve, independent of how we want to solve it – or in other words, this type contains all information that is static and independent of the specific solver at hand. Usually the problem variable is called `mp`.

### The Solver – An `AbstractManoptSolverState`

Everything that is needed by a solver during the iterations – all its parameters and any interim values that are needed beyond just one iteration – is stored in a subtype of `AbstractManoptSolverState`. This identifies the solver uniquely. In our case we want to store five things, which we can define as

```julia
mutable struct RandomWalkState{
    P,
    R<:AbstractRetractionMethod,
    S<:StoppingCriterion,
} <: AbstractManoptSolverState
    p::P                 # the current iterate p^{(k)}
    q::P                 # the best point visited so far
    σ::Float64           # the length of the random steps
    retraction_method::R # the retraction to use
    stop::S              # the stopping criterion
end
```

The stopping criterion is usually stored in the state’s `stop` field. If you have a reason to do otherwise, you have one more function to implement (see the next section). For ease of use, we can provide a constructor that, for example, chooses a good default for the retraction based on a given manifold.
```julia
function RandomWalkState(M::AbstractManifold, p::P=rand(M);
    σ=0.1,
    retraction_method::R=default_retraction_method(M),
    stopping_criterion::S=StopAfterIteration(200),
) where {P, R<:AbstractRetractionMethod, S<:StoppingCriterion}
    return RandomWalkState{P,R,S}(p, copy(M, p), σ, retraction_method, stopping_criterion)
end
```

Parametrising the state avoids abstractly typed fields. The keyword arguments for the retraction and the stopping criterion are the ones usually used in Manopt.jl and provide an easy way to construct this state now. States usually have a shortened name as their variable; we will use `rws` for our state here.

## Implementing your solver

There are basically only two methods we need to implement for our solver, together with two access functions:

* `initialize_solver!(mp, rws)`, which initialises the solver before the first iteration
* `step_solver!(mp, rws, i)`, which implements the `i`th iteration, where `i` is given to you as the third parameter
* `get_iterate(rws)`, which accesses the iterate from other places in the solver
* `get_solver_result(rws)`, returning the solver’s final (best) point we reached. By default this would return the last iterate `rws.p` (or, more precisely, call `get_iterate`), but since we randomly walk and remember our best point in `q`, this has to return `rws.q`.

The first two functions are in-place functions, that is, they modify our solver state `rws`. You implement these by multiple dispatch on the types after importing said functions from Manopt:

```julia
import Manopt: initialize_solver!, step_solver!, get_iterate, get_solver_result
```

The state above has two fields where we use the common names used in Manopt.jl: the `StoppingCriterion` is usually in `stop` and the iterate in `p`. If your choice is different, you need to reimplement

* `stop_solver!(mp, rws, i)` to determine whether or not to stop after the `i`th iteration
* `get_iterate(rws)` to access the current iterate

We recommend following the general scheme with the `stop` field.
If you have specific criteria when to stop, consider implementing your own stopping criterion instead.

### Initialization & Iterate Access

For our solver there is not so much to initialize; just to be safe we should copy the initial value in `p` we start with over to `q`. We do not have to care about remembering the iterate, that is done by Manopt.jl. For the iterate access we just have to pass `p`.

```julia
function initialize_solver!(mp::AbstractManoptProblem, rws::RandomWalkState)
    M = get_manifold(mp)
    copyto!(M, rws.q, rws.p) # set q = p^{(0)}
    return rws
end
get_iterate(rws::RandomWalkState) = rws.p
get_solver_result(rws::RandomWalkState) = rws.q
```

Similarly, we implement the step. Here we make use of the fact that the problem (and in fact also the objective) has access functions for its elements; the one we need is `get_cost`.

```julia
function step_solver!(mp::AbstractManoptProblem, rws::RandomWalkState, i)
    M = get_manifold(mp) # for ease of use, get the manifold from the problem
    X = rand(M; vector_at=rws.p) # generate a direction
    X .*= rws.σ / norm(M, rws.p, X) # scale it to length σ
    # walk
    retract!(M, rws.p, rws.p, X, rws.retraction_method)
    # is the new point better? Then store it
    if get_cost(mp, rws.p) < get_cost(mp, rws.q)
        copyto!(M, rws.q, rws.p)
    end
    return rws
end
```

Performance-wise we could reduce the number of allocations by making `X` a field of our `rws` as well, but let’s keep it simple here. We could also store the cost of `q` in the state; caching such intermediate values in the state is preferable when it is easily achieved, since it avoids the overhead of an external cache. We will nevertheless see below how to easily enable this solver to also work with the caching decorator.

Now we can already run the solver! We take the same example as in the other tutorials: we first define our task, the Riemannian Center of Mass from the Get Started: Optimize! tutorial.
```julia
Random.seed!(23)
n = 100
σ = π / 8
M = Sphere(2)
p = 1 / sqrt(2) * [1.0, 0.0, 1.0]
data = [exp(M, p, σ * rand(M; vector_at=p)) for i in 1:n];
f(M, p) = sum(1 / (2 * n) * distance.(Ref(M), Ref(p), data) .^ 2)
```

We can now generate the problem with its objective and the state

```julia
mp = DefaultManoptProblem(M, ManifoldCostObjective(f))
s = RandomWalkState(M; σ=0.2)
solve!(mp, s)
get_solver_result(s)
```

```
3-element Vector{Float64}:
 -0.2412674850987521
  0.8608618657176527
 -0.44800317943876844
```

The function `solve!` works in place of `s`; the last line illustrates how to access the result in general. We could also just look at `s.p`, but the function `get_iterate` is also used in several other places. We could, for example, easily set up a second solver to work from a specified starting point with a different `σ`, like

```julia
s2 = RandomWalkState(M, [1.0, 0.0, 0.0]; σ=0.1)
solve!(mp, s2)
get_solver_result(s2)
```

```
3-element Vector{Float64}:
 1.0
 0.0
 0.0
```

## Ease of Use I: The high level interface(s)

Manopt.jl offers a few additional features for solvers in their high level interfaces, for example the `debug=` and `record=` keywords for debugging and recording within solver states, or the `count=` and `cache=` keywords for the objective. We can introduce these here as well with just a few lines of code. There are usually two steps. We further need three internal functions from Manopt.jl

```julia
using Manopt: get_solver_return, indicates_convergence, status_summary
```

### A high level interface using the objective

This could be considered an interim step towards the high-level interface: if we already have the objective – in our case a `ManifoldCostObjective` – at hand, the high level interface consists of the steps

1. possibly decorate the objective
2. generate the problem
3. generate and possibly decorate the state
4. call the solver
5. determine the return value

We illustrate this step with an in-place variant here. A variant that keeps the given start point unchanged would just add a `copy(M, p)` upfront.
Manopt.jl provides both variants.

```julia
function random_walk_algorithm!(
    M::AbstractManifold, mgo::ManifoldCostObjective, p;
    σ=0.1,
    retraction_method::AbstractRetractionMethod=default_retraction_method(M, typeof(p)),
    stopping_criterion::StoppingCriterion=StopAfterIteration(200),
    kwargs...,
)
    dmgo = decorate_objective!(M, mgo; kwargs...)
    dmp = DefaultManoptProblem(M, dmgo)
    s = RandomWalkState(M, p;
        σ=σ,
        retraction_method=retraction_method,
        stopping_criterion=stopping_criterion,
    )
    ds = decorate_state!(s; kwargs...)
    solve!(dmp, ds)
    return get_solver_return(get_objective(dmp), ds)
end
```

```
random_walk_algorithm! (generic function with 1 method)
```

### The high level interface

Starting from the last section, the usual call a user would prefer is just passing a manifold `M`, the cost `f`, and maybe a start point `p`.

```julia
function random_walk_algorithm!(M::AbstractManifold, f, p=rand(M); kwargs...)
    mgo = ManifoldCostObjective(f)
    return random_walk_algorithm!(M, mgo, p; kwargs...)
end
```

```
random_walk_algorithm! (generic function with 3 methods)
```

## Ease of Use II: The State Summary

For the case that you set `return_state=true`, the solver should return a summary of the run. When a `show` method is provided, users can easily read such a summary in a terminal. It should reflect the solver’s main parameters, if they are not too verbose, and provide information about the reason it stopped and whether this indicates convergence. Here it would for example look like

```julia
import Base: show
function show(io::IO, rws::RandomWalkState)
    i = get_count(rws, :Iterations)
    Iter = (i > 0) ? "After $i iterations\n" : ""
    Conv = indicates_convergence(rws.stop) ? "Yes" : "No"
    s = """
    # Solver state for Manopt.jl's Tutorial Random Walk
    $Iter
    ## Parameters
    * retraction method: $(rws.retraction_method)
    * σ                : $(rws.σ)

    ## Stopping Criterion
    $(status_summary(rws.stop))
    This indicates convergence: $Conv"""
    return print(io, s)
end
```

```
show (generic function with 671 methods)
```

Now the algorithm can be easily called and provides – if wanted – all features of a Manopt.jl algorithm. For example, to see the summary, we could now just call

```julia
q = random_walk_algorithm!(M, f; return_state=true)
```

```
# Solver state for Manopt.jl's Tutorial Random Walk
After 200 iterations

## Parameters
* retraction method: ExponentialRetraction()
* σ                : 0.1

## Stopping Criterion
Max Iteration 200:  reached
This indicates convergence: No
```
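Since the high-level interface decorates both the objective and the state, the other keywords mentioned above should work here as well. The following is a sketch, assuming the standard `debug=` and `cache=` decorators of Manopt.jl (the `:LRU` cache additionally requires loading LRUCache.jl); the exact debug actions and cache tuple shown are illustrative choices, not the only possible ones:

```julia
using LRUCache # required for the :LRU cache decorator

# print the iteration number and cost every 50th iteration,
# and cache the 25 most recently computed cost values
q = random_walk_algorithm!(M, f;
    debug=[:Iteration, " | ", :Cost, "\n", 50],
    cache=(:LRU, [:Cost], 25),
)
```

Both keywords are consumed by `decorate_state!` and `decorate_objective!`, respectively, so our solver code itself did not have to change at all to support them.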

## Conclusion & Beyond

We saw in this tutorial how to implement a simple cost-based algorithm, to illustrate how optimization algorithms are covered in Manopt.jl.

One feature we did not cover is that most algorithms allow for in-place as well as allocating functions, as soon as they work with more than just the cost, for example gradients, proximal maps, or Hessians. This is usually controlled by a keyword argument of the objective and hence also part of the high-level interfaces.
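As a sketch of how this looks from a user's perspective – here for the existing `gradient_descent` solver applied to the Riemannian center of mass cost from above, whose gradient is the well-known $-\frac{1}{n}\sum_i \log_p d_i$ – the allocating and in-place variants differ only in the gradient's signature and the `evaluation=` keyword. This example is an illustration and not part of the solver we implemented:

```julia
# allocating variant: grad_f returns a new tangent vector
grad_f(M, p) = -1 / n * sum(log(M, p, d) for d in data)
q1 = gradient_descent(M, f, grad_f, p)

# in-place variant: grad_f! writes the result into X
function grad_f!(M, X, p)
    zero_vector!(M, X, p)
    for d in data
        X .-= log(M, p, d) ./ n
    end
    return X
end
q2 = gradient_descent(M, f, grad_f!, p; evaluation=InplaceEvaluation())
```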