最完整清晰的 Redis + Lua 脚本 + 令牌桶算法实现限流控制
在网上看了很多博客，感觉讲得都不是很清楚，于是决定自己动手写一个。
一、自定义一个注解，用来标注需要限流的方法
@Target({ElementType
.TYPE
, ElementType
.METHOD
})
@Retention(RetentionPolicy
.RUNTIME
)
public @
interface RateLimit {
String
key() default "";
int time() default 1;
int count();
boolean ipLimit() default false;
}
二、编写 Lua 脚本
脚本的主要逻辑如下（伪代码）：
根据key(参数) 查询 对应的 value(令牌数)
如果为null 说明该key 是第一次进入
{
初始化 令牌桶(参数)数量;记录初始化时间 ->返回 剩余令牌数
}
如果不为null
{
判断 value 是否大于 0
{
大于 0 -> value - 1 -> 返回 剩余令牌数
等于 0 -> 判断 距上次补充令牌的时间间隔是否足够
{
足够 -> 补充令牌;更新补充令牌时间-> 返回 剩余令牌数
不足够 -> 返回 -1 (说明超过限流访问次数)
}
}
}
-- Rate-limit script (fixed-interval token refill).
--
-- KEYS[1] : key holding the remaining token count for one limit.
-- ARGV[1] : refill interval in seconds (compared against Redis TIME seconds).
-- ARGV[2] : bucket size, i.e. tokens granted per interval.
--
-- Returns the number of tokens left after consuming one, or -1 when the
-- bucket is empty and a full interval has not yet passed since the last
-- refill. Tokens are only refilled once the bucket is empty, so this behaves
-- like a fixed-window counter rather than a continuously refilling bucket.
--
-- NOTE(review): the refill timestamp lives in a second key derived from
-- KEYS[1] but not declared in KEYS, which Redis Cluster does not permit; and
-- neither key has a TTL, so per-IP keys accumulate indefinitely.

-- TIME is non-deterministic; on Redis < 5 a script must switch to effects
-- replication before it may write after calling it.
redis.replicate_commands()

local key = KEYS[1]
local update_len = tonumber(ARGV[1])
local token_size = tonumber(ARGV[2])
local key_time = 'ratetokenprefix' .. key

local now_time = tonumber(redis.call('TIME')[1])
-- Time of the last refill. This must read key_time: reading KEYS[1] (the
-- token count) made the refill check below always succeed.
local last_refill_time = tonumber(redis.call('get', key_time) or 0)
-- GET yields false for a missing key; -1 marks "bucket not initialised".
local token_count = tonumber(redis.call('get', key) or -1)

if token_count < 0 then
  -- First request for this key: fill the bucket and consume one token.
  redis.call('set', key_time, now_time)
  redis.call('set', key, token_size - 1)
  return token_size - 1
elseif token_count > 0 then
  redis.call('set', key, token_count - 1)
  return token_count - 1
elseif now_time - last_refill_time >= update_len then
  -- Empty bucket and at least one full interval elapsed: refill, record the
  -- refill time so the next interval is measured from now, consume a token.
  redis.call('set', key_time, now_time)
  redis.call('set', key, token_size - 1)
  return token_size - 1
else
  return -1
end
三、加载 Lua 脚本并配置 RedisTemplate
/**
 * Beans used by the rate limiter: the Lua script and the RedisTemplate that executes it.
 */
@Component
public class CommonConfig {

    /**
     * Loads {@code myLua.lua} from the classpath as a Redis script whose reply is read as a {@link Number}.
     */
    @Bean
    public DefaultRedisScript<Number> redisluaScript() {
        DefaultRedisScript<Number> script = new DefaultRedisScript<>();
        script.setResultType(Number.class);
        script.setScriptSource(new ResourceScriptSource(new ClassPathResource("myLua.lua")));
        return script;
    }

    /**
     * RedisTemplate with plain-string keys and JSON-serialized values.
     *
     * <p>NOTE(review): script arguments are presumably serialized with the value serializer, so the
     * integer ARGV reach Lua as JSON number strings (which {@code tonumber} accepts) — confirm.
     */
    @Bean
    public RedisTemplate<String, Serializable> limitRedisTemplate(LettuceConnectionFactory redisConnectionFactory) {
        RedisTemplate<String, Serializable> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        redisTemplate.setKeySerializer(new StringRedisSerializer());
        redisTemplate.setValueSerializer(new GenericJackson2JsonRedisSerializer());
        return redisTemplate;
    }
}
四、创建拦截器拦截带有该注解的方法
/**
 * Spring MVC interceptor that enforces {@link RateLimit}.
 *
 * <p>For every request whose handler carries {@code @RateLimit} (on the method, or failing that on the
 * controller class) it runs the rate-limit Lua script with ARGV = (time, count). A script result of
 * {@code -1} means the bucket is exhausted: a JSON error body is written and the request is rejected.
 * Any other result is the number of tokens left and the request proceeds.
 */
@Component
public class RateLimitInterceptor implements HandlerInterceptor {

    /** 429 Too Many Requests (RFC 6585); the Servlet API defines no constant for it. */
    private static final int SC_TOO_MANY_REQUESTS = 429;

    private static final Logger LOG = LoggerFactory.getLogger(RateLimitInterceptor.class);

    @Autowired
    private RedisTemplate<String, Serializable> limitRedisTemplate;

    @Autowired
    private DefaultRedisScript<Number> redisLuaScript;

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        // Static resources etc. are served by non-HandlerMethod handlers; they are never limited.
        // (An `assert` here is a no-op unless -ea is set, which would let the cast below throw.)
        if (!(handler instanceof HandlerMethod)) {
            return true;
        }
        HandlerMethod method = (HandlerMethod) handler;
        RateLimit rateLimit = resolveRateLimit(method);
        if (rateLimit == null) {
            return true;
        }

        String redisKey = buildKey(rateLimit, method, request);
        List<String> keys = Collections.singletonList(redisKey);
        Number remaining = limitRedisTemplate.execute(redisLuaScript, keys, rateLimit.time(), rateLimit.count());
        if (remaining == null) {
            // No reply (e.g. pipelined/transactional execution): fail open rather than throw an NPE.
            LOG.warn("Rate-limit script returned no result, allowing request, key:{}", redisKey);
            return true;
        }

        if (-1 == remaining.intValue()) {
            ResultModel resultModel = ResultModel.error_900("接口调用超过限流次数");
            // 901 is not a valid HTTP status (status classes are 1xx-5xx); use the standard 429.
            response.setStatus(SC_TOO_MANY_REQUESTS);
            response.setCharacterEncoding("utf-8");
            response.setContentType("application/json");
            response.getWriter().write(JSONObject.toJSONString(resultModel));
            response.getWriter().flush();
            response.getWriter().close();
            LOG.info("当前接口调用超过时间段内限流,key:{}", redisKey);
            return false;
        }
        LOG.info("当前访问时间段内剩余{}次访问次数", remaining.toString());
        return true;
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler,
            ModelAndView modelAndView) throws Exception {
        // Nothing to do after the handler runs.
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler,
            Exception ex) throws Exception {
        // Nothing to clean up.
    }

    /** Returns the method-level {@link RateLimit}, falling back to one declared on the controller class. */
    private static RateLimit resolveRateLimit(HandlerMethod method) {
        RateLimit onMethod = method.getMethodAnnotation(RateLimit.class);
        return onMethod != null ? onMethod : method.getBeanType().getAnnotation(RateLimit.class);
    }

    /**
     * Builds the Redis key: {@code RATE_LIMIT_KEY + name + ":"}, plus {@code ip + ":"} when
     * {@code ipLimit} is set. An empty {@code key()} falls back to {@code ControllerClass#method} so
     * unrelated handlers using the default key do not share one bucket (overloads still share one).
     */
    private static String buildKey(RateLimit rateLimit, HandlerMethod method, HttpServletRequest request) {
        String name = rateLimit.key();
        if (name.isEmpty()) {
            name = method.getBeanType().getName() + "#" + method.getMethod().getName();
        }
        StringBuilder sb = new StringBuilder(Constants.RATE_LIMIT_KEY).append(name).append(':');
        if (rateLimit.ipLimit()) {
            sb.append(getIpAddr(request)).append(':');
        }
        return sb.toString();
    }

    /**
     * Best-effort client IP. Takes the first value that is non-empty and not "unknown" from the
     * {@code x-forwarded-for}, {@code Proxy-Client-IP} and {@code WL-Proxy-Client-IP} headers, in that
     * order, else {@link HttpServletRequest#getRemoteAddr()}. For a comma-separated chain the first
     * entry is returned. Returns "" if reading the request throws.
     *
     * <p>SECURITY: these headers are client-controlled. Unless a trusted reverse proxy overwrites them,
     * a caller can send a fresh "IP" per request and bypass per-IP limiting.
     */
    public static String getIpAddr(HttpServletRequest request) {
        try {
            String ip = request.getHeader("x-forwarded-for");
            if (isBlankOrUnknown(ip)) {
                ip = request.getHeader("Proxy-Client-IP");
            }
            if (isBlankOrUnknown(ip)) {
                ip = request.getHeader("WL-Proxy-Client-IP");
            }
            if (isBlankOrUnknown(ip)) {
                ip = request.getRemoteAddr();
            }
            // Split on any comma: a length > 15 guard missed short chains like "1.2.3.4,5.6.7.8".
            if (ip != null) {
                int comma = ip.indexOf(',');
                if (comma > 0) {
                    ip = ip.substring(0, comma).trim();
                }
            }
            return ip;
        } catch (Exception e) {
            return "";
        }
    }

    /** True when a header value carries no usable address. */
    private static boolean isBlankOrUnknown(String value) {
        return value == null || value.isEmpty() || "unknown".equalsIgnoreCase(value);
    }
}
一个自定义的常量，用作 Redis key 的前缀：
/**
 * Shared constants for the rate limiter.
 */
public class Constants {

    /** Prefix prepended to every rate-limit Redis key. */
    public static final String RATE_LIMIT_KEY = "rateLimit:";
}
五、在 WebConfig 中注册这个拦截器
/**
 * Registers {@link RateLimitInterceptor} with Spring MVC. No path patterns are given, so it sees every
 * request; the interceptor itself lets through handlers that carry no {@code @RateLimit}.
 *
 * <p>NOTE(review): in Spring Boot, {@code @EnableWebMvc} switches off the MVC auto-configuration.
 * {@code WebMvcConfigurerAdapter} is deprecated since Spring 5.0 in favour of implementing
 * {@code WebMvcConfigurer}. Confirm both are intended for the Spring version in use.
 */
@Configuration
@EnableWebMvc
public class WebConfig extends WebMvcConfigurerAdapter {

    @Autowired
    private RateLimitInterceptor rateLimitInterceptor;

    @Override
    public void addInterceptors(InterceptorRegistry interceptorRegistry) {
        interceptorRegistry.addInterceptor(this.rateLimitInterceptor);
        super.addInterceptors(interceptorRegistry);
    }
}
六、注解使用
/**
 * Demo controller showing {@link RateLimit} usage.
 *
 * <p>{@code /test/get} (any HTTP method) is limited under key "testGet" to 5 calls per 1-second
 * interval, tracked separately for each client IP.
 */
@RestController
@RequestMapping("/test")
public class TestController {

    @RequestMapping("/get")
    @RateLimit(key = "testGet", count = 5, time = 1, ipLimit = true)
    public ResultModel testGet() {
        return ResultModel.ok_200();
    }
}
如果觉得有问题，欢迎各位大佬指正；觉得可以的话，点个赞再走吧！