# How to implement your own solver

Ronny Bergmann

When you have used a few solvers from `Manopt.jl`, for example in the opening tutorial Get started: optimize!, you might come to the idea of implementing a solver yourself.

After a short introduction of the algorithm we aim to implement, this tutorial first discusses the structural details, for example what a solver consists of and “works with”. Afterwards, we show how to implement the algorithm. Finally, we discuss how to make the algorithm both convenient for the user and initialized in a way that lets it benefit from features already available in `Manopt.jl`.

If you have implemented your own solver, we would be very happy to have that within `Manopt.jl` as well, so maybe consider opening a Pull Request.

`using Manopt, Manifolds, Random`

## Our guiding example: a random walk minimization

Since most serious algorithms should be implemented in `Manopt.jl` itself directly, we implement a solver that randomly walks on the manifold and keeps track of the point with the lowest cost visited so far. As for the algorithms in `Manopt.jl`, we aim to implement this *generically* for any manifold that is implemented using ManifoldsBase.jl.

**The random walk minimization**

Given:

- a manifold $\mathcal M$
- a starting point $p=p^{(0)}$
- a cost function $f: \mathcal M → ℝ$.
- a parameter $\sigma > 0$.
- a retraction $\operatorname{retr}_p(X)$ that maps $X ∈ T_p\mathcal M$ to the manifold.

We can run the following steps of the algorithm

- set $k=0$
- set our best point $q = p^{(0)}$
- Repeat until a stopping criterion is fulfilled
  - Choose a random tangent vector $X^{(k)} ∈ T_{p^{(k)}}\mathcal M$ of length $\lVert X^{(k)} \rVert = \sigma$
  - “Walk” along this direction, that is $p^{(k+1)} = \operatorname{retr}_{p^{(k)}}(X^{(k)})$
  - If $f(p^{(k+1)}) < f(q)$, set $q = p^{(k+1)}$ as our new best visited point
- Return $q$ as the resulting best point we visited
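The steps above can be sketched without any package support for the special case of the unit sphere. This is only an illustration under stated assumptions: the helper names `random_tangent`, `retract_sphere`, `random_walk_sketch` and the cost `f_height` are hypothetical and not part of `Manopt.jl`, and the projection retraction used here is just one possible choice for $\operatorname{retr}$.

```julia
using LinearAlgebra, Random

# Hypothetical helper: a random tangent vector of length σ at p on the unit sphere
function random_tangent(rng, p, σ)
    Y = randn(rng, length(p))
    X = Y - dot(p, Y) * p          # project Y onto the tangent space at p
    return σ * X / norm(X)         # rescale to length σ
end

# Assumption: projection retraction on the sphere, retr_p(X) = (p + X) / ‖p + X‖
retract_sphere(p, X) = (p + X) / norm(p + X)

# Hypothetical sketch of the random walk minimization with a fixed iteration budget
function random_walk_sketch(f, p0; σ=0.1, iterations=200, rng=Random.default_rng())
    p = copy(p0)                   # current iterate p^(k)
    q = copy(p0)                   # best point visited so far
    for k in 1:iterations
        X = random_tangent(rng, p, σ)
        p = retract_sphere(p, X)   # “walk” along X
        if f(p) < f(q)             # better than the best point so far?
            q = copy(p)
        end
    end
    return q
end

f_height(p) = p[3]                 # hypothetical cost: the height on the sphere
q = random_walk_sketch(f_height, [0.0, 0.0, 1.0]; rng=MersenneTwister(42))
```

The stopping criterion is reduced to a fixed number of iterations here; the remainder of the tutorial replaces these hard-coded pieces with the generic manifold, retraction, and stopping criterion interfaces.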

## Preliminaries: elements a solver works on

There are two main ingredients a solver needs: a problem to work on and the state of a solver, which “identifies” the solver and stores intermediate results.

### Specifying the task: an `AbstractManoptProblem`

A problem in `Manopt.jl` usually consists of a manifold (an `AbstractManifold`) and an `AbstractManifoldObjective` describing the function we have and its features. In our case the objective is (just) a `ManifoldCostObjective` that stores the cost function `f(M,p) -> R`. More generally, it might for example store a gradient function or the Hessian or any other information we have about our task.

This is something independent of the solver itself, since it only identifies the problem we want to solve independent of how we want to solve it, or in other words, this type contains all information that is static and independent of the specific solver at hand.

Usually the variable holding the problem is called `mp`.
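A minimal sketch of how such a problem might be assembled and queried, assuming `Manopt.jl` and `Manifolds.jl` are installed. The cost `f_example` is a hypothetical placeholder chosen only for illustration.

```julia
using Manopt, Manifolds

M = Sphere(2)
f_example(M, p) = p[3]             # hypothetical cost with the signature f(M, p)

# The problem combines the manifold with the objective
mp = DefaultManoptProblem(M, ManifoldCostObjective(f_example))

get_manifold(mp)                   # access the manifold stored in the problem
get_objective(mp)                  # access the objective stored in the problem
c = get_cost(mp, [0.0, 0.0, 1.0])  # evaluate the cost through the problem
```

Going through `get_cost` instead of calling `f_example` directly keeps a solver independent of how the objective is stored or decorated.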

### Specifying a solver: an `AbstractManoptSolverState`

Everything that is needed by a solver during the iterations, that is all its parameters and the interim values that are needed beyond just one iteration, is stored in a subtype of the `AbstractManoptSolverState`. This identifies the solver uniquely.

In our case we want to store five things

- the current iterate `p` $=p^{(k)}$
- the best visited point $q$
- the variable $\sigma > 0$
- the retraction $\operatorname{retr}$ to use (cf. retractions and inverse retractions)
- a criterion, when to stop: a `StoppingCriterion`

We can define this as

```
mutable struct RandomWalkState{
    P,
    R<:AbstractRetractionMethod,
    S<:StoppingCriterion,
} <: AbstractManoptSolverState
    p::P
    q::P
    σ::Float64
    retraction_method::R
    stop::S
end
```

The stopping criterion is usually stored in the state’s `stop` field. If you have a reason to do otherwise, you have one more function to implement (see next section). For ease of use, a constructor can be provided that, for example, chooses a good default for the retraction based on a given manifold.

```
function RandomWalkState(M::AbstractManifold, p::P=rand(M);
    σ = 0.1,
    retraction_method::R=default_retraction_method(M, typeof(p)),
    stopping_criterion::S=StopAfterIteration(200)
) where {P, R<:AbstractRetractionMethod, S<:StoppingCriterion}
    return RandomWalkState{P,R,S}(p, copy(M, p), σ, retraction_method, stopping_criterion)
end
```

Parametrising the state avoids abstractly typed fields. The keyword arguments for the retraction and stopping criterion are the ones usually used in `Manopt.jl` and provide an easy way to construct this state now.

States usually have a shortened name as their variable; we use `rws` for our state here.
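To illustrate why type parameters are preferred over abstractly typed fields, the following small sketch compares the two. Both struct names are hypothetical and exist only for this comparison.

```julia
# Hypothetical state with an abstractly typed field
struct AbstractFieldState
    p::AbstractVector
end

# Hypothetical state where the field type is a type parameter
struct ParametricState{P}
    p::P
end

s_abstract = AbstractFieldState([1.0, 0.0, 0.0])
s_param = ParametricState([1.0, 0.0, 0.0])

# The declared field type is abstract in the first case ...
abstract_is_concrete = isconcretetype(fieldtype(typeof(s_abstract), :p))
# ... and concrete (here Vector{Float64}) in the second
param_is_concrete = isconcretetype(fieldtype(typeof(s_param), :p))
```

With a concrete field type the compiler knows the type of `s_param.p` from the type of the state alone, which usually allows it to generate more efficient code for the solver functions.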

## Implementing your solver

There are basically only four methods we need to implement for our solver

- `initialize_solver!(mp, rws)` which initialises the solver before the first iteration
- `step_solver!(mp, rws, i)` which implements the `i`th iteration, where `i` is given to you as the third parameter
- `get_iterate(rws)` which accesses the iterate from other places in the solver
- `get_solver_result(rws)` returning the solver’s final (best) point we reached. By default this would return the last iterate `rws.p` (or more precisely calls `get_iterate`), but since we randomly walk and remember our best point in `q`, this has to return `rws.q`.

The first two functions are in-place functions, that is they modify our solver state `rws`. You implement these by multiple dispatch on the types after importing said functions from Manopt:

`import Manopt: initialize_solver!, step_solver!, get_iterate, get_solver_result`

The state we defined before has two fields where we use the common names used in `Manopt.jl`, that is the `StoppingCriterion` is usually in `stop` and the iterate in `p`. If your choice is different, you need to reimplement

- `stop_solver!(mp, rws, i)` to determine whether or not to stop after the `i`th iteration.
- `get_iterate(rws)` to access the current iterate

We recommend following the general scheme with the `stop` field. If you have specific criteria when to stop, consider implementing your own stopping criterion instead.

### Initialization and iterate access

For our solver, there is not much to initialize; just to be safe we should copy over the initial value in `p` we start with, to `q`. We do not have to care about remembering the iterate, that is done by `Manopt.jl`. For the iterate access we just have to pass `p`.

```
function initialize_solver!(mp::AbstractManoptProblem, rws::RandomWalkState)
    M = get_manifold(mp) # get the manifold from the problem
    copyto!(M, rws.q, rws.p) # Set q = p^{(0)}
    return rws
end
get_iterate(rws::RandomWalkState) = rws.p
get_solver_result(rws::RandomWalkState) = rws.q
```

Similarly, we implement the step. Here we make use of the fact that the problem (and in fact also the objective) has access functions for its elements; the one we need is `get_cost`.

```
function step_solver!(mp::AbstractManoptProblem, rws::RandomWalkState, i)
    M = get_manifold(mp) # for ease of use get the manifold from the problem
    X = rand(M; vector_at=rws.p) # generate a direction at the current iterate
    X .*= rws.σ / norm(M, rws.p, X) # rescale the direction to length σ
    # Walk
    retract!(M, rws.p, rws.p, X, rws.retraction_method)
    # is the new point better? Then store it in q
    if get_cost(mp, rws.p) < get_cost(mp, rws.q)
        copyto!(M, rws.q, rws.p)
    end
    return rws
end
```

Performance-wise we could reduce the number of allocations by making `X` also a field of our `rws`, but let’s keep it simple here. We could also store the cost of `q` in the state, but we shall see how to easily also enable this solver to allow for caching. In practice, however, it is preferable to cache intermediate values like the cost of `q` in the state when it can be easily achieved. This way we do not have to deal with the overhead of an external cache.
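As a rough sketch of what caching the cost of `q` in the state could look like, consider the following simplified variant. It is not part of the tutorial’s solver: the names `CachedBestState`, `update_best!`, and `f_sq` are hypothetical, and plain arrays with `copyto!(q, p)` stand in for manifold points and `copyto!(M, q, p)`.

```julia
# Hypothetical simplified state that also remembers the cost of the best point
mutable struct CachedBestState{P}
    p::P               # current iterate
    q::P               # best point visited so far
    q_cost::Float64    # cached cost f(q)
end

# Compare the current iterate against the cached best cost; f is evaluated once
function update_best!(s::CachedBestState, f)
    c = f(s.p)
    if c < s.q_cost
        copyto!(s.q, s.p)
        s.q_cost = c
    end
    return s
end

f_sq(p) = sum(abs2, p)                 # hypothetical cost
s = CachedBestState([1.0, 1.0], [2.0, 2.0], f_sq([2.0, 2.0]))
update_best!(s, f_sq)                  # p is better: q and q_cost are updated
first_q = copy(s.q)
first_cost = s.q_cost
s.p .= [3.0, 0.0]
update_best!(s, f_sq)                  # p is worse: q and q_cost stay unchanged
```

Compared to evaluating the cost at both `p` and `q` in every step, this needs only one cost evaluation per iteration.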

Now we can just run the solver already. We take the same example as for the other tutorials.

We first define our task, the Riemannian Center of Mass from the Get started: optimize! tutorial.

```
Random.seed!(23)
n = 100
σ = π / 8
M = Sphere(2)
p = 1 / sqrt(2) * [1.0, 0.0, 1.0]
data = [exp(M, p, σ * rand(M; vector_at=p)) for i in 1:n];
f(M, p) = sum(1 / (2 * n) * distance.(Ref(M), Ref(p), data) .^ 2)
```

We can now generate the problem with its objective and the state

```
mp = DefaultManoptProblem(M, ManifoldCostObjective(f))
s = RandomWalkState(M; σ = 0.2)
solve!(mp, s)
get_solver_result(s)
```

```
3-element Vector{Float64}:
-0.2412674850987521
0.8608618657176527
-0.44800317943876844
```

The function `solve!` works in place of `s`, but the last line illustrates how to access the result in general; we could also just look at `s.q`, but the function `get_solver_result` is also used in several other places.

We could for example easily set up a second solver to work from a specified starting point with a different `σ` like

```
s2 = RandomWalkState(M, [1.0, 0.0, 0.0]; σ = 0.1)
solve!(mp, s2)
get_solver_result(s2)
```

```
3-element Vector{Float64}:
1.0
0.0
0.0
```

## Ease of use I: a high level interface

`Manopt.jl` offers a few additional features for solvers in their high level interfaces, for example the `debug=` and `record=` keywords for debug and recording within solver states or the `count=` and `cache=` keywords for the objective.

We can introduce these here as well with just a few lines of code. There are usually two steps. We further need three internal functions from `Manopt.jl`:

`using Manopt: get_solver_return, indicates_convergence, status_summary`

### A high level interface using the objective

This could be considered as an interim step to the high-level interface: if the objective, a `ManifoldCostObjective`, is already initialized, the high level interface consists of the steps

- possibly decorate the objective
- generate the problem
- generate and possibly decorate the state
- call the solver
- determine the return value

We illustrate the steps with an in-place variant here. A variant that keeps the given start point unchanged would just add a `copy(M, p)` upfront. `Manopt.jl` provides both variants.

```
function random_walk_algorithm!(
    M::AbstractManifold,
    mgo::ManifoldCostObjective,
    p;
    σ = 0.1,
    retraction_method::AbstractRetractionMethod=default_retraction_method(M, typeof(p)),
    stopping_criterion::StoppingCriterion=StopAfterIteration(200),
    kwargs...,
)
    # possibly decorate the objective, for example with a cache or counters
    dmgo = decorate_objective!(M, mgo; kwargs...)
    dmp = DefaultManoptProblem(M, dmgo)
    # generate the state from the given start point and keyword arguments
    s = RandomWalkState(M, p;
        σ=σ,
        retraction_method=retraction_method, stopping_criterion=stopping_criterion,
    )
    # possibly decorate the state, for example with debug or record
    ds = decorate_state!(s; kwargs...)
    solve!(dmp, ds)
    return get_solver_return(get_objective(dmp), ds)
end
```

`random_walk_algorithm! (generic function with 1 method)`

### The high level interface

Starting from the last section, the usual call a user would prefer is just passing a manifold `M`, the cost `f`, and maybe a start point `p`.

```
function random_walk_algorithm!(M::AbstractManifold, f, p=rand(M); kwargs...)
    mgo = ManifoldCostObjective(f)
    return random_walk_algorithm!(M, mgo, p; kwargs...)
end
```

`random_walk_algorithm! (generic function with 3 methods)`

## Ease of Use II: the state summary

For the case that you set `return_state=true`, the solver should return a summary of the run. When a `show` method is provided, users can easily read such a summary in a terminal. It should reflect the solver’s main parameters, if they are not too verbose, and provide information about the reason it stopped and whether this indicates convergence.

Here it would for example look like

```
import Base: show
function show(io::IO, rws::RandomWalkState)
    i = get_count(rws, :Iterations)
    Iter = (i > 0) ? "After $i iterations\n" : ""
    Conv = indicates_convergence(rws.stop) ? "Yes" : "No"
    s = """
    # Solver state for `Manopt.jl`s Tutorial Random Walk
    $Iter
    ## Parameters
    * retraction method: $(rws.retraction_method)
    * σ : $(rws.σ)
    ## Stopping criterion
    $(status_summary(rws.stop))
    This indicates convergence: $Conv"""
    return print(io, s)
end
```

Now the algorithm can be easily called and provides all features of a `Manopt.jl` algorithm. For example to see the summary, we could now just call

`q = random_walk_algorithm!(M, f; return_state=true)`

```
# Solver state for `Manopt.jl`s Tutorial Random Walk
After 200 iterations
## Parameters
* retraction method: ExponentialRetraction()
* σ : 0.1
## Stopping criterion
Max Iteration 200: reached
This indicates convergence: No
```

## Conclusion & beyond

We saw in this tutorial how to implement a simple cost-based algorithm, to illustrate how optimization algorithms are covered in `Manopt.jl`.

One feature we did not cover is that most algorithms allow for both in-place and allocating functions as soon as they work on more than just the cost, for example when they use gradients, proximal maps or Hessians. This is usually a keyword argument of the objective and hence also part of the high-level interfaces.

## Technical details

This tutorial is cached. It was last run on the following package versions.

```
using Pkg
Pkg.status()
```

```
Status `~/work/Manopt.jl/Manopt.jl/tutorials/Project.toml`
[6e4b80f9] BenchmarkTools v1.5.0
[5ae59095] Colors v0.12.11
[31c24e10] Distributions v0.25.111
[26cc04aa] FiniteDifferences v0.12.32
[7073ff75] IJulia v1.25.0
[8ac3fa9e] LRUCache v1.6.1
[af67fdf4] ManifoldDiff v0.3.11
[1cead3c2] Manifolds v0.10.0
[3362f125] ManifoldsBase v0.15.14
[0fc0a36d] Manopt v0.5.0 `~/work/Manopt.jl/Manopt.jl`
[91a5bcdd] Plots v1.40.8
[731186ca] RecursiveArrayTools v3.27.0
```

```
using Dates
now()
```

`2024-08-29T06:09:23.096`