Fixes, readme
This commit is contained in:
+72
-15
@@ -1,10 +1,12 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -14,21 +16,35 @@ import (
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
engine *gin.Engine
|
||||
port int
|
||||
httpServer *http.Server
|
||||
engine *gin.Engine
|
||||
configPath string
|
||||
sanitizer atomic.Pointer[proxy.Sanitizer]
|
||||
apiKeys atomic.Pointer[map[string]struct{}]
|
||||
}
|
||||
|
||||
func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile) *Server {
|
||||
s := &Server{configPath: "config.yaml"}
|
||||
|
||||
san := proxy.NewSanitizer(cfg.Sanitize)
|
||||
s.sanitizer.Store(san)
|
||||
|
||||
keys := makeKeySet(cfg.APIKeys)
|
||||
s.apiKeys.Store(&keys)
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
engine := gin.New()
|
||||
engine.Use(gin.Recovery())
|
||||
engine.Use(corsMiddleware())
|
||||
engine.Use(authMiddleware(cfg.APIKeys))
|
||||
engine.Use(s.authMiddleware())
|
||||
|
||||
handler := proxy.HandleMessages(pool, profile, cfg.Sanitize)
|
||||
handler := proxy.HandleMessages(pool, profile, func() *proxy.Sanitizer {
|
||||
return s.sanitizer.Load()
|
||||
})
|
||||
engine.POST("/v1/messages", handler)
|
||||
engine.POST("/messages", handler)
|
||||
|
||||
engine.POST("/reload", s.handleReload())
|
||||
engine.GET("/healthz", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
@@ -37,12 +53,56 @@ func New(cfg *config.Config, pool *auth.Pool, profile *proxy.SniffedProfile) *Se
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||
})
|
||||
|
||||
return &Server{engine: engine, port: cfg.Port}
|
||||
s.engine = engine
|
||||
s.httpServer = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.Port),
|
||||
Handler: engine,
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
addr := fmt.Sprintf(":%d", s.port)
|
||||
return s.engine.Run(addr)
|
||||
return s.httpServer.ListenAndServe()
|
||||
}
|
||||
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
return s.httpServer.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (s *Server) handleReload() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
cfg, err := config.Load(s.configPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("load config: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
san := proxy.NewSanitizer(cfg.Sanitize)
|
||||
s.sanitizer.Store(san)
|
||||
|
||||
keys := makeKeySet(cfg.APIKeys)
|
||||
s.apiKeys.Store(&keys)
|
||||
|
||||
log.Printf("config reloaded: %d tool renames, %d system rules, %d body rules, %d api keys",
|
||||
len(cfg.Sanitize.Tools), len(cfg.Sanitize.System), len(cfg.Sanitize.Body), len(cfg.APIKeys))
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "reloaded",
|
||||
"tool_renames": len(cfg.Sanitize.Tools),
|
||||
"system_rules": len(cfg.Sanitize.System),
|
||||
"body_rules": len(cfg.Sanitize.Body),
|
||||
"api_keys": len(cfg.APIKeys),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func makeKeySet(apiKeys []string) map[string]struct{} {
|
||||
keySet := make(map[string]struct{}, len(apiKeys))
|
||||
for _, k := range apiKeys {
|
||||
keySet[k] = struct{}{}
|
||||
}
|
||||
return keySet
|
||||
}
|
||||
|
||||
func corsMiddleware() gin.HandlerFunc {
|
||||
@@ -60,14 +120,10 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func authMiddleware(apiKeys []string) gin.HandlerFunc {
|
||||
keySet := make(map[string]struct{}, len(apiKeys))
|
||||
for _, k := range apiKeys {
|
||||
keySet[k] = struct{}{}
|
||||
}
|
||||
|
||||
func (s *Server) authMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.URL.Path == "/healthz" {
|
||||
path := c.Request.URL.Path
|
||||
if path == "/healthz" || path == "/reload" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -85,7 +141,8 @@ func authMiddleware(apiKeys []string) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := keySet[token]; !ok {
|
||||
keys := s.apiKeys.Load()
|
||||
if _, ok := (*keys)[token]; !ok {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "invalid api key"})
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user