欢迎您访问程序员文章站！本站旨在为大家提供分享程序员计算机编程知识！
您现在的位置是: 首页

Spring Boot 添加 XssFilter 过滤器

程序员文章站 2022-04-28 16:41:46
...

第一部分代码（先在 ResultEnum 中新增错误码 XSS_ERROR(90006, "入参含有非法字符")，然后编写过滤器）：

@Component
@Slf4j
@WebFilter(filterName = "xssFilter", urlPatterns = "/*")
@Order(5)
public class XssFilter implements Filter {

    private static final String SCRIPT_LOW_REGEX = ".*((((\\%3C)|<)[^\\n]+((\\%3E)|>))|(((\\%22)|"|(\\%27)|')[(\\%20) ]*((\\%2B)|\\+|(\\%3B)|;))|(((\\%3D)|=)[(\\%20) ]*((\\%22)|"|(\\%27)|'))).*";
    private static final String SCRIPT_UPPER_REGEX = ".*((((\\%3C)|<)[^\\n]+((\\%3E)|>))|(((\\%22)|"|(\\%27)|')[(\\%20) ]*((\\%2B)|\\+|(\\%3B)|;))|(((\\%3D)|=)[(\\%20) ]*((\\%22)|"|(\\%27)|'))).*";
    private static final String SQL_REGEX = ".*((\\%27)|(\\')).*((\\-\\-)|(((\\%6F)|o|(\\%4F))((\\%72)|r|(\\%52)))|((\\%3B)|(;))).*";
    private static final String NOT_PROTECT = "/undefined$";

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {

    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        System.out.println("*************执行过滤xssFilter******");
        HttpServletResponse res = (HttpServletResponse) servletResponse;
        res.setCharacterEncoding("UTF-8");
        res.setContentType("application/json; charset=utf-8");
        HttpServletRequest request ;
        if (servletRequest instanceof HttpServletRequest) {
            request = new XssHttpServletRequestWrapper((HttpServletRequest) servletRequest);
        } else {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }

        Map<String, String[]> parameterMap = getParams(request);
        if (Objects.isNull(parameterMap) || parameterMap.size() == 0) {
            parameterMap = request.getParameterMap();
        }

        String url = request.getRequestURI();
        final Pattern scriptLowRegex = Pattern.compile(SCRIPT_LOW_REGEX);
        final Pattern sqlRegex = Pattern.compile(SQL_REGEX);
        final Pattern notProject = Pattern.compile(NOT_PROTECT);
        final Pattern scriptUpperRegex = Pattern.compile(SCRIPT_UPPER_REGEX);
        if (!notProject.matcher(url).find()) {
            if (Objects.nonNull(parameterMap) && parameterMap.size() > 0) {
                Iterator<Map.Entry<String, String[]>> iterator = parameterMap.entrySet().iterator();
                while (iterator.hasNext()) {
                    Map.Entry<String, String[]> next = iterator.next();
                    String[] value = next.getValue();
                    for (int i = 0; i < value.length; i++) {
                        String paraValue = value[i].toUpperCase();
                        paraValue = replaceStr(paraValue);
                        if (scriptLowRegex.matcher(paraValue).matches()|| sqlRegex.matcher(paraValue).matches() || scriptUpperRegex.matcher(paraValue).matches()) {
                            PrintWriter writer = null;
                            try {
                                writer = res.getWriter();
                                String resultVal = JSON.toJSONString(ApiRes.getInstance(ResultEnum.XSS_ERROR));
                                writer.write(resultVal);
                                writer.flush();
                                writer.close();
                            } catch (Exception e) {
                                log.error("AuthPathFilter Error" + e.getMessage(), e);
                            } finally {
                                if (null != writer) {
                                    writer.close();
                                }
                            }
                            return;
                        }
                    }
                }
            }
        }
        filterChain.doFilter(request, servletResponse);
        return;
    }

    @Override
    public void destroy() {

    }

    private Map<String, String[]> getParams(HttpServletRequest request) {
        Map<String, String[]> paras = null;
        BufferedReader streamReader = null;
        try {
            streamReader = new BufferedReader(new InputStreamReader(request.getInputStream(), "UTF-8"));
            StringBuilder responseStrBuilder = new StringBuilder();
            String inputStr;
            while ((inputStr = streamReader.readLine()) != null) {
                responseStrBuilder.append(inputStr);
            }
            String jString = responseStrBuilder.toString();
            if (StringUtils.isNotBlank(jString)) {
                boolean valid = isJSONValid(jString);
                if (valid) {
                    paras = new HashMap<>();
                    JSONObject jsonObject = JSONObject.parseObject(jString);
                    Iterator<Map.Entry<String, Object>> iterator = jsonObject.entrySet().iterator();
                    while (iterator.hasNext()) {
                        Map.Entry<String, Object> next = iterator.next();
                        String key = next.getKey();
                        if (Objects.isNull(next.getValue())) {
                            paras.put(key, new String[]{""});
                        } else {
                            String value = next.getValue().toString();
                            paras.put(key, new String[]{value});
                        }
                    }
                }
            }
        } catch (Exception e) {
            log.error("参数解析错误", e);
        }
        return paras;
    }

    private boolean isJSONValid(String jsonInString) {
        try {
            final ObjectMapper mapper = new ObjectMapper();
            mapper.readTree(jsonInString);
            return true;
        } catch (IOException e) {
            return false;
        }
    }

    private String replaceStr(String value) {
        value = value.replace("+", "%2B")
                //.replace("/", "%2F")
                .replace("?", "%3F")
                .replace("%", "%25")
                .replace("#", "%23")
                .replace("&", "%26")
                .replace("=", "%3D")
                .replace("@", "%40")
                .replace(":", "%3A")
                .replace(";", "%3B")
                .replace("<", "%3C")
                .replace(">", "%3E")
                .replace("\\", "%5C")
                .replace("|", "%7C")
                .replace("$", "%24")
                .replace("^", "%5E")
                .replace(",", "%2C")
                .replace("'", "%27")
                .replace("=", "%3D")
                .replace("[", "%5B")
                .replace("]", "%5D")
                .replace("{", "%7B")
                .replace("}", "%7D")
                .replace("\"", "%22");
        return value;
    }

}

第二部分代码:

/**
 * Request wrapper that buffers the entire request body in memory so it can be read more than once.
 *
 * <p>The body is read eagerly in the constructor. Every call to {@link #getInputStream()} or
 * {@link #getReader()} returns a fresh stream over the same bytes.
 *
 * <p>NOTE(review): because the wrapped request's input stream is consumed in the constructor, the
 * container may no longer parse {@code application/x-www-form-urlencoded} bodies into request
 * parameters — confirm form posts still work behind this wrapper.
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Request body captured at construction time. */
    private final byte[] bytes;

    /**
     * Constructs a request object wrapping the given request and buffers its body.
     *
     * @param request The request to wrap
     * @throws IOException if the request body cannot be read
     * @throws IllegalArgumentException if the request is null
     */
    public XssHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        // Read the whole body once and keep it so it can be replayed.
        bytes = IOUtils.toByteArray(request.getInputStream());
    }

    /** Returns a new stream positioned at the start of the buffered body. */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return new BufferedServletInputStream(this.bytes);
    }

    /**
     * Returns a reader over the buffered body. Without this override the wrapped request's reader
     * would be used, which is not allowed once its input stream has been obtained (done in the
     * constructor). Uses the request's character encoding, falling back to UTF-8.
     */
    @Override
    public java.io.BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        return new java.io.BufferedReader(
                new java.io.InputStreamReader(getInputStream(), encoding != null ? encoding : "UTF-8"));
    }

    /**
     * {@link ServletInputStream} backed by an in-memory byte array. Reads never block, so the
     * stream is always ready. Non-blocking {@link ReadListener} callbacks are not supported.
     */
    static class BufferedServletInputStream extends ServletInputStream {
        private final ByteArrayInputStream inputStream;

        public BufferedServletInputStream(byte[] buffer) {
            this.inputStream = new ByteArrayInputStream(buffer);
        }

        @Override
        public int available() throws IOException {
            return inputStream.available();
        }

        @Override
        public int read() throws IOException {
            return inputStream.read();
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            return inputStream.read(b, off, len);
        }

        /** True once every buffered byte has been consumed. */
        @Override
        public boolean isFinished() {
            return inputStream.available() == 0;
        }

        /** Always true: the data is already in memory, so a read never blocks. */
        @Override
        public boolean isReady() {
            return true;
        }

        /** Intentionally a no-op; asynchronous reading is not supported by this stream. */
        @Override
        public void setReadListener(ReadListener listener) {
            // not supported
        }
    }
}
相关标签: Spring Boot xss