Skip to content

Commit cd6dee1

Browse files
committed
feat: validation in parsers and concurrent components
1 parent bebd269 commit cd6dee1

6 files changed

Lines changed: 42 additions & 10 deletions

File tree

src/bench/parallel.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ pub fn bench_worksteal(data: &[u8]) -> Result<(u64, f64, f64)> {
5151
let chunks = split_into_chunks(data, chunk_size);
5252

5353
for (start, end) in &chunks {
54-
parser.submit_arc(Arc::clone(&data_arc), *start, *end);
54+
parser
55+
.submit_arc(Arc::clone(&data_arc), *start, *end)
56+
.unwrap();
5557
}
5658

5759
let mut total_messages = 0;

src/concurrent.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ impl ConcurrentParser for SpscParser {
309309
if let Ok(msgs) = self.output_receiver.try_recv() {
310310
return Some(msgs);
311311
}
312-
if self.shutdown.load(Ordering::Relaxed) {
312+
if self.shutdown.load(Ordering::Acquire) {
313313
return None;
314314
}
315315
let (lock, cvar) = &*self.has_data;
@@ -1062,7 +1062,8 @@ impl AdaptiveBatchProcessor {
10621062
}
10631063
let n = self.throughput_count as f64;
10641064
let mean = self.throughput_sum / n;
1065-
(self.throughput_sum_sq / n) - (mean * mean)
1065+
let variance = (self.throughput_sum_sq / n) - (mean * mean);
1066+
variance.max(0.0)
10661067
}
10671068

10681069
#[inline]
@@ -1714,11 +1715,17 @@ impl WorkStealingParser {
17141715
self.injector.push(WorkUnit::Owned(data));
17151716
}
17161717

1717-
pub fn submit_arc(&self, data: Arc<[u8]>, start: usize, end: usize) {
1718+
pub fn submit_arc(&self, data: Arc<[u8]>, start: usize, end: usize) -> Result<()> {
17181719
if start >= end || end > data.len() {
1719-
return;
1720+
return Err(ParseError::InvalidArgument(format!(
1721+
"invalid range {}..{} for data of length {}",
1722+
start,
1723+
end,
1724+
data.len()
1725+
)));
17201726
}
17211727
self.injector.push(WorkUnit::ArcSlice(data, start, end));
1728+
Ok(())
17221729
}
17231730

17241731
pub fn submit_chunks(&self, data: Arc<[u8]>, chunk_size: usize) -> usize {

src/config.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ pub struct Config {
33
pub max_buffer_size: usize,
44
pub max_message_size: usize,
55
pub initial_capacity: usize,
6+
pub strict_validation: bool,
67
}
78

89
impl Default for Config {
@@ -17,6 +18,7 @@ impl Config {
1718
max_buffer_size: 1024 * 1024 * 1024, // 1 GB
1819
max_message_size: 64 * 1024, // 64 KB
1920
initial_capacity: 2 * 1024 * 1024, // 2 MB
21+
strict_validation: false,
2022
}
2123
}
2224

@@ -40,6 +42,11 @@ impl Config {
4042
self
4143
}
4244

45+
pub const fn with_strict_validation(mut self, strict: bool) -> Self {
46+
self.strict_validation = strict;
47+
self
48+
}
49+
4350
pub fn from_size_mb(size_mb: usize) -> Self {
4451
Self::new().with_max_buffer_size_mb(size_mb)
4552
}

src/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ pub enum ParseError {
3232
#[error("Invalid UTF-8 in field '{field}'")]
3333
InvalidUtf8 { field: &'static str },
3434

35+
#[error("Invalid timestamp: {value} nanoseconds (must be <= 86,400,000,000,000)")]
36+
InvalidTimestamp { value: u64 },
37+
3538
#[error("Parser state error: {reason}")]
3639
StateError { reason: &'static str },
3740

@@ -59,6 +62,7 @@ impl ParseError {
5962
| ParseError::InvalidHeader { .. }
6063
| ParseError::LengthMismatch { .. }
6164
| ParseError::InvalidUtf8 { .. }
65+
| ParseError::InvalidTimestamp { .. }
6266
)
6367
}
6468

src/parser.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,18 @@ impl Parser {
212212
}
213213

214214
let message_type = remaining[2];
215-
let expected = EXPECTED_LENGTHS[message_type as usize] as usize;
216-
if expected != 0 && expected != length {
215+
let expected = EXPECTED_LENGTHS[message_type as usize];
216+
if expected == NO_VALIDATION {
217+
#[cfg(debug_assertions)]
218+
eprintln!(
219+
"Warning: no expected length for message type 0x{:02X}",
220+
message_type
221+
);
222+
} else if expected as usize != length {
217223
return Err(ParseError::LengthMismatch {
218224
msg_type: message_type,
219225
declared: length,
220-
expected,
226+
expected: expected as usize,
221227
});
222228
}
223229

@@ -486,6 +492,10 @@ impl Parser {
486492
data[*pos + 5],
487493
]);
488494
*pos += 6;
495+
const MAX_VALID_TIMESTAMP: u64 = 86_400_000_000_000;
496+
if self.config.strict_validation && value > MAX_VALID_TIMESTAMP {
497+
return Err(ParseError::InvalidTimestamp { value });
498+
}
489499
Ok(value)
490500
}
491501

@@ -917,8 +927,10 @@ impl Parser {
917927
}
918928
}
919929

930+
const NO_VALIDATION: u16 = u16::MAX;
931+
920932
const EXPECTED_LENGTHS: [u16; 256] = {
921-
let mut arr = [0u16; 256];
933+
let mut arr = [NO_VALIDATION; 256];
922934
arr[b'S' as usize] = 12;
923935
arr[b'R' as usize] = 39;
924936
arr[b'H' as usize] = 25;

src/simd.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,7 @@ unsafe fn memcpy_nontemporal_avx2(dst: *mut u8, src: *const u8, len: usize) {
934934
///
935935
/// Caller must ensure:
936936
/// - `src` is valid for reads of `len` bytes
937-
/// - `dst` is valid for writes of `len` bytes and is 16-byte aligned
937+
/// - `dst` is valid for writes of `len` bytes
938938
/// - the regions do not overlap
939939
pub unsafe fn memcpy_nontemporal(dst: *mut u8, src: *const u8, len: usize) {
940940
unsafe {

0 commit comments

Comments
 (0)