python包-cython编译脚本

使用 cython 编译 python 脚本。

安装 cython

1
pip install cython

然后查看是否安装上

1
2
$ cython -V
Cython version 0.29.33

修改 pyx 脚本

以下面的程序为例,这个程序是计算一个对称矩阵和一个输入向量的乘积,这个矩阵的上三角元素按照三元组的形式存储在 (row_list, col_list, data_list) 中,x 为输入向量,函数返回乘积 b

首先需要用 cimport cython 导入模块。

row_listcol_list 的数据类型定义为 np.intc ,与 C 语言中的 int 类型对应;将 data_list的数据类型定义为 np.double ,与 C 语言中的 double 类型对应。这里就是要保证数据类型一致,具体的 numpy 与 C 语言所有数据类型的对应关系可见官网 数组类型之间的转换

通过网上发现 numpy.double 等同于 numpy.float64 , numpy.intc 等同于 numpy.int32, 证明如下:

1
2
3
4
5
6
7
In [1]: import numpy as np

In [2]: np.double is np.float64
Out[2]: True

In [3]: np.intc is np.int32
Out[3]: True

@cython.wraparound(False)@cython.boundscheck(False) 这两句是说不再检查循环变量 (i ) 的范围。

@cython.cdivision(True) 是说使用 C 语言中的除法,而不是 python 的除法。注意如果是两个整数相除,C 语言的除法会取整,想要得到正确的结果,需要将一个数改为浮点数,比如将 x 改为 <double> (x)

然后我们看函数内部 ,首先就是将数据类型改为与 C 语言一致,x = x.astype(np.double) 。然后数组对象需要放到缓存中,这样就能将 python 的数组对象转为 C 语言直接处理的对象,比如 cdef int[:] row_list_view = row_list(对 row_list_view 的处理,row_list 也会出现相应的改变)。标量可以直接定义,注意这里涉及数据维度的地方也需要提前定义,因为数组维度对象也是 python 对象,比如 cdef int element_num = row_list.shape[0] 。对于循环变量,定义为 cdef Py_ssize_t

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
cimport cython

row_list = np.array(row_list, dtype = np.intc)
col_list = np.array(col_list, dtype = np.intc)
data_list = np.array(data_list, dtype = np.double)

@cython.wraparound(False)
@cython.boundscheck(False)
@cython.cdivision(True) # 使用 c 的除法
def Ax(x):
# lhs_dict2 为 A 的字典, 0-index
# x 为需要乘的向量,输出 Ax
# x = x.astype(np.double) # 转换类型
b = np.zeros_like(x, dtype = np.double)
cdef int[:] row_list_view = row_list # 存入缓存中
cdef int[:] col_list_view = col_list # 存入缓存中
cdef double[:] data_list_view = data_list # 存入缓存中
cdef double[:] x_view = x # 存入缓存中
cdef double[:] b_view = b # 存入缓存中

cdef int element_num = row_list.shape[0]
cdef Py_ssize_t i # 循环变量
cdef int row_index,col_index
cdef double data
for i in range(element_num):
row_index = row_list_view[i] - 1 # 0-index
col_index = col_list_view[i] - 1 # 0-index
data = data_list_view[i]
b_view[row_index] += data * x_view[col_index]
if (row_index != col_index): # 非对角线元素
b_view[col_index] += data * x_view[row_index]
return b

查看修改效果

Cython 程序的扩展名是 .pyx ,使用下面的命令会生成一个同名的 html ,这个文件中黄色部分就是和 python 进行交互的地方,也就是拖累性能的地方。

1
cython -a *.pyx

如果遇到提示 FutureWarning: Cython directive 'language_level' not set, using 2 for now (Py2). 。只需要在 pyx 文件开头增加一条注释,指明使用 python3 即可。

1
# cython: language_level=3

这里只要循环部分不是黄的,就可以了,比如上面的函数的 html 文件如下:

使用 cython 编译 py 文件

创建一个临时的 py 文件,比如 setup.py,其中内容如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from distutils.core import setup
from distutils.extension import Extension
from Cython.Build import cythonize

# 多个文件
# 第一个参数 recode 是输出模块的前缀,第二个列表是所有需要编译的源文件
extensions=[
Extension("recode", ["recode_so.py"]), # 配置需要cython编译的源文件
Extension("blup", ["blup_so.py"]),
]

setup(ext_modules=cythonize(extensions, language_level=3))
# 里面的 language_level=3 表示只需要兼容 python3 即可, 而默认是 2 和 3 都兼容

然后运行脚本进行编译

1
python setup.py build

在我们执行命令之后,运行成功结尾就是 .so。并且在当前目录会多出一个 build 目录,里面的 .so 文件就是编译后的文件(windows 系统中是 .pyd

实际使用发现,cython 不能和 numba 联合使用。

参考文献

  1. https://www.cnblogs.com/traditional/p/13213173.html

  2. https://www.jianshu.com/p/9053dacee822

  3. https://www.bilibili.com/video/BV1NF411i74S?p=1&vd_source=584d44ba7fed271fb82848e7894a53c8

  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2019-2024 Vincere Zhou
  • 访问人数: | 浏览次数:

请我喝杯茶吧~

支付宝
微信