
Spring AOP rate limiting with Redis-backed token bucket and leaky bucket

Published on 程序员文章站 (Programmer Article Site), 2022-06-25 11:28:06

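Overview: rate limiting is implemented as a Spring AOP aspect. Each check runs as a Redis Lua script, so reading and updating a bucket happens atomically. The limit can apply globally, to a single application instance, to a client IP, or even to a field of a method argument extracted with a SpEL (Spring Expression Language) expression.

Time units: rolling windows (per second, per minute, per hour, ...) and calendar ("natural") periods such as the current minute or the current day.

Algorithms: token bucket and leaky bucket.

Refunds: if the protected method throws an exception, the consumed count can optionally be returned to the bucket.

Several limits can be combined on one method, for example a daily quota together with a weekly or yearly quota. Each limit has its own priority (order) and its own refund-on-exception setting.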

1. Define the rate-limit unit enum

package com.around.common.utils.limitrate;

import java.util.Calendar;

/**
 * @description: rate-limit time unit
 * @author: moodincode
 * @create 2021/1/8
 */
public enum LimitUnitEnum {
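    /** PER_* units are rolling windows measured from the current time, not aligned to the calendar **/
    PER_SECOND(0,"PER_SECOND","per second"),
    PER_MINUTES(0,"PER_MINUTES","per minute"),
    PER_HOUR(0,"PER_HOUR","per hour"),
    PER_DAY(0,"PER_DAY","per day"),
    PER_WEEK(0,"PER_WEEK","per week"),
    PER_MONTH(0,"PER_MONTH","per month"),
    PER_YEAR(0,"PER_YEAR","per year"),
    /** Calendar ("natural") periods: the window ends with the current period. For MONTH, for example, a counter started on Jan 31 resets on Feb 1 **/
    MINUTES(1,"MINUTES","calendar minute"),
    HOUR(1,"HOUR","calendar hour"),
    DAY(1,"DAY","calendar day"),
    WEEK(1,"WEEK","calendar week"),
    MONTH(1,"MONTH","calendar month"),
    YEAR(1,"YEAR","calendar year"),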

    ;
    private Integer type;
    private String code;
    private String name;

    public Integer getType() {
        return type;
    }

    public void setType(Integer type) {
        this.type = type;
    }

    public String getCode() {
        return code;
    }

    public void setCode(String code) {
        this.code = code;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    LimitUnitEnum(Integer type,String code, String name) {
        this.type = type;
        this.code = code;
        this.name = name;
    }

    public Long getExpireTs(){
        return getExpireTs(System.currentTimeMillis());
    }
    public Long getExpireTs(Long compareTime){
        switch (this){
            case PER_SECOND:
                return compareTime+1000L;
            case PER_MINUTES:
                return  compareTime+60000L;
            case PER_HOUR:
                return compareTime+60*60000L;
            case PER_DAY:
                return  compareTime+24*60*60000L;
            case PER_WEEK:
                return compareTime+7*24*60*60000L;
            case PER_MONTH:
                // approximated as 30 days
                return compareTime+30*24*60*60000L;
            case PER_YEAR:
                // approximated as 365 days (leap years ignored)
                return compareTime+365*24*60*60000L;
            default:
                return getNaturalTime(compareTime);
        }

    }

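    /**
     * Returns the last millisecond of the calendar period that contains compareTime
     * @param compareTime reference time, epoch milliseconds
     * @return end of the current minute/hour/day/week/month/year, epoch milliseconds
     */
    private Long getNaturalTime(Long compareTime) {
        Calendar c=Calendar.getInstance();
        c.setTimeInMillis(compareTime);
        // Intentional fall-through (no break): each larger unit also sets every smaller field
        switch (this){
            case YEAR:
                c.set(Calendar.MONTH,Calendar.DECEMBER);
            case MONTH:
                c.set(Calendar.DAY_OF_MONTH,c.getActualMaximum(Calendar.DAY_OF_MONTH));
            case WEEK:
                if(WEEK.equals(this)){
                    // DAY_OF_WEEK: Sunday=1 ... Saturday=7; move forward to Sunday (the week ends on Sunday)
                    int week = c.get(Calendar.DAY_OF_WEEK);
                    if(week>1){
                        c.add(Calendar.DATE,8-week);
                    }
                }
            case DAY:
                // HOUR_OF_DAY is the 24-hour field. Calendar.HOUR is the 12-hour field, and setting
                // it to 23 in the afternoon rolls over into the next day
                c.set(Calendar.HOUR_OF_DAY,23);
            case HOUR:
                c.set(Calendar.MINUTE,59);
            case MINUTES:
                c.set(Calendar.SECOND,59);
            default:
                break;
        }
        c.set(Calendar.MILLISECOND,999);
        return c.getTimeInMillis();
    }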

}
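The fall-through switch in getNaturalTime is easy to get subtly wrong. For example, Calendar.HOUR is the 12-hour field and HOUR_OF_DAY is the 24-hour field. As a cross-check, the sketch below computes the same "last millisecond of the current calendar period" with java.time. It assumes weeks end on Sunday, as the WEEK branch above does, and it takes an explicit zone instead of the JVM default that Calendar.getInstance() uses. The class NaturalPeriods and its Unit enum are hypothetical names.

```java
import java.time.DayOfWeek;
import java.time.Instant;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalAdjusters;

class NaturalPeriods {
    // Hypothetical mirror of the calendar ("natural") units in LimitUnitEnum
    enum Unit { MINUTES, HOUR, DAY, WEEK, MONTH, YEAR }

    /** Last millisecond (epoch millis) of the calendar period that contains nowMillis. */
    static long endOfPeriod(Unit unit, long nowMillis, ZoneId zone) {
        ZonedDateTime now = Instant.ofEpochMilli(nowMillis).atZone(zone);
        ZonedDateTime nextStart;
        switch (unit) {
            case MINUTES:
                nextStart = now.truncatedTo(ChronoUnit.MINUTES).plusMinutes(1);
                break;
            case HOUR:
                nextStart = now.truncatedTo(ChronoUnit.HOURS).plusHours(1);
                break;
            case DAY:
                nextStart = now.toLocalDate().plusDays(1).atStartOfDay(zone);
                break;
            case WEEK:
                // The week ends on Sunday, like the WEEK branch of getNaturalTime
                nextStart = now.toLocalDate()
                        .with(TemporalAdjusters.nextOrSame(DayOfWeek.SUNDAY))
                        .plusDays(1).atStartOfDay(zone);
                break;
            case MONTH:
                nextStart = now.toLocalDate().with(TemporalAdjusters.firstDayOfNextMonth()).atStartOfDay(zone);
                break;
            default: // YEAR
                nextStart = now.toLocalDate().with(TemporalAdjusters.firstDayOfNextYear()).atStartOfDay(zone);
                break;
        }
        // One millisecond before the next period starts, i.e. ...:59.999
        return nextStart.toInstant().toEpochMilli() - 1;
    }

    public static void main(String[] args) {
        long t = Instant.parse("2021-01-08T10:15:30Z").toEpochMilli(); // a Friday
        for (Unit u : Unit.values()) {
            System.out.println(u + " -> " + Instant.ofEpochMilli(endOfPeriod(u, t, ZoneOffset.UTC)));
        }
    }
}
```

Computing the start of the next period and subtracting one millisecond avoids setting each field by hand. It also handles month lengths and leap years automatically.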

2. Define the rate-limit type enum

package com.around.common.utils.limitrate;

/**
 * @program: com-around
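 * @description: rate-limit type, i.e. key scope (IP, application, global) plus algorithm; code 0 = token bucket, 1 = leaky bucket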
 * @author: moodincode
 * @create: 2021/1/08
 **/
public enum LimitTypeEnum {


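    /**
     * Per client IP, token bucket (keyed by the requesting user's IP)
     */
    IP_TOKEN(0,"IP_TOKEN","per IP - token bucket"),
    /**
     * Per application instance, token bucket (keyed by the deployed instance)
     */
    APP_TOKEN(0,"APP_TOKEN","per application - token bucket"),
    /**
     * Global, token bucket (one bucket shared by all callers and instances)
     */
    GLOBAL_TOKEN(0,"GLOBAL_TOKEN","global - token bucket"),
    /**
     * Per client IP, leaky bucket (keyed by the requesting user's IP)
     */
    IP_LEAKY(1,"IP_LEAKY","per IP - leaky bucket"),
    /**
     * Per application instance, leaky bucket (keyed by the deployed instance)
     */
    APP_LEAKY(1,"APP_LEAKY","per application - leaky bucket"),
    /**
     * Global, leaky bucket (one bucket shared by all callers and instances)
     */
    GLOBAL_LEAKY(1,"GLOBAL_LEAKY","global - leaky bucket"),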
    ;
    private int code;
    private String name;
    private String decr;

    LimitTypeEnum(int code, String name, String decr) {
        this.code = code;
        this.name = name;
        this.decr = decr;
    }

    public int getCode() {
        return code;
    }

    public String getName() {
        return name;
    }

    public String getDecr() {
        return decr;
    }



}

3. Define the common rate-limit interface

package com.around.common.utils.limitrate;

/**
 * @description: rate-limit provider
 * @author: moodincode
 * @create 2021/1/8
 */
public interface LimitRateProvider {
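    /**
     * Consumes num units from the bucket stored under key
     * @param key bucket key
     * @param num number of units to consume
     * @param ts time of this consumption, epoch milliseconds
     * @param expire absolute time (epoch milliseconds) at which the bucket key expires and is reset
     * @param rate tokens generated per interval (token bucket) or units drained per interval (leaky bucket)
     * @param interval refill/drain interval in milliseconds; should be greater than 0
     * @param capacity maximum bucket size (LimitRateUtil replaces values below 1 with rate)
     * @param type algorithm: 0 = token bucket, 1 = leaky bucket
     * @param compareTime time used to count elapsed intervals, epoch milliseconds; 0 means "now"
     * @return remaining tokens (token bucket) or remaining free space (leaky bucket); -1 if the request is rejected
     */
    Long consumeCount(String key, Long num, Long ts, Long expire, Long rate, Long interval, Long capacity, int type, Long compareTime);

    /**
     * Reads the current count without changing it
     * @param key bucket key
     * @return 0 if the key does not exist or the count is 0, otherwise a positive number
     */
    Long getCurrentCount(String key);

    /**
     * Gives units back to a bucket, e.g. to refund a consumption
     * @param key bucket key
     * @param num number of units to give back
     * @param capacity upper bound for the token count; values below 1 mean no bound
     * @param type algorithm: 0 = token bucket, 1 = leaky bucket
     * @return -1 if the key has already expired
     */
    Long addCount(String key,long num,long capacity,int type);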
}

4. Write the Redis serializer

package com.around.common.utils.limitrate;

import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.SerializationException;
import java.nio.charset.StandardCharsets;

/**
 * @description: serializes Long values as plain decimal strings so the Lua scripts can read them with tonumber()
 * @author: moodincode
 * @create 2021/1/8
 */
public class LongStringSerialize implements RedisSerializer<Long> {
    @Override
    public byte[] serialize(Long number) throws SerializationException {
        // Decimal text, e.g. 1000L -> "1000"
        if (number == null) {
            return null;
        }
        return String.valueOf(number).getBytes(StandardCharsets.UTF_8);
    }

    @Override
    public Long deserialize(byte[] bytes) throws SerializationException {
        if (bytes == null || bytes.length == 0) {
            return null;
        }
        return Long.valueOf(new String(bytes,StandardCharsets.UTF_8));
    }
}
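A dedicated serializer is needed because the Lua scripts call tonumber(ARGV[i]), and tonumber() only understands plain numeric text. The plain-JDK sketch below (no Spring needed) compares the bytes this serializer produces with what Java's built-in object serialization emits for the same Long. The comparison is meant to show why a JDK-style value serializer would not work as the argument serializer here. The class name ArgEncodingDemo is hypothetical.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.nio.charset.StandardCharsets;

class ArgEncodingDemo {
    // Same encoding as LongStringSerialize.serialize: the decimal text in UTF-8
    static byte[] asDecimalText(Long value) {
        return String.valueOf(value).getBytes(StandardCharsets.UTF_8);
    }

    // Plain JDK object serialization of the same value, for comparison
    static byte[] asJdkSerialized(Long value) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(value);
        }
        return bytes.toByteArray();
    }

    public static void main(String[] args) throws IOException {
        // "1000": Lua's tonumber() parses this as the number 1000
        System.out.println(new String(asDecimalText(1000L), StandardCharsets.UTF_8));
        byte[] jdk = asJdkSerialized(1000L);
        // Starts with the serialization stream magic AC ED, which tonumber() turns into nil
        System.out.printf("%02X %02X ... (%d bytes)%n", jdk[0], jdk[1], jdk.length);
    }
}
```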

5. Write the Redis provider implementation and the Lua scripts

package com.around.common.utils.limitrate;

import com.alibaba.fastjson.support.spring.GenericFastJsonRedisSerializer;
import com.google.common.collect.Lists;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import java.util.List;

/**
 * @description: Redis implementation of LimitRateProvider; every check runs as one Lua script, so it is atomic
 * @author: moodincode
 * @create 2021/1/8
 */
@Component
public class RedisLimitRateProvider implements LimitRateProvider {
    @Resource
    private RedisTemplate<String,Object> redisTemplate;
    /** Token-bucket consume script. ARGV[1]: num, ARGV[2]: interval, ARGV[3]: rate, ARGV[4]: capacity, ARGV[5]: compareTime, ARGV[6]: ts, ARGV[7]: expire **/
    private static final String CONSUME_TOKEN_SCRIP="local r_key = tostring(KEYS[1]);\n" +
            "local num = tonumber(ARGV[1]);\n" +
            "local interval = tonumber(ARGV[2]);\n" +
            "local rate = tonumber(ARGV[3]);\n" +
            "local capacity = tonumber(ARGV[4]);\n" +
            "local compareTime = tonumber(ARGV[5]);\n" +
            "local ts = tonumber(ARGV[6]);\n" +
            "local expire = tonumber(ARGV[7]);\n" +
            "local haskey = redis.call('exists', r_key);\n" +
            "if tonumber(haskey) == 1 then\n" +
            "    local gen_ts = redis.call('hget', r_key, 'gen_ts');\n" +
            "    local count = redis.call('hget', r_key, 'count');\n" +
            // number of whole intervals elapsed since the last refill
            "    local factor = math.ceil((compareTime - gen_ts) / interval - 1);\n" +
            "    if -factor > 0 then\n" +
            // plain assignment: 'local factor = 0' here would only declare a shadowing variable
            "        factor = 0\n" +
            "    end;\n" +
            "    local add_count = factor * rate;\n" +
            "    if add_count > 0 then\n" +
            "        count = count + add_count;\n" +
            "        if capacity > 0 and count > capacity then\n" +
            "            count = capacity;\n" +
            "        end;\n" +
            "        redis.call('hset', r_key, 'count', count);\n" +
            "        redis.call('hset', r_key, 'gen_ts', ts);\n" +
            "        redis.call('pexpireat', r_key, expire);\n" +
            "    end;\n" +
            "    if (count - num) > -1 then\n" +
            "        redis.call('hset', r_key, 'count', count - num);\n" +
            "        redis.call('hset', r_key, 'get_ts', ts);\n" +
            "        return (count - num);\n" +
            "    else\n" +
            "        return -1;\n" +
            "    end;\n" +
            "else\n" +
            // first request in the window: start with a full bucket and take num tokens
            "    redis.call('hset', r_key, 'gen_ts', ts);\n" +
            "    redis.call('hset', r_key, 'count', capacity - num);\n" +
            "    redis.call('hset', r_key, 'get_ts', ts);\n" +
            "    redis.call('pexpireat', r_key, expire);\n" +
            "    return (capacity - num);\n" +
            "end;";

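    /** Give tokens back (used for rollback); returns the new count, or -1 if the key has expired */
    private static final String ADD_TOKEN_SCRIP="local num = tonumber(ARGV[1]);\n" +
            "local capacity = tonumber(ARGV[2]);\n" +
            // EXISTS returns the integer 1, so compare with 1 rather than the string '1'
            "if redis.call('exists', KEYS[1]) == 1 then\n" +
            "    local count = tonumber(redis.call('hget', KEYS[1], 'count'));\n" +
            "    count = count + num;\n" +
            "    if capacity > 0 and count > capacity then count = capacity; end;\n" +
            "    redis.call('hset', KEYS[1], 'count', count);\n" +
            "    return count;\n" +
            "else\n" +
            "    return -1;\n" +
            "end;";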
    /** Leaky-bucket consume script. ARGV[1]: num, ARGV[2]: interval, ARGV[3]: rate, ARGV[4]: capacity, ARGV[5]: compareTime, ARGV[6]: ts, ARGV[7]: expire **/
    private static final String CONSUME_LEAKY_SCRIP="local r_key = tostring(KEYS[1]);\n" +
            "local num = tonumber(ARGV[1]);\n" +
            "local interval = tonumber(ARGV[2]);\n" +
            "local rate = tonumber(ARGV[3]);\n" +
            "local capacity = tonumber(ARGV[4]);\n" +
            "local compareTime = tonumber(ARGV[5]);\n" +
            "local ts = tonumber(ARGV[6]);\n" +
            "local expire = tonumber(ARGV[7]);\n" +
            "local haskey = redis.call('exists', r_key);\n" +
            "if tonumber(haskey) == 1 then\n" +
            "    local gen_ts = redis.call('hget', r_key, 'gen_ts');\n" +
            "    local count = redis.call('hget', r_key, 'count');\n" +
            // number of whole intervals elapsed since the last drain
            "    local factor = math.ceil((compareTime - gen_ts) / interval - 1);\n" +
            "    if -factor > 0 then\n" +
            // plain assignment: 'local factor = 0' here would only declare a shadowing variable
            "        factor = 0\n" +
            "    end;\n" +
            "    local de_count = factor * rate;\n" +
            "    if de_count > 0 then\n" +
            "        count = count - de_count;\n" +
            "        if -count > 0 then\n" +
            "            count = 0;\n" +
            "        end;\n" +
            "        redis.call('hset', r_key, 'count', count);\n" +
            "        redis.call('hset', r_key, 'gen_ts', ts);\n" +
            "        redis.call('pexpireat', r_key, expire);\n" +
            "    end;\n" +
            "    local total = (count + num);\n" +
            "    if capacity - total > -1 then\n" +
            "        redis.call('hset', r_key, 'count', total);\n" +
            "        redis.call('hset', r_key, 'get_ts', ts);\n" +
            "        return (capacity - total);\n" +
            "    else\n" +
            "        return -1;\n" +
            "    end;\n" +
            "else\n" +
            // first request in the window: the bucket holds only this request
            "    redis.call('hset', r_key, 'gen_ts', ts);\n" +
            "    redis.call('hset', r_key, 'count', num);\n" +
            "    redis.call('hset', r_key, 'get_ts', ts);\n" +
            "    redis.call('pexpireat', r_key, expire);\n" +
            "    return (capacity - num);\n" +
            "end;";
    /** Remove units from the leaky bucket (used for rollback); returns the new fill level, or -1 if the key has expired */
    private static final String DEDUCE_LEAKY_SCRIP="local num = tonumber(ARGV[1]);\n" +
            // EXISTS returns the integer 1, so compare with 1 rather than the string '1'
            "if redis.call('exists', KEYS[1]) == 1 then\n" +
            "    local count = tonumber(redis.call('hget', KEYS[1], 'count'));\n" +
            "    count = count - num;\n" +
            "    if count < 0 then count = 0; end;\n" +
            "    redis.call('hset', KEYS[1], 'count', count);\n" +
            "    return count;\n" +
            "else\n" +
            "    return -1;\n" +
            "end;";


    /**
     * Consumes num units from the bucket stored under key
     *
     * @param key         bucket key
     * @param num         number of units to consume
     * @param ts          time of this consumption (epoch ms); values below 1 mean "now"
     * @param expire      absolute expiry time of the key (epoch ms), passed to PEXPIREAT
     * @param rate        tokens generated (token bucket) or units drained (leaky bucket) per interval
     * @param interval    refill/drain interval in milliseconds
     * @param capacity    maximum bucket size
     * @param type        algorithm: 0 = token bucket, 1 = leaky bucket
     * @param compareTime time used to count elapsed intervals (epoch ms); values below 1 mean "now"
     * @return remaining tokens or remaining free space; -1 if the request is rejected
     */
    @Override
    public Long consumeCount(String key, Long num, Long ts, Long expire, Long rate, Long interval, Long capacity, int type, Long compareTime) {
        long millis = System.currentTimeMillis();
        if (compareTime < 1) {
            compareTime = millis;
        }
        if(ts<1){
            ts=millis;
        }
        // With a Long result type, Spring reads the script's integer reply and returns it as a Long
        DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>(type == 0 ? CONSUME_TOKEN_SCRIP : CONSUME_LEAKY_SCRIP, Long.class);
        // Arguments are sent as plain decimal strings so that tonumber() in Lua can parse them
        LongStringSerialize longSerializer = new LongStringSerialize();
        return redisTemplate.execute(redisScript, longSerializer, longSerializer, Lists.newArrayList(key),
                num, interval, rate, capacity, compareTime, ts, expire);
    }
    /**
     * Reads the current count without changing it
     *
     * @param key bucket key
     * @return 0 if the key does not exist or the count is 0, otherwise a positive number
     */
    @Override
    public Long getCurrentCount(String key) {
        Object count = redisTemplate.opsForHash().get(key, "count");
        if(count!=null){
            return Long.valueOf(count.toString());
        }
        return 0L;
    }

    /**
     * Gives units back to a bucket, e.g. to refund a consumption when the protected call fails
     *
     * @param key      bucket key
     * @param num      tokens to add back (token bucket) or units to remove (leaky bucket)
     * @param capacity upper bound for the token count after the refund (token bucket); values below 1 mean no bound
     * @param type     algorithm: 0 = token bucket, 1 = leaky bucket
     * @return -1 if the key has already expired
     */
    @Override
    public Long addCount(String key, long num, long capacity, int type) {
        DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>(type == 0 ? ADD_TOKEN_SCRIP : DEDUCE_LEAKY_SCRIP, Long.class);
        LongStringSerialize longSerializer = new LongStringSerialize();
        return redisTemplate.execute(redisScript, longSerializer, longSerializer, Lists.newArrayList(key), num, capacity);
    }

}
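To reason about the refill and drain arithmetic without Redis, here is a plain-Java transcription of the core of CONSUME_TOKEN_SCRIP and CONSUME_LEAKY_SCRIP for a key that already exists. It keeps the scripts' formula factor = ceil(elapsed / interval - 1), clamped at 0, and leaves out the Redis writes and the expiry handling. The class and method names are illustrative only.

```java
class BucketMath {
    /** Whole intervals credited by the scripts: ceil(elapsed / interval - 1), never negative. */
    static long factor(long compareTime, long genTs, long interval) {
        long f = (long) Math.ceil((double) (compareTime - genTs) / interval - 1);
        return Math.max(f, 0);
    }

    /** Token bucket: refill, cap at capacity, then take num tokens. Returns remaining tokens or -1. */
    static long consumeToken(long count, long genTs, long now, long interval, long rate, long capacity, long num) {
        long add = factor(now, genTs, interval) * rate;
        if (add > 0) {
            count += add;
            if (capacity > 0 && count > capacity) {
                count = capacity;
            }
        }
        return count - num > -1 ? count - num : -1;
    }

    /** Leaky bucket: drain, floor at 0, then pour num in. Returns remaining free space or -1. */
    static long consumeLeaky(long count, long genTs, long now, long interval, long rate, long capacity, long num) {
        long drained = factor(now, genTs, interval) * rate;
        if (drained > 0) {
            count = Math.max(count - drained, 0);
        }
        long total = count + num;
        return capacity - total > -1 ? capacity - total : -1;
    }

    public static void main(String[] args) {
        // Empty token bucket: 5 tokens per 1000 ms, capacity 5, last refill at t=0
        System.out.println(consumeToken(0, 0, 1000, 1000, 5, 5, 1)); // -1: exactly one interval adds nothing yet
        System.out.println(consumeToken(0, 0, 1001, 1000, 5, 5, 1)); // 4
        // Full leaky bucket (5 of 5), draining 5 per 1000 ms
        System.out.println(consumeLeaky(5, 0, 500, 1000, 5, 5, 1));  // -1: nothing drained yet
        System.out.println(consumeLeaky(5, 0, 2500, 1000, 5, 5, 1)); // 4
    }
}
```

One consequence of the "- 1" in the formula is that a bucket is only credited once strictly more than one full interval has passed since the last refill or drain.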

6. Write the HTTP request helper

package com.around.common.utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.util.*;

/**
 * @program:
 * @description: HTTP helper
 * @author: moodincode
 * @create: 2020/9/28
 **/
public class WebRequestUtil {
    private final static Logger log = LoggerFactory.getLogger(WebRequestUtil.class);
    private static String localIp;

    public static HttpServletRequest getRequest() {
        if (RequestContextHolder.getRequestAttributes() != null) {
            return ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
        } else {
            return null;
        }
    }

    public static HttpServletResponse getResponse() {
        if (RequestContextHolder.getRequestAttributes() != null) {
            return ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getResponse();
        } else {
            return null;
        }
    }

    private static final String UNKNOWN = "unknown";

    private static final String SYMBOL = ",";

    /**
     * Reads the HTTP request body as a single line (CR, LF and tab characters removed)
     *
     * @param request current request
     * @return body text, or an empty string if it could not be read
     */
    public static String readRequestBodyParams(HttpServletRequest request) {
        StringBuilder sb = new StringBuilder();
        // try-with-resources closes the reader exactly once
        try (BufferedReader br = request.getReader()) {
            String str;
            while ((str = br.readLine()) != null) {
                sb.append(str);
            }
        } catch (IOException e) {
            log.error("Failed to read the request body", e);
        }
        return sb.toString().replaceAll("\r|\n|\t", "");
    }

    public static String getIpAddress() {
        return getIpAddress(getRequest());
    }

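    /**
     * Returns the client IP address. The proxy headers checked below come from the client or from
     * intermediate proxies and are not verified here, so a client can forge them.
     *
     * @param request current request, may be null
     * @return client IP
     */
    public static String getIpAddress(HttpServletRequest request) {
        // Outside a request (e.g. a scheduled job), fall back to this machine's IP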
        if (request == null) {
            return getLocalIpAddress();
        }
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }
        if (ip == null || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        }
        if (ip == null || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }
        // Behind several proxies the header holds a comma-separated chain; take the first entry as the client IP
        if (ip != null && ip.indexOf(SYMBOL) != -1) {
            ip = ip.substring(0, ip.indexOf(SYMBOL)).trim();
        }

        return ip;
    }

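    /**
     * Returns the value of the cookie with the given name
     *
     * @param request current request
     * @param key     cookie name
     * @return cookie value, or null if absent
     */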
    public static String getCookie(HttpServletRequest request, String key) {
        String token = null;
        Cookie[] cookies = request.getCookies();
        if (cookies != null) {
            for (int i = 0; i < cookies.length; i++) {
                if (key.equals(cookies[i].getName())) {
                    token = cookies[i].getValue();
                    break;
                }
            }
        }
        return token;
    }

    /**
     * Returns all request headers as "name:value" strings
     *
     * @param request current request
     * @return header lines
     */
    public static List<String> getHeaders(HttpServletRequest request) {
        List<String> headList = new ArrayList<>();
        Enumeration<String> headers = request.getHeaderNames();
        while (headers.hasMoreElements()) {
            String headName = headers.nextElement();
            headList.add(String.format("%s:%s", headName, request.getHeader(headName)));
        }
        return headList;
    }

    /**
     * Returns all request headers as a name-to-value map
     *
     * @param request current request
     * @return header map
     */
    public static Map<String, String> getHeaderMap(HttpServletRequest request) {
        Map<String, String> headList = new HashMap<>();
        Enumeration<String> headers = request.getHeaderNames();
        while (headers.hasMoreElements()) {
            String headName = headers.nextElement();
            headList.put(headName, request.getHeader(headName));
        }
        return headList;
    }

    /**
     * Returns this machine's IP address, preferring a site-local (private) address; the result is cached
     *
     * @return local IP address, or null if it cannot be determined
     */
    public static String getLocalIpAddress() {
        if (localIp != null) {
            return localIp;
        }
        try {
            InetAddress candidateAddress = null;
            // Iterate over all network interfaces
            for (Enumeration<NetworkInterface> ifaces = NetworkInterface.getNetworkInterfaces(); ifaces.hasMoreElements(); ) {
                NetworkInterface iface = ifaces.nextElement();
                // ...and over every address bound to each interface
                for (Enumeration<InetAddress> inetAddrs = iface.getInetAddresses(); inetAddrs.hasMoreElements(); ) {
                    InetAddress inetAddr = inetAddrs.nextElement();
                    // Skip loopback addresses
                    if (!inetAddr.isLoopbackAddress()) {
                        if (inetAddr.isSiteLocalAddress()) {
                            // A site-local address is the preferred answer
                            localIp = inetAddr.getHostAddress();
                            return localIp;
                        } else if (candidateAddress == null) {
                            // No site-local address yet: remember the first non-loopback one as a fallback
                            candidateAddress = inetAddr;
                        }
                    }
                }
            }
            if (candidateAddress != null) {
                localIp = candidateAddress.getHostAddress();
                return localIp;
            }
            // No non-loopback address found: last resort is whatever the JDK reports for the local host
            InetAddress jdkSuppliedAddress = InetAddress.getLocalHost();
            localIp = jdkSuppliedAddress.getHostAddress();
        } catch (Exception e) {
            log.error("Failed to determine the local IP address", e);
        }
        return localIp;
    }

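    /**
     * Reads the request body as a string. request.getReader() can only be consumed once, so the
     * request must be wrapped in a re-readable wrapper if the body is needed again later
     * @param request current request
     * @return body text, or null if it could not be read
     */
    public static String getBodyString(HttpServletRequest request){
        // read line by line
        try {
            BufferedReader br = request.getReader();
            StringBuilder wholeStr = new StringBuilder();
            String str;
            while((str = br.readLine()) != null){
                wholeStr.append(str);
            }
            return wholeStr.toString();
        } catch (IOException e) {
            log.error("Failed to read the request body", e);
        }
        return null;
    }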
}
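getIpAddress trusts X-Forwarded-For and the similar headers unconditionally. A client can therefore send a different value on every request and get a fresh IP_TOKEN/IP_LEAKY bucket each time. One common mitigation, sketched below, is to honor the header only when the direct peer (getRemoteAddr()) is one of your own proxies, and then take the right-most address that is not a trusted proxy. The sketch assumes you know your proxies' addresses. ClientIpResolver and the trusted-proxy set are hypothetical, and the sketch works on plain strings so it runs without a servlet container.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class ClientIpResolver {
    // Addresses of our own reverse proxies / load balancers (assumed to be known in advance)
    private final Set<String> trustedProxies;

    ClientIpResolver(Set<String> trustedProxies) {
        this.trustedProxies = trustedProxies;
    }

    /**
     * @param remoteAddr   the direct peer, i.e. what request.getRemoteAddr() returns
     * @param forwardedFor raw X-Forwarded-For header value, may be null
     * @return the address to use in the rate-limit key
     */
    String resolve(String remoteAddr, String forwardedFor) {
        // The header only means something if one of our proxies added it
        if (!trustedProxies.contains(remoteAddr) || forwardedFor == null || forwardedFor.trim().isEmpty()) {
            return remoteAddr;
        }
        List<String> hops = Arrays.asList(forwardedFor.split(","));
        // Each proxy appends the peer it saw, so walk from the right and skip our own proxies;
        // anything further left was supplied by the client and cannot be trusted
        for (int i = hops.size() - 1; i >= 0; i--) {
            String hop = hops.get(i).trim();
            if (!hop.isEmpty() && !trustedProxies.contains(hop)) {
                return hop;
            }
        }
        return remoteAddr;
    }

    public static void main(String[] args) {
        ClientIpResolver resolver = new ClientIpResolver(new HashSet<>(Arrays.asList("10.0.0.1")));
        // Direct client: a forged header is ignored
        System.out.println(resolver.resolve("203.0.113.7", "1.2.3.4"));
        // Via our proxy: the client prepended a fake entry, the proxy appended the real peer
        System.out.println(resolver.resolve("10.0.0.1", "1.2.3.4, 198.51.100.9"));
    }
}
```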

7. Write the time-window calculation helper

package com.around.common.utils.limitrate;

/**
 * @description: time-window calculation for a rate limit
 * @author: moodincode
 * @create 2021/1/08
 */
public class LimitTime {
    private Long millis;
    private Long expire;
    private Long intermission;
    private Long compareTime;

    public LimitTime(Long millis, Long expire, Long intermission, Long compareTime) {
        this.millis = millis;
        this.expire = expire;
        this.intermission = intermission;
        this.compareTime = compareTime;
    }

    public Long getMillis() {
        return millis;
    }

    public void setMillis(Long millis) {
        this.millis = millis;
    }

    public Long getExpire() {
        return expire;
    }

    public void setExpire(Long expire) {
        this.expire = expire;
    }

    public Long getIntermission() {
        return intermission;
    }

    public void setIntermission(Long intermission) {
        this.intermission = intermission;
    }

    public Long getCompareTime() {
        return compareTime;
    }

    public void setCompareTime(Long compareTime) {
        this.compareTime = compareTime;
    }
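    /**
     * Computes the time window for one check
     * @param millis   current time, epoch ms
     * @param interval refill/drain interval in ms; values below 1000 mean "one whole unit"
     * @param unit     time unit of the limit
     * @return window with the key's expiry time and the effective interval
     */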
    public static LimitTime calculateTime(long millis, Long interval,LimitUnitEnum unit) {
        LimitTime time=new LimitTime(millis,-1L,interval,millis);
        Long expireTs = unit.getExpireTs(millis);
        time.setExpire(expireTs);
        if(interval<1000L){
            time.setIntermission(expireTs-millis);
        }

        return time;
    }
}
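Here is a worked example of how `interval` and `unit` combine in calculateTime. The sketch re-implements only the rolling (PER_*) case, taking the unit length in milliseconds so it runs on its own. LimitTimeDemo and the unitMillis parameter are hypothetical stand-ins for LimitTime and LimitUnitEnum.getExpireTs.

```java
import java.util.Arrays;

class LimitTimeDemo {
    /**
     * Mirrors LimitTime.calculateTime for a rolling (PER_*) unit whose length is unitMillis.
     * @return {expire, intermission}
     */
    static long[] calculate(long millis, long interval, long unitMillis) {
        long expire = millis + unitMillis;   // what getExpireTs returns for PER_* units
        long intermission = interval < 1000L // intervals below 1000 ms mean "one whole unit"
                ? expire - millis
                : interval;
        return new long[] {expire, intermission};
    }

    public static void main(String[] args) {
        long now = 1_610_000_000_000L;
        // @LimitRate defaults (interval = 1, unit = PER_SECOND): refill interval 1000 ms, key lives 1 s
        System.out.println(Arrays.toString(calculate(now, 1, 1_000L)));
        // interval = 1, unit = PER_MINUTES: the interval equals the key's lifetime,
        // so in practice this acts like a fixed window of `rate` calls per 60 s
        System.out.println(Arrays.toString(calculate(now, 1, 60_000L)));
        // interval = 10_000, unit = PER_MINUTES: refill every 10 s, key expires 60 s after creation or last refill
        System.out.println(Arrays.toString(calculate(now, 10_000, 60_000L)));
    }
}
```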

8. Write the rate-limit utility class

package com.around.common.utils.limitrate;

import com.around.common.utils.WebRequestUtil;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;

/**
 * @description: rate-limit utility
 * @author: moodincode
 * @create 2021/1/8
 */
@Component
public class LimitRateUtil {
    /** Key prefix for per-instance (APP_*) limits; it includes the startup time, so every restart starts new buckets **/
    public static String MACHINE_ID="";
    /** Key prefix for global (GLOBAL_*) limits */
    public static String GLOBAL="limit:00:";
    static {
        String localIpAddress = WebRequestUtil.getLocalIpAddress();
        if(!StringUtils.isEmpty(localIpAddress)){
            MACHINE_ID+=localIpAddress.replace(".","_").replace(":","_");
        }
        MACHINE_ID+="_"+System.currentTimeMillis();
        MACHINE_ID="limit:m"+Math.abs(MACHINE_ID.hashCode())+":";
    }
    @Resource
    private LimitRateProvider limitRateProvider;

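    /**
     * Consumes one unit for the current request
     * @param request  current request, used for IP-scoped limits
     * @param key      key name without the scope prefix
     * @param rate     tokens generated / units drained per interval
     * @param interval interval in ms; values below 1000 use the whole unit period
     * @param unit     time unit of the limit
     * @param type     key scope and algorithm
     * @param capacity bucket size; values below 1 fall back to rate
     * @return remaining tokens or free space; negative if the request is rejected
     */
    public Long consumeCount(HttpServletRequest request,String key, Long rate, Long interval, LimitUnitEnum unit, LimitTypeEnum type, Long capacity){
        // key prefix for the limit's scope
        String keySuffix=getKeySuffix(type,request);
        long millis = System.currentTimeMillis();
        // compute the time window
        LimitTime time=LimitTime.calculateTime(millis,interval,unit);
        return consumeCount(keySuffix + key,1L, rate, type, capacity, time);
    }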

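    /**
     * Consumes num units from the bucket under key
     * @param key      full Redis key, including the scope prefix
     * @param num      units to consume
     * @param rate     tokens generated / units drained per interval
     * @param type     key scope and algorithm
     * @param capacity bucket size; values below 1 fall back to rate
     * @param time     time window from LimitTime.calculateTime
     * @return remaining tokens or free space; -1 if the request is rejected
     */
    public Long consumeCount(String key,Long num, Long rate, LimitTypeEnum type, Long capacity, LimitTime time) {
        if(capacity<1){
            // default capacity: one interval's worth of tokens
            capacity=rate;
        }
        return limitRateProvider.consumeCount(key,num, time.getMillis(), time.getExpire(), rate, time.getIntermission(), capacity, type.getCode(), time.getCompareTime());
    }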


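    /**
     * Returns the Redis key prefix for the limit's scope (despite the method name, the value is used as a prefix)
     * @param type    key scope and algorithm
     * @param request current request, used for IP-scoped limits
     * @return key prefix ending with ':'
     */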
    public String getKeySuffix(LimitTypeEnum type, HttpServletRequest request) {
        if(LimitTypeEnum.GLOBAL_LEAKY.equals(type)||LimitTypeEnum.GLOBAL_TOKEN.equals(type)){
            return GLOBAL;
        }else if(LimitTypeEnum.APP_LEAKY.equals(type)||LimitTypeEnum.APP_TOKEN.equals(type)){
            return MACHINE_ID;
        }else {
            String ipAddress = WebRequestUtil.getIpAddress(request);
            return "limit:i"+ipAddress.replace(".","_").replace(":","_")+":";
        }
    }

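    /**
     * Gives units back to the bucket, used to roll back a consumption
     * @param key      full Redis key
     * @param num      units to give back
     * @param capacity bucket size
     * @param type     key scope and algorithm
     * @return -1 if the key has already expired
     */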
    public Long addCount(String key, Long num, Long capacity, LimitTypeEnum type){
        return limitRateProvider.addCount(key,num,capacity, type.getCode());
    }
    public static void main(String[] args) {
        System.out.println(MACHINE_ID);
    }

}
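For reference, this is what the Redis keys end up looking like. In the aspect, the full key is the scope prefix, then the annotation's name (or a class/method hash), then the optional SpEL value, with no separator between the last two. The sketch rebuilds the three prefixes in plain Java. The Scope enum and the sample name, SpEL value, machine seed and IP are made-up inputs.

```java
class LimitKeyDemo {
    // Hypothetical: the scope half of LimitTypeEnum
    enum Scope { GLOBAL, APP, IP }

    /** Mirrors LimitRateUtil.getKeySuffix plus the MACHINE_ID recipe from its static block. */
    static String prefix(Scope scope, String machineSeed, String clientIp) {
        switch (scope) {
            case GLOBAL:
                return "limit:00:";
            case APP:
                // MACHINE_ID: hash of "<local ip>_<startup millis>"
                return "limit:m" + Math.abs(machineSeed.hashCode()) + ":";
            default:
                // dots (IPv4) and colons (IPv6) become underscores
                return "limit:i" + clientIp.replace(".", "_").replace(":", "_") + ":";
        }
    }

    public static void main(String[] args) {
        String name = "orderQuery"; // hypothetical @LimitRate(name = "orderQuery")
        String elValue = "42";      // hypothetical value of @LimitRate(key = "#userId")
        for (Scope s : Scope.values()) {
            System.out.println(prefix(s, "192_168_1_5_1610000000000", "192.168.1.10") + name + elValue);
        }
    }
}
```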

9. Define the annotations

package com.around.common.utils.limitrate;

import java.lang.annotation.*;

/**
 * @description: rate-limit annotation; repeat it on a method to combine several limits
 * @author: moodincode
 * @create 2021/1/8
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Repeatable(LimitRates.class)
public @interface LimitRate {

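    /**
     * Redis key name; defaults to a hash of the class name and method name
     * @return
     */
    String name() default "";

    /**
     *  Key suffix as a SpEL expression evaluated against the method arguments, e.g. "#userId"
     * @return
     */
    String key() default "";

    /**
     * Message of the exception thrown when the call is rejected
     * @return
     */
    String msg() default "Too many requests, please try again later";

    /**
     * Tokens generated per interval (token bucket) or units drained per interval (leaky bucket); default 200
     * @return
     */
    long rate() default 200;

    /**
     * Refill/drain interval in milliseconds; values below 1000 use the whole unit period as the interval
     * @return
     */
    long interval() default 1;

    /**
     * Time unit of the limit: a rolling window (PER_*) or a calendar period
     * @return
     */
    LimitUnitEnum unit() default LimitUnitEnum.PER_SECOND;


    /**
     * Key scope and algorithm
     * @return
     */
    LimitTypeEnum type() default LimitTypeEnum.IP_TOKEN;

    /**
     * Maximum bucket size; values below 1 (including the default 0) use rate as the capacity
     * @return
     */
    long capacity() default 0;

    /**
     * Evaluation order when several @LimitRate annotations are present: higher values are checked first
     * @return
     */
    int order() default 0;

    /**
     * Whether to give the consumed count back when the method throws an exception
     * @return
     */
    boolean rbEx() default false;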
}

package com.around.common.utils.limitrate;

import java.lang.annotation.*;

/**
 * @description: container that allows @LimitRate to be repeated on one method
 * @author: moodincode
 * @create 2021/1/12
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface LimitRates {
    LimitRate[] value();
}
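Because @LimitRate is @Repeatable, putting it on a method twice makes the compiler store both copies inside a single @LimitRates container. The standalone sketch below, using hypothetical annotations Limit and Limits, shows the consequence. The method no longer carries the repeatable annotation directly, so a pointcut written only as @annotation(LimitRate) is likely to miss methods with several rules. getAnnotationsByType, which the aspect uses to read the rules, still returns every one of them.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Repeatable;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Repeatable(Limits.class)
@interface Limit {
    long rate();
}

// Container annotation, generated implicitly when @Limit is repeated
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@interface Limits {
    Limit[] value();
}

class RepeatableDemo {
    @Limit(rate = 5)
    public void single() { }

    @Limit(rate = 5)
    @Limit(rate = 1000)
    public void combined() { }

    public static void main(String[] args) throws NoSuchMethodException {
        Method single = RepeatableDemo.class.getMethod("single");
        Method combined = RepeatableDemo.class.getMethod("combined");
        System.out.println(single.isAnnotationPresent(Limit.class));           // true
        System.out.println(combined.isAnnotationPresent(Limit.class));         // false: wrapped in @Limits
        System.out.println(combined.isAnnotationPresent(Limits.class));        // true
        System.out.println(combined.getAnnotationsByType(Limit.class).length); // 2
    }
}
```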

10. Write the AOP aspect

package com.around.common.utils.limitrate;
import com.around.common.utils.WebRequestUtil;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.annotation.Order;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.text.MessageFormat;
import java.util.*;

/**
 * @description: rate-limit aspect
 * @author: moodincode
 * @create: 2021/01/12
 * Note: @Order(999) gives this aspect a low priority to avoid execution-order problems with other aspects
 **/
@Aspect
@Component
@Order(999)
public class LimitRateCheckAop {
    private static final Logger log= LoggerFactory.getLogger(LimitRateCheckAop.class);
    @Resource
    private LimitRateUtil limitRateUtil;
    /**
     * Checks every @LimitRate on the intercepted method before invoking it.
     * The pointcut also matches @LimitRates, because repeated annotations are stored in that container.
     * @param point join point of the intercepted method
     */
    @Around("@annotation(com.around.common.utils.limitrate.LimitRate) || @annotation(com.around.common.utils.limitrate.LimitRates)")
    public Object around(ProceedingJoinPoint point) throws Throwable {
        log.trace("enter LimitRateCheckAop");
        MethodSignature methodSignature = (MethodSignature)  point.getSignature();
        // Collect all rules (declared once or repeated), highest order first
        LimitRate[] limitRates = methodSignature.getMethod().getAnnotationsByType(LimitRate.class);
        Arrays.sort(limitRates, Comparator.comparing(LimitRate::order).reversed());
        HttpServletRequest request = WebRequestUtil.getRequest();
        Map<String,LimitRate>  rollbackKey=new HashMap<>();

        for (LimitRate limitRate : limitRates) {
            String key = buildRedisKey(point, methodSignature, limitRate);
            if(limitRate.rbEx()){
                String keySuffix = limitRateUtil.getKeySuffix(limitRate.type(), request);
                key=keySuffix+key;
                long millis = System.currentTimeMillis();
                // compute the time window
                LimitTime time=LimitTime.calculateTime(millis,limitRate.interval(),limitRate.unit());
                Long count = limitRateUtil.consumeCount(key, 1L, limitRate.rate(), limitRate.type(), limitRate.capacity(), time);
                if(count<0){
                    // rejected: refund the rules already charged for this call, then throw
                    doRollback(rollbackKey);
                    throw new RuntimeException(limitRate.msg());
                }else{
                    // remember the key so it can be refunded if the call fails
                    rollbackKey.put(key,limitRate);
                }

            }else{
                Long count = limitRateUtil.consumeCount(request, key, limitRate.rate(), limitRate.interval(), limitRate.unit(), limitRate.type(), limitRate.capacity());
                if(count<0){
                    // rejected: refund the rules already charged for this call, then throw
                    doRollback(rollbackKey);
                    throw new RuntimeException(limitRate.msg());
                }
            }
        }
        try {
            return point.proceed();
        }catch (Throwable e){
            // the call failed: refund the rules marked rbEx = true
            if(!CollectionUtils.isEmpty(rollbackKey)){
                doRollback(rollbackKey);
            }
            throw e;
        }

    }

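    /**
     * Gives one unit back for every recorded key
     * @param rollbackKey keys consumed for this call, with their annotations
     */
    private void doRollback(Map<String, LimitRate> rollbackKey) {
        for (Map.Entry<String, LimitRate> entry : rollbackKey.entrySet()) {
            LimitRate limitRate = entry.getValue();
            // same default as LimitRateUtil: capacities below 1 fall back to rate
            long capacity = limitRate.capacity() < 1 ? limitRate.rate() : limitRate.capacity();
            // return the unit consumed by this call
            limitRateUtil.addCount(entry.getKey(), 1L, capacity, limitRate.type());
        }
    }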

    /**
     * Builds the key name: the annotation's name (or a hash of class + method name) followed by the optional SpEL value
     * @param point           join point
     * @param methodSignature signature of the intercepted method
     * @param limitRate       the rule being checked
     * @return key name without the scope prefix
     */
    private String buildRedisKey(ProceedingJoinPoint point, MethodSignature methodSignature,LimitRate limitRate) {
        // method name
        String methodName = methodSignature.getName();
        // class name
        String className = methodSignature.getDeclaringTypeName();
        log.debug("class {} method {}", className, methodName);
        String keyPrefix;
        // without an explicit name, use the hash of "class:method"
        if(StringUtils.isEmpty(limitRate.name())){
            // the hash keeps the key short
            keyPrefix = String.valueOf(Math.abs(MessageFormat.format("{0}:{1}", className, methodName).hashCode()));
        }else{
            keyPrefix=limitRate.name();
        }

        String paramKey = "";
        Object[] args = point.getArgs();
        if(!StringUtils.isEmpty(limitRate.key())){
            // append the value of the SpEL key expression
            paramKey = buildExpressKey(methodSignature, limitRate.key(), args);
        }
        return keyPrefix+paramKey;
    }

    /**
     * Evaluates the SpEL key expression against the method arguments
     * @param methodSignature signature of the intercepted method
     * @param el              SpEL expression, e.g. "#userId"
     * @param args            actual arguments
     * @return evaluated value as a string, or "" if evaluation fails
     */
    private String buildExpressKey(MethodSignature methodSignature, String el, Object[] args) {
        String paramKey="";

        ExpressionParser expressionParser = new SpelExpressionParser();
        Expression expression = expressionParser.parseExpression(el);
        EvaluationContext context = new StandardEvaluationContext();
        // parameter names are only available when compiled with debug info or -parameters
        String[] parameterNames = methodSignature.getParameterNames();
        if (parameterNames != null) {
            for (int i = 0; i <parameterNames.length ; i++) {
                context.setVariable(parameterNames[i],args[i]);
            }
        }
        try {
            // evaluate the expression, replacing #id etc. with argument values
            paramKey= expression.getValue(context).toString();
            log.debug("Rate-limit key built from SpEL expression {}: {}",el,paramKey);
        }catch (Exception e){
            log.error("Failed to evaluate rate-limit key expression {}",el,e);
        }
        return paramKey;
    }
}
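The control flow of around() can be reduced to plain Java. Rules are checked in descending order, and each successful check on a refundable (rbEx) rule is recorded. The recorded consumptions are refunded if a later rule rejects the call or the target throws. In the sketch, the Rule type and the in-memory counters are hypothetical stand-ins for @LimitRate and the Redis buckets. A quota counter replaces the real token/leaky arithmetic.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

class MultiRuleDemo {
    /** Hypothetical stand-in for one @LimitRate: key, order, quota per window, rbEx flag. */
    static final class Rule {
        final String key;
        final int order;
        final long quota;
        final boolean rbEx;

        Rule(String key, int order, long quota, boolean rbEx) {
            this.key = key;
            this.order = order;
            this.quota = quota;
            this.rbEx = rbEx;
        }
    }

    // In-memory stand-in for the Redis buckets: remaining calls per key
    final Map<String, Long> remaining = new HashMap<>();

    private boolean tryConsume(Rule rule) {
        long left = remaining.getOrDefault(rule.key, rule.quota) - 1;
        if (left < 0) {
            return false;
        }
        remaining.put(rule.key, left);
        return true;
    }

    <T> T invoke(List<Rule> rules, Supplier<T> target) {
        List<Rule> sorted = new ArrayList<>(rules);
        // Highest order first, as in the aspect
        sorted.sort(Comparator.comparingInt((Rule r) -> r.order).reversed());
        Deque<Rule> charged = new ArrayDeque<>();
        try {
            for (Rule rule : sorted) {
                if (!tryConsume(rule)) {
                    throw new IllegalStateException("rate limited: " + rule.key);
                }
                if (rule.rbEx) {
                    charged.push(rule); // remember what to refund if something fails later
                }
            }
            return target.get();
        } catch (RuntimeException e) {
            // Refund rbEx rules on a later rejection or a failing target, then rethrow
            while (!charged.isEmpty()) {
                remaining.merge(charged.pop().key, 1L, Long::sum);
            }
            throw e;
        }
    }

    public static void main(String[] args) {
        MultiRuleDemo demo = new MultiRuleDemo();
        Rule daily = new Rule("daily", 2, 10, true);         // checked first, refundable
        Rule perSecond = new Rule("perSecond", 1, 0, false); // already exhausted
        try {
            demo.invoke(Arrays.asList(daily, perSecond), () -> "ok");
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage() + ", daily left: " + demo.remaining.get("daily"));
        }
    }
}
```

Refunding on a later rejection keeps a higher-priority quota (here the daily one) from being charged for a call that was never served.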

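11. Annotate the controller methods you want to protect with @LimitRate (the annotation targets methods, not classes). The available settings are documented on the annotation's attributes.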

Project repository: https://gitee.com/moodincode/com_around_project/tree/master/common-utils

Original article: https://blog.csdn.net/qq_24577585/article/details/112539375