跳转到内容
ForeverYoung
返回

PyTorch 2.12 稀疏矩阵深度解析:COO、CSR、CSC、BSR 与 BSC

现在很多模型和系统里,矩阵的形状都大得有点吓人。embedding table、图邻接矩阵、推荐系统里的特征矩阵、attention mask,看起来动不动就是百万、千万级别的元素。但摊开一看,很多位置其实没东西,值就是 0。

稀疏矩阵的直觉很朴素:0 太多了,就别存 0。只存非零值,内存会小一些,后面的矩阵乘法也可能少做一些无用功。

麻烦也在这里。你把 0 扔掉以后,必须把另一件事记下来:这些非零值原来在哪。位置怎么编码,基本就决定了稀疏矩阵长什么样,也决定了后面 torch.sparse.mm 会走哪类计算路径。

PyTorch 2.12 里常见的稀疏格式包括 COO、CSR、CSC、BSR 和 BSC。它们都在做同一件事:保存非零值,以及这些非零值的位置。区别在于,有的直接保存坐标,有的压缩行指针,有的压缩列指针,有的把多个元素打包成块。

我想先把这件事拆开看:PyTorch 里的这些稀疏矩阵格式到底存了哪些 tensor?什么时候真的省内存?做 torch.sparse.mm 时,哪些格式能跑,哪些格式只是能构造出来?

文末放了一组 CPU/GPU 实验。它只说明这些机器、这些 PyTorch wheel、这些矩阵形状下发生了什么。稀疏性能太依赖硬件和 kernel,一张表不该被读成通用结论。

从“不存 0”开始

假设我们有一个矩阵 A,内容是:

它的形状是 5 x 6,非零元素个数是 nnz = 5。如果用 dense float32 存储,需要 5 x 6 x 4 = 120 bytes,不算 tensor metadata。

Dense 存储只关心每个位置的值。Sparse 存储要多想一步:除了非零值本身,也就是 values,还得存这些值在原矩阵里的位置,也就是某种形式的 indices。所以稀疏格式之间的差别,主要不是 API 名字不同,而是索引结构不同。

稀疏格式省掉了 0,但索引不是免费的。PyTorch 里索引通常是 int64,每个索引 8 bytes。矩阵很小,或者本来就不够稀疏时,最后反而可能更占空间。

同一个矩阵,可以被拆成坐标、行指针、列指针或块指针 Dense 345789 COO arrays row = [0, 1, 1, 3, 3, 4] col = [2, 0, 4, 1, 3, 5] values = [3, 4, 5, 7, 8, 9] COO 直接保存每个非零元素的坐标。 CSR/CSC/BSR/BSC 会进一步压缩坐标。
稀疏格式的差别,主要在坐标如何编码。

COO:最直接的坐标表

COO 是 coordinate list,基本就是一张坐标表。它很适合做入口格式,因为你只要知道“哪一行、哪一列、是什么值”,就能把稀疏矩阵拼出来。PyTorch 的 torch.sparse_coo_tensor(indices, values, size) 接收两部分核心数据:

import torch

indices = torch.tensor([
    [0, 1, 1, 3, 3, 4],  # row
    [2, 0, 4, 1, 3, 5],  # col
])
values = torch.tensor([3, 4, 5, 7, 8, 9], dtype=torch.float32)

a = torch.sparse_coo_tensor(indices, values, size=(5, 6))

COO 的好处就是直接。很多稀疏数据本来也不是一开始就是矩阵,而是按事件、边、样本一条条产生的。先把它们攒成 COO,通常最省心。

但 COO 的索引开销也最直观。二维矩阵里,每个非零元素要存两个坐标。如果 value 是 float32,index 是 int64,每个非零元素大约是 20 bytes:

这里 s_i 是 index bytes,s_v 是 value bytes。这还没有算 tensor metadata。矩阵如果不够稀疏,COO 的索引很快就会把省下来的空间吃回去。

PyTorch COO 还有一个容易踩到的细节:它可以是 uncoalesced。也就是同一个坐标可以出现多次,语义上这些重复值会相加。比如 (1, 4) 位置可以出现两次,coalesce() 之后才会合并成一个坐标。这个设计对增量构造很方便;真要计算时,通常还是先 coalesce(),否则同一位置的重复项会增加遍历负担。

unmerged = torch.sparse_coo_tensor(
    torch.tensor([[1, 1], [4, 4]]),
    torch.tensor([2.0, 3.0]),
    size=(5, 6),
)

merged = unmerged.coalesce()
# merged.indices() -> [[1], [4]]
# merged.values()  -> [5.]

COO 的直觉是:构造容易,表达宽松,索引冗余也明显。

CSR:把行坐标压成指针

CSR 是 compressed sparse row。到 CSR 这里,格式开始压缩重复坐标:它不再给每个非零元素都存一份 row,而是用一个行指针数组描述每一行的数据范围:

crow_indices = torch.tensor([0, 1, 3, 3, 5, 6])
col_indices = torch.tensor([2, 0, 4, 1, 3, 5])
values = torch.tensor([3, 4, 5, 7, 8, 9], dtype=torch.float32)

a = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(5, 6))

crow_indices 的长度是 n_rows + 1。第 i 行的非零元素,存在 crow_indices[i]crow_indices[i + 1] 之间,也就是一个左闭右开的 slice。比如第 1 行的范围是 [1, 3),所以它读 col_indices[1:3] = [0, 4]values[1:3] = [4, 5]

CSR:用 row pointer 切开 col_indices 和 values crow_indices 013356 col_indices 204135 values 345789 第 1 行范围: [crow[1], crow[2]) = [1, 3) 因此读取 col[1:3] 和 values[1:3] row 不再逐元素保存。
CSR 把重复的行坐标变成行指针。行遍历很自然,列坐标仍逐元素保存。

CSR 的存储开销大约是:

和 COO 相比,CSR 少了每个非零元素的 row index,但多了一个长度为 n_rows + 1 的指针数组。矩阵越宽、每行非零元素越多,CSR 相对 COO 越容易省索引。

PyTorch 文档也提到,CSR 的 sparse matrix multiplication 通常比 COO 更适合压缩行格式。不过这句话不能单独拿出来当结论。具体速度还得看稀疏模式、矩阵大小、右侧 dense 矩阵宽度、CPU/GPU backend 和 PyTorch build。

CSC:把列坐标压成指针

CSC 是 compressed sparse column。可以把它看成 CSR 的转置视角:压缩列,逐元素保存 row index。如果 CSR 是把“按行访问”变成连续 slice,CSC 做的就是把“按列访问”变成连续 slice。

ccol_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6])
row_indices = torch.tensor([1, 3, 0, 3, 1, 4])
values = torch.tensor([4, 7, 3, 8, 5, 9], dtype=torch.float32)

a = torch.sparse_csc_tensor(ccol_indices, row_indices, values, size=(5, 6))

ccol_indices[j]ccol_indices[j + 1] 描述第 j 列的非零元素。也就是说,CSR 对“按行找数据”友好,CSC 对“按列找数据”友好。

CSR:行是连续片段 CSC:列是连续片段 row slice column slice
CSR 和 CSC 的值数组都压缩坐标,但它们让不同方向的遍历变成连续片段。

A @ X 这类左侧稀疏、右侧 dense 的乘法里,CSR 的行切片更贴近输出行的计算方式。CSC 不是没用,它更适合列导向的访问、某些转置形式,或者特定 backend 的实现路径。本文的本机实验里,CSC 在 CPU 上可以跑 torch.sparse.mm,但比 COO/CSR 慢很多。这只能说明该环境和该 workload 下的结果,不代表 CSC 一般更慢。

BSR 与 BSC:坐标压缩到块级别

BSR 是 block sparse row,BSC 是 block sparse column。它们把元素坐标继续往上压一层,变成块坐标:不再把每个非零元素当成一个独立坐标,而是把矩阵切成固定大小的 dense block。只要一个 block 里有需要保留的元素,这个 block 的整个 dense payload 都会存下来。

比如使用 2 x 2 block。一个 block 里即使只有一个非零值,BSR 仍然要存 4 个 value。代价是可能把块内的 0 也存回来。好处是 block 坐标数量少了很多,块内计算也更像规则的 dense 小矩阵。

dense = torch.tensor([
    [1, 2, 0, 0],
    [3, 4, 0, 0],
    [0, 0, 5, 6],
    [0, 0, 7, 8],
], dtype=torch.float32)

bsr = dense.to_sparse_bsr(blocksize=(2, 2))
bsc = dense.to_sparse_bsc(blocksize=(2, 2))
Block sparse:索引变少,但块内 0 会被一起存下 storedskippedstored block pointers: 每个块行/块列的范围 block indices: 哪些块列/块行存在 values: [n_blocks, block_h, block_w]每个 block 是 dense 小矩阵 块结构真实存在时,索引更省。 块结构不真实时,块内 0 会浪费空间。
BSR/BSC 的关键问题是:非零是不是自然成块出现。

BSR/BSC 的存储近似是:

这里的 n_b 是非零块数量,n_p 是被压缩的块行数或块列数,b_h x b_w 是 block shape。对于 attention mask、图结构、有限元矩阵、某些分块特征交互,这种结构可能有意义。对于随机散落的非零元素,block sparse 往往会把很多 0 又存回来。

本机实验里的随机 5% 稀疏矩阵就是一个例子:16x16 BSR/BSC 的存储已经略高于 dense,因为几乎所有 block 都被至少一个非零元素“点亮”了。

torch.sparse.mm:格式存在,不等于计算路径一样

PyTorch 2.12 提供了这些 constructor,但 torch.sparse.mm 的支持不是“所有 sparse layout 都一样”。PyTorch 2.12 docstring 写明它支持 COO 和 CSR storage formats;在 gradient support 部分,COO @ DenseCSR @ Dense 都支持对两个输入做 backward,而 CSC/BSR/BSC @ Dense 不支持 backward。

这点对训练代码很重要。你可以构造 CSR/CSC/BSR/BSC tensor,不代表每个 layout 都有你想要的 forward kernel、backward kernel,或者支持你的 device。实际写代码时,应该把 layout 当成 API contract,而不是只看 tensor 能不能创建。

一个最小的检查方式是直接跑:

def try_sparse_mm(a_sparse, x_dense):
    try:
        y = torch.sparse.mm(a_sparse, x_dense)
        return y, None
    except Exception as exc:
        return None, f"{type(exc).__name__}: {exc}"

本文的 benchmark 脚本也是这么做的。能跑就计时,不能跑就把 unsupported 记录下来。

CPU/GPU 实验:存储省多少,乘法快多少

数据结构讲完以后,再看一组实际数字。表里把 CPU 和 GPU 的结果放在一起。两个实验都使用 2048 x 2048 矩阵、2048 x 64 的右侧 dense 矩阵、float3216 x 16 block size。CPU 数字来自本机 arm64 PyTorch 2.12 wheel;GPU 数字来自 H200/CUDA 的同 workload 运行。

CPU 运行命令:

uv run --python 3.12 --with 'torch==2.12.*' scripts/benchmark_pytorch_sparse.py --device cpu

GPU 运行命令:

python -m haptic_foundation.scripts.benchmark_pytorch_sparse_gpu --device cuda --rows 2048 --cols 2048 --rhs-cols 64 --block-size 16
代码:GPU 稀疏矩阵 benchmark
from __future__ import annotations

import argparse
import platform
import time
from dataclasses import dataclass
from typing import Callable

import torch


@dataclass
class Case:
    name: str
    dense: torch.Tensor
    rhs: torch.Tensor


@dataclass
class Result:
    case: str
    layout: str
    nnz: int
    storage_bytes: int | None
    storage_ratio: float | None
    time_ms: float | None
    time_ratio: float | None
    note: str


def tensor_bytes(tensor: torch.Tensor) -> int:
    return tensor.untyped_storage().nbytes()


def sparse_storage_bytes(tensor: torch.Tensor) -> int:
    layout = tensor.layout
    if layout == torch.sparse_coo:
        return (
            tensor.indices().untyped_storage().nbytes()
            + tensor.values().untyped_storage().nbytes()
        )
    if layout in {torch.sparse_csr, torch.sparse_bsr}:
        return (
            tensor.crow_indices().untyped_storage().nbytes()
            + tensor.col_indices().untyped_storage().nbytes()
            + tensor.values().untyped_storage().nbytes()
        )
    if layout in {torch.sparse_csc, torch.sparse_bsc}:
        return (
            tensor.ccol_indices().untyped_storage().nbytes()
            + tensor.row_indices().untyped_storage().nbytes()
            + tensor.values().untyped_storage().nbytes()
        )
    raise ValueError(f"unsupported layout for storage accounting: {layout}")


def cuda_median_ms(
    fn: Callable[[], torch.Tensor],
    *,
    warmup: int,
    repeats: int,
) -> float:
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    times: list[float] = []
    for _ in range(repeats):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return sorted(times)[len(times) // 2]


def cpu_median_ms(fn: Callable[[], torch.Tensor], *, min_run_time: float) -> float:
    for _ in range(3):
        fn()
    start = time.perf_counter()
    iters = 0
    samples: list[float] = []
    while time.perf_counter() - start < min_run_time or iters < 5:
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1000.0)
        iters += 1
    return sorted(samples)[len(samples) // 2]


def median_ms(
    fn: Callable[[], torch.Tensor],
    *,
    device: torch.device,
    warmup: int,
    repeats: int,
    min_run_time: float,
) -> float:
    if device.type == "cuda":
        return cuda_median_ms(fn, warmup=warmup, repeats=repeats)
    return cpu_median_ms(fn, min_run_time=min_run_time)


def make_unstructured_case(
    name: str,
    rows: int,
    cols: int,
    rhs_cols: int,
    density: float,
    *,
    seed: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Case:
    generator = torch.Generator(device=device).manual_seed(seed)
    nnz = max(1, int(rows * cols * density))
    flat = torch.randperm(rows * cols, generator=generator, device=device)[:nnz]
    row = torch.div(flat, cols, rounding_mode="floor")
    col = flat.remainder(cols)
    values = torch.randn(nnz, generator=generator, dtype=dtype, device=device)
    dense = torch.zeros((rows, cols), dtype=dtype, device=device)
    dense[row, col] = values
    rhs = torch.randn((cols, rhs_cols), generator=generator, dtype=dtype, device=device)
    return Case(name=name, dense=dense, rhs=rhs)


def make_block_case(
    name: str,
    rows: int,
    cols: int,
    rhs_cols: int,
    block_size: tuple[int, int],
    block_density: float,
    *,
    seed: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Case:
    brow, bcol = block_size
    if rows % brow or cols % bcol:
        raise ValueError("rows and cols must be divisible by block size")
    generator = torch.Generator(device=device).manual_seed(seed)
    block_rows = rows // brow
    block_cols = cols // bcol
    n_blocks = max(1, int(block_rows * block_cols * block_density))
    chosen = torch.randperm(block_rows * block_cols, generator=generator, device=device)[:n_blocks]
    dense = torch.zeros((rows, cols), dtype=dtype, device=device)
    for block in chosen.tolist():
        br = block // block_cols
        bc = block % block_cols
        dense[br * brow : (br + 1) * brow, bc * bcol : (bc + 1) * bcol] = torch.randn(
            (brow, bcol), generator=generator, dtype=dtype, device=device
        )
    rhs = torch.randn((cols, rhs_cols), generator=generator, dtype=dtype, device=device)
    return Case(name=name, dense=dense, rhs=rhs)


def convert_layouts(dense: torch.Tensor, block_size: tuple[int, int]) -> dict[str, torch.Tensor]:
    coo = dense.to_sparse_coo().coalesce()
    return {
        "dense": dense,
        "coo": coo,
        "csr": dense.to_sparse_csr(),
        "csc": dense.to_sparse_csc(),
        "bsr": dense.to_sparse_bsr(blocksize=block_size),
        "bsc": dense.to_sparse_bsc(blocksize=block_size),
    }


def run_case(
    case: Case,
    block_size: tuple[int, int],
    *,
    device: torch.device,
    warmup: int,
    repeats: int,
    min_run_time: float,
) -> list[Result]:
    layouts = convert_layouts(case.dense, block_size)
    dense_bytes = tensor_bytes(case.dense)
    dense_time_ms = median_ms(
        lambda: torch.mm(case.dense, case.rhs),
        device=device,
        warmup=warmup,
        repeats=repeats,
        min_run_time=min_run_time,
    )
    results = [
        Result(
            case=case.name,
            layout="dense",
            nnz=int(torch.count_nonzero(case.dense).item()),
            storage_bytes=dense_bytes,
            storage_ratio=1.0,
            time_ms=dense_time_ms,
            time_ratio=1.0,
            note="baseline",
        )
    ]
    for layout_name in ["coo", "csr", "csc", "bsr", "bsc"]:
        sparse = layouts[layout_name]
        storage_bytes = sparse_storage_bytes(sparse)
        try:
            torch.sparse.mm(sparse, case.rhs)
            if device.type == "cuda":
                torch.cuda.synchronize()
            time_ms = median_ms(
                lambda sparse=sparse: torch.sparse.mm(sparse, case.rhs),
                device=device,
                warmup=warmup,
                repeats=repeats,
                min_run_time=min_run_time,
            )
            note = "ok"
        except Exception as exc:
            time_ms = None
            note = f"unsupported: {type(exc).__name__}: {str(exc).splitlines()[0]}"
        results.append(
            Result(
                case=case.name,
                layout=layout_name,
                nnz=int(sparse._nnz()),
                storage_bytes=storage_bytes,
                storage_ratio=storage_bytes / dense_bytes,
                time_ms=time_ms,
                time_ratio=None if time_ms is None else time_ms / dense_time_ms,
                note=note,
            )
        )
    return results


def fmt_bytes(value: int | None) -> str:
    if value is None:
        return "-"
    units = ["B", "KiB", "MiB", "GiB"]
    size = float(value)
    for unit in units:
        if size < 1024 or unit == units[-1]:
            return f"{size:.1f} {unit}" if unit != "B" else f"{int(size)} B"
        size /= 1024
    return f"{value} B"


def fmt_ratio(value: float | None) -> str:
    return "-" if value is None else f"{value:.3f}x"


def fmt_ms(value: float | None) -> str:
    return "-" if value is None else f"{value:.3f}"


def print_markdown(results: list[Result]) -> None:
    print(
        "| case | layout | nnz / blocks | storage | storage vs dense | "
        "median ms | time vs dense | note |"
    )
    print("| --- | ---: | ---: | ---: | ---: | ---: | ---: | --- |")
    for row in results:
        print(
            f"| {row.case} | {row.layout} | {row.nnz} | {fmt_bytes(row.storage_bytes)} | "
            f"{fmt_ratio(row.storage_ratio)} | {fmt_ms(row.time_ms)} | "
            f"{fmt_ratio(row.time_ratio)} | {row.note} |"
        )


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--rows", type=int, default=2048)
    parser.add_argument("--cols", type=int, default=2048)
    parser.add_argument("--rhs-cols", type=int, default=64)
    parser.add_argument("--block-size", type=int, default=16)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--repeats", type=int, default=50)
    parser.add_argument("--min-run-time", type=float, default=0.2)
    parser.add_argument("--require-torch-prefix", default="2.12.")
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    if args.require_torch_prefix and not torch.__version__.startswith(args.require_torch_prefix):
        raise SystemExit(
            f"Expected torch {args.require_torch_prefix}*, got torch {torch.__version__}"
        )
    if args.device == "cuda" and not torch.cuda.is_available():
        raise SystemExit("CUDA device requested, but torch.cuda.is_available() is false")

    device = torch.device(args.device)
    dtype = torch.float32
    block_size = (args.block_size, args.block_size)
    cases = [
        make_unstructured_case("random 0.1%", args.rows, args.cols, args.rhs_cols, 0.001, seed=11, dtype=dtype, device=device),
        make_unstructured_case("random 1%", args.rows, args.cols, args.rhs_cols, 0.01, seed=12, dtype=dtype, device=device),
        make_unstructured_case("random 5%", args.rows, args.cols, args.rhs_cols, 0.05, seed=13, dtype=dtype, device=device),
        make_block_case("16x16 blocks 1%", args.rows, args.cols, args.rhs_cols, block_size, 0.01, seed=21, dtype=dtype, device=device),
    ]

    print(f"torch: {torch.__version__}")
    print(f"python: {platform.python_version()} ({platform.machine()})")
    print(f"device: {device}")
    if device.type == "cuda":
        props = torch.cuda.get_device_properties(device)
        print(f"cuda: {torch.version.cuda}")
        print(f"gpu: {props.name}")
        print(f"compute capability: {props.major}.{props.minor}")
        print(f"total memory: {fmt_bytes(props.total_memory)}")
    else:
        print(f"threads: {torch.get_num_threads()}")
    print(f"shape: {args.rows}x{args.cols}, rhs: {args.cols}x{args.rhs_cols}, dtype: float32")
    print(f"block size: {block_size[0]}x{block_size[1]}")
    print()

    all_results: list[Result] = []
    for case in cases:
        all_results.extend(
            run_case(
                case,
                block_size,
                device=device,
                warmup=args.warmup,
                repeats=args.repeats,
                min_run_time=args.min_run_time,
            )
        )
    print_markdown(all_results)


if __name__ == "__main__":
    main()

运行环境和统计方式:

结果用 dense 作为 1.0。time vs dense (CPU) = 0.102x 表示 CPU sparse 路径时间约为 CPU dense 的 10.2%;time vs dense (GPU) = 1.748x 表示 GPU sparse 路径比同一 GPU 上的 dense 慢约 1.75 倍。

caselayoutnnz / blocksstoragestorage vs densetime vs dense (CPU)time vs dense (GPU)note
random 0.1%dense419416.0 MiB1.000x1.000x1.000xbaseline
random 0.1%coo419481.9 KiB0.005x0.102x1.748xok
random 0.1%csr419497.9 KiB0.006x0.137x0.971xok
random 0.1%csc419497.9 KiB0.006x0.306x3.925xok
random 0.1%bsr36843.7 MiB0.228x-2.290xCPU unsupported
random 0.1%bsc36843.7 MiB0.228x--CPU/GPU unsupported
random 1%dense4194316.0 MiB1.000x1.000x1.000xbaseline
random 1%coo41943819.2 KiB0.050x0.792x1.486xok
random 1%csr41943835.2 KiB0.051x0.920x0.929xok
random 1%csc41943835.2 KiB0.051x4.169x4.115xok
random 1%bsr1513515.0 MiB0.938x-2.555xCPU unsupported
random 1%bsc1513515.0 MiB0.938x--CPU/GPU unsupported
random 5%dense20971516.0 MiB1.000x1.000x1.000xbaseline
random 5%coo2097154.0 MiB0.250x3.571x2.015xok
random 5%csr2097154.0 MiB0.251x4.149x1.007xok
random 5%csc2097154.0 MiB0.251x17.511x4.172xok
random 5%bsr1638416.3 MiB1.016x-2.674xCPU unsupported
random 5%bsc1638416.3 MiB1.016x--CPU/GPU unsupported
16x16 blocks 1%dense4172816.0 MiB1.000x1.000x1.000xbaseline
16x16 blocks 1%coo41728815.0 KiB0.050x0.833x1.481xok
16x16 blocks 1%csr41728831.0 KiB0.051x0.796x0.971xok
16x16 blocks 1%csc41728831.0 KiB0.051x2.753x4.004xok
16x16 blocks 1%bsr163166.6 KiB0.010x-2.226xCPU unsupported
16x16 blocks 1%bsc163166.6 KiB0.010x--CPU/GPU unsupported

这张表最容易被误读,所以先把边界说清楚:这些数字只对应这组 shape、这组稀疏模式和这两套运行环境。

先看存储。随机 0.1% 的 COO 只用了 dense 的 0.5% 左右。这里 dense 是 16 MiB,COO 是 81.9 KiB。如果问题只是“这个矩阵本身怎么放得下”,稀疏格式的收益很直接。

但存储省,不等于乘法一定快。CPU 上,随机 0.1% 的 COO/CSR 比 dense 快;到随机 5% 时,COO/CSR/CSC 都慢了。GPU 上,这组 H200 结果里 CSR 基本贴近 dense,COO 和 CSC 都更慢。原因不复杂:dense GEMM 是高度优化的连续计算;稀疏乘法要读索引、做间接寻址,还可能写出 dense 输出。非零元素多到一定程度后,索引和不规则访存会把省掉的乘法抵消掉。

Block sparse 更挑数据。随机 1% 时,16x16 BSR/BSC 已经接近 dense storage;随机 5% 时甚至略高于 dense。原因就是那个老问题:一个 block 里只要有一个非零元素,整个 16x16 payload 都要保存。换成 16x16 blocks 1% 这个人工 block case,BSR/BSC 只用了 dense 的 1% 左右。块结构是真的,block sparse 才像它应该有的样子。

还有一个很现实的点:layout 支持本身就是性能结论的一部分。本机 CPU wheel 没有跑通 BSR/BSC 的 torch.sparse.mm 路径;H200 上 BSR 可以跑,但这组 workload 里仍然慢于 dense;BSC 在 CUDA 上也没有这条 Strided + SparseBsc @ Strided 路径。这不是说 BSR/BSC 数据结构没意义,而是当前 backend 和 kernel 覆盖决定了你能不能用、用起来快不快。

怎么选格式

如果你手里是一批坐标和值,想先装进 sparse tensor,COO 通常最方便。它对构造友好,也能表达重复坐标。构造完成后如果要计算,先考虑 coalesce()

如果主要操作是 A @ X,其中 A 是左侧稀疏矩阵,CSR 是很自然的候选。它按行压缩,输出的每一行可以从一个连续 slice 里读出这一行的非零项。它不是永远最快,但它的数据结构和这类访问模式匹配。

如果你的访问模式天然按列组织,或者你经常处理转置视角,CSC 值得考虑。不要只因为 CSC 是 CSR 的“镜像”就假设两者速度相同。实际 kernel 往往不是对称的。

如果非零元素确实成块出现,再考虑 BSR/BSC。block sparse 的收益来自“减少块坐标”和“块内 dense 计算更规则”。如果你的非零只是随机散点,block sparse 会把块内的 0 一起存回来,存储可能很快接近 dense。

如果这是生产代码,最后还是要测。至少固定 shape、density、dtype、index dtype、right-hand side 宽度、device、PyTorch version 和 build 信息。稀疏矩阵的选择不是一个抽象格式题,而是数据分布和 kernel 支持共同决定的工程题。

参考


分享这篇文章:

上一篇
低维表示:投影降维、MRL 与稀疏表示
下一篇
当思考(CoT)遇见embedding