Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp committed Dec 7, 2023
1 parent 4bd280f commit 06a518b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
8 changes: 2 additions & 6 deletions src/messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function UserMessage(content::T,
end
Base.@kwdef struct UserMessageWithImages{T <: AbstractString} <: AbstractChatMessage
content::T
image_url::Vector{<:AbstractString} # no default! fail when not provided
image_url::Vector{String} # no default! fail when not provided
variables::Vector{Symbol} = _extract_handlebar_variables(content)
_type::Symbol = :usermessagewithimages
UserMessageWithImages{T}(c, i, v, t) where {T <: AbstractString} = new(c, i, v, t)
Expand All @@ -59,7 +59,7 @@ function UserMessageWithImages(content::T, image_url::Vector{<:AbstractString},
type::Symbol) where {T <: AbstractString}
not_allowed_kwargs = intersect(variables, RESERVED_KWARGS)
@assert length(not_allowed_kwargs)==0 "Error: Some placeholders are invalid, as they are reserved for `ai*` functions. Change: $(join(not_allowed_kwargs,","))"
return UserMessageWithImages{T}(content, image_url, variables, type)
return UserMessageWithImages{T}(content, string.(image_url), variables, type)
end
Base.@kwdef struct AIMessage{T <: Union{AbstractString, Nothing}} <: AbstractChatMessage
content::T = nothing
Expand Down Expand Up @@ -89,10 +89,6 @@ Base.var"=="(m1::AbstractMessage, m2::AbstractMessage) = false
function Base.var"=="(m1::T, m2::T) where {T <: AbstractMessage}
all([getproperty(m1, f) == getproperty(m2, f) for f in fieldnames(T)])
end
Base.length(t::AbstractMessage) = nfields(t)
function Base.iterate(t::AbstractMessage, iter = 1)
iter > nfields(t) ? nothing : (getfield(t, iter), iter + 1)
end

## Vision Models -- Constructor and Conversion
"Construct `UserMessageWithImages` with 1 or more images. Images can be either URLs or local paths."
Expand Down
7 changes: 5 additions & 2 deletions test/serialization.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using PromptingTools: AIMessage, SystemMessage, UserMessage
using PromptingTools: AIMessage,
SystemMessage, UserMessage, UserMessageWithImages, AbstractMessage, DataMessage
using PromptingTools: save_conversation, load_conversation
using PromptingTools: save_template, load_template

Expand All @@ -7,7 +8,9 @@ using PromptingTools: save_template, load_template
messages = AbstractMessage[SystemMessage("System message 1"),
UserMessage("User message"),
AIMessage("AI message"),
DataMessage(; content = "Data message")]
UserMessageWithImages(; content = "a", image_url = String["b", "c"]),
DataMessage(;
content = "Data message")]
tmp, _ = mktemp()
save_conversation(tmp, messages)
# Test load_conversation
Expand Down

0 comments on commit 06a518b

Please sign in to comment.