
Redis-Based Distributed Rate Limiting for Services

程序员文章站 2022-03-07 10:48:48
...

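An API cannot control how its callers behave. When the request volume suddenly spikes, an endpoint can consume too many server resources. Other requests then slow down or time out, and in the worst case the server goes down.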

 

Rate limiting caps the number of calls made to an application's service endpoints. Requests beyond the limit either fail fast or are dropped.

 

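Rate limiting protects against:

1. high-concurrency traffic caused by hot business events;

2. concurrent requests caused by clients that retry abnormally;

3. malicious attack traffic.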

 

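There are many rate limiting algorithms. Common ones are the fixed window counter, the sliding window counter, the leaky bucket and the token bucket. This article uses a Redis Lua script to implement the token bucket algorithm. Each key stores its remaining token count and the time of its last refill. Redis runs the script atomically, so all application servers share one consistent bucket per key.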
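The simplest algorithm in that list is the fixed window counter. A minimal in-memory sketch of it follows, for comparison with the token bucket; it is not part of the article's Redis solution. The class name `FixedWindowCounter` is hypothetical, and the sketch only works inside a single process. The counter resets at every window boundary, so up to twice the limit can pass in a short span around a boundary. Token buckets avoid that burst pattern.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical in-memory fixed window counter, for comparison with the token bucket only. */
public class FixedWindowCounter {
    private final int limit;                                     // max requests per window
    private final long windowMillis;                             // window length in ms
    private final Map<String, long[]> windows = new HashMap<>(); // key -> {windowStart, count}

    public FixedWindowCounter(int limit, long windowMillis) {
        this.limit = limit;
        this.windowMillis = windowMillis;
    }

    /** Returns true if a request for key at nowMillis still fits in the current window. */
    public synchronized boolean tryAcquire(String key, long nowMillis) {
        long windowStart = nowMillis - (nowMillis % windowMillis); // align to the window boundary
        long[] state = windows.get(key);
        if (state == null || state[0] != windowStart) {
            state = new long[] {windowStart, 0};                   // a new window resets the count
            windows.put(key, state);
        }
        if (state[1] < limit) {
            state[1]++;
            return true;
        }
        return false;
    }
}
```

With a limit of 2 per 1000 ms, requests at 998, 999, 1000 and 1001 ms all pass: four requests in about 3 ms.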

 

1. The Redis Lua script:

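-- KEYS[1]: rate limit key
-- ARGV[1]: max_token, the bucket capacity (the annotation's rate)
-- ARGV[2]: token_rate, the period in seconds (the annotation's period)
-- ARGV[3]: current_time, the caller's clock in milliseconds
local ratelimit_info = redis.pcall('HMGET',KEYS[1],'last_time','current_token')
local last_time = ratelimit_info[1]
local current_token = tonumber(ratelimit_info[2])
local max_token = tonumber(ARGV[1])
local token_rate = tonumber(ARGV[2])
local current_time = tonumber(ARGV[3])
-- milliseconds needed to regenerate one token
local reverse_time = token_rate*1000/max_token
if current_token == nil then
  -- first request (or expired key): start with a full bucket
  current_token = max_token
  last_time = current_time
else
  -- add the whole tokens regenerated since last_time, capped at max_token
  local past_time = current_time-last_time
  local reverse_token = math.floor(past_time/reverse_time)
  current_token = current_token+reverse_token
  -- advance last_time only by the time actually converted into tokens
  last_time = reverse_time*reverse_token+last_time
  if current_token>max_token then
    current_token = max_token
  end
end

-- '1' = allowed (one token consumed), '0' = rejected
local result = '0'
if(current_token>0) then
  result = '1'
  current_token = current_token-1
end

redis.call('HMSET',KEYS[1],'last_time',last_time,'current_token',current_token)
-- let the key expire once the bucket would be full again (with some margin)
redis.call('pexpire',KEYS[1],math.ceil(reverse_time*(max_token-current_token)+(current_time-last_time)))

return result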
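You can check the script's arithmetic without a Redis server. Below is a minimal single-process Java port of the same logic. It assumes the caller passes `rate` as ARGV[1] and `period` as ARGV[2], as the Java code later in the article does. The class name `TokenBucketSimulation` is hypothetical. The port ignores key expiry and Redis atomicity.

Note that ARGV[3] comes from each application server's `System.currentTimeMillis()`. The servers' clocks should therefore be kept in sync; with skewed clocks, refills are miscounted.

```java
/** Hypothetical in-memory port of the Lua token bucket, for checking the arithmetic only. */
public class TokenBucketSimulation {
    private final long maxToken;       // ARGV[1]: bucket capacity (the annotation's rate)
    private final double refillMillis; // reverse_time: ms needed to regenerate one token
    private Long currentToken;         // null plays the role of a missing Redis key
    private double lastTime;           // last_time: timestamp up to which refills were counted

    public TokenBucketSimulation(long maxToken, long periodSeconds) {
        this.maxToken = maxToken;
        // reverse_time = token_rate*1000/max_token, where token_rate is the period in seconds
        this.refillMillis = periodSeconds * 1000.0 / maxToken;
    }

    /** Mirrors one script call: true for '1' (allowed), false for '0' (rejected). */
    public boolean tryAcquire(long nowMillis) {
        long tokens;
        if (currentToken == null) {
            tokens = maxToken;             // first call: full bucket
            lastTime = nowMillis;
        } else {
            long refilled = (long) Math.floor((nowMillis - lastTime) / refillMillis);
            tokens = currentToken + refilled;
            lastTime = refillMillis * refilled + lastTime; // advance by whole tokens only
            if (tokens > maxToken) {
                tokens = maxToken;         // cap at capacity
            }
        }
        boolean allowed = tokens > 0;
        if (allowed) {
            tokens--;                      // consume one token
        }
        currentToken = tokens;             // equivalent of the HMSET write-back
        return allowed;
    }
}
```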

 

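2. Add the spring-data-redis and commons-codec dependencies to the project. The code below also uses commons-lang3 and fastjson. For the related configuration, such as the Redis connection and the StringRedisTemplate bean, consult each library's documentation.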

 

3. The RedisRateLimitScript class

package com.huatech.support.limit;

import org.apache.commons.codec.digest.DigestUtils;
import org.springframework.data.redis.core.script.RedisScript;

public class RedisRateLimitScript implements RedisScript<String> {

   private static final String SCRIPT = 
      "local ratelimit_info = redis.pcall('HMGET',KEYS[1],'last_time','current_token') local last_time = ratelimit_info[1] local current_token = tonumber(ratelimit_info[2]) local max_token = tonumber(ARGV[1]) local token_rate = tonumber(ARGV[2]) local current_time = tonumber(ARGV[3]) local reverse_time = token_rate*1000/max_token if current_token == nil then current_token = max_token last_time = current_time else local past_time = current_time-last_time local reverse_token = math.floor(past_time/reverse_time) current_token = current_token+reverse_token last_time = reverse_time*reverse_token+last_time if current_token>max_token then current_token = max_token end end local result = '0' if(current_token>0) then result = '1' current_token = current_token-1 end redis.call('HMSET',KEYS[1],'last_time',last_time,'current_token',current_token) redis.call('pexpire',KEYS[1],math.ceil(reverse_time*(max_token-current_token)+(current_time-last_time))) return result"; 

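  // SHA-1 of the script text, computed once instead of on every call
  private static final String SHA1 = DigestUtils.sha1Hex(SCRIPT);

  @Override
  public String getSha1() {
    return SHA1;
  }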

  @Override   
  public Class<String> getResultType() {     
	  return String.class; 
  } 

  @Override   
  public String getScriptAsString() {     
	  return SCRIPT; 
  } 
} 
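`getSha1()` matters because `RedisTemplate.execute(RedisScript, ...)` first tries `EVALSHA` with this hash. It falls back to sending the full script with `EVAL` only when Redis replies `NOSCRIPT`. The hash must therefore be the SHA-1 of the exact script text.

If you would rather not depend on commons-codec, the JDK can compute the same lowercase hex digest. The sketch below mirrors what `DigestUtils.sha1Hex(String)` returns for UTF-8 text; the class name `ScriptSha1` is hypothetical. Alternatively, Spring's `DefaultRedisScript` computes the SHA-1 for you.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/** Hypothetical JDK-only equivalent of DigestUtils.sha1Hex(String). */
public final class ScriptSha1 {
    private ScriptSha1() {}

    public static String sha1Hex(String script) {
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-1");
            byte[] digest = md.digest(script.getBytes(StandardCharsets.UTF_8));
            StringBuilder sb = new StringBuilder(digest.length * 2);
            for (byte b : digest) {
                sb.append(String.format("%02x", b)); // lowercase hex, two chars per byte
            }
            return sb.toString();
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-1 is required on every Java platform", e);
        }
    }
}
```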

 

4. Add the RateLimit annotation

package com.huatech.support.limit;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimit {
	
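	/**
	 * Endpoint identifier, used as the rate limit key;
	 * defaults to the class name plus the method name when empty
	 * @return
	 */
	String value() default "";
	
	/**
	 * Period length, in seconds
	 * @return
	 */
	int period() default 1;
	
	/**
	 * Number of requests allowed per period (also the token bucket capacity)
	 * @return
	 */
	int rate() default 100;
	
	/**
	 * Limit scope; by default the endpoint is limited as a whole
	 * @return
	 */
	LimitType limitType() default LimitType.GLOBAL;
	
	/**
	 * How requests over the limit are handled; access is denied by default
	 * @return
	 */
	LimitedMethod method() default LimitedMethod.ACCESS_DENIED;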

}
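Both implementations rely on `@Retention(RetentionPolicy.RUNTIME)`. Only runtime-retained annotations are visible through reflection, which is how the aspect (`method.getAnnotation`) and the interceptor (`getMethodAnnotation`) read `value`, `period` and `rate`.

`@Target` also allows `ElementType.TYPE`, but both implementations only look at method-level annotations. Annotating a whole controller class therefore has no effect. The self-contained sketch below uses a trimmed-down stand-in annotation, `DemoRateLimit` (hypothetical), to show the lookup.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

public class RateLimitReflectionDemo {

    /** Hypothetical stand-in for @RateLimit; RUNTIME retention keeps it visible to reflection. */
    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    public @interface DemoRateLimit {
        String value() default "";
        int period() default 1;
        int rate() default 100;
    }

    public static class DemoController {
        @DemoRateLimit(value = "ping", period = 5, rate = 5)
        public void ping() {}

        public void unlimited() {}
    }

    /** Returns "value|period|rate" for the named method, or null if it is not annotated. */
    public static String describe(Class<?> type, String methodName) {
        for (Method m : type.getDeclaredMethods()) {
            if (m.getName().equals(methodName)) {
                DemoRateLimit limit = m.getAnnotation(DemoRateLimit.class);
                return limit == null ? null : limit.value() + "|" + limit.period() + "|" + limit.rate();
            }
        }
        throw new IllegalArgumentException("No method named " + methodName);
    }
}
```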

 

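There are two ways to put Redis-based distributed rate limiting into practice:

an AOP aspect or a Spring MVC interceptor. Both are described below. Use only one of them for a given endpoint. With both enabled, every request runs the script twice and consumes two tokens.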

 

Option 1: AOP implementation based on AspectJ

1. Add the LimitAspect class

package com.huatech.common.aop;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.util.WebUtils;

import com.alibaba.fastjson.JSONObject;
import com.huatech.common.constant.Constants;
import com.huatech.common.util.IpUtil;
import com.huatech.support.limit.RateLimit;
import com.huatech.support.limit.RedisRateLimitScript;


@Aspect
@Component
public class LimitAspect {
	
	private static final Logger LOGGER = LoggerFactory.getLogger(LimitAspect.class);
	@Autowired
	private StringRedisTemplate redisTemplate;
	
	@Around("execution(* com.huatech.core.controller..*(..) ) && @annotation(com.huatech.support.limit.RateLimit)")
	public Object interceptor(ProceedingJoinPoint joinPoint) throws Throwable{
		
		MethodSignature signature = (MethodSignature) joinPoint.getSignature();
		Method method = signature.getMethod();
		RateLimit rateLimit = method.getAnnotation(RateLimit.class);
		if(rateLimit !=	null) {
			ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
			HttpServletRequest request = requestAttributes.getRequest();
			HttpServletResponse response = requestAttributes.getResponse();
			
			Class<?> targetClass = method.getDeclaringClass();
			List<String> keyList = new ArrayList<>(1);
		    String key = rateLimit.value();
		    if(StringUtils.isBlank(key)){
		    	key = targetClass.getName() + "-" + method.getName();
		    }
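		    switch (rateLimit.limitType()) {
			case IP:
				String ip = IpUtil.getRemortIP(request);
				key = ip + "-" + key;
				break;
			case USER:
				// The session attribute is null for anonymous requests: fall back to the client IP
				// instead of throwing a NullPointerException
				Object userId = WebUtils.getSessionAttribute(request, Constants.SESSION_USER_ID);
				key = (userId != null ? "uid:" + userId : IpUtil.getRemortIP(request)) + "-" + key;
				break;
			default:
				break;
			}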
		    keyList.add(key);
		    
		    long timer = System.currentTimeMillis();
		    boolean pass = "1".equals(redisTemplate.execute(new RedisRateLimitScript(), keyList, 
		    		Integer.toString(rateLimit.rate()), Integer.toString(rateLimit.period()), 
		    		Long.toString(timer)));
		    if(pass){
		    	return joinPoint.proceed();
		    }else{				
		    	LOGGER.warn("Rate limit exceeded, key: {}, period: {}s, rate: {}", key, rateLimit.period(), rateLimit.rate());
		    	Map<String, Object> result = new HashMap<>();
				result.put("code", "400");
				result.put("msg", "Too many requests: access limit exceeded!");
				response.setContentType("application/json");
				response.setCharacterEncoding("utf-8");
				response.getWriter().print(JSONObject.toJSON(result));
		    	return null;
		    }
		    
		}else{
			return joinPoint.proceed();
		}
	}
	
}
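The key derivation shared by the aspect and the interceptor can be isolated as a pure function, which makes the resulting Redis keys easy to inspect and test. The sketch mirrors that logic and uses the `uid:` prefix from the interceptor version. It makes one assumption of ours: when no user is logged in, the USER scope falls back to the client IP instead of failing. `RateLimitKeys` and its `Scope` enum are hypothetical stand-ins for the article's code and `LimitType`.

```java
/** Hypothetical pure-function version of the rate limit key derivation. */
public final class RateLimitKeys {

    /** Stand-in for LimitType. */
    public enum Scope { GLOBAL, IP, USER }

    private RateLimitKeys() {}

    public static String buildKey(String value, String className, String methodName,
                                  Scope scope, String clientIp, Object userId) {
        // Blank value: default to "<fully qualified class>-<method>"
        String key = (value == null || value.trim().isEmpty())
                ? className + "-" + methodName
                : value;
        switch (scope) {
            case IP:
                return clientIp + "-" + key;
            case USER:
                // Assumption: anonymous callers fall back to their IP instead of failing
                return (userId != null ? "uid:" + userId : clientIp) + "-" + key;
            default:
                return key;
        }
    }
}
```

An explicit `value` replaces the class/method default, so two endpoints annotated with the same `value` share one bucket.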

 

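2. Enable AspectJ auto-proxying in the Spring MVC configuration file. It must be declared in the same application context that creates the controllers; otherwise the controllers are not proxied and the aspect never runs.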

<aop:aspectj-autoproxy/>

 

3. Let component scanning pick up the LimitAspect class, or declare it as a bean in the Spring configuration file:

<context:component-scan base-package="com.huatech.common.aop,com.huatech.core.controller"/>  

 

 

Option 2: implementation based on a Spring MVC interceptor

1. Add the RateLimitInterceptor class

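package com.huatech.common.interceptor;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;
import org.springframework.web.util.WebUtils;

import com.alibaba.fastjson.JSONObject;
import com.huatech.common.constant.Constants;
import com.huatech.common.util.IpUtil;
import com.huatech.support.limit.RateLimit;
import com.huatech.support.limit.RedisRateLimitScript;

// HandlerInterceptorAdapter is deprecated since Spring 5.3; there, implement HandlerInterceptor directly
public class RateLimitInterceptor extends HandlerInterceptorAdapter {
	
	private static final Logger LOGGER = LoggerFactory.getLogger(RateLimitInterceptor.class);
	@Autowired
	private StringRedisTemplate redisTemplate;

	@Override
	public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {

		if (handler instanceof HandlerMethod) {
			HandlerMethod method = (HandlerMethod) handler;
			final RateLimit rateLimit = method.getMethodAnnotation(RateLimit.class);
			if (rateLimit != null) {
				// Rate limit key
				List<String> keyList = new ArrayList<>(1);
				String key = rateLimit.value();
				if (StringUtils.isBlank(key)) {
					// Use the controller class; method.getClass() would always be HandlerMethod
					key = method.getMethod().getDeclaringClass().getName() + "-" + method.getMethod().getName();
				}
				switch (rateLimit.limitType()) {
					case IP:
						key = IpUtil.getRemortIP(request) + "-" + key;
						break;
					case USER:
						// Anonymous request: fall back to the client IP instead of throwing a NullPointerException
						Object userId = WebUtils.getSessionAttribute(request, Constants.SESSION_USER_ID);
						key = (userId != null ? "uid:" + userId : IpUtil.getRemortIP(request)) + "-" + key;
						break;
					default:
						break;
				}
				keyList.add(key);

				long timer = System.currentTimeMillis();
				// Script arguments: capacity (rate), period in seconds, current time in ms
				boolean pass = "1".equals(redisTemplate.execute(new RedisRateLimitScript(), keyList,
						Integer.toString(rateLimit.rate()), Integer.toString(rateLimit.period()),
						Long.toString(timer)));
				if (pass) {
					return true;
				} else {
					LOGGER.warn("Rate limit exceeded, key: {}, period: {}s, rate: {}", key, rateLimit.period(), rateLimit.rate());
					Map<String, Object> result = new HashMap<>();
					result.put("code", "400");
					result.put("msg", "Too many requests: access limit exceeded!");
					response.setContentType("application/json");
					response.setCharacterEncoding("utf-8");
					response.getWriter().print(JSONObject.toJSON(result));
					return false;
				}
			}
		}

		return true;
	}
}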
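In the interceptor, the default key has to come from the controller class. `HandlerMethod` is a wrapper object, so calling `getClass()` on it returns the wrapper's own class. Every endpoint without an explicit `value` would then get a key built from that wrapper class. Any two such endpoints with the same method name would end up sharing one bucket.

The plain-Java sketch below reproduces the effect. It uses a hypothetical `Handler` wrapper standing in for `HandlerMethod`.

```java
import java.lang.reflect.Method;

public class HandlerKeyDemo {

    /** Hypothetical stand-in for Spring's HandlerMethod: wraps a controller method. */
    public static final class Handler {
        private final Method method;
        public Handler(Method method) { this.method = method; }
        public Method getMethod() { return method; }
    }

    public static class OrderController { public void list() {} }
    public static class UserController { public void list() {} }

    /** Buggy variant: uses the wrapper's class, so different controllers collide. */
    public static String buggyKey(Handler h) {
        return h.getClass().getName() + "-" + h.getMethod().getName();
    }

    /** Fixed variant: uses the class that declares the controller method. */
    public static String fixedKey(Handler h) {
        return h.getMethod().getDeclaringClass().getName() + "-" + h.getMethod().getName();
    }

    public static Handler handlerFor(Class<?> type, String name) {
        for (Method m : type.getDeclaredMethods()) {
            if (m.getName().equals(name)) {
                return new Handler(m);
            }
        }
        throw new IllegalArgumentException("No method named " + name);
    }
}
```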

 

2. Register the interceptor in the Spring MVC configuration file:

<!-- Interceptor configuration -->
<mvc:interceptors>
	<!-- Other interceptors -->
	<!-- ... -->
	<!-- Rate limit interceptor -->
	<mvc:interceptor>
		<mvc:mapping path="/**"/>
		<bean class="com.huatech.common.interceptor.RateLimitInterceptor"/>
	</mvc:interceptor>
</mvc:interceptors>

 

Using @RateLimit

Add the RateLimit annotation to a controller method:

    /**
     * Server ping endpoint, limited to 5 calls per 5-second period on average
     * @param request
     * @param response
     * @throws Exception
     */
    @RequestMapping(value = "/api/app/open/ping.htm")
    @RateLimit(value="ping", period=5, rate=5)
    public void ping(HttpServletRequest request, HttpServletResponse response) throws Exception {
        Map<String, Object> data = new HashMap<String, Object>();
        data.put("time", System.currentTimeMillis());
        ServletUtils.successData(response, data);
    }
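For the `ping` example, the script computes its refill interval as `reverse_time = token_rate*1000/max_token`, where `token_rate` receives `period` and `max_token` receives `rate`. This works out as:

$$
t_{\text{refill}} = \frac{1000 \cdot \mathit{period}}{\mathit{rate}}\ \text{ms} = \frac{1000 \cdot 5}{5}\ \text{ms} = 1000\ \text{ms}
$$

The default GLOBAL scope gives all callers one shared bucket of 5 tokens. Together they can make up to 5 calls in a burst, after which they are held to roughly one call per second.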

 

 

The two enum types used above:

package com.huatech.support.limit;
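/**
 * How requests over the limit are handled
 * @author lh@erongdu.com
 * @since 2019-08-28
 * @version 1.0
 *
 */
public enum LimitedMethod {
	
	/**
	 * Deny access (reject the request directly, without an alert)
	 */
	ACCESS_DENIED,
	/**
	 * Alert SMS (send an alert SMS but do not deny access)
	 */
	WARN_SMS,
	/**
	 * Deny access and send an alert
	 */
	DENIED_AND_SMS
	;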

}
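The annotation's `method()` attribute, of type `LimitedMethod`, is declared but never read by the aspect or the interceptor. Both always reject requests over the limit. Below is a minimal sketch of how the three policies could map to actions. It assumes the SMS alert is sent elsewhere, since no sender appears in the article. The names `LimitedAction` and `Policy` are hypothetical.

```java
/** Hypothetical mapping from an over-limit policy to concrete actions. */
public final class LimitedAction {

    /** Stand-in for LimitedMethod. */
    public enum Policy { ACCESS_DENIED, WARN_SMS, DENIED_AND_SMS }

    public final boolean deny;      // reject the request?
    public final boolean sendAlert; // trigger an SMS alert (handled elsewhere)?

    private LimitedAction(boolean deny, boolean sendAlert) {
        this.deny = deny;
        this.sendAlert = sendAlert;
    }

    /** Decides what to do once the script has returned "0" (limit exceeded). */
    public static LimitedAction onLimitExceeded(Policy policy) {
        switch (policy) {
            case WARN_SMS:
                return new LimitedAction(false, true);  // alert only, let the request through
            case DENIED_AND_SMS:
                return new LimitedAction(true, true);   // reject and alert
            case ACCESS_DENIED:
            default:
                return new LimitedAction(true, false);  // reject silently (the current behaviour)
        }
    }
}
```

In the aspect or interceptor, the rejection branch would then proceed normally when `deny` is false. It would pass the alert to whatever notification mechanism the project uses when `sendAlert` is true.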

 

package com.huatech.support.limit;
/**
 * Scope of the rate limit
 * @author lh@erongdu.com
 * @since 2019-08-29
 * @version 1.0
 *
 */
public enum LimitType {
	
	/**
	 * Limit the endpoint as a whole
	 */
	GLOBAL("endpoint"), 
	/**
	 * Limit per client IP
	 */
	IP("ip"), 
	/**
	 * Limit per logged-in user
	 */
	USER("user");
	
	public final String value;
	private LimitType(String value) {
		this.value = value;
	}
	
	
}

 

The IpUtil utility class

package com.huatech.common.util;

import java.net.InetAddress;
import java.net.UnknownHostException;

import javax.servlet.http.HttpServletRequest;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * 
 * @author lh@erongdu.com
 * @since 2019-08-29
 * @version 1.0
 *
 */
public class IpUtil {
	
	public static final Logger logger = LoggerFactory.getLogger(IpUtil.class);
    
	/**
	 * Returns the client IP of the request
	 * @param request
	 * @return
	 */
	public static String getRemortIP(HttpServletRequest request) {
		String ip = request.getHeader("x-forwarded-for");
		if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
			ip = request.getHeader("X-Real-IP");
		}
		if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
			ip = request.getHeader("WL-Proxy-Client-IP");
		}
		if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
			ip = request.getRemoteAddr();
		}
		if (ip == null) {
			return "";
		}

		// For loopback requests, resolve the local machine's IP instead (optional)
		if ("127.0.0.1".equals(ip) || "::1".equals(ip) || ip.endsWith("0:0:0:0:0:0:1")) {
			// Use the IP configured on the local network interface
			InetAddress inet = null;
			try {
				inet = InetAddress.getLocalHost();
			} catch (UnknownHostException e) {
				logger.error(e.getMessage(), e);
			}
			if (inet != null) {
				ip = inet.getHostAddress();
			}
			return ip;
		}
		// X-Forwarded-For may hold "client, proxy1, proxy2": the first entry is the original client
		if (ip.length() > 0) {
			String[] ipArray = ip.split(",");
			return ipArray.length > 0 ? ipArray[0].trim() : "";
		}

		return "";
	}
}
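`X-Forwarded-For` can carry a chain such as `client, proxy1, proxy2`, in which the leftmost entry is the original client. The client supplies this header itself unless a proxy you control overwrites it. A caller can therefore bypass IP-scoped limits by forging it, so only trust it behind such a proxy. The helper below sketches a stricter parse that skips empty and `unknown` entries; its name, `ForwardedFor`, is hypothetical.

```java
/** Hypothetical helper: pick the client address out of an X-Forwarded-For value. */
public final class ForwardedFor {
    private ForwardedFor() {}

    /** Returns the first non-empty, non-"unknown" entry, trimmed; null if there is none. */
    public static String firstClientIp(String header) {
        if (header == null) {
            return null;
        }
        for (String part : header.split(",")) {
            String candidate = part.trim();
            if (!candidate.isEmpty() && !"unknown".equalsIgnoreCase(candidate)) {
                return candidate;
            }
        }
        return null;
    }
}
```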

 
