From 0b2fbef1f23d16be61c6f6c51043b5cc6a9e0324 Mon Sep 17 00:00:00 2001 From: Luca Bigliardi Date: Sat, 27 Mar 2021 00:49:16 +0100 Subject: [PATCH] new channel management logic this should handle bans and kicks a bit better Signed-off-by: Luca Bigliardi --- irc.go | 36 ++++++-- irc_test.go | 22 ++--- irc_testserver.go | 7 ++ reconciler.go | 224 +++++++++++++++++++++++++++++++++++++++------ reconciler_test.go | 25 +++-- 5 files changed, 255 insertions(+), 59 deletions(-) diff --git a/irc.go b/irc.go index efd2409..773a7d2 100644 --- a/irc.go +++ b/irc.go @@ -177,14 +177,36 @@ func (n *IRCNotifier) MaybeIdentifyNick() { time.Sleep(n.NickservDelayWait) } -func (n *IRCNotifier) MaybeSendAlertMsg(alertMsg *AlertMsg) { +func (n *IRCNotifier) ChannelJoined(channel string) bool { + + isJoined, waitJoined := n.channelReconciler.JoinChannel(channel) + if isJoined { + return true + } + + select { + case <-waitJoined: + return true + case <-time.After(ircJoinWaitSecs * time.Second): + log.Printf("Channel %s not joined after %d seconds, giving bad news to caller", channel, ircJoinWaitSecs) + return false + case <-n.stopCtx.Done(): + log.Printf("Context canceled while waiting for join on channel %s", channel) + return false + } +} + +func (n *IRCNotifier) SendAlertMsg(alertMsg *AlertMsg) { if !n.sessionUp { - log.Printf("Cannot send alert to %s : IRC not connected", - alertMsg.Channel) + log.Printf("Cannot send alert to %s : IRC not connected", alertMsg.Channel) ircSendMsgErrors.WithLabelValues(alertMsg.Channel, "not_connected").Inc() return } - n.channelReconciler.JoinChannel(&IRCChannel{Name: alertMsg.Channel}) + if !n.ChannelJoined(alertMsg.Channel) { + log.Printf("Cannot send alert to %s : cannot join channel", alertMsg.Channel) + ircSendMsgErrors.WithLabelValues(alertMsg.Channel, "not_joined").Inc() + return + } if n.UsePrivmsg { n.Client.Privmsg(alertMsg.Channel, alertMsg.Alert) @@ -213,10 +235,10 @@ func (n *IRCNotifier) ShutdownPhase() { func (n *IRCNotifier) ConnectedPhase() { select { case alertMsg := <-n.AlertMsgs: - n.MaybeSendAlertMsg(&alertMsg) + n.SendAlertMsg(&alertMsg) case <-n.sessionDownSignal: n.sessionUp = false - n.channelReconciler.CleanupChannels() + n.channelReconciler.Stop() n.Client.Quit("see ya") ircConnectedGauge.Set(0) case <-n.stopCtx.Done(): @@ -240,7 +262,7 @@ func (n *IRCNotifier) SetupPhase() { case <-n.sessionUpSignal: n.sessionUp = true n.MaybeIdentifyNick() - n.channelReconciler.JoinChannels() + n.channelReconciler.Start(n.stopCtx) ircConnectedGauge.Set(1) case <-n.sessionDownSignal: log.Printf("Receiving a session down before the session is up, this is odd") diff --git a/irc_test.go b/irc_test.go index c30fabe..9f22df8 100644 --- a/irc_test.go +++ b/irc_test.go @@ -108,7 +108,7 @@ func TestSendAlertOnPreJoinedChannel(t *testing.T) { if line.Args[0] == testChannel { testStep.Done() } - return nil + return hJOIN(conn, line) } server.SetHandler("JOIN", joinedHandler) @@ -117,7 +117,7 @@ func TestSendAlertOnPreJoinedChannel(t *testing.T) { testStep.Wait() - server.SetHandler("JOIN", nil) + server.SetHandler("JOIN", hJOIN) noticeHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { testStep.Done() @@ -163,7 +163,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) { if line.Args[0] == testChannel { testStep.Done() } - return nil + return hJOIN(conn, line) } server.SetHandler("JOIN", joinedHandler) @@ -172,7 +172,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) { testStep.Wait() - server.SetHandler("JOIN", nil) + server.SetHandler("JOIN", hJOIN) privmsgHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { testStep.Done() @@ -215,7 +215,7 @@ func TestSendAlertAndJoinChannel(t *testing.T) { // ordering. joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { testStep.Done() - return nil + return hJOIN(conn, line) } server.SetHandler("JOIN", joinHandler) @@ -224,7 +224,7 @@ func TestSendAlertAndJoinChannel(t *testing.T) { testStep.Wait() - server.SetHandler("JOIN", nil) + server.SetHandler("JOIN", hJOIN) noticeHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { testStep.Done() @@ -295,7 +295,7 @@ func TestSendAlertDisconnected(t *testing.T) { testStep.Add(1) joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { testStep.Done() - return nil + return hJOIN(conn, line) } server.SetHandler("JOIN", joinHandler) @@ -339,7 +339,7 @@ func TestReconnect(t *testing.T) { joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { testStep.Done() - return nil + return hJOIN(conn, line) } server.SetHandler("JOIN", joinHandler) @@ -409,7 +409,7 @@ func TestConnectErrorRetry(t *testing.T) { joinStep.Add(1) joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { joinStep.Done() - return nil + return hJOIN(conn, line) } server.SetHandler("JOIN", joinHandler) server.SetCloseEarly(nil) @@ -446,7 +446,7 @@ func TestIdentify(t *testing.T) { // after identification). joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { testStep.Done() - return nil + return hJOIN(conn, line) } server.SetHandler("JOIN", joinHandler) @@ -498,7 +498,7 @@ func TestGhostAndIdentify(t *testing.T) { // after identification). joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { testStep.Done() - return nil + return hJOIN(conn, line) } server.SetHandler("JOIN", joinHandler) diff --git a/irc_testserver.go b/irc_testserver.go index 5aca144..80791af 100644 --- a/irc_testserver.go +++ b/irc_testserver.go @@ -29,6 +29,12 @@ import ( type LineHandlerFunc func(*bufio.ReadWriter, *irc.Line) error +func hJOIN(conn *bufio.ReadWriter, line *irc.Line) error { + r := fmt.Sprintf(":foo!foo@example.com JOIN :%s\n", line.Args[0]) + _, err := conn.WriteString(r) + return err +} + func hUSER(conn *bufio.ReadWriter, line *irc.Line) error { r := fmt.Sprintf(":example.com 001 %s :Welcome\n", line.Args[0]) _, err := conn.WriteString(r) @@ -61,6 +67,7 @@ func (s *testServer) setDefaultHandlers() { if s.lineHandlers == nil { s.lineHandlers = make(map[string]LineHandlerFunc) } + s.lineHandlers["JOIN"] = hJOIN s.lineHandlers["USER"] = hUSER s.lineHandlers["QUIT"] = hQUIT } diff --git a/reconciler.go b/reconciler.go index 4caf136..d05eafc 100644 --- a/reconciler.go +++ b/reconciler.go @@ -23,9 +23,123 @@ import ( irc "github.com/fluffle/goirc/client" ) +const ( + ircJoinWaitSecs = 10 + ircJoinMaxBackoffSecs = 300 + ircJoinBackoffResetSecs = 1800 +) + type channelState struct { - Channel IRCChannel - BackoffCounter Delayer + channel IRCChannel + client *irc.Conn + delayer Delayer + + joinDone chan struct{} // joined when channel is closed + joined bool + + joinUnsetSignal chan bool + + mu sync.Mutex +} + +func newChannelState(channel *IRCChannel, client *irc.Conn, delayerMaker DelayerMaker) *channelState { + delayer := delayerMaker.NewDelayer(ircJoinMaxBackoffSecs, ircJoinBackoffResetSecs, time.Second) + + return &channelState{ + channel: *channel, + client: client, + delayer: delayer, + joinDone: make(chan struct{}), + joined: false, + joinUnsetSignal: make(chan bool), + } +} + +func (c *channelState) JoinDone() <-chan struct{} { + c.mu.Lock() + defer c.mu.Unlock() + + return c.joinDone +} + +func (c *channelState) SetJoined() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.joined == true { + log.Printf("Not setting JOIN state on channel %s: already set", c.channel.Name) + return + } + + log.Printf("Setting JOIN state on channel %s", c.channel.Name) + c.joined = true + close(c.joinDone) +} + +func (c *channelState) UnsetJoined() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.joined == false { + log.Printf("Not removing JOIN state on channel %s: already not set", c.channel.Name) + return + } + + log.Printf("Removing JOIN state on channel %s", c.channel.Name) + c.joined = false + c.joinDone = make(chan struct{}) + + // eventually poke monitor routine + select { + case c.joinUnsetSignal <- true: + default: + } +} + +func (c *channelState) join(ctx context.Context) { + log.Printf("Channel %s monitor: waiting to join", c.channel.Name) + if ok := c.delayer.DelayContext(ctx); !ok { + return + } + + c.client.Join(c.channel.Name, c.channel.Password) + log.Printf("Channel %s monitor: join request sent", c.channel.Name) + + select { + case <-c.JoinDone(): + log.Printf("Channel %s monitor: join succeeded", c.channel.Name) + case <-time.After(ircJoinWaitSecs * time.Second): + log.Printf("Channel %s monitor: could not join after %d seconds, will retry", c.channel.Name, ircJoinWaitSecs) + case <-ctx.Done(): + log.Printf("Channel %s monitor: context canceled while waiting for join", c.channel.Name) + } +} + +func (c *channelState) monitorJoinUnset(ctx context.Context) { + select { + case <-c.joinUnsetSignal: + log.Printf("Channel %s monitor: channel no longer joined", c.channel.Name) + case <-ctx.Done(): + log.Printf("Channel %s monitor: context canceled while monitoring", c.channel.Name) + } +} + +func (c *channelState) Monitor(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + + joined := func() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.joined + } + + for ctx.Err() != context.Canceled { + if !joined() { + c.join(ctx) + } else { + c.monitorJoinUnset(ctx) + } + } } type ChannelReconciler struct { @@ -57,53 +171,107 @@ func NewChannelReconciler(config *Config, client *irc.Conn, delayerMaker Delayer } func (r *ChannelReconciler) registerHandlers() { + r.client.HandleFunc(irc.JOIN, + func(_ *irc.Conn, line *irc.Line) { + r.HandleJoin(line.Nick, line.Args[0]) + }) + r.client.HandleFunc(irc.KICK, func(_ *irc.Conn, line *irc.Line) { r.HandleKick(line.Args[1], line.Args[0]) }) } +func (r *ChannelReconciler) HandleJoin(nick string, channel string) { + r.mu.Lock() + defer r.mu.Unlock() + + if nick != r.client.Me().Nick { + // received join info for somebody else + return + } + log.Printf("Received JOIN confirmation for channel %s", channel) + + c, ok := r.channels[channel] + if !ok { + log.Printf("Not processing JOIN for channel %s: unknown channel", channel) + return + } + c.SetJoined() +} + func (r *ChannelReconciler) HandleKick(nick string, channel string) { + r.mu.Lock() + defer r.mu.Unlock() + if nick != r.client.Me().Nick { // received kick info for somebody else return } - state, ok := r.channels[channel] + log.Printf("Received KICK for channel %s", channel) + + c, ok := r.channels[channel] if !ok { - log.Printf("Being kicked out of non-joined channel (%s), ignoring", channel) + log.Printf("Not processing KICK for channel %s: unknown channel", channel) return } - log.Printf("Being kicked out of %s, re-joining", channel) - go func() { - if ok := state.BackoffCounter.DelayContext(r.stopCtx); !ok { - return - } - r.client.Join(state.Channel.Name, state.Channel.Password) - }() + c.UnsetJoined() } -func (r *ChannelReconciler) CleanupChannels() { - log.Printf("Deregistering all channels.") +func (r *ChannelReconciler) unsafeAddChannel(channel *IRCChannel) *channelState { + c := newChannelState(channel, r.client, r.delayerMaker) + + r.stopWg.Add(1) + go c.Monitor(r.stopCtx, &r.stopWg) + + r.channels[channel.Name] = c + return c +} + +func (r *ChannelReconciler) JoinChannel(channel string) (bool, <-chan struct{}) { + r.mu.Lock() + defer r.mu.Unlock() + + c, ok := r.channels[channel] + if !ok { + log.Printf("Request to JOIN new channel %s", channel) + c = r.unsafeAddChannel(&IRCChannel{Name: channel}) + } + + select { + case <-c.JoinDone(): + return true, nil + default: + return false, c.JoinDone() + } +} + +func (r *ChannelReconciler) unsafeStop() { + if r.stopCtxCancel == nil { + // calling stop before first start, ignoring + return + } + r.stopCtxCancel() + r.stopWg.Wait() r.channels = make(map[string]*channelState) } -func (r *ChannelReconciler) JoinChannel(channel *IRCChannel) { - if _, joined := r.channels[channel.Name]; joined { - return - } - log.Printf("Joining %s", channel.Name) - r.client.Join(channel.Name, channel.Password) - state := &channelState{ - Channel: *channel, - BackoffCounter: r.delayerMaker.NewDelayer( - ircConnectMaxBackoffSecs, ircConnectBackoffResetSecs, - time.Second), - } - r.channels[channel.Name] = state +func (r *ChannelReconciler) Stop() { + r.mu.Lock() + defer r.mu.Unlock() + + r.unsafeStop() } -func (r *ChannelReconciler) JoinChannels() { +func (r *ChannelReconciler) Start(ctx context.Context) { + r.mu.Lock() + defer r.mu.Unlock() + + r.unsafeStop() + + r.stopCtx, r.stopCtxCancel = context.WithCancel(ctx) + for _, channel := range r.preJoinChannels { - r.JoinChannel(&channel) + r.unsafeAddChannel(&channel) } } diff --git a/reconciler_test.go b/reconciler_test.go index 413797b..d9cb09e 100644 --- a/reconciler_test.go +++ b/reconciler_test.go @@ -16,7 +16,9 @@ package main import ( "bufio" + "context" "reflect" + "sort" "sync" "testing" @@ -63,12 +65,14 @@ func TestPreJoinChannels(t *testing.T) { var testStep sync.WaitGroup + joinedChannels := []string{} + joinHandler := func(conn *bufio.ReadWriter, line *irc.Line) error { - // #baz is configured as the last channel to pre-join - if line.Args[0] == "#baz" { + joinedChannels = append(joinedChannels, line.Args[0]) + if len(joinedChannels) == 3 { testStep.Done() } - return nil + return hJOIN(conn, line) } server.SetHandler("JOIN", joinHandler) @@ -77,25 +81,20 @@ func TestPreJoinChannels(t *testing.T) { reconciler.client.Connect() <-sessionUp - reconciler.JoinChannels() + reconciler.Start(context.Background()) testStep.Wait() reconciler.client.Quit("see ya") <-sessionDown + reconciler.Stop() server.Stop() - expectedCommands := []string{ - "NICK foo", - "USER foo 12 * :", - "JOIN #foo", - "JOIN #bar", - "JOIN #baz", - "QUIT :see ya", - } + expectedJoinedChannels := []string{"#bar", "#baz", "#foo"} + sort.Strings(joinedChannels) - if !reflect.DeepEqual(expectedCommands, server.Log) { + if !reflect.DeepEqual(expectedJoinedChannels, joinedChannels) { t.Error("Did not pre-join channels") } }