From 812562e09637dcceff69acf7a661840a5e021c6a Mon Sep 17 00:00:00 2001 From: LandaMm Date: Sat, 19 Apr 2025 13:31:08 +0200 Subject: [PATCH] feat: support streamer routes --- hsp/server/router.go | 100 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 89 insertions(+), 11 deletions(-) diff --git a/hsp/server/router.go b/hsp/server/router.go index de3f887..07c0c6a 100644 --- a/hsp/server/router.go +++ b/hsp/server/router.go @@ -1,7 +1,7 @@ package server import ( - "errors" + "fmt" "log" "net" @@ -9,22 +9,42 @@ import ( ) type RouteHandler func(req *hsp.Request) *hsp.Response +type StreamHandler func(req *hsp.Request, stream chan []byte) type Router struct { - Routes map[string]RouteHandler + routes map[string]RouteHandler + streamers map[string]StreamHandler + streamMaxSize uint64 + streamBufferSize uint16 } func NewRouter() *Router { return &Router{ - Routes: make(map[string]RouteHandler), + routes: make(map[string]RouteHandler), + streamers: make(map[string]StreamHandler), } } func (r *Router) AddRoute(pathname string, handler RouteHandler) { - if _, ok := r.Routes[pathname]; ok { + if _, ok := r.routes[pathname]; ok { log.Printf("WARN: Rewriting existing route '%s'\n", pathname) } - r.Routes[pathname] = handler + r.routes[pathname] = handler +} + +func (r *Router) AddStreamer(pathname string, handler StreamHandler) { + if _, ok := r.streamers[pathname]; ok { + log.Printf("WARN: Rewriting existing streamer '%s'\n", pathname) + } + r.streamers[pathname] = handler +} + +func (r *Router) SetStreamMaxSize(size uint64) { + r.streamMaxSize = size +} + +func (r *Router) SetStreamBufferSize(size uint16) { + r.streamBufferSize = size } func (r *Router) Handle(conn net.Conn) error { @@ -43,12 +63,70 @@ func (r *Router) Handle(conn net.Conn) error { if route, ok := packet.Headers["route"]; ok { log.Printf("[ROUTER] New connection to '%s'", route) - if handler, ok := r.Routes[route]; ok { - req := hsp.NewRequest(conn, packet) - res := handler(req) - _, err := dupl.WritePacket(res.ToPacket()) - return err + req := hsp.NewRequest(conn, packet) + + switch req.GetRequestKind() { + case "single-hit": + if handler, ok := r.routes[route]; ok { + res := handler(req) + _, err := dupl.WritePacket(res.ToPacket()) + return err + } + case "stream": + if handler, ok := r.streamers[route]; ok { + info, err := req.GetStreamInfo() + if err != nil { + _, err = dupl.WritePacket(hsp.NewErrorResponse(err).ToPacket()) + return err + } + + streamSize := uint64(min(info.TotalBytes, r.streamMaxSize)) + bufferSize := uint16(min(info.BufferSize, r.streamBufferSize)) + + res := hsp.NewStatusResponse(hsp.STATUS_SUCCESS) + res.AddHeader(hsp.H_XSTREAM, fmt.Sprintf("%d:%d", streamSize, bufferSize)) + res.AddHeader(hsp.H_XSTREAM_KEY, "0") // TODO: generate id + + _, err = dupl.WritePacket(res.ToPacket()) + if err != nil { + return err + } + + req := hsp.NewRequest(conn, res.ToPacket()) + bc := make(chan []byte) + + go func() { + handler(req, bc) + }() + + buf := make([]byte, bufferSize) + var totalReceived uint64 + totalReceived = 0 + for totalReceived < streamSize { + n, err := conn.Read(buf) + if err != nil || n <= 0 { + break + } + if n > 0 { + totalReceived += uint64(n) + } + } + + res = hsp.NewStatusResponse(hsp.STATUS_SUCCESS) + res.AddHeader(hsp.H_XSTREAM, fmt.Sprintf("%d:0", streamSize - totalReceived)) + res.AddHeader(hsp.H_XSTREAM_KEY, "0") // TODO: generate id + _, err = dupl.WritePacket(res.ToPacket()) + + conn.Close() + close(bc) + + return err + } + default: + return fmt.Errorf("Unsupported request kind: %s", req.GetRequestKind()) } } - return errors.New("Not Found") + + _, err = dupl.WritePacket(hsp.NewStatusResponse(hsp.STATUS_NOTFOUND).ToPacket()) + return err }