Skip to content

Commit 8fdaf91

Browse files
authored
Merge pull request #17300 from JuliaLang/dot-fusion
fusion of nested f.(args) calls into a single broadcast call
2 parents f47d9fe + fb8f1e1 commit 8fdaf91

File tree

4 files changed

+174
-6
lines changed

4 files changed

+174
-6
lines changed

NEWS.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ New language features
1010
* Generators and comprehensions support filtering using `if` ([#550]) and nested
1111
iteration using multiple `for` keywords ([#4867]).
1212

13-
* Broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]).
13+
* Broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]),
14+
and nested `f.(g.(args...))` calls are fused into a single `broadcast` loop ([#17300]).
1415

1516
* Macro expander functions are now generic, so macros can have multiple definitions
1617
(e.g. for different numbers of arguments, or optional arguments) ([#8846], [#9627]).
@@ -319,4 +320,5 @@ Deprecated or removed
319320
[#17037]: https://github.com/JuliaLang/julia/issues/17037
320321
[#17075]: https://github.com/JuliaLang/julia/issues/17075
321322
[#17266]: https://github.com/JuliaLang/julia/issues/17266
323+
[#17300]: https://github.com/JuliaLang/julia/issues/17300
322324
[#17374]: https://github.com/JuliaLang/julia/issues/17374

doc/manual/functions.rst

+16
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,22 @@ then ``f.(pi,A)`` will return a new array consisting of ``f(pi,a)`` for each
640640
consisting of ``f(vector1[i],vector2[i])`` for each index ``i``
641641
(throwing an exception if the vectors have different length).
642642

643+
Moreover, *nested* ``f.(args...)`` calls are *fused* into a single ``broadcast``
644+
loop. For example, ``sin.(cos.(X))`` is equivalent to ``broadcast(x -> sin(cos(x)), X)``,
645+
similar to ``[sin(cos(x)) for x in X]``: there is only a single loop over ``X``,
646+
and a single array is allocated for the result. [In contrast, ``sin(cos(X))``
647+
in a typical "vectorized" language would first allocate one temporary array for ``tmp=cos(X)``,
648+
and then compute ``sin(tmp)`` in a separate loop, allocating a second array.]
649+
This loop fusion is not a compiler optimization that may or may not occur; it
650+
is a *syntactic guarantee* whenever nested ``f.(args...)`` calls are encountered. Technically,
651+
the fusion stops as soon as a "non-dot" function is encountered; for example,
652+
in ``sin.(sort(cos.(X)))`` the ``sin`` and ``cos`` loops cannot be merged
653+
because of the intervening ``sort`` function.
654+
655+
(In future versions of Julia, operators like ``.*`` will also be handled with
656+
the same mechanism: they will be equivalent to ``broadcast`` calls and
657+
will be fused with other nested "dot" calls.)
658+
643659
Further Reading
644660
---------------
645661

src/julia-syntax.scm

+120-5
Original file line numberDiff line numberDiff line change
@@ -1546,6 +1546,121 @@
15461546
(cadr expr) ;; eta reduce `x->f(x)` => `f`
15471547
`(-> ,argname (block ,@splat ,expr)))))
15481548

1549+
;; Predicate separating field access from a dot call: it is true when the head of x
;; is quote, inert, or $.  When it is false, the '|.| expander treats the form as
;; f.(args...), with x being (tuple args...).
(define (getfield-field? x) ; whether x from (|.| f x) is a getfield call
1550+
(or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$)))
1551+
1552+
;; fuse nested calls to f.(args...) into a single broadcast call
;;
;; f is the function expression and args the argument list of the outermost dot call.
;; The work is done in three steps (see the final `let` of this definition):
;;   1. make-fuse builds an intermediate (fuse lambda args) form, recursively merging
;;      any positional argument that is itself a dot call into the lambda;
;;   2. compress-fuse simplifies that form's argument list;
;;   3. the result is lowered with expand-forms as (call broadcast func args...).
1553+
(define (expand-fuse-broadcast f args)
; fuse? recognizes the intermediate (fuse lambda args) form used by all helpers below.
1554+
(define (fuse? e) (and (pair? e) (eq? (car e) 'fuse)))
; anyfuse? is true when at least one element of exprs is a fuse form.
1555+
(define (anyfuse? exprs)
1556+
(if (null? exprs) #f (if (fuse? (car exprs)) #t (anyfuse? (cdr exprs)))))
; The result has the shape (-> (tuple g1 g2 ...) (call f g1 g2 ...)), with one fresh
; gensym per positional argument.  Keyword arguments are not turned into formals;
; they are passed through unchanged inside a (parameters ...) block.
1557+
(define (to-lambda f args kwargs) ; convert f to anonymous function with hygienic tuple args
; a splatted argument gets a splatted formal, (... gensym), so it stays a vararg
1558+
(define (genarg arg) (if (vararg? arg) (list '... (gensy)) (gensy)))
1559+
; (To do: optimize the case where f is already an anonymous function, in which
1560+
; case we only need to hygienicize the arguments? But it is quite tricky
1561+
; to fully handle splatted args, typed args, keywords, etcetera. And probably
1562+
; the extra function call is harmless because it will get inlined anyway.)
1563+
(let ((genargs (map genarg args))) ; hygienic formal parameters
1564+
(if (null? kwargs)
1565+
`(-> ,(cons 'tuple genargs) (call ,f ,@genargs)) ; no keyword args
1566+
`(-> ,(cons 'tuple genargs) (call ,f (parameters ,@kwargs) ,@genargs)))))
; Inverse of to-lambda (eta reduction).  It applies only when the lambda's formal
; parameter list is `equal?` to the argument list of the call in its body; any other
; f (e.g. one with a parameters block, inlined literals or merged arguments) is
; returned unchanged.
1567+
(define (from-lambda f) ; convert (-> (tuple args...) (call func args...)) back to func
1568+
(if (and (pair? f) (eq? (car f) '->) (pair? (cadr f)) (eq? (caadr f) 'tuple)
1569+
(pair? (caddr f)) (eq? (caaddr f) 'call) (equal? (cdadr f) (cdr (cdaddr f))))
1570+
(car (cdaddr f))
1571+
f))
; Non-fuse elements are kept as they are; each (fuse f args) element is replaced
; in place by its own args, so the result is a flat list in the original order.
1572+
(define (fuse-args oldargs) ; replace (fuse f args) with args in oldargs list
; fargs accumulates the output in reverse order; the caller reverses it at the end
1573+
(define (fargs newargs oldargs)
1574+
(if (null? oldargs)
1575+
newargs
1576+
(fargs (if (fuse? (car oldargs))
1577+
(append (reverse (caddar oldargs)) newargs)
1578+
(cons (car oldargs) newargs))
1579+
(cdr oldargs))))
1580+
(reverse (fargs '() oldargs)))
; f is expected to be a lambda of the form (-> (tuple fargs...) body), as produced
; by to-lambda; args are the actual arguments, some of which are fuse forms.
1581+
(define (fuse-funcs f args) ; for (fuse g a) in args, merge/inline g into f
1582+
; any argument A of f that is (fuse g a) gets replaced by let A=(body of g):
; (fargs and args are walked in lockstep; lets is accumulated in reverse)
1583+
(define (fuse-lets fargs args lets)
1584+
(if (null? args)
1585+
lets
1586+
(if (fuse? (car args))
1587+
(fuse-lets (cdr fargs) (cdr args) (cons (list '= (car fargs) (caddr (cadar args))) lets))
1588+
(fuse-lets (cdr fargs) (cdr args) lets))))
1589+
(let ((fargs (cdadr f))
1590+
(fbody (caddr f)))
; New formal list: a formal whose actual argument is a fuse form is replaced by the
; formals of the nested lambda.  This reuses fuse-args by wrapping those formals in
; a dummy (fuse _ formals) form.  The new body is the old body wrapped in a let that
; binds each replaced formal to the nested lambda's body.
1590+
`(->
1592+
(tuple ,@(fuse-args (map (lambda (oldarg arg) (if (fuse? arg)
1593+
`(fuse _ ,(cdadr (cadr arg)))
1594+
oldarg))
1595+
fargs args)))
1596+
(let ,fbody ,@(reverse (fuse-lets fargs args '()))))))
; make-fuse always returns a (fuse lambda args) form.  It is mutually recursive with
; dot-to-fuse, so arbitrarily deep nesting of dot calls is merged bottom-up.
1597+
(define (make-fuse f args) ; check for nested (fuse f args) exprs and combine
1598+
(define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args
; sk scans a reversed argument list and conses onto its accumulators, which
; restores the original relative order in both output lists
1599+
(define (sk args kwargs pargs)
1600+
(if (null? args)
1601+
(cons kwargs pargs)
1602+
(if (kwarg? (car args))
1603+
(sk (cdr args) (cons (car args) kwargs) pargs)
1604+
(sk (cdr args) kwargs (cons (car args) pargs)))))
; When has-parameters? holds, the tail of the first element of args seeds the keyword
; list and only the remaining elements are scanned.  (has-parameters? is defined
; elsewhere; presumably it tests for a leading (parameters ...) block.)
1605+
(if (has-parameters? args)
1606+
(sk (reverse (cdr args)) (cdar args) '())
1607+
(sk (reverse args) '() '())))
; field-access forms (see getfield-field?) and all non-dot expressions are returned
; unchanged, which is what stops fusion at a "non-dot" call
1608+
(define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args)
1609+
(if (and (pair? e) (eq? (car e) '|.|) (not (getfield-field? (caddr e))))
1610+
(make-fuse (cadr e) (cdaddr e))
1611+
e))
1612+
(let* ((kws.args (split-kwargs args))
1613+
(kws (car kws.args))
1614+
(args (cdr kws.args)) ; fusing occurs on positional args only
1615+
(args_ (map dot-to-fuse args)))
1616+
(if (anyfuse? args_)
1617+
`(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_))
1618+
`(fuse ,(to-lambda f args kws) ,args_))))
1619+
; given e == (fuse lambda args), compress the argument list by removing (pure)
1620+
; duplicates in args, inlining literals, and moving any varargs to the end:
; (only arguments that are symbols are treated as duplicates, and only numbers are
; inlined; every other argument expression is kept as a separate broadcast argument)
1621+
(define (compress-fuse e)
; args and fargs are parallel lists; arg is compared with eq? and is assumed to be
; present in args (the only caller checks this with memq first)
1622+
(define (findfarg arg args fargs) ; for arg in args, return corresponding farg
1623+
(if (eq? arg (car args))
1624+
(car fargs)
1625+
(findfarg arg (cdr args) (cdr fargs))))
1626+
(let ((f (cadr e))
1627+
(args (caddr e)))
; cf walks the formals (old-fargs) and actuals (old-args) in lockstep.
;   new-fargs / new-args : formals and actuals kept so far, in reverse order
;   renames              : alist of (formal . replacement) applied to the lambda body
;                          by replace-vars once the walk is finished
;   varfarg / vararg     : the single splatted formal/actual seen so far, or '()
1628+
(define (cf old-fargs old-args new-fargs new-args renames varfarg vararg)
1629+
(if (null? old-args)
; done: consing the vararg pair onto the reversed lists puts it last after reversal
1630+
(let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs)))
1631+
(nargs (if (null? vararg) new-args (cons vararg new-args))))
1632+
`(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames))
1633+
,(reverse nargs)))
1634+
(let ((farg (car old-fargs)) (arg (car old-args)))
1635+
(cond
1636+
((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument
1637+
(if (null? varfarg)
; first splat encountered: remember it instead of adding it to the new lists
1638+
(cf (cdr old-fargs) (cdr old-args)
1639+
new-fargs new-args renames farg arg)
; a later splat is accepted only if it splats the same (eq?) expression as the
; first one, in which case its formal is renamed to the first splat's formal
1640+
(if (eq? (cadr vararg) (cadr arg))
1641+
(cf (cdr old-fargs) (cdr old-args)
1642+
new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames)
1643+
varfarg vararg)
1644+
(error "multiple splatted args cannot be fused into a single broadcast"))))
1645+
((number? arg) ; inline numeric literals
1646+
(cf (cdr old-fargs) (cdr old-args)
1647+
new-fargs new-args
1648+
(cons (cons farg arg) renames)
1649+
varfarg vararg))
1650+
((and (symbol? arg) (memq arg new-args)) ; combine duplicate args
1651+
; (note: calling memq for every arg is O(length(args)^2) ...
1652+
; ... would be better to replace with a hash table if args is long)
1653+
(cf (cdr old-fargs) (cdr old-args)
1654+
new-fargs new-args
1655+
(cons (cons farg (findfarg arg new-args new-fargs)) renames)
1656+
varfarg vararg))
; any other argument is kept as its own formal/actual pair
1657+
(else
1658+
(cf (cdr old-fargs) (cdr old-args)
1659+
(cons farg new-fargs) (cons arg new-args) renames varfarg vararg))))))
1660+
(cf (cdadr f) args '() '() '() '() '())))
; from-lambda turns a trivial wrapper lambda back into the plain function, so an
; unfused f.(args...) still lowers to (call broadcast f args...)
1661+
(let ((e (compress-fuse (make-fuse f args)))) ; an expression '(fuse func args)
1662+
(expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e)))))
1663+
15491664
;; table mapping expression head to a function expanding that form
15501665
(define expand-table
15511666
(table
@@ -1584,11 +1699,11 @@
15841699
(lambda (e) ; e = (|.| f x)
15851700
(let ((f (cadr e))
15861701
(x (caddr e)))
1587-
(if (or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$))
1588-
`(call (core getfield) ,(expand-forms f) ,(expand-forms x))
1589-
; otherwise, came from f.(args...) --> broadcast(f, args...),
1590-
; where x = (tuple args...) at this point:
1591-
(expand-forms `(call broadcast ,f ,@(cdr x))))))
1702+
(if (getfield-field? x)
1703+
`(call (core getfield) ,(expand-forms f) ,(expand-forms x))
1704+
; otherwise, came from f.(args...) --> broadcast(f, args...),
1705+
; where we want to fuse with any nested broadcast calls.
1706+
(expand-fuse-broadcast f (cdr x)))))
15921707

15931708
'|<:| syntactic-op-to-call
15941709
'|>:| syntactic-op-to-call

test/broadcast.jl

+35
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,41 @@ let a = sin.([1, 2])
212212
@test a ≈ [0.8414709848078965, 0.9092974268256817]
213213
end
214214

215+
# PR #17300: loop fusion
# Nested dot calls must produce the same values as the equivalent unfused calls.
216+
@test (x->x+1).((x->x+2).((x->x+3).(1:10))) == collect(7:16)
217+
let A = [sqrt(i)+j for i = 1:3, j=1:4]
# the second argument, sum(A,1), is a non-dot call and is passed as an ordinary argument
218+
@test atan2.(log.(A), sum(A,1)) == broadcast(atan2, broadcast(log, A), sum(A, 1))
219+
end
220+
let x = sin.(1:10)
# the same variable appears as the argument of two nested dot calls
221+
@test atan2.((x->x+1).(x), (x->x+2).(x)) == atan2(x+1, x+2) == atan2(x.+1, x.+2)
# splatted array literal as the argument list of a nested dot call
222+
@test sin.(atan2.([x+1,x+2]...)) == sin.(atan2.(x+1,x+2))
# numeric literal argument next to an array argument
223+
@test sin.(atan2.(x, 3.7)) == broadcast(x -> sin(atan2(x,3.7)), x)
224+
@test atan2.(x, 3.7) == broadcast(x -> atan2(x,3.7), x) == broadcast(atan2, x, 3.7)
225+
end
226+
# Use side effects to check for loop fusion. Note that, due to #17314,
227+
# a broadcasted function is currently called an extra time with an argument 1.
228+
let g = Int[]
# f17300 records every argument it is called with in g
229+
f17300(x) = begin; push!(g, x); x+1; end
230+
f17300.(f17300.(f17300.(1:3)))
# Expected order: for each element the three nested calls run back to back (x, x+1, x+2).
# One of the two leading 1,2,3 runs comes from the extra call mentioned above.
231+
@test g == [1,2,3, 1,2,3, 2,3,4, 3,4,5]
232+
end
233+
# fusion with splatted args:
234+
let x = sin.(1:10), a = [x]
235+
@test cos.(x) == cos.(a...)
# the same symbol splatted twice, and a splatted array literal
236+
@test atan2.(x,x) == atan2.(a..., a...) == atan2.([x, x]...)
237+
@test atan2.(x, cos.(x)) == atan2.(a..., cos.(x)) == atan2(x, cos.(a...)) == atan2(a..., cos.(a...))
# anonymous functions that themselves take varargs
238+
@test ((args...)->cos(args[1])).(x) == cos.(x) == ((y,args...)->cos(y)).(x)
239+
end
# dot calls with only scalar arguments, and with no arguments at all
240+
@test atan2.(3,4) == atan2(3,4) == (() -> atan2(3,4)).()
241+
# fusion with keyword args:
242+
let x = [1:4;]
243+
f17300kw(x; y=0) = x + y
244+
@test f17300kw.(x) == x
# keyword given after a comma, after a semicolon, and as a splatted (key, value) list
245+
@test f17300kw.(x, y=1) == f17300kw.(x; y=1) == f17300kw.(x; [(:y,1)]...) == x .+ 1
# keyword call as the outer function of a nested dot call ...
246+
@test f17300kw.(sin.(x), y=1) == f17300kw.(sin.(x); y=1) == sin.(x) .+ 1
# ... and as the inner one
247+
@test sin.(f17300kw.(x, y=1)) == sin.(f17300kw.(x; y=1)) == sin.(x .+ 1)
248+
end
249+
215250
# PR 16988
216251
@test Base.promote_op(+, Bool) === Int
217252
@test isa(broadcast(+, [true]), Array{Int,1})

0 commit comments

Comments
 (0)