Merge pull request #3234 from ultraworkers/fix/openai-compatible-reasoning-history

fix(providers): preserve OpenAI-compatible reasoning history
This commit is contained in:
YeonGyu-Kim
2026-06-08 09:27:33 +09:00
committed by GitHub
6 changed files with 130 additions and 117 deletions
-5
View File
@@ -296,9 +296,6 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
None None
} }
#[must_use] #[must_use]
pub fn strip_provider_prefix(canonical_model: &str) -> String { pub fn strip_provider_prefix(canonical_model: &str) -> String {
if let Some(pos) = canonical_model.find('/') { if let Some(pos) = canonical_model.find('/') {
@@ -308,8 +305,6 @@ pub fn strip_provider_prefix(canonical_model: &str) -> String {
} }
} }
#[must_use] #[must_use]
pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics { pub fn provider_diagnostics_for_model(model: &str) -> ProviderDiagnostics {
let resolved_model = resolve_model_alias(model); let resolved_model = resolve_model_alias(model);
+39 -98
View File
@@ -16,8 +16,7 @@ use crate::types::{
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
}; };
use super::{preflight_message_request, Provider, ProviderFuture, resolve_model_alias, strip_provider_prefix}; use super::{preflight_message_request, resolve_model_alias, Provider, ProviderFuture};
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
@@ -213,80 +212,22 @@ impl OpenAiCompatClient {
} }
pub async fn send_message( pub async fn send_message(
&self, &self,
request: &MessageRequest, request: &MessageRequest,
) -> Result<MessageResponse, ApiError> { ) -> Result<MessageResponse, ApiError> {
// 1. Keep track of what Claw originally asked for let original_model = request.model.clone();
let original_model = request.model.clone(); let canonical = resolve_model_alias(&request.model);
let canonical = resolve_model_alias(&request.model);
// 2. Clean the model string (e.g., "openai/deepseek-v4-flash" -> "deepseek-v4-flash")
let downstream_model = strip_provider_prefix(&canonical);
let mut request = MessageRequest { let mut request = MessageRequest {
stream: false, stream: false,
..request.clone() ..request.clone()
}; };
request.model = downstream_model; // Use the clean name for the API payload request.model = canonical;
preflight_message_request(&request)?;
let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers());
let body = response.text().await.map_err(ApiError::from)?;
// Some backends return {"error":{"message":"...","type":"...","code":...}} preflight_message_request(&request)?;
// instead of a valid completion object. Check for this before attempting let response = self.send_with_retry(&request).await?;
// full deserialization so the user sees the actual error, not a cryptic. let request_id = request_id_from_headers(response.headers());
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) { let body = response.text().await.map_err(ApiError::from)?;
if let Some(err_obj) = raw.get("error") {
let msg = err_obj
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("provider returned an error")
.to_string();
let code = err_obj
.get("code")
.and_then(serde_json::Value::as_u64)
.map(|c| c as u16);
return Err(ApiError::Api {
status: reqwest::StatusCode::from_u16(code.unwrap_or(400))
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
error_type: err_obj
.get("type")
.and_then(|t| t.as_str())
.map(str::to_owned),
message: Some(msg),
request_id,
body,
retryable: false,
suggested_action: suggested_action_for_status(
reqwest::StatusCode::from_u16(code.unwrap_or(400))
.unwrap_or(reqwest::StatusCode::BAD_REQUEST),
),
retry_after: None,
});
}
}
// Pass original_model to the deserializer error context so debugging logs are accurate
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
})?;
let mut normalized = normalize_response(&request.model, payload)?;
if normalized.request_id.is_none() {
normalized.request_id = request_id;
}
// 3. CRITICAL: Put the original model string back so Claw's internal routing stays happy
normalized.model = original_model;
Ok(normalized)
}
// Some backends return {"error":{"message":"...","type":"...","code":...}}
// instead of a valid completion object. Check for this before attempting
// full deserialization so the user sees the actual error, not a cryptic
// "missing field 'id'" parse failure.
if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) { if let Ok(raw) = serde_json::from_str::<serde_json::Value>(&body) {
if let Some(err_obj) = raw.get("error") { if let Some(err_obj) = raw.get("error") {
let msg = err_obj let msg = err_obj
@@ -318,41 +259,41 @@ impl OpenAiCompatClient {
} }
} }
let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| { let payload = serde_json::from_str::<ChatCompletionResponse>(&body).map_err(|error| {
ApiError::json_deserialize(self.config.provider_name, &request.model, &body, error) ApiError::json_deserialize(self.config.provider_name, &original_model, &body, error)
})?; })?;
let mut normalized = normalize_response(&request.model, payload)?; let mut normalized = normalize_response(&request.model, payload)?;
if normalized.request_id.is_none() { if normalized.request_id.is_none() {
normalized.request_id = request_id; normalized.request_id = request_id;
} }
normalized.model = original_model;
Ok(normalized) Ok(normalized)
} }
pub async fn stream_message( pub async fn stream_message(
&self, &self,
request: &MessageRequest, request: &MessageRequest,
) -> Result<MessageStream, ApiError> { ) -> Result<MessageStream, ApiError> {
// 1. Keep track of the original model name let original_model = request.model.clone();
let original_model = request.model.clone(); let canonical = resolve_model_alias(&request.model);
let canonical = resolve_model_alias(&request.model);
// 2. Clean it up for DeepSeek
let downstream_model = strip_provider_prefix(&canonical);
let mut streaming_request = request.clone().with_streaming(); let mut streaming_request = request.clone().with_streaming();
streaming_request.model = downstream_model; streaming_request.model = canonical;
preflight_message_request(&streaming_request)?; preflight_message_request(&streaming_request)?;
let response = self.send_with_retry(&streaming_request).await?; let response = self.send_with_retry(&streaming_request).await?;
Ok(MessageStream { Ok(MessageStream {
request_id: request_id_from_headers(response.headers()), request_id: request_id_from_headers(response.headers()),
response, response,
parser: OpenAiSseParser::with_context(self.config.provider_name, original_model.clone()), parser: OpenAiSseParser::with_context(
pending: VecDeque::new(), self.config.provider_name,
done: false, original_model.clone(),
state: StreamState::new(original_model), // 3. Use the original name here ),
}) pending: VecDeque::new(),
} done: false,
state: StreamState::new(original_model),
})
}
async fn send_with_retry( async fn send_with_retry(
&self, &self,
@@ -548,12 +548,13 @@ async fn openai_compatible_client_honors_http_proxy_for_requests() {
.with_base_url("http://origin.invalid/v1"); .with_base_url("http://origin.invalid/v1");
let response = client let response = client
.send_message(&MessageRequest { .send_message(&MessageRequest {
model: "gpt-4o".to_string(), model: "openai/gpt-4.1-mini".to_string(),
..sample_request(false) ..sample_request(false)
}) })
.await .await
.expect("proxy should return the OpenAI-compatible response"); .expect("proxy should return the OpenAI-compatible response");
assert_eq!(response.model, "openai/gpt-4.1-mini");
assert_eq!(response.total_tokens(), 7); assert_eq!(response.total_tokens(), 7);
let captured = state.lock().await; let captured = state.lock().await;
let request = captured.first().expect("proxy should capture request"); let request = captured.first().expect("proxy should capture request");
@@ -562,6 +563,8 @@ async fn openai_compatible_client_honors_http_proxy_for_requests() {
request.headers.get("authorization").map(String::as_str), request.headers.get("authorization").map(String::as_str),
Some("Bearer openai-test-key") Some("Bearer openai-test-key")
); );
let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
assert_eq!(body["model"], json!("openai/gpt-4.1-mini"));
} }
#[allow(clippy::await_holding_lock)] #[allow(clippy::await_holding_lock)]
@@ -832,6 +832,28 @@ mod tests {
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0); static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
struct EnvVarGuard {
key: &'static str,
previous: Option<std::ffi::OsString>,
}
impl EnvVarGuard {
fn set(key: &'static str, value: &Path) -> Self {
let previous = std::env::var_os(key);
std::env::set_var(key, value);
Self { key, previous }
}
}
impl Drop for EnvVarGuard {
fn drop(&mut self) {
match &self.previous {
Some(value) => std::env::set_var(self.key, value),
None => std::env::remove_var(self.key),
}
}
}
fn temp_dir() -> PathBuf { fn temp_dir() -> PathBuf {
let nanos = SystemTime::now() let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
@@ -1290,8 +1312,11 @@ mod tests {
#[test] #[test]
fn latest_session_returns_all_empty_error_when_sessions_exist_but_have_no_messages() { fn latest_session_returns_all_empty_error_when_sessions_exist_but_have_no_messages() {
// given — create sessions with 0 messages (empty) // given — create sessions with 0 messages (empty)
let _env_guard = crate::test_env_lock();
let base = temp_dir(); let base = temp_dir();
fs::create_dir_all(&base).expect("base dir should exist"); fs::create_dir_all(&base).expect("base dir should exist");
let isolated_config_home = base.join("config-home");
let _claw_config_home = EnvVarGuard::set("CLAW_CONFIG_HOME", &isolated_config_home);
let store = SessionStore::from_cwd(&base).expect("store should build"); let store = SessionStore::from_cwd(&base).expect("store should build");
let empty_handle = store.create_handle("empty-session"); let empty_handle = store.create_handle("empty-session");
+6 -9
View File
@@ -1644,16 +1644,13 @@ mod tests {
let tmp = tempfile::tempdir().expect("tempdir"); let tmp = tempfile::tempdir().expect("tempdir");
let worktree = tmp.path().join("worktree"); let worktree = tmp.path().join("worktree");
let git_dir = tmp.path().join("external-gitdir");
fs::create_dir_all(&worktree).expect("worktree dir"); fs::create_dir_all(&worktree).expect("worktree dir");
fs::create_dir_all(git_dir.join("objects")).expect("objects dir"); Command::new("git")
fs::create_dir_all(git_dir.join("refs/heads")).expect("refs dir"); .arg("init")
fs::write(git_dir.join("HEAD"), "ref: refs/heads/main\n").expect("HEAD"); .current_dir(&worktree)
fs::write( .output()
worktree.join(".git"), .expect("git init should run");
format!("gitdir: {}\n", git_dir.display()), let git_dir = worktree.join(".git");
)
.expect(".git file");
let original_permissions = fs::metadata(&git_dir) let original_permissions = fs::metadata(&git_dir)
.expect("gitdir metadata") .expect("gitdir metadata")
+56 -4
View File
@@ -13737,8 +13737,15 @@ fn push_output_block(
}; };
*pending_tool = Some((id, name, initial_input)); *pending_tool = Some((id, name, initial_input));
} }
OutputContentBlock::Thinking { thinking, .. } => { OutputContentBlock::Thinking {
thinking,
signature,
} => {
render_thinking_block_summary(out, Some(thinking.chars().count()), false)?; render_thinking_block_summary(out, Some(thinking.chars().count()), false)?;
events.push(AssistantEvent::Thinking {
thinking,
signature,
});
*block_has_thinking_summary = true; *block_has_thinking_summary = true;
} }
OutputContentBlock::RedactedThinking { .. } => { OutputContentBlock::RedactedThinking { .. } => {
@@ -19073,6 +19080,13 @@ UU conflicted.rs",
assert!(matches!( assert!(matches!(
&events[0], &events[0],
AssistantEvent::Thinking {
thinking,
signature
} if thinking == "step 1" && signature.as_deref() == Some("sig_123")
));
assert!(matches!(
&events[1],
AssistantEvent::TextDelta(text) if text == "Final answer" AssistantEvent::TextDelta(text) if text == "Final answer"
)); ));
let rendered = String::from_utf8(out).expect("utf8"); let rendered = String::from_utf8(out).expect("utf8");
@@ -19649,6 +19663,41 @@ mod dump_manifests_tests {
#[cfg(test)] #[cfg(test)]
mod alias_resolution_tests { mod alias_resolution_tests {
fn ollama_env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
LOCK.get_or_init(|| std::sync::Mutex::new(()))
.lock()
.expect("ollama env lock poisoned")
}
struct EnvVarGuard {
key: &'static str,
previous: Option<String>,
}
impl EnvVarGuard {
fn unset(key: &'static str) -> Self {
let previous = std::env::var(key).ok();
std::env::remove_var(key);
Self { key, previous }
}
fn set(key: &'static str, value: &str) -> Self {
let previous = std::env::var(key).ok();
std::env::set_var(key, value);
Self { key, previous }
}
}
impl Drop for EnvVarGuard {
fn drop(&mut self) {
match &self.previous {
Some(value) => std::env::set_var(self.key, value),
None => std::env::remove_var(self.key),
}
}
}
use super::{resolve_model_alias_with_config, validate_model_syntax}; use super::{resolve_model_alias_with_config, validate_model_syntax};
#[test] #[test]
@@ -19670,6 +19719,8 @@ mod alias_resolution_tests {
#[test] #[test]
fn test_alias_resolution_syntax_validation() { fn test_alias_resolution_syntax_validation() {
let _guard = ollama_env_lock();
let _env = EnvVarGuard::unset("OLLAMA_HOST");
// Resolved aliases should pass syntax validation // Resolved aliases should pass syntax validation
let resolved = resolve_model_alias_with_config("opus"); let resolved = resolve_model_alias_with_config("opus");
assert!(validate_model_syntax(&resolved).is_ok()); assert!(validate_model_syntax(&resolved).is_ok());
@@ -19680,6 +19731,8 @@ mod alias_resolution_tests {
#[test] #[test]
fn test_unknown_alias_fails_validation() { fn test_unknown_alias_fails_validation() {
let _guard = ollama_env_lock();
let _env = EnvVarGuard::unset("OLLAMA_HOST");
// Unknown aliases resolve to themselves // Unknown aliases resolve to themselves
let resolved = resolve_model_alias_with_config("unknown-alias"); let resolved = resolve_model_alias_with_config("unknown-alias");
assert_eq!(resolved, "unknown-alias"); assert_eq!(resolved, "unknown-alias");
@@ -19699,14 +19752,13 @@ mod alias_resolution_tests {
} }
#[test] #[test]
fn test_ollama_host_bypasses_model_validation() { fn test_ollama_host_bypasses_model_validation() {
// Safety: test sets and clears env var within the test. let _guard = ollama_env_lock();
std::env::set_var("OLLAMA_HOST", "http://127.0.0.1:11434"); let _env = EnvVarGuard::set("OLLAMA_HOST", "http://127.0.0.1:11434");
// Ollama model names with colons pass // Ollama model names with colons pass
assert!(validate_model_syntax("qwen3:8b").is_ok()); assert!(validate_model_syntax("qwen3:8b").is_ok());
assert!(validate_model_syntax("gemma4:e2b").is_ok()); assert!(validate_model_syntax("gemma4:e2b").is_ok());
assert!(validate_model_syntax("qwen3.6:27b-nvfp4").is_ok()); assert!(validate_model_syntax("qwen3.6:27b-nvfp4").is_ok());
// Empty model still rejected // Empty model still rejected
assert!(validate_model_syntax("").is_err()); assert!(validate_model_syntax("").is_err());
std::env::remove_var("OLLAMA_HOST");
} }
} }