Skip to content

Commit 559b0a5

Browse files
committed
feat: propagate trace context through HTTP channel
- inject span context into outbound HTTP headers and extract it back onto responses - add Context-aware HTTP response wrapper and corresponding tests - require otel/trace to satisfy new imports
1 parent 66b6a2c commit 559b0a5

4 files changed

Lines changed: 146 additions & 4 deletions

File tree

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ require (
1616
github.com/pkg/errors v0.9.1
1717
github.com/stretchr/testify v1.9.0
1818
github.com/ucan-wg/go-ucan v0.0.0-20240916120445-37f52863156c
19+
go.opentelemetry.io/otel v1.30.0
20+
go.opentelemetry.io/otel/trace v1.30.0
1921
)
2022

2123
require (
@@ -71,9 +73,7 @@ require (
7173
github.com/polydawn/refmt v0.89.1-0.20231129105047-37766d95467a // indirect
7274
github.com/spaolacci/murmur3 v1.1.0 // indirect
7375
github.com/whyrusleeping/cbor-gen v0.1.2 // indirect
74-
go.opentelemetry.io/otel v1.30.0 // indirect
7576
go.opentelemetry.io/otel/metric v1.30.0 // indirect
76-
go.opentelemetry.io/otel/trace v1.30.0 // indirect
7777
go.uber.org/atomic v1.11.0 // indirect
7878
go.uber.org/multierr v1.11.0 // indirect
7979
go.uber.org/zap v1.27.0 // indirect

transport/http/channel.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ import (
88

99
"slices"
1010

11+
"go.opentelemetry.io/otel"
12+
"go.opentelemetry.io/otel/propagation"
13+
1114
"github.com/storacha/go-ucanto/transport"
1215
)
1316

@@ -67,6 +70,8 @@ func (c *Channel) Request(ctx context.Context, req transport.HTTPRequest) (trans
6770
}
6871

6972
addAllHeaders(hr.Header, req.Headers(), c.headers)
73+
injectTraceContext(ctx, hr)
74+
7075
res, err := c.client.Do(hr)
7176
if err != nil {
7277
return nil, fmt.Errorf("doing HTTP request: %w", err)
@@ -75,7 +80,8 @@ func (c *Channel) Request(ctx context.Context, req transport.HTTPRequest) (trans
7580
return nil, NewHTTPError(fmt.Sprintf("HTTP Request failed. %s %s → %d", hr.Method, c.url.String(), res.StatusCode), res.StatusCode, res.Header)
7681
}
7782

78-
return NewResponse(res.StatusCode, res.Body, res.Header), nil
83+
ctx = extractTraceContext(ctx, res.Header)
84+
return NewResponseWithContext(ctx, res.StatusCode, res.Body, res.Header), nil
7985
}
8086

8187
func addAllHeaders(dst http.Header, srcs ...http.Header) {
@@ -88,6 +94,20 @@ func addAllHeaders(dst http.Header, srcs ...http.Header) {
8894
}
8995
}
9096

97+
func injectTraceContext(ctx context.Context, req *http.Request) {
98+
if ctx == nil || req == nil {
99+
return
100+
}
101+
otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header))
102+
}
103+
104+
func extractTraceContext(ctx context.Context, headers http.Header) context.Context {
105+
if ctx == nil || headers == nil {
106+
return ctx
107+
}
108+
return otel.GetTextMapPropagator().Extract(ctx, propagation.HeaderCarrier(headers))
109+
}
110+
91111
var _ transport.Channel = (*Channel)(nil)
92112

93113
func NewChannel(url *url.URL, options ...Option) *Channel {

transport/http/channel_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package http
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"net/url"
8+
"testing"
9+
10+
"go.opentelemetry.io/otel"
11+
"go.opentelemetry.io/otel/propagation"
12+
"go.opentelemetry.io/otel/trace"
13+
)
14+
15+
func TestChannelPropagatesTraceContext(t *testing.T) {
16+
const (
17+
requestTraceIDHex = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
18+
requestSpanIDHex = "bbbbbbbbbbbbbbbb"
19+
responseTraceIDHex = "cccccccccccccccccccccccccccccccc"
20+
responseSpanIDHex = "dddddddddddddddd"
21+
responseTrace = "00-" + responseTraceIDHex + "-" + responseSpanIDHex + "-01"
22+
expectedRequest = "00-" + requestTraceIDHex + "-" + requestSpanIDHex + "-01"
23+
)
24+
25+
var seenRequestTrace string
26+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
27+
seenRequestTrace = r.Header.Get("traceparent")
28+
w.Header().Set("traceparent", responseTrace)
29+
w.WriteHeader(http.StatusOK)
30+
}))
31+
t.Cleanup(server.Close)
32+
33+
endpoint, err := url.Parse(server.URL)
34+
if err != nil {
35+
t.Fatalf("parsing server URL: %v", err)
36+
}
37+
38+
channel := NewChannel(endpoint, WithClient(server.Client()))
39+
40+
restoreProp := setTraceContextPropagator()
41+
t.Cleanup(restoreProp)
42+
43+
ctx := context.Background()
44+
ctx = trace.ContextWithSpanContext(ctx, newSpanContext(t, requestTraceIDHex, requestSpanIDHex))
45+
46+
res, err := channel.Request(ctx, NewRequest(http.NoBody, nil))
47+
if err != nil {
48+
t.Fatalf("request failed: %v", err)
49+
}
50+
t.Cleanup(func() { res.Body().Close() })
51+
52+
if seenRequestTrace != expectedRequest {
53+
t.Fatalf("expected traceparent %q, got %q", expectedRequest, seenRequestTrace)
54+
}
55+
56+
responseCtx, ok := res.(*Response)
57+
if !ok {
58+
t.Fatalf("expected *Response, got %T", res)
59+
}
60+
sc := trace.SpanContextFromContext(responseCtx.Context())
61+
expectedTraceID := mustTraceIDFromHex(t, responseTraceIDHex)
62+
if sc.TraceID() != expectedTraceID {
63+
t.Fatalf("expected response trace ID %s, got %s", expectedTraceID, sc.TraceID())
64+
}
65+
expectedSpanID := mustSpanIDFromHex(t, responseSpanIDHex)
66+
if sc.SpanID() != expectedSpanID {
67+
t.Fatalf("expected response span ID %s, got %s", expectedSpanID, sc.SpanID())
68+
}
69+
}
70+
71+
func newSpanContext(t *testing.T, traceIDHex, spanIDHex string) trace.SpanContext {
72+
t.Helper()
73+
traceID := mustTraceIDFromHex(t, traceIDHex)
74+
spanID := mustSpanIDFromHex(t, spanIDHex)
75+
return trace.NewSpanContext(trace.SpanContextConfig{
76+
TraceID: traceID,
77+
SpanID: spanID,
78+
TraceFlags: trace.FlagsSampled,
79+
})
80+
}
81+
82+
func mustTraceIDFromHex(t *testing.T, hex string) trace.TraceID {
83+
t.Helper()
84+
traceID, err := trace.TraceIDFromHex(hex)
85+
if err != nil {
86+
t.Fatalf("parsing trace ID: %v", err)
87+
}
88+
return traceID
89+
}
90+
91+
func mustSpanIDFromHex(t *testing.T, hex string) trace.SpanID {
92+
t.Helper()
93+
spanID, err := trace.SpanIDFromHex(hex)
94+
if err != nil {
95+
t.Fatalf("parsing span ID: %v", err)
96+
}
97+
return spanID
98+
}
99+
100+
func setTraceContextPropagator() func() {
101+
prev := otel.GetTextMapPropagator()
102+
otel.SetTextMapPropagator(propagation.TraceContext{})
103+
return func() {
104+
otel.SetTextMapPropagator(prev)
105+
}
106+
}

transport/http/response.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package http
22

33
import (
4+
"context"
45
"io"
56
"net/http"
67
"net/url"
@@ -30,6 +31,7 @@ var _ transport.HTTPRequest = (*Request)(nil)
3031
var _ transport.InboundHTTPRequest = (*Request)(nil)
3132

3233
type Response struct {
34+
ctx context.Context
3335
status int
3436
hdrs http.Header
3537
body io.ReadCloser
@@ -47,10 +49,24 @@ func (res *Response) Body() io.ReadCloser {
4749
return res.body
4850
}
4951

52+
func (res *Response) Context() context.Context {
53+
if res.ctx == nil {
54+
return context.Background()
55+
}
56+
return res.ctx
57+
}
58+
5059
var _ transport.HTTPResponse = (*Response)(nil)
5160

5261
func NewResponse(status int, body io.ReadCloser, headers http.Header) *Response {
53-
return &Response{status, headers, body}
62+
return NewResponseWithContext(context.Background(), status, body, headers)
63+
}
64+
65+
func NewResponseWithContext(ctx context.Context, status int, body io.ReadCloser, headers http.Header) *Response {
66+
if ctx == nil {
67+
ctx = context.Background()
68+
}
69+
return &Response{ctx: ctx, status: status, hdrs: headers, body: body}
5470
}
5571

5672
// NewRequest creates a [transport.HTTPRequest]

0 commit comments

Comments
 (0)