|
1546 | 1546 | (cadr expr) ;; eta reduce `x->f(x)` => `f`
|
1547 | 1547 | `(-> ,argname (block ,@splat ,expr)))))
|
1548 | 1548 |
|
| 1549 | +(define (getfield-field? x) ; whether x from (|.| f x) is a getfield call |
| 1550 | + (or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$))) |
| 1551 | + |
| 1552 | +;; fuse nested calls to f.(args...) into a single broadcast call |
| 1553 | +(define (expand-fuse-broadcast f args) |
| 1554 | + (define (fuse? e) (and (pair? e) (eq? (car e) 'fuse))) |
| 1555 | + (define (anyfuse? exprs) |
| 1556 | + (if (null? exprs) #f (if (fuse? (car exprs)) #t (anyfuse? (cdr exprs))))) |
| 1557 | + (define (to-lambda f args kwargs) ; convert f to anonymous function with hygienic tuple args |
| 1558 | + (define (genarg arg) (if (vararg? arg) (list '... (gensy)) (gensy))) |
| 1559 | + ; (To do: optimize the case where f is already an anonymous function, in which |
| 1560 | + ; case we only need to hygienicize the arguments? But it is quite tricky |
| 1561 | + ; to fully handle splatted args, typed args, keywords, etcetera. And probably |
| 1562 | + ; the extra function call is harmless because it will get inlined anyway.) |
| 1563 | + (let ((genargs (map genarg args))) ; hygienic formal parameters |
| 1564 | + (if (null? kwargs) |
| 1565 | + `(-> ,(cons 'tuple genargs) (call ,f ,@genargs)) ; no keyword args |
| 1566 | + `(-> ,(cons 'tuple genargs) (call ,f (parameters ,@kwargs) ,@genargs))))) |
| 1567 | + (define (from-lambda f) ; convert (-> (tuple args...) (call func args...)) back to func |
| 1568 | + (if (and (pair? f) (eq? (car f) '->) (pair? (cadr f)) (eq? (caadr f) 'tuple) |
| 1569 | + (pair? (caddr f)) (eq? (caaddr f) 'call) (equal? (cdadr f) (cdr (cdaddr f)))) |
| 1570 | + (car (cdaddr f)) |
| 1571 | + f)) |
| 1572 | + (define (fuse-args oldargs) ; replace (fuse f args) with args in oldargs list |
| 1573 | + (define (fargs newargs oldargs) |
| 1574 | + (if (null? oldargs) |
| 1575 | + newargs |
| 1576 | + (fargs (if (fuse? (car oldargs)) |
| 1577 | + (append (reverse (caddar oldargs)) newargs) |
| 1578 | + (cons (car oldargs) newargs)) |
| 1579 | + (cdr oldargs)))) |
| 1580 | + (reverse (fargs '() oldargs))) |
| 1581 | + (define (fuse-funcs f args) ; for (fuse g a) in args, merge/inline g into f |
| 1582 | + ; any argument A of f that is (fuse g a) gets replaced by let A=(body of g): |
| 1583 | + (define (fuse-lets fargs args lets) |
| 1584 | + (if (null? args) |
| 1585 | + lets |
| 1586 | + (if (fuse? (car args)) |
| 1587 | + (fuse-lets (cdr fargs) (cdr args) (cons (list '= (car fargs) (caddr (cadar args))) lets)) |
| 1588 | + (fuse-lets (cdr fargs) (cdr args) lets)))) |
| 1589 | + (let ((fargs (cdadr f)) |
| 1590 | + (fbody (caddr f))) |
| 1591 | + `(-> |
| 1592 | + (tuple ,@(fuse-args (map (lambda (oldarg arg) (if (fuse? arg) |
| 1593 | + `(fuse _ ,(cdadr (cadr arg))) |
| 1594 | + oldarg)) |
| 1595 | + fargs args))) |
| 1596 | + (let ,fbody ,@(reverse (fuse-lets fargs args '())))))) |
| 1597 | + (define (make-fuse f args) ; check for nested (fuse f args) exprs and combine |
| 1598 | + (define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args |
| 1599 | + (define (sk args kwargs pargs) |
| 1600 | + (if (null? args) |
| 1601 | + (cons kwargs pargs) |
| 1602 | + (if (kwarg? (car args)) |
| 1603 | + (sk (cdr args) (cons (car args) kwargs) pargs) |
| 1604 | + (sk (cdr args) kwargs (cons (car args) pargs))))) |
| 1605 | + (if (has-parameters? args) |
| 1606 | + (sk (reverse (cdr args)) (cdar args) '()) |
| 1607 | + (sk (reverse args) '() '()))) |
| 1608 | + (define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args) |
| 1609 | + (if (and (pair? e) (eq? (car e) '|.|) (not (getfield-field? (caddr e)))) |
| 1610 | + (make-fuse (cadr e) (cdaddr e)) |
| 1611 | + e)) |
| 1612 | + (let* ((kws.args (split-kwargs args)) |
| 1613 | + (kws (car kws.args)) |
| 1614 | + (args (cdr kws.args)) ; fusing occurs on positional args only |
| 1615 | + (args_ (map dot-to-fuse args))) |
| 1616 | + (if (anyfuse? args_) |
| 1617 | + `(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_)) |
| 1618 | + `(fuse ,(to-lambda f args kws) ,args_)))) |
| 1619 | + ; given e == (fuse lambda args), compress the argument list by removing (pure) |
| 1620 | + ; duplicates in args, inlining literals, and moving any varargs to the end: |
| 1621 | + (define (compress-fuse e) |
| 1622 | + (define (findfarg arg args fargs) ; for arg in args, return corresponding farg |
| 1623 | + (if (eq? arg (car args)) |
| 1624 | + (car fargs) |
| 1625 | + (findfarg arg (cdr args) (cdr fargs)))) |
| 1626 | + (let ((f (cadr e)) |
| 1627 | + (args (caddr e))) |
| 1628 | + (define (cf old-fargs old-args new-fargs new-args renames varfarg vararg) |
| 1629 | + (if (null? old-args) |
| 1630 | + (let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs))) |
| 1631 | + (nargs (if (null? vararg) new-args (cons vararg new-args)))) |
| 1632 | + `(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames)) |
| 1633 | + ,(reverse nargs))) |
| 1634 | + (let ((farg (car old-fargs)) (arg (car old-args))) |
| 1635 | + (cond |
| 1636 | + ((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument |
| 1637 | + (if (null? varfarg) |
| 1638 | + (cf (cdr old-fargs) (cdr old-args) |
| 1639 | + new-fargs new-args renames farg arg) |
| 1640 | + (if (eq? (cadr vararg) (cadr arg)) |
| 1641 | + (cf (cdr old-fargs) (cdr old-args) |
| 1642 | + new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames) |
| 1643 | + varfarg vararg) |
| 1644 | + (error "multiple splatted args cannot be fused into a single broadcast")))) |
| 1645 | + ((number? arg) ; inline numeric literals |
| 1646 | + (cf (cdr old-fargs) (cdr old-args) |
| 1647 | + new-fargs new-args |
| 1648 | + (cons (cons farg arg) renames) |
| 1649 | + varfarg vararg)) |
| 1650 | + ((and (symbol? arg) (memq arg new-args)) ; combine duplicate args |
| 1651 | + ; (note: calling memq for every arg is O(length(args)^2) ... |
| 1652 | + ; ... would be better to replace with a hash table if args is long) |
| 1653 | + (cf (cdr old-fargs) (cdr old-args) |
| 1654 | + new-fargs new-args |
| 1655 | + (cons (cons farg (findfarg arg new-args new-fargs)) renames) |
| 1656 | + varfarg vararg)) |
| 1657 | + (else |
| 1658 | + (cf (cdr old-fargs) (cdr old-args) |
| 1659 | + (cons farg new-fargs) (cons arg new-args) renames varfarg vararg)))))) |
| 1660 | + (cf (cdadr f) args '() '() '() '() '()))) |
| 1661 | + (let ((e (compress-fuse (make-fuse f args)))) ; an expression '(fuse func args) |
| 1662 | + (expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e))))) |
| 1663 | + |
1549 | 1664 | ;; table mapping expression head to a function expanding that form
|
1550 | 1665 | (define expand-table
|
1551 | 1666 | (table
|
|
1584 | 1699 | (lambda (e) ; e = (|.| f x)
|
1585 | 1700 | (let ((f (cadr e))
|
1586 | 1701 | (x (caddr e)))
|
1587 |
| - (if (or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$)) |
1588 |
| - `(call (core getfield) ,(expand-forms f) ,(expand-forms x)) |
1589 |
| - ; otherwise, came from f.(args...) --> broadcast(f, args...), |
1590 |
| - ; where x = (tuple args...) at this point: |
1591 |
| - (expand-forms `(call broadcast ,f ,@(cdr x)))))) |
| 1702 | + (if (getfield-field? x) |
| 1703 | + `(call (core getfield) ,(expand-forms f) ,(expand-forms x)) |
| 1704 | + ; otherwise, came from f.(args...) --> broadcast(f, args...), |
| 1705 | + ; where we want to fuse with any nested broadcast calls. |
| 1706 | + (expand-fuse-broadcast f (cdr x))))) |
1592 | 1707 |
|
1593 | 1708 | '|<:| syntactic-op-to-call
|
1594 | 1709 | '|>:| syntactic-op-to-call
|
|
0 commit comments