欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

redis+lua+拦截器实现限流

程序员文章站 2022-07-09 20:34:32
...

前言

当非法用户死命调用你的接口(机器攻击)时,怎么办?
正常情况下,用户是不会那么频繁的通过前端调用你的接口的。一般出现某个用户极其频繁的调用你的接口时,那就一定要小心了,可能是想搞你!!!所以,一定要在API调用前端加个限流策略,也就是将用户的一段时间的访问次数记下来,超过某个值的时候,拒绝其访问。这种限流,可以加在nginx里面,也可以加在项目的过滤器中。
但是这种高频数据放在哪呢?数据库?那你的数据库可能直接就炸了!!!ok,还是放在redis里吧!这时候就要考虑操作的原子性了。
今天用lua来实现这个功能!

实现

lua实现的限流脚本

-- ip限流脚本
-- 限定每个ip在expire_time时间段内只能访问limit次
-- 脚本返回0,说明已到达上限,返回1,说明没有到达上限

local ip = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = ARGV[2]

local exist = redis.call('exists', ip)
if exist == 1 then
    if redis.call('incr', ip) > limit then
        return 0
    else
        return 1
    end
else
    redis.call('set', ip, 1);
    redis.call('expire', ip, expire_time)
    return 1
end

实现一个限流处理器

package com.zyu.boot.demo.utils.limit;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Component;

import java.util.Arrays;

/**
 * redis实现IP限流器
 */
@Component
public class IPLimiter {

    private RedisTemplate redisTemplate;
    /**
     * 每个流控周期允许访问的次数
     */
    private final int limit = 5;
    /**
     * 流控的时间段:秒
     */
    private final int expire = 2;
    /**
     * 限流的脚本
     */
    private final String LIMITER_SCRIPT_PATH = "lua/ipLimiter.lua";

    private DefaultRedisScript<Boolean> limiterScript;

    public IPLimiter(@Autowired RedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
        limiterScript = new DefaultRedisScript();
        limiterScript.setLocation(new ClassPathResource(LIMITER_SCRIPT_PATH));
        limiterScript.setResultType(Boolean.class);
    }

    /**
     * 返回true表示未触发流控,false表示触发流控
     *
     * @param ip
     * @return
     */
    public boolean limiterValidate(String ip) {
        return (Boolean) redisTemplate.execute(limiterScript, Arrays.asList(new String[]{ip}), limit, expire);
    }
}

限流器测试代码

package com.zyu.boot.demo;

import com.zyu.boot.demo.utils.limit.IPLimiter;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.junit4.SpringRunner;

@SpringBootTest(classes = {DemoApplication.class})
@RunWith(SpringRunner.class)
public class LimiterTest {
    @Autowired
    IPLimiter ipLimiter;
    @Test
    public void test(){
        System.out.println(ipLimiter.limiterValidate("127.0.0.1"));
        System.out.println(ipLimiter.limiterValidate("127.0.0.1"));
        System.out.println(ipLimiter.limiterValidate("127.0.0.1"));
        System.out.println(ipLimiter.limiterValidate("127.0.0.1"));
        System.out.println(ipLimiter.limiterValidate("127.0.0.1"));
        System.out.println(ipLimiter.limiterValidate("127.0.0.1"));
        Thread.sleep(2000);
        System.out.println(ipLimiter.limiterValidate("127.0.0.1"));
        System.out.println(ipLimiter.limiterValidate("127.0.0.1"));
        System.out.println(ipLimiter.limiterValidate("127.0.0.1"));
        System.out.println(ipLimiter.limiterValidate("127.0.0.1"));
        System.out.println(ipLimiter.limiterValidate("127.0.0.1"));
        System.out.println(ipLimiter.limiterValidate("127.0.0.1"));
    }
}

redis+lua+拦截器实现限流

写个拦截器

这里使用拦截器实现,还可以使用过滤器、或者spring的AOP配合注解实现更细粒度的方法层限流

package com.zyu.boot.demo.interceptor;

import com.alibaba.fastjson.JSONObject;
import com.zyu.boot.demo.utils.item.RespEntity;
import com.zyu.boot.demo.utils.limit.IPLimiter;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.PrintWriter;

/**
 * 实现流控的过滤器
 */
@Component
public class LimiterInterceptor implements HandlerInterceptor {
    @Autowired
    private IPLimiter ipLimiter;

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if (request.getRequestURI().contains(".html") ||
                request.getRequestURI().contains(".css") ||
                request.getRequestURI().contains(".js")) {
            //静态资源请求放行
            return true;
        }
        String ip = getIPAddress(request);
        if (ipLimiter.limiterValidate(ip)) {
            return true;
        } else {
            PrintWriter writer = response.getWriter();
            writer.write(JSONObject.toJSONString(new RespEntity(-999, "触发流控", null)));
            writer.close();
        }
        return false;
    }

    /**
     * 根据request获取客户端IP地址
     *
     * @param request
     * @return
     */
    private String getIPAddress(HttpServletRequest request) {
        String ip = null;

        //X-Forwarded-For:Squid 服务代理
        String ipAddresses = request.getHeader("X-Forwarded-For");
        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
            //Proxy-Client-IP:apache 服务代理
            ipAddresses = request.getHeader("Proxy-Client-IP");
        }
        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
            //WL-Proxy-Client-IP:weblogic 服务代理
            ipAddresses = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
            //HTTP_CLIENT_IP:有些代理服务器
            ipAddresses = request.getHeader("HTTP_CLIENT_IP");
        }
        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
            //X-Real-IP:nginx服务代理
            ipAddresses = request.getHeader("X-Real-IP");
        }

        //有些网络通过多层代理,那么获取到的ip就会有多个,一般都是通过逗号(,)分割开来,并且第一个ip为客户端的真实IP
        if (ipAddresses != null && ipAddresses.length() != 0) {
            ip = ipAddresses.split(",")[0];
        }

        //还是不能获取到,最后再通过request.getRemoteAddr();获取
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
            ip = request.getRemoteAddr();
        }
        return ip.equals("0:0:0:0:0:0:0:1") ? "127.0.0.1" : ip;
    }
}

配置拦截器

package com.zyu.boot.demo.config;

import com.zyu.boot.demo.interceptor.LimiterInterceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.annotation.Order;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

import java.util.ArrayList;
import java.util.List;

@Configuration
@Order(1)
public class WebConfig implements WebMvcConfigurer {

    @Autowired
    private LimiterInterceptor limiterInterceptor;// 流控拦截器
    /**
     * 配置拦截器
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        //配置拦截器,执行顺序和声明顺序一致
        registry.addInterceptor(limiterInterceptor).addPathPatterns("/**");
        WebMvcConfigurer.super.addInterceptors(registry);
    }
}

结束语

基本功能还是实现了,开心!