@@ -31,6 +31,114 @@ namespace LinearAlgebra
31
31
32
32
namespace TpetraWrappers
33
33
{
34
+ namespace internal
35
+ {
36
+ template <typename Number, typename MemorySpace>
37
+ void
38
+ apply (const SparseMatrix<Number, MemorySpace> &M,
39
+ const Vector<Number, MemorySpace> &src,
40
+ Vector<Number, MemorySpace> &dst,
41
+ Teuchos::ETransp mode = Teuchos::NO_TRANS,
42
+ Number alpha = Teuchos::ScalarTraits<Number>::one(),
43
+ Number beta = Teuchos::ScalarTraits<Number>::zero())
44
+ {
45
+ Assert (&src != &dst,
46
+ SparseMatrix<double >::ExcSourceEqualsDestination ());
47
+ Assert (M.trilinos_matrix ().isFillComplete (),
48
+ SparseMatrix<double >::ExcMatrixNotCompressed ());
49
+
50
+ if (mode == Teuchos::NO_TRANS)
51
+ {
52
+ Assert (src.trilinos_vector ().getMap ()->isSameAs (
53
+ *M.trilinos_matrix ().getDomainMap ()),
54
+ SparseMatrix<double >::ExcColMapMissmatch ());
55
+ Assert (dst.trilinos_vector ().getMap ()->isSameAs (
56
+ *M.trilinos_matrix ().getRangeMap ()),
57
+ SparseMatrix<double >::ExcDomainMapMissmatch ());
58
+ }
59
+ else
60
+ {
61
+ Assert (dst.trilinos_vector ().getMap ()->isSameAs (
62
+ *M.trilinos_matrix ().getDomainMap ()),
63
+ SparseMatrix<double >::ExcColMapMissmatch ());
64
+ Assert (src.trilinos_vector ().getMap ()->isSameAs (
65
+ *M.trilinos_matrix ().getRangeMap ()),
66
+ SparseMatrix<double >::ExcDomainMapMissmatch ());
67
+ }
68
+
69
+ M.trilinos_matrix ().apply (
70
+ src.trilinos_vector (), dst.trilinos_vector (), mode, alpha, beta);
71
+ }
72
+
73
+
74
+
75
+ template <typename Number, typename MemorySpace>
76
+ void
77
+ apply (const SparseMatrix<Number, MemorySpace> &M,
78
+ const dealii::Vector<Number> &src,
79
+ dealii::Vector<Number> &dst,
80
+ Teuchos::ETransp mode = Teuchos::NO_TRANS,
81
+ Number alpha = Teuchos::ScalarTraits<Number>::one(),
82
+ Number beta = Teuchos::ScalarTraits<Number>::zero())
83
+ {
84
+ Assert (&src != &dst,
85
+ SparseMatrix<double >::ExcSourceEqualsDestination ());
86
+ Assert (M.trilinos_matrix ().isFillComplete (),
87
+ SparseMatrix<double >::ExcMatrixNotCompressed ());
88
+
89
+ // get the size of the input vectors:
90
+ const size_type dst_local_size = dst.end () - dst.begin ();
91
+ const size_type src_local_size = src.end () - src.begin ();
92
+
93
+ // For the dst vector:
94
+ Kokkos::View<Number **, Kokkos::LayoutLeft, Kokkos::HostSpace>
95
+ kokkos_view_dst (dst.begin (), dst_local_size, 1 );
96
+
97
+ // get a Kokkos::DualView
98
+ auto mirror_view_dst = Kokkos::create_mirror_view_and_copy (
99
+ typename MemorySpace::kokkos_space{}, kokkos_view_dst);
100
+ typename SparseMatrix<Number, MemorySpace>::VectorType::dual_view_type
101
+ kokkos_dual_view_dst (mirror_view_dst, kokkos_view_dst);
102
+
103
+ // create the Tpetra::Vector
104
+ typename SparseMatrix<Number, MemorySpace>::VectorType tpetra_dst (
105
+ M.trilinos_matrix ().getRangeMap (), kokkos_dual_view_dst);
106
+
107
+ // For the src vector:
108
+ // create a Kokkos::View from the src vector
109
+ Kokkos::View<Number **, Kokkos::LayoutLeft, Kokkos::HostSpace>
110
+ kokkos_view_src (const_cast <Number *>(src.begin ()), src_local_size, 1 );
111
+
112
+ // get a Kokkos::DualView
113
+ auto mirror_view_src = Kokkos::create_mirror_view_and_copy (
114
+ typename MemorySpace::kokkos_space{}, kokkos_view_src);
115
+ typename SparseMatrix<Number, MemorySpace>::VectorType::dual_view_type
116
+ kokkos_dual_view_src (mirror_view_src, kokkos_view_src);
117
+
118
+ // create the Tpetra::Vector
119
+ typename SparseMatrix<Number, MemorySpace>::VectorType tpetra_src (
120
+ M.trilinos_matrix ().getDomainMap (), kokkos_dual_view_src);
121
+
122
+ M.trilinos_matrix ().apply (tpetra_src, tpetra_dst, mode, alpha, beta);
123
+ }
124
+
125
+
126
+
127
+ template <typename Number, typename MemorySpace, typename VectorType>
128
+ void
129
+ apply (SparseMatrix<Number, MemorySpace> &,
130
+ const VectorType &,
131
+ VectorType &,
132
+ Teuchos::ETransp,
133
+ Number,
134
+ Number)
135
+ {
136
+ DEAL_II_NOT_IMPLEMENTED ();
137
+ }
138
+ } // namespace internal
139
+
140
+
141
+
34
142
// reinit_matrix():
35
143
namespace
36
144
{
@@ -1170,81 +1278,58 @@ namespace LinearAlgebra
1170
1278
1171
1279
1172
1280
// Multiplications
1173
-
1174
1281
template <typename Number, typename MemorySpace>
1282
+ template <typename InputVectorType>
1175
1283
void
1176
- SparseMatrix<Number, MemorySpace>::vmult(
1177
- Vector<Number, MemorySpace> &dst,
1178
- const Vector<Number, MemorySpace> &src) const
1284
+ SparseMatrix<Number, MemorySpace>::vmult(InputVectorType &dst,
1285
+ const InputVectorType &src) const
1179
1286
{
1180
- Assert (&src != &dst, ExcSourceEqualsDestination ());
1181
- Assert (matrix->isFillComplete (), ExcMatrixNotCompressed ());
1182
- Assert (src.trilinos_vector ().getMap ()->isSameAs (*matrix->getDomainMap ()),
1183
- ExcColMapMissmatch ());
1184
- Assert (dst.trilinos_vector ().getMap ()->isSameAs (*matrix->getRangeMap ()),
1185
- ExcDomainMapMissmatch ());
1186
- matrix->apply (src.trilinos_vector (), dst.trilinos_vector ());
1287
+ internal::apply (*this , src, dst);
1187
1288
}
1188
1289
1189
1290
1190
1291
1191
1292
template <typename Number, typename MemorySpace>
1293
+ template <typename InputVectorType>
1192
1294
void
1193
- SparseMatrix<Number, MemorySpace>::Tvmult(
1194
- Vector<Number, MemorySpace> &dst,
1195
- const Vector<Number, MemorySpace> &src) const
1295
+ SparseMatrix<Number, MemorySpace>::Tvmult(InputVectorType &dst,
1296
+ const InputVectorType &src) const
1196
1297
{
1197
- Assert (&src != &dst, ExcSourceEqualsDestination ());
1198
- Assert (matrix->isFillComplete (), ExcMatrixNotCompressed ());
1199
- Assert (dst.trilinos_vector ().getMap ()->isSameAs (*matrix->getDomainMap ()),
1200
- ExcColMapMissmatch ());
1201
- Assert (src.trilinos_vector ().getMap ()->isSameAs (*matrix->getRangeMap ()),
1202
- ExcDomainMapMissmatch ());
1203
- matrix->apply (src.trilinos_vector (),
1204
- dst.trilinos_vector (),
1205
- Teuchos::TRANS);
1298
+ internal::apply (*this , src, dst, Teuchos::TRANS);
1206
1299
}
1207
1300
1208
1301
1209
1302
1210
1303
template <typename Number, typename MemorySpace>
1304
+ template <typename InputVectorType>
1211
1305
void
1212
1306
SparseMatrix<Number, MemorySpace>::vmult_add(
1213
- Vector<Number, MemorySpace> &dst,
1214
- const Vector<Number, MemorySpace> &src) const
1307
+ InputVectorType &dst,
1308
+ const InputVectorType &src) const
1215
1309
{
1216
- Assert (&src != &dst, ExcSourceEqualsDestination ());
1217
- Assert (matrix->isFillComplete (), ExcMatrixNotCompressed ());
1218
- Assert (src.trilinos_vector ().getMap ()->isSameAs (*matrix->getDomainMap ()),
1219
- ExcColMapMissmatch ());
1220
- Assert (dst.trilinos_vector ().getMap ()->isSameAs (*matrix->getRangeMap ()),
1221
- ExcDomainMapMissmatch ());
1222
- matrix->apply (src.trilinos_vector (),
1223
- dst.trilinos_vector (),
1224
- Teuchos::NO_TRANS,
1225
- Teuchos::ScalarTraits<Number>::one (),
1226
- Teuchos::ScalarTraits<Number>::one ());
1310
+ internal::apply (*this ,
1311
+ src,
1312
+ dst,
1313
+ Teuchos::NO_TRANS,
1314
+ Teuchos::ScalarTraits<Number>::one (),
1315
+ Teuchos::ScalarTraits<Number>::one ());
1227
1316
}
1228
1317
1229
1318
1230
1319
1231
1320
template <typename Number, typename MemorySpace>
1321
+ template <typename InputVectorType>
1232
1322
void
1233
1323
SparseMatrix<Number, MemorySpace>::Tvmult_add(
1234
- Vector<Number, MemorySpace> &dst,
1235
- const Vector<Number, MemorySpace> &src) const
1324
+ InputVectorType &dst,
1325
+ const InputVectorType &src) const
1236
1326
{
1237
- Assert (&src != &dst, ExcSourceEqualsDestination ());
1238
- Assert (matrix->isFillComplete (), ExcMatrixNotCompressed ());
1239
- Assert (dst.trilinos_vector ().getMap ()->isSameAs (*matrix->getDomainMap ()),
1240
- ExcColMapMissmatch ());
1241
- Assert (src.trilinos_vector ().getMap ()->isSameAs (*matrix->getRangeMap ()),
1242
- ExcDomainMapMissmatch ());
1243
- matrix->apply (src.trilinos_vector (),
1244
- dst.trilinos_vector (),
1245
- Teuchos::TRANS,
1246
- Teuchos::ScalarTraits<Number>::one (),
1247
- Teuchos::ScalarTraits<Number>::one ());
1327
+ internal::apply (*this ,
1328
+ src,
1329
+ dst,
1330
+ Teuchos::TRANS,
1331
+ Teuchos::ScalarTraits<Number>::one (),
1332
+ Teuchos::ScalarTraits<Number>::one ());
1248
1333
}
1249
1334
1250
1335
0 commit comments