Low-dimensional representations: projection-based reduction, MRL, and sparse representations

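When working on embeddings, it is easy to focus on the model itself first: swap in a larger backbone, add more data, push MTEB or internal eval numbers up a little. Vector dimensionality climbs along with it. 384- and 768-dimensional embeddings are not obsolete, but 2048- and 4096-dimensional text and multimodal embeddings are already common.

This is not a strict leaderboard timeline. MTEB has since split into English, multilingual, and versioned variants, and the rankings keep changing. But looking at a set of representative strong models, output dimensionality has moved steadily upward: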

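| Year | Representative model | Output dim | Notes |
| --- | --- | --- | --- |
| 2022 | sentence-transformers/all-mpnet-base-v2 | 768 | An early, widely used sentence-transformers baseline |
| 2023 | BAAI/bge-large-en-v1.5 | 1024 | The BGE series ranked near the top of MTEB / C-MTEB at release |
| 2024 | intfloat/e5-mistral-7b-instruct | 4096 | LLM backbones start to become the main line on embedding leaderboards |
| 2024 | nvidia/NV-Embed-v2 | 4096 | Model card reports No. 1 on MTEB (56 tasks) as of 2024-08-30 |
| 2025 | Qwen/Qwen3-Embedding-8B | 4096 | Supports variable output dimensions from 32 to 4096; the default upper limit is still 4096 |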

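In offline experiments this rarely stands out. A bit more GPU memory, a bit more disk, a smaller batch size, and the experiment keeps running. The pressure on indexing and serving is more direct. One 4096-dimensional float32 vector is 16 KB. A billion of them is roughly 16 TB, before counting index structures, replicas, caches, and metadata. When a query arrives, the system still has to find nearest neighbors among these vectors within a very short time.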
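The storage arithmetic above can be checked in a few lines. This is only a back-of-the-envelope sketch; the function name `raw_vector_bytes` is made up for illustration, and it counts the raw vector payload only.

```python
def raw_vector_bytes(num_vectors: int, dim: int, bytes_per_value: int = 4) -> int:
    """Raw payload only; index structures, replicas, caches, and metadata are extra."""
    return num_vectors * dim * bytes_per_value


one_vector = raw_vector_bytes(1, 4096)                  # float32 uses 4 bytes per value
one_billion = raw_vector_bytes(1_000_000_000, 4096)

print(f"one vector:  {one_vector / 1024:.0f} KiB")
print(f"1e9 vectors: {one_billion / 1e12:.1f} TB")
```

Halving the dimension or switching to a 2-byte dtype halves both numbers, which is why dimension and dtype are usually the first two knobs to look at.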

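The question is not only whether the vectors can be made smaller. More concretely: in a retrieval system that already has quality requirements, which representations reduce storage, bandwidth, and compute? And how much does recall change in exchange for those savings?

This post compares only three common routes: projection-based dimensionality reduction, Matryoshka Representation Learning (MRL), and Contrastive Sparse Representation (CSR). The first two make the dense vector shorter. CSR's representation space can be large, but each sample uses only a few positions in it.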

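Representation forms

Start with the representation the system receives at inference time, then discuss training details.

Projection is the most direct form. The original dense embedding is $x \in \mathbb{R}^d$; a linear layer or MLP maps it to $z \in \mathbb{R}^m$ with $m \ll d$. The retrieval system does not need to care about the mapping. It still sees a dense vector, just with fewer dimensions.

MRL also keeps the dense vector interface. The difference is that it does not truncate after the fact; during training it makes the first 32, first 64, and first 128 dimensions each work on their own. At inference time, take a short prefix when the budget is tight and a longer one when it is not. This design is compatible with existing dense KNN / ANN systems.

CSR takes a different route. It first maps the original embedding into a larger latent space, say $h$ dimensions, then keeps only the TopK nonzero values. Storage holds indices and values instead of the full dense vector. With k = 16, each item stores 16 positions and 16 values. This representation only pays off when the retrieval system also uses sparse retrieval; otherwise the sparsity exists only in the complexity formula.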
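To make the three inference-time forms concrete, here is a minimal NumPy sketch on one toy vector. Everything in it is illustrative: the sizes, the random (untrained) weight matrices `W_proj` and `W_enc`, and the ReLU before TopK are assumptions of the sketch, not details taken from the MRL or CSR papers.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, h, k = 16, 4, 64, 3          # toy sizes, chosen only for illustration

x = rng.standard_normal(d)         # stand-in for a dense embedding

# 1) Projection: a linear map from d to m dims (random here, learned in practice).
W_proj = rng.standard_normal((d, m)) / np.sqrt(m)
z_proj = x @ W_proj                # short dense vector, shape (m,)

# 2) MRL-style prefix: keep the first m dims. This is only meaningful when the
#    model was trained so that prefixes work on their own.
z_prefix = x[:m]

# 3) CSR-style TopK: map into a larger latent space, keep the k largest activations.
W_enc = rng.standard_normal((d, h)) / np.sqrt(d)
latent = np.maximum(x @ W_enc, 0.0)             # ReLU is an assumption of this sketch
sparse_indices = np.sort(np.argsort(-latent)[:k])
sparse_values = latent[sparse_indices]          # stored as (indices, values) pairs

print(z_proj.shape, z_prefix.shape, sparse_indices, sparse_values)
```

The first two results are plain arrays that any dense index can consume. The third is a pair of short arrays that only helps if the index knows how to use them.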

[Figure: four representations of the same features. Full Rep.: dense (0.2, 0.3, 0.4, 0.1). Projection: short dense, a learned short vector (0.7, 0.3). MRL Rep.: dense prefix, use the first m dims and drop the rest. CSR Rep.: sparse / ultra-sparse (0, 0.3, 0, 0.7), store indices + values.]
How the representations differ. Projection re-compresses the vector into a short dense one; MRL uses a prefix; CSR/CSRv2 still represent in a high-dimensional space but keep only a few nonzero positions per sample.

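MRL and CSR mean different things by active count. MRL with m=64 means "use only the first 64 dense dimensions." CSR with k=64 means "activate only 64 positions in a possibly very large latent space." The two configurations may cost about the same, but they organize information in completely different ways.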
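Whether the two configurations really cost about the same depends on how indices and values are stored. The sketch below compares bytes per item under a few dtype choices; the helper names and the dtype combinations are assumptions for illustration, not a storage format prescribed by either method.

```python
def dense_bytes(m: int, value_bytes: int = 4) -> int:
    """Bytes per item for a dense vector of m values."""
    return m * value_bytes


def sparse_bytes(k: int, index_bytes: int = 4, value_bytes: int = 4) -> int:
    """Bytes per item for k (index, value) pairs."""
    return k * (index_bytes + value_bytes)


mrl_64 = dense_bytes(64)                                        # float32 prefix
csr_64_wide = sparse_bytes(64)                                  # 4-byte index + float32 value
csr_64_narrow = sparse_bytes(64, index_bytes=2, value_bytes=2)  # 2-byte index + float16 value

print(mrl_64, csr_64_wide, csr_64_narrow)
```

With 4-byte indices, k=64 sparse storage is twice the size of an m=64 float32 prefix; it only matches once both indices and values are narrowed to 2 bytes, which in turn requires the latent dimension to fit in a 16-bit index.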

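Training objectives

Projection has no fixed recipe. You can run the original retrieval loss, such as InfoNCE, directly on the low-dimensional vectors. You can distill, so that the low-dimensional vectors imitate the similarity distribution of a high-dimensional teacher. You can add a reconstruction loss so that the low-dimensional vectors retain as much of the original embedding's information as possible. However it is trained, one fact remains: all the information has to fit into m dimensions.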
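The distillation option can be sketched as a KL divergence between teacher and student similarity distributions. This is one common way to write it, not a recipe from any specific paper; the function names, the temperature value, and the random projection used as a stand-in student are all assumptions.

```python
import numpy as np


def log_softmax(a: np.ndarray) -> np.ndarray:
    a = a - a.max(axis=1, keepdims=True)
    return a - np.log(np.exp(a).sum(axis=1, keepdims=True))


def l2_normalize(a: np.ndarray) -> np.ndarray:
    return a / np.linalg.norm(a, axis=1, keepdims=True)


def similarity_distillation_loss(teacher_q, teacher_c, student_q, student_c,
                                 temperature: float = 0.05) -> float:
    """Mean KL(teacher || student) over each query's similarity distribution."""
    t_logp = log_softmax(l2_normalize(teacher_q) @ l2_normalize(teacher_c).T / temperature)
    s_logp = log_softmax(l2_normalize(student_q) @ l2_normalize(student_c).T / temperature)
    return float((np.exp(t_logp) * (t_logp - s_logp)).sum(axis=1).mean())


rng = np.random.default_rng(0)
teacher_q = rng.standard_normal((8, 256))      # high-dimensional teacher queries
teacher_c = rng.standard_normal((32, 256))     # high-dimensional teacher candidates
W = rng.standard_normal((256, 32)) / np.sqrt(32)   # untrained stand-in for the student head

loss = similarity_distillation_loss(teacher_q, teacher_c, teacher_q @ W, teacher_c @ W)
print(loss)
```

The loss is zero when the student reproduces the teacher's similarity distribution exactly, so training the projection means pushing this value down while the output stays at m dimensions.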

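MRL's constraint is more explicit. The original Matryoshka Representation Learning applies the loss at several truncation lengths at once. In effect it tells the model: do not waste the leading dimensions, because short vectors must be usable too.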

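Written for a retrieval task, MRL can be expressed as:

$$
\mathcal{L}_{\mathrm{MRL}} = \sum_{m \in M} \mathcal{L}_{\mathrm{ret}}\left(z_{1:m}\right)
$$

Here $M$ is a set of truncation lengths, such as {32, 64, 128, 256}, $z_{1:m}$ is the first $m$ dimensions of the embedding, and $\mathcal{L}_{\mathrm{ret}}$ is the retrieval loss. During training, the retrieval loss is computed once per length. At inference time, pick one length. The benefit is simple deployment: it is still dense vector search. The cost is just as direct: the model has to learn to put information up front. The shorter the prefix, the less fits into it.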
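A forward-only sketch of this objective, using in-batch InfoNCE as the retrieval loss and uniform weights across lengths. Both choices are assumptions of the sketch, and there is no gradient step here; it only shows how the per-length losses are assembled.

```python
import numpy as np


def info_nce(q: np.ndarray, c: np.ndarray, temperature: float = 0.05) -> float:
    """In-batch InfoNCE: row i of q is paired with row i of c."""
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    c = c / np.linalg.norm(c, axis=1, keepdims=True)
    logits = q @ c.T / temperature
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_prob)))


def mrl_retrieval_loss(q: np.ndarray, c: np.ndarray, lengths=(32, 64, 128, 256)) -> float:
    """Sum of the retrieval loss over every prefix length m in M."""
    return sum(info_nce(q[:, :m], c[:, :m]) for m in lengths)


rng = np.random.default_rng(0)
c = rng.standard_normal((16, 256))             # stand-in document embeddings
q = c + 0.1 * rng.standard_normal(c.shape)     # stand-in queries: noisy copies

per_length = {m: info_nce(q[:, :m], c[:, :m]) for m in (32, 64, 128, 256)}
total = mrl_retrieval_loss(q, c)
print(per_length, total)
```

Each prefix is re-normalized before scoring, which matches how a truncated vector would be used at serving time.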

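CSR training looks more like a sparse autoencoder plus a task constraint. Beyond Matryoshka: Revisiting Sparse Coding for Adaptive Representation starts from a pretrained dense embedding, then trains a sparse module that maps it to a TopK latent. The form can be written as:

$$
z = \mathrm{TopK}\left(f_{\mathrm{enc}}(x)\right), \qquad \hat{x} = f_{\mathrm{dec}}(z), \qquad \mathcal{L}_{\mathrm{CSR}} = \lVert x - \hat{x} \rVert_2^2 + \lambda\, \mathcal{L}_{\mathrm{cl}}(z)
$$

where $\lambda$ weights the two terms. The reconstruction loss is responsible for keeping the information in the original embedding. The contrastive loss is responsible for keeping the sparse latent usable for retrieval or classification. What sets CSR apart is that it does not require every sample to fit into the same set of low-dimensional coordinates. The model can have a large latent dictionary, and each sample picks only a few positions.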

[Figure: CSR training view. Main path: dense x (base embedding) to encoder (dense to latent) to TopK sparse z (indices + values) to decoder to x_hat. Top: reconstruction loss compares x_hat with dense x. Bottom: contrastive / retrieval loss, sparse z should separate positives from negatives.]
CSR training can be split into two constraints: the decoder has to recover the original embedding from the sparse latent, and the task loss has to keep the sparse latent usable for retrieval.
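The two constraints can be sketched as a forward pass through a TopK autoencoder. This is a minimal illustration with random, untrained weights; the ReLU before TopK, the noisy-copy positives, the weight `lam`, and all helper names are assumptions of the sketch and not the paper's exact setup.

```python
import numpy as np


def topk_rows(a: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest entries of each row and zero out the rest."""
    idx = np.argpartition(-a, k - 1, axis=1)[:, :k]
    out = np.zeros_like(a)
    np.put_along_axis(out, idx, np.take_along_axis(a, idx, axis=1), axis=1)
    return out


def info_nce(a: np.ndarray, b: np.ndarray, temperature: float = 0.05) -> float:
    """In-batch InfoNCE: row i of a is the positive for row i of b."""
    a = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-12)
    b = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-12)
    logits = a @ b.T / temperature
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_prob)))


rng = np.random.default_rng(0)
n, d, h, k, lam = 8, 32, 128, 4, 1.0            # toy sizes; lam weights the two losses

x = rng.standard_normal((n, d))                 # stand-in for pretrained dense embeddings
x_pos = x + 0.1 * rng.standard_normal((n, d))   # stand-in positives (noisy copies)
W_enc = rng.standard_normal((d, h)) / np.sqrt(d)
W_dec = rng.standard_normal((h, d)) / np.sqrt(h)


def encode(v: np.ndarray) -> np.ndarray:
    """Dense to latent, then TopK. The ReLU is an assumption of this sketch."""
    return topk_rows(np.maximum(v @ W_enc, 0.0), k)


z, z_pos = encode(x), encode(x_pos)
x_hat = z @ W_dec                               # decoder output

recon_loss = float(np.mean(np.sum((x - x_hat) ** 2, axis=1)))
contrastive_loss = info_nce(z, z_pos)
total_loss = recon_loss + lam * contrastive_loss
print(recon_loss, contrastive_loss, total_loss)
```

Only `z` would be stored and searched at serving time; the decoder exists to supply the reconstruction signal during training.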

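CSRv2 fits into this line of work. It does not introduce a new representation; it fixes training problems in CSR. In the ultra-sparse regime in particular, such as k=2 or k=4, the original CSR is prone to dead neurons: many latent dimensions go unselected for long stretches, so the model appears to have a large latent space while its usable capacity is much smaller.

CSRv2 mainly changes training. It uses k-annealing: start training with a larger k, then gradually lower it to the target sparsity, so the model is not constrained by a very small activation count from the start. It also introduces a supervised contrastive signal so that the few active features serve the downstream task more directly. For cross-domain tasks, the paper also discusses full finetuning. The more cautious conclusion is this: CSRv2 does not prove that "sparse is always better." It shows that part of the failure in the ultra-sparse regime comes from training collapse and cannot be attributed to sparse representations alone.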
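The idea of k-annealing can be sketched as a schedule that maps a training step to the k used at that step. The text above only says that k starts large and is lowered gradually; the linear shape, the function name `annealed_k`, and its parameters are hypothetical and should not be read as the schedule used in the paper.

```python
def annealed_k(step: int, total_steps: int, k_start: int, k_target: int) -> int:
    """Linearly interpolate from k_start down to k_target, then hold at k_target."""
    if total_steps <= 0 or step >= total_steps:
        return k_target
    frac = step / total_steps
    return max(k_target, round(k_start + (k_target - k_start) * frac))


schedule = [annealed_k(s, total_steps=1000, k_start=64, k_target=2) for s in range(0, 1001, 250)]
print(schedule)
```

In a training loop, the TopK operator would read its k from this function at every step, so early updates still touch many latent dimensions before the budget tightens.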

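Evaluation

Complexity analysis has to be paired with measurement. A low-cost representation has to answer at least two questions: how much retrieval quality it keeps, and how much cost it actually saves.

These two questions should be looked at separately. Quality can be measured with Recall, nDCG, and MRR on retrieval tasks, or first with the overlap against the full dense top-k. Performance depends on storage, index build time, query latency, QPS, batch size, hardware, and kernels. Reporting only the dimension or k is incomplete.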
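For the quality side, the metrics are small enough to write out. The sketch below covers recall@k, the reciprocal rank that MRR averages over queries, and overlap with the full dense top-k; the function names and the toy rankings are made up for illustration.

```python
def recall_at_k(ranked: list[int], relevant: set[int], k: int) -> float:
    """Share of the relevant ids that appear in the first k results."""
    if not relevant:
        return 0.0
    return len(set(ranked[:k]) & relevant) / len(relevant)


def reciprocal_rank(ranked: list[int], relevant: set[int]) -> float:
    """1 / rank of the first relevant result, or 0.0 if none is retrieved."""
    for rank, item in enumerate(ranked, start=1):
        if item in relevant:
            return 1.0 / rank
    return 0.0


def overlap_at_k(ranked: list[int], dense_ranked: list[int], k: int) -> float:
    """Share of the full dense top-k that the cheaper method also returns."""
    return len(set(ranked[:k]) & set(dense_ranked[:k])) / k


cheap = [7, 3, 9, 1, 5]     # ranking from a compressed representation
dense = [3, 7, 2, 5, 1]     # ranking from full dense search
relevant = {3, 5}           # labeled relevant ids for this query

print(recall_at_k(cheap, relevant, k=3))
print(reciprocal_rank(cheap, relevant))
print(overlap_at_k(cheap, dense, k=3))
```

Overlap needs no labels, which makes it a convenient first check, but it measures agreement with the dense system and not relevance itself.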

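The basic form of dense retrieval is:

$$
S = Q C^\top, \qquad Q \in \mathbb{R}^{B \times d}, \quad C \in \mathbb{R}^{N \times d}
$$

followed by a top-k over each row of $S$. If the queries are B x d and the corpus is N x d, the main cost of one exact dense search is B x N x d multiply-adds plus the corresponding memory reads. MRL and projection keep this form, with d replaced by a smaller m.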
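A small NumPy sketch of exact dense search, together with the multiply-add count discussed above. The helper names are illustrative, and the count covers the score matrix only, not the top-k selection or memory traffic.

```python
import numpy as np


def exact_dense_topk(Q: np.ndarray, C: np.ndarray, k: int) -> np.ndarray:
    """Return the ids of the k highest inner products per query, best first."""
    scores = Q @ C.T                                        # (B, N) score matrix
    top = np.argpartition(-scores, k - 1, axis=1)[:, :k]    # unordered top-k ids
    order = np.argsort(-np.take_along_axis(scores, top, axis=1), axis=1)
    return np.take_along_axis(top, order, axis=1)


def dense_macs(B: int, N: int, d: int) -> int:
    """Multiply-adds needed to fill the B x N score matrix."""
    return B * N * d


rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 64))
C = rng.standard_normal((1000, 64))
top3 = exact_dense_topk(Q, C, k=3)

full = dense_macs(512, 1_000_000, 2048)
short = dense_macs(512, 1_000_000, 64)
print(top3.shape, full // short)
```

Going from d = 2048 to m = 64 cuts the multiply-adds by 32x on paper. The top-k over N scores and the memory reads do not shrink by the same factor, which is one reason measured speedups tend to be smaller than the ratio suggests.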

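The ideal path for CSR is different. Both query and corpus are TopK sparse. Similarity only needs to be accumulated over the dimensions that both activate:

$$
s(q, c) = \sum_{j \in \mathcal{I}(q) \cap \mathcal{I}(c)} q_j\, c_j
$$

where $\mathcal{I}(\cdot)$ is the set of active indices. If the sparse index is done well, the cost tracks the number of active features and the posting lists visited, not the full latent dimension. But this is also where the engineering cost sits. Position indices are not free, and neither are scatter / index_add. When k is very small, kernel launches and non-contiguous memory access can also eat part of the theoretical gain.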
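The shared-index accumulation and the posting-list view can be sketched in plain Python. This is a toy CPU illustration with dictionaries, not a performant index; the type alias and function names are made up for the example.

```python
from collections import defaultdict

SparseVec = dict[int, float]        # latent index -> value, active positions only


def sparse_dot(q: SparseVec, c: SparseVec) -> float:
    """Accumulate products only over indices active in both vectors."""
    if len(q) > len(c):
        q, c = c, q                 # iterate over the shorter vector
    return sum(v * c[j] for j, v in q.items() if j in c)


def build_postings(corpus: list[SparseVec]) -> dict[int, list[tuple[int, float]]]:
    """Inverted index: latent index -> [(item id, value), ...]."""
    postings: dict[int, list[tuple[int, float]]] = defaultdict(list)
    for item_id, vec in enumerate(corpus):
        for j, v in vec.items():
            postings[j].append((item_id, v))
    return postings


def score_query(query: SparseVec, postings: dict[int, list[tuple[int, float]]]) -> dict[int, float]:
    """Visit only the posting lists of the query's active indices."""
    scores: dict[int, float] = defaultdict(float)
    for j, qv in query.items():
        for item_id, cv in postings.get(j, ()):
            scores[item_id] += qv * cv
    return dict(scores)


corpus = [{1: 0.6, 7: 0.8}, {2: 1.0}, {7: 0.6, 9: 0.8}]
query = {7: 1.0, 9: 0.5}

scores = score_query(query, build_postings(corpus))
print(scores)
```

Item 1 shares no active index with the query, so it is never touched. The work done is the total length of the visited posting lists, which is why a few very popular latent dimensions can dominate the cost.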

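Below is a reproducible benchmark script. It does not replace a formal evaluation; it only puts quality and performance numbers for the same set of vectors side by side, so that quality metrics and performance metrics do not come from different setups: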

python scripts/benchmark_low_dim_sparse_retrieval.py \
  --device cuda \
  --num-items 1000000 \
  --num-queries 512 \
  --dim 2048 \
  --reduced-dims 32 64 128 256 \
  --sparse-topks 4 8 16 32 64 \
  --top-k 10 \
  --query-batch-size 64 \
  --warmup 10 \
  --repeats 50 \
  --output-json sparse_retrieval_benchmark.json

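The script can be used in two ways. Without real embeddings, it generates a synthetic paired query/corpus set, which is suitable for checking the mechanics and performance path of dense, projection, prefix, and sparse retrieval. Those results are not a model-quality conclusion. To measure real quality, pass in real embeddings or CSR sparse latents: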

python scripts/benchmark_low_dim_sparse_retrieval.py \
  --device cuda \
  --corpus-file corpus_embeddings.pt \
  --query-file query_embeddings.pt \
  --target-file target_ids.pt \
  --sparse-corpus-file csr_corpus_sparse.pt \
  --sparse-query-file csr_query_sparse.pt
Code: low-dimensional dense / sparse retrieval benchmark
from __future__ import annotations

import argparse
import json
import math
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any

import torch
import torch.nn.functional as F


@dataclass
class MethodResult:
    method: str
    dim: int
    active_dim: int
    recall_at_k_vs_dense: float | None
    paired_recall_at_k: float | None
    index_build_ms: float | None
    latency_ms: float
    qps: float
    storage_mb: float
    note: str
    latency_speedup_vs_dense: float | None = None
    qps_vs_dense: float | None = None
    storage_vs_dense: float | None = None


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=(
            "Benchmark dense, low-dimensional dense, prefix/MRL-proxy, and "
            "TopK sparse retrieval. Use real tensors when available; otherwise "
            "the script creates synthetic paired query/corpus embeddings."
        )
    )
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--num-items", type=int, default=100_000)
    parser.add_argument("--num-queries", type=int, default=512)
    parser.add_argument("--dim", type=int, default=2048)
    parser.add_argument("--reduced-dims", type=int, nargs="+", default=[32, 64, 128, 256])
    parser.add_argument("--sparse-topks", type=int, nargs="+", default=[4, 8, 16, 32, 64])
    parser.add_argument("--top-k", type=int, default=10)
    parser.add_argument("--query-batch-size", type=int, default=64)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--repeats", type=int, default=50)
    parser.add_argument("--noise", type=float, default=0.05)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dtype", choices=["float32", "float16", "bfloat16"], default="float32")
    parser.add_argument("--corpus-file", type=Path, default=None)
    parser.add_argument("--query-file", type=Path, default=None)
    parser.add_argument("--target-file", type=Path, default=None)
    parser.add_argument("--sparse-corpus-file", type=Path, default=None)
    parser.add_argument("--sparse-query-file", type=Path, default=None)
    parser.add_argument("--output-json", type=Path, default=None)
    return parser.parse_args()


def dtype_from_name(name: str) -> torch.dtype:
    return {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[name]


def load_tensor(path: Path) -> torch.Tensor:
    value = torch.load(path, map_location="cpu")
    if isinstance(value, dict):
        for key in ("embeddings", "tensor", "data"):
            if key in value:
                value = value[key]
                break
    if not isinstance(value, torch.Tensor):
        raise TypeError(f"{path} must contain a tensor or a dict with tensor-like embeddings")
    return value


def load_sparse(path: Path) -> dict[str, torch.Tensor | int]:
    value = torch.load(path, map_location="cpu")
    if not isinstance(value, dict):
        raise TypeError(f"{path} must contain a dict with indices, values, and dim")
    required = {"indices", "values", "dim"}
    missing = required - set(value)
    if missing:
        raise KeyError(f"{path} is missing sparse keys: {sorted(missing)}")
    return {
        "indices": value["indices"].long(),
        "values": value["values"],
        "dim": int(value["dim"]),
    }


def make_synthetic(
    num_items: int,
    num_queries: int,
    dim: int,
    noise: float,
    seed: int,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    generator = torch.Generator(device="cpu").manual_seed(seed)
    corpus = torch.randn(num_items, dim, generator=generator, dtype=torch.float32)
    corpus = F.normalize(corpus, dim=1).to(dtype)
    target_ids = torch.randint(num_items, (num_queries,), generator=generator)
    query = corpus[target_ids].float()
    query = query + noise * torch.randn(query.shape, generator=generator)
    query = F.normalize(query, dim=1).to(dtype)
    return corpus, query, target_ids


def maybe_sync(device: torch.device) -> None:
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def tensor_storage_mb(tensor: torch.Tensor) -> float:
    return tensor.numel() * tensor.element_size() / 1_000_000


def sparse_storage_mb(sparse: dict[str, torch.Tensor | int]) -> float:
    indices = sparse["indices"]
    values = sparse["values"]
    assert isinstance(indices, torch.Tensor)
    assert isinstance(values, torch.Tensor)
    return (indices.numel() * indices.element_size() + values.numel() * values.element_size()) / 1_000_000


@torch.no_grad()
def dense_topk(
    query: torch.Tensor,
    corpus: torch.Tensor,
    k: int,
    query_batch_size: int,
) -> torch.Tensor:
    all_indices = []
    # The transposed copy is rebuilt on every call, so its cost is part of the timed latency.
    corpus_t = corpus.t().contiguous()
    for start in range(0, query.shape[0], query_batch_size):
        q = query[start : start + query_batch_size]
        scores = q @ corpus_t
        all_indices.append(torch.topk(scores, k=min(k, corpus.shape[0]), dim=1).indices)
    return torch.cat(all_indices, dim=0)


def time_call(fn: Any, warmup: int, repeats: int, device: torch.device) -> tuple[Any, float]:
    result = None
    for _ in range(warmup):
        result = fn()
    maybe_sync(device)
    start = time.perf_counter()
    for _ in range(repeats):
        result = fn()
    maybe_sync(device)
    elapsed_ms = (time.perf_counter() - start) * 1000.0 / repeats
    return result, elapsed_ms


def recall_vs_reference(found: torch.Tensor, reference: torch.Tensor) -> float:
    hits = 0
    for row_found, row_ref in zip(found.cpu(), reference.cpu(), strict=True):
        hits += len(set(row_found.tolist()) & set(row_ref.tolist()))
    return hits / max(found.shape[0] * reference.shape[1], 1)


def paired_recall(found: torch.Tensor, target_ids: torch.Tensor | None) -> float | None:
    if target_ids is None:
        return None
    target_ids = target_ids.cpu()
    hits = 0
    for row, target in zip(found.cpu(), target_ids, strict=True):
        hits += int(int(target) in set(row.tolist()))
    return hits / max(found.shape[0], 1)


def projected(
    corpus: torch.Tensor,
    query: torch.Tensor,
    out_dim: int,
    seed: int,
    dtype: torch.dtype,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    generator = torch.Generator(device="cpu").manual_seed(seed + out_dim)
    scale = 1.0 / math.sqrt(out_dim)
    projection = torch.randn(corpus.shape[1], out_dim, generator=generator) * scale
    projection = projection.to(device=device, dtype=dtype)
    return (
        F.normalize(corpus @ projection, dim=1),
        F.normalize(query @ projection, dim=1),
    )


def prefix(corpus: torch.Tensor, query: torch.Tensor, out_dim: int) -> tuple[torch.Tensor, torch.Tensor]:
    return F.normalize(corpus[:, :out_dim], dim=1), F.normalize(query[:, :out_dim], dim=1)


def to_topk_sparse(x: torch.Tensor, k: int) -> dict[str, torch.Tensor | int]:
    # Pick the k largest-magnitude entries per row, then keep their signed values.
    _, indices = torch.topk(x.abs(), k=min(k, x.shape[1]), dim=1)
    signed_values = torch.gather(x, 1, indices)
    # Re-normalize over the kept entries so sparse scores stay on a cosine-like scale.
    signed_values = F.normalize(signed_values, dim=1)
    return {"indices": indices.long(), "values": signed_values, "dim": x.shape[1]}


def build_postings(
    sparse: dict[str, torch.Tensor | int],
    num_items: int,
    dim: int,
    device: torch.device,
) -> dict[str, torch.Tensor]:
    indices = sparse["indices"]
    values = sparse["values"]
    assert isinstance(indices, torch.Tensor)
    assert isinstance(values, torch.Tensor)

    flat_dims = indices.reshape(-1).to(device=device)
    flat_values = values.reshape(-1).to(device=device)
    flat_items = (
        torch.arange(num_items, device=device)
        .repeat_interleave(indices.shape[1])
        .to(torch.long)
    )
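    # Sort all (dim, item, value) triples by latent dim so each dim's postings are contiguous.
    order = torch.argsort(flat_dims)
    flat_dims = flat_dims[order]
    flat_values = flat_values[order]
    flat_items = flat_items[order]
    # offsets[j]:offsets[j + 1] is the slice of the sorted arrays that belongs to latent dim j.
    counts = torch.bincount(flat_dims, minlength=dim)
    offsets = torch.zeros(dim + 1, dtype=torch.long, device=device)
    offsets[1:] = torch.cumsum(counts, dim=0)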
    return {
        "items": flat_items,
        "values": flat_values,
        "offsets_cpu": offsets.cpu(),
    }


@torch.no_grad()
def sparse_topk_from_postings(
    query_sparse: dict[str, torch.Tensor | int],
    postings: dict[str, torch.Tensor],
    num_items: int,
    k: int,
) -> torch.Tensor:
    query_indices = query_sparse["indices"]
    query_values = query_sparse["values"]
    assert isinstance(query_indices, torch.Tensor)
    assert isinstance(query_values, torch.Tensor)

    device = query_values.device
    rows = []
    query_indices_cpu = query_indices.cpu()
    offsets_cpu = postings["offsets_cpu"]
    items = postings["items"]
    values = postings["values"]
    for row_indices, row_values in zip(query_indices_cpu, query_values, strict=True):
        scores = torch.zeros(num_items, dtype=query_values.dtype, device=device)
        for dim_id, query_value in zip(row_indices, row_values, strict=True):
            dim_int = int(dim_id)
            start = int(offsets_cpu[dim_int])
            end = int(offsets_cpu[dim_int + 1])
            if end > start:
                scores.index_add_(0, items[start:end], values[start:end] * query_value)
        rows.append(torch.topk(scores, k=min(k, num_items), dim=0).indices)
    return torch.stack(rows, dim=0)


def format_markdown(results: list[MethodResult]) -> str:
    headers = [
        "method",
        "dim",
        "active",
        "recall@k vs dense",
        "paired recall@k",
        "build ms",
        "latency ms",
        "latency x dense",
        "qps",
        "qps x dense",
        "storage MB",
        "storage x dense",
        "note",
    ]
    lines = ["| " + " | ".join(headers) + " |", "| " + " | ".join(["---"] * len(headers)) + " |"]
    for result in results:
        lines.append(
            "| "
            + " | ".join(
                [
                    result.method,
                    str(result.dim),
                    str(result.active_dim),
                    "n/a" if result.recall_at_k_vs_dense is None else f"{result.recall_at_k_vs_dense:.4f}",
                    "n/a" if result.paired_recall_at_k is None else f"{result.paired_recall_at_k:.4f}",
                    "n/a" if result.index_build_ms is None else f"{result.index_build_ms:.2f}",
                    f"{result.latency_ms:.2f}",
                    "n/a" if result.latency_speedup_vs_dense is None else f"{result.latency_speedup_vs_dense:.2f}x",
                    f"{result.qps:.1f}",
                    "n/a" if result.qps_vs_dense is None else f"{result.qps_vs_dense:.2f}x",
                    f"{result.storage_mb:.2f}",
                    "n/a" if result.storage_vs_dense is None else f"{result.storage_vs_dense:.4f}x",
                    result.note,
                ]
            )
            + " |"
        )
    return "\n".join(lines)


def add_relative_metrics(results: list[MethodResult]) -> None:
    baseline = next(result for result in results if result.method == "full_dense")
    for result in results:
        result.latency_speedup_vs_dense = baseline.latency_ms / result.latency_ms
        result.qps_vs_dense = result.qps / baseline.qps
        result.storage_vs_dense = result.storage_mb / baseline.storage_mb


def main() -> None:
    args = parse_args()
    device = torch.device(args.device)
    dtype = dtype_from_name(args.dtype)

    if args.corpus_file and args.query_file:
        corpus = load_tensor(args.corpus_file)
        query = load_tensor(args.query_file)
        target_ids = load_tensor(args.target_file).long() if args.target_file else None
    else:
        corpus, query, target_ids = make_synthetic(
            args.num_items,
            args.num_queries,
            args.dim,
            args.noise,
            args.seed,
            dtype,
        )

    corpus = F.normalize(corpus.to(device=device, dtype=dtype), dim=1)
    query = F.normalize(query.to(device=device, dtype=dtype), dim=1)
    if target_ids is not None:
        target_ids = target_ids.cpu()

    if corpus.shape[1] != query.shape[1]:
        raise ValueError(f"corpus dim {corpus.shape[1]} != query dim {query.shape[1]}")

    results: list[MethodResult] = []
    full_indices, full_latency_ms = time_call(
        lambda: dense_topk(query, corpus, args.top_k, args.query_batch_size),
        args.warmup,
        args.repeats,
        device,
    )
    results.append(
        MethodResult(
            method="full_dense",
            dim=corpus.shape[1],
            active_dim=corpus.shape[1],
            recall_at_k_vs_dense=1.0,
            paired_recall_at_k=paired_recall(full_indices, target_ids),
            index_build_ms=None,
            latency_ms=full_latency_ms,
            qps=query.shape[0] / (full_latency_ms / 1000.0),
            storage_mb=tensor_storage_mb(corpus),
            note="exact dense baseline",
        )
    )

    for reduced_dim in args.reduced_dims:
        if reduced_dim > corpus.shape[1]:
            continue
        p_corpus, p_query = projected(corpus, query, reduced_dim, args.seed, dtype, device)
        p_indices, p_latency_ms = time_call(
            lambda: dense_topk(p_query, p_corpus, args.top_k, args.query_batch_size),
            args.warmup,
            args.repeats,
            device,
        )
        results.append(
            MethodResult(
                method=f"projection_{reduced_dim}",
                dim=reduced_dim,
                active_dim=reduced_dim,
                recall_at_k_vs_dense=recall_vs_reference(p_indices, full_indices),
                paired_recall_at_k=paired_recall(p_indices, target_ids),
                index_build_ms=None,
                latency_ms=p_latency_ms,
                qps=query.shape[0] / (p_latency_ms / 1000.0),
                storage_mb=tensor_storage_mb(p_corpus),
                note="random projection; replace with trained projection for model quality",
            )
        )

        m_corpus, m_query = prefix(corpus, query, reduced_dim)
        m_indices, m_latency_ms = time_call(
            lambda: dense_topk(m_query, m_corpus, args.top_k, args.query_batch_size),
            args.warmup,
            args.repeats,
            device,
        )
        results.append(
            MethodResult(
                method=f"prefix_mrl_proxy_{reduced_dim}",
                dim=reduced_dim,
                active_dim=reduced_dim,
                recall_at_k_vs_dense=recall_vs_reference(m_indices, full_indices),
                paired_recall_at_k=paired_recall(m_indices, target_ids),
                index_build_ms=None,
                latency_ms=m_latency_ms,
                qps=query.shape[0] / (m_latency_ms / 1000.0),
                storage_mb=tensor_storage_mb(m_corpus),
                note="prefix mechanics only; real MRL requires trained nested embeddings",
            )
        )

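    # Real CSR latents come with a fixed k, so benchmark them once instead of
    # re-running the same files under every --sparse-topks label.
    use_real_sparse = bool(args.sparse_corpus_file and args.sparse_query_file)
    for sparse_topk in ([0] if use_real_sparse else args.sparse_topks):
        if use_real_sparse:
            sparse_corpus = load_sparse(args.sparse_corpus_file)
            sparse_query = load_sparse(args.sparse_query_file)
            sparse_topk = int(sparse_corpus["indices"].shape[1])  # report the k stored in the file
        else:
            sparse_corpus = to_topk_sparse(corpus, sparse_topk)
            sparse_query = to_topk_sparse(query, sparse_topk)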

        sparse_corpus = {
            "indices": sparse_corpus["indices"].to(device=device),
            "values": sparse_corpus["values"].to(device=device, dtype=dtype),
            "dim": int(sparse_corpus["dim"]),
        }
        sparse_query = {
            "indices": sparse_query["indices"].to(device=device),
            "values": sparse_query["values"].to(device=device, dtype=dtype),
            "dim": int(sparse_query["dim"]),
        }

        maybe_sync(device)
        start = time.perf_counter()
        postings = build_postings(sparse_corpus, corpus.shape[0], int(sparse_corpus["dim"]), device)
        maybe_sync(device)
        build_ms = (time.perf_counter() - start) * 1000.0
        sparse_indices, sparse_latency_ms = time_call(
            lambda: sparse_topk_from_postings(sparse_query, postings, corpus.shape[0], args.top_k),
            args.warmup,
            args.repeats,
            device,
        )
        results.append(
            MethodResult(
                method=f"sparse_topk_{sparse_topk}",
                dim=int(sparse_corpus["dim"]),
                active_dim=sparse_topk,
                recall_at_k_vs_dense=recall_vs_reference(sparse_indices, full_indices),
                paired_recall_at_k=paired_recall(sparse_indices, target_ids),
                index_build_ms=build_ms,
                latency_ms=sparse_latency_ms,
                qps=query.shape[0] / (sparse_latency_ms / 1000.0),
                storage_mb=sparse_storage_mb(sparse_corpus),
                note="TopK sparse retrieval; use real CSR latents for CSR model quality",
            )
        )

    add_relative_metrics(results)
    print(format_markdown(results))
    if args.output_json:
        args.output_json.write_text(
            json.dumps([asdict(result) for result in results], indent=2),
            encoding="utf-8",
        )


if __name__ == "__main__":
    main()

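The output has two kinds of metrics. recall@k vs dense measures how much the current method's top-k overlaps with the full dense top-k. paired recall@k uses synthetic or real target ids to check whether the target item lands in the top-k. The performance part records latency, QPS, storage MB, and the build time of the sparse posting index. Sparse storage is counted as the script holds it: int64 indices plus values.

Model quality has to come from paper experiments instead. The table below is from the e5-Mistral-7B comparison in the CSRv2 paper: same backbone, same training configuration, comparing MRL, CSR, and CSRv2 on six categories of MTEB tasks. Only the averages are kept here, to show the overall trend across active dimensions.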

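| active dim / k | MRL avg | CSR avg | CSRv2-linear avg | CSRv2 avg | Notes |
| --- | --- | --- | --- | --- | --- |
| 64 | 61.86 | 66.68 | 67.58 | 68.08 | At the same active count, CSR/CSRv2 average higher |
| 16 | 51.93 | 62.83 | 64.26 | 65.76 | Sparse latents keep a relatively high average at low active counts |
| 4 | 40.83 | 52.94 | 58.62 | 61.01 | In the ultra-sparse regime, CSRv2's training changes make a large difference |
| 2 | 33.81 | 44.33 | 53.35 | 58.38 | The extremely low activation range the CSRv2 paper focuses on |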

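This table cannot directly replace quality evaluation on your own workload. What it does show is that when the sparse latents are trained, the model quality of CSR/CSRv2 cannot be inferred from recall@k vs dense in the synthetic benchmark below. The numbers below are mainly for checking mechanics and performance paths: latency, QPS, and storage for the dense baseline, random projection, prefix truncation, and TopK sparse retrieval on the same machine.

All relative columns use full_dense as the baseline. latency x dense above 1.0x means faster than full dense, below 1.0x means slower; the smaller storage x dense is, the less storage is used. The full script is in the collapsed code block above, where add_relative_metrics computes these relative metrics.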

| method | dim | active | recall@k vs dense | paired recall@k | build ms | latency ms | latency x dense | qps | qps x dense | storage MB | storage x dense | note |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| full_dense | 2048 | 2048 | 1.0000 | 1.0000 | n/a | 71.44 | 1.00x | 7167.0 | 1.00x | 8192.00 | 1.0000x | exact dense baseline |
| projection_32 | 32 | 32 | 0.0014 | 0.0137 | n/a | 8.89 | 8.04x | 57577.7 | 8.03x | 128.00 | 0.0156x | random projection; replace with trained projection for model quality |
| prefix_mrl_proxy_32 | 32 | 32 | 0.0021 | 0.0195 | n/a | 8.89 | 8.04x | 57621.4 | 8.04x | 128.00 | 0.0156x | prefix mechanics only; real MRL requires trained nested embeddings |
| projection_64 | 64 | 64 | 0.0193 | 0.1914 | n/a | 10.79 | 6.62x | 47431.4 | 6.62x | 256.00 | 0.0313x | random projection; replace with trained projection for model quality |
| prefix_mrl_proxy_64 | 64 | 64 | 0.0139 | 0.1367 | n/a | 10.82 | 6.60x | 47308.5 | 6.60x | 256.00 | 0.0313x | prefix mechanics only; real MRL requires trained nested embeddings |
| projection_128 | 128 | 128 | 0.0631 | 0.6211 | n/a | 13.78 | 5.18x | 37156.0 | 5.18x | 512.00 | 0.0625x | random projection; replace with trained projection for model quality |
| prefix_mrl_proxy_128 | 128 | 128 | 0.0678 | 0.6738 | n/a | 13.74 | 5.20x | 37264.3 | 5.20x | 512.00 | 0.0625x | prefix mechanics only; real MRL requires trained nested embeddings |
| projection_256 | 256 | 256 | 0.0994 | 0.9805 | n/a | 19.62 | 3.64x | 26092.6 | 3.64x | 1024.00 | 0.1250x | random projection; replace with trained projection for model quality |
| prefix_mrl_proxy_256 | 256 | 256 | 0.1000 | 0.9922 | n/a | 19.71 | 3.62x | 25971.5 | 3.62x | 1024.00 | 0.1250x | prefix mechanics only; real MRL requires trained nested embeddings |
| sparse_topk_4 | 2048 | 4 | 0.0002 | 0.0000 | 61.53 | 128.40 | 0.56x | 3987.6 | 0.56x | 48.00 | 0.0059x | TopK sparse retrieval; use real CSR latents for CSR model quality |
| sparse_topk_8 | 2048 | 8 | 0.0014 | 0.0117 | 2.26 | 143.58 | 0.50x | 3565.9 | 0.50x | 96.00 | 0.0117x | TopK sparse retrieval; use real CSR latents for CSR model quality |
| sparse_topk_16 | 2048 | 16 | 0.0043 | 0.0430 | 4.12 | 203.11 | 0.35x | 2520.8 | 0.35x | 192.00 | 0.0234x | TopK sparse retrieval; use real CSR latents for CSR model quality |
| sparse_topk_32 | 2048 | 32 | 0.0150 | 0.1504 | 7.69 | 361.48 | 0.20x | 1416.4 | 0.20x | 384.00 | 0.0469x | TopK sparse retrieval; use real CSR latents for CSR model quality |
| sparse_topk_64 | 2048 | 64 | 0.0523 | 0.5234 | 14.90 | 676.08 | 0.11x | 757.3 | 0.11x | 768.00 | 0.0938x | TopK sparse retrieval; use real CSR latents for CSR model quality |

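So these numbers are better treated as an engineering sanity check than as a model-quality conclusion. The recall of random projection and the prefix proxy shows that untrained low-dimensional representations cannot stand in for the real performance of projection or MRL. recall@k vs dense also stays near 0.10 even where paired recall@k is close to 1, most likely because on random synthetic data only the planted target is a meaningful neighbor, while the other nine dense neighbors are chance matches. The sparse TopK performance only shows that the sparse retrieval path in this script has not yet reached its theoretical advantage. The 61.53 ms build time for sparse_topk_4 is timed once without warmup and is probably a cold-start artifact, since the larger k values build in a few milliseconds. A real CSR/CSRv2 evaluation needs trained sparse latents, with quality and latency compared together.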

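Choosing a method

Projection is a good baseline. It is simple to implement and simple to deploy. The trade-off is that a small dense vector has to absorb the information bottleneck. If the task itself is fairly redundant, or the teacher embedding has more dimensions than are actually needed, projection may already be enough.

MRL suits one model serving several budget tiers. One embedding can be cut to prefixes of 32/64/128/256, and the downstream interface is still dense search. Its training constraint is stronger: the model has to place useful information up front. When the dimension is very short, the capacity of the prefix is a hard ceiling.

CSR suits scenarios where the retrieval path can be changed and the system can exploit sparsity. It moves "low cost" from low dimension to low activation: the latent space can be large, but each sample brings only a few active features into the computation. CSRv2 improves training stability and makes lower active counts feasible, but it still depends heavily on the sparse retrieval implementation. If the sparse index, kernels, and serving path do not turn sparsity into real latency or memory gains, it is not a drop-in replacement for dense embeddings.

A safer order in practice: measure the full dense upper bound first, then trained projection and MRL prefixes, and finally CSR sparse latents. Look at quality and performance together at every step. Low-dimensional representation is not the goal in itself. The goal is acceptable retrieval quality under a given latency, memory, and cost budget.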
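Picking a configuration under a budget can be written as a small filter over measurements. The rows below are copied from the synthetic benchmark table in this post, so they only illustrate the selection logic and say nothing about trained models; the thresholds, the `Measurement` type, and the function name are assumptions.

```python
from typing import NamedTuple, Optional


class Measurement(NamedTuple):
    method: str
    paired_recall: float
    latency_ms: float
    storage_mb: float


rows = [
    Measurement("full_dense", 1.0000, 71.44, 8192.00),
    Measurement("projection_128", 0.6211, 13.78, 512.00),
    Measurement("projection_256", 0.9805, 19.62, 1024.00),
    Measurement("prefix_mrl_proxy_256", 0.9922, 19.71, 1024.00),
    Measurement("sparse_topk_64", 0.5234, 676.08, 768.00),
]


def cheapest_within_budget(
    candidates: list[Measurement], min_recall: float, max_latency_ms: float
) -> Optional[Measurement]:
    """Smallest storage among candidates that meet both the quality and latency bars."""
    ok = [r for r in candidates if r.paired_recall >= min_recall and r.latency_ms <= max_latency_ms]
    return min(ok, key=lambda r: (r.storage_mb, r.latency_ms)) if ok else None


choice = cheapest_within_budget(rows, min_recall=0.95, max_latency_ms=50.0)
print(choice)
```

Quality and latency act as constraints and storage is what gets minimized; swapping which quantity is constrained and which is minimized is a one-line change, and the right arrangement depends on which budget is hardest in a given system.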

References

  1. Kusupati et al., Matryoshka Representation Learning, NeurIPS 2022.
  2. Wen et al., Beyond Matryoshka: Revisiting Sparse Coding for Adaptive Representation, 2025.
  3. CSRv2 authors, CSRv2: Unlocking Ultra-Sparse Embeddings, 2026.
  4. Muennighoff et al., MTEB: Massive Text Embedding Benchmark, 2022.
  5. Hugging Face model cards: all-mpnet-base-v2, bge-large-en-v1.5, e5-mistral-7b-instruct, NV-Embed-v2, Qwen3-Embedding-8B.
