0

0

高效计算Python中的稀疏成对距离

聖光之護

聖光之護

发布时间:2025-09-30 15:28:41

|

305人浏览过

|

来源于php中文网

原创

高效计算python中的稀疏成对距离

本文旨在解决在Python中高效计算两组向量之间稀疏成对距离的问题。针对传统NumPy方法在处理大量向量时因计算冗余而导致的性能瓶颈,本文提出了一种结合Numba即时编译和SciPy稀疏矩阵(特别是CSR格式)的优化方案。通过在Numba加速的循环中仅计算所需的距离并构建稀扑矩阵,该方法显著提升了计算效率和内存利用率,特别适用于距离矩阵高度稀疏的场景。

问题背景与传统方法分析

在许多数据处理和机器学习任务中,我们可能需要计算两组向量集 A 和 B 之间的所有成对距离。然而,在某些特定场景下,我们仅对其中一小部分成对距离感兴趣,例如,当一个掩码矩阵 M 指定了需要保留的距离对时。

传统的NumPy方法通常涉及计算所有可能的成对距离,然后通过掩码矩阵进行筛选。以下是一个示例:

import numpy as np

A = np.array([[1, 2], [2, 3], [3, 4]])                              # (3, 2)
B = np.array([[4, 5], [5, 6], [6, 7], [7, 8], [8, 9]])              # (5, 2)
M = np.array([[0, 0, 0, 1, 0], [1, 1, 0, 0, 0], [0, 0, 0, 0, 1]])   # (3, 5)

# 计算所有向量对的差值
diff = A[:, None] - B[None, :]                                      # (3, 5, 2)
# 计算所有成对距离(L2范数)
distances = np.linalg.norm(diff, ord=2, axis=2)                     # (3, 5)
# 应用掩码,保留所需距离
masked_distances = distances * M                                    # (3, 5)

print("计算的距离矩阵:\n", distances)
print("掩码后的距离矩阵:\n", masked_distances)

这种方法虽然简洁,但当 A 和 B 的行数非常大时(例如数千行),diff 和 distances 矩阵会变得非常庞大,导致计算大量不必要的距离,从而消耗大量的计算资源和内存。即使通过 np.vectorize 尝试创建条件函数,也可能因为Python循环的开销而导致性能不佳,甚至更慢。

优化方案:Numba加速与CSR稀疏矩阵

为了解决上述性能瓶颈,我们引入一种结合 Numba 即时编译和 SciPy 稀疏矩阵(特别是 Compressed Sparse Row, CSR 格式)的优化方案。该方案的核心思想是:

立即学习Python免费学习笔记(深入)”;

  1. 避免冗余计算:仅计算掩码矩阵 M 中指定为 True 的那些成对距离。
  2. 高效存储:使用 CSR 稀疏矩阵来存储结果,只存储非零距离值,显著减少内存占用
  3. 性能提升:利用 Numba 对核心计算逻辑进行 JIT 编译,将 Python 循环的性能提升至接近 C 语言的水平。

1. 自定义欧几里得距离函数

首先,我们定义一个 Numba 加速的欧几里得距离函数。在 Numba 环境下,自定义的循环计算通常比调用 np.linalg.norm 更快。

import numba as nb
import numpy as np
import scipy.sparse
import math

@nb.njit()
def euclidean_distance(vec_a, vec_b):
    """
    计算两个向量之间的欧几里得距离。
    使用 Numba 加速,避免 np.linalg.norm 的开销。
    """
    acc = 0.0
    for i in range(vec_a.shape[0]):
        acc += (vec_a[i] - vec_b[i]) ** 2
    return math.sqrt(acc)

这里,@nb.njit() 装饰器指示 Numba 在函数首次调用时将其编译为优化的机器码。

问小白
问小白

免费使用DeepSeek满血版

下载

2. 核心距离计算与稀疏数据填充函数

接下来,我们创建 masked_distance_inner 函数。这是一个 Numba 加速的核心函数,负责遍历掩码矩阵,只计算所需的距离,并将结果填充到 CSR 矩阵所需的 data、indicies 和 indptr 数组中。

@nb.njit()
def masked_distance_inner(data, indicies, indptr, matrix_a, matrix_b, mask):
    """
    Numba 加速的核心函数,根据掩码计算距离并填充 CSR 矩阵的内部数组。
    参数:
        data (np.ndarray): 存储非零距离值的数组。
        indicies (np.ndarray): 存储非零距离值对应列索引的数组。
        indptr (np.ndarray): 存储每行在 data/indicies 中起始位置的数组。
        matrix_a (np.ndarray): 第一个向量集。
        matrix_b (np.ndarray): 第二个向量集。
        mask (np.ndarray): 布尔掩码矩阵,指示哪些距离需要计算。
    """
    write_pos = 0
    N, M = matrix_a.shape[0], matrix_b.shape[0]

    # 遍历所有可能的向量对
    for i in range(N):
        for j in range(M):
            # 只有当掩码为 True 时才计算距离
            if mask[i, j]:
                # 记录距离值
                data[write_pos] = euclidean_distance(matrix_a[i], matrix_b[j])
                # 记录该距离值对应的列索引
                indicies[write_pos] = j
                write_pos += 1
        # 记录当前行结束后,data/indicies 中元素的总数,作为下一行的起始位置
        indptr[i + 1] = write_pos

    # 确保所有预分配的空间都被使用
    assert write_pos == data.shape[0]
    assert write_pos == indicies.shape[0]
    # data, indicies, indptr 会在函数外部被修改并用于构建 CSR 矩阵

3. 稀疏距离矩阵构建函数

最后,我们定义 masked_distance 函数,它负责设置算法的参数、预分配内存,并调用 masked_distance_inner 来执行计算,最终返回一个 scipy.sparse.csr_matrix 对象。

def masked_distance(matrix_a, matrix_b, mask):
    """
    计算两组向量之间掩码指定的稀疏成对距离。
    参数:
        matrix_a (np.ndarray): 第一个向量集。
        matrix_b (np.ndarray): 第二个向量集。
        mask (np.ndarray): 布尔掩码矩阵。
    返回:
        scipy.sparse.csr_matrix: 包含指定成对距离的稀疏矩阵。
    """
    N, M = matrix_a.shape[0], matrix_b.shape[0]
    assert mask.shape == (N, M), "掩码矩阵的形状必须与向量集兼容。"

    # 确保掩码是布尔类型
    mask = mask != 0

    # 计算稀疏矩阵中非零元素的总数
    sparse_length = mask.sum()

    # 为 CSR 矩阵预分配内存。这些数组不需要初始化为零,直接分配内存更高效。
    data = np.empty(sparse_length, dtype='float64')    # 存储非零数据值
    indicies = np.empty(sparse_length, dtype='int64')  # 存储列索引
    indptr = np.zeros(N + 1, dtype='int64')            # 存储行指针

    # 调用 Numba 加速的核心函数进行计算和填充
    masked_distance_inner(data, indicies, indptr, matrix_a, matrix_b, mask)

    # 使用填充好的数据构建 CSR 稀疏矩阵
    return scipy.sparse.csr_matrix((data, indicies, indptr), shape=(N, M))

示例用法与性能分析

为了演示和评估其性能,我们使用更大的随机生成数据集进行测试。

# 准备大型测试数据
A_big = np.random.rand(2000, 10)
B_big = np.random.rand(4000, 10)
# 创建一个高度稀疏的掩码(0.1% 的元素为 True)
M_big = np.random.rand(A_big.shape[0], B_big.shape[0]) < 0.001

# 使用优化的方法计算稀疏距离
sparse_distances = masked_distance(A_big, B_big, M_big)

print(f"稀疏距离矩阵的形状: {sparse_distances.shape}")
print(f"稀疏距离矩阵的非零元素数量: {sparse_distances.nnz}")
print(f"稀疏距离矩阵的密度: {sparse_distances.nnz / (sparse_distances.shape[0] * sparse_distances.shape[1]):.6f}")

# 性能基准测试 (在Jupyter/IPython环境中运行)
# %timeit masked_distance(A_big, B_big, M_big)
# 
# 原始方法的性能基准测试 (仅供参考,不推荐在生产环境运行大型矩阵)
# %timeit np.linalg.norm(A_big[:,None] - B_big[None,:], ord=2, axis=2) * M_big

在上述 A_big (2000x10) 和 B_big (4000x10) 的测试场景中,当掩码 M_big 只有约 0.1% 的元素为 True 时,此优化方案相比原始的 NumPy 全量计算方法,可以实现显著的性能提升(例如,40倍甚至更高)。具体的加速效果会随着矩阵大小和掩码稀疏度的增加而更加明显。

注意事项与优化建议

  1. Numba 编译开销:euclidean_distance 和 masked_distance_inner 函数在首次调用时会有编译开销。在后续调用中,性能将大幅提升。
  2. 数据类型选择
    • data 数组默认使用 float64。如果对精度要求不高,可以考虑使用 float32 来减少内存占用并可能提高计算速度。
    • indicies 和 indptr 数组默认使用 int64。如果矩阵的维度和非零元素数量都小于 231,可以安全地使用 int32,进一步节省内存。
  3. 稀疏度影响:此方法的性能优势主要体现在掩码矩阵高度稀疏的场景。如果掩码矩阵非常稠密(例如,超过 50% 的元素为 True),那么全量计算后筛选的方法可能反而更简单或性能差异不显著。
  4. 正确性验证:在实际应用中,务必通过 np.allclose() 等方法验证优化后的结果与原始方法的结果是否一致。
  5. 内存管理:在 masked_distance 函数中,data 和 indicies 数组是使用 np.empty 创建的,它们不进行零初始化,这比 np.zeros 更快,因为我们会在 masked_distance_inner 中完全覆盖这些内存。

总结

通过结合 Numba 的即时编译能力和 SciPy 的 CSR 稀疏矩阵格式,我们能够高效地计算两组向量之间指定的一小部分成对距离。这种方法通过避免不必要的计算和优化内存使用,为处理大规模稀疏距离计算问题提供了一个强大且高性能的解决方案。在面临大量数据且仅需少量成对距离的场景时,采用此教程介绍的方案将显著提升应用程序的性能和资源利用率。

相关专题

更多
python开发工具
python开发工具

php中文网为大家提供各种python开发工具,好的开发工具,可帮助开发者攻克编程学习中的基础障碍,理解每一行源代码在程序执行时在计算机中的过程。php中文网还为大家带来python相关课程以及相关文章等内容,供大家免费下载使用。

769

2023.06.15

python打包成可执行文件
python打包成可执行文件

本专题为大家带来python打包成可执行文件相关的文章,大家可以免费的下载体验。

661

2023.07.20

python能做什么
python能做什么

python能做的有:可用于开发基于控制台的应用程序、多媒体部分开发、用于开发基于Web的应用程序、使用python处理数据、系统编程等等。本专题为大家提供python相关的各种文章、以及下载和课程。

764

2023.07.25

format在python中的用法
format在python中的用法

Python中的format是一种字符串格式化方法,用于将变量或值插入到字符串中的占位符位置。通过format方法,我们可以动态地构建字符串,使其包含不同值。php中文网给大家带来了相关的教程以及文章,欢迎大家前来阅读学习。

639

2023.07.31

python教程
python教程

Python已成为一门网红语言,即使是在非编程开发者当中,也掀起了一股学习的热潮。本专题为大家带来python教程的相关文章,大家可以免费体验学习。

1325

2023.08.03

python环境变量的配置
python环境变量的配置

Python是一种流行的编程语言,被广泛用于软件开发、数据分析和科学计算等领域。在安装Python之后,我们需要配置环境变量,以便在任何位置都能够访问Python的可执行文件。php中文网给大家带来了相关的教程以及文章,欢迎大家前来学习阅读。

549

2023.08.04

python eval
python eval

eval函数是Python中一个非常强大的函数,它可以将字符串作为Python代码进行执行,实现动态编程的效果。然而,由于其潜在的安全风险和性能问题,需要谨慎使用。php中文网给大家带来了相关的教程以及文章,欢迎大家前来学习阅读。

579

2023.08.04

scratch和python区别
scratch和python区别

scratch和python的区别:1、scratch是一种专为初学者设计的图形化编程语言,python是一种文本编程语言;2、scratch使用的是基于积木的编程语法,python采用更加传统的文本编程语法等等。本专题为大家提供scratch和python相关的文章、下载、课程内容,供大家免费下载体验。

709

2023.08.11

excel表格操作技巧大全 表格制作excel教程
excel表格操作技巧大全 表格制作excel教程

Excel表格操作的核心技巧在于 熟练使用快捷键、数据处理函数及视图工具,如Ctrl+C/V(复制粘贴)、Alt+=(自动求和)、条件格式、数据验证及数据透视表。掌握这些可大幅提升数据分析与办公效率,实现快速录入、查找、筛选和汇总。

0

2026.01.21

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 9.6万人学习

Django 教程
Django 教程

共28课时 | 3.3万人学习

SciPy 教程
SciPy 教程

共10课时 | 1.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号