How to implement your own solver
Ronny Bergmann
When you have used a few solvers from Manopt.jl, for example as in the opening tutorial Get started: optimize!, you might come to the idea of implementing a solver yourself.
After a short introduction of the algorithm we aim to implement, this tutorial first discusses the structural details, for example what a solver consists of and “works with”. Afterwards, we show how to implement the algorithm. Finally, we discuss how to make the algorithm both convenient for the user and set up in such a way that it can benefit from features already available in Manopt.jl.
If you have implemented your own solver, we would be very happy to have it within Manopt.jl as well, so maybe consider opening a Pull Request.
using Manopt, Manifolds, Random
Our guiding example: a random walk minimization
Since most serious algorithms should be implemented in Manopt.jl itself directly, we implement a solver that randomly walks on the manifold and keeps track of the lowest point visited. As for the algorithms in Manopt.jl, we aim to implement this generically for any manifold that is implemented using ManifoldsBase.jl.
The random walk minimization
Given:
- a manifold $\mathcal M$
- a starting point $p=p^{(0)}$
- a cost function $f: \mathcal M → ℝ$.
- a parameter $\sigma > 0$.
- a retraction $\operatorname{retr}_p(X)$ that maps $X ∈ T_p\mathcal M$ to the manifold.
We can run the following steps of the algorithm
- set $k=0$
- set our best point $q = p^{(0)}$
- Repeat until a stopping criterion is fulfilled
  - Choose a random tangent vector $X^{(k)} ∈ T_{p^{(k)}}\mathcal M$ of length $\lVert X^{(k)} \rVert = \sigma$
  - “Walk” along this direction, that is $p^{(k+1)} = \operatorname{retr}_{p^{(k)}}(X^{(k)})$
  - If $f(p^{(k+1)}) < f(q)$, set $q = p^{(k+1)}$ as our new best visited point
  - Set $k = k+1$
- Return $q$ as the resulting best point we visited
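Before turning to the Manopt.jl structures, the steps above can be sketched in plain Julia for one specific case. This is only an illustration under assumptions that are not part of the tutorial: the manifold is fixed to the unit sphere, the retraction is the projection retraction, and the names random_direction, retract, random_walk and the toy cost are hypothetical.

```julia
using LinearAlgebra, Random

# Hypothetical helper: a random tangent vector of length σ at p on the unit sphere.
function random_direction(rng, p, σ)
    X = randn(rng, length(p))
    X .-= dot(p, X) .* p          # project onto the tangent space at p
    return X .* (σ / norm(X))     # rescale to length σ
end

# Projection retraction on the sphere (an assumption; any retraction would do).
retract(p, X) = (p .+ X) ./ norm(p .+ X)

function random_walk(f, p0; σ=0.1, iterations=200, rng=Random.default_rng())
    p = copy(p0)
    q = copy(p0)                  # best point visited so far
    for k in 1:iterations
        X = random_direction(rng, p, σ)
        p = retract(p, X)         # walk
        if f(p) < f(q)            # keep track of the best point
            q = copy(p)
        end
    end
    return q
end

f(p) = -p[3]                      # toy cost, minimised at the north pole
p0 = [1.0, 0.0, 0.0]
q = random_walk(f, p0; rng=MersenneTwister(42))
```

By construction the returned point q stays on the sphere and never has a higher cost than the start point. The remainder of the tutorial replaces the hard-coded sphere by a generic manifold and the fixed iteration count by a stopping criterion.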
Preliminaries: elements a solver works on
There are two main ingredients a solver needs: a problem to work on and the state of a solver, which “identifies” the solver and stores intermediate results.
Specifying the task: an AbstractManoptProblem
A problem in Manopt.jl usually consists of a manifold (an AbstractManifold) and an AbstractManifoldObjective describing the function we have and its features. In our case the objective is (just) a ManifoldCostObjective that stores the cost function f(M,p) -> R. More generally, it might for example store a gradient function or the Hessian or any other information we have about our task.
This is something independent of the solver itself, since it only identifies the problem we want to solve independent of how we want to solve it, or in other words, this type contains all information that is static and independent of the specific solver at hand.
Usually the problem's variable is called mp.
Specifying a solver: an AbstractManoptSolverState
Everything that is needed by a solver during the iterations, that is all its parameters and interim values that are needed beyond just one iteration, is stored in a subtype of the AbstractManoptSolverState. This identifies the solver uniquely.
In our case we want to store five things:
- the current iterate p $=p^{(k)}$
- the best visited point $q$
- the variable $\sigma > 0$
- the retraction $\operatorname{retr}$ to use (cf. retractions and inverse retractions)
- a criterion, when to stop: a StoppingCriterion
We can define this as
mutable struct RandomWalkState{
    P,
    R<:AbstractRetractionMethod,
    S<:StoppingCriterion,
} <: AbstractManoptSolverState
    p::P
    q::P
    σ::Float64
    retraction_method::R
    stop::S
end
The stopping criterion is usually stored in the state’s stop field. If you have a reason to do otherwise, you have one more function to implement (see next section). For ease of use, a constructor can be provided that, for example, chooses a good default for the retraction based on a given manifold.
function RandomWalkState(M::AbstractManifold, p::P=rand(M);
    σ = 0.1,
    retraction_method::R=default_retraction_method(M, typeof(p)),
    stopping_criterion::S=StopAfterIteration(200)
) where {P, R<:AbstractRetractionMethod, S<:StoppingCriterion}
    return RandomWalkState{P,R,S}(p, copy(M, p), σ, retraction_method, stopping_criterion)
end
Parametrising the state avoids fields with abstract types. The keyword arguments for the retraction and stopping criterion are the ones usually used in Manopt.jl and provide an easy way to construct this state now.
States usually have a shortened name as their variable; we use rws for our state here.
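The effect of parametrising a state can be checked with a small plain-Julia sketch. The types AbstractRule, RuleA, LooseState and TightState are hypothetical stand-ins introduced only for this illustration and are not part of Manopt.jl.

```julia
abstract type AbstractRule end     # hypothetical stand-in for an abstract method type
struct RuleA <: AbstractRule end

# Field annotated with the abstract type: its concrete type is unknown to the compiler.
struct LooseState
    rule::AbstractRule
end

# Parametrised variant: the concrete type of the field is part of the state's type.
struct TightState{R<:AbstractRule}
    rule::R
end

loose = LooseState(RuleA())
tight = TightState(RuleA())

loose_is_concrete = isconcretetype(fieldtype(typeof(loose), :rule))  # false
tight_is_concrete = isconcretetype(fieldtype(typeof(tight), :rule))  # true
```

With a concretely typed field, Julia can specialise the code that works on the state, which is the reason RandomWalkState carries the type parameters P, R and S.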
Implementing your solver
There are only a few methods we need to consider for our solver:
- initialize_solver!(mp, rws), which initialises the solver before the first iteration
- step_solver!(mp, rws, i), which implements the i-th iteration, where i is given to you as the third parameter
- get_iterate(rws), which accesses the iterate from other places in the solver
- get_solver_result(rws), returning the solver's final (best) point we reached. By default this would return the last iterate rws.p (or more precisely calls get_iterate), but since we randomly walk and remember our best point in q, this has to return rws.q.
The first two functions are in-place functions, that is they modify our solver state rws. You implement these by multiple dispatch on the types after importing said functions from Manopt:
import Manopt: initialize_solver!, step_solver!, get_iterate, get_solver_result
The state we defined before has two fields where we use the common names used in Manopt.jl, that is the StoppingCriterion is usually in stop and the iterate in p. If your choice is different, you need to reimplement
- stop_solver!(mp, rws, i) to determine whether or not to stop after the i-th iteration.
- get_iterate(rws) to access the current iterate
We recommend following the general scheme with the stop field. If you have specific criteria when to stop, consider implementing your own stopping criterion instead.
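To see why implementing a handful of methods is enough, the following plain-Julia miniature shows how a generic loop can drive any state type purely through multiple dispatch. It is a sketch of the idea only: all names (ToyState, toy_solve!, toy_step! and so on) are hypothetical and this is not Manopt.jl's actual implementation.

```julia
# Hypothetical miniature of a solver framework; none of these names belong to Manopt.jl.
abstract type ToyState end

toy_initialize!(problem, s::ToyState) = s              # default: nothing to set up
function toy_step! end                                 # every solver has to provide this
toy_stop(problem, s::ToyState, k) = k >= s.max_iter    # default stopping rule
toy_result(s::ToyState) = s.x                          # default result: the iterate

function toy_solve!(problem, s::ToyState)
    toy_initialize!(problem, s)
    k = 0
    while !toy_stop(problem, s, k)
        k += 1
        toy_step!(problem, s, k)
    end
    return toy_result(s)
end

# A concrete "solver": halve x in every iteration.
mutable struct HalvingState <: ToyState
    x::Float64
    max_iter::Int
end
toy_step!(problem, s::HalvingState, k) = (s.x /= 2; s)

result = toy_solve!(nothing, HalvingState(8.0, 3))     # three halvings of 8.0
```

Only toy_step! had to be written for HalvingState; the defaults cover the rest, much like the defaults for the stop and p fields described above.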
Initialization and iterate access
For our solver, there is not much to initialize. Just to be safe, we copy the initial value in p that we start with to q. We do not have to care about remembering the iterate; that is done by Manopt.jl. For the iterate access we just have to pass p.
function initialize_solver!(mp::AbstractManoptProblem, rws::RandomWalkState)
    M = get_manifold(mp) # the manifold is stored in the problem
    copyto!(M, rws.q, rws.p) # set q = p^{(0)}
    return rws
end
get_iterate(rws::RandomWalkState) = rws.p
get_solver_result(rws::RandomWalkState) = rws.q
Similarly, we implement the step. Here we make use of the fact that the problem (and in fact also the objective) has access functions for its elements; the one we need is get_cost.
function step_solver!(mp::AbstractManoptProblem, rws::RandomWalkState, i)
    M = get_manifold(mp) # for ease of use get the manifold from the problem
    X = rand(M; vector_at=rws.p) # generate a direction at the current iterate
    X .*= rws.σ / norm(M, rws.p, X) # rescale it to length σ
    # Walk
    retract!(M, rws.p, rws.p, X, rws.retraction_method)
    # is the new point better? Then store it in q
    if get_cost(mp, rws.p) < get_cost(mp, rws.q)
        copyto!(M, rws.q, rws.p)
    end
    return rws
end
Performance-wise, we could reduce the number of allocations by making X also a field of our rws, but let’s keep it simple here. We could also store the cost of q in the state, but we shall see how to easily enable this solver to allow for caching. In practice, however, it is preferable to cache intermediate values like the cost of q in the state when this can be easily achieved. This way we do not have to deal with the overhead of an external cache.
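To illustrate the remark about storing the cost of q, here is a plain-Julia sketch that counts cost evaluations. It again assumes the unit sphere with the projection retraction; the state CachedWalk, the function step! and the counting toy cost are hypothetical and not part of Manopt.jl.

```julia
using LinearAlgebra, Random

mutable struct CachedWalk          # hypothetical state, not a Manopt.jl type
    p::Vector{Float64}             # current iterate
    q::Vector{Float64}             # best point visited
    cost_q::Float64                # cached cost of q
    σ::Float64
end

function step!(f, s::CachedWalk, rng)
    X = randn(rng, length(s.p))
    X .-= dot(s.p, X) .* s.p               # tangent vector at p on the unit sphere
    X .*= s.σ / norm(X)                    # rescale to length σ
    s.p .= (s.p .+ X) ./ norm(s.p .+ X)    # projection retraction, in place
    c = f(s.p)                             # the only cost evaluation in this step
    if c < s.cost_q
        s.q .= s.p
        s.cost_q = c
    end
    return s
end

evaluations = Ref(0)
f(p) = (evaluations[] += 1; -p[3])         # toy cost that counts its calls

p0 = [1.0, 0.0, 0.0]
s = CachedWalk(copy(p0), copy(p0), f(p0), 0.1)
rng = MersenneTwister(1)
for k in 1:50
    step!(f, s, rng)
end
evaluations[]  # 51: one for the start point plus one per step
```

Without the cached value, each step would evaluate the cost twice, once at the new iterate and once at q.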
Now we can already run the solver. We take the same example as in the other tutorials.
We first define our task, the Riemannian Center of Mass from the Get started: optimize! tutorial.
Random.seed!(23)
n = 100
σ = π / 8
M = Sphere(2)
p = 1 / sqrt(2) * [1.0, 0.0, 1.0]
data = [exp(M, p, σ * rand(M; vector_at=p)) for i in 1:n];
f(M, p) = sum(1 / (2 * n) * distance.(Ref(M), Ref(p), data) .^ 2)
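Written out, the cost implemented by f above is the following, where the notation $d_1,\ldots,d_n$ for the entries of data and $d_{\mathcal M}$ for the Riemannian distance is introduced here only for illustration:

$$
f(p) = \frac{1}{2n} \sum_{i=1}^{n} d_{\mathcal M}(p, d_i)^2 .
$$

On the unit sphere used here, the distance between two points is the angle between them, $d_{\mathcal M}(p, q) = \arccos\langle p, q\rangle$.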
We can now generate the problem with its objective and the state
mp = DefaultManoptProblem(M, ManifoldCostObjective(f))
s = RandomWalkState(M; σ = 0.2)
solve!(mp, s)
get_solver_result(s)
3-element Vector{Float64}:
-0.2412674850987521
0.8608618657176527
-0.44800317943876844
The function solve! works in place on s, but the last line illustrates how to access the result in general; we could also just look at s.q, but the function get_solver_result is also used in several other places.
We could for example easily set up a second solver to work from a specified starting point with a different σ like
s2 = RandomWalkState(M, [1.0, 0.0, 0.0]; σ = 0.1)
solve!(mp, s2)
get_solver_result(s2)
3-element Vector{Float64}:
1.0
0.0
0.0
Ease of use I: a high level interface
Manopt.jl offers a few additional features for solvers in their high level interfaces, for example the debug= and record= keywords for debugging and recording within solver states, or the count= and cache= keywords for the objective.
We can introduce these here as well with just a few lines of code. There are usually two steps. We further need three internal functions from Manopt.jl:
using Manopt: get_solver_return, indicates_convergence, status_summary
A high level interface using the objective
This could be considered as an interim step to the high-level interface: if the objective, a ManifoldCostObjective, is already initialized, the high level interface consists of the steps
- possibly decorate the objective
- generate the problem
- generate and possibly decorate the state
- call the solver
- determine the return value
We illustrate the steps with an in-place variant here. A variant that keeps the given start point unchanged would just add a copy(M, p) upfront. Manopt.jl provides both variants.
function random_walk_algorithm!(
    M::AbstractManifold,
    mgo::ManifoldCostObjective,
    p;
    σ = 0.1,
    retraction_method::AbstractRetractionMethod=default_retraction_method(M, typeof(p)),
    stopping_criterion::StoppingCriterion=StopAfterIteration(200),
    kwargs...,
)
    dmgo = decorate_objective!(M, mgo; kwargs...)
    dmp = DefaultManoptProblem(M, dmgo)
    # use the given start point p and step length σ
    s = RandomWalkState(M, p;
        σ=σ,
        retraction_method=retraction_method, stopping_criterion=stopping_criterion,
    )
    ds = decorate_state!(s; kwargs...)
    solve!(dmp, ds)
    return get_solver_return(get_objective(dmp), ds)
end
random_walk_algorithm! (generic function with 1 method)
The high level interface
Starting from the last section, the usual call a user would prefer is just passing a manifold M, the cost f, and maybe a start point p.
function random_walk_algorithm!(M::AbstractManifold, f, p=rand(M); kwargs...)
mgo = ManifoldCostObjective(f)
return random_walk_algorithm!(M, mgo, p; kwargs...)
end
random_walk_algorithm! (generic function with 3 methods)
Ease of Use II: the state summary
If you set return_state=true, the solver should return a summary of the run. When a show method is provided, users can easily read such a summary in a terminal. It should reflect the solver's main parameters, if they are not too verbose, and provide information about the reason it stopped and whether this indicates convergence.
Here it would for example look like
import Base: show
function show(io::IO, rws::RandomWalkState)
i = get_count(rws, :Iterations)
Iter = (i > 0) ? "After $i iterations\n" : ""
Conv = indicates_convergence(rws.stop) ? "Yes" : "No"
s = """
# Solver state for `Manopt.jl`s Tutorial Random Walk
$Iter
## Parameters
* retraction method: $(rws.retraction_method)
* σ : $(rws.σ)
## Stopping criterion
$(status_summary(rws.stop))
This indicates convergence: $Conv"""
return print(io, s)
end
Now the algorithm can be easily called and provides all features of a Manopt.jl algorithm. For example, to see the summary, we could now just call
q = random_walk_algorithm!(M, f; return_state=true)
# Solver state for `Manopt.jl`s Tutorial Random Walk
After 200 iterations
## Parameters
* retraction method: ExponentialRetraction()
* σ : 0.1
## Stopping criterion
Max Iteration 200: reached
This indicates convergence: No
Conclusion & beyond
We saw in this tutorial how to implement a simple cost-based algorithm, to illustrate how optimization algorithms are covered in Manopt.jl.
One feature we did not cover is that most algorithms allow for both in-place and allocating functions as soon as they work on more than just the cost, for example when they use gradients, proximal maps or Hessians. This is usually a keyword argument of the objective and hence also part of the high-level interfaces.
Technical details
This tutorial is cached. It was last run on the following package versions.
using Pkg
Pkg.status()
Status `~/work/Manopt.jl/Manopt.jl/tutorials/Project.toml`
[6e4b80f9] BenchmarkTools v1.5.0
⌃ [5ae59095] Colors v0.12.11
[31c24e10] Distributions v0.25.115
[26cc04aa] FiniteDifferences v0.12.32
[7073ff75] IJulia v1.26.0
[8ac3fa9e] LRUCache v1.6.1
⌅ [af67fdf4] ManifoldDiff v0.3.13
⌃ [1cead3c2] Manifolds v0.10.7
[3362f125] ManifoldsBase v0.15.23
[0fc0a36d] Manopt v0.5.5 `~/work/Manopt.jl/Manopt.jl`
[91a5bcdd] Plots v1.40.9
[731186ca] RecursiveArrayTools v3.27.4
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`
using Dates
now()
2024-12-25T14:13:01.400