From 0dc08e391f0913a0aa29f9caa3c49e82efc192fa Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Mon, 11 Dec 2023 21:41:03 +0000 Subject: [PATCH 1/3] catch parsing errors --- src/code_generation.jl | 32 +++++++++++++++++++++++++++++++- test/code_generation.jl | 39 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/code_generation.jl b/src/code_generation.jl index ff10030f2..c11b7d4ba 100644 --- a/src/code_generation.jl +++ b/src/code_generation.jl @@ -129,6 +129,26 @@ function Base.show(io::IO, cb::AICode) "AICode(Success: $success_str, Parsed: $expression_str, Evaluated: $output_str, Error Caught: $error_str, StdOut: $stdout_str, Code: $count_lines Lines)") end +## Parsing error detection +function isparsed(ex::Expr) + parse_error = Meta.isexpr(ex, :toplevel) && !isempty(ex.args) && + Meta.isexpr(ex.args[end], (:error, :incomplete)) + return !parse_error +end +function isparsed(ex::Nothing) + return false +end +function isparseerror(err::Exception) + return err isa Base.Meta.ParseError || + (err isa ErrorException && startswith(err.msg, "syntax:")) +end +function isparseerror(err::Nothing) + return false +end +function isparsed(cb::AICode) + return isparsed(cb.expression) && !isparseerror(cb.error) +end + ## Overload for AIMessage - simply extracts the code blocks and concatenates them function AICode(msg::AIMessage; kwargs...) code = extract_code_blocks(msg.content) |> Base.Fix2(join, "\n") @@ -171,7 +191,7 @@ function detect_missing_packages(imports_required::AbstractVector{<:Symbol}) end "Checks if a given string has a Julia prompt (`julia> `) at the beginning of a line." 
-has_julia_prompt(s::T) where {T <: AbstractString} = occursin(r"^julia> "m, s) +has_julia_prompt(s::T) where {T <: AbstractString} = occursin(r"(:?^julia> |^> )"m, s) """ remove_julia_prompt(s::T) where {T<:AbstractString} @@ -191,6 +211,10 @@ function remove_julia_prompt(s::T) where {T <: AbstractString} code_line = true # remove the prompt println(io, replace(line, "julia> " => "")) + elseif startswith(line, r"^> ") + code_line = true + # remove the prompt + println(io, replace(line, "> " => "")) elseif code_line && startswith(line, r"^ ") # continuation of the code line println(io, line) @@ -430,6 +454,12 @@ function eval!(cb::AbstractCodeBlock; return cb end end + ## Catch bad code extraction + if isempty(code) + cb.error = ErrorException("Parse Error: No code found!") + cb.success = false + return cb + end ## Parse into an expression try ex = Meta.parseall(code_extra) diff --git a/test/code_generation.jl b/test/code_generation.jl index 38bfb87fc..56691b71d 100644 --- a/test/code_generation.jl +++ b/test/code_generation.jl @@ -2,6 +2,7 @@ using PromptingTools: extract_julia_imports using PromptingTools: detect_pkg_operation, detect_missing_packages, extract_function_name using PromptingTools: has_julia_prompt, remove_julia_prompt, extract_code_blocks, eval! 
using PromptingTools: escape_interpolation, find_subsequence_positions +using PromptingTools: AICode, isparsed, isparseerror @testset "extract_imports tests" begin @test extract_julia_imports("using Test, LinearAlgebra") == @@ -32,10 +33,17 @@ end @testset "has_julia_prompt" begin @test has_julia_prompt("julia> a=1") + @test has_julia_prompt("> a=1") @test has_julia_prompt(""" # something else first julia> a=1 """) + @test has_julia_prompt(""" + > a=\"\"\" + hey + there + \"\"\" + """) @test !has_julia_prompt(""" # something # new @@ -45,6 +53,7 @@ end @testset "remove_julia_prompt" begin @test remove_julia_prompt("julia> a=1") == "a=1" + @test remove_julia_prompt("> a=1") == "a=1" @test remove_julia_prompt(""" # something else first julia> a=1 @@ -281,8 +290,13 @@ end @test cb.output == 123 @test a123 == 123 + # Check that empty code is invalid + cb = AICode("") + @test !isvalid(cb) + @test cb.error isa Exception + # Test prefix and suffix - cb = AICode(; code = "") + cb = AICode(; code = "x=1") eval!(cb; prefix = "a=1", suffix = "b=2") @test cb.output.a == 1 @test cb.output.b == 2 @@ -391,4 +405,27 @@ end code1 = AICode("print(\"Hello\")"; safe_eval = true) code2 = AICode("print(\"Hello\")"; safe_eval = false) @test code1 != code2 +end +@testset "isparsed, isparseerror" begin + ## isparsed + @test isparsed(:(x = 1)) == true + # parse an incomplete call + @test isparsed(Meta.parseall("(")) == false + # parse an error call + @test isparsed(Meta.parseall("+-+-+--+")) == false + # nothing + @test isparsed(nothing) == false + # Validate that we don't have false positives with error + @test isparsed(Meta.parseall("error(\"s\")")) == true + + ## isparseerror + @test isparseerror(nothing) == false + @test isparseerror(ErrorException("syntax: unexpected \"(\" in argument list")) == true + @test isparseerror(Base.Meta.ParseError("xyz")) == true + + # AICode + cb = AICode("(") + @test isparsed(cb) == false + cb = AICode("a+1") + @test isparsed(cb) == true end \ No newline 
at end of file From cfd3c1f6c57858f878c306b939f5d1f1013e43d7 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Tue, 12 Dec 2023 20:57:25 +0000 Subject: [PATCH 2/3] improve code parsing --- CHANGELOG.md | 1 + Project.toml | 2 +- src/code_generation.jl | 23 ++++++++++++++++++++++- test/code_generation.jl | 33 ++++++++++++++++++++++++++++++++- 4 files changed, 56 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c60b5cb3..d4821241b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- Improved AICode parsing and error handling (eg, capture more REPL prompts, detect parsing errors earlier), including the option to remove unsafe code (eg, `Pkg.add("SomePkg")`) with `AICode(msg; skip_unsafe=true, verbose=true)` ### Fixed diff --git a/Project.toml b/Project.toml index bb1236c5b..5957c08c7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.3.0" +version = "0.4.0-DEV" [deps] Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" diff --git a/src/code_generation.jl b/src/code_generation.jl index c11b7d4ba..578e99685 100644 --- a/src/code_generation.jl +++ b/src/code_generation.jl @@ -150,8 +150,12 @@ function isparsed(cb::AICode) end ## Overload for AIMessage - simply extracts the code blocks and concatenates them -function AICode(msg::AIMessage; kwargs...) +function AICode(msg::AIMessage; + verbose::Bool = false, + skip_unsafe::Bool = false, + kwargs...) code = extract_code_blocks(msg.content) |> Base.Fix2(join, "\n") + skip_unsafe && (code = remove_unsafe_lines(code; verbose)) return AICode(code; kwargs...) 
end @@ -181,6 +185,9 @@ end # Utility to pinpoint unavailable dependencies function detect_missing_packages(imports_required::AbstractVector{<:Symbol}) + # shortcut if no packages are required + isempty(imports_required) && return false, Symbol[] + # available_packages = Base.loaded_modules |> values .|> Symbol missing_packages = filter(pkg -> !in(pkg, available_packages), imports_required) if length(missing_packages) > 0 @@ -190,6 +197,20 @@ function detect_missing_packages(imports_required::AbstractVector{<:Symbol}) end end +"Iterates over the lines of a string and removes those that contain a package operation or a missing import." +function remove_unsafe_lines(code::AbstractString; verbose::Bool = false) + io = IOBuffer() + for line in readlines(IOBuffer(code)) + if !detect_pkg_operation(line) && + !detect_missing_packages(extract_julia_imports(line))[1] + println(io, line) + else + verbose && @info "Unsafe line removed: $line" + end + end + return String(take!(io)) +end + "Checks if a given string has a Julia prompt (`julia> `) at the beginning of a line." has_julia_prompt(s::T) where {T <: AbstractString} = occursin(r"(:?^julia> |^> )"m, s) diff --git a/test/code_generation.jl b/test/code_generation.jl index 56691b71d..0e69ef6f5 100644 --- a/test/code_generation.jl +++ b/test/code_generation.jl @@ -1,5 +1,6 @@ using PromptingTools: extract_julia_imports -using PromptingTools: detect_pkg_operation, detect_missing_packages, extract_function_name +using PromptingTools: detect_pkg_operation, + detect_missing_packages, extract_function_name, remove_unsafe_lines using PromptingTools: has_julia_prompt, remove_julia_prompt, extract_code_blocks, eval! 
using PromptingTools: escape_interpolation, find_subsequence_positions using PromptingTools: AICode, isparsed, isparseerror @@ -31,6 +32,20 @@ end @test detect_pkg_operation("import Pkg;") == false end +@testset "remove_unsafe_lines" begin + @test remove_unsafe_lines("Pkg.activate(\".\")") == "" + @test remove_unsafe_lines("Pkg.add(\"SomePkg\")") == "" + s = """ + a=1 + Pkg.add("a") + b=2 + Pkg.add("b") + using 12315456NotExisting + """ + @test remove_unsafe_lines(s) == "a=1\nb=2\n" + @test remove_unsafe_lines("Nothing"; verbose = true) == "Nothing\n" +end + @testset "has_julia_prompt" begin @test has_julia_prompt("julia> a=1") @test has_julia_prompt("> a=1") @@ -345,6 +360,22 @@ b=2 @test cb.stdout == "hello\nworld\n" @test cb.output.b == 2 end + # skip_unsafe=true + s = """ + + """ + let msg = AIMessage(""" + ```julia + a=1 + Pkg.add("a") + b=2 + Pkg.add("b") + using 12315456NotExisting + ``` + """) + cb = AICode(msg; skip_unsafe = true) + @test cb.code == "a=1\nb=2\n" + end # Methods - copy let msg = AIMessage(""" From 47f25ea14f4ba1b3d1a52275db6df423292d8b0a Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Tue, 12 Dec 2023 21:14:14 +0000 Subject: [PATCH 3/3] update tests --- src/code_generation.jl | 2 +- test/code_generation.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/code_generation.jl b/src/code_generation.jl index 578e99685..4a0e96b39 100644 --- a/src/code_generation.jl +++ b/src/code_generation.jl @@ -119,7 +119,7 @@ function Base.var"=="(c1::T, c2::T) where {T <: AICode} end function Base.show(io::IO, cb::AICode) success_str = cb.success === nothing ? "N/A" : titlecase(string(cb.success)) - expression_str = cb.expression === nothing ? "N/A" : "True" + expression_str = cb.expression === nothing ? "N/A" : titlecase(string(isparsed(cb))) stdout_str = cb.stdout === nothing ? "N/A" : "True" output_str = cb.output === nothing ? "N/A" : "True" error_str = cb.error === nothing ? 
"N/A" : "True" diff --git a/test/code_generation.jl b/test/code_generation.jl index 0e69ef6f5..6ec54cc44 100644 --- a/test/code_generation.jl +++ b/test/code_generation.jl @@ -414,7 +414,7 @@ end "AICode(Success: True, Parsed: True, Evaluated: True, Error Caught: N/A, StdOut: True, Code: 1 Lines)" # Test with error - code_block = AICode("error(\"Test Error\"))\nprint(\"\")") + code_block = AICode("error(\"Test Error\")\nprint(\"\")") buffer = IOBuffer() show(buffer, code_block) output = String(take!(buffer))