Skip to content

Commit b18e0a8

Browse files
committed
feat: message type validation in ZeroCopyParser
1 parent 1ad6ad4 commit b18e0a8

5 files changed

Lines changed: 128 additions & 39 deletions

File tree

src/zerocopy.rs

Lines changed: 56 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -361,54 +361,73 @@ impl<'a> ZeroCopyParser<'a> {
361361

362362
#[inline(always)]
363363
pub fn parse_next(&mut self) -> Option<ZeroCopyMessage<'a>> {
364-
let data_len = self.data.len();
365-
let pos = self.position;
364+
loop {
365+
let data_len = self.data.len();
366+
let pos = self.position;
366367

367-
if pos + 3 > data_len {
368-
return None;
369-
}
368+
if pos + 3 > data_len {
369+
return None;
370+
}
370371

371-
let length = unsafe {
372-
let ptr = self.data.as_ptr().add(pos);
373-
u16::from_be_bytes(std::ptr::read_unaligned(ptr as *const [u8; 2]))
374-
} as usize;
372+
let length = unsafe {
373+
let ptr = self.data.as_ptr().add(pos);
374+
u16::from_be_bytes(std::ptr::read_unaligned(ptr as *const [u8; 2]))
375+
} as usize;
375376

376-
let total_size = length + 2;
377-
let msg_end = pos + total_size;
377+
let total_size = length + 2;
378+
let msg_end = pos + total_size;
378379

379-
if msg_end > data_len {
380-
return None;
381-
}
380+
if msg_end <= pos {
381+
self.position = pos + 1;
382+
continue;
383+
}
382384

383-
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
384-
{
385-
if msg_end + 64 <= data_len {
386-
unsafe {
387-
use std::arch::x86_64::*;
388-
_mm_prefetch(self.data.as_ptr().add(msg_end) as *const i8, _MM_HINT_T0);
385+
if msg_end > data_len {
386+
return None;
387+
}
388+
389+
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
390+
{
391+
if msg_end + 64 <= data_len {
392+
unsafe {
393+
use std::arch::x86_64::*;
394+
_mm_prefetch(self.data.as_ptr().add(msg_end) as *const i8, _MM_HINT_T0);
395+
}
389396
}
390397
}
391-
}
392398

393-
let msg_type = self.data[pos + 2];
394-
let header_start = pos + 3;
395-
let header_end = header_start + 10;
399+
let msg_type = self.data[pos + 2];
396400

397-
if header_end > data_len {
398-
return None;
399-
}
401+
if !crate::simd::is_valid_message_type(msg_type) {
402+
self.position = msg_end;
403+
continue;
404+
}
400405

401-
let payload_start = header_end;
406+
let header_start = pos + 3;
407+
let header_end = header_start + 10;
408+
409+
if header_end > msg_end {
410+
self.position = msg_end;
411+
continue;
412+
}
413+
414+
if header_end > data_len {
415+
return None;
416+
}
417+
418+
let payload_start = header_end;
419+
420+
if let Ok((hdr_ref, _)) =
421+
Ref::<&[u8], MessageHeaderRaw>::from_prefix(&self.data[header_start..header_end])
422+
{
423+
let payload = &self.data[payload_start..msg_end];
424+
self.position = msg_end;
425+
return Some(ZeroCopyMessage::new(msg_type, hdr_ref, payload));
426+
}
402427

403-
if let Ok((hdr_ref, _)) =
404-
Ref::<&[u8], MessageHeaderRaw>::from_prefix(&self.data[header_start..header_end])
405-
{
406-
let payload = &self.data[payload_start..msg_end];
407428
self.position = msg_end;
408-
return Some(ZeroCopyMessage::new(msg_type, hdr_ref, payload));
429+
continue;
409430
}
410-
411-
None
412431
}
413432

414433
#[inline]
@@ -741,14 +760,14 @@ mod tests {
741760
let mut data = Vec::new();
742761
for _ in 0..2 {
743762
data.extend(&[0u8, 11u8]);
744-
data.push(1u8);
763+
data.push(b'S');
745764
data.extend(&[0u8; 10]);
746765
}
747766

748767
let mut parser = ZeroCopyParser::new(&data);
749768
let arcs = parser.parse_all_arc();
750769
assert_eq!(arcs.len(), 2);
751-
assert_eq!(arcs[0].msg_type, 1u8);
770+
assert_eq!(arcs[0].msg_type, b'S');
752771
assert!(Arc::strong_count(&arcs[0].payload) >= 1);
753772
}
754773
}

tests/fixtures.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pub fn create_uniform_buffer(count: usize, msg_type: u8, payload_len: usize) ->
1616
create_test_buffer(&messages)
1717
}
1818

19+
#[allow(dead_code)]
1920
pub fn standard_fixture() -> Vec<u8> {
2021
let mut messages = Vec::new();
2122
for i in 0..1000u16 {

tests/integration.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,35 @@ fn test_large_buffer_no_data_loss() {
197197
"All messages should be parsed without data loss"
198198
);
199199
}
200+
201+
#[test]
202+
fn test_skip_invalid_message_type() {
203+
let messages = vec![
204+
(0xFF, &[][..]), // invalid
205+
(b'S', &[0; 10][..]), // valid
206+
];
207+
let buf = fixtures::create_test_buffer(&messages);
208+
let mut parser = ZeroCopyParser::new(&buf);
209+
let msg = parser.parse_next().unwrap();
210+
assert_eq!(msg.msg_type(), b'S');
211+
assert_eq!(parser.position(), buf.len());
212+
assert!(parser.parse_next().is_none());
213+
}
214+
215+
#[test]
216+
fn test_skip_invalid_with_while_loop() {
217+
let messages = vec![(0xFF, &[][..]), (b'S', &[0; 10][..]), (b'R', &[1; 5][..])];
218+
let buf = fixtures::create_test_buffer(&messages);
219+
let mut parser = ZeroCopyParser::new(&buf);
220+
let mut count = 0;
221+
while let Some(msg) = parser.parse_next() {
222+
count += 1;
223+
if count == 1 {
224+
assert_eq!(msg.msg_type(), b'S');
225+
} else {
226+
panic!("Unexpected extra message with type {}", msg.msg_type());
227+
}
228+
}
229+
assert_eq!(count, 1);
230+
assert_eq!(parser.position(), buf.len());
231+
}

tests/mmap_shared.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ fn test_mmap_shared_and_owned_messages() {
99
let mut buf = Vec::new();
1010
for _ in 0..2 {
1111
buf.extend(&[0, 11]);
12-
buf.push(1u8);
12+
buf.push(b'S');
1313
buf.extend(&[0u8; 10]);
1414
}
1515

@@ -23,5 +23,5 @@ fn test_mmap_shared_and_owned_messages() {
2323
assert_eq!(owned.len(), 2);
2424

2525
drop(shared);
26-
assert_eq!(owned[0].msg_type, 1u8);
26+
assert_eq!(owned[0].msg_type, b'S');
2727
}

tests/zerocopy_validation.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
mod fixtures;
2+
3+
use lunary::ZeroCopyParser;
4+
5+
use fixtures::create_test_buffer;
6+
7+
#[test]
8+
fn test_skip_invalid_message_type() {
9+
let messages = vec![
10+
(0xFF, &[][..]), // invalid
11+
(b'S', &[0; 10][..]), // valid
12+
];
13+
let buf = create_test_buffer(&messages);
14+
let mut parser = ZeroCopyParser::new(&buf);
15+
let msg = parser.parse_next().unwrap();
16+
assert_eq!(msg.msg_type(), b'S');
17+
assert_eq!(parser.position(), buf.len());
18+
assert!(parser.parse_next().is_none());
19+
}
20+
21+
#[test]
22+
fn test_skip_invalid_with_while_loop() {
23+
let messages = vec![(0xFF, &[][..]), (b'S', &[0; 10][..]), (b'R', &[1; 5][..])];
24+
let buf = create_test_buffer(&messages);
25+
let mut parser = ZeroCopyParser::new(&buf);
26+
let mut count = 0;
27+
while let Some(msg) = parser.parse_next() {
28+
count += 1;
29+
if count == 1 {
30+
assert_eq!(msg.msg_type(), b'S');
31+
} else {
32+
panic!("Unexpected extra message with type {}", msg.msg_type());
33+
}
34+
}
35+
assert_eq!(count, 1);
36+
assert_eq!(parser.position(), buf.len());
37+
}

0 commit comments

Comments
 (0)