纯C++超分辨率重建LapSRN --改编--(三)转置卷积vl_nnconvt函数
程序员文章站
2022-03-09 13:06:55
...
接前文,
转置卷积vl_nnconvt:
void vl_nnconvt(卷积层 const *indata, 卷积层 *out, 层数据 *filters_biases, // transposed conv (author's note: residual upsampling)
                int upsampleX = 1 , int upsampleY = 1 ,
                int cropLeft = 0 , int cropRight = 0 , int cropTop = 0 , int cropBottom = 0 )
{
	/*
	 * Public entry point for the transposed convolution.
	 *
	 * indata         : input feature map.
	 * out            : destination feature map, written by vl_nnconvt_forward.
	 * filters_biases : kernel weights and (optional) biases of the layer.
	 * upsampleX/Y    : upsampling factors (default 1).
	 * crop*          : amount cropped from each border (default 0).
	 *
	 * vl_nnconvt_forward takes the destination first and expects the
	 * geometry in Y-before-X / top,bottom,left,right order, so the
	 * arguments are reordered here.
	 */
	vl_nnconvt_forward(out, indata, filters_biases,
	                   upsampleY, upsampleX,
	                   cropTop, cropBottom,
	                   cropLeft, cropRight) ;
}
vl_nnconvt_forward 分两步,分别处理卷积核和偏置:
void vl_nnconvt_forward(
	卷积层 * output,
	卷积层 const *data,
	层数据 *filters_biases,
	int upsampleY, int upsampleX,
	int cropTop, int cropBottom,
	int cropLeft, int cropRight)
{
	/*
	 * Forward pass of the transposed convolution, done in two steps:
	 *   1. kernel: run the convolution backward-data routine, passing the
	 *      upsampling factors as its strides and the crops as its paddings;
	 *   2. bias:   add the biases, but only if the layer has any
	 *      (偏移长度 > 0).
	 */

	// Step 1: kernel.
	vl_impl_nnconv_backward_blas_CPU_float(output, data, filters_biases,
	                                       upsampleY, upsampleX,
	                                       cropTop, cropBottom,
	                                       cropLeft, cropRight) ;

	// Step 2: bias. Skip it when the layer carries no bias values.
	if (!(filters_biases->偏移长度 > 0)) {
		return ;
	}
	// The constant arguments are passed through unchanged; their meaning is
	// defined by vl_nnbias_forward.
	vl_nnbias_forward(output, 1, 0, filters_biases, 1) ;
}
这个 vl_impl_nnconv_backward_blas_CPU_float 就是前文中的那个卷积反向函数(转置卷积的前向计算,正是普通卷积对输入求导的反向计算,上采样倍数作为步长,裁剪量作为填充):
void vl_impl_nnconv_backward_blas_CPU_float(
	卷积层 * derData,
	卷积层 const *derOutput,
	层数据 *filters_biases,
	int strideY, int strideX,
	int padTop, int padBottom,
	int padLeft, int padRight)
{
	/*
	 * Convolution backward pass w.r.t. the input (dz/dx). Here it serves as
	 * the forward pass of the transposed convolution:
	 *   1. tempMemory = derOutput * filters^T   (one GEMM per filter group)
	 *   2. scatter-add tempMemory into derData  (vl_impl_row2im)
	 *
	 * derData        : destination; its depth determines the number of
	 *                  filter groups. Nothing is done when it is NULL.
	 * derOutput      : input of the transposed convolution.
	 * filters_biases : layer parameters. The kernel shape is inferred from
	 *                  the weight count (权重长度); only 36864 and 16 are
	 *                  recognised. For any other count an error is printed
	 *                  and derData is left untouched.
	 * strideY/X, pad*: forwarded to vl_impl_row2im.
	 */

	// Without a destination there is nothing to compute. (The previous code
	// went on to divide by numGroups == 0 here.)
	if (derData == NULL) {
		return ;
	}

	// This port stores no explicit kernel shape, so infer it from the number
	// of weights.
	int filters_getDepth = 0, filters_getHeight = 0, filters_getWidth = 0 ;
	if (filters_biases->权重长度 == 36864) {        // 3*3*64 weights per filter, 64 filters
		filters_getHeight = 3 ;
		filters_getWidth = 3 ;
		filters_getDepth = 64 ;
	} else if (filters_biases->权重长度 == 16) {    // 4*4*1 weights per filter, 1 filter
		filters_getHeight = 4 ;
		filters_getWidth = 4 ;
		filters_getDepth = 1 ;
	}
	if (filters_getDepth == 0) {
		// Unknown kernel: the old code divided by zero on the next line.
		printf("vl_impl_nnconv_backward_blas_CPU_float: 未知的卷积核尺寸!\n");
		return ;
	}

	int const numGroups = derData->depth / filters_getDepth ;
	if (numGroups <= 0) {
		// Fewer input channels than the kernel depth: there is no valid group.
		printf("vl_impl_nnconv_backward_blas_CPU_float: 输入通道数小于卷积核深度!\n");
		return ;
	}
	int const numOutputPixels = derOutput->height * derOutput->width ;
	int const filtersVolume = filters_getHeight * filters_getWidth * filters_getDepth ;
	int const numFiltersPerGroup = derOutput->depth / numGroups ;

	// Scratch space: one numOutputPixels x filtersVolume matrix per group.
	// new[] takes an element count (not bytes), and it throws std::bad_alloc
	// on failure instead of returning NULL.
	int const tempVolume = numOutputPixels * filtersVolume * numGroups ;
	float *tempMemory = (tempVolume > 0) ? new float[tempVolume] : NULL ;

	/* compute derData dz/dx */
	float alpha = 1 ;
	float beta = 0 ;
	// Every group must be multiplied: row2im below reads all numGroups slabs
	// of tempMemory. Handling only group 0 left the other slabs uninitialised.
	for (int g = 0 ; g < numGroups ; ++g) {
		size_t const filterGrpOffset = (size_t)filtersVolume * numFiltersPerGroup * g ;
		size_t const tempGrpOffset = (size_t)numOutputPixels * filtersVolume * g ;
		size_t const derOutputGrpOffset = (size_t)numOutputPixels * numFiltersPerGroup * g ;
		gemm(
			'n', 't',
			numOutputPixels, filtersVolume, numFiltersPerGroup,
			alpha,
			(float*)derOutput->data + derOutputGrpOffset, numOutputPixels,
			(float*)filters_biases->权重_数据 + filterGrpOffset, filtersVolume,
			beta,
			tempMemory + tempGrpOffset, numOutputPixels) ;
	}

	// Width and height are passed swapped on purpose. row2im_cpu declares
	// its parameters width-first, and vl_impl_row2im forwards them
	// positionally, so the swap is undone there.
	vl_impl_row2im(
		(float*)derData->data,
		tempMemory,
		derData->width, derData->height, derData->depth,
		filters_getHeight, filters_getWidth,
		strideY, strideX,
		padTop, padBottom, padLeft, padRight) ;

	delete [] tempMemory ;    // delete[] on NULL is a no-op
}
vl_impl_row2im:
void vl_impl_row2im(
	float* data,
	float const* stacked,
	size_t height, size_t width, size_t depth,
	size_t windowHeight, size_t windowWidth,
	size_t strideY, size_t strideX,
	size_t padTop, size_t padBottom, size_t padLeft, size_t padRight)
{
	/*
	 * Thin adapter around row2im_cpu.
	 *
	 * row2im_cpu was written for MATLAB's column-major layout, so normally
	 * `data` would be transposed before the call. Here `data` carries no
	 * values yet (row2im_cpu zeroes it before accumulating), so the transpose
	 * is skipped.
	 *
	 * The arguments are forwarded positionally, with no reordering. Because
	 * row2im_cpu declares its parameters in width-first (X-before-Y) order,
	 * each "height" value here lands in its "width" slot, and so on. The
	 * caller in this file compensates by passing width and height swapped.
	 */
	row2im_cpu(data, stacked,
	           height, width, depth,
	           windowHeight, windowWidth,
	           strideY, strideX,
	           padTop, padBottom,
	           padLeft, padRight) ;
}
row2im_cpu:
/*
 * row2im_cpu: the reverse of im2col. Scatters the patch matrix `stacked`
 * back into the planar image `data`; values from overlapping windows are
 * summed.
 *
 * NOTE: the size/window/stride/pad parameters are declared in X-before-Y
 * (width-before-height) order. This is the reverse of vl_impl_row2im, which
 * forwards its arguments positionally; the author marked this "转置"
 * (transposed).
 *
 * Layout, as indexed below:
 *   data    : depth planes of height x width floats, x fastest, i.e. index
 *             (z * height + y) * width + x. It is zeroed before accumulation.
 *   stacked : windowWidth * windowHeight * depth rows. Row r is one
 *             patchesY x patchesX block (patch x fastest), where
 *             r = (z * windowHeight + tapY) * windowWidth + tapX.
 *
 * Only patch positions whose tap lands inside the unpadded image are read.
 * The bounds come from the ceil_divide / floor_divide / static_min /
 * static_max helpers.
 */
static inline void
row2im_cpu(float* data,
           float const* stacked,
           size_t width,
           size_t height,
           size_t depth,
           size_t windowWidth,
           size_t windowHeight,
           size_t strideX,
           size_t strideY,
           size_t padLeft,
           size_t padRight,
           size_t padTop,
           size_t padBottom)
{
	// Number of window placements along each axis.
	int patchesX = (width + (padLeft + padRight) - windowWidth)/strideX + 1 ;
	int patchesY = (height + (padTop + padBottom) - windowHeight)/strideY + 1 ;
	int tapCount = windowWidth * windowHeight * depth ;
	// Each row of `stacked` holds exactly patchesY * patchesX values.
	long long const rowSize = (long long)patchesX * patchesY ;

	memset(data, 0, sizeof(float) * width * height * depth) ;

	for (int tap = 0 ; tap < tapCount ; ++tap) {
		// Split the row index into (tapX, tapY, plane); tapX varies fastest.
		int tapX = tap % windowWidth ;
		int rest = tap / windowWidth ;
		int tapY = rest % windowHeight ;
		int plane = rest / windowHeight ;

		// Patch index ranges [begin, end) for which this tap falls inside
		// the unpadded image.
		int beginX = static_min(patchesX, ceil_divide(padLeft - tapX, strideX)) ;
		int beginY = static_min(patchesY, ceil_divide(padTop - tapY, strideY)) ;
		int endX = static_min(patchesX, floor_divide(width-1 + padLeft - tapX, strideX) + 1) ;
		int endY = static_min(patchesY, floor_divide(height-1 + padTop - tapY, strideY) + 1) ;

		float const* rowBase = stacked + tap * rowSize ;

		for (int py = static_max(0, beginY) ; py < endY ; ++py) {
			int px = static_max(0, beginX) ;
			int y_data = py * strideY + tapY - padTop ;
			int x_data = px * strideX + tapX - padLeft ;
			float * dst = data + (plane * height + y_data) * width + x_data ;
			float const* src = rowBase + (long long)py * patchesX ;
			for ( ; px < endX ; ++px) {
				*dst += src[px] ;
				dst += strideX ;
			}
		}
	}
}
vl_nnconvt函数已完成。