Getting Started with Cython: Converting Python Code to Cython

程序员文章站 2022-03-20 12:16:50

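Preface

This post does not explain how Cython works or go into its details. It gives one minimal example of converting a Python module into a Cython module. I am new to Cython myself and know only one simple technique: declaring variable types inside Cython code. In my experiment, simply adding these type declarations made the code run nearly 4 times faster.
Cython helps a lot when the code contains many loops!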

Python code

This is the official evaluation code provided by CVPPP (evaluate.py).
To save space, the comments and some non-essential checks have been removed.

import numpy as np
def DiffFGLabels(inLabel,gtLabel):
    # np.int was removed in NumPy 1.24; the built-in int does the same conversion
    maxInLabel = int(np.max(inLabel))
    minInLabel = int(np.min(inLabel))
    maxGtLabel = int(np.max(gtLabel))
    minGtLabel = int(np.min(gtLabel))
    return  (maxInLabel-minInLabel) - (maxGtLabel-minGtLabel) 

def BestDice(inLabel,gtLabel):
    score = 0
    maxInLabel = np.max(inLabel) 
    minInLabel = np.min(inLabel) 
    maxGtLabel = np.max(gtLabel) 
    minGtLabel = np.min(gtLabel) 
    if(maxInLabel==minInLabel):
        return score
    for i in range(minInLabel+1,maxInLabel+1):
        sMax = 0
        for j in range(minGtLabel+1,maxGtLabel+1):
            s = Dice(inLabel, gtLabel, i, j)
            if(sMax < s):
                sMax = s
        score = score + sMax
    score = score/(maxInLabel-minInLabel)
    return score

def FGBGDice(inLabel,gtLabel):
    minInLabel = np.min(inLabel) 
    minGtLabel = np.min(gtLabel) 
    one = np.ones(inLabel.shape)    
    inFgLabel = (inLabel != minInLabel*one)*one
    gtFgLabel = (gtLabel != minGtLabel*one)*one
    return Dice(inFgLabel,gtFgLabel,1,1)

def Dice(inLabel, gtLabel, i, j):
    one = np.ones(inLabel.shape)
    inMask = (inLabel==i*one) 
    gtMask = (gtLabel==j*one) 
    inSize = np.sum(inMask*one) 
    gtSize = np.sum(gtMask*one) 
    overlap= np.sum(inMask*gtMask*one) 
    if ((inSize + gtSize)>1e-8):
        out = 2*overlap/(inSize + gtSize) 
    else:
        out = 0
    return out

def AbsDiffFGLabels(inLabel,gtLabel):
    return np.abs( DiffFGLabels(inLabel,gtLabel) )

def SymmetricBestDice(inLabel,gtLabel):
    bd1 = BestDice(inLabel,gtLabel)
    bd2 = BestDice(gtLabel,inLabel)
    if bd1 < bd2:
        return bd1
    else:
        return bd2
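
To make the metric concrete, here is a small worked example of the Dice computation on two toy label images. It is a condensed sketch rather than the CVPPP code itself: the function name dice, the boolean-mask formulation, and the arrays pred and truth are assumptions made for illustration.

```python
import numpy as np

def dice(in_label, gt_label, i, j):
    """Dice overlap between region i of in_label and region j of gt_label."""
    in_mask = (in_label == i)
    gt_mask = (gt_label == j)
    in_size = int(in_mask.sum())
    gt_size = int(gt_mask.sum())
    overlap = int((in_mask & gt_mask).sum())
    if in_size + gt_size == 0:
        return 0.0
    return 2.0 * overlap / (in_size + gt_size)

# Toy 2x3 label images (illustrative data): 0 is background, 1 and 2 are objects
pred = np.array([[0, 1, 1],
                 [0, 2, 2]], dtype=np.uint16)
truth = np.array([[0, 1, 1],
                  [0, 1, 2]], dtype=np.uint16)

print(dice(pred, truth, 1, 1))  # 2*2 / (2 + 3) = 0.8
print(dice(pred, truth, 2, 2))  # 2*1 / (2 + 1), about 0.667
```

BestDice repeats this computation for every pair of labels (i, j), which is why the nested loops dominate the running time.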

Cython code

Create a file named evaluate.pyx (note: the extension must be .pyx!)

from __future__ import division
from libcpp cimport bool as bool_t
import numpy as np
cimport numpy as np
cimport cython

ctypedef bint TYPE_BOOL
ctypedef unsigned long long TYPE_U_INT64
ctypedef unsigned int TYPE_U_INT32
ctypedef unsigned short TYPE_U_INT16
ctypedef unsigned char TYPE_U_INT8
ctypedef long long TYPE_INT64
ctypedef int TYPE_INT32
ctypedef short TYPE_INT16
ctypedef signed char TYPE_INT8
ctypedef float TYPE_FLOAT
ctypedef double TYPE_DOUBLE

@cython.boundscheck(False) # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
def DiffFGLabels(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
    # np.int was removed in NumPy 1.24; the built-in int does the same conversion
    cdef int maxInLabel = int(np.max(inLabel))
    cdef int minInLabel = int(np.min(inLabel))
    cdef int maxGtLabel = int(np.max(gtLabel))
    cdef int minGtLabel = int(np.min(gtLabel))
    cdef double out = (maxInLabel-minInLabel) - (maxGtLabel-minGtLabel)
    return out

@cython.boundscheck(False)
@cython.wraparound(False)
def BestDice(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
    cdef int i, j
    cdef double sMax = 0.0
    cdef double s = 0.0
    cdef double score = 0.0 
    cdef int maxInLabel = np.max(inLabel) 
    cdef int minInLabel = np.min(inLabel) 
    cdef int maxGtLabel = np.max(gtLabel) 
    cdef int minGtLabel = np.min(gtLabel) 
    if(maxInLabel == minInLabel): 
        return score
    for i in range(minInLabel+1, maxInLabel+1):
        sMax = 0
        for j in range(minGtLabel+1, maxGtLabel+1):
            s = Dice(inLabel, gtLabel, i, j)
            if(sMax < s):
                sMax = s
        score = score + sMax
    score = score / (maxInLabel-minInLabel)
    return score

@cython.boundscheck(False)
@cython.wraparound(False)
def FGBGDice(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
    cdef int minInLabel = np.min(inLabel) 
    cdef int minGtLabel = np.min(gtLabel) 
    cdef np.ndarray[TYPE_U_INT16, ndim=2] one = np.ones_like(inLabel)
    cdef np.ndarray[TYPE_U_INT16, ndim=2] inFgLabel = (inLabel != minInLabel*one)*one
    cdef np.ndarray[TYPE_U_INT16, ndim=2] gtFgLabel = (gtLabel != minGtLabel*one)*one
    cdef double out = Dice(inFgLabel,gtFgLabel,1,1) 
    return out

@cython.boundscheck(False)
@cython.wraparound(False)
def Dice(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel, int i, int j):
    cdef double out = 0.0
    cdef np.ndarray[TYPE_U_INT16, ndim=2] one = np.ones_like(inLabel)
    cdef int inSize = np.sum((inLabel==i*one)*one) 
    cdef int gtSize = np.sum((gtLabel==j*one)*one) 
    cdef int overlap= np.sum((inLabel==i*one)*(gtLabel==j*one)*one) 
    if ((inSize + gtSize)>1e-8):
        out = 2*overlap/(inSize + gtSize) 
    else:
        out = 0
    return out

@cython.boundscheck(False)
@cython.wraparound(False)
def AbsDiffFGLabels(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
    cdef double out = np.abs(DiffFGLabels(inLabel, gtLabel))
    return out

@cython.boundscheck(False)
@cython.wraparound(False)
def SymmetricBestDice(np.ndarray[TYPE_U_INT16, ndim=2] inLabel, np.ndarray[TYPE_U_INT16, ndim=2] gtLabel):
    cdef double bd1 = BestDice(inLabel,gtLabel)
    cdef double bd2 = BestDice(gtLabel,inLabel)
    if bd1 < bd2:
        return bd1
    else:
        return bd2

Compiling

Once the .pyx file is written, you also need a setup.py file. Its contents are simple (remember to change the .pyx file name to match yours!):

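# distutils was removed from the standard library in Python 3.12;
# setuptools provides the same setup() entry point
from setuptools import setup
from Cython.Build import cythonize
import numpy as np

setup(
    ext_modules=cythonize("evaluate.pyx"),
    include_dirs=[np.get_include()],  # NumPy C headers required by "cimport numpy"
)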

Build:

python setup.py build_ext --inplace

After a successful build, the functions in the module can be imported as usual.
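
Before trusting the compiled module, it is worth checking that it returns the same values as the pure-Python version and measuring the actual speedup. The helper below is a minimal sketch of such a check; the name compare and its parameters are assumptions for illustration, not part of the original code.

```python
import timeit
import numpy as np

def compare(reference, candidate, *args, repeat=3, number=1):
    """Return (results_match, speedup) for two implementations of one function."""
    expected = reference(*args)
    actual = candidate(*args)
    results_match = bool(np.isclose(expected, actual))
    # Best-of-N timing reduces noise from other processes
    t_ref = min(timeit.repeat(lambda: reference(*args), repeat=repeat, number=number))
    t_new = min(timeit.repeat(lambda: candidate(*args), repeat=repeat, number=number))
    return results_match, t_ref / max(t_new, 1e-12)

# Hypothetical usage, assuming the pure-Python version was saved under a
# different name such as evaluate_py.py so it does not clash with the
# compiled evaluate extension:
#   import evaluate_py, evaluate
#   ok, speedup = compare(evaluate_py.SymmetricBestDice,
#                         evaluate.SymmetricBestDice, inLabel, gtLabel)
```

The speedup you see depends on the image size and the number of labels, so measure on data that resembles your real inputs.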

Explanation

Comparing the two versions reveals a few patterns:

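  1. Among the imports, the most important line is cimport numpy as np, which gives access to NumPy's Cython (C-level) interface. The regular import numpy as np is still present as well; the compiler resolves each use of np to the Python module or to the C-level declarations depending on context. The other line, from libcpp cimport bool as bool_t, makes the C++ bool type available. This example never uses it, so it can be ignored.
  2. ctypedef gives a data type a new name. It is not required; I use it here for readability. The table below lists common NumPy dtypes and the corresponding C types: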
     numpy         C
     np.uint8      unsigned char
     np.uint16     unsigned short
     np.uint32     unsigned int
     np.uint64     unsigned long long
     np.int8       signed char
     np.int16      short
     np.int32      int
     np.int64      long long
     np.float32    float
     np.float64    double
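  3. Every function is preceded by @cython.boundscheck(False) and @cython.wraparound(False). These turn off bounds checking and negative-index wrapping to gain speed, so the code must be known to be correct beforehand; I recommend validating it in pure Python first. These directives only affect element-wise indexing such as inLabel[i, j], which this example does not use.
  4. Every function argument has a declared type. Here they are all np.ndarray[TYPE_U_INT16, ndim=2], meaning a two-dimensional uint16 NumPy array. Passing an array of any other type raises an error.
  5. cdef int / cdef double declare integer and double-precision floating-point variables. Declare each variable before using it; an undeclared variable is treated as a generic Python object whose type is handled at run time, which is slower.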