Skip to content

Commit 335f34a

Browse files
Merge pull request #110 from Shreyas582/v1.3.0-provider-agnostic-config
feat: provider-agnostic ModelConfig (#49, #50, #51)
2 parents 28c9ca7 + 95e2a21 commit 335f34a

7 files changed

Lines changed: 127 additions & 25 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ The format is inspired by Keep a Changelog and this project follows Semantic Ver
1212

1313
### Changed
1414

15-
- (none yet)
15+
- **Provider-agnostic `ModelConfig`** (#49): replaced `vitis_config: Option<VitisEpConfig>` with generic `backend_override: Option<String>` and `backend_config: HashMap<String, String>`. `VitisEpConfig` is retained as a CLI-level helper with `into_backend_config()` / `from_backend_config()` conversion methods.
16+
- **Vitis EP reads via `backend_config` map** (#50): `onnx_vitis` functions (`discover_ort_dylib_path`, `build_base_session_builder_with_provider`, `build_session_with_vitis_cascade`) now read config values from the generic `backend_config` map instead of the Vitis-specific struct.
17+
- **CPU EP unblocked by config refactor** (#51): `CpuBackend` and all non-Vitis callers now use `backend_override: None, backend_config: Default::default()`, removing any coupling to Vitis types.
1618

1719
### Fixed
1820

api_server/src/routes.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,8 @@ async fn run_investigation(task: &str, max_steps: usize) -> anyhow::Result<core_
353353
max_new_tokens: 256,
354354
temperature: 0.2,
355355
dry_run: true,
356-
vitis_config: None,
356+
backend_override: None,
357+
backend_config: Default::default(),
357358
};
358359
let engine = OnnxVitisEngine::new(model_config);
359360
let tools = ToolRegistry::with_default_tools();

cli/src/main.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4405,14 +4405,20 @@ fn run_model_pack_doctor_checks(runtime: &RuntimeConfig, report: &mut DoctorRepo
44054405
return;
44064406
}
44074407

4408+
let vitis_cfg = build_vitis_config(runtime);
4409+
let (bo, bc) = match vitis_cfg {
4410+
Some(cfg) => (Some("vitis".to_string()), cfg.into_backend_config()),
4411+
None => (None, Default::default()),
4412+
};
44084413
let compatibility = inspect_runtime_compatibility(
44094414
&ModelConfig {
44104415
model_path: runtime.model.clone(),
44114416
tokenizer_path: runtime.tokenizer.clone(),
44124417
max_new_tokens: 1,
44134418
temperature: runtime.temperature,
44144419
dry_run: false,
4415-
vitis_config: build_vitis_config(runtime),
4420+
backend_override: bo,
4421+
backend_config: bc,
44164422
},
44174423
true,
44184424
);
@@ -5734,13 +5740,18 @@ async fn run_agent_once(runtime: &RuntimeConfig, dry_run: bool) -> Result<RunRep
57345740
}
57355741

57365742
let vitis_config = build_vitis_config(runtime);
5743+
let (backend_override, backend_config) = match vitis_config {
5744+
Some(cfg) => (Some("vitis".to_string()), cfg.into_backend_config()),
5745+
None => (None, Default::default()),
5746+
};
57375747
let model_config = ModelConfig {
57385748
model_path: runtime.model.clone(),
57395749
tokenizer_path: runtime.tokenizer.clone(),
57405750
max_new_tokens: runtime.max_new_tokens,
57415751
temperature: runtime.temperature,
57425752
dry_run,
5743-
vitis_config,
5753+
backend_override,
5754+
backend_config,
57445755
};
57455756

57465757
// Determine capability tier: override > probe > default.

docs/upgrades.md

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,58 @@
11
# Upgrade Notes
22

3+
## Unreleased
4+
5+
### Breaking/visible changes
6+
7+
- **`ModelConfig` struct changed** (#49): the `vitis_config: Option<VitisEpConfig>` field is replaced by two new fields:
8+
- `backend_override: Option<String>` — optional explicit backend name (e.g. `"cpu"`, `"vitis"`, `"cuda"`)
9+
- `backend_config: HashMap<String, String>` — generic key-value config map
10+
11+
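Both fields are marked `#[serde(default)]` (`backend_override` defaults to `None`, `backend_config` to an empty map), so TOML/JSON that omits them still deserializes. Note that a leftover `vitis_config` key is not migrated: serde ignores unknown keys by default, so its values are silently dropped — move them into `backend_config` as shown below.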
12+
13+
- **`VitisEpConfig` is still available** as a helper. Use `into_backend_config()` to convert to the new map and `from_backend_config()` to reconstruct from one.
14+
15+
### Migration
16+
17+
Before:
18+
```rust
19+
let config = ModelConfig {
20+
// ...
21+
vitis_config: Some(VitisEpConfig {
22+
config_file: Some("/path/to/vitis.json".into()),
23+
cache_dir: None,
24+
cache_key: None,
25+
}),
26+
};
27+
```
28+
29+
After:
30+
```rust
31+
use std::collections::HashMap;
32+
33+
let config = ModelConfig {
34+
// ...
35+
backend_override: Some("vitis".to_string()),
36+
backend_config: HashMap::from([
37+
("config_file".to_string(), "/path/to/vitis.json".to_string()),
38+
]),
39+
};
40+
```
41+
42+
Or using the helper:
43+
```rust
44+
let vitis = VitisEpConfig {
45+
config_file: Some("/path/to/vitis.json".into()),
46+
cache_dir: None,
47+
cache_key: None,
48+
};
49+
let config = ModelConfig {
50+
// ...
51+
backend_override: Some("vitis".to_string()),
52+
backend_config: vitis.into_backend_config(),
53+
};
54+
```
55+
356
## v1.3.0
457

558
### Breaking/visible changes

inference_bridge/src/backend.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,8 @@ mod tests {
482482
max_new_tokens: 1,
483483
temperature: 0.0,
484484
dry_run: true,
485-
vitis_config: None,
485+
backend_override: None,
486+
backend_config: Default::default(),
486487
};
487488
let cpu = CpuBackend;
488489
let session = cpu.build_session(&config, &BackendOptions::new());
@@ -529,7 +530,8 @@ mod tests {
529530
max_new_tokens: 1,
530531
temperature: 0.0,
531532
dry_run: true,
532-
vitis_config: None,
533+
backend_override: None,
534+
backend_config: Default::default(),
533535
};
534536
let result =
535537
registry.build_session_with_fallback(&config, &BackendOptions::new(), None);
@@ -547,7 +549,8 @@ mod tests {
547549
max_new_tokens: 1,
548550
temperature: 0.0,
549551
dry_run: true,
550-
vitis_config: None,
552+
backend_override: None,
553+
backend_config: Default::default(),
551554
};
552555
let result =
553556
registry.build_session_with_fallback(&config, &BackendOptions::new(), Some("CPU"));

inference_bridge/src/lib.rs

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ fn estimate_params_from_file_size(model_path: &PathBuf) -> f32 {
6666

6767
/// Detect which execution provider would be used for this config.
6868
fn detect_execution_provider(config: &ModelConfig) -> String {
69-
if config.vitis_config.is_some() {
69+
if config.backend_override.as_deref() == Some("vitis")
70+
|| config.backend_config.contains_key("config_file")
71+
{
7072
"VitisAIExecutionProvider".to_string()
7173
} else if cfg!(feature = "onnx") {
7274
// Without Vitis config, ONNX Runtime defaults to CPU.
@@ -129,14 +131,45 @@ pub struct VitisEpConfig {
129131
pub cache_key: Option<String>,
130132
}
131133

134+
impl VitisEpConfig {
135+
/// Convert to a generic backend config map.
136+
pub fn into_backend_config(self) -> std::collections::HashMap<String, String> {
137+
let mut map = std::collections::HashMap::new();
138+
if let Some(v) = self.config_file {
139+
map.insert("config_file".to_string(), v);
140+
}
141+
if let Some(v) = self.cache_dir {
142+
map.insert("cache_dir".to_string(), v);
143+
}
144+
if let Some(v) = self.cache_key {
145+
map.insert("cache_key".to_string(), v);
146+
}
147+
map
148+
}
149+
150+
/// Reconstruct from a generic backend config map.
151+
pub fn from_backend_config(map: &std::collections::HashMap<String, String>) -> Self {
152+
Self {
153+
config_file: map.get("config_file").cloned(),
154+
cache_dir: map.get("cache_dir").cloned(),
155+
cache_key: map.get("cache_key").cloned(),
156+
}
157+
}
158+
}
159+
132160
#[derive(Debug, Clone, Serialize, Deserialize)]
133161
pub struct ModelConfig {
134162
pub model_path: PathBuf,
135163
pub tokenizer_path: Option<PathBuf>,
136164
pub max_new_tokens: usize,
137165
pub temperature: f32,
138166
pub dry_run: bool,
139-
pub vitis_config: Option<VitisEpConfig>,
167+
/// Explicit backend override (e.g., "cpu", "vitis", "cuda").
168+
#[serde(default)]
169+
pub backend_override: Option<String>,
170+
/// Provider-specific key-value configuration.
171+
#[serde(default)]
172+
pub backend_config: std::collections::HashMap<String, String>,
140173
}
141174

142175
#[async_trait]
@@ -356,7 +389,8 @@ mod tests {
356389
max_new_tokens: 16,
357390
temperature: 0.2,
358391
dry_run: true,
359-
vitis_config: None,
392+
backend_override: None,
393+
backend_config: Default::default(),
360394
})
361395
}
362396

inference_bridge/src/onnx_vitis.rs

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -450,9 +450,8 @@ fn discover_ort_dylib_path(config: &ModelConfig) -> Option<PathBuf> {
450450
let mut candidates = Vec::new();
451451

452452
if let Some(vitis_config_path) = config
453-
.vitis_config
454-
.as_ref()
455-
.and_then(|cfg| cfg.config_file.as_deref())
453+
.backend_config
454+
.get("config_file")
456455
{
457456
let vitis_config_path = PathBuf::from(vitis_config_path);
458457
if let Some(parent) = vitis_config_path.parent() {
@@ -1203,16 +1202,14 @@ fn build_base_session_builder_with_provider(
12031202
let mut vitis = ep::Vitis::default();
12041203

12051204
if use_vitis_provider {
1206-
if let Some(vitis_cfg) = &config.vitis_config {
1207-
if let Some(config_file) = &vitis_cfg.config_file {
1208-
vitis = vitis.with_config_file(config_file);
1209-
}
1210-
if let Some(cache_dir) = &vitis_cfg.cache_dir {
1211-
vitis = vitis.with_cache_dir(cache_dir);
1212-
}
1213-
if let Some(cache_key) = &vitis_cfg.cache_key {
1214-
vitis = vitis.with_cache_key(cache_key);
1215-
}
1205+
if let Some(config_file) = config.backend_config.get("config_file") {
1206+
vitis = vitis.with_config_file(config_file);
1207+
}
1208+
if let Some(cache_dir) = config.backend_config.get("cache_dir") {
1209+
vitis = vitis.with_cache_dir(cache_dir);
1210+
}
1211+
if let Some(cache_key) = config.backend_config.get("cache_key") {
1212+
vitis = vitis.with_cache_key(cache_key);
12161213
}
12171214
}
12181215

@@ -1254,7 +1251,7 @@ fn build_session_with_vitis_cascade(config: &ModelConfig) -> Result<Session> {
12541251
let force_cpu_provider = env_var_truthy("WRAITHRUN_FORCE_CPU_EP");
12551252
debug!(
12561253
model = %config.model_path.display(),
1257-
has_vitis_config = config.vitis_config.is_some(),
1254+
has_vitis_config = config.backend_config.contains_key("config_file"),
12581255
force_cpu_provider,
12591256
"building Vitis ONNX Runtime session"
12601257
);
@@ -2792,7 +2789,8 @@ mod tests {
27922789
max_new_tokens: 1,
27932790
temperature: 0.0,
27942791
dry_run: false,
2795-
vitis_config: None,
2792+
backend_override: None,
2793+
backend_config: Default::default(),
27962794
};
27972795

27982796
let report = inspect_runtime_compatibility(&config, true);

0 commit comments

Comments
 (0)