diff --git a/paddle/fluid/distributed/collective/bkcl_tools.cc b/paddle/fluid/distributed/collective/bkcl_tools.cc index 7e95eb8b748eb..902328d869c4a 100644 --- a/paddle/fluid/distributed/collective/bkcl_tools.cc +++ b/paddle/fluid/distributed/collective/bkcl_tools.cc @@ -41,5 +41,29 @@ std::string SerializeBKCLUniqueId(const BKCLUniqueId& bkclID) { return oss.str(); } +std::string BKCLDTypeToString(BKCLDataType dtype) { +#define PD_BKCL_DTYPE_TO_STR(__bkcl_dtype, __str_dtype) \ + if (dtype == __bkcl_dtype) return __str_dtype; + PD_BKCL_DTYPE_TO_STR(BKCL_FLOAT, "float32"); + PD_BKCL_DTYPE_TO_STR(BKCL_FLOAT16, "float16"); + PD_BKCL_DTYPE_TO_STR(BKCL_BFLOAT16, "bfloat16"); + PD_BKCL_DTYPE_TO_STR(BKCL_FLOAT64, "float64"); + PD_BKCL_DTYPE_TO_STR(BKCL_UINT8, "uint8"); + PD_BKCL_DTYPE_TO_STR(BKCL_INT32, "int32"); + PD_BKCL_DTYPE_TO_STR(BKCL_INT64, "int64"); + +#undef PD_BKCL_DTYPE_TO_STR + PADDLE_THROW(phi::errors::InvalidArgument( + "This datatype %d in nccl is not supported.", static_cast(dtype))); +} + +std::string BKCLRedTypeToString(BKCLOp op) { + if (op == BKCL_ADD) return "SUM"; + if (op == BKCL_PRODUCT) return "PROD"; + if (op == BKCL_MIN) return "MIN"; + if (op == BKCL_MAX) return "MAX"; + return "UDF_" + std::to_string(op); +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/bkcl_tools.h b/paddle/fluid/distributed/collective/bkcl_tools.h index 19d321080d47a..c5462d00fdb20 100644 --- a/paddle/fluid/distributed/collective/bkcl_tools.h +++ b/paddle/fluid/distributed/collective/bkcl_tools.h @@ -102,6 +102,8 @@ class XPUEventManager { BKCLOp ToBKCLRedType(ReduceOp reduction); std::string SerializeBKCLUniqueId(const BKCLUniqueId& bkclId); +std::string BKCLDTypeToString(BKCLDataType dtype); +std::string BKCLRedTypeToString(BKCLOp op); } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.cc b/paddle/fluid/distributed/collective/process_group_bkcl.cc index dc3c38d283594..28cdc77325250 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.cc +++ b/paddle/fluid/distributed/collective/process_group_bkcl.cc @@ -23,6 +23,7 @@ #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/core/device_context.h" #include "paddle/phi/core/distributed/check/static_check.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" @@ -84,12 +85,10 @@ ProcessGroupBKCL::ProcessGroupBKCL( : ProcessGroupWithStream(rank, size, gid), store_(store) {} void ProcessGroupBKCL::GroupStart() { - VLOG(3) << "bkcl_group_start"; PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start()); } void ProcessGroupBKCL::GroupEnd() { - VLOG(3) << "bkcl_group_end"; PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_end()); } @@ -107,31 +106,24 @@ std::shared_ptr ProcessGroupBKCL::Recv( tensor = &partial_tensor; } - return Collective( - tensor, - // have to pass a tensor here - // TODO(zhangxiaoci) catch up with nccl's api - *tensor, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { - VLOG(3) << "calling bkcl_recv" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", src_rank: " << src_rank << ", numel: " << output->numel() - << ", dtype: " << output->type() << ", sync_op: " << sync_op + return Point2Point( + [&](phi::distributed::BKCLCommContext* comm_context, + XPUStream stream, + int rank_in_group) { + VLOG(3) << "[bkcl_recv] " + << "recvbuff: " << tensor->data() + << ", count: " << tensor->numel() << ", datatype: " + << BKCLDTypeToString(phi::ToBKCLDataType(tensor->dtype())) + << ", src_in_group: " << src_rank + << ", bkcl_comm: " << comm_context->GetBKCLComm() + << ", stream: " << stream + << ", rank_in_group: " << rank_in_group << ", nranks: " << size_ + << ", offset: " << offset << ", sync_op: " << sync_op << ", use_calc_stream: " << use_calc_stream; - int r = bkcl_recv(comm, - output->data(), - output->numel(), - src_rank, - platform::ToBKCLDataType( - framework::TransToProtoVarType(output->type())), - stream); - return r; + comm_context->Recv(tensor, tensor->numel(), rank_in_group, stream); }, + src_rank, + *tensor, CommType::RECV, sync_op, use_calc_stream); @@ -150,30 +142,28 @@ std::shared_ptr ProcessGroupBKCL::Send( const phi::DenseTensor& tensor_maybe_partial = numel > 0 ? GetPartialTensor(tensor_tmp, offset, numel) : tensor_tmp; - return Collective( - nullptr, - tensor_maybe_partial, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { - VLOG(3) << "calling bkcl_send" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", dst_rank: " << dst_rank - << ", input numel: " << input.numel() - << ", dtype: " << input.type() << ", sync_op: " << sync_op + return Point2Point( + [&](phi::distributed::BKCLCommContext* comm_context, + XPUStream stream, + int rank_in_group) { + VLOG(3) << "[bkcl_send] " + << "sendbuff: " << tensor_maybe_partial.data() + << ", count: " << tensor_maybe_partial.numel() << ", datatype: " + << BKCLDTypeToString( + phi::ToBKCLDataType(tensor_maybe_partial.dtype())) + << ", dst_in_group: " << dst_rank + << ", bkcl_comm: " << comm_context->GetBKCLComm() + << ", stream: " << stream + << ", rank_in_group: " << rank_in_group << ", nranks: " << size_ + << ", offset: " << offset << ", sync_op: " << sync_op << ", use_calc_stream: " << use_calc_stream; - int r = bkcl_send(comm, - input.data(), - input.numel(), - dst_rank, - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - stream); - return r; + comm_context->Send(tensor_maybe_partial, + tensor_maybe_partial.numel(), + rank_in_group, + stream); }, + dst_rank, + tensor_maybe_partial, CommType::SEND, sync_op, use_calc_stream); @@ -205,29 +195,25 @@ void ProcessGroupBKCL::BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id) { void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place, const std::string& place_key) { platform::XPUDeviceGuard guard(place.GetDeviceId()); - BKCLUniqueId bkcl_id; - if (rank_ == 0) { - PADDLE_ENFORCE_XPU_SUCCESS(bkcl_get_unique_id(&bkcl_id)); - } - BroadcastUniqueBKCLID(&bkcl_id); VLOG(3) << "init bkcl rank: " << rank_ << ", nranks: " << size_ - << ", place: " << place_key - << ", bkcl uniqueid: " << SerializeBKCLUniqueId(bkcl_id); + << ", place: " << place_key; + + phi::distributed::CommContextManager::CreateBKCLCommContext( + store_, std::to_string(gid_), rank_, size_); calc_event_ = std::make_shared(); auto* calc_ctx = static_cast( platform::DeviceContextPool::Instance().Get(place)); // must use XPUDeviceContext here to make sure XPUContext::Init() is called auto comm_ctx = std::make_unique(place); + auto bkcl_comm_ctx = this->GetCommContext(); + comm_ctx->SetBkclContext(bkcl_comm_ctx->GetBKCLComm()); + // set allocator comm_ctx->SetAllocator(memory::allocation::AllocatorFacade::Instance() .GetAllocator(place) .get()); - - BKCLContext_t bkcl_comm; - BKCLCHECK(bkcl_init_rank(&bkcl_comm, GetRank(), GetSize(), &bkcl_id)); - comm_ctx->SetBkclContext(bkcl_comm); // comm context creates a separate XPU stream for communication comm_ctx->CreateStream(); @@ -243,19 +229,17 @@ void ProcessGroupBKCL::SyncCalcStream(const Place& place) { calc_event_->Block(*comm_ctx); } -template std::shared_ptr ProcessGroupBKCL::Collective( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - Fn fn, + std::function fn, + const phi::DenseTensor& tensor, CommType op_type, bool sync_op, bool use_calc_stream) { - auto tensor_tmp = - paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensor); - const auto& place = tensor_tmp.place(); + const auto& place = tensor.place(); const auto& key = GetKeyFromPlace(place); + platform::XPUDeviceGuard xpu_guard(place); + if (!calc_event_ || (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end())) { CreateBKCLEnvCache(place, key); @@ -267,11 +251,59 @@ std::shared_ptr ProcessGroupBKCL::Collective( auto task = CreateTask(place, rank_, op_type, sync_op, use_calc_stream); - const auto* calc_ctx = place_to_calc_ctx_[key]; - const auto& comm_ctx = place_to_comm_ctx_[key]; + const auto* calc_ctx = place_to_calc_ctx_.at(key); + const auto& comm_ctx = place_to_comm_ctx_.at(key); + auto bkcl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream(); + + auto bkcl_comm_ctx = this->GetCommContext(); + + fn(bkcl_comm_ctx, bkcl_stream); + + if (!use_calc_stream) { + PADDLE_ENFORCE_NOT_NULL( + comm_ctx.get(), platform::errors::Fatal("comm context is nullptr.")); + task->comm_event_->Record(*comm_ctx.get()); + } + + if (sync_op) { + task->Wait(); + } + + return task; +} + +std::shared_ptr ProcessGroupBKCL::Point2Point( + std::function fn, + int peer, + const phi::DenseTensor& tensor, + CommType comm_type, + bool sync_op, + bool use_calc_stream) { + auto tensor_tmp = + paddle::experimental::CheckAndTrans2NewContiguousTensor(tensor); + const auto& place = tensor_tmp.place(); + + int p2p_target_rank = peer; + std::string key = GetKeyFromPlace(place); + + platform::XPUDeviceGuard xpu_guard(place); + + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { + CreateBKCLEnvCache(place, key); + } + + if (!use_calc_stream) { + SyncCalcStream(place); + } + + auto task = CreateTask(place, rank_, comm_type, sync_op, use_calc_stream); + const auto* calc_ctx = place_to_calc_ctx_.at(key); + const auto& comm_ctx = place_to_comm_ctx_.at(key); auto bkcl_stream = use_calc_stream ? calc_ctx->stream() : comm_ctx->stream(); - PADDLE_ENFORCE_XPU_SUCCESS( - fn(out_tensor, tensor_tmp, comm_ctx->bkcl_context(), bkcl_stream)); + + auto bkcl_comm_ctx = this->GetCommContext(); + + fn(bkcl_comm_ctx, bkcl_stream, p2p_target_rank); if (!use_calc_stream) { PADDLE_ENFORCE_NOT_NULL( @@ -279,6 +311,10 @@ std::shared_ptr ProcessGroupBKCL::Collective( task->comm_event_->Record(*comm_ctx.get()); } + if (sync_op) { + task->Wait(); + } + return task; } @@ -291,31 +327,22 @@ std::shared_ptr ProcessGroupBKCL::AllReduce( auto tensor_tmp = paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensor); return Collective( - out_tensor, - tensor_tmp, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { - VLOG(3) << "calling bkcl_all_reduce" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", numel: " << input.numel() << ", dtype: " << input.type() - << ", reduce_type: " << ToBKCLRedType(opts.reduce_op) - << ", sync_op: " << sync_op + [&](phi::distributed::BKCLCommContext* comm_context, XPUStream stream) { + VLOG(3) << "bkcl_all_reduce" + << "sendbuff: " << tensor_tmp.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << tensor_tmp.numel() << ", datatype: " + << BKCLDTypeToString(phi::ToBKCLDataType(tensor_tmp.dtype())) + << ", redop: " << ToBKCLRedType(opts.reduce_op) + << ", bkcl_comm: " << comm_context->GetBKCLComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op << ", use_calc_stream: " << use_calc_stream; - int r = - bkcl_all_reduce(comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - ToBKCLRedType(opts.reduce_op), - stream); - return r; + + comm_context->AllReduce( + out_tensor, tensor_tmp, ToBKCLRedType(opts.reduce_op), stream); }, + tensor_tmp, CommType::ALLREDUCE, sync_op, use_calc_stream); @@ -330,45 +357,22 @@ std::shared_ptr ProcessGroupBKCL::Broadcast( auto tensor_tmp = paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensor); return Collective( - out_tensor, - tensor_tmp, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { + [&](phi::distributed::BKCLCommContext* comm_context, XPUStream stream) { int root = opts.source_rank + opts.source_root; - VLOG(3) << "calling bkcl_broadcast" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", root: " << root << ", numel: " << input.numel() - << ", dtype: " << input.type() << ", sync_op: " << sync_op + + VLOG(3) << "[bkcl_broadcast] " + << "sendbuff: " << tensor_tmp.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << tensor_tmp.numel() << ", datatype: " + << BKCLDTypeToString(phi::ToBKCLDataType(tensor_tmp.dtype())) + << ", root: " << root + << ", bkcl_comm: " << comm_context->GetBKCLComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op << ", use_calc_stream: " << use_calc_stream; - if (framework::TransToProtoVarType(input.dtype()) == - framework::proto::VarType::INT64) { - // special for int64_t, send as int32_t with DOUBLE NUMEL - int r = bkcl_broadcast( - comm, - input.data(), - output->data(), - input.numel() * 2, - platform::ToBKCLDataType(framework::proto::VarType::INT32), - root, - stream); - return r; - } else { - int r = - bkcl_broadcast(comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - root, - stream); - return r; - } + comm_context->Broadcast(out_tensor, tensor_tmp, root, stream); }, + tensor_tmp, CommType::BROADCAST, sync_op, use_calc_stream); @@ -392,29 +396,22 @@ std::shared_ptr ProcessGroupBKCL::AllGather( size_, phi::AllocationType::XPU); return Collective( - out_tensor, - in_tensor_maybe_partial, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { - VLOG(3) << "calling bkcl_all_gather" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", numel: " << in_tensor_maybe_partial.numel() - << ", dtype: " << input.type() << ", sync_op: " << sync_op + [&](phi::distributed::BKCLCommContext* comm_context, XPUStream stream) { + VLOG(3) << "bkcl_all_gather" + << "sendbuff: " << in_tensor_maybe_partial.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor_maybe_partial.numel() + << ", datatype: " + << BKCLDTypeToString(phi::ToBKCLDataType(in_tensor.dtype())) + << ", bkcl_comm: " << comm_context->GetBKCLComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", offset: " << offset + << ", sync_op: " << sync_op << ", use_calc_stream: " << use_calc_stream; - int r = - bkcl_all_gather(comm, - in_tensor_maybe_partial.data(), - in_tensor_maybe_partial.numel(), - output->data(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - stream); - return r; + + comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream); }, + in_tensor_maybe_partial, CommType::ALLGATHER, sync_op, use_calc_stream); @@ -429,32 +426,26 @@ std::shared_ptr ProcessGroupBKCL::Reduce( auto tensor_tmp = paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensor); return Collective( - out_tensor, - tensor_tmp, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { - VLOG(3) << "calling bkcl_reduce" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", root: " << opts.root_rank << ", numel: " << input.numel() - << ", dtype: " << input.type() - << ", reduce_type: " << ToBKCLRedType(opts.reduce_op) - << ", sync_op: " << sync_op + [&](phi::distributed::BKCLCommContext* comm_context, XPUStream stream) { + VLOG(3) << "[bkcl_reduce] " + << "sendbuff: " << tensor_tmp.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << tensor_tmp.numel() << ", datatype: " + << BKCLDTypeToString(phi::ToBKCLDataType(tensor_tmp.dtype())) + << ", redop: " + << BKCLRedTypeToString(ToBKCLRedType(opts.reduce_op)) + << ", root: " << opts.root_rank + << ", bkcl_comm: " << comm_context->GetBKCLComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op << ", use_calc_stream: " << use_calc_stream; - int r = bkcl_reduce(comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - ToBKCLRedType(opts.reduce_op), - opts.root_rank, - stream); - return r; + comm_context->Reduce(out_tensor, + tensor_tmp, + ToBKCLRedType(opts.reduce_op), + opts.root_rank, + stream); }, + tensor_tmp, CommType::REDUCE, sync_op, use_calc_stream); @@ -469,31 +460,22 @@ std::shared_ptr ProcessGroupBKCL::ReduceScatter( auto tensor_tmp = paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensor); return Collective( - out_tensor, - tensor_tmp, - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { - VLOG(3) << "calling bkcl_reduce_scatter" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", numel: " << output->numel() << ", dtype: " << input.type() - << ", reduce_type: " << ToBKCLRedType(opts.reduce_op) - << ", sync_op: " << sync_op + [&](phi::distributed::BKCLCommContext* comm_context, XPUStream stream) { + VLOG(3) << "[bkcl_reduce_scatter] " + << "sendbuff: " << tensor_tmp.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << tensor_tmp.numel() << ", datatype: " + << BKCLDTypeToString(phi::ToBKCLDataType(tensor_tmp.dtype())) + << ", redop: " + << BKCLRedTypeToString(ToBKCLRedType(opts.reduce_op)) + << ", bkcl_comm: " << comm_context->GetBKCLComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op << ", use_calc_stream: " << use_calc_stream; - int r = bkcl_reduce_scatter( - comm, - input.data(), - output->data(), - output->numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - ToBKCLRedType(opts.reduce_op), - stream); - return r; + comm_context->ReduceScatter( + out_tensor, tensor_tmp, ToBKCLRedType(opts.reduce_op), stream); }, + tensor_tmp, CommType::REDUCE_SCATTER, sync_op, use_calc_stream); @@ -542,314 +524,6 @@ phi::DeviceContext* ProcessGroupBKCL::GetDeviceContext( } } -// below are old apis -std::shared_ptr ProcessGroupBKCL::AllReduce( - std::vector& in_tensors, - std::vector& out_tensors, - const AllreduceOptions& opts) { - auto tensor_tmp = - paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensors); - PADDLE_ENFORCE_EQ( - tensor_tmp.size(), - 1, - platform::errors::InvalidArgument( - "BKCL only support single tensor collective communication.")); - PADDLE_ENFORCE_EQ( - out_tensors.size(), - 1, - platform::errors::InvalidArgument( - "BKCL only support single tensor collective communication.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInXPUPlace(tensor_tmp), - true, - platform::errors::InvalidArgument("All inputs should be in XPUPlace.")); - return Collective( - &out_tensors[0], - tensor_tmp[0], - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { - VLOG(3) << "calling bkcl_all_reduce" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", numel: " << input.numel() << ", dtype: " << input.type() - << ", reduce_type: " << ToBKCLRedType(opts.reduce_op) - << ", sync_op: " << true << ", use_calc_stream: " << false; - int r = - bkcl_all_reduce(comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - ToBKCLRedType(opts.reduce_op), - stream); - return r; - }, - CommType::ALLREDUCE, - /*sync_op*/ true, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupBKCL::AllReduce( - std::vector& in_tensors, - std::vector& out_tensors, - const AllreduceOptions& opts, - bool sync_op) { - auto tensor_tmp = - paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensors); - PADDLE_ENFORCE_EQ( - tensor_tmp.size(), - 1, - platform::errors::InvalidArgument( - "BKCL only support single tensor collective communication.")); - PADDLE_ENFORCE_EQ( - out_tensors.size(), - 1, - platform::errors::InvalidArgument( - "BKCL only support single tensor collective communication.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInXPUPlace(tensor_tmp), - true, - platform::errors::InvalidArgument("All inputs should be in XPUPlace.")); - return Collective( - &out_tensors[0], - tensor_tmp[0], - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { - VLOG(3) << "calling bkcl_all_reduce" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", numel: " << input.numel() << ", dtype: " << input.type() - << ", reduce_type: " << ToBKCLRedType(opts.reduce_op) - << ", sync_op: " << sync_op << ", use_calc_stream: " << false; - int r = - bkcl_all_reduce(comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - ToBKCLRedType(opts.reduce_op), - stream); - return r; - }, - CommType::ALLREDUCE, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupBKCL::Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions& opts) { - auto tensor_tmp = - paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensors); - PADDLE_ENFORCE_EQ( - tensor_tmp.size(), - 1, - platform::errors::InvalidArgument( - "BKCL only support single tensor collective communication.")); - PADDLE_ENFORCE_EQ( - out_tensors.size(), - 1, - platform::errors::InvalidArgument( - "BKCL only support single tensor collective communication.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInXPUPlace(tensor_tmp), - true, - platform::errors::InvalidArgument("All inputs should be in XPUPlace.")); - - return Collective( - &out_tensors[0], - tensor_tmp[0], - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { - const auto root = - opts.source_rank * tensor_tmp.size() + opts.source_root; - VLOG(3) << "calling bkcl_broadcast" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", root: " << root << ", numel: " << input.numel() - << ", dtype: " << input.type() << ", sync_op: " << true - << ", use_calc_stream: " << false; - int r = - bkcl_broadcast(comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - root, - stream); - return r; - }, - CommType::BROADCAST, - /*sync_op*/ true, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupBKCL::Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions& opts, - bool sync_op) { - auto tensor_tmp = - paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensors); - PADDLE_ENFORCE_EQ( - tensor_tmp.size(), - 1, - platform::errors::InvalidArgument( - "BKCL only support single tensor collective communication.")); - PADDLE_ENFORCE_EQ( - out_tensors.size(), - 1, - platform::errors::InvalidArgument( - "BKCL only support single tensor collective communication.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInXPUPlace(tensor_tmp), - true, - platform::errors::InvalidArgument("All inputs should be in XPUPlace.")); - - return Collective( - &out_tensors[0], - tensor_tmp[0], - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { - const auto root = - opts.source_rank * tensor_tmp.size() + opts.source_root; - VLOG(3) << "calling bkcl_broadcast" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", root: " << root << ", numel: " << input.numel() - << ", dtype: " << input.type() << ", sync_op: " << sync_op - << ", use_calc_stream: " << false; - int r = - bkcl_broadcast(comm, - input.data(), - output->data(), - input.numel(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - root, - stream); - return r; - }, - CommType::BROADCAST, - sync_op, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupBKCL::AllGather( - std::vector& in_tensors, - std::vector& out_tensors) { - auto tensor_tmp = - paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensors); - PADDLE_ENFORCE_EQ( - tensor_tmp.size(), - 1, - platform::errors::InvalidArgument( - "BKCL only support single tensor collective communication.")); - PADDLE_ENFORCE_EQ( - out_tensors.size(), - 1, - platform::errors::InvalidArgument( - "BKCL only support single tensor collective communication.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInXPUPlace(tensor_tmp), - true, - platform::errors::InvalidArgument("All inputs should be in XPUPlace.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInXPUPlace(out_tensors), - true, - platform::errors::InvalidArgument("All outputs should be in XPUPlace.")); - return Collective( - &out_tensors[0], - tensor_tmp[0], - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { - VLOG(3) << "calling bkcl_all_gather" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", numel: " << input.numel() << ", dtype: " << input.type() - << ", sync_op: " << true << ", use_calc_stream: " << false; - int r = - bkcl_all_gather(comm, - input.data(), - input.numel(), - output->data(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - stream); - return r; - }, - CommType::ALLGATHER, - /*sync_op*/ true, - /*use_calc_stream*/ false); -} - -std::shared_ptr ProcessGroupBKCL::AllGather( - std::vector& in_tensors, - std::vector& out_tensors, - bool sync_op) { - auto tensor_tmp = - paddle::experimental::CheckAndTrans2NewContiguousTensor(in_tensors); - PADDLE_ENFORCE_EQ( - tensor_tmp.size(), - 1, - platform::errors::InvalidArgument( - "BKCL only support single tensor collective communication.")); - PADDLE_ENFORCE_EQ( - out_tensors.size(), - 1, - platform::errors::InvalidArgument( - "BKCL only support single tensor collective communication.")); - PADDLE_ENFORCE_EQ( - CheckTensorsInXPUPlace(out_tensors), - true, - platform::errors::InvalidArgument("All outputs should be in XPUPlace.")); - return Collective( - &out_tensors[0], - tensor_tmp[0], - [&](phi::DenseTensor* output, - const phi::DenseTensor& input, - BKCLContext_t comm, - const XPUStream& stream) { - VLOG(3) << "calling bkcl_all_gather" - << ", rank_id: " << platform::GetBKCLRankID(comm) - << ", dev_id: " << platform::GetBKCLDevID(comm) - << ", nranks: " << platform::GetBKCLNRanks(comm) - << ", numel: " << input.numel() << ", dtype: " << input.type() - << ", sync_op: " << sync_op << ", use_calc_stream: " << false; - int r = - bkcl_all_gather(comm, - input.data(), - input.numel(), - output->data(), - platform::ToBKCLDataType( - framework::TransToProtoVarType(input.type())), - stream); - return r; - }, - CommType::ALLGATHER, - sync_op, - /*use_calc_stream*/ false); -} - std::shared_ptr ProcessGroupBKCL::CreateProcessGroupBKCL( const std::shared_ptr& store, int rank, @@ -861,5 +535,16 @@ std::shared_ptr ProcessGroupBKCL::CreateProcessGroupBKCL( return process_group; } +phi::distributed::BKCLCommContext* ProcessGroupBKCL::GetCommContext() { + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + auto comm_context = static_cast( + comm_context_manager.Get(std::to_string(this->gid_))); + PADDLE_ENFORCE_NE(comm_context, + nullptr, + phi::errors::Unavailable("BKCLCommContext is nullptr")); + return comm_context; +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.h b/paddle/fluid/distributed/collective/process_group_bkcl.h index ea89dbdd6d87d..bd83257cc8d63 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.h +++ b/paddle/fluid/distributed/collective/process_group_bkcl.h @@ -25,6 +25,7 @@ #include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/distributed/bkcl_comm_context.h" #include "paddle/phi/core/distributed/store/store.h" #if defined(PADDLE_WITH_XPU) @@ -145,38 +146,6 @@ class ProcessGroupBKCL : public ProcessGroupWithStream { BKCLContext_t BKCLComm(const Place& place) const; - // below are old apis - std::shared_ptr AllReduce( - std::vector& in_tensors, - std::vector& out_tensors, - const AllreduceOptions& = AllreduceOptions()) override; - - std::shared_ptr AllReduce( - std::vector& in_tensors, - std::vector& out_tensors, - const AllreduceOptions& options, - bool sync_op) override; - - std::shared_ptr Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions& = BroadcastOptions()) override; - - std::shared_ptr Broadcast( - std::vector& in_tensors, - std::vector& out_tensors, - const BroadcastOptions&, - bool sync_op) override; - - std::shared_ptr AllGather( - std::vector& in_tensors, - std::vector& out_tensors) override; - - std::shared_ptr AllGather( - std::vector& in_tensors, - std::vector& out_tensors, - bool sync_op) override; - private: std::shared_ptr CreateTask(const Place& place, int rank, @@ -188,16 +157,24 @@ class ProcessGroupBKCL : public ProcessGroupWithStream { void CreateBKCLEnvCache(const Place& place, const std::string& place_key); - template std::shared_ptr Collective( - phi::DenseTensor* out_tensor, - const phi::DenseTensor& in_tensor, - Fn fn, + std::function fn, + const phi::DenseTensor& tensor, + CommType comm_type, + bool sync_op, + bool use_calc_stream); + + std::shared_ptr Point2Point( + std::function + fn, + int peer, + const phi::DenseTensor& tensor, CommType comm_type, bool sync_op, bool use_calc_stream); void SyncCalcStream(const Place& place); + phi::distributed::BKCLCommContext* GetCommContext(); private: std::shared_ptr store_; diff --git a/paddle/phi/backends/xpu/enforce_xpu.h b/paddle/phi/backends/xpu/enforce_xpu.h index 0a2a21e236d04..e4fc15f4cb747 100644 --- a/paddle/phi/backends/xpu/enforce_xpu.h +++ b/paddle/phi/backends/xpu/enforce_xpu.h @@ -16,9 +16,6 @@ limitations under the License. */ #include "paddle/phi/backends/xpu/xpu_header.h" #include "paddle/phi/core/enforce.h" -#ifdef PADDLE_WITH_XPU_BKCL -#include "xpu/bkcl.h" -#endif namespace phi { namespace backends { diff --git a/paddle/phi/backends/xpu/xpu_header.h b/paddle/phi/backends/xpu/xpu_header.h index 36caaf00f5ea4..17cb0c4615efa 100644 --- a/paddle/phi/backends/xpu/xpu_header.h +++ b/paddle/phi/backends/xpu/xpu_header.h @@ -21,6 +21,9 @@ limitations under the License. */ #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" +#ifdef PADDLE_WITH_XPU_BKCL +#include "xpu/bkcl.h" +#endif #include "xpu/runtime.h" #include "xpu/runtime_ex.h" #include "xpu/xdnn.h" diff --git a/paddle/phi/backends/xpu/xpu_info.h b/paddle/phi/backends/xpu/xpu_info.h index ad5a0b9745832..e47a45e9b2451 100644 --- a/paddle/phi/backends/xpu/xpu_info.h +++ b/paddle/phi/backends/xpu/xpu_info.h @@ -71,13 +71,10 @@ void MemcpySyncD2D(void *dst, class XPUDeviceGuard { public: - explicit inline XPUDeviceGuard(int dev_id) { - int prev_id = GetXPUCurrentDeviceId(); - if (prev_id != dev_id) { - prev_id_ = prev_id; - SetXPUDeviceId(dev_id); - } - } + explicit XPUDeviceGuard(int dev_id) { SetDeviceIndex(dev_id); } + + explicit XPUDeviceGuard(const XPUPlace &place) + : XPUDeviceGuard(place.device) {} inline ~XPUDeviceGuard() { if (prev_id_ != -1) { @@ -85,6 +82,14 @@ class XPUDeviceGuard { } } + inline void SetDeviceIndex(const int dev_id) { + int prev_id = GetXPUCurrentDeviceId(); + if (prev_id != dev_id) { + prev_id_ = prev_id; + SetXPUDeviceId(dev_id); + } + } + XPUDeviceGuard(const XPUDeviceGuard &o) = delete; XPUDeviceGuard &operator=(const XPUDeviceGuard &o) = delete; diff --git a/paddle/phi/core/distributed/CMakeLists.txt b/paddle/phi/core/distributed/CMakeLists.txt index 8e58ab4bf840e..00000c3fff9e0 100644 --- a/paddle/phi/core/distributed/CMakeLists.txt +++ b/paddle/phi/core/distributed/CMakeLists.txt @@ -18,4 +18,8 @@ if(WITH_CUSTOM_DEVICE) list(APPEND DISTRIBUTED_COMMON_SRCS xccl_comm_context.cc) endif() +if(WITH_XPU_BKCL) + list(APPEND DISTRIBUTED_COMMON_SRCS bkcl_comm_context.cc) +endif() + collect_srcs(core_srcs SRCS ${DISTRIBUTED_COMMON_SRCS}) diff --git a/paddle/phi/core/distributed/bkcl_comm_context.cc b/paddle/phi/core/distributed/bkcl_comm_context.cc new file mode 100644 index 0000000000000..2f5fe0eb3ccbe --- /dev/null +++ b/paddle/phi/core/distributed/bkcl_comm_context.cc @@ -0,0 +1,182 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/bkcl_comm_context.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/check/static_check.h" +#include "paddle/phi/core/distributed/utils.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/utils/data_type.h" + +namespace phi { +namespace distributed { + +BKCLCommContext::BKCLCommContext(int rank, int size, BKCLUniqueId bkcl_id) + : CommContext(rank, size) { + PADDLE_ENFORCE_XPU_SUCCESS( + bkcl_init_rank(&bkcl_comm_, rank_, size_, &bkcl_id)); +} + +BKCLContext_t BKCLCommContext::GetBKCLComm() { return bkcl_comm_; } + +XPUStream BKCLCommContext::GetStream() { return dev_ctx_->stream(); } + +phi::XPUContext* BKCLCommContext::GetDevContext() { return dev_ctx_.get(); } + +void BKCLCommContext::SetDevContext( + std::unique_ptr&& dev_ctx) { + dev_ctx_ = std::move(dev_ctx); +} + +void BKCLCommContext::Broadcast(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int root, + XPUStream stream) { + CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::XPU); + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_broadcast(bkcl_comm_, + in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToBKCLDataType(in_tensor.type()), + root, + stream)); +} + +void BKCLCommContext::AllGather(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + XPUStream stream) { + phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::XPU); + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_all_gather(bkcl_comm_, + in_tensor.data(), + in_tensor.numel(), + out_tensor->data(), + ToBKCLDataType(in_tensor.type()), + stream)); +} + +void BKCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + BKCLOp reduce_type, + XPUStream stream) { + phi::distributed::CommStaticCheck::ScatterLikeShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::XPU); + PADDLE_ENFORCE_XPU_SUCCESS( + bkcl_reduce_scatter(bkcl_comm_, + in_tensor.data(), + out_tensor->data(), + out_tensor->numel(), + ToBKCLDataType(in_tensor.type()), + reduce_type, + stream)); +} + +void BKCLCommContext::Send(const phi::DenseTensor& in_tensor, + const int64_t& count, + const int& peer, + XPUStream stream) { + phi::distributed::CommStaticCheck::CheckShape( + in_tensor, rank_, size_, phi::AllocationType::XPU); + + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_send(bkcl_comm_, + in_tensor.data(), + count, + peer, + ToBKCLDataType(in_tensor.dtype()), + stream)); + VLOG(3) << "rank " << GetRank() << " send " << phi::product(in_tensor.dims()) + << " to " << peer; +} + +void BKCLCommContext::Recv(phi::DenseTensor* out_tensor, + const int64_t& count, + const int& peer, + XPUStream stream) { + phi::distributed::CommStaticCheck::CheckShape( + *out_tensor, rank_, size_, phi::AllocationType::XPU); + + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_recv(bkcl_comm_, + out_tensor->data(), + count, + peer, + ToBKCLDataType(out_tensor->dtype()), + stream)); + VLOG(3) << "rank " << GetRank() << " recv " + << phi::product(out_tensor->dims()) << " from " << peer; +} + +void BKCLCommContext::AllReduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + BKCLOp reduce_type, + XPUStream stream) { + phi::distributed::CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::XPU); + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_all_reduce(bkcl_comm_, + in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToBKCLDataType(in_tensor.type()), + reduce_type, + stream)); +} + +void BKCLCommContext::Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + BKCLOp reduce_type, + int root, + XPUStream stream) { + phi::distributed::CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ root, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::XPU); + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_reduce(bkcl_comm_, + in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToBKCLDataType(in_tensor.type()), + reduce_type, + root, + stream)); +} + +void BKCLCommContext::GroupStart() { + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start()); +} +void BKCLCommContext::GroupEnd() { + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_end()); +} +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/bkcl_comm_context.h b/paddle/phi/core/distributed/bkcl_comm_context.h new file mode 100644 index 0000000000000..5ba0594aba7d4 --- /dev/null +++ b/paddle/phi/core/distributed/bkcl_comm_context.h @@ -0,0 +1,85 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/distributed/comm_context.h" + +namespace phi { +class DenseTensor; +namespace distributed { + +class BKCLCommContext final : public CommContext { + public: + BKCLCommContext(int rank, int size, BKCLUniqueId BKCL_id); + ~BKCLCommContext() override = default; + + BKCLContext_t GetBKCLComm(); + + XPUStream GetStream(); + + phi::XPUContext* GetDevContext(); + + void SetDevContext(std::unique_ptr&& dev_ctx); + + void Broadcast(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int root, + XPUStream stream); + + void Send(const phi::DenseTensor& in_tensor, + const int64_t& count, + const int& peer, + XPUStream stream); + + void Recv(phi::DenseTensor* out_tensor, + const int64_t& count, + const int& peer, + XPUStream stream); + + void ReduceScatter(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + BKCLOp reduce_type, + XPUStream stream); + + void AllGather(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + XPUStream stream); + + void AllReduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + BKCLOp reduce_type, + XPUStream stream); + + void Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + BKCLOp reduce_type, + int root, + XPUStream stream); + + void GroupStart(); + + void GroupEnd(); + + private: + DISABLE_COPY_AND_ASSIGN(BKCLCommContext); + + BKCLContext_t bkcl_comm_; + + std::unique_ptr dev_ctx_; +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index 2a5b336f34e25..d7f9d81214265 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -39,6 +39,11 @@ #include "paddle/phi/core/distributed/xccl_comm_context.h" #endif +#ifdef PADDLE_WITH_XPU_BKCL +#include "paddle/phi/backends/xpu/xpu_info.h" +#include "paddle/phi/core/distributed/bkcl_comm_context.h" +#endif + namespace phi { namespace distributed { @@ -169,6 +174,40 @@ void CommContextManager::CreateXCCLCommContext( } #endif +#if defined(PADDLE_WITH_XPU_BKCL) +void CommContextManager::CreateBKCLCommContext( + const std::shared_ptr& store, + const std::string& unique_comm_key, + int rank, + int size, + const std::string& hash_key) { + auto& comm_context_manager = CommContextManager::GetInstance(); + if (comm_context_manager.Has(unique_comm_key)) { + return; + } + BKCLUniqueId bkcl_id; + if (rank == 0) { + PADDLE_ENFORCE_XPU_SUCCESS(bkcl_get_unique_id(&bkcl_id)); + } + + std::string unique_key = "BKCLCommContext/" + unique_comm_key + hash_key; + if (rank == 0) { + std::vector bkcl_id_wrapper( + reinterpret_cast(&bkcl_id), + reinterpret_cast(&bkcl_id) + BKCL_UNIQUE_ID_BYTES); + store->set(unique_key, bkcl_id_wrapper); + } else { + const auto& bkcl_id_wrapper = store->get(unique_key); + std::memcpy(&bkcl_id, bkcl_id_wrapper.data(), bkcl_id_wrapper.size()); + } + + auto bkcl_comm_context = + std::make_unique(rank, size, bkcl_id); + + comm_context_manager.SetStore(store); + comm_context_manager.Emplace(unique_comm_key, std::move(bkcl_comm_context)); +} +#endif CommContext* CommContextManager::Emplace( const std::string& unique_comm_key, std::unique_ptr comm_context) { diff --git a/paddle/phi/core/distributed/comm_context_manager.h b/paddle/phi/core/distributed/comm_context_manager.h index 2229786db3855..f9de38d8b5005 100644 --- a/paddle/phi/core/distributed/comm_context_manager.h +++ b/paddle/phi/core/distributed/comm_context_manager.h @@ -89,6 +89,14 @@ class CommContextManager { const std::string& hash_key = ""); #endif +#if defined(PADDLE_WITH_XPU_BKCL) + static void CreateBKCLCommContext(const std::shared_ptr& store, + const std::string& unique_comm_key, + int rank, + int size, + const std::string& hash_key = ""); +#endif + private: DISABLE_COPY_AND_ASSIGN(CommContextManager); diff --git a/paddle/phi/core/enforce.h b/paddle/phi/core/enforce.h index 48ade86e6ae33..aa68dd802c0b4 100644 --- a/paddle/phi/core/enforce.h +++ b/paddle/phi/core/enforce.h @@ -93,6 +93,9 @@ limitations under the License. */ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/phi/backends/gpu/gpu_types.h" #endif +#if defined(PADDLE_WITH_XPU_BKCL) +#include "xpu/bkcl.h" +#endif #include "paddle/utils/variant.h" diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc index d24320900dbc6..ebef7410c31bf 100644 --- a/paddle/phi/core/flags.cc +++ b/paddle/phi/core/flags.cc @@ -1361,7 +1361,8 @@ PHI_DEFINE_EXPORTED_int32( PHI_DEFINE_EXPORTED_bool(print_ir, false, "Whether print ir debug str."); -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ + defined(PADDLE_WITH_XPU_BKCL) /** * Communication library related FLAG * Name: FLAGS_dynamic_static_unified_comm diff --git a/paddle/phi/core/utils/data_type.h b/paddle/phi/core/utils/data_type.h index 018672e45b597..449d7cbe8966d 100644 --- a/paddle/phi/core/utils/data_type.h +++ b/paddle/phi/core/utils/data_type.h @@ -239,5 +239,29 @@ inline ncclDataType_t ToNCCLDataType(DataType type) { } } #endif +#if defined(PADDLE_WITH_XPU_BKCL) +inline BKCLDataType ToBKCLDataType(DataType type) { + if (type == DataType::FLOAT32) { + return BKCL_FLOAT; + } else if (type == DataType::FLOAT64) { + return BKCL_FLOAT64; + } else if (type == DataType::INT32) { + return BKCL_INT32; + } else if (type == DataType::INT64) { + return BKCL_INT64; + } else if (type == DataType::FLOAT16) { + return BKCL_FLOAT16; + } else if (type == DataType::UINT8) { + return BKCL_UINT8; + } else if (type == DataType::BOOL) { + return BKCL_UINT8; + } else if (type == DataType::BFLOAT16) { + return BKCL_BFLOAT16; + } else { + PADDLE_THROW( + errors::Unimplemented("This datatype in bkcl is not supported.")); + } +} +#endif } // namespace phi