spring aop基于redis的令牌桶和漏桶限流
程序员文章站
2022-03-12 21:27:09
1.编写限流单位枚举类package com.around.common.utils.limitrate;import java.util.Calendar;/** * @description: 限流单位 * @author: moodincode * @create 2021/1/8 */public enum LimitUnitEnum { /**每相当于非自然时间段,即从当前时间起的时间间隔**/ PER_SECOND(0,"PER_SECOND","每秒")...
说明: 基于 Spring AOP 的切面限流,使用 Redis Lua 脚本保证原子性;限流对象支持全局、单个应用、用户 IP,甚至可以用 SpEL 表达式解析方法入参中的某个字段作为限流对象。
限流单位:支持每秒,每分,每时,或自然分,自然日等等时间单位的限流规则
限流类型:支持令牌桶和漏桶算法
支持异常时,返还扣减的次数
支持多个限流规则组合使用,例如“每日访问次数 + 每周或每年访问次数”等组合,并可为每条规则设置优先级以及异常时是否回退次数
1.编写限流单位枚举类
package com.around.common.utils.limitrate;
import java.util.Calendar;
/**
 * Rate-limit time unit.
 *
 * <p>Units with {@code type == 0} ("PER_*") are rolling windows measured from the comparison
 * time. Units with {@code type == 1} are natural (calendar-aligned) windows that end at the last
 * millisecond of the current minute/hour/day/week/month/year in the JVM default time zone.
 *
 * @author: moodincode
 * @create 2021/1/8
 */
public enum LimitUnitEnum {
    /** Rolling windows: the interval starts at the current time, not at a calendar boundary. */
    PER_SECOND(0, "PER_SECOND", "每秒"),
    PER_MINUTES(0, "PER_MINUTES", "每分"),
    PER_HOUR(0, "PER_HOUR", "每时"),
    PER_DAY(0, "PER_DAY", "每日"),
    PER_WEEK(0, "PER_WEEK", "每周"),
    PER_MONTH(0, "PER_MONTH", "每月"),
    PER_YEAR(0, "PER_YEAR", "每年"),
    /** Natural windows: reset at the calendar boundary, e.g. a natural month seen on 31 Jan resets on 1 Feb. */
    MINUTES(1, "MINUTES", "自然分"),
    HOUR(1, "HOUR", "自然时"),
    DAY(1, "DAY", "自然日"),
    WEEK(1, "WEEK", "自然周"),
    MONTH(1, "MONTH", "自然月"),
    YEAR(1, "YEAR", "自然年"),
    ;

    /** 0 = rolling window, 1 = natural window. */
    private Integer type;
    private String code;
    private String name;

    public Integer getType() {
        return type;
    }

    public void setType(Integer type) {
        this.type = type;
    }

    public String getCode() {
        return code;
    }

    public void setCode(String code) {
        this.code = code;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    LimitUnitEnum(Integer type, String code, String name) {
        this.type = type;
        this.code = code;
        this.name = name;
    }

    /**
     * Returns the expiry timestamp of the window that contains "now".
     *
     * @return epoch milliseconds at which the current window ends
     */
    public Long getExpireTs() {
        return getExpireTs(System.currentTimeMillis());
    }

    /**
     * Returns the expiry timestamp of the window that contains {@code compareTime}.
     *
     * @param compareTime reference time in epoch milliseconds
     * @return {@code compareTime} plus the fixed span for rolling units; the last millisecond of
     *     the enclosing calendar period for natural units
     */
    public Long getExpireTs(Long compareTime) {
        switch (this) {
            case PER_SECOND:
                return compareTime + 1000L;
            case PER_MINUTES:
                return compareTime + 60000L;
            case PER_HOUR:
                return compareTime + 60 * 60000L;
            case PER_DAY:
                return compareTime + 24 * 60 * 60000L;
            case PER_WEEK:
                return compareTime + 7 * 24 * 60 * 60000L;
            case PER_MONTH:
                // approximated as 30 days
                return compareTime + 30 * 24 * 60 * 60000L;
            case PER_YEAR:
                // approximated as 365 days
                return compareTime + 365 * 24 * 60 * 60000L;
            default:
                return getNaturalTime(compareTime);
        }
    }

    /**
     * Computes the last millisecond of the natural period that contains {@code compareTime}.
     * Weeks end on Sunday.
     *
     * @param compareTime reference time in epoch milliseconds
     * @return end of the enclosing natural period in epoch milliseconds
     */
    private Long getNaturalTime(Long compareTime) {
        Calendar c = Calendar.getInstance();
        c.setTimeInMillis(compareTime);
        // Intentional fall-through: every coarser unit also pins all finer fields to their maximum.
        switch (this) {
            case YEAR:
                c.set(Calendar.MONTH, Calendar.DECEMBER);
                // fall through
            case MONTH:
                c.set(Calendar.DAY_OF_MONTH, c.getActualMaximum(Calendar.DAY_OF_MONTH));
                // fall through
            case WEEK:
                if (this == WEEK) {
                    int dayOfWeek = c.get(Calendar.DAY_OF_WEEK);
                    if (dayOfWeek > Calendar.SUNDAY) {
                        // move forward to the coming Sunday
                        c.add(Calendar.DATE, 8 - dayOfWeek);
                    }
                }
                // fall through
            case DAY:
                // HOUR_OF_DAY (0-23), not HOUR (0-11): HOUR=23 in the afternoon would roll to 11:00 next day
                c.set(Calendar.HOUR_OF_DAY, 23);
                // fall through
            case HOUR:
                c.set(Calendar.MINUTE, 59);
                // fall through
            case MINUTES:
                c.set(Calendar.SECOND, 59);
                // fall through
            default:
                break;
        }
        c.set(Calendar.MILLISECOND, 999);
        return c.getTimeInMillis();
    }
}
2.编写限流类型枚举类
package com.around.common.utils.limitrate;
/**
 * Rate-limit strategy: the scope of the limit (client IP, a single application instance, or
 * all instances) combined with the bucket algorithm.
 *
 * <p>{@link #getCode()} selects the algorithm: {@code 0} = token bucket, {@code 1} = leaky bucket.
 *
 * @program: com-around
 * @author: moodincode
 * @create: 2021/1/08
 **/
public enum LimitTypeEnum {
    /** Token bucket keyed by the requesting client's IP. */
    IP_TOKEN(0, "IP_TOKEN", "单IP-令牌桶算法"),
    /** Token bucket keyed by the deployed application instance. */
    APP_TOKEN(0, "APP_TOKEN", "单应用-令牌桶算法"),
    /** Token bucket shared by every caller. */
    GLOBAL_TOKEN(0, "GLOBAL_TOKEN", "全局-令牌桶算法"),
    /** Leaky bucket keyed by the requesting client's IP. */
    IP_LEAKY(1, "IP_LEAKY", "单IP-漏桶算法"),
    /** Leaky bucket keyed by the deployed application instance. */
    APP_LEAKY(1, "APP_LEAKY", "单应用-漏桶算法"),
    /** Leaky bucket shared by every caller. */
    GLOBAL_LEAKY(1, "GLOBAL_LEAKY", "全局-漏桶算法"),
    ;

    /** Algorithm selector passed to the provider: 0 token bucket, 1 leaky bucket. */
    private final int code;
    /** Stable identifier, identical to the constant name. */
    private final String name;
    /** Human-readable description. */
    private final String decr;

    LimitTypeEnum(int algorithmCode, String typeName, String description) {
        code = algorithmCode;
        name = typeName;
        decr = description;
    }

    public int getCode() {
        return code;
    }

    public String getName() {
        return name;
    }

    public String getDecr() {
        return decr;
    }
}
3.编写限流通用接口
package com.around.common.utils.limitrate;
/**
 * Storage-agnostic contract for the bucket counters behind the rate limiter.
 *
 * <p>The {@code type} argument selects the algorithm everywhere: {@code 0} = token bucket,
 * {@code 1} = leaky bucket. Timestamps are epoch milliseconds.
 *
 * @author: moodincode
 * @create 2021/1/8
 */
public interface LimitRateProvider {

    /**
     * Reads the current count of a bucket without adding or consuming anything.
     *
     * @param key bucket key
     * @return the stored count; {@code 0} when the key does not exist or the count is zero
     */
    Long getCurrentCount(String key);

    /**
     * Refills (token bucket) or drains (leaky bucket) the bucket for the elapsed time, then
     * consumes {@code num} units.
     *
     * @param key         bucket key
     * @param num         number of units to consume
     * @param ts          timestamp of this consumption, epoch ms
     * @param expire      time at which the bucket resets if untouched, epoch ms
     * @param rate        tokens generated / units drained per interval
     * @param interval    refill/drain interval in ms; 0 means "add one rate amount directly"
     * @param capacity    maximum bucket size; -1 means unlimited
     * @param type        0 = token bucket, 1 = leaky bucket
     * @param compareTime reference time for the elapsed-interval computation; 0 means "now", epoch ms
     * @return the remaining allowance; callers treat a negative value as "rejected"
     */
    Long consumeCount(String key, Long num, Long ts, Long expire, Long rate, Long interval, Long capacity, int type, Long compareTime);

    /**
     * Gives units back to a bucket, e.g. to undo a consumption.
     *
     * @param key      bucket key
     * @param num      number of units to give back
     * @param capacity capacity bound; 0 = unbounded, greater than 0 = bounded
     * @param type     0 = token bucket, 1 = leaky bucket
     * @return implementation-defined result of the update
     */
    Long addCount(String key, long num, long capacity, int type);
}
4.编写redis序列化工具类
package com.around.common.utils.limitrate;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.data.redis.serializer.SerializationException;
import java.nio.charset.StandardCharsets;
/**
 * Redis serializer that stores a {@link Long} as its decimal text encoded in UTF-8.
 *
 * <p>Used as the argument serializer for the rate-limit Lua scripts, so numbers reach the
 * script as plain numeric strings that {@code tonumber()} can parse.
 *
 * <p>Follows the {@link RedisSerializer} contract that {@code serialize} and
 * {@code deserialize} may receive and return {@code null}.
 *
 * @author: moodincode
 * @create 2021/1/8
 */
public class LongStringSerialize implements RedisSerializer<Long> {

    /**
     * @param number value to write; may be {@code null}
     * @return UTF-8 decimal text, or {@code null} for a {@code null} input
     */
    @Override
    public byte[] serialize(Long number) throws SerializationException {
        // Without this guard String.valueOf(null) would store the literal text "null".
        if (number == null) {
            return null;
        }
        return number.toString().getBytes(StandardCharsets.UTF_8);
    }

    /**
     * @param bytes UTF-8 decimal text; may be {@code null} or empty
     * @return the parsed value, or {@code null} for a {@code null}/empty payload
     * @throws NumberFormatException if the payload is not a decimal long
     */
    @Override
    public Long deserialize(byte[] bytes) throws SerializationException {
        // Missing or empty values map to null instead of an NPE / NumberFormatException.
        if (bytes == null || bytes.length == 0) {
            return null;
        }
        return Long.valueOf(new String(bytes, StandardCharsets.UTF_8));
    }
}
5.编写redis限流实现类和lua脚本
package com.around.common.utils.limitrate;
import com.alibaba.fastjson.support.spring.GenericFastJsonRedisSerializer;
import com.google.common.collect.Lists;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.util.List;
/**
 * Redis-backed {@link LimitRateProvider}. It runs token-bucket and leaky-bucket rate limiting
 * atomically inside Lua scripts.
 *
 * <p>Each limited key is a Redis hash with these fields:
 * <ul>
 *   <li>{@code count}: tokens left (token bucket) or current water level (leaky bucket)</li>
 *   <li>{@code gen_ts}: last refill/leak timestamp, in ms</li>
 *   <li>{@code get_ts}: last consume timestamp, in ms</li>
 * </ul>
 *
 * @author: moodincode
 * @create 2021/1/8
 */
@Component
public class RedisLimitRateProvider implements LimitRateProvider {
    @Resource
    private RedisTemplate<String, Object> redisTemplate;

    /**
     * Token-bucket consume script.
     *
     * <p>ARGV: 1 num, 2 interval, 3 rate, 4 capacity, 5 compareTime, 6 ts, 7 expire. It adds
     * {@code rate} tokens per elapsed interval, capped at capacity when capacity is greater
     * than 0. An interval of 0 or less adds a single {@code rate} amount. The script returns the
     * tokens left after consuming, or -1 when there are not enough tokens.
     */
    private static final String CONSUME_TOKEN_SCRIP = "local r_key = tostring(KEYS[1])\n" +
            "local num = tonumber(ARGV[1])\n" +
            "local interval = tonumber(ARGV[2])\n" +
            "local rate = tonumber(ARGV[3])\n" +
            "local capacity = tonumber(ARGV[4])\n" +
            "local compareTime = tonumber(ARGV[5])\n" +
            "local ts = tonumber(ARGV[6])\n" +
            "local expire = tonumber(ARGV[7])\n" +
            "if redis.call('exists', r_key) == 1 then\n" +
            "  local gen_ts = tonumber(redis.call('hget', r_key, 'gen_ts')) or ts\n" +
            "  local count = tonumber(redis.call('hget', r_key, 'count')) or 0\n" +
            "  local factor = 1\n" +
            "  if interval > 0 then\n" +
            "    factor = math.ceil((compareTime - gen_ts) / interval - 1)\n" +
            "  end\n" +
            "  if factor < 0 then\n" +
            "    factor = 0\n" +
            "  end\n" +
            "  local add_count = factor * rate\n" +
            "  if add_count > 0 then\n" +
            "    count = count + add_count\n" +
            "    if capacity > 0 and count > capacity then\n" +
            "      count = capacity\n" +
            "    end\n" +
            "    redis.call('hset', r_key, 'count', count)\n" +
            "    redis.call('hset', r_key, 'gen_ts', ts)\n" +
            "    redis.call('pexpireat', r_key, expire)\n" +
            "  end\n" +
            "  if count - num > -1 then\n" +
            "    redis.call('hset', r_key, 'count', count - num)\n" +
            "    redis.call('hset', r_key, 'get_ts', ts)\n" +
            "    return count - num\n" +
            "  end\n" +
            "  return -1\n" +
            "end\n" +
            "redis.call('hset', r_key, 'gen_ts', ts)\n" +
            "redis.call('hset', r_key, 'count', capacity - num)\n" +
            "redis.call('hset', r_key, 'get_ts', ts)\n" +
            "redis.call('pexpireat', r_key, expire)\n" +
            "return capacity - num\n";

    /**
     * Adds tokens back. ARGV: 1 num, 2 capacity (a value of 0 or less means no cap).
     *
     * <p>Returns the new token count, or -1 when the key has already expired.
     */
    private static final String ADD_TOKEN_SCRIP = "local num = tonumber(ARGV[1])\n" +
            "local capacity = tonumber(ARGV[2])\n" +
            "if redis.call('exists', KEYS[1]) == 1 then\n" +
            "  local count = (tonumber(redis.call('hget', KEYS[1], 'count')) or 0) + num\n" +
            "  if capacity > 0 and count > capacity then\n" +
            "    count = capacity\n" +
            "  end\n" +
            "  redis.call('hset', KEYS[1], 'count', count)\n" +
            "  return count\n" +
            "end\n" +
            "return -1\n";

    /**
     * Leaky-bucket consume script.
     *
     * <p>ARGV: 1 num, 2 interval, 3 rate, 4 capacity, 5 compareTime, 6 ts, 7 expire. It drains
     * {@code rate} units per elapsed interval, then adds {@code num}. It returns the remaining
     * room ({@code capacity - level}), or -1 when the bucket would overflow.
     */
    private static final String CONSUME_LEAKY_SCRIP = "local r_key = tostring(KEYS[1])\n" +
            "local num = tonumber(ARGV[1])\n" +
            "local interval = tonumber(ARGV[2])\n" +
            "local rate = tonumber(ARGV[3])\n" +
            "local capacity = tonumber(ARGV[4])\n" +
            "local compareTime = tonumber(ARGV[5])\n" +
            "local ts = tonumber(ARGV[6])\n" +
            "local expire = tonumber(ARGV[7])\n" +
            "if redis.call('exists', r_key) == 1 then\n" +
            "  local gen_ts = tonumber(redis.call('hget', r_key, 'gen_ts')) or ts\n" +
            "  local count = tonumber(redis.call('hget', r_key, 'count')) or 0\n" +
            "  local factor = 1\n" +
            "  if interval > 0 then\n" +
            "    factor = math.ceil((compareTime - gen_ts) / interval - 1)\n" +
            "  end\n" +
            "  if factor < 0 then\n" +
            "    factor = 0\n" +
            "  end\n" +
            "  local de_count = factor * rate\n" +
            "  if de_count > 0 then\n" +
            "    count = count - de_count\n" +
            "    if count < 0 then\n" +
            "      count = 0\n" +
            "    end\n" +
            "    redis.call('hset', r_key, 'count', count)\n" +
            "    redis.call('hset', r_key, 'gen_ts', ts)\n" +
            "    redis.call('pexpireat', r_key, expire)\n" +
            "  end\n" +
            "  local total = count + num\n" +
            "  if capacity - total > -1 then\n" +
            "    redis.call('hset', r_key, 'count', total)\n" +
            "    redis.call('hset', r_key, 'get_ts', ts)\n" +
            "    return capacity - total\n" +
            "  end\n" +
            "  return -1\n" +
            "end\n" +
            "redis.call('hset', r_key, 'gen_ts', ts)\n" +
            "redis.call('hset', r_key, 'count', num)\n" +
            "redis.call('hset', r_key, 'get_ts', ts)\n" +
            "redis.call('pexpireat', r_key, expire)\n" +
            "return capacity - num\n";

    /**
     * Lowers the leaky-bucket level by ARGV[1] and never lets it go below 0.
     *
     * <p>Returns the new level, or -1 when the key has already expired. ARGV[2] (capacity) is
     * accepted but not used.
     */
    private static final String DEDUCE_LEAKY_SCRIP = "local num = tonumber(ARGV[1])\n" +
            "if redis.call('exists', KEYS[1]) == 1 then\n" +
            "  local count = (tonumber(redis.call('hget', KEYS[1], 'count')) or 0) - num\n" +
            "  if count < 0 then\n" +
            "    count = 0\n" +
            "  end\n" +
            "  redis.call('hset', KEYS[1], 'count', count)\n" +
            "  return count\n" +
            "end\n" +
            "return -1\n";

    // A DefaultRedisScript is immutable once configured and caches its SHA1, so the scripts
    // are built once and shared instead of being rebuilt per call.
    private static final DefaultRedisScript<Object> CONSUME_TOKEN_SCRIPT = newScript(CONSUME_TOKEN_SCRIP);
    private static final DefaultRedisScript<Object> ADD_TOKEN_SCRIPT = newScript(ADD_TOKEN_SCRIP);
    private static final DefaultRedisScript<Object> CONSUME_LEAKY_SCRIPT = newScript(CONSUME_LEAKY_SCRIP);
    private static final DefaultRedisScript<Object> DEDUCE_LEAKY_SCRIPT = newScript(DEDUCE_LEAKY_SCRIP);
    private static final LongStringSerialize ARGS_SERIALIZER = new LongStringSerialize();
    private static final GenericFastJsonRedisSerializer RESULT_SERIALIZER = new GenericFastJsonRedisSerializer();

    /** Builds a script whose raw Redis reply is handed back as an Object. */
    private static DefaultRedisScript<Object> newScript(String scriptText) {
        DefaultRedisScript<Object> script = new DefaultRedisScript<>();
        script.setScriptText(scriptText);
        script.setResultType(Object.class);
        return script;
    }

    /**
     * Converts a script reply, which may be a scalar or a single-element list, to a Long.
     *
     * @throws IllegalStateException if Redis returned no value
     */
    private static Long toLong(Object result) {
        Object value = result;
        if (result instanceof List) {
            List<?> values = (List<?>) result;
            value = values.isEmpty() ? null : values.get(0);
        }
        if (value == null) {
            throw new IllegalStateException("Redis rate-limit script returned no result");
        }
        return Long.valueOf(value.toString());
    }

    /**
     * Consumes {@code num} units from the bucket stored under {@code key}.
     *
     * @param key         bucket key
     * @param num         number of units to consume
     * @param ts          timestamp of this consumption; a value below 1 or null means now
     * @param expire      absolute expiry of the bucket key, epoch ms
     * @param rate        tokens generated / units drained per interval
     * @param interval    refill/drain interval in ms
     * @param capacity    bucket capacity
     * @param type        0 = token bucket, 1 = leaky bucket
     * @param compareTime reference time for the elapsed-interval computation; a value below 1 or null means now
     * @return remaining allowance, or -1 when the request is rejected
     */
    @Override
    public Long consumeCount(String key, Long num, Long ts, Long expire, Long rate, Long interval, Long capacity, int type, Long compareTime) {
        long millis = System.currentTimeMillis();
        if (compareTime == null || compareTime < 1) {
            compareTime = millis;
        }
        if (ts == null || ts < 1) {
            ts = millis;
        }
        DefaultRedisScript<Object> script = type == 0 ? CONSUME_TOKEN_SCRIPT : CONSUME_LEAKY_SCRIPT;
        // ARGV order must match the scripts: num, interval, rate, capacity, compareTime, ts, expire
        Object result = redisTemplate.execute(script, ARGS_SERIALIZER, RESULT_SERIALIZER, Lists.newArrayList(key),
                num, interval, rate, capacity, compareTime, ts, expire);
        return toLong(result);
    }

    /**
     * Reads the current count without adding or consuming anything.
     *
     * @param key bucket key
     * @return the stored count, or 0 if the key or field does not exist
     */
    @Override
    public Long getCurrentCount(String key) {
        Object count = redisTemplate.opsForHash().get(key, "count");
        if (count != null) {
            return Long.valueOf(count.toString());
        }
        return 0L;
    }

    /**
     * Gives units back to a bucket, for example to undo a consumption.
     *
     * @param key      bucket key
     * @param num      tokens to add back (token bucket) or level to remove (leaky bucket)
     * @param capacity token-bucket cap; a value below 1 means no cap. Ignored for the leaky bucket.
     * @param type     0 = token bucket, 1 = leaky bucket
     * @return the bucket's new count, or -1 if the key has already expired
     */
    @Override
    public Long addCount(String key, long num, long capacity, int type) {
        DefaultRedisScript<Object> script = type == 0 ? ADD_TOKEN_SCRIPT : DEDUCE_LEAKY_SCRIPT;
        Object result = redisTemplate.execute(script, ARGS_SERIALIZER, RESULT_SERIALIZER, Lists.newArrayList(key),
                num, capacity);
        return toLong(result);
    }
}
6.编写http获取request工具类
package com.around.common.utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.util.*;
/**
 * HTTP helper methods: access to the current request/response, client and local IP
 * resolution, cookies, headers and request-body reading.
 *
 * @program:
 * @author: moodincode
 * @create: 2020/9/28
 **/
public class WebRequestUtil {
    private final static Logger log = LoggerFactory.getLogger(WebRequestUtil.class);
    private static final String UNKNOWN = "unknown";
    private static final String SYMBOL = ",";
    /** IP address of this machine, resolved lazily and cached by getLocalIpAddress(). */
    private static String localIp;

    /**
     * Returns the request bound to the current thread by Spring.
     *
     * @return the current request, or {@code null} outside a servlet request context
     */
    public static HttpServletRequest getRequest() {
        ServletRequestAttributes attributes = currentServletRequestAttributes();
        return attributes == null ? null : attributes.getRequest();
    }

    /**
     * Returns the response bound to the current thread by Spring.
     *
     * @return the current response, or {@code null} outside a servlet request context
     */
    public static HttpServletResponse getResponse() {
        ServletRequestAttributes attributes = currentServletRequestAttributes();
        return attributes == null ? null : attributes.getResponse();
    }

    /** Returns the thread-bound attributes when they are servlet-based; otherwise null (no ClassCastException). */
    private static ServletRequestAttributes currentServletRequestAttributes() {
        Object attributes = RequestContextHolder.getRequestAttributes();
        return attributes instanceof ServletRequestAttributes ? (ServletRequestAttributes) attributes : null;
    }

    /**
     * Reads the whole request body and strips CR, LF and TAB characters.
     *
     * @param request current request; its reader can be consumed only once
     * @return the body text; whatever was read before an I/O error occurred
     */
    public static String readRequestBodyParams(HttpServletRequest request) {
        StringBuilder sb = new StringBuilder();
        // try-with-resources closes the reader exactly once, including on failure
        try (BufferedReader br = request.getReader()) {
            String line;
            while ((line = br.readLine()) != null) {
                sb.append(line);
            }
        } catch (IOException e) {
            log.error("获取body参数失败", e);
        }
        return sb.toString().replaceAll("\r|\n|\t", "");
    }

    /**
     * Returns the client IP of the current request, or this machine's IP when there is no request.
     */
    public static String getIpAddress() {
        return getIpAddress(getRequest());
    }

    /**
     * Resolves the client IP, checking the usual proxy headers before the socket address.
     *
     * <p>Security note: these headers are supplied by the client and can be forged. They are
     * only trustworthy when a reverse proxy overwrites them.
     *
     * @param request request to inspect; {@code null} yields the local machine's IP
     * @return the first IP in a multi-proxy chain, or the remote address
     */
    public static String getIpAddress(HttpServletRequest request) {
        // no request (e.g. a scheduled job): fall back to this machine's IP
        if (request == null) {
            return getLocalIpAddress();
        }
        String ip = request.getHeader("x-forwarded-for");
        if (isMissing(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (isMissing(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isMissing(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }
        if (isMissing(ip)) {
            ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        }
        if (isMissing(ip)) {
            ip = request.getRemoteAddr();
        }
        // with several proxies the first entry is the original client
        if (ip != null && ip.contains(SYMBOL)) {
            ip = ip.substring(0, ip.indexOf(SYMBOL)).trim();
        }
        return ip;
    }

    /** True when a header value carries no usable IP. */
    private static boolean isMissing(String ip) {
        return ip == null || ip.isEmpty() || UNKNOWN.equalsIgnoreCase(ip);
    }

    /**
     * Returns the value of the first cookie named {@code key}.
     *
     * @param request request to inspect
     * @param key     cookie name
     * @return the cookie value, or {@code null} if absent
     */
    public static String getCookie(HttpServletRequest request, String key) {
        Cookie[] cookies = request.getCookies();
        if (cookies != null) {
            for (Cookie cookie : cookies) {
                if (key.equals(cookie.getName())) {
                    return cookie.getValue();
                }
            }
        }
        return null;
    }

    /**
     * Lists all request headers as {@code "name:value"} strings.
     *
     * @param request request to inspect
     * @return header lines; empty if the container hides header names
     */
    public static List<String> getHeaders(HttpServletRequest request) {
        List<String> headList = new ArrayList<>();
        Enumeration<String> headers = request.getHeaderNames();
        // the Servlet API allows getHeaderNames() to return null
        if (headers == null) {
            return headList;
        }
        while (headers.hasMoreElements()) {
            String headName = headers.nextElement();
            headList.add(String.format("%s:%s", headName, request.getHeader(headName)));
        }
        return headList;
    }

    /**
     * Collects request headers into a name-to-value map; the first value wins for repeated headers.
     *
     * @param request request to inspect
     * @return header map; empty if the container hides header names
     */
    public static Map<String, String> getHeaderMap(HttpServletRequest request) {
        Map<String, String> headMap = new HashMap<>();
        Enumeration<String> headers = request.getHeaderNames();
        if (headers == null) {
            return headMap;
        }
        while (headers.hasMoreElements()) {
            String headName = headers.nextElement();
            headMap.put(headName, request.getHeader(headName));
        }
        return headMap;
    }

    /**
     * Returns this machine's IP address and caches it after the first successful lookup.
     *
     * <p>Preference order:
     * <ol>
     *   <li>a non-loopback site-local address</li>
     *   <li>any other non-loopback address</li>
     *   <li>{@link InetAddress#getLocalHost()}</li>
     * </ol>
     *
     * @return the address, or {@code null} if it cannot be resolved
     */
    public static String getLocalIpAddress() {
        if (localIp != null) {
            return localIp;
        }
        try {
            InetAddress candidateAddress = null;
            Enumeration<NetworkInterface> ifaces = NetworkInterface.getNetworkInterfaces();
            while (ifaces != null && ifaces.hasMoreElements()) {
                Enumeration<InetAddress> inetAddrs = ifaces.nextElement().getInetAddresses();
                while (inetAddrs.hasMoreElements()) {
                    InetAddress inetAddr = inetAddrs.nextElement();
                    if (inetAddr.isLoopbackAddress()) {
                        continue;
                    }
                    if (inetAddr.isSiteLocalAddress()) {
                        localIp = inetAddr.getHostAddress();
                        return localIp;
                    }
                    if (candidateAddress == null) {
                        // remember the first non-site-local address as a fallback
                        candidateAddress = inetAddr;
                    }
                }
            }
            if (candidateAddress != null) {
                localIp = candidateAddress.getHostAddress();
                return localIp;
            }
            // no non-loopback address found: use the JDK's choice as last resort
            localIp = InetAddress.getLocalHost().getHostAddress();
        } catch (Exception e) {
            log.warn("Failed to resolve the local IP address", e);
        }
        return localIp;
    }

    /**
     * Reads the request body as text. The request must be wrapped if the body is needed again,
     * because {@code getReader()} can be used only once.
     *
     * @param request current request
     * @return the body with line terminators removed, or {@code null} on I/O error
     */
    public static String getBodyString(HttpServletRequest request) {
        try {
            BufferedReader br = request.getReader();
            StringBuilder body = new StringBuilder();
            String line;
            while ((line = br.readLine()) != null) {
                body.append(line);
            }
            return body.toString();
        } catch (IOException e) {
            log.error("解析body参数失败,原因:", e);
        }
        return null;
    }
}
7.编写限流单位计算工具
package com.around.common.utils.limitrate;
/**
 * Time parameters for a single rate-limit check, derived from a {@link LimitUnitEnum}.
 *
 * <p>{@code millis}, {@code expire} and {@code compareTime} are epoch milliseconds;
 * {@code intermission} is a duration in milliseconds.
 *
 * @author: moodincode
 * @create 2021/1/08
 */
public class LimitTime {
    /** Time of the check; used as the consume timestamp. */
    private Long millis;
    /** Absolute expiry timestamp of the bucket key. */
    private Long expire;
    /** Refill/leak interval in milliseconds. */
    private Long intermission;
    /** Reference time used to compute elapsed intervals. */
    private Long compareTime;

    public LimitTime(Long millis, Long expire, Long intermission, Long compareTime) {
        this.millis = millis;
        this.expire = expire;
        this.intermission = intermission;
        this.compareTime = compareTime;
    }

    public Long getMillis() {
        return millis;
    }

    public void setMillis(Long millis) {
        this.millis = millis;
    }

    public Long getExpire() {
        return expire;
    }

    public void setExpire(Long expire) {
        this.expire = expire;
    }

    public Long getIntermission() {
        return intermission;
    }

    public void setIntermission(Long intermission) {
        this.intermission = intermission;
    }

    public Long getCompareTime() {
        return compareTime;
    }

    public void setCompareTime(Long compareTime) {
        this.compareTime = compareTime;
    }

    /**
     * Computes the time parameters of a check performed at {@code millis}.
     *
     * @param millis   time of the check, epoch ms
     * @param interval refill interval in ms; {@code null} or below 1000 means "one refill per
     *                 whole unit window", i.e. the time left until the window expires
     * @param unit     limit unit that defines the window expiry
     * @return time parameters with {@code compareTime == millis}
     */
    public static LimitTime calculateTime(long millis, Long interval, LimitUnitEnum unit) {
        Long expireTs = unit.getExpireTs(millis);
        LimitTime time = new LimitTime(millis, expireTs, interval, millis);
        if (interval == null || interval < 1000L) {
            // At least 1 ms: a natural window can end on the very millisecond of the check
            // (e.g. hh:mm:59.999 for MINUTES), and a 0 interval makes the Lua scripts divide by zero.
            time.setIntermission(Math.max(1L, expireTs - millis));
        }
        return time;
    }
}
8.编写限流工具类
package com.around.common.utils.limitrate;
import com.around.common.utils.WebRequestUtil;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
/**
 * Rate-limit helper: builds scope-specific key prefixes and delegates counting to the
 * {@link LimitRateProvider}.
 *
 * @author: moodincode
 * @create 2021/1/8
 */
@Component
public class LimitRateUtil {
    /**
     * Key prefix for per-application-instance limits. It is derived from this machine's IP and
     * the class-load time, so it differs on every restart.
     */
    public static String MACHINE_ID = buildMachineId();
    /** Key prefix for global limits shared by every instance. */
    public static String GLOBAL = "limit:00:";

    @Resource
    private LimitRateProvider limitRateProvider;

    /** Builds the {@code limit:m<hash>:} prefix from the local IP and the current time. */
    private static String buildMachineId() {
        String seed = "";
        String localIpAddress = WebRequestUtil.getLocalIpAddress();
        if (!StringUtils.isEmpty(localIpAddress)) {
            seed += localIpAddress.replace(".", "_").replace(":", "_");
        }
        seed += "_" + System.currentTimeMillis();
        return "limit:m" + Math.abs(seed.hashCode()) + ":";
    }

    /**
     * Consumes one permit for {@code key} within the scope selected by {@code type}.
     *
     * @param request  current request (used for IP-scoped limits; may be {@code null})
     * @param key      limit key without the scope prefix
     * @param rate     tokens per interval / leak rate
     * @param interval refill interval in ms; below 1000 means "use the whole unit window"
     * @param unit     limit unit
     * @param type     scope and algorithm
     * @param capacity bucket capacity; below 1 means "same as rate"
     * @return remaining allowance; negative means the call is rejected
     */
    public Long consumeCount(HttpServletRequest request, String key, Long rate, Long interval, LimitUnitEnum unit, LimitTypeEnum type, Long capacity) {
        String keySuffix = getKeySuffix(type, request);
        long millis = System.currentTimeMillis();
        LimitTime time = LimitTime.calculateTime(millis, interval, unit);
        return consumeCount(keySuffix + key, 1L, rate, type, capacity, time);
    }

    /**
     * Consumes {@code num} permits from an already prefixed key.
     *
     * @param key      full Redis key
     * @param num      number of permits
     * @param rate     tokens per interval / leak rate
     * @param type     scope and algorithm
     * @param capacity bucket capacity; {@code null} or below 1 means "same as rate"
     * @param time     precomputed time parameters
     * @return remaining allowance; negative means the call is rejected
     */
    public Long consumeCount(String key, Long num, Long rate, LimitTypeEnum type, Long capacity, LimitTime time) {
        if (capacity == null || capacity < 1) {
            capacity = rate;
        }
        return limitRateProvider.consumeCount(key, num, time.getMillis(), time.getExpire(), rate,
                time.getIntermission(), capacity, type.getCode(), time.getCompareTime());
    }

    /**
     * Returns the Redis key prefix for the scope of {@code type}.
     *
     * @param type    limit type
     * @param request current request, used for IP-scoped limits
     * @return the global prefix, the per-instance prefix, or a per-IP prefix
     */
    public String getKeySuffix(LimitTypeEnum type, HttpServletRequest request) {
        if (LimitTypeEnum.GLOBAL_LEAKY.equals(type) || LimitTypeEnum.GLOBAL_TOKEN.equals(type)) {
            return GLOBAL;
        } else if (LimitTypeEnum.APP_LEAKY.equals(type) || LimitTypeEnum.APP_TOKEN.equals(type)) {
            return MACHINE_ID;
        }
        String ipAddress = WebRequestUtil.getIpAddress(request);
        // IP resolution can fail (null); share one "unknown" bucket instead of throwing an NPE
        if (StringUtils.isEmpty(ipAddress)) {
            ipAddress = "unknown";
        }
        return "limit:i" + ipAddress.replace(".", "_").replace(":", "_") + ":";
    }

    /**
     * Gives permits back, for example to roll back a consumption.
     *
     * @param key      full Redis key
     * @param num      number of permits to give back
     * @param capacity capacity bound for the token bucket
     * @param type     scope and algorithm
     * @return provider result; -1 if the key has expired
     */
    public Long addCount(String key, Long num, Long capacity, LimitTypeEnum type) {
        return limitRateProvider.addCount(key, num, capacity, type.getCode());
    }

    public static void main(String[] args) {
        System.out.println(MACHINE_ID);
    }
}
9.编写注解类
package com.around.common.utils.limitrate;
import java.lang.annotation.*;
/**
* Declares a rate-limit rule on a method. The annotation is repeatable: multiple rules are
* collected into {@link LimitRates} and enforced by LimitRateCheckAop.
* @author: moodincode
* @create 2021/1/8
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Repeatable(LimitRates.class)
public @interface LimitRate {
/**
* Redis key prefix. When empty, the absolute hash of "className:methodName" is used.
* @return key prefix
*/
String name() default "";
/**
* SpEL expression evaluated against the method arguments (referenced as #paramName). Its string
* value is appended to the key prefix, so each distinct value gets its own bucket.
* @return SpEL expression, or empty for none
*/
String key() default "";
/**
* Message of the RuntimeException thrown by LimitRateCheckAop when the limit is hit.
* @return rejection message
*/
String msg() default "接口被限制访问,请稍后再试";
/**
* Tokens added per refill interval (token bucket), or units drained per interval (leaky bucket).
* Default 200.
* @return rate per interval
*/
long rate() default 200;
/**
* Refill/leak interval in milliseconds. Values below 1000, including the default 1, mean one
* refill per whole unit window: LimitTime.calculateTime then uses the time left until the
* unit's window expires.
* @return interval in ms
*/
long interval() default 1;
/**
* Time unit that defines when the bucket key expires (rolling "PER_*" or natural calendar window).
* @return limit unit
*/
LimitUnitEnum unit() default LimitUnitEnum.PER_SECOND;
/**
* Limit scope (client IP / application instance / global) and algorithm (token or leaky bucket).
* @return limit type
*/
LimitTypeEnum type() default LimitTypeEnum.IP_TOKEN;
/**
* Maximum bucket size. NOTE(review): LimitRateUtil.consumeCount replaces any value below 1 (the
* default 0 and also -1) with rate(), so -1 does not currently mean "unlimited".
* @return capacity
*/
long capacity() default 0;
/**
* Priority when several rules are present. LimitRateCheckAop sorts the rules by this value in
* descending order, so a higher value is checked first.
* @return order
*/
int order() default 0;
/**
* Whether LimitRateCheckAop should give the consumed permit back (via LimitRateUtil.addCount)
* when the annotated method throws.
* @return true to roll back on exception
*/
boolean rbEx() default false;
}
package com.around.common.utils.limitrate;
import java.lang.annotation.*;
/**
 * Container that lets {@link LimitRate} be declared several times on one method. When
 * {@code @LimitRate} is repeated, the compiler wraps the annotations in this type automatically.
 *
 * @author: moodincode
 * @create 2021/1/12
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface LimitRates {
    /** The repeated rate-limit rules, in declaration order. */
    LimitRate[] value();
}
10.编写AOP拦截器
package com.around.common.utils.limitrate;
import com.around.common.utils.WebRequestUtil;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.annotation.Order;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.text.MessageFormat;
import java.util.*;
/**
 * Aspect that enforces {@link LimitRate} rules on annotated methods.
 *
 * <p>Each rule consumes one permit before the method runs. If any rule is exhausted, a
 * {@link RuntimeException} carrying {@link LimitRate#msg()} is thrown. Rules with
 * {@link LimitRate#rbEx()} give their permit back when a later rule rejects the call or when
 * the method itself throws.
 *
 * <p>{@code @Order(999)} gives this aspect a low priority to avoid ordering problems with other
 * aspects.
 *
 * @author: moodincode
 * @create: 2021/01/12
 **/
@Aspect
@Component
@Order(999)
public class LimitRateCheckAop {
    private static final Logger log = LoggerFactory.getLogger(LimitRateCheckAop.class);
    /** SpelExpressionParser is thread-safe, so a single parser is shared by all invocations. */
    private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();

    @Resource
    private LimitRateUtil limitRateUtil;

    /**
     * Checks every rate-limit rule of the intercepted method, then proceeds.
     *
     * <p>The pointcut also matches {@link LimitRates}. A repeated {@code @LimitRate} is stored by
     * the compiler inside that container, so matching only {@code @LimitRate} would skip methods
     * that combine several rules.
     *
     * @param point join point of the annotated method
     * @return the method's result
     * @throws Throwable a RuntimeException with the rule's message when limited, or whatever the method throws
     */
    @Around("@annotation(com.around.common.utils.limitrate.LimitRate) || @annotation(com.around.common.utils.limitrate.LimitRates)")
    public Object around(ProceedingJoinPoint point) throws Throwable {
        log.trace("enter LimitRateCheckAop");
        MethodSignature methodSignature = (MethodSignature) point.getSignature();
        // getAnnotationsByType unwraps the @LimitRates container
        LimitRate[] limitRates = methodSignature.getMethod().getAnnotationsByType(LimitRate.class);
        // higher order() is checked first
        Arrays.sort(limitRates, Comparator.comparingInt(LimitRate::order).reversed());
        HttpServletRequest request = WebRequestUtil.getRequest();
        Map<String, LimitRate> rollbackKey = new LinkedHashMap<>();
        for (int i = 0; i < limitRates.length; i++) {
            LimitRate limitRate = limitRates[i];
            String key = buildRedisKey(point, methodSignature, limitRate);
            if (limitRates.length > 1) {
                // With the default name every rule of one method maps to the same hash; give each rule
                // its own bucket. Single-rule keys are unchanged.
                key = key + ":" + i + "_" + limitRate.unit().getCode();
            }
            Long count;
            if (limitRate.rbEx()) {
                // build the full key here so it can be used for the rollback
                key = limitRateUtil.getKeySuffix(limitRate.type(), request) + key;
                LimitTime time = LimitTime.calculateTime(System.currentTimeMillis(), limitRate.interval(), limitRate.unit());
                count = limitRateUtil.consumeCount(key, 1L, limitRate.rate(), limitRate.type(), limitRate.capacity(), time);
            } else {
                count = limitRateUtil.consumeCount(request, key, limitRate.rate(), limitRate.interval(), limitRate.unit(), limitRate.type(), limitRate.capacity());
            }
            if (count < 0) {
                // the call is rejected: return permits already taken by higher-priority rules
                doRollback(rollbackKey);
                throw new RuntimeException(limitRate.msg());
            }
            if (limitRate.rbEx()) {
                rollbackKey.put(key, limitRate);
            }
        }
        try {
            return point.proceed();
        } catch (Exception e) {
            if (!CollectionUtils.isEmpty(rollbackKey)) {
                doRollback(rollbackKey);
            }
            throw e;
        }
    }

    /**
     * Gives one permit back to each recorded key. A failure is logged per key so that it does
     * not mask the exception that triggered the rollback.
     *
     * @param rollbackKey full Redis keys mapped to the rule that consumed from them
     */
    private void doRollback(Map<String, LimitRate> rollbackKey) {
        for (Map.Entry<String, LimitRate> entry : rollbackKey.entrySet()) {
            LimitRate limitRate = entry.getValue();
            // same effective capacity as LimitRateUtil.consumeCount: below 1 means "capacity = rate"
            long capacity = limitRate.capacity() < 1 ? limitRate.rate() : limitRate.capacity();
            try {
                limitRateUtil.addCount(entry.getKey(), 1L, capacity, limitRate.type());
            } catch (Exception ex) {
                log.error("limit rate rollback failed, key={}", entry.getKey(), ex);
            }
        }
    }

    /**
     * Builds the rule key without the scope prefix: the rule name (or the absolute hash of
     * "className:methodName") followed by the evaluated SpEL key, if any.
     *
     * @param point           join point supplying the arguments
     * @param methodSignature intercepted method
     * @param limitRate       rule being checked
     * @return the key
     */
    private String buildRedisKey(ProceedingJoinPoint point, MethodSignature methodSignature, LimitRate limitRate) {
        String methodName = methodSignature.getName();
        String className = methodSignature.getDeclaringTypeName();
        log.debug("类名{}方法名{}", className, methodName);
        String keyPrefix;
        if (StringUtils.isEmpty(limitRate.name())) {
            // hash only, to keep the key short
            keyPrefix = String.valueOf(Math.abs(MessageFormat.format("{0}:{1}", className, methodName).hashCode()));
        } else {
            keyPrefix = limitRate.name();
        }
        String paramKey = "";
        if (!StringUtils.isEmpty(limitRate.key())) {
            paramKey = buildExpressKey(methodSignature, limitRate.key(), point.getArgs());
        }
        return keyPrefix + paramKey;
    }

    /**
     * Evaluates a SpEL expression with the method arguments bound as {@code #paramName} variables.
     *
     * @param methodSignature intercepted method (supplies the parameter names)
     * @param el              SpEL expression
     * @param args            actual arguments
     * @return the expression's string value, or {@code ""} when evaluation fails
     */
    private String buildExpressKey(MethodSignature methodSignature, String el, Object[] args) {
        String paramKey = "";
        Expression expression = EXPRESSION_PARSER.parseExpression(el);
        EvaluationContext context = new StandardEvaluationContext();
        // parameter names may be unavailable (compiled without -parameters / debug info)
        String[] parameterNames = methodSignature.getParameterNames();
        if (parameterNames != null && args != null) {
            int bound = Math.min(parameterNames.length, args.length);
            for (int i = 0; i < bound; i++) {
                context.setVariable(parameterNames[i], args[i]);
            }
        }
        try {
            paramKey = expression.getValue(context).toString();
            log.debug("限流工具构建key,el表达式解析key成功,el={},解析后值为:{}", el, paramKey);
        } catch (Exception e) {
            log.error("参数key={}解析失败,{}{}", el, e, e.getMessage());
        }
        return paramKey;
    }
}
11.在需要限流的 controller 方法上使用 @LimitRate 注解进行拦截(该注解仅能标注在方法上,可重复标注以组合多条规则),具体设置项见注解说明
项目git地址 https://gitee.com/moodincode/com_around_project/tree/master/common-utils
本文地址:https://blog.csdn.net/qq_24577585/article/details/112539375
下一篇: 如何使用Axcure的元件?