How to Record Data During the Iterations
The recording and debugging features make it possible to record nearly any data during the iterations. This tutorial illustrates how to:
record one value during the iterations;
record multiple values during the iterations and access them afterwards;
define your own RecordAction to perform individual recordings.
Several predefined recordings exist, depending on the solver used, for example RecordCost() or RecordGradient(). Fields of the Options can be recorded directly using RecordEntry(:field). For other recordings, for example when a more advanced computation is needed before a value is stored, you can define your own RecordAction.
We illustrate these using the gradient descent from the mean computation tutorial.
using Manopt, Manifolds, Random
begin
Random.seed!(42)
m = 30
M = Sphere(m)
n = 800
σ = π / 8
x = zeros(Float64, m + 1)
x[2] = 1.0
data = [exp(M, x, random_tangent(M, x, Val(:Gaussian), σ)) for i in 1:n]
end
800-element Vector{Vector{Float64}}: [-0.054658825167894595, -0.5592077846510423, -0.04738273828111257, … 0.1972125856685378] [-0.08192376929745251, -0.5097715132187672, -0.008339904915541008, … -0.026214242996704024] ⋮ [0.12108727495744157, -0.714787344982596, 0.01632521838262752, … 0.15296635978447756]
F(M, y) = sum(1 / (2 * n) * distance.(Ref(M), Ref(y), data) .^ 2)
F (generic function with 1 method)
gradF(M, y) = sum(1 / n * grad_distance.(Ref(M), data, Ref(y)))
gradF (generic function with 1 method)
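The two definitions above implement the Riemannian center of mass (Karcher mean) objective and its gradient. In formulas, with data points $x_1,\dots,x_n$ on $\mathcal M$, and *assuming* that `grad_distance(M, x, y)` returns the Riemannian gradient of $\tfrac12 d_{\mathcal M}^2(x,\cdot)$ at $y$, which is $-\log_y x$, the code computes:

```math
F(y) = \frac{1}{2n}\sum_{i=1}^{n} d_{\mathcal M}^2(y, x_i),
\qquad
\operatorname{grad} F(y) = -\frac{1}{n}\sum_{i=1}^{n} \log_y x_i .
```

With the `ConstantStepsize(1.0)` and `ExponentialRetraction()` visible in the outputs below, one gradient step $y_{k+1} = \exp_{y_k}\bigl(-\operatorname{grad} F(y_k)\bigr)$ becomes the classical fixed-point update $y_{k+1} = \exp_{y_k}\bigl(\tfrac{1}{n}\sum_{i=1}^{n}\log_{y_k} x_i\bigr)$.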
Plain Examples
For the high-level interfaces of the solvers, like gradient_descent, we have to set return_options to true to obtain the whole options structure and not only the resulting minimizer.
We can then use the record= keyword to add recorded values. This keyword accepts RecordActions as well as several symbols as shortcuts. For example, :Cost records the cost, and if your options have a field f, :f records that entry.
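To make the idea of a per-iteration recording concrete, here is a minimal plain-Julia sketch. It does not use Manopt: `ToyState`, `record_value`, and `toy_gradient_descent` are hypothetical names chosen for illustration, and the solver is an ordinary Euclidean gradient descent. The sketch mirrors the shortcut rule described above. `:Cost` evaluates the cost, and any other symbol is read as a field of the state.

```julia
# Hypothetical stand-in for a solver state (NOT Manopt's Options type).
mutable struct ToyState
    x::Vector{Float64}
    stepsize::Float64
end

# :Cost evaluates the cost; any other symbol reads that field of the state.
# deepcopy matters because the iterate is updated in place below.
record_value(sym::Symbol, f, s::ToyState) =
    sym === :Cost ? f(s.x) : deepcopy(getfield(s, sym))

# Plain Euclidean gradient descent that records one value after every iteration.
function toy_gradient_descent(f, gradf, x0; record::Symbol=:Cost, maxiter=20)
    s = ToyState(copy(x0), 0.5)
    recorded = Any[]
    for _ in 1:maxiter
        s.x .-= s.stepsize .* gradf(s.x)             # gradient step
        push!(recorded, record_value(record, f, s))  # record after the step
    end
    return s, recorded
end
```

Without the copy, every recorded iterate would alias the same array and show only the final value. The sketch makes that explicit. It does not claim to describe how Manopt handles this internally.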
R = gradient_descent(M, F, gradF, data[1]; record=:Cost, return_options=true)
RecordOptions{GradientDescentOptions{Vector{Float64}, Vector{Float64}, StopWhenAny{Tuple{StopAfterIteration, StopWhenGradientNormLess}}, ConstantStepsize, ExponentialRetraction}, NamedTuple{(:Iteration,), Tuple{RecordCost}}}(GradientDescentOptions{…}([0.003348563697666398, -0.9989177042237649, 0.012603956176158468, … -0.0025375708198013104], IdentityUpdateRule(), StopWhenAny{…}((StopAfterIteration(200, ""), StopWhenGradientNormLess(1.0e-8, …)), "The algorithm reached approximately critical point after 20 iterations; the gradient norm (7.386440725063072e-9) is less than 1.0e-8.\n"), ConstantStepsize(1.0), [-6.129711326192944e-10, 5.1812992794094857e-11, … 1.9981575853015958e-9], ExponentialRetraction()), (Iteration = RecordCost([0.5808287253777765, 0.5395268557323746, 0.5333529073733115, … 0.5322977905736683, 0.5322977905736684]),))
From the returned options we see that the Options are encapsulated (decorated) with RecordOptions.
You can attach different recorders to some operations (:Start, :Stop, and :Iteration at the time of writing). :Iteration is the default, so the following call is the same as get_record(R, :Iteration). We get
get_record(R)
20-element Vector{Float64}: 0.5808287253777765 0.5395268557323746 0.5333529073733115 0.5324514620174543 0.5323201743667151 0.5323010518577256 0.5322982658416161 ⋮ 0.5322977905737657 0.5322977905736823 0.5322977905736703 0.5322977905736688 0.5322977905736683 0.5322977905736684
To record more than one value, you can pass an array that mixes symbols and RecordActions. This introduces a RecordGroup, which records a tuple of values in every iteration.
R2 = gradient_descent(M, F, gradF, data[1]; record=[:Iteration, :Cost], return_options=true)
RecordOptions{GradientDescentOptions{Vector{Float64}, Vector{Float64}, StopWhenAny{Tuple{StopAfterIteration, StopWhenGradientNormLess}}, ConstantStepsize, ExponentialRetraction}, NamedTuple{(:Iteration,), Tuple{RecordGroup}}}(GradientDescentOptions{…}([0.003348563697666398, -0.9989177042237649, 0.012603956176158468, … -0.0025375708198013104], IdentityUpdateRule(), StopWhenAny{…}((StopAfterIteration(200, ""), StopWhenGradientNormLess(1.0e-8, …)), "The algorithm reached approximately critical point after 20 iterations; the gradient norm (7.386440725063072e-9) is less than 1.0e-8.\n"), ConstantStepsize(1.0), [-6.129711326192944e-10, 5.1812992794094857e-11, … 1.9981575853015958e-9], ExponentialRetraction()), (Iteration = RecordGroup(Manopt.RecordAction[RecordIteration([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]), RecordCost([0.5808287253777765, 0.5395268557323746, 0.5333529073733115, … 0.5322977905736683, 0.5322977905736684])], Dict(:Iteration => 1, :Cost => 2)),))
Here, the symbol :Cost is mapped to the RecordCost action. The same holds for :Iteration, :Iterate, and any member field of the current Options. To access these, you can first extract the group of records, which is stored under :Iteration because its values are recorded during the iterations, and then access the :Cost entry:
get_record_action(R2, :Iteration)
RecordGroup(Manopt.RecordAction[RecordIteration([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]), RecordCost([0.5808287253777765, 0.5395268557323746, 0.5333529073733115, 0.5324514620174543, 0.5323201743667151, 0.5323010518577256, 0.5322982658416161, 0.532297859847447, 0.5322978006725337, 0.5322977920461375, 0.5322977907883957, 0.5322977906049865, 0.5322977905782369, 0.532297790574335, 0.5322977905737657, 0.5322977905736823, 0.5322977905736703, 0.5322977905736688, 0.5322977905736683, 0.5322977905736684])], Dict(:Iteration => 1, :Cost => 2))
:Iteration is the default here as well, meaning values recorded during the iterations. We can access the recorded data by indexing with the same symbols we used in the record= keyword:
get_record_action(R2)[:Cost]
20-element Vector{Float64}: 0.5808287253777765 0.5395268557323746 0.5333529073733115 0.5324514620174543 0.5323201743667151 0.5323010518577256 0.5322982658416161 ⋮ 0.5322977905737657 0.5322977905736823 0.5322977905736703 0.5322977905736688 0.5322977905736683 0.5322977905736684
This can also be done using the high-level interface get_record:
get_record(R2, :Iteration, :Cost)
20-element Vector{Float64}: 0.5808287253777765 0.5395268557323746 0.5333529073733115 0.5324514620174543 0.5323201743667151 0.5323010518577256 0.5322982658416161 ⋮ 0.5322977905737657 0.5322977905736823 0.5322977905736703 0.5322977905736688 0.5322977905736683 0.5322977905736684
Note that the first symbol again refers to the point at which we record, not to the quantity we record. We can also pass a tuple as the second argument to choose our own order. In that case, the second :Iteration refers to the recorded iteration numbers.
get_record(R2, :Iteration, (:Iteration, :Cost))
20-element Vector{Tuple{Int64, Float64}}: (1, 0.5808287253777765) (2, 0.5395268557323746) (3, 0.5333529073733115) (4, 0.5324514620174543) (5, 0.5323201743667151) (6, 0.5323010518577256) (7, 0.5322982658416161) ⋮ (15, 0.5322977905737657) (16, 0.5322977905736823) (17, 0.5322977905736703) (18, 0.5322977905736688) (19, 0.5322977905736683) (20, 0.5322977905736684)
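The access pattern above has three parts: a group of recorders, a symbol-to-position dictionary such as `Dict(:Iteration => 1, :Cost => 2)`, and indexing by a symbol or a tuple of symbols. The following plain-Julia sketch reproduces that pattern. `ToyRecordGroup` and `record!` are hypothetical names, and this is not Manopt's RecordGroup implementation.

```julia
# Hypothetical, simplified record group: one history per recorder, plus a Dict
# mapping each symbol to the position of its recorder (as in the output above).
struct ToyRecordGroup
    names::Vector{Symbol}
    values::Vector{Vector{Any}}
    index::Dict{Symbol,Int}
end
ToyRecordGroup(names::Vector{Symbol}) =
    ToyRecordGroup(names, [Any[] for _ in names], Dict(n => k for (k, n) in enumerate(names)))

# Store one tuple of values (one per recorder) for the current iteration.
function record!(g::ToyRecordGroup, vals::Tuple)
    for (k, v) in enumerate(vals)
        push!(g.values[k], v)
    end
    return g
end

# A single history by symbol, analogous to `group[:Cost]`.
Base.getindex(g::ToyRecordGroup, s::Symbol) = g.values[g.index[s]]

# Several histories zipped into tuples, in the order the caller asks for.
Base.getindex(g::ToyRecordGroup, s::NTuple{N,Symbol}) where {N} =
    collect(zip((g[t] for t in s)...))
```

Positions are assigned in the order the symbols are given, which matches the `Dict(:Iteration => 1, :Cost => 2)` shown for `record=[:Iteration, :Cost]`.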
A More Complex Example
To illustrate a more complex example, let's record:
the iteration number, the cost, and the gradient, but only every sixth iteration;
the iteration at which we stop.
We first generate the problem and the options. This also illustrates how the low-level interface works when gradient_descent is not used.
p = GradientProblem(M, F, gradF)
GradientProblem{AllocatingEvaluation, Sphere{30, ℝ}, typeof(F), typeof(gradF)}(Sphere(30, ℝ), Main.var"workspace#3".F, Main.var"workspace#3".gradF)
o = GradientDescentOptions(
M,
copy(data[1]);
stopping_criterion=StopAfterIteration(200) | StopWhenGradientNormLess(10.0^-9),
)
GradientDescentOptions{Vector{Float64}, Vector{Float64}, StopWhenAny{Tuple{StopAfterIteration, StopWhenGradientNormLess}}, ConstantStepsize, ExponentialRetraction}([-0.054658825167894595, -0.5592077846510423, -0.04738273828111257, -0.04682080720921302, 0.12279468849667038, 0.07171438895366239, -0.12930045409417057, -0.22102081626380404, -0.31805333254577767, 0.0065859500152017645 … -0.21999168261518043, 0.19570142227077295, 0.340909965798364, -0.0310802190082894, -0.04674431076254687, -0.006088297671169996, 0.01576037011323387, -0.14523596850249543, 0.14526158060820338, 0.1972125856685378], IdentityUpdateRule(), StopWhenAny{Tuple{StopAfterIteration, StopWhenGradientNormLess}}((StopAfterIteration(200, ""), StopWhenGradientNormLess(1.0e-9, "")), #undef), ConstantStepsize(1.0), [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ExponentialRetraction())
We first build a RecordGroup to group the three entries we want to record per iteration. We then wrap it in a RecordEvery so that it only records every 6th iteration. This is the action referred to as rI below:
RecordEvery(RecordGroup(Manopt.RecordAction[RecordIteration(Int64[]), RecordCost(Float64[]), RecordEntry{Vector{Float64}}(Vector{Float64}[], :gradient)], Dict(:Iteration => 1, :Gradient => 3, :Cost => 2)), 6, true)
and a small action, referred to as sI below, to record the iteration number:
RecordIteration(Int64[])
We now combine both into the RecordOptions decorator. It behaves exactly like the Options it wraps, but additionally records values. The recorders are stored in a dictionary of RecordActions. :Iteration maps to the action executed in every iteration, here the every-6th-iteration group rI. :Stop maps to sI, which is executed when the solver stops.
Note that the record= keyword of the high-level interface gradient_descent only fills the :Iteration entry.
r = RecordOptions(o, Dict(:Iteration => rI, :Stop => sI))
RecordOptions{GradientDescentOptions{Vector{Float64}, Vector{Float64}, StopWhenAny{Tuple{StopAfterIteration, StopWhenGradientNormLess}}, ConstantStepsize, ExponentialRetraction}, NamedTuple{(:Iteration, :Stop), Tuple{RecordEvery, RecordIteration}}}(GradientDescentOptions{Vector{Float64}, Vector{Float64}, StopWhenAny{Tuple{StopAfterIteration, StopWhenGradientNormLess}}, ConstantStepsize, ExponentialRetraction}([-0.054658825167894595, -0.5592077846510423, -0.04738273828111257, -0.04682080720921302, 0.12279468849667038, 0.07171438895366239, -0.12930045409417057, -0.22102081626380404, -0.31805333254577767, 0.0065859500152017645 … -0.21999168261518043, 0.19570142227077295, 0.340909965798364, -0.0310802190082894, -0.04674431076254687, -0.006088297671169996, 0.01576037011323387, -0.14523596850249543, 0.14526158060820338, 0.1972125856685378], IdentityUpdateRule(), StopWhenAny{Tuple{StopAfterIteration, StopWhenGradientNormLess}}((StopAfterIteration(200, ""), StopWhenGradientNormLess(1.0e-9, "")), #undef), ConstantStepsize(1.0), [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], ExponentialRetraction()), (Iteration = RecordEvery(RecordGroup(Manopt.RecordAction[RecordIteration(Int64[]), RecordCost(Float64[]), RecordEntry{Vector{Float64}}(Vector{Float64}[], :gradient)], Dict(:Iteration => 1, :Gradient => 3, :Cost => 2)), 6, true), Stop = RecordIteration(Int64[])))
We now call the solver:
res = solve(p, r)
RecordOptions{GradientDescentOptions{Vector{Float64}, Vector{Float64}, StopWhenAny{Tuple{StopAfterIteration, StopWhenGradientNormLess}}, ConstantStepsize, ExponentialRetraction}, NamedTuple{(:Iteration, :Stop), Tuple{RecordEvery, RecordIteration}}}(GradientDescentOptions{…}([0.003348564059848561, -0.9989177042539492, 0.01260395649346607, … -0.0025375719951008105], IdentityUpdateRule(), StopWhenAny{…}((StopAfterIteration(200, ""), StopWhenGradientNormLess(1.0e-9, …)), "The algorithm reached approximately critical point after 23 iterations; the gradient norm (4.1219635493336396e-10) is less than 1.0e-9.\n"), ConstantStepsize(1.0), [-3.504430570061036e-11, 2.8771833539999853e-12, … 1.1318272976289755e-10], ExponentialRetraction()), (Iteration = RecordEvery(RecordGroup(Manopt.RecordAction[RecordIteration([6, 12, 18]), RecordCost([0.5323010518577256, 0.5322977906049865, 0.5322977905736688]), RecordEntry{Vector{Float64}}([[-0.00038158718978169347, 8.243775455566567e-5, … 0.0013124151032857276], [-1.2578098855403912e-6, 1.161830737771279e-7, … 4.218349473223207e-6], [-4.128322534329267e-9, 3.56089609426109e-10, … 1.3546235655870383e-8]], :gradient)], Dict(:Iteration => 1, :Gradient => 3, :Cost => 2)), 6, true), Stop = RecordIteration([23])))
We can check the value recorded at :Stop to see how many iterations were performed:
get_record(res, :Stop)
1-element Vector{Int64}: 23
The values recorded during the iterations are:
get_record(res, :Iteration, (:Iteration, :Cost))
3-element Vector{Tuple{Int64, Float64}}: (6, 0.5323010518577256) (12, 0.5322977906049865) (18, 0.5322977905736688)
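The recorded numbers follow from simple arithmetic. A run that stops after 23 iterations, with a recorder that fires every 6th iteration, records at 6, 12, and 18, and the :Stop hook records 23 once. The plain-Julia sketch below reproduces only that control flow. `EveryK` and `run_with_records` are hypothetical names that mimic the roles of RecordEvery and the :Stop entry. They are not Manopt's implementation, and the loop performs no optimization.

```julia
# Hypothetical "record every k-th iteration" recorder (mimics the role of RecordEvery).
struct EveryK
    k::Int
    iterations::Vector{Int}
end
EveryK(k::Int) = EveryK(k, Int[])

function (r::EveryK)(i::Int)
    i % r.k == 0 && push!(r.iterations, i)  # only act on multiples of k
    return nothing
end

# Control-flow skeleton: an :Iteration hook after every step and a :Stop hook once at the end.
function run_with_records(stop_after::Int; every::Int=6)
    on_iteration = EveryK(every)
    at_stop = Int[]
    i = 0
    while i < stop_after
        i += 1
        on_iteration(i)   # :Iteration hook
    end
    push!(at_stop, i)     # :Stop hook
    return on_iteration.iterations, at_stop
end
```

Calling `run_with_records(23)` gives the same iteration numbers as the output above: `([6, 12, 18], [23])`.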
Writing Your Own RecordActions
Suppose we want to count the number of cost function evaluations. This is again just for illustration, since for gradient descent this is just one evaluation per iteration. We first define a cost that counts its own calls.
begin
mutable struct MyCost{T}
data::T
count::Int
end
MyCost(data::T) where {T} = MyCost{T}(data, 0)
function (c::MyCost)(M, x)
c.count += 1
return sum(1 / (2 * length(c.data)) * distance.(Ref(M), Ref(x), c.data) .^ 2)
end
end
We then define the following RecordAction, which is a functor, that is, a struct that can be called like a function. The function we have to implement has a signature similar to a single solver step, since it might be called in every iteration:
begin
mutable struct RecordCount <: RecordAction
recorded_values::Vector{Int}
RecordCount() = new(Vector{Int}())
end
function (r::RecordCount)(p::Problem, ::Options, i)
if i > 0
push!(r.recorded_values, p.cost.count)
elseif i < 0 # reset if negative
r.recorded_values = Vector{Int}()
end
end
end
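The calling convention in the code above works as follows. A positive iteration index records a value, a negative one resets the record, and `0` does nothing. The self-contained mock below exercises that convention without Manopt. `CountingCost`, `ToyProblem`, and `ToyRecordCount` are hypothetical stand-ins for the Problem, the counting cost, and RecordCount. They take a simplified `(p, i)` argument list instead of `(p, o, i)`.

```julia
# Hypothetical cost that counts its own calls (plays the role of MyCost).
mutable struct CountingCost
    count::Int
end
(c::CountingCost)(x) = (c.count += 1; sum(abs2, x))

# Hypothetical stand-in for a problem holding that cost.
struct ToyProblem
    cost::CountingCost
end

mutable struct ToyRecordCount
    recorded_values::Vector{Int}
end
ToyRecordCount() = ToyRecordCount(Int[])

# Same branching as RecordCount above: record for i > 0, reset for i < 0.
function (r::ToyRecordCount)(p::ToyProblem, i::Int)
    if i > 0
        push!(r.recorded_values, p.cost.count)
    elseif i < 0  # reset if negative
        r.recorded_values = Int[]
    end
    return r
end
```

In the usage below, each loop iteration evaluates the cost once, so the recorded counts grow by one per iteration.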
Now we can initialize the new cost and call gradient descent. This also illustrates the last use case: you can pass symbol-action pairs into the record= array.
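A record= array can mix bare symbols with `symbol => action` pairs, as in `[:Iteration, :Count => RecordCount(), :Cost]`. One way to interpret such an array is to normalize every entry to a named pair. The sketch below does that in plain Julia. It is a hypothetical illustration, not Manopt's internal code. `normalize_entry`, `normalize_record`, and the `predefined` lookup table are names introduced here, and strings stand in for real actions.

```julia
# Bare symbols are looked up in a table of predefined actions;
# user-supplied pairs keep their action under the given name.
normalize_entry(s::Symbol, predefined::AbstractDict) = s => predefined[s]
normalize_entry(p::Pair{Symbol}, ::AbstractDict) = p

# Keep the user's order, so positions in the resulting group follow the array.
normalize_record(entries, predefined::AbstractDict) =
    [normalize_entry(e, predefined) for e in entries]
```

Because the order is preserved, `:Count` ends up in the second position, between `:Iteration` and `:Cost`.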
F2 = MyCost(data)
MyCost{Vector{Vector{Float64}}}([[-0.054658825167894595, -0.5592077846510423, -0.04738273828111257, … 0.1972125856685378], [-0.08192376929745251, -0.5097715132187672, -0.008339904915541008, … -0.026214242996704024] … [0.12108727495744157, -0.714787344982596, 0.01632521838262752, … 0.15296635978447756]], 0)
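R3 = gradient_descent(
    M,
    F2,
    gradF,
    data[1];
    # record the iteration number, our own RecordCount action (stored under :Count), and the cost
    record=[:Iteration, :Count => RecordCount(), :Cost],
    # return the options (wrapped in RecordOptions) instead of only the resulting point
    return_options=true,
)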
RecordOptions{GradientDescentOptions{Vector{Float64}, Vector{Float64}, StopWhenAny{Tuple{StopAfterIteration, StopWhenGradientNormLess}}, ConstantStepsize, ExponentialRetraction}, NamedTuple{(:Iteration,), Tuple{RecordGroup}}}(GradientDescentOptions{Vector{Float64}, Vector{Float64}, StopWhenAny{Tuple{StopAfterIteration, StopWhenGradientNormLess}}, ConstantStepsize, ExponentialRetraction}([0.003348563697666398, -0.9989177042237649, 0.012603956176158468, 3.724185091494971e-5, 0.003476887105914318, 0.007393313151087908, 0.0015131380045400118, -0.020957966249101206, 0.014862723464370162, -0.007213435035962335 … -0.003341406163845814, 0.004841936440829273, -0.010592698188325093, -0.01301712666275464, 0.0033263344419417, -0.004530004343370803, 0.0030077997761761826, -0.015610321221756972, -0.0016794409150365508, -0.0025375708198013104], IdentityUpdateRule(), StopWhenAny{Tuple{StopAfterIteration, StopWhenGradientNormLess}}((StopAfterIteration(200, ""), StopWhenGradientNormLess(1.0e-8, "The algorithm reached approximately critical point after 20 iterations; the gradient norm (7.386440725063072e-9) is less than 1.0e-8.\n")), "The algorithm reached approximately critical point after 20 iterations; the gradient norm (7.386440725063072e-9) is less than 1.0e-8.\n"), ConstantStepsize(1.0), [-6.129711326192944e-10, 5.1812992794094857e-11, -5.405597899451432e-10, -2.1340247270847362e-10, 9.874240397637337e-10, 5.979410710462525e-10, -1.042001310031708e-9, -1.7296361451923058e-9, -2.7893922138993825e-9, 8.807665458510777e-11 … -1.7041825315966585e-9, 1.6257109125691546e-9, 3.232189115420954e-9, -2.7365968945948765e-10, -3.5264198551825283e-10, 7.746778828772623e-11, 1.4421730910180175e-10, -1.2080581403121644e-9, 1.0696541392159111e-9, 1.9981575853015958e-9], ExponentialRetraction()), (Iteration = RecordGroup(Manopt.RecordAction[RecordIteration([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]), RecordCount([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 
16, 17, 18, 19]), RecordCost([0.5808287253777765, 0.5395268557323746, 0.5333529073733115, 0.5324514620174543, 0.5323201743667151, 0.5323010518577256, 0.5322982658416161, 0.532297859847447, 0.5322978006725337, 0.5322977920461375, 0.5322977907883957, 0.5322977906049865, 0.5322977905782369, 0.532297790574335, 0.5322977905737657, 0.5322977905736823, 0.5322977905736703, 0.5322977905736688, 0.5322977905736683, 0.5322977905736684])], Dict(:Iteration => 1, :Count => 2, :Cost => 3)),))
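The record keyword above mixes plain symbols with a name => action pair. The end of the printed output, Dict(:Iteration => 1, :Count => 2, :Cost => 3), shows that each name is mapped to a position in the recorded tuples. The following is a minimal plain-Julia sketch of how such a mapping can be derived from the record list. It only illustrates the idea and is not Manopt's internal code; the string "RecordCount()" is a hypothetical placeholder for the actual action object.

```julia
# Hypothetical stand-in for the record list; the string replaces the RecordCount() action
record_spec = Any[:Iteration, :Count => "RecordCount()", :Cost]

# A plain symbol is its own name; for a pair `name => action`, the name is the first entry
record_names = [entry isa Pair ? first(entry) : entry for entry in record_spec]

# Map each name to its position within one recorded tuple
positions = Dict(name => i for (i, name) in enumerate(record_names))

println(record_names)        # [:Iteration, :Count, :Cost]
println(positions[:Count])   # 2
```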
We already learned how to access the recorded :Cost values. The prefix :Count => assigns the name :Count to the action that follows it, here our own RecordCount(). As before, we can access the whole set of records with
get_record(R3)
20-element Vector{Tuple{Int64, Int64, Float64}}: (1, 0, 0.5808287253777765) (2, 1, 0.5395268557323746) (3, 2, 0.5333529073733115) (4, 3, 0.5324514620174543) (5, 4, 0.5323201743667151) (6, 5, 0.5323010518577256) (7, 6, 0.5322982658416161) ⋮ (15, 14, 0.5322977905737657) (16, 15, 0.5322977905736823) (17, 16, 0.5322977905736703) (18, 17, 0.5322977905736688) (19, 18, 0.5322977905736683) (20, 19, 0.5322977905736684)
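Since get_record(R3) returns a vector of tuples, each ordered like the record list (iteration, count, cost), the individual series can be extracted by broadcasting over the tuple positions. The sketch below uses the first three tuples from the output above as sample data; the same lines should apply directly to the result of get_record(R3).

```julia
# First three records as printed above: (iteration, count, cost)
records = [
    (1, 0, 0.5808287253777765),
    (2, 1, 0.5395268557323746),
    (3, 2, 0.5333529073733115),
]

# Extract one vector per recorded quantity by tuple position
iterations = first.(records)      # position 1: :Iteration
counts = getindex.(records, 2)    # position 2: :Count
costs = last.(records)            # position 3: :Cost

# The recorded cost does not increase from one iteration to the next
println(issorted(costs; rev=true))   # true
```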
This is equivalent to calling R3[:Iteration]. Since we gave our own action the name :Count, we can also access the values of a single recorded quantity using
R3[:Iteration, :Count]
20-element Vector{Int64}: 0 1 2 3 4 5 6 ⋮ 14 15 16 17 18 19
and since the counter increases by exactly one from one iteration to the next, we see that the cost function is called once per iteration.
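The counter starts at 0 while the iteration number starts at 1. This is consistent with :Count being recorded before :Cost evaluates the cost within each iteration, as the order of the record list suggests. The plain-Julia simulation below mimics that order under the assumption that recording the cost is the only cost evaluation per iteration. CountingCost, simulate_recording, the placeholder cost and the halving update are all hypothetical stand-ins, not Manopt code.

```julia
# Hypothetical cost wrapper that counts how often it is evaluated
mutable struct CountingCost
    count::Int
end
function (c::CountingCost)(x)
    c.count += 1
    return sum(abs2, x) / 2   # placeholder cost, not the tutorial's F2
end

function simulate_recording(n)
    f = CountingCost(0)
    x = [1.0, 2.0]
    iterations, counts, costs = Int[], Int[], Float64[]
    for i in 1:n
        x = x ./ 2                  # placeholder update step
        push!(iterations, i)        # like :Iteration
        push!(counts, f.count)      # like :Count, recorded before the cost is evaluated
        push!(costs, f(x))          # like :Cost, evaluates (and counts) the cost once
    end
    return iterations, counts, costs
end

iterations, counts, costs = simulate_recording(5)
println(counts)                       # [0, 1, 2, 3, 4]
println(all(==(1), diff(counts)))     # true: one cost evaluation per iteration
```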