diff --git a/src/code_generation.jl b/src/code_generation.jl index a3357910c..11840addd 100644 --- a/src/code_generation.jl +++ b/src/code_generation.jl @@ -198,6 +198,51 @@ function remove_julia_prompt(s::T) where {T <: AbstractString} String(take!(io)) |> strip end +# escape dollar sign only if not preceeded by backslash already, ie, unescaped -- use negative lookbehind +# Useful in cases where we have double nested interpolation, eg, string code -> has string literal -> function with interpolation inside it +escape_interpolation(s::AbstractString) = replace(s, r"(? String(['\\', '$'])) + +""" + find_subsequence_positions(subseq, seq) -> Vector{Int} + +Find all positions of a subsequence `subseq` within a larger sequence `seq`. Used to lookup positions of code blocks in markdown. + +This function scans the sequence `seq` and identifies all starting positions where the subsequence `subseq` is found. Both `subseq` and `seq` should be vectors of integers, typically obtained using `codeunits` on strings. + +# Arguments +- `subseq`: A vector of integers representing the subsequence to search for. +- `seq`: A vector of integers representing the larger sequence in which to search. + +# Returns +- `Vector{Int}`: A vector of starting positions (1-based indices) where the subsequence is found in the sequence. + +# Examples +```julia +find_subsequence_positions(codeunits("ab"), codeunits("cababcab")) # Returns [2, 5] +``` +""" +function find_subsequence_positions(subseq, seq) + positions = Int[] + len_subseq = length(subseq) + len_seq = length(seq) + lim = len_seq - len_subseq + 1 + cur = 1 + while cur <= lim + match = true + @inbounds for i in 1:len_subseq + if seq[cur + i - 1] != subseq[i] + match = false + break + end + end + if match + push!(positions, cur) + end + cur += 1 + end + return positions +end + """ extract_code_blocks(markdown_content::String) -> Vector{String} @@ -243,17 +288,46 @@ extract_code_blocks(markdown_multiple) # Output: ["x = 5", "y = x + 2"] ``` """ -function extract_code_blocks(markdown_content::AbstractString) - # Define the pattern for Julia code blocks - pattern = r"```julia\n(.*?)\n```"s - - # Find all matches and extract the code - matches = eachmatch(pattern, markdown_content) +function extract_code_blocks(markdown_content::T) where {T <: AbstractString} + # Convert content and delimiters to codeunits + content_units = codeunits(markdown_content) + start_delim_units = codeunits("```julia") + end_delim_units = codeunits("```") + + # Find all starting and ending positions of code blocks + start_positions = find_subsequence_positions(start_delim_units, content_units) + end_positions = setdiff(find_subsequence_positions(end_delim_units, content_units), + start_positions) + unused_end_positions = trues(length(end_positions)) + + # Generate code block position pairs + block_positions = Tuple{Int, Int}[] + for start_pos in reverse(start_positions) + for (i, end_pos) in enumerate(end_positions) + if end_pos > start_pos && unused_end_positions[i] + push!(block_positions, (start_pos, end_pos)) + unused_end_positions[i] = false + break + end + end + end - # Extract and clean the code blocks (remove the julia prompt) - code_blocks = String[remove_julia_prompt(m.captures[1]) for m in matches] + # Filter out nested blocks (only if they have full overlap) + filtered_positions = filter(inner -> !any(outer -> (outer[1] < inner[1]) && + (inner[2] < outer[2]), + block_positions), + block_positions) + + # Extract code blocks + eltype_ = typeof(@view(markdown_content[begin:end])) + code_blocks = Vector{eltype_}() + for (start_pos, end_pos) in filtered_positions + code_block = markdown_content[(start_pos + length(start_delim_units)):(end_pos - 1)] + # Also remove the julia prompt + push!(code_blocks, remove_julia_prompt(strip(code_block))) + end - return code_blocks + return reverse(code_blocks) # Reverse to maintain original order end """ diff --git a/test/code_generation.jl b/test/code_generation.jl index 1f18e7d34..892f253e1 100644 --- a/test/code_generation.jl +++ b/test/code_generation.jl @@ -1,6 +1,7 @@ using PromptingTools: extract_julia_imports using PromptingTools: detect_pkg_operation, detect_missing_packages, extract_function_name using PromptingTools: has_julia_prompt, remove_julia_prompt, extract_code_blocks, eval! +using PromptingTools: escape_interpolation, find_subsequence_positions @testset "extract_imports tests" begin @test extract_julia_imports("using Test, LinearAlgebra") == @@ -71,6 +72,29 @@ a=\"\"\" \"\"\"""" end +@testset "escape_interpolation" begin + @test escape_interpolation("aaa") == "aaa" + @test escape_interpolation("\$") == String(['\\', '$']) +end + +@testset "find_subsequence_positions" begin + # Test 1: Basic functionality + @test find_subsequence_positions(codeunits("ab"), codeunits("cababcab")) == [2, 4, 7] + + # Test 2: Subsequence not in sequence + @test find_subsequence_positions(codeunits("xyz"), codeunits("hello")) == [] + + # Test 3: Empty subsequence -- should return all positions+1 + @test find_subsequence_positions(codeunits(""), codeunits("hello")) == 1:6 + + # Test 4: Subsequence longer than sequence + @test find_subsequence_positions(codeunits("longsubsequence"), codeunits("short")) == [] + + # Test 5: Repeated characters + @test find_subsequence_positions(codeunits("ana"), codeunits("banana")) == [2, 4] + @test find_subsequence_positions(codeunits("a"), codeunits("a"^6)) == 1:6 +end + @testset "extract_code_blocks" begin # Single Julia Code Block markdown_content = """ @@ -79,7 +103,8 @@ end println("Hello, World!") ``` """ - @test extract_code_blocks(markdown_content) == ["println(\"Hello, World!\")"] + @test extract_code_blocks(markdown_content) == + SubString{String}["println(\"Hello, World!\")"] # Multiple Julia Code Blocks markdown_content = """ @@ -92,7 +117,7 @@ end ``` """ @test extract_code_blocks(markdown_content) == - ["println(\"First Block\")", "println(\"Second Block\")"] + SubString{String}["println(\"First Block\")", "println(\"Second Block\")"] # No Julia Code Blocks markdown_content = """ @@ -109,9 +134,10 @@ end println("This is Julia") ``` """ - @test extract_code_blocks(markdown_content) == ["println(\"This is Julia\")"] + @test extract_code_blocks(markdown_content) == + SubString{String}["println(\"This is Julia\")"] - # Nested Code Blocks" + # Nested Blocks (plain block outer) markdown_content = """ ``` ```julia @@ -119,7 +145,28 @@ end ``` ``` """ - @test extract_code_blocks(markdown_content) == ["println(\"Nested Block\")"] + @test extract_code_blocks(markdown_content) == + SubString{String}["println(\"Nested Block\")"] + + # Nested Julia code blocks + markdown_example = """ + ```julia + # Outer Julia code block + + # An example of a nested Julia code block in markdown + \"\"\" + ```julia + x = 5 + println(x) + ``` + \"\"\" + + y = 10 + println(y) + ``` + """ + @test extract_code_blocks(markdown_example) == + SubString{String}["# Outer Julia code block\n\n# An example of a nested Julia code block in markdown\n\"\"\"\n```julia\nx = 5\nprintln(x)\n```\n\"\"\"\n\ny = 10\nprintln(y)"] end @testset "extract_function_name" begin