Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve code parsing for small LLMs #28

Merged
merged 3 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Added
- Improved AICode parsing and error handling (eg, capture more REPL prompts, detect parsing errors earlier), including the option to remove unsafe code (eg, `Pkg.add("SomePkg")`) with `AICode(msg; skip_unsafe=true, verbose=true)`

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.3.0"
version = "0.4.0-DEV"

[deps]
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Expand Down
57 changes: 54 additions & 3 deletions src/code_generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ function Base.var"=="(c1::T, c2::T) where {T <: AICode}
end
function Base.show(io::IO, cb::AICode)
success_str = cb.success === nothing ? "N/A" : titlecase(string(cb.success))
expression_str = cb.expression === nothing ? "N/A" : "True"
expression_str = cb.expression === nothing ? "N/A" : titlecase(string(isparsed(cb)))
stdout_str = cb.stdout === nothing ? "N/A" : "True"
output_str = cb.output === nothing ? "N/A" : "True"
error_str = cb.error === nothing ? "N/A" : "True"
Expand All @@ -129,9 +129,33 @@ function Base.show(io::IO, cb::AICode)
"AICode(Success: $success_str, Parsed: $expression_str, Evaluated: $output_str, Error Caught: $error_str, StdOut: $stdout_str, Code: $count_lines Lines)")
end

## Parsing error detection
"""
    isparsed(ex::Expr)

Return `false` when `ex` is a `:toplevel` expression whose final argument is an
`:error` or `:incomplete` expression; return `true` for every other `Expr`.
"""
function isparsed(ex::Expr)
    # Anything that is not a `:toplevel` wrapper is treated as parsed.
    Meta.isexpr(ex, :toplevel) || return true
    # An empty `:toplevel` has no trailing marker to inspect.
    isempty(ex.args) && return true
    # Only the last argument is checked for an error/incomplete marker.
    return !Meta.isexpr(ex.args[end], (:error, :incomplete))
end
# `nothing` in place of an expression is never reported as parsed.
isparsed(::Nothing) = false
# An exception counts as a parse error when it is a `Base.Meta.ParseError`, or an
# `ErrorException` whose message starts with "syntax:". Any other exception does not.
isparseerror(::Base.Meta.ParseError) = true
isparseerror(err::ErrorException) = startswith(err.msg, "syntax:")
isparseerror(::Exception) = false
# `nothing` in place of an error is never a parse error.
isparseerror(::Nothing) = false
# A code block is parsed when its stored expression passes `isparsed` and its
# stored error is not a parse error.
function isparsed(cb::AICode)
    isparsed(cb.expression) || return false
    return !isparseerror(cb.error)
end

## Overload for AIMessage - simply extracts the code blocks and concatenates them
function AICode(msg::AIMessage; kwargs...)
function AICode(msg::AIMessage;
        verbose::Bool = false,
        skip_unsafe::Bool = false,
        kwargs...)
    # Concatenate all code blocks extracted from the message, one per line group
    code = join(extract_code_blocks(msg.content), "\n")
    # Optionally filter the code through `remove_unsafe_lines` first
    # (`verbose` is forwarded to it)
    if skip_unsafe
        code = remove_unsafe_lines(code; verbose)
    end
    # Remaining keyword arguments are passed on to the string-based constructor
    return AICode(code; kwargs...)
end

Expand Down Expand Up @@ -161,6 +185,9 @@ end

# Utility to pinpoint unavailable dependencies
function detect_missing_packages(imports_required::AbstractVector{<:Symbol})
# shortcut if no packages are required
isempty(imports_required) && return false, Symbol[]
#
available_packages = Base.loaded_modules |> values .|> Symbol
missing_packages = filter(pkg -> !in(pkg, available_packages), imports_required)
if length(missing_packages) > 0
Expand All @@ -170,8 +197,22 @@ function detect_missing_packages(imports_required::AbstractVector{<:Symbol})
end
end

"""
    remove_unsafe_lines(code::AbstractString; verbose::Bool = false)

Return `code` with every line dropped for which `detect_pkg_operation` is true or
for which `detect_missing_packages` flags the imports found by `extract_julia_imports`.

Each kept line is written back followed by a newline. When `verbose` is `true`,
every removed line is reported via `@info`.
"""
function remove_unsafe_lines(code::AbstractString; verbose::Bool = false)
    kept = IOBuffer()
    for line in eachline(IOBuffer(code))
        # The import check is only reached when the line is not a package operation
        is_unsafe = detect_pkg_operation(line) ||
                    detect_missing_packages(extract_julia_imports(line))[1]
        if is_unsafe
            verbose && @info "Unsafe line removed: $line"
        else
            println(kept, line)
        end
    end
    return String(take!(kept))
end

"Checks if a given string has a Julia prompt (`julia> `) at the beginning of a line."
has_julia_prompt(s::T) where {T <: AbstractString} = occursin(r"^julia> "m, s)
# Matches either the full REPL prompt (`julia> `) or a bare `> ` prompt, in both
# cases only at the start of a line (multiline `m` flag). The alternatives are
# wrapped in a non-capturing group `(?:...)`.
has_julia_prompt(s::T) where {T <: AbstractString} = occursin(r"(?:^julia> |^> )"m, s)

"""
remove_julia_prompt(s::T) where {T<:AbstractString}
Expand All @@ -191,6 +232,10 @@ function remove_julia_prompt(s::T) where {T <: AbstractString}
code_line = true
# remove the prompt
println(io, replace(line, "julia> " => ""))
elseif startswith(line, r"^> ")
code_line = true
# remove the prompt
println(io, replace(line, "> " => ""))
elseif code_line && startswith(line, r"^ ")
# continuation of the code line
println(io, line)
Expand Down Expand Up @@ -430,6 +475,12 @@ function eval!(cb::AbstractCodeBlock;
return cb
end
end
## Catch bad code extraction
if isempty(code)
cb.error = ErrorException("Parse Error: No code found!")
cb.success = false
return cb
end
## Parse into an expression
try
ex = Meta.parseall(code_extra)
Expand Down
74 changes: 71 additions & 3 deletions test/code_generation.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using PromptingTools: extract_julia_imports
using PromptingTools: detect_pkg_operation, detect_missing_packages, extract_function_name
using PromptingTools: detect_pkg_operation,
detect_missing_packages, extract_function_name, remove_unsafe_lines
using PromptingTools: has_julia_prompt, remove_julia_prompt, extract_code_blocks, eval!
using PromptingTools: escape_interpolation, find_subsequence_positions
using PromptingTools: AICode, isparsed, isparseerror

@testset "extract_imports tests" begin
@test extract_julia_imports("using Test, LinearAlgebra") ==
Expand Down Expand Up @@ -30,12 +32,33 @@ end
@test detect_pkg_operation("import Pkg;") == false
end

# Checks that `remove_unsafe_lines` drops package operations and the import of a
# package that does not exist, while keeping every other line.
@testset "remove_unsafe_lines" begin
# A lone package operation is removed, leaving an empty string
@test remove_unsafe_lines("Pkg.activate(\".\")") == ""
@test remove_unsafe_lines("Pkg.add(\"SomePkg\")") == ""
# Mixed input: only the two plain assignments are expected to survive
s = """
a=1
Pkg.add("a")
b=2
Pkg.add("b")
using 12315456NotExisting
"""
@test remove_unsafe_lines(s) == "a=1\nb=2\n"
# A safe line is kept and comes back with a trailing newline
@test remove_unsafe_lines("Nothing"; verbose = true) == "Nothing\n"
end

@testset "has_julia_prompt" begin
@test has_julia_prompt("julia> a=1")
@test has_julia_prompt("> a=1")
@test has_julia_prompt("""
# something else first
julia> a=1
""")
@test has_julia_prompt("""
> a=\"\"\"
hey
there
\"\"\"
""")
@test !has_julia_prompt("""
# something
# new
Expand All @@ -45,6 +68,7 @@ end

@testset "remove_julia_prompt" begin
@test remove_julia_prompt("julia> a=1") == "a=1"
@test remove_julia_prompt("> a=1") == "a=1"
@test remove_julia_prompt("""
# something else first
julia> a=1
Expand Down Expand Up @@ -281,8 +305,13 @@ end
@test cb.output == 123
@test a123 == 123

# Check that empty code is invalid
cb = AICode("")
@test !isvalid(cb)
@test cb.error isa Exception

# Test prefix and suffix
cb = AICode(; code = "")
cb = AICode(; code = "x=1")
eval!(cb; prefix = "a=1", suffix = "b=2")
@test cb.output.a == 1
@test cb.output.b == 2
Expand Down Expand Up @@ -331,6 +360,22 @@ b=2
@test cb.stdout == "hello\nworld\n"
@test cb.output.b == 2
end
# skip_unsafe=true
# NOTE(review): `s` is not referenced in the visible part of this test — confirm
# whether it is used further down or can be removed
s = """

"""
# With `skip_unsafe = true`, the `Pkg.add` calls and the import of the
# non-existent package are expected to be missing from `cb.code`
let msg = AIMessage("""
```julia
a=1
Pkg.add("a")
b=2
Pkg.add("b")
using 12315456NotExisting
```
""")
cb = AICode(msg; skip_unsafe = true)
@test cb.code == "a=1\nb=2\n"
end

# Methods - copy
let msg = AIMessage("""
Expand Down Expand Up @@ -369,7 +414,7 @@ end
"AICode(Success: True, Parsed: True, Evaluated: True, Error Caught: N/A, StdOut: True, Code: 1 Lines)"

# Test with error
code_block = AICode("error(\"Test Error\"))\nprint(\"\")")
code_block = AICode("error(\"Test Error\")\nprint(\"\")")
buffer = IOBuffer()
show(buffer, code_block)
output = String(take!(buffer))
Expand All @@ -391,4 +436,27 @@ end
code1 = AICode("print(\"Hello\")"; safe_eval = true)
code2 = AICode("print(\"Hello\")"; safe_eval = false)
@test code1 != code2
end
@testset "isparsed, isparseerror" begin
    ## isparsed on raw expressions
    # a well-formed assignment
    @test isparsed(:(x = 1))
    # an incomplete call (unclosed parenthesis)
    @test !isparsed(Meta.parseall("("))
    # a sequence of operators that does not parse
    @test !isparsed(Meta.parseall("+-+-+--+"))
    # no expression at all
    @test !isparsed(nothing)
    # a call to `error(...)` parses fine and must not be mistaken for a parse error
    @test isparsed(Meta.parseall("error(\"s\")"))

    ## isparseerror on exceptions
    @test !isparseerror(nothing)
    @test isparseerror(ErrorException("syntax: unexpected \"(\" in argument list"))
    @test isparseerror(Base.Meta.ParseError("xyz"))

    ## isparsed on AICode objects
    broken_cb = AICode("(")
    @test !isparsed(broken_cb)
    valid_cb = AICode("a+1")
    @test isparsed(valid_cb)
end
Loading