From 338139ad0f650ce201eb1f9f51bb770f5ecc23e4 Mon Sep 17 00:00:00 2001 From: Kioubit Date: Wed, 22 Dec 2021 06:04:00 -0500 Subject: [PATCH] Release all ressources on shutdown --- config.go | 4 +++- main.go | 48 +++++++++++++++++++++++------------------ process.go | 60 ++++++++++++++++++++++++++++++++++++++-------------- rawsocket.go | 7 ++++-- responder.go | 9 +++++++- 5 files changed, 87 insertions(+), 41 deletions(-) diff --git a/config.go b/config.go index 846b747..5fe2b4a 100644 --- a/config.go +++ b/config.go @@ -39,6 +39,7 @@ func readConfig(dest string) { } if strings.HasPrefix(line, "responder") { obj := configResponder{} + filter := "" for { scanner.Scan() line = scanner.Text() @@ -46,9 +47,10 @@ func readConfig(dest string) { obj.Iface = strings.TrimSpace(strings.TrimPrefix(line, "iface")) } if strings.HasPrefix(line, "filter") { - obj.Filter = strings.TrimSpace(strings.TrimPrefix(line, "filter")) + filter += strings.TrimSpace(strings.TrimPrefix(line, "filter")) + ";" } if strings.HasPrefix(line, "}") { + obj.Filter = filter break } } diff --git a/main.go b/main.go index de4f9b4..08a2a8c 100644 --- a/main.go +++ b/main.go @@ -6,28 +6,34 @@ import ( ) func main() { - fmt.Println("PNDPD Version 0.3 by Kioubit") + fmt.Println("PNDPD Version 0.4 by Kioubit") + + if len(os.Args) <= 2 { + printUsage() + return + } + + switch os.Args[1] { + case "respond": + if len(os.Args) == 4 { + go simpleRespond(os.Args[2], parseFilter(os.Args[3])) + } else { + go simpleRespond(os.Args[2], nil) + } + case "proxy": + go proxy(os.Args[2], os.Args[3]) + case "readconfig": + readConfig(os.Args[2]) + default: + printUsage() + return + } + waitForSignal() +} + +func printUsage() { + fmt.Println("Specify command") fmt.Println("Usage: pndpd readconfig ") fmt.Println("Usage: pndpd respond ") fmt.Println("Usage: pndpd proxy ") - - if len(os.Args) <= 1 { - fmt.Println("Specify command") - os.Exit(1) - } - if os.Args[1] == "respond" { - if len(os.Args) == 4 { - simpleRespond(os.Args[2], parseFilter(os.Args[3])) - } else { - simpleRespond(os.Args[2], nil) - } - } - if os.Args[1] == "proxy" { - proxy(os.Args[2], os.Args[3]) - } - - if os.Args[1] == "readConfig" { - readConfig(os.Args[2]) - } - } diff --git a/process.go b/process.go index 0ab9a87..3cae951 100644 --- a/process.go +++ b/process.go @@ -5,28 +5,63 @@ import ( "net" "os" "os/signal" + "runtime/pprof" "strings" + "sync" "syscall" + "time" ) var GlobalDebug = false +// Items needed for graceful shutdown +var stop = make(chan struct{}) +var stopWg sync.WaitGroup +var sigCh = make(chan os.Signal) + +func waitForSignal() { + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + <-sigCh + fmt.Println("Shutting down...") + close(stop) + if wgWaitTimout(&stopWg, 10*time.Second) { + fmt.Println("Done") + } else { + fmt.Println("Aborting shutdown, since it is taking too long") + pprof.Lookup("goroutine").WriteTo(os.Stdout, 1) + } + + os.Exit(0) +} + +func wgWaitTimout(wg *sync.WaitGroup, timeout time.Duration) bool { + t := make(chan struct{}) + go func() { + defer close(t) + wg.Wait() + }() + select { + case <-t: + return true + case <-time.After(timeout): + return false + } +} + func simpleRespond(iface string, filter []*net.IPNet) { + defer stopWg.Done() + stopWg.Add(3) // This function, 2x goroutines requests := make(chan *NDRequest, 100) defer close(requests) go respond(iface, requests, NDP_ADV, filter) go listen(iface, requests, NDP_SOL) - - sigCh := make(chan os.Signal) - signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) - select { - case <-sigCh: - fmt.Println("Exit") - os.Exit(0) - } + <-stop } func proxy(iface1, iface2 string) { + defer stopWg.Done() + stopWg.Add(9) // This function, 8x goroutines + req_iface1_sol_iface2 := make(chan *NDRequest, 100) defer close(req_iface1_sol_iface2) go listen(iface1, req_iface1_sol_iface2, NDP_SOL) @@ -46,14 +81,7 @@ func proxy(iface1, iface2 string) { defer close(req_iface2_adv_iface1) go listen(iface2, req_iface2_adv_iface1, NDP_ADV) go respond(iface1, req_iface2_adv_iface1, NDP_ADV, nil) - - sigCh := make(chan os.Signal) - signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) - select { - case <-sigCh: - fmt.Println("Exit") - os.Exit(0) - } + <-stop } func parseFilter(f string) []*net.IPNet { diff --git a/rawsocket.go b/rawsocket.go index dd8e824..d6dfe5d 100644 --- a/rawsocket.go +++ b/rawsocket.go @@ -41,7 +41,6 @@ func htons(v uint16) int { func htons16(v uint16) uint16 { return v<<8 | v>>8 } func listen(iface string, responder chan *NDRequest, requestType NDPType) { - niface, err := net.InterfaceByName(iface) if err != nil { panic(err.Error()) @@ -55,7 +54,11 @@ func listen(iface string, responder chan *NDRequest, requestType NDPType) { if err != nil { fmt.Println(err.Error()) } - defer syscall.Close(fd) + go func() { + <-stop + syscall.Close(fd) + stopWg.Done() // syscall.read does not release when the file descriptor is closed + }() fmt.Println("Obtained fd ", fd) if len([]byte(iface)) > syscall.IFNAMSIZ { diff --git a/responder.go b/responder.go index b4b844b..47819b0 100644 --- a/responder.go +++ b/responder.go @@ -10,6 +10,7 @@ import ( var globalFd int func respond(iface string, requests chan *NDRequest, respondType NDPType, filter []*net.IPNet) { + defer stopWg.Done() fd, err := syscall.Socket(syscall.AF_INET6, syscall.SOCK_RAW, syscall.IPPROTO_RAW) if err != nil { panic(err) @@ -46,7 +47,13 @@ func respond(iface string, requests chan *NDRequest, respondType NDPType, filter } for { - n := <-requests + var n *NDRequest + select { + case <-stop: + return + case n = <-requests: + } + if filter != nil { ok := false for _, i := range filter {