package middleware

import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"io"
	"net/http"
	"strconv"
	"time"
)

// HMACAuth returns middleware that authenticates each request with an
// HMAC-SHA256 signature keyed by secret.
//
// Clients must send two headers:
//   - X-Timestamp: the request time as decimal Unix seconds; it must lie
//     within 300 seconds (either direction) of the server clock.
//   - X-Signature: the hex-encoded (either case) HMAC-SHA256 of
//     method + "\n" + URL path + "\n" + X-Timestamp + "\n" + raw body.
//
// Authentication failures are answered with 401 Unauthorized. A request body
// that cannot be read is answered with 400 Bad Request. In both cases next is
// not called. On success the consumed body is replaced with an in-memory copy
// so next can read it again.
//
// NOTE(review): the query string is not part of the signed message and no
// nonce is tracked, so a captured request can be replayed (even with altered
// query parameters) inside the skew window. Closing either gap needs a
// coordinated client-side change.
func HMACAuth(secret string) func(http.Handler) http.Handler {
	// Convert once; hmac.New only reads the key, so sharing it across
	// concurrent requests is safe.
	key := []byte(secret)

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ts := r.Header.Get("X-Timestamp")
			sig := r.Header.Get("X-Signature")
			if ts == "" || sig == "" {
				http.Error(w, "Missing HMAC headers", http.StatusUnauthorized)
				return
			}

			unix, err := strconv.ParseInt(ts, 10, 64)
			if err != nil {
				http.Error(w, "Invalid timestamp", http.StatusUnauthorized)
				return
			}
			const skew = int64(300) // allowed clock drift, in seconds
			now := time.Now().Unix()
			if unix < now-skew || unix > now+skew {
				http.Error(w, "Timestamp out of allowed range", http.StatusUnauthorized)
				return
			}

			var body []byte
			if r.Body != nil {
				body, err = io.ReadAll(r.Body)
				if err != nil {
					// Never sign or forward a partially read body.
					http.Error(w, "Failed to read request body", http.StatusBadRequest)
					return
				}
			}
			// The original body has been consumed; give next a fresh reader.
			r.Body = io.NopCloser(bytes.NewReader(body))

			provided, err := hex.DecodeString(sig)
			if err != nil {
				http.Error(w, "Invalid signature encoding", http.StatusUnauthorized)
				return
			}

			// Feed the body straight into the MAC instead of copying it into a
			// string first; the MAC input is unchanged. hash.Hash.Write never
			// returns an error.
			mac := hmac.New(sha256.New, key)
			mac.Write([]byte(r.Method + "\n" + r.URL.Path + "\n" + ts + "\n"))
			mac.Write(body)

			// Compare only with the constant-time hmac.Equal. A ==/!= comparison
			// on the hex strings would leak, through timing, how many leading
			// characters of a forged signature are correct.
			if !hmac.Equal(mac.Sum(nil), provided) {
				http.Error(w, "Invalid signature", http.StatusUnauthorized)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}