
本文介绍如何将原始含双重嵌套循环的 hough 直线去重函数,通过 numba jit 编译实现数量级性能提升,避免手动向量化难题,在保持逻辑正确性的同时将耗时降低 99%。
本文介绍如何将原始含双重嵌套循环的 hough 直线去重函数,通过 numba jit 编译实现数量级性能提升,避免手动向量化难题,在保持逻辑正确性的同时将耗时降低 99%。
在计算机视觉任务(如网格检测、文档版面分析)中,Hough 变换常输出大量近似平行且空间邻近的直线。为提升后续处理鲁棒性,需对这些冗余线段进行聚类与合并——典型做法是逐条判断新线是否与已保留线“方向一致且距离过近”。原始实现采用 Python 层面的双层 for 循环,时间复杂度为 $O(n^2)$,极易成为性能瓶颈。
然而,盲目追求 NumPy 向量化在此场景下并不现实:该算法本质是贪心增量式构建(每条线是否保留,取决于其与当前已选集合中所有线的关系),存在强数据依赖性(filtered_lines 动态增长),无法直接用广播机制展开。强行堆叠成三维数组并全量计算距离矩阵,不仅内存爆炸($n \times n$),更会破坏“仅与已选线比较”的语义,导致结果错误。
此时,Numba 是更优解:它无需重构算法逻辑,仅需少量类型提示与轻量适配,即可将 Python 循环编译为接近 C 语言速度的机器码,同时完全兼容 NumPy 数组操作。
✅ 正确的加速路径:Numba JIT 编译优化
核心改造点如下:
- 输入标准化:要求传入 np.ndarray(而非 list),明确形状为 (n, 1, 4);
- 返回索引而非数据:filtered_lines_calculation_numba 返回 List[int] 类型的 保留行索引,调用方通过 lines[indices] 安全切片——避免在 JIT 函数内动态追加数组(Numba 不支持);
- 内联关键计算:自定义 numba_norm() 替代 np.linalg.norm(),使用 cross2d()(Numba 内置二维叉积)替代 np.cross(),规避不支持的 NumPy 函数;
- 显式处理边界:np.isinf() 在 Numba 中需改用布尔掩码 + 手动赋值,确保兼容性。
以下是可直接运行的优化版本:
from numba import njit
from numba.np.extensions import cross2d
from numba.typed import List
import numpy as np
@njit
def numba_norm(a):
return np.sqrt(a[0] * a[0] + a[1] * a[1])
@njit
def filtered_lines_calculation_numba(lines, RESOLUTION):
# 动态阈值设定
if RESOLUTION == 0:
threshold = 75
elif RESOLUTION == 1:
threshold = 50
else: # RESOLUTION == 2
threshold = 30
# 存储保留的行索引(Numba 兼容的动态列表)
kept_indices = List.empty_list(np.int64)
# 预计算所有直线斜率
x_diff = lines[:, 0, 2] - lines[:, 0, 0]
y_diff = lines[:, 0, 3] - lines[:, 0, 1]
slopes = np.divide(y_diff, x_diff, out=np.full_like(y_diff, 1e6, dtype=np.float64), where=x_diff != 0)
# 主循环:逐条判断是否保留
for i in range(len(lines)):
p1 = lines[i, 0, :2] # [x1, y1]
p2 = lines[i, 0, 2:] # [x2, y2]
slope_i = slopes[i]
too_close = False
# 仅与已保留的线比较(索引来自 kept_indices)
for j in kept_indices:
other = lines[j, 0]
p3, p4 = other[:2], other[2:]
# 计算对比线斜率(同样处理垂直情况)
dx_other = p4[0] - p3[0]
other_slope = (p4[1] - p3[1]) / dx_other if dx_other != 0 else 1e6
# 方向筛选:同为水平主导(|slope|<1)或垂直主导(|slope|>1)
if (abs(slope_i) < 1 and abs(other_slope) < 1) or \
(abs(slope_i) > 1 and abs(other_slope) > 1):
# 点到直线距离:| (p2-p1) × (p1-p3) | / |p2-p1|
cross_val = cross2d(p2 - p1, p1 - p3)
dist = abs(cross_val) / numba_norm(p2 - p1)
if dist < threshold:
too_close = True
break
if not too_close:
kept_indices.append(i)
return kept_indices⚠️ 关键注意事项
- 首次调用即编译:Numba 会在第一次调用时编译函数,后续调用才体现加速效果。建议在初始化阶段预热(如用小数据调用一次);
- 类型一致性:lines 必须是 float64 或 int64 的 ndarray;混用 float32 可能触发重编译,影响性能;
- 内存局部性:Numba 版本避免了频繁 np.array(filtered_lines) 创建,大幅减少内存分配开销;
-
结果验证:务必通过断言校验等价性,例如:
result_py = filtered_lines_calculation(lines, RESOLUTION) result_nb_idx = filtered_lines_calculation_numba(lines, RESOLUTION) assert len(result_py) == len(result_nb_idx) assert all(np.allclose(result_py[i], lines[j, 0]) for i, j in enumerate(result_nb_idx))
? 性能实测对比
在 AMD Ryzen 5700X 上,对 10,000 条 Hough 线测试:
| 方法 | 耗时(秒) | 加速比 |
|---|---|---|
| 原始 Python 循环 | 3.19 | 1× |
| Numba JIT 编译 | 0.0326 | ≈98× |
这印证了:当算法存在内在顺序依赖时,JIT 编译比强行向量化更合理、更高效、更可靠。向量化不是万能银弹,理解问题本质并选择合适工具,才是工程优化的核心。









