How to run stochastic gradient descent

Ronny Bergmann

This tutorial illustrates how to use the stochastic_gradient_descent solver and different DirectionUpdateRules to introduce the average or momentum variant; see Stochastic Gradient Descent.

Computationally, we look at a very simple but large-scale problem, the Riemannian Center of Mass or Fréchet mean: for given points $p_i \in \mathcal M$, $i=1,…,N$, this optimization problem reads

\[\operatorname*{arg\,min}_{x∈\mathcal M} \frac{1}{2}\sum_{i=1}^{N} \operatorname{d}^2_{\mathcal M}(x,p_i),\]

which of course can be (and is) solved by a gradient descent, see the introductory tutorial or Statistics in Manifolds.jl. If $N$ is very large, evaluating the complete gradient might be quite expensive. A remedy is to evaluate only one of the terms at a time and choose a random order for these.
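The idea can be sketched by hand before turning to the solver: in every sweep, only one summand's gradient is evaluated at a time, in a random order. The following is a minimal, self-contained sketch; the manifold, points, and constant step size η here are illustrative assumptions, not the tutorial's setup.

```julia
using Manifolds, ManifoldDiff, Random

# one random-order sweep of stochastic gradient steps with constant step size η
function sgd_sweep(M, pts, x, η)
    for i in shuffle(1:length(pts))                   # random order of the summands
        X = ManifoldDiff.grad_distance(M, pts[i], x)  # gradient of x ↦ d²(x, pts[i])/2, i.e. -log_x(pts[i])
        x = exp(M, x, -η * X)                         # step along one negative summand gradient
    end
    return x
end

M = Sphere(2)
pts = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
x = sgd_sweep(M, pts, [1.0, 0.0, 0.0], 0.5)
```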

We first initialize the packages

using Manifolds, Manopt, Random, BenchmarkTools, ManifoldDiff
using ManifoldDiff: grad_distance
Random.seed!(42);

We next generate a (little) large(r) data set

n = 5000
σ = π / 12
M = Sphere(2)
p = 1 / sqrt(2) * [1.0, 0.0, 1.0]
data = [exp(M, p, σ * rand(M; vector_at=p)) for i in 1:n];

Note that due to the construction of the points from zero-mean random tangent vectors, the mean should be very close to our initial point p.
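As a quick sanity check (a sketch, assuming the `mean` method Manifolds.jl provides for the `Statistics.mean` function), one can measure how far the Riemannian mean of the data is from p:

```julia
using Statistics: mean
distance(M, p, mean(M, data))  # should be close to zero
```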

In order to use the stochastic gradient, we now need a function that returns the vector of gradients. There are two ways to define it in Manopt.jl: either as a single function that returns a vector, or as a vector of functions.

The first variant is of course easier to define, but the second is more efficient when only evaluating one of the gradients.

For the mean, the gradient is

\[\operatorname{grad}f(p) = \sum_{i=1}^N \operatorname{grad}f_i(p), \quad \text{where} \quad \operatorname{grad}f_i(p) = -\log_p p_i,\]

which we define in Manopt.jl in two different ways: either as one function returning all gradients as a vector (see gradF), or, maybe more fitting for a large scale problem, as a vector of small gradient functions (see gradf)

F(M, p) = 1 / (2 * n) * sum(map(q -> distance(M, p, q)^2, data))
gradF(M, p) = [grad_distance(M, q, p) for q in data]
gradf = [(M, p) -> grad_distance(M, q, p) for q in data];
p0 = 1 / sqrt(3) * [1.0, 1.0, 1.0]
3-element Vector{Float64}:
 0.5773502691896258
 0.5773502691896258
 0.5773502691896258

The calls are only slightly different, but notice that accessing the second gradient element requires evaluating all logs in the first function, while we only call one of the functions in the second array of functions. So while you can use both gradF and gradf in the following call, the second one is (much) faster:
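To see the difference in access cost directly, compare evaluating a single gradient term in both variants (a small illustration using the definitions above):

```julia
gradF(M, p0)[2]   # evaluates all n gradient terms, then indexes the second
gradf[2](M, p0)   # evaluates only the second gradient function
```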

p_opt1 = stochastic_gradient_descent(M, gradF, p0)
3-element Vector{Float64}:
  0.6940527079187876
 -0.37439006629268595
 -0.614917000002404
@benchmark stochastic_gradient_descent($M, $gradF, $p0)
BenchmarkTools.Trial: 1 sample with 1 evaluation per sample.
 Single result which took 5.998 s (11.85% GC) to evaluate,
 with a memory estimate of 7.84 GiB, over 200470602 allocations.
p_opt2 = stochastic_gradient_descent(M, gradf, p0)
3-element Vector{Float64}:
 0.6828818855405706
 0.1754529371758115
 0.7091463863243864
@benchmark stochastic_gradient_descent($M, $gradf, $p0)
BenchmarkTools.Trial: 2211 samples with 1 evaluation per sample.
 Range (min … max):  934.774 μs … 10.931 ms  ┊ GC (min … max): 0.00% … 66.78%
 Time  (median):       1.856 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):     2.262 ms ±  1.471 ms  ┊ GC (mean ± σ):  8.90% ± 12.65%

  935 μs          Histogram: frequency by time         9.59 ms <

 Memory estimate: 955.47 KiB, allocs estimate: 21665.

This result is reasonably close. But we can improve it by using a DirectionUpdateRule, namely:

On the one hand there is MomentumGradient, which keeps track of the current iterate and parallel transports the last direction to it. The necessary vector_transport_method keyword is set to a suitable default on every manifold, see default_vector_transport_method. We get

p_opt3 = stochastic_gradient_descent(
    M, gradf, p0; direction=MomentumGradient(; direction=StochasticGradient())
)
3-element Vector{Float64}:
  0.8018714772674244
 -0.1618929268674662
  0.5751459068577677
MG = MomentumGradient(; direction=StochasticGradient());
@benchmark stochastic_gradient_descent($M, $gradf, $p0; direction=$MG)
BenchmarkTools.Trial: 858 samples with 1 evaluation per sample.
 Range (min … max):  4.840 ms … 13.422 ms  ┊ GC (min … max):  0.00% … 55.23%
 Time  (median):     5.035 ms              ┊ GC (median):     0.00%
 Time  (mean ± σ):   5.836 ms ±  2.214 ms  ┊ GC (mean ± σ):  13.20% ± 18.08%

  4.84 ms      Histogram: log(frequency) by time     12.7 ms <

 Memory estimate: 8.56 MiB, allocs estimate: 221660.

And on the other hand the AverageGradient computes an average of the last n gradients, here with n=10. This is done by

p_opt4 = stochastic_gradient_descent(
    M, gradf, p0; direction=AverageGradient(; n=10, direction=StochasticGradient()), debug=[],
)
3-element Vector{Float64}:
 -0.9984748902117264
 -0.03991462423400982
 -0.03814074447279103
AG = AverageGradient(; n=10, direction=StochasticGradient());
@benchmark stochastic_gradient_descent($M, $gradf, $p0; direction=$AG, debug=[])
BenchmarkTools.Trial: 3 samples with 1 evaluation per sample.
 Range (min … max):  1.676 s …   1.696 s  ┊ GC (min … max): 0.70% … 0.89%
 Time  (median):     1.680 s              ┊ GC (median):    0.90%
 Time  (mean ± σ):   1.684 s ± 10.420 ms  ┊ GC (mean ± σ):  0.84% ± 0.12%

  1.68 s         Histogram: frequency by time         1.7 s <

 Memory estimate: 154.59 MiB, allocs estimate: 2691686.

Note that the default StoppingCriterion is a fixed number of iterations, which helps the comparison here.

For both update rules we have to explicitly specify the inner direction=StochasticGradient(), since both rules can also be used with the IdentityUpdateRule within gradient_descent.
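As a small illustration of that remark (a sketch, not part of the original tutorial): the same MomentumGradient rule, wrapping its default inner update rule, can be passed to gradient_descent; fullgrad here is a hypothetical helper that sums all gradient terms from gradF above.

```julia
# Sketch: MomentumGradient around the default update rule inside the
# (deterministic) gradient_descent; `fullgrad` is a hypothetical helper.
fullgrad(M, q) = 1 / n * sum(gradF(M, q))
p_opt_m = gradient_descent(M, F, fullgrad, p0; direction=MomentumGradient())
```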

For this not-that-large-scale example we can of course also use a gradient descent with ArmijoLinesearch,

fullGradF(M, p) = 1 / n * sum(grad_distance(M, q, p) for q in data)
p_opt5 = gradient_descent(M, F, fullGradF, p0; stepsize=ArmijoLinesearch())
3-element Vector{Float64}:
  0.7050420976839262
 -0.006374163322665686
  0.7091368066426853

but in general, since every iteration evaluates the full gradient over all N points, it is expected to be slower.

AL = ArmijoLinesearch();
@benchmark gradient_descent($M, $F, $fullGradF, $p0; stepsize=$AL)
BenchmarkTools.Trial: 48 samples with 1 evaluation per sample.
 Range (min … max):   98.627 ms … 108.740 ms  ┊ GC (min … max):  6.18% … 13.10%
 Time  (median):     106.552 ms               ┊ GC (median):    12.70%
 Time  (mean ± σ):   105.520 ms ±   2.787 ms  ┊ GC (mean ± σ):  11.72% ±  2.39%

  98.6 ms          Histogram: frequency by time          109 ms <

 Memory estimate: 138.29 MiB, allocs estimate: 3395753.

Technical details

This tutorial is cached. It was last run on the following package versions.

Status `~/work/Manopt.jl/Manopt.jl/tutorials/Project.toml`
  [47edcb42] ADTypes v1.22.0
  [6e4b80f9] BenchmarkTools v1.8.0
  [5ae59095] Colors v0.13.1
  [31c24e10] Distributions v0.25.125
  [26cc04aa] FiniteDifferences v0.12.33
  [8ac3fa9e] LRUCache v1.6.2
  [af67fdf4] ManifoldDiff v0.4.5
  [1cead3c2] Manifolds v0.11.23
  [3362f125] ManifoldsBase v2.3.5
  [0fc0a36d] Manopt v0.5.36 `.`
  [91a5bcdd] Plots v1.41.6
  [731186ca] RecursiveArrayTools v4.3.0
  [37e2e46d] LinearAlgebra v1.12.0
  [9a3f8284] Random v1.11.0

This tutorial was last rendered May 4, 2026, 07:22:52.