Stochastic Gradient Descent

This tutorial illustrates how to use the stochastic_gradient_descent solver and different DirectionUpdateRules in order to introduce the average or momentum variant; see Stochastic Gradient Descent for more details.

Computationally, we look at a very simple but large-scale problem, the Riemannian Center of Mass or Fréchet mean: for given points $p_i ∈\mathcal M$, $i=1,…,N$, this optimization problem reads

$$\operatorname*{arg\,min}_{x∈\mathcal M} \frac{1}{2}\sum_{i=1}^{N} \operatorname{d}^2_{\mathcal M}(x,p_i),$$

which can of course be (and is) solved by a gradient descent, see the introductory tutorial or Statistics in Manifolds.jl. If $N$ is very large, evaluating the complete gradient might be quite expensive. A remedy is to evaluate only one of the terms at a time and to choose a random order for these.
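To sketch this idea in plain (Euclidean) Julia first – with hypothetical data ps, start point y, and a constant step size η – one sweep visits the summands in random order and takes a cheap step for each single term:

using Random
ps = [randn(3) for _ in 1:1000]    # hypothetical data points in ℝ³
y = randn(3)                       # hypothetical start point
η = 0.01                           # hypothetical constant step size
for i in shuffle(1:length(ps))     # random order through the summands
    y .-= η .* (y .- ps[i])        # one step along -grad fᵢ(y), where fᵢ(y) = ½‖y - pᵢ‖²
end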

We first initialize the packages

Setup

If you open this notebook in Pluto locally it switches between two modes. If the tutorial is within the Manopt.jl repository, this notebook tries to use the local package in development mode. Otherwise, the file uses the Pluto package management version. In this case, the inclusion of images might be broken, unless you create a subfolder optimize and activate asy-rendering.

Since the loading is a little complicated, we show in the following which versions of the packages were installed.

with_terminal() do
    Pkg.status()
end
Status `/private/var/folders/_v/wg192lpd3mb1lp55zz7drpcw0000gn/T/jl_KZHwyN/Project.toml`
  [6e4b80f9] BenchmarkTools v1.3.2
  [5ae59095] Colors v0.12.10
  [1cead3c2] Manifolds v0.8.42
  [0fc0a36d] Manopt v0.4.0 `~/Repositories/Julia/Manopt.jl`
  [7f904dfe] PlutoUI v0.7.49
  [44cfe95a] Pkg v1.8.0
  [9a3f8284] Random

and we define some colors from Paul Tol

begin
    black = RGBA{Float64}(colorant"#000000")
    TolVibrantOrange = RGBA{Float64}(colorant"#EE7733") # Start
    TolVibrantBlue = RGBA{Float64}(colorant"#0077BB") # a path
    TolVibrantTeal = RGBA{Float64}(colorant"#009988") # points
end;

We next generate a (little) large(r) data set

begin
    n = 5000
    σ = π / 12
    M = Sphere(2)
    x = 1 / sqrt(2) * [1.0, 0.0, 1.0]
    Random.seed!(42)
    data = [exp(M, x,  σ * rand(M; vector_at=x)) for i in 1:n]
    localpath = join(splitpath(@__FILE__)[1:(end - 1)], "/") # files folder
    image_prefix = localpath * "/stochastic_gradient_descent"
    _in_package && @info image_prefix
    render_asy = false # on CI or when you do not have asymptote, this should be false
end
false
render_asy && asymptote_export_S2_signals(
    image_prefix * "/center_and_large_data.asy";
    points=[[x], data],
    colors=Dict(:points => [TolVibrantBlue, TolVibrantTeal]),
    dot_sizes=[2.5, 1.0],
    camera_position=(1.0, 0.5, 0.5),
);
render_asy && render_asymptote(image_prefix * "/center_and_large_data.asy"; render=2);
PlutoUI.LocalResource(image_prefix * "/center_and_large_data.png")

Note that due to the construction of the points as zero mean tangent vectors, the mean should be very close to our initial point x.
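As a small sanity check, one could compare x with the Fréchet mean computed by Manifolds.jl (this assumes the Statistics-based mean method for manifolds is available in this setup):

using Statistics: mean
distance(M, x, mean(M, data))  # should be close to zero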

In order to use the stochastic gradient, we now need a function that returns the vector of gradients. There are two ways to define it in Manopt.jl: either as a single function that returns a vector, or as a vector of functions.

The first variant is of course easier to define, but the second is more efficient when only evaluating one of the gradients.

For the mean, the gradient is

$$\operatorname{grad}F(x) = \sum_{i=1}^N \operatorname{grad}f_i(x) \quad\text{where}\quad \operatorname{grad}f_i(x) = -\log_x p_i,$$

which we define in Manopt.jl in two different ways: either as one function returning all gradients as a vector (see gradF), or – maybe more fitting for a large-scale problem – as a vector of small, single-summand gradient functions (see gradf).

F(M, x) = 1 / (2 * n) * sum(map(p -> distance(M, x, p)^2, data))
F (generic function with 1 method)
gradF(M, x) = [grad_distance(M, p, x) for p in data]
gradF (generic function with 1 method)
gradf = [(M, x) -> grad_distance(M, p, x) for p in data];

The calls are only slightly different, but notice that accessing the 2nd gradient element requires evaluating all logarithmic maps in the first variant, while the second variant only evaluates the one that is needed. For example:
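gradF(M, x)[2]   # evaluates all n gradients (one logarithmic map each), then keeps only the 2nd entry
gradf[2](M, x)   # evaluates only the single logarithmic map that is needed

So while you can use both gradF and gradf in the following call, the second one is (much) faster: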

x_opt1 = stochastic_gradient_descent(M, gradF, x)
3-element Vector{Float64}:
 0.7071067811865475
 0.0
 0.7071067811865475
@benchmark stochastic_gradient_descent($M, $gradF, $x)
BenchmarkTools.Trial: 8417 samples with 1 evaluation.
 Range (min … max):  482.250 μs …   5.693 ms  ┊ GC (min … max): 0.00% … 86.30%
 Time  (median):     530.500 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   592.126 μs ± 451.419 μs  ┊ GC (mean ± σ):  7.69% ±  8.91%

  █▆▃                                                           ▁
  ██████▇▅▃▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▄▁▁▅ █
  482 μs        Histogram: log(frequency) by time       4.65 ms <

 Memory estimate: 861.50 KiB, allocs estimate: 10031.
x_opt2 = stochastic_gradient_descent(M, gradf, x)
3-element Vector{Float64}:
 0.7071067811865475
 0.0
 0.7071067811865475
@benchmark stochastic_gradient_descent($M, $gradf, $x)
BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  1.949 μs …  1.558 ms  ┊ GC (min … max):  0.00% … 99.36%
 Time  (median):     5.349 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   9.043 μs ± 55.784 μs  ┊ GC (mean ± σ):  40.39% ±  6.77%

                     ▂▇█▇▅▂
  ▃▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▄▆██████▇▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂ ▃
  1.95 μs        Histogram: frequency by time        10.9 μs <

 Memory estimate: 41.06 KiB, allocs estimate: 26.
x_opt2
3-element Vector{Float64}:
 0.7071067811865475
 0.0
 0.7071067811865475

This result is reasonably close. But we can improve it by using a DirectionUpdateRule, namely:

On the one hand, MomentumGradient requires both the manifold and the initial value in order to keep track of the current iterate and to parallel transport the last direction to that iterate. You can also set a vector_transport_method if ParallelTransport() is not available on your manifold. Here, we simply do

x_opt3 = stochastic_gradient_descent(
    M, gradf, x; direction=MomentumGradient(M, x; direction=StochasticGradient(M; X=zero_vector(M, x)))
)
3-element Vector{Float64}:
 0.7071067811865475
 0.0
 0.7071067811865475
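If ParallelTransport() were not available on your manifold, one could for instance pass another transport here; a sketch, assuming ProjectionTransport() from ManifoldsBase.jl is defined for the manifold at hand:

MomentumGradient(M, x; vector_transport_method=ProjectionTransport(), direction=StochasticGradient(M; X=zero_vector(M, x)))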
MG = MomentumGradient(M, x; direction=StochasticGradient(M; X=zero_vector(M, x)));
@benchmark stochastic_gradient_descent($M, $gradf, $x; direction=$MG)
BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  1.815 μs … 556.102 μs  ┊ GC (min … max):  0.00% … 98.68%
 Time  (median):     5.111 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   7.014 μs ±  30.657 μs  ┊ GC (mean ± σ):  29.35% ±  6.66%

                     ▁▆█▇▄▁
  ▃▇▄▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▄██████▆▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▂▂ ▃
  1.81 μs         Histogram: frequency by time        10.7 μs <

 Memory estimate: 40.97 KiB, allocs estimate: 24.

On the other hand, AverageGradient computes an average of the last n gradients; here we average over the last n=10:

x_opt4 = stochastic_gradient_descent(
    M, gradf, x; direction=AverageGradient(M, x; n=10, direction=StochasticGradient(M; X=zero_vector(M, x)))
);
AG = AverageGradient(M, x; n=10, direction=StochasticGradient(M; X=zero_vector(M, x)));
@benchmark stochastic_gradient_descent($M, $gradf, $x; direction=$AG)
BenchmarkTools.Trial: 10000 samples with 9 evaluations.
 Range (min … max):  1.843 μs … 811.903 μs  ┊ GC (min … max):  0.00% … 98.94%
 Time  (median):     5.139 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   7.863 μs ±  40.574 μs  ┊ GC (mean ± σ):  34.27% ±  6.68%

                     ▁▅█▇▅▁
  ▂▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▅██████▇▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂ ▃
  1.84 μs         Histogram: frequency by time        10.6 μs <

 Memory estimate: 40.97 KiB, allocs estimate: 24.

Note that the default StoppingCriterion is a fixed number of iterations which helps the comparison here.
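If you want to change that, you could pass a different bound via the stopping_criterion keyword; a sketch using Manopt.jl's StopAfterIteration criterion:

stochastic_gradient_descent(M, gradf, x; stopping_criterion=StopAfterIteration(100))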

For both update rules we have to specify internally, via the inner direction=StochasticGradient(...), that we are still in the stochastic setting (since both rules can also be used with the IdentityUpdateRule within gradient_descent).

For this not-that-large-scale example, we can of course also use gradient_descent with ArmijoLinesearch, but it will usually be a little slower:

fullGradF(M, x) = sum(grad_distance(M, p, x) for p in data)
fullGradF (generic function with 1 method)
x_opt5 = gradient_descent(M, F, fullGradF, x; stepsize=ArmijoLinesearch(M))
3-element Vector{Float64}:
  0.7071067811865475
 -8.428836557897668e-17
  0.7071067811865475
AL = ArmijoLinesearch(M);
@benchmark gradient_descent($M, $F, $fullGradF, $x; stepsize=$AL)
BenchmarkTools.Trial: 11 samples with 1 evaluation.
 Range (min … max):  473.913 ms … 497.582 ms  ┊ GC (min … max): 8.47% … 10.42%
 Time  (median):     480.760 ms               ┊ GC (median):    8.85%
 Time  (mean ± σ):   482.758 ms ±   7.100 ms  ┊ GC (mean ± σ):  9.13% ±  0.99%

  ▁           ▁█▁  █  ▁     ▁                          ▁      ▁
  █▁▁▁▁▁▁▁▁▁▁▁███▁▁█▁▁█▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁█ ▁
  474 ms           Histogram: frequency by time          498 ms <

 Memory estimate: 703.23 MiB, allocs estimate: 9023227.
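As noted above, the momentum rule can also wrap the (default) IdentityUpdateRule within gradient_descent; a sketch of such a call, reusing the full gradient from above, could look like

gradient_descent(M, F, fullGradF, x; direction=MomentumGradient(M, x))  # the inner direction then defaults to the IdentityUpdateRule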