diff --git a/example/intrinsics/CMakeLists.txt b/example/intrinsics/CMakeLists.txt index 1645ba8a1..162744b66 100644 --- a/example/intrinsics/CMakeLists.txt +++ b/example/intrinsics/CMakeLists.txt @@ -1,2 +1,3 @@ ADD_EXAMPLE(sum) -ADD_EXAMPLE(dot_product) \ No newline at end of file +ADD_EXAMPLE(dot_product) +ADD_EXAMPLE(matmul) diff --git a/example/intrinsics/example_matmul.f90 b/example/intrinsics/example_matmul.f90 new file mode 100644 index 000000000..18ab1a0ec --- /dev/null +++ b/example/intrinsics/example_matmul.f90 @@ -0,0 +1,19 @@ +program example_matmul + use stdlib_intrinsics, only: stdlib_matmul + complex :: x(2, 2), y(2, 2) + real :: r1(50, 100), r2(100, 40), r3(40, 50) + real, allocatable :: res(:, :) + x = reshape([(0, 0), (1, 0), (1, 0), (0, 0)], [2, 2]) + y = reshape([(0, 0), (0, 1), (0, -1), (0, 0)], [2, 2]) ! pauli y-matrix + + print *, stdlib_matmul(y, y, y) ! should be y + print *, stdlib_matmul(x, x, y, x) ! should be -i x sigma_z + + call random_seed() + call random_number(r1) + call random_number(r2) + call random_number(r3) + + res = stdlib_matmul(r1, r2, r3) ! 50x50 matrix + print *, shape(res) +end program example_matmul diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 933da34de..acfe315e8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -19,6 +19,7 @@ set(fppFiles stdlib_hash_64bit_spookyv2.fypp stdlib_intrinsics_dot_product.fypp stdlib_intrinsics_sum.fypp + stdlib_intrinsics_matmul.fypp stdlib_intrinsics.fypp stdlib_io.fypp stdlib_io_npy.fypp @@ -32,14 +33,14 @@ set(fppFiles stdlib_linalg_kronecker.fypp stdlib_linalg_cross_product.fypp stdlib_linalg_eigenvalues.fypp - stdlib_linalg_solve.fypp + stdlib_linalg_solve.fypp stdlib_linalg_determinant.fypp stdlib_linalg_qr.fypp stdlib_linalg_inverse.fypp stdlib_linalg_pinv.fypp stdlib_linalg_norms.fypp stdlib_linalg_state.fypp - stdlib_linalg_svd.fypp + stdlib_linalg_svd.fypp stdlib_linalg_cholesky.fypp stdlib_linalg_schur.fypp stdlib_optval.fypp diff --git a/src/stdlib_intrinsics.fypp b/src/stdlib_intrinsics.fypp index b2c16a5a6..bb3103140 100644 --- a/src/stdlib_intrinsics.fypp +++ b/src/stdlib_intrinsics.fypp @@ -146,6 +146,30 @@ module stdlib_intrinsics #:endfor end interface public :: kahan_kernel + + interface stdlib_matmul + !! version: experimental + !! + !!### Summary + !! compute the matrix multiplication of more than two matrices with a single function call. + !! ([Specification](../page/specs/stdlib_intrinsics.html#stdlib_matmul)) + !! + !!### Description + !! + !! matrix multiply more than two matrices with a single function call + !! the multiplication with the optimal parenthesization for efficiency of computation is done automatically + !! Supported data types are `real` and `complex`. + !! + !! Note: The matrices must be of compatible shapes to be multiplied + #:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES + pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:) + ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) + ${t}$, allocatable :: r(:,:) + end function stdlib_matmul_${s}$ + #:endfor + end interface stdlib_matmul + public :: stdlib_matmul contains diff --git a/src/stdlib_intrinsics_matmul.fypp b/src/stdlib_intrinsics_matmul.fypp new file mode 100644 index 000000000..2d5a320cd --- /dev/null +++ b/src/stdlib_intrinsics_matmul.fypp @@ -0,0 +1,226 @@ +#:include "common.fypp" +#:set I_KINDS_TYPES = list(zip(INT_KINDS, INT_TYPES, INT_KINDS)) +#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX)) +#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX)) + +submodule (stdlib_intrinsics) stdlib_intrinsics_matmul + use stdlib_linalg_blas, only: gemm + use stdlib_constants + implicit none + +contains + + ! Algorithm for the optimal parenthesization of matrices + ! Reference: Cormen, "Introduction to Algorithms", 4ed, ch-14, section-2 + ! Internal use only! + pure function matmul_chain_order(p) result(s) + integer, intent(in) :: p(:) + integer :: s(1:size(p) - 2, 2:size(p) - 1), m(1:size(p) - 1, 1:size(p) - 1) + integer :: n, l, i, j, k, q + n = size(p) - 1 + m(:,:) = 0 + s(:,:) = 0 + + do l = 2, n + do i = 1, n - l + 1 + j = i + l - 1 + m(i,j) = huge(1) + + do k = i, j - 1 + q = m(i,k) + m(k+1,j) + p(i)*p(k+1)*p(j+1) + + if (q < m(i, j)) then + m(i,j) = q + s(i,j) = k + end if + end do + end do + end do + end function matmul_chain_order + +#:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES + + pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s, p) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:) + integer, intent(in) :: start, s(:,2:), p(:) + ${t}$, allocatable :: r(:,:), temp(:,:) + integer :: ord, m, n, k + ord = s(start, start + 2) + allocate(r(p(start), p(start + 3))) + + if (ord == start) then + ! m1*(m2*m3) + m = p(start + 1) + n = p(start + 3) + k = p(start + 2) + allocate(temp(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m2, m, m3, k, zero_${s}$, temp, m) + m = p(start) + n = p(start + 3) + k = p(start + 1) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m) + else if (ord == start + 1) then + ! (m1*m2)*m3 + m = p(start) + n = p(start + 2) + k = p(start + 1) + allocate(temp(m, n)) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m) + m = p(start) + n = p(start + 3) + k = p(start + 1) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m3, k, zero_${s}$, r, m) + else + error stop "stdlib_matmul: error: unexpected s(i,j)" + end if + + end function matmul_chain_mult_${s}$_3 + + pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s, p) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:) + integer, intent(in) :: start, s(:,2:), p(:) + ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:) + integer :: ord, m, n, k + ord = s(start, start + 3) + allocate(r(p(start), p(start + 4))) + + if (ord == start) then + ! m1*(m2*m3*m4) + temp = matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s, p) + m = p(start) + n = p(start + 4) + k = p(start + 1) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m) + else if (ord == start + 1) then + ! (m1*m2)*(m3*m4) + m = p(start) + n = p(start + 2) + k = p(start + 1) + allocate(temp(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m) + + m = p(start + 2) + n = p(start + 4) + k = p(start + 3) + allocate(temp1(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m3, m, m4, k, zero_${s}$, temp1, m) + + m = p(start) + n = p(start + 4) + k = p(start + 2) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m) + else if (ord == start + 2) then + ! (m1*m2*m3)*m4 + temp = matmul_chain_mult_${s}$_3(m1, m2, m3, start, s, p) + m = p(start) + n = p(start + 4) + k = p(start + 3) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m4, k, zero_${s}$, r, m) + else + error stop "stdlib_matmul: error: unexpected s(i,j)" + end if + + end function matmul_chain_mult_${s}$_4 + + pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:) + ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) + ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:) + integer :: p(6), num_present, m, n, k + integer, allocatable :: s(:,:) + + p(1) = size(m1, 1) + p(2) = size(m2, 1) + p(3) = size(m2, 2) + + num_present = 2 + if (present(m3)) then + p(3) = size(m3, 1) + p(4) = size(m3, 2) + num_present = num_present + 1 + end if + if (present(m4)) then + p(4) = size(m4, 1) + p(5) = size(m4, 2) + num_present = num_present + 1 + end if + if (present(m5)) then + p(5) = size(m5, 1) + p(6) = size(m5, 2) + num_present = num_present + 1 + end if + + allocate(r(p(1), p(num_present + 1))) + + if (num_present == 2) then + m = p(1) + n = p(3) + k = p(2) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r, m) + return + end if + + ! Now num_present >= 3 + allocate(s(1:num_present - 1, 2:num_present)) + + s = matmul_chain_order(p(1: num_present + 1)) + + if (num_present == 3) then + r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4)) + return + else if (num_present == 4) then + r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5)) + return + end if + + ! Now num_present is 5 + + select case (s(1, 5)) + case (1) + ! m1*(m2*m3*m4*m5) + temp = matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s, p) + m = p(1) + n = p(6) + k = p(2) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m) + case (2) + ! (m1*m2)*(m3*m4*m5) + m = p(1) + n = p(3) + k = p(2) + allocate(temp(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m) + + temp1 = matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s, p) + + k = n + n = p(6) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m) + case (3) + ! (m1*m2*m3)*(m4*m5) + temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 3, s, p) + + m = p(4) + n = p(6) + k = p(5) + allocate(temp1(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m4, m, m5, k, zero_${s}$, temp1, m) + + k = m + m = p(1) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m) + case (4) + ! (m1*m2*m3*m4)*m5 + temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p) + m = p(1) + n = p(6) + k = p(5) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r, m) + case default + error stop "stdlib_matmul: error: unexpected s(i,j)" + end select + + end function stdlib_matmul_${s}$ + +#:endfor +end submodule stdlib_intrinsics_matmul