用Go語(yǔ)言做了一個(gè)分布式限流器, 看看實(shí)現(xiàn)的方法與步驟
項(xiàng)目的要求主要有以下幾點(diǎn):
- 支持本地/分布式限流,接口統(tǒng)一
- 支持多種限流算法的切換
- 方便配置,配置方式不確定
go 語(yǔ)言不是很支持 OOP,我在實(shí)現(xiàn)的時(shí)候是按 Java 的思路走的,所以看起來(lái)有點(diǎn)不倫不類(lèi),希望能拋磚引玉。
1. 接口定義
package ratelimit
import "time"
// 限流器接口
type Limiter interface {
Acquire() error
TryAcquire() bool
}
// 限流定義接口
type Limit interface {
Name() string
Key() string
Period() time.Duration
Count() int32
LimitType() LimitType
}
// 支持 burst
type BurstLimit interface {
Limit
BurstCount() int32
}
// 分布式定義的 burst
type DistLimit interface {
Limit
ClusterNum() int32
}
type LimitType int32
const (
CUSTOM LimitType = iota
IP
)
Limiter 接口參考了 Google 的 guava 包里的 Limiter 實(shí)現(xiàn)。Acquire 接口是阻塞接口,其實(shí)還需要加上 context 來(lái)保證調(diào)用鏈安全,因?yàn)閷?shí)際項(xiàng)目中并沒(méi)有用到 Acquire 接口,所以沒(méi)有實(shí)現(xiàn)完善;同理,超時(shí)時(shí)間的支持也可以通過(guò)添加新接口繼承自 Limiter 接口來(lái)實(shí)現(xiàn)。TryAcquire 會(huì)立即返回。
Limit 抽象了一個(gè)限流定義,Key() 方法返回這個(gè) Limit 的唯一標(biāo)識(shí),Name() 僅作輔助,Period() 表示周期,單位是秒,Count() 表示周期內(nèi)的最大次數(shù),LimitType()表示根據(jù)什么來(lái)做區(qū)分,如 IP,默認(rèn)是 CUSTOM.
BurstLimit 提供突發(fā)的能力,一般是配合令牌桶算法。DistLimit 新增 ClusterNum() 方法,因?yàn)?nbsp;mentor 要求分布式遇到錯(cuò)誤的時(shí)候,需要退化為單機(jī)版本,退化的策略即是:2 節(jié)點(diǎn)總共 100QPS,如果出現(xiàn)分區(qū),每個(gè)節(jié)點(diǎn)需要調(diào)整為各 50QPS
2. LocalCounterLimiter
package ratelimit
import (
"errors"
"fmt"
"math"
"sync"
"sync/atomic"
"time"
)
// todo timer 需要 stop
type localCounterLimiter struct {
limit Limit
limitCount int32 // 內(nèi)部使用,對(duì) limit.count 做了 <0 時(shí)的轉(zhuǎn)換
ticker *time.Ticker
quit chan bool
lock sync.Mutex
newTerm *sync.Cond
count int32
}
func (lim *localCounterLimiter) init() {
lim.newTerm = sync.NewCond(&lim.lock)
lim.limitCount = lim.limit.Count()
if lim.limitCount < 0 {
lim.limitCount = math.MaxInt32 // count 永遠(yuǎn)不會(huì)大于 limitCount,后面的寫(xiě)法保證溢出也沒(méi)問(wèn)題
} else if lim.limitCount == 0 {
// 禁止訪問(wèn), 會(huì)無(wú)限阻塞
} else {
lim.ticker = time.NewTicker(lim.limit.Period())
lim.quit = make(chan bool, 1)
go func() {
for {
select {
case <- lim.ticker.C:
fmt.Println("ticker .")
atomic.StoreInt32(&lim.count, 0)
lim.newTerm.Broadcast()
//lim.newTerm.L.Unlock()
case <- lim.quit:
fmt.Println("work well .")
lim.ticker.Stop()
return
}
}
}()
}
}
// todo 需要機(jī)制來(lái)防止無(wú)限阻塞, 不超時(shí)也應(yīng)該有個(gè)極限時(shí)間
func (lim *localCounterLimiter) Acquire() error {
if lim.limitCount == 0 {
return errors.New("rate limit is 0, infinity wait")
}
lim.newTerm.L.Lock()
for lim.count >= lim.limitCount {
// block instead of spinning
lim.newTerm.Wait()
//fmt.Println(count, lim.limitCount)
}
lim.count++
lim.newTerm.L.Unlock()
return nil
}
func (lim *localCounterLimiter) TryAcquire() bool {
count := atomic.AddInt32(&lim.count, 1)
if count > lim.limitCount {
return false
} else {
return true
}
}
代碼很簡(jiǎn)單,就不多說(shuō)了
3. LocalTokenBucketLimitergolang 的官方庫(kù)里提供了一個(gè) ratelimiter,就是采用令牌桶的算法。所以這里并沒(méi)有重復(fù)造輪子,直接代理了 ratelimiter。
package ratelimit
import (
"context"
"golang.org/x/time/rate"
"math"
)
type localTokenBucketLimiter struct {
limit Limit
limiter *rate.Limiter // 直接復(fù)用令牌桶的
}
func (lim *localTokenBucketLimiter) init() {
burstCount := lim.limit.Count()
if burstLimit, ok := lim.limit.(BurstLimit); ok {
burstCount = burstLimit.BurstCount()
}
count := lim.limit.Count()
if count < 0 {
count = math.MaxInt32
}
f := float64(count) / lim.limit.Period().Seconds()
if f < 0 {
f = float64(rate.Inf) // 無(wú)限
} else if f == 0 {
panic("為 0 的時(shí)候,底層實(shí)現(xiàn)有問(wèn)題")
}
lim.limiter = rate.NewLimiter(rate.Limit(f), int(burstCount))
}
func (lim *localTokenBucketLimiter) Acquire() error {
err := lim.limiter.Wait(context.TODO())
return err
}
func (lim *localTokenBucketLimiter) TryAcquire() bool {
return lim.limiter.Allow()
}
4. RedisCounterLimiter
package ratelimit
import (
"math"
"sync"
"xg-go/log"
"xg-go/xg/common"
)
type redisCounterLimiter struct {
limit DistLimit
limitCount int32 // 內(nèi)部使用,對(duì) limit.count 做了 <0 時(shí)的轉(zhuǎn)換
redisClient *common.RedisClient
once sync.Once // 退化為本地計(jì)數(shù)器的時(shí)候使用
localLim Limiter
//script string
}
func (lim *redisCounterLimiter) init() {
lim.limitCount = lim.limit.Count()
if lim.limitCount < 0 {
lim.limitCount = math.MaxInt32
}
//lim.script = buildScript()
}
//func buildScript() string {
// sb := strings.Builder{}
//
// sb.WriteString("local c")
// sb.WriteString("\nc = redis.call('get',KEYS[1])")
// // 調(diào)用不超過(guò)最大值,則直接返回
// sb.WriteString("\nif c and tonumber(c) > tonumber(ARGV[1]) then")
// sb.WriteString("\nreturn c;")
// sb.WriteString("\nend")
// // 執(zhí)行計(jì)算器自加
// sb.WriteString("\nc = redis.call('incr',KEYS[1])")
// sb.WriteString("\nif tonumber(c) == 1 then")
// sb.WriteString("\nredis.call('expire',KEYS[1],ARGV[2])")
// sb.WriteString("\nend")
// sb.WriteString("\nif tonumber(c) == 1 then")
// sb.WriteString("\nreturn c;")
//
// return sb.String()
//}
func (lim *redisCounterLimiter) Acquire() error {
panic("implement me")
}
func (lim *redisCounterLimiter) TryAcquire() (success bool) {
defer func() {
// 一般是 redis 連接斷了,會(huì)觸發(fā)空指針
if err := recover(); err != nil {
//log.Errorw("TryAcquire err", common.ERR, err)
//success = lim.degradeTryAcquire()
//return
success = true
}
// 沒(méi)有錯(cuò)誤,判斷是否開(kāi)啟了 local 如果開(kāi)啟了,把它停掉
//if lim.localLim != nil {
// // stop 線(xiàn)程安全
// lim.localLim.Stop()
//}
}()
count, err := lim.redisClient.IncrBy(lim.limit.Key(), 1)
//panic("模擬 redis 出錯(cuò)")
if err != nil {
log.Errorw("TryAcquire err", common.ERR, err)
panic(err)
}
// *2 是為了保留久一點(diǎn),便于觀察
err = lim.redisClient.Expire(lim.limit.Key(), int(2 * lim.limit.Period().Seconds()))
if err != nil {
log.Errorw("TryAcquire error", common.ERR, err)
panic(err)
}
// 業(yè)務(wù)正確的情況下 確認(rèn)超限
if int32(count) > lim.limitCount {
return false
}
return true
//keys := []string{lim.limit.Key()}
//
//log.Errorw("TryAcquire ", keys, lim.limit.Count(), lim.limit.Period().Seconds())
//count, err := lim.redisClient.Eval(lim.script, keys, lim.limit.Count(), lim.limit.Period().Seconds())
//if err != nil {
// log.Errorw("TryAcquire error", common.ERR, err)
// return false
//}
//
//
//typeName := reflect.TypeOf(count).Name()
//log.Errorw(typeName)
//
//if count != nil && count.(int32) <= lim.limitCount {
//
// return true
//}
//return false
}
func (lim *redisCounterLimiter) Stop() {
// 判斷是否開(kāi)啟了 local 如果開(kāi)啟了,把它停掉
if lim.localLim != nil {
// stop 線(xiàn)程安全
lim.localLim.Stop()
}
}
func (lim *redisCounterLimiter) degradeTryAcquire() bool {
lim.once.Do(func() {
count := lim.limit.Count() / lim.limit.ClusterNum()
limit := LocalLimit {
name: lim.limit.Name(),
key: lim.limit.Key(),
count: count,
period: lim.limit.Period(),
limitType: lim.limit.LimitType(),
}
lim.localLim = NewLimiter(&limit)
})
return lim.localLim.TryAcquire()
}
代碼里回退的部分注釋了,因?yàn)榫€(xiàn)上為了穩(wěn)定,實(shí)習(xí)生的代碼畢竟,所以先不跑。
本來(lái)原有的思路是直接用 lua 腳本在 redis上保證原子操作,但是底層封裝的庫(kù)對(duì)于直接調(diào) eval 跑的時(shí)候,會(huì)拋錯(cuò),而且 source 是 go-redis 里面,趕 ddl 沒(méi)有時(shí)間去 debug,所以只能用 incrBy + expire 分開(kāi)來(lái)。
5. RedisTokenBucketLimiter令牌桶的狀態(tài)變量得放在一個(gè) 線(xiàn)程安全/一致 的地方,redis 是不二人選。但是令牌桶的算法核心是個(gè)延遲計(jì)算得到令牌數(shù)量,這個(gè)是一個(gè)很長(zhǎng)的臨界區(qū),所以要么用分布式鎖,要么直接利用 redis 的單線(xiàn)程以原子方式跑。一般業(yè)界是后者,即 lua 腳本維護(hù)令牌桶的狀態(tài)變量、計(jì)算令牌。代碼類(lèi)似這種
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)
local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local intval = tonumber(ARGV[5])
local fill_time = capacity/rate
local ttl = math.floor(fill_time*2) * intval
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
last_tokens = capacity
end
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
last_refreshed = 0
end
local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
if allowed then
new_tokens = filled_tokens - requested
end
redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)
return { allowed, new_tokens }