diff --git a/go.sum b/go.sum index 7e80efe..534c189 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= diff --git a/main.go b/main.go index 024c4d4..6f36c2f 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ import ( "time" "github.com/fatih/color" + "github.com/gorilla/websocket" "github.com/spf13/cobra" ) @@ -26,25 +27,39 @@ var ( Version = "dev" APIURL = "https://usewebhook.com/api/webhooks/" BaseURL = "https://usewebhook.com" + WSURL = "wss://usewebhook.com/ws/webhook/" SettingsFilename = ".usewebhook" ) // WebhookRequest represents a single webhook request type WebhookRequest struct { - RequestID string `json:"request_id"` - Timestamp string `json:"timestamp"` - IP string `json:"ip"` - Method string `json:"method"` - Query string `json:"query"` - Headers map[string]string `json:"headers"` - Body string `json:"body"` + RequestID string `json:"request_id"` + Timestamp string `json:"timestamp"` + IP string `json:"ip"` + CountryCode string `json:"country_code"` + UserAgent string `json:"user_agent"` + Method string `json:"method"` + Scheme string `json:"scheme"` + Hostname string `json:"hostname"` + Path string `json:"path"` + Query string `json:"query"` + Headers map[string]string `json:"headers"` + Body string `json:"body"` } -// WebhookResponse represents the response from the webhook API +// WebhookResponse represents the response from the webhook HTTP API type WebhookResponse struct { Requests []WebhookRequest `json:"requests"` } +// WSMessage is the envelope for WebSocket messages from the server. +// type "webhook.init" contains historical Requests; type "webhook.new" contains a single Request. +type WSMessage struct { + Type string `json:"type"` + Requests []WebhookRequest `json:"requests"` + Request *WebhookRequest `json:"request"` +} + // Config represents the user's configuration type Config struct { WebhookHistory []string `json:"webhook_history"` @@ -57,7 +72,6 @@ type AppConfig struct { ForwardTo string WebhookID string RequestID string - PollSleep time.Duration InitialSleep time.Duration } @@ -196,43 +210,95 @@ func decodeBase64Body(encodedBody string) (string, string, error) { return string(decoded), originalContentType, nil } -// pollWebhook continuously polls the webhook API for new requests -func pollWebhook(config AppConfig) { - lastPollTime := time.Now().UTC() +// fetchSingleRequest fetches a specific request by ID from the HTTP API and exits +func fetchSingleRequest(config AppConfig) { + params := url.Values{} + params.Set("request_id", config.RequestID) - for { - params := url.Values{} - if config.RequestID != "" { - params.Set("request_id", config.RequestID) - } else { - params.Set("since", lastPollTime.Format(time.RFC3339)) + webhookData, err := fetchWebhookData(config.WebhookID, params) + if err != nil { + color.Red("Error fetching webhook data: %v", err) + os.Exit(1) + } + + if len(webhookData.Requests) == 0 { + color.Red("No requests found for request ID: %s", config.RequestID) + os.Exit(1) + } + + for _, request := range webhookData.Requests { + logRequest(request, config.FullLog) + if config.ForwardTo != "" { + forwardRequest(request, config.ForwardTo) } + } + os.Exit(0) +} - webhookData, err := fetchWebhookData(config.WebhookID, params) +// connectAndListen opens a WebSocket connection and dispatches incoming requests until an error occurs. +// seen tracks request IDs already processed; isFirstConnect suppresses the history batch on the initial connection. +func connectAndListen(config AppConfig, seen map[string]bool, isFirstConnect *bool) error { + conn, _, err := websocket.DefaultDialer.Dial(WSURL+config.WebhookID, http.Header{ + "Origin": []string{BaseURL}, + }) + if err != nil { + return err + } + defer conn.Close() + + for { + _, message, err := conn.ReadMessage() if err != nil { - color.Red("Error fetching webhook data: %v", err) - time.Sleep(config.InitialSleep) - continue + return err } - for _, request := range webhookData.Requests { - logRequest(request, config.FullLog) - if config.ForwardTo != "" { - forwardRequest(request, config.ForwardTo) - } + var msg WSMessage + if err := json.Unmarshal(message, &msg); err != nil { + color.Yellow("Warning: failed to parse message: %v", err) + continue } - // if single request mode, exit after the first request - if config.RequestID != "" { - if len(webhookData.Requests) <= 0 { - color.Red("No requests found for request ID: %s", config.RequestID) - os.Exit(1) + switch msg.Type { + case "webhook.init": + for _, req := range msg.Requests { + if *isFirstConnect { + // Mark historical requests as seen without displaying them + seen[req.RequestID] = true + } else if !seen[req.RequestID] { + // Requests that arrived while we were disconnected + seen[req.RequestID] = true + logRequest(req, config.FullLog) + if config.ForwardTo != "" { + forwardRequest(req, config.ForwardTo) + } + } + } + *isFirstConnect = false + + case "webhook.new": + if msg.Request != nil && !seen[msg.Request.RequestID] { + seen[msg.Request.RequestID] = true + logRequest(*msg.Request, config.FullLog) + if config.ForwardTo != "" { + forwardRequest(*msg.Request, config.ForwardTo) + } } - os.Exit(0) } + } +} - lastPollTime = time.Now().UTC() - time.Sleep(config.PollSleep) +// listenWebSocket connects via WebSocket and reconnects automatically on disconnect. +// The seen map and isFirstConnect flag persist across reconnects to avoid replaying requests. +func listenWebSocket(config AppConfig) { + seen := make(map[string]bool) + isFirstConnect := true + + for { + err := connectAndListen(config, seen, &isFirstConnect) + if err != nil { + color.Red("WebSocket error: %v. Reconnecting...", err) + time.Sleep(config.InitialSleep) + } } } @@ -311,7 +377,6 @@ func saveConfig(config *Config) error { // createRootCommand creates and returns the root command for the CLI func createRootCommand() *cobra.Command { appConfig := AppConfig{ - PollSleep: 3 * time.Second, InitialSleep: 1 * time.Second, } @@ -378,6 +443,7 @@ func runRootCommand(cmd *cobra.Command, args []string, appConfig *AppConfig) { if appConfig.RequestID != "" { color.Green("Single request mode. Retrieving webhook=%s request=%s\n\n", appConfig.WebhookID, appConfig.RequestID) + fetchSingleRequest(*appConfig) } else { color.Green("Dashboard: %s/?id=%s", BaseURL, appConfig.WebhookID) color.Green("Webhook URL: %s/%s", BaseURL, appConfig.WebhookID) @@ -385,8 +451,8 @@ func runRootCommand(cmd *cobra.Command, args []string, appConfig *AppConfig) { color.Green("Forwarding to: %s", appConfig.ForwardTo) } color.HiBlack("\nPress Ctrl+C to stop\n\n") + listenWebSocket(*appConfig) } - pollWebhook(*appConfig) } // contains checks if a slice contains a specific item diff --git a/main_test.go b/main_test.go index a289ff3..c79858c 100644 --- a/main_test.go +++ b/main_test.go @@ -1,9 +1,124 @@ package main import ( + "encoding/json" + "fmt" + "net/http" + "strings" "testing" + "time" + + "github.com/gorilla/websocket" ) +// sharedWebhookID is a fixed test webhook — reused across all live tests +const sharedWebhookID = "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4" + +func dialTestWS(t *testing.T) *websocket.Conn { + t.Helper() + conn, _, err := websocket.DefaultDialer.Dial(WSURL+sharedWebhookID, http.Header{ + "Origin": []string{BaseURL}, + }) + if err != nil { + t.Fatalf("failed to connect to WebSocket: %v", err) + } + return conn +} + +func readWSMessage(t *testing.T, conn *websocket.Conn, timeout time.Duration) WSMessage { + t.Helper() + conn.SetReadDeadline(time.Now().Add(timeout)) + _, raw, err := conn.ReadMessage() + if err != nil { + t.Fatalf("failed to read WebSocket message: %v", err) + } + var msg WSMessage + if err := json.Unmarshal(raw, &msg); err != nil { + t.Fatalf("failed to unmarshal message: %v\nraw: %s", err, raw) + } + return msg +} + +func sendWebhookRequest(t *testing.T, method, body string) { + t.Helper() + url := fmt.Sprintf("%s/%s", BaseURL, sharedWebhookID) + req, err := http.NewRequest(method, url, strings.NewReader(body)) + if err != nil { + t.Fatalf("failed to create HTTP request: %v", err) + } + if body != "" { + req.Header.Set("Content-Type", "application/json") + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("failed to send webhook request: %v", err) + } + defer resp.Body.Close() +} + +// TestWSConnectReceivesInit verifies that connecting to the WebSocket returns a webhook.init message +func TestWSConnectReceivesInit(t *testing.T) { + conn := dialTestWS(t) + defer conn.Close() + + msg := readWSMessage(t, conn, 5*time.Second) + + if msg.Type != "webhook.init" { + t.Errorf("expected type webhook.init, got %q", msg.Type) + } +} + +// TestWSReceivesNewRequestOnHTTPPost verifies that sending an HTTP request triggers a webhook.new message +func TestWSReceivesNewRequestOnHTTPPost(t *testing.T) { + conn := dialTestWS(t) + defer conn.Close() + + // Consume the init message + readWSMessage(t, conn, 5*time.Second) + + // Send a POST request to the webhook URL + payload := `{"test": "live-integration"}` + sendWebhookRequest(t, http.MethodPost, payload) + + msg := readWSMessage(t, conn, 10*time.Second) + + if msg.Type != "webhook.new" { + t.Errorf("expected type webhook.new, got %q", msg.Type) + } + if msg.Request == nil { + t.Fatal("expected request to be non-nil") + } + if msg.Request.Method != http.MethodPost { + t.Errorf("expected method POST, got %q", msg.Request.Method) + } + if !strings.Contains(msg.Request.Body, "live-integration") { + t.Errorf("expected body to contain 'live-integration', got %q", msg.Request.Body) + } +} + +// TestWSReceivesNewRequestOnHTTPGet verifies a GET request also triggers webhook.new +func TestWSReceivesNewRequestOnHTTPGet(t *testing.T) { + conn := dialTestWS(t) + defer conn.Close() + + // Consume the init message + readWSMessage(t, conn, 5*time.Second) + + sendWebhookRequest(t, http.MethodGet, "") + + msg := readWSMessage(t, conn, 10*time.Second) + + if msg.Type != "webhook.new" { + t.Errorf("expected type webhook.new, got %q", msg.Type) + } + if msg.Request == nil { + t.Fatal("expected request to be non-nil") + } + if msg.Request.Method != http.MethodGet { + t.Errorf("expected method GET, got %q", msg.Request.Method) + } +} + // TestExtractIdsFromURLOrArgs covers the various input formats accepted by the CLI func TestExtractIdsFromURLOrArgs(t *testing.T) { id := "409bdb1f81abfa826c2022d18ddff2e5"