Stochastic Gradient Descent
This tutorial illustrates how to use the stochastic_gradient_descent solver and different DirectionUpdateRules to introduce the average or momentum variant, see Stochastic Gradient Descent.
Computationally, we look at a very simple but large-scale problem, the Riemannian Center of Mass or Fréchet mean: for given points $p_i ∈\mathcal M$, $i=1,…,N$, this optimization problem reads
$$\operatorname*{arg\,min}_{x∈\mathcal M} \frac{1}{2}\sum_{i=1}^{N} \operatorname{d}^2_{\mathcal M}(x,p_i),$$
which of course can be (and is) solved by gradient descent, see the introductory tutorial or the Statistics in Manifolds.jl. If $N$ is very large, evaluating the complete gradient can be quite expensive. A remedy is to evaluate only one of the terms at a time and to visit the terms in a random order.
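The idea of using only one term per step can be sketched in plain Julia, independently of Manopt.jl. The sketch makes several assumptions: the unit sphere $\mathbb S^2 \subset \mathbb R^3$, hand-written exponential and logarithmic maps, a constant step size, and one random permutation of the terms per epoch. The helper names (sphere_exp, sphere_log, sgd_frechet_mean) are hypothetical, and this is not Manopt.jl's implementation.

```julia
using LinearAlgebra, Random

# Exponential map on the unit sphere: walk along the great circle from p in direction X.
function sphere_exp(p, X)
    θ = norm(X)
    θ == 0 && return copy(p)
    return cos(θ) .* p .+ (sin(θ) / θ) .* X
end

# Logarithmic map on the unit sphere: tangent vector at p pointing to q, of length d(p, q).
# Antipodal points (where the log is not unique) are not handled in this sketch.
function sphere_log(p, q)
    c = clamp(dot(p, q), -1.0, 1.0)
    v = q .- c .* p                  # component of q orthogonal to p
    nv = norm(v)
    return nv < eps() ? zero(p) : (acos(c) / nv) .* v
end

# Hypothetical SGD for the Fréchet mean: each step uses a single term f_i(x) = d²(x, p_i)/2.
function sgd_frechet_mean(data, x0; epochs=100, stepsize=0.05, rng=Random.default_rng())
    x = copy(x0)
    for _ in 1:epochs
        for i in randperm(rng, length(data))   # visit the N terms in random order
            g = -sphere_log(x, data[i])        # grad f_i(x) = -log_x p_i
            x = sphere_exp(x, -stepsize .* g)  # step along -grad f_i(x)
        end
    end
    return x
end

# Two points placed symmetrically around e₁ = (1, 0, 0), so their mean is e₁.
pts = [normalize([1.0, 0.1, 0.0]), normalize([1.0, -0.1, 0.0])]
m = sgd_frechet_mean(pts, [0.0, 0.0, 1.0]; epochs=200, rng=MersenneTwister(1))
```

With a constant step size the iterate ends up close to the mean but keeps fluctuating around it. A decreasing step size would be needed to converge exactly.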
We first load the packages
using BenchmarkTools, Colors, Manopt, Manifolds, PlutoUI, Random
and define some colors from Paul Tol's color schemes
begin
black = RGBA{Float64}(colorant"#000000")
TolVibrantOrange = RGBA{Float64}(colorant"#EE7733") # Start
TolVibrantBlue = RGBA{Float64}(colorant"#0077BB") # a path
TolVibrantTeal = RGBA{Float64}(colorant"#009988") # points
end;
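Next, we generate a somewhat larger data set: n = 5000 points scattered around x on the sphere $\mathbb S^2$. Each point is obtained by mapping a Gaussian tangent vector at x with standard deviation σ = π/12 onto the sphere via the exponential map.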
begin
n = 5000
σ = π / 12
M = Sphere(2)
x = 1 / sqrt(2) * [1.0, 0.0, 1.0]
Random.seed!(42)
data = [exp(M, x, random_tangent(M, x, Val(:Gaussian), σ)) for i in 1:n]
localpath = join(splitpath(@__FILE__)[1:(end - 1)], "/") # files folder
image_prefix = localpath * "/stochastic_gradient_descent"
@info image_prefix
render_asy = false # on CI or when you do not have asymptote, this should be false
end
render_asy && asymptote_export_S2_signals(
image_prefix * "/center_and_large_data.asy";
points=[[x], data],
colors=Dict(:points => [TolVibrantBlue, TolVibrantTeal]),
dot_sizes=[2.5, 1.0],
camera_position=(1.0, 0.5, 0.5),
)
render_asy && render_asymptote(image_prefix * "/center_and_large_data.asy"; render=2);
PlutoUI.LocalResource(image_prefix * "/center_and_large_data.png")
Note that since the points are constructed from zero-mean Gaussian tangent vectors at x, the mean should be very close to our initial point x.
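This claim can be checked by repeating the construction by hand. The sampler below is a hypothetical stand-in for random_tangent. It draws a Gaussian vector in $\mathbb R^3$ with standard deviation σ per coordinate and projects it onto the tangent space at x, which is not necessarily the distribution Manifolds.jl uses. The mean is then computed with the classical fixed-point (Karcher) iteration $x \gets \exp_x\bigl(\frac{1}{N}\sum_i \log_x p_i\bigr)$. This is full gradient descent with step size $\frac{1}{N}$ on the cost above.

```julia
using LinearAlgebra, Random

function sphere_exp(p, X)
    θ = norm(X)
    θ == 0 && return copy(p)
    return cos(θ) .* p .+ (sin(θ) / θ) .* X
end

function sphere_log(p, q)
    c = clamp(dot(p, q), -1.0, 1.0)
    v = q .- c .* p
    nv = norm(v)
    return nv < eps() ? zero(p) : (acos(c) / nv) .* v
end

# Hypothetical sampler: exponential map of a projected Gaussian tangent vector at x.
function sample_around(rng, x, σ, n)
    return map(1:n) do i
        v = σ .* randn(rng, 3)
        v .-= dot(v, x) .* x        # project onto the tangent space T_x S²
        sphere_exp(x, v)
    end
end

# Karcher iteration: full gradient step with step size 1/N on F(x) = ½ Σ d²(x, p_i).
function karcher_mean(pts, x0; iters=20)
    x = copy(x0)
    for _ in 1:iters
        X = sum(sphere_log(x, p) for p in pts) ./ length(pts)
        x = sphere_exp(x, X)
    end
    return x
end

rng = MersenneTwister(42)
x = [1.0, 0.0, 1.0] ./ sqrt(2)
pts = sample_around(rng, x, π / 12, 5000)
m = karcher_mean(pts, x)
dist_to_x = acos(clamp(dot(m, x), -1.0, 1.0))   # geodesic distance between mean and x
```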
In order to use the stochastic gradient, we now need a function that returns the vector of gradients. There are two ways to define it in Manopt.jl: as one function that returns a vector of gradients, or as a vector of functions.
The first variant is easier to define, but the second is more efficient when only one of the gradients is evaluated.
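For the mean, writing $f_i(x) = \frac{1}{2}\operatorname{d}^2_{\mathcal M}(x,p_i)$ so that $F(x) = \sum_{i=1}^N f_i(x)$, the gradient reads
$$\operatorname{grad}F(x) = \sum_{i=1}^N \operatorname{grad}f_i(x) \quad\text{where}\quad \operatorname{grad}f_i(x) = -\log_x p_i.$$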
We define this in Manopt.jl in two different ways: either as one function returning all gradients as a vector (gradF) or, more fitting for a large-scale problem, as a vector of small gradient functions (gradf). The cost F, which carries an additional factor 1/n, is only needed for the line search at the end.
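Setting this gradient to zero gives the first-order optimality condition that the solvers below aim for. This follows directly from the formula above, assuming all $p_i$ lie away from the cut locus of $x$ so that $\log_x p_i$ is defined:
$$\operatorname{grad}F(x^*) = -\sum_{i=1}^N \log_{x^*} p_i = 0.$$
One step of full gradient descent with step size $s_k$ needs all $N$ logarithmic maps. A stochastic step with a randomly drawn index $i_k ∈ \{1,…,N\}$ needs only one:
$$x_{k+1} = \exp_{x_k}\Bigl(s_k \sum_{i=1}^N \log_{x_k} p_i\Bigr) \qquad\text{versus}\qquad x_{k+1} = \exp_{x_k}\bigl(s_k \log_{x_k} p_{i_k}\bigr).$$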
F(M, x) = 1 / (2 * n) * sum(map(p -> distance(M, x, p)^2, data))
F (generic function with 1 method)
gradF(M, x) = [grad_distance(M, p, x) for p in data]
gradF (generic function with 1 method)
gradf = [(M, x) -> grad_distance(M, p, x) for p in data];
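Why the second variant is cheaper when only one term is needed can be checked without Manopt.jl by counting gradient evaluations. In this hypothetical sketch, expensive_grad stands in for grad_distance and uses the Euclidean gradient $x - p$ of $\frac12\lVert x-p\rVert^2$. A global counter records how often it runs.

```julia
# Counter for how many single-term gradients have been evaluated.
const calls = Ref(0)

# Hypothetical stand-in for grad_distance: Euclidean gradient of ½‖x - p‖².
expensive_grad(p, x) = (calls[] += 1; x .- p)

data = [[Float64(i), 0.0, 0.0] for i in 1:1000]

# Variant 1: one function returning all gradients as a vector.
gradF(x) = [expensive_grad(p, x) for p in data]
# Variant 2: a vector of functions, one per term.
gradf = [x -> expensive_grad(p, x) for p in data]

x = zeros(3)
calls[] = 0; g_all = gradF(x)[2]; n_all = calls[]   # evaluates all 1000 terms
calls[] = 0; g_one = gradf[2](x); n_one = calls[]   # evaluates only term 2
```

Both variants return the same second gradient. However, the first one evaluates every term in order to return a single one.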
The calls are only slightly different, but notice that accessing the second gradient element of gradF requires evaluating all logarithmic maps first. So while you can use both gradF and gradf in the following call, the second one is (much) faster:
x_opt1 = stochastic_gradient_descent(M, gradF, x)
3-element Vector{Float64}:
0.7071067811865475
0.0
0.7071067811865475
@benchmark stochastic_gradient_descent($M, $gradF, $x)
BenchmarkTools.Trial: 2988 samples with 1 evaluation.
Range (min … max): 1.138 ms … 28.163 ms ┊ GC (min … max): 0.00% … 93.56%
Time (median): 1.340 ms ┊ GC (median): 0.00%
Time (mean ± σ): 1.670 ms ± 2.271 ms ┊ GC (mean ± σ): 11.76% ± 8.19%
Histogram: frequency by time (1.14 ms to 2.4 ms)
Memory estimate: 861.47 KiB, allocs estimate: 10030.
x_opt2 = stochastic_gradient_descent(M, gradf, x)
3-element Vector{Float64}:
0.7071067811865475
0.0
0.7071067811865475
@benchmark stochastic_gradient_descent($M, $gradf, $x)
BenchmarkTools.Trial: 10000 samples with 5 evaluations.
Range (min … max): 6.480 μs … 1.680 ms ┊ GC (min … max): 0.00% … 98.02%
Time (median): 8.720 μs ┊ GC (median): 0.00%
Time (mean ± σ): 13.061 μs ± 66.942 μs ┊ GC (mean ± σ): 21.53% ± 4.18%
Histogram: log(frequency) by time (6.48 μs to 33.1 μs)
Memory estimate: 41.00 KiB, allocs estimate: 23.
x_opt2
3-element Vector{Float64}:
0.7071067811865475
0.0
0.7071067811865475
This result is reasonably close, but we can improve it by using a DirectionUpdateRule, namely:
On the one hand, MomentumGradient, which requires both the manifold and the initial value in order to keep track of the iterate and to parallel transport the last direction to the current iterate. If ParallelTransport() is not available on your manifold, you can set a different vector_transport_method. Here we simply do
x_opt3 = stochastic_gradient_descent(
M, gradf, x; direction=MomentumGradient(M, x, StochasticGradient(zero_vector(M, x)))
)
3-element Vector{Float64}:
0.7071067811865475
0.0
0.7071067811865475
MG = MomentumGradient(M, x, StochasticGradient(zero_vector(M, x)));
@benchmark stochastic_gradient_descent($M, $gradf, $x; direction=$MG)
BenchmarkTools.Trial: 10000 samples with 5 evaluations.
Range (min … max): 6.800 μs … 2.111 ms ┊ GC (min … max): 0.00% … 98.30%
Time (median): 8.700 μs ┊ GC (median): 0.00%
Time (mean ± σ): 13.385 μs ± 70.096 μs ┊ GC (mean ± σ): 21.86% ± 4.18%
Histogram: log(frequency) by time (6.8 μs to 33.8 μs)
Memory estimate: 40.92 KiB, allocs estimate: 22.
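To illustrate what such a momentum rule does, here is a plain-Julia sketch on $\mathbb S^2$. It is an illustration, not MomentumGradient itself. The following are all choices made for this sketch:

- the update $d \gets \beta d - s\,\operatorname{grad}f_i(x)$,
- the fixed momentum β and step size s,
- the projection onto the new tangent space in place of parallel transport (a projection is one simple vector transport on the sphere).

```julia
using LinearAlgebra, Random

function sphere_exp(p, X)
    θ = norm(X)
    θ == 0 && return copy(p)
    return cos(θ) .* p .+ (sin(θ) / θ) .* X
end

function sphere_log(p, q)
    c = clamp(dot(p, q), -1.0, 1.0)
    v = q .- c .* p
    nv = norm(v)
    return nv < eps() ? zero(p) : (acos(c) / nv) .* v
end

# Orthogonal projection onto T_p S², used here as a simple vector transport.
project(p, X) = X .- dot(p, X) .* p

# Hypothetical momentum SGD for the Fréchet mean on the sphere.
function momentum_sgd(data, x0; epochs=200, stepsize=0.05, momentum=0.2, rng=Random.default_rng())
    x = copy(x0)
    d = zero(x0)                                 # previous direction, tangent at x
    for _ in 1:epochs
        for i in randperm(rng, length(data))
            g = -sphere_log(x, data[i])          # grad f_i(x)
            d = momentum .* d .- stepsize .* g   # blend old direction with new gradient
            x_new = sphere_exp(x, d)
            d = project(x_new, d)                # move d to the tangent space at x_new
            x = x_new
        end
    end
    return x
end

pts = [normalize([1.0, 0.1, 0.0]), normalize([1.0, -0.1, 0.0])]
m = momentum_sgd(pts, [0.0, 0.0, 1.0]; rng=MersenneTwister(1))
```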
On the other hand, AverageGradient computes an average of the most recent gradients (the last 10 in the following call), i.e.
x_opt4 = stochastic_gradient_descent(
M, gradf, x; direction=AverageGradient(M, x, 10, StochasticGradient(zero_vector(M, x)))
);
AG = AverageGradient(M, x, 10, StochasticGradient(zero_vector(M, x)));
@benchmark stochastic_gradient_descent($M, $gradf, $x; direction=$AG)
BenchmarkTools.Trial: 10000 samples with 4 evaluations.
Range (min … max): 6.850 μs … 2.402 ms ┊ GC (min … max): 0.00% … 98.00%
Time (median): 8.775 μs ┊ GC (median): 0.00%
Time (mean ± σ): 13.460 μs ± 76.977 μs ┊ GC (mean ± σ): 21.07% ± 3.68%
Histogram: log(frequency) by time (6.85 μs to 34.4 μs)
Memory estimate: 40.92 KiB, allocs estimate: 22.
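A corresponding sketch of the averaging idea follows, again a plain-Julia illustration on $\mathbb S^2$ rather than AverageGradient itself. Its assumptions are:

- a buffer keeps the last k gradients, where the hypothetical parameter k plays the role of the 10 above;
- the stored vectors are moved to each new iterate by projection onto its tangent space;
- each step uses the mean of the buffer.

```julia
using LinearAlgebra, Random

function sphere_exp(p, X)
    θ = norm(X)
    θ == 0 && return copy(p)
    return cos(θ) .* p .+ (sin(θ) / θ) .* X
end

function sphere_log(p, q)
    c = clamp(dot(p, q), -1.0, 1.0)
    v = q .- c .* p
    nv = norm(v)
    return nv < eps() ? zero(p) : (acos(c) / nv) .* v
end

# Orthogonal projection onto T_p S², used here as a simple vector transport.
project(p, X) = X .- dot(p, X) .* p

# Hypothetical SGD with an average over the last k stochastic gradients.
function averaged_sgd(data, x0; epochs=200, stepsize=0.05, k=10, rng=Random.default_rng())
    x = copy(x0)
    buffer = Vector{Vector{Float64}}()               # last k gradients, tangent at x
    for _ in 1:epochs
        for i in randperm(rng, length(data))
            push!(buffer, -sphere_log(x, data[i]))   # newest gradient grad f_i(x)
            length(buffer) > k && popfirst!(buffer)  # drop the oldest one
            avg = sum(buffer) ./ length(buffer)
            x = sphere_exp(x, -stepsize .* avg)
            buffer = [project(x, g) for g in buffer] # transport stored gradients to T_x
        end
    end
    return x
end

pts = [normalize([1.0, 0.1, 0.0]), normalize([1.0, -0.1, 0.0])]
m = averaged_sgd(pts, [0.0, 0.0, 1.0]; rng=MersenneTwister(1))
```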
Note that the default StoppingCriterion is a fixed number of iterations, which helps the comparison here.
For both update rules we have to specify that we are still in the stochastic setting, which is what the inner StochasticGradient(zero_vector(M, x)) does. This is needed because both rules can also be used with the IdentityUpdateRule within gradient_descent.
For this not-that-large-scale example we can of course also use a gradient descent with ArmijoLinesearch, but here it is considerably slower, as the following benchmark shows.
fullGradF(M, x) = sum(grad_distance(M, p, x) for p in data)
fullGradF (generic function with 1 method)
x_opt5 = gradient_descent(M, F, fullGradF, x; stepsize=ArmijoLinesearch())
3-element Vector{Float64}:
0.7071067811865475
-3.9257729890729064e-16
0.7071067811865475
AL = ArmijoLinesearch();
@benchmark gradient_descent($M, $F, $fullGradF, $x; stepsize=$AL)
BenchmarkTools.Trial: 4 samples with 1 evaluation.
Range (min … max): 1.396 s … 1.461 s ┊ GC (min … max): 11.71% … 13.27%
Time (median): 1.450 s ┊ GC (median): 11.55%
Time (mean ± σ): 1.439 s ± 28.979 ms ┊ GC (mean ± σ): 12.03% ± 0.84%
Histogram: frequency by time (1.4 s to 1.46 s)
Memory estimate: 711.54 MiB, allocs estimate: 9034670.
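For comparison, gradient descent with a backtracking Armijo line search on the full gradient can also be sketched in plain Julia. The constants (initial step s0 = 1, contraction factor 0.5, sufficient-decrease constant c = 1e-4) are common textbook choices assumed for this sketch, not necessarily the defaults of ArmijoLinesearch. In this sketch the cost and its gradient are both scaled by 1/N, so that one is exactly the gradient of the other.

```julia
using LinearAlgebra

function sphere_exp(p, X)
    θ = norm(X)
    θ == 0 && return copy(p)
    return cos(θ) .* p .+ (sin(θ) / θ) .* X
end

function sphere_log(p, q)
    c = clamp(dot(p, q), -1.0, 1.0)
    v = q .- c .* p
    nv = norm(v)
    return nv < eps() ? zero(p) : (acos(c) / nv) .* v
end

sphere_dist(p, q) = acos(clamp(dot(p, q), -1.0, 1.0))

# Hypothetical cost and Riemannian gradient, both scaled by 1/N.
cost(x, pts) = sum(sphere_dist(x, p)^2 for p in pts) / (2 * length(pts))
grad_cost(x, pts) = -sum(sphere_log(x, p) for p in pts) ./ length(pts)

function armijo_gd(pts, x0; iters=50, s0=1.0, shrink=0.5, c=1e-4)
    x = copy(x0)
    for _ in 1:iters
        g = grad_cost(x, pts)
        ng2 = dot(g, g)
        ng2 < 1e-20 && break                 # gradient is numerically zero
        f0, t = cost(x, pts), s0
        # Backtrack until cost(exp_x(-t g)) ≤ cost(x) - c t ‖g‖².
        while cost(sphere_exp(x, -t .* g), pts) > f0 - c * t * ng2 && t > 1e-10
            t *= shrink
        end
        x = sphere_exp(x, -t .* g)
    end
    return x
end

pts = [normalize([1.0, 0.1, 0.0]), normalize([1.0, -0.1, 0.0])]
m = armijo_gd(pts, [0.0, 0.0, 1.0])
```

Unlike the constant-step stochastic sketches, each iteration here evaluates all N logarithmic maps, plus extra cost evaluations during backtracking.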