Skip to content

Commit 17d2967

Browse files
【Comm】Add FlagCX Comm Context (#71924)
1 parent 6d4cb7e commit 17d2967

19 files changed

+683
-0
lines changed

.gitmodules

+4
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,7 @@
129129
path = third_party/openvino
130130
url = https://github.com/openvinotoolkit/openvino.git
131131
ignore = dirty
132+
[submodule "third_party/flagcx"]
133+
path = third_party/flagcx
134+
url = https://github.com/FlagOpen/FlagCX.git
135+
ignore = dirty

CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ option(
313313
OFF)
314314
option(WITH_CINN "Compile PaddlePaddle with CINN" OFF)
315315
option(WITH_NCCL "Compile PaddlePaddle with NCCL support" ON)
316+
option(WITH_FLAGCX "Compile PaddlePaddle with FLAGCX support" OFF)
316317
option(WITH_RCCL "Compile PaddlePaddle with RCCL support" ON)
317318
option(WITH_XPU_BKCL "Compile PaddlePaddle with BAIDU KUNLUN XPU BKCL" OFF)
318319
option(WITH_CRYPTO "Compile PaddlePaddle with crypto support" ON)
@@ -538,6 +539,10 @@ else()
538539
endif()
539540
endif()
540541

542+
if(WITH_FLAGCX)
543+
add_definitions("-DPADDLE_WITH_FLAGCX")
544+
endif()
545+
541546
if(WITH_HETERPS AND WITH_PSLIB)
542547
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
543548
endif()

cmake/external/flagcx.cmake

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# flagcx.cmake
#
# Configure-time integration of the FlagCX communication library:
#   1. copy the third_party/flagcx submodule into the build tree,
#   2. build it there with its own Makefile,
#   3. locate flagcx.h and the built library,
#   4. expose an INTERFACE target named `flagcx` for dependents.
# The module does nothing unless WITH_FLAGCX is enabled.
if(NOT WITH_FLAGCX)
  return()
endif()

set(FLAGCX_SOURCE_DIR "${PADDLE_SOURCE_DIR}/third_party/flagcx")
set(FLAGCX_BINARY_DIR "${PADDLE_SOURCE_DIR}/build/third_party/flagcx")
set(THIRD_PARTY_DIR "${PADDLE_SOURCE_DIR}/build/third_party")
set(FLAGCX_ROOT "/usr/local/flagcx")
set(FLAGCX_LIB_DIR "${FLAGCX_BINARY_DIR}/build/lib")
set(USR_LOCAL_DIR "/usr/local")

# Always rebuild from a fresh copy of the submodule sources.
file(REMOVE_RECURSE ${FLAGCX_BINARY_DIR})
message(STATUS "removed old flagcx dir")
message(STATUS "Copying third-party source to build directory")
# file(COPY) creates THIRD_PARTY_DIR when it is missing and always places the
# tree at ${THIRD_PARTY_DIR}/flagcx. A plain `cp -r src dst` would instead
# rename the source to `dst` when `dst` does not exist yet, leaving
# FLAGCX_BINARY_DIR absent.
file(COPY ${FLAGCX_SOURCE_DIR} DESTINATION ${THIRD_PARTY_DIR})

if(NOT EXISTS ${FLAGCX_BINARY_DIR})
  message(FATAL_ERROR "Failed to copy third-party source to build directory")
endif()

# Build the third-party library now, during configuration (this is not a
# build-time custom target).
message(STATUS "Building third-party library with its Makefile")
execute_process(
  COMMAND make
  WORKING_DIRECTORY ${FLAGCX_BINARY_DIR}
  RESULT_VARIABLE BUILD_RESULT)

# Stop immediately on a failed build instead of continuing with a missing or
# stale library.
if(NOT BUILD_RESULT EQUAL 0)
  message(
    FATAL_ERROR "Failed to build flagcx with make (result: ${BUILD_RESULT})")
endif()

find_path(
  FLAGCX_INCLUDE_DIR flagcx.h
  PATHS ${FLAGCX_SOURCE_DIR}/flagcx/include
  NO_DEFAULT_PATH)

if(NOT FLAGCX_INCLUDE_DIR)
  message(
    FATAL_ERROR "flagcx.h not found under ${FLAGCX_SOURCE_DIR}/flagcx/include")
endif()

message(STATUS "FLAGCX_INCLUDE_DIR is ${FLAGCX_INCLUDE_DIR}")
include_directories(SYSTEM ${FLAGCX_INCLUDE_DIR})

add_library(flagcx INTERFACE)
find_library(
  FLAGCX_LIB
  NAMES flagcx libflagcx
  PATHS ${FLAGCX_LIB_DIR}
  DOC "My custom library")

if(NOT FLAGCX_LIB)
  message(WARNING "flagcx library was not found in ${FLAGCX_LIB_DIR}")
endif()

# NOTE(review): FLAGCX_LIB is the cache variable filled by find_library above,
# not a CMake target, so this call cannot order any build step. Confirm the
# intent before replacing it (e.g. with target_link_libraries).
add_dependencies(flagcx FLAGCX_LIB)
message(STATUS "FLAGCX_LIB is ${FLAGCX_LIB}")

cmake/third_party.cmake

+5
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,11 @@ if(WITH_TESTING OR WITH_DISTRIBUTE)
470470
list(APPEND third_party_deps extern_gtest)
471471
endif()
472472

473+
if(WITH_FLAGCX)
474+
include(external/flagcx)
475+
list(APPEND third_party_deps flagcx)
476+
endif()
477+
473478
if(WITH_ONNXRUNTIME)
474479
include(external/onnxruntime
475480
)# download, build, install onnxruntime、paddle2onnx

paddle/phi/CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ if(WITH_GLOO)
8787
list(APPEND PHI_DEPS gloo)
8888
endif()
8989

90+
if(WITH_FLAGCX)
91+
list(APPEND PHI_DEPS flagcx)
92+
endif()
93+
9094
if(WITH_CUDNN_FRONTEND)
9195
list(APPEND PHI_DEPS cudnn-frontend)
9296
endif()

paddle/phi/backends/dynload/CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ if(WITH_XPU)
8080
collect_srcs(backends_srcs SRCS xpti.cc)
8181
endif()
8282

83+
if(WITH_FLAGCX)
84+
collect_srcs(backends_srcs SRCS flagcx.cc)
85+
endif()
86+
8387
if(WITH_FLASHATTN)
8488
list(APPEND DYNLOAD_COMMON_SRCS flashattn.cc)
8589
endif()

paddle/phi/backends/dynload/dynamic_loader.cc

+19
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ PHI_DEFINE_string(rccl_dir,
6868
"dlopen will search rccl from LD_LIBRARY_PATH");
6969
#endif
7070

71+
#ifdef PADDLE_WITH_FLAGCX
72+
COMMON_DECLARE_string(flagcx_dir);
73+
#endif
74+
75+
PHI_DEFINE_EXPORTED_string(
76+
flagcx_dir, // NOLINT
77+
"",
78+
"Specify path for loading libflagcx.so. For instance, "
79+
"For instance, /usr/local/flagcx/lib. If default, "
80+
"dlopen will search flagcx from LD_LIBRARY_PATH");
81+
7182
#ifdef PADDLE_WITH_XPU
7283
PD_DEFINE_string(xpti_dir, "", "Specify path for loading libxpti.so.");
7384
#endif
@@ -777,6 +788,14 @@ void* GetNCCLDsoHandle() {
777788
#endif
778789
}
779790

791+
// Returns the handle used to resolve FlagCX symbols.
// Builds without PADDLE_WITH_FLAGCX always get nullptr; otherwise the result
// of GetDsoHandleFromSearchPath for "libflagcx.so" with FLAGS_flagcx_dir as
// the search directory is returned unchanged.
void* GetFLAGCXDsoHandle() {
#ifndef PADDLE_WITH_FLAGCX
  return nullptr;
#else
  return GetDsoHandleFromSearchPath(FLAGS_flagcx_dir, "libflagcx.so");
#endif
}
798+
780799
void* GetTensorRtDsoHandle() {
781800
#if defined(__APPLE__) || defined(__OSX__)
782801
return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.dylib");

paddle/phi/backends/dynload/dynamic_loader.h

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ void* GetWarpRNNTDsoHandle();
3939
void* GetFlashAttnDsoHandle();
4040
void* GetFlashAttnV3DsoHandle();
4141
void* GetNCCLDsoHandle();
42+
void* GetFLAGCXDsoHandle();
4243
void* GetTensorRtDsoHandle();
4344
void* GetMKLMLDsoHandle();
4445
void* GetLAPACKDsoHandle();

paddle/phi/backends/dynload/flagcx.cc

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/phi/backends/dynload/flagcx.h"
13+
14+
namespace phi {
15+
namespace dynload {
16+
17+
std::once_flag flagcx_dso_flag;
18+
void* flagcx_dso_handle = nullptr;
19+
20+
#define DEFINE_WRAP(__name) DynLoad__##__name __name
21+
22+
FLAGCX_RAND_ROUTINE_EACH(DEFINE_WRAP);
23+
24+
} // namespace dynload
25+
} // namespace phi

paddle/phi/backends/dynload/flagcx.h

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
#pragma once
12+
13+
#include <flagcx.h>
14+
15+
#include <mutex> // NOLINT
16+
17+
#include "paddle/phi/backends/dynload/dynamic_loader.h"
18+
#include "paddle/phi/common/port.h"
19+
20+
namespace phi {
21+
namespace dynload {
22+
23+
extern std::once_flag flagcx_dso_flag;
24+
extern void* flagcx_dso_handle;
25+
26+
#define DECLARE_DYNAMIC_LOAD_FLAGCX_WRAP(__name) \
27+
struct DynLoad__##__name { \
28+
template <typename... Args> \
29+
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
30+
using flagcx_func = decltype(&::__name); \
31+
std::call_once(flagcx_dso_flag, []() { \
32+
flagcx_dso_handle = phi::dynload::GetFLAGCXDsoHandle(); \
33+
}); \
34+
static void* p_##__name = dlsym(flagcx_dso_handle, #__name); \
35+
return reinterpret_cast<flagcx_func>(p_##__name)(args...); \
36+
} \
37+
}; \
38+
extern struct DynLoad__##__name __name
39+
40+
#define FLAGCX_RAND_ROUTINE_EACH(__macro) \
41+
__macro(flagcxGetUniqueId); \
42+
__macro(flagcxCommInitRank); \
43+
__macro(flagcxGetVersion); \
44+
__macro(flagcxCommAbort); \
45+
__macro(flagcxCommDestroy); \
46+
__macro(flagcxCommCount); \
47+
__macro(flagcxCommUserRank); \
48+
__macro(flagcxAllReduce); \
49+
__macro(flagcxBroadcast); \
50+
__macro(flagcxAllGather); \
51+
__macro(flagcxGroupStart); \
52+
__macro(flagcxGroupEnd); \
53+
__macro(flagcxReduce); \
54+
__macro(flagcxReduceScatter); \
55+
__macro(flagcxCommGetAsyncError); \
56+
__macro(flagcxSend); \
57+
__macro(flagcxRecv); \
58+
__macro(flagcxHandleInit); \
59+
__macro(flagcxHandleFree); \
60+
__macro(flagcxGetErrorString);
61+
62+
FLAGCX_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLAGCX_WRAP)
63+
64+
#undef DECLARE_DYNAMIC_LOAD_FLAGCX_WRAP
65+
66+
} // namespace dynload
67+
} // namespace phi

paddle/phi/core/distributed/CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,8 @@ if(WITH_XPU_BKCL)
2323
list(APPEND DISTRIBUTED_COMMON_SRCS bkcl_comm_context.cc)
2424
endif()
2525

26+
if(WITH_FLAGCX)
27+
list(APPEND DISTRIBUTED_COMMON_SRCS flagcx_comm_context.cc flagcx_tools.cc)
28+
endif()
29+
2630
collect_srcs(core_srcs SRCS ${DISTRIBUTED_COMMON_SRCS})

paddle/phi/core/distributed/comm_context_manager.cc

+50
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@
4444
#include "paddle/phi/core/distributed/xccl_comm_context.h"
4545
#endif
4646

47+
#if defined(PADDLE_WITH_FLAGCX)
48+
#include "paddle/phi/core/distributed/flagcx_comm_context.h"
49+
#include "paddle/phi/core/distributed/flagcx_tools.h"
50+
#endif
51+
4752
namespace phi::distributed {
4853

4954
int CommContextManager::device_id = -1;
@@ -261,6 +266,51 @@ void CommContextManager::CreateBKCLCommContext(
261266
comm_context_manager.Emplace(unique_comm_key, std::move(bkcl_comm_context));
262267
}
263268
#endif
269+
270+
#if defined(PADDLE_WITH_FLAGCX)
// Creates a FlagcxCommContext and registers it under `unique_comm_key`.
//
// Does nothing when a context with that key is already registered. Otherwise
// rank 0 generates the FlagCX unique id and publishes its bytes through
// `store` under "FlagcxCommContext/" + unique_comm_key + hash_key; every other
// rank reads the id back from the store under the same key.
//
// store:           key/value store used to exchange the unique id.
// unique_comm_key: key the new context is registered under.
// rank, size:      forwarded to the FlagcxCommContext constructor.
// hash_key:        suffix appended to the store key.
void CommContextManager::CreateFlagcxCommContext(
    const std::shared_ptr<Store>& store,
    const std::string& unique_comm_key,
    int rank,
    int size,
    const std::string& hash_key) {
  auto& comm_context_manager = CommContextManager::GetInstance();
  if (comm_context_manager.Has(unique_comm_key)) {
    return;
  }
  flagcxHandlerGroup_t flagcx_handler;
  phi::dynload::flagcxHandleInit(&flagcx_handler);
  if (rank == 0) {
    phi::dynload::flagcxGetUniqueId(&flagcx_handler->uniqueId);
  }

  std::string unique_key = "FlagcxCommContext/" + unique_comm_key + hash_key;
  if (rank == 0) {
    std::vector<uint8_t> flagcx_id_wrapper(
        reinterpret_cast<uint8_t*>(flagcx_handler->uniqueId),
        reinterpret_cast<uint8_t*>(flagcx_handler->uniqueId) +
            sizeof(flagcxUniqueId));
    store->set(unique_key, flagcx_id_wrapper);
  } else {
    const auto& flagcx_id_wrapper = store->get(unique_key);
    // The byte count comes from the store, so never copy more than the
    // destination unique id can hold.
    const size_t expected_bytes = sizeof(flagcxUniqueId);
    const size_t received_bytes = flagcx_id_wrapper.size();
    if (received_bytes != expected_bytes) {
      LOG(WARNING) << "Unexpected FlagCX unique id size from store for key "
                   << unique_key << ": got " << received_bytes
                   << " bytes, expected " << expected_bytes << " bytes";
    }
    const size_t copy_bytes =
        received_bytes < expected_bytes ? received_bytes : expected_bytes;
    std::memcpy(reinterpret_cast<uint8_t*>(flagcx_handler->uniqueId),
                flagcx_id_wrapper.data(),
                copy_bytes);
  }

  VLOG(3) << "init FlagcxCommContext rank: " << rank << ", size: " << size
          << ", unique_comm_key: " << unique_comm_key
          << ", unique_key: " << unique_key << ", flagcx_id: "
          << SerializeFlagcxUniqueId(*flagcx_handler->uniqueId);
  auto flagcx_comm_context =
      std::make_unique<FlagcxCommContext>(rank, size, flagcx_handler);
  // TODO(changtao): find a way to manage different device context,
  // now we use cuda device context as default
  comm_context_manager.SetStore(store);
  comm_context_manager.Emplace(unique_comm_key, std::move(flagcx_comm_context));
}
#endif
313+
264314
CommContext* CommContextManager::Emplace(
265315
const std::string& unique_comm_key,
266316
std::unique_ptr<CommContext> comm_context) {

paddle/phi/core/distributed/comm_context_manager.h

+12
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
#include "paddle/phi/backends/gpu/forwards.h"
2929
#endif
3030

31+
#if defined(PADDLE_WITH_FLAGCX)
32+
#include <flagcx.h>
33+
#endif
34+
3135
namespace phi {
3236
namespace distributed {
3337

@@ -105,6 +109,14 @@ class CommContextManager {
105109
const std::string& hash_key = "");
106110
#endif
107111

112+
#if defined(PADDLE_WITH_FLAGCX)
113+
// Creates a FlagCX communication context and registers it under
// `unique_comm_key`; `store` is used to share the FlagCX unique id between
// ranks, and `hash_key` (empty by default) is appended to the store key.
static void CreateFlagcxCommContext(const std::shared_ptr<Store>& store,
114+
const std::string& unique_comm_key,
115+
int rank,
116+
int size,
117+
const std::string& hash_key = "");
118+
#endif
119+
108120
private:
109121
DISABLE_COPY_AND_ASSIGN(CommContextManager);
110122

0 commit comments

Comments
 (0)