Constrained Mean using Frank–Wolfe

This example illustrates the use of the Riemannian Frank–Wolfe algorithm by reimplementing the example given in [WeberSra2022], Section 4.2, namely the constrained geometric mean (the GM variant of Alg. 3).

using Manopt, Manifolds, LinearAlgebra, Plots, PlutoUI, Random

We first define some variables.

  • n is the dimension of our SPD matrices

  • N is the number of data matrices we will generate

begin
    # seed the global RNG so that the random data generated below is reproducible
    Random.seed!(42)
    # n: size of the n×n SPD matrices, N: number of data points
    n, N = 50, 100
end;

We now generate N = 100 random data points on the manifold `M` of SPD matrices (defined further below) and N random weights, normalized so that they sum up to 1.

begin
    # N random points on M (M is defined further below)
    data = [random_point(M) for _ in 1:N]
    # N random weights in [0, 1), rescaled so that they sum to one
    weights = let raw = rand(N)
        raw ./ sum(raw)
    end
end;
@doc raw"""
    harmonic_mean(pts,w=1/N.*ones(N))

for `N` points `pts` on the SPD matrices, this computes the weighted harmonic mean, i.e.

````math
    \biggl( \sum_{i=1}^N w_i p_i^{-1} \biggr)^{-1}
````

The weights `w` default to uniform weights ``w_i = \frac{1}{N}``.
"""
function harmonic_mean(pts, w=1 / length(pts) .* ones(length(pts)))
    # A generator avoids materializing all N inverted matrices at once.
    # Using `p` (not `pi`) as the loop variable avoids shadowing Base's constant π.
    return inv(sum(wi * inv(p) for (wi, p) in zip(w, pts)))
end
@doc raw"""
    arithmetic_mean(pts,w=1/N.*ones(N))

for `N` points `pts` on the SPD matrices, this computes the weighted arithmetic mean, i.e.

````math
    \sum_{i=1}^N w_i p_i
````

The weights `w` default to uniform weights ``w_i = \frac{1}{N}``.
"""
function arithmetic_mean(pts, w=1 / length(pts) .* ones(length(pts)))
    # A generator avoids materializing all N scaled matrices at once.
    # Using `p` (not `pi`) as the loop variable avoids shadowing Base's constant π.
    return sum(wi * p for (wi, p) in zip(w, pts))
end
@doc raw"""
    weighted_mean_cost(M, p)

Return the weighted sum of squared distances on `M` from `p` to the global `data` points,
``\sum_i w_i \operatorname{distance}(M, d_i, p)^2``, using the global `weights` ``w_i``.
"""
function weighted_mean_cost(M, p)
    # generator instead of an array comprehension: no temporary vector per cost evaluation
    return sum(wi * distance(M, di, p)^2 for (wi, di) in zip(weights, data))
end
weighted_mean_cost (generic function with 1 method)
@doc raw"""
    grad_weighted_mean(M, p)

Return ``\sum_i w_i X_i``, where ``X_i`` is `grad_distance(M, d_i, p)` for the global
`data` points ``d_i`` and the global `weights` ``w_i``.
"""
function grad_weighted_mean(M, p)
    # NOTE(review): if `grad_distance` returns the gradient of ½d², this is the gradient of
    # ½⋅weighted_mean_cost; confirm the intended scaling.
    # wrap p once as an SPDPoint so that the same wrapper is passed to every grad_distance call
    point = SPDPoint(p)
    summands = [w * grad_distance(M, d, point) for (w, d) in zip(weights, data)]
    return sum(summands)
end
grad_weighted_mean (generic function with 1 method)
@doc raw"""
    grad_weighted_mean!(M, X, p)

In-place variant of `grad_weighted_mean`: overwrite `X` with ``\sum_i w_i X_i``, where
``X_i`` is computed by `grad_distance!` for the global `data` points ``d_i`` and the
global `weights` ``w_i``, and return `X`.
"""
function grad_weighted_mean!(M, X, p)
    point = SPDPoint(p)
    zero_vector!(M, X, p)
    # scratch tangent vector, reused for every data point
    term = zero_vector(M, p)
    for (w, d) in zip(weights, data)
        grad_distance!(M, term, d, point)
        @. X += w * term
    end
    return X
end
grad_weighted_mean! (generic function with 1 method)
# the manifold of n×n symmetric positive definite matrices, with n set in the first cell
M = SymmetricPositiveDefinite(n)
SymmetricPositiveDefinite(50)
@doc raw"""
    FW_oracle!(M::SymmetricPositiveDefinite, q, L, U, p, X)

Given a lower bound `L` and an upper bound `U` (spd matrices),
a point `p` and a tangent vector `X` (e.g. the gradient at `p`),
this oracle solves the subproblem related to the constraint problem

````math
    \operatorname{arg\,min}_{L\preceq q \preceq U} ⟨ X, \log_p q⟩
````
which has a closed form solution, cf. (38) in [^WeberSra2022], computed in place of `q`.

If `q` is an `SPDPoint`, every stored field that is not `missing` (the matrix itself, its
eigendecomposition, its square root and its inverse square root) is updated to the new point.
"""
function FW_oracle!(M::SymmetricPositiveDefinite, q, L, U, p, X)
    copyto!(M, q, _FW_oracle_solution(L, U, p, X))
    return q
end
function FW_oracle!(M::SymmetricPositiveDefinite, q::SPDPoint, L, U, p, X)
    S = _FW_oracle_solution(L, U, p, X)
    !ismissing(q.p) && copyto!(q.p, S)
    # An `Eigen` object is not an array, so broadcast assignment into `q.eigen` cannot work.
    # Overwrite its eigenvalue and eigenvector arrays instead.
    e = eigen(Symmetric(S))
    copyto!(q.eigen.values, e.values)
    copyto!(q.eigen.vectors, e.vectors)
    if !ismissing(q.sqrt) || !ismissing(q.sqrt_inv)
        (S_sqrt, S_sqrt_inv) = Manifolds.spd_sqrt_and_sqrt_inv(S)
        !ismissing(q.sqrt) && copyto!(q.sqrt, S_sqrt)
        !ismissing(q.sqrt_inv) && copyto!(q.sqrt_inv, S_sqrt_inv)
    end
    return q
end

# Closed-form solution of the oracle subproblem above, returned as a plain matrix;
# shared by both `FW_oracle!` methods.
function _FW_oracle_solution(L, U, p, X)
    (p_sqrt, p_sqrt_inv) = Manifolds.spd_sqrt_and_sqrt_inv(p)
    # Symmetric(...) guarantees real eigenvalues and orthonormal eigenvectors, so that
    # Q' really is the inverse of Q below, even when rounding makes the product asymmetric
    e2 = eigen(Symmetric(p_sqrt * X * p_sqrt))
    # 0/1 diagonal that selects the negative eigenvalues
    D = Diagonal(1.0 .* (e2.values .< 0))
    Q = e2.vectors
    Uprime = Q' * p_sqrt_inv * U * p_sqrt_inv * Q
    Lprime = Q' * p_sqrt_inv * L * p_sqrt_inv * Q
    # Uprime - Lprime = P.U' * P.U; Hermitian(...) discards rounding asymmetry first
    P = cholesky(Hermitian(Uprime - Lprime))
    z = P.U' * D * P.U + Lprime
    S = p_sqrt * Q * z * Q' * p_sqrt
    # symmetrize so that the returned point is exactly symmetric despite rounding
    return (S + S') / 2
end
FW_oracle! (generic function with 2 methods)
# weighted harmonic mean of the data: used as lower bound L in special_oracle!
H = harmonic_mean(data, weights);
# weighted arithmetic mean of the data: used as upper bound U in special_oracle!
A = arithmetic_mean(data, weights);
@doc raw"""
    special_oracle!(M, q, p, X)

Frank–Wolfe subtask for this example. It calls `FW_oracle!` in place of `q`, with the
weighted harmonic mean `H` as lower bound and the weighted arithmetic mean `A` as upper bound.
"""
function special_oracle!(M, q, p, X)
    lower, upper = H, A
    return FW_oracle!(M, q, lower, upper, p, X)
end
special_oracle! (generic function with 1 method)
# reference solution: the weighted mean from Manifolds.jl, with its timing stored in statsM
statsM = @timed qT = mean(M, data, weights);
# its cost, for comparison with the Frank–Wolfe and gradient descent results below
cT = weighted_mean_cost(M, qT)
1.9416538882004648
# Frank–Wolfe run with debug output and recording; its timing is stored in statsF
# and the returned solver options in oF (used for the results and the plot below)
PlutoUI.with_terminal() do
    global statsF = @timed global oF = Frank_Wolfe_method(
        M,
        weighted_mean_cost,
        grad_weighted_mean!,
        data[1];  # start at the first data point
        subtask=special_oracle!,  # closed-form oracle with bounds H and A
        debug=[
            :Iteration,
            :Cost,
            (:Change, " | Change: %1.5e | "),
            DebugGradientNorm(; format=" | grad F |: %1.5e |"),
            "\n",
            :Stop,
            50,  # print only every 50th iteration
        ],
        record=[:Iteration, :Iterate, :Cost],  # read back via get_record for the plot
        evaluation=MutatingEvaluation(),  # the gradient is evaluated in place
        return_options=true,  # return the options, not only the final point
    )
end
Initial F(x): 4.023987
# 50    F(x): 1.941664 | Change: 4.04687e-01 |  | grad F |: 3.20442e-03 |
# 100   F(x): 1.941655 | Change: 3.25453e-03 |  | grad F |: 1.55214e-03 |
# 150   F(x): 1.941654 | Change: 1.78362e-03 |  | grad F |: 1.08175e-03 |
# 200   F(x): 1.941654 | Change: 1.22176e-03 |  | grad F |: 7.98141e-04 |
# 250   F(x): 1.941653 | Change: 9.57239e-04 |  | grad F |: 6.37809e-04 |
# 300   F(x): 1.941653 | Change: 7.82434e-04 |  | grad F |: 5.36181e-04 |
# 350   F(x): 1.941653 | Change: 6.62996e-04 |  | grad F |: 4.44599e-04 |
# 400   F(x): 1.941653 | Change: 5.49384e-04 |  | grad F |: 4.01004e-04 |
# 450   F(x): 1.941653 | Change: 5.06441e-04 |  | grad F |: 3.57819e-04 |
# 500   F(x): 1.941653 | Change: 4.56883e-04 |  | grad F |: 3.15573e-04 |
The algorithm reached its maximal number of iterations (500).
# final iterate of the recorded Frank–Wolfe run and its cost
qF = get_solver_result(oF);
cF = weighted_mean_cost(M, qF)
1.9416530862221668
# copy of the start point: the mutating Frank_Wolfe_method! leaves its result in q1,
# so data[1] itself stays untouched
q1 = copy(M, data[1]);
# timed in-place Frank–Wolfe run, stopped after 20 iterations (no debug, no record)
statsF20 = @timed Frank_Wolfe_method!(
    M,
    weighted_mean_cost,
    grad_weighted_mean!,
    q1;
    subtask=special_oracle!,
    evaluation=MutatingEvaluation(),
    stopping_criterion=StopAfterIteration(20),
);
# cost reached after these 20 iterations
c1 = weighted_mean_cost(M, q1)
1.9417001445586257
# gradient descent from the same start point, with debug output and recording
# (not timed); the returned options oG are used for the plot below
PlutoUI.with_terminal() do
    global oG = gradient_descent(
        M,
        weighted_mean_cost,
        grad_weighted_mean!,
        data[1];
        record=[:Iteration, :Iterate, :Cost],
        debug=[
            :Iteration,
            :Cost,
            (:Change, " | Change: %1.5e | "),
            DebugGradientNorm(; format=" | grad F |: %1.5e |"),
            "\n",
            :Stop,
            1,  # print every iteration
        ],
        evaluation=MutatingEvaluation(),
        # stop after 200 iterations or once the gradient norm drops below 1e-12
        stopping_criterion=StopAfterIteration(200) | StopWhenGradientNormLess(1e-12),
        return_options=true,
    )
end
Initial F(x): 4.023987
# 1     F(x): 1.941731 | Change: 1.44730e+00 |  | grad F |: 1.36082e+00 |
# 2     F(x): 1.941653 | Change: 8.85765e-03 |  | grad F |: 8.85831e-03 |
# 3     F(x): 1.941653 | Change: 5.79311e-05 |  | grad F |: 5.79311e-05 |
# 4     F(x): 1.941653 | Change: 3.92724e-07 |  | grad F |: 3.92724e-07 |
# 5     F(x): 1.941653 | Change: 2.73998e-09 |  | grad F |: 2.73997e-09 |
# 6     F(x): 1.941653 | Change: 1.96115e-11 |  | grad F |: 1.95947e-11 |
# 7     F(x): 1.941653 | Change: 3.76508e-13 |  | grad F |: 3.60086e-13 |
The algorithm reached approximately critical point after 7 iterations; the gradient norm (3.6008637846307747e-13) is less than 1.0e-12.
# copy of the start point for the in-place gradient descent run, which leaves its result in q2
q2 = copy(M, data[1]);
# timed in-place gradient descent run (no debug, no record) with the same stopping criterion
statsG = @timed gradient_descent!(
    M,
    weighted_mean_cost,
    grad_weighted_mean!,
    q2;
    evaluation=MutatingEvaluation(),
    stopping_criterion=StopAfterIteration(200) | StopWhenGradientNormLess(1e-12),
);
# cost of the gradient descent result
cG = weighted_mean_cost(M, q2)
1.9416529805815455

We get the following results for the cost and the computational time:

| Method | Cost | Computational time |
| :--- | :--- | :--- |
| `mean` (Manifolds.jl) | 1.9416538882004648 | 0.25751222 sec. |
| Frank–Wolfe | 1.9416530862221668 | 53.206128953 sec. |
| `gradient_descent` | 1.9416529805815455 | 0.887520014 sec. |

Since we recorded the iterations and costs in the first run of each solver, we can also plot how the cost evolves over the iterations. Gradient descent already finishes after 7 iterations, so for Frank–Wolfe we only show the first 8 recorded iterations.

begin
    # Cost per iteration, starting with the cost at the common start point data[1]
    # (iteration 0). Frank–Wolfe is limited to its first 8 recorded iterations to match
    # the short gradient descent run; `first` also works if fewer were recorded.
    fig = plot(
        vcat(0, first(get_record(oF, :Iteration, :Iteration), 8)),
        vcat(weighted_mean_cost(M, data[1]), first(get_record(oF, :Iteration, :Cost), 8));
        label="Frank Wolfe",
        xlabel="Iteration",
        ylabel="Cost",
    )
    plot!(
        fig,
        vcat(0, get_record(oG, :Iteration, :Iteration)),
        vcat(weighted_mean_cost(M, data[1]), get_record(oG, :Iteration, :Cost));
        label="Gradient Descent",
    )
end

One challenge seems to be finding a good stopping criterion for Frank–Wolfe on manifolds. After iteration 10, the cost only changes on the order of 1e-4, and so do the iterates, which is above the threshold currently used for most other algorithms.

It also seems that gradient descent with `ArmijoLinesearch` reaches the final cost considerably faster than Frank–Wolfe. This setting was not used as an example in the paper [WeberSra2022].

Literature

WeberSra2022

M. Weber, S. Sra: Riemannian Optimization via Frank-Wolfe Methods, Math. Prog., 2022, to appear. doi: 10.1007/s10107-022-01840-5.