diff --git a/request.go b/request.go index 9c2cb13..7290f77 100644 --- a/request.go +++ b/request.go @@ -13,12 +13,13 @@ import ( ) type Request struct { - alias string - chainCallback Callback - hostURL *url.URL - metrics Metrics - restyRequest *resty.Request - startTime time.Time + alias string + chainCallback Callback + hostURL *url.URL + metrics Metrics + additionalAttrs map[string]string + restyRequest *resty.Request + startTime time.Time } // NewRequest creates a request for the specified HTTP method. @@ -36,6 +37,11 @@ func (r *Request) HostURL() *url.URL { return r.hostURL } +func (r *Request) SetMetricsAttrs(attrs map[string]string) *Request { + r.additionalAttrs = attrs + return r +} + // SetHostURL sets the host url for the request. func (r *Request) SetHostURL(url *url.URL) *Request { r.hostURL = url @@ -131,6 +137,11 @@ func (r *Request) Execute(method string, url string) (*Response, error) { metricsAlias = strings.Replace(metricsAlias, ".", "-", -1) + attrs := r.additionalAttrs + if len(r.additionalAttrs) == 0 { + attrs = map[string]string{} + } + return registerMetrics(metricsAlias, r.metrics, func() (*Response, error) { execute := func() (*Response, error) { r.startTime = time.Now() @@ -142,24 +153,21 @@ func (r *Request) Execute(method string, url string) (*Response, error) { } return r.chainCallback(execute) - }) + }, attrs) } -func registerMetrics(key string, metrics Metrics, f func() (*Response, error)) (*Response, error) { +func registerMetrics(key string, metrics Metrics, f func() (*Response, error), additionalAttrs map[string]string) (*Response, error) { resp, err := f() if metrics != nil { go func(resp *Response, err error) { - var attrs map[string]string if resp != nil { - attrs = map[string]string{ - "host": resp.Request().HostURL().Host, - "path": resp.Request().HostURL().Path, - } + additionalAttrs["host"] = resp.Request().HostURL().Host + additionalAttrs["path"] = resp.Request().HostURL().Path metrics.PushToSeries(fmt.Sprintf("%s.%s", key, "response_time"), resp.ResponseTime().Seconds()) if resp.statusCode != 0 { metrics.IncrCounter(fmt.Sprintf("%s.status.%d", key, resp.StatusCode())) - attrs["status"] = fmt.Sprintf("%d", resp.StatusCode()) + additionalAttrs["status"] = fmt.Sprintf("%d", resp.StatusCode()) } } if err != nil { @@ -169,7 +177,7 @@ func registerMetrics(key string, metrics Metrics, f func() (*Response, error)) ( metrics.IncrCounter(fmt.Sprintf("%s.%s", key, "errors")) } } - metrics.IncrCounterWithAttrs(fmt.Sprintf("%s.%s", key, "total"), attrs) + metrics.IncrCounterWithAttrs(fmt.Sprintf("%s.%s", key, "total"), additionalAttrs) }(resp, err) } diff --git a/request_test.go b/request_test.go index c2db8b1..e341f8e 100644 --- a/request_test.go +++ b/request_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "net/url" "testing" + "time" "github.com/globocom/httpclient" @@ -160,3 +161,31 @@ func testSetHostURL(target *httpclient.Request) func(*testing.T) { assert.Nil(t, target.HostURL()) } } + +func TestSetMetricsAttrs_PropagatesToMetrics(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(handleFunc)) + defer server.Close() + + metrics := &mockMetrics{} + client := httpclient.NewHTTPClient( + &httpclient.LoggerAdapter{Writer: io.Discard}, + httpclient.WithHostURL(server.URL), + httpclient.WithMetrics(metrics), + ) + request := client.NewRequest() + + attrs := map[string]string{"foo": "bar", "baz": "qux"} + request.SetMetricsAttrs(attrs) + _, _ = request.Get("/") + + time.Sleep(100 * time.Millisecond) + + found := false + for _, calledAttrs := range metrics.incrCounterWithAttrsCalls { + if calledAttrs.attrs["foo"] == "bar" && calledAttrs.attrs["baz"] == "qux" { + found = true + break + } + } + assert.True(t, found, "Attributes passed to SetMetricsAttrs must be propagated to IncrCounterWithAttrs") +}