From f06f5561138bc68884520367305a785fe70d95d0 Mon Sep 17 00:00:00 2001 From: "supritsj@Arch" Date: Thu, 13 Mar 2025 21:52:49 +0530 Subject: [PATCH 01/11] add interface and procedures --- src/stdlib_intrinsics.fypp | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/stdlib_intrinsics.fypp b/src/stdlib_intrinsics.fypp index b2c16a5a6..ecd273173 100644 --- a/src/stdlib_intrinsics.fypp +++ b/src/stdlib_intrinsics.fypp @@ -146,6 +146,38 @@ module stdlib_intrinsics #:endfor end interface public :: kahan_kernel + + interface stdlib_matmul + !! version: experimental + !! + !!### Summary + !! compute the matrix multiplication of more than two matrices with a single function call. + !! ([Specification](../page/specs/stdlib_intrinsics.html#stdlib_matmul)) + !! + !!### Description + !! + !! matrix multiply more than two matrices with a single function call + !! the multiplication with the optimal bracketization is done automatically + !! Supported data types are `real`, `integer` and `complex`. + !! + #:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES + pure module function stdlib_matmul_${s}$_3 (a, b, c) result(d) + ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:) + ${t}$, allocatable :: d(:,:) + end function stdlib_matmul_${s}$_3 + + pure module function stdlib_matmul_${s}$_4 (a, b, c, d) result(e) + ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:) + ${t}$, allocatable :: e(:,:) + end function stdlib_matmul_${s}$_4 + + pure module function stdlib_matmul_${s}$_5 (a, b, c, d, e) result(f) + ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:), e(:,:) + ${t}$, allocatable :: f(:,:) + end function stdlib_matmul_${s}$_5 + #:endfor + end interface stdlib_matmul + public :: stdlib_matmul contains From fed4d73965d7a2cab7db89e2b88b2e7debc8f2b0 Mon Sep 17 00:00:00 2001 From: "supritsj@Arch" Date: Thu, 13 Mar 2025 21:53:31 +0530 Subject: [PATCH 02/11] add implementation for 3,4,5 matrices --- src/CMakeLists.txt | 5 +- src/stdlib_intrinsics_matmul.fypp | 119 ++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 2 deletions(-) create mode 100644 src/stdlib_intrinsics_matmul.fypp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 933da34de..acfe315e8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -19,6 +19,7 @@ set(fppFiles stdlib_hash_64bit_spookyv2.fypp stdlib_intrinsics_dot_product.fypp stdlib_intrinsics_sum.fypp + stdlib_intrinsics_matmul.fypp stdlib_intrinsics.fypp stdlib_io.fypp stdlib_io_npy.fypp @@ -32,14 +33,14 @@ set(fppFiles stdlib_linalg_kronecker.fypp stdlib_linalg_cross_product.fypp stdlib_linalg_eigenvalues.fypp - stdlib_linalg_solve.fypp + stdlib_linalg_solve.fypp stdlib_linalg_determinant.fypp stdlib_linalg_qr.fypp stdlib_linalg_inverse.fypp stdlib_linalg_pinv.fypp stdlib_linalg_norms.fypp stdlib_linalg_state.fypp - stdlib_linalg_svd.fypp + stdlib_linalg_svd.fypp stdlib_linalg_cholesky.fypp stdlib_linalg_schur.fypp stdlib_optval.fypp diff --git a/src/stdlib_intrinsics_matmul.fypp b/src/stdlib_intrinsics_matmul.fypp new file mode 100644 index 000000000..240faecf8 --- /dev/null +++ b/src/stdlib_intrinsics_matmul.fypp @@ -0,0 +1,119 @@ +#:include "common.fypp" +#:set I_KINDS_TYPES = list(zip(INT_KINDS, INT_TYPES, INT_KINDS)) +#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX)) +#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX)) + +submodule (stdlib_intrinsics) stdlib_intrinsics_matmul + implicit none + +contains + + ! Algorithm for the optimal bracketization of matrices + ! Reference: Cormen, "Introduction to Algorithms", 4ed, ch-14, section-2 + ! Internal use only! + pure function matmul_chain_order(n, p) result(s) + integer, intent(in) :: n, p(:) + integer :: s(1:n-1, 2:n), m(1:n, 1:n), l, i, j, k, q + m(:,:) = 0 + s(:,:) = 0 + + do l = 2, n + do i = 1, n - l + 1 + j = i + l - 1 + m(i,j) = huge(1) + + do k = i, j - 1 + q = m(i,k) + m(k+1,j) + p(i)*p(k+1)*p(j+1) + + if (q < m(i, j)) then + m(i,j) = q + s(i,j) = k + end if + end do + end do + end do + end function matmul_chain_order + +#:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES + + pure module function stdlib_matmul_${s}$_3 (a, b, c) result(d) + ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:) + ${t}$, allocatable :: d(:,:) + integer :: sa(2), sb(2), sc(2), cost1, cost2 + sa = shape(a) + sb = shape(b) + sc = shape(c) + + if ((sa(2) /= sb(1)) .or. (sb(2) /= sc(1))) then + error stop "stdlib_matmul: Incompatible array shapes" + end if + + ! computes the cost (number of scalar multiplications required) + ! cost(A, B) = shape(A)(1) * shape(A)(2) * shape(B)(2) + cost1 = sa(1) * sa(2) * sb(2) + sa(1) * sb(2) * sc(2) ! ((AB)C) + cost2 = sb(1) * sb(2) * sc(2) + sa(1) * sa(2) * sc(2) ! (A(BC)) + + if (cost1 < cost2) then + d = matmul(matmul(a, b), c) + else + d = matmul(a, matmul(b, c)) + end if + end function stdlib_matmul_${s}$_3 + + pure module function stdlib_matmul_${s}$_4 (a, b, c, d) result(e) + ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:) + ${t}$, allocatable :: e(:,:) + integer :: p(5), i + integer :: s(3,2:4) + + p(1) = size(a, 1) + p(2) = size(b, 1) + p(3) = size(c, 1) + p(4) = size(d, 1) + p(5) = size(d, 2) + + s = matmul_chain_order(4, p) + + select case (s(1,4)) + case (1) + e = matmul(a, stdlib_matmul(b, c, d)) + case (2) + e = matmul(matmul(a, b), matmul(c, d)) + case (3) + e = matmul(stdlib_matmul(a, b ,c), d) + case default + error stop "stdlib_matmul: unexpected error unexpected s(i,j)" + end select + end function stdlib_matmul_${s}$_4 + + pure module function stdlib_matmul_${s}$_5 (a, b, c, d, e) result(f) + ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:), e(:,:) + ${t}$, allocatable :: f(:,:) + integer :: p(6), i + integer :: s(4,2:5) + + p(1) = size(a, 1) + p(2) = size(b, 1) + p(3) = size(c, 1) + p(4) = size(d, 1) + p(5) = size(e, 1) + p(6) = size(e, 2) + + s = matmul_chain_order(5, p) + + select case (s(1,5)) + case (1) + f = matmul(a, stdlib_matmul(b, c, d, e)) + case (2) + f = matmul(matmul(a, b), stdlib_matmul(c, d, e)) + case (3) + f = matmul(stdlib_matmul(a, b ,c), matmul(d, e)) + case (4) + f = matmul(stdlib_matmul(a, b, c, d), e) + case default + error stop "stdlib_matmul: unexpected error unexpected s(i,j)" + end select + end function stdlib_matmul_${s}$_5 + +#:endfor +end submodule stdlib_intrinsics_matmul From 27911ae88ad7d18a337fb0d6b1ef807b7980ce20 Mon Sep 17 00:00:00 2001 From: "supritsj@Arch" Date: Thu, 13 Mar 2025 21:54:02 +0530 Subject: [PATCH 03/11] add very basic example --- example/intrinsics/CMakeLists.txt | 3 ++- example/intrinsics/example_matmul.f90 | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 example/intrinsics/example_matmul.f90 diff --git a/example/intrinsics/CMakeLists.txt b/example/intrinsics/CMakeLists.txt index 1645ba8a1..162744b66 100644 --- a/example/intrinsics/CMakeLists.txt +++ b/example/intrinsics/CMakeLists.txt @@ -1,2 +1,3 @@ ADD_EXAMPLE(sum) -ADD_EXAMPLE(dot_product) \ No newline at end of file +ADD_EXAMPLE(dot_product) +ADD_EXAMPLE(matmul) diff --git a/example/intrinsics/example_matmul.f90 b/example/intrinsics/example_matmul.f90 new file mode 100644 index 000000000..31906a65d --- /dev/null +++ b/example/intrinsics/example_matmul.f90 @@ -0,0 +1,7 @@ +program example_matmul + use stdlib_intrinsics, only: stdlib_matmul + complex :: a(2,2) + a = reshape([(0, 0), (0, -1), (0, 1), (0, 0)], [2, 2]) ! pauli y-matrix + + print *, stdlib_matmul(a, a, a, a, a) ! should be sigma_y +end program example_matmul From a7f645c17caed49524cbf21683e18e96a0cca6cf Mon Sep 17 00:00:00 2001 From: "supritsj@Arch" Date: Thu, 13 Mar 2025 22:02:46 +0530 Subject: [PATCH 04/11] fix typo --- src/stdlib_intrinsics.fypp | 3 ++- src/stdlib_intrinsics_matmul.fypp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/stdlib_intrinsics.fypp b/src/stdlib_intrinsics.fypp index ecd273173..fc3470fc3 100644 --- a/src/stdlib_intrinsics.fypp +++ b/src/stdlib_intrinsics.fypp @@ -157,9 +157,10 @@ module stdlib_intrinsics !!### Description !! !! matrix multiply more than two matrices with a single function call - !! the multiplication with the optimal bracketization is done automatically + !! the multiplication with the optimal parenthesization for efficiency of computation is done automatically !! Supported data types are `real`, `integer` and `complex`. !! + !! Note: The matrices must be of compatible shapes to be multiplied #:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES pure module function stdlib_matmul_${s}$_3 (a, b, c) result(d) ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:) diff --git a/src/stdlib_intrinsics_matmul.fypp b/src/stdlib_intrinsics_matmul.fypp index 240faecf8..58938f9ac 100644 --- a/src/stdlib_intrinsics_matmul.fypp +++ b/src/stdlib_intrinsics_matmul.fypp @@ -8,7 +8,7 @@ submodule (stdlib_intrinsics) stdlib_intrinsics_matmul contains - ! Algorithm for the optimal bracketization of matrices + ! Algorithm for the optimal parenthesization of matrices ! Reference: Cormen, "Introduction to Algorithms", 4ed, ch-14, section-2 ! Internal use only! pure function matmul_chain_order(n, p) result(s) From cc77dee0106a05c5509eabfa2ad1cb5ff2418904 Mon Sep 17 00:00:00 2001 From: "supritsj@Arch" Date: Fri, 14 Mar 2025 06:04:14 +0530 Subject: [PATCH 05/11] a bit efficient --- src/stdlib_intrinsics_matmul.fypp | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/stdlib_intrinsics_matmul.fypp b/src/stdlib_intrinsics_matmul.fypp index 58938f9ac..6c909d0ee 100644 --- a/src/stdlib_intrinsics_matmul.fypp +++ b/src/stdlib_intrinsics_matmul.fypp @@ -76,11 +76,25 @@ contains select case (s(1,4)) case (1) - e = matmul(a, stdlib_matmul(b, c, d)) + select case (s(2, 4)) + case (2) + e = matmul(a, matmul(b, matmul(c, d))) + case (3) + e = matmul(a, matmul(matmul(b, c), d)) + case default + error stop "stdlib_matmul: unexpected error unexpected s(i,j)" + end select case (2) e = matmul(matmul(a, b), matmul(c, d)) case (3) - e = matmul(stdlib_matmul(a, b ,c), d) + select case (s(1, 3)) + case (1) + e = matmul(matmul(a, matmul(b, c)), d) + case (2) + e = matmul(matmul(matmul(a, b), c), d) + case default + error stop "stdlib_matmul: unexpected error unexpected s(i,j)" + end select case default error stop "stdlib_matmul: unexpected error unexpected s(i,j)" end select From 3958018460d688a72bf678015fccd76a7e27403e Mon Sep 17 00:00:00 2001 From: "supritsj@Arch" Date: Sat, 15 Mar 2025 01:19:48 +0530 Subject: [PATCH 06/11] refactor algorithm --- src/stdlib_intrinsics_matmul.fypp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/stdlib_intrinsics_matmul.fypp b/src/stdlib_intrinsics_matmul.fypp index 6c909d0ee..e6bab57ff 100644 --- a/src/stdlib_intrinsics_matmul.fypp +++ b/src/stdlib_intrinsics_matmul.fypp @@ -11,9 +11,11 @@ contains ! Algorithm for the optimal parenthesization of matrices ! Reference: Cormen, "Introduction to Algorithms", 4ed, ch-14, section-2 ! Internal use only! - pure function matmul_chain_order(n, p) result(s) - integer, intent(in) :: n, p(:) - integer :: s(1:n-1, 2:n), m(1:n, 1:n), l, i, j, k, q + pure function matmul_chain_order(p) result(s) + integer, intent(in) :: p(:) + integer :: s(1:size(p) - 2, 2: size(p) - 1), m(1: size(p) - 1, 1: size(p) - 1) + integer :: n, l, i, j, k, q + n = size(p) - 1 m(:,:) = 0 s(:,:) = 0 @@ -72,7 +74,7 @@ contains p(4) = size(d, 1) p(5) = size(d, 2) - s = matmul_chain_order(4, p) + s = matmul_chain_order(p) select case (s(1,4)) case (1) @@ -113,7 +115,7 @@ contains p(5) = size(e, 1) p(6) = size(e, 2) - s = matmul_chain_order(5, p) + s = matmul_chain_order(p) select case (s(1,5)) case (1) From 35a5a282d00008842da80e4a684f46be6348d227 Mon Sep 17 00:00:00 2001 From: "supritsj@Arch" Date: Sat, 15 Mar 2025 03:03:20 +0530 Subject: [PATCH 07/11] add new interface --- src/stdlib_intrinsics.fypp | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/stdlib_intrinsics.fypp b/src/stdlib_intrinsics.fypp index fc3470fc3..e14c161ed 100644 --- a/src/stdlib_intrinsics.fypp +++ b/src/stdlib_intrinsics.fypp @@ -162,20 +162,11 @@ module stdlib_intrinsics !! !! Note: The matrices must be of compatible shapes to be multiplied #:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES - pure module function stdlib_matmul_${s}$_3 (a, b, c) result(d) - ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:) - ${t}$, allocatable :: d(:,:) - end function stdlib_matmul_${s}$_3 - - pure module function stdlib_matmul_${s}$_4 (a, b, c, d) result(e) - ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:) - ${t}$, allocatable :: e(:,:) - end function stdlib_matmul_${s}$_4 - - pure module function stdlib_matmul_${s}$_5 (a, b, c, d, e) result(f) - ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:), e(:,:) - ${t}$, allocatable :: f(:,:) - end function stdlib_matmul_${s}$_5 + pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:) + ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) + ${t}$, allocatable :: r(:,:) + end function stdlib_matmul_${s}$ #:endfor end interface stdlib_matmul public :: stdlib_matmul From ebf92d79f91190c143e16dd6c786fa486e9c08f3 Mon Sep 17 00:00:00 2001 From: "supritsj@Arch" Date: Sat, 15 Mar 2025 03:03:44 +0530 Subject: [PATCH 08/11] add helper functions --- src/stdlib_intrinsics_matmul.fypp | 97 ++++++------------------------- 1 file changed, 18 insertions(+), 79 deletions(-) diff --git a/src/stdlib_intrinsics_matmul.fypp b/src/stdlib_intrinsics_matmul.fypp index e6bab57ff..2dc1b7fb5 100644 --- a/src/stdlib_intrinsics_matmul.fypp +++ b/src/stdlib_intrinsics_matmul.fypp @@ -37,99 +37,38 @@ contains end function matmul_chain_order #:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES + + pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:) + integer, intent(in) :: start, s(:,:) + ${t}$, allocatable :: r(:,:) - pure module function stdlib_matmul_${s}$_3 (a, b, c) result(d) - ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:) - ${t}$, allocatable :: d(:,:) - integer :: sa(2), sb(2), sc(2), cost1, cost2 - sa = shape(a) - sb = shape(b) - sc = shape(c) - - if ((sa(2) /= sb(1)) .or. (sb(2) /= sc(1))) then - error stop "stdlib_matmul: Incompatible array shapes" - end if - - ! computes the cost (number of scalar multiplications required) - ! cost(A, B) = shape(A)(1) * shape(A)(2) * shape(B)(2) - cost1 = sa(1) * sa(2) * sb(2) + sa(1) * sb(2) * sc(2) ! ((AB)C) - cost2 = sb(1) * sb(2) * sc(2) + sa(1) * sa(2) * sc(2) ! (A(BC)) - - if (cost1 < cost2) then - d = matmul(matmul(a, b), c) - else - d = matmul(a, matmul(b, c)) - end if - end function stdlib_matmul_${s}$_3 - - pure module function stdlib_matmul_${s}$_4 (a, b, c, d) result(e) - ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:) - ${t}$, allocatable :: e(:,:) - integer :: p(5), i - integer :: s(3,2:4) - - p(1) = size(a, 1) - p(2) = size(b, 1) - p(3) = size(c, 1) - p(4) = size(d, 1) - p(5) = size(d, 2) - - s = matmul_chain_order(p) - - select case (s(1,4)) + select case (s(start, start + 2)) case (1) - select case (s(2, 4)) - case (2) - e = matmul(a, matmul(b, matmul(c, d))) - case (3) - e = matmul(a, matmul(matmul(b, c), d)) - case default - error stop "stdlib_matmul: unexpected error unexpected s(i,j)" - end select + r = matmul(m1, matmul(m2, m3)) case (2) - e = matmul(matmul(a, b), matmul(c, d)) - case (3) - select case (s(1, 3)) - case (1) - e = matmul(matmul(a, matmul(b, c)), d) - case (2) - e = matmul(matmul(matmul(a, b), c), d) - case default - error stop "stdlib_matmul: unexpected error unexpected s(i,j)" - end select + r = matmul(matmul(m1, m2), m3) case default error stop "stdlib_matmul: unexpected error unexpected s(i,j)" end select - end function stdlib_matmul_${s}$_4 - - pure module function stdlib_matmul_${s}$_5 (a, b, c, d, e) result(f) - ${t}$, intent(in) :: a(:,:), b(:,:), c(:,:), d(:,:), e(:,:) - ${t}$, allocatable :: f(:,:) - integer :: p(6), i - integer :: s(4,2:5) - - p(1) = size(a, 1) - p(2) = size(b, 1) - p(3) = size(c, 1) - p(4) = size(d, 1) - p(5) = size(e, 1) - p(6) = size(e, 2) + end function matmul_chain_mult_${s}$_3 - s = matmul_chain_order(p) + pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:) + integer, intent(in) :: start, s(:,:) + ${t}$, allocatable :: r(:,:) - select case (s(1,5)) + select case (s(start, start + 3)) case (1) - f = matmul(a, stdlib_matmul(b, c, d, e)) + r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s)) case (2) - f = matmul(matmul(a, b), stdlib_matmul(c, d, e)) + r = matmul(matmul(m1, m2), matmul(m3, m4)) case (3) - f = matmul(stdlib_matmul(a, b ,c), matmul(d, e)) - case (4) - f = matmul(stdlib_matmul(a, b, c, d), e) + r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4) case default error stop "stdlib_matmul: unexpected error unexpected s(i,j)" end select - end function stdlib_matmul_${s}$_5 + end function matmul_chain_mult_${s}$_4 #:endfor end submodule stdlib_intrinsics_matmul From 5f5c5a9849d1e742bbd570b03eec6ca08f1282d8 Mon Sep 17 00:00:00 2001 From: "supritsj@Arch" Date: Sat, 15 Mar 2025 19:57:52 +0530 Subject: [PATCH 09/11] add implementation, refactor select to if clauses --- src/stdlib_intrinsics_matmul.fypp | 101 +++++++++++++++++++++++++----- 1 file changed, 84 insertions(+), 17 deletions(-) diff --git a/src/stdlib_intrinsics_matmul.fypp b/src/stdlib_intrinsics_matmul.fypp index 2dc1b7fb5..5028c5cc2 100644 --- a/src/stdlib_intrinsics_matmul.fypp +++ b/src/stdlib_intrinsics_matmul.fypp @@ -13,7 +13,7 @@ contains ! Internal use only! pure function matmul_chain_order(p) result(s) integer, intent(in) :: p(:) - integer :: s(1:size(p) - 2, 2: size(p) - 1), m(1: size(p) - 1, 1: size(p) - 1) + integer :: s(1:size(p) - 2, 2:size(p) - 1), m(1:size(p) - 1, 1:size(p) - 1) integer :: n, l, i, j, k, q n = size(p) - 1 m(:,:) = 0 @@ -40,35 +40,102 @@ contains pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s) result(r) ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:) - integer, intent(in) :: start, s(:,:) + integer, intent(in) :: start, s(:,2:) ${t}$, allocatable :: r(:,:) + integer :: tmp + tmp = s(start, start + 2) + + if (tmp == start) then + r = matmul(m1, matmul(m2, m3)) + else if (tmp == start + 1) then + r = matmul(matmul(m1, m2), m3) + else + error stop "stdlib_matmul: error: unexpected s(i,j)" + end if - select case (s(start, start + 2)) - case (1) - r = matmul(m1, matmul(m2, m3)) - case (2) - r = matmul(matmul(m1, m2), m3) - case default - error stop "stdlib_matmul: unexpected error unexpected s(i,j)" - end select end function matmul_chain_mult_${s}$_3 pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r) ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:) - integer, intent(in) :: start, s(:,:) + integer, intent(in) :: start, s(:,2:) + ${t}$, allocatable :: r(:,:) + integer :: tmp + tmp = s(start, start + 3) + + if (tmp == start) then + r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s)) + else if (tmp == start + 1) then + r = matmul(matmul(m1, m2), matmul(m3, m4)) + else if (tmp == start + 2) then + r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4) + else + error stop "stdlib_matmul: error: unexpected s(i,j)" + end if + + end function matmul_chain_mult_${s}$_4 + + pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:) + ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) ${t}$, allocatable :: r(:,:) + integer :: p(6), num_present + integer, allocatable :: s(:,:) - select case (s(start, start + 3)) + p(1) = size(m1, 1) + p(2) = size(m2, 1) + p(3) = size(m2, 2) + + num_present = 2 + if (present(m3)) then + p(3) = size(m3, 1) + p(4) = size(m3, 2) + num_present = num_present + 1 + end if + if (present(m4)) then + p(4) = size(m4, 1) + p(5) = size(m4, 2) + num_present = num_present + 1 + end if + if (present(m5)) then + p(5) = size(m5, 1) + p(6) = size(m5, 2) + num_present = num_present + 1 + end if + + if (num_present == 2) then + r = matmul(m1, m2) + return + end if + + ! Now num_present >= 3 + allocate(s(1:num_present - 1, 2:num_present)) + + s = matmul_chain_order(p(1: num_present + 1)) + + if (num_present == 3) then + r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s) + return + else if (num_present == 4) then + r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s) + return + end if + + ! Now num_present is 5 + + select case (s(1, 5)) case (1) - r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s)) + r = matmul(m1, matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s)) case (2) - r = matmul(matmul(m1, m2), matmul(m3, m4)) + r = matmul(matmul(m1, m2), matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s)) case (3) - r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4) + r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s), matmul(m4, m5)) + case (4) + r = matmul(matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s), m5) case default - error stop "stdlib_matmul: unexpected error unexpected s(i,j)" + error stop "stdlib_matmul: error: unexpected s(i,j)" end select - end function matmul_chain_mult_${s}$_4 + + end function stdlib_matmul_${s}$ #:endfor end submodule stdlib_intrinsics_matmul From 06ce7351996597b1b6ab8cf22e2d2454529a28a8 Mon Sep 17 00:00:00 2001 From: "supritsj@Arch" Date: Sat, 15 Mar 2025 19:58:15 +0530 Subject: [PATCH 10/11] slightly better examples --- example/intrinsics/example_matmul.f90 | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/example/intrinsics/example_matmul.f90 b/example/intrinsics/example_matmul.f90 index 31906a65d..99ba770dd 100644 --- a/example/intrinsics/example_matmul.f90 +++ b/example/intrinsics/example_matmul.f90 @@ -1,7 +1,19 @@ program example_matmul use stdlib_intrinsics, only: stdlib_matmul - complex :: a(2,2) - a = reshape([(0, 0), (0, -1), (0, 1), (0, 0)], [2, 2]) ! pauli y-matrix + complex :: x(2, 2), y(2, 2) + real :: r1(50, 100), r2(100, 40), r3(40, 50) + real, allocatable :: res(:, :) + x = reshape([(0, 0), (1, 0), (1, 0), (0, 0)], [2, 2]) + y = reshape([(0, 0), (0, -1), (0, 1), (0, 0)], [2, 2]) ! pauli y-matrix - print *, stdlib_matmul(a, a, a, a, a) ! should be sigma_y + print *, stdlib_matmul(y, y, y, y, y) ! should be y + print *, stdlib_matmul(x, x, y, x) ! should be -i x sigma_z + + call random_seed() + call random_number(r1) + call random_number(r2) + call random_number(r3) + + res = stdlib_matmul(r1, r2, r3) ! 50x50 matrix + print *, shape(res) end program example_matmul From e709f838aeb1b57276bdfe36840ef9a720e1086c Mon Sep 17 00:00:00 2001 From: "supritsj@Arch" Date: Fri, 21 Mar 2025 01:58:17 +0530 Subject: [PATCH 11/11] replace all matmul's by gemm --- example/intrinsics/example_matmul.f90 | 4 +- src/stdlib_intrinsics.fypp | 4 +- src/stdlib_intrinsics_matmul.fypp | 149 ++++++++++++++++++++------ 3 files changed, 121 insertions(+), 36 deletions(-) diff --git a/example/intrinsics/example_matmul.f90 b/example/intrinsics/example_matmul.f90 index 99ba770dd..18ab1a0ec 100644 --- a/example/intrinsics/example_matmul.f90 +++ b/example/intrinsics/example_matmul.f90 @@ -4,9 +4,9 @@ program example_matmul real :: r1(50, 100), r2(100, 40), r3(40, 50) real, allocatable :: res(:, :) x = reshape([(0, 0), (1, 0), (1, 0), (0, 0)], [2, 2]) - y = reshape([(0, 0), (0, -1), (0, 1), (0, 0)], [2, 2]) ! pauli y-matrix + y = reshape([(0, 0), (0, 1), (0, -1), (0, 0)], [2, 2]) ! pauli y-matrix - print *, stdlib_matmul(y, y, y, y, y) ! should be y + print *, stdlib_matmul(y, y, y) ! should be y print *, stdlib_matmul(x, x, y, x) ! should be -i x sigma_z call random_seed() diff --git a/src/stdlib_intrinsics.fypp b/src/stdlib_intrinsics.fypp index e14c161ed..bb3103140 100644 --- a/src/stdlib_intrinsics.fypp +++ b/src/stdlib_intrinsics.fypp @@ -158,10 +158,10 @@ module stdlib_intrinsics !! !! matrix multiply more than two matrices with a single function call !! the multiplication with the optimal parenthesization for efficiency of computation is done automatically - !! Supported data types are `real`, `integer` and `complex`. + !! Supported data types are `real` and `complex`. !! !! Note: The matrices must be of compatible shapes to be multiplied - #:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES + #:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r) ${t}$, intent(in) :: m1(:,:), m2(:,:) ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) diff --git a/src/stdlib_intrinsics_matmul.fypp b/src/stdlib_intrinsics_matmul.fypp index 5028c5cc2..2d5a320cd 100644 --- a/src/stdlib_intrinsics_matmul.fypp +++ b/src/stdlib_intrinsics_matmul.fypp @@ -4,6 +4,8 @@ #:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX)) submodule (stdlib_intrinsics) stdlib_intrinsics_matmul + use stdlib_linalg_blas, only: gemm + use stdlib_constants implicit none contains @@ -36,38 +38,84 @@ contains end do end function matmul_chain_order -#:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES +#:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES - pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s) result(r) + pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s, p) result(r) ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:) - integer, intent(in) :: start, s(:,2:) - ${t}$, allocatable :: r(:,:) - integer :: tmp - tmp = s(start, start + 2) - - if (tmp == start) then - r = matmul(m1, matmul(m2, m3)) - else if (tmp == start + 1) then - r = matmul(matmul(m1, m2), m3) + integer, intent(in) :: start, s(:,2:), p(:) + ${t}$, allocatable :: r(:,:), temp(:,:) + integer :: ord, m, n, k + ord = s(start, start + 2) + allocate(r(p(start), p(start + 3))) + + if (ord == start) then + ! m1*(m2*m3) + m = p(start + 1) + n = p(start + 3) + k = p(start + 2) + allocate(temp(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m2, m, m3, k, zero_${s}$, temp, m) + m = p(start) + n = p(start + 3) + k = p(start + 1) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m) + else if (ord == start + 1) then + ! (m1*m2)*m3 + m = p(start) + n = p(start + 2) + k = p(start + 1) + allocate(temp(m, n)) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m) + m = p(start) + n = p(start + 3) + k = p(start + 1) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m3, k, zero_${s}$, r, m) else error stop "stdlib_matmul: error: unexpected s(i,j)" end if end function matmul_chain_mult_${s}$_3 - pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r) + pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s, p) result(r) ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:) - integer, intent(in) :: start, s(:,2:) - ${t}$, allocatable :: r(:,:) - integer :: tmp - tmp = s(start, start + 3) - - if (tmp == start) then - r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s)) - else if (tmp == start + 1) then - r = matmul(matmul(m1, m2), matmul(m3, m4)) - else if (tmp == start + 2) then - r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4) + integer, intent(in) :: start, s(:,2:), p(:) + ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:) + integer :: ord, m, n, k + ord = s(start, start + 3) + allocate(r(p(start), p(start + 4))) + + if (ord == start) then + ! m1*(m2*m3*m4) + temp = matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s, p) + m = p(start) + n = p(start + 4) + k = p(start + 1) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m) + else if (ord == start + 1) then + ! (m1*m2)*(m3*m4) + m = p(start) + n = p(start + 2) + k = p(start + 1) + allocate(temp(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m) + + m = p(start + 2) + n = p(start + 4) + k = p(start + 3) + allocate(temp1(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m3, m, m4, k, zero_${s}$, temp1, m) + + m = p(start) + n = p(start + 4) + k = p(start + 2) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m) + else if (ord == start + 2) then + ! (m1*m2*m3)*m4 + temp = matmul_chain_mult_${s}$_3(m1, m2, m3, start, s, p) + m = p(start) + n = p(start + 4) + k = p(start + 3) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m4, k, zero_${s}$, r, m) else error stop "stdlib_matmul: error: unexpected s(i,j)" end if @@ -77,8 +125,8 @@ contains pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r) ${t}$, intent(in) :: m1(:,:), m2(:,:) ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) - ${t}$, allocatable :: r(:,:) - integer :: p(6), num_present + ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:) + integer :: p(6), num_present, m, n, k integer, allocatable :: s(:,:) p(1) = size(m1, 1) @@ -102,8 +150,13 @@ contains num_present = num_present + 1 end if + allocate(r(p(1), p(num_present + 1))) + if (num_present == 2) then - r = matmul(m1, m2) + m = p(1) + n = p(3) + k = p(2) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r, m) return end if @@ -113,10 +166,10 @@ contains s = matmul_chain_order(p(1: num_present + 1)) if (num_present == 3) then - r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s) + r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4)) return else if (num_present == 4) then - r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s) + r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5)) return end if @@ -124,13 +177,45 @@ contains select case (s(1, 5)) case (1) - r = matmul(m1, matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s)) + ! m1*(m2*m3*m4*m5) + temp = matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s, p) + m = p(1) + n = p(6) + k = p(2) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m) case (2) - r = matmul(matmul(m1, m2), matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s)) + ! (m1*m2)*(m3*m4*m5) + m = p(1) + n = p(3) + k = p(2) + allocate(temp(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m) + + temp1 = matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s, p) + + k = n + n = p(6) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m) case (3) - r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s), matmul(m4, m5)) + ! (m1*m2*m3)*(m4*m5) + temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 3, s, p) + + m = p(4) + n = p(6) + k = p(5) + allocate(temp1(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m4, m, m5, k, zero_${s}$, temp1, m) + + k = m + m = p(1) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m) case (4) - r = matmul(matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s), m5) + ! (m1*m2*m3*m4)*m5 + temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p) + m = p(1) + n = p(6) + k = p(5) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r, m) case default error stop "stdlib_matmul: error: unexpected s(i,j)" end select