From 0ec08d5ea16b5a26e33fa4cc8541a280166a4d13 Mon Sep 17 00:00:00 2001
From: Luca Bigliardi <shammash@google.com>
Date: Sat, 27 Mar 2021 12:35:38 +0100
Subject: [PATCH] stop storing context received from outside

Signed-off-by: Luca Bigliardi <shammash@google.com>
---
 irc.go      | 41 +++++++++++++++------------------
 irc_test.go | 65 +++++++++++++++++++++++++++++++++--------------------
 main.go     |  4 ++--
 3 files changed, 61 insertions(+), 49 deletions(-)

diff --git a/irc.go b/irc.go
index 773a7d2..ccdad83 100644
--- a/irc.go
+++ b/irc.go
@@ -84,9 +84,6 @@ type IRCNotifier struct {
 	Client       *irc.Conn
 	AlertMsgs    chan AlertMsg
 
-	stopCtx context.Context
-	stopWg  *sync.WaitGroup
-
 	// irc.Conn has a Connected() method that can tell us wether the TCP
 	// connection is up, and thus if we should trigger connect/disconnect.
 	// We need to track the session establishment also at a higher level to
@@ -104,7 +101,7 @@ type IRCNotifier struct {
 	BackoffCounter    Delayer
 }
 
-func NewIRCNotifier(stopCtx context.Context, stopWg *sync.WaitGroup, config *Config, alertMsgs chan AlertMsg, delayerMaker DelayerMaker) (*IRCNotifier, error) {
+func NewIRCNotifier(config *Config, alertMsgs chan AlertMsg, delayerMaker DelayerMaker) (*IRCNotifier, error) {
 
 	ircConfig := makeGOIRCConfig(config)
 
@@ -121,8 +118,6 @@ func NewIRCNotifier(stopCtx context.Context, stopWg *sync.WaitGroup, config *Con
 		NickPassword:      config.IRCNickPass,
 		Client:            client,
 		AlertMsgs:         alertMsgs,
-		stopCtx:           stopCtx,
-		stopWg:            stopWg,
 		sessionUpSignal:   make(chan bool),
 		sessionDownSignal: make(chan bool),
 		channelReconciler: channelReconciler,
@@ -177,7 +172,7 @@ func (n *IRCNotifier) MaybeIdentifyNick() {
 	time.Sleep(n.NickservDelayWait)
 }
 
-func (n *IRCNotifier) ChannelJoined(channel string) bool {
+func (n *IRCNotifier) ChannelJoined(ctx context.Context, channel string) bool {
 
 	isJoined, waitJoined := n.channelReconciler.JoinChannel(channel)
 	if isJoined {
@@ -190,19 +185,19 @@ func (n *IRCNotifier) ChannelJoined(channel string) bool {
 	case <-time.After(ircJoinWaitSecs * time.Second):
 		log.Printf("Channel %s not joined after %d seconds, giving bad news to caller", channel, ircJoinWaitSecs)
 		return false
-	case <-n.stopCtx.Done():
+	case <-ctx.Done():
 		log.Printf("Context canceled while waiting for join on channel %s", channel)
 		return false
 	}
 }
 
-func (n *IRCNotifier) SendAlertMsg(alertMsg *AlertMsg) {
+func (n *IRCNotifier) SendAlertMsg(ctx context.Context, alertMsg *AlertMsg) {
 	if !n.sessionUp {
 		log.Printf("Cannot send alert to %s : IRC not connected", alertMsg.Channel)
 		ircSendMsgErrors.WithLabelValues(alertMsg.Channel, "not_connected").Inc()
 		return
 	}
-	if !n.ChannelJoined(alertMsg.Channel) {
+	if !n.ChannelJoined(ctx, alertMsg.Channel) {
 		log.Printf("Cannot send alert to %s : cannot join channel", alertMsg.Channel)
 		ircSendMsgErrors.WithLabelValues(alertMsg.Channel, "not_joined").Inc()
 		return
@@ -232,27 +227,27 @@ func (n *IRCNotifier) ShutdownPhase() {
 	}
 }
 
-func (n *IRCNotifier) ConnectedPhase() {
+func (n *IRCNotifier) ConnectedPhase(ctx context.Context) {
 	select {
 	case alertMsg := <-n.AlertMsgs:
-		n.SendAlertMsg(&alertMsg)
+		n.SendAlertMsg(ctx, &alertMsg)
 	case <-n.sessionDownSignal:
 		n.sessionUp = false
 		n.channelReconciler.Stop()
 		n.Client.Quit("see ya")
 		ircConnectedGauge.Set(0)
-	case <-n.stopCtx.Done():
+	case <-ctx.Done():
 		log.Printf("IRC routine asked to terminate")
 	}
 }
 
-func (n *IRCNotifier) SetupPhase() {
+func (n *IRCNotifier) SetupPhase(ctx context.Context) {
 	if !n.Client.Connected() {
 		log.Printf("Connecting to IRC %s", n.Client.Config().Server)
-		if ok := n.BackoffCounter.DelayContext(n.stopCtx); !ok {
+		if ok := n.BackoffCounter.DelayContext(ctx); !ok {
 			return
 		}
-		if err := n.Client.ConnectContext(n.stopCtx); err != nil {
+		if err := n.Client.ConnectContext(ctx); err != nil {
 			log.Printf("Could not connect to IRC: %s", err)
 			return
 		}
@@ -262,23 +257,23 @@ func (n *IRCNotifier) SetupPhase() {
 	case <-n.sessionUpSignal:
 		n.sessionUp = true
 		n.MaybeIdentifyNick()
-		n.channelReconciler.Start(n.stopCtx)
+		n.channelReconciler.Start(ctx)
 		ircConnectedGauge.Set(1)
 	case <-n.sessionDownSignal:
 		log.Printf("Receiving a session down before the session is up, this is odd")
-	case <-n.stopCtx.Done():
+	case <-ctx.Done():
 		log.Printf("IRC routine asked to terminate")
 	}
 }
 
-func (n *IRCNotifier) Run() {
-	defer n.stopWg.Done()
+func (n *IRCNotifier) Run(ctx context.Context, stopWg *sync.WaitGroup) {
+	defer stopWg.Done()
 
-	for n.stopCtx.Err() != context.Canceled {
+	for ctx.Err() != context.Canceled {
 		if !n.sessionUp {
-			n.SetupPhase()
+			n.SetupPhase(ctx)
 		} else {
-			n.ConnectedPhase()
+			n.ConnectedPhase(ctx)
 		}
 	}
 	n.ShutdownPhase()
diff --git a/irc_test.go b/irc_test.go
index 9f22df8..97b36b4 100644
--- a/irc_test.go
+++ b/irc_test.go
@@ -42,26 +42,26 @@ func makeTestIRCConfig(IRCPort int) *Config {
 	}
 }
 
-func makeTestNotifier(t *testing.T, config *Config) (*IRCNotifier, chan AlertMsg, context.CancelFunc, *sync.WaitGroup) {
+func makeTestNotifier(t *testing.T, config *Config) (*IRCNotifier, chan AlertMsg, context.Context, context.CancelFunc, *sync.WaitGroup) {
 	fakeDelayerMaker := &FakeDelayerMaker{}
 	alertMsgs := make(chan AlertMsg)
 	ctx, cancel := context.WithCancel(context.Background())
 	stopWg := sync.WaitGroup{}
 	stopWg.Add(1)
-	notifier, err := NewIRCNotifier(ctx, &stopWg, config, alertMsgs, fakeDelayerMaker)
+	notifier, err := NewIRCNotifier(config, alertMsgs, fakeDelayerMaker)
 	if err != nil {
 		t.Fatal(fmt.Sprintf("Could not create IRC notifier: %s", err))
 	}
 	notifier.Client.Config().Flood = true
 
-	return notifier, alertMsgs, cancel, &stopWg
+	return notifier, alertMsgs, ctx, cancel, &stopWg
 }
 
 func TestServerPassword(t *testing.T) {
 	server, port := makeTestServer(t)
 	config := makeTestIRCConfig(port)
 	config.IRCHostPass = "hostsecret"
-	notifier, _, cancel, _ := makeTestNotifier(t, config)
+	notifier, _, ctx, cancel, stopWg := makeTestNotifier(t, config)
 
 	var testStep sync.WaitGroup
 
@@ -72,11 +72,13 @@ func TestServerPassword(t *testing.T) {
 	server.SetHandler("JOIN", joinHandler)
 
 	testStep.Add(1)
-	go notifier.Run()
+	go notifier.Run(ctx, stopWg)
 
 	testStep.Wait()
 
 	cancel()
+	stopWg.Wait()
+
 	server.Stop()
 
 	expectedCommands := []string{
@@ -95,7 +97,7 @@ func TestServerPassword(t *testing.T) {
 func TestSendAlertOnPreJoinedChannel(t *testing.T) {
 	server, port := makeTestServer(t)
 	config := makeTestIRCConfig(port)
-	notifier, alertMsgs, cancel, _ := makeTestNotifier(t, config)
+	notifier, alertMsgs, ctx, cancel, stopWg := makeTestNotifier(t, config)
 
 	var testStep sync.WaitGroup
 
@@ -113,7 +115,7 @@ func TestSendAlertOnPreJoinedChannel(t *testing.T) {
 	server.SetHandler("JOIN", joinedHandler)
 
 	testStep.Add(1)
-	go notifier.Run()
+	go notifier.Run(ctx, stopWg)
 
 	testStep.Wait()
 
@@ -131,6 +133,8 @@ func TestSendAlertOnPreJoinedChannel(t *testing.T) {
 	testStep.Wait()
 
 	cancel()
+	stopWg.Wait()
+
 	server.Stop()
 
 	expectedCommands := []string{
@@ -150,7 +154,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) {
 	server, port := makeTestServer(t)
 	config := makeTestIRCConfig(port)
 	config.UsePrivmsg = true
-	notifier, alertMsgs, cancel, _ := makeTestNotifier(t, config)
+	notifier, alertMsgs, ctx, cancel, stopWg := makeTestNotifier(t, config)
 
 	var testStep sync.WaitGroup
 
@@ -168,7 +172,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) {
 	server.SetHandler("JOIN", joinedHandler)
 
 	testStep.Add(1)
-	go notifier.Run()
+	go notifier.Run(ctx, stopWg)
 
 	testStep.Wait()
 
@@ -186,6 +190,8 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) {
 	testStep.Wait()
 
 	cancel()
+	stopWg.Wait()
+
 	server.Stop()
 
 	expectedCommands := []string{
@@ -204,7 +210,7 @@ func TestUsePrivmsgToSendAlertOnPreJoinedChannel(t *testing.T) {
 func TestSendAlertAndJoinChannel(t *testing.T) {
 	server, port := makeTestServer(t)
 	config := makeTestIRCConfig(port)
-	notifier, alertMsgs, cancel, _ := makeTestNotifier(t, config)
+	notifier, alertMsgs, ctx, cancel, stopWg := makeTestNotifier(t, config)
 
 	var testStep sync.WaitGroup
 
@@ -220,7 +226,7 @@ func TestSendAlertAndJoinChannel(t *testing.T) {
 	server.SetHandler("JOIN", joinHandler)
 
 	testStep.Add(1)
-	go notifier.Run()
+	go notifier.Run(ctx, stopWg)
 
 	testStep.Wait()
 
@@ -238,6 +244,8 @@ func TestSendAlertAndJoinChannel(t *testing.T) {
 	testStep.Wait()
 
 	cancel()
+	stopWg.Wait()
+
 	server.Stop()
 
 	expectedCommands := []string{
@@ -258,7 +266,7 @@ func TestSendAlertAndJoinChannel(t *testing.T) {
 func TestSendAlertDisconnected(t *testing.T) {
 	server, port := makeTestServer(t)
 	config := makeTestIRCConfig(port)
-	notifier, alertMsgs, cancel, _ := makeTestNotifier(t, config)
+	notifier, alertMsgs, ctx, cancel, stopWg := makeTestNotifier(t, config)
 
 	var testStep, holdUserStep sync.WaitGroup
 
@@ -278,7 +286,7 @@ func TestSendAlertDisconnected(t *testing.T) {
 	}
 	server.SetHandler("USER", holdUser)
 
-	go notifier.Run()
+	go notifier.Run(ctx, stopWg)
 
 	// Alert channels is not consumed while disconnected
 	select {
@@ -314,6 +322,8 @@ func TestSendAlertDisconnected(t *testing.T) {
 	testStep.Wait()
 
 	cancel()
+	stopWg.Wait()
+
 	server.Stop()
 
 	expectedCommands := []string{
@@ -333,7 +343,7 @@ func TestSendAlertDisconnected(t *testing.T) {
 func TestReconnect(t *testing.T) {
 	server, port := makeTestServer(t)
 	config := makeTestIRCConfig(port)
-	notifier, _, cancel, _ := makeTestNotifier(t, config)
+	notifier, _, ctx, cancel, stopWg := makeTestNotifier(t, config)
 
 	var testStep sync.WaitGroup
 
@@ -344,7 +354,7 @@ func TestReconnect(t *testing.T) {
 	server.SetHandler("JOIN", joinHandler)
 
 	testStep.Add(1)
-	go notifier.Run()
+	go notifier.Run(ctx, stopWg)
 
 	// Wait until the pre-joined channel is seen.
 	testStep.Wait()
@@ -357,6 +367,8 @@ func TestReconnect(t *testing.T) {
 	testStep.Wait()
 
 	cancel()
+	stopWg.Wait()
+
 	server.Stop()
 
 	expectedCommands := []string{
@@ -382,7 +394,7 @@ func TestConnectErrorRetry(t *testing.T) {
 	// Attempt SSL handshake. The server does not support it, resulting in
 	// a connection error.
 	config.IRCUseSSL = true
-	notifier, _, cancel, _ := makeTestNotifier(t, config)
+	notifier, _, ctx, cancel, stopWg := makeTestNotifier(t, config)
 	// Pilot reconnect attempts via backoff delay to prevent race
 	// conditions in the test while we change the components behavior on
 	// the fly.
@@ -398,7 +410,7 @@ func TestConnectErrorRetry(t *testing.T) {
 
 	server.SetCloseEarly(earlyHandler)
 
-	go notifier.Run()
+	go notifier.Run(ctx, stopWg)
 
 	delayer.StopDelay <- true
 
@@ -419,6 +431,8 @@ func TestConnectErrorRetry(t *testing.T) {
 	joinStep.Wait()
 
 	cancel()
+	stopWg.Wait()
+
 	server.Stop()
 
 	expectedCommands := []string{
@@ -437,7 +451,7 @@ func TestIdentify(t *testing.T) {
 	server, port := makeTestServer(t)
 	config := makeTestIRCConfig(port)
 	config.IRCNickPass = "nickpassword"
-	notifier, _, cancel, _ := makeTestNotifier(t, config)
+	notifier, _, ctx, cancel, stopWg := makeTestNotifier(t, config)
 	notifier.NickservDelayWait = 0 * time.Second
 
 	var testStep sync.WaitGroup
@@ -451,11 +465,13 @@ func TestIdentify(t *testing.T) {
 	server.SetHandler("JOIN", joinHandler)
 
 	testStep.Add(1)
-	go notifier.Run()
+	go notifier.Run(ctx, stopWg)
 
 	testStep.Wait()
 
 	cancel()
+	stopWg.Wait()
+
 	server.Stop()
 
 	expectedCommands := []string{
@@ -475,7 +491,7 @@ func TestGhostAndIdentify(t *testing.T) {
 	server, port := makeTestServer(t)
 	config := makeTestIRCConfig(port)
 	config.IRCNickPass = "nickpassword"
-	notifier, _, cancel, _ := makeTestNotifier(t, config)
+	notifier, _, ctx, cancel, stopWg := makeTestNotifier(t, config)
 	notifier.NickservDelayWait = 0 * time.Second
 
 	var testStep, usedNick, unregisteredNickHandler sync.WaitGroup
@@ -503,7 +519,7 @@ func TestGhostAndIdentify(t *testing.T) {
 	server.SetHandler("JOIN", joinHandler)
 
 	testStep.Add(1)
-	go notifier.Run()
+	go notifier.Run(ctx, stopWg)
 
 	usedNick.Wait()
 	server.SetHandler("NICK", nil)
@@ -512,6 +528,8 @@ func TestGhostAndIdentify(t *testing.T) {
 	testStep.Wait()
 
 	cancel()
+	stopWg.Wait()
+
 	server.Stop()
 
 	expectedCommands := []string{
@@ -533,7 +551,7 @@ func TestGhostAndIdentify(t *testing.T) {
 func TestStopRunningWhenHalfConnected(t *testing.T) {
 	server, port := makeTestServer(t)
 	config := makeTestIRCConfig(port)
-	notifier, _, cancel, stopWg := makeTestNotifier(t, config)
+	notifier, _, ctx, cancel, stopWg := makeTestNotifier(t, config)
 
 	var testStep, holdQuitWait sync.WaitGroup
 
@@ -556,12 +574,11 @@ func TestStopRunningWhenHalfConnected(t *testing.T) {
 	}
 	server.SetHandler("QUIT", holdQuit)
 
-	go notifier.Run()
+	go notifier.Run(ctx, stopWg)
 
 	testStep.Wait()
 
 	cancel()
-
 	stopWg.Wait()
 
 	holdQuitWait.Wait()
diff --git a/main.go b/main.go
index 8a318dc..1a1350d 100644
--- a/main.go
+++ b/main.go
@@ -59,12 +59,12 @@ func main() {
 	alertMsgs := make(chan AlertMsg, config.AlertBufferSize)
 
 	stopWg.Add(1)
-	ircNotifier, err := NewIRCNotifier(ctx, &stopWg, config, alertMsgs, &BackoffMaker{})
+	ircNotifier, err := NewIRCNotifier(config, alertMsgs, &BackoffMaker{})
 	if err != nil {
 		log.Printf("Could not create IRC notifier: %s", err)
 		return
 	}
-	go ircNotifier.Run()
+	go ircNotifier.Run(ctx, &stopWg)
 
 	httpServer, err := NewHTTPServer(config, alertMsgs)
 	if err != nil {