Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip thinking section in Claude tool call response #226

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
93 changes: 80 additions & 13 deletions src/magentic/chat_model/anthropic_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import AsyncIterator, Callable, Iterable, Iterator
from enum import Enum
from functools import singledispatch
from itertools import chain, groupby
from itertools import chain, dropwhile, groupby
from typing import Any, AsyncIterable, Generic, Sequence, TypeVar, cast, overload

from pydantic import ValidationError
Expand Down Expand Up @@ -39,8 +39,11 @@
AsyncStreamedStr,
StreamedStr,
achain,
adropwhile,
agroupby,
apeek,
async_iter,
peek,
)
from magentic.typing import is_any_origin_subclass, is_origin_subclass

Expand Down Expand Up @@ -312,6 +315,70 @@ async def agenerator(
return usage_ref, agenerator(response)


def _extract_thinking(
response: Iterator[ToolsBetaMessageStreamEvent],
) -> tuple[str | None, Iterator[ToolsBetaMessageStreamEvent]]:
"""Extract the <thinking>...</thinking> block from the response."""
first_chunk = next(response)
if not (
first_chunk.type == "content_block_start"
and first_chunk.content_block.type == "text"
):
return None, chain([first_chunk], response)

second_chunk = next(response)
assert second_chunk.type == "content_block_delta" # noqa: S101
assert second_chunk.delta.type == "text_delta" # noqa: S101
if not second_chunk.delta.text.startswith("<thinking>"):
return None, chain([first_chunk, second_chunk], response)

thinking = second_chunk.delta.text.removeprefix("<thinking>").lstrip()
for chunk in response:
assert chunk.type == "content_block_delta" # noqa: S101
assert chunk.delta.type == "text_delta" # noqa: S101
thinking += chunk.delta.text
if "</thinking>" in thinking:
break
thinking = thinking.rstrip().removesuffix("</thinking>").rstrip()
first_chunk = next(response)
# content_block_stop encountered if switching to tool calls
if first_chunk.type == "content_block_stop":
first_chunk = next(response)
return thinking, chain([first_chunk], response)


async def _aextract_thinking(
response: AsyncIterator[ToolsBetaMessageStreamEvent],
) -> tuple[str | None, AsyncIterator[ToolsBetaMessageStreamEvent]]:
"""Async version of `_extract_thinking`."""
first_chunk = await anext(response)
if not (
first_chunk.type == "content_block_start"
and first_chunk.content_block.type == "text"
):
return None, achain(async_iter([first_chunk]), response)

second_chunk = await anext(response)
assert second_chunk.type == "content_block_delta" # noqa: S101
assert second_chunk.delta.type == "text_delta" # noqa: S101
if not second_chunk.delta.text.startswith("<thinking>"):
return None, achain(async_iter([first_chunk, second_chunk]), response)

thinking = second_chunk.delta.text.removeprefix("<thinking>").lstrip()
async for chunk in response:
assert chunk.type == "content_block_delta" # noqa: S101
assert chunk.delta.type == "text_delta" # noqa: S101
thinking += chunk.delta.text
if "</thinking>" in thinking:
break
thinking = thinking.rstrip().removesuffix("</thinking>").rstrip()
first_chunk = await anext(response)
# content_block_stop encountered if switching to tool calls
if first_chunk.type == "content_block_stop":
first_chunk = await anext(response)
return thinking, achain(async_iter([first_chunk]), response)


R = TypeVar("R")


Expand Down Expand Up @@ -450,16 +517,16 @@ def _response_generator() -> Iterator[ToolsBetaMessageStreamEvent]:

response = _response_generator()
usage_ref, response = _create_usage_ref(response)

first_chunk = next(response)
if first_chunk.type == "message_start":
first_chunk = next(response)
assert first_chunk.type == "content_block_start" # noqa: S101
response = chain([first_chunk], response)
response = dropwhile(lambda x: x.type != "content_block_start", response)
_, response = _extract_thinking(response)
first_chunk, response = peek(response)

if (
first_chunk.type == "content_block_start"
and first_chunk.content_block.type == "text"
) or (
first_chunk.type == "content_block_delta"
and first_chunk.delta.type == "text_delta"
):
streamed_str = StreamedStr(
chunk.delta.text
Expand Down Expand Up @@ -569,16 +636,16 @@ async def _response_generator() -> AsyncIterator[ToolsBetaMessageStreamEvent]:

response = _response_generator()
usage_ref, response = _create_usage_ref_async(response)

first_chunk = await anext(response)
if first_chunk.type == "message_start":
first_chunk = await anext(response)
assert first_chunk.type == "content_block_start" # noqa: S101
response = achain(async_iter([first_chunk]), response)
response = adropwhile(lambda x: x.type != "content_block_start", response)
_, response = await _aextract_thinking(response)
first_chunk, response = await apeek(response)

if (
first_chunk.type == "content_block_start"
and first_chunk.content_block.type == "text"
) or (
first_chunk.type == "content_block_delta"
and first_chunk.delta.type == "text_delta"
):
async_streamed_str = AsyncStreamedStr(
chunk.delta.text
Expand Down
41 changes: 41 additions & 0 deletions tests/chat_model/test_anthropic_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,30 @@ def plus(a: int, b: int) -> int:
assert isinstance(message.content, FunctionCall)


@pytest.mark.parametrize(
    "model_name",
    [
        "claude-3-opus-20240229",
        "claude-3-sonnet-20240229",
        "claude-3-haiku-20240307",
    ],
)
@pytest.mark.anthropic
def test_anthropic_chat_model_complete_function_call_with_thinking(model_name):
    """Tool call is parsed even when the model emits a <thinking> section first."""

    def plus(a: int, b: int) -> int:
        """Sum two numbers."""
        return a + b

    # Union with str so tool call is not forced => <thinking> section is generated
    allowed_outputs = [FunctionCall[int], str]  # type: ignore[misc]
    message = AnthropicChatModel(model_name).complete(
        messages=[UserMessage("Use the tool to sum 1 and 2")],
        functions=[plus],
        output_types=allowed_outputs,
    )
    assert isinstance(message.content, FunctionCall)


@pytest.mark.anthropic
def test_anthropic_chat_model_complete_parallel_function_call():
def plus(a: int, b: int) -> int:
Expand Down Expand Up @@ -165,6 +189,23 @@ def plus(a: int, b: int) -> int:
assert isinstance(message.content, FunctionCall)


@pytest.mark.asyncio
@pytest.mark.anthropic
async def test_anthropic_chat_model_acomplete_function_call_with_thinking():
    """Async variant: tool call is parsed despite a preceding <thinking> section."""

    def plus(a: int, b: int) -> int:
        """Sum two numbers."""
        return a + b

    model = AnthropicChatModel("claude-3-haiku-20240307")
    message = await model.acomplete(
        messages=[UserMessage("Use the tool to sum 1 and 2")],
        functions=[plus],
        # Union with str so tool call is not forced => <thinking> section is generated
        output_types=[FunctionCall[int], str],  # type: ignore[misc]
    )
    assert isinstance(message.content, FunctionCall)


@pytest.mark.asyncio
@pytest.mark.anthropic
async def test_anthropic_chat_model_acomplete_async_parallel_function_call():
Expand Down