欢迎您访问程序员文章站！本站旨在为大家分享程序员计算机编程知识！
您现在的位置是: 首页

Java：解决 POST 请求的 request 输入流只能读取一次的问题

程序员文章站 2024-02-03 11:21:52
...

最近在做文件上传下载功能，需要从 request 中获取文件流，结果发现 POST 请求的 request 输入流只能读取一次，之后就读不到数据了。原因是请求体来自底层连接，Servlet 输入流读完后无法回退重读。思索了一下午，终于想到了解决方法：精髓是把 request 的输入流读出并保存为字节数组，再用一个包装类替换原来的 request，让它每次都基于这份字节数组返回新的输入流，这样就可以反复读取了。代码如下：

package com.openailab.oascloud.gateway.util;

import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.parser.Feature;
import com.google.common.collect.Maps;
import com.netflix.zuul.context.RequestContext;
import com.netflix.zuul.http.HttpServletRequestWrapper;
import com.netflix.zuul.http.ServletInputStreamWrapper;
import com.openailab.oascloud.common.util.IPUtil;
import com.openailab.oascloud.gateway.filter.CharacterEncodeFilter;
import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileUploadException;
import org.apache.commons.fileupload.disk.DiskFileItemFactory;
import org.apache.commons.fileupload.servlet.ServletFileUpload;
import org.apache.commons.fileupload.servlet.ServletRequestContext;
import org.apache.tomcat.util.http.fileupload.FileItemIterator;
import org.apache.tomcat.util.http.fileupload.FileItemStream;
import org.apache.tomcat.util.http.fileupload.util.Streams;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StreamUtils;

import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Helpers used by the gateway's Zuul filters to read and enrich request parameters.
 *
 * <p>A servlet request body can only be consumed once. Every method here that reads the body
 * therefore caches it as a byte array and replaces the request in the {@link RequestContext} with
 * a wrapper whose input stream can be read again by later consumers.
 *
 * @Classname: com.openailab.oascloud.gateway.util.ParamUtil
 * @Author: zxzhang
 * @Date: 2019/6/27
 */
public class ParamUtil {

    private static final Logger LOGGER = LoggerFactory.getLogger(ParamUtil.class);

    /**
     * Extracts the request parameters (GET query string, JSON POST body, or multipart form).
     *
     * <ul>
     *   <li>URIs starting with {@code /zuul}: the body is parsed as multipart; form fields map to
     *       their value, file fields map to the uploaded file name.</li>
     *   <li>GET: the first value of every query parameter.</li>
     *   <li>POST: the body parsed as a JSON object, preserving field order.</li>
     * </ul>
     *
     * <p>Whenever the body is read it is cached and the request in {@code ctx} is replaced, so the
     * body is still readable afterwards.
     *
     * @param ctx the current Zuul request context
     * @return the extracted parameters; {@code null} for a GET without query parameters, an empty
     *     POST body, any other HTTP method, or when reading/parsing fails (the error is logged).
     *     On the multipart path a failure may yield a partially filled (possibly empty) map.
     * @author zxzhang
     * @date 2019/6/27
     */
    public static Map<String, Object> getRequestParams(RequestContext ctx) {
        HttpServletRequest request = ctx.getRequest();
        String method = request.getMethod();
        String uri = request.getRequestURI();
        // The way parameters are obtained depends on the path and the HTTP method.
        Map<String, Object> param = null;
        try {
            if (uri.startsWith("/zuul")) {
                // Cache the body first: multipart parsing consumes the stream.
                byte[] requestByte = saveaIns(request.getInputStream());
                rewriteRequest(ctx, requestByte);
                param = Maps.newLinkedHashMap();
                ServletFileUpload upload = new ServletFileUpload(new DiskFileItemFactory());
                upload.setHeaderEncoding("UTF-8");
                // The rewritten request returns a fresh stream on every getInputStream() call, so
                // parsing here leaves the body readable for later consumers; no second rewrite needed.
                List<FileItem> items = upload.parseRequest(new ServletRequestContext(ctx.getRequest()));
                try {
                    for (FileItem item : items) {
                        String name = item.getFieldName();
                        if (item.isFormField()) {
                            param.put(name, item.getString("UTF-8"));
                        } else {
                            param.put(name, item.getName());
                        }
                    }
                } finally {
                    // Only names/values are kept, so release any storage (e.g. temp files) the items hold.
                    for (FileItem item : items) {
                        item.delete();
                    }
                }
            } else if ("GET".equalsIgnoreCase(method)) {
                Map<String, List<String>> query = ctx.getRequestQueryParams();
                if (query != null && !query.isEmpty()) {
                    param = Maps.newLinkedHashMap();
                    for (Map.Entry<String, List<String>> entry : query.entrySet()) {
                        param.put(entry.getKey(), entry.getValue().get(0));
                    }
                }
            } else if ("POST".equalsIgnoreCase(method)) {
                // Cache the body and put it back so later readers (other filters, routing) still see it.
                byte[] requestByte;
                try (InputStream inputStream = request.getInputStream()) {
                    requestByte = StreamUtils.copyToByteArray(inputStream);
                }
                rewriteRequest(ctx, requestByte);
                String body = new String(requestByte, StandardCharsets.UTF_8);
                // NOTE(review): logs the full raw body at INFO; may expose sensitive fields.
                LOGGER.info("***************原始参数:{}***************", body);
                param = JSONObject.parseObject(body, LinkedHashMap.class, Feature.OrderedField);
            }
        } catch (Exception e) {
            // Throwable passed as the last argument without a placeholder so the stack trace is logged.
            LOGGER.error("***************ParamUtil->getRequestParams throw Exception***************", e);
        }
        return param;
    }

    /**
     * Adds {@code source}, {@code client} and {@code ip} to the request parameters.
     *
     * <p>For GET the values are added to the query parameters. For POST the body is parsed as a
     * JSON object, the values are added, and the request in {@code ctx} is replaced with one whose
     * body is the re-serialized JSON. Requests under {@code /zuul} and other HTTP methods are left
     * untouched.
     *
     * @param ctx the current Zuul request context
     * @author zxzhang
     * @date 2019/6/28
     */
    public static void setRequestParams(RequestContext ctx) {
        HttpServletRequest request = ctx.getRequest();
        String uri = request.getRequestURI();
        if (uri.startsWith("/zuul")) {
            return;
        }
        // NOTE(review): `> 0` does not match a URI that starts with "/platform" or "/web"; confirm intended.
        String source = uri.indexOf("/platform") > 0 ? "platform" : "application";
        String client = uri.indexOf("/web") > 0 ? "web" : "app";
        String method = request.getMethod();
        // The way parameters are set depends on the HTTP method.
        if ("GET".equalsIgnoreCase(method)) {
            Map<String, List<String>> param = ctx.getRequestQueryParams();
            if (param == null || param.isEmpty()) {
                param = Maps.newHashMap();
            }
            param.put("source", Arrays.asList(source));
            param.put("client", Arrays.asList(client));
            param.put("ip", Arrays.asList(IPUtil.getClientIp(request)));
            ctx.setRequestQueryParams(param);
        } else if ("POST".equalsIgnoreCase(method)) {
            try (InputStream inputStream = request.getInputStream()) {
                byte[] originalBody = StreamUtils.copyToByteArray(inputStream);
                // Restore the original body first so it is not lost if JSON parsing below throws.
                rewriteRequest(ctx, originalBody);
                Map<String, Object> param =
                        JSONObject.parseObject(new String(originalBody, StandardCharsets.UTF_8));
                if (param == null || param.isEmpty()) {
                    param = Maps.newHashMap();
                }
                param.put("source", source);
                param.put("client", client);
                param.put("ip", IPUtil.getClientIp(request));
                // Serialize explicitly as JSON: a plain HashMap's toString() is "{k=v}", not JSON.
                rewriteRequest(ctx, JSONObject.toJSONString(param).getBytes(StandardCharsets.UTF_8));
            } catch (IOException e) {
                LOGGER.error("***************ParamUtil->setRequestParams failed to read request body***************", e);
            }
        }
    }

    /**
     * Replaces the request held by {@code ctx} with a wrapper whose body is {@code paramBytes}.
     *
     * <p>Every call to the wrapper's {@code getInputStream()} returns a new stream over the same
     * bytes, so the body can be read repeatedly; the reported content length matches the bytes.
     *
     * @param ctx        the current Zuul request context
     * @param paramBytes the new request body; {@code null} is treated as an empty body
     * @author zxzhang
     * @date 2019/10/21
     */
    private static void rewriteRequest(RequestContext ctx, byte[] paramBytes) {
        final byte[] body = paramBytes == null ? new byte[0] : paramBytes;
        ctx.setRequest(new HttpServletRequestWrapper(ctx.getRequest()) {
            @Override
            public ServletInputStream getInputStream() throws IOException {
                return new ServletInputStreamWrapper(body);
            }

            @Override
            public int getContentLength() {
                return body.length;
            }

            @Override
            public long getContentLengthLong() {
                return body.length;
            }
        });
    }


    /**
     * Reads {@code ins} fully into a byte array so the data survives the stream being consumed.
     *
     * <p>Wrap the result in a {@link ByteArrayInputStream} wherever an {@link InputStream} is
     * needed again, e.g. {@code new HSSFWorkbook(new ByteArrayInputStream(buf))}. If the stream
     * only has to be read once, use it directly instead.
     *
     * @param ins the stream to read; may be {@code null}
     * @return the stream's bytes, or {@code null} if {@code ins} is {@code null} or reading fails
     *     (the failure is logged)
     */
    public static byte[] saveaIns(InputStream ins) {
        if (ins == null) {
            return null;
        }
        try {
            return org.apache.commons.io.IOUtils.toByteArray(ins);
        } catch (IOException e) {
            LOGGER.error("***************ParamUtil->saveaIns failed to read input stream***************", e);
            return null;
        }
    }
}

这段代码用于在 Zuul 过滤器中缓存 POST 请求的输入流（包括文件上传流），使其可以被重复读取和利用。需要注意的是，整个请求体会被一次性读入内存，上传大文件时要留意内存占用。