Cython入门:将python代码转为cython
程序员文章站
2022-03-20 12:16:50
...
前言
本节不具体讲解cython的原理和细节,提供一个最简单的例子,将一个python代码转化为一个cython代码,同时由于本人对cython刚入门,只会一个简单的操作,即在cython中声明变量的类型。实验证明,就这样简单添加变量类型,代码运行速度提升了将近4倍
cython对于代码中许多循环的情况很有帮助!
python代码
这里给的是CVPPP官方提供的evaluate代码(evaluate.py)
为了节省空间,这里删除了注释和一些无关紧要的判断语句
import numpy as np
def DiffFGLabels(inLabel,gtLabel):
maxInLabel = np.int(np.max(inLabel))
minInLabel = np.int(np.min(inLabel))
maxGtLabel = np.int(np.max(gtLabel))
minGtLabel = np.int(np.min(gtLabel))
return (maxInLabel-minInLabel) - (maxGtLabel-minGtLabel)
def BestDice(inLabel,gtLabel):
score = 0
maxInLabel = np.max(inLabel)
minInLabel = np.min(inLabel)
maxGtLabel = np.max(gtLabel)
minGtLabel = np.min(gtLabel)
if(maxInLabel==minInLabel):
return score
for i in range(minInLabel+1,maxInLabel+1):
sMax = 0;
for j in range(minGtLabel+1,maxGtLabel+1):
s = Dice(inLabel, gtLabel, i, j)
if(sMax < s):
sMax = s
score = score + sMax;
score = score/(maxInLabel-minInLabel)
return score
def FGBGDice(inLabel,gtLabel):
minInLabel = np.min(inLabel)
minGtLabel = np.min(gtLabel)
one = np.ones(inLabel.shape)
inFgLabel = (inLabel != minInLabel*one)*one
gtFgLabel = (gtLabel != minGtLabel*one)*one
return Dice(inFgLabel,gtFgLabel,1,1)
def Dice(inLabel, gtLabel, i, j):
one = np.ones(inLabel.shape)
inMask = (inLabel==i*one)
gtMask = (gtLabel==j*one)
inSize = np.sum(inMask*one)
gtSize = np.sum(gtMask*one)
overlap= np.sum(inMask*gtMask*one)
if ((inSize + gtSize)>1e-8):
out = 2*overlap/(inSize + gtSize)
else:
out = 0
return out
def AbsDiffFGLabels(inLabel,gtLabel):
return np.abs( DiffFGLabels(inLabel,gtLabel) )
def SymmetricBestDice(inLabel,gtLabel):
bd1 = BestDice(inLabel,gtLabel)
bd2 = BestDice(gtLabel,inLabel)
if bd1 < bd2:
return bd1
else:
return bd2
Cython代码
创建一个evaluate.pyx文件(注意:后缀得是pyx!!!)
from __future__ import division
from libcpp cimport bool as bool_t
import numpy as np
cimport numpy as np
cimport cython
ctypedef bint TYPE_BOOL
ctypedef unsigned long long TYPE_U_INT64
ctypedef unsigned int TYPE_U_INT32
ctypedef unsigned short TYPE_U_INT16
ctypedef unsigned char TYPE_U_INT8
ctypedef long long TYPE_INT64
ctypedef int TYPE_INT32
ctypedef short TYPE_INT16
ctypedef signed char TYPE_INT8
ctypedef float TYPE_FLOAT
ctypedef double TYPE_DOUBLE
@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False) # turn off negative index wrapping for entire function
def DiffFGLabels(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
cdef int maxInLabel = np.int(np.max(inLabel))
cdef int minInLabel = np.int(np.min(inLabel))
cdef int maxGtLabel = np.int(np.max(gtLabel))
cdef int minGtLabel = np.int(np.min(gtLabel))
cdef double out = (maxInLabel-minInLabel) - (maxGtLabel-minGtLabel)
return out
@cython.boundscheck(False)
@cython.wraparound(False)
def BestDice(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
cdef int i, j
cdef double sMax = 0.0
cdef double s = 0.0
cdef double score = 0.0
cdef int maxInLabel = np.max(inLabel)
cdef int minInLabel = np.min(inLabel)
cdef int maxGtLabel = np.max(gtLabel)
cdef int minGtLabel = np.min(gtLabel)
if(maxInLabel == minInLabel):
return score
for i in range(minInLabel+1, maxInLabel+1):
sMax = 0;
for j in range(minGtLabel+1, maxGtLabel+1):
s = Dice(inLabel, gtLabel, i, j)
if(sMax < s):
sMax = s
score = score + sMax;
score = score / (maxInLabel-minInLabel)
return score
@cython.boundscheck(False)
@cython.wraparound(False)
def FGBGDice(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
cdef int minInLabel = np.min(inLabel)
cdef int minGtLabel = np.min(gtLabel)
cdef np.ndarray[TYPE_U_INT16, ndim=2] one = np.ones_like(inLabel)
cdef np.ndarray[TYPE_U_INT16, ndim=2] inFgLabel = (inLabel != minInLabel*one)*one
cdef np.ndarray[TYPE_U_INT16, ndim=2] gtFgLabel = (gtLabel != minGtLabel*one)*one
cdef double out = Dice(inFgLabel,gtFgLabel,1,1)
return out
@cython.boundscheck(False)
@cython.wraparound(False)
def Dice(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel, int i, int j):
cdef double out = 0.0
cdef np.ndarray[TYPE_U_INT16, ndim=2] one = np.ones_like(inLabel)
cdef int inSize = np.sum((inLabel==i*one)*one)
cdef int gtSize = np.sum((gtLabel==j*one)*one)
cdef int overlap= np.sum((inLabel==i*one)*(gtLabel==j*one)*one)
if ((inSize + gtSize)>1e-8):
out = 2*overlap/(inSize + gtSize)
else:
out = 0
return out
@cython.boundscheck(False)
@cython.wraparound(False)
def AbsDiffFGLabels(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
cdef double out = np.abs(DiffFGLabels(inLabel, gtLabel))
return out
@cython.boundscheck(False)
@cython.wraparound(False)
def SymmetricBestDice(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
cdef double bd1 = BestDice(inLabel,gtLabel)
cdef double bd2 = BestDice(gtLabel,inLabel)
if bd1 < bd2:
return bd1
else:
return bd2
编译
写好pyx文件后,需要再写一个setup.py文件,里面的内容也很简单(注意修改相应的pyx文件名!!!):
import distutils.core
import Cython.Build
import numpy as np
distutils.core.setup(
ext_modules = Cython.Build.cythonize("evaluate.pyx"),
include_dirs = [np.get_include()])
编译:
python setup.py build_ext --inplace
编译成功后,就可以正常的 import 里面的函数了
解释
通过对比两个代码,我们可以看出一些规律也可以总结出一些规律
- 在导入包的时候,有一句最重要的是:cimport numpy as np,表明使用的是cython接口的numpy。(当然也有import numpy as np,编译器会根据情况使用numpy还是c-numpy);还有一句是:from libcpp cimport bool as bool_t,这是为了使用C语言中的bool类型(这个例子里面没有用到bool类型,可以不用管)
- 为数据类型起一个新名字:ctypedef。这个不是必须,但这里为了可读性,我列举了一些numpy中常用的数据类型对应的C语言中的数据类型
numpy | C |
---|---|
np.uint8 | unsigned char |
np.uint16 | unsigned short |
np.uint32 | unsigned int |
np.uint64 | unsigned long long |
np.int8 | signed char |
np.int16 | short |
np.int32 | int |
np.int64 | long long |
np.float32 | double |
- 每个函数前面都有:@cython.boundscheck(False) 和 @cython.wraparound(False),这是为了加速而关闭边界检查,这样做就需要提前保证代码的准确性,建议在python下验证代码的准确性
- 每个函数的输入变量都定义了数据类型,比如这里全是:np.ndarray[TYPE_U_INT16, ndim=2],这表明输入的一个二维的uint16的numpy数组,如果输入类型不是这样,那就会报错
- cdef int/double:定义整型/双精度浮点型变量。在使用每个变量须先对它进行定义,如果没有编译器就会花时间来判断,就会耗时
上一篇: 使用 php4 加速 web 传输
下一篇: php生成WAP页面