package proxy

import (
	"bytes"
	"strconv"
	"strings"

	"github.com/rs/zerolog/log"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"

	"github.com/fujin/anthropic-proxy/internal/config"
)

// Sanitizer rewrites request bodies before they are forwarded upstream and
// undoes the tool renaming in the responses that come back.
//
// Tool names are mapped in both directions: requests use toolsForward
// (original -> sanitized), responses use toolsReverse (sanitized -> original).
// systemRules are applied to the text of the "system" field and bodyRules to
// the raw request bytes.
type Sanitizer struct {
	toolsForward map[string]string // original name -> sanitized name
	toolsReverse map[string]string // sanitized name -> original name
	systemRules  []config.ReplaceRule
	bodyRules    []config.ReplaceRule
}

// NewSanitizer builds a Sanitizer from cfg.
//
// Tool rules with an empty From or To are skipped with a warning, since an
// empty tool name is never meaningful and would make lookups of missing JSON
// fields (which read as "") match. Replace rules with an empty Match are also
// skipped: strings.ReplaceAll with an empty pattern inserts the replacement
// between every character, which would corrupt the body.
func NewSanitizer(cfg config.SanitizeConfig) *Sanitizer {
	s := &Sanitizer{
		toolsForward: make(map[string]string, len(cfg.Tools)),
		toolsReverse: make(map[string]string, len(cfg.Tools)),
		systemRules:  validReplaceRules(cfg.System, "system"),
		bodyRules:    validReplaceRules(cfg.Body, "body"),
	}
	for _, r := range cfg.Tools {
		if r.From == "" || r.To == "" {
			log.Warn().Str("from", r.From).Str("to", r.To).Msg("sanitize: ignoring tool rule with empty name")
			continue
		}
		if prev, dup := s.toolsReverse[r.To]; dup && prev != r.From {
			// The reverse map can hold only one original per sanitized name;
			// the last rule wins, so responses will be ambiguous.
			log.Warn().Str("to", r.To).Str("from", r.From).Str("previous", prev).
				Msg("sanitize: multiple tools map to the same sanitized name")
		}
		s.toolsForward[r.From] = r.To
		s.toolsReverse[r.To] = r.From
	}
	return s
}

// validReplaceRules returns the rules that have a non-empty Match, logging
// and dropping the rest. section names the config section for the log.
func validReplaceRules(rules []config.ReplaceRule, section string) []config.ReplaceRule {
	out := make([]config.ReplaceRule, 0, len(rules))
	for _, r := range rules {
		if r.Match == "" {
			log.Warn().Str("section", section).Msg("sanitize: ignoring replace rule with empty match")
			continue
		}
		out = append(out, r)
	}
	return out
}

// SanitizeRequest applies all request-side transformations in order: tool
// renaming, system-text rules, then raw body rules. It returns the rewritten
// body; steps that have nothing to do return their input unchanged.
func (s *Sanitizer) SanitizeRequest(body []byte) []byte {
	body = s.renameTools(body)
	body = s.replaceSystem(body)
	body = s.replaceBody(body)
	return body
}

// DesanitizeResponse restores original tool names in the "tool_use" blocks of
// a non-streaming response body's "content" array. Bodies without such an
// array are returned unchanged.
func (s *Sanitizer) DesanitizeResponse(body []byte) []byte {
	if len(s.toolsReverse) == 0 {
		return body
	}
	content := gjson.GetBytes(body, "content")
	if !content.IsArray() {
		return body
	}
	for i, block := range content.Array() {
		if block.Get("type").String() != "tool_use" {
			continue
		}
		path := "content." + strconv.Itoa(i) + ".name"
		body, _ = renameAt(body, path, block.Get("name"), s.toolsReverse, "desanitize response: set name failed")
	}
	return body
}

// DesanitizeStreamEvent restores original tool names in a single SSE line.
// Only "data: " lines that mention tool_use are inspected; within them the
// "content_block.name" and "delta.name" fields are mapped back. The original
// line is returned when nothing was changed.
func (s *Sanitizer) DesanitizeStreamEvent(line string) string {
	const prefix = "data: "
	if len(s.toolsReverse) == 0 ||
		!strings.HasPrefix(line, prefix) ||
		!strings.Contains(line, "tool_use") {
		return line
	}
	data := []byte(line[len(prefix):])
	changed := false
	for _, path := range []string{"content_block.name", "delta.name"} {
		var ok bool
		data, ok = renameAt(data, path, gjson.GetBytes(data, path), s.toolsReverse, "desanitize stream event: set name failed")
		changed = changed || ok
	}
	if !changed {
		return line
	}
	return prefix + string(data)
}

// renameTools replaces original tool names with their sanitized names
// wherever the request refers to a tool by name: the "tools" definitions, a
// forced "tool_choice", and tool_use blocks in the message history. All three
// must agree, otherwise upstream sees references to tools it was never given.
func (s *Sanitizer) renameTools(body []byte) []byte {
	if len(s.toolsForward) == 0 {
		return body
	}
	if tools := gjson.GetBytes(body, "tools"); tools.IsArray() {
		for i, tool := range tools.Array() {
			path := "tools." + strconv.Itoa(i) + ".name"
			body, _ = renameAt(body, path, tool.Get("name"), s.toolsForward, "rename tool failed")
		}
	}
	// A {"type":"tool","name":...} choice must name one of the (renamed) tools.
	if choice := gjson.GetBytes(body, "tool_choice"); choice.Get("type").String() == "tool" {
		body, _ = renameAt(body, "tool_choice.name", choice.Get("name"), s.toolsForward, "rename tool_choice failed")
	}
	return s.renameHistoryToolUses(body)
}

// renameHistoryToolUses renames tool_use blocks inside "messages" content
// arrays. Clients echo earlier assistant turns back, and those blocks carry the
// original names that DesanitizeResponse/DesanitizeStreamEvent restored.
func (s *Sanitizer) renameHistoryToolUses(body []byte) []byte {
	msgs := gjson.GetBytes(body, "messages")
	if !msgs.IsArray() {
		return body
	}
	for i, msg := range msgs.Array() {
		content := msg.Get("content")
		if !content.IsArray() {
			continue // plain-string content cannot contain tool_use blocks
		}
		for j, block := range content.Array() {
			if block.Get("type").String() != "tool_use" {
				continue
			}
			path := "messages." + strconv.Itoa(i) + ".content." + strconv.Itoa(j) + ".name"
			body, _ = renameAt(body, path, block.Get("name"), s.toolsForward, "rename history tool_use failed")
		}
	}
	return body
}

// renameAt sets the string at path to mapping[name] when the name field
// exists and has a mapping entry. It reports whether body was rewritten; on
// an sjson error it logs msg and returns body unchanged.
func renameAt(body []byte, path string, name gjson.Result, mapping map[string]string, msg string) ([]byte, bool) {
	if !name.Exists() {
		return body, false
	}
	to, ok := mapping[name.String()]
	if !ok {
		return body, false
	}
	b, err := sjson.SetBytes(body, path, to)
	if err != nil {
		log.Warn().Err(err).Str("tool", name.String()).Msg(msg)
		return body, false
	}
	return b, true
}

// replaceSystem applies the system rules to the "system" field, which may be
// either a plain string or an array of blocks with a "text" field. Blocks
// without a string "text" are left alone, and a location is only rewritten
// when at least one rule changed its text.
func (s *Sanitizer) replaceSystem(body []byte) []byte {
	if len(s.systemRules) == 0 {
		return body
	}
	system := gjson.GetBytes(body, "system")
	switch {
	case system.Type == gjson.String:
		body = s.replaceSystemText(body, "system", system.String())
	case system.IsArray():
		for i, block := range system.Array() {
			text := block.Get("text")
			if text.Type != gjson.String {
				continue
			}
			body = s.replaceSystemText(body, "system."+strconv.Itoa(i)+".text", text.String())
		}
	}
	return body
}

// replaceSystemText applies the system rules to text and, if the result
// differs, writes it back at path. Failures are logged and leave body as is.
func (s *Sanitizer) replaceSystemText(body []byte, path, text string) []byte {
	replaced := applyRules(text, s.systemRules)
	if replaced == text {
		return body
	}
	b, err := sjson.SetBytes(body, path, replaced)
	if err != nil {
		log.Warn().Err(err).Str("path", path).Msg("replace system text failed")
		return body
	}
	return b
}

// applyRules runs every rule's literal Match -> Replace substitution over
// text, in order. Rules with an empty Match are ignored.
func applyRules(text string, rules []config.ReplaceRule) string {
	for _, r := range rules {
		if r.Match == "" {
			continue
		}
		text = strings.ReplaceAll(text, r.Match, r.Replace)
	}
	return text
}

// replaceBody applies the body rules to the raw request bytes, in order.
// Matching is purely textual: a rule can touch JSON keys and syntax as well
// as values, and a Replace containing quotes or backslashes can yield invalid
// JSON. Rules with an empty Match are ignored.
func (s *Sanitizer) replaceBody(body []byte) []byte {
	for _, r := range s.bodyRules {
		if r.Match == "" {
			continue
		}
		body = bytes.ReplaceAll(body, []byte(r.Match), []byte(r.Replace))
	}
	return body
}