Constrained Mean using Frank–Wolfe

This example illustrates the use of the Riemannian Frank-Wolfe algorithm by reimplementing the example from [WeberSra2022], Section 4.2: computing the weighted geometric mean of SPD matrices subject to constraints (the GM variant of Alg. 3).

using Manopt, Manifolds, LinearAlgebra, Plots, PlutoUI, Random

We first define some variables.

  • n is the dimension of our SPD matrices

  • N is the number of data matrices we will generate

begin
    Random.seed!(42)
    # dimension of the SPD matrices
    n = 50
    # number of data points
    N = 100
end;

We now generate the N = 100 data points on the manifold M of SPD matrices (defined below) and N random weights that sum up to 1.

begin
    data = [random_point(M) for _ in 1:N]
    weights = rand(N)
    weights ./= sum(weights)
end;
@doc raw"""
    harmonic_mean(pts,w=1/N.*ones(N))

For `N` SPD matrices `pts`, this computes the weighted harmonic mean, i.e.

````math
    \biggl( \sum_{i=1}^N w_i p_i^{-1} \biggr)^{-1}
````
"""
function harmonic_mean(pts, w=1 / length(pts) .* ones(length(pts)))
    return inv(sum([wi * inv(pi) for (wi, pi) in zip(w, pts)]))
end
@doc raw"""
    arithmetic_mean(pts,w=1/N.*ones(N))

For `N` SPD matrices `pts`, this computes the weighted arithmetic mean, i.e.

````math
    \sum_{i=1}^N w_i p_i
````
"""
function arithmetic_mean(pts, w=1 / length(pts) .* ones(length(pts)))
    return sum([wi * pi for (wi, pi) in zip(w, pts)])
end
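
As a quick sanity check, a hypothetical snippet that is not part of the original example, both helpers reduce to the scalar weighted means when applied to (made-up) 1×1 SPD matrices:

begin
    # for 1×1 SPD matrices the formulas reduce to scalar weighted means
    pts1 = [fill(2.0, 1, 1), fill(4.0, 1, 1)]
    w1 = [0.5, 0.5]
    harmonic_mean(pts1, w1)    # ≈ 2.6667 = 1 / (0.5/2 + 0.5/4)
    arithmetic_mean(pts1, w1)  # = 3.0 = 0.5⋅2 + 0.5⋅4
end
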
function weighted_mean_cost(M, p)
    return sum([wi * distance(M, di, p)^2 for (wi, di) in zip(weights, data)])
end
weighted_mean_cost (generic function with 1 method)
function grad_weighted_mean(M, p)
    q = SPDPoint(p)
    return sum([wi * grad_distance(M, di, q) for (wi, di) in zip(weights, data)])
end
grad_weighted_mean (generic function with 1 method)
function grad_weighted_mean!(M, X, p)
    q = SPDPoint(p)
    zero_vector!(M, X, p)
    Y = zero_vector(M, p)
    for (wi, di) in zip(weights, data)
        grad_distance!(M, Y, di, q)
        X .+= wi .* Y
    end
    return X
end
grad_weighted_mean! (generic function with 1 method)
Our data points live on the manifold of n×n symmetric positive definite matrices.

M = SymmetricPositiveDefinite(n)
SymmetricPositiveDefinite(50)
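
As a small usage sketch (the names `p0`, `c0`, and `X0` are introduced only here), we can evaluate the cost and its Riemannian gradient at one of the data points:

begin
    p0 = data[1]
    c0 = weighted_mean_cost(M, p0)  # cost at p0
    X0 = grad_weighted_mean(M, p0)  # Riemannian gradient, a tangent vector at p0
    norm(M, p0, X0)                 # Riemannian norm of the gradient
end
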
@doc raw"""
    FW_oracle!(M::SymmetricPositiveDefinite, q, L, U, p, X)

Given a lower bound `L` and an upper bound `U` (SPD matrices),
a point `p`, and a tangent vector `X` (e.g. the gradient of the cost at `p`),
this oracle solves the Frank-Wolfe subproblem of the constrained problem

````math
    \operatorname{arg\,min}_{L\preceq q \preceq U} ⟨ X, \log_p q⟩,
````
which has a closed-form solution, cf. (38) in [WeberSra2022], computed in place of `q`.
"""
function FW_oracle!(M::SymmetricPositiveDefinite, q, L, U, p, X)
    # compute p^{1/2} and p^{-1/2} to transform the problem into a common eigenbasis
    (p_sqrt, p_sqrt_inv) = Manifolds.spd_sqrt_and_sqrt_inv(p)

    # eigendecomposition of p^{1/2} X p^{1/2}; D indicates the negative eigenvalues
    e2 = eigen(p_sqrt * X * p_sqrt)
    D = Diagonal(1.0 .* (e2.values .< 0))
    Q = e2.vectors

    # transform the bounds into the eigenbasis and apply the closed form (38)
    Uprime = Q' * p_sqrt_inv * U * p_sqrt_inv * Q
    Lprime = Q' * p_sqrt_inv * L * p_sqrt_inv * Q
    P = cholesky(Hermitian(Uprime - Lprime))
    z = P.U' * D * P.U + Lprime
    # transform the solution back and store it in q
    copyto!(M, q, p_sqrt * Q * z * Q' * p_sqrt)
    return q
end
function FW_oracle!(M::SymmetricPositiveDefinite, q::SPDPoint, L, U, p, X)
    (p_sqrt, p_sqrt_inv) = Manifolds.spd_sqrt_and_sqrt_inv(p)

    e2 = eigen(p_sqrt * X * p_sqrt)
    D = Diagonal(1.0 .* (e2.values .< 0))
    Q = e2.vectors

    Uprime = Q' * p_sqrt_inv * U * p_sqrt_inv * Q
    Lprime = Q' * p_sqrt_inv * L * p_sqrt_inv * Q
    P = cholesky(Hermitian(Uprime - Lprime))
    z = P.U' * D * P.U + Lprime
    # transform the solution back and update all (non-missing) fields of the SPDPoint q
    q_new = p_sqrt * Q * z * Q' * p_sqrt
    !ismissing(q.p) && copyto!(q.p, q_new)
    e_new = eigen(q_new)
    copyto!(q.eigen.values, e_new.values)
    copyto!(q.eigen.vectors, e_new.vectors)
    # recompute the stored (inverse) matrix square roots from the eigendecomposition
    V = e_new.vectors
    !ismissing(q.sqrt) && copyto!(q.sqrt, V * Diagonal(sqrt.(e_new.values)) * V')
    !ismissing(q.sqrt_inv) && copyto!(q.sqrt_inv, V * Diagonal(1 ./ sqrt.(e_new.values)) * V')
    return q
end
FW_oracle! (generic function with 2 methods)
In the Loewner partial order, the weighted harmonic mean is a lower bound and the weighted arithmetic mean an upper bound of the weighted geometric mean, so we use them as the constraints:

H = harmonic_mean(data, weights);
A = arithmetic_mean(data, weights);
special_oracle!(M, q, p, X) = FW_oracle!(M, q, H, A, p, X)
special_oracle! (generic function with 1 method)
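
Before running the algorithm, a hypothetical sanity check that is not part of the original example: the oracle output has to satisfy the constraint H ⪯ q ⪯ A, i.e. both q - H and A - q have to be positive semidefinite (up to numerical noise):

begin
    q_chk = copy(M, data[1])
    special_oracle!(M, q_chk, data[1], grad_weighted_mean(M, data[1]))
    # both checks should return true
    all(eigvals(Hermitian(q_chk - H)) .≥ -1e-12) &&
        all(eigvals(Hermitian(A - q_chk)) .≥ -1e-12)
end
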
As a baseline, we compute the weighted mean using `mean` from Manifolds.jl and time it:

statsM = @timed qT = mean(M, data, weights);
cT = weighted_mean_cost(M, qT)
1.8916283036527741
Now we run the Riemannian Frank-Wolfe algorithm starting from `data[1]`, recording iteration, iterate, and cost, with debug output every 50 iterations:

PlutoUI.with_terminal() do
    global statsF = @timed global oF = Frank_Wolfe_method(
        M,
        weighted_mean_cost,
        grad_weighted_mean!,
        data[1];
        subtask=special_oracle!,
        debug=[
            :Iteration,
            :Cost,
            (:Change, " | Change: %1.5e | "),
            DebugGradientNorm(; format=" | grad F |: %1.5e |"),
            "\n",
            :Stop,
            50,
        ],
        record=[:Iteration, :Iterate, :Cost],
        evaluation=MutatingEvaluation(),
        return_options=true,
    )
end
Initial F(x): 3.819152
# 50    F(x): 1.891638 | Change: 3.88088e-01 |  | grad F |: 3.32328e-03 |
# 100   F(x): 1.891629 | Change: 3.14309e-03 |  | grad F |: 1.56326e-03 |
# 150   F(x): 1.891629 | Change: 1.69170e-03 |  | grad F |: 1.01927e-03 |
# 200   F(x): 1.891628 | Change: 1.21721e-03 |  | grad F |: 7.88593e-04 |
The algorithm reached its maximal number of iterations (200).
qF = get_solver_result(oF);
cF = weighted_mean_cost(M, qF)
1.8916280467199975
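
As an additional comparison, one can also measure how far the two minimizers obtained so far are apart in the Riemannian metric:

distance(M, qT, qF)
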
Already after 20 iterations, Frank-Wolfe comes close to the final cost, which we also time:

q1 = copy(M, data[1]);
statsF20 = @timed Frank_Wolfe_method!(
    M,
    weighted_mean_cost,
    grad_weighted_mean!,
    q1;
    subtask=special_oracle!,
    evaluation=MutatingEvaluation(),
    stopping_criterion=StopAfterIteration(20),
);
c1 = weighted_mean_cost(M, q1)
1.8916721422547849
For comparison, we run gradient descent, here using ArmijoLinesearch:

PlutoUI.with_terminal() do
    global oG = gradient_descent(
        M,
        weighted_mean_cost,
        grad_weighted_mean!,
        data[1];
        record=[:Iteration, :Iterate, :Cost],
        debug=[
            :Iteration,
            :Cost,
            (:Change, " | Change: %1.5e | "),
            DebugGradientNorm(; format=" | grad F |: %1.5e |"),
            "\n",
            :Stop,
            1,
        ],
        evaluation=MutatingEvaluation(),
        stopping_criterion=StopAfterIteration(200) | StopWhenGradientNormLess(1e-12),
        return_options=true,
    )
end
Initial F(x): 3.819152
# 1     F(x): 1.891695 | Change: 1.39241e+00 |  | grad F |: 1.44040e+00 |
# 2     F(x): 1.891627 | Change: 8.27573e-03 |  | grad F |: 8.27536e-03 |
# 3     F(x): 1.891627 | Change: 5.16064e-05 |  | grad F |: 5.16065e-05 |
# 4     F(x): 1.891627 | Change: 3.33490e-07 |  | grad F |: 3.33490e-07 |
# 5     F(x): 1.891627 | Change: 2.22400e-09 |  | grad F |: 2.22401e-09 |
# 6     F(x): 1.891627 | Change: 1.52379e-11 |  | grad F |: 1.52297e-11 |
# 7     F(x): 1.891627 | Change: 4.25243e-13 |  | grad F |: 4.37838e-13 |
The algorithm reached approximately critical point after 7 iterations; the gradient norm (4.378376265415871e-13) is less than 1.0e-12.
We time gradient descent as well, in a run without recording:

q2 = copy(M, data[1]);
statsG = @timed gradient_descent!(
    M,
    weighted_mean_cost,
    grad_weighted_mean!,
    q2;
    evaluation=MutatingEvaluation(),
    stopping_criterion=StopAfterIteration(200) | StopWhenGradientNormLess(1e-12),
);
cG = weighted_mean_cost(M, q2)
1.8916274280033853

We get the following results for cost and runtime:

| Method                | Cost               | Computational time |
|:--------------------- |:------------------ |:------------------ |
| `mean` (Manifolds.jl) | 1.8916283036527741 | 0.330450405 sec.   |
| `Frank_Wolfe_method`  | 1.8916280467199975 | 36.496048782 sec.  |
| `gradient_descent`    | 1.8916274280033853 | 1.103814741 sec.   |
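
The times in this table stem from the `@timed` results collected above; they can be gathered as follows (the exact numbers depend on the machine):

(statsM.time, statsF.time, statsG.time)
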

And since we recorded iterates and cost in the first runs, we can also plot how the cost evolves over the iterations. Note that gradient descent already finishes after 7 iterations, while we restrict the Frank-Wolfe curve to its first 8 iterations.

begin
    fig = plot(
        [0, get_record(oF, :Iteration, :Iteration)[1:8]...],
        [weighted_mean_cost(M, data[1]), get_record(oF, :Iteration, :Cost)[1:8]...];
        label="Frank Wolfe",
    )
    plot!(
        fig,
        [0, get_record(oG, :Iteration, :Iteration)...],
        [weighted_mean_cost(M, data[1]), get_record(oG, :Iteration, :Cost)...];
        label="Gradient Descent",
    )
end

A remaining challenge is to find a good stopping criterion for Frank-Wolfe on manifolds: after about iteration 10, the cost as well as the iterates change only on the order of 1e-4, which is still above the default thresholds most other algorithms use.
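
One pragmatic alternative, sketched here with an (arbitrarily chosen) tolerance of 1e-7, is to stop as soon as the change of the iterates becomes small, using `StopWhenChangeLess`:

begin
    q3 = copy(M, data[1])
    Frank_Wolfe_method!(
        M,
        weighted_mean_cost,
        grad_weighted_mean!,
        q3;
        subtask=special_oracle!,
        evaluation=MutatingEvaluation(),
        stopping_criterion=StopAfterIteration(500) | StopWhenChangeLess(1e-7),
    )
end
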

It also seems that gradient descent with ArmijoLinesearch reaches the final cost faster than Frank-Wolfe; this setup was not part of the experiments in [WeberSra2022].

Literature

WeberSra2022

M. Weber, S. Sra: Riemannian Optimization via Frank-Wolfe Methods, Mathematical Programming, 2022. doi: 10.1007/s10107-022-01840-5