diff --git a/backoff.go b/backoff.go index 5873365..90492d2 100644 --- a/backoff.go +++ b/backoff.go @@ -15,6 +15,7 @@ package main import ( + "context" "log" "math" "math/rand" @@ -23,10 +24,9 @@ import ( type JitterFunc func(int) int -type TimeFunc func() time.Time - type Delayer interface { Delay() + DelayContext(context.Context) bool } type Backoff struct { @@ -36,7 +36,7 @@ type Backoff struct { lastAttempt time.Time durationUnit time.Duration jitterer JitterFunc - timeGetter TimeFunc + timeTeller TimeTeller } func jitterFunc(input int) int { @@ -46,27 +46,44 @@ func jitterFunc(input int) int { return rand.Intn(input) } +// TimeTeller interface allows injection of fake time during testing +type TimeTeller interface { + Now() time.Time + After(time.Duration) <-chan time.Time +} + +type RealTime struct{} + +func (r *RealTime) Now() time.Time { + return time.Now() +} + +func (r *RealTime) After(d time.Duration) <-chan time.Time { + return time.After(d) +} + func NewBackoff(maxBackoff float64, resetDelta float64, durationUnit time.Duration) *Backoff { + timeTeller := &RealTime{} return NewBackoffForTesting( - maxBackoff, resetDelta, durationUnit, jitterFunc, time.Now) + maxBackoff, resetDelta, durationUnit, jitterFunc, timeTeller) } func NewBackoffForTesting(maxBackoff float64, resetDelta float64, - durationUnit time.Duration, jitterer JitterFunc, timeGetter TimeFunc) *Backoff { + durationUnit time.Duration, jitterer JitterFunc, timeTeller TimeTeller) *Backoff { return &Backoff{ step: 0, maxBackoff: maxBackoff, resetDelta: resetDelta, - lastAttempt: timeGetter(), + lastAttempt: timeTeller.Now(), durationUnit: durationUnit, jitterer: jitterer, - timeGetter: timeGetter, + timeTeller: timeTeller, } } func (b *Backoff) maybeReset() { - now := b.timeGetter() + now := b.timeTeller.Now() lastAttemptDelta := float64(now.Sub(b.lastAttempt) / b.durationUnit) b.lastAttempt = now @@ -96,7 +113,18 @@ func (b *Backoff) GetDelay() time.Duration { } func (b *Backoff) Delay() { - delay := b.GetDelay() - log.Printf("Backoff for %s", delay) - time.Sleep(delay) + b.DelayContext(context.Background()) +} + +func (b *Backoff) DelayContext(ctx context.Context) bool { + delay := b.GetDelay() + log.Printf("Backoff for %s starts", delay) + select { + case <-b.timeTeller.After(delay): + log.Printf("Backoff for %s ends", delay) + case <-ctx.Done(): + log.Printf("Backoff for %s canceled by context", delay) + return false + } + return true } diff --git a/backoff_test.go b/backoff_test.go index f6448e7..f262ea3 100644 --- a/backoff_test.go +++ b/backoff_test.go @@ -15,6 +15,7 @@ package main import ( + "context" "testing" "time" ) @@ -23,29 +24,39 @@ type FakeTime struct { timeseries []int lastIndex int durationUnit time.Duration + afterChan chan time.Time } -func (f *FakeTime) GetTime() time.Time { +func (f *FakeTime) Now() time.Time { timeDelta := time.Duration(f.timeseries[f.lastIndex]) * f.durationUnit fakeTime := time.Unix(0, 0).Add(timeDelta) f.lastIndex++ return fakeTime } +func (f *FakeTime) After(d time.Duration) <-chan time.Time { + return f.afterChan +} + func FakeJitter(input int) int { return input } -func RunBackoffTest(t *testing.T, - maxBackoff float64, resetDelta float64, - elapsedTime []int, expectedDelays []int) { +func MakeTestingBackoff(maxBackoff float64, resetDelta float64, elapsedTime []int) (*Backoff, *FakeTime) { fakeTime := &FakeTime{ timeseries: elapsedTime, lastIndex: 0, durationUnit: time.Millisecond, + afterChan: make(chan time.Time, 1), } backoff := NewBackoffForTesting(maxBackoff, resetDelta, time.Millisecond, - FakeJitter, fakeTime.GetTime) + FakeJitter, fakeTime) + return backoff, fakeTime +} + +func RunBackoffTest(t *testing.T, maxBackoff float64, resetDelta float64, elapsedTime []int, expectedDelays []int) { + + backoff, _ := MakeTestingBackoff(maxBackoff, resetDelta, elapsedTime) for i, value := range expectedDelays { expected_delay := time.Duration(value) * time.Millisecond @@ -78,3 +89,19 @@ func TestBackoffReset(t *testing.T) { []int{0, 2, 4, 0, 2, 0, 2, 4}, ) } + +func TestBackoffDelayContext(t *testing.T) { + backoff, fakeTime := MakeTestingBackoff(8, 32, []int{0, 0, 0}) + + ctx, cancel := context.WithCancel(context.Background()) + + fakeTime.afterChan <- time.Now() + if ok := backoff.DelayContext(ctx); !ok { + t.Errorf("Expired time does not return true") + } + + cancel() + if ok := backoff.DelayContext(ctx); ok { + t.Errorf("Canceled context does not return false") + } +} diff --git a/irc.go b/irc.go index f30f2e6..a13da13 100644 --- a/irc.go +++ b/irc.go @@ -164,7 +164,9 @@ func (notifier *IRCNotifier) HandleKick(nick string, channel string) { } log.Printf("Being kicked out of %s, re-joining", channel) go func() { - state.BackoffCounter.Delay() + if ok := state.BackoffCounter.DelayContext(notifier.ctx); !ok { + return + } notifier.Client.Join(state.Channel.Name, state.Channel.Password) }() @@ -242,7 +244,9 @@ func (notifier *IRCNotifier) Run() { for notifier.ctx.Err() != context.Canceled { if !notifier.Client.Connected() { log.Printf("Connecting to IRC %s", notifier.Client.Config().Server) - notifier.BackoffCounter.Delay() + if ok := notifier.BackoffCounter.DelayContext(notifier.ctx); !ok { + continue + } if err := notifier.Client.Connect(); err != nil { log.Printf("Could not connect to IRC: %s", err) continue diff --git a/irc_test.go b/irc_test.go index 15d608d..d6cc9c2 100644 --- a/irc_test.go +++ b/irc_test.go @@ -202,6 +202,11 @@ func (f *FakeDelayer) Delay() { log.Printf("Faking Backoff") } +func (f *FakeDelayer) DelayContext(ctx context.Context) bool { + log.Printf("Faking Backoff") + return true +} + func makeTestIRCConfig(IRCPort int) *Config { return &Config{ IRCNick: "foo",