Skip to content

Commit

Permalink
allow splatting in hvcat syntax (#39249)
Browse files Browse the repository at this point in the history
  • Loading branch information
simeonschaub authored Feb 5, 2021
1 parent 0ea19e4 commit 7d34b0d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 38 deletions.
4 changes: 4 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1793,6 +1793,10 @@ typed_hcat(T::Type, A::AbstractArray...) = cat_t(T, A...; dims=Val(2))

# 2d horizontal and vertical concatenation

# these are produced in lowering if splatting occurs inside hvcat
hvcat_rows(rows::Tuple...) = hvcat(map(length, rows), (rows...)...)
typed_hvcat_rows(T::Type, rows::Tuple...) = typed_hvcat(T, map(length, rows), (rows...)...)

function hvcat(nbc::Integer, as...)
# nbc = # of block columns
n = length(as)
Expand Down
69 changes: 31 additions & 38 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,34 @@
,@(map expand-forms (cddr e))))
(cons (car e) (map expand-forms (cdr e))))))

(define (expand-vcat e
(vcat '((top vcat)))
(hvcat '((top hvcat)))
(hvcat_rows '((top hvcat_rows))))
(let ((a (cdr e)))
(if (any assignment? a)
(error (string "misplaced assignment statement in \"" (deparse e) "\"")))
(if (has-parameters? a)
(error "unexpected semicolon in array expression")
(expand-forms
(if (any (lambda (x)
(and (pair? x) (eq? (car x) 'row)))
a)
;; convert nested hcat inside vcat to hvcat
(let ((rows (map (lambda (x)
(if (and (pair? x) (eq? (car x) 'row))
(cdr x)
(list x)))
a)))
;; in case there is splatting inside `hvcat`, collect each row as a
;; separate tuple and pass those to `hvcat_rows` instead (ref #38844)
(if (any (lambda (row) (any vararg? row)) rows)
`(call ,@hvcat_rows ,@(map (lambda (x) `(tuple ,@x)) rows))
`(call ,@hvcat
(tuple ,@(map length rows))
,@(apply append rows))))
`(call ,@vcat ,@a))))))

(define (expand-tuple-destruct lhss x)
(define (sides-match? l r)
;; l and r either have equal lengths, or r has a trailing ...
Expand Down Expand Up @@ -2449,27 +2477,7 @@
(error (string "misplaced assignment statement in \"" (deparse e) "\"")))
(expand-forms `(call (top hcat) ,@(cdr e))))

'vcat
(lambda (e)
(let ((a (cdr e)))
(if (any assignment? a)
(error (string "misplaced assignment statement in \"" (deparse e) "\"")))
(if (has-parameters? a)
(error "unexpected semicolon in array expression")
(expand-forms
(if (any (lambda (x)
(and (pair? x) (eq? (car x) 'row)))
a)
;; convert nested hcat inside vcat to hvcat
(let ((rows (map (lambda (x)
(if (and (pair? x) (eq? (car x) 'row))
(cdr x)
(list x)))
a)))
`(call (top hvcat)
(tuple ,.(map length rows))
,.(apply append rows)))
`(call (top vcat) ,@a))))))
'vcat expand-vcat

'typed_hcat
(lambda (e)
Expand All @@ -2480,23 +2488,8 @@
'typed_vcat
(lambda (e)
(let ((t (cadr e))
(a (cddr e)))
(if (any assignment? (cddr e))
(error (string "misplaced assignment statement in \"" (deparse e) "\"")))
(expand-forms
(if (any (lambda (x)
(and (pair? x) (eq? (car x) 'row)))
a)
;; convert nested hcat inside vcat to hvcat
(let ((rows (map (lambda (x)
(if (and (pair? x) (eq? (car x) 'row))
(cdr x)
(list x)))
a)))
`(call (top typed_hvcat) ,t
(tuple ,.(map length rows))
,.(apply append rows)))
`(call (top typed_vcat) ,t ,@a)))))
(e (cdr e)))
(expand-vcat e `((top typed_vcat) ,t) `((top typed_hvcat) ,t) `((top typed_hvcat_rows) ,t))))

'|'| (lambda (e) (expand-forms `(call |'| ,(cadr e))))

Expand Down
11 changes: 11 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1269,3 +1269,14 @@ Base.pushfirst!(tpa::TestPushArray{T}, a::T) where T = pushfirst!(tpa.data, a)
pushfirst!(tpa, 6, 5, 4, 3, 2)
@test tpa.data == reverse(collect(1:6))
end

@testset "splatting into hvcat" begin
t = (1, 2)
@test [t...; 3 4] == [1 2; 3 4]
@test [0 t...; t... 0] == [0 1 2; 1 2 0]
@test_throws ArgumentError [t...; 3 4 5]

@test Int[t...; 3 4] == [1 2; 3 4]
@test Int[0 t...; t... 0] == [0 1 2; 1 2 0]
@test_throws ArgumentError Int[t...; 3 4 5]
end

0 comments on commit 7d34b0d

Please sign in to comment.