9 changes: 6 additions & 3 deletions RELEASES.md
@@ -2,15 +2,18 @@

## 0.9.7.dev0

This new release adds support for sparse cost matrices in the exact EMD solver. Users can now pass sparse cost matrices (e.g., k-NN graphs, sparse graphs) and receive sparse transport plans, significantly reducing the memory footprint for large-scale problems. The implementation is backend-agnostic, automatically handling scipy.sparse for NumPy and torch.sparse for PyTorch, and preserves full gradient computation capabilities for automatic differentiation in PyTorch. This enables efficient solving of OT problems on graphs with millions of nodes where only a sparse subset of edges has finite costs.
This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation.
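A minimal usage sketch of the lazy solver described above (editor's illustration, not part of the release notes); the exact signature of `ot.emd2_lazy`, including the argument order and the `metric` keyword, is assumed here rather than taken from the documentation.

```python
import numpy as np
import ot

rng = np.random.default_rng(0)
xs = rng.standard_normal((5000, 2))   # source coordinates
xt = rng.standard_normal((6000, 2))   # target coordinates
a = ot.unif(5000)                     # uniform source weights
b = ot.unif(6000)                     # uniform target weights

# Costs are computed on the fly from the coordinates, so no 5000 x 6000
# cost matrix is ever materialized (assumed signature and metric keyword).
cost = ot.emd2_lazy(a, b, xs, xt, metric="sqeuclidean")
```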

#### New features
- Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788)
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
- Geomloss function now handles both scalar and slice indices for i and j, using backend-agnostic reshaping; this allows `plan[i, :]` and `plan[:, j]` (PR #785)
- Geomloss function now handles both scalar and slice indices for i and j (PR #785)
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)

#### Closed issues
- Fix NumPy 2.x compatibility in Brenier potential bounds (PR #788)
- Fix MSVC Windows build by removing __restrict__ keyword (PR #788)
- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration (PR #785)
- Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
- Add test for build from source (PR #772, Issue #764)
- Fix device for batch OT solver in `ot.batch` (PR #784, Issue #783)
2 changes: 2 additions & 0 deletions ot/__init__.py
@@ -41,6 +41,7 @@
from .lp import (
    emd,
    emd2,
    emd2_lazy,
    emd_1d,
    emd2_1d,
    wasserstein_1d,
@@ -82,6 +83,7 @@
__all__ = [
"emd",
"emd2",
"emd2_lazy",
"emd_1d",
"sinkhorn",
"sinkhorn2",
25 changes: 16 additions & 9 deletions ot/gromov/_estimators.py
@@ -122,8 +122,8 @@ def GW_distance_estimation(

    for i in range(nb_samples_p):
        if nx.issparse(T):
            T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,))
            T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,))
            T_indexi = nx.reshape(nx.todense(T[[index_i[i]], :]), (-1,))
            T_indexj = nx.reshape(nx.todense(T[[index_j[i]], :]), (-1,))
        else:
            T_indexi = T[index_i[i], :]
            T_indexj = T[index_j[i], :]
@@ -243,16 +243,18 @@ def pointwise_gromov_wasserstein(
    index = np.zeros(2, dtype=int)

    # Initialize with default marginal
    index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
    index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q))
    index[0] = int(generator.choice(len_p, size=1, p=nx.to_numpy(p)).item())
    index[1] = int(generator.choice(len_q, size=1, p=nx.to_numpy(q)).item())
    T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))

    best_gw_dist_estimated = np.inf
    for cpt in range(max_iter):
        index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
        T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
        index[1] = generator.choice(
            len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
        index[0] = int(generator.choice(len_p, size=1, p=nx.to_numpy(p)).item())
        T_index0 = nx.reshape(nx.todense(T[[index[0]], :]), (-1,))
        index[1] = int(
            generator.choice(
                len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
            ).item()
        )

        if alpha == 1:
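The `int(... .item())` wrapping in the hunk above appears to guard against NumPy's deprecation of implicitly converting size-1 arrays to scalars; the sketch below illustrates just that pattern and is not code from the PR.

```python
import numpy as np

# Illustration only: recent NumPy versions deprecate implicit conversion of a
# size-1 array to a scalar, so the sampled index is extracted explicitly.
rng = np.random.default_rng(0)
p = np.full(10, 0.1)                           # uniform sampling weights

idx = int(rng.choice(10, size=1, p=p).item())  # plain Python int, no warning
```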
@@ -404,10 +406,15 @@ def sampled_gromov_wasserstein(
        )
        Lik = 0
        for i, index0_i in enumerate(index0):
            T_row = (
                nx.reshape(nx.todense(T[[index0_i], :]), (-1,))
                if nx.issparse(T)
                else T[index0_i, :]
            )
            index1 = generator.choice(
                len_q,
                size=nb_samples_grad_q,
                p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])),
                p=nx.to_numpy(T_row / nx.sum(T_row)),
                replace=False,
            )
            # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly.
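A minimal sketch of the row-extraction pattern used in these hunks: indexing a sparse plan with `[[i], :]` keeps the result two-dimensional before densifying, so one code path serves dense and sparse backends alike. Here `nx` stands for a POT backend object and the helper name is illustrative only.

```python
def plan_row(nx, T, i):
    # Return row i of a transport plan as a dense 1-D array, whether T is
    # stored sparse or dense; [[i], :] keeps sparse indexing two-dimensional.
    if nx.issparse(T):
        return nx.reshape(nx.todense(T[[i], :]), (-1,))
    return T[i, :]
```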
16 changes: 16 additions & 0 deletions ot/lp/EMD.h
@@ -51,5 +51,21 @@ int EMD_wrap_sparse(
    uint64_t maxIter   // Maximum iterations for solver
);

int EMD_wrap_lazy(
    int n1,            // Number of source points
    int n2,            // Number of target points
    double *X,         // Source weights (n1)
    double *Y,         // Target weights (n2)
    double *coords_a,  // Source coordinates (n1 x dim)
    double *coords_b,  // Target coordinates (n2 x dim)
    int dim,           // Dimension of coordinates
    int metric,        // Distance metric: 0=sqeuclidean, 1=euclidean, 2=cityblock
    double *G,         // Output: transport plan (n1 x n2)
    double *alpha,     // Output: dual variables for sources (n1)
    double *beta,      // Output: dual variables for targets (n2)
    double *cost,      // Output: total transportation cost
    uint64_t maxIter   // Maximum iterations for solver
);


#endif
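For intuition, the Python sketch below (not part of the PR) mirrors the per-pair cost a lazy solver would evaluate on demand, following the metric codes documented in the declaration above; the function name and signature are illustrative.

```python
import numpy as np

def lazy_pair_cost(coords_a, coords_b, i, j, metric=0):
    # Cost for a single (i, j) pair, computed from coordinates on demand:
    # 0 = squared Euclidean, 1 = Euclidean, 2 = cityblock (L1).
    d = coords_a[i] - coords_b[j]
    if metric == 0:
        return float(np.dot(d, d))
    if metric == 1:
        return float(np.sqrt(np.dot(d, d)))
    return float(np.abs(d).sum())
```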
106 changes: 105 additions & 1 deletion ot/lp/EMD_wrapper.cpp
@@ -370,4 +370,108 @@ int EMD_wrap_sparse(
        }
    }
    return ret;
}

int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b,
                  int dim, int metric, double *G, double *alpha, double *beta,
                  double *cost, uint64_t maxIter) {
    using namespace lemon;
    typedef FullBipartiteDigraph Digraph;
    DIGRAPH_TYPEDEFS(Digraph);

    // Filter source nodes with non-zero weights
    std::vector<int> idx_a;
    std::vector<double> weights_a_filtered;
    std::vector<double> coords_a_filtered;

    // Reserve space to avoid reallocations
    idx_a.reserve(n1);
    weights_a_filtered.reserve(n1);
    coords_a_filtered.reserve(n1 * dim);

    for (int i = 0; i < n1; i++) {
        if (X[i] > 0) {
            idx_a.push_back(i);
            weights_a_filtered.push_back(X[i]);
            for (int d = 0; d < dim; d++) {
                coords_a_filtered.push_back(coords_a[i * dim + d]);
            }
        }
    }
    int n = idx_a.size();

    // Filter target nodes with non-zero weights
    std::vector<int> idx_b;
    std::vector<double> weights_b_filtered;
    std::vector<double> coords_b_filtered;

    // Reserve space to avoid reallocations
    idx_b.reserve(n2);
    weights_b_filtered.reserve(n2);
    coords_b_filtered.reserve(n2 * dim);

    for (int j = 0; j < n2; j++) {
        if (Y[j] > 0) {
            idx_b.push_back(j);
            weights_b_filtered.push_back(-Y[j]); // Demand is negative supply
            for (int d = 0; d < dim; d++) {
                coords_b_filtered.push_back(coords_b[j * dim + d]);
            }
        }
    }
    int m = idx_b.size();

    if (n == 0 || m == 0) {
        *cost = 0.0;
        return 0;
    }

    // Create full bipartite graph
    Digraph di(n, m);

    NetworkSimplexSimple<Digraph, double, double, node_id_type> net(
        di, true, (int)(n + m), (uint64_t)(n) * (uint64_t)(m), maxIter
    );

    // Set supplies
    net.supplyMap(&weights_a_filtered[0], n, &weights_b_filtered[0], m);

    // Enable lazy cost computation - costs will be computed on-the-fly
    net.setLazyCost(&coords_a_filtered[0], &coords_b_filtered[0], dim, metric, n, m);

    // Run solver
    int ret = net.run();

    if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) {
        *cost = 0;

        // Initialize output arrays
        for (int i = 0; i < n1 * n2; i++) G[i] = 0.0;
        for (int i = 0; i < n1; i++) alpha[i] = 0.0;
        for (int i = 0; i < n2; i++) beta[i] = 0.0;

        // Extract solution
        Arc a;
        di.first(a);
        for (; a != INVALID; di.next(a)) {
            int i = di.source(a);
            int j = di.target(a) - n;

            int orig_i = idx_a[i];
            int orig_j = idx_b[j];

            double flow = net.flow(a);
            G[orig_i * n2 + orig_j] = flow;

            alpha[orig_i] = -net.potential(i);
            beta[orig_j] = net.potential(j + n);

            if (flow > 0) {
                double c = net.computeLazyCost(i, j);
                *cost += flow * c;
            }
        }
    }

    return ret;
}
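As a rough outline (editor's Python sketch, not the actual binding), the wrapper's bookkeeping amounts to: drop zero-weight points, solve the reduced problem with lazily computed costs, then scatter the plan and dual potentials back to the original indexing. `solve_reduced` is a hypothetical stand-in for the network-simplex call.

```python
import numpy as np

def lazy_emd_outline(a, b, coords_a, coords_b, solve_reduced):
    # Keep only points with positive mass, as EMD_wrap_lazy does.
    keep_a = np.flatnonzero(a > 0)
    keep_b = np.flatnonzero(b > 0)
    if keep_a.size == 0 or keep_b.size == 0:
        return np.zeros((a.size, b.size)), np.zeros(a.size), np.zeros(b.size), 0.0

    # Hypothetical reduced solve; costs would be computed lazily from coordinates.
    G_r, alpha_r, beta_r, cost = solve_reduced(
        a[keep_a], b[keep_b], coords_a[keep_a], coords_b[keep_b]
    )

    # Scatter results back to the original (unfiltered) indexing.
    G = np.zeros((a.size, b.size))
    G[np.ix_(keep_a, keep_b)] = G_r
    alpha = np.zeros(a.size)
    beta = np.zeros(b.size)
    alpha[keep_a] = alpha_r
    beta[keep_b] = beta_r
    return G, alpha, beta, cost
```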
3 changes: 2 additions & 1 deletion ot/lp/__init__.py
@@ -9,7 +9,7 @@
# License: MIT License

from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize
from ._network_simplex import emd, emd2
from ._network_simplex import emd, emd2, emd2_lazy
from ._barycenter_solvers import (
    barycenter,
    free_support_barycenter,
@@ -35,6 +35,7 @@
__all__ = [
"emd",
"emd2",
"emd2_lazy",
"barycenter",
"free_support_barycenter",
"cvx",