Skip to content

Commit

Permalink
Improve code parsing for small LLMs
Browse files Browse the repository at this point in the history
Improve code parsing for small LLMs
  • Loading branch information
svilupp authored Dec 12, 2023
2 parents f991745 + 20b3f65 commit 9f369f1
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]

### Added
- Improved AICode parsing and error handling (eg, capture more REPL prompts, detect parsing errors earlier), including the option to remove unsafe code (eg, `Pkg.add("SomePkg")`) with `AICode(msg; skip_unsafe=true, vebose=true)`

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.3.0"
version = "0.4.0-DEV"

[deps]
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Expand Down
57 changes: 54 additions & 3 deletions src/code_generation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ function Base.var"=="(c1::T, c2::T) where {T <: AICode}
end
function Base.show(io::IO, cb::AICode)
success_str = cb.success === nothing ? "N/A" : titlecase(string(cb.success))
expression_str = cb.expression === nothing ? "N/A" : "True"
expression_str = cb.expression === nothing ? "N/A" : titlecase(string(isparsed(cb)))
stdout_str = cb.stdout === nothing ? "N/A" : "True"
output_str = cb.output === nothing ? "N/A" : "True"
error_str = cb.error === nothing ? "N/A" : "True"
Expand All @@ -129,9 +129,33 @@ function Base.show(io::IO, cb::AICode)
"AICode(Success: $success_str, Parsed: $expression_str, Evaluated: $output_str, Error Caught: $error_str, StdOut: $stdout_str, Code: $count_lines Lines)")
end

## Parsing error detection
function isparsed(ex::Expr)
parse_error = Meta.isexpr(ex, :toplevel) && !isempty(ex.args) &&
Meta.isexpr(ex.args[end], (:error, :incomplete))
return !parse_error
end
function isparsed(ex::Nothing)
return false
end
function isparseerror(err::Exception)
return err isa Base.Meta.ParseError ||
(err isa ErrorException && startswith(err.msg, "syntax:"))
end
function isparseerror(err::Nothing)
return false
end
function isparsed(cb::AICode)
return isparsed(cb.expression) && !isparseerror(cb.error)
end

## Overload for AIMessage - simply extracts the code blocks and concatenates them
function AICode(msg::AIMessage; kwargs...)
function AICode(msg::AIMessage;
verbose::Bool = false,
skip_unsafe::Bool = false,
kwargs...)
code = extract_code_blocks(msg.content) |> Base.Fix2(join, "\n")
skip_unsafe && (code = remove_unsafe_lines(code; verbose))
return AICode(code; kwargs...)
end

Expand Down Expand Up @@ -161,6 +185,9 @@ end

# Utility to pinpoint unavailable dependencies
function detect_missing_packages(imports_required::AbstractVector{<:Symbol})
# shortcut if no packages are required
isempty(imports_required) && return false, Symbol[]
#
available_packages = Base.loaded_modules |> values .|> Symbol
missing_packages = filter(pkg -> !in(pkg, available_packages), imports_required)
if length(missing_packages) > 0
Expand All @@ -170,8 +197,22 @@ function detect_missing_packages(imports_required::AbstractVector{<:Symbol})
end
end

"Iterates over the lines of a string and removes those that contain a package operation or a missing import."
function remove_unsafe_lines(code::AbstractString; verbose::Bool = false)
io = IOBuffer()
for line in readlines(IOBuffer(code))
if !detect_pkg_operation(line) &&
!detect_missing_packages(extract_julia_imports(line))[1]
println(io, line)
else
verbose && @info "Unsafe line removed: $line"
end
end
return String(take!(io))
end

"Checks if a given string has a Julia prompt (`julia> `) at the beginning of a line."
has_julia_prompt(s::T) where {T <: AbstractString} = occursin(r"^julia> "m, s)
has_julia_prompt(s::T) where {T <: AbstractString} = occursin(r"(:?^julia> |^> )"m, s)

"""
remove_julia_prompt(s::T) where {T<:AbstractString}
Expand All @@ -191,6 +232,10 @@ function remove_julia_prompt(s::T) where {T <: AbstractString}
code_line = true
# remove the prompt
println(io, replace(line, "julia> " => ""))
elseif startswith(line, r"^> ")
code_line = true
# remove the prompt
println(io, replace(line, "> " => ""))
elseif code_line && startswith(line, r"^ ")
# continuation of the code line
println(io, line)
Expand Down Expand Up @@ -430,6 +475,12 @@ function eval!(cb::AbstractCodeBlock;
return cb
end
end
## Catch bad code extraction
if isempty(code)
cb.error = ErrorException("Parse Error: No code found!")
cb.success = false
return cb
end
## Parse into an expression
try
ex = Meta.parseall(code_extra)
Expand Down
74 changes: 71 additions & 3 deletions test/code_generation.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using PromptingTools: extract_julia_imports
using PromptingTools: detect_pkg_operation, detect_missing_packages, extract_function_name
using PromptingTools: detect_pkg_operation,
detect_missing_packages, extract_function_name, remove_unsafe_lines
using PromptingTools: has_julia_prompt, remove_julia_prompt, extract_code_blocks, eval!
using PromptingTools: escape_interpolation, find_subsequence_positions
using PromptingTools: AICode, isparsed, isparseerror

@testset "extract_imports tests" begin
@test extract_julia_imports("using Test, LinearAlgebra") ==
Expand Down Expand Up @@ -30,12 +32,33 @@ end
@test detect_pkg_operation("import Pkg;") == false
end

@testset "remove_unsafe_lines" begin
@test remove_unsafe_lines("Pkg.activate(\".\")") == ""
@test remove_unsafe_lines("Pkg.add(\"SomePkg\")") == ""
s = """
a=1
Pkg.add("a")
b=2
Pkg.add("b")
using 12315456NotExisting
"""
@test remove_unsafe_lines(s) == "a=1\nb=2\n"
@test remove_unsafe_lines("Nothing"; verbose = true) == "Nothing\n"
end

@testset "has_julia_prompt" begin
@test has_julia_prompt("julia> a=1")
@test has_julia_prompt("> a=1")
@test has_julia_prompt("""
# something else first
julia> a=1
""")
@test has_julia_prompt("""
> a=\"\"\"
hey
there
\"\"\"
""")
@test !has_julia_prompt("""
# something
# new
Expand All @@ -45,6 +68,7 @@ end

@testset "remove_julia_prompt" begin
@test remove_julia_prompt("julia> a=1") == "a=1"
@test remove_julia_prompt("> a=1") == "a=1"
@test remove_julia_prompt("""
# something else first
julia> a=1
Expand Down Expand Up @@ -281,8 +305,13 @@ end
@test cb.output == 123
@test a123 == 123

# Check that empty code is invalid
cb = AICode("")
@test !isvalid(cb)
@test cb.error isa Exception

# Test prefix and suffix
cb = AICode(; code = "")
cb = AICode(; code = "x=1")
eval!(cb; prefix = "a=1", suffix = "b=2")
@test cb.output.a == 1
@test cb.output.b == 2
Expand Down Expand Up @@ -331,6 +360,22 @@ b=2
@test cb.stdout == "hello\nworld\n"
@test cb.output.b == 2
end
# skip_unsafe=true
s = """
"""
let msg = AIMessage("""
```julia
a=1
Pkg.add("a")
b=2
Pkg.add("b")
using 12315456NotExisting
```
""")
cb = AICode(msg; skip_unsafe = true)
@test cb.code == "a=1\nb=2\n"
end

# Methods - copy
let msg = AIMessage("""
Expand Down Expand Up @@ -369,7 +414,7 @@ end
"AICode(Success: True, Parsed: True, Evaluated: True, Error Caught: N/A, StdOut: True, Code: 1 Lines)"

# Test with error
code_block = AICode("error(\"Test Error\"))\nprint(\"\")")
code_block = AICode("error(\"Test Error\")\nprint(\"\")")
buffer = IOBuffer()
show(buffer, code_block)
output = String(take!(buffer))
Expand All @@ -391,4 +436,27 @@ end
code1 = AICode("print(\"Hello\")"; safe_eval = true)
code2 = AICode("print(\"Hello\")"; safe_eval = false)
@test code1 != code2
end
@testset "isparsed, isparseerror" begin
## isparsed
@test isparsed(:(x = 1)) == true
# parse an incomplete call
@test isparsed(Meta.parseall("(")) == false
# parse an error call
@test isparsed(Meta.parseall("+-+-+--+")) == false
# nothing
@test isparsed(nothing) == false
# Validate that we don't have false positives with error
@test isparsed(Meta.parseall("error(\"s\")")) == true

## isparseerror
@test isparseerror(nothing) == false
@test isparseerror(ErrorException("syntax: unexpected \"(\" in argument list")) == true
@test isparseerror(Base.Meta.ParseError("xyz")) == true

# AICode
cb = AICode("(")
@test isparsed(cb) == false
cb = AICode("a+1")
@test isparsed(cb) == true
end

0 comments on commit 9f369f1

Please sign in to comment.