diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs index 054c335c..2524e552 100644 --- a/rust/crates/api/src/providers/mod.rs +++ b/rust/crates/api/src/providers/mod.rs @@ -296,9 +296,6 @@ pub fn metadata_for_model(model: &str) -> Option { None } - - - #[must_use] pub fn strip_provider_prefix(canonical_model: &str) -> String { if let Some(pos) = canonical_model.find('/') { @@ -308,8 +305,6 @@ pub fn strip_provider_prefix(canonical_model: &str) -> String { } } - - #[must_use] pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics { let resolved_model = resolve_model_alias(model); diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index c378b585..09fe09c2 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -16,8 +16,7 @@ use crate::types::{ ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, }; -use super::{preflight_message_request, Provider, ProviderFuture, resolve_model_alias, strip_provider_prefix}; - +use super::{preflight_message_request, resolve_model_alias, Provider, ProviderFuture}; pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; @@ -213,80 +212,22 @@ impl OpenAiCompatClient { } pub async fn send_message( - &self, - request: &MessageRequest, -) -> Result { - // 1. Keep track of what Claw originally asked for - let original_model = request.model.clone(); - let canonical = resolve_model_alias(&request.model); - - // 2. Clean the model string (e.g., "openai/deepseek-v4-flash" -> "deepseek-v4-flash") - let downstream_model = strip_provider_prefix(&canonical); + &self, + request: &MessageRequest, + ) -> Result { + let original_model = request.model.clone(); + let canonical = resolve_model_alias(&request.model); - let mut request = MessageRequest { - stream: false, - ..request.clone() - }; - request.model = downstream_model; // Use the clean name for the API payload - - preflight_message_request(&request)?; - let response = self.send_with_retry(&request).await?; - let request_id = request_id_from_headers(response.headers()); - let body = response.text().await.map_err(ApiError::from)?; + let mut request = MessageRequest { + stream: false, + ..request.clone() + }; + request.model = canonical; - // Some backends return {"error":{"message":"...","type":"...","code":...}} - // instead of a valid completion object. Check for this before attempting - // full deserialization so the user sees the actual error, not a cryptic. - if let Ok(raw) = serde_json::from_str::(&body) { - if let Some(err_obj) = raw.get("error") { - let msg = err_obj - .get("message") - .and_then(|m| m.as_str()) - .unwrap_or("provider returned an error") - .to_string(); - let code = err_obj - .get("code") - .and_then(serde_json::Value::as_u64) - .map(|c| c as u16); - return Err(ApiError::Api { - status: reqwest::StatusCode::from_u16(code.unwrap_or(400)) - .unwrap_or(reqwest::StatusCode::BAD_REQUEST), - error_type: err_obj - .get("type") - .and_then(|t| t.as_str()) - .map(str::to_owned), - message: Some(msg), - request_id, - body, - retryable: false, - suggested_action: suggested_action_for_status( - reqwest::StatusCode::from_u16(code.unwrap_or(400)) - .unwrap_or(reqwest::StatusCode::BAD_REQUEST), - ), - retry_after: None, - }); - } - } - - // Pass original_model to the deserializer error context so debugging logs are accurate - let payload = serde_json::from_str::(&body).map_err(|error| { - ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error) - })?; - - let mut normalized = normalize_response(&request.model, payload)?; - if normalized.request_id.is_none() { - normalized.request_id = request_id; - } - - // 3. CRITICAL: Put the original model string back so Claw's internal routing stays happy - normalized.model = original_model; - - Ok(normalized) -} - // Some backends return {"error":{"message":"...","type":"...","code":...}} - // instead of a valid completion object. Check for this before attempting - // full deserialization so the user sees the actual error, not a cryptic - // "missing field 'id'" parse failure. + preflight_message_request(&request)?; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let body = response.text().await.map_err(ApiError::from)?; if let Ok(raw) = serde_json::from_str::(&body) { if let Some(err_obj) = raw.get("error") { let msg = err_obj @@ -318,41 +259,41 @@ impl OpenAiCompatClient { } } let payload = serde_json::from_str::(&body).map_err(|error| { - ApiError::json_deserialize(self.config.provider_name, &request.model, &body, error) + ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error) })?; let mut normalized = normalize_response(&request.model, payload)?; if normalized.request_id.is_none() { normalized.request_id = request_id; } + normalized.model = original_model; Ok(normalized) } -pub async fn stream_message( - &self, - request: &MessageRequest, -) -> Result { - // 1. Keep track of the original model name - let original_model = request.model.clone(); - let canonical = resolve_model_alias(&request.model); - - // 2. Clean it up for DeepSeek - let downstream_model = strip_provider_prefix(&canonical); + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + let original_model = request.model.clone(); + let canonical = resolve_model_alias(&request.model); - let mut streaming_request = request.clone().with_streaming(); - streaming_request.model = downstream_model; + let mut streaming_request = request.clone().with_streaming(); + streaming_request.model = canonical; - preflight_message_request(&streaming_request)?; - let response = self.send_with_retry(&streaming_request).await?; + preflight_message_request(&streaming_request)?; + let response = self.send_with_retry(&streaming_request).await?; - Ok(MessageStream { - request_id: request_id_from_headers(response.headers()), - response, - parser: OpenAiSseParser::with_context(self.config.provider_name, original_model.clone()), - pending: VecDeque::new(), - done: false, - state: StreamState::new(original_model), // 3. Use the original name here - }) -} + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: OpenAiSseParser::with_context( + self.config.provider_name, + original_model.clone(), + ), + pending: VecDeque::new(), + done: false, + state: StreamState::new(original_model), + }) + } async fn send_with_retry( &self, diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs index 4521ebed..e6edb791 100644 --- a/rust/crates/api/tests/openai_compat_integration.rs +++ b/rust/crates/api/tests/openai_compat_integration.rs @@ -548,12 +548,13 @@ async fn openai_compatible_client_honors_http_proxy_for_requests() { .with_base_url("http://origin.invalid/v1"); let response = client .send_message(&MessageRequest { - model: "gpt-4o".to_string(), + model: "openai/gpt-4.1-mini".to_string(), ..sample_request(false) }) .await .expect("proxy should return the OpenAI-compatible response"); + assert_eq!(response.model, "openai/gpt-4.1-mini"); assert_eq!(response.total_tokens(), 7); let captured = state.lock().await; let request = captured.first().expect("proxy should capture request"); @@ -562,6 +563,8 @@ async fn openai_compatible_client_honors_http_proxy_for_requests() { request.headers.get("authorization").map(String::as_str), Some("Bearer openai-test-key") ); + let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body"); + assert_eq!(body["model"], json!("openai/gpt-4.1-mini")); } #[allow(clippy::await_holding_lock)] diff --git a/rust/crates/runtime/src/session_control.rs b/rust/crates/runtime/src/session_control.rs index 90120eff..4a789a89 100644 --- a/rust/crates/runtime/src/session_control.rs +++ b/rust/crates/runtime/src/session_control.rs @@ -832,6 +832,28 @@ mod tests { static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0); + struct EnvVarGuard { + key: &'static str, + previous: Option, + } + + impl EnvVarGuard { + fn set(key: &'static str, value: &Path) -> Self { + let previous = std::env::var_os(key); + std::env::set_var(key, value); + Self { key, previous } + } + } + + impl Drop for EnvVarGuard { + fn drop(&mut self) { + match &self.previous { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } + } + fn temp_dir() -> PathBuf { let nanos = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -1290,8 +1312,11 @@ mod tests { #[test] fn latest_session_returns_all_empty_error_when_sessions_exist_but_have_no_messages() { // given — create sessions with 0 messages (empty) + let _env_guard = crate::test_env_lock(); let base = temp_dir(); fs::create_dir_all(&base).expect("base dir should exist"); + let isolated_config_home = base.join("config-home"); + let _claw_config_home = EnvVarGuard::set("CLAW_CONFIG_HOME", &isolated_config_home); let store = SessionStore::from_cwd(&base).expect("store should build"); let empty_handle = store.create_handle("empty-session"); diff --git a/rust/crates/runtime/src/worker_boot.rs b/rust/crates/runtime/src/worker_boot.rs index c1193db6..6b4fbfbf 100644 --- a/rust/crates/runtime/src/worker_boot.rs +++ b/rust/crates/runtime/src/worker_boot.rs @@ -1644,16 +1644,13 @@ mod tests { let tmp = tempfile::tempdir().expect("tempdir"); let worktree = tmp.path().join("worktree"); - let git_dir = tmp.path().join("external-gitdir"); fs::create_dir_all(&worktree).expect("worktree dir"); - fs::create_dir_all(git_dir.join("objects")).expect("objects dir"); - fs::create_dir_all(git_dir.join("refs/heads")).expect("refs dir"); - fs::write(git_dir.join("HEAD"), "ref: refs/heads/main\n").expect("HEAD"); - fs::write( - worktree.join(".git"), - format!("gitdir: {}\n", git_dir.display()), - ) - .expect(".git file"); + Command::new("git") + .arg("init") + .current_dir(&worktree) + .output() + .expect("git init should run"); + let git_dir = worktree.join(".git"); let original_permissions = fs::metadata(&git_dir) .expect("gitdir metadata") diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 38974eb5..543ab88a 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -13737,8 +13737,15 @@ fn push_output_block( }; *pending_tool = Some((id, name, initial_input)); } - OutputContentBlock::Thinking { thinking, .. } => { + OutputContentBlock::Thinking { + thinking, + signature, + } => { render_thinking_block_summary(out, Some(thinking.chars().count()), false)?; + events.push(AssistantEvent::Thinking { + thinking, + signature, + }); *block_has_thinking_summary = true; } OutputContentBlock::RedactedThinking { .. } => { @@ -19073,6 +19080,13 @@ UU conflicted.rs", assert!(matches!( &events[0], + AssistantEvent::Thinking { + thinking, + signature + } if thinking == "step 1" && signature.as_deref() == Some("sig_123") + )); + assert!(matches!( + &events[1], AssistantEvent::TextDelta(text) if text == "Final answer" )); let rendered = String::from_utf8(out).expect("utf8"); @@ -19649,6 +19663,41 @@ mod dump_manifests_tests { #[cfg(test)] mod alias_resolution_tests { + fn ollama_env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: std::sync::OnceLock> = std::sync::OnceLock::new(); + LOCK.get_or_init(|| std::sync::Mutex::new(())) + .lock() + .expect("ollama env lock poisoned") + } + + struct EnvVarGuard { + key: &'static str, + previous: Option, + } + + impl EnvVarGuard { + fn unset(key: &'static str) -> Self { + let previous = std::env::var(key).ok(); + std::env::remove_var(key); + Self { key, previous } + } + + fn set(key: &'static str, value: &str) -> Self { + let previous = std::env::var(key).ok(); + std::env::set_var(key, value); + Self { key, previous } + } + } + + impl Drop for EnvVarGuard { + fn drop(&mut self) { + match &self.previous { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } + } + use super::{resolve_model_alias_with_config, validate_model_syntax}; #[test] @@ -19670,6 +19719,8 @@ mod alias_resolution_tests { #[test] fn test_alias_resolution_syntax_validation() { + let _guard = ollama_env_lock(); + let _env = EnvVarGuard::unset("OLLAMA_HOST"); // Resolved aliases should pass syntax validation let resolved = resolve_model_alias_with_config("opus"); assert!(validate_model_syntax(&resolved).is_ok()); @@ -19680,6 +19731,8 @@ mod alias_resolution_tests { #[test] fn test_unknown_alias_fails_validation() { + let _guard = ollama_env_lock(); + let _env = EnvVarGuard::unset("OLLAMA_HOST"); // Unknown aliases resolve to themselves let resolved = resolve_model_alias_with_config("unknown-alias"); assert_eq!(resolved, "unknown-alias"); @@ -19699,14 +19752,13 @@ mod alias_resolution_tests { } #[test] fn test_ollama_host_bypasses_model_validation() { - // Safety: test sets and clears env var within the test. - std::env::set_var("OLLAMA_HOST", "http://127.0.0.1:11434"); + let _guard = ollama_env_lock(); + let _env = EnvVarGuard::set("OLLAMA_HOST", "http://127.0.0.1:11434"); // Ollama model names with colons pass assert!(validate_model_syntax("qwen3:8b").is_ok()); assert!(validate_model_syntax("gemma4:e2b").is_ok()); assert!(validate_model_syntax("qwen3.6:27b-nvfp4").is_ok()); // Empty model still rejected assert!(validate_model_syntax("").is_err()); - std::env::remove_var("OLLAMA_HOST"); } }