解决 NumPy 浮点数显示为 np.float64 的问题
问题描述
当在 GitHub Actions 的 Python 3.9+ 环境中运行包含 NumPy 数组的 doctest 测试时,会出现以下错误:
Expected:
[13.0, 12.0, 7.0]
Got:
[np.float64(13.0), np.float64(12.0), np.float64(7.0)]
测试数据计算结果正确,但 NumPy 浮点数被显示为 np.float64(...)
形式,导致测试失败。本地 Python 3.8 环境通常正常工作,问题主要发生在 CI/CD 环境中 Python 3.9 及以上版本。
问题根源
此问题由 NumPy 2.0 的标量显示格式变更引起:
- NumPy 1.x 显示:
1.0
- NumPy 2.x 显示:
np.float64(1.0)
具体行为改变:
python
# NumPy 1.x
>>> repr(np.array([1.0])[0])
'1.0'
# NumPy 2.x
>>> repr(np.array([1.0])[0])
'np.float64(1.0)'
关键提示
此打印格式变化不会影响数值计算本身,仅影响变量在 REPL 环境中的默认字符串表示形式(即调用 repr()
的结果)。
解决方案
推荐方案:配置 NumPy 打印选项 (最佳实践)
在测试入口或配置文件(如 conftest.py
)中加入:
python
np.set_printoptions(legacy='1.25')
注意事项
- 此设置仅影响打印格式,不会改变实际数据类型
legacy='1.25'
表示强制使用 NumPy 1.25.x 的打印样式- 对项目中所有 NumPy 数组均有效
python
# 设置前
print([np.float64(1.0), np.float64(2.0)])
# 输出:[np.float64(1.0), np.float64(2.0)]
# 设置后
np.set_printoptions(legacy='1.25')
print([np.float64(1.0), np.float64(2.0)])
# 输出:[1.0, 2.0]
备选方案:降级 NumPy 版本
在 requirements.txt
或 pyproject.toml
中固定版本:
text
numpy ~> 1.26 # 等效于 numpy>=1.26, ==1.*
替代方法:显式类型转换
python
# 将数组元素转为 Python 原生 float
print(list(map(float, np.array([13.0, 12.0, 7.0]))))
# 输出:[13.0, 12.0, 7.0]
实用技巧
对于一维数组打印,亦可使用:
python
np.savetxt(sys.stdout, np.array([1.0, 2.0, 3.0]))
方案比较
方法 | 优势 | 劣势 |
---|---|---|
设置打印选项 | 单行解决,全局生效 | 需提前配置 |
降级 NumPy | 无需修改代码 | 无法使用新版本特性 |
显式类型转换 | 精确控制输出格式 | 需逐处修改,增加维护成本 |
实施建议
优先采用打印选项配置:
- 在测试初始化代码中添加
np.set_printoptions(legacy='1.25')
- 确保所有测试运行前执行该设置
- 在测试初始化代码中添加
长期策略:
pythonimport numpy as np if hasattr(np, 'float64') and callable(np.float64): np.set_printoptions(legacy='1.25')
总结
NumPy 2.0 改变了标量默认显示格式是导致该问题的根本原因。推荐使用 np.set_printoptions(legacy='1.25')
全局配置解决方案。此方法既能维持数值计算的准确性,又能确保不同环境中数字显示的一致性,无需大规模修改测试用例或依赖特定 NumPy 版本。