欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

测试开发进阶(三十六)

程序员文章站 2022-04-10 20:58:58
...

项目模块

list优化

def list(self, request, *args, **kwargs):	
    queryset = self.filter_queryset(self.get_queryset())	
    page = self.paginate_queryset(queryset)	
    if page is not None:	
        serializer = self.get_serializer(page, many=True)	
        datas = serializer.data	
        datas = get_count_by_project(datas)	
        return self.get_paginated_response(datas)	
    serializer = self.get_serializer(queryset, many=True)	
    datas = serializer.data	
    datas = get_count_by_project(datas)	
    return Response(datas)

这个 list其实就是拷贝了父类中的 list方法

使用

super().list(request, *args, **kwargs)

调用父类的 list方法

查看返回的 Response对象

测试开发进阶(三十六)

所以优化为:

def list(self, request, *args, **kwargs):	
    response = super().list(request, *args, **kwargs)	
    response.data['results'] = get_count_by_project(response.data['results'])	
    return response

重写getserializerclass

names中的 serializer使用 serializers.ProjectNameSerializer

为了让它可以直接使用 self.get_serializer方法,重写 get_serializer_class

源码

def get_serializer_class(self):	
    """	
    Return the class to use for the serializer.	
    Defaults to using `self.serializer_class`.	
    You may want to override this if you need to provide different	
    serializations depending on the incoming request.	
    (Eg. admins get full serialization, others get basic serialization)	
    """	
    assert self.serializer_class is not None, (	
        "'%s' should either include a `serializer_class` attribute, "	
        "or override the `get_serializer_class()` method."	
        % self.__class__.__name__	
    )	
    return self.serializer_class

重写

def get_serializer_class(self):	
    if self.action == 'names':	
        return serializers.ProjectNameSerializer	
    else:	
        return self.serializer_class

报告模块

序列化器

from datetime import datetime	
from rest_framework import serializers	
from .models import Reports	
class ReportsSerializer(serializers.ModelSerializer):	
    """	
    报告序列化器	
    """	
    class Meta:	
        model = Reports	
        exclude = ('update_time', 'is_delete')	
        extra_kwargs = {	
            'html': {	
                'write_only': True	
            },	
            'create_time': {	
                'read_only': True	
            }	
        }	
    def create(self, validated_data):	
        report_name = validated_data['name']	
        validated_data['name'] = f"{report_name}_{datetime.strftime(datetime.now(), '%Y%m%d%H%M%S')}"	
        report = Reports.objects.create(**validated_data)	
        return report

从数据库中可以看出其中 html是一串字符串,它需要转换成html格式才可以正常展示,所以在接口返回的内容中不应该包含它,设置它为只写模式 write_only

create函数进行重定义

下面是数据库中显示的内容

测试开发进阶(三十六)

name = models.CharField('报告名称', max_length=200, unique=True, help_text='报告名称')

查看 models文件可以看到 name字段是唯一的「 unique=True」所以我们在添加的时候需要携带上当前的时间信息

视图

定义一个类 ReportsViewSet还是继承 ModelViewSet

其他和之前的类似

其中要注意的是一个 download接口

import re	
import os	
from datetime import datetime	
from django.conf import settings	
from django.http import StreamingHttpResponse	
from rest_framework.viewsets import ModelViewSet	
from rest_framework import permissions	
from rest_framework.decorators import action	
from reports.utils import format_output, get_file_contents	
from .models import Reports	
from .serializers import ReportsSerializer	
class ReportsViewSet(ModelViewSet):	
    """	
    list:	
    返回测试报告(多个)列表数据	
    create:	
    创建测试报告	
    retrieve:	
    返回测试报告(单个)详情数据	
    update:	
    更新(全)测试报告	
    partial_update:	
    更新(部分)测试报告	
    destroy:	
    删除测试报告	
    """	
    queryset = Reports.objects.filter(is_delete=False)	
    serializer_class = ReportsSerializer	
    permission_classes = (permissions.IsAuthenticated,)	
    ordering_fields = ('id', 'name')	
    def perform_destroy(self, instance):	
        instance.is_delete = True	
        instance.save()	
    def list(self, request, *args, **kwargs):	
        response = super().list(request, *args, **kwargs)	
        response.data['results'] = format_output(response.data['results'])	
        return response	
    @action(detail=True)	
    def download(self, request, pk=None):	
        instance = self.get_object()	
        html = instance.html	
        name = instance.name	
        mtch = re.match(r'(.*_)\d+', name)	
        if mtch:	
            mtch = mtch.group(1) + datetime.strftime(datetime.now(), '%Y%m%d%H%M%S') + '.html'	
        report_dir = os.path.join(settings.BASE_DIR, 'reports')	
        report_path = os.path.join(report_dir, mtch)	
        with open(report_path, 'w') as f:	
            f.write(html)	
        response = StreamingHttpResponse(get_file_contents(report_path))	
        response['Content-Type'] = "application/octet-stream"	
        response['Content-Disposition'] = "attachment; filename*=UTF-8''{}".format(name)	
        return response

每次下载之后我们都会在本地存放一次,然后我们需要以数据流的方式返回html报告

response = StreamingHttpResponse(get_file_contents(report_path))
def get_file_contents(filename, chunk_size=512):	
    with open(filename,encoding='utf8') as f:	
        while True:	
            c = f.read(chunk_size)	
            if c:	
                yield c	
            else:	
                break

这里用到了分段的方式,每512字节返回一次,直到全部返回完毕