diff --git a/base/abstractarray.jl b/base/abstractarray.jl index a880b4eecfd3f..4326b999a6930 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1793,6 +1793,10 @@ typed_hcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(2)) # 2d horizontal and vertical concatenation +# these are produced in lowering if splatting occurs inside hvcat +hvcat_rows(rows::Tuple...) = hvcat(map(length, rows), (rows...)...) +typed_hvcat_rows(T::Type, rows::Tuple...) = typed_hvcat(T, map(length, rows), (rows...)...) + function hvcat(nbc::Integer, as...) # nbc = # of block columns n = length(as) diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index ac798109ed1d5..6329293e3a795 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -1950,6 +1950,34 @@ ,@(map expand-forms (cddr e)))) (cons (car e) (map expand-forms (cdr e)))))) +(define (expand-vcat e + (vcat '((top vcat))) + (hvcat '((top hvcat))) + (hvcat_rows '((top hvcat_rows)))) + (let ((a (cdr e))) + (if (any assignment? a) + (error (string "misplaced assignment statement in \"" (deparse e) "\""))) + (if (has-parameters? a) + (error "unexpected semicolon in array expression") + (expand-forms + (if (any (lambda (x) + (and (pair? x) (eq? (car x) 'row))) + a) + ;; convert nested hcat inside vcat to hvcat + (let ((rows (map (lambda (x) + (if (and (pair? x) (eq? (car x) 'row)) + (cdr x) + (list x))) + a))) + ;; in case there is splatting inside `hvcat`, collect each row as a + ;; separate tuple and pass those to `hvcat_rows` instead (ref #38844) + (if (any (lambda (row) (any vararg? row)) rows) + `(call ,@hvcat_rows ,@(map (lambda (x) `(tuple ,@x)) rows)) + `(call ,@hvcat + (tuple ,@(map length rows)) + ,@(apply append rows)))) + `(call ,@vcat ,@a)))))) + (define (expand-tuple-destruct lhss x) (define (sides-match? l r) ;; l and r either have equal lengths, or r has a trailing ... @@ -2449,27 +2477,7 @@ (error (string "misplaced assignment statement in \"" (deparse e) "\""))) (expand-forms `(call (top hcat) ,@(cdr e)))) - 'vcat - (lambda (e) - (let ((a (cdr e))) - (if (any assignment? a) - (error (string "misplaced assignment statement in \"" (deparse e) "\""))) - (if (has-parameters? a) - (error "unexpected semicolon in array expression") - (expand-forms - (if (any (lambda (x) - (and (pair? x) (eq? (car x) 'row))) - a) - ;; convert nested hcat inside vcat to hvcat - (let ((rows (map (lambda (x) - (if (and (pair? x) (eq? (car x) 'row)) - (cdr x) - (list x))) - a))) - `(call (top hvcat) - (tuple ,.(map length rows)) - ,.(apply append rows))) - `(call (top vcat) ,@a)))))) + 'vcat expand-vcat 'typed_hcat (lambda (e) @@ -2480,23 +2488,8 @@ 'typed_vcat (lambda (e) (let ((t (cadr e)) - (a (cddr e))) - (if (any assignment? (cddr e)) - (error (string "misplaced assignment statement in \"" (deparse e) "\""))) - (expand-forms - (if (any (lambda (x) - (and (pair? x) (eq? (car x) 'row))) - a) - ;; convert nested hcat inside vcat to hvcat - (let ((rows (map (lambda (x) - (if (and (pair? x) (eq? (car x) 'row)) - (cdr x) - (list x))) - a))) - `(call (top typed_hvcat) ,t - (tuple ,.(map length rows)) - ,.(apply append rows))) - `(call (top typed_vcat) ,t ,@a))))) + (e (cdr e))) + (expand-vcat e `((top typed_vcat) ,t) `((top typed_hvcat) ,t) `((top typed_hvcat_rows) ,t)))) '|'| (lambda (e) (expand-forms `(call |'| ,(cadr e)))) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 52af916acbdac..dd24dc28364c7 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1269,3 +1269,14 @@ Base.pushfirst!(tpa::TestPushArray{T}, a::T) where T = pushfirst!(tpa.data, a) pushfirst!(tpa, 6, 5, 4, 3, 2) @test tpa.data == reverse(collect(1:6)) end + +@testset "splatting into hvcat" begin + t = (1, 2) + @test [t...; 3 4] == [1 2; 3 4] + @test [0 t...; t... 0] == [0 1 2; 1 2 0] + @test_throws ArgumentError [t...; 3 4 5] + + @test Int[t...; 3 4] == [1 2; 3 4] + @test Int[0 t...; t... 0] == [0 1 2; 1 2 0] + @test_throws ArgumentError Int[t...; 3 4 5] +end