Skip to content

Commit c5fbb48

Browse files
committed
Support RTP extensions when writing samples
1 parent d0bc062 commit c5fbb48

2 files changed

Lines changed: 117 additions & 1 deletion

File tree

track_local_static.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,15 @@ func (s *TrackLocalStaticRTP) Write(b []byte) (n int, err error) {
230230
return len(b), s.writeRTP(packet)
231231
}
232232

233+
// SampleRTPExtension is a RTP extension that can be added to packets generated
234+
// when writing Samples to a TrackLocalStaticSample.
235+
type SampleRTPExtension struct {
236+
// ID is the negotiated extension id.
237+
ID uint8
238+
// Payload is the payload of the extension to add to the packet.
239+
Payload []byte
240+
}
241+
233242
// TrackLocalStaticSample is a TrackLocal that has a pre-set codec and accepts Samples.
234243
// If you wish to send a RTP Packet use TrackLocalStaticRTP.
235244
type TrackLocalStaticSample struct {
@@ -340,7 +349,7 @@ func (s *TrackLocalStaticSample) Unbind(t TrackLocalContext) error {
340349
// If one PeerConnection fails the packets will still be sent to
341350
// all PeerConnections. The error message will contain the ID of the failed
342351
// PeerConnections so you can remove them.
343-
func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error {
352+
func (s *TrackLocalStaticSample) WriteSample(sample media.Sample, extensions ...SampleRTPExtension) error {
344353
s.rtpTrack.mu.RLock()
345354
packetizer := s.packetizer
346355
clockRate := s.clockRate
@@ -377,6 +386,11 @@ func (s *TrackLocalStaticSample) WriteSample(sample media.Sample) error {
377386

378387
writeErrs := []error{}
379388
for _, p := range packets {
389+
for _, e := range extensions {
390+
if err := p.SetExtension(e.ID, e.Payload); err != nil {
391+
writeErrs = append(writeErrs, err)
392+
}
393+
}
380394
if err := s.rtpTrack.WriteRTP(p); err != nil {
381395
writeErrs = append(writeErrs, err)
382396
}

track_local_static_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,3 +1027,105 @@ func (p *countingPacketizer) EnableAbsSendTime(value int) {}
10271027
func (p *countingPacketizer) SkipSamples(skippedSamples uint32) {
10281028
p.totalSamples += uint64(skippedSamples)
10291029
}
1030+
1031+
type sampleExtensionChecker struct {
1032+
dummyWriter
1033+
1034+
t *testing.T
1035+
errorTest bool
1036+
called atomic.Uint32
1037+
}
1038+
1039+
func (c *sampleExtensionChecker) WriteRTP(header *rtp.Header, payload []byte) (int, error) {
1040+
c.called.Add(1)
1041+
assert.True(c.t, header.Extension)
1042+
assert.EqualValues(c.t, "hello", header.GetExtension(1))
1043+
if !c.errorTest {
1044+
assert.EqualValues(c.t, "world", header.GetExtension(2))
1045+
} else {
1046+
assert.Nil(c.t, header.GetExtension(2))
1047+
}
1048+
1049+
return 0, nil
1050+
}
1051+
1052+
func TestTrackLocalStaticSample_WriteSample_Extensions(t *testing.T) {
1053+
track, err := NewTrackLocalStaticSample(
1054+
RTPCodecCapability{MimeType: MimeTypeVP8},
1055+
"video",
1056+
"pion",
1057+
)
1058+
require.NoError(t, err)
1059+
1060+
checker := &sampleExtensionChecker{
1061+
t: t,
1062+
}
1063+
1064+
track.rtpTrack.mu.Lock()
1065+
track.rtpTrack.bindings = []trackBinding{{
1066+
id: "b1",
1067+
ssrc: 0x1234,
1068+
payloadType: 96,
1069+
writeStream: checker,
1070+
}}
1071+
fp := &fakePacketizer{}
1072+
track.packetizer = fp
1073+
track.rtpTrack.mu.Unlock()
1074+
1075+
sample := media.Sample{
1076+
Data: []byte("hi"),
1077+
Duration: 20 * time.Millisecond,
1078+
}
1079+
extension1 := SampleRTPExtension{
1080+
ID: 1,
1081+
Payload: []byte("hello"),
1082+
}
1083+
extension2 := SampleRTPExtension{
1084+
ID: 2,
1085+
Payload: []byte("world"),
1086+
}
1087+
err = track.WriteSample(sample, extension1, extension2)
1088+
require.NoError(t, err)
1089+
require.EqualValues(t, 2, checker.called.Load())
1090+
}
1091+
1092+
func TestTrackLocalStaticSample_WriteSample_ExtensionsError(t *testing.T) {
1093+
track, err := NewTrackLocalStaticSample(
1094+
RTPCodecCapability{MimeType: MimeTypeVP8},
1095+
"video",
1096+
"pion",
1097+
)
1098+
require.NoError(t, err)
1099+
1100+
checker := &sampleExtensionChecker{
1101+
t: t,
1102+
errorTest: true,
1103+
}
1104+
1105+
track.rtpTrack.mu.Lock()
1106+
track.rtpTrack.bindings = []trackBinding{{
1107+
id: "b1",
1108+
ssrc: 0x1234,
1109+
payloadType: 96,
1110+
writeStream: checker,
1111+
}}
1112+
fp := &fakePacketizer{}
1113+
track.packetizer = fp
1114+
track.rtpTrack.mu.Unlock()
1115+
1116+
sample := media.Sample{
1117+
Data: []byte("hi"),
1118+
Duration: 20 * time.Millisecond,
1119+
}
1120+
extension1 := SampleRTPExtension{
1121+
ID: 1,
1122+
Payload: []byte("hello"),
1123+
}
1124+
extension2 := SampleRTPExtension{
1125+
ID: 2,
1126+
Payload: []byte("this is a long extension payload that will trigger an error"),
1127+
}
1128+
err = track.WriteSample(sample, extension1, extension2)
1129+
assert.ErrorContains(t, err, "one byte extension")
1130+
require.EqualValues(t, 2, checker.called.Load())
1131+
}

0 commit comments

Comments
 (0)