Advanced Recording Example

Manopt's recording and debugging facilities make it possible to record nearly any data during the iterations. This tutorial illustrates how to

  • record one value during the iterations
  • record multiple values during the iterations and access them afterwards
  • define your own RecordAction to perform custom recordings.

Several predefined recordings exist, for example RecordCost or RecordGradient; which ones are available depends on the solver used. Fields of the Options can be recorded directly using RecordEntry. For anything else, you can define your own RecordAction.
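To make the idea of a record action concrete, here is a minimal, stdlib-only sketch of the pattern: callable structs that are invoked once per iteration and push a value into their own storage, one of them reading an arbitrary field by name via getfield (similar in spirit to RecordEntry). All names (ToyRecordAction, ToyState, ToyRecordCost, ToyRecordEntry, toy_cost) are hypothetical and this is not how Manopt implements recording.

```julia
# Hypothetical, stdlib-only sketch of the recording idea. This is NOT Manopt's
# implementation; every name below is made up for illustration.
abstract type ToyRecordAction end

# A toy solver state standing in for the Options of a solver.
mutable struct ToyState
    x::Float64         # current iterate
    gradient::Float64  # current gradient
end

# Records the cost at the current iterate (only for iteration numbers i > 0).
struct ToyRecordCost{F} <: ToyRecordAction
    f::F
    values::Vector{Float64}
end
ToyRecordCost(f) = ToyRecordCost(f, Float64[])
(r::ToyRecordCost)(s::ToyState, i::Int) = i > 0 ? push!(r.values, r.f(s.x)) : r.values

# Records any field of the state by its name, similar in spirit to RecordEntry.
struct ToyRecordEntry{T} <: ToyRecordAction
    field::Symbol
    values::Vector{T}
end
ToyRecordEntry(::Type{T}, field::Symbol) where {T} = ToyRecordEntry{T}(field, T[])
function (r::ToyRecordEntry)(s::ToyState, i::Int)
    i > 0 && push!(r.values, getfield(s, r.field))
    return r.values
end

# Plain gradient descent on toy_cost(x) = x^2/2 with constant step size 0.5,
# calling every recorder once per iteration.
toy_cost(x) = x^2 / 2
state = ToyState(1.0, 0.0)
actions = ToyRecordAction[ToyRecordCost(toy_cost), ToyRecordEntry(Float64, :gradient)]
for i in 1:3
    state.gradient = state.x            # the gradient of x^2/2 is x
    state.x -= 0.5 * state.gradient     # gradient step
    foreach(a -> a(state, i), actions)  # record after the step
end
```

After the loop, actions[1].values holds the three recorded costs and actions[2].values the three recorded gradients; the solver loop itself never needs to know what is being recorded.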

We illustrate these using the gradient descent from the introductory Get Started: Optimize! tutorial, which computes the Riemannian center of mass, and refer to that tutorial for the mathematical details.

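using Manopt, Manifolds, Random
Random.seed!(42)
m = 30
M = Sphere(m)          # the unit sphere in ℝ^(m+1)
n = 800
σ = π / 8
x = zeros(Float64, m + 1)
x[2] = 1.0             # base point: the second unit vector
# n random data points around x, obtained from Gaussian tangent vectors (parameter σ)
data = [exp(M, x, random_tangent(M, x, Val(:Gaussian), σ)) for i in 1:n];
# cost: half the mean squared Riemannian distance to the data points
F(M, y) = sum(1 / (2 * n) * distance.(Ref(M), Ref(y), data) .^ 2)
# Riemannian gradient of F
gradF(M, y) = sum(1 / n * grad_distance.(Ref(M), data, Ref(y)))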
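For readers who want to see the math behind F and gradF without library calls, here is a minimal stdlib-only sketch on the unit sphere. The helper names sphere_distance, sphere_log, center_cost and center_grad are hypothetical and not part of Manopt or Manifolds. We assume that grad_distance(M, d, y) is the gradient of ½d(y, d)² with respect to y, i.e. -log_y(d), which is consistent with the factors 1/(2n) in F and 1/n in gradF.

```julia
using LinearAlgebra

# Geodesic distance on the unit sphere: the angle between two unit vectors.
sphere_distance(x, y) = acos(clamp(dot(x, y), -1.0, 1.0))

# Logarithmic map log_y(x): the tangent vector at y pointing towards x,
# whose length equals the geodesic distance between y and x.
function sphere_log(y, x)
    θ = sphere_distance(y, x)
    θ < 1e-12 && return zero(y)
    return θ / sin(θ) * (x - cos(θ) * y)
end

# Center of mass cost F(y) = 1/(2n) Σ d(y, dᵢ)² and its Riemannian gradient
# grad F(y) = -1/n Σ log_y(dᵢ) (assumed to match gradF above).
center_cost(data, y) = sum(sphere_distance(y, d)^2 for d in data) / (2 * length(data))
center_grad(data, y) = -sum(sphere_log(y, d) for d in data) / length(data)
```

For two orthogonal points the geodesic midpoint normalize([1, 1, 0]) is the center of mass, so center_grad vanishes there up to rounding.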

Plain examples

For the high-level interfaces of the solvers, like gradient_descent, we have to set return_options to true to obtain the whole options structure and not only the resulting minimizer.

R = gradient_descent(M, F, gradF, data[1]; record=:Cost, return_options=true)
RecordOptions{GradientDescentOptions{Array{Float64,1},Array{Float64,1},StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}},ConstantStepsize,ExponentialRetraction},NamedTuple{(:Iteration,),Tuple{RecordCost}}}(GradientDescentOptions{Array{Float64,1},Array{Float64,1},StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}},ConstantStepsize,ExponentialRetraction}([-0.013171687776920208, -0.9988254858704124, 0.010022911212629302, 0.0013344707055430774, -0.0002677083094212463, 0.0072844866211142075, 0.0015113333205996316, -0.011281003233671358, -0.006712279079023681, 0.007454047607253493  …  -0.008560555645040935, -0.0009246495682439206, -0.002882588217415678, 0.005024753641084912, 0.0014807368499568049, -0.00038585441415452137, -0.000792496232822748, 0.0019746551544632665, -0.004531172402756976, -0.0026734658178187744], IdentityUpdateRule(), StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}}((StopAfterIteration(200, ""), StopWhenGradientNormLess(1.0e-8, "The algorithm reached approximately critical point after 20 iterations; the gradient norm (3.856776202540161e-9) is less than 1.0e-8.\n")), "The algorithm reached approximately critical point after 20 iterations; the gradient norm (3.856776202540161e-9) is less than 1.0e-8.\n"), ConstantStepsize(1.0), [-3.933077346880004e-10, -4.3171833343208076e-11, 6.56732853862585e-11, -1.5759354289753523e-10, 8.464979115222865e-10, -7.668974285907765e-10, -2.0538018128627572e-10, 1.900122719621566e-10, -1.4441114040805898e-9, 5.516133830632021e-10  …  3.412740687581419e-10, 7.649628380821533e-11, -3.992971282394328e-10, -2.58712254612816e-9, 2.6960121234011897e-10, -6.873708650878552e-10, -2.3091756615580342e-11, 2.5162697842246243e-10, -2.967947426452816e-11, 2.003055126278963e-10], ExponentialRetraction()), (Iteration = RecordCost([0.5467764136711742, 0.536496662070443, 0.5349809370338188, 0.5347588107780821, 0.5347262771824008, 0.5347215112273678, 0.5347208128443974, 0.5347207104761789, 
0.5347206954668047, 0.5347206932654798, 0.5347206929425345, 0.534720692895144, 0.5347206928881875, 0.5347206928871665, 0.5347206928870166, 0.5347206928869943, 0.5347206928869911, 0.5347206928869905, 0.5347206928869904, 0.5347206928869904]),))

You can attach different recorders to some operations (:Start, :Stop, :Iteration at the time of writing), where :Iteration is the default, so the following is the same as get_record(R, :Iteration). We get

get_record(R)
20-element Array{Float64,1}:
 0.5467764136711742
 0.536496662070443
 0.5349809370338188
 0.5347588107780821
 0.5347262771824008
 0.5347215112273678
 0.5347208128443974
 0.5347207104761789
 0.5347206954668047
 0.5347206932654798
 0.5347206929425345
 0.534720692895144
 0.5347206928881875
 0.5347206928871665
 0.5347206928870166
 0.5347206928869943
 0.5347206928869911
 0.5347206928869905
 0.5347206928869904
 0.5347206928869904
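The idea of attaching recorders to the hooks :Start, :Iteration and :Stop, with :Iteration as the default for access, can be sketched in a few lines of plain Julia. The functions toy_run and toy_get_record are hypothetical; this only mimics the behavior described above and is not Manopt's implementation.

```julia
# Hypothetical sketch (not Manopt's implementation): a toy loop that records
# the iterate at three hooks, stored per hook symbol.
function toy_run(x0, step, n_iter)
    records = Dict(:Start => Float64[], :Iteration => Float64[], :Stop => Float64[])
    x = x0
    push!(records[:Start], x)           # once, before the first iteration
    for _ in 1:n_iter
        x = step(x)
        push!(records[:Iteration], x)   # after every iteration
    end
    push!(records[:Stop], x)            # once, after the loop has finished
    return records
end

# Accessor using :Iteration as the default hook.
toy_get_record(records, hook::Symbol = :Iteration) = records[hook]

records = toy_run(8.0, x -> x / 2, 3)
toy_get_record(records)          # the iterates 4.0, 2.0, 1.0
toy_get_record(records, :Stop)   # only the final iterate 1.0
```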

To record more than one value, you can pass an array containing a mix of symbols and RecordActions, which is mapped to a RecordGroup

R = gradient_descent(M, F, gradF, data[1]; record=[:Iteration, :Cost], return_options=true)
RecordOptions{GradientDescentOptions{Array{Float64,1},Array{Float64,1},StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}},ConstantStepsize,ExponentialRetraction},NamedTuple{(:Iteration,),Tuple{RecordGroup}}}(GradientDescentOptions{Array{Float64,1},Array{Float64,1},StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}},ConstantStepsize,ExponentialRetraction}([-0.013171687776920208, -0.9988254858704124, 0.010022911212629302, 0.0013344707055430774, -0.0002677083094212463, 0.0072844866211142075, 0.0015113333205996316, -0.011281003233671358, -0.006712279079023681, 0.007454047607253493  …  -0.008560555645040935, -0.0009246495682439206, -0.002882588217415678, 0.005024753641084912, 0.0014807368499568049, -0.00038585441415452137, -0.000792496232822748, 0.0019746551544632665, -0.004531172402756976, -0.0026734658178187744], IdentityUpdateRule(), StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}}((StopAfterIteration(200, ""), StopWhenGradientNormLess(1.0e-8, "The algorithm reached approximately critical point after 20 iterations; the gradient norm (3.856776202540161e-9) is less than 1.0e-8.\n")), "The algorithm reached approximately critical point after 20 iterations; the gradient norm (3.856776202540161e-9) is less than 1.0e-8.\n"), ConstantStepsize(1.0), [-3.933077346880004e-10, -4.3171833343208076e-11, 6.56732853862585e-11, -1.5759354289753523e-10, 8.464979115222865e-10, -7.668974285907765e-10, -2.0538018128627572e-10, 1.900122719621566e-10, -1.4441114040805898e-9, 5.516133830632021e-10  …  3.412740687581419e-10, 7.649628380821533e-11, -3.992971282394328e-10, -2.58712254612816e-9, 2.6960121234011897e-10, -6.873708650878552e-10, -2.3091756615580342e-11, 2.5162697842246243e-10, -2.967947426452816e-11, 2.003055126278963e-10], ExponentialRetraction()), (Iteration = RecordGroup(RecordAction[RecordIteration([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]), RecordCost([0.5467764136711742, 0.536496662070443, 
0.5349809370338188, 0.5347588107780821, 0.5347262771824008, 0.5347215112273678, 0.5347208128443974, 0.5347207104761789, 0.5347206954668047, 0.5347206932654798, 0.5347206929425345, 0.534720692895144, 0.5347206928881875, 0.5347206928871665, 0.5347206928870166, 0.5347206928869943, 0.5347206928869911, 0.5347206928869905, 0.5347206928869904, 0.5347206928869904])], Dict(:Iteration => 1,:Cost => 2)),))

Here, the symbol :Cost is mapped to the RecordCost action. The same holds for :Iteration and :Iterate and for any member field of the current Options. To access these, you can first extract the group of records (of the :Iteration action) and then access the :Cost entry

ra = get_record_action(R)[:Cost]
20-element Array{Float64,1}:
 0.5467764136711742
 0.536496662070443
 0.5349809370338188
 0.5347588107780821
 0.5347262771824008
 0.5347215112273678
 0.5347208128443974
 0.5347207104761789
 0.5347206954668047
 0.5347206932654798
 0.5347206929425345
 0.534720692895144
 0.5347206928881875
 0.5347206928871665
 0.5347206928870166
 0.5347206928869943
 0.5347206928869911
 0.5347206928869905
 0.5347206928869904
 0.5347206928869904

Or similarly

get_record(R, :Iteration, :Cost)
20-element Array{Float64,1}:
 0.5467764136711742
 0.536496662070443
 0.5349809370338188
 0.5347588107780821
 0.5347262771824008
 0.5347215112273678
 0.5347208128443974
 0.5347207104761789
 0.5347206954668047
 0.5347206932654798
 0.5347206929425345
 0.534720692895144
 0.5347206928881875
 0.5347206928871665
 0.5347206928870166
 0.5347206928869943
 0.5347206928869911
 0.5347206928869905
 0.5347206928869904
 0.5347206928869904

Note that the first symbol again refers to the point where we record (not to the thing we record). We can also pass a Tuple as the second argument to get our own order (note that now the second :Iteration refers to the recorded iterations)

get_record(R, :Iteration, (:Cost, :Iteration))
20-element Array{Tuple{Float64,Int64},1}:
 (0.5467764136711742, 1)
 (0.536496662070443, 2)
 (0.5349809370338188, 3)
 (0.5347588107780821, 4)
 (0.5347262771824008, 5)
 (0.5347215112273678, 6)
 (0.5347208128443974, 7)
 (0.5347207104761789, 8)
 (0.5347206954668047, 9)
 (0.5347206932654798, 10)
 (0.5347206929425345, 11)
 (0.534720692895144, 12)
 (0.5347206928881875, 13)
 (0.5347206928871665, 14)
 (0.5347206928870166, 15)
 (0.5347206928869943, 16)
 (0.5347206928869911, 17)
 (0.5347206928869905, 18)
 (0.5347206928869904, 19)
 (0.5347206928869904, 20)
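The way a group maps symbols to its members (compare the Dict(:Iteration => 1, :Cost => 2) in the output further up) and returns tuples in a requested order can be sketched as follows. ToyRecordGroup and record! are hypothetical names; this is an illustration of the lookup idea, not Manopt's RecordGroup.

```julia
# Hypothetical sketch (not Manopt's RecordGroup): several recorded series plus
# a Dict mapping each symbol to the position of its series.
struct ToyRecordGroup
    series::Vector{Vector{Any}}
    indices::Dict{Symbol,Int}
end
ToyRecordGroup(names::Vector{Symbol}) =
    ToyRecordGroup([Any[] for _ in names], Dict(s => k for (k, s) in enumerate(names)))

# Store one value per series for the current iteration.
function record!(g::ToyRecordGroup, values...)
    foreach((s, v) -> push!(s, v), g.series, values)
    return g
end

# Access a single series by its symbol.
Base.getindex(g::ToyRecordGroup, s::Symbol) = g.series[g.indices[s]]
# Access several series as tuples, in the order given by the tuple.
Base.getindex(g::ToyRecordGroup, t::Tuple) = collect(zip((g[s] for s in t)...))

g = ToyRecordGroup([:Iteration, :Cost])
record!(g, 1, 0.5)
record!(g, 2, 0.25)
g[:Cost]                # the two recorded costs
g[(:Cost, :Iteration)]  # (cost, iteration) tuples in our own order
```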

A more complex example

To illustrate a more complicated setup, let's record

  • the iteration number, cost and gradient field, but only every sixth iteration
  • the iteration at which we stop

We first generate the problem and the options

p = GradientProblem(M, F, gradF)
o = GradientDescentOptions(
    M,
    copy(data[1]);
    stopping_criterion=StopAfterIteration(200) | StopWhenGradientNormLess(10.0^-9),
)
GradientDescentOptions{Array{Float64,1},Array{Float64,1},StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}},ConstantStepsize,ExponentialRetraction}([-0.03955888268160963, -0.8779652649012455, 0.0019319836442624376, -0.021306984435043162, 0.12648704877303707, -0.0814547413419126, -0.033339260945549724, 0.011108925098935594, -0.1879661160602998, 0.07138111553205544  …  0.035720759576640955, -0.015385100827104183, -0.05025900025037785, -0.30859191198481517, 0.04054081850698768, -0.10117327773801128, -0.02649469510827848, 0.026253469357567893, -0.000541630267770923, 0.0400314274185503], IdentityUpdateRule(), StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}}((StopAfterIteration(200, ""), StopWhenGradientNormLess(1.0e-9, "")), #undef), ConstantStepsize(1.0), [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ExponentialRetraction())

and now decorate these with RecordOptions

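# every sixth iteration: record the iteration number, the cost, and the :gradient field of the options
rI = RecordEvery(
    RecordGroup([
        :Iteration => RecordIteration(),
        :Cost => RecordCost(),
        :Gradient => RecordEntry(similar(data[1]), :gradient),
    ]),
    6,
)
# when stopping: record the final iteration number
sI = RecordIteration()
r = RecordOptions(o, Dict(:Iteration => rI, :Stop => sI))
r2 = solve(p, r)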
RecordOptions{GradientDescentOptions{Array{Float64,1},Array{Float64,1},StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}},ConstantStepsize,ExponentialRetraction},NamedTuple{(:Iteration, :Stop),Tuple{RecordEvery,RecordIteration}}}(GradientDescentOptions{Array{Float64,1},Array{Float64,1},StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}},ConstantStepsize,ExponentialRetraction}([-0.013171687562367885, -0.9988254858473309, 0.010022911173862174, 0.0013344707882379395, -0.0002677087531867079, 0.007284487030165048, 0.0015113334265440613, -0.011281003334715156, -0.006712278313387459, 0.007454047312965663  …  -0.008560555825942161, -0.0009246496158967537, -0.0028825880049497443, 0.005024755015904156, 0.0014807367083181707, -0.00038585405334631124, -0.0007924962268515648, 0.0019746550194331483, -0.004531172384967613, -0.002673465918907319], IdentityUpdateRule(), StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}}((StopAfterIteration(200, ""), StopWhenGradientNormLess(1.0e-9, "The algorithm reached approximately critical point after 22 iterations; the gradient norm (5.673508914769152e-10) is less than 1.0e-9.\n")), "The algorithm reached approximately critical point after 22 iterations; the gradient norm (5.673508914769152e-10) is less than 1.0e-9.\n"), ConstantStepsize(1.0), [-6.039159265985167e-11, -6.425622468326673e-12, 1.1355642229489427e-11, -2.277929421153936e-11, 1.2218354831058698e-10, -1.1372630755960428e-10, -2.8890151759059478e-11, 2.8040223720954098e-11, -2.12150862307324e-10, 8.18362990356392e-11  …  5.01215479032811e-11, 1.431754205211827e-11, -5.899276190159466e-11, -3.81440174964439e-10, 3.904555135160817e-11, -9.941452641572951e-11, -6.763190268238573e-13, 3.766488266025367e-11, -5.250628165760773e-12, 2.720983600110541e-11], ExponentialRetraction()), (Iteration = RecordEvery(RecordGroup(RecordAction[RecordIteration([6, 12, 18]), RecordCost([0.5347215112273678, 0.534720692895144, 0.5347206928869905]), 
RecordEntry{Array{Float64,1}}([[-0.00017957831321735007, -1.5369053303947344e-5, -1.5991196719197115e-5, -0.00011959238883483006, 0.0006571819032620557, -0.00048628549082952734, -0.00017756397389585285, 0.00012025863611221717, -0.000986519270927751, 0.0003550879627678967  …  0.00023469206857972487, -4.7254727233998e-5, -0.00026201382297258506, -0.001714290190688771, 0.00020493319322553813, -0.0005259803510424072, -0.00010926439890075726, 0.0001438395281741208, 9.27241057713978e-6, 0.0002088961695966906], [-6.893087484227765e-7, -8.768942524391848e-8, 3.543936780977608e-8, -3.6104561563268236e-7, 1.959731277980112e-6, -1.5856713767760247e-6, -5.126309279540581e-7, 3.9571557224334455e-7, -3.1101487401679576e-6, 1.1456996152070113e-6  …  7.368801164374264e-7, -1.5206183151951318e-8, -8.405001014450092e-7, -5.481754520715056e-6, 6.161943729087686e-7, -1.578222204929279e-6, -2.1470153043176327e-7, 4.951592679875664e-7, -1.1018928325551638e-8, 5.614941172160529e-7], [-2.5555321544893164e-9, -2.900042552732567e-10, 3.6615709818549683e-10, -1.0902745839393851e-9, 5.867557496557995e-9, -5.171426792002776e-9, -1.4563943645501789e-9, 1.2860835777822556e-9, -9.833099704083638e-9, 3.7202802075276707e-9  …  2.324737582470903e-9, 3.7827086753726424e-10, -2.703431187082734e-9, -1.754936396664063e-8, 1.862537259730672e-9, -4.754830305636788e-9, -2.849126727187852e-10, 1.679187907536843e-9, -1.6056917964300463e-10, 1.4667528849474087e-9]], :gradient)], Dict(:Iteration => 1,:Gradient => 3,:Cost => 2)), 6, true), Stop = RecordIteration([22])))

and we see

get_record(r2, :Stop)
1-element Array{Int64,1}:
 22

as well as

get_record(r2, :Iteration, (:Iteration, :Cost))
3-element Array{Tuple{Int64,Float64},1}:
 (6, 0.5347215112273678)
 (12, 0.534720692895144)
 (18, 0.5347206928869905)

Note that a meta-record like RecordEvery just passes such an access on to the record it wraps, so we can again also do

get_record_action(r2, :Iteration)[:Gradient]
3-element Array{Array{Float64,1},1}:
 [-0.00017957831321735007, -1.5369053303947344e-5, -1.5991196719197115e-5, -0.00011959238883483006, 0.0006571819032620557, -0.00048628549082952734, -0.00017756397389585285, 0.00012025863611221717, -0.000986519270927751, 0.0003550879627678967  …  0.00023469206857972487, -4.7254727233998e-5, -0.00026201382297258506, -0.001714290190688771, 0.00020493319322553813, -0.0005259803510424072, -0.00010926439890075726, 0.0001438395281741208, 9.27241057713978e-6, 0.0002088961695966906]
 [-6.893087484227765e-7, -8.768942524391848e-8, 3.543936780977608e-8, -3.6104561563268236e-7, 1.959731277980112e-6, -1.5856713767760247e-6, -5.126309279540581e-7, 3.9571557224334455e-7, -3.1101487401679576e-6, 1.1456996152070113e-6  …  7.368801164374264e-7, -1.5206183151951318e-8, -8.405001014450092e-7, -5.481754520715056e-6, 6.161943729087686e-7, -1.578222204929279e-6, -2.1470153043176327e-7, 4.951592679875664e-7, -1.1018928325551638e-8, 5.614941172160529e-7]
 [-2.5555321544893164e-9, -2.900042552732567e-10, 3.6615709818549683e-10, -1.0902745839393851e-9, 5.867557496557995e-9, -5.171426792002776e-9, -1.4563943645501789e-9, 1.2860835777822556e-9, -9.833099704083638e-9, 3.7202802075276707e-9  …  2.324737582470903e-9, 3.7827086753726424e-10, -2.703431187082734e-9, -1.754936396664063e-8, 1.862537259730672e-9, -4.754830305636788e-9, -2.849126727187852e-10, 1.679187907536843e-9, -1.6056917964300463e-10, 1.4667528849474087e-9]
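The forwarding behavior of RecordEvery can be illustrated with a small wrapper that only passes every k-th iteration to an inner recorder and forwards indexing to it. ToyRecordEvery and ToyPairRecorder are hypothetical names; this sketch is not Manopt's RecordEvery.

```julia
# Hypothetical sketch (not Manopt's RecordEvery): wrap an inner recorder and
# only forward iterations that are multiples of `every`; indexing is passed
# through to the inner recorder.
struct ToyRecordEvery{R}
    inner::R
    every::Int
end
function (r::ToyRecordEvery)(x, i::Int)
    i > 0 && i % r.every == 0 && r.inner(x, i)
    return nothing
end
Base.getindex(r::ToyRecordEvery, s::Symbol) = r.inner[s]

# A minimal inner recorder storing iteration numbers and iterates by symbol.
struct ToyPairRecorder
    data::Dict{Symbol,Vector{Any}}
end
ToyPairRecorder() = ToyPairRecorder(Dict(:Iteration => Any[], :Iterate => Any[]))
function (r::ToyPairRecorder)(x, i::Int)
    push!(r.data[:Iteration], i)
    push!(r.data[:Iterate], x)
    return nothing
end
Base.getindex(r::ToyPairRecorder, s::Symbol) = r.data[s]

rec = ToyRecordEvery(ToyPairRecorder(), 6)
for i in 1:20
    rec(2.0^-i, i)   # pretend the i-th iterate is 2^-i
end
rec[:Iteration]      # iterations 6, 12 and 18, like in the output above
```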

Writing your own RecordActions

Let's consider the case where we want to count the number of cost function evaluations. This is again just an illustration, since for gradient descent this is just one evaluation per iteration. We first define a cost that counts its own calls.

# a cost function object that counts how often it is called
mutable struct MyCost{T}
    data::T
    count::Int
end
MyCost(data::T) where {T} = MyCost{T}(data, 0)
function (c::MyCost)(M, x)
    c.count += 1   # count this evaluation
    return sum(1 / (2 * length(c.data)) * distance.(Ref(M), Ref(x), c.data) .^ 2)
end

and we define the following RecordAction,

# records the evaluation counter of the cost stored in the problem
mutable struct RecordCount <: RecordAction
    recorded_values::Vector{Int}
    RecordCount() = new(Vector{Int}())
end
function (r::RecordCount)(p::Problem, ::Options, i)
    if i > 0 # record during the iterations
        push!(r.recorded_values, p.cost.count)
    elseif i < 0 # reset if negative
        r.recorded_values = Vector{Int}()
    end
end
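The interplay of a self-counting cost and a counting record can be tried out without Manopt. The sketch below mirrors MyCost and RecordCount with the hypothetical types CountingCost and ToyRecordCount and calls them in the order count, then cost. We assume, as a plausible reading of the output below rather than a documented fact, that a RecordGroup calls its actions in the order given and that this constant-step gradient descent does not evaluate the cost itself; under these assumptions each count is recorded just before RecordCost evaluates F2, which is why the recorded counts start at 0.

```julia
# Hypothetical stdlib-only sketch (not Manopt code) mirroring MyCost and RecordCount.
mutable struct CountingCost{F}
    f::F
    count::Int
end
CountingCost(f) = CountingCost(f, 0)
function (c::CountingCost)(x)
    c.count += 1          # count this evaluation
    return c.f(x)
end

mutable struct ToyRecordCount
    recorded_values::Vector{Int}
end
ToyRecordCount() = ToyRecordCount(Int[])
function (r::ToyRecordCount)(c::CountingCost, i::Int)
    if i > 0              # record during the iterations
        push!(r.recorded_values, c.count)
    elseif i < 0          # negative iteration number: reset
        r.recorded_values = Int[]
    end
    return r
end

cost = CountingCost(x -> x^2 / 2)
rc = ToyRecordCount()
costs = Float64[]
for i in 1:5
    rc(cost, i)                  # record the counter first ...
    push!(costs, cost(1.0 / i))  # ... then evaluate (and count) the cost
end
```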

Now we can initialize the new cost and call gradient descent. This also illustrates the last use case: you can pass symbol-action pairs into the record= array.

F2 = MyCost(data)
R = gradient_descent(
    M,
    F2,
    gradF,
    data[1];
    record=[:Iteration, :Count => RecordCount(), :Cost],
    return_options=true,
)
RecordOptions{GradientDescentOptions{Array{Float64,1},Array{Float64,1},StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}},ConstantStepsize,ExponentialRetraction},NamedTuple{(:Iteration,),Tuple{RecordGroup}}}(GradientDescentOptions{Array{Float64,1},Array{Float64,1},StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}},ConstantStepsize,ExponentialRetraction}([-0.013171687776920208, -0.9988254858704124, 0.010022911212629302, 0.0013344707055430774, -0.0002677083094212463, 0.0072844866211142075, 0.0015113333205996316, -0.011281003233671358, -0.006712279079023681, 0.007454047607253493  …  -0.008560555645040935, -0.0009246495682439206, -0.002882588217415678, 0.005024753641084912, 0.0014807368499568049, -0.00038585441415452137, -0.000792496232822748, 0.0019746551544632665, -0.004531172402756976, -0.0026734658178187744], IdentityUpdateRule(), StopWhenAny{Tuple{StopAfterIteration,StopWhenGradientNormLess}}((StopAfterIteration(200, ""), StopWhenGradientNormLess(1.0e-8, "The algorithm reached approximately critical point after 20 iterations; the gradient norm (3.856776202540161e-9) is less than 1.0e-8.\n")), "The algorithm reached approximately critical point after 20 iterations; the gradient norm (3.856776202540161e-9) is less than 1.0e-8.\n"), ConstantStepsize(1.0), [-3.933077346880004e-10, -4.3171833343208076e-11, 6.56732853862585e-11, -1.5759354289753523e-10, 8.464979115222865e-10, -7.668974285907765e-10, -2.0538018128627572e-10, 1.900122719621566e-10, -1.4441114040805898e-9, 5.516133830632021e-10  …  3.412740687581419e-10, 7.649628380821533e-11, -3.992971282394328e-10, -2.58712254612816e-9, 2.6960121234011897e-10, -6.873708650878552e-10, -2.3091756615580342e-11, 2.5162697842246243e-10, -2.967947426452816e-11, 2.003055126278963e-10], ExponentialRetraction()), (Iteration = RecordGroup(RecordAction[RecordIteration([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]), Main.ex-HowToRecord.RecordCount([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 
11, 12, 13, 14, 15, 16, 17, 18, 19]), RecordCost([0.5467764136711742, 0.536496662070443, 0.5349809370338188, 0.5347588107780821, 0.5347262771824008, 0.5347215112273678, 0.5347208128443974, 0.5347207104761789, 0.5347206954668047, 0.5347206932654798, 0.5347206929425345, 0.534720692895144, 0.5347206928881875, 0.5347206928871665, 0.5347206928870166, 0.5347206928869943, 0.5347206928869911, 0.5347206928869905, 0.5347206928869904, 0.5347206928869904])], Dict(:Iteration => 1,:Count => 2,:Cost => 3)),))

We can again access the whole set of records

get_record(R)
20-element Array{Tuple{Int64,Int64,Float64},1}:
 (1, 0, 0.5467764136711742)
 (2, 1, 0.536496662070443)
 (3, 2, 0.5349809370338188)
 (4, 3, 0.5347588107780821)
 (5, 4, 0.5347262771824008)
 (6, 5, 0.5347215112273678)
 (7, 6, 0.5347208128443974)
 (8, 7, 0.5347207104761789)
 (9, 8, 0.5347206954668047)
 (10, 9, 0.5347206932654798)
 (11, 10, 0.5347206929425345)
 (12, 11, 0.534720692895144)
 (13, 12, 0.5347206928881875)
 (14, 13, 0.5347206928871665)
 (15, 14, 0.5347206928870166)
 (16, 15, 0.5347206928869943)
 (17, 16, 0.5347206928869911)
 (18, 17, 0.5347206928869905)
 (19, 18, 0.5347206928869904)
 (20, 19, 0.5347206928869904)

This is equivalent to calling R[:Iteration]. Since we introduced :Count, we can also access a single recorded value using

R[:Iteration, :Count]
20-element Array{Int64,1}:
  0
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
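The two-level indexing used above, R[:Iteration] for all values of a hook and R[:Iteration, :Count] for a single series, can be mimicked by overloading Base.getindex with one or two symbols. ToyRecords is a hypothetical type; this sketch only illustrates the indexing idea and is not how Manopt stores its records.

```julia
# Hypothetical sketch (not Manopt's implementation): the first symbol picks the
# hook, the optional second symbol picks one recorded series of that hook.
struct ToyRecords
    hooks::Dict{Symbol,Dict{Symbol,Vector{Any}}}
    order::Dict{Symbol,Vector{Symbol}}  # order of the series per hook
end

# R[hook]: all series of that hook, combined into one tuple per iteration.
function Base.getindex(r::ToyRecords, hook::Symbol)
    cols = r.order[hook]
    return collect(zip((r.hooks[hook][c] for c in cols)...))
end
# R[hook, name]: a single series.
Base.getindex(r::ToyRecords, hook::Symbol, name::Symbol) = r.hooks[hook][name]

toy_rec = ToyRecords(
    Dict(:Iteration => Dict(:Iteration => Any[1, 2, 3], :Count => Any[0, 1, 2])),
    Dict(:Iteration => [:Iteration, :Count]),
)
toy_rec[:Iteration]           # tuples (iteration, count)
toy_rec[:Iteration, :Count]   # only the counts
```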