From 3f796223d14801361a416417c536b8b711aeca73 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 7 Jan 2016 23:16:41 +0000 Subject: [PATCH] add systemd socket activation support --- server.go | 173 ++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 134 insertions(+), 39 deletions(-) diff --git a/server.go b/server.go index 2974643..e36a115 100644 --- a/server.go +++ b/server.go @@ -12,26 +12,57 @@ import ( "path" "regexp" "sort" + "strconv" "strings" + "sync" + "sync/atomic" "syscall" + "time" ) type Server struct { - DataPath string + DataPath string + LastConnection time.Time + SocketActivation bool + stopListening int32 + activeWorkers sync.WaitGroup } -func (s Server) Run(listener *net.TCPListener) { - for { - conn, e := listener.AcceptTCP() - if e != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", e) +func New(dataPath string) *Server { + return &Server{dataPath, time.Now(), false, 0, sync.WaitGroup{}} +} + +func (s *Server) Run(listener *net.TCPListener) { + atomic.StoreInt32(&s.stopListening, 0) + s.activeWorkers.Add(1) + defer s.activeWorkers.Done() + defer listener.Close() + for atomic.LoadInt32(&s.stopListening) != 1 { + if e := listener.SetDeadline(time.Now().Add(time.Second)); e != nil { + fmt.Fprintf(os.Stderr, "Error setting deadline: %v\n", e) continue } + conn, err := listener.AcceptTCP() + if err != nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + continue + } else { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + } + s.activeWorkers.Add(1) + s.LastConnection = time.Now() go s.handleConn(conn) } } +func (s *Server) Shutdown() { + atomic.StoreInt32(&s.stopListening, 1) + s.activeWorkers.Wait() +} + func readCidrs(path string) ([]net.IPNet, error) { files, err := ioutil.ReadDir(path) if err != nil { @@ -118,8 +149,12 @@ func parseQuery(conn *net.TCPConn) map[int]interface{} { return queryArgs } -func (s Server) handleConn(conn *net.TCPConn) { - defer conn.Close() +func (s *Server) handleConn(conn *net.TCPConn) { + defer func() { + conn.Close() + s.activeWorkers.Done() + }() + queryArgs := parseQuery(conn) if queryArgs == nil { return @@ -145,7 +180,7 @@ func (s Server) handleConn(conn *net.TCPConn) { } } -func (s Server) printNet(conn *net.TCPConn, name string, ip net.IP) bool { +func (s *Server) printNet(conn *net.TCPConn, name string, ip net.IP) bool { routePath := path.Join(s.DataPath, name) cidrs, err := readCidrs(routePath) if err != nil { @@ -163,7 +198,7 @@ func (s Server) printNet(conn *net.TCPConn, name string, ip net.IP) bool { return found } -func (s Server) printObject(conn *net.TCPConn, objType string, obj string) { +func (s *Server) printObject(conn *net.TCPConn, objType string, obj string) { f, err := os.Open(path.Join(s.DataPath, objType, obj)) defer f.Close() if err != nil && !os.IsNotExist(err) { @@ -173,11 +208,10 @@ func (s Server) printObject(conn *net.TCPConn, objType string, obj string) { } type options struct { - Port uint - Address string - Registry string - User string - Group string + Port uint + Address string + Registry string + SocketTimeout float64 } func parseFlags() options { @@ -185,6 +219,8 @@ func parseFlags() options { flag.UintVar(&o.Port, "port", 43, "port to listen") flag.StringVar(&o.Address, "address", "*", "address to listen") flag.StringVar(&o.Registry, "registry", ".", "path to dn42 registry") + msg := "timeout in seconds before suspending the service when using socket activation" + flag.Float64Var(&o.SocketTimeout, "timeout", 10, msg) flag.Parse() if o.Address == "*" { o.Address = "" @@ -192,39 +228,98 @@ func parseFlags() options { return o } -func main() { - opts := parseFlags() - registryPath := path.Join(opts.Registry, "data") +func Listeners() []*net.TCPListener { + defer os.Unsetenv("LISTEN_PID") + defer os.Unsetenv("LISTEN_FDS") - if _, err := os.Stat(registryPath); err != nil { - fmt.Fprintf(os.Stderr, - "Cannot access '%s', should be in the registry repository: %s\n", - registryPath, - err) - os.Exit(1) + pid, err := strconv.Atoi(os.Getenv("LISTEN_PID")) + if err != nil || pid != os.Getpid() { + return nil } - address := opts.Address + ":" + fmt.Sprint(opts.Port) - listener, err := net.Listen("tcp", address) + nfds, err := strconv.Atoi(os.Getenv("LISTEN_FDS")) + if err != nil || nfds == 0 { + return nil + } + + listeners := make([]*net.TCPListener, 0) + for fd := 3; fd < 3+nfds; fd++ { + syscall.CloseOnExec(fd) + file := os.NewFile(uintptr(fd), "LISTEN_FD_"+strconv.Itoa(fd)) + if listener, err := net.FileListener(file); err == nil { + if l, ok := listener.(*net.TCPListener); ok { + listeners = append(listeners, l) + } + } + } + + return listeners +} + +func checkDataPath(registry string) (string, error) { + dataPath := path.Join(registry, "data") + + if _, err := os.Stat(dataPath); err != nil { + return "", fmt.Errorf("Cannot access '%s', should be in the registry repository: %s\n", + dataPath, + err) + } + return dataPath, nil +} + +func createServer(opts options) (*Server, error) { + dataPath, err := checkDataPath(opts.Registry) + if err != nil { + return nil, err + } + server := New(dataPath) + + if listeners := Listeners(); len(listeners) > 0 { + fmt.Printf("socket action detected\n") + server.SocketActivation = true + for _, listener := range listeners { + go server.Run(listener) + } + } else { + address := opts.Address + ":" + strconv.Itoa(int(opts.Port)) + listener, err := net.Listen("tcp", address) + if err != nil { + return nil, err + } + go server.Run(listener.(*net.TCPListener)) + } + return server, nil +} + +func main() { + opts := parseFlags() + server, err := createServer(opts) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } - server := Server{registryPath} - go server.Run(listener.(*net.TCPListener)) + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt) + signal.Notify(signals, syscall.SIGTERM) + signal.Notify(signals, syscall.SIGINT) - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt) - signal.Notify(c, syscall.SIGTERM) - signal.Notify(c, syscall.SIGINT) - - for { - select { - case <-c: - fmt.Printf("Shutting socket down\n") - listener.Close() - os.Exit(0) + if server.SocketActivation { + Out: + for { + select { + case <-signals: + break Out + case <-time.After(time.Second * 3): + if time.Since(server.LastConnection).Seconds() >= opts.SocketTimeout { + break Out + } + } } + } else { + <-signals } + + fmt.Printf("Shutting socket(s) down (takes up to 1s)\n") + server.Shutdown() }