Skip to content

Commit 1d56f85

Browse files
authored
Merge pull request dealii#16614 from kinnewig/tpetra_wrappers_vmult
Template TpetraWrappers::SparseMatrix::vmult on the vector type.
2 parents 0b3410f + 9db41cd commit 1d56f85

File tree

3 files changed

+182
-58
lines changed

3 files changed

+182
-58
lines changed

Diff for: include/deal.II/lac/trilinos_tpetra_sparse_matrix.h

+17-9
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,14 @@
2727
# include <deal.II/lac/sparsity_pattern.h>
2828
# include <deal.II/lac/trilinos_tpetra_sparsity_pattern.h>
2929
# include <deal.II/lac/trilinos_tpetra_vector.h>
30+
# include <deal.II/lac/vector.h>
3031

3132
// Tpetra includes
3233
# include <Tpetra_Core.hpp>
3334
# include <Tpetra_CrsMatrix.hpp>
3435

36+
# include <type_traits>
37+
3538

3639
DEAL_II_NAMESPACE_OPEN
3740

@@ -162,6 +165,12 @@ namespace LinearAlgebra
162165
using GraphType =
163166
Tpetra::CrsGraph<int, dealii::types::signed_global_dof_index, NodeType>;
164167

168+
/**
169+
* Typedef for Tpetra::Vector
170+
*/
171+
using VectorType = Tpetra::
172+
Vector<Number, int, dealii::types::signed_global_dof_index, NodeType>;
173+
165174
/**
166175
* @name Constructors and initialization.
167176
*/
@@ -799,9 +808,9 @@ namespace LinearAlgebra
799808
* initialized with the same IndexSet that was used for the column indices
800809
* of the matrix.
801810
*/
811+
template <typename InputVectorType>
802812
void
803-
vmult(Vector<Number, MemorySpace> &dst,
804-
const Vector<Number, MemorySpace> &src) const;
813+
vmult(InputVectorType &dst, const InputVectorType &src) const;
805814

806815
/*
807816
* Matrix-vector multiplication: let <i>dst = M<sup>T</sup>*src</i> with
@@ -810,20 +819,19 @@ namespace LinearAlgebra
810819
*
811820
* Source and destination must not be the same vector.
812821
*/
822+
template <typename InputVectorType>
813823
void
814-
Tvmult(Vector<Number, MemorySpace> &dst,
815-
const Vector<Number, MemorySpace> &src) const;
824+
Tvmult(InputVectorType &dst, const InputVectorType &src) const;
816825

817826
/**
818827
* Adding matrix-vector multiplication. Add <i>M*src</i> on <i>dst</i>
819828
* with <i>M</i> being this matrix.
820829
*
821830
* Source and destination must not be the same vector.
822831
*/
832+
template <typename InputVectorType>
823833
void
824-
vmult_add(Vector<Number, MemorySpace> &dst,
825-
const Vector<Number, MemorySpace> &src) const;
826-
834+
vmult_add(InputVectorType &dst, const InputVectorType &src) const;
827835

828836
/**
829837
* Adding matrix-vector multiplication. Add <i>M<sup>T</sup>*src</i> to
@@ -832,9 +840,9 @@ namespace LinearAlgebra
832840
*
833841
* Source and destination must not be the same vector.
834842
*/
843+
template <typename InputVectorType>
835844
void
836-
Tvmult_add(Vector<Number, MemorySpace> &dst,
837-
const Vector<Number, MemorySpace> &src) const;
845+
Tvmult_add(InputVectorType &dst, const InputVectorType &src) const;
838846

839847
/**
840848
* Return the square of the norm of the vector $v$ with respect to the

Diff for: include/deal.II/lac/trilinos_tpetra_sparse_matrix.templates.h

+134-49
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,114 @@ namespace LinearAlgebra
3131

3232
namespace TpetraWrappers
3333
{
34+
namespace internal
35+
{
36+
template <typename Number, typename MemorySpace>
37+
void
38+
apply(const SparseMatrix<Number, MemorySpace> &M,
39+
const Vector<Number, MemorySpace> &src,
40+
Vector<Number, MemorySpace> &dst,
41+
Teuchos::ETransp mode = Teuchos::NO_TRANS,
42+
Number alpha = Teuchos::ScalarTraits<Number>::one(),
43+
Number beta = Teuchos::ScalarTraits<Number>::zero())
44+
{
45+
Assert(&src != &dst,
46+
SparseMatrix<double>::ExcSourceEqualsDestination());
47+
Assert(M.trilinos_matrix().isFillComplete(),
48+
SparseMatrix<double>::ExcMatrixNotCompressed());
49+
50+
if (mode == Teuchos::NO_TRANS)
51+
{
52+
Assert(src.trilinos_vector().getMap()->isSameAs(
53+
*M.trilinos_matrix().getDomainMap()),
54+
SparseMatrix<double>::ExcColMapMissmatch());
55+
Assert(dst.trilinos_vector().getMap()->isSameAs(
56+
*M.trilinos_matrix().getRangeMap()),
57+
SparseMatrix<double>::ExcDomainMapMissmatch());
58+
}
59+
else
60+
{
61+
Assert(dst.trilinos_vector().getMap()->isSameAs(
62+
*M.trilinos_matrix().getDomainMap()),
63+
SparseMatrix<double>::ExcColMapMissmatch());
64+
Assert(src.trilinos_vector().getMap()->isSameAs(
65+
*M.trilinos_matrix().getRangeMap()),
66+
SparseMatrix<double>::ExcDomainMapMissmatch());
67+
}
68+
69+
M.trilinos_matrix().apply(
70+
src.trilinos_vector(), dst.trilinos_vector(), mode, alpha, beta);
71+
}
72+
73+
74+
75+
template <typename Number, typename MemorySpace>
76+
void
77+
apply(const SparseMatrix<Number, MemorySpace> &M,
78+
const dealii::Vector<Number> &src,
79+
dealii::Vector<Number> &dst,
80+
Teuchos::ETransp mode = Teuchos::NO_TRANS,
81+
Number alpha = Teuchos::ScalarTraits<Number>::one(),
82+
Number beta = Teuchos::ScalarTraits<Number>::zero())
83+
{
84+
Assert(&src != &dst,
85+
SparseMatrix<double>::ExcSourceEqualsDestination());
86+
Assert(M.trilinos_matrix().isFillComplete(),
87+
SparseMatrix<double>::ExcMatrixNotCompressed());
88+
89+
// get the size of the input vectors:
90+
const size_type dst_local_size = dst.end() - dst.begin();
91+
const size_type src_local_size = src.end() - src.begin();
92+
93+
// For the dst vector:
94+
Kokkos::View<Number **, Kokkos::LayoutLeft, Kokkos::HostSpace>
95+
kokkos_view_dst(dst.begin(), dst_local_size, 1);
96+
97+
// get a Kokkos::DualView
98+
auto mirror_view_dst = Kokkos::create_mirror_view_and_copy(
99+
typename MemorySpace::kokkos_space{}, kokkos_view_dst);
100+
typename SparseMatrix<Number, MemorySpace>::VectorType::dual_view_type
101+
kokkos_dual_view_dst(mirror_view_dst, kokkos_view_dst);
102+
103+
// create the Tpetra::Vector
104+
typename SparseMatrix<Number, MemorySpace>::VectorType tpetra_dst(
105+
M.trilinos_matrix().getRangeMap(), kokkos_dual_view_dst);
106+
107+
// For the src vector:
108+
// create a Kokkos::View from the src vector
109+
Kokkos::View<Number **, Kokkos::LayoutLeft, Kokkos::HostSpace>
110+
kokkos_view_src(const_cast<Number *>(src.begin()), src_local_size, 1);
111+
112+
// get a Kokkos::DualView
113+
auto mirror_view_src = Kokkos::create_mirror_view_and_copy(
114+
typename MemorySpace::kokkos_space{}, kokkos_view_src);
115+
typename SparseMatrix<Number, MemorySpace>::VectorType::dual_view_type
116+
kokkos_dual_view_src(mirror_view_src, kokkos_view_src);
117+
118+
// create the Tpetra::Vector
119+
typename SparseMatrix<Number, MemorySpace>::VectorType tpetra_src(
120+
M.trilinos_matrix().getDomainMap(), kokkos_dual_view_src);
121+
122+
M.trilinos_matrix().apply(tpetra_src, tpetra_dst, mode, alpha, beta);
123+
}
124+
125+
126+
127+
template <typename Number, typename MemorySpace, typename VectorType>
128+
void
129+
apply(SparseMatrix<Number, MemorySpace> &,
130+
const VectorType &,
131+
VectorType &,
132+
Teuchos::ETransp,
133+
Number,
134+
Number)
135+
{
136+
DEAL_II_NOT_IMPLEMENTED();
137+
}
138+
} // namespace internal
139+
140+
141+
34142
// reinit_matrix():
35143
namespace
36144
{
@@ -1170,81 +1278,58 @@ namespace LinearAlgebra
11701278

11711279

11721280
// Multiplications
1173-
11741281
template <typename Number, typename MemorySpace>
1282+
template <typename InputVectorType>
11751283
void
1176-
SparseMatrix<Number, MemorySpace>::vmult(
1177-
Vector<Number, MemorySpace> &dst,
1178-
const Vector<Number, MemorySpace> &src) const
1284+
SparseMatrix<Number, MemorySpace>::vmult(InputVectorType &dst,
1285+
const InputVectorType &src) const
11791286
{
1180-
Assert(&src != &dst, ExcSourceEqualsDestination());
1181-
Assert(matrix->isFillComplete(), ExcMatrixNotCompressed());
1182-
Assert(src.trilinos_vector().getMap()->isSameAs(*matrix->getDomainMap()),
1183-
ExcColMapMissmatch());
1184-
Assert(dst.trilinos_vector().getMap()->isSameAs(*matrix->getRangeMap()),
1185-
ExcDomainMapMissmatch());
1186-
matrix->apply(src.trilinos_vector(), dst.trilinos_vector());
1287+
internal::apply(*this, src, dst);
11871288
}
11881289

11891290

11901291

11911292
template <typename Number, typename MemorySpace>
1293+
template <typename InputVectorType>
11921294
void
1193-
SparseMatrix<Number, MemorySpace>::Tvmult(
1194-
Vector<Number, MemorySpace> &dst,
1195-
const Vector<Number, MemorySpace> &src) const
1295+
SparseMatrix<Number, MemorySpace>::Tvmult(InputVectorType &dst,
1296+
const InputVectorType &src) const
11961297
{
1197-
Assert(&src != &dst, ExcSourceEqualsDestination());
1198-
Assert(matrix->isFillComplete(), ExcMatrixNotCompressed());
1199-
Assert(dst.trilinos_vector().getMap()->isSameAs(*matrix->getDomainMap()),
1200-
ExcColMapMissmatch());
1201-
Assert(src.trilinos_vector().getMap()->isSameAs(*matrix->getRangeMap()),
1202-
ExcDomainMapMissmatch());
1203-
matrix->apply(src.trilinos_vector(),
1204-
dst.trilinos_vector(),
1205-
Teuchos::TRANS);
1298+
internal::apply(*this, src, dst, Teuchos::TRANS);
12061299
}
12071300

12081301

12091302

12101303
template <typename Number, typename MemorySpace>
1304+
template <typename InputVectorType>
12111305
void
12121306
SparseMatrix<Number, MemorySpace>::vmult_add(
1213-
Vector<Number, MemorySpace> &dst,
1214-
const Vector<Number, MemorySpace> &src) const
1307+
InputVectorType &dst,
1308+
const InputVectorType &src) const
12151309
{
1216-
Assert(&src != &dst, ExcSourceEqualsDestination());
1217-
Assert(matrix->isFillComplete(), ExcMatrixNotCompressed());
1218-
Assert(src.trilinos_vector().getMap()->isSameAs(*matrix->getDomainMap()),
1219-
ExcColMapMissmatch());
1220-
Assert(dst.trilinos_vector().getMap()->isSameAs(*matrix->getRangeMap()),
1221-
ExcDomainMapMissmatch());
1222-
matrix->apply(src.trilinos_vector(),
1223-
dst.trilinos_vector(),
1224-
Teuchos::NO_TRANS,
1225-
Teuchos::ScalarTraits<Number>::one(),
1226-
Teuchos::ScalarTraits<Number>::one());
1310+
internal::apply(*this,
1311+
src,
1312+
dst,
1313+
Teuchos::NO_TRANS,
1314+
Teuchos::ScalarTraits<Number>::one(),
1315+
Teuchos::ScalarTraits<Number>::one());
12271316
}
12281317

12291318

12301319

12311320
template <typename Number, typename MemorySpace>
1321+
template <typename InputVectorType>
12321322
void
12331323
SparseMatrix<Number, MemorySpace>::Tvmult_add(
1234-
Vector<Number, MemorySpace> &dst,
1235-
const Vector<Number, MemorySpace> &src) const
1324+
InputVectorType &dst,
1325+
const InputVectorType &src) const
12361326
{
1237-
Assert(&src != &dst, ExcSourceEqualsDestination());
1238-
Assert(matrix->isFillComplete(), ExcMatrixNotCompressed());
1239-
Assert(dst.trilinos_vector().getMap()->isSameAs(*matrix->getDomainMap()),
1240-
ExcColMapMissmatch());
1241-
Assert(src.trilinos_vector().getMap()->isSameAs(*matrix->getRangeMap()),
1242-
ExcDomainMapMissmatch());
1243-
matrix->apply(src.trilinos_vector(),
1244-
dst.trilinos_vector(),
1245-
Teuchos::TRANS,
1246-
Teuchos::ScalarTraits<Number>::one(),
1247-
Teuchos::ScalarTraits<Number>::one());
1327+
internal::apply(*this,
1328+
src,
1329+
dst,
1330+
Teuchos::TRANS,
1331+
Teuchos::ScalarTraits<Number>::one(),
1332+
Teuchos::ScalarTraits<Number>::one());
12481333
}
12491334

12501335

Diff for: source/lac/trilinos_tpetra_sparse_matrix.cc

+31
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,37 @@ namespace LinearAlgebra
4646
template void
4747
SparseMatrix<double>::reinit(const dealii::DynamicSparsityPattern &);
4848

49+
template void
50+
SparseMatrix<double>::vmult(Vector<double> &dst,
51+
const Vector<double> &src) const;
52+
53+
template void
54+
SparseMatrix<double>::Tvmult(Vector<double> &dst,
55+
const Vector<double> &src) const;
56+
57+
template void
58+
SparseMatrix<double>::vmult_add(Vector<double> &dst,
59+
const Vector<double> &src) const;
60+
61+
template void
62+
SparseMatrix<double>::Tvmult_add(Vector<double> &dst,
63+
const Vector<double> &src) const;
64+
65+
template void
66+
SparseMatrix<double>::vmult(::dealii::Vector<double> &dst,
67+
const ::dealii::Vector<double> &src) const;
68+
69+
template void
70+
SparseMatrix<double>::Tvmult(::dealii::Vector<double> &dst,
71+
const ::dealii::Vector<double> &src) const;
72+
73+
template void
74+
SparseMatrix<double>::vmult_add(::dealii::Vector<double> &dst,
75+
const ::dealii::Vector<double> &src) const;
76+
77+
template void
78+
SparseMatrix<double>::Tvmult_add(::dealii::Vector<double> &dst,
79+
const ::dealii::Vector<double> &src) const;
4980
} // namespace TpetraWrappers
5081
} // namespace LinearAlgebra
5182
# endif // DOXYGEN

0 commit comments

Comments
 (0)