基于Redis的分布式服务限流方案
程序员文章站
2022-03-07 10:48:48
...
由于API接口无法控制调用方的行为,因此当遇到瞬时请求量激增时,会导致接口占用过多服务器资源,使得其他请求响应速度降低或是超时,更有甚者可能导致服务器宕机。
限流指对应用服务接口的请求调用次数进行限制,对超过限制次数的请求则进行快速失败或丢弃。
限流可以应对:
1、热点业务带来的高并发请求;
2、客户端异常重试导致的并发请求;
3、恶意攻击请求;
限流算法多种多样,比如常见的:固定窗口计数器、滑动窗口计数器、漏桶、令牌桶等。本章通过Redis的Lua脚本实现令牌桶算法:Redis会原子地执行整个脚本,因此多个服务实例并发访问同一个限流key时,令牌的补充与扣减不会相互干扰。
1、Redis lua脚本如下:
-- Token-bucket rate limiter, executed atomically by Redis.
-- KEYS[1] : hash holding the bucket state (fields: last_time, current_token)
-- ARGV[1] : bucket capacity, i.e. max requests per period (RateLimit.rate)
-- ARGV[2] : period length in seconds (RateLimit.period)
-- ARGV[3] : caller's current time in milliseconds
-- Returns '1' when a token was taken (request allowed), '0' otherwise.
local ratelimit_info = redis.pcall('HMGET', KEYS[1], 'last_time', 'current_token')
local last_time = ratelimit_info[1]
local current_token = tonumber(ratelimit_info[2])
local max_token = tonumber(ARGV[1])
local token_rate = tonumber(ARGV[2])
local current_time = tonumber(ARGV[3])
-- milliseconds needed to regenerate a single token
local reverse_time = token_rate * 1000 / max_token
if current_token == nil then
  -- no state yet (first call or key expired): start with a full bucket
  current_token = max_token
  last_time = current_time
else
  -- current_time comes from the calling app server; clocks of different servers can
  -- disagree, so a "negative" elapsed time must not remove tokens from the bucket
  local past_time = math.max(current_time - last_time, 0)
  local reverse_token = math.floor(past_time / reverse_time)
  current_token = current_token + reverse_token
  -- advance only by whole regenerated tokens so the fractional remainder is kept
  last_time = reverse_time * reverse_token + last_time
  if current_token > max_token then
    current_token = max_token
  end
end
local result = '0'
if current_token > 0 then
  result = '1'
  current_token = current_token - 1
end
redis.call('HMSET', KEYS[1], 'last_time', last_time, 'current_token', current_token)
-- keep the key at least until the bucket would be full again; the elapsed part is clamped
-- so a lagging caller clock cannot yield a non-positive TTL (which would delete the key)
redis.call('pexpire', KEYS[1], math.ceil(reverse_time * (max_token - current_token) + math.max(current_time - last_time, 0)))
return result
2、项目中引入spring-data-redis和commons-codec(用于计算Lua脚本的SHA1值),相关配置请自行查阅。
3、RedisRateLimitScript类
package com.huatech.support.limit;

import org.apache.commons.codec.digest.DigestUtils;
import org.springframework.data.redis.core.script.RedisScript;

/**
 * Token-bucket rate-limit Lua script for Spring Data Redis.
 * <p>
 * KEYS[1] is the bucket key. ARGV[1] is the bucket capacity (requests per period), ARGV[2] the
 * period in seconds and ARGV[3] the caller's current time in milliseconds. The script returns
 * {@code "1"} when a token was taken (request allowed) and {@code "0"} otherwise.
 * <p>
 * The class has no instance state, so one instance can be shared by all callers.
 */
public class RedisRateLimitScript implements RedisScript<String> {

    private static final String SCRIPT =
            "local ratelimit_info = redis.pcall('HMGET',KEYS[1],'last_time','current_token')\n"
            + "local last_time = ratelimit_info[1]\n"
            + "local current_token = tonumber(ratelimit_info[2])\n"
            + "local max_token = tonumber(ARGV[1])\n"
            + "local token_rate = tonumber(ARGV[2])\n"
            + "local current_time = tonumber(ARGV[3])\n"
            + "local reverse_time = token_rate*1000/max_token\n"
            + "if current_token == nil then\n"
            + "  current_token = max_token\n"
            + "  last_time = current_time\n"
            + "else\n"
            // Clock skew between app servers must not produce a negative elapsed time
            // (that would remove tokens from the bucket).
            + "  local past_time = math.max(current_time-last_time, 0)\n"
            + "  local reverse_token = math.floor(past_time/reverse_time)\n"
            + "  current_token = current_token+reverse_token\n"
            + "  last_time = reverse_time*reverse_token+last_time\n"
            + "  if current_token>max_token then\n"
            + "    current_token = max_token\n"
            + "  end\n"
            + "end\n"
            + "local result = '0'\n"
            + "if current_token>0 then\n"
            + "  result = '1'\n"
            + "  current_token = current_token-1\n"
            + "end\n"
            + "redis.call('HMSET',KEYS[1],'last_time',last_time,'current_token',current_token)\n"
            // Clamped for the same reason: a non-positive TTL would delete the key.
            + "redis.call('pexpire',KEYS[1],"
            + "math.ceil(reverse_time*(max_token-current_token)+math.max(current_time-last_time, 0)))\n"
            + "return result";

    /** SHA1 of {@link #SCRIPT}; the script is a constant, so it is hashed only once. */
    private static final String SHA1 = DigestUtils.sha1Hex(SCRIPT);

    @Override
    public String getSha1() {
        return SHA1;
    }

    @Override
    public Class<String> getResultType() {
        return String.class;
    }

    @Override
    public String getScriptAsString() {
        return SCRIPT;
    }
}
4、添加RateLimit注解
package com.huatech.support.limit;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Declares a rate limit: at most {@link #rate()} calls per {@link #period()} seconds,
 * counted within the scope given by {@link #limitType()}.
 * <p>
 * NOTE(review): {@code TYPE} is an allowed target, but the LimitAspect and
 * RateLimitInterceptor shown in this article only read method-level annotations.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface RateLimit {

    /**
     * Identifier of the limited interface, used as the base of the limit key.
     * When blank, the consumers in this article derive it from class and method name.
     *
     * @return the interface identifier, empty by default
     */
    String value() default "";

    /**
     * Length of one limiting period, in seconds.
     *
     * @return the period in seconds, 1 by default
     */
    int period() default 1;

    /**
     * Number of calls allowed within one period.
     *
     * @return the calls per period, 100 by default
     */
    int rate() default 100;

    /**
     * Scope in which calls are counted; by default one limit for the whole interface.
     *
     * @return the limit scope, {@link LimitType#GLOBAL} by default
     */
    LimitType limitType() default LimitType.GLOBAL;

    /**
     * Handling once the limit is exceeded; access is denied by default.
     * NOTE(review): the aspect/interceptor in this article do not read this value.
     *
     * @return the over-limit handling, {@link LimitedMethod#ACCESS_DENIED} by default
     */
    LimitedMethod method() default LimitedMethod.ACCESS_DENIED;
}
基于Redis的分布式服务限流有两种落地方案:
一种是基于aop的切面实现,另一种是基于interceptor的拦截器实现,下面分别做介绍。
方案一:基于AspectJ的AOP实现方案
1、添加LimitAspect类
package com.huatech.common.aop; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.commons.lang3.StringUtils; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.Around; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.reflect.MethodSignature; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.stereotype.Component; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.util.WebUtils; import com.alibaba.fastjson.JSONObject; import com.huatech.common.constant.Constants; import com.huatech.common.util.IpUtil; import com.huatech.support.limit.RateLimit; import com.huatech.support.limit.RedisRateLimitScript; @Aspect @Component public class LimitAspect { private static final Logger LOGGER = LoggerFactory.getLogger(LimitAspect.class); @Autowired private StringRedisTemplate redisTemplate; @Around("execution(* com.huatech.core.controller..*(..) 
) && @annotation(com.huatech.support.limit.RateLimit)") public Object interceptor(ProceedingJoinPoint joinPoint) throws Throwable{ MethodSignature signature = (MethodSignature) joinPoint.getSignature(); Method method = signature.getMethod(); RateLimit rateLimit = method.getAnnotation(RateLimit.class); if(rateLimit != null) { ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); HttpServletRequest request = requestAttributes.getRequest(); HttpServletResponse response = requestAttributes.getResponse(); Class<?> targetClass = method.getDeclaringClass(); List<String> keyList = new ArrayList<>(1); String key = rateLimit.value(); if(StringUtils.isBlank(key)){ key = targetClass.getName() + "-" + method.getName(); } switch (rateLimit.limitType()) { case IP: String ip = IpUtil.getRemortIP(request); key = ip + "-" + key; break; case USER: String userId = WebUtils.getSessionAttribute(request, Constants.SESSION_USER_ID).toString(); key = userId + "-" + key; default: break; } keyList.add(key); long timer = System.currentTimeMillis(); boolean pass = "1".equals(redisTemplate.execute(new RedisRateLimitScript(), keyList, Integer.toString(rateLimit.rate()), Integer.toString(rateLimit.period()), Long.toString(timer))); if(pass){ return joinPoint.proceed(); }else{ LOGGER.warn("接口key:{}, 周期:{}, 频率:{}", key, rateLimit.period(), rateLimit.rate()); Map<String, Object> result = new HashMap<>(); result.put("code", "400"); result.put("msg", "访问超过次数限制!"); response.setContentType("application/json"); response.setCharacterEncoding("utf-8"); response.getWriter().print(JSONObject.toJSON(result)); return null; } }else{ return joinPoint.proceed(); } } }
2、在spring-mvc配置文件中开启AspectJ注解的自动代理,使@Aspect切面生效
<!-- Enable auto-proxying of @AspectJ-annotated beans (such as LimitAspect) in this context;
     it must be the same context that holds the advised controllers. -->
<aop:aspectj-autoproxy/>
3、开启LimitAspect类的自动扫描操作,或者在spring配置文件中配置bean
<!-- Register the @Component aspect (package com.huatech.common.aop) and the controllers
     matched by its pointcut (com.huatech.core.controller) as beans of this context. -->
<context:component-scan base-package="com.huatech.common.aop,com.huatech.core.controller"/>
方案二:基于Interceptor的拦截器实现方案
1、添加RateLimitInterceptor类
public class RateLimitInterceptor extends HandlerInterceptorAdapter { private static final Logger LOGGER = LoggerFactory.getLogger(RateLimitInterceptor.class); @Autowired StringRedisTemplate redisTemplate; @Override public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { if (handler instanceof HandlerMethod) { HandlerMethod method = (HandlerMethod) handler; final RateLimit rateLimit = method.getMethodAnnotation(RateLimit.class); if (rateLimit != null) { // 令牌名称 List<String> keyList = new ArrayList<>(1); String key = rateLimit.value(); if(StringUtils.isBlank(key)){ key = method.getClass().getName() + "-" + method.getMethod().getName(); } switch (rateLimit.limitType()) { case IP: String ip = IpUtil.getRemortIP(request); key = ip + "-" + key; break; case USER: String userId = WebUtils.getSessionAttribute(request, Constants.SESSION_USER_ID).toString(); key = "uid:" + userId + "-" + key; default: break; } keyList.add(key); long timer = System.currentTimeMillis(); boolean pass = "1".equals(redisTemplate.execute(new RedisRateLimitScript(), keyList, Integer.toString(rateLimit.rate()), Integer.toString(rateLimit.period()), Long.toString(timer))); if(pass){ return true; }else{ LOGGER.warn("接口key:{}, 周期:{}, 频率:{}", key, rateLimit.period(), rateLimit.rate()); Map<String, Object> result = new HashMap<>(); result.put("code", "400"); result.put("msg", "访问超过次数限制!"); response.setContentType("application/json"); response.setCharacterEncoding("utf-8"); response.getWriter().print(JSONObject.toJSON(result)); return false; } } } return true; } }
2、在spring-mvc配置文件中配置拦截器
<!-- Interceptor configuration -->
<mvc:interceptors>
    <!-- other interceptors go here -->
    <!-- Rate-limit interceptor: invoked for every path, but only handler methods annotated
         with @RateLimit are actually limited. -->
    <mvc:interceptor>
        <mvc:mapping path="/**"/>
        <bean class="com.huatech.common.interceptor.RateLimitInterceptor"/>
    </mvc:interceptor>
</mvc:interceptors>
使用@RateLimit
在controller类的方法头上添加RateLimit注解
/**
 * Server-side ping endpoint: answers with the current server time in milliseconds.
 * Limited under the key "ping" to 5 calls per 5 seconds, shared by all callers
 * (default limit type).
 *
 * @param request  the incoming request (not used)
 * @param response the response handed to ServletUtils.successData
 * @throws Exception propagated from ServletUtils.successData
 */
@RequestMapping(value = "/api/app/open/ping.htm")
@RateLimit(value = "ping", period = 5, rate = 5)
public void ping(HttpServletRequest request, HttpServletResponse response) throws Exception {
    long now = System.currentTimeMillis();
    Map<String, Object> payload = new HashMap<>();
    payload.put("time", now);
    ServletUtils.successData(response, payload);
}
另外两个枚举类
package com.huatech.support.limit; /** * 超限处理方式 * @author lh@erongdu.com * @since 2019年8月28日 * @version 1.0 * */ public enum LimitedMethod { /** * 拒绝访问(直接拒绝访问,不预警) */ ACCESS_DENIED, /** * 预警短信(发送预警短信,但不拒绝访问) */ WARN_SMS, /** * 拒绝访问并预警 */ DENIED_AND_SMS ; }
package com.huatech.support.limit; /** * 接口限制类型 * @author lh@erongdu.com * @since 2019年8月29日 * @version 1.0 * */ public enum LimitType { /** * 整个接口限制 */ GLOBAL("接口"), /** * ip层面限制 */ IP("ip"), /** * 用户层面限制 */ USER("用户"); public String value; private LimitType(String value) { this.value = value; } }
IpUtil工具类
package com.huatech.common.util; import java.net.InetAddress; import java.net.UnknownHostException; import javax.servlet.http.HttpServletRequest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * * @author lh@erongdu.com * @since 2019年8月29日 * @version 1.0 * */ public class IpUtil { public static final Logger logger = LoggerFactory.getLogger(IpUtil.class); /** * 获取请求IP * @param request * @return */ public static String getRemortIP(HttpServletRequest request) { String ip = request.getHeader("x-forwarded-for"); if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) { ip = request.getHeader("X-Real-IP"); } if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) { ip = request.getHeader("WL-Proxy-Client-IP"); } if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) { ip = request.getRemoteAddr(); } //这里主要是获取本机的ip,可有可无 if ("127.0.0.1".equals(ip) || ip.endsWith("0:0:0:0:0:0:1")) { // 根据网卡取本机配置的IP InetAddress inet = null; try { inet = InetAddress.getLocalHost(); } catch (UnknownHostException e) { logger.error(e.getMessage(), e); } if(inet != null){ ip = inet.getHostAddress(); } return ip; } if(ip.length() > 0){ String[] ipArray = ip.split(","); if (ipArray != null && ipArray.length > 1) { return ipArray[0]; } return ip; } return ""; } }