Skip to content

Commit 4073878

Browse files
committed
Validate dimensions in VectorIndex
1 parent 63393c8 commit 4073878

4 files changed

Lines changed: 39 additions & 26 deletions

File tree

benches/index_performance.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::collections::HashMap;
55

66
fn bench_vector_index(c: &mut Criterion) {
77
// Setup test vectors
8-
let dimensions = 128;
8+
let dimensions = 60;
99
let vector_count = 10000;
1010
let test_vectors: Vec<Vector> = (0..vector_count)
1111
.map(|_| Vector::random_normal(dimensions, 0.0, 1.0))
@@ -36,7 +36,7 @@ fn bench_vector_index(c: &mut Criterion) {
3636
b.iter_with_setup(
3737
|| {
3838
// Create a new index for each iteration
39-
let index = VectorIndex::new("bench_index", dimensions, metric, None);
39+
let index = VectorIndex::new("bench_index", dimensions, metric, None).unwrap();
4040
(index, test_vectors.clone())
4141
},
4242
|(index, vectors)| {
@@ -66,7 +66,7 @@ fn bench_vector_index(c: &mut Criterion) {
6666
|b, _| {
6767
// Setup the index with test vectors
6868
let rt = tokio::runtime::Runtime::new().unwrap();
69-
let index = VectorIndex::new("bench_index", dimensions, metric, None);
69+
let index = VectorIndex::new("bench_index", dimensions, metric, None).unwrap();
7070

7171
rt.block_on(async {
7272
for vector in &test_vectors {
@@ -97,7 +97,7 @@ fn bench_vector_index(c: &mut Criterion) {
9797
|b, _| {
9898
// Setup the index with test vectors
9999
let rt = tokio::runtime::Runtime::new().unwrap();
100-
let index = VectorIndex::new("bench_index", dimensions, DistanceMetric::Cosine, None);
100+
let index = VectorIndex::new("bench_index", dimensions, DistanceMetric::Cosine, None).unwrap();
101101

102102
rt.block_on(async {
103103
for vector in &test_vectors {

src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ async fn main() -> anyhow::Result<()> {
7575
let shard_id = shard_manager.create_shard("demo_shard").await?;
7676

7777
// Create a vector index
78-
let dimensions = 128;
78+
let dimensions = 60;
7979
let index = shard_manager
8080
.create_vector_index(shard_id, "demo_index", dimensions, DistanceMetric::Cosine)
8181
.await?;

src/sharding/manager.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,13 @@ impl ShardManager {
118118

119119
// Create the index
120120
let index = VectorIndex::new(
121-
name,
122-
dimensions,
121+
name,
122+
dimensions,
123123
distance_metric,
124124
Some(self.metrics.clone()),
125-
);
126-
125+
)
126+
.map_err(|e| anyhow!("Failed to create vector index: {}", e))?;
127+
127128
let index = Arc::new(index);
128129

129130
// Store the index

src/sharding/vector_index.rs

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -97,30 +97,42 @@ pub struct VectorIndex {
9797
impl VectorIndex {
9898
/// Create a new vector index
9999
pub fn new(
100-
name: &str,
101-
dimensions: usize,
100+
name: &str,
101+
dimensions: usize,
102102
distance_metric: DistanceMetric,
103103
metrics: Option<Arc<MetricsCollector>>,
104-
) -> Self {
105-
// Determine bits per dimension based on dimensions
106-
// We want to keep the total bits under 64 (for u64 hilbert index)
107-
let max_total_bits = 60; // Leave some room for safety
108-
let bits_per_dimension = std::cmp::min(
109-
10, // Maximum reasonable value
110-
max_total_bits / dimensions
111-
);
112-
104+
) -> Result<Self, String> {
105+
let max_total_bits = 60;
106+
107+
if dimensions == 0 {
108+
return Err("dimensions must be greater than zero".to_string());
109+
}
110+
if dimensions > max_total_bits {
111+
return Err(format!(
112+
"dimensions ({}) exceed maximum supported {}",
113+
dimensions, max_total_bits
114+
));
115+
}
116+
117+
let bits_per_dimension = std::cmp::min(10, max_total_bits / dimensions);
118+
if bits_per_dimension == 0 {
119+
return Err(format!(
120+
"calculated bits_per_dimension is zero for {} dimensions",
121+
dimensions
122+
));
123+
}
124+
113125
let hilbert_curve = HilbertCurve::new(dimensions, bits_per_dimension);
114-
115-
Self {
126+
127+
Ok(Self {
116128
name: name.to_string(),
117129
vectors: RwLock::new(HashMap::new()),
118130
hilbert_curve,
119131
hilbert_map: RwLock::new(HashMap::new()),
120132
dimensions,
121133
distance_metric,
122134
metrics,
123-
}
135+
})
124136
}
125137

126138
/// Convert a vector to a Hilbert index
@@ -432,7 +444,7 @@ mod tests {
432444
use rand::Rng;
433445

434446
async fn create_test_index(vector_count: usize, dimensions: usize) -> VectorIndex {
435-
let index = VectorIndex::new("test_index", dimensions, DistanceMetric::Euclidean, None);
447+
let index = VectorIndex::new("test_index", dimensions, DistanceMetric::Euclidean, None).unwrap();
436448

437449
// Add random vectors
438450
for _ in 0..vector_count {
@@ -454,7 +466,7 @@ mod tests {
454466

455467
#[tokio::test]
456468
async fn test_remove() {
457-
let index = VectorIndex::new("test_remove", 3, DistanceMetric::Euclidean, None);
469+
let index = VectorIndex::new("test_remove", 3, DistanceMetric::Euclidean, None).unwrap();
458470

459471
// Add a vector
460472
let vector = Vector::random(3);
@@ -512,7 +524,7 @@ mod tests {
512524
// Test each metric
513525
for metric in [DistanceMetric::Euclidean, DistanceMetric::Cosine,
514526
DistanceMetric::Manhattan, DistanceMetric::Hamming].iter() {
515-
let index = VectorIndex::new("test_metric", dimensions, *metric, None);
527+
let index = VectorIndex::new("test_metric", dimensions, *metric, None).unwrap();
516528

517529
// Add all vectors
518530
for v in &vectors {

0 commit comments

Comments
 (0)