manage clients via callback function

This commit is contained in:
Milarin 2024-02-15 17:49:22 +01:00
parent f280a09feb
commit 5aa9db4e51
2 changed files with 20 additions and 15 deletions

View File

@ -10,15 +10,12 @@ var _ io.Writer = &multiWriter{}
func (t *multiWriter) Write(p []byte) (n int, err error) { func (t *multiWriter) Write(p []byte) (n int, err error) {
for _, w := range t.writers { for _, w := range t.writers {
n, err = w.Write(p) w.Write(p)
if err != nil || n != len(p) {
continue
}
} }
return len(p), nil return len(p), nil
} }
func MultiWriter(writers ...io.Writer) io.Writer { func newMultiWriter(writers ...io.Writer) io.Writer {
allWriters := make([]io.Writer, 0, len(writers)) allWriters := make([]io.Writer, 0, len(writers))
for _, w := range writers { for _, w := range writers {
if mw, ok := w.(*multiWriter); ok { if mw, ok := w.(*multiWriter); ok {

View File

@ -14,9 +14,11 @@ type Server struct {
socketPath string socketPath string
server net.Listener server net.Listener
clients *cmap.Map[net.Conn, struct{}] clients *cmap.Map[net.Conn, struct{}]
onNewClient func(client net.Conn)
} }
func Listen(socketPath string) (*Server, error) { func Listen(socketPath string, onNewClient func(client net.Conn)) (*Server, error) {
absPath, err := filepath.Abs(socketPath) absPath, err := filepath.Abs(socketPath)
if err != nil { if err != nil {
return nil, err return nil, err
@ -28,10 +30,19 @@ func Listen(socketPath string) (*Server, error) {
return nil, err return nil, err
} }
if onNewClient == nil {
onNewClient = func(client net.Conn) {
data := make([]byte, 1024)
for _, err := client.Read(data); err == nil; _, err = client.Read(data) {
}
}
}
s := &Server{ s := &Server{
socketPath: absPath, socketPath: absPath,
server: server, server: server,
clients: cmap.New[net.Conn, struct{}](), clients: cmap.New[net.Conn, struct{}](),
onNewClient: onNewClient,
} }
go s.handleClients() go s.handleClients()
@ -41,7 +52,7 @@ func Listen(socketPath string) (*Server, error) {
func (s *Server) Broadcast(r io.Reader) error { func (s *Server) Broadcast(r io.Reader) error {
clients := slices.Map(s.clients.Keys(), func(c net.Conn) io.Writer { return c }) clients := slices.Map(s.clients.Keys(), func(c net.Conn) io.Writer { return c })
w := MultiWriter(clients...) w := newMultiWriter(clients...)
if _, err := io.Copy(w, r); err != nil { if _, err := io.Copy(w, r); err != nil {
return err return err
@ -59,14 +70,11 @@ func (s *Server) handleClients() {
func (s *Server) handleClient(client net.Conn) { func (s *Server) handleClient(client net.Conn) {
s.clients.Put(client, struct{}{}) s.clients.Put(client, struct{}{})
defer s.clients.Delete(client) defer s.clients.Delete(client)
s.onNewClient(client)
data := make([]byte, 1024)
for _, err := client.Read(data); err == nil; _, err = client.Read(data) {
}
} }
func (s *Server) Close() error { func (s *Server) Close() error {
s.clients.Iter(func(client net.Conn, _ struct{}) { client.Close() })
if err := s.server.Close(); err != nil { if err := s.server.Close(); err != nil {
return err return err