From 5aa9db4e51bfaa7781a03020b8c500bdce4c49bc Mon Sep 17 00:00:00 2001 From: Milarin Date: Thu, 15 Feb 2024 17:49:22 +0100 Subject: [PATCH] manage clients via callback function --- multiwriter.go | 7 ++----- server.go | 28 ++++++++++++++++++---------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/multiwriter.go b/multiwriter.go index 31c2325..98fc8e2 100644 --- a/multiwriter.go +++ b/multiwriter.go @@ -10,15 +10,12 @@ var _ io.Writer = &multiWriter{} func (t *multiWriter) Write(p []byte) (n int, err error) { for _, w := range t.writers { - n, err = w.Write(p) - if err != nil || n != len(p) { - continue - } + w.Write(p) } return len(p), nil } -func MultiWriter(writers ...io.Writer) io.Writer { +func newMultiWriter(writers ...io.Writer) io.Writer { allWriters := make([]io.Writer, 0, len(writers)) for _, w := range writers { if mw, ok := w.(*multiWriter); ok { diff --git a/server.go b/server.go index f87b6a3..209f67b 100644 --- a/server.go +++ b/server.go @@ -14,9 +14,11 @@ type Server struct { socketPath string server net.Listener clients *cmap.Map[net.Conn, struct{}] + + onNewClient func(client net.Conn) } -func Listen(socketPath string) (*Server, error) { +func Listen(socketPath string, onNewClient func(client net.Conn)) (*Server, error) { absPath, err := filepath.Abs(socketPath) if err != nil { return nil, err @@ -28,10 +30,19 @@ func Listen(socketPath string) (*Server, error) { return nil, err } + if onNewClient == nil { + onNewClient = func(client net.Conn) { + data := make([]byte, 1024) + for _, err := client.Read(data); err == nil; _, err = client.Read(data) { + } + } + } + s := &Server{ - socketPath: absPath, - server: server, - clients: cmap.New[net.Conn, struct{}](), + socketPath: absPath, + server: server, + clients: cmap.New[net.Conn, struct{}](), + onNewClient: onNewClient, } go s.handleClients() @@ -41,7 +52,7 @@ func Listen(socketPath string) (*Server, error) { func (s *Server) Broadcast(r io.Reader) error { clients := slices.Map(s.clients.Keys(), func(c net.Conn) io.Writer { return c }) - w := MultiWriter(clients...) + w := newMultiWriter(clients...) if _, err := io.Copy(w, r); err != nil { return err @@ -59,14 +70,11 @@ func (s *Server) handleClients() { func (s *Server) handleClient(client net.Conn) { s.clients.Put(client, struct{}{}) defer s.clients.Delete(client) - - data := make([]byte, 1024) - for _, err := client.Read(data); err == nil; _, err = client.Read(data) { - - } + s.onNewClient(client) } func (s *Server) Close() error { + s.clients.Iter(func(client net.Conn, _ struct{}) { client.Close() }) if err := s.server.Close(); err != nil { return err