Rename .turns() -> .get_turns(); .last_turn() -> `.get_last_turn()` (#19)

* Rename .turns() -> .get_turns(); .last_turn() -> .get_last_turn()

* Missed some renaming

* Eliminate warnings from tests
cpsievert authored Dec 10, 2024
1 parent 4e8084a commit 6f3569e
Showing 14 changed files with 59 additions and 53 deletions.
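For downstream callers the change is a mechanical rename. A minimal before/after sketch (assuming an OpenAI-backed chat, as in the package docs):

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI()

# Before this change:
# turns = chat.turns()
# last = chat.last_turn()

# After this change (#19):
turns = chat.get_turns()
last = chat.get_last_turn()
```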
8 changes: 4 additions & 4 deletions README.md
@@ -196,16 +196,16 @@ Easily get a full markdown or HTML export of a conversation:
chat.export("index.html", title="Python Q&A")
```

If the export doesn't have all the information you need, you can also access the full conversation history via the `.turns()` method:
If the export doesn't have all the information you need, you can also access the full conversation history via the `.get_turns()` method:

```python
chat.turns()
chat.get_turns()
```

And, if the conversation is too long, you can specify which turns to include:

```python
chat.export("index.html", turns=chat.turns()[-5:])
chat.export("index.html", turns=chat.get_turns()[-5:])
```

### Async
@@ -242,7 +242,7 @@ chat.chat("What is the capital of France?", echo="all")

This shows important information like tool call results, finish reasons, and more.

If the problem isn't self-evident, you can also reach into the `.last_turn()`, which contains the full response object, with full details about the completion.
If the problem isn't self-evident, you can also reach into the `.get_last_turn()`, which contains the full response object, with full details about the completion.
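For instance, a quick inspection might look like the following sketch; the `text`, `tokens`, and `finish_reason` attributes are the ones exercised by the tests in this commit:

```python
turn = chat.get_last_turn()    # most recent assistant turn, or None
if turn is not None:
    print(turn.text)           # the response text
    print(turn.tokens)         # (input, output) token counts
    print(turn.finish_reason)  # e.g. "end_turn"
```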


<div style="display:flex;justify-content:center;">
30 changes: 18 additions & 12 deletions chatlas/_chat.py
@@ -94,7 +94,7 @@ def __init__(
"css_styles": {},
}

def turns(
def get_turns(
self,
*,
include_system_prompt: bool = False,
@@ -115,7 +115,7 @@ def turns(
return self._turns[1:]
return self._turns

def last_turn(
def get_last_turn(
self,
*,
role: Literal["assistant", "user", "system"] = "assistant",
@@ -158,7 +158,12 @@ def set_turns(self, turns: Sequence[Turn]):
@property
def system_prompt(self) -> str | None:
"""
Get the system prompt for the chat.
A property to get (or set) the system prompt for the chat.

Returns
-------
str | None
    The system prompt (if any).
"""
if self._turns and self._turns[0].role == "system":
return self._turns[0].text
@@ -228,7 +233,8 @@ def server(input): # noqa: A002
chat = ui.Chat(
"chat",
messages=[
{"role": turn.role, "content": turn.text} for turn in self.turns()
{"role": turn.role, "content": turn.text}
for turn in self.get_turns()
],
)

@@ -533,7 +539,7 @@ def extract_data(
for _ in response:
pass

turn = self.last_turn()
turn = self.get_last_turn()
assert turn is not None

res: list[ContentJson] = []
@@ -593,7 +599,7 @@ async def extract_data_async(
async for _ in response:
pass

turn = self.last_turn()
turn = self.get_last_turn()
assert turn is not None

res: list[ContentJson] = []
@@ -711,7 +717,7 @@ def export(
The filename to export the chat to. Currently this must
be a `.md` or `.html` file.
turns
The `.turns()` to export. If not provided, the chat's current turns
The `.get_turns()` to export. If not provided, the chat's current turns
will be used.
title
A title to place at the top of the exported file.
@@ -729,7 +735,7 @@ def export(
The path to the exported file.
"""
if not turns:
turns = self.turns(include_system_prompt=False)
turns = self.get_turns(include_system_prompt=False)
if not turns:
raise ValueError("No turns to export.")

@@ -986,7 +992,7 @@ def emit(text: str | Content):
self._turns.extend([user_turn, turn])

def _invoke_tools(self) -> Turn | None:
turn = self.last_turn()
turn = self.get_last_turn()
if turn is None:
return None

@@ -1003,7 +1009,7 @@ def _invoke_tools(self) -> Turn | None:
return Turn("user", results)

async def _invoke_tools_async(self) -> Turn | None:
turn = self.last_turn()
turn = self.get_last_turn()
if turn is None:
return None

@@ -1112,15 +1118,15 @@ def set_echo_options(
}

def __str__(self):
turns = self.turns(include_system_prompt=False)
turns = self.get_turns(include_system_prompt=False)
res = ""
for turn in turns:
icon = "👤" if turn.role == "user" else "🤖"
res += f"## {icon} {turn.role.capitalize()} turn:\n\n{str(turn)}\n\n"
return res

def __repr__(self):
turns = self.turns(include_system_prompt=True)
turns = self.get_turns(include_system_prompt=True)
tokens = sum(sum(turn.tokens) for turn in turns if turn.tokens)
res = f"<Chat turns={len(turns)} tokens={tokens}>"
for turn in turns:
4 changes: 2 additions & 2 deletions chatlas/_turn.py
@@ -33,15 +33,15 @@ class Turn(Generic[CompletionT]):
chat = ChatOpenAI()
str(chat.chat("What is the capital of France?"))
turns = chat.turns()
turns = chat.get_turns()
assert len(turns) == 2
assert isinstance(turns[0], Turn)
assert turns[0].role == "user"
assert turns[1].role == "assistant"
# Load context into a new chat instance
chat2 = ChatAnthropic(turns=turns)
turns2 = chat2.turns()
turns2 = chat2.get_turns()
assert turns == turns2
```
6 changes: 3 additions & 3 deletions docs/reference/Chat.qmd
@@ -32,7 +32,7 @@ You should generally not create this object yourself, but instead call
| [console](#chatlas.Chat.console) | Enter a chat console to interact with the LLM. |
| [extract_data](#chatlas.Chat.extract_data) | Extract structured data from the given input. |
| [extract_data_async](#chatlas.Chat.extract_data_async) | Extract structured data from the given input asynchronously. |
| [last_turn](#chatlas.Chat.last_turn) | Get the last turn in the chat with a specific role. |
| [get_last_turn](#chatlas.Chat.get_last_turn) | Get the last turn in the chat with a specific role. |
| [register_tool](#chatlas.Chat.register_tool) | Register a tool (function) with the chat. |
| [set_turns](#chatlas.Chat.set_turns) | Set the turns of the chat. |
| [tokens](#chatlas.Chat.tokens) | Get the tokens for each turn in the chat. |
@@ -158,7 +158,7 @@ Extract structured data from the given input asynchronously.
|--------|-----------------------------------------------------|---------------------|
| | [dict](`dict`)\[[str](`str`), [Any](`typing.Any`)\] | The extracted data. |

### last_turn { #chatlas.Chat.last_turn }
### get_last_turn { #chatlas.Chat.get_last_turn }

```python
Chat.get_last_turn(role='assistant')
@@ -284,7 +284,7 @@ Get the tokens for each turn in the chat.
### turns { #chatlas.Chat.turns }

```python
Chat.turns(include_system_prompt=False)
Chat.get_turns(include_system_prompt=False)
```

Get all the turns (i.e., message contents) in the chat.
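A minimal usage sketch of the renamed accessor (assuming an OpenAI-backed chat, as in the other reference examples):

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI()
chat.chat("What is the capital of France?")

# By default the system prompt is excluded
turns = chat.get_turns()

# Pass include_system_prompt=True to keep it
all_turns = chat.get_turns(include_system_prompt=True)
```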
4 changes: 2 additions & 2 deletions docs/reference/Turn.qmd
@@ -25,15 +25,15 @@ from chatlas import Turn, ChatOpenAI, ChatAnthropic

chat = ChatOpenAI()
str(chat.chat("What is the capital of France?"))
turns = chat.turns()
turns = chat.get_turns()
assert len(turns) == 2
assert isinstance(turns[0], Turn)
assert turns[0].role == "user"
assert turns[1].role == "assistant"

# Load context into a new chat instance
chat2 = ChatAnthropic(turns=turns)
turns2 = chat2.turns()
turns2 = chat2.get_turns()
assert turns == turns2
```

2 changes: 1 addition & 1 deletion docs/web-apps.qmd
@@ -73,5 +73,5 @@ if prompt := st.chat_input():
with st.chat_message("assistant"):
st.write_stream(response)

st.session_state["turns"] = chat.turns()
st.session_state["turns"] = chat.get_turns()
```
16 changes: 8 additions & 8 deletions tests/conftest.py
@@ -50,13 +50,13 @@ def assert_turns_system(chat_fun: ChatFun):
chat = chat_fun(system_prompt=system_prompt)
response = chat.chat("What is the name of Winnie the Pooh's human friend?")
response_text = str(response)
assert len(chat.turns()) == 2
assert len(chat.get_turns()) == 2
assert "CHRISTOPHER ROBIN" in response_text

chat = chat_fun(turns=[Turn("system", system_prompt)])
response = chat.chat("What is the name of Winnie the Pooh's human friend?")
assert "CHRISTOPHER ROBIN" in str(response)
assert len(chat.turns()) == 2
assert len(chat.get_turns()) == 2


def assert_turns_existing(chat_fun: ChatFun):
@@ -70,11 +70,11 @@ def assert_turns_existing(chat_fun: ChatFun):
),
]
)
assert len(chat.turns()) == 2
assert len(chat.get_turns()) == 2

response = chat.chat("Who is the remaining one? Just give the name")
assert "Prancer" in str(response)
assert len(chat.turns()) == 4
assert len(chat.get_turns()) == 4


def assert_tools_simple(chat_fun: ChatFun, stream: bool = True):
@@ -133,7 +133,7 @@ def favorite_color(person: str):

assert "Joe: sage green" in str(response)
assert "Hadley: red" in str(response)
assert len(chat.turns()) == 4
assert len(chat.get_turns()) == 4


def assert_tools_sequential(chat_fun: ChatFun, total_calls: int, stream: bool = True):
@@ -156,7 +156,7 @@ def equipment(weather: str):
stream=stream,
)
assert "umbrella" in str(response).lower()
assert len(chat.turns()) == total_calls
assert len(chat.get_turns()) == total_calls


def assert_data_extraction(chat_fun: ChatFun):
@@ -178,7 +178,7 @@ def assert_images_inline(chat_fun: ChatFun, stream: bool = True):
chat = chat_fun()
response = chat.chat(
"What's in this image?",
content_image_file(str(img_path)),
content_image_file(str(img_path), resize="low"),
stream=stream,
)
assert "red" in str(response).lower()
@@ -202,4 +202,4 @@ def assert_images_remote_error(chat_fun: ChatFun):
with pytest.raises(Exception, match="Remote images aren't supported"):
chat.chat("What's in this image?", image_remote)

assert len(chat.turns()) == 0
assert len(chat.get_turns()) == 0
16 changes: 8 additions & 8 deletions tests/test_chat.py
@@ -33,7 +33,7 @@ def test_simple_streaming_chat():
result = "".join(chunks)
rainbow_re = "^red *\norange *\nyellow *\ngreen *\nblue *\nindigo *\nviolet *\n?$"
assert re.match(rainbow_re, result.lower())
turn = chat.last_turn()
turn = chat.get_last_turn()
assert turn is not None
assert re.match(rainbow_re, turn.text.lower())

@@ -50,7 +50,7 @@ async def test_simple_streaming_chat_async():
result = "".join(chunks)
rainbow_re = "^red *\norange *\nyellow *\ngreen *\nblue *\nindigo *\nviolet *\n?$"
assert re.match(rainbow_re, result.lower())
turn = chat.last_turn()
turn = chat.get_last_turn()
assert turn is not None
assert re.match(rainbow_re, turn.text.lower())

@@ -119,24 +119,24 @@ class Person(BaseModel):

def test_last_turn_retrieval():
chat = ChatOpenAI()
assert chat.last_turn(role="user") is None
assert chat.last_turn(role="assistant") is None
assert chat.get_last_turn(role="user") is None
assert chat.get_last_turn(role="assistant") is None

chat.chat("Hi")
user_turn = chat.last_turn(role="user")
user_turn = chat.get_last_turn(role="user")
assert user_turn is not None and user_turn.role == "user"
turn = chat.last_turn(role="assistant")
turn = chat.get_last_turn(role="assistant")
assert turn is not None and turn.role == "assistant"


def test_system_prompt_retrieval():
chat1 = ChatOpenAI()
assert chat1.system_prompt is None
assert chat1.last_turn(role="system") is None
assert chat1.get_last_turn(role="system") is None

chat2 = ChatOpenAI(system_prompt="You are from New Zealand")
assert chat2.system_prompt == "You are from New Zealand"
turn = chat2.last_turn(role="system")
turn = chat2.get_last_turn(role="system")
assert turn is not None and turn.text == "You are from New Zealand"


5 changes: 3 additions & 2 deletions tests/test_content_image.py
@@ -37,7 +37,7 @@ def test_can_create_image_from_path(tmp_path):
path = tmp_path / "test.png"
img.save(path)

obj = content_image_file(str(path))
obj = content_image_file(str(path), resize="low")
assert isinstance(obj, ContentImageInline)


@@ -65,7 +65,8 @@ def test_image_resizing(tmp_path):
content_image_file(str(tmp_path / "test.txt"))

# Test valid resize options
assert content_image_file(str(img_path)) is not None
with pytest.warns(RuntimeWarning):
assert content_image_file(str(img_path)) is not None
assert content_image_file(str(img_path), resize="low") is not None
assert content_image_file(str(img_path), resize="high") is not None
assert content_image_file(str(img_path), resize="none") is not None
4 changes: 2 additions & 2 deletions tests/test_provider_anthropic.py
@@ -21,7 +21,7 @@ def test_anthropic_simple_request():
system_prompt="Be as terse as possible; no punctuation",
)
chat.chat("What is 1 + 1?")
turn = chat.last_turn()
turn = chat.get_last_turn()
assert turn is not None
assert turn.tokens == (26, 5)
assert turn.finish_reason == "end_turn"
@@ -37,7 +37,7 @@ async def test_anthropic_simple_streaming_request():
async for x in foo:
res.append(x)
assert "2" in "".join(res)
turn = chat.last_turn()
turn = chat.get_last_turn()
assert turn is not None
assert turn.finish_reason == "end_turn"

4 changes: 2 additions & 2 deletions tests/test_provider_azure.py
@@ -18,7 +18,7 @@ def test_azure_simple_request():

response = chat.chat("What is 1 + 1?")
assert "2" == response.get_content()
turn = chat.last_turn()
turn = chat.get_last_turn()
assert turn is not None
assert turn.tokens == (27, 1)

@@ -34,6 +34,6 @@ async def test_azure_simple_request_async():

response = await chat.chat_async("What is 1 + 1?")
assert "2" == await response.get_content()
turn = chat.last_turn()
turn = chat.get_last_turn()
assert turn is not None
assert turn.tokens == (27, 1)
2 changes: 1 addition & 1 deletion tests/test_provider_bedrock.py
@@ -22,7 +22,7 @@
# system_prompt="Be as terse as possible; no punctuation",
# )
# _ = str(chat.chat("What is 1 + 1?"))
# turn = chat.last_turn()
# turn = chat.get_last_turn()
# assert turn is not None
# assert turn.tokens == (26, 5)

