|
44 | 44 | #include "paddle/phi/core/distributed/xccl_comm_context.h"
|
45 | 45 | #endif
|
46 | 46 |
|
| 47 | +#if defined(PADDLE_WITH_FLAGCX) |
| 48 | +#include "paddle/phi/core/distributed/flagcx_comm_context.h" |
| 49 | +#include "paddle/phi/core/distributed/flagcx_tools.h" |
| 50 | +#endif |
| 51 | + |
47 | 52 | namespace phi::distributed {
|
48 | 53 |
|
49 | 54 | int CommContextManager::device_id = -1;
|
@@ -261,6 +266,51 @@ void CommContextManager::CreateBKCLCommContext(
|
261 | 266 | comm_context_manager.Emplace(unique_comm_key, std::move(bkcl_comm_context));
|
262 | 267 | }
|
263 | 268 | #endif
|
| 269 | + |
| 270 | +#if defined(PADDLE_WITH_FLAGCX) |
| 271 | +void CommContextManager::CreateFlagcxCommContext( |
| 272 | + const std::shared_ptr<Store>& store, |
| 273 | + const std::string& unique_comm_key, |
| 274 | + int rank, |
| 275 | + int size, |
| 276 | + const std::string& hash_key) { |
| 277 | + auto& comm_context_manager = CommContextManager::GetInstance(); |
| 278 | + if (comm_context_manager.Has(unique_comm_key)) { |
| 279 | + return; |
| 280 | + } |
| 281 | + flagcxHandlerGroup_t flagcx_handler; |
| 282 | + phi::dynload::flagcxHandleInit(&flagcx_handler); |
| 283 | + if (rank == 0) { |
| 284 | + phi::dynload::flagcxGetUniqueId(&flagcx_handler->uniqueId); |
| 285 | + } |
| 286 | + |
| 287 | + std::string unique_key = "FlagcxCommContext/" + unique_comm_key + hash_key; |
| 288 | + if (rank == 0) { |
| 289 | + std::vector<uint8_t> flagcx_id_wrapper( |
| 290 | + reinterpret_cast<uint8_t*>(flagcx_handler->uniqueId), |
| 291 | + reinterpret_cast<uint8_t*>(flagcx_handler->uniqueId) + |
| 292 | + sizeof(flagcxUniqueId)); |
| 293 | + store->set(unique_key, flagcx_id_wrapper); |
| 294 | + } else { |
| 295 | + const auto& flagcx_id_wrapper = store->get(unique_key); |
| 296 | + std::memcpy(reinterpret_cast<uint8_t*>(flagcx_handler->uniqueId), |
| 297 | + flagcx_id_wrapper.data(), |
| 298 | + flagcx_id_wrapper.size()); |
| 299 | + } |
| 300 | + |
| 301 | + VLOG(3) << "init FlagcxCommContext rank: " << rank << ", size: " << size |
| 302 | + << ", unique_comm_key: " << unique_comm_key |
| 303 | + << ", unique_key: " << unique_key << ", flagcx_id: " |
| 304 | + << SerializeFlagcxUniqueId(*flagcx_handler->uniqueId); |
| 305 | + auto flagcx_comm_context = |
| 306 | + std::make_unique<FlagcxCommContext>(rank, size, flagcx_handler); |
| 307 | + // TODO(changtao): find a way to manage different device context, |
| 308 | + // now we use cuda device context as default |
| 309 | + comm_context_manager.SetStore(store); |
| 310 | + comm_context_manager.Emplace(unique_comm_key, std::move(flagcx_comm_context)); |
| 311 | +} |
| 312 | +#endif |
| 313 | + |
264 | 314 | CommContext* CommContextManager::Emplace(
|
265 | 315 | const std::string& unique_comm_key,
|
266 | 316 | std::unique_ptr<CommContext> comm_context) {
|
|
0 commit comments