From 5f126e50d0995cbe3f8aff9ef574254c0cc1bfc5 Mon Sep 17 00:00:00 2001
From: danleifeng
Date: Tue, 14 Jun 2022 21:19:45 +0800
Subject: [PATCH 01/31] add adam/sharedadam optimizer for gpups;edit optimizer
 struct;test=develop

---
 .../distributed/ps/table/ctr_dymf_accessor.cc |  21 +-
 .../distributed/ps/table/ctr_dymf_accessor.h  |   5 +-
 .../distributed/ps/wrapper/CMakeLists.txt     |   7 +
 paddle/fluid/distributed/ps/wrapper/fleet.cc  |  25 +-
 .../framework/distributed_strategy.proto      |   4 +-
 paddle/fluid/framework/fleet/heter_context.h  |   4 +-
 .../framework/fleet/heter_ps/feature_value.h  | 381 +++++++++++-
 .../framework/fleet/heter_ps/hashtable.h      |   8 +-
 .../fleet/heter_ps/hashtable_kernel.cu        | 149 +++--
 .../framework/fleet/heter_ps/heter_comm.h     |  16 +-
 .../framework/fleet/heter_ps/heter_comm_inl.h | 130 ++--
 .../fleet/heter_ps/heter_comm_kernel.cu       | 126 ++--
 .../fleet/heter_ps/heter_comm_kernel.h        |  57 +-
 .../framework/fleet/heter_ps/heter_ps.cc      |  27 +-
 .../framework/fleet/heter_ps/heter_ps.cu      |  48 +-
 .../fluid/framework/fleet/heter_ps/heter_ps.h |  21 +-
 .../framework/fleet/heter_ps/heter_ps_base.h  |  16 +-
 .../fluid/framework/fleet/heter_ps/mem_pool.h |  26 +-
 .../framework/fleet/heter_ps/optimizer.cuh.h  | 579 +++++++++++++++---
 .../framework/fleet/heter_ps/optimizer_conf.h |  28 +-
 .../fluid/framework/fleet/ps_gpu_wrapper.cc   | 425 +++++++------
 .../fluid/framework/fleet/ps_gpu_wrapper.cu   |  83 ++-
 paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 295 ++++++++-
 .../fluid/framework/fleet/ps_gpu_wrapper.kps  |  12 +-
 paddle/fluid/framework/multi_trainer.cc       |  88 +++
 paddle/fluid/framework/trainer.h              |   9 +
 paddle/fluid/pybind/ps_gpu_wrapper_py.cc      |   2 +
 .../distributed/fleet/base/fleet_base.py      |  28 +-
 python/paddle/distributed/ps/the_one_ps.py    |   8 +-
 python/paddle/fluid/device_worker.py          |   1 +
 python/paddle/fluid/trainer_factory.py        |   6 +
 31 files changed, 2066 insertions(+), 569 deletions(-)

diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc
index c65eac99acc03d..82d761c37c5929 100644
--- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc
+++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc
@@ -24,9 +24,12 @@ namespace distributed {
 
 int CtrDymfAccessor::Initialize() {
   auto name = _config.embed_sgd_param().name();
+
   _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
   _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
+  VLOG(0) << "CtrDymfAccessor::Initialize embed_sgd_param name:" << name
+          << " embedx_sgd_param name: " << _config.embedx_sgd_param().name();
   name = _config.embedx_sgd_param().name();
   _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
   _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
@@ -42,20 +45,32 @@ int CtrDymfAccessor::Initialize() {
   if (_config.ctr_accessor_param().show_scale()) {
     _show_scale = true;
   }
-  VLOG(0) << " INTO CtrDymfAccessor::Initialize()";
+  VLOG(0) << " INTO CtrDymfAccessor::Initialize(); embed_sgd_dim:" << common_feature_value.embed_sgd_dim
+          << " embedx_dim:" << common_feature_value.embedx_dim
+          << " embedx_sgd_dim:" << common_feature_value.embedx_sgd_dim;
   InitAccessorInfo();
   return 0;
 }
 
+// int CtrDymfAccessor::InitializeDim(int embed_sgd_dim, int embedx_dim, int embedx_sgd_dim) {
+//   common_feature_value.embed_sgd_dim = embed_sgd_dim;
+//   common_feature_value.embedx_dim = embedx_dim;
+//   common_feature_value.embedx_sgd_dim = embedx_sgd_dim;
+//   VLOG(0) << " INTO CtrDymfAccessor::InitializeDim(); embed_sgd_dim:" << embed_sgd_dim
+//           << " embedx_dim:" <<
embedx_dim<< " embedx_sgd_dim:" << embedx_sgd_dim; +// InitAccessorInfo(); +// return 0; +// } + void CtrDymfAccessor::InitAccessorInfo() { _accessor_info.dim = common_feature_value.Dim(); _accessor_info.size = common_feature_value.Size(); auto embedx_dim = _config.embedx_dim(); VLOG(0) << "InitAccessorInfo embedx_dim:" << embedx_dim; - _accessor_info.select_dim = 3 + embedx_dim; + _accessor_info.select_dim = 4 + embedx_dim; _accessor_info.select_size = _accessor_info.select_dim * sizeof(float); - _accessor_info.update_dim = 4 + embedx_dim; + _accessor_info.update_dim = 5 + embedx_dim; _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); _accessor_info.mf_size = (embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float); diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h index 38b3e6ecae68d9..9444922ac833e7 100644 --- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h @@ -54,10 +54,10 @@ class CtrDymfAccessor : public ValueAccessor { int ClickIndex() { return ShowIndex() + 1; } int EmbedWIndex() { return ClickIndex() + 1; } int EmbedG2SumIndex() { return EmbedWIndex() + 1; } - int SlotIndex() { return EmbedG2SumIndex() + 1; } + int SlotIndex() { return EmbedG2SumIndex() + embed_sgd_dim; } int MfDimIndex() { return SlotIndex() + 1; } int EmbedxG2SumIndex() { return MfDimIndex() + 1; } - int EmbedxWIndex() { return EmbedxG2SumIndex() + 1; } + int EmbedxWIndex() { return EmbedxG2SumIndex() + embedx_sgd_dim; } float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; } float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; } @@ -151,6 +151,7 @@ class CtrDymfAccessor : public ValueAccessor { CtrDymfAccessor() {} virtual ~CtrDymfAccessor() {} virtual int Initialize(); + // virtual int InitializeDim(int embed_sgd_dim, int embedx_dim, int embedx_sgd_dim); // 初始化AccessorInfo virtual void InitAccessorInfo(); // 判断该value是否进行shrink diff --git a/paddle/fluid/distributed/ps/wrapper/CMakeLists.txt b/paddle/fluid/distributed/ps/wrapper/CMakeLists.txt index 8b5457ef9eea52..352e3aa19eb09f 100644 --- a/paddle/fluid/distributed/ps/wrapper/CMakeLists.txt +++ b/paddle/fluid/distributed/ps/wrapper/CMakeLists.txt @@ -1,5 +1,11 @@ get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) +set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses") +if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) + set(DISTRIBUTE_COMPILE_FLAGS + "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") +endif() + set_source_files_properties(fleet.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_library( @@ -13,6 +19,7 @@ cc_library( op_registry fs shell + ps_gpu_wrapper ${RPC_DEPS}) target_link_libraries(fleet z) diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.cc b/paddle/fluid/distributed/ps/wrapper/fleet.cc index 57ff1d3bcd4fa5..f474810bc87dd3 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.cc +++ b/paddle/fluid/distributed/ps/wrapper/fleet.cc @@ -18,6 +18,10 @@ limitations under the License. 
*/ #include "paddle/fluid/distributed/ps/service/communicator/communicator.h" #include "paddle/fluid/distributed/ps/table/table.h" +#include "paddle/fluid/distributed/ps/wrapper/fleet.h" +#if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE +#include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" +#endif namespace paddle { namespace distributed { @@ -129,6 +133,13 @@ void FleetWrapper::InitWorker(const std::string& dist_desc, worker_ptr_ = std::shared_ptr( paddle::distributed::PSClientFactory::Create(ps_param)); worker_ptr_->Configure(ps_param, dense_pull_regions, ps_env_, index); +#if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE + VLOG(0) << "FleetWrapper::InitWorker InitializeGPUServer"; + auto* accessor = worker_ptr_->GetTableAccessor(0); + auto ps_gpu_wrapper = paddle::framework::PSGPUWrapper::GetInstance(); + ps_gpu_wrapper->InitializeGPUServer(ps_param); + ps_gpu_wrapper->SetTableAccessor(accessor); +#endif } } else { VLOG(3) << "Client can be initialized only once"; @@ -525,11 +536,11 @@ void FleetWrapper::PushSparseFromTensorAsync( int batch_size = -1; bool batch_size_consist = true; for (auto* input : *inputs) { - int cur_batch_size = + size_t cur_batch_size = input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0]; if (batch_size == -1) { - batch_size = cur_batch_size; - } else if (batch_size != cur_batch_size) { + batch_size = int(cur_batch_size); + } else if (batch_size != int(cur_batch_size)) { // CHECK(batch_size == cur_batch_size); // NOLINT batch_size_consist = false; break; @@ -537,12 +548,12 @@ void FleetWrapper::PushSparseFromTensorAsync( } CHECK(batch_size > 0); // NOLINT - int show_size = + size_t show_size = shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0]; - CHECK(show_size == batch_size || show_size == 1); - int clk_size = + CHECK(show_size == size_t(batch_size) || show_size == 1); + size_t clk_size = clks->lod().size() ? 
clks->lod()[0].size() - 1 : clks->dims()[0]; - CHECK(clk_size == batch_size || clk_size == 1); + CHECK(clk_size == size_t(batch_size) || clk_size == 1); CHECK(outputs->size() == inputs->size()); std::vector push_keys; diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index b3a01ae169e4e2..7504e6f93a1e65 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -197,14 +197,14 @@ message TableParameter { message TableAccessorParameter { optional string accessor_class = 1; - optional SGDParameter embed_sgd_param = 2; - optional SGDParameter embedx_sgd_param = 3; optional uint32 fea_dim = 4 [ default = 11 ]; // field size of one value optional uint32 embedx_dim = 5 [ default = 8 ]; // embedx feature size optional uint32 embedx_threshold = 6 [ default = 10 ]; // embedx feature create threshold optional CtrAccessorParameter ctr_accessor_param = 7; repeated TableAccessorSaveParameter table_accessor_save_param = 8; + optional SGDParameter embed_sgd_param = 10; + optional SGDParameter embedx_sgd_param = 11; } message SGDParameter { diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h index 3955502c8b8081..3407608d90cdbc 100644 --- a/paddle/fluid/framework/fleet/heter_context.h +++ b/paddle/fluid/framework/fleet/heter_context.h @@ -81,7 +81,7 @@ class HeterContext { std::vector> device_values_; std::vector> device_keys_; std::vector>> device_dim_keys_; - std::vector>> device_dim_values_; + // std::vector>> device_dim_values_; std::vector mutex_; std::vector> dim_mutex_; int multi_mf_dim_ = 0; @@ -114,7 +114,7 @@ class HeterContext { value_dim_ptr_[i].resize(dim_num); } device_values_.resize(device_num); - device_dim_values_.resize(device_num); + // device_dim_values_.resize(device_num); device_keys_.resize(device_num); device_dim_keys_.resize(device_num); diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index cb7f3a40d6720b..43e70d8edabf22 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -17,6 +17,9 @@ limitations under the License. 
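
[Editor's note, not part of the patch: feature_value.h below replaces the old
fixed-layout FeatureValue struct with CommonFeatureValueAccessor, which
addresses one sparse-table value as a flat float array through computed
indices. As a worked example of that index arithmetic, assuming the adam case
(embed_sgd_dim = 4, embedx_dim = 8, hence embedx_sgd_dim = 8 * 2 + 2 = 18),
the helper functions defined below resolve to:

  CpuPtrIndex()      = 0   (cpu_ptr is a uint64_t spanning two float slots)
  DeltaScoreIndex()  = 2
  ShowIndex()        = 3
  ClickIndex()       = 4
  EmbedWIndex()      = 5
  EmbedG2SumIndex()  = 6   (embed_sgd_dim = 4 slots: 6..9)
  SlotIndex()        = 10
  MfDimIndex()       = 11
  MfSizeIndex()      = 12
  EmbedxG2SumIndex() = 13  (embedx_sgd_dim = 18 slots: 13..30)
  EmbedxWIndex()     = 31  (embedx_dim = 8 slots: 31..38)

so Dim() = 8 + 4 + 18 + 8 = 38 while the buffer actually holds 39 float
slots, and Size() = (38 - 1) * sizeof(float) + sizeof(uint64_t) = 156 bytes:
Dim() counts cpu_ptr once even though it occupies two float-sized slots.]
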
*/ #ifdef PADDLE_WITH_HETERPS #include +#include +#include + namespace paddle { namespace framework { @@ -24,23 +27,305 @@ namespace framework { typedef uint64_t FeatureKey; +struct GpuAccessorInfo { + // value维度 + size_t dim; + // value各个维度的size + size_t size; + // push value维度 + size_t update_dim; + // push value各个维度的size + size_t update_size; + // value中mf动态长度部分总size大小, sparse下生效 + size_t mf_size; +}; + +class FeatureValueAccessor { + public: + __host__ __device__ FeatureValueAccessor() {} + __host__ __device__ ~FeatureValueAccessor() {} + + __host__ __device__ virtual int Configure(std::unordered_map config) { + _config = config; + Initialize(); + return 0; + } + __host__ __device__ virtual int Initialize() = 0; + + __host__ __device__ virtual GpuAccessorInfo GetAccessorInfo() { return _accessor_info; } + + protected: + // TableAccessorParameter _config; + std::unordered_map _config; + GpuAccessorInfo _accessor_info; +}; + +// adagrad: embed_sgd_dim=1, embedx_sgd_dim=1,embedx_dim=n +// adam std: embed_sgd_dim=4, embedx_sgd_dim=n*2+2,embedx_dim=n +// adam shared: embed_sgd_dim=4, embedx_sgd_dim=4,embedx_dim=n +class CommonFeatureValueAccessor : public FeatureValueAccessor { + public: + struct CommonFeatureValue { + /* + uint64_t cpu_ptr; + float delta_score; + float show; + float click; + float embed_w; + std::vector embed_g2sum; + float slot; + float mf_dim + float mf_size + std::vector embedx_g2sum; + std::vector embedx_w; + */ + + __host__ __device__ int Dim() { return 8 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; } // has cpu_ptr + __host__ __device__ int DimSize(size_t dim, int embedx_dim) { return sizeof(float); } + __host__ __device__ int Size() { return (Dim()-1) * sizeof(float) + sizeof(uint64_t); } + __host__ __device__ int EmbedDim() { return embed_sgd_dim;} + __host__ __device__ int EmbedXDim() { return embedx_sgd_dim;} + __host__ __device__ int EmbedWDim() { return embedx_dim;} + __host__ __device__ int CpuPtrIndex() {return 0; } // cpuprt uint64 + __host__ __device__ int DeltaScoreIndex() { return CpuPtrIndex() + 2; } + __host__ __device__ int ShowIndex() { return DeltaScoreIndex() + 1; } + __host__ __device__ int ClickIndex() { return ShowIndex() + 1; } + __host__ __device__ int EmbedWIndex() { return ClickIndex() + 1; } + __host__ __device__ int EmbedG2SumIndex() { return EmbedWIndex() + 1; } + __host__ __device__ int SlotIndex() { return EmbedG2SumIndex() + embed_sgd_dim; } + __host__ __device__ int MfDimIndex() { return SlotIndex() + 1; } + __host__ __device__ int MfSizeIndex() { return MfDimIndex() + 1; } // actual mf size (ex. 
0) + __host__ __device__ int EmbedxG2SumIndex() { return MfSizeIndex() + 1; } + __host__ __device__ int EmbedxWIndex() { return EmbedxG2SumIndex() + embedx_sgd_dim; } + + __host__ __device__ uint64_t CpuPtr(float* val) {return *(reinterpret_cast(val)); } + __host__ __device__ float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; } + __host__ __device__ float& Show(float* val) { return val[ShowIndex()]; } + __host__ __device__ float& Click(float* val) { return val[ClickIndex()]; } + __host__ __device__ float& Slot(float* val) { return val[SlotIndex()]; } + __host__ __device__ float& MfDim(float* val) { return val[MfDimIndex()]; } + __host__ __device__ float& MfSize(float* val) { return val[MfSizeIndex()]; } + __host__ __device__ float& EmbedW(float* val) { return val[EmbedWIndex()]; } + __host__ __device__ float& EmbedG2Sum(float* val) { return val[EmbedG2SumIndex()]; } + __host__ __device__ float& EmbedxG2Sum(float* val) { return val[EmbedxG2SumIndex()]; } + __host__ __device__ float& EmbedxW(float* val) { return val[EmbedxWIndex()]; } + + int embed_sgd_dim; + int embedx_dim; + int embedx_sgd_dim; + }; + + struct CommonPushValue { + /* + float slot; + float show; + float click; + float mf_dim; + float embed_g; + std::vector embedx_g; + */ + + __host__ __device__ int Dim(int embedx_dim) { return 5 + embedx_dim; } + + __host__ __device__ int DimSize(int dim, int embedx_dim) { return sizeof(float); } + __host__ __device__ int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } + __host__ __device__ int SlotIndex() { return 0; } + __host__ __device__ int ShowIndex() { return CommonPushValue::SlotIndex() + 1; } + __host__ __device__ int ClickIndex() { return CommonPushValue::ShowIndex() + 1; } + __host__ __device__ int MfDimIndex() { return CommonPushValue::ClickIndex() + 1; } + __host__ __device__ int EmbedGIndex() { return CommonPushValue::MfDimIndex() + 1; } + __host__ __device__ int EmbedxGIndex() { return CommonPushValue::EmbedGIndex() + 1; } + __host__ __device__ float& Slot(float* val) { + return val[CommonPushValue::SlotIndex()]; + } + __host__ __device__ float& Show(float* val) { + return val[CommonPushValue::ShowIndex()]; + } + __host__ __device__ float& Click(float* val) { + return val[CommonPushValue::ClickIndex()]; + } + __host__ __device__ float& MfDim(float* val) { + return val[CommonPushValue::MfDimIndex()]; + } + __host__ __device__ float& EmbedG(float* val) { + return val[CommonPushValue::EmbedGIndex()]; + } + __host__ __device__ float* EmbedxG(float* val) { + return val + CommonPushValue::EmbedxGIndex(); + } + }; + + + __host__ __device__ CommonFeatureValueAccessor() {} + __host__ __device__ ~CommonFeatureValueAccessor() {} + // __host__ __device__ virtual int Initialize() { + // std::string name = (_config.find("embed_sparse_optimizer") == _config.end()) + // ? "adagrad" + // : _config["embed_sparse_optimizer"]; + // int sparse_embedx_dim = (_config.find("sparse_embedx_dim") == _config.end()) + // ? 
8 + // : std::stoi(_config["sparse_embedx_dim"]); + // if (name.compare("adam") == 0) { + // common_feature_value.embed_sgd_dim = 4; + // common_feature_value.embedx_sgd_dim = sparse_embedx_dim * 2 + 2; + // } else if (name.compare("sharedadam") == 0) { + // common_feature_value.embed_sgd_dim = 4; + // common_feature_value.embedx_sgd_dim = 4; + // } else { + // common_feature_value.embed_sgd_dim = 1; + // common_feature_value.embedx_sgd_dim = 1; + // } + + // common_feature_value.embedx_dim = sparse_embedx_dim; + + // // VLOG(0) << " INTO FeatureValueAccessor::Initialize()"; + // InitAccessorInfo(); + // return 0; + // } + + __host__ __device__ virtual int Initialize() { + int optimizer_type = (_config.find("optimizer_type") == _config.end()) + ? 1 + : int(_config["optimizer_type"]); + int sparse_embedx_dim = (_config.find("embedx_dim") == _config.end()) + ? 8 + : int(_config["embedx_dim"]); + if (optimizer_type == 3) { //adam + common_feature_value.embed_sgd_dim = 4; + common_feature_value.embedx_sgd_dim = sparse_embedx_dim * 2 + 2; + } else if (optimizer_type == 4) { //sharedadam + common_feature_value.embed_sgd_dim = 4; + common_feature_value.embedx_sgd_dim = 4; + } else { + common_feature_value.embed_sgd_dim = 1; + common_feature_value.embedx_sgd_dim = 1; + } + + common_feature_value.embedx_dim = sparse_embedx_dim; + + // VLOG(0) << " INTO FeatureValueAccessor::Initialize()"; + InitAccessorInfo(); + return 0; + } + + // 初始化AccessorInfo + __host__ __device__ virtual void InitAccessorInfo() { + _accessor_info.dim = common_feature_value.Dim(); + _accessor_info.size = common_feature_value.Size(); + + int embedx_dim = (_config.find("embedx_dim") == _config.end()) + ? 8 + : int(_config["embedx_dim"]); + // VLOG(0) << "feature value InitAccessorInfo embedx_dim:" << embedx_dim; + _accessor_info.update_dim = 5 + embedx_dim; + _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); + _accessor_info.mf_size = + (embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float); + } + + // friend std::ostream& operator<<(std::ostream& out, CommonFeatureValueAccessor& v) { + // /* + // uint64_t cpu_ptr; + // float delta_score; + // float show; + // float click; + // float embed_w; + // std::vector embed_g2sum; + // float slot; + // float mf_dim + // float mf_size + // std::vector embedx_g2sum; + // std::vector embedx_w; + // */ + // out << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4]; + // // << v[5] << " " << v[6]; + // for (int i = common_feature_value.EmbedG2SumIndex(); + // i < common_feature_value.EmbedxWIndex(); i++) { + // out << " " << v[i]; + // } + // out << " " << common_feature_value.Slot(v) << " " + // << common_feature_value.MfDim(v) + // << common_feature_value.MfSize(v); + + // for (int x = 0; x < common_feature_value.EmbedXDim(); x++) { + // out << " " << v[common_feature_value.EmbedxG2SumIndex() + x]; + // } + // for (int x = 0; x < common_feature_value.MfDim(v); x++) { + // out << " " << v[common_feature_value.EmbedxWIndex() + x]; + // } + // return out; + // } + + + __host__ __device__ std::string ParseToString(const float* v, int param_size) { + /* + uint64_t cpu_ptr; // 2float + float delta_score; + float show; + float click; + float embed_w; + std::vector embed_g2sum; + float slot; + float mf_dim + float mf_size + std::vector embedx_g2sum; + std::vector embedx_w; + */ + std::stringstream os; + os << "cpuptr: " << common_feature_value.CpuPtr(const_cast(v)) << " delta_score: " << v[2] + << " show: " << v[3] << " click: " << v[4] + << " 
embed_w:" << v[5] << " embed_g2sum:"; + for (int i = common_feature_value.EmbedG2SumIndex(); + i < common_feature_value.SlotIndex(); i++) { + os << " " << v[i]; + } + os << " slot: " << common_feature_value.Slot(const_cast(v)) + << " mf_dim: " << common_feature_value.MfDim(const_cast(v)) + << " mf_size: " << common_feature_value.MfSize(const_cast(v)) + << " mf: "; + if (param_size > common_feature_value.EmbedxG2SumIndex()) { + for (auto i = common_feature_value.EmbedxG2SumIndex(); + i < common_feature_value.Dim(); ++i) { + os << " " << v[i]; + } + } + return os.str(); + } + + public: + CommonFeatureValue common_feature_value; + CommonPushValue common_push_value; + // SparseValueSGDRule* _embed_sgd_rule; + // SparseValueSGDRule* _embedx_sgd_rule; +}; + + struct FeatureValue { float delta_score; float show; float clk; int slot; float lr; - float lr_g2sum; int mf_size; int mf_dim; uint64_t cpu_ptr; + int lr_sgd_dim; + int mf_sgd_dim; + float lr_g2sum[1]; + float mf_g2sum[1]; float mf[0]; friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) { out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot << " lr: " << val.lr << " mf_dim: " << val.mf_dim << "cpuptr: " << val.cpu_ptr << " mf_size: " << val.mf_size << " mf:"; - for (int i = 0; i < val.mf_dim + 1; ++i) { + for (int i = 0; i < val.lr_sgd_dim; ++i) { + out << " " << val.lr_g2sum[i]; + } + for (int i = 0; i < val.mf_sgd_dim; ++i) { + out << " " << val.mf_g2sum[i]; + } + for (int i = 0; i < val.mf_dim; ++i) { out << " " << val.mf[i]; } return out; @@ -51,16 +336,104 @@ struct FeatureValue { clk = in.clk; slot = in.slot; lr = in.lr; - lr_g2sum = in.lr_g2sum; mf_size = in.mf_size; mf_dim = in.mf_dim; cpu_ptr = in.cpu_ptr; - for (int i = 0; i < mf_dim + 1; i++) { + lr_sgd_dim = in.lr_sgd_dim; + mf_sgd_dim = in.mf_sgd_dim; + + for (int i = 0; i < lr_sgd_dim; ++i) { + lr_g2sum[i] = in.lr_g2sum[i]; + } + for (int i = 0; i < mf_sgd_dim; ++i) { + mf_g2sum[i] = in.mf_g2sum[i]; + } + for (int i = 0; i < mf_dim; i++) { mf[i] = in.mf[i]; } } }; +// struct AdamFeatureValue { +// float delta_score; +// float show; +// float clk; +// int slot; +// float lr; +// int mf_size; +// int mf_dim; +// int lr_sgd_dim; +// int mf_sgd_dim; +// uint64_t cpu_ptr; +// float lr_g2sum[0]; +// float mf_g2sum[0]; +// float mf[0]; + + +// __device__ __forceinline__ void operator=(const FeatureValue& in) { +// delta_score = in.delta_score; +// show = in.show; +// clk = in.clk; +// slot = in.slot; +// lr = in.lr; +// mf_size = in.mf_size; +// mf_dim = in.mf_dim; +// cpu_ptr = in.cpu_ptr; +// lr_sgd_dim = in.lr_sgd_dim; +// mf_sgd_dim = in.mf_sgd_dim; + +// for (int i = 0; i < lr_sgd_dim; ++i) { +// lr_g2sum[i] = in.lr_g2sum[i]; +// } +// for (int i = 0; i < mf_sgd_dim; ++i) { +// mf_g2sum[i] = in.mf_g2sum[i]; +// } +// for (int i = 0; i < mf_dim; i++) { +// mf[i] = in.mf[i]; +// } +// } +// }; + +// struct AdamSharedFeatureValue { +// float delta_score; +// float show; +// float clk; +// int slot; +// float lr; +// int mf_size; +// int mf_dim; +// int lr_sgd_dim; +// int mf_sgd_dim; +// uint64_t cpu_ptr; +// float lr_g2sum[4]; +// float mf_g2sum[4]; +// float mf[0]; + +// __device__ __forceinline__ void operator=(const FeatureValue& in) { +// delta_score = in.delta_score; +// show = in.show; +// clk = in.clk; +// slot = in.slot; +// lr = in.lr; +// mf_size = in.mf_size; +// mf_dim = in.mf_dim; +// cpu_ptr = in.cpu_ptr; +// lr_sgd_dim = in.lr_sgd_dim; +// mf_sgd_dim = in.mf_sgd_dim; + +// for (int i = 0; i < lr_sgd_dim; ++i) { +// 
lr_g2sum[i] = in.lr_g2sum[i]; +// } +// for (int i = 0; i < mf_sgd_dim; ++i) { +// mf_g2sum[i] = in.mf_g2sum[i]; +// } +// for (int i = 0; i < mf_dim; i++) { +// mf[i] = in.mf[i]; +// } +// } +// }; + + struct FeaturePushValue { float show; float clk; diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index dbd6130c1461dc..8b0d4dd3a53202 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -150,7 +150,7 @@ class HashTable { #if defined(PADDLE_WITH_CUDA) - template + template void update(const KeyType* d_keys, const GradType* d_grads, size_t len, @@ -189,7 +189,12 @@ class HashTable { << " push value size: " << push_grad_value_size_; } + void set_accessor(CommonFeatureValueAccessor& accessor) { + feature_value_accessor_ = accessor; + } + std::unique_ptr rwlock_{nullptr}; + CommonFeatureValueAccessor feature_value_accessor_; private: #if defined(PADDLE_WITH_CUDA) @@ -206,6 +211,7 @@ class HashTable { size_t max_mf_dim_ = 8; size_t pull_feature_value_size_; size_t push_grad_value_size_; + }; } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 04842caef6b7f9..467564f97b0f8c 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -88,7 +88,8 @@ __global__ void dy_mf_search_kernel(Table* table, const typename Table::key_type* const keys, char* vals, size_t len, - size_t pull_feature_value_size) { + size_t pull_feature_value_size, + CommonFeatureValueAccessor feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; // return; if (i < len) { @@ -96,19 +97,44 @@ __global__ void dy_mf_search_kernel(Table* table, if (it != table->end()) { uint64_t offset = i * pull_feature_value_size; - FeatureValue* cur = (FeatureValue*)(vals + offset); - FeatureValue& input = *(FeatureValue*)(it->second); - cur->slot = input.slot; - cur->show = input.show; - cur->clk = input.clk; - cur->mf_dim = input.mf_dim; - cur->lr = input.lr; - cur->mf_size = input.mf_size; - cur->cpu_ptr = input.cpu_ptr; - cur->delta_score = input.delta_score; - cur->lr_g2sum = input.lr_g2sum; - for (int j = 0; j < cur->mf_dim + 1; ++j) { - cur->mf[j] = input.mf[j]; + float* cur = (float*)(vals + offset); + float* input = it->second; + + cur[feature_value_accessor.common_feature_value.SlotIndex()] = + input[feature_value_accessor.common_feature_value.SlotIndex()]; + cur[feature_value_accessor.common_feature_value.ShowIndex()] = + input[feature_value_accessor.common_feature_value.ShowIndex()]; + cur[feature_value_accessor.common_feature_value.ClickIndex()] = + input[feature_value_accessor.common_feature_value.ClickIndex()]; + cur[feature_value_accessor.common_feature_value.MfDimIndex()] = + input[feature_value_accessor.common_feature_value.MfDimIndex()]; + cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = + input[feature_value_accessor.common_feature_value.EmbedWIndex()]; + cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = + input[feature_value_accessor.common_feature_value.MfSizeIndex()]; + cur[feature_value_accessor.common_feature_value.CpuPtrIndex()] = + input[feature_value_accessor.common_feature_value.CpuPtrIndex()]; + cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = + 
input[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; + cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = + input[feature_value_accessor.common_feature_value.EmbedWIndex()]; + printf("dy_mf_search_kernel table slot: %f; show: %f; click: %f; lr: %f", + cur[feature_value_accessor.common_feature_value.SlotIndex()], + cur[feature_value_accessor.common_feature_value.ShowIndex()], + cur[feature_value_accessor.common_feature_value.ClickIndex()], + cur[feature_value_accessor.common_feature_value.EmbedWIndex()]); + for (int i = 0; i < feature_value_accessor.common_feature_value.EmbedDim(); i++) { + cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i] = + input[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; + } + + for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedXDim(); x++) { + cur[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x] = + input[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x]; + } + for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedWDim(); x++) { + cur[feature_value_accessor.common_feature_value.EmbedxWIndex() + x] = + input[feature_value_accessor.common_feature_value.EmbedxWIndex() + x]; } } else { if (keys[i] != 0) { @@ -146,8 +172,10 @@ __global__ void dy_mf_update_kernel(Table* table, if (i < len) { auto it = table->find(keys[i]); if (it != table->end()) { - FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size); - sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur); + float* cur = (float*)(grads + i * grad_value_size); + sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, cur); + // printf("dy_mf_update_kernel: %d, %s", keys[i], + // sgd.feature_value_accessor_.ParseToString(cur, sgd.feature_value_accessor_.GetAccessorInfo().dim)); } else { if (keys[i] != 0) { printf("warning::push miss key: %llu", keys[i]); @@ -221,8 +249,11 @@ void HashTable::get(const KeyType* d_keys, return; } const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; + VLOG(0) << "GET:" << feature_value_accessor_.common_feature_value.EmbedDim() + << " " << feature_value_accessor_.common_feature_value.EmbedXDim() + << " " << feature_value_accessor_.common_feature_value.EmbedWDim(); dy_mf_search_kernel<<>>( - container_, d_keys, d_vals, len, pull_feature_value_size_); + container_, d_keys, d_vals, len, pull_feature_value_size_, feature_value_accessor_); } template @@ -299,27 +330,27 @@ void HashTable::dump_to_cpu(int devid, StreamType stream) { } } #endif -#ifdef PADDLE_WITH_PSCORE - auto* downpour_value = - (paddle::distributed::FixedFeatureValue*)(gpu_val.cpu_ptr); - int downpour_value_size = downpour_value->size(); - if (gpu_val.mf_size > 0 && downpour_value_size == 7) { - downpour_value->resize(gpu_val.mf_size + downpour_value_size); - } - float* cpu_val = downpour_value->data(); - // cpu_val[0] = 0; - cpu_val[2] = gpu_val.delta_score; - cpu_val[3] = gpu_val.show; - cpu_val[4] = gpu_val.clk; - cpu_val[5] = gpu_val.lr; - cpu_val[6] = gpu_val.lr_g2sum; - cpu_val[0] = gpu_val.slot; - if (gpu_val.mf_size > 0) { - for (int x = 0; x < gpu_val.mf_size; x++) { - cpu_val[x + 7] = gpu_val.mf[x]; - } - } -#endif +// #ifdef PADDLE_WITH_PSCORE +// auto* downpour_value = +// (paddle::distributed::FixedFeatureValue*)(gpu_val.cpu_ptr); +// int downpour_value_size = downpour_value->size(); +// if (gpu_val.mf_size > 0 && downpour_value_size == 7) { +// downpour_value->resize(gpu_val.mf_size + downpour_value_size); +// } +// float* cpu_val = 
downpour_value->data(); +// // cpu_val[0] = 0; +// cpu_val[2] = gpu_val.delta_score; +// cpu_val[3] = gpu_val.show; +// cpu_val[4] = gpu_val.clk; +// cpu_val[5] = gpu_val.lr; +// cpu_val[6] = gpu_val.lr_g2sum; +// cpu_val[0] = gpu_val.slot; +// if (gpu_val.mf_size > 0) { +// for (int x = 0; x < gpu_val.mf_size; x++) { +// cpu_val[x + 7] = gpu_val.mf[x]; +// } +// } +// #endif } }; @@ -336,9 +367,9 @@ void HashTable::dump_to_cpu(int devid, StreamType stream) { } template -template +template void HashTable::update(const KeyType* d_keys, - const GradType* d_grads, + const float* d_grads, size_t len, Sgd sgd, StreamType stream) { @@ -371,8 +402,8 @@ void HashTable::update(const KeyType* d_keys, push_grad_value_size_); } -template class HashTable; -template class HashTable; +template class HashTable; +template class HashTable; template class HashTable; template class HashTable; template class HashTable; @@ -382,14 +413,14 @@ template class HashTable; template class HashTable; template class HashTable; -template void HashTable::get< +template void HashTable::get< cudaStream_t>(const unsigned long* d_keys, - paddle::framework::FeatureValue* d_vals, + float* d_vals, size_t len, cudaStream_t stream); template void -HashTable::get( +HashTable::get( const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t stream); template void HashTable::get(const long* d_keys, @@ -414,13 +445,13 @@ template void HashTable::get( // const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t // stream); -template void HashTable::insert< +template void HashTable::insert< cudaStream_t>(const unsigned long* d_keys, - const paddle::framework::FeatureValue* d_vals, + const float* d_vals, size_t len, cudaStream_t stream); -template void HashTable:: +template void HashTable:: insert(const unsigned long* d_keys, size_t len, char* pool, @@ -460,18 +491,20 @@ template void HashTable::insert( size_t len, cudaStream_t stream); -template void HashTable:: +template void HashTable:: dump_to_cpu(int devid, cudaStream_t stream); -template void HashTable::update< - paddle::framework::FeaturePushValue, - Optimizer, - cudaStream_t>(const unsigned long* d_keys, - const paddle::framework::FeaturePushValue* d_grads, - size_t len, - Optimizer sgd, +template void +HashTable::update(const unsigned long* d_keys, const char* d_grads, size_t len, + SparseAdagradOptimizer sgd, + cudaStream_t stream); +template void +HashTable::update(const unsigned long* d_keys, const char* d_grads, size_t len, + SparseAdamOptimizer sgd, + cudaStream_t stream); +template void +HashTable::update(const unsigned long* d_keys, const char* d_grads, size_t len, + SparseAdamSharedOptimizer sgd, cudaStream_t stream); template void HashTable:: diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index 92b97625ba5d77..00a1583ee36797 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -50,6 +50,8 @@ template class HeterComm { public: HeterComm(size_t capacity, std::shared_ptr resource); + HeterComm(size_t capacity, std::shared_ptr resource, + CommonFeatureValueAccessor& accessor); virtual ~HeterComm(); HeterComm(const HeterComm&) = delete; HeterComm& operator=(const HeterComm&) = delete; @@ -67,10 +69,10 @@ class HeterComm { int& uniq_len); // NOLINT void dynamic_merge_grad(int gpu_num, KeyType* d_keys, - GradType* d_grads, + float* d_grads, size_t len, int& uniq_len); - void pull_sparse(int num, KeyType* d_keys, ValType* 
d_vals, size_t len); + void pull_sparse(int num, KeyType* d_keys, float* d_vals, size_t len); void build_ps(int num, KeyType* h_keys, ValType* h_vals, @@ -92,7 +94,7 @@ class HeterComm { template void push_sparse(int num, KeyType* d_keys, - GradType* d_grads, + float* d_grads, size_t len, Sgd& sgd); // NOLINT #elif defined(PADDLE_WITH_XPU_KP) @@ -149,6 +151,10 @@ class HeterComm { multi_mf_dim_ = multi_mf_dim; max_mf_dim_ = max_mf_dim; } + + void set_accessor(CommonFeatureValueAccessor& accessor) { + feature_value_accessor_ = accessor; + } #endif bool need_transfer(int send_id, int receive_id) { @@ -282,9 +288,11 @@ class HeterComm { char* src_val, size_t val_size); + + CommonFeatureValueAccessor feature_value_accessor_; protected: using Table = HashTable; - using PtrTable = HashTable; + using PtrTable = HashTable; std::vector tables_; std::vector ptr_tables_; std::shared_ptr resource_; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index ace533cb0c7458..0016ff318fd5f1 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -36,17 +36,48 @@ HeterComm::HeterComm( platform::CUDADeviceGuard guard(resource_->dev_id(i)); allocators_.push_back(std::make_shared( 8, 1, (unsigned int)-1, (size_t)-1, false, false)); // NOLINT +#endif + if (!multi_mf_dim_) { + auto table = new Table(capacity / load_factor_); + tables_.push_back(table); + } else { + VLOG(0) << "Error:use HeterComm Construct with accessor"; + return; + } + if (multi_node_) { + storage_[i].init(feanum_, resource_->dev_id(i)); + } + } + heter_comm_kernel_ = std::make_unique(block_size_); + init_path(); +} + +template +HeterComm::HeterComm( + size_t capacity, std::shared_ptr resource, + CommonFeatureValueAccessor& feature_value_accessor) { + VLOG(1) << "Construct new HeterComm"; + resource_ = resource; + storage_.resize(resource_->total_device()); + multi_mf_dim_ = resource->multi_mf(); + for (int i = 0; i < resource_->total_device(); ++i) { +#if defined(PADDLE_WITH_CUDA) + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + allocators_.push_back(std::make_shared( + 8, 1, (unsigned int)-1, (size_t)-1, false, false)); // NOLINT #endif if (!multi_mf_dim_) { auto table = new Table(capacity / load_factor_); tables_.push_back(table); } else { max_mf_dim_ = resource_->max_mf_dim(); - size_t val_type_size = TYPEALIGN( - 8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1)); - size_t grad_type_size = TYPEALIGN( - 8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float))); + feature_value_accessor_ = feature_value_accessor; + VLOG(0) << " HeterComm INIT:" << feature_value_accessor_.GetAccessorInfo().size + << " " << feature_value_accessor_.GetAccessorInfo().update_size; + size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); + size_t grad_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); auto ptr_table = new PtrTable(capacity / load_factor_); + ptr_table->set_accessor(feature_value_accessor_); ptr_table->set_feature_value_size(val_type_size, grad_type_size); ptr_tables_.push_back(ptr_table); } @@ -54,10 +85,11 @@ HeterComm::HeterComm( storage_[i].init(feanum_, resource_->dev_id(i)); } } - heter_comm_kernel_ = std::make_unique(block_size_); + heter_comm_kernel_ = std::make_unique(block_size_, feature_value_accessor_); init_path(); } + template void HeterComm::init_path() { int total_device = resource_->total_device(); @@ 
-648,7 +680,7 @@ template void HeterComm::dynamic_merge_grad( int gpu_num, KeyType* d_keys, - GradType* d_grads, + float* d_grads, size_t len, int& uniq_len) { int dev_id = resource_->dev_id(gpu_num); @@ -659,15 +691,14 @@ void HeterComm::dynamic_merge_grad( size_t temp_storage_bytes; // VLOG(1) << "hetercomm merge_grad: max_mf_dim: " << max_mf_dim_; - size_t grad_value_size = - TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float))); + size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); auto d_merge_grads = memory::Alloc(place, len * grad_value_size); - GradType* d_merge_grads_ptr = - reinterpret_cast(d_merge_grads->ptr()); + float* d_merge_grads_ptr = + reinterpret_cast(d_merge_grads->ptr()); auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1)); uint32_t* d_fea_num_info_ptr = @@ -836,7 +867,7 @@ void HeterComm::split_input_to_shard( template void HeterComm::pull_sparse(int num, KeyType* d_keys, - ValType* d_vals, + float* d_vals, //new edit from ValType size_t len) { if (len == 0) { return; @@ -883,12 +914,12 @@ void HeterComm::pull_sparse(int num, auto d_idx = memory::Alloc(place, len * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); - size_t val_type_size = - TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1)); + size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); + VLOG(0) << "PULLSPARSE len:" << len << " val_type_size: " << val_type_size; auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); auto d_shard_vals = memory::Alloc(place, len * val_type_size); - ValType* d_shard_vals_ptr = reinterpret_cast(d_shard_vals->ptr()); + float* d_shard_vals_ptr = reinterpret_cast(d_shard_vals->ptr()); split_input_to_shard(d_keys, d_idx_ptr, len, d_left_ptr, d_right_ptr, num); @@ -958,6 +989,15 @@ void HeterComm::pull_sparse(int num, d_shard_vals_ptr, d_vals, d_idx_ptr, len, val_type_size, stream); sync_stream(stream); + + // char* tmp_mem2 = (char*)malloc(len * val_type_size); + // cudaMemcpy(tmp_mem2, reinterpret_cast(d_shard_vals_ptr), len * val_type_size, + // cudaMemcpyDeviceToHost); + // for (int i =0 ; i < 20; i++){ + // float* val = (float*)(void*)&tmp_mem2[(i)*val_type_size]; + // VLOG(0) << "pullsparse walk_to_src fill_dvals cpu: "<< i << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.GetAccessorInfo().dim); + // } + for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { continue; @@ -971,7 +1011,7 @@ template template void HeterComm::push_sparse(int dev_num, KeyType* d_keys, - GradType* d_grads, + float* d_grads, size_t len, Sgd& sgd) { // NOLINT if (len == 0) { @@ -982,7 +1022,7 @@ void HeterComm::push_sparse(int dev_num, int dev_id = resource_->dev_id(dev_num); size_t grad_value_size = - TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float))); + TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); DevPlace place = DevPlace(dev_id); AnyDeviceGuard guard(dev_id); auto stream = resource_->local_stream(dev_num, 0); @@ -1027,8 +1067,8 @@ void HeterComm::push_sparse(int dev_num, KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); auto d_shard_grads = memory::Alloc(place, len * grad_value_size); - GradType* d_shard_grads_ptr = - 
reinterpret_cast(d_shard_grads->ptr()); + float* d_shard_grads_ptr = + reinterpret_cast(d_shard_grads->ptr()); int uniq_len = len; dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len); @@ -1038,16 +1078,7 @@ void HeterComm::push_sparse(int dev_num, split_input_to_shard( d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, dev_num); - if (!multi_mf_dim_) { - heter_comm_kernel_->fill_shard_grads(d_shard_keys_ptr, - d_keys, - d_shard_grads_ptr, - d_grads, - d_idx_ptr, - uniq_len, - stream); - } else { - heter_comm_kernel_->dy_mf_fill_shard_grads(d_shard_keys_ptr, + heter_comm_kernel_->dy_mf_fill_shard_grads(d_shard_keys_ptr, d_keys, d_shard_grads_ptr, d_grads, @@ -1090,21 +1121,13 @@ void HeterComm::push_sparse(int dev_num, } } - if (!multi_mf_dim_) { - walk_to_dest(dev_num, - total_device, - h_left, - h_right, - d_shard_keys_ptr, - d_shard_grads_ptr); - } else { - walk_to_dest(dev_num, - total_device, - h_left, - h_right, - d_shard_keys_ptr, - reinterpret_cast(d_shard_grads_ptr), - grad_value_size); + walk_to_dest(dev_num, + total_device, + h_left, + h_right, + d_shard_keys_ptr, + reinterpret_cast(d_shard_grads_ptr), + grad_value_size); } for (int i = 0; i < total_device; ++i) { @@ -1115,21 +1138,12 @@ void HeterComm::push_sparse(int dev_num, sync_stream(node.in_stream); AnyDeviceGuard guard(resource_->dev_id(i)); - if (!multi_mf_dim_) { - tables_[i]->rwlock_->WRLock(); - tables_[i]->update(reinterpret_cast(node.key_storage), - reinterpret_cast(node.val_storage), - h_right[i] - h_left[i] + 1, - sgd, - resource_->remote_stream(i, dev_num)); - } else { - ptr_tables_[i]->rwlock_->WRLock(); - ptr_tables_[i]->update(reinterpret_cast(node.key_storage), - node.val_storage, - h_right[i] - h_left[i] + 1, - sgd, - resource_->remote_stream(i, dev_num)); - } + ptr_tables_[i]->rwlock_->WRLock(); + ptr_tables_[i]->update(reinterpret_cast(node.key_storage), + node.val_storage, + h_right[i] - h_left[i] + 1, + sgd, + resource_->remote_stream(i, dev_num)); } for (int i = 0; i < total_device; ++i) { diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index fd0dd1a72cca1a..98295bb36560a9 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -128,19 +128,38 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, } } -template +template __global__ void dy_mf_fill_shard_grads_kernel(KeyType* d_shard_keys, KeyType* d_keys, - GradType* d_shard_grads, - GradType* d_grads, + float* d_shard_grads, + float* d_grads, T* idx, size_t len, - size_t grad_value_size) { + size_t grad_value_size, + CommonFeatureValueAccessor feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { d_shard_keys[i] = d_keys[idx[i]]; - *(GradType*)((char*)d_shard_grads + i * grad_value_size) = - *(GradType*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size); + // *(float*)((char*)d_shard_grads + i * grad_value_size) = + // *(float*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size); + float* cur = (float*)((char*)d_shard_grads + i * grad_value_size); + float* shard_val = (float*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size); + + cur[feature_value_accessor.common_push_value.SlotIndex()] = + shard_val[feature_value_accessor.common_push_value.SlotIndex()]; + cur[feature_value_accessor.common_push_value.ShowIndex()] = + shard_val[feature_value_accessor.common_push_value.ShowIndex()]; + 
cur[feature_value_accessor.common_push_value.ClickIndex()] = + shard_val[feature_value_accessor.common_push_value.ClickIndex()]; + cur[feature_value_accessor.common_push_value.MfDimIndex()] = + shard_val[feature_value_accessor.common_push_value.MfDimIndex()]; + cur[feature_value_accessor.common_push_value.EmbedGIndex()] = + shard_val[feature_value_accessor.common_push_value.EmbedGIndex()]; + + for (int x = 0; x < int(shard_val[feature_value_accessor.common_push_value.MfDimIndex()]); x++) { + cur[feature_value_accessor.common_push_value.EmbedxGIndex() + x] = + shard_val[feature_value_accessor.common_push_value.EmbedxGIndex() + x]; + } } } @@ -151,36 +170,71 @@ __global__ void merge_gradients_kernel(const uint32_t* offset, char* output, int n, size_t grad_value_size, - DynamicGradMerger& merger_) { + DynamicGradMerger& merger_, + CommonFeatureValueAccessor& feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n) { uint32_t start = offset[i]; uint32_t num = fea_num[i]; int ori_index = index[start]; - FeaturePushValue& out = *(FeaturePushValue*)(output + i * grad_value_size); - FeaturePushValue& in = - *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size); - merger_.update_one(out, in); + float* out = (float*)(output + i * grad_value_size); + float* in = + (float*)(input + size_t(ori_index) * grad_value_size); + merger_.update_one(out, in, feature_value_accessor); for (int j = 1; j < num; ++j) { ori_index = index[start + j]; - FeaturePushValue& rhs = - *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size); - merger_.merge_one(out, rhs); + float& rhs = + *(float*)(input + size_t(ori_index) * grad_value_size); + merger_.merge_one(out, rhs, feature_value_accessor); } } } -template -__global__ void dy_mf_fill_dvals_kernel(ValType* d_shard_vals, - ValType* d_vals, +template +__global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, + float* d_vals, T* idx, size_t len, - size_t val_size) { + size_t val_size, + CommonFeatureValueAccessor feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { uint64_t new_offset = uint64_t(idx[i]) * val_size; - *(ValType*)((char*)d_vals + new_offset) = - *(ValType*)((char*)d_shard_vals + i * val_size); + float* cur = (float*)((char*)d_vals + new_offset); + float* shard_val = (float*)((char*)d_shard_vals + uint64_t(i) * val_size); + // *(float*)((char*)d_vals + new_offset) = + // (float*)((char*)d_shard_vals + i * val_size); + cur[feature_value_accessor.common_feature_value.SlotIndex()] = + shard_val[feature_value_accessor.common_feature_value.SlotIndex()]; + cur[feature_value_accessor.common_feature_value.ShowIndex()] = + shard_val[feature_value_accessor.common_feature_value.ShowIndex()]; + cur[feature_value_accessor.common_feature_value.ClickIndex()] = + shard_val[feature_value_accessor.common_feature_value.ClickIndex()]; + cur[feature_value_accessor.common_feature_value.MfDimIndex()] = + shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]; + cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = + shard_val[feature_value_accessor.common_feature_value.EmbedWIndex()]; + cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = + shard_val[feature_value_accessor.common_feature_value.MfSizeIndex()]; + cur[feature_value_accessor.common_feature_value.CpuPtrIndex()] = + shard_val[feature_value_accessor.common_feature_value.CpuPtrIndex()]; + cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = + 
shard_val[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; + cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = + shard_val[feature_value_accessor.common_feature_value.EmbedWIndex()]; + for (int i = 0; i < feature_value_accessor.common_feature_value.EmbedDim(); i++) { + cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i] = + shard_val[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; + } + + for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedXDim(); x++) { + cur[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x] = + shard_val[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x]; + } + for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedWDim(); x++) { + cur[feature_value_accessor.common_feature_value.EmbedxWIndex() + x] = + shard_val[feature_value_accessor.common_feature_value.EmbedxWIndex() + x]; + } } } @@ -312,7 +366,7 @@ void HeterCommKernel::reduce_by_key(void* d_temp_storage, debug_synchronous)); } -template +template void HeterCommKernel::dy_mf_fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys, GradType* d_shard_grads, @@ -330,7 +384,8 @@ void HeterCommKernel::dy_mf_fill_shard_grads(KeyType* d_shard_keys, d_grads, idx, c_len, - grad_value_size); + grad_value_size, + feature_value_accessor_); } template @@ -345,12 +400,12 @@ void HeterCommKernel::merge_gradient(const uint32_t* offset, const StreamType& stream) { int grid_size = (n - 1) / block_size_ + 1; merge_gradients_kernel<<>>( - offset, fea_num, index, input, output, n, grad_value_size, merger_); + offset, fea_num, index, input, output, n, grad_value_size, merger_, feature_value_accessor_); } -template -void HeterCommKernel::dy_mf_fill_dvals(ValType* d_shard_vals, - ValType* d_vals, +template +void HeterCommKernel::dy_mf_fill_dvals(float* d_shard_vals, + float* d_vals, T* idx, long long len, size_t val_size, @@ -358,7 +413,7 @@ void HeterCommKernel::dy_mf_fill_dvals(ValType* d_shard_vals, int grid_size = (len - 1) / block_size_ + 1; size_t c_len = (size_t)len; dy_mf_fill_dvals_kernel<<>>( - d_shard_vals, d_vals, idx, c_len, val_size); + d_shard_vals, d_vals, idx, c_len, val_size, feature_value_accessor_); } template void HeterCommKernel::fill_idx( @@ -404,12 +459,12 @@ template void HeterCommKernel::fill_shard_key( template void HeterCommKernel::fill_shard_grads< unsigned long, - paddle::framework::FeaturePushValue, + float, int, cudaStream_t>(unsigned long* d_shard_keys, unsigned long* d_keys, - paddle::framework::FeaturePushValue* d_shard_grads, - paddle::framework::FeaturePushValue* d_grads, + float* d_shard_grads, + float* d_grads, int* idx, long long len, const cudaStream_t& stream); @@ -469,12 +524,11 @@ template void HeterCommKernel::reduce_by_key< template void HeterCommKernel::dy_mf_fill_shard_grads< unsigned long, - paddle::framework::FeaturePushValue, int, cudaStream_t>(unsigned long* d_shard_keys, unsigned long* d_keys, - paddle::framework::FeaturePushValue* d_shard_grads, - paddle::framework::FeaturePushValue* d_grads, + float* d_shard_grads, + float* d_grads, int* idx, long long len, size_t grad_value_size, @@ -492,9 +546,9 @@ template void HeterCommKernel::merge_gradient( const cudaStream_t& stream); template void HeterCommKernel:: - dy_mf_fill_dvals( - paddle::framework::FeatureValue* d_shard_vals, - paddle::framework::FeatureValue* d_vals, + dy_mf_fill_dvals( + float* d_shard_vals, + float* d_vals, int* idx, long long len, size_t val_size, diff --git 
a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index d1555dc2e09196..be433065c581e5 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -41,24 +41,35 @@ struct DynamicGradMerger { return out; } - template - __device__ __forceinline__ void update_one(T& output, const T& input) { - output.slot = input.slot; - output.show = input.show; - output.clk = input.clk; - output.mf_dim = input.mf_dim; - output.lr_g = input.lr_g; - for (int i = 0; i < output.mf_dim; ++i) { - output.mf_g[i] = input.mf_g[i]; + __device__ __forceinline__ void update_one(float* output, const float* input, + CommonFeatureValueAccessor& feature_value_accessor) { + output[feature_value_accessor.common_push_value.SlotIndex()] = + input[feature_value_accessor.common_push_value.SlotIndex()]; + output[feature_value_accessor.common_push_value.ShowIndex()] = + input[feature_value_accessor.common_push_value.ShowIndex()]; + output[feature_value_accessor.common_push_value.ClickIndex()] = + input[feature_value_accessor.common_push_value.ClickIndex()]; + output[feature_value_accessor.common_push_value.MfDimIndex()] = + input[feature_value_accessor.common_push_value.MfDimIndex()]; + output[feature_value_accessor.common_push_value.EmbedGIndex()] = + input[feature_value_accessor.common_push_value.EmbedGIndex()]; + for (int j = 0; j < int(output[feature_value_accessor.common_push_value.MfDimIndex()]); j++) { + output[feature_value_accessor.common_push_value.EmbedxGIndex() + j] = + input[feature_value_accessor.common_push_value.EmbedxGIndex() + j]; } } - template - __device__ __forceinline__ void merge_one(T& output, const T& input) { - output.show += input.show; - output.clk += input.clk; - output.lr_g += input.lr_g; - for (int i = 0; i < input.mf_dim; ++i) { - output.mf_g[i] += input.mf_g[i]; + + __device__ __forceinline__ void merge_one(float* output, const float* input, + CommonFeatureValueAccessor& feature_value_accessor) { + output[feature_value_accessor.common_push_value.ShowIndex()] += + input[feature_value_accessor.common_push_value.ShowIndex()]; + output[feature_value_accessor.common_push_value.ClickIndex()] += + input[feature_value_accessor.common_push_value.ClickIndex()]; + output[feature_value_accessor.common_push_value.EmbedGIndex()] += + input[feature_value_accessor.common_push_value.EmbedGIndex()]; + for (int j = 0; j < int(output[feature_value_accessor.common_push_value.MfDimIndex()]); j++) { + output[feature_value_accessor.common_push_value.EmbedxGIndex() + j] += + input[feature_value_accessor.common_push_value.EmbedxGIndex() + j]; } } }; @@ -68,6 +79,8 @@ class HeterCommKernel { HeterCommKernel() {} explicit HeterCommKernel(const int block_size) : block_size_(block_size) {} + explicit HeterCommKernel(const int block_size, CommonFeatureValueAccessor& feature_value_accessor) : block_size_(block_size), feature_value_accessor_(feature_value_accessor) {} + template void fill_idx(T* idx, long long len, const StreamType& stream); @@ -146,13 +159,12 @@ class HeterCommKernel { bool debug_synchronous = false); template void dy_mf_fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys, - GradType* d_shard_grads, - GradType* d_grads, + float* d_shard_grads, + float* d_grads, T* idx, long long len, size_t grad_value_size, @@ -169,14 +181,15 @@ class HeterCommKernel { DynamicGradMerger& merger_, const StreamType& stream); - template - void dy_mf_fill_dvals(ValType* d_shard_vals, 
- ValType* d_vals, + template + void dy_mf_fill_dvals(float* d_shard_vals, + float* d_vals, T* idx, long long len, size_t val_size, const StreamType& stream); + CommonFeatureValueAccessor feature_value_accessor_; private: int block_size_{256}; }; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc index 82f5393c3660ba..618a88fd70e56f 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc @@ -22,33 +22,31 @@ namespace paddle { namespace framework { HeterPsBase* HeterPsBase::get_instance( - size_t capacity, std::shared_ptr resource) { - return new HeterPs(capacity, resource); + size_t capacity, std::shared_ptr resource, + CommonFeatureValueAccessor feature_value_accessor, + int optimizer_type) { + return new HeterPs(capacity, resource, feature_value_accessor, optimizer_type); } -HeterPs::HeterPs(size_t capacity, std::shared_ptr resource) { +HeterPs::HeterPs(size_t capacity, std::shared_ptr resource, + CommonFeatureValueAccessor feature_value_accessor, + int optimizer_type) { comm_ = - std::make_shared>( + std::make_shared>( capacity, resource); + feature_value_accessor_ = feature_value_accessor; + optimizer_type_ = optimizer_type; } HeterPs::~HeterPs() {} void HeterPs::pull_sparse(int num, FeatureKey* d_keys, - FeatureValue* d_vals, + float* d_vals, size_t len) { comm_->pull_sparse(num, d_keys, d_vals, len); } -void HeterPs::build_ps(int num, - FeatureKey* h_keys, - FeatureValue* h_vals, - size_t len, - size_t chunk_size, - int stream_num) { - comm_->build_ps(num, h_keys, h_vals, len, chunk_size, stream_num); -} int HeterPs::get_index_by_devid(int devid) { return comm_->get_index_by_devid(devid); @@ -68,12 +66,13 @@ void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); } void HeterPs::push_sparse(int num, FeatureKey* d_keys, - FeaturePushValue* d_grads, + float* d_grads, size_t len) { comm_->push_sparse(num, d_keys, d_grads, len); // comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_); } + } // end namespace framework } // end namespace paddle #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu index 005cbd401223d5..e5af93c0ef8bde 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu @@ -22,35 +22,32 @@ namespace paddle { namespace framework { HeterPsBase* HeterPsBase::get_instance( - size_t capacity, std::shared_ptr resource) { - return new HeterPs(capacity, resource); + size_t capacity, std::shared_ptr resource, + CommonFeatureValueAccessor feature_value_accessor, + int optimizer_type) { + return new HeterPs(capacity, resource, feature_value_accessor, optimizer_type); } -HeterPs::HeterPs(size_t capacity, std::shared_ptr resource) { +HeterPs::HeterPs(size_t capacity, std::shared_ptr resource, + CommonFeatureValueAccessor feature_value_accessor, + int optimizer_type) { comm_ = - std::make_shared>( - capacity, resource); - opt_ = Optimizer(); + std::make_shared>( + capacity, resource, feature_value_accessor); + feature_value_accessor_ = feature_value_accessor; + optimizer_type_ = optimizer_type; + // opt_ = Optimizer(); } HeterPs::~HeterPs() {} void HeterPs::pull_sparse(int num, FeatureKey* d_keys, - FeatureValue* d_vals, + float* d_vals, size_t len) { comm_->pull_sparse(num, d_keys, d_vals, len); } -void HeterPs::build_ps(int num, - FeatureKey* h_keys, - FeatureValue* h_vals, - size_t len, - size_t 
chunk_size, - int stream_num) { - comm_->build_ps(num, h_keys, h_vals, len, chunk_size, stream_num); -} - void HeterPs::build_ps(int num, FeatureKey* h_keys, char* pool, @@ -80,10 +77,19 @@ void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); } void HeterPs::push_sparse(int num, FeatureKey* d_keys, - FeaturePushValue* d_grads, + float* d_grads, size_t len) { - comm_->push_sparse(num, d_keys, d_grads, len, opt_); - // comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_); + if (optimizer_type_ == 3) { //adam + auto optimizer = SparseAdamOptimizer(feature_value_accessor_); + VLOG(0) << "INTO push_sparse SparseAdamOptimizer EmbedDim():" << optimizer.EmbedDim(); + comm_->push_sparse(num, d_keys, d_grads, len, optimizer); + } else if (optimizer_type_ == 4) { //sharedadam + auto optimizer = SparseAdamSharedOptimizer(feature_value_accessor_); + comm_->push_sparse(num, d_keys, d_grads, len, optimizer); + } else { + auto optimizer = SparseAdagradOptimizer(feature_value_accessor_); + comm_->push_sparse(num, d_keys, d_grads, len, optimizer); + } } void HeterPs::set_nccl_comm_and_size(const std::vector& inner_comms, @@ -96,6 +102,10 @@ void HeterPs::set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) { comm_->set_multi_mf_dim(multi_mf_dim, max_mf_dim); } +void HeterPs::set_accessor(CommonFeatureValueAccessor& accessor) { + comm_->set_accessor(accessor); +} + } // end namespace framework } // end namespace paddle #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h index 7fee2297388308..63f35dee92bc83 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h @@ -29,21 +29,17 @@ namespace framework { class HeterPs : public HeterPsBase { public: HeterPs() {} - HeterPs(size_t capacity, std::shared_ptr resource); + HeterPs(size_t capacity, std::shared_ptr resource, + CommonFeatureValueAccessor feature_value_accessor, + int optimizer_type); virtual ~HeterPs(); HeterPs(const HeterPs&) = delete; HeterPs& operator=(const HeterPs&) = delete; void pull_sparse(int num, FeatureKey* d_keys, - FeatureValue* d_vals, + float* d_vals, size_t len) override; - void build_ps(int num, - FeatureKey* h_keys, - FeatureValue* h_vals, - size_t len, - size_t chunk_size, - int stream_num) override; void build_ps(int num, FeatureKey* h_keys, char* pool, @@ -56,6 +52,7 @@ class HeterPs : public HeterPsBase { const std::vector& inter_comms, int comm_size) override; void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) override; + void set_accessor(CommonFeatureValueAccessor& accessor) override; #endif void set_sparse_sgd(const OptimizerConfig& optimizer_config) override; @@ -66,13 +63,15 @@ class HeterPs : public HeterPsBase { void show_one_table(int gpu_num) override; void push_sparse(int num, FeatureKey* d_keys, - FeaturePushValue* d_grads, + float* d_grads, size_t len) override; private: - std::shared_ptr> comm_; + std::shared_ptr> comm_; #if defined(PADDLE_WITH_CUDA) - Optimizer opt_; + // Optimizer opt_; + CommonFeatureValueAccessor feature_value_accessor_; + int optimizer_type_; #endif }; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h index acc984f14adaaf..0769b9280ef4d8 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h @@ -34,14 +34,8 @@ class HeterPsBase { virtual void pull_sparse(int num, FeatureKey* d_keys, - 
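// [Editor's sketch] The push_sparse dispatch above keys the sparse rule off an
// integer: 3 selects SparseAdamOptimizer, 4 SparseAdamSharedOptimizer, and
// anything else falls back to SparseAdagradOptimizer. Standalone mock of that
// convention (the enum name is invented for illustration):
enum class MockOptType { kAdagrad = 1, kAdam = 3, kSharedAdam = 4 };
inline const char* mock_opt_name(int optimizer_type) {
  switch (static_cast<MockOptType>(optimizer_type)) {
    case MockOptType::kAdam:       return "SparseAdamOptimizer";
    case MockOptType::kSharedAdam: return "SparseAdamSharedOptimizer";
    default:                       return "SparseAdagradOptimizer";  // fallback
  }
}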
FeatureValue* d_vals, + float* d_vals, size_t len) = 0; - virtual void build_ps(int num, - FeatureKey* h_keys, - FeatureValue* h_vals, - size_t len, - size_t chunk_size, - int stream_num) = 0; virtual void build_ps(int num, FeatureKey* h_keys, char* pool, @@ -56,19 +50,23 @@ class HeterPsBase { const std::vector& inter_comms, int comm_size) = 0; virtual void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) = 0; + virtual void set_accessor(CommonFeatureValueAccessor& accessor) = 0; + #endif virtual void end_pass() = 0; virtual void show_one_table(int gpu_num) = 0; virtual void push_sparse(int num, FeatureKey* d_keys, - FeaturePushValue* d_grads, + float* d_grads, size_t len) = 0; virtual void set_sparse_sgd(const OptimizerConfig& optimizer_config) = 0; virtual void set_embedx_sgd(const OptimizerConfig& optimizer_config) = 0; static HeterPsBase* get_instance(size_t capacity, - std::shared_ptr resource); + std::shared_ptr resource, + CommonFeatureValueAccessor feature_value_accessor, + int optimizer_type); }; } // end namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/mem_pool.h b/paddle/fluid/framework/fleet/heter_ps/mem_pool.h index 88c3136dd77d16..0a95576cd2987b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/mem_pool.h +++ b/paddle/fluid/framework/fleet/heter_ps/mem_pool.h @@ -82,19 +82,19 @@ class HBMMemoryPool : public managed { cudaMemset(mem_, 0, block_size_ * capacity); } - friend std::ostream& operator<<(std::ostream& out, HBMMemoryPool& p) { - for (size_t k = 0; k < 5; k++) { - auto x = (FeatureValue*)(p.mem() + k * p.capacity()); - out << "show: " << x->show << " clk: " << x->clk << " slot: " << x->slot - << " lr: " << x->lr << " mf_dim: " << x->mf_size - << " mf_size: " << x->mf_size << " mf:"; - for (int i = 0; i < x->mf_size + 1; ++i) { - out << " " << x->mf[i]; - } - out << "\n"; - } - return out; - } + // friend std::ostream& operator<<(std::ostream& out, HBMMemoryPool& p) { + // for (size_t k = 0; k < 5; k++) { + // auto x = (float*)(p.mem() + k * p.capacity()); + // out << "show: " << x->show << " clk: " << x->clk << " slot: " << x->slot + // << " lr: " << x->lr << " mf_dim: " << x->mf_size + // << " mf_size: " << x->mf_size << " mf:"; + // for (int i = 0; i < x->mf_size + 1; ++i) { + // out << " " << x->mf[i]; + // } + // out << "\n"; + // } + // return out; + // } char* mem() { return mem_; } diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h index 74a4f1ca16c2b7..0e1206de23a7f3 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h @@ -27,42 +27,91 @@ namespace paddle { namespace framework { #if defined(PADDLE_WITH_CUDA) -template + class Optimizer { public: - Optimizer() {} + __host__ Optimizer(CommonFeatureValueAccessor feature_value_accessor) { + feature_value_accessor_ = feature_value_accessor; + } + __host__ ~Optimizer() {} + // __host__ virtual void initialize(const OptimizerConfig& optimizer_config, + // size_t emb_dim) { + + // _lr_embedding_dim = 1; + // _embedding_dim = emb_dim; + // } - ~Optimizer() {} + __device__ void update_value(const OptimizerConfig& optimizer_config, + float& val, // NOLINT + const float& grad) { + } - void initialize() {} + __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config, + float* ptr, const float* grad) { + } + + // __host__ float& MinBound() { return _min_bound; } + // float& MaxBound() { return _max_bound; } - __device__ void update_lr(const 
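// [Editor's sketch] The HBMMemoryPool printer was commented out above because
// pool rows are now raw float spans rather than FeatureValue structs; any
// replacement has to resolve fields through accessor indices. A minimal
// standalone mock (the *_idx parameters stand in for the real index getters):
#include <cstdio>
inline void dump_row_mock(const float* row, int show_idx, int clk_idx,
                          int slot_idx, int mf_size_idx) {
  std::printf("show: %f clk: %f slot: %d mf_size: %d\n",
              row[show_idx], row[clk_idx],
              static_cast<int>(row[slot_idx]),
              static_cast<int>(row[mf_size_idx]));
}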
OptimizerConfig& optimizer_config, - float& w, // NOLINT - float& g2sum, - float g, // NOLINT - float scale) { - double add_g2sum = 0; - double ratio = optimizer_config.learning_rate * - sqrt(optimizer_config.initial_g2sum / - (optimizer_config.initial_g2sum + g2sum)); - double scaled_grad = g / scale; + CommonFeatureValueAccessor feature_value_accessor_; - w += scaled_grad * ratio; +// protected: + // float _min_bound; + // float _max_bound; + // float _initial_range; + size_t _embedding_dim; + size_t _lr_embedding_dim; - if (w < optimizer_config.min_bound) w = optimizer_config.min_bound; - if (w > optimizer_config.max_bound) w = optimizer_config.max_bound; - add_g2sum += scaled_grad * scaled_grad; +}; - g2sum += add_g2sum; +class SparseAdagradOptimizer : public Optimizer { + public: + + __host__ SparseAdagradOptimizer(CommonFeatureValueAccessor feature_value_accessor): Optimizer(feature_value_accessor) { + _lr_embedding_dim = 1; + _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); } - __device__ void update_mf(const OptimizerConfig& optimizer_config, - int n, + // __host__ virtual void initialize(const OptimizerConfig& optimizer_config, + // size_t emb_dim) { + // _lr_embedding_dim = 1; + // _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); + // // this->_learning_rate = optimizer_config.learning_rate; + // // this->_initial_g2sum = optimizer_config.initial_g2sum; + // // this->_initial_range = optimizer_config.initial_range; + // // this->_min_bound = optimizer_config.min_bound; + // // this->_max_bound = optimizer_config.max_bound; + // } + + + // __device__ void update_lr(const OptimizerConfig& optimizer_config, + // float& w, // NOLINT + // float* sgd, float g, // NOLINT + // float scale) { + // float& g2sum = sgd[G2SumIndex()]; + // double add_g2sum = 0; + // double ratio = optimizer_config.learning_rate * + // sqrt(optimizer_config.initial_g2sum / + // (optimizer_config.initial_g2sum + g2sum)); + // double scaled_grad = g / scale; + + // w += scaled_grad * ratio; + + // if (w < optimizer_config.min_bound) w = optimizer_config.min_bound; + // if (w > optimizer_config.max_bound) w = optimizer_config.max_bound; + + // add_g2sum += scaled_grad * scaled_grad; + + // g2sum += add_g2sum; + // } + + + __device__ void update_value_work(const OptimizerConfig& optimizer_config, int n, float* w, - float& g2sum, // NOLINT - const float* g, - float scale) { + float* sgd, // NOLINT + const float* g, float scale) { + float& g2sum = sgd[G2SumIndex()]; double add_g2sum = 0; double ratio = optimizer_config.mf_learning_rate * sqrt(optimizer_config.mf_initial_g2sum / @@ -82,81 +131,471 @@ class Optimizer { g2sum += add_g2sum / n; } + // __device__ void update_value_work(const OptimizerConfig& optimizer_config, + // float* w, + // float* sgd, // NOLINT + // const float* g, float scale) { + // float& g2sum = sgd[G2SumIndex()]; + // double add_g2sum = 0; + // double ratio = _learning_rate * + // sqrt(_initial_g2sum / + // (_initial_g2sum + g2sum)); + // for (int i = 0; i < _embedding_dim; ++i) { + // double scaled_grad = g[i] / scale; + + // w[i] += scaled_grad * ratio; + + // if (w[i] < this->_min_bound) + // w[i] = this->_min_bound; + // if (w[i] > this->_max_bound) + // w[i] = this->_max_bound; + // add_g2sum += scaled_grad * scaled_grad; + // } + + // g2sum += add_g2sum / _embedding_dim; + // } + __device__ void update_value(const OptimizerConfig& optimizer_config, - ValType& val, // NOLINT - const GradType& grad) { - val.slot = grad.slot; - val.show += 
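// [Editor's sketch] The Adagrad-style rule in update_value_work above, reduced
// to a single weight for clarity. Self-contained; the default arguments echo
// the learning_rate/initial_g2sum/bound fields of OptimizerConfig but are
// illustrative, not authoritative:
#include <algorithm>
#include <cmath>
inline void adagrad_step_mock(float& w, float& g2sum, float g, float scale,
                              float lr = 0.05f, float initial_g2sum = 3.0f,
                              float min_bound = -10.f, float max_bound = 10.f) {
  double ratio = lr * std::sqrt(initial_g2sum / (initial_g2sum + g2sum));
  double scaled_grad = g / scale;  // gradients arrive scaled down by show
  w += static_cast<float>(scaled_grad * ratio);
  w = std::min(std::max(w, min_bound), max_bound);
  g2sum += static_cast<float>(scaled_grad * scaled_grad);
}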
grad.show; - val.clk += grad.clk; - val.delta_score += optimizer_config.nonclk_coeff * (grad.show - grad.clk) + - optimizer_config.clk_coeff * grad.clk; + float& val, // NOLINT + const float& grad) { + } + __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config, + float* ptr, const float* grad) { + float g_show = grad[feature_value_accessor_.common_push_value.ShowIndex()]; + float g_click = grad[feature_value_accessor_.common_push_value.ClickIndex()]; + - update_lr(optimizer_config, val.lr, val.lr_g2sum, grad.lr_g, grad.show); + ptr[feature_value_accessor_.common_feature_value.SlotIndex()] = + grad[feature_value_accessor_.common_push_value.SlotIndex()]; + ptr[feature_value_accessor_.common_feature_value.ShowIndex()] += g_show; + ptr[feature_value_accessor_.common_feature_value.ClickIndex()] += g_click; + ptr[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] += + optimizer_config.nonclk_coeff * (g_show - g_click) + + optimizer_config.clk_coeff * g_click; - if (val.mf_size == 0) { + update_value_work(optimizer_config, 1, + ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(), + ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(), + grad + feature_value_accessor_.common_push_value.EmbedGIndex(), + g_show); + + int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); + if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= - optimizer_config.nonclk_coeff * (val.show - val.clk) + - optimizer_config.clk_coeff * val.clk) { - val.mf_size = MF_DIM + 1; - val.mf[0] = 0; + optimizer_config.nonclk_coeff * + (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - + ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { + ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = mf_dim; + + + // ptr->mf_size = MF_DIM + 1; int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; curand_init(clock64(), tid_x, 0, &state); - for (int i = 0; i < MF_DIM; ++i) { - val.mf[i + 1] = + for (int i = 0; i < mf_dim; ++i) { + ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = (curand_uniform(&state)) * optimizer_config.mf_initial_range; } } } else { - update_mf(optimizer_config, - MF_DIM, - &val.mf[1], - val.mf[0], - grad.mf_g, - grad.show); + update_value_work(optimizer_config, mf_dim, + ptr + feature_value_accessor_.common_feature_value.EmbedxWIndex(), + ptr + feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(), + grad + feature_value_accessor_.common_push_value.EmbedxGIndex(), + g_show); } } + + __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim();} + __host__ __device__ size_t EmbedDim() { return _lr_embedding_dim; } + __host__ __device__ size_t EmbedxDim() { return _embedding_dim; } + __host__ __device__ size_t G2SumIndex() { return 0; } + __host__ __device__ size_t EmbedxG2SumIndex() { return 0; } + + +// private: + // float _learning_rate; + // float _initial_g2sum; +}; + +class SparseAdamOptimizer : public Optimizer { + public: + + __host__ SparseAdamOptimizer(CommonFeatureValueAccessor feature_value_accessor): Optimizer(feature_value_accessor) { + _lr_embedding_dim = 1; + _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); + } + + // __host__ virtual void initialize(const OptimizerConfig& optimizer_config, + // size_t emb_dim) { + // _lr_embedding_dim = 1; 
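// [Editor's sketch] The branch above creates the mf vector lazily: it only
// materializes once a feature's weighted traffic crosses mf_create_thresholds,
// after which the weights are seeded uniformly within mf_initial_range.
// Standalone predicate mirroring that check:
inline bool should_create_mf_mock(float show, float click,
                                  float nonclk_coeff, float clk_coeff,
                                  float mf_create_thresholds) {
  return mf_create_thresholds <=
         nonclk_coeff * (show - click) + clk_coeff * click;
}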
+ // _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); + // // this->_learning_rate = optimizer_config.learning_rate; + // // this->_initial_range = optimizer_config.initial_range; + // // this->_beta1_decay_rate = optimizer_config.beta1_decay_rate; + // // this->_beta2_decay_rate = optimizer_config.beta2_decay_rate; + // // this->_ada_epsilon = optimizer_config.ada_epsilon; // float epsilon = 1e-08; + // // this->_min_bound = optimizer_config.min_bound; + // // this->_max_bound = optimizer_config.max_bound; + // } + + // __device__ void update_lr(const OptimizerConfig& optimizer_config, + // float& w, // NOLINT + // float* sgd, float g, // NOLINT + // float scale) { + + // float* moment1 = sgd + GSumIndex(); + // float* moment2 = sgd + G2SumIndex(); + // float* beta1_pow = sgd + Beta1PowIndex(); + // float* beta2_pow = sgd + Beta2PowIndex(); + + // float beta1_pow_ = *beta1_pow; + // float beta2_pow_ = *beta2_pow; + // float moment1_ = *moment1; + // float moment2_ = *moment2; + + // float epsilon = 1e-08; + // double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) / (1.0 - beta1_pow_); + // double scaled_grad = g / scale; + // double new_moment1 = optimizer_config.beta1_decay_rate * moment1_ + (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad; + // double new_moment2 = optimizer_config.beta2_decay_rate * moment2_ + (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad; + // w += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon)); + + // if (w < optimizer_config.min_bound) w = optimizer_config.min_bound; + // if (w > optimizer_config.max_bound) w = optimizer_config.max_bound; + + // (*moment1) = new_moment1; + // (*moment2) = new_moment2; + + // (*beta1_pow) *= optimizer_config.beta1_decay_rate; + // (*beta2_pow) *= optimizer_config.beta2_decay_rate; + + // } + + __device__ void update_value_work(const OptimizerConfig& optimizer_config, int n, + float* w, + float* sgd, + const float* g, float scale) { + float* moment1 = sgd + GSumIndex(); + float* moment2 = sgd + G2SumIndex(); + float* beta1_pow = sgd + Beta1PowIndex(); + float* beta2_pow = sgd + Beta2PowIndex(); + + float beta1_pow_ = *beta1_pow; + float beta2_pow_ = *beta2_pow; + + float epsilon = 1e-08; + double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) / (1.0 - beta1_pow_); + for (int i = 0; i < n; ++i) { + double scaled_grad = g[i] / scale; + + double new_moment1 = optimizer_config.beta1_decay_rate * moment1[i] + (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad; + double new_moment2 = optimizer_config.beta2_decay_rate * moment2[i] + (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad; + w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon)); + + + if (w[i] < optimizer_config.mf_min_bound) + w[i] = optimizer_config.mf_min_bound; + if (w[i] > optimizer_config.mf_max_bound) + w[i] = optimizer_config.mf_max_bound; + + moment1[i] = new_moment1; + moment2[i] = new_moment2; + } + (*beta1_pow) *= optimizer_config.beta1_decay_rate; + (*beta2_pow) *= optimizer_config.beta2_decay_rate; + } + + // __device__ void update_value_work(const OptimizerConfig& optimizer_config, + // float* w, + // float* sgd, // NOLINT + // const float* g, float scale) { + // float* moment1 = sgd + GSumIndex(); + // float* moment2 = sgd + G2SumIndex(); + // float* beta1_pow = sgd + Beta1PowIndex(); + // float* beta2_pow = sgd + Beta2PowIndex(); + + // float beta1_pow_ = *beta1_pow; + // float beta2_pow_ = *beta2_pow; + + // double ratio = 
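// [Editor's sketch] The Adam step implemented by update_value_work above, for
// one weight. Self-contained; epsilon matches the 1e-08 hard-coded in this
// file, and the remaining defaults echo OptimizerConfig but are illustrative:
#include <algorithm>
#include <cmath>
inline void adam_step_mock(float& w, float& m1, float& m2,
                           float& beta1_pow, float& beta2_pow,
                           float g, float scale,
                           float lr = 0.05f, float beta1 = 0.9f,
                           float beta2 = 0.999f, float epsilon = 1e-8f,
                           float min_bound = -10.f, float max_bound = 10.f) {
  double ratio = lr * std::sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow);
  double scaled_grad = g / scale;
  double new_m1 = beta1 * m1 + (1.0 - beta1) * scaled_grad;
  double new_m2 = beta2 * m2 + (1.0 - beta2) * scaled_grad * scaled_grad;
  w += static_cast<float>(ratio * (new_m1 / (std::sqrt(new_m2) + epsilon)));
  w = std::min(std::max(w, min_bound), max_bound);
  m1 = static_cast<float>(new_m1);
  m2 = static_cast<float>(new_m2);
  beta1_pow *= beta1;  // bias-correction powers advance once per update
  beta2_pow *= beta2;
}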
_learning_rate * sqrt(1.0 - beta2_pow_) / (1.0 - beta1_pow_); + + // for (int i = 0; i < this->_embedding_dim; ++i) { + // double scaled_grad = g[i] / scale; + + // double new_moment1 = _beta1_decay_rate * moment1[i] + (1.0 - _beta1_decay_rate) * scaled_grad; + // double new_moment2 = _beta2_decay_rate * moment2[i] + (1.0 - _beta2_decay_rate) * scaled_grad * scaled_grad; + // w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + _ada_epsilon)); + + + // if (w[i] < this->_min_bound) + // w[i] = this->_min_bound; + // if (w[i] > this->_max_bound) + // w[i] = this->_max_bound; + + // moment1[i] = new_moment1; + // moment2[i] = new_moment2; + // } + // (*beta1_pow) *= _beta1_decay_rate; + // (*beta2_pow) *= _beta2_decay_rate; + // } + + __device__ void update_value(const OptimizerConfig& optimizer_config, + float& val, // NOLINT + const float& grad) { + } + __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config, + float* ptr, const float* grad) { + + float g_show = grad[feature_value_accessor_.common_push_value.ShowIndex()]; + float g_click = grad[feature_value_accessor_.common_push_value.ClickIndex()]; + + + ptr[feature_value_accessor_.common_feature_value.SlotIndex()] = + grad[feature_value_accessor_.common_push_value.SlotIndex()]; + ptr[feature_value_accessor_.common_feature_value.ShowIndex()] += g_show; + ptr[feature_value_accessor_.common_feature_value.ClickIndex()] += g_click; + ptr[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] += + optimizer_config.nonclk_coeff * (g_show - g_click) + + optimizer_config.clk_coeff * g_click; + + update_value_work(optimizer_config, 1, + ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(), + ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(), + grad + feature_value_accessor_.common_push_value.EmbedGIndex(), + g_show); + int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); + if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { + if (optimizer_config.mf_create_thresholds <= + optimizer_config.nonclk_coeff * + (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - + ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { + ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = mf_dim; + + + // ptr->mf_size = MF_DIM + 1; + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + curandState state; + curand_init(clock64(), tid_x, 0, &state); + for (int i = 0; i < mf_dim; ++i) { + ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = + (curand_uniform(&state)) * optimizer_config.mf_initial_range; + } + ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + Beta1PowIndex()] = + optimizer_config.beta1_decay_rate; + ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + Beta2PowIndex()] = + optimizer_config.beta2_decay_rate; + ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta1PowIndex()] = + optimizer_config.beta1_decay_rate; + ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta2PowIndex()] = + optimizer_config.beta2_decay_rate; + } + } else { + update_value_work(optimizer_config, mf_dim, + ptr + feature_value_accessor_.common_feature_value.EmbedxWIndex(), + ptr + feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(), + grad + feature_value_accessor_.common_push_value.EmbedxGIndex(), + g_show); + } + } + + __host__ 
__device__ size_t Dim() { return EmbedDim() + EmbedxDim(); } + __host__ __device__ size_t EmbedDim() { return _lr_embedding_dim * 2 + 2; } + __host__ __device__ size_t EmbedxDim() { return _embedding_dim * 2 + 2; } + __host__ __device__ size_t GSumIndex() { return 0; } + __host__ __device__ size_t G2SumIndex() { return GSumIndex() + _lr_embedding_dim; } + __host__ __device__ size_t Beta1PowIndex() { return G2SumIndex() + _lr_embedding_dim; } + __host__ __device__ size_t Beta2PowIndex() { return Beta1PowIndex() + 1; } + __host__ __device__ size_t EmbedxGSumIndex() { return 0; } + __host__ __device__ size_t EmbedxG2SumIndex() { return EmbedxGSumIndex() + _embedding_dim; } + __host__ __device__ size_t EmbedxBeta1PowIndex() { return EmbedxG2SumIndex() + _embedding_dim; } + __host__ __device__ size_t EmbedxBeta2PowIndex() { return EmbedxBeta1PowIndex() + 1; } + + +// protected: + // float _learning_rate; + // float _beta1_decay_rate; + // float _beta2_decay_rate; + // float _ada_epsilon; +}; + + +class SparseAdamSharedOptimizer : public Optimizer { + public: + + __host__ SparseAdamSharedOptimizer(CommonFeatureValueAccessor feature_value_accessor): Optimizer(feature_value_accessor) { + _lr_embedding_dim = 1; + _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); + } + + // virtual void initialize() { + // _lr_embedding_dim = 1; + // _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); + // // this->_learning_rate = optimizer_config.learning_rate; + // // this->_initial_range = optimizer_config.initial_range; + // // this->_beta1_decay_rate = optimizer_config.beta1_decay_rate; + // // this->_beta2_decay_rate = optimizer_config.beta2_decay_rate; + // // this->_ada_epsilon = optimizer_config.ada_epsilon; // float epsilon = 1e-08; + // // this->_min_bound = optimizer_config.min_bound; + // // this->_max_bound = optimizer_config.max_bound; + // } + + // __device__ void update_lr(const OptimizerConfig& optimizer_config, + // float& w, // NOLINT + // float* sgd, float g, // NOLINT + // float scale) { + // float* moment1 = sgd + GSumIndex(); + // float* moment2 = sgd + G2SumIndex(); + // float* beta1_pow = sgd + Beta1PowIndex(); + // float* beta2_pow = sgd + Beta2PowIndex(); + + // float beta1_pow_ = *beta1_pow; + // float beta2_pow_ = *beta2_pow; + // float moment1_ = *moment1; + // float moment2_ = *moment2; + + // float epsilon = 1e-08; + // double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) / (1.0 - beta1_pow_); + // double scaled_grad = g / scale; + // double new_moment1 = optimizer_config.beta1_decay_rate * moment1_ + (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad; + // double new_moment2 = optimizer_config.beta2_decay_rate * moment2_ + (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad; + // w += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon)); + + // if (w < optimizer_config.min_bound) w = optimizer_config.min_bound; + // if (w > optimizer_config.max_bound) w = optimizer_config.max_bound; + + // (*moment1) = new_moment1; + // (*moment2) = new_moment2; + + // (*beta1_pow) *= optimizer_config.beta1_decay_rate; + // (*beta2_pow) *= optimizer_config.beta2_decay_rate; + // } + + __device__ void update_value_work(const OptimizerConfig& optimizer_config, int n, + float* w, + float* sgd, + const float* g, float scale) { + float* moment1 = sgd + GSumIndex(); + float* moment2 = sgd + G2SumIndex(); + float* beta1_pow = sgd + Beta1PowIndex(); + float* beta2_pow = sgd + Beta2PowIndex(); + + float beta1_pow_ = 
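// [Editor's note, as a standalone layout check] Per-dimension Adam stores two
// moments per weight plus the two beta powers (EmbedxDim = d*2 + 2), while the
// shared variant below keeps a single moment pair for the whole vector, so its
// state is a fixed 4 floats regardless of d:
#include <cstddef>
constexpr std::size_t adam_state_floats(std::size_t d) { return d * 2 + 2; }
constexpr std::size_t shared_adam_state_floats(std::size_t) { return 4; }
static_assert(adam_state_floats(8) == 18, "8-dim mf: 16 moments + 2 beta pows");
static_assert(shared_adam_state_floats(8) == 4, "gsum, g2sum, 2 beta pows");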
*beta1_pow; + float beta2_pow_ = *beta2_pow; + float moment1_ = *moment1; + float moment2_ = *moment2; + float epsilon = 1e-08; + double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) / (1.0 - beta1_pow_); + + double sum_mom1 = 0.0; + double sum_mom2 = 0.0; + for (int i = 0; i < n; ++i) { + double scaled_grad = g[i] / scale; + + double new_moment1 = optimizer_config.beta1_decay_rate * moment1_ + (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad; + double new_moment2 = optimizer_config.beta2_decay_rate * moment2_ + (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad; + w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon)); + + + if (w[i] < optimizer_config.mf_min_bound) + w[i] = optimizer_config.mf_min_bound; + if (w[i] > optimizer_config.mf_max_bound) + w[i] = optimizer_config.mf_max_bound; + + sum_mom1 += new_moment1; + sum_mom2 += new_moment2; + } + + (*moment1) = sum_mom1 / n; + (*moment2) = sum_mom2 / n; + (*beta1_pow) *= optimizer_config.beta1_decay_rate; + (*beta2_pow) *= optimizer_config.beta2_decay_rate; + } + + __device__ void update_value(const OptimizerConfig& optimizer_config, + float& val, // NOLINT + const float& grad) { + } __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config, - ValType* ptr, - const GradType& grad) { - ptr->slot = grad.slot; - ptr->show += grad.show; - ptr->clk += grad.clk; - ptr->delta_score += optimizer_config.nonclk_coeff * (grad.show - grad.clk) + - optimizer_config.clk_coeff * grad.clk; - - update_lr(optimizer_config, ptr->lr, ptr->lr_g2sum, grad.lr_g, grad.show); - // use MF_DIM temporarily - // ptr->mf_dim = grad.mf_dim; - - if (ptr->mf_size == 0) { + float* ptr, const float* grad) { + + float g_show = grad[feature_value_accessor_.common_push_value.ShowIndex()]; + float g_click = grad[feature_value_accessor_.common_push_value.ClickIndex()]; + + + ptr[feature_value_accessor_.common_feature_value.SlotIndex()] = + grad[feature_value_accessor_.common_push_value.SlotIndex()]; + ptr[feature_value_accessor_.common_feature_value.ShowIndex()] += g_show; + ptr[feature_value_accessor_.common_feature_value.ClickIndex()] += g_click; + ptr[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] += + optimizer_config.nonclk_coeff * (g_show - g_click) + + optimizer_config.clk_coeff * g_click; + + update_value_work(optimizer_config, 1, + ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(), + ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(), + grad + feature_value_accessor_.common_push_value.EmbedGIndex(), + g_show); + int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); + if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= - optimizer_config.nonclk_coeff * (ptr->show - ptr->clk) + - optimizer_config.clk_coeff * ptr->clk) { - ptr->mf_size = ptr->mf_dim + 1; + optimizer_config.nonclk_coeff * + (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - + ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { + ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = mf_dim; + // ptr->mf_size = MF_DIM + 1; - ptr->mf[0] = 0; int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; curand_init(clock64(), tid_x, 0, &state); - for (int i = 0; i < ptr->mf_dim; ++i) { - ptr->mf[i + 1] = + for (int i = 0; i < mf_dim; ++i) { + 
ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = (curand_uniform(&state)) * optimizer_config.mf_initial_range; } + ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + Beta1PowIndex()] = + optimizer_config.beta1_decay_rate; + ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + Beta2PowIndex()] = + optimizer_config.beta2_decay_rate; + ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta1PowIndex()] = + optimizer_config.beta1_decay_rate; + ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta2PowIndex()] = + optimizer_config.beta2_decay_rate; } } else { - update_mf(optimizer_config, - ptr->mf_dim, - &(ptr->mf[1]), - ptr->mf[0], - grad.mf_g, - grad.show); // for local test + update_value_work(optimizer_config, mf_dim, + ptr + feature_value_accessor_.common_feature_value.EmbedxWIndex(), + ptr + feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(), + grad + feature_value_accessor_.common_push_value.EmbedxGIndex(), + g_show); } } + + __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim(); } + __host__ __device__ size_t EmbedDim() { return 4; } + __host__ __device__ size_t EmbedxDim() { return 4; } + __host__ __device__ size_t GSumIndex() { return 0; } + __host__ __device__ size_t G2SumIndex() { return GSumIndex() + 1; } + __host__ __device__ size_t Beta1PowIndex() { return G2SumIndex() + 1; } + __host__ __device__ size_t Beta2PowIndex() { return Beta1PowIndex() + 1; } + __host__ __device__ size_t EmbedxGSumIndex() { return 0; } + __host__ __device__ size_t EmbedxG2SumIndex() { return EmbedxGSumIndex() + 1; } + __host__ __device__ size_t EmbedxBeta1PowIndex() { return EmbedxG2SumIndex() + 1; } + __host__ __device__ size_t EmbedxBeta2PowIndex() { return EmbedxBeta1PowIndex() + 1; } + + +// protected: +// float _learning_rate; +// float _beta1_decay_rate; +// float _beta2_decay_rate; +// float _ada_epsilon; }; + #endif } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h b/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h index 0db72992215a28..8b301b9dbae015 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h @@ -27,13 +27,19 @@ class OptimizerConfig { float learning_rate = 0.05; float initial_g2sum = 3.0; float initial_range = 0; + float beta1_decay_rate = 0.9; //adam + float beta2_decay_rate = 0.999; //adam + float ada_epsilon = 1e-8; float mf_create_thresholds = 10; float mf_learning_rate = 0.05; float mf_initial_g2sum = 3.0; float mf_initial_range = 1e-4; + float mf_beta1_decay_rate = 0.9; //adam + float mf_beta2_decay_rate = 0.999; //adam float mf_min_bound = -10; float mf_max_bound = 10; + float mf_ada_epsilon = 1e-8; void set_sparse_sgd(float nonclk_coeff, float clk_coeff, @@ -41,7 +47,10 @@ class OptimizerConfig { float max_bound, float learning_rate, float initial_g2sum, - float initial_range) { + float initial_range, + float beta1_decay_rate, + float beta2_decay_rate, + float ada_epsilon) { this->nonclk_coeff = nonclk_coeff; this->clk_coeff = clk_coeff; this->min_bound = min_bound; @@ -49,6 +58,9 @@ class OptimizerConfig { this->learning_rate = learning_rate; this->initial_g2sum = initial_g2sum; this->initial_range = initial_range; + this->beta1_decay_rate = beta1_decay_rate; + this->beta2_decay_rate = beta2_decay_rate; + this->ada_epsilon = ada_epsilon; } void set_sparse_sgd(const OptimizerConfig& 
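// [Editor's sketch] Wiring the new Adam hyperparameters through the extended
// set_sparse_sgd overload above. The learning-rate/g2sum/beta/epsilon values
// are the defaults declared in OptimizerConfig; the coefficient and bound
// values are illustrative:
inline void configure_sparse_sgd_mock(OptimizerConfig& cfg) {
  cfg.set_sparse_sgd(/*nonclk_coeff=*/0.1f, /*clk_coeff=*/1.0f,
                     /*min_bound=*/-10.f, /*max_bound=*/10.f,
                     /*learning_rate=*/0.05f, /*initial_g2sum=*/3.0f,
                     /*initial_range=*/0.f,
                     /*beta1_decay_rate=*/0.9f, /*beta2_decay_rate=*/0.999f,
                     /*ada_epsilon=*/1e-8f);
}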
optimizer_config) { @@ -59,6 +71,9 @@ class OptimizerConfig { this->learning_rate = optimizer_config.learning_rate; this->initial_g2sum = optimizer_config.initial_g2sum; this->initial_range = optimizer_config.initial_range; + this->beta1_decay_rate = optimizer_config.beta1_decay_rate; + this->beta2_decay_rate = optimizer_config.beta2_decay_rate; + this->ada_epsilon = optimizer_config.ada_epsilon; } void set_embedx_sgd(float mf_create_thresholds, @@ -66,13 +81,19 @@ class OptimizerConfig { float mf_initial_g2sum, float mf_initial_range, float mf_min_bound, - float mf_max_bound) { + float mf_max_bound, + float mf_beta1_decay_rate, + float mf_beta2_decay_rate, + float mf_ada_epsilon) { this->mf_create_thresholds = mf_create_thresholds; this->mf_learning_rate = mf_learning_rate; this->mf_initial_g2sum = mf_initial_g2sum; this->mf_initial_range = mf_initial_range; this->mf_min_bound = mf_min_bound; this->mf_max_bound = mf_max_bound; + this->mf_beta1_decay_rate = mf_beta1_decay_rate; + this->mf_beta2_decay_rate = mf_beta2_decay_rate; + this->mf_ada_epsilon = mf_ada_epsilon; } void set_embedx_sgd(const OptimizerConfig& optimizer_config) { @@ -82,6 +103,9 @@ class OptimizerConfig { this->mf_initial_range = optimizer_config.mf_initial_range; this->mf_min_bound = optimizer_config.mf_min_bound; this->mf_max_bound = optimizer_config.mf_max_bound; + this->mf_beta1_decay_rate = optimizer_config.mf_beta1_decay_rate; + this->mf_beta2_decay_rate = optimizer_config.mf_beta2_decay_rate; + this->mf_ada_epsilon = optimizer_config.mf_ada_epsilon; } }; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 94fa386aac4880..fa4f571ff02bef 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -34,8 +34,8 @@ limitations under the License. 
*/ #include #include "paddle/fluid/platform/timer.h" +#include "paddle/fluid/framework/data_set.h" #if defined(PADDLE_WITH_PSCORE) -#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h" #include "paddle/fluid/distributed/ps/table/depends/feature_value.h" #endif @@ -540,17 +540,17 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { &device_vals, &device_task_keys, &device_task_ptrs](int dev, int shard_id) { - auto& task_keys = device_task_keys[shard_id]; + // auto& task_keys = device_task_keys[shard_id]; #ifdef PADDLE_WITH_PSLIB auto& task_ptrs = device_task_ptrs[shard_id]; #endif -#ifdef PADDLE_WITH_PSCORE - auto& task_ptrs = device_task_ptrs[shard_id]; -#endif +// #ifdef PADDLE_WITH_PSCORE +// auto& task_ptrs = device_task_ptrs[shard_id]; +// #endif - int len = prefix_sum[dev][shard_id + 1] - prefix_sum[dev][shard_id]; - int cur = prefix_sum[dev][shard_id]; + // int len = prefix_sum[dev][shard_id + 1] - prefix_sum[dev][shard_id]; + // int cur = prefix_sum[dev][shard_id]; #ifdef PADDLE_WITH_PSLIB for (int j = 0; j < len; ++j) { device_keys[dev][cur + j] = task_keys[dev][j]; @@ -579,33 +579,33 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { } } #endif -#ifdef PADDLE_WITH_PSCORE - for (int j = 0; j < len; ++j) { - device_keys[dev][cur + j] = task_keys[dev][j]; - float* ptr_val = task_ptrs[dev][j]->data(); - FeatureValue& val = device_vals[dev][cur + j]; - size_t dim = task_ptrs[dev][j]->size(); - val.delta_score = ptr_val[2]; - val.show = ptr_val[3]; - val.clk = ptr_val[4]; - val.slot = ptr_val[0]; - val.lr = ptr_val[5]; - val.lr_g2sum = ptr_val[6]; - val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]); - - if (dim > 7) { - val.mf_size = MF_DIM + 1; - for (int x = 0; x < val.mf_size; x++) { - val.mf[x] = ptr_val[x + 7]; - } - } else { - val.mf_size = 0; - for (int x = 0; x < MF_DIM + 1; x++) { - val.mf[x] = 0; - } - } - } -#endif +// #ifdef PADDLE_WITH_PSCORE + // for (int j = 0; j < len; ++j) { + // device_keys[dev][cur + j] = task_keys[dev][j]; + // float* ptr_val = task_ptrs[dev][j]->data(); + // FeatureValue& val = device_vals[dev][cur + j]; + // size_t dim = task_ptrs[dev][j]->size(); + // val.delta_score = ptr_val[2]; + // val.show = ptr_val[3]; + // val.clk = ptr_val[4]; + // val.slot = ptr_val[0]; + // val.lr = ptr_val[5]; + // val.lr_g2sum = ptr_val[6]; + // val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]); + + // if (dim > 7) { + // val.mf_size = MF_DIM + 1; + // for (int x = 0; x < val.mf_size; x++) { + // val.mf[x] = ptr_val[x + 7]; + // } + // } else { + // val.mf_size = 0; + // for (int x = 0; x < MF_DIM + 1; x++) { + // val.mf[x] = 0; + // } + // } + // } +// #endif VLOG(3) << "GpuPs build hbmps done"; }; @@ -665,16 +665,21 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { return; } std::vector threads(device_num); - HeterPs_ = HeterPsBase::get_instance(size_max, resource_); + HeterPs_ = HeterPsBase::get_instance(size_max, resource_, feature_value_accessor_, optimizer_type_); #ifdef PADDLE_WITH_CUDA HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_); #endif auto build_dymf_mem_pool = [this, &gpu_task](int i, int j) { this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_); + // this->HeterPs_->set_accessor(feature_value_accessor_); int mf_dim = this->index_dim_vec_[j]; - size_t feature_value_size = - TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float))); + VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim + << " feature_value_DIM:" << feature_value_accessor_.GetAccessorInfo().dim; + size_t 
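// [Editor's note] BuildGPUTask below derives the pool stride from the accessor
// (TYPEALIGN(8, GetAccessorInfo().size)) instead of sizeof(FeatureValue). A
// standalone sketch of the round-up-to-alignment math, with TYPEALIGN
// reimplemented here purely for illustration:
#include <cstddef>
constexpr std::size_t typealign_mock(std::size_t alignment, std::size_t len) {
  return ((len + alignment - 1) / alignment) * alignment;
}
static_assert(typealign_mock(8, 42) == 48, "rounds up to the next multiple of 8");
static_assert(typealign_mock(8, 48) == 48, "aligned sizes pass through");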
feature_value_size = + TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); + // size_t feature_value_size = + // TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1 + 1) * sizeof(float))); auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; size_t len = device_dim_keys.size(); @@ -765,23 +770,6 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { ptr_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: mf_dim_index()] = float(mf_dim); val->mf_dim = mf_dim; -#endif -#ifdef PADDLE_WITH_PSCORE - paddle::distributed::CtrDymfAccessor accessor; - val->delta_score = - ptr_val[accessor.common_feature_value.DeltaScoreIndex()]; - val->show = ptr_val[accessor.common_feature_value.ShowIndex()]; - val->clk = ptr_val[accessor.common_feature_value.ClickIndex()]; - val->slot = int(ptr_val[accessor.common_feature_value.SlotIndex()]); - val->lr = ptr_val[accessor.common_feature_value.EmbedWIndex()]; - val->lr_g2sum = ptr_val[accessor.common_feature_value.EmbedG2SumIndex()]; - - val->cpu_ptr = (uint64_t)(device_dim_ptrs[k]); - - // TODO(xuefeng) set mf_dim while using DownpourCtrDymfAccessor - ptr_val[accessor.common_feature_value.MfDimIndex()] = float(mf_dim); - val->mf_dim = mf_dim; -#endif if (dim > 8) { // CpuPS alreay expand as mf_dim val->mf_size = mf_dim + 1; for (int x = 0; x < val->mf_dim + 1; x++) { @@ -794,7 +782,70 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { } } } - }; +#endif +#ifdef PADDLE_WITH_PSCORE + VLOG(0) << "cpu build "<< k << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) + << " |: "<< cpu_table_accessor_->ParseToString(ptr_val, dim); + // paddle::distributed::CtrDymfAccessor accessor; + // accessor.InitializeDim(embed_sgd_dim_, mf_dim, embedx_sgd_dim_); + // VLOG(0) << "cpu_table_accessor_ DIM:" << cpu_table_accessor_->GetAccessorInfo().dim; + val[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] = + ptr_val[cpu_table_accessor_->common_feature_value.DeltaScoreIndex()]; + val[feature_value_accessor_.common_feature_value.ShowIndex()] = + ptr_val[cpu_table_accessor_->common_feature_value.ShowIndex()]; + val[feature_value_accessor_.common_feature_value.ClickIndex()] = + ptr_val[cpu_table_accessor_->common_feature_value.ClickIndex()]; + val[feature_value_accessor_.common_feature_value.SlotIndex()] = + ptr_val[cpu_table_accessor_->common_feature_value.SlotIndex()]; + val[feature_value_accessor_.common_feature_value.EmbedWIndex()] = + ptr_val[cpu_table_accessor_->common_feature_value.EmbedWIndex()]; + // val->lr_sgd_dim = 1; + // val->mf_sgd_dim = 1; + // // sgd: embed_sgd_dim=1; adam: embed_sgd_dim=1*2+2 + // for (int i = 0; i < val->lr_sgd_dim; i++) { + // val->lr_g2sum[i] = ptr_val[cpu_table_accessor_->common_feature_value.EmbedG2SumIndex() + i]; + // } + // VLOG(0)<< "EmbedDim:" << feature_value_accessor_.common_feature_value.EmbedDim(); + for (int i = 0; i < feature_value_accessor_.common_feature_value.EmbedDim(); i++) { + val[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + i] = + ptr_val[cpu_table_accessor_->common_feature_value.EmbedG2SumIndex() + i]; + } + + // VLOG(0)<< "CpuPtrIndex:" << feature_value_accessor_.common_feature_value.CpuPtrIndex(); + // reinterpret_cast(val[feature_value_accessor_.common_feature_value.CpuPtrIndex()]) = + // (uint64_t)(device_dim_ptrs[k]); + *(reinterpret_cast(val + feature_value_accessor_.common_feature_value.CpuPtrIndex())) = (uint64_t)(device_dim_ptrs[k]); + // (uint64_t*)(val + 
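// [Editor's sketch] The reinterpret_cast above stashes the 64-bit host pointer
// of the CPU-side value across two consecutive float slots of the GPU row.
// Standalone mock of the write/read pair (cpu_ptr_index stands in for
// CpuPtrIndex()):
#include <cstdint>
inline void stash_cpu_ptr_mock(float* row, int cpu_ptr_index, void* host_ptr) {
  *reinterpret_cast<std::uint64_t*>(row + cpu_ptr_index) =
      reinterpret_cast<std::uint64_t>(host_ptr);
}
inline void* load_cpu_ptr_mock(const float* row, int cpu_ptr_index) {
  return reinterpret_cast<void*>(static_cast<std::uintptr_t>(
      *reinterpret_cast<const std::uint64_t*>(row + cpu_ptr_index)));
}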
feature_value_accessor_.common_feature_value.CpuPtrIndex()) = (uint64_t)(device_dim_ptrs[k]); + // ptr_val[cpu_table_accessor_->common_feature_value.MfDimIndex()] = float(mf_dim); + // val->mf_dim = mf_dim; + val[feature_value_accessor_.common_feature_value.MfDimIndex()] = mf_dim; + // VLOG(0) << " dim:" << dim + // << " DIM:" << feature_value_accessor_.GetAccessorInfo().dim + // << " MF_DIM:" << feature_value_accessor_.GetAccessorInfo().mf_size / sizeof(float); + if (dim > cpu_table_accessor_->GetAccessorInfo().dim - + cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float)) { + val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = + feature_value_accessor_.GetAccessorInfo().mf_size / sizeof(float); + for (int x = 0; x < feature_value_accessor_.common_feature_value.EmbedXDim(); x++) { + val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x] = + ptr_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x]; + } + for (int x = 0; x < mf_dim; x++) { + val[feature_value_accessor_.common_feature_value.EmbedxWIndex() + x] = + ptr_val[cpu_table_accessor_->common_feature_value.EmbedxWIndex() + x]; + } + } else { + val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = 0; + for (int x = 0; x < feature_value_accessor_.common_feature_value.EmbedXDim(); x++) { + val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x] = 0; + } + for (int x = 0; x < mf_dim; x++) { + val[feature_value_accessor_.common_feature_value.EmbedxWIndex() + x] = 0; + } + } + VLOG(0) << "build "<< k << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.GetAccessorInfo().dim); + } +#endif threads.resize(device_num * multi_mf_dim_); for (int i = 0; i < device_num; i++) { @@ -965,9 +1016,9 @@ void PSGPUWrapper::EndPass() { } // ============ multi-thread process feasign============ int mf_dim = this->index_dim_vec_[j]; - VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim; - size_t feature_value_size = - TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float))); + VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim << " key_len :" << len; + size_t feature_value_size = + TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); char* test_build_values = (char*)malloc(feature_value_size * real_len); uint64_t offset = left * feature_value_size; cudaMemcpy(test_build_values, @@ -981,7 +1032,7 @@ void PSGPUWrapper::EndPass() { continue; } size_t local_offset = (i - left) * feature_value_size; - FeatureValue* gpu_val = (FeatureValue*)(test_build_values + local_offset); + float* gpu_val = (float*)(test_build_values + local_offset); #ifdef PADDLE_WITH_PSLIB auto* downpour_value = (paddle::ps::DownpourFixedFeatureValue*)(gpu_val->cpu_ptr); @@ -1002,32 +1053,63 @@ void PSGPUWrapper::EndPass() { embed_g2sum_index()] = gpu_val->lr_g2sum; cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: slot_index()] = gpu_val->slot; + + if (gpu_val->mf_size > 0) { + for (int x = 0; x < gpu_val->mf_dim + 1; x++) { + cpu_val[x + 8] = gpu_val->mf[x]; + } + } + } #endif #ifdef PADDLE_WITH_PSCORE + // paddle::distributed::CtrDymfAccessor accessor; + // accessor.InitializeDim(embed_sgd_dim_, mf_dim, embedx_sgd_dim_); auto* downpour_value = - (paddle::distributed::FixedFeatureValue*)(gpu_val->cpu_ptr); - int downpour_value_size = downpour_value->size(); - if (gpu_val->mf_size > 0 && downpour_value_size == 8) { - downpour_value->resize(gpu_val->mf_dim + 1 + downpour_value_size); + 
(paddle::distributed::FixedFeatureValue*)(reinterpret_cast(gpu_val+ feature_value_accessor_.common_feature_value.CpuPtrIndex())); + size_t downpour_value_size = downpour_value->size(); + VLOG(0) << "downpour_value_size:" <GetAccessorInfo().dim + << " MF_FIM:" << cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float); + if (gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0 && + downpour_value_size == (cpu_table_accessor_->GetAccessorInfo().dim - + cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float))) { // cpu_accessor + downpour_value->resize(cpu_table_accessor_->GetAccessorInfo().dim); } float* cpu_val = downpour_value->data(); - paddle::distributed::CtrDymfAccessor accessor; - cpu_val[accessor.common_feature_value.DeltaScoreIndex()] = - gpu_val->delta_score; - cpu_val[accessor.common_feature_value.ShowIndex()] = gpu_val->show; - cpu_val[accessor.common_feature_value.ClickIndex()] = gpu_val->clk; - cpu_val[accessor.common_feature_value.EmbedWIndex()] = gpu_val->lr; - cpu_val[accessor.common_feature_value.EmbedG2SumIndex()] = - gpu_val->lr_g2sum; - cpu_val[accessor.common_feature_value.SlotIndex()] = gpu_val->slot; -#endif - if (gpu_val->mf_size > 0) { - for (int x = 0; x < gpu_val->mf_dim + 1; x++) { - cpu_val[x + 8] = gpu_val->mf[x]; + cpu_val[cpu_table_accessor_->common_feature_value.DeltaScoreIndex()] = + gpu_val[feature_value_accessor_.common_feature_value.DeltaScoreIndex()]; + cpu_val[cpu_table_accessor_->common_feature_value.ShowIndex()] = + gpu_val[feature_value_accessor_.common_feature_value.ShowIndex()]; + cpu_val[cpu_table_accessor_->common_feature_value.ClickIndex()] = + gpu_val[feature_value_accessor_.common_feature_value.ClickIndex()]; + cpu_val[cpu_table_accessor_->common_feature_value.EmbedWIndex()] = + gpu_val[feature_value_accessor_.common_feature_value.EmbedWIndex()]; + cpu_val[cpu_table_accessor_->common_feature_value.SlotIndex()] = + gpu_val[feature_value_accessor_.common_feature_value.SlotIndex()]; + + for (int i = 0; i < feature_value_accessor_.common_feature_value.EmbedDim(); i++) { + cpu_val[cpu_table_accessor_->common_feature_value.EmbedG2SumIndex() + i] = + gpu_val[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + i]; + } + + if (gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0) { + for (int x = 0; x < feature_value_accessor_.common_feature_value.EmbedXDim(); x++) { + cpu_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x] = + gpu_val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x]; + } + for (int x = 0; x < feature_value_accessor_.common_feature_value.EmbedWDim(); x++) { + cpu_val[cpu_table_accessor_->common_feature_value.EmbedxWIndex() + x] = + gpu_val[feature_value_accessor_.common_feature_value.EmbedxWIndex() + x]; } } + VLOG(0) << "dump to cpu "<< index << " : "<< feature_value_accessor_.ParseToString(gpu_val, feature_value_accessor_.GetAccessorInfo().dim) + << " =====CPU:" << cpu_table_accessor_->ParseToString(cpu_val, cpu_table_accessor_->GetAccessorInfo().update_dim); + } + + +#endif free(test_build_values); }; if (multi_mf_dim_) { @@ -1066,80 +1148,73 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, const std::vector& values, const std::vector& slot_lengths, const int hidden_size) { - platform::Timer all_timer; - platform::Timer pull_gpups_timer; - all_timer.Start(); - int64_t total_length = - std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); - VLOG(3) << "Begine Gpu/Xpu Ps PullSparse"; - auto buf = 
memory::Alloc(place, total_length * sizeof(FeatureValue)); - FeatureValue* total_values_gpu = reinterpret_cast(buf->ptr()); - if (platform::is_cpu_place(place)) { - PADDLE_THROW(platform::errors::Unimplemented( - "Warning:: CPUPlace is not supported in GpuPs now.")); - } else if (platform::is_gpu_place(place)) { -#ifdef PADDLE_WITH_CUDA - VLOG(3) << "Begin copy keys, key_num[" << total_length << "]"; - int device_id = place.GetDeviceId(); - int devid_2_index = HeterPs_->get_index_by_devid(device_id); - LoDTensor& total_keys_tensor = keys_tensor[devid_2_index]; - uint64_t* total_keys = reinterpret_cast( - total_keys_tensor.mutable_data({total_length, 1}, place)); - - // construct slot_level lod info - auto slot_lengths_lod = slot_lengths; - for (size_t i = 1; i < slot_lengths_lod.size(); i++) { - slot_lengths_lod[i] += slot_lengths_lod[i - 1]; - } - auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*)); - auto buf_length = - memory::Alloc(place, slot_lengths.size() * sizeof(int64_t)); - uint64_t** gpu_keys = reinterpret_cast(buf_key->ptr()); - int64_t* gpu_len = reinterpret_cast(buf_length->ptr()); - cudaMemcpy(gpu_keys, - keys.data(), - keys.size() * sizeof(uint64_t*), - cudaMemcpyHostToDevice); - cudaMemcpy(gpu_len, - slot_lengths_lod.data(), - slot_lengths.size() * sizeof(int64_t), - cudaMemcpyHostToDevice); - - this->CopyKeys(place, - gpu_keys, - total_keys, - gpu_len, - static_cast(slot_lengths.size()), - static_cast(total_length)); - VLOG(3) << "Begin call PullSparseGPU in GPUPS, dev: " << devid_2_index - << " len: " << total_length; - pull_gpups_timer.Start(); - HeterPs_->pull_sparse(devid_2_index, - total_keys, - total_values_gpu, - static_cast(total_length)); - pull_gpups_timer.Pause(); - - VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length - << "]"; - this->CopyForPull(place, - gpu_keys, - values, - total_values_gpu, - gpu_len, - static_cast(slot_lengths.size()), - hidden_size, - total_length); - } else { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "GpuPs: PullSparse Only Support CUDAPlace Now.")); } - all_timer.Pause(); - VLOG(3) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec() - << " s, of which GPUPS costs: " << pull_gpups_timer.ElapsedSec() - << " s"; - VLOG(3) << "End PullSparse"; -} +// void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, +// const int table_id, +// const std::vector& keys, +// const std::vector& values, +// const std::vector& slot_lengths, +// const int hidden_size) { +// platform::Timer all_timer; +// platform::Timer pull_gpups_timer; +// all_timer.Start(); +// int64_t total_length = +// std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); +// VLOG(3) << "Begine Gpu/Xpu Ps PullSparse"; +// auto buf = memory::Alloc(place, total_length * sizeof(FeatureValue)); +// FeatureValue* total_values_gpu = reinterpret_cast(buf->ptr()); +// if (platform::is_cpu_place(place)) { +// PADDLE_THROW(platform::errors::Unimplemented( +// "Warning:: CPUPlace is not supported in GpuPs now.")); +// } else if (platform::is_gpu_place(place)) { +// #ifdef PADDLE_WITH_CUDA +// VLOG(3) << "Begin copy keys, key_num[" << total_length << "]"; +// int device_id = place.GetDeviceId(); +// int devid_2_index = HeterPs_->get_index_by_devid(device_id); +// LoDTensor& total_keys_tensor = keys_tensor[devid_2_index]; +// uint64_t* total_keys = reinterpret_cast( +// total_keys_tensor.mutable_data({total_length, 1}, place)); + +// // construct slot_level lod info +// auto slot_lengths_lod = slot_lengths; +// for 
(size_t i = 1; i < slot_lengths_lod.size(); i++) { +// slot_lengths_lod[i] += slot_lengths_lod[i - 1]; +// } +// auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*)); +// auto buf_length = +// memory::Alloc(place, slot_lengths.size() * sizeof(int64_t)); +// uint64_t** gpu_keys = reinterpret_cast(buf_key->ptr()); +// int64_t* gpu_len = reinterpret_cast(buf_length->ptr()); +// cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*), +// cudaMemcpyHostToDevice); +// cudaMemcpy(gpu_len, slot_lengths_lod.data(), +// slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice); + +// this->CopyKeys(place, gpu_keys, total_keys, gpu_len, +// static_cast(slot_lengths.size()), +// static_cast(total_length)); +// VLOG(3) << "Begin call PullSparseGPU in GPUPS, dev: " << devid_2_index +// << " len: " << total_length; +// pull_gpups_timer.Start(); +// HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu, +// static_cast(total_length)); +// pull_gpups_timer.Pause(); + +// VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length +// << "]"; +// this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len, +// static_cast(slot_lengths.size()), hidden_size, +// total_length); +// } else { +// PADDLE_THROW(platform::errors::PreconditionNotMet( +// "GpuPs: PullSparse Only Support CUDAPlace Now.")); +// } +// all_timer.Pause(); +// VLOG(3) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec() +// << " s, of which GPUPS costs: " << pull_gpups_timer.ElapsedSec() +// << " s"; +// VLOG(3) << "End PullSparse"; +// } void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, const int table_id, @@ -1156,13 +1231,15 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); size_t feature_value_size = 0; - feature_value_size = TYPEALIGN( - 8, sizeof(FeatureValue) + sizeof(float) * (index_dim_vec_.back() + 1)); - + // feature_value_size = TYPEALIGN( + // 8, sizeof(FeatureValue) + sizeof(float) * (index_dim_vec_.back() + 1)); + feature_value_size = TYPEALIGN( 8, feature_value_accessor_.GetAccessorInfo().size); + VLOG(0) << "PULLSPASE" << feature_value_accessor_.GetAccessorInfo().size; + #ifdef PADDLE_WITH_CUDA VLOG(3) << "Begine Gpu Ps PullSparse"; auto buf = memory::Alloc(place, total_length * feature_value_size); - FeatureValue* total_values_gpu = reinterpret_cast(buf->ptr()); + float* total_values_gpu = reinterpret_cast(buf->ptr()); #endif #ifdef PADDLE_WITH_XPU_KP VLOG(3) << "Begine Xpu Ps PullSparse"; @@ -1221,6 +1298,14 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, HeterPs_->pull_sparse( devid_2_index, total_keys, total_values_gpu, total_length); + // char* tmp_mem = (char*)malloc(total_length * feature_value_size); + // cudaMemcpy(tmp_mem, total_values_gpu, total_length * feature_value_size, + // cudaMemcpyDeviceToHost); + // for (int i =0 ; i < 20; i++){ + // float* val = (float*)(void*)&tmp_mem[(i)*feature_value_size]; + // VLOG(0) << "pullsparse_cpu "<< i << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.GetAccessorInfo().dim); + // } + VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length << "]"; @@ -1317,12 +1402,12 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); // #ifdef PADDLE_WITH_CUDA VLOG(3) << "Begin GPUPS PushSparseGrad"; - size_t grad_value_size = - TYPEALIGN(8, 
sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float))); + size_t grad_value_size = + TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); auto buf = memory::Alloc(place, total_length * grad_value_size); VLOG(3) << "Push Sparse Max mf dimention: " << max_mf_dim_; - FeaturePushValue* total_grad_values_gpu = - reinterpret_cast(buf->ptr()); + float* total_grad_values_gpu = + reinterpret_cast(buf->ptr()); if (platform::is_cpu_place(place)) { PADDLE_THROW(platform::errors::Unimplemented( "Warning:: CPUPlace is not supported in GPUPS now.")); @@ -1334,23 +1419,13 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, uint64_t* total_keys = reinterpret_cast(cached_total_keys_tensor.data()); VLOG(3) << "Begin copy grad tensor to gpups struct"; - if (!multi_mf_dim_) { - this->CopyForPush(place, - grad_values, - total_grad_values_gpu, - slot_lengths, - hidden_size, - total_length, - batch_size); - } else { - this->CopyForPush(place, - grad_values, - total_grad_values_gpu, - slot_lengths, - total_length, - batch_size, - grad_value_size); - } + this->CopyForPush(place, + grad_values, + total_grad_values_gpu, + slot_lengths, + total_length, + batch_size, + grad_value_size); VLOG(3) << "Begin call PushSparseGPU in GPUPS, dev: " << devid_2_index << " len: " << total_length; @@ -1401,4 +1476,4 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, } // end namespace framework } // end namespace paddle -#endif +// #endif diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu index 734765fa954238..e14b78522ada35 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -67,13 +67,14 @@ __global__ void PullCopy(float** dest, } __global__ void PullCopy(float** dest, - const FeatureValue* src, + const float* src, const int64_t* len, int slot_num, int total_len, uint64_t** keys, uint64_t max_val_size, - int* gpu_dim) { + int* gpu_dim, + CommonFeatureValueAccessor feature_value_accessor) { CUDA_KERNEL_LOOP(i, total_len) { int low = 0; int high = slot_num - 1; @@ -86,25 +87,26 @@ __global__ void PullCopy(float** dest, } int x = low; int y = i - (x ? 
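// [Editor's sketch] The copy kernels in this file all map a flattened element
// index to its slot with the same binary search over prefix-summed slot
// lengths. Standalone host-side mock of that index math:
inline void locate_slot_mock(const long long* prefix_len, int slot_num,
                             long long i, int* x_out, long long* y_out) {
  int low = 0, high = slot_num - 1;
  while (low < high) {
    int mid = (low + high) / 2;
    if (i < prefix_len[mid]) high = mid;
    else low = mid + 1;
  }
  *x_out = low;                                  // slot index
  *y_out = i - (low ? prefix_len[low - 1] : 0);  // row within the slot
}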
len[x - 1] : 0); - FeatureValue* feature_value_ptr = - (FeatureValue*)((char*)src + uint64_t(i) * uint64_t(max_val_size)); + float* feature_value_ptr = + (float*)((char*)src + uint64_t(i) * uint64_t(max_val_size)); int mf_dim = gpu_dim[x] - 3; if (*(keys[x] + y) == 0) { *(dest[x] + y * (mf_dim + 3)) = 0; *(dest[x] + y * (mf_dim + 3) + 1) = 0; *(dest[x] + y * (mf_dim + 3) + 2) = 0; } else { - *(dest[x] + y * (mf_dim + 3)) = feature_value_ptr->show; - *(dest[x] + y * (mf_dim + 3) + 1) = feature_value_ptr->clk; - *(dest[x] + y * (mf_dim + 3) + 2) = feature_value_ptr->lr; + *(dest[x] + y * (mf_dim + 3)) = feature_value_ptr[feature_value_accessor.common_feature_value.ShowIndex()]; + *(dest[x] + y * (mf_dim + 3) + 1) = feature_value_ptr[feature_value_accessor.common_feature_value.ClickIndex()]; + *(dest[x] + y * (mf_dim + 3) + 2) = feature_value_ptr[feature_value_accessor.common_feature_value.EmbedWIndex()]; } - if ((feature_value_ptr)->mf_size == 0 || *(keys[x] + y) == 0) { + + if (feature_value_ptr[feature_value_accessor.common_feature_value.MfSizeIndex()] == 0 || *(keys[x] + y) == 0) { for (int j = 0; j < mf_dim; j++) { *(dest[x] + y * (mf_dim + 3) + 3 + j) = 0; } } else { for (int j = 0; j < mf_dim; j++) { - *(dest[x] + y * (mf_dim + 3) + 3 + j) = feature_value_ptr->mf[1 + j]; + *(dest[x] + y * (mf_dim + 3) + 3 + j) = feature_value_ptr[feature_value_accessor.common_feature_value.EmbedxWIndex() + j]; } } } @@ -161,7 +163,7 @@ __global__ void PushCopy(FeaturePushValue* dest, } } -__global__ void PushCopyWithPool(FeaturePushValue* dest, +__global__ void PushCopyWithPool(float* dest, float** src, int64_t* len, int slot_num, @@ -169,7 +171,8 @@ __global__ void PushCopyWithPool(FeaturePushValue* dest, int bs, int* slot_vector, int* mf_dim_vector, - size_t grad_value_size) { + size_t grad_value_size, + CommonFeatureValueAccessor feature_value_accessor) { CUDA_KERNEL_LOOP(i, total_len) { int low = 0; int high = slot_num - 1; @@ -182,16 +185,24 @@ __global__ void PushCopyWithPool(FeaturePushValue* dest, } int x = low; int y = i - (x ? len[low - 1] : 0); - FeaturePushValue* cur = - (FeaturePushValue*)((char*)dest + i * grad_value_size); - cur->slot = slot_vector[x]; + float* cur = + (float*)((char*)dest + i * grad_value_size); + + cur[feature_value_accessor.common_push_value.SlotIndex()] = + (float)slot_vector[x]; int mf_dim = mf_dim_vector[x]; - cur->mf_dim = mf_dim; - cur->show = *(src[x] + y * (mf_dim + 3)); - cur->clk = *(src[x] + y * (mf_dim + 3) + 1); - cur->lr_g = *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs; - for (int j = 0; j < cur->mf_dim; j++) { - cur->mf_g[j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs; + cur[feature_value_accessor.common_push_value.MfDimIndex()] = mf_dim; + + cur[feature_value_accessor.common_push_value.ShowIndex()] = + *(src[x] + y * (mf_dim + 3)); + cur[feature_value_accessor.common_push_value.ClickIndex()] = + *(src[x] + y * (mf_dim + 3) + 1); + cur[feature_value_accessor.common_push_value.EmbedGIndex()] = + *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs; + // printf("PushCopyWithPool show:%f ,click: %d\n", cur[feature_value_accessor.common_push_value.ShowIndex()], + // cur[feature_value_accessor.common_push_value.ClickIndex()]); + for (int j = 0; j < mf_dim; j++) { + cur[feature_value_accessor.common_push_value.EmbedxGIndex() + j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. 
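// NOTE(editor): layout sketch inferred from the index calls above; the
// authoritative offsets live in CommonFeatureValueAccessor's common_push_value.
// The flattened push record appears to be, in floats:
//   [slot][show][clk][mf_dim][embed_g][embedx_g x mf_dim]
// and the gradients are stored pre-scaled by -1 * bs, matching what the
// pre-existing PushCopy kernel above already did: the sign flip lets the
// optimizer add its update term directly, and the batch-size factor rescales
// the batch-averaged gradient.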
* bs; } } } @@ -229,7 +240,7 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys, const std::vector& values, - const FeatureValue* total_values_gpu, + const float* total_values_gpu, const int64_t* gpu_len, const int slot_num, const int hidden_size, @@ -252,7 +263,8 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, total_length, gpu_keys, val_type_size_, - gpu_dim); + gpu_dim, + feature_value_accessor_); cudaStreamSynchronize(stream); } @@ -321,7 +333,7 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, const std::vector& grad_values, - FeaturePushValue* total_grad_values_gpu, + float* total_grad_values_gpu, const std::vector& slot_lengths, const uint64_t total_length, const int batch_size, @@ -369,7 +381,8 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, batch_size, d_slot_vector, d_mf_dim_vector, - grad_value_size); + grad_value_size, + feature_value_accessor_); cudaStreamSynchronize(stream); } @@ -379,7 +392,10 @@ void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float max_bound, float learning_rate, float initial_g2sum, - float initial_range) { + float initial_range, + float beta1_decay_rate, + float beta2_decay_rate, + float ada_epsilon) { OptimizerConfig optimizer_config; optimizer_config.set_sparse_sgd(nonclk_coeff, clk_coeff, @@ -387,7 +403,10 @@ void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, max_bound, learning_rate, initial_g2sum, - initial_range); + initial_range, + beta1_decay_rate, + beta2_decay_rate, + ada_epsilon); HeterPs_->set_sparse_sgd(optimizer_config); } @@ -396,14 +415,20 @@ void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds, float mf_initial_g2sum, float mf_initial_range, float mf_min_bound, - float mf_max_bound) { + float mf_max_bound, + float mf_beta1_decay_rate, + float mf_beta2_decay_rate, + float mf_ada_epsilon) { OptimizerConfig optimizer_config; optimizer_config.set_embedx_sgd(mf_create_thresholds, mf_learning_rate, mf_initial_g2sum, mf_initial_range, - mf_min_bound, - mf_max_bound); + mf_min_bound, + mf_max_bound, + mf_beta1_decay_rate, + mf_beta2_decay_rate, + mf_ada_epsilon); HeterPs_->set_embedx_sgd(optimizer_config); } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index 0e816beef0d331..34242aa12212b2 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -41,6 +41,7 @@ limitations under the License. */ #include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/dynload/nccl.h" +// #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" #endif #ifdef PADDLE_WITH_XPU_KP #include "paddle/fluid/platform/device/xpu/enforce_xpu.h" @@ -52,6 +53,9 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/place.h" #ifdef PADDLE_WITH_PSCORE #include "paddle/fluid/distributed/ps/wrapper/fleet.h" +#include "paddle/fluid/distributed/ps/table/accessor.h" +#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h" +#include "paddle/fluid/distributed/ps.pb.h" #endif #ifdef PADDLE_WITH_PSLIB #include "afs_api.h" @@ -95,9 +99,17 @@ class AfsWrapper { }; #endif +enum OptimizerType { + OPTIMIZER_NAIVE = 0, + OPTIMIZER_ADAGRAD = 1, + OPTIMIZER_STDADAGRAD = 2, + OPTIMIZER_ADAM = 3, + OPTIMIZER_SHARDADAM = 4, +}; + class PSGPUWrapper { public: - virtual ~PSGPUWrapper(); + ~PSGPUWrapper(); PSGPUWrapper() { HeterPs_ = NULL; @@ -149,7 +161,7 @@ class PSGPUWrapper { void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys, const std::vector& values, - const FeatureValue* total_values_gpu, + const float* total_values_gpu, const int64_t* gpu_len, const int slot_num, const int hidden_size, @@ -164,7 +176,7 @@ class PSGPUWrapper { const int batch_size); void CopyForPush(const paddle::platform::Place& place, const std::vector& grad_values, - FeaturePushValue* total_grad_values_gpu, + float* total_grad_values_gpu, const std::vector& slot_lengths, const uint64_t total_length, const int batch_size, @@ -273,13 +285,200 @@ class PSGPUWrapper { float max_bound, float learning_rate, float initial_g2sum, - float initial_range); + float initial_range, + float beta1_decay_rate, + float beta2_decay_rate, + float ada_epsilon); void SetEmbedxSGD(float mf_create_thresholds, float mf_learning_rate, float mf_initial_g2sum, float mf_initial_range, float mf_min_bound, - float mf_max_bound); + float mf_max_bound, + float mf_beta1_decay_rate, + float mf_beta2_decay_rate, + float mf_ada_epsilon); + +#ifdef PADDLE_WITH_PSCORE + void add_sparse_optimizer( + std::unordered_map& config, // NOLINT + const ::paddle::distributed::SparseCommonSGDRuleParameter& sgd_param, + const std::string& prefix = "") { + auto optimizer_name = sgd_param.name(); + if (optimizer_name == "SparseNaiveSGDRule") { + config[prefix + "optimizer_type"] = 0; + config[prefix + "learning_rate"] = sgd_param.naive().learning_rate(); + config[prefix + "initial_range"] = sgd_param.naive().initial_range(); + config[prefix + "min_bound"] = sgd_param.naive().weight_bounds()[0]; + config[prefix + "max_bound"] = sgd_param.naive().weight_bounds()[1]; + } else if (optimizer_name == "SparseAdaGradSGDRule") { + config[prefix + "optimizer_type"] = 1; + config[prefix + "learning_rate"] = sgd_param.adagrad().learning_rate(); + config[prefix + "initial_range"] = sgd_param.adagrad().initial_range(); + config[prefix + "initial_g2sum"] = sgd_param.adagrad().initial_g2sum(); + config[prefix + "min_bound"] = sgd_param.adagrad().weight_bounds()[0]; + config[prefix + "max_bound"] = sgd_param.adagrad().weight_bounds()[1]; + } else if (optimizer_name == "StdAdaGradSGDRule") { + config[prefix + "optimizer_type"] = 2; + config[prefix + "learning_rate"] = sgd_param.adagrad().learning_rate(); + config[prefix + "initial_range"] = sgd_param.adagrad().initial_range(); + config[prefix + "initial_g2sum"] = sgd_param.adagrad().initial_g2sum(); + config[prefix + "min_bound"] = sgd_param.adagrad().weight_bounds()[0]; + config[prefix + "max_bound"] = sgd_param.adagrad().weight_bounds()[1]; + } else if (optimizer_name == "SparseAdamSGDRule") { + config[prefix + "optimizer_type"] = 3; + config[prefix + "learning_rate"] = sgd_param.adam().learning_rate(); + config[prefix + "initial_range"] = sgd_param.adam().initial_range(); + config[prefix + 
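// NOTE(editor): hedged usage sketch, not part of the patch. The rule-name ->
// optimizer_type mapping in add_sparse_optimizer (naive 0, adagrad 1,
// std_adagrad 2, adam 3, shared adam 4) is keyed off the sparse table's
// accessor proto; a table running Adam on embed_w and shared Adam on embedx
// would carry roughly this configuration (field names from
// SparseCommonSGDRuleParameter / SparseAdamSGDParameter):
//
//   embed_sgd_param {
//     name: "SparseAdamSGDRule"
//     adam {
//       learning_rate: 0.001
//       initial_range: 0.0001
//       beta1_decay_rate: 0.9
//       beta2_decay_rate: 0.999
//       ada_epsilon: 1e-08
//       weight_bounds: -10.0
//       weight_bounds: 10.0
//     }
//   }
//   embedx_sgd_param { name: "SparseSharedAdamSGDRule" adam { ... } }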
"beta1_decay_rate"] = sgd_param.adam().beta1_decay_rate(); + config[prefix + "beta2_decay_rate"] = sgd_param.adam().beta2_decay_rate(); + config[prefix + "ada_epsilon"] = sgd_param.adam().ada_epsilon(); + config[prefix + "min_bound"] = sgd_param.adam().weight_bounds()[0]; + config[prefix + "max_bound"] = sgd_param.adam().weight_bounds()[1]; + } else if (optimizer_name == "SparseAdamSharedSGDRule") { + config[prefix + "optimizer_type"] = 4; + config[prefix + "learning_rate"] = sgd_param.adam().learning_rate(); + config[prefix + "initial_range"] = sgd_param.adam().initial_range(); + config[prefix + "beta1_decay_rate"] = sgd_param.adam().beta1_decay_rate(); + config[prefix + "beta2_decay_rate"] = sgd_param.adam().beta2_decay_rate(); + config[prefix + "ada_epsilon"] = sgd_param.adam().ada_epsilon(); + config[prefix + "min_bound"] = sgd_param.adam().weight_bounds()[0]; + config[prefix + "max_bound"] = sgd_param.adam().weight_bounds()[1]; + } + } + + void InitializeGPUServer(paddle::distributed::PSParameter ps_param) { + auto sparse_table = + ps_param.server_param().downpour_server_param().downpour_table_param(0); + auto sparse_table_accessor = sparse_table.accessor(); + auto sparse_table_accessor_parameter = + sparse_table_accessor.ctr_accessor_param(); + auto accessor_class = sparse_table_accessor.accessor_class(); + + // NOTE(zhangminxu): gpups' sparse table optimizer config, + // now only support single sparse table + // auto sparse_table = param_.sparse_table(0); + std::unordered_map config; + config["embedx_dim"] = sparse_table_accessor.embedx_dim(); + config["nonclk_coeff"] = sparse_table_accessor_parameter.nonclk_coeff(); + config["clk_coeff"] = sparse_table_accessor_parameter.click_coeff(); + + + if (accessor_class == "CtrDymfAccessor") { + // optimizer config for embed_w and embedx + add_sparse_optimizer(config, sparse_table_accessor.embed_sgd_param()); + add_sparse_optimizer(config, sparse_table_accessor.embedx_sgd_param(), + "mf_"); + } + + // CommonFeatureValueAccessor feature_value_accessor_; + feature_value_accessor_.Configure(config); + VLOG(0) << "INIT feature_value_accessor_:" << feature_value_accessor_.GetAccessorInfo().dim + << " EMBX:" << feature_value_accessor_.common_feature_value.embedx_sgd_dim; + InitializeGPUServer(config); + } + #endif + +// void InitializeGPUServer(std::unordered_map config) { +// float nonclk_coeff = (config.find("sparse_nonclk_coeff") == config.end()) +// ? 1.0 +// : std::stof(config["sparse_nonclk_coeff"]); +// float clk_coeff = +// (config.find("sparse_click_coeff") == config.end()) ? 1.0 : std::stof(config["sparse_click_coeff"]); +// float min_bound = (config.find("min_bound") == config.end()) +// ? -10.0 +// : std::stof(config["min_bound"]); +// float max_bound = (config.find("max_bound") == config.end()) +// ? 10.0 +// : std::stof(config["max_bound"]); +// float learning_rate = (config.find("sparse_learning_rate") == config.end()) +// ? 0.05 +// : std::stof(config["sparse_learning_rate"]); +// float initial_g2sum = (config.find("sparse_initial_g2sum") == config.end()) +// ? 3.0 +// : std::stof(config["sparse_initial_g2sum"]); +// float initial_range = (config.find("sparse_initial_range") == config.end()) +// ? 1e-4 +// : std::stof(config["sparse_initial_range"]); +// float beta1_decay_rate = (config.find("embed_sparse_beta1_decay_rate") == config.end()) +// ? 0.9 +// : std::stof(config["embed_sparse_beta1_decay_rate"]); +// float beta2_decay_rate = (config.find("embed_sparse_beta2_decay_rate") == config.end()) +// ? 
0.999 +// : std::stof(config["embed_sparse_beta2_decay_rate"]); +// float ada_epsilon = (config.find("embed_sparse_ada_epsilon") == config.end()) +// ? 1e-8 +// : std::stof(config["embed_sparse_ada_epsilon"]); +// // mf config settings +// float mf_create_thresholds = +// (config.find("sparse_embedx_threshold") == config.end()) +// ? static_cast(1.0) +// : std::stof(config["sparse_embedx_threshold"]); +// float mf_learning_rate = (config.find("embedx_sparse_learning_rate") == config.end()) +// ? 0.05 +// : std::stof(config["embedx_sparse_learning_rate"]); +// float mf_initial_g2sum = (config.find("sparse_initial_g2sum") == config.end()) +// ? 3.0 +// : std::stof(config["sparse_initial_g2sum"]); +// float mf_initial_range = (config.find("embedx_sparse_initial_range") == config.end()) +// ? 1e-4 +// : std::stof(config["embedx_sparse_initial_range"]); +// float mf_min_bound = (config.find("mf_min_bound") == config.end()) +// ? -10.0 +// : std::stof(config["mf_min_bound"]); +// float mf_max_bound = (config.find("mf_max_bound") == config.end()) +// ? 10.0 +// : std::stof(config["mf_max_bound"]); +// float mf_beta1_decay_rate = (config.find("embedx_sparse_beta1_decay_rate") == config.end()) +// ? 0.9 +// : std::stof(config["embedx_sparse_beta1_decay_rate"]); +// float mf_beta2_decay_rate = (config.find("embedx_sparse_beta2_decay_rate") == config.end()) +// ? 0.999 +// : std::stof(config["embedx_sparse_beta2_decay_rate"]); +// float mf_ada_epsilon = (config.find("embedx_sparse_ada_epsilon") == config.end()) +// ? 1e-8 +// : std::stof(config["embedx_sparse_ada_epsilon"]); +// for (size_t i = 0; i < heter_devices_.size(); i++) { +// #ifdef PADDLE_WITH_CUDA +// PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(heter_devices_[i])); +// #elif defined(PADDLE_WITH_XPU_KP) +// PADDLE_ENFORCE_XPU_SUCCESS(xpu_set_device(heter_devices_[i])); +// #endif +// this->SetSparseSGD(nonclk_coeff, clk_coeff, min_bound, max_bound, +// learning_rate, initial_g2sum, initial_range, +// beta1_decay_rate, beta2_decay_rate, ada_epsilon); +// this->SetEmbedxSGD(mf_create_thresholds, mf_learning_rate, +// mf_initial_g2sum, mf_initial_range, mf_min_bound, +// mf_max_bound, mf_beta1_decay_rate, mf_beta2_decay_rate, +// mf_ada_epsilon); + +// // set optimizer type(naive,adagrad,std_adagrad,adam,share_adam) +// // optimizer_type_ = (config.find("optimizer_type") == config.end()) +// // ? 1 +// // : config["optimizer_type"]; +// optimizer_name_ = (config.find("embed_sparse_optimizer") == config.end()) +// ? "adagrad" +// : config["embed_sparse_optimizer"]; +// CommonFeatureValueAccessor feature_value_accessor_; +// feature_value_accessor_.Configure(config); +// embedx_dim_ = (config.find("sparse_embedx_dim") == config.end()) +// ? 8 +// : std::stoi(config["sparse_embedx_dim"]); +// if(optimizer_name_ == "adagrad") { +// embed_sgd_dim_ = 1; +// embedx_sgd_dim_ = 1; +// } else if (optimizer_name_ == "adam") { +// embed_sgd_dim_ = 4; +// embedx_sgd_dim_ = embedx_dim_ * 2 + 2; +// } else if (optimizer_name_ == "sharedadam") { +// embed_sgd_dim_ = 4; +// embedx_sgd_dim_ = 4; +// } else { +// embed_sgd_dim_ = 1; +// embedx_sgd_dim_ = 1; +// } +// } +// } void InitializeGPUServer(std::unordered_map config) { float nonclk_coeff = (config.find("nonclk_coeff") == config.end()) ? 1.0 @@ -287,10 +486,10 @@ class PSGPUWrapper { float clk_coeff = (config.find("clk_coeff") == config.end()) ? 
1.0 : config["clk_coeff"]; float min_bound = (config.find("min_bound") == config.end()) - ? -10000.0 + ? -10.0 : config["min_bound"]; float max_bound = (config.find("max_bound") == config.end()) - ? 10000.0 + ? 10.0 : config["max_bound"]; float learning_rate = (config.find("learning_rate") == config.end()) ? 1.0 @@ -301,7 +500,15 @@ class PSGPUWrapper { float initial_range = (config.find("initial_range") == config.end()) ? 1.0 : config["initial_range"]; - + float beta1_decay_rate = (config.find("beta1_decay_rate") == config.end()) + ? 0.9 + : config["beta1_decay_rate"]; + float beta2_decay_rate = (config.find("beta2_decay_rate") == config.end()) + ? 0.999 + : config["beta2_decay_rate"]; + float ada_epsilon = (config.find("ada_epsilon") == config.end()) + ? 1e-8 + : config["ada_epsilon"]; // mf config settings float mf_create_thresholds = (config.find("mf_create_thresholds") == config.end()) @@ -322,6 +529,15 @@ class PSGPUWrapper { float mf_max_bound = (config.find("mf_max_bound") == config.end()) ? 1.0 : config["mf_max_bound"]; + float mf_beta1_decay_rate = (config.find("mf_beta1_decay_rate") == config.end()) + ? 0.9 + : config["mf_beta1_decay_rate"]; + float mf_beta2_decay_rate = (config.find("mf_beta2_decay_rate") == config.end()) + ? 0.999 + : config["mf_beta2_decay_rate"]; + float mf_ada_epsilon = (config.find("mf_ada_epsilon") == config.end()) + ? 1e-8 + : config["mf_ada_epsilon"]; for (size_t i = 0; i < heter_devices_.size(); i++) { #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(heter_devices_[i])); @@ -334,14 +550,42 @@ class PSGPUWrapper { max_bound, learning_rate, initial_g2sum, - initial_range); + initial_range, + beta1_decay_rate, + beta2_decay_rate, + ada_epsilon); this->SetEmbedxSGD(mf_create_thresholds, mf_learning_rate, mf_initial_g2sum, mf_initial_range, mf_min_bound, - mf_max_bound); + mf_max_bound, + mf_beta1_decay_rate, + mf_beta2_decay_rate, + mf_ada_epsilon); + } + + // set optimizer type(naive,adagrad,std_adagrad,adam,share_adam) + optimizer_type_ = (config.find("optimizer_type") == config.end()) + ? 1 + : int(config["optimizer_type"]); + embedx_dim_ = (config.find("embedx_dim") == config.end()) + ? 
8 + : int(config["embedx_dim"]); + if (optimizer_type_ == 3) { //adam + embed_sgd_dim_ = 4; + embedx_sgd_dim_ = embedx_dim_ * 2 + 2; + } else if (optimizer_type_ == 4) { //sharedadam + embed_sgd_dim_ = 4; + embedx_sgd_dim_ = 4; + } else { + embed_sgd_dim_ = 1; + embedx_sgd_dim_ = 1; } + + VLOG(0) << "InitializeGPUServer embed_sgd_dim_:" << embed_sgd_dim_ << " embedx_sgd_dim_:" + << embedx_sgd_dim_ << " embedx_dim_:" << embedx_dim_ + << " optimizer_type_:" << optimizer_type_; } void SetDate(int year, int month, int day) { @@ -386,7 +630,7 @@ class PSGPUWrapper { if (slot_info_initialized_) { return; } - SlotRecordDataset* dataset = dynamic_cast(dataset_); + SlotRecordDataset* dataset = (SlotRecordDataset*)(dataset_); auto slots_vec = dataset->GetSlots(); slot_offset_vector_.clear(); for (auto& slot : slot_vector_) { @@ -427,10 +671,17 @@ class PSGPUWrapper { for (size_t i = 0; i < slot_index_vec_.size(); i++) { slot_index_vec_[i] = dim_index_map[slot_mf_dim_vector_[i]]; } - val_type_size_ = - TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1)); - grad_type_size_ = - TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float))); + //TODO(FENGDANLEI): max_mf + VLOG(0) << "InitSlotInfo embed_sgd_dim_:" << embed_sgd_dim_ << " embedx_sgd_dim_:" + << embedx_sgd_dim_ << " embedx_dim_:" << embedx_dim_ + << " optimizer_type_:" << optimizer_type_; + VLOG(0) << "InitSlotInfo:" << feature_value_accessor_.GetAccessorInfo().size; + val_type_size_ =TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); + // val_type_size_ = + // TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1)); + grad_type_size_ = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); + // grad_type_size_ = + // TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float))); slot_info_initialized_ = true; } #endif @@ -451,6 +702,13 @@ class PSGPUWrapper { const std::string& conf); #endif +#ifdef PADDLE_WITH_PSCORE + void SetTableAccessor(paddle::distributed::ValueAccessor* accessor) { + cpu_table_accessor_ = dynamic_cast(accessor); + } +#endif + + CommonFeatureValueAccessor feature_value_accessor_; private: static std::shared_ptr s_instance_; Dataset* dataset_; @@ -503,6 +761,13 @@ class PSGPUWrapper { int day_; bool slot_info_initialized_ = false; int use_afs_api_ = 0; + int optimizer_type_ = 1; + int embed_sgd_dim_ = 1; + int embedx_sgd_dim_ = 1; + int embedx_dim_ = 8; +#ifdef PADDLE_WITH_PSCORE + paddle::distributed::CtrDymfAccessor* cpu_table_accessor_; +#endif #ifdef PADDLE_WITH_CUDA std::vector mem_pools_; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps b/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps index ef6c70e624d4cf..369a20874d42e3 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps @@ -277,21 +277,25 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff, float min_bound, float max_bound, float learning_rate, float initial_g2sum, - float initial_range) { + float initial_range, float beta1_decay_rate, + float beta2_decay_rate, float ada_epsilon) { OptimizerConfig optimizer_config; optimizer_config.set_sparse_sgd(nonclk_coeff, clk_coeff, min_bound, max_bound, - learning_rate, initial_g2sum, initial_range); + learning_rate, initial_g2sum, initial_range, + beta1_decay_rate, beta2_decay_rate, ada_epsilon); HeterPs_->set_sparse_sgd(optimizer_config); } void PSGPUWrapper::SetEmbedxSGD(float 
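// NOTE(editor): worked example of the state sizes chosen above, assuming the
// default embedx_dim_ == 8:
//   adam (type 3):        embed_sgd_dim_  = 4           // m, v, beta1_pow, beta2_pow for embed_w
//                         embedx_sgd_dim_ = 2 * 8 + 2   // per-dim m[8], v[8] + shared beta pows
//   shared_adam (type 4): embed_sgd_dim_  = 4
//                         embedx_sgd_dim_ = 4           // one m/v pair shared over all dims
//   naive/adagrad:        embed_sgd_dim_  = embedx_sgd_dim_ = 1
// The shared_adam dims mirror SparseSharedAdamSGDRule::Dim() == 4 on the CPU
// side, so the CPU rule and the GPU value layout agree on state size.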
mf_create_thresholds, float mf_learning_rate, float mf_initial_g2sum, float mf_initial_range, float mf_min_bound, - float mf_max_bound) { + float mf_max_bound, float mf_beta1_decay_rate, + float mf_beta2_decay_rate, float mf_ada_epsilon) { OptimizerConfig optimizer_config; optimizer_config.set_embedx_sgd(mf_create_thresholds, mf_learning_rate, mf_initial_g2sum, mf_initial_range, - mf_min_bound, mf_max_bound); + mf_min_bound, mf_max_bound,mf_beta1_decay_rate, + mf_beta2_decay_rate, mf_ada_epsilon); HeterPs_->set_embedx_sgd(optimizer_config); } diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index c2c05f373c2d23..5bca9650d6f29f 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +// #include #include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/trainer.h" @@ -22,6 +23,9 @@ limitations under the License. */ #include "paddle/fluid/distributed/ps/service/communicator/communicator.h" #endif +// #if defined PADDLE_WITH_HETERPS +// #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" +// #endif namespace paddle { namespace framework { @@ -46,6 +50,11 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, places_.push_back(place); } #endif + +// #if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE +// InitializeGPUServer(trainer_desc); +// #endif + // get filelist from trainer_desc here const std::vector readers = dataset->GetReaders(); @@ -79,6 +88,83 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, SetDebug(trainer_desc.debug()); } +#if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE +void add_sparse_optimizer( + std::unordered_map& config, // NOLINT + const ::paddle::distributed::SparseCommonSGDRuleParameter& sgd_param, + const std::string& prefix = "") { + auto optimizer_name = sgd_param.name(); + if (optimizer_name == "SparseNaiveSGDRule") { + config[prefix + "optimizer_type"] = 0; + config[prefix + "learning_rate"] = sgd_param.naive().learning_rate(); + config[prefix + "initial_range"] = sgd_param.naive().initial_range(); + config[prefix + "min_bound"] = sgd_param.naive().weight_bounds()[0]; + config[prefix + "max_bound"] = sgd_param.naive().weight_bounds()[1]; + } else if (optimizer_name == "SparseAdaGradSGDRule") { + config[prefix + "optimizer_type"] = 1; + config[prefix + "learning_rate"] = sgd_param.adagrad().learning_rate(); + config[prefix + "initial_range"] = sgd_param.adagrad().initial_range(); + config[prefix + "initial_g2sum"] = sgd_param.adagrad().initial_g2sum(); + config[prefix + "min_bound"] = sgd_param.adagrad().weight_bounds()[0]; + config[prefix + "max_bound"] = sgd_param.adagrad().weight_bounds()[1]; + } else if (optimizer_name == "StdAdaGradSGDRule") { + config[prefix + "optimizer_type"] = 2; + config[prefix + "learning_rate"] = sgd_param.adagrad().learning_rate(); + config[prefix + "initial_range"] = sgd_param.adagrad().initial_range(); + config[prefix + "initial_g2sum"] = sgd_param.adagrad().initial_g2sum(); + config[prefix + "min_bound"] = sgd_param.adagrad().weight_bounds()[0]; + config[prefix + "max_bound"] = sgd_param.adagrad().weight_bounds()[1]; + } else if (optimizer_name == "SparseAdamSGDRule") { + config[prefix + "optimizer_type"] = 3; + config[prefix + "learning_rate"] = sgd_param.adam().learning_rate(); + config[prefix + "initial_range"] = 
sgd_param.adam().initial_range(); + config[prefix + "beta1_decay_rate"] = sgd_param.adam().beta1_decay_rate(); + config[prefix + "beta2_decay_rate"] = sgd_param.adam().beta2_decay_rate(); + config[prefix + "ada_epsilon"] = sgd_param.adam().ada_epsilon(); + config[prefix + "min_bound"] = sgd_param.adam().weight_bounds()[0]; + config[prefix + "max_bound"] = sgd_param.adam().weight_bounds()[1]; + } else if (optimizer_name == "SparseSharedAdamSGDRule") { + config[prefix + "optimizer_type"] = 4; + config[prefix + "learning_rate"] = sgd_param.adam().learning_rate(); + config[prefix + "initial_range"] = sgd_param.adam().initial_range(); + config[prefix + "beta1_decay_rate"] = sgd_param.adam().beta1_decay_rate(); + config[prefix + "beta2_decay_rate"] = sgd_param.adam().beta2_decay_rate(); + config[prefix + "ada_epsilon"] = sgd_param.adam().ada_epsilon(); + config[prefix + "min_bound"] = sgd_param.adam().weight_bounds()[0]; + config[prefix + "max_bound"] = sgd_param.adam().weight_bounds()[1]; + } +} + +// void MultiTrainer::InitializeGPUServer(const TrainerDesc& trainer_desc) { +// // optimizer config for hbmps +// auto fleet_desc_str = trainer_desc.fleet_desc(); +// VLOG(0) << "InitializeGPUServer fleet_desc_str" << fleet_desc_str; +// google::protobuf::TextFormat::ParseFromString(fleet_desc_str, &_ps_param); +// auto sparse_table = +// _ps_param.server_param().downpour_server_param().downpour_table_param(0); +// auto sparse_table_accessor = sparse_table.accessor(); +// auto sparse_table_accessor_parameter = +// sparse_table_accessor.ctr_accessor_param(); +// auto accessor_class = sparse_table_accessor.accessor_class(); + +// // NOTE(zhangminxu): gpups' sparse table optimizer config, +// // now only support single sparse table +// // auto sparse_table = param_.sparse_table(0); +// std::unordered_map config; +// config["nonclk_coeff"] = sparse_table_accessor_parameter.nonclk_coeff(); +// config["clk_coeff"] = sparse_table_accessor_parameter.click_coeff(); + +// if (accessor_class == "CtrDymfAccessor") { +// // optimizer config for embed_w and embedx +// add_sparse_optimizer(config, sparse_table_accessor.embed_sgd_param()); +// add_sparse_optimizer(config, sparse_table_accessor.embedx_sgd_param(), +// "mf_"); +// } +// auto ps_gpu_wrapper = paddle::framework::PSGPUWrapper::GetInstance(); +// ps_gpu_wrapper->InitializeGPUServer(config); +// } +#endif + std::string MultiTrainer::GetDumpPath(int tid) { if (user_define_dump_filename_ != "") { return string::format_string("%s/part-%s-%05d", @@ -304,5 +390,7 @@ void MultiTrainer::Finalize() { root_scope_->DropKids(); } + + } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 1a805ccd76e440..75e18865505c6a 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -122,6 +122,9 @@ class MultiTrainer : public TrainerBase { void MergeDenseParam(); #endif +// #if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE +// void InitializeGPUServer(const TrainerDesc& trainer_desc); +// #endif protected: int thread_num_; @@ -132,7 +135,13 @@ class MultiTrainer : public TrainerBase { std::vector trainable_param_; #ifdef PADDLE_WITH_HETERPS std::vector places_; + // _ps_param for gpups optimizer config #endif + +// #if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE +// ::paddle::distributed::PSParameter _ps_param; +// #endif + int mpi_rank_; int mpi_size_; int dump_file_num_; diff --git a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc 
b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc index e9c993d3ee1282..da833641227b1f 100644 --- a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc +++ b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc @@ -41,6 +41,8 @@ void BindPSGPUWrapper(py::module* m) { .def("set_slot_vector", &framework::PSGPUWrapper::SetSlotVector, py::call_guard()) + // .def("init_GPU_server", &framework::PSGPUWrapper::InitializeGPUServer, + // py::call_guard()) #ifdef PADDLE_WITH_CUDA .def("set_slot_dim_vector", &framework::PSGPUWrapper::SetSlotDimVector, diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index f4f2076cd12b79..1b0acd79924086 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -39,7 +39,7 @@ from paddle.fluid.dygraph import to_variable from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar - +from paddle.distributed.fleet.proto import the_one_ps_pb2 __all__ = [] _grad_scalar = None @@ -1618,7 +1618,7 @@ def _minimize_impl(self, context["valid_strategy"] = copy.deepcopy(valid_strategy) # print("valid_strategy:", context["valid_strategy"]) - # print("user_defined_strategy:", context["user_defined_strategy"]) + print("user_defined_strategy:", context["user_defined_strategy"]) applied_meta_list = self.strategy_compiler._get_applied_meta_list() applied_graph_list = self.strategy_compiler._get_applied_graph_list() @@ -1648,17 +1648,17 @@ def _minimize_impl(self, no_grad_set=no_grad_set) if meta_optimizer: - # print("before minimize program id:", id(loss.block.program)) + print("before minimize program id:", id(loss.block.program)) optimize_ops, params_grads = meta_optimizer.minimize( loss, startup_program, parameter_list, no_grad_set=no_grad_set) - # print("after minimize program id:", id(loss.block.program)) + print("after minimize program id:", id(loss.block.program)) default_program = paddle.static.default_main_program() - # print("default program id:", id(default_program)) + print("default program id:", id(default_program)) if id(default_program) != id(loss.block.program): paddle.fluid.framework.switch_main_program(loss.block.program) - # print("default program id after switch:", id(default_program)) + print("default program id after switch:", id(default_program)) else: optimize_ops, params_grads = self.user_defined_optimizer.minimize( @@ -1668,7 +1668,7 @@ def _minimize_impl(self, context["program_params_grads"] = params_grads if graph_optimizer: - # print("before graph minimize program id:", id(loss.block.program)) + print("before graph minimize program id:", id(loss.block.program)) optimize_ops, params_grads = graph_optimizer.minimize( loss, startup_program, parameter_list, no_grad_set=no_grad_set) # since we do not encourage users to use graph operations @@ -1680,6 +1680,17 @@ def _minimize_impl(self, else: apply_ir_passes(loss.block.program, startup_program, self) + # ps_param = the_one_ps_pb2.PSParameter() + + # all_table_proto = context["user_defined_strategy"].sparse_table_configs + # ps_param = all_table_proto.add() + + # opt_info = {} + # opt_info["fleet_desc"] = ps_param + # program = paddle.static.default_main_program() + # program._fleet_opt = opt_info + # print("ps_param:", ps_param) + if not self._role_maker._is_heter_parameter_server_mode: program = paddle.static.default_main_program() opt_info = {} if program._fleet_opt is None else program._fleet_opt @@ -1691,6 +1702,7 @@ def 
_minimize_impl(self, opt_info[k] = v program._fleet_opt = opt_info + print("_fleet_opt:", program._fleet_opt) if self._runtime_handle is None: self._runtime_handle = RuntimeFactory()._create_runtime(context) @@ -1765,7 +1777,7 @@ def _minimize_losses_impl(self, for k, v in self._user_defined_strategy.trainer_desc_configs.items( ): if v or k not in opt_info: - opt_info[k] = v + opt_info[k] = v program._fleet_opt = opt_info # print("fleet base opt info:", id(program), program._fleet_opt) diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index c6ba48e5e32b57..72757e6c229109 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -591,6 +591,7 @@ def _set(self, table_proto): print('new table_name: {}'.format(self.common.table_name)) all_table_proto = self.context[ "user_defined_strategy"].sparse_table_configs + print("all_table_proto:", all_table_proto) usr_table_proto = all_table_proto.add() for proto in all_table_proto: if proto.table_name == self.common.table_name: @@ -619,8 +620,12 @@ def _set(self, table_proto): warnings.warn( "The accessor of sparse table is not set, use default value.") + # table_proto.accessor = usr_table_proto.accessor table_proto.accessor.ParseFromString( usr_table_proto.accessor.SerializeToString()) + # table_proto.accessor.CopyFrom(usr_table_proto.accessor) + print("usr_table_proto.accessor.SerializeToString():", usr_table_proto.accessor.SerializeToString()) + print("====table_proto:", table_proto) self.accessor._set(table_proto.accessor, self.common.table_name, ctx.program_id(), self.context) @@ -936,6 +941,7 @@ def _init_worker(self, scopes=None): #with open("test_fl_ps_worker_desc", "w") as f: # f.write(worker_desc) if self.context['use_ps_gpu']: + #main_program._fleet_opt["fleet_desc"] = worker_desc() main_program = self.context['loss'].block.program if not main_program._fleet_opt: main_program._fleet_opt = {} @@ -965,7 +971,7 @@ def sync_strategy_envs(): proto_txt = worker_desc debug = bool(int(os.getenv("PSERVER_DEBUG", "0"))) - if debug: + if True: print("worker: \n{}".format(proto_txt)) print("communicator send_ctx:") for key in send_ctx: diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py index f0c094a84f758b..adfce2a19d1ed8 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -104,6 +104,7 @@ def _gen_worker_desc(self, trainer_desc): print("program of current device worker is not configured") exit(-1) opt_info = self._program._fleet_opt + print("opt_info:", opt_info) # when opt_info is None or empty dict, it should return if not opt_info: return diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py index a34fb2dea7dc50..1b4a63d8c30784 100644 --- a/python/paddle/fluid/trainer_factory.py +++ b/python/paddle/fluid/trainer_factory.py @@ -52,6 +52,12 @@ def _create_trainer(self, opt_info=None): else: trainer_class = opt_info.get("trainer", "MultiTrainer") device_worker_class = opt_info.get("device_worker", "Hogwild") + if trainer_class == '': + trainer_class = "MultiTrainer" + opt_info["trainer"] = "MultiTrainer" + if device_worker_class == '': + device_worker_class = "Hogwild" + opt_info["device_worker"] = "Hogwild" trainer = globals()[trainer_class]() device_worker = globals()[device_worker_class]() From 509768aed83865804a85b768c6adef42b5c0550b Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 15 Jun 2022 10:04:46 +0800 Subject: [PATCH 
02/31] remove useless code;test=develop --- .../distributed/ps/table/ctr_dymf_accessor.cc | 16 -- .../distributed/ps/table/ctr_dymf_accessor.h | 1 - .../distributed/ps/table/sparse_sgd_rule.cc | 80 ++++++ .../distributed/ps/table/sparse_sgd_rule.h | 21 ++ paddle/fluid/distributed/ps/table/table.cc | 1 + .../framework/distributed_strategy.proto | 2 +- paddle/fluid/framework/fleet/heter_context.h | 2 - .../framework/fleet/heter_ps/feature_value.h | 6 +- .../fleet/heter_ps/hashtable_kernel.cu | 26 -- .../framework/fleet/heter_ps/heter_comm_inl.h | 8 - .../fleet/heter_ps/heter_comm_kernel.cu | 2 - .../framework/fleet/heter_ps/heter_ps.cc | 2 - .../framework/fleet/heter_ps/heter_ps.cu | 2 +- .../fluid/framework/fleet/heter_ps/heter_ps.h | 1 - .../fluid/framework/fleet/heter_ps/mem_pool.h | 14 - .../framework/fleet/heter_ps/optimizer.cuh.h | 245 +++--------------- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 112 +------- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 120 +-------- paddle/fluid/framework/multi_trainer.cc | 87 ------- paddle/fluid/framework/trainer.h | 8 - paddle/fluid/pybind/ps_gpu_wrapper_py.cc | 2 - .../fleet/base/distributed_strategy.py | 15 ++ .../distributed/fleet/base/fleet_base.py | 27 +- python/paddle/distributed/ps/the_one_ps.py | 7 +- python/paddle/fluid/device_worker.py | 1 - python/paddle/fluid/trainer_factory.py | 6 - 26 files changed, 173 insertions(+), 641 deletions(-) diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc index 82d761c37c5929..5eaf49a0ebbf03 100644 --- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc @@ -52,16 +52,6 @@ int CtrDymfAccessor::Initialize() { return 0; } -// int CtrDymfAccessor::InitializeDim(int embed_sgd_dim, int embedx_dim, int embedx_sgd_dim) { -// common_feature_value.embed_sgd_dim = embed_sgd_dim; -// common_feature_value.embedx_dim = embedx_dim; -// common_feature_value.embedx_sgd_dim = embedx_sgd_dim; -// VLOG(0) << " INTO CtrDymfAccessor::InitializeDim(); embed_sgd_dim:" << embed_sgd_dim -// << " embedx_dim:" << embedx_dim<< " embedx_sgd_dim:" << embedx_sgd_dim; -// InitAccessorInfo(); -// return 0; -// } - void CtrDymfAccessor::InitAccessorInfo() { _accessor_info.dim = common_feature_value.Dim(); _accessor_info.size = common_feature_value.Size(); @@ -315,12 +305,6 @@ std::string CtrDymfAccessor::ParseToString(const float* v, int param) { auto score = ShowClickScore(show, click); if (score >= _config.embedx_threshold() && param > common_feature_value.EmbedxG2SumIndex()) { - // VLOG(1) << "common_feature_value.EmbedxG2SumIndex():" - // << common_feature_value.EmbedxG2SumIndex(); - // VLOG(1) << "common_feature_value.EmbedxWIndex():" - // << common_feature_value.EmbedxWIndex(); - // VLOG(1) << "common_feature_value.MfDim():" - // << common_feature_value.MfDim(const_cast(v)); for (auto i = common_feature_value.EmbedxG2SumIndex(); i < common_feature_value.EmbedxWIndex() + common_feature_value.MfDim(const_cast(v)); diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h index 9444922ac833e7..d5b9acd8e9b258 100644 --- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h @@ -151,7 +151,6 @@ class CtrDymfAccessor : public ValueAccessor { CtrDymfAccessor() {} virtual ~CtrDymfAccessor() {} virtual int Initialize(); - // virtual int InitializeDim(int embed_sgd_dim, int 
embedx_dim, int embedx_sgd_dim); // 初始化AccessorInfo virtual void InitAccessorInfo(); // 判断该value是否进行shrink diff --git a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc index 07562f566d3267..f23e7f71c603d0 100644 --- a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc +++ b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc @@ -252,5 +252,85 @@ void SparseAdamSGDRule::InitValueWork(float* value, *(sgd + Beta1PowIndex()) = _beta1_decay_rate; *(sgd + Beta2PowIndex()) = _beta2_decay_rate; } + +void SparseSharedAdamSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param, + size_t emb_dim) { + _embedding_dim = emb_dim; + auto adam_param = param.adam(); + learning_rate_ = adam_param.learning_rate(); + _initial_range = adam_param.initial_range(); + _beta1_decay_rate = adam_param.beta1_decay_rate(); + _beta2_decay_rate = adam_param.beta2_decay_rate(); + _ada_epsilon = adam_param.ada_epsilon(); + if (adam_param.weight_bounds_size() == 0) { + _min_bound = -std::numeric_limits::max(); + _max_bound = std::numeric_limits::max(); + } else { + CHECK(adam_param.weight_bounds_size() >= 2) + << "invalid repeated size for weight_bounds:" + << adam_param.weight_bounds_size(); + _min_bound = adam_param.weight_bounds(0); + _max_bound = adam_param.weight_bounds(1); + } +} + +void SparseSharedAdamSGDRule::UpdateValueWork(float* w, float* sgd, const float* grad, + float scale) { + float* gsum = sgd + GSumIndex(); + float* g2sum = sgd + G2SumIndex(); + float* beta1_pow = sgd + Beta1PowIndex(); + float* beta2_pow = sgd + Beta2PowIndex(); + const float* g = grad; + + float lr = learning_rate_; + float beta1_pow_ = *beta1_pow; + float beta2_pow_ = *beta2_pow; + float gsum_ = *gsum; + float g2sum_ = *g2sum; + + // lr not change in one update + lr *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_); + double sum_gsum = 0.0; + double sum_g2sum = 0.0; + for (int i = 0; i < _embedding_dim; i++) { + // Calculation + double new_gsum = _beta1_decay_rate * gsum_ + (1 - _beta1_decay_rate) * g[i]; + double new_g2sum = + _beta2_decay_rate * g2sum_ + (1 - _beta2_decay_rate) * g[i] * g[i]; + w[i] = w[i] - lr * (new_gsum / (sqrt(new_g2sum) + _ada_epsilon)); + BoundValue(w[i]); + sum_gsum += new_gsum; + sum_g2sum += new_g2sum; + } + // update beta_pow_decay + (*gsum) = sum_gsum / _embedding_dim; + (*g2sum) = sum_g2sum / _embedding_dim; + (*beta1_pow) *= _beta1_decay_rate; + (*beta2_pow) *= _beta2_decay_rate; +} + +void SparseSharedAdamSGDRule::InitValueWork(float* value, float* sgd, + bool zero_init) { + for (int i = 0; i < _embedding_dim; ++i) { + if (zero_init) { + value[i] = 0.0; + BoundValue(value[i]); + } else { + value[i] = + (local_uniform_real_distribution()(local_random_engine()) * + 2 - + 1) * + _initial_range; + BoundValue(value[i]); + } + } + // init rule gsum and g2sum + for (int i = GSumIndex(); i < Beta1PowIndex(); i++) { + sgd[i] = 0.0; + } + // init beta1_pow and beta2_pow + *(sgd + Beta1PowIndex()) = _beta1_decay_rate; + *(sgd + Beta2PowIndex()) = _beta2_decay_rate; +} } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.h b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.h index 215a15a7d31eb4..aea7fa2cd85f14 100644 --- a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.h +++ b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.h @@ -144,5 +144,26 @@ class SparseAdamSGDRule : public SparseValueSGDRule { float _beta2_decay_rate; float _ada_epsilon; }; + +class SparseSharedAdamSGDRule : public 
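// NOTE(editor): for reference, UpdateValueWork above is bias-corrected Adam
// with both moments shared across the embedding dimensions:
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g_i
//   v_t = beta2 * v_{t-1} + (1 - beta2) * g_i^2
//   lr' = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
//   w_i -= lr' * m_t / (sqrt(v_t) + ada_epsilon)
// After the loop the per-dim m/v are averaged back into the single shared
// pair (sum_gsum / _embedding_dim), which is what distinguishes this rule
// from SparseAdamSGDRule and cuts per-feature optimizer state from
// 2 * emb_dim + 2 floats down to 4.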
SparseValueSGDRule { + public: + virtual void LoadConfig(const SparseCommonSGDRuleParameter& param, + size_t emb_dim); + virtual void UpdateValueWork(float* w, float* sgd, const float* push_value, + float scale); + virtual void InitValueWork(float* value, float* sgd, bool zero_init); + virtual size_t Dim() { return 4; } + size_t GSumIndex() { return 0; } + size_t G2SumIndex() { return GSumIndex() + 1; } + size_t Beta1PowIndex() { return G2SumIndex() + 1; } + size_t Beta2PowIndex() { return Beta1PowIndex() + 1; } + + protected: + float learning_rate_; + float _beta1_decay_rate; + float _beta2_decay_rate; + float _ada_epsilon; +}; + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/ps/table/table.cc b/paddle/fluid/distributed/ps/table/table.cc index cfa286f1c3f7f5..3e6d5a99412065 100644 --- a/paddle/fluid/distributed/ps/table/table.cc +++ b/paddle/fluid/distributed/ps/table/table.cc @@ -49,6 +49,7 @@ REGISTER_PSCORE_CLASS(SparseValueSGDRule, StdAdaGradSGDRule); REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule); REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseNaiveSGDRule); REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdaGradSGDRule); +REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseSharedAdamSGDRule); int32_t TableManager::Initialize() { static bool initialized = false; diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 7504e6f93a1e65..45758389c54135 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -228,7 +228,7 @@ message repeated float weight_bounds = 4; } -message SparseAdamSGDParameter { // SparseAdamSGDRule +message SparseAdamSGDParameter { // SparseAdamSGDRule | SparseSharedAdamSGDRule optional double learning_rate = 1 [ default = 0.001 ]; optional double initial_range = 2 [ default = 0.0001 ]; optional double beta1_decay_rate = 3 [ default = 0.9 ]; diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h index 3407608d90cdbc..ef2e73d6dd5b56 100644 --- a/paddle/fluid/framework/fleet/heter_context.h +++ b/paddle/fluid/framework/fleet/heter_context.h @@ -81,7 +81,6 @@ class HeterContext { std::vector> device_values_; std::vector> device_keys_; std::vector>> device_dim_keys_; - // std::vector>> device_dim_values_; std::vector mutex_; std::vector> dim_mutex_; int multi_mf_dim_ = 0; @@ -114,7 +113,6 @@ class HeterContext { value_dim_ptr_[i].resize(dim_num); } device_values_.resize(device_num); - // device_dim_values_.resize(device_num); device_keys_.resize(device_num); device_dim_keys_.resize(device_num); diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 43e70d8edabf22..7e899cb6f377de 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -168,7 +168,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { // if (name.compare("adam") == 0) { // common_feature_value.embed_sgd_dim = 4; // common_feature_value.embedx_sgd_dim = sparse_embedx_dim * 2 + 2; - // } else if (name.compare("sharedadam") == 0) { + // } else if (name.compare("shared_adam") == 0) { // common_feature_value.embed_sgd_dim = 4; // common_feature_value.embedx_sgd_dim = 4; // } else { @@ -193,7 +193,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { if (optimizer_type == 3) { //adam 
common_feature_value.embed_sgd_dim = 4; common_feature_value.embedx_sgd_dim = sparse_embedx_dim * 2 + 2; - } else if (optimizer_type == 4) { //sharedadam + } else if (optimizer_type == 4) { //shared_adam common_feature_value.embed_sgd_dim = 4; common_feature_value.embedx_sgd_dim = 4; } else { @@ -285,7 +285,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { << " mf: "; if (param_size > common_feature_value.EmbedxG2SumIndex()) { for (auto i = common_feature_value.EmbedxG2SumIndex(); - i < common_feature_value.Dim(); ++i) { + i < int(common_feature_value.Size() / sizeof(float)); ++i) { os << " " << v[i]; } } diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 467564f97b0f8c..f072b20e7e27aa 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -118,11 +118,6 @@ __global__ void dy_mf_search_kernel(Table* table, input[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = input[feature_value_accessor.common_feature_value.EmbedWIndex()]; - printf("dy_mf_search_kernel table slot: %f; show: %f; click: %f; lr: %f", - cur[feature_value_accessor.common_feature_value.SlotIndex()], - cur[feature_value_accessor.common_feature_value.ShowIndex()], - cur[feature_value_accessor.common_feature_value.ClickIndex()], - cur[feature_value_accessor.common_feature_value.EmbedWIndex()]); for (int i = 0; i < feature_value_accessor.common_feature_value.EmbedDim(); i++) { cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i] = input[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; @@ -330,27 +325,6 @@ void HashTable::dump_to_cpu(int devid, StreamType stream) { } } #endif -// #ifdef PADDLE_WITH_PSCORE -// auto* downpour_value = -// (paddle::distributed::FixedFeatureValue*)(gpu_val.cpu_ptr); -// int downpour_value_size = downpour_value->size(); -// if (gpu_val.mf_size > 0 && downpour_value_size == 7) { -// downpour_value->resize(gpu_val.mf_size + downpour_value_size); -// } -// float* cpu_val = downpour_value->data(); -// // cpu_val[0] = 0; -// cpu_val[2] = gpu_val.delta_score; -// cpu_val[3] = gpu_val.show; -// cpu_val[4] = gpu_val.clk; -// cpu_val[5] = gpu_val.lr; -// cpu_val[6] = gpu_val.lr_g2sum; -// cpu_val[0] = gpu_val.slot; -// if (gpu_val.mf_size > 0) { -// for (int x = 0; x < gpu_val.mf_size; x++) { -// cpu_val[x + 7] = gpu_val.mf[x]; -// } -// } -// #endif } }; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 0016ff318fd5f1..70edcee5950553 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -990,14 +990,6 @@ void HeterComm::pull_sparse(int num, sync_stream(stream); - // char* tmp_mem2 = (char*)malloc(len * val_type_size); - // cudaMemcpy(tmp_mem2, reinterpret_cast(d_shard_vals_ptr), len * val_type_size, - // cudaMemcpyDeviceToHost); - // for (int i =0 ; i < 20; i++){ - // float* val = (float*)(void*)&tmp_mem2[(i)*val_type_size]; - // VLOG(0) << "pullsparse walk_to_src fill_dvals cpu: "<< i << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.GetAccessorInfo().dim); - // } - for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { continue; diff --git 
a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 98295bb36560a9..58001d5c022bc7 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -202,8 +202,6 @@ __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, uint64_t new_offset = uint64_t(idx[i]) * val_size; float* cur = (float*)((char*)d_vals + new_offset); float* shard_val = (float*)((char*)d_shard_vals + uint64_t(i) * val_size); - // *(float*)((char*)d_vals + new_offset) = - // (float*)((char*)d_shard_vals + i * val_size); cur[feature_value_accessor.common_feature_value.SlotIndex()] = shard_val[feature_value_accessor.common_feature_value.SlotIndex()]; cur[feature_value_accessor.common_feature_value.ShowIndex()] = diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc index 618a88fd70e56f..29b6c525971b12 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc @@ -47,7 +47,6 @@ void HeterPs::pull_sparse(int num, comm_->pull_sparse(num, d_keys, d_vals, len); } - int HeterPs::get_index_by_devid(int devid) { return comm_->get_index_by_devid(devid); } @@ -72,7 +71,6 @@ void HeterPs::push_sparse(int num, // comm_->push_sparse_multi_node(num, d_keys, d_grads, len, opt_); } - } // end namespace framework } // end namespace paddle #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu index e5af93c0ef8bde..2d030c75a82ba5 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu @@ -83,7 +83,7 @@ void HeterPs::push_sparse(int num, auto optimizer = SparseAdamOptimizer(feature_value_accessor_); VLOG(0) << "INTO push_sparse SparseAdamOptimizer EmbedDim():" << optimizer.EmbedDim(); comm_->push_sparse(num, d_keys, d_grads, len, optimizer); - } else if (optimizer_type_ == 4) { //sharedadam + } else if (optimizer_type_ == 4) { //shared_adam auto optimizer = SparseAdamSharedOptimizer(feature_value_accessor_); comm_->push_sparse(num, d_keys, d_grads, len, optimizer); } else { diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h index 63f35dee92bc83..109facb5828cc2 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h @@ -69,7 +69,6 @@ class HeterPs : public HeterPsBase { private: std::shared_ptr> comm_; #if defined(PADDLE_WITH_CUDA) - // Optimizer opt_; CommonFeatureValueAccessor feature_value_accessor_; int optimizer_type_; #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/mem_pool.h b/paddle/fluid/framework/fleet/heter_ps/mem_pool.h index 0a95576cd2987b..05e252b2afe44e 100644 --- a/paddle/fluid/framework/fleet/heter_ps/mem_pool.h +++ b/paddle/fluid/framework/fleet/heter_ps/mem_pool.h @@ -82,20 +82,6 @@ class HBMMemoryPool : public managed { cudaMemset(mem_, 0, block_size_ * capacity); } - // friend std::ostream& operator<<(std::ostream& out, HBMMemoryPool& p) { - // for (size_t k = 0; k < 5; k++) { - // auto x = (float*)(p.mem() + k * p.capacity()); - // out << "show: " << x->show << " clk: " << x->clk << " slot: " << x->slot - // << " lr: " << x->lr << " mf_dim: " << x->mf_size - // << " mf_size: " << x->mf_size << " mf:"; - // for (int i = 0; i < x->mf_size + 1; ++i) { - // out << " " << x->mf[i]; - // } - // 
out << "\n"; - // } - // return out; - // } - char* mem() { return mem_; } size_t capacity() { return capacity_; } diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h index 0e1206de23a7f3..db27da8712f2dc 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h @@ -34,12 +34,6 @@ class Optimizer { feature_value_accessor_ = feature_value_accessor; } __host__ ~Optimizer() {} - // __host__ virtual void initialize(const OptimizerConfig& optimizer_config, - // size_t emb_dim) { - - // _lr_embedding_dim = 1; - // _embedding_dim = emb_dim; - // } __device__ void update_value(const OptimizerConfig& optimizer_config, float& val, // NOLINT @@ -50,15 +44,8 @@ class Optimizer { float* ptr, const float* grad) { } - // __host__ float& MinBound() { return _min_bound; } - // float& MaxBound() { return _max_bound; } - CommonFeatureValueAccessor feature_value_accessor_; -// protected: - // float _min_bound; - // float _max_bound; - // float _initial_range; size_t _embedding_dim; size_t _lr_embedding_dim; @@ -72,40 +59,6 @@ class SparseAdagradOptimizer : public Optimizer { _lr_embedding_dim = 1; _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); } - - // __host__ virtual void initialize(const OptimizerConfig& optimizer_config, - // size_t emb_dim) { - // _lr_embedding_dim = 1; - // _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); - // // this->_learning_rate = optimizer_config.learning_rate; - // // this->_initial_g2sum = optimizer_config.initial_g2sum; - // // this->_initial_range = optimizer_config.initial_range; - // // this->_min_bound = optimizer_config.min_bound; - // // this->_max_bound = optimizer_config.max_bound; - // } - - - // __device__ void update_lr(const OptimizerConfig& optimizer_config, - // float& w, // NOLINT - // float* sgd, float g, // NOLINT - // float scale) { - // float& g2sum = sgd[G2SumIndex()]; - // double add_g2sum = 0; - // double ratio = optimizer_config.learning_rate * - // sqrt(optimizer_config.initial_g2sum / - // (optimizer_config.initial_g2sum + g2sum)); - // double scaled_grad = g / scale; - - // w += scaled_grad * ratio; - - // if (w < optimizer_config.min_bound) w = optimizer_config.min_bound; - // if (w > optimizer_config.max_bound) w = optimizer_config.max_bound; - - // add_g2sum += scaled_grad * scaled_grad; - - // g2sum += add_g2sum; - // } - __device__ void update_value_work(const OptimizerConfig& optimizer_config, int n, float* w, @@ -131,30 +84,6 @@ class SparseAdagradOptimizer : public Optimizer { g2sum += add_g2sum / n; } - // __device__ void update_value_work(const OptimizerConfig& optimizer_config, - // float* w, - // float* sgd, // NOLINT - // const float* g, float scale) { - // float& g2sum = sgd[G2SumIndex()]; - // double add_g2sum = 0; - // double ratio = _learning_rate * - // sqrt(_initial_g2sum / - // (_initial_g2sum + g2sum)); - // for (int i = 0; i < _embedding_dim; ++i) { - // double scaled_grad = g[i] / scale; - - // w[i] += scaled_grad * ratio; - - // if (w[i] < this->_min_bound) - // w[i] = this->_min_bound; - // if (w[i] > this->_max_bound) - // w[i] = this->_max_bound; - // add_g2sum += scaled_grad * scaled_grad; - // } - - // g2sum += add_g2sum / _embedding_dim; - // } - __device__ void update_value(const OptimizerConfig& optimizer_config, float& val, // NOLINT const float& grad) { @@ -188,8 +117,6 @@ class SparseAdagradOptimizer : public Optimizer { 
optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = mf_dim; - - // ptr->mf_size = MF_DIM + 1; int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; curand_init(clock64(), tid_x, 0, &state); @@ -213,10 +140,6 @@ class SparseAdagradOptimizer : public Optimizer { __host__ __device__ size_t G2SumIndex() { return 0; } __host__ __device__ size_t EmbedxG2SumIndex() { return 0; } - -// private: - // float _learning_rate; - // float _initial_g2sum; }; class SparseAdamOptimizer : public Optimizer { @@ -227,53 +150,7 @@ class SparseAdamOptimizer : public Optimizer { _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); } - // __host__ virtual void initialize(const OptimizerConfig& optimizer_config, - // size_t emb_dim) { - // _lr_embedding_dim = 1; - // _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); - // // this->_learning_rate = optimizer_config.learning_rate; - // // this->_initial_range = optimizer_config.initial_range; - // // this->_beta1_decay_rate = optimizer_config.beta1_decay_rate; - // // this->_beta2_decay_rate = optimizer_config.beta2_decay_rate; - // // this->_ada_epsilon = optimizer_config.ada_epsilon; // float epsilon = 1e-08; - // // this->_min_bound = optimizer_config.min_bound; - // // this->_max_bound = optimizer_config.max_bound; - // } - - // __device__ void update_lr(const OptimizerConfig& optimizer_config, - // float& w, // NOLINT - // float* sgd, float g, // NOLINT - // float scale) { - - // float* moment1 = sgd + GSumIndex(); - // float* moment2 = sgd + G2SumIndex(); - // float* beta1_pow = sgd + Beta1PowIndex(); - // float* beta2_pow = sgd + Beta2PowIndex(); - - // float beta1_pow_ = *beta1_pow; - // float beta2_pow_ = *beta2_pow; - // float moment1_ = *moment1; - // float moment2_ = *moment2; - - // float epsilon = 1e-08; - // double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) / (1.0 - beta1_pow_); - // double scaled_grad = g / scale; - // double new_moment1 = optimizer_config.beta1_decay_rate * moment1_ + (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad; - // double new_moment2 = optimizer_config.beta2_decay_rate * moment2_ + (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad; - // w += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon)); - - // if (w < optimizer_config.min_bound) w = optimizer_config.min_bound; - // if (w > optimizer_config.max_bound) w = optimizer_config.max_bound; - - // (*moment1) = new_moment1; - // (*moment2) = new_moment2; - - // (*beta1_pow) *= optimizer_config.beta1_decay_rate; - // (*beta2_pow) *= optimizer_config.beta2_decay_rate; - - // } - - __device__ void update_value_work(const OptimizerConfig& optimizer_config, int n, + __device__ void update_lr(const OptimizerConfig& optimizer_config, int n, float* w, float* sgd, const float* g, float scale) { @@ -307,39 +184,39 @@ class SparseAdamOptimizer : public Optimizer { (*beta2_pow) *= optimizer_config.beta2_decay_rate; } - // __device__ void update_value_work(const OptimizerConfig& optimizer_config, - // float* w, - // float* sgd, // NOLINT - // const float* g, float scale) { - // float* moment1 = sgd + GSumIndex(); - // float* moment2 = sgd + G2SumIndex(); - // float* beta1_pow = sgd + Beta1PowIndex(); - // float* beta2_pow = sgd + Beta2PowIndex(); - - // float beta1_pow_ = *beta1_pow; - // float beta2_pow_ = *beta2_pow; + __device__ void update_mf(const OptimizerConfig& 
optimizer_config, int n,
+ float* w,
+ float* sgd,
+ const float* g, float scale) {
+ float* moment1 = sgd + EmbedxGSumIndex();
+ float* moment2 = sgd + EmbedxG2SumIndex();
+ float* beta1_pow = sgd + EmbedxBeta1PowIndex();
+ float* beta2_pow = sgd + EmbedxBeta2PowIndex();
- // double ratio = _learning_rate * sqrt(1.0 - beta2_pow_) / (1.0 - beta1_pow_);
+ float beta1_pow_ = *beta1_pow;
+ float beta2_pow_ = *beta2_pow;
- // for (int i = 0; i < this->_embedding_dim; ++i) {
- // double scaled_grad = g[i] / scale;
+ float epsilon = 1e-08;
+ double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) / (1.0 - beta1_pow_);
+ for (int i = 0; i < n; ++i) {
+ double scaled_grad = g[i] / scale;
- // double new_moment1 = _beta1_decay_rate * moment1[i] + (1.0 - _beta1_decay_rate) * scaled_grad;
- // double new_moment2 = _beta2_decay_rate * moment2[i] + (1.0 - _beta2_decay_rate) * scaled_grad * scaled_grad;
- // w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + _ada_epsilon));
+ double new_moment1 = optimizer_config.beta1_decay_rate * moment1[i] + (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad;
+ double new_moment2 = optimizer_config.beta2_decay_rate * moment2[i] + (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad;
+ w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon));
- // if (w[i] < this->_min_bound)
- // w[i] = this->_min_bound;
- // if (w[i] > this->_max_bound)
- // w[i] = this->_max_bound;
+ if (w[i] < optimizer_config.mf_min_bound)
+ w[i] = optimizer_config.mf_min_bound;
+ if (w[i] > optimizer_config.mf_max_bound)
+ w[i] = optimizer_config.mf_max_bound;
- // moment1[i] = new_moment1;
- // moment2[i] = new_moment2;
- // }
- // (*beta1_pow) *= _beta1_decay_rate;
- // (*beta2_pow) *= _beta2_decay_rate;
- // }
+ moment1[i] = new_moment1;
+ moment2[i] = new_moment2;
+ }
+ (*beta1_pow) *= optimizer_config.beta1_decay_rate;
+ (*beta2_pow) *= optimizer_config.beta2_decay_rate;
+ }
 __device__ void update_value(const OptimizerConfig& optimizer_config,
 float& val, // NOLINT
@@ -360,12 +237,13 @@ class SparseAdamOptimizer : public Optimizer {
 optimizer_config.nonclk_coeff * (g_show - g_click) +
 optimizer_config.clk_coeff * g_click;
- update_value_work(optimizer_config, 1,
+ update_lr(optimizer_config, 1,
 ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(),
 ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(),
 grad + feature_value_accessor_.common_push_value.EmbedGIndex(),
 g_show);
 int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]);
+ printf("mf_dim: %d, lr_gsum: %f, ", mf_dim, ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex()]);
 if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) {
 if (optimizer_config.mf_create_thresholds <=
 optimizer_config.nonclk_coeff *
 ptr[feature_value_accessor_.common_feature_value.ShowIndex()] +
 optimizer_config.clk_coeff *
 ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) {
 ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = mf_dim;
-
- // ptr->mf_size = MF_DIM + 1;
 int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
 curandState state;
 curand_init(clock64(), tid_x, 0, &state);
 for (int i = 0; i < mf_dim; ++i) {
 ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] =
 (curand_uniform(&state)) * optimizer_config.mf_initial_range;
 }
 ptr[feature_value_accessor_.common_feature_value.EmbedxBeta1PowIndex()] =
 optimizer_config.beta1_decay_rate;
 ptr[feature_value_accessor_.common_feature_value.EmbedxBeta2PowIndex()] =
 optimizer_config.beta2_decay_rate;
 }
 } else {
- update_value_work(optimizer_config, mf_dim,
+ update_mf(optimizer_config, mf_dim,
 ptr + feature_value_accessor_.common_feature_value.EmbedxWIndex(),
 ptr + feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(),
grad + feature_value_accessor_.common_push_value.EmbedxGIndex(), g_show);
 }
+ printf("EmbedxGIndex: %d, mf_gsum: %f, ", (int)feature_value_accessor_.common_push_value.EmbedxGIndex(),
+ ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex()]);
 }
 
 __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim(); }
@@ -413,12 +291,6 @@ class SparseAdamOptimizer : public Optimizer {
 __host__ __device__ size_t EmbedxBeta1PowIndex() { return EmbedxG2SumIndex() + _embedding_dim; }
 __host__ __device__ size_t EmbedxBeta2PowIndex() { return EmbedxBeta1PowIndex() + 1; }
-
-// protected:
- // float _learning_rate;
- // float _beta1_decay_rate;
- // float _beta2_decay_rate;
- // float _ada_epsilon;
 };
 
@@ -430,49 +302,6 @@ class SparseAdamSharedOptimizer : public Optimizer {
 _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim();
 }
 
- // virtual void initialize() {
- // _lr_embedding_dim = 1;
- // _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim();
- // // this->_learning_rate = optimizer_config.learning_rate;
- // // this->_initial_range = optimizer_config.initial_range;
- // // this->_beta1_decay_rate = optimizer_config.beta1_decay_rate;
- // // this->_beta2_decay_rate = optimizer_config.beta2_decay_rate;
- // // this->_ada_epsilon = optimizer_config.ada_epsilon; // float epsilon = 1e-08;
- // // this->_min_bound = optimizer_config.min_bound;
- // // this->_max_bound = optimizer_config.max_bound;
- // }
-
- // __device__ void update_lr(const OptimizerConfig& optimizer_config,
- // float& w, // NOLINT
- // float* sgd, float g, // NOLINT
- // float scale) {
- // float* moment1 = sgd + GSumIndex();
- // float* moment2 = sgd + G2SumIndex();
- // float* beta1_pow = sgd + Beta1PowIndex();
- // float* beta2_pow = sgd + Beta2PowIndex();
-
- // float beta1_pow_ = *beta1_pow;
- // float beta2_pow_ = *beta2_pow;
- // float moment1_ = *moment1;
- // float moment2_ = *moment2;
-
- // float epsilon = 1e-08;
- // double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) / (1.0 - beta1_pow_);
- // double scaled_grad = g / scale;
- // double new_moment1 = optimizer_config.beta1_decay_rate * moment1_ + (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad;
- // double new_moment2 = optimizer_config.beta2_decay_rate * moment2_ + (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad;
- // w += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon));
-
- // if (w < optimizer_config.min_bound) w = optimizer_config.min_bound;
- // if (w > optimizer_config.max_bound) w = optimizer_config.max_bound;
-
- // (*moment1) = new_moment1;
- // (*moment2) = new_moment2;
-
- // (*beta1_pow) *= optimizer_config.beta1_decay_rate;
- // (*beta2_pow) *= optimizer_config.beta2_decay_rate;
- // }
-
 __device__ void update_value_work(const OptimizerConfig& optimizer_config, int n,
 float* w,
 float* sgd,
@@ -548,8 +377,6 @@ class SparseAdamSharedOptimizer : public Optimizer {
 optimizer_config.clk_coeff *
 ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) {
 ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = mf_dim;
-
- // ptr->mf_size = MF_DIM + 1;
 int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
 curandState state;
 curand_init(clock64(), tid_x, 0, &state);
@@ -587,12 +414,6 @@ class SparseAdamSharedOptimizer : public Optimizer {
 __host__ __device__ size_t EmbedxBeta1PowIndex() { return EmbedxG2SumIndex() + 1; }
 __host__ __device__ size_t EmbedxBeta2PowIndex() { return EmbedxBeta1PowIndex() + 1; }
-
-// protected:
-// float _learning_rate;
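// NOTE: for reference, a hedged host-side sketch (illustrative names, not
// Paddle API) of how the two Adam variants above size their per-feature
// optimizer state. Both keep gsum/g2sum moments plus beta1_pow/beta2_pow for
// the 1-d embed_w; for the embedx weights, SparseAdamOptimizer stores
// per-dimension moments while SparseAdamSharedOptimizer shares one scalar
// moment pair.
#include <cstddef>
struct AdamStateDims {
  std::size_t embed_sgd_dim;   // optimizer state floats for the 1-d embed_w
  std::size_t embedx_sgd_dim;  // optimizer state floats for the embedx part
};
inline AdamStateDims adam_state_dims(std::size_t embedx_dim, bool shared) {
  AdamStateDims d;
  d.embed_sgd_dim = 1 * 2 + 2;                         // gsum + g2sum + two beta pows = 4
  d.embedx_sgd_dim = shared ? 4 : embedx_dim * 2 + 2;  // shared: 4; per-dim: 2*D + 2
  return d;
}
// e.g. embedx_dim = 8 gives {4, 18} for adam and {4, 4} for shared_adam, which
// matches embed_sgd_dim_/embedx_sgd_dim_ set in InitializeGPUServer below.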
-// float _beta1_decay_rate; -// float _beta2_decay_rate; -// float _ada_epsilon; }; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index fa4f571ff02bef..5c4c29961f775b 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -272,8 +272,6 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { auto& local_dim_keys = gpu_task->feature_dim_keys_; auto& local_dim_ptr = gpu_task->value_dim_ptr_; - auto& device_keys = gpu_task->device_keys_; - auto& device_vals = gpu_task->device_values_; auto& device_dim_keys = gpu_task->device_dim_keys_; auto& device_dim_ptr = gpu_task->device_dim_ptr_; auto& device_dim_mutex = gpu_task->dim_mutex_; @@ -619,17 +617,6 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { for (std::thread& t : threads) { t.join(); } - } else { - for (int i = 0; i < thread_keys_shard_num_; i++) { - for (int j = 0; j < device_num; j++) { - task_futures.emplace_back( - hbm_thread_pool_[i]->enqueue(prepare_dev_value_func, j, i)); - } - } - for (auto& f : task_futures) { - f.wait(); - } - task_futures.clear(); } timeline.Pause(); VLOG(0) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec() @@ -678,8 +665,6 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { << " feature_value_DIM:" << feature_value_accessor_.GetAccessorInfo().dim; size_t feature_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); - // size_t feature_value_size = - // TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1 + 1) * sizeof(float))); auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; size_t len = device_dim_keys.size(); @@ -786,9 +771,6 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { #ifdef PADDLE_WITH_PSCORE VLOG(0) << "cpu build "<< k << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) << " |: "<< cpu_table_accessor_->ParseToString(ptr_val, dim); - // paddle::distributed::CtrDymfAccessor accessor; - // accessor.InitializeDim(embed_sgd_dim_, mf_dim, embedx_sgd_dim_); - // VLOG(0) << "cpu_table_accessor_ DIM:" << cpu_table_accessor_->GetAccessorInfo().dim; val[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] = ptr_val[cpu_table_accessor_->common_feature_value.DeltaScoreIndex()]; val[feature_value_accessor_.common_feature_value.ShowIndex()] = @@ -799,13 +781,6 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { ptr_val[cpu_table_accessor_->common_feature_value.SlotIndex()]; val[feature_value_accessor_.common_feature_value.EmbedWIndex()] = ptr_val[cpu_table_accessor_->common_feature_value.EmbedWIndex()]; - // val->lr_sgd_dim = 1; - // val->mf_sgd_dim = 1; - // // sgd: embed_sgd_dim=1; adam: embed_sgd_dim=1*2+2 - // for (int i = 0; i < val->lr_sgd_dim; i++) { - // val->lr_g2sum[i] = ptr_val[cpu_table_accessor_->common_feature_value.EmbedG2SumIndex() + i]; - // } - // VLOG(0)<< "EmbedDim:" << feature_value_accessor_.common_feature_value.EmbedDim(); for (int i = 0; i < feature_value_accessor_.common_feature_value.EmbedDim(); i++) { val[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + i] = ptr_val[cpu_table_accessor_->common_feature_value.EmbedG2SumIndex() + i]; @@ -816,12 +791,8 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { // (uint64_t)(device_dim_ptrs[k]); *(reinterpret_cast(val + feature_value_accessor_.common_feature_value.CpuPtrIndex())) = (uint64_t)(device_dim_ptrs[k]); // (uint64_t*)(val + 
feature_value_accessor_.common_feature_value.CpuPtrIndex()) = (uint64_t)(device_dim_ptrs[k]); - // ptr_val[cpu_table_accessor_->common_feature_value.MfDimIndex()] = float(mf_dim); - // val->mf_dim = mf_dim; + ptr_val[cpu_table_accessor_->common_feature_value.MfDimIndex()] = float(mf_dim); val[feature_value_accessor_.common_feature_value.MfDimIndex()] = mf_dim; - // VLOG(0) << " dim:" << dim - // << " DIM:" << feature_value_accessor_.GetAccessorInfo().dim - // << " MF_DIM:" << feature_value_accessor_.GetAccessorInfo().mf_size / sizeof(float); if (dim > cpu_table_accessor_->GetAccessorInfo().dim - cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float)) { val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = @@ -1062,10 +1033,8 @@ void PSGPUWrapper::EndPass() { } #endif #ifdef PADDLE_WITH_PSCORE - // paddle::distributed::CtrDymfAccessor accessor; - // accessor.InitializeDim(embed_sgd_dim_, mf_dim, embedx_sgd_dim_); auto* downpour_value = - (paddle::distributed::FixedFeatureValue*)(reinterpret_cast(gpu_val+ feature_value_accessor_.common_feature_value.CpuPtrIndex())); + (paddle::distributed::FixedFeatureValue*)(*(reinterpret_cast(gpu_val+ feature_value_accessor_.common_feature_value.CpuPtrIndex()))); size_t downpour_value_size = downpour_value->size(); VLOG(0) << "downpour_value_size:" <GetAccessorInfo().dim @@ -1104,11 +1073,11 @@ void PSGPUWrapper::EndPass() { } } VLOG(0) << "dump to cpu "<< index << " : "<< feature_value_accessor_.ParseToString(gpu_val, feature_value_accessor_.GetAccessorInfo().dim) - << " =====CPU:" << cpu_table_accessor_->ParseToString(cpu_val, cpu_table_accessor_->GetAccessorInfo().update_dim); + << " =====CPU "<< cpu_table_accessor_->GetAccessorInfo().dim + << "-" << downpour_value->size() + << " :" << cpu_table_accessor_->ParseToString(cpu_val, downpour_value->size()); } - - #endif free(test_build_values); }; @@ -1148,73 +1117,8 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, const std::vector& values, const std::vector& slot_lengths, const int hidden_size) { - } -// void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, -// const int table_id, -// const std::vector& keys, -// const std::vector& values, -// const std::vector& slot_lengths, -// const int hidden_size) { -// platform::Timer all_timer; -// platform::Timer pull_gpups_timer; -// all_timer.Start(); -// int64_t total_length = -// std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); -// VLOG(3) << "Begine Gpu/Xpu Ps PullSparse"; -// auto buf = memory::Alloc(place, total_length * sizeof(FeatureValue)); -// FeatureValue* total_values_gpu = reinterpret_cast(buf->ptr()); -// if (platform::is_cpu_place(place)) { -// PADDLE_THROW(platform::errors::Unimplemented( -// "Warning:: CPUPlace is not supported in GpuPs now.")); -// } else if (platform::is_gpu_place(place)) { -// #ifdef PADDLE_WITH_CUDA -// VLOG(3) << "Begin copy keys, key_num[" << total_length << "]"; -// int device_id = place.GetDeviceId(); -// int devid_2_index = HeterPs_->get_index_by_devid(device_id); -// LoDTensor& total_keys_tensor = keys_tensor[devid_2_index]; -// uint64_t* total_keys = reinterpret_cast( -// total_keys_tensor.mutable_data({total_length, 1}, place)); - -// // construct slot_level lod info -// auto slot_lengths_lod = slot_lengths; -// for (size_t i = 1; i < slot_lengths_lod.size(); i++) { -// slot_lengths_lod[i] += slot_lengths_lod[i - 1]; -// } -// auto buf_key = memory::Alloc(place, keys.size() * sizeof(uint64_t*)); -// auto buf_length = -// memory::Alloc(place, 
slot_lengths.size() * sizeof(int64_t));
-// uint64_t** gpu_keys = reinterpret_cast(buf_key->ptr());
-// int64_t* gpu_len = reinterpret_cast(buf_length->ptr());
-// cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*),
-// cudaMemcpyHostToDevice);
-// cudaMemcpy(gpu_len, slot_lengths_lod.data(),
-// slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);

-// this->CopyKeys(place, gpu_keys, total_keys, gpu_len,
-// static_cast(slot_lengths.size()),
-// static_cast(total_length));
-// VLOG(3) << "Begin call PullSparseGPU in GPUPS, dev: " << devid_2_index
-// << " len: " << total_length;
-// pull_gpups_timer.Start();
-// HeterPs_->pull_sparse(devid_2_index, total_keys, total_values_gpu,
-// static_cast(total_length));
-// pull_gpups_timer.Pause();

-// VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
-// << "]";
-// this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len,
-// static_cast(slot_lengths.size()), hidden_size,
-// total_length);
-// } else {
-// PADDLE_THROW(platform::errors::PreconditionNotMet(
-// "GpuPs: PullSparse Only Support CUDAPlace Now."));
-// }
-// all_timer.Pause();
-// VLOG(3) << "GpuPs PullSparse total costs: " << all_timer.ElapsedSec()
-// << " s, of which GPUPS costs: " << pull_gpups_timer.ElapsedSec()
-// << " s";
-// VLOG(3) << "End PullSparse";
-// }
+ VLOG(0) << "Warning:: recommend using pull_gpups_sparse op instead. This PullSparse is not used.";
+}

 void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
 const int table_id,
@@ -1231,8 +1135,6 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
 std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);

 size_t feature_value_size = 0;
- // feature_value_size = TYPEALIGN(
- // 8, sizeof(FeatureValue) + sizeof(float) * (index_dim_vec_.back() + 1));
 feature_value_size = TYPEALIGN(
 8, feature_value_accessor_.GetAccessorInfo().size);
 VLOG(0) << "PULLSPARSE: " << feature_value_accessor_.GetAccessorInfo().size;
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
index 34242aa12212b2..1a6b0812293c25 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -99,14 +99,6 @@ class AfsWrapper {
 };
 #endif

-enum OptimizerType {
- OPTIMIZER_NAIVE = 0,
- OPTIMIZER_ADAGRAD = 1,
- OPTIMIZER_STDADAGRAD = 2,
- OPTIMIZER_ADAM = 3,
- OPTIMIZER_SHARDADAM = 4,
-};
-
 class PSGPUWrapper {
 public:
 ~PSGPUWrapper();
@@ -334,7 +326,7 @@ class PSGPUWrapper {
 config[prefix + "ada_epsilon"] = sgd_param.adam().ada_epsilon();
 config[prefix + "min_bound"] = sgd_param.adam().weight_bounds()[0];
 config[prefix + "max_bound"] = sgd_param.adam().weight_bounds()[1];
- } else if (optimizer_name == "SparseAdamSharedSGDRule") {
+ } else if (optimizer_name == "SparseSharedAdamSGDRule") {
 config[prefix + "optimizer_type"] = 4;
 config[prefix + "learning_rate"] = sgd_param.adam().learning_rate();
 config[prefix + "initial_range"] = sgd_param.adam().initial_range();
@@ -354,9 +346,6 @@ class PSGPUWrapper {
 sparse_table_accessor.ctr_accessor_param();
 auto accessor_class = sparse_table_accessor.accessor_class();

- // NOTE(zhangminxu): gpups' sparse table optimizer config,
- // now only support single sparse table
- // auto sparse_table = param_.sparse_table(0);
 std::unordered_map config;
 config["embedx_dim"] = sparse_table_accessor.embedx_dim();
 config["nonclk_coeff"] = sparse_table_accessor_parameter.nonclk_coeff();
 config["clk_coeff"] = sparse_table_accessor_parameter.click_coeff();
@@ -378,107 +367,6 @@ class PSGPUWrapper {
 }
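// NOTE: the BuildGPUTask/EndPass hunks in ps_gpu_wrapper.cc above stash the
// host-side FixedFeatureValue pointer inside the GPU value slot itself: the
// accessor layout reserves two floats (8 bytes) at CpuPtrIndex() for a
// 64-bit pointer (the "uint64_t cpu_ptr; // 2float" field), which is also why
// value sizes are rounded with TYPEALIGN(8, ...). A minimal self-contained
// sketch of that round trip; store_cpu_ptr/load_cpu_ptr are illustrative
// names, not Paddle API:
#include <cstddef>
#include <cstdint>
#include <cstring>
inline void store_cpu_ptr(float* val, std::size_t cpu_ptr_index, void* host_ptr) {
  std::uint64_t bits = reinterpret_cast<std::uint64_t>(host_ptr);
  std::memcpy(val + cpu_ptr_index, &bits, sizeof(bits));  // spans two floats
}
inline void* load_cpu_ptr(const float* val, std::size_t cpu_ptr_index) {
  std::uint64_t bits = 0;
  std::memcpy(&bits, val + cpu_ptr_index, sizeof(bits));
  return reinterpret_cast<void*>(bits);
}
// The patch itself writes through a reinterpret_cast pointer (the template
// argument, presumably uint64_t*, is elided in this extracted diff), which
// additionally relies on the 8-byte alignment guaranteed by TYPEALIGN.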
#endif -// void InitializeGPUServer(std::unordered_map config) { -// float nonclk_coeff = (config.find("sparse_nonclk_coeff") == config.end()) -// ? 1.0 -// : std::stof(config["sparse_nonclk_coeff"]); -// float clk_coeff = -// (config.find("sparse_click_coeff") == config.end()) ? 1.0 : std::stof(config["sparse_click_coeff"]); -// float min_bound = (config.find("min_bound") == config.end()) -// ? -10.0 -// : std::stof(config["min_bound"]); -// float max_bound = (config.find("max_bound") == config.end()) -// ? 10.0 -// : std::stof(config["max_bound"]); -// float learning_rate = (config.find("sparse_learning_rate") == config.end()) -// ? 0.05 -// : std::stof(config["sparse_learning_rate"]); -// float initial_g2sum = (config.find("sparse_initial_g2sum") == config.end()) -// ? 3.0 -// : std::stof(config["sparse_initial_g2sum"]); -// float initial_range = (config.find("sparse_initial_range") == config.end()) -// ? 1e-4 -// : std::stof(config["sparse_initial_range"]); -// float beta1_decay_rate = (config.find("embed_sparse_beta1_decay_rate") == config.end()) -// ? 0.9 -// : std::stof(config["embed_sparse_beta1_decay_rate"]); -// float beta2_decay_rate = (config.find("embed_sparse_beta2_decay_rate") == config.end()) -// ? 0.999 -// : std::stof(config["embed_sparse_beta2_decay_rate"]); -// float ada_epsilon = (config.find("embed_sparse_ada_epsilon") == config.end()) -// ? 1e-8 -// : std::stof(config["embed_sparse_ada_epsilon"]); -// // mf config settings -// float mf_create_thresholds = -// (config.find("sparse_embedx_threshold") == config.end()) -// ? static_cast(1.0) -// : std::stof(config["sparse_embedx_threshold"]); -// float mf_learning_rate = (config.find("embedx_sparse_learning_rate") == config.end()) -// ? 0.05 -// : std::stof(config["embedx_sparse_learning_rate"]); -// float mf_initial_g2sum = (config.find("sparse_initial_g2sum") == config.end()) -// ? 3.0 -// : std::stof(config["sparse_initial_g2sum"]); -// float mf_initial_range = (config.find("embedx_sparse_initial_range") == config.end()) -// ? 1e-4 -// : std::stof(config["embedx_sparse_initial_range"]); -// float mf_min_bound = (config.find("mf_min_bound") == config.end()) -// ? -10.0 -// : std::stof(config["mf_min_bound"]); -// float mf_max_bound = (config.find("mf_max_bound") == config.end()) -// ? 10.0 -// : std::stof(config["mf_max_bound"]); -// float mf_beta1_decay_rate = (config.find("embedx_sparse_beta1_decay_rate") == config.end()) -// ? 0.9 -// : std::stof(config["embedx_sparse_beta1_decay_rate"]); -// float mf_beta2_decay_rate = (config.find("embedx_sparse_beta2_decay_rate") == config.end()) -// ? 0.999 -// : std::stof(config["embedx_sparse_beta2_decay_rate"]); -// float mf_ada_epsilon = (config.find("embedx_sparse_ada_epsilon") == config.end()) -// ? 
1e-8 -// : std::stof(config["embedx_sparse_ada_epsilon"]); -// for (size_t i = 0; i < heter_devices_.size(); i++) { -// #ifdef PADDLE_WITH_CUDA -// PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(heter_devices_[i])); -// #elif defined(PADDLE_WITH_XPU_KP) -// PADDLE_ENFORCE_XPU_SUCCESS(xpu_set_device(heter_devices_[i])); -// #endif -// this->SetSparseSGD(nonclk_coeff, clk_coeff, min_bound, max_bound, -// learning_rate, initial_g2sum, initial_range, -// beta1_decay_rate, beta2_decay_rate, ada_epsilon); -// this->SetEmbedxSGD(mf_create_thresholds, mf_learning_rate, -// mf_initial_g2sum, mf_initial_range, mf_min_bound, -// mf_max_bound, mf_beta1_decay_rate, mf_beta2_decay_rate, -// mf_ada_epsilon); - -// // set optimizer type(naive,adagrad,std_adagrad,adam,share_adam) -// // optimizer_type_ = (config.find("optimizer_type") == config.end()) -// // ? 1 -// // : config["optimizer_type"]; -// optimizer_name_ = (config.find("embed_sparse_optimizer") == config.end()) -// ? "adagrad" -// : config["embed_sparse_optimizer"]; -// CommonFeatureValueAccessor feature_value_accessor_; -// feature_value_accessor_.Configure(config); -// embedx_dim_ = (config.find("sparse_embedx_dim") == config.end()) -// ? 8 -// : std::stoi(config["sparse_embedx_dim"]); -// if(optimizer_name_ == "adagrad") { -// embed_sgd_dim_ = 1; -// embedx_sgd_dim_ = 1; -// } else if (optimizer_name_ == "adam") { -// embed_sgd_dim_ = 4; -// embedx_sgd_dim_ = embedx_dim_ * 2 + 2; -// } else if (optimizer_name_ == "sharedadam") { -// embed_sgd_dim_ = 4; -// embedx_sgd_dim_ = 4; -// } else { -// embed_sgd_dim_ = 1; -// embedx_sgd_dim_ = 1; -// } -// } -// } ->>>>>>> 4e395c7ebd... add adam/sharedadam optimzier for gpups;edit optimizer struct;test=develop void InitializeGPUServer(std::unordered_map config) { float nonclk_coeff = (config.find("nonclk_coeff") == config.end()) ? 1.0 @@ -575,7 +463,7 @@ class PSGPUWrapper { if (optimizer_type_ == 3) { //adam embed_sgd_dim_ = 4; embedx_sgd_dim_ = embedx_dim_ * 2 + 2; - } else if (optimizer_type_ == 4) { //sharedadam + } else if (optimizer_type_ == 4) { //shared_adam embed_sgd_dim_ = 4; embedx_sgd_dim_ = 4; } else { @@ -677,11 +565,7 @@ class PSGPUWrapper { << " optimizer_type_:" << optimizer_type_; VLOG(0) << "InitSlotInfo:" << feature_value_accessor_.GetAccessorInfo().size; val_type_size_ =TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); - // val_type_size_ = - // TYPEALIGN(8, sizeof(FeatureValue) + sizeof(float) * (max_mf_dim_ + 1)); grad_type_size_ = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); - // grad_type_size_ = - // TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float))); slot_info_initialized_ = true; } #endif diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 5bca9650d6f29f..cabaac270274e3 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -// #include #include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/trainer.h" @@ -23,9 +22,6 @@ limitations under the License. 
*/ #include "paddle/fluid/distributed/ps/service/communicator/communicator.h" #endif -// #if defined PADDLE_WITH_HETERPS -// #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" -// #endif namespace paddle { namespace framework { @@ -51,10 +47,6 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, } #endif -// #if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE -// InitializeGPUServer(trainer_desc); -// #endif - // get filelist from trainer_desc here const std::vector readers = dataset->GetReaders(); @@ -88,83 +80,6 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, SetDebug(trainer_desc.debug()); } -#if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE -void add_sparse_optimizer( - std::unordered_map& config, // NOLINT - const ::paddle::distributed::SparseCommonSGDRuleParameter& sgd_param, - const std::string& prefix = "") { - auto optimizer_name = sgd_param.name(); - if (optimizer_name == "SparseNaiveSGDRule") { - config[prefix + "optimizer_type"] = 0; - config[prefix + "learning_rate"] = sgd_param.naive().learning_rate(); - config[prefix + "initial_range"] = sgd_param.naive().initial_range(); - config[prefix + "min_bound"] = sgd_param.naive().weight_bounds()[0]; - config[prefix + "max_bound"] = sgd_param.naive().weight_bounds()[1]; - } else if (optimizer_name == "SparseAdaGradSGDRule") { - config[prefix + "optimizer_type"] = 1; - config[prefix + "learning_rate"] = sgd_param.adagrad().learning_rate(); - config[prefix + "initial_range"] = sgd_param.adagrad().initial_range(); - config[prefix + "initial_g2sum"] = sgd_param.adagrad().initial_g2sum(); - config[prefix + "min_bound"] = sgd_param.adagrad().weight_bounds()[0]; - config[prefix + "max_bound"] = sgd_param.adagrad().weight_bounds()[1]; - } else if (optimizer_name == "StdAdaGradSGDRule") { - config[prefix + "optimizer_type"] = 2; - config[prefix + "learning_rate"] = sgd_param.adagrad().learning_rate(); - config[prefix + "initial_range"] = sgd_param.adagrad().initial_range(); - config[prefix + "initial_g2sum"] = sgd_param.adagrad().initial_g2sum(); - config[prefix + "min_bound"] = sgd_param.adagrad().weight_bounds()[0]; - config[prefix + "max_bound"] = sgd_param.adagrad().weight_bounds()[1]; - } else if (optimizer_name == "SparseAdamSGDRule") { - config[prefix + "optimizer_type"] = 3; - config[prefix + "learning_rate"] = sgd_param.adam().learning_rate(); - config[prefix + "initial_range"] = sgd_param.adam().initial_range(); - config[prefix + "beta1_decay_rate"] = sgd_param.adam().beta1_decay_rate(); - config[prefix + "beta2_decay_rate"] = sgd_param.adam().beta2_decay_rate(); - config[prefix + "ada_epsilon"] = sgd_param.adam().ada_epsilon(); - config[prefix + "min_bound"] = sgd_param.adam().weight_bounds()[0]; - config[prefix + "max_bound"] = sgd_param.adam().weight_bounds()[1]; - } else if (optimizer_name == "SparseAdamSharedSGDRule") { - config[prefix + "optimizer_type"] = 4; - config[prefix + "learning_rate"] = sgd_param.adam().learning_rate(); - config[prefix + "initial_range"] = sgd_param.adam().initial_range(); - config[prefix + "beta1_decay_rate"] = sgd_param.adam().beta1_decay_rate(); - config[prefix + "beta2_decay_rate"] = sgd_param.adam().beta2_decay_rate(); - config[prefix + "ada_epsilon"] = sgd_param.adam().ada_epsilon(); - config[prefix + "min_bound"] = sgd_param.adam().weight_bounds()[0]; - config[prefix + "max_bound"] = sgd_param.adam().weight_bounds()[1]; - } -} - -// void MultiTrainer::InitializeGPUServer(const TrainerDesc& trainer_desc) { -// // optimizer config for 
hbmps -// auto fleet_desc_str = trainer_desc.fleet_desc(); -// VLOG(0) << "InitializeGPUServer fleet_desc_str" << fleet_desc_str; -// google::protobuf::TextFormat::ParseFromString(fleet_desc_str, &_ps_param); -// auto sparse_table = -// _ps_param.server_param().downpour_server_param().downpour_table_param(0); -// auto sparse_table_accessor = sparse_table.accessor(); -// auto sparse_table_accessor_parameter = -// sparse_table_accessor.ctr_accessor_param(); -// auto accessor_class = sparse_table_accessor.accessor_class(); - -// // NOTE(zhangminxu): gpups' sparse table optimizer config, -// // now only support single sparse table -// // auto sparse_table = param_.sparse_table(0); -// std::unordered_map config; -// config["nonclk_coeff"] = sparse_table_accessor_parameter.nonclk_coeff(); -// config["clk_coeff"] = sparse_table_accessor_parameter.click_coeff(); - -// if (accessor_class == "CtrDymfAccessor") { -// // optimizer config for embed_w and embedx -// add_sparse_optimizer(config, sparse_table_accessor.embed_sgd_param()); -// add_sparse_optimizer(config, sparse_table_accessor.embedx_sgd_param(), -// "mf_"); -// } -// auto ps_gpu_wrapper = paddle::framework::PSGPUWrapper::GetInstance(); -// ps_gpu_wrapper->InitializeGPUServer(config); -// } -#endif - std::string MultiTrainer::GetDumpPath(int tid) { if (user_define_dump_filename_ != "") { return string::format_string("%s/part-%s-%05d", @@ -390,7 +305,5 @@ void MultiTrainer::Finalize() { root_scope_->DropKids(); } - - } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 75e18865505c6a..7dd988a22c908a 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -122,9 +122,6 @@ class MultiTrainer : public TrainerBase { void MergeDenseParam(); #endif -// #if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE -// void InitializeGPUServer(const TrainerDesc& trainer_desc); -// #endif protected: int thread_num_; @@ -135,13 +132,8 @@ class MultiTrainer : public TrainerBase { std::vector trainable_param_; #ifdef PADDLE_WITH_HETERPS std::vector places_; - // _ps_param for gpups optimizer config #endif -// #if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE -// ::paddle::distributed::PSParameter _ps_param; -// #endif - int mpi_rank_; int mpi_size_; int dump_file_num_; diff --git a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc index da833641227b1f..e9c993d3ee1282 100644 --- a/paddle/fluid/pybind/ps_gpu_wrapper_py.cc +++ b/paddle/fluid/pybind/ps_gpu_wrapper_py.cc @@ -41,8 +41,6 @@ void BindPSGPUWrapper(py::module* m) { .def("set_slot_vector", &framework::PSGPUWrapper::SetSlotVector, py::call_guard()) - // .def("init_GPU_server", &framework::PSGPUWrapper::InitializeGPUServer, - // py::call_guard()) #ifdef PADDLE_WITH_CUDA .def("set_slot_dim_vector", &framework::PSGPUWrapper::SetSlotDimVector, diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 902854a7c72796..c58b539b6877d3 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -594,6 +594,21 @@ def sparse_optimizer_config(sgd, strategy, prefix): bounds = strategy.get(prefix + 'sparse_weight_bounds', [-10, 10]) sgd.adam.weight_bounds.extend(bounds) + elif optimizer_name == "shared_adam": + sgd.name = 'SparseSharedAdamSGDRule' + sgd.adam.learning_rate = 
strategy.get( + prefix + 'sparse_learning_rate', 0.001) + sgd.adam.initial_range = strategy.get( + prefix + 'sparse_initial_range', 1e-4) + sgd.adam.beta1_decay_rate = strategy.get( + prefix + 'sparse_beta1_decay_rate', 0.9) + sgd.adam.beta2_decay_rate = strategy.get( + prefix + 'sparse_beta2_decay_rate', 0.999) + sgd.adam.ada_epsilon = strategy.get( + prefix + 'sparse_ada_epsilon', 1e-8) + bounds = strategy.get(prefix + 'sparse_weight_bounds', + [-10, 10]) + sgd.adam.weight_bounds.extend(bounds) def set_sparse_table_config(table_data, config): for key in config: diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 1b0acd79924086..be2902855f5760 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -39,7 +39,6 @@ from paddle.fluid.dygraph import to_variable from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar -from paddle.distributed.fleet.proto import the_one_ps_pb2 __all__ = [] _grad_scalar = None @@ -1618,7 +1617,7 @@ def _minimize_impl(self, context["valid_strategy"] = copy.deepcopy(valid_strategy) # print("valid_strategy:", context["valid_strategy"]) - print("user_defined_strategy:", context["user_defined_strategy"]) + # print("user_defined_strategy:", context["user_defined_strategy"]) applied_meta_list = self.strategy_compiler._get_applied_meta_list() applied_graph_list = self.strategy_compiler._get_applied_graph_list() @@ -1648,17 +1647,17 @@ def _minimize_impl(self, no_grad_set=no_grad_set) if meta_optimizer: - print("before minimize program id:", id(loss.block.program)) + # print("before minimize program id:", id(loss.block.program)) optimize_ops, params_grads = meta_optimizer.minimize( loss, startup_program, parameter_list, no_grad_set=no_grad_set) - print("after minimize program id:", id(loss.block.program)) + # print("after minimize program id:", id(loss.block.program)) default_program = paddle.static.default_main_program() - print("default program id:", id(default_program)) + # print("default program id:", id(default_program)) if id(default_program) != id(loss.block.program): paddle.fluid.framework.switch_main_program(loss.block.program) - print("default program id after switch:", id(default_program)) + # print("default program id after switch:", id(default_program)) else: optimize_ops, params_grads = self.user_defined_optimizer.minimize( @@ -1668,7 +1667,7 @@ def _minimize_impl(self, context["program_params_grads"] = params_grads if graph_optimizer: - print("before graph minimize program id:", id(loss.block.program)) + # print("before graph minimize program id:", id(loss.block.program)) optimize_ops, params_grads = graph_optimizer.minimize( loss, startup_program, parameter_list, no_grad_set=no_grad_set) # since we do not encourage users to use graph operations @@ -1680,17 +1679,6 @@ def _minimize_impl(self, else: apply_ir_passes(loss.block.program, startup_program, self) - # ps_param = the_one_ps_pb2.PSParameter() - - # all_table_proto = context["user_defined_strategy"].sparse_table_configs - # ps_param = all_table_proto.add() - - # opt_info = {} - # opt_info["fleet_desc"] = ps_param - # program = paddle.static.default_main_program() - # program._fleet_opt = opt_info - # print("ps_param:", ps_param) - if not self._role_maker._is_heter_parameter_server_mode: program = paddle.static.default_main_program() opt_info = {} if program._fleet_opt is None 
else program._fleet_opt @@ -1702,7 +1690,6 @@ def _minimize_impl(self, opt_info[k] = v program._fleet_opt = opt_info - print("_fleet_opt:", program._fleet_opt) if self._runtime_handle is None: self._runtime_handle = RuntimeFactory()._create_runtime(context) @@ -1777,7 +1764,7 @@ def _minimize_losses_impl(self, for k, v in self._user_defined_strategy.trainer_desc_configs.items( ): if v or k not in opt_info: - opt_info[k] = v + opt_info[k] = v program._fleet_opt = opt_info # print("fleet base opt info:", id(program), program._fleet_opt) diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index 72757e6c229109..4e3db17d6584a3 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -195,7 +195,7 @@ def _set(self, accessor_proto, varname, program_id, context): sgd_param.naive.initial_range = 0.0001 if len(sgd_param.naive.weight_bounds) == 0: sgd_param.naive.weight_bounds.extend([-10.0, 10.0]) - if sgd_param.name == "SparseAdamSGDRule": + if sgd_param.name == "SparseAdamSGDRule" or sgd_param.name == "SparseSharedAdamSGDRule": if not sgd_param.adam.HasField("learning_rate"): sgd_param.adam.learning_rate = 0.001 if not sgd_param.adam.HasField("initial_range"): @@ -591,7 +591,6 @@ def _set(self, table_proto): print('new table_name: {}'.format(self.common.table_name)) all_table_proto = self.context[ "user_defined_strategy"].sparse_table_configs - print("all_table_proto:", all_table_proto) usr_table_proto = all_table_proto.add() for proto in all_table_proto: if proto.table_name == self.common.table_name: @@ -623,8 +622,6 @@ def _set(self, table_proto): # table_proto.accessor = usr_table_proto.accessor table_proto.accessor.ParseFromString( usr_table_proto.accessor.SerializeToString()) - # table_proto.accessor.CopyFrom(usr_table_proto.accessor) - print("usr_table_proto.accessor.SerializeToString():", usr_table_proto.accessor.SerializeToString()) print("====table_proto:", table_proto) self.accessor._set(table_proto.accessor, self.common.table_name, ctx.program_id(), self.context) @@ -971,7 +968,7 @@ def sync_strategy_envs(): proto_txt = worker_desc debug = bool(int(os.getenv("PSERVER_DEBUG", "0"))) - if True: + if debug: print("worker: \n{}".format(proto_txt)) print("communicator send_ctx:") for key in send_ctx: diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py index adfce2a19d1ed8..f0c094a84f758b 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -104,7 +104,6 @@ def _gen_worker_desc(self, trainer_desc): print("program of current device worker is not configured") exit(-1) opt_info = self._program._fleet_opt - print("opt_info:", opt_info) # when opt_info is None or empty dict, it should return if not opt_info: return diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py index 1b4a63d8c30784..a34fb2dea7dc50 100644 --- a/python/paddle/fluid/trainer_factory.py +++ b/python/paddle/fluid/trainer_factory.py @@ -52,12 +52,6 @@ def _create_trainer(self, opt_info=None): else: trainer_class = opt_info.get("trainer", "MultiTrainer") device_worker_class = opt_info.get("device_worker", "Hogwild") - if trainer_class == '': - trainer_class = "MultiTrainer" - opt_info["trainer"] = "MultiTrainer" - if device_worker_class == '': - device_worker_class = "Hogwild" - opt_info["device_worker"] = "Hogwild" trainer = globals()[trainer_class]() device_worker = globals()[device_worker_class]() From 
724381cf6cbf1ed7d8fb90d8037ad33f69a9bfa5 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 15 Jun 2022 10:18:26 +0800 Subject: [PATCH 03/31] remove useless code;test=develop --- .../distributed/ps/table/ctr_dymf_accessor.cc | 3 - .../framework/fleet/heter_ps/feature_value.h | 164 +----------------- .../fleet/heter_ps/hashtable_kernel.cu | 5 - .../fluid/framework/fleet/ps_gpu_wrapper.cc | 12 +- 4 files changed, 6 insertions(+), 178 deletions(-) diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc index 5eaf49a0ebbf03..a677c6c190177e 100644 --- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc @@ -24,12 +24,9 @@ namespace distributed { int CtrDymfAccessor::Initialize() { auto name = _config.embed_sgd_param().name(); - _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1); - VLOG(0) << "CtrDymfAccessor::Initialize embed_sgd_param name:" << name - << " embedx_sgd_param name: " << _config.embedx_sgd_param().name(); name = _config.embedx_sgd_param().name(); _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(), diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 7e899cb6f377de..4f7768f45f13ad 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -158,30 +158,6 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { __host__ __device__ CommonFeatureValueAccessor() {} __host__ __device__ ~CommonFeatureValueAccessor() {} - // __host__ __device__ virtual int Initialize() { - // std::string name = (_config.find("embed_sparse_optimizer") == _config.end()) - // ? "adagrad" - // : _config["embed_sparse_optimizer"]; - // int sparse_embedx_dim = (_config.find("sparse_embedx_dim") == _config.end()) - // ? 
8 - // : std::stoi(_config["sparse_embedx_dim"]); - // if (name.compare("adam") == 0) { - // common_feature_value.embed_sgd_dim = 4; - // common_feature_value.embedx_sgd_dim = sparse_embedx_dim * 2 + 2; - // } else if (name.compare("shared_adam") == 0) { - // common_feature_value.embed_sgd_dim = 4; - // common_feature_value.embedx_sgd_dim = 4; - // } else { - // common_feature_value.embed_sgd_dim = 1; - // common_feature_value.embedx_sgd_dim = 1; - // } - - // common_feature_value.embedx_dim = sparse_embedx_dim; - - // // VLOG(0) << " INTO FeatureValueAccessor::Initialize()"; - // InitAccessorInfo(); - // return 0; - // } __host__ __device__ virtual int Initialize() { int optimizer_type = (_config.find("optimizer_type") == _config.end()) @@ -223,40 +199,6 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { (embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float); } - // friend std::ostream& operator<<(std::ostream& out, CommonFeatureValueAccessor& v) { - // /* - // uint64_t cpu_ptr; - // float delta_score; - // float show; - // float click; - // float embed_w; - // std::vector embed_g2sum; - // float slot; - // float mf_dim - // float mf_size - // std::vector embedx_g2sum; - // std::vector embedx_w; - // */ - // out << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4]; - // // << v[5] << " " << v[6]; - // for (int i = common_feature_value.EmbedG2SumIndex(); - // i < common_feature_value.EmbedxWIndex(); i++) { - // out << " " << v[i]; - // } - // out << " " << common_feature_value.Slot(v) << " " - // << common_feature_value.MfDim(v) - // << common_feature_value.MfSize(v); - - // for (int x = 0; x < common_feature_value.EmbedXDim(); x++) { - // out << " " << v[common_feature_value.EmbedxG2SumIndex() + x]; - // } - // for (int x = 0; x < common_feature_value.MfDim(v); x++) { - // out << " " << v[common_feature_value.EmbedxWIndex() + x]; - // } - // return out; - // } - - __host__ __device__ std::string ParseToString(const float* v, int param_size) { /* uint64_t cpu_ptr; // 2float @@ -295,8 +237,6 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { public: CommonFeatureValue common_feature_value; CommonPushValue common_push_value; - // SparseValueSGDRule* _embed_sgd_rule; - // SparseValueSGDRule* _embedx_sgd_rule; }; @@ -306,26 +246,17 @@ struct FeatureValue { float clk; int slot; float lr; + float lr_g2sum; int mf_size; int mf_dim; uint64_t cpu_ptr; - int lr_sgd_dim; - int mf_sgd_dim; - float lr_g2sum[1]; - float mf_g2sum[1]; float mf[0]; friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) { out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot << " lr: " << val.lr << " mf_dim: " << val.mf_dim << "cpuptr: " << val.cpu_ptr << " mf_size: " << val.mf_size << " mf:"; - for (int i = 0; i < val.lr_sgd_dim; ++i) { - out << " " << val.lr_g2sum[i]; - } - for (int i = 0; i < val.mf_sgd_dim; ++i) { - out << " " << val.mf_g2sum[i]; - } - for (int i = 0; i < val.mf_dim; ++i) { + for (int i = 0; i < val.mf_dim + 1; ++i) { out << " " << val.mf[i]; } return out; @@ -336,103 +267,16 @@ struct FeatureValue { clk = in.clk; slot = in.slot; lr = in.lr; + lr_g2sum = in.lr_g2sum; mf_size = in.mf_size; mf_dim = in.mf_dim; cpu_ptr = in.cpu_ptr; - lr_sgd_dim = in.lr_sgd_dim; - mf_sgd_dim = in.mf_sgd_dim; - - for (int i = 0; i < lr_sgd_dim; ++i) { - lr_g2sum[i] = in.lr_g2sum[i]; - } - for (int i = 0; i < mf_sgd_dim; ++i) { - mf_g2sum[i] = in.mf_g2sum[i]; - } - for (int i = 0; i < mf_dim; i++) { + for (int i = 0; i < mf_dim 
+ 1; i++) { mf[i] = in.mf[i]; } } }; -// struct AdamFeatureValue { -// float delta_score; -// float show; -// float clk; -// int slot; -// float lr; -// int mf_size; -// int mf_dim; -// int lr_sgd_dim; -// int mf_sgd_dim; -// uint64_t cpu_ptr; -// float lr_g2sum[0]; -// float mf_g2sum[0]; -// float mf[0]; - - -// __device__ __forceinline__ void operator=(const FeatureValue& in) { -// delta_score = in.delta_score; -// show = in.show; -// clk = in.clk; -// slot = in.slot; -// lr = in.lr; -// mf_size = in.mf_size; -// mf_dim = in.mf_dim; -// cpu_ptr = in.cpu_ptr; -// lr_sgd_dim = in.lr_sgd_dim; -// mf_sgd_dim = in.mf_sgd_dim; - -// for (int i = 0; i < lr_sgd_dim; ++i) { -// lr_g2sum[i] = in.lr_g2sum[i]; -// } -// for (int i = 0; i < mf_sgd_dim; ++i) { -// mf_g2sum[i] = in.mf_g2sum[i]; -// } -// for (int i = 0; i < mf_dim; i++) { -// mf[i] = in.mf[i]; -// } -// } -// }; - -// struct AdamSharedFeatureValue { -// float delta_score; -// float show; -// float clk; -// int slot; -// float lr; -// int mf_size; -// int mf_dim; -// int lr_sgd_dim; -// int mf_sgd_dim; -// uint64_t cpu_ptr; -// float lr_g2sum[4]; -// float mf_g2sum[4]; -// float mf[0]; - -// __device__ __forceinline__ void operator=(const FeatureValue& in) { -// delta_score = in.delta_score; -// show = in.show; -// clk = in.clk; -// slot = in.slot; -// lr = in.lr; -// mf_size = in.mf_size; -// mf_dim = in.mf_dim; -// cpu_ptr = in.cpu_ptr; -// lr_sgd_dim = in.lr_sgd_dim; -// mf_sgd_dim = in.mf_sgd_dim; - -// for (int i = 0; i < lr_sgd_dim; ++i) { -// lr_g2sum[i] = in.lr_g2sum[i]; -// } -// for (int i = 0; i < mf_sgd_dim; ++i) { -// mf_g2sum[i] = in.mf_g2sum[i]; -// } -// for (int i = 0; i < mf_dim; i++) { -// mf[i] = in.mf[i]; -// } -// } -// }; - struct FeaturePushValue { float show; diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index f072b20e7e27aa..d2ecde637327d1 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -169,8 +169,6 @@ __global__ void dy_mf_update_kernel(Table* table, if (it != table->end()) { float* cur = (float*)(grads + i * grad_value_size); sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, cur); - // printf("dy_mf_update_kernel: %d, %s", keys[i], - // sgd.feature_value_accessor_.ParseToString(cur, sgd.feature_value_accessor_.GetAccessorInfo().dim)); } else { if (keys[i] != 0) { printf("warning::push miss key: %llu", keys[i]); @@ -244,9 +242,6 @@ void HashTable::get(const KeyType* d_keys, return; } const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; - VLOG(0) << "GET:" << feature_value_accessor_.common_feature_value.EmbedDim() - << " " << feature_value_accessor_.common_feature_value.EmbedXDim() - << " " << feature_value_accessor_.common_feature_value.EmbedWDim(); dy_mf_search_kernel<<>>( container_, d_keys, d_vals, len, pull_feature_value_size_, feature_value_accessor_); } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 5c4c29961f775b..7c982c4a699814 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -769,7 +769,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { } #endif #ifdef PADDLE_WITH_PSCORE - VLOG(0) << "cpu build "<< k << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) + VLOG(5) << "cpu build "<< k << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) << " |: "<< 
cpu_table_accessor_->ParseToString(ptr_val, dim); val[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] = ptr_val[cpu_table_accessor_->common_feature_value.DeltaScoreIndex()]; @@ -814,7 +814,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { val[feature_value_accessor_.common_feature_value.EmbedxWIndex() + x] = 0; } } - VLOG(0) << "build "<< k << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.GetAccessorInfo().dim); + VLOG(5) << "build "<< k << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.GetAccessorInfo().dim); } #endif @@ -1200,14 +1200,6 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, HeterPs_->pull_sparse( devid_2_index, total_keys, total_values_gpu, total_length); - // char* tmp_mem = (char*)malloc(total_length * feature_value_size); - // cudaMemcpy(tmp_mem, total_values_gpu, total_length * feature_value_size, - // cudaMemcpyDeviceToHost); - // for (int i =0 ; i < 20; i++){ - // float* val = (float*)(void*)&tmp_mem[(i)*feature_value_size]; - // VLOG(0) << "pullsparse_cpu "<< i << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.GetAccessorInfo().dim); - // } - VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length << "]"; From 15dd7a1399fd93e377dee28bbbb9696588a60b03 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 15 Jun 2022 10:30:57 +0800 Subject: [PATCH 04/31] remove useless code;test=develop --- paddle/fluid/framework/fleet/heter_ps/feature_value.h | 1 - paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h | 2 +- .../fluid/framework/fleet/heter_ps/heter_comm_kernel.cu | 2 -- paddle/fluid/framework/fleet/heter_ps/heter_ps.cu | 1 - paddle/fluid/framework/fleet/ps_gpu_wrapper.cc | 9 ++------- paddle/fluid/framework/fleet/ps_gpu_wrapper.cu | 2 -- paddle/fluid/framework/multi_trainer.cc | 1 - paddle/fluid/framework/trainer.h | 1 - python/paddle/distributed/fleet/base/fleet_base.py | 3 ++- python/paddle/distributed/ps/the_one_ps.py | 3 --- 10 files changed, 5 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 4f7768f45f13ad..67f6d207b109d5 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -277,7 +277,6 @@ struct FeatureValue { } }; - struct FeaturePushValue { float show; float clk; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 70edcee5950553..5558734ee0fa50 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -867,7 +867,7 @@ void HeterComm::split_input_to_shard( template void HeterComm::pull_sparse(int num, KeyType* d_keys, - float* d_vals, //new edit from ValType + float* d_vals, size_t len) { if (len == 0) { return; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 58001d5c022bc7..15adbded3e01f2 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -140,8 +140,6 @@ __global__ void dy_mf_fill_shard_grads_kernel(KeyType* d_shard_keys, const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { d_shard_keys[i] = d_keys[idx[i]]; - // *(float*)((char*)d_shard_grads + i * grad_value_size) = - // 
*(float*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size); float* cur = (float*)((char*)d_shard_grads + i * grad_value_size); float* shard_val = (float*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu index 2d030c75a82ba5..f5a03131f1bf7c 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu @@ -36,7 +36,6 @@ HeterPs::HeterPs(size_t capacity, std::shared_ptr resource, capacity, resource, feature_value_accessor); feature_value_accessor_ = feature_value_accessor; optimizer_type_ = optimizer_type; - // opt_ = Optimizer(); } HeterPs::~HeterPs() {} diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 7c982c4a699814..75bcd489619c83 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -1036,9 +1036,6 @@ void PSGPUWrapper::EndPass() { auto* downpour_value = (paddle::distributed::FixedFeatureValue*)(*(reinterpret_cast(gpu_val+ feature_value_accessor_.common_feature_value.CpuPtrIndex()))); size_t downpour_value_size = downpour_value->size(); - VLOG(0) << "downpour_value_size:" <GetAccessorInfo().dim - << " MF_FIM:" << cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float); if (gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0 && downpour_value_size == (cpu_table_accessor_->GetAccessorInfo().dim - cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float))) { // cpu_accessor @@ -1072,10 +1069,8 @@ void PSGPUWrapper::EndPass() { gpu_val[feature_value_accessor_.common_feature_value.EmbedxWIndex() + x]; } } - VLOG(0) << "dump to cpu "<< index << " : "<< feature_value_accessor_.ParseToString(gpu_val, feature_value_accessor_.GetAccessorInfo().dim) - << " =====CPU "<< cpu_table_accessor_->GetAccessorInfo().dim - << "-" << downpour_value->size() - << " :" << cpu_table_accessor_->ParseToString(cpu_val, downpour_value->size()); + VLOG(5) << "dump to cpu "<< index << " : "<< feature_value_accessor_.ParseToString(gpu_val, feature_value_accessor_.GetAccessorInfo().dim) + << " ===== CPU:" << cpu_table_accessor_->ParseToString(cpu_val, downpour_value->size()); } #endif diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu index e14b78522ada35..550baf5f50f3e7 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -199,8 +199,6 @@ __global__ void PushCopyWithPool(float* dest, *(src[x] + y * (mf_dim + 3) + 1); cur[feature_value_accessor.common_push_value.EmbedGIndex()] = *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs; - // printf("PushCopyWithPool show:%f ,click: %d\n", cur[feature_value_accessor.common_push_value.ShowIndex()], - // cur[feature_value_accessor.common_push_value.ClickIndex()]); for (int j = 0; j < mf_dim; j++) { cur[feature_value_accessor.common_push_value.EmbedxGIndex() + j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. 
* bs; } diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index cabaac270274e3..c2c05f373c2d23 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -46,7 +46,6 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, places_.push_back(place); } #endif - // get filelist from trainer_desc here const std::vector readers = dataset->GetReaders(); diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 7dd988a22c908a..1a805ccd76e440 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -133,7 +133,6 @@ class MultiTrainer : public TrainerBase { #ifdef PADDLE_WITH_HETERPS std::vector places_; #endif - int mpi_rank_; int mpi_size_; int dump_file_num_; diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index be2902855f5760..e555047b0e8abb 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -3,7 +3,7 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# + # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software @@ -39,6 +39,7 @@ from paddle.fluid.dygraph import to_variable from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar + __all__ = [] _grad_scalar = None diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index 4e3db17d6584a3..7d240983a1c289 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -619,10 +619,8 @@ def _set(self, table_proto): warnings.warn( "The accessor of sparse table is not set, use default value.") - # table_proto.accessor = usr_table_proto.accessor table_proto.accessor.ParseFromString( usr_table_proto.accessor.SerializeToString()) - print("====table_proto:", table_proto) self.accessor._set(table_proto.accessor, self.common.table_name, ctx.program_id(), self.context) @@ -938,7 +936,6 @@ def _init_worker(self, scopes=None): #with open("test_fl_ps_worker_desc", "w") as f: # f.write(worker_desc) if self.context['use_ps_gpu']: - #main_program._fleet_opt["fleet_desc"] = worker_desc() main_program = self.context['loss'].block.program if not main_program._fleet_opt: main_program._fleet_opt = {} From 4f91222b296cbcd3bc9b3fe6f5d2f2683340b90a Mon Sep 17 00:00:00 2001 From: danleifeng Date: Thu, 16 Jun 2022 09:40:48 +0800 Subject: [PATCH 05/31] remove useless code;test=develop --- paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc | 2 -- paddle/fluid/framework/fleet/heter_ps/feature_value.h | 5 ++--- paddle/fluid/framework/fleet/heter_ps/hashtable.h | 1 - paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h | 3 +-- paddle/fluid/framework/fleet/heter_ps/heter_ps.cu | 4 +++- paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h | 10 +++++++--- paddle/fluid/framework/fleet/ps_gpu_wrapper.cc | 2 +- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 2 -- python/paddle/distributed/fleet/base/fleet_base.py | 2 +- 9 files changed, 15 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc index f23e7f71c603d0..49ee493dbef50a 
100644 --- a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc +++ b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc @@ -213,7 +213,6 @@ void SparseAdamSGDRule::UpdateValueWork(float* w, float beta1_pow_ = *beta1_pow; float beta2_pow_ = *beta2_pow; - // lr not change in one update lr *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_); for (size_t i = 0; i < _embedding_dim; i++) { // Calculation @@ -288,7 +287,6 @@ void SparseSharedAdamSGDRule::UpdateValueWork(float* w, float* sgd, const float* float gsum_ = *gsum; float g2sum_ = *g2sum; - // lr not change in one update lr *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_); double sum_gsum = 0.0; double sum_g2sum = 0.0; diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 67f6d207b109d5..039c4563cccf37 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -55,7 +55,6 @@ class FeatureValueAccessor { __host__ __device__ virtual GpuAccessorInfo GetAccessorInfo() { return _accessor_info; } protected: - // TableAccessorParameter _config; std::unordered_map _config; GpuAccessorInfo _accessor_info; }; @@ -80,9 +79,9 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { std::vector embedx_w; */ - __host__ __device__ int Dim() { return 8 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; } // has cpu_ptr + __host__ __device__ int Dim() { return 8 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; } // has cpu_ptr(1) __host__ __device__ int DimSize(size_t dim, int embedx_dim) { return sizeof(float); } - __host__ __device__ int Size() { return (Dim()-1) * sizeof(float) + sizeof(uint64_t); } + __host__ __device__ int Size() { return (Dim()-1) * sizeof(float) + sizeof(uint64_t); } // cpu_ptr:uint64 __host__ __device__ int EmbedDim() { return embed_sgd_dim;} __host__ __device__ int EmbedXDim() { return embedx_sgd_dim;} __host__ __device__ int EmbedWDim() { return embedx_dim;} diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index 8b0d4dd3a53202..3269b7a24e9c48 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -211,7 +211,6 @@ class HashTable { size_t max_mf_dim_ = 8; size_t pull_feature_value_size_; size_t push_grad_value_size_; - }; } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 5558734ee0fa50..28a38ab520c99d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -690,7 +690,6 @@ void HeterComm::dynamic_merge_grad( size_t temp_storage_bytes; - // VLOG(1) << "hetercomm merge_grad: max_mf_dim: " << max_mf_dim_; size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); @@ -915,7 +914,7 @@ void HeterComm::pull_sparse(int num, auto d_idx = memory::Alloc(place, len * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); - VLOG(0) << "PULLSPARSE len:" << len << " val_type_size: " << val_type_size; + VLOG(5) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_shard_keys_ptr = 
reinterpret_cast(d_shard_keys->ptr()); auto d_shard_vals = memory::Alloc(place, len * val_type_size); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu index f5a03131f1bf7c..a653f08253b141 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu @@ -80,13 +80,15 @@ void HeterPs::push_sparse(int num, size_t len) { if (optimizer_type_ == 3) { //adam auto optimizer = SparseAdamOptimizer(feature_value_accessor_); - VLOG(0) << "INTO push_sparse SparseAdamOptimizer EmbedDim():" << optimizer.EmbedDim(); + VLOG(5) << "INTO push_sparse SparseAdamOptimizer, EmbedDim():" << optimizer.EmbedDim(); comm_->push_sparse(num, d_keys, d_grads, len, optimizer); } else if (optimizer_type_ == 4) { //shared_adam auto optimizer = SparseAdamSharedOptimizer(feature_value_accessor_); + VLOG(5) << "INTO push_sparse SparseAdamSharedOptimizer, EmbedDim():" << optimizer.EmbedDim(); comm_->push_sparse(num, d_keys, d_grads, len, optimizer); } else { auto optimizer = SparseAdagradOptimizer(feature_value_accessor_); + VLOG(5) << "INTO push_sparse SparseAdagradOptimizer, EmbedDim():" << optimizer.EmbedDim(); comm_->push_sparse(num, d_keys, d_grads, len, optimizer); } } diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h index db27da8712f2dc..cdd6aa8b4a0690 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h @@ -38,6 +38,7 @@ class Optimizer { __device__ void update_value(const OptimizerConfig& optimizer_config, float& val, // NOLINT const float& grad) { + printf("Warning: update_value will not used. Please use dy_mf_update_value\n"); } __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config, @@ -87,6 +88,7 @@ class SparseAdagradOptimizer : public Optimizer { __device__ void update_value(const OptimizerConfig& optimizer_config, float& val, // NOLINT const float& grad) { + printf("Warning: update_value will not used. Please use dy_mf_update_value\n"); } __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config, float* ptr, const float* grad) { @@ -221,6 +223,7 @@ class SparseAdamOptimizer : public Optimizer { __device__ void update_value(const OptimizerConfig& optimizer_config, float& val, // NOLINT const float& grad) { + printf("Warning: update_value will not used. 
Please use dy_mf_update_value\n"); } __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config, float* ptr, const float* grad) { @@ -243,7 +246,7 @@ class SparseAdamOptimizer : public Optimizer { grad + feature_value_accessor_.common_push_value.EmbedGIndex(), g_show); int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); - printf("mf_dim: %f, lr_gsum: %f, ", mf_dim, ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex()]); + // printf("mf_dim: %f, lr_gsum: %f, ", mf_dim, ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex()]); if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= optimizer_config.nonclk_coeff * @@ -275,8 +278,8 @@ class SparseAdamOptimizer : public Optimizer { grad + feature_value_accessor_.common_push_value.EmbedxGIndex(), g_show); } - printf("EmbedxGIndex: %f, mf_gsum: %f, ", feature_value_accessor_.common_push_value.EmbedxGIndex(), - ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex()]); + // printf("EmbedxGIndex: %f, mf_gsum: %f, ", feature_value_accessor_.common_push_value.EmbedxGIndex(), + // ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex()]); } __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim(); } @@ -346,6 +349,7 @@ class SparseAdamSharedOptimizer : public Optimizer { __device__ void update_value(const OptimizerConfig& optimizer_config, float& val, // NOLINT const float& grad) { + printf("Warning: update_value will not used. Please use dy_mf_update_value\n"); } __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config, diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 75bcd489619c83..3f4c382b25be7a 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -1131,7 +1131,7 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, size_t feature_value_size = 0; feature_value_size = TYPEALIGN( 8, feature_value_accessor_.GetAccessorInfo().size); - VLOG(0) << "PULLSPASE" << feature_value_accessor_.GetAccessorInfo().size; + VLOG(5) << "PullSparse feature_value_size:" << feature_value_accessor_.GetAccessorInfo().size; #ifdef PADDLE_WITH_CUDA VLOG(3) << "Begine Gpu Ps PullSparse"; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index 1a6b0812293c25..692bb1bd838c37 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -41,7 +41,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/dynload/nccl.h" -// #include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h" #endif #ifdef PADDLE_WITH_XPU_KP #include "paddle/fluid/platform/device/xpu/enforce_xpu.h" @@ -359,7 +358,6 @@ class PSGPUWrapper { "mf_"); } - // CommonFeatureValueAccessor feature_value_accessor_; feature_value_accessor_.Configure(config); VLOG(0) << "INIT feature_value_accessor_:" << feature_value_accessor_.GetAccessorInfo().dim << " EMBX:" << feature_value_accessor_.common_feature_value.embedx_sgd_dim; diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index e555047b0e8abb..f4f2076cd12b79 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -3,7 +3,7 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software From ecc76cd373aef9ec2cd5250c21481fd2728456a7 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Thu, 16 Jun 2022 14:24:58 +0800 Subject: [PATCH 06/31] remove useless code;test=develop --- .../framework/fleet/heter_ps/heter_comm_inl.h | 15 ++++----------- paddle/fluid/framework/fleet/ps_gpu_wrapper.cc | 4 ---- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 28a38ab520c99d..80499b3ead97c0 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -72,8 +72,8 @@ HeterComm::HeterComm( } else { max_mf_dim_ = resource_->max_mf_dim(); feature_value_accessor_ = feature_value_accessor; - VLOG(0) << " HeterComm INIT:" << feature_value_accessor_.GetAccessorInfo().size - << " " << feature_value_accessor_.GetAccessorInfo().update_size; + VLOG(3) << " HeterComm init, feature_value_size:" << feature_value_accessor_.GetAccessorInfo().size + << ", feature_value_push_size:" << feature_value_accessor_.GetAccessorInfo().update_size; size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); size_t grad_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); auto ptr_table = new PtrTable(capacity / load_factor_); @@ -1101,15 +1101,8 @@ void HeterComm::push_sparse(int dev_num, if (h_left[i] == -1 || h_right[i] == -1) { continue; } - if (!multi_mf_dim_) { - create_storage(dev_num, - i, - shard_len * sizeof(KeyType), - shard_len * sizeof(GradType)); - } else { - create_storage( - dev_num, i, shard_len * sizeof(KeyType), shard_len * grad_value_size); - } + create_storage( + dev_num, i, shard_len * sizeof(KeyType), shard_len * grad_value_size); } walk_to_dest(dev_num, diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 3f4c382b25be7a..5cdbc7dd2a06a4 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -786,11 +786,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { ptr_val[cpu_table_accessor_->common_feature_value.EmbedG2SumIndex() + i]; } - // VLOG(0)<< "CpuPtrIndex:" << 
feature_value_accessor_.common_feature_value.CpuPtrIndex(); - // reinterpret_cast(val[feature_value_accessor_.common_feature_value.CpuPtrIndex()]) = - // (uint64_t)(device_dim_ptrs[k]); *(reinterpret_cast(val + feature_value_accessor_.common_feature_value.CpuPtrIndex())) = (uint64_t)(device_dim_ptrs[k]); - // (uint64_t*)(val + feature_value_accessor_.common_feature_value.CpuPtrIndex()) = (uint64_t)(device_dim_ptrs[k]); ptr_val[cpu_table_accessor_->common_feature_value.MfDimIndex()] = float(mf_dim); val[feature_value_accessor_.common_feature_value.MfDimIndex()] = mf_dim; if (dim > cpu_table_accessor_->GetAccessorInfo().dim - From 1cffb3ee60ae87df48ef8478ee81307664f64ddd Mon Sep 17 00:00:00 2001 From: danleifeng Date: Mon, 20 Jun 2022 19:34:22 +0800 Subject: [PATCH 07/31] fix adam; test=develop --- .../distributed/ps/table/ctr_dymf_accessor.cc | 23 ++++++-- .../distributed/ps/table/ctr_dymf_accessor.h | 13 +++++ .../framework/fleet/heter_ps/feature_value.h | 52 +++++++++++++++---- .../fleet/heter_ps/hashtable_kernel.cu | 30 ++++++++--- .../framework/fleet/heter_ps/heter_comm_inl.h | 3 +- .../fleet/heter_ps/heter_comm_kernel.cu | 2 + .../framework/fleet/heter_ps/optimizer.cuh.h | 13 +++-- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 12 +++-- .../fluid/framework/fleet/ps_gpu_wrapper.cu | 3 +- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 4 +- 10 files changed, 122 insertions(+), 33 deletions(-) diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc index a677c6c190177e..04b7295ea7a5ad 100644 --- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc @@ -31,6 +31,7 @@ int CtrDymfAccessor::Initialize() { _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name); _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(), _config.embedx_dim()); + common_feature_value.optimizer_name = name; common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim(); common_feature_value.embedx_dim = _config.embedx_dim(); @@ -63,6 +64,20 @@ void CtrDymfAccessor::InitAccessorInfo() { (embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float); } +void CtrDymfAccessor::DynamicChangeDim(int mf_dim) { + // 假设一个任务中sparse优化器是不变的,改变的只是不同slot的embedding维度,比如组网中既包括8维又有32维 + if (common_feature_value.optimizer_name == "SparseAdamSGDRule") {//adam + common_feature_value.embedx_sgd_dim = mf_dim * 2 + 2; + } else if (common_feature_value.optimizer_name == "SparseSharedAdamSGDRule") { //shared_adam + common_feature_value.embedx_sgd_dim = 4; + } else { + common_feature_value.embedx_sgd_dim = 1; + } + common_feature_value.embedx_dim = mf_dim; + + // InitAccessorInfo(); + } + bool CtrDymfAccessor::Shrink(float* value) { auto delete_after_unseen_days = _config.ctr_accessor_param().delete_after_unseen_days(); @@ -182,7 +197,8 @@ int32_t CtrDymfAccessor::Create(float** values, size_t num) { value[common_feature_value.SlotIndex()] = -1; value[common_feature_value.MfDimIndex()] = -1; _embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(), - value + common_feature_value.EmbedG2SumIndex()); + value + common_feature_value.EmbedG2SumIndex(), + false); // adam embed init not zero, adagrad embed init zero _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(), value + common_feature_value.EmbedxG2SumIndex(), false); @@ -295,13 +311,14 @@ std::string CtrDymfAccessor::ParseToString(const float* v, int param) { i++) { os << " " << v[i]; } - // os << " " << 
common_feature_value.Slot(const_cast(v)) << " " - // << common_feature_value.MfDim(const_cast(v)); auto show = common_feature_value.Show(const_cast(v)); auto click = common_feature_value.Click(const_cast(v)); auto score = ShowClickScore(show, click); + auto mf_dim = common_feature_value.MfDim(const_cast(v)); if (score >= _config.embedx_threshold() && param > common_feature_value.EmbedxG2SumIndex()) { + + DynamicChangeDim(int(mf_dim)); for (auto i = common_feature_value.EmbedxG2SumIndex(); i < common_feature_value.EmbedxWIndex() + common_feature_value.MfDim(const_cast(v)); diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h index d5b9acd8e9b258..720b7e5076350e 100644 --- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h @@ -58,6 +58,16 @@ class CtrDymfAccessor : public ValueAccessor { int MfDimIndex() { return SlotIndex() + 1; } int EmbedxG2SumIndex() { return MfDimIndex() + 1; } int EmbedxWIndex() { return EmbedxG2SumIndex() + embedx_sgd_dim; } + // int EmbedxWOffsetIndex(float* val) { + // if (optimizer_name == "SparseAdamSGDRule") {//adam + // embedx_sgd_dim = int(MfDim(val)) * 2 + 2; + // } else if (optimizer_name == "SparseSharedAdamSGDRule") { //shared_adam + // embedx_sgd_dim = 4; + // } else { + // embedx_sgd_dim = 1; + // } + // return EmbedxG2SumIndex() + embedx_sgd_dim; + // } float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; } float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; } @@ -73,6 +83,7 @@ class CtrDymfAccessor : public ValueAccessor { int embed_sgd_dim; int embedx_dim; int embedx_sgd_dim; + std::string optimizer_name; }; struct CtrDymfPushValue { @@ -151,6 +162,8 @@ class CtrDymfAccessor : public ValueAccessor { CtrDymfAccessor() {} virtual ~CtrDymfAccessor() {} virtual int Initialize(); + // 多种维度时更新目前的长度 + virtual void DynamicChangeDim(int mf_dim); // 初始化AccessorInfo virtual void InitAccessorInfo(); // 判断该value是否进行shrink diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 039c4563cccf37..c996d2cef34892 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -36,7 +36,7 @@ struct GpuAccessorInfo { size_t update_dim; // push value各个维度的size size_t update_size; - // value中mf动态长度部分总size大小, sparse下生效 + // value中mf动态长度部分总size大小, 包含mf_g2sum和 mf_dim, sparse下生效 size_t mf_size; }; @@ -96,6 +96,28 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { __host__ __device__ int MfSizeIndex() { return MfDimIndex() + 1; } // actual mf size (ex. 0) __host__ __device__ int EmbedxG2SumIndex() { return MfSizeIndex() + 1; } __host__ __device__ int EmbedxWIndex() { return EmbedxG2SumIndex() + embedx_sgd_dim; } + + __host__ __device__ int EmbedxG2SumOffsetIndex() { return 0; } + __host__ __device__ int EmbedxWOffsetIndex(float* val) { + // has mf + if (int(MfSize(val)) > 0) { + if (optimizer_type_ == 3) {//adam + embedx_sgd_dim = int(MfDim(val)) * 2 + 2; + } else if (optimizer_type_ == 4) { //shared_adam + embedx_sgd_dim = 4; + } else { + embedx_sgd_dim = 1; + } + // PADDLE_ENFORCE(embedx_sgd_dim + int(MfDim(val)) == int(MfSize(val)), + // "The number of embedx_sgd_dim size must be equal to mf_size." 
+ // "But got embedx_sgd_dim = %d, mf_size = %s", embedx_sgd_dim, int(MfSize(val))); + return EmbedxG2SumIndex() + embedx_sgd_dim; + } else { + // no mf + return 0; + } + } + __host__ __device__ uint64_t CpuPtr(float* val) {return *(reinterpret_cast(val)); } __host__ __device__ float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; } @@ -112,6 +134,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { int embed_sgd_dim; int embedx_dim; int embedx_sgd_dim; + int optimizer_type_; }; struct CommonPushValue { @@ -175,27 +198,36 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { common_feature_value.embed_sgd_dim = 1; common_feature_value.embedx_sgd_dim = 1; } - + common_feature_value.optimizer_type_ = optimizer_type; common_feature_value.embedx_dim = sparse_embedx_dim; - + // VLOG(0) << " INTO FeatureValueAccessor::Initialize()"; InitAccessorInfo(); return 0; } + __host__ __device__ virtual void DynamicChangeDim(int mf_dim) { + // 假设一个任务中sparse优化器是不变的,改变的只是不同slot的embedding维度,比如组网中既包括8维又有32维 + if (common_feature_value.optimizer_type_ == 3) { //adam + common_feature_value.embedx_sgd_dim = mf_dim * 2 + 2; + } else if (common_feature_value.optimizer_type_ == 4) { //shared_adam + common_feature_value.embedx_sgd_dim = 4; + } else { + common_feature_value.embedx_sgd_dim = 1; + } + common_feature_value.embedx_dim = mf_dim; + + InitAccessorInfo(); + } + // 初始化AccessorInfo __host__ __device__ virtual void InitAccessorInfo() { _accessor_info.dim = common_feature_value.Dim(); _accessor_info.size = common_feature_value.Size(); - - int embedx_dim = (_config.find("embedx_dim") == _config.end()) - ? 8 - : int(_config["embedx_dim"]); - // VLOG(0) << "feature value InitAccessorInfo embedx_dim:" << embedx_dim; - _accessor_info.update_dim = 5 + embedx_dim; + _accessor_info.update_dim = 5 + common_feature_value.EmbedWDim(); _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); _accessor_info.mf_size = - (embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float); + (common_feature_value.EmbedWDim() + common_feature_value.EmbedXDim()) * sizeof(float); } __host__ __device__ std::string ParseToString(const float* v, int param_size) { diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index d2ecde637327d1..47b1a8ed3b60ef 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -99,7 +99,8 @@ __global__ void dy_mf_search_kernel(Table* table, uint64_t offset = i * pull_feature_value_size; float* cur = (float*)(vals + offset); float* input = it->second; - + int mf_dim = int(input[feature_value_accessor.common_feature_value.MfDimIndex()]); + feature_value_accessor.DynamicChangeDim(mf_dim); cur[feature_value_accessor.common_feature_value.SlotIndex()] = input[feature_value_accessor.common_feature_value.SlotIndex()]; cur[feature_value_accessor.common_feature_value.ShowIndex()] = @@ -123,13 +124,17 @@ __global__ void dy_mf_search_kernel(Table* table, input[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; } - for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedXDim(); x++) { - cur[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x] = - input[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x]; - } - for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedWDim(); x++) { - 
cur[feature_value_accessor.common_feature_value.EmbedxWIndex() + x] = - input[feature_value_accessor.common_feature_value.EmbedxWIndex() + x]; + if (int(input[feature_value_accessor.common_feature_value.MfSizeIndex()]) > 0) { + for (int i =0; i < int(input[feature_value_accessor.common_feature_value.MfSizeIndex()]); i++){ + cur[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + i] = + input[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + i]; + } + // input[feature_value_accessor.common_feature_value.EmbedxWOffsetIndex(input) + j]; + } else { + for (int i =0; i < int(feature_value_accessor.GetAccessorInfo().mf_size / sizeof(float)); i++){ + cur[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + i] = + input[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + i]; + } } } else { if (keys[i] != 0) { @@ -361,6 +366,15 @@ void HashTable::update(const KeyType* d_keys, return; } const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; + + char* tmp_config = (char*)malloc(sizeof(OptimizerConfig)); + cudaMemcpy(tmp_config, device_optimizer_config_, + sizeof(OptimizerConfig), cudaMemcpyDeviceToHost); + VLOG(0) << "tmp_config: learning_rate:" << tmp_config->learning_rate + << " initial_range:" << tmp_config->initial_range + << " mf_learning_rate:" << tmp_config->mf_learning_rate + << " mf_initial_range:" << tmp_config->mf_initial_range + << " mf_initial_g2sum:" << tmp_config->mf_initial_g2sum; dy_mf_update_kernel<<>>( container_, *device_optimizer_config_, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 80499b3ead97c0..e81a5c447f5c07 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -72,7 +72,8 @@ HeterComm::HeterComm( } else { max_mf_dim_ = resource_->max_mf_dim(); feature_value_accessor_ = feature_value_accessor; - VLOG(3) << " HeterComm init, feature_value_size:" << feature_value_accessor_.GetAccessorInfo().size + feature_value_accessor_.DynamicChangeDim(max_mf_dim_); + VLOG(0) << " HeterComm init, max feature_value_size:" << feature_value_accessor_.GetAccessorInfo().size << ", feature_value_push_size:" << feature_value_accessor_.GetAccessorInfo().update_size; size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); size_t grad_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 15adbded3e01f2..3975f8dcd8aab4 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -200,6 +200,8 @@ __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, uint64_t new_offset = uint64_t(idx[i]) * val_size; float* cur = (float*)((char*)d_vals + new_offset); float* shard_val = (float*)((char*)d_shard_vals + uint64_t(i) * val_size); + int mf_dim = int(shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]); + feature_value_accessor.DynamicChangeDim(mf_dim); cur[feature_value_accessor.common_feature_value.SlotIndex()] = shard_val[feature_value_accessor.common_feature_value.SlotIndex()]; cur[feature_value_accessor.common_feature_value.ShowIndex()] = diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h index cdd6aa8b4a0690..a713aad70b5e4d 100644 --- 
a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h @@ -111,13 +111,15 @@ class SparseAdagradOptimizer : public Optimizer { g_show); int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); + feature_value_accessor_.DynamicChangeDim(mf_dim); if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= optimizer_config.nonclk_coeff * (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { - ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = mf_dim; + ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = + feature_value_accessor_.GetAccessorInfo().mf_size / sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; @@ -246,14 +248,15 @@ class SparseAdamOptimizer : public Optimizer { grad + feature_value_accessor_.common_push_value.EmbedGIndex(), g_show); int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); - // printf("mf_dim: %f, lr_gsum: %f, ", mf_dim, ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex()]); + feature_value_accessor_.DynamicChangeDim(mf_dim); if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= optimizer_config.nonclk_coeff * (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { - ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = mf_dim; + ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = + feature_value_accessor_.GetAccessorInfo().mf_size / sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; @@ -373,13 +376,15 @@ class SparseAdamSharedOptimizer : public Optimizer { grad + feature_value_accessor_.common_push_value.EmbedGIndex(), g_show); int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); + feature_value_accessor_.DynamicChangeDim(mf_dim); if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= optimizer_config.nonclk_coeff * (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { - ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = mf_dim; + ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = + feature_value_accessor_.GetAccessorInfo().mf_size / sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 5cdbc7dd2a06a4..e853467fef766a 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -661,6 +661,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_); // this->HeterPs_->set_accessor(feature_value_accessor_); int mf_dim = this->index_dim_vec_[j]; + feature_value_accessor_.DynamicChangeDim(mf_dim); VLOG(0) << 
"building table: " << i << "with mf dim: " << mf_dim << " feature_value_DIM:" << feature_value_accessor_.GetAccessorInfo().dim; size_t feature_value_size = @@ -787,6 +788,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { } *(reinterpret_cast(val + feature_value_accessor_.common_feature_value.CpuPtrIndex())) = (uint64_t)(device_dim_ptrs[k]); + ptr_val[cpu_table_accessor_->common_feature_value.MfDimIndex()] = float(mf_dim); val[feature_value_accessor_.common_feature_value.MfDimIndex()] = mf_dim; if (dim > cpu_table_accessor_->GetAccessorInfo().dim - @@ -983,7 +985,9 @@ void PSGPUWrapper::EndPass() { } // ============ multi-thread process feasign============ int mf_dim = this->index_dim_vec_[j]; - VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim << " key_len :" << len; + feature_value_accessor_.DynamicChangeDim(mf_dim); + VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim + << " key_len :" << len; size_t feature_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); char* test_build_values = (char*)malloc(feature_value_size * real_len); @@ -1126,8 +1130,9 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); size_t feature_value_size = 0; - feature_value_size = TYPEALIGN( 8, feature_value_accessor_.GetAccessorInfo().size); - VLOG(5) << "PullSparse feature_value_size:" << feature_value_accessor_.GetAccessorInfo().size; + feature_value_accessor_.DynamicChangeDim(max_mf_dim_); + feature_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); + VLOG(0) << "PullSparse max_dim:" << max_mf_dim_ << " feature_value_size:" << feature_value_accessor_.GetAccessorInfo().size; #ifdef PADDLE_WITH_CUDA VLOG(3) << "Begine Gpu Ps PullSparse"; @@ -1287,6 +1292,7 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL); // #ifdef PADDLE_WITH_CUDA VLOG(3) << "Begin GPUPS PushSparseGrad"; + feature_value_accessor_.DynamicChangeDim(max_mf_dim_); size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); auto buf = memory::Alloc(place, total_length * grad_value_size); diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu index 550baf5f50f3e7..5eff19e08ae29b 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -106,7 +106,8 @@ __global__ void PullCopy(float** dest, } } else { for (int j = 0; j < mf_dim; j++) { - *(dest[x] + y * (mf_dim + 3) + 3 + j) = feature_value_ptr[feature_value_accessor.common_feature_value.EmbedxWIndex() + j]; + *(dest[x] + y * (mf_dim + 3) + 3 + j) = + feature_value_ptr[feature_value_accessor.common_feature_value.EmbedxWOffsetIndex(feature_value_ptr) + j]; } } } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index 692bb1bd838c37..dd81a7ba081ffe 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -558,9 +558,7 @@ class PSGPUWrapper { slot_index_vec_[i] = dim_index_map[slot_mf_dim_vector_[i]]; } //TODO(FENGDANLEI): max_mf - VLOG(0) << "InitSlotInfo embed_sgd_dim_:" << embed_sgd_dim_ << " embedx_sgd_dim_:" - << embedx_sgd_dim_ << " embedx_dim_:" << embedx_dim_ - << " optimizer_type_:" << optimizer_type_; + feature_value_accessor_.DynamicChangeDim(max_mf_dim_); VLOG(0) << 
"InitSlotInfo:" << feature_value_accessor_.GetAccessorInfo().size; val_type_size_ =TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); grad_type_size_ = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); From 9a6ff5def5aa131f9ebdd7d26a26bc40b0636431 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 22 Jun 2022 00:24:28 +0800 Subject: [PATCH 08/31] fix adam; test=develop --- .../framework/fleet/heter_ps/feature_value.h | 39 +++++++++-- .../fleet/heter_ps/hashtable_kernel.cu | 29 ++------ .../framework/fleet/heter_ps/heter_comm_inl.h | 16 +++-- .../fleet/heter_ps/heter_comm_kernel.cu | 10 +-- .../framework/fleet/heter_ps/optimizer.cuh.h | 12 ++-- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 5 +- .../fluid/framework/fleet/ps_gpu_wrapper.cu | 8 +-- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 68 +++++++++---------- 8 files changed, 98 insertions(+), 89 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index c996d2cef34892..2bc81db51a58b6 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -81,7 +81,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { __host__ __device__ int Dim() { return 8 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; } // has cpu_ptr(1) __host__ __device__ int DimSize(size_t dim, int embedx_dim) { return sizeof(float); } - __host__ __device__ int Size() { return (Dim()-1) * sizeof(float) + sizeof(uint64_t); } // cpu_ptr:uint64 + __host__ __device__ int Size() { return (Dim() - 1) * sizeof(float) + sizeof(uint64_t); } // cpu_ptr:uint64 __host__ __device__ int EmbedDim() { return embed_sgd_dim;} __host__ __device__ int EmbedXDim() { return embedx_sgd_dim;} __host__ __device__ int EmbedWDim() { return embedx_dim;} @@ -97,21 +97,48 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { __host__ __device__ int EmbedxG2SumIndex() { return MfSizeIndex() + 1; } __host__ __device__ int EmbedxWIndex() { return EmbedxG2SumIndex() + embedx_sgd_dim; } + + // 根据mf_dim计算的总长度 + __host__ __device__ int Dim(int& mf_dim) { + int tmp_embedx_sgd_dim = 1; + if (optimizer_type_ == 3) {//adam + tmp_embedx_sgd_dim = mf_dim * 2 + 2; + } else if (optimizer_type_ == 4) { //shared_adam + tmp_embedx_sgd_dim = 4; + } + return 8 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim; + } + + // 根据mf_dim 计算的总byte数 + __host__ __device__ int Size(int& mf_dim) { + return (Dim(mf_dim) - 1) * sizeof(float) + sizeof(uint64_t); // cpu_ptr:2 + } + + // 根据mf_dim 计算的总byte数 + __host__ __device__ int MfSize(int& mf_dim) { + int tmp_embedx_sgd_dim = 1; + if (optimizer_type_ == 3) {//adam + tmp_embedx_sgd_dim = mf_dim * 2 + 2; + } else if (optimizer_type_ == 4) { //shared_adam + tmp_embedx_sgd_dim = 4; + } + return (tmp_embedx_sgd_dim + mf_dim) * sizeof(float); + } + __host__ __device__ int EmbedxG2SumOffsetIndex() { return 0; } __host__ __device__ int EmbedxWOffsetIndex(float* val) { // has mf + int tmp_embedx_sgd_dim = 1; if (int(MfSize(val)) > 0) { if (optimizer_type_ == 3) {//adam - embedx_sgd_dim = int(MfDim(val)) * 2 + 2; + tmp_embedx_sgd_dim = int(MfDim(val)) * 2 + 2; } else if (optimizer_type_ == 4) { //shared_adam - embedx_sgd_dim = 4; - } else { - embedx_sgd_dim = 1; + tmp_embedx_sgd_dim = 4; } // PADDLE_ENFORCE(embedx_sgd_dim + int(MfDim(val)) == int(MfSize(val)), // "The number of embedx_sgd_dim size must be equal to mf_size." 
// "But got embedx_sgd_dim = %d, mf_size = %s", embedx_sgd_dim, int(MfSize(val))); - return EmbedxG2SumIndex() + embedx_sgd_dim; + return EmbedxG2SumIndex() + tmp_embedx_sgd_dim; } else { // no mf return 0; diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 47b1a8ed3b60ef..940dab449ef9ec 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -99,8 +99,6 @@ __global__ void dy_mf_search_kernel(Table* table, uint64_t offset = i * pull_feature_value_size; float* cur = (float*)(vals + offset); float* input = it->second; - int mf_dim = int(input[feature_value_accessor.common_feature_value.MfDimIndex()]); - feature_value_accessor.DynamicChangeDim(mf_dim); cur[feature_value_accessor.common_feature_value.SlotIndex()] = input[feature_value_accessor.common_feature_value.SlotIndex()]; cur[feature_value_accessor.common_feature_value.ShowIndex()] = @@ -124,17 +122,13 @@ __global__ void dy_mf_search_kernel(Table* table, input[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; } - if (int(input[feature_value_accessor.common_feature_value.MfSizeIndex()]) > 0) { - for (int i =0; i < int(input[feature_value_accessor.common_feature_value.MfSizeIndex()]); i++){ - cur[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + i] = - input[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + i]; - } - // input[feature_value_accessor.common_feature_value.EmbedxWOffsetIndex(input) + j]; - } else { - for (int i =0; i < int(feature_value_accessor.GetAccessorInfo().mf_size / sizeof(float)); i++){ - cur[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + i] = - input[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + i]; - } + for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedXDim(); x++) { + cur[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x] = + input[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x]; + } + for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedWDim(); x++) { + cur[feature_value_accessor.common_feature_value.EmbedxWIndex() + x] = + input[feature_value_accessor.common_feature_value.EmbedxWIndex() + x]; } } else { if (keys[i] != 0) { @@ -366,15 +360,6 @@ void HashTable::update(const KeyType* d_keys, return; } const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; - - char* tmp_config = (char*)malloc(sizeof(OptimizerConfig)); - cudaMemcpy(tmp_config, device_optimizer_config_, - sizeof(OptimizerConfig), cudaMemcpyDeviceToHost); - VLOG(0) << "tmp_config: learning_rate:" << tmp_config->learning_rate - << " initial_range:" << tmp_config->initial_range - << " mf_learning_rate:" << tmp_config->mf_learning_rate - << " mf_initial_range:" << tmp_config->mf_initial_range - << " mf_initial_g2sum:" << tmp_config->mf_initial_g2sum; dy_mf_update_kernel<<>>( container_, *device_optimizer_config_, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index e81a5c447f5c07..63265dd6a19ae1 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -459,16 +459,18 @@ int HeterComm::get_index_by_devid(int devid) { template void HeterComm::set_sparse_sgd( const OptimizerConfig& optimizer_config) { - for (auto& table : tables_) { - table->set_sparse_sgd(optimizer_config); + for (int i = 0; i 
< resource_->total_device(); ++i) { + AnyDeviceGuard guard(resource_->dev_id(i)); + ptr_tables_[i]->set_sparse_sgd(optimizer_config); } } template void HeterComm::set_embedx_sgd( const OptimizerConfig& optimizer_config) { - for (auto& table : tables_) { - table->set_embedx_sgd(optimizer_config); + for (int i = 0; i < resource_->total_device(); ++i) { + AnyDeviceGuard guard(resource_->dev_id(i)); + ptr_tables_[i]->set_embedx_sgd(optimizer_config); } } @@ -915,7 +917,7 @@ void HeterComm::pull_sparse(int num, auto d_idx = memory::Alloc(place, len * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); - VLOG(5) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; + VLOG(3) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); auto d_shard_vals = memory::Alloc(place, len * val_type_size); @@ -996,6 +998,7 @@ void HeterComm::pull_sparse(int num, } destroy_storage(num, i); } + VLOG(0) << "pull sparse done"; } #if defined(PADDLE_WITH_CUDA) @@ -1148,6 +1151,9 @@ void HeterComm::push_sparse(int dev_num, } destroy_storage(dev_num, i); } + + VLOG(0) << " PUSHSPARSE destroy_storage done"; + } #elif defined(PADDLE_WITH_XPU_KP) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 3975f8dcd8aab4..b24919b2deea85 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -200,8 +200,8 @@ __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, uint64_t new_offset = uint64_t(idx[i]) * val_size; float* cur = (float*)((char*)d_vals + new_offset); float* shard_val = (float*)((char*)d_shard_vals + uint64_t(i) * val_size); - int mf_dim = int(shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]); - feature_value_accessor.DynamicChangeDim(mf_dim); + // int mf_dim = int(shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]); + // feature_value_accessor.DynamicChangeDim(mf_dim); cur[feature_value_accessor.common_feature_value.SlotIndex()] = shard_val[feature_value_accessor.common_feature_value.SlotIndex()]; cur[feature_value_accessor.common_feature_value.ShowIndex()] = @@ -214,8 +214,10 @@ __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, shard_val[feature_value_accessor.common_feature_value.EmbedWIndex()]; cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = shard_val[feature_value_accessor.common_feature_value.MfSizeIndex()]; - cur[feature_value_accessor.common_feature_value.CpuPtrIndex()] = - shard_val[feature_value_accessor.common_feature_value.CpuPtrIndex()]; + for (int i = 0; i < 2; i ++) { + cur[feature_value_accessor.common_feature_value.CpuPtrIndex() + i] = + shard_val[feature_value_accessor.common_feature_value.CpuPtrIndex() + i]; + } cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = shard_val[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h index a713aad70b5e4d..5231fc82b160d8 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h @@ -111,7 
+111,6 @@ class SparseAdagradOptimizer : public Optimizer { g_show); int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); - feature_value_accessor_.DynamicChangeDim(mf_dim); if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= optimizer_config.nonclk_coeff * @@ -119,7 +118,7 @@ class SparseAdagradOptimizer : public Optimizer { ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - feature_value_accessor_.GetAccessorInfo().mf_size / sizeof(float); + feature_value_accessor_.common_feature_value.MfSize(mf_dim) / sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; @@ -248,15 +247,14 @@ class SparseAdamOptimizer : public Optimizer { grad + feature_value_accessor_.common_push_value.EmbedGIndex(), g_show); int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); - feature_value_accessor_.DynamicChangeDim(mf_dim); if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= optimizer_config.nonclk_coeff * (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { - ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - feature_value_accessor_.GetAccessorInfo().mf_size / sizeof(float); + ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = + feature_value_accessor_.common_feature_value.MfSize(mf_dim) / sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; @@ -361,7 +359,6 @@ class SparseAdamSharedOptimizer : public Optimizer { float g_show = grad[feature_value_accessor_.common_push_value.ShowIndex()]; float g_click = grad[feature_value_accessor_.common_push_value.ClickIndex()]; - ptr[feature_value_accessor_.common_feature_value.SlotIndex()] = grad[feature_value_accessor_.common_push_value.SlotIndex()]; ptr[feature_value_accessor_.common_feature_value.ShowIndex()] += g_show; @@ -376,7 +373,6 @@ class SparseAdamSharedOptimizer : public Optimizer { grad + feature_value_accessor_.common_push_value.EmbedGIndex(), g_show); int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); - feature_value_accessor_.DynamicChangeDim(mf_dim); if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= optimizer_config.nonclk_coeff * @@ -384,7 +380,7 @@ class SparseAdamSharedOptimizer : public Optimizer { ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - feature_value_accessor_.GetAccessorInfo().mf_size / sizeof(float); + feature_value_accessor_.common_feature_value.MfSize(mf_dim) / sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index e853467fef766a..f0da42b855e3ad 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -655,6 +655,8 @@ void 
PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { HeterPs_ = HeterPsBase::get_instance(size_max, resource_, feature_value_accessor_, optimizer_type_); #ifdef PADDLE_WITH_CUDA HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_); + HeterPs_->set_sparse_sgd(optimizer_config_); + HeterPs_->set_embedx_sgd(optimizer_config_); #endif auto build_dymf_mem_pool = [this, &gpu_task](int i, int j) { @@ -663,7 +665,8 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { int mf_dim = this->index_dim_vec_[j]; feature_value_accessor_.DynamicChangeDim(mf_dim); VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim - << " feature_value_DIM:" << feature_value_accessor_.GetAccessorInfo().dim; + << " feature_value_dim:" << feature_value_accessor_.GetAccessorInfo().dim + << " feature_value_size:" << feature_value_accessor_.GetAccessorInfo().size; size_t feature_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu index 5eff19e08ae29b..94c7862fe36de3 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -395,8 +395,7 @@ void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float beta1_decay_rate, float beta2_decay_rate, float ada_epsilon) { - OptimizerConfig optimizer_config; - optimizer_config.set_sparse_sgd(nonclk_coeff, + optimizer_config_.set_sparse_sgd(nonclk_coeff, clk_coeff, min_bound, max_bound, @@ -406,7 +405,6 @@ void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, beta1_decay_rate, beta2_decay_rate, ada_epsilon); - HeterPs_->set_sparse_sgd(optimizer_config); } void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds, @@ -418,8 +416,7 @@ void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds, float mf_beta1_decay_rate, float mf_beta2_decay_rate, float mf_ada_epsilon) { - OptimizerConfig optimizer_config; - optimizer_config.set_embedx_sgd(mf_create_thresholds, + optimizer_config_.set_embedx_sgd(mf_create_thresholds, mf_learning_rate, mf_initial_g2sum, mf_initial_range, @@ -428,7 +425,6 @@ void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds, mf_beta1_decay_rate, mf_beta2_decay_rate, mf_ada_epsilon); - HeterPs_->set_embedx_sgd(optimizer_config); } } // end namespace framework diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index dd81a7ba081ffe..ec00473d586207 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -349,8 +349,9 @@ class PSGPUWrapper { config["embedx_dim"] = sparse_table_accessor.embedx_dim(); config["nonclk_coeff"] = sparse_table_accessor_parameter.nonclk_coeff(); config["clk_coeff"] = sparse_table_accessor_parameter.click_coeff(); - - + config["mf_create_thresholds"] = sparse_table_accessor.embedx_threshold(); + + if (accessor_class == "CtrDymfAccessor") { // optimizer config for embed_w and embedx add_sparse_optimizer(config, sparse_table_accessor.embed_sgd_param()); @@ -378,13 +379,13 @@ class PSGPUWrapper { ? 10.0 : config["max_bound"]; float learning_rate = (config.find("learning_rate") == config.end()) - ? 1.0 + ? 0.05 : config["learning_rate"]; float initial_g2sum = (config.find("initial_g2sum") == config.end()) - ? 1.0 + ? 3.0 : config["initial_g2sum"]; float initial_range = (config.find("initial_range") == config.end()) - ? 1.0 + ? 
1e-4 : config["initial_range"]; float beta1_decay_rate = (config.find("beta1_decay_rate") == config.end()) ? 0.9 @@ -401,19 +402,19 @@ class PSGPUWrapper { ? static_cast(1.0) : config["mf_create_thresholds"]; float mf_learning_rate = (config.find("mf_learning_rate") == config.end()) - ? 1.0 + ? 0.05 : config["mf_learning_rate"]; float mf_initial_g2sum = (config.find("mf_initial_g2sum") == config.end()) - ? 1.0 + ? 3.0 : config["mf_initial_g2sum"]; float mf_initial_range = (config.find("mf_initial_range") == config.end()) - ? 1.0 + ? 1e-4 : config["mf_initial_range"]; float mf_min_bound = (config.find("mf_min_bound") == config.end()) - ? 1.0 + ? -10.0 : config["mf_min_bound"]; float mf_max_bound = (config.find("mf_max_bound") == config.end()) - ? 1.0 + ? 10.0 : config["mf_max_bound"]; float mf_beta1_decay_rate = (config.find("mf_beta1_decay_rate") == config.end()) ? 0.9 @@ -424,32 +425,25 @@ class PSGPUWrapper { float mf_ada_epsilon = (config.find("mf_ada_epsilon") == config.end()) ? 1e-8 : config["mf_ada_epsilon"]; - for (size_t i = 0; i < heter_devices_.size(); i++) { -#ifdef PADDLE_WITH_CUDA - PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(heter_devices_[i])); -#elif defined(PADDLE_WITH_XPU_KP) - PADDLE_ENFORCE_XPU_SUCCESS(xpu_set_device(heter_devices_[i])); -#endif - this->SetSparseSGD(nonclk_coeff, - clk_coeff, - min_bound, - max_bound, - learning_rate, - initial_g2sum, - initial_range, - beta1_decay_rate, - beta2_decay_rate, - ada_epsilon); - this->SetEmbedxSGD(mf_create_thresholds, - mf_learning_rate, - mf_initial_g2sum, - mf_initial_range, - mf_min_bound, - mf_max_bound, - mf_beta1_decay_rate, - mf_beta2_decay_rate, - mf_ada_epsilon); - } + this->SetSparseSGD(nonclk_coeff, + clk_coeff, + min_bound, + max_bound, + learning_rate, + initial_g2sum, + initial_range, + beta1_decay_rate, + beta2_decay_rate, + ada_epsilon); + this->SetEmbedxSGD(mf_create_thresholds, + mf_learning_rate, + mf_initial_g2sum, + mf_initial_range, + mf_min_bound, + mf_max_bound, + mf_beta1_decay_rate, + mf_beta2_decay_rate, + mf_ada_epsilon); // set optimizer type(naive,adagrad,std_adagrad,adam,share_adam) optimizer_type_ = (config.find("optimizer_type") == config.end()) @@ -672,7 +666,7 @@ class PSGPUWrapper { bool running_ = false; std::vector> pull_thread_pool_; std::vector> hbm_thread_pool_; - + OptimizerConfig optimizer_config_; protected: static bool is_initialized_; }; From e3f9c28135a84b57d7adc9da92ea3d14e4659d0e Mon Sep 17 00:00:00 2001 From: danleifeng Date: Thu, 23 Jun 2022 19:18:09 +0800 Subject: [PATCH 09/31] fix adam; test=develop --- .../distributed/ps/table/ctr_dymf_accessor.cc | 21 +------ .../distributed/ps/table/ctr_dymf_accessor.h | 27 ++++---- .../framework/fleet/heter_ps/feature_value.h | 37 +++-------- .../fleet/heter_ps/hashtable_kernel.cu | 38 +++++------- .../framework/fleet/heter_ps/heter_comm_inl.h | 16 ++--- .../fleet/heter_ps/heter_comm_kernel.cu | 38 +++++------- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 62 ++++++++----------- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 10 ++- 8 files changed, 97 insertions(+), 152 deletions(-) diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc index 04b7295ea7a5ad..cd34c9e0e7ea3b 100644 --- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc @@ -64,20 +64,6 @@ void CtrDymfAccessor::InitAccessorInfo() { (embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float); } -void 
CtrDymfAccessor::DynamicChangeDim(int mf_dim) {
-  // Assume the sparse optimizer is fixed within one job; only the embedding dim differs across slots, e.g. a net mixing 8-dim and 32-dim slots.
-  if (common_feature_value.optimizer_name == "SparseAdamSGDRule") {  // adam
-    common_feature_value.embedx_sgd_dim = mf_dim * 2 + 2;
-  } else if (common_feature_value.optimizer_name == "SparseSharedAdamSGDRule") {  // shared_adam
-    common_feature_value.embedx_sgd_dim = 4;
-  } else {
-    common_feature_value.embedx_sgd_dim = 1;
-  }
-  common_feature_value.embedx_dim = mf_dim;
-
-  // InitAccessorInfo();
-  }
-
 bool CtrDymfAccessor::Shrink(float* value) {
   auto delete_after_unseen_days =
       _config.ctr_accessor_param().delete_after_unseen_days();
@@ -314,14 +300,11 @@ std::string CtrDymfAccessor::ParseToString(const float* v, int param) {
   auto show = common_feature_value.Show(const_cast<float*>(v));
   auto click = common_feature_value.Click(const_cast<float*>(v));
   auto score = ShowClickScore(show, click);
-  auto mf_dim = common_feature_value.MfDim(const_cast<float*>(v));
+  auto mf_dim = int(common_feature_value.MfDim(const_cast<float*>(v)));
   if (score >= _config.embedx_threshold() &&
       param > common_feature_value.EmbedxG2SumIndex()) {
-
-    DynamicChangeDim(int(mf_dim));
     for (auto i = common_feature_value.EmbedxG2SumIndex();
-         i < common_feature_value.EmbedxWIndex() +
-                 common_feature_value.MfDim(const_cast<float*>(v));
+         i < common_feature_value.Dim(mf_dim);
          ++i) {
       os << " " << v[i];
     }
diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h
index 720b7e5076350e..04ff2dbcd3a6dc 100644
--- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h
+++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h
@@ -58,16 +58,21 @@ class CtrDymfAccessor : public ValueAccessor {
   int MfDimIndex() { return SlotIndex() + 1; }
   int EmbedxG2SumIndex() { return MfDimIndex() + 1; }
   int EmbedxWIndex() { return EmbedxG2SumIndex() + embedx_sgd_dim; }
-  // int EmbedxWOffsetIndex(float* val) {
-  //   if (optimizer_name == "SparseAdamSGDRule") {  // adam
-  //     embedx_sgd_dim = int(MfDim(val)) * 2 + 2;
-  //   } else if (optimizer_name == "SparseSharedAdamSGDRule") {  // shared_adam
-  //     embedx_sgd_dim = 4;
-  //   } else {
-  //     embedx_sgd_dim = 1;
-  //   }
-  //   return EmbedxG2SumIndex() + embedx_sgd_dim;
-  // }
+
+  // total value length (in floats), computed from mf_dim
+  int Dim(int& mf_dim) {
+    int tmp_embedx_sgd_dim = 1;
+    if (optimizer_name == "SparseAdamSGDRule") {  // adam
+      tmp_embedx_sgd_dim = mf_dim * 2 + 2;
+    } else if (optimizer_name == "SparseSharedAdamSGDRule") {  // shared_adam
+      tmp_embedx_sgd_dim = 4;
+    }
+    return 7 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim;
+  }
+
+  // total value size (in bytes), computed from mf_dim
+  int Size(int& mf_dim) { return (Dim(mf_dim)) * sizeof(float); }
+
   float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; }
   float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; }
@@ -162,8 +167,6 @@ class CtrDymfAccessor : public ValueAccessor {
   CtrDymfAccessor() {}
   virtual ~CtrDymfAccessor() {}
   virtual int Initialize();
-  // update the cached lengths when multiple embedding dims are in use
-  virtual void DynamicChangeDim(int mf_dim);
   // initialize AccessorInfo
   virtual void InitAccessorInfo();
   // decide whether this value should be shrunk
diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h
index 2bc81db51a58b6..a9073d611343da 100644
--- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h
+++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h
@@ -79,9 +79,9 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor {
     std::vector<float> embedx_w;
   */
-  __host__ __device__ int Dim() { return 8 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; }
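The new Dim(mf_dim) above pins the CPU value layout to 7 fixed slots plus the embed-sgd block, the optimizer-dependent embedx-sgd block, and the embedding itself. A quick self-check; embed_sgd_dim = 4 is an assumption here, matching adam's 1*2+2:

    #include <cstdio>
    #include <string>

    // Mirrors CtrDymfAccessor::Dim(int&) from the hunk above.
    int Dim(const std::string& opt, int embed_sgd_dim, int mf_dim) {
      int tmp_embedx_sgd_dim = 1;  // naive / adagrad
      if (opt == "SparseAdamSGDRule") tmp_embedx_sgd_dim = mf_dim * 2 + 2;
      else if (opt == "SparseSharedAdamSGDRule") tmp_embedx_sgd_dim = 4;
      return 7 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim;
    }

    int main() {
      printf("%d\n", Dim("SparseAdamSGDRule", 4, 8));         // 7+4+18+8 = 37
      printf("%d\n", Dim("SparseSharedAdamSGDRule", 4, 64));  // 7+4+4+64 = 79
    }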
// has cpu_ptr(1) + __host__ __device__ int Dim() { return 9 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; } // has cpu_ptr(2) __host__ __device__ int DimSize(size_t dim, int embedx_dim) { return sizeof(float); } - __host__ __device__ int Size() { return (Dim() - 1) * sizeof(float) + sizeof(uint64_t); } // cpu_ptr:uint64 + __host__ __device__ int Size() { return Dim() * sizeof(float); } // cpu_ptr:uint64=2float __host__ __device__ int EmbedDim() { return embed_sgd_dim;} __host__ __device__ int EmbedXDim() { return embedx_sgd_dim;} __host__ __device__ int EmbedWDim() { return embedx_dim;} @@ -106,18 +106,18 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { } else if (optimizer_type_ == 4) { //shared_adam tmp_embedx_sgd_dim = 4; } - return 8 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim; + return 9 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim; } // 根据mf_dim 计算的总byte数 __host__ __device__ int Size(int& mf_dim) { - return (Dim(mf_dim) - 1) * sizeof(float) + sizeof(uint64_t); // cpu_ptr:2 + return Dim(mf_dim) * sizeof(float); // cpu_ptr:2float } - // 根据mf_dim 计算的总byte数 + // 根据mf_dim 计算的 mf_size byte数 __host__ __device__ int MfSize(int& mf_dim) { int tmp_embedx_sgd_dim = 1; - if (optimizer_type_ == 3) {//adam + if (optimizer_type_ == 3) { //adam tmp_embedx_sgd_dim = mf_dim * 2 + 2; } else if (optimizer_type_ == 4) { //shared_adam tmp_embedx_sgd_dim = 4; @@ -135,9 +135,6 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { } else if (optimizer_type_ == 4) { //shared_adam tmp_embedx_sgd_dim = 4; } - // PADDLE_ENFORCE(embedx_sgd_dim + int(MfDim(val)) == int(MfSize(val)), - // "The number of embedx_sgd_dim size must be equal to mf_size." - // "But got embedx_sgd_dim = %d, mf_size = %s", embedx_sgd_dim, int(MfSize(val))); return EmbedxG2SumIndex() + tmp_embedx_sgd_dim; } else { // no mf @@ -227,24 +224,9 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { } common_feature_value.optimizer_type_ = optimizer_type; common_feature_value.embedx_dim = sparse_embedx_dim; - - // VLOG(0) << " INTO FeatureValueAccessor::Initialize()"; - InitAccessorInfo(); - return 0; - } - - __host__ __device__ virtual void DynamicChangeDim(int mf_dim) { - // 假设一个任务中sparse优化器是不变的,改变的只是不同slot的embedding维度,比如组网中既包括8维又有32维 - if (common_feature_value.optimizer_type_ == 3) { //adam - common_feature_value.embedx_sgd_dim = mf_dim * 2 + 2; - } else if (common_feature_value.optimizer_type_ == 4) { //shared_adam - common_feature_value.embedx_sgd_dim = 4; - } else { - common_feature_value.embedx_sgd_dim = 1; - } - common_feature_value.embedx_dim = mf_dim; InitAccessorInfo(); + return 0; } // 初始化AccessorInfo @@ -279,13 +261,14 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { i < common_feature_value.SlotIndex(); i++) { os << " " << v[i]; } + int mf_dim = int(common_feature_value.MfDim(const_cast(v))); os << " slot: " << common_feature_value.Slot(const_cast(v)) - << " mf_dim: " << common_feature_value.MfDim(const_cast(v)) + << " mf_dim: " << mf_dim << " mf_size: " << common_feature_value.MfSize(const_cast(v)) << " mf: "; if (param_size > common_feature_value.EmbedxG2SumIndex()) { for (auto i = common_feature_value.EmbedxG2SumIndex(); - i < int(common_feature_value.Size() / sizeof(float)); ++i) { + i < common_feature_value.Dim(mf_dim); ++i) { os << " " << v[i]; } } diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 940dab449ef9ec..98d4c53803d942 100644 --- 
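The 8 -> 9 bump above is cpu_ptr bookkeeping: the host pointer is a uint64 stored in two float slots, so counting both slots in Dim() lets Size() become a plain Dim()*sizeof(float) instead of the old (Dim()-1)*sizeof(float)+sizeof(uint64_t) special case. Both formulas give the same byte count, checked here for an assumed adam value with embed_sgd_dim=4, embedx_sgd_dim=18, mf_dim=8:

    #include <cstdint>
    #include <cstdio>

    int main() {
      int dim_old = 8 + 4 + 18 + 8;  // cpu_ptr counted as one slot
      int dim_new = 9 + 4 + 18 + 8;  // cpu_ptr counted as two float slots
      printf("%zu\n", (dim_old - 1) * sizeof(float) + sizeof(uint64_t));  // 156
      printf("%zu\n", dim_new * sizeof(float));                           // 156
    }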
a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -99,36 +99,32 @@ __global__ void dy_mf_search_kernel(Table* table, uint64_t offset = i * pull_feature_value_size; float* cur = (float*)(vals + offset); float* input = it->second; - cur[feature_value_accessor.common_feature_value.SlotIndex()] = - input[feature_value_accessor.common_feature_value.SlotIndex()]; + int mf_dim = int(input[feature_value_accessor.common_feature_value.MfDimIndex()]); + + *(reinterpret_cast(cur + feature_value_accessor.common_feature_value.CpuPtrIndex())) = + *(reinterpret_cast(input + feature_value_accessor.common_feature_value.CpuPtrIndex())); + cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = + input[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; cur[feature_value_accessor.common_feature_value.ShowIndex()] = input[feature_value_accessor.common_feature_value.ShowIndex()]; cur[feature_value_accessor.common_feature_value.ClickIndex()] = input[feature_value_accessor.common_feature_value.ClickIndex()]; - cur[feature_value_accessor.common_feature_value.MfDimIndex()] = - input[feature_value_accessor.common_feature_value.MfDimIndex()]; cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = input[feature_value_accessor.common_feature_value.EmbedWIndex()]; + for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedDim(); x++) { + cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + x] = + input[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + x]; + } + cur[feature_value_accessor.common_feature_value.SlotIndex()] = + input[feature_value_accessor.common_feature_value.SlotIndex()]; + cur[feature_value_accessor.common_feature_value.MfDimIndex()] = + input[feature_value_accessor.common_feature_value.MfDimIndex()]; cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = input[feature_value_accessor.common_feature_value.MfSizeIndex()]; - cur[feature_value_accessor.common_feature_value.CpuPtrIndex()] = - input[feature_value_accessor.common_feature_value.CpuPtrIndex()]; - cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = - input[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; - cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = - input[feature_value_accessor.common_feature_value.EmbedWIndex()]; - for (int i = 0; i < feature_value_accessor.common_feature_value.EmbedDim(); i++) { - cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i] = - input[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; - } - for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedXDim(); x++) { - cur[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x] = - input[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x]; - } - for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedWDim(); x++) { - cur[feature_value_accessor.common_feature_value.EmbedxWIndex() + x] = - input[feature_value_accessor.common_feature_value.EmbedxWIndex() + x]; + for (int x = feature_value_accessor.common_feature_value.EmbedxG2SumIndex(); + x < int(feature_value_accessor.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ + cur[x] = input[x]; } } else { if (keys[i] != 0) { diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 63265dd6a19ae1..e38daa453dbc8a 100644 --- 
a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -72,11 +72,10 @@ HeterComm::HeterComm( } else { max_mf_dim_ = resource_->max_mf_dim(); feature_value_accessor_ = feature_value_accessor; - feature_value_accessor_.DynamicChangeDim(max_mf_dim_); - VLOG(0) << " HeterComm init, max feature_value_size:" << feature_value_accessor_.GetAccessorInfo().size - << ", feature_value_push_size:" << feature_value_accessor_.GetAccessorInfo().update_size; - size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); - size_t grad_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); + size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); + size_t grad_type_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + VLOG(0) << " HeterComm init, max feature_value_size:" << val_type_size + << ", feature_value_push_size:" << grad_type_size; auto ptr_table = new PtrTable(capacity / load_factor_); ptr_table->set_accessor(feature_value_accessor_); ptr_table->set_feature_value_size(val_type_size, grad_type_size); @@ -693,7 +692,7 @@ void HeterComm::dynamic_merge_grad( size_t temp_storage_bytes; - size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); + size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); @@ -916,7 +915,8 @@ void HeterComm::pull_sparse(int num, auto d_idx = memory::Alloc(place, len * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); - size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); + + size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); VLOG(3) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); @@ -1017,7 +1017,7 @@ void HeterComm::push_sparse(int dev_num, int dev_id = resource_->dev_id(dev_num); size_t grad_value_size = - TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); + TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); DevPlace place = DevPlace(dev_id); AnyDeviceGuard guard(dev_id); auto stream = resource_->local_stream(dev_num, 0); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index b24919b2deea85..790c744d3ee6a5 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -200,40 +200,32 @@ __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, uint64_t new_offset = uint64_t(idx[i]) * val_size; float* cur = (float*)((char*)d_vals + new_offset); float* shard_val = (float*)((char*)d_shard_vals + uint64_t(i) * val_size); - // int mf_dim = int(shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]); - // feature_value_accessor.DynamicChangeDim(mf_dim); - cur[feature_value_accessor.common_feature_value.SlotIndex()] = - shard_val[feature_value_accessor.common_feature_value.SlotIndex()]; + int mf_dim = int(shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]); + + *(reinterpret_cast(cur + 
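HeterComm now derives both staging sizes directly from the worst-case mf_dim, rounded up to an 8-byte boundary. A sketch of the arithmetic; TYPEALIGN is assumed to be the usual round-up-to-multiple macro, and the push layout (5 header floats plus mf_dim gradients) is taken from the CommonPushValue usage in this series:

    #include <cstddef>
    #include <cstdio>

    // Assumed TYPEALIGN(a, len) semantics: round len up to a multiple of a.
    constexpr size_t TypeAlign(size_t a, size_t len) {
      return (len + a - 1) / a * a;
    }

    int main() {
      int max_mf_dim = 8;
      size_t val_bytes = 39 * sizeof(float);                 // adam value, see above
      size_t grad_bytes = (5 + max_mf_dim) * sizeof(float);  // slot,show,click,mf_dim,embed_g + embedx_g
      printf("val_type_size=%zu grad_type_size=%zu\n",
             TypeAlign(8, val_bytes), TypeAlign(8, grad_bytes));  // 160, 56
    }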
feature_value_accessor.common_feature_value.CpuPtrIndex())) = + *(reinterpret_cast(shard_val + feature_value_accessor.common_feature_value.CpuPtrIndex())); + cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = + shard_val[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; cur[feature_value_accessor.common_feature_value.ShowIndex()] = shard_val[feature_value_accessor.common_feature_value.ShowIndex()]; cur[feature_value_accessor.common_feature_value.ClickIndex()] = shard_val[feature_value_accessor.common_feature_value.ClickIndex()]; - cur[feature_value_accessor.common_feature_value.MfDimIndex()] = - shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]; - cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = - shard_val[feature_value_accessor.common_feature_value.EmbedWIndex()]; - cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = - shard_val[feature_value_accessor.common_feature_value.MfSizeIndex()]; - for (int i = 0; i < 2; i ++) { - cur[feature_value_accessor.common_feature_value.CpuPtrIndex() + i] = - shard_val[feature_value_accessor.common_feature_value.CpuPtrIndex() + i]; - } - cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = - shard_val[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = shard_val[feature_value_accessor.common_feature_value.EmbedWIndex()]; for (int i = 0; i < feature_value_accessor.common_feature_value.EmbedDim(); i++) { cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i] = shard_val[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; } + cur[feature_value_accessor.common_feature_value.SlotIndex()] = + shard_val[feature_value_accessor.common_feature_value.SlotIndex()]; + cur[feature_value_accessor.common_feature_value.MfDimIndex()] = + shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]; + cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = + shard_val[feature_value_accessor.common_feature_value.MfSizeIndex()]; - for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedXDim(); x++) { - cur[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x] = - shard_val[feature_value_accessor.common_feature_value.EmbedxG2SumIndex() + x]; - } - for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedWDim(); x++) { - cur[feature_value_accessor.common_feature_value.EmbedxWIndex() + x] = - shard_val[feature_value_accessor.common_feature_value.EmbedxWIndex() + x]; + for (int x = feature_value_accessor.common_feature_value.EmbedxG2SumIndex(); + x < int(feature_value_accessor.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ + cur[x] = shard_val[x]; } } } diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index f0da42b855e3ad..77b7e231eda5aa 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -663,12 +663,11 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_); // this->HeterPs_->set_accessor(feature_value_accessor_); int mf_dim = this->index_dim_vec_[j]; - feature_value_accessor_.DynamicChangeDim(mf_dim); VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim - << " feature_value_dim:" << feature_value_accessor_.GetAccessorInfo().dim - << " feature_value_size:" << feature_value_accessor_.GetAccessorInfo().size; + << " 
feature_value_dim:" << feature_value_accessor_.common_feature_value.Dim(mf_dim) + << " feature_value_size:" << feature_value_accessor_.common_feature_value.Size(mf_dim); size_t feature_value_size = - TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); + TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(mf_dim)); auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; size_t len = device_dim_keys.size(); @@ -797,25 +796,20 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { if (dim > cpu_table_accessor_->GetAccessorInfo().dim - cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float)) { val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - feature_value_accessor_.GetAccessorInfo().mf_size / sizeof(float); - for (int x = 0; x < feature_value_accessor_.common_feature_value.EmbedXDim(); x++) { - val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x] = - ptr_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x]; - } - for (int x = 0; x < mf_dim; x++) { - val[feature_value_accessor_.common_feature_value.EmbedxWIndex() + x] = - ptr_val[cpu_table_accessor_->common_feature_value.EmbedxWIndex() + x]; + feature_value_accessor_.common_feature_value.MfSize(mf_dim) / sizeof(float); + + for (int x = feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(); + x < int(feature_value_accessor_.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ + val[x] = ptr_val[x]; } } else { val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = 0; - for (int x = 0; x < feature_value_accessor_.common_feature_value.EmbedXDim(); x++) { - val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x] = 0; - } - for (int x = 0; x < mf_dim; x++) { - val[feature_value_accessor_.common_feature_value.EmbedxWIndex() + x] = 0; + for (int x = feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(); + x < int(feature_value_accessor_.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ + val[x] = 0; } } - VLOG(5) << "build "<< k << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.GetAccessorInfo().dim); + VLOG(5) << "build "<< k << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.common_feature_value.Dim(mf_dim)); } #endif @@ -988,11 +982,11 @@ void PSGPUWrapper::EndPass() { } // ============ multi-thread process feasign============ int mf_dim = this->index_dim_vec_[j]; - feature_value_accessor_.DynamicChangeDim(mf_dim); - VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim - << " key_len :" << len; size_t feature_value_size = - TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); + TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(mf_dim)); + VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim + << " key_len :" << len << " feature_value_size:" << feature_value_size; + char* test_build_values = (char*)malloc(feature_value_size * real_len); uint64_t offset = left * feature_value_size; cudaMemcpy(test_build_values, @@ -1042,7 +1036,7 @@ void PSGPUWrapper::EndPass() { if (gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0 && downpour_value_size == (cpu_table_accessor_->GetAccessorInfo().dim - cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float))) { // cpu_accessor - downpour_value->resize(cpu_table_accessor_->GetAccessorInfo().dim); + 
downpour_value->resize(cpu_table_accessor_->common_feature_value.Dim(mf_dim));
       }
       float* cpu_val = downpour_value->data();
@@ -1063,16 +1057,14 @@ void PSGPUWrapper::EndPass() {
         }
         if (gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0) {
-          for (int x = 0; x < feature_value_accessor_.common_feature_value.EmbedXDim(); x++) {
-            cpu_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x] =
+
+          for (int x = 0; x < int(feature_value_accessor_.common_feature_value.MfSize(mf_dim) / sizeof(float));
+               x++) {
+            cpu_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x] =
               gpu_val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x];
           }
-          for (int x = 0; x < feature_value_accessor_.common_feature_value.EmbedWDim(); x++) {
-            cpu_val[cpu_table_accessor_->common_feature_value.EmbedxWIndex() + x] =
-              gpu_val[feature_value_accessor_.common_feature_value.EmbedxWIndex() + x];
-          }
         }
-        VLOG(5) << "dump to cpu "<< index << " : "<< feature_value_accessor_.ParseToString(gpu_val, feature_value_accessor_.GetAccessorInfo().dim)
+        VLOG(5) << "dump to cpu "<< index << " : "<< feature_value_accessor_.ParseToString(gpu_val, feature_value_accessor_.common_feature_value.Dim(mf_dim))
                 << " ===== CPU:" << cpu_table_accessor_->ParseToString(cpu_val, downpour_value->size());
       }
@@ -1133,9 +1125,8 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
   size_t feature_value_size = 0;
-  feature_value_accessor_.DynamicChangeDim(max_mf_dim_);
-  feature_value_size = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size);
-  VLOG(0) << "PullSparse max_dim:" << max_mf_dim_ << " feature_value_size:" << feature_value_accessor_.GetAccessorInfo().size;
+  feature_value_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_));
+  VLOG(0) << "PullSparse max_dim:" << max_mf_dim_ << " feature_value_size:" << feature_value_size;

 #ifdef PADDLE_WITH_CUDA
   VLOG(3) << "Begin Gpu Ps PullSparse";
@@ -1295,11 +1286,10 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
   // #ifdef PADDLE_WITH_CUDA
   VLOG(3) << "Begin GPUPS PushSparseGrad";
-  feature_value_accessor_.DynamicChangeDim(max_mf_dim_);
   size_t grad_value_size =
-      TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size);
+      TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_));
   auto buf = memory::Alloc(place, total_length * grad_value_size);
-  VLOG(3) << "Push Sparse Max mf dimension: " << max_mf_dim_;
+  VLOG(3) << "Push Sparse Max mf dimension: " << max_mf_dim_ << " grad_value_size:" << grad_value_size;
   float* total_grad_values_gpu = reinterpret_cast<float*>(buf->ptr());
   if (platform::is_cpu_place(place)) {
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
index ec00473d586207..fdd740f7d0f1e2 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -361,7 +361,7 @@ class PSGPUWrapper {
     feature_value_accessor_.Configure(config);
     VLOG(0) << "INIT feature_value_accessor_:" << feature_value_accessor_.GetAccessorInfo().dim
-            << " EMBX:" << feature_value_accessor_.common_feature_value.embedx_sgd_dim;
+            << " embedx_sgd_dim:" << feature_value_accessor_.common_feature_value.embedx_sgd_dim;
     InitializeGPUServer(config);
   }
 #endif
@@ -551,11 +551,9 @@ class PSGPUWrapper {
    for (size_t i = 0; i < slot_index_vec_.size(); i++)
{ slot_index_vec_[i] = dim_index_map[slot_mf_dim_vector_[i]]; } - //TODO(FENGDANLEI): max_mf - feature_value_accessor_.DynamicChangeDim(max_mf_dim_); - VLOG(0) << "InitSlotInfo:" << feature_value_accessor_.GetAccessorInfo().size; - val_type_size_ =TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().size); - grad_type_size_ = TYPEALIGN(8, feature_value_accessor_.GetAccessorInfo().update_size); + val_type_size_ = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); + grad_type_size_ = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + VLOG(0) << "InitSlotInfo: val_type_size_" << val_type_size_ << " grad_type_size_:" << grad_type_size_; slot_info_initialized_ = true; } #endif From ef8a7130fa8d5d1092eff9d5ddf785a97c867a65 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Thu, 23 Jun 2022 19:21:40 +0800 Subject: [PATCH 10/31] remove useless code;test=develop --- paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index e38daa453dbc8a..2e2d4b85174713 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -998,7 +998,6 @@ void HeterComm::pull_sparse(int num, } destroy_storage(num, i); } - VLOG(0) << "pull sparse done"; } #if defined(PADDLE_WITH_CUDA) @@ -1148,12 +1147,10 @@ void HeterComm::push_sparse(int dev_num, for (int i = 0; i < total_device; ++i) { if (h_left[i] == -1 || h_right[i] == -1) { continue; + } destroy_storage(dev_num, i); } - - VLOG(0) << " PUSHSPARSE destroy_storage done"; - } #elif defined(PADDLE_WITH_XPU_KP) From dda7284319047b6a3e0c9fdcd15edf83190c9a82 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Mon, 27 Jun 2022 17:04:28 +0800 Subject: [PATCH 11/31] fix adam; test=develop --- .../fluid/framework/fleet/heter_ps/feature_value.h | 2 +- .../fluid/framework/fleet/heter_ps/optimizer.cuh.h | 14 +++----------- paddle/fluid/framework/fleet/ps_gpu_wrapper.cc | 13 +++++++------ 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index a9073d611343da..53011571bcb230 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -115,7 +115,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { } // 根据mf_dim 计算的 mf_size byte数 - __host__ __device__ int MfSize(int& mf_dim) { + __host__ __device__ int MFSize(int& mf_dim) { int tmp_embedx_sgd_dim = 1; if (optimizer_type_ == 3) { //adam tmp_embedx_sgd_dim = mf_dim * 2 + 2; diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h index 5231fc82b160d8..96116ba954a07b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h @@ -118,7 +118,7 @@ class SparseAdagradOptimizer : public Optimizer { ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - feature_value_accessor_.common_feature_value.MfSize(mf_dim) / sizeof(float); + feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); int tid_x = blockIdx.x * 
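For context on the optimizer hunks around this point: every optimizer guards embedx creation with the same rule. Once nonclk_coeff*(show-click)+clk_coeff*click reaches mf_create_thresholds, MfSize is stamped and the embedx weights are drawn uniformly from [0, mf_initial_range). A reduced sketch of that step; the real code indexes into the packed value through the accessor, and the names here mirror the OptimizerConfig fields:

    #include <curand_kernel.h>

    // Reduced sketch of the mf-creation step; packed-value indexing elided.
    __global__ void MaybeCreateMf(float* embedx_w, int mf_dim, float show,
                                  float click, float nonclk_coeff,
                                  float clk_coeff, float create_threshold,
                                  float initial_range,
                                  unsigned long long seed) {
      float score = nonclk_coeff * (show - click) + clk_coeff * click;
      if (create_threshold <= score) {
        int tid = blockIdx.x * blockDim.x + threadIdx.x;
        curandState state;
        curand_init(seed, tid, 0, &state);
        for (int i = 0; i < mf_dim; ++i) {
          embedx_w[i] = curand_uniform(&state) * initial_range;
        }
      }
    }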
blockDim.x + threadIdx.x; curandState state; @@ -254,7 +254,7 @@ class SparseAdamOptimizer : public Optimizer { ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - feature_value_accessor_.common_feature_value.MfSize(mf_dim) / sizeof(float); + feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; @@ -263,10 +263,6 @@ class SparseAdamOptimizer : public Optimizer { ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = (curand_uniform(&state)) * optimizer_config.mf_initial_range; } - ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + Beta1PowIndex()] = - optimizer_config.beta1_decay_rate; - ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + Beta2PowIndex()] = - optimizer_config.beta2_decay_rate; ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta1PowIndex()] = optimizer_config.beta1_decay_rate; ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta2PowIndex()] = @@ -380,7 +376,7 @@ class SparseAdamSharedOptimizer : public Optimizer { ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - feature_value_accessor_.common_feature_value.MfSize(mf_dim) / sizeof(float); + feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; @@ -389,10 +385,6 @@ class SparseAdamSharedOptimizer : public Optimizer { ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = (curand_uniform(&state)) * optimizer_config.mf_initial_range; } - ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + Beta1PowIndex()] = - optimizer_config.beta1_decay_rate; - ptr[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + Beta2PowIndex()] = - optimizer_config.beta2_decay_rate; ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta1PowIndex()] = optimizer_config.beta1_decay_rate; ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta2PowIndex()] = diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 77b7e231eda5aa..120d8007a5583c 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -796,11 +796,12 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { if (dim > cpu_table_accessor_->GetAccessorInfo().dim - cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float)) { val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - feature_value_accessor_.common_feature_value.MfSize(mf_dim) / sizeof(float); + feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); - for (int x = feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(); - x < int(feature_value_accessor_.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ - val[x] = ptr_val[x]; + for (int x = 0; x < int(feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float)); + x++) { + val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x] = + 
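The deleted lines above appear to be the core of this "fix adam" commit: creating the embedx block must not re-seed the embed_w optimizer state, so only the embedx beta1/beta2 powers are initialized there. For reference, a host-side sketch of the standard per-feature Adam step those stored powers drive (the packed layout and accessor indexing are elided; this is not the exact device code):

    #include <cmath>

    // beta1_pow/beta2_pow start at the decay rates, as initialized above,
    // and are advanced after every update so step t sees beta^t.
    void AdamStep(float* w, float* mom1, float* mom2, float* beta1_pow,
                  float* beta2_pow, float g, float lr, float beta1,
                  float beta2, float eps) {
      *mom1 = beta1 * (*mom1) + (1.0f - beta1) * g;
      *mom2 = beta2 * (*mom2) + (1.0f - beta2) * g * g;
      float m_hat = *mom1 / (1.0f - *beta1_pow);
      float v_hat = *mom2 / (1.0f - *beta2_pow);
      *w -= lr * m_hat / (std::sqrt(v_hat) + eps);
      *beta1_pow *= beta1;
      *beta2_pow *= beta2;
    }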
ptr_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x]; } } else { val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = 0; @@ -1035,7 +1036,7 @@ void PSGPUWrapper::EndPass() { size_t downpour_value_size = downpour_value->size(); if (gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0 && downpour_value_size == (cpu_table_accessor_->GetAccessorInfo().dim - - cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float))) { // cpu_accessor + int(cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float)))) { // cpu_accessor downpour_value->resize(cpu_table_accessor_->common_feature_value.Dim(mf_dim)); } float* cpu_val = downpour_value->data(); @@ -1058,7 +1059,7 @@ void PSGPUWrapper::EndPass() { if (gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0) { - for (int x = 0; x < int(feature_value_accessor_.common_feature_value.MfSize(mf_dim) / sizeof(float)); + for (int x = 0; x < int(feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float)); x++) { cpu_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x] = gpu_val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x]; From 811d61f3e34530f2b41a1bf5abce57861c3966aa Mon Sep 17 00:00:00 2001 From: danleifeng Date: Mon, 27 Jun 2022 18:01:06 +0800 Subject: [PATCH 12/31] remove useless code;test=develop --- .../framework/fleet/heter_ps/feature_value.h | 27 ------------------- .../fleet/heter_ps/heter_comm_kernel.cu | 3 +-- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 2 +- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 2 -- 4 files changed, 2 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 53011571bcb230..221915fc713a82 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -27,19 +27,6 @@ namespace framework { typedef uint64_t FeatureKey; -struct GpuAccessorInfo { - // value维度 - size_t dim; - // value各个维度的size - size_t size; - // push value维度 - size_t update_dim; - // push value各个维度的size - size_t update_size; - // value中mf动态长度部分总size大小, 包含mf_g2sum和 mf_dim, sparse下生效 - size_t mf_size; -}; - class FeatureValueAccessor { public: __host__ __device__ FeatureValueAccessor() {} @@ -52,11 +39,8 @@ class FeatureValueAccessor { } __host__ __device__ virtual int Initialize() = 0; - __host__ __device__ virtual GpuAccessorInfo GetAccessorInfo() { return _accessor_info; } - protected: std::unordered_map _config; - GpuAccessorInfo _accessor_info; }; // adagrad: embed_sgd_dim=1, embedx_sgd_dim=1,embedx_dim=n @@ -225,20 +209,9 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { common_feature_value.optimizer_type_ = optimizer_type; common_feature_value.embedx_dim = sparse_embedx_dim; - InitAccessorInfo(); return 0; } - // 初始化AccessorInfo - __host__ __device__ virtual void InitAccessorInfo() { - _accessor_info.dim = common_feature_value.Dim(); - _accessor_info.size = common_feature_value.Size(); - _accessor_info.update_dim = 5 + common_feature_value.EmbedWDim(); - _accessor_info.update_size = _accessor_info.update_dim * sizeof(float); - _accessor_info.mf_size = - (common_feature_value.EmbedWDim() + common_feature_value.EmbedXDim()) * sizeof(float); - } - __host__ __device__ std::string ParseToString(const float* v, int param_size) { /* uint64_t cpu_ptr; // 2float diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu 
b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
index 790c744d3ee6a5..aaa6bdc1787602 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
@@ -218,8 +218,7 @@ __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals,
     }
     cur[feature_value_accessor.common_feature_value.SlotIndex()] =
         shard_val[feature_value_accessor.common_feature_value.SlotIndex()];
-    cur[feature_value_accessor.common_feature_value.MfDimIndex()] =
-        shard_val[feature_value_accessor.common_feature_value.MfDimIndex()];
+    cur[feature_value_accessor.common_feature_value.MfDimIndex()] = mf_dim;
     cur[feature_value_accessor.common_feature_value.MfSizeIndex()] =
         shard_val[feature_value_accessor.common_feature_value.MfSizeIndex()];

diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
index 120d8007a5583c..e545956844672a 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -1127,7 +1127,7 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
   size_t feature_value_size = 0;
   feature_value_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_));
-  VLOG(0) << "PullSparse max_dim:" << max_mf_dim_ << " feature_value_size:" << feature_value_size;
+  VLOG(3) << "PullSparse max_dim:" << max_mf_dim_ << " feature_value_size:" << feature_value_size;

 #ifdef PADDLE_WITH_CUDA
   VLOG(3) << "Begin Gpu Ps PullSparse";
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
index fdd740f7d0f1e2..555de56f1e0034 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -360,8 +360,6 @@ class PSGPUWrapper {
     }
     feature_value_accessor_.Configure(config);
-    VLOG(0) << "INIT feature_value_accessor_:" << feature_value_accessor_.GetAccessorInfo().dim
-            << " embedx_sgd_dim:" << feature_value_accessor_.common_feature_value.embedx_sgd_dim;
     InitializeGPUServer(config);
   }
 #endif

From 106c2adb2286c28fea3f7d85a01a98bd9ae797af Mon Sep 17 00:00:00 2001
From: danleifeng
Date: Tue, 28 Jun 2022 11:56:36 +0800
Subject: [PATCH 13/31] remove useless code;test=develop

---
 python/paddle/distributed/ps/the_one_ps.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py
index 7d240983a1c289..42d7c11eab2631 100755
--- a/python/paddle/distributed/ps/the_one_ps.py
+++ b/python/paddle/distributed/ps/the_one_ps.py
@@ -824,7 +824,6 @@ def build_worker_desc(self):
             self.barrier_table_id = table.idx
         self.service._set(
             self.ps_desc.server_param.downpour_server_param.service_param)
-        self.fs_client._set(self.ps_desc.fs_client_param)
         return text_format.MessageToString(self.ps_desc)

     def build_server_desc(self):

From 7814b04ce90782f7c49af9e0596fe7f24ae44ab4 Mon Sep 17 00:00:00 2001
From: danleifeng
Date: Fri, 1 Jul 2022 16:03:38 +0800
Subject: [PATCH 14/31] [gpups]refine adam accessor;test=develop

---
 .../framework/fleet/heter_ps/feature_value.h | 208 ++++++++++++++++++
 .../fleet/heter_ps/hashtable_kernel.cu       |  30 +--
 .../fleet/heter_ps/heter_comm_kernel.cu      |  41 +---
 .../fleet/heter_ps/heter_comm_kernel.h       |  27 +--
 .../fluid/framework/fleet/ps_gpu_wrapper.cc  |  61 +----
 .../fluid/framework/fleet/ps_gpu_wrapper.cu  |  21 +-
 6 files changed, 217 insertions(+), 171 deletions(-)

diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h
b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 221915fc713a82..2486f34cc0e256 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h" namespace paddle { @@ -185,6 +186,35 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { } }; + struct CommonPullValue { + /* + float show; + float click; + float embed_w; + std::vector embedx_w; + */ + + __host__ __device__ static int Dim(int embedx_dim) { return 3 + embedx_dim; } + __host__ __device__ int DimSize(size_t dim) { return sizeof(float); } + __host__ __device__ int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } + __host__ __device__ int ShowIndex() { return 0; } + __host__ __device__ int ClickIndex() { return 1; } + __host__ __device__ int EmbedWIndex() { return 2; } + __host__ __device__ int EmbedxWIndex() { return 3; } + __host__ __device__ float& Show(float* val) { + return val[CommonPullValue::ShowIndex()]; + } + __host__ __device__ float& Click(float* val) { + return val[CommonPullValue::ClickIndex()]; + } + __host__ __device__ float& EmbedW(float* val) { + return val[CommonPullValue::EmbedWIndex()]; + } + __host__ __device__ float* EmbedxW(float* val) { + return val + CommonPullValue::EmbedxWIndex(); + } + }; + __host__ __device__ CommonFeatureValueAccessor() {} __host__ __device__ ~CommonFeatureValueAccessor() {} @@ -212,6 +242,183 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { return 0; } + +// build阶段从cpu_val赋值给gpu_val +__host__ __device__ void BuildFill(float* gpu_val, + float* cpu_val, + paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + int mf_dim, + size_t cpu_fv_dim) { + + gpu_val[common_feature_value.DeltaScoreIndex()] = + cpu_val[cpu_table_accessor->common_feature_value.DeltaScoreIndex()]; + gpu_val[common_feature_value.ShowIndex()] = + cpu_val[cpu_table_accessor->common_feature_value.ShowIndex()]; + gpu_val[common_feature_value.ClickIndex()] = + cpu_val[cpu_table_accessor->common_feature_value.ClickIndex()]; + gpu_val[common_feature_value.SlotIndex()] = + cpu_val[cpu_table_accessor->common_feature_value.SlotIndex()]; + gpu_val[common_feature_value.EmbedWIndex()] = + cpu_val[cpu_table_accessor->common_feature_value.EmbedWIndex()]; + for (int i = 0; i < common_feature_value.EmbedDim(); i++) { + gpu_val[common_feature_value.EmbedG2SumIndex() + i] = + cpu_val[cpu_table_accessor->common_feature_value.EmbedG2SumIndex() + i]; + } + + cpu_val[cpu_table_accessor->common_feature_value.MfDimIndex()] = float(mf_dim); + gpu_val[common_feature_value.MfDimIndex()] = mf_dim; + if (cpu_fv_dim > cpu_table_accessor->GetAccessorInfo().dim - + cpu_table_accessor->GetAccessorInfo().mf_size / sizeof(float)) { + gpu_val[common_feature_value.MfSizeIndex()] = + common_feature_value.MFSize(mf_dim) / sizeof(float); + + for (int x = 0; x < int(common_feature_value.MFSize(mf_dim) / sizeof(float)); + x++) { + gpu_val[common_feature_value.EmbedxG2SumIndex() + x] = + cpu_val[cpu_table_accessor->common_feature_value.EmbedxG2SumIndex() + x]; + } + } else { + gpu_val[common_feature_value.MfSizeIndex()] = 0; + for (int x = common_feature_value.EmbedxG2SumIndex(); + x < int(common_feature_value.Size(mf_dim) / sizeof(float)); x++){ + gpu_val[x] = 0; + } + } +} + + +// dump_to_cpu阶段从gpu_val赋值给cpu_val +__host__ __device__ void DumpFill(float* cpu_val, + float* gpu_val, + 
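CommonPullValue above finally gives the pull path its own compact layout, show, click, embed_w, then mf_dim embedx weights, instead of reusing the full training value. Reading one pulled record, as a sketch with the index constants from the struct:

    #include <cstdio>
    #include <vector>

    int main() {
      // One pulled record in the CommonPullValue layout:
      // [0]=show [1]=click [2]=embed_w [3..3+mf_dim)=embedx_w
      int mf_dim = 8;
      std::vector<float> rec(3 + mf_dim, 0.0f);
      rec[0] = 5.0f; rec[1] = 1.0f; rec[2] = 0.25f;
      printf("show=%.1f click=%.1f embed_w=%.2f embedx_w[0]=%.2f\n",
             rec[0], rec[1], rec[2], rec[3]);
    }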
paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + int mf_dim) { + + cpu_val[cpu_table_accessor->common_feature_value.DeltaScoreIndex()] = + gpu_val[common_feature_value.DeltaScoreIndex()]; + cpu_val[cpu_table_accessor->common_feature_value.ShowIndex()] = + gpu_val[common_feature_value.ShowIndex()]; + cpu_val[cpu_table_accessor->common_feature_value.ClickIndex()] = + gpu_val[common_feature_value.ClickIndex()]; + cpu_val[cpu_table_accessor->common_feature_value.EmbedWIndex()] = + gpu_val[common_feature_value.EmbedWIndex()]; + cpu_val[cpu_table_accessor->common_feature_value.SlotIndex()] = + gpu_val[common_feature_value.SlotIndex()]; + + for (int i = 0; i < common_feature_value.EmbedDim(); i++) { + cpu_val[cpu_table_accessor->common_feature_value.EmbedG2SumIndex() + i] = + gpu_val[common_feature_value.EmbedG2SumIndex() + i]; + } + + if (gpu_val[common_feature_value.MfSizeIndex()] > 0) { + + for (int x = 0; x < int(common_feature_value.MFSize(mf_dim) / sizeof(float)); + x++) { + cpu_val[cpu_table_accessor->common_feature_value.EmbedxG2SumIndex() + x] = + gpu_val[common_feature_value.EmbedxG2SumIndex() + x]; + } + } +} + + +// dy_mf_fill_dvals_kernel, dy_mf_search_kernel 阶段 gpukernel 中从src_val赋值给dest_val +__host__ __device__ void FeatureValueFill(float* dest_val, + float* src_val, + int mf_dim) { + *(reinterpret_cast(dest_val + common_feature_value.CpuPtrIndex())) = + *(reinterpret_cast(src_val + common_feature_value.CpuPtrIndex())); + dest_val[common_feature_value.DeltaScoreIndex()] = src_val[common_feature_value.DeltaScoreIndex()]; + dest_val[common_feature_value.ShowIndex()] = src_val[common_feature_value.ShowIndex()]; + dest_val[common_feature_value.ClickIndex()] = src_val[common_feature_value.ClickIndex()]; + dest_val[common_feature_value.EmbedWIndex()] = src_val[common_feature_value.EmbedWIndex()]; + for (int i = 0; i < common_feature_value.EmbedDim(); i++) { + dest_val[common_feature_value.EmbedG2SumIndex() + i] = + src_val[common_feature_value.EmbedG2SumIndex() + i]; + } + dest_val[common_feature_value.SlotIndex()] = src_val[common_feature_value.SlotIndex()]; + dest_val[common_feature_value.MfDimIndex()] = mf_dim; + dest_val[common_feature_value.MfSizeIndex()] = src_val[common_feature_value.MfSizeIndex()]; + + for (int x = common_feature_value.EmbedxG2SumIndex(); + x < int(common_feature_value.Size(mf_dim) / sizeof(float)); x++){ + dest_val[x] = src_val[x]; + } +} + + +// dy_mf_fill_shard_grads_kernel,update_one 阶段 gpukernel 中从src_val赋值给dest_val +__host__ __device__ void PushValueFill(float* dest_val, + const float* src_val) { + dest_val[common_push_value.SlotIndex()] = src_val[common_push_value.SlotIndex()]; + dest_val[common_push_value.ShowIndex()] = src_val[common_push_value.ShowIndex()]; + dest_val[common_push_value.ClickIndex()] = src_val[common_push_value.ClickIndex()]; + dest_val[common_push_value.MfDimIndex()] = src_val[common_push_value.MfDimIndex()]; + dest_val[common_push_value.EmbedGIndex()] = src_val[common_push_value.EmbedGIndex()]; + + for (int x = 0; x < int(src_val[common_push_value.MfDimIndex()]); x++) { + dest_val[common_push_value.EmbedxGIndex() + x] = src_val[common_push_value.EmbedxGIndex() + x]; + } +} + +// update_basic 阶段 gpukernel 中从src_val赋值给dest_val +__host__ __device__ void PushValueFillBasic(float* dest_val, + const float* src_val) { + dest_val[common_push_value.SlotIndex()] = src_val[common_push_value.SlotIndex()]; + dest_val[common_push_value.ShowIndex()] = src_val[common_push_value.ShowIndex()]; + dest_val[common_push_value.ClickIndex()] = 
src_val[common_push_value.ClickIndex()]; + dest_val[common_push_value.MfDimIndex()] = src_val[common_push_value.MfDimIndex()]; + dest_val[common_push_value.EmbedGIndex()] = src_val[common_push_value.EmbedGIndex()]; + +} + + +// merge_one 阶段 gpukernel 中 PushValue 从src_val赋值给dest_val +__host__ __device__ void MergePushValue(float* dest_val, + const float* src_val) { + dest_val[common_push_value.ShowIndex()] += src_val[common_push_value.ShowIndex()]; + dest_val[common_push_value.ClickIndex()] += src_val[common_push_value.ClickIndex()]; + dest_val[common_push_value.EmbedGIndex()] += src_val[common_push_value.EmbedGIndex()]; + for (int j = 0; j < int(dest_val[common_push_value.MfDimIndex()]); j++) { + dest_val[common_push_value.EmbedxGIndex() + j] += src_val[common_push_value.EmbedxGIndex() + j]; + } +} + + +// merge_basic 阶段 gpukernel 中 PushValue 从src_val赋值给dest_val +__host__ __device__ void MergePushValueBasic(float* dest_val, + const float* src_val) { + dest_val[common_push_value.ShowIndex()] += src_val[common_push_value.ShowIndex()]; + dest_val[common_push_value.ClickIndex()] += src_val[common_push_value.ClickIndex()]; + dest_val[common_push_value.EmbedGIndex()] += src_val[common_push_value.EmbedGIndex()]; +} + +// PullCopy 阶段 gpukernel 中 FeatureValue回填到PullValue +__host__ __device__ void Select(float* dest_val, + float* src_val, + uint64_t* key, + int mf_dim) { + if (*key == 0) { + *(dest_val + common_pull_value.ShowIndex()) = 0; + *(dest_val + common_pull_value.ClickIndex()) = 0; + *(dest_val + common_pull_value.EmbedWIndex()) = 0; + } else { + *(dest_val + common_pull_value.ShowIndex()) = src_val[common_feature_value.ShowIndex()]; + *(dest_val + common_pull_value.ClickIndex()) = src_val[common_feature_value.ClickIndex()]; + *(dest_val + common_pull_value.EmbedWIndex()) = src_val[common_feature_value.EmbedWIndex()]; + } + + if (src_val[common_feature_value.MfSizeIndex()] == 0 || *key == 0) { + for (int j = 0; j < mf_dim; j++) { + *(dest_val + common_pull_value.EmbedxWIndex() + j) = 0; + } + } else { + for (int j = 0; j < mf_dim; j++) { + *(dest_val + common_pull_value.EmbedxWIndex() + j) = + src_val[common_feature_value.EmbedxWOffsetIndex(src_val) + j]; + } + } +} + + __host__ __device__ std::string ParseToString(const float* v, int param_size) { /* uint64_t cpu_ptr; // 2float @@ -251,6 +458,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { public: CommonFeatureValue common_feature_value; CommonPushValue common_push_value; + CommonPullValue common_pull_value; }; diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 98d4c53803d942..9f210e545bc0ed 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -101,35 +101,7 @@ __global__ void dy_mf_search_kernel(Table* table, float* input = it->second; int mf_dim = int(input[feature_value_accessor.common_feature_value.MfDimIndex()]); - *(reinterpret_cast(cur + feature_value_accessor.common_feature_value.CpuPtrIndex())) = - *(reinterpret_cast(input + feature_value_accessor.common_feature_value.CpuPtrIndex())); - cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = - input[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; - cur[feature_value_accessor.common_feature_value.ShowIndex()] = - input[feature_value_accessor.common_feature_value.ShowIndex()]; - cur[feature_value_accessor.common_feature_value.ClickIndex()] = - 
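Select() above centralizes the PullCopy edge cases: a zero key (padding) yields all zeros, and an entry whose mf block was never created yields zero embedx weights. A host-side mirror for illustration; flat indices stand in for the accessor's index helpers:

    #include <cstring>

    // Mirror of Select(): dst holds show, click, embed_w, then mf_dim
    // embedx weights; src_header3 points at the value's show/click/embed_w.
    void SelectHost(float* dst, const float* src_header3,
                    const float* src_embedx, float mf_size,
                    unsigned long long key, int mf_dim) {
      if (key == 0) {
        dst[0] = dst[1] = dst[2] = 0.0f;  // show, click, embed_w
      } else {
        memcpy(dst, src_header3, 3 * sizeof(float));
      }
      if (mf_size == 0 || key == 0) {
        memset(dst + 3, 0, mf_dim * sizeof(float));
      } else {
        memcpy(dst + 3, src_embedx, mf_dim * sizeof(float));
      }
    }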
input[feature_value_accessor.common_feature_value.ClickIndex()]; - cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = - input[feature_value_accessor.common_feature_value.EmbedWIndex()]; - for (int x = 0; x < feature_value_accessor.common_feature_value.EmbedDim(); x++) { - cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + x] = - input[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + x]; - } - cur[feature_value_accessor.common_feature_value.SlotIndex()] = - input[feature_value_accessor.common_feature_value.SlotIndex()]; - cur[feature_value_accessor.common_feature_value.MfDimIndex()] = - input[feature_value_accessor.common_feature_value.MfDimIndex()]; - cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = - input[feature_value_accessor.common_feature_value.MfSizeIndex()]; - - for (int x = feature_value_accessor.common_feature_value.EmbedxG2SumIndex(); - x < int(feature_value_accessor.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ - cur[x] = input[x]; - } - } else { - if (keys[i] != 0) { - printf("warning::pull miss key: %llu", keys[i]); - } + feature_value_accessor.FeatureValueFill(cur, input, mf_dim); } } } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index aaa6bdc1787602..a5ed2be90142bf 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -143,21 +143,7 @@ __global__ void dy_mf_fill_shard_grads_kernel(KeyType* d_shard_keys, float* cur = (float*)((char*)d_shard_grads + i * grad_value_size); float* shard_val = (float*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size); - cur[feature_value_accessor.common_push_value.SlotIndex()] = - shard_val[feature_value_accessor.common_push_value.SlotIndex()]; - cur[feature_value_accessor.common_push_value.ShowIndex()] = - shard_val[feature_value_accessor.common_push_value.ShowIndex()]; - cur[feature_value_accessor.common_push_value.ClickIndex()] = - shard_val[feature_value_accessor.common_push_value.ClickIndex()]; - cur[feature_value_accessor.common_push_value.MfDimIndex()] = - shard_val[feature_value_accessor.common_push_value.MfDimIndex()]; - cur[feature_value_accessor.common_push_value.EmbedGIndex()] = - shard_val[feature_value_accessor.common_push_value.EmbedGIndex()]; - - for (int x = 0; x < int(shard_val[feature_value_accessor.common_push_value.MfDimIndex()]); x++) { - cur[feature_value_accessor.common_push_value.EmbedxGIndex() + x] = - shard_val[feature_value_accessor.common_push_value.EmbedxGIndex() + x]; - } + feature_value_accessor.PushValueFill(cur, shard_val); } } @@ -202,30 +188,7 @@ __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, float* shard_val = (float*)((char*)d_shard_vals + uint64_t(i) * val_size); int mf_dim = int(shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]); - *(reinterpret_cast(cur + feature_value_accessor.common_feature_value.CpuPtrIndex())) = - *(reinterpret_cast(shard_val + feature_value_accessor.common_feature_value.CpuPtrIndex())); - cur[feature_value_accessor.common_feature_value.DeltaScoreIndex()] = - shard_val[feature_value_accessor.common_feature_value.DeltaScoreIndex()]; - cur[feature_value_accessor.common_feature_value.ShowIndex()] = - shard_val[feature_value_accessor.common_feature_value.ShowIndex()]; - cur[feature_value_accessor.common_feature_value.ClickIndex()] = - shard_val[feature_value_accessor.common_feature_value.ClickIndex()]; - 
cur[feature_value_accessor.common_feature_value.EmbedWIndex()] = - shard_val[feature_value_accessor.common_feature_value.EmbedWIndex()]; - for (int i = 0; i < feature_value_accessor.common_feature_value.EmbedDim(); i++) { - cur[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i] = - shard_val[feature_value_accessor.common_feature_value.EmbedG2SumIndex() + i]; - } - cur[feature_value_accessor.common_feature_value.SlotIndex()] = - shard_val[feature_value_accessor.common_feature_value.SlotIndex()]; - cur[feature_value_accessor.common_feature_value.MfDimIndex()] = mf_dim; - cur[feature_value_accessor.common_feature_value.MfSizeIndex()] = - shard_val[feature_value_accessor.common_feature_value.MfSizeIndex()]; - - for (int x = feature_value_accessor.common_feature_value.EmbedxG2SumIndex(); - x < int(feature_value_accessor.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ - cur[x] = shard_val[x]; - } + feature_value_accessor.FeatureValueFill(cur, shard_val, mf_dim); } } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index be433065c581e5..96ec86f5b36c25 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -43,34 +43,13 @@ struct DynamicGradMerger { __device__ __forceinline__ void update_one(float* output, const float* input, CommonFeatureValueAccessor& feature_value_accessor) { - output[feature_value_accessor.common_push_value.SlotIndex()] = - input[feature_value_accessor.common_push_value.SlotIndex()]; - output[feature_value_accessor.common_push_value.ShowIndex()] = - input[feature_value_accessor.common_push_value.ShowIndex()]; - output[feature_value_accessor.common_push_value.ClickIndex()] = - input[feature_value_accessor.common_push_value.ClickIndex()]; - output[feature_value_accessor.common_push_value.MfDimIndex()] = - input[feature_value_accessor.common_push_value.MfDimIndex()]; - output[feature_value_accessor.common_push_value.EmbedGIndex()] = - input[feature_value_accessor.common_push_value.EmbedGIndex()]; - for (int j = 0; j < int(output[feature_value_accessor.common_push_value.MfDimIndex()]); j++) { - output[feature_value_accessor.common_push_value.EmbedxGIndex() + j] = - input[feature_value_accessor.common_push_value.EmbedxGIndex() + j]; - } + feature_value_accessor.PushValueFill(output, input); } __device__ __forceinline__ void merge_one(float* output, const float* input, CommonFeatureValueAccessor& feature_value_accessor) { - output[feature_value_accessor.common_push_value.ShowIndex()] += - input[feature_value_accessor.common_push_value.ShowIndex()]; - output[feature_value_accessor.common_push_value.ClickIndex()] += - input[feature_value_accessor.common_push_value.ClickIndex()]; - output[feature_value_accessor.common_push_value.EmbedGIndex()] += - input[feature_value_accessor.common_push_value.EmbedGIndex()]; - for (int j = 0; j < int(output[feature_value_accessor.common_push_value.MfDimIndex()]); j++) { - output[feature_value_accessor.common_push_value.EmbedxGIndex() + j] += - input[feature_value_accessor.common_push_value.EmbedxGIndex() + j]; - } + feature_value_accessor.MergePushValue(output, input); + } }; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index e545956844672a..3f7ea8e141849f 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -774,42 +774,8 @@ void 
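With PushValueFill and MergePushValue moved into the accessor, DynamicGradMerger::update_one and merge_one above shrink to one-liners. What merge_one accumulates for two gradients of the same key, sketched with flat indices (push layout assumed: slot, show, click, mf_dim, embed_g, then embedx_g[mf_dim]):

    // Flat-index sketch of MergePushValue: show/click/embed_g and every
    // embedx gradient accumulate; slot and mf_dim stay as set by update_one.
    void MergeOne(float* out, const float* in, int mf_dim) {
      out[1] += in[1];  // show
      out[2] += in[2];  // click
      out[4] += in[4];  // embed_g
      for (int j = 0; j < mf_dim; ++j) out[5 + j] += in[5 + j];
    }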
PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { #ifdef PADDLE_WITH_PSCORE VLOG(5) << "cpu build "<< k << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) << " |: "<< cpu_table_accessor_->ParseToString(ptr_val, dim); - val[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] = - ptr_val[cpu_table_accessor_->common_feature_value.DeltaScoreIndex()]; - val[feature_value_accessor_.common_feature_value.ShowIndex()] = - ptr_val[cpu_table_accessor_->common_feature_value.ShowIndex()]; - val[feature_value_accessor_.common_feature_value.ClickIndex()] = - ptr_val[cpu_table_accessor_->common_feature_value.ClickIndex()]; - val[feature_value_accessor_.common_feature_value.SlotIndex()] = - ptr_val[cpu_table_accessor_->common_feature_value.SlotIndex()]; - val[feature_value_accessor_.common_feature_value.EmbedWIndex()] = - ptr_val[cpu_table_accessor_->common_feature_value.EmbedWIndex()]; - for (int i = 0; i < feature_value_accessor_.common_feature_value.EmbedDim(); i++) { - val[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + i] = - ptr_val[cpu_table_accessor_->common_feature_value.EmbedG2SumIndex() + i]; - } - + feature_value_accessor_.BuildFill(val, ptr_val, cpu_table_accessor_, mf_dim, dim); *(reinterpret_cast(val + feature_value_accessor_.common_feature_value.CpuPtrIndex())) = (uint64_t)(device_dim_ptrs[k]); - - ptr_val[cpu_table_accessor_->common_feature_value.MfDimIndex()] = float(mf_dim); - val[feature_value_accessor_.common_feature_value.MfDimIndex()] = mf_dim; - if (dim > cpu_table_accessor_->GetAccessorInfo().dim - - cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float)) { - val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); - - for (int x = 0; x < int(feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float)); - x++) { - val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x] = - ptr_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x]; - } - } else { - val[feature_value_accessor_.common_feature_value.MfSizeIndex()] = 0; - for (int x = feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(); - x < int(feature_value_accessor_.common_feature_value.Size(mf_dim) / sizeof(float)); x++){ - val[x] = 0; - } - } VLOG(5) << "build "<< k << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.common_feature_value.Dim(mf_dim)); } #endif @@ -1041,30 +1007,7 @@ void PSGPUWrapper::EndPass() { } float* cpu_val = downpour_value->data(); - cpu_val[cpu_table_accessor_->common_feature_value.DeltaScoreIndex()] = - gpu_val[feature_value_accessor_.common_feature_value.DeltaScoreIndex()]; - cpu_val[cpu_table_accessor_->common_feature_value.ShowIndex()] = - gpu_val[feature_value_accessor_.common_feature_value.ShowIndex()]; - cpu_val[cpu_table_accessor_->common_feature_value.ClickIndex()] = - gpu_val[feature_value_accessor_.common_feature_value.ClickIndex()]; - cpu_val[cpu_table_accessor_->common_feature_value.EmbedWIndex()] = - gpu_val[feature_value_accessor_.common_feature_value.EmbedWIndex()]; - cpu_val[cpu_table_accessor_->common_feature_value.SlotIndex()] = - gpu_val[feature_value_accessor_.common_feature_value.SlotIndex()]; - - for (int i = 0; i < feature_value_accessor_.common_feature_value.EmbedDim(); i++) { - cpu_val[cpu_table_accessor_->common_feature_value.EmbedG2SumIndex() + i] = - gpu_val[feature_value_accessor_.common_feature_value.EmbedG2SumIndex() + i]; - } - - if 
(gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0) { - - for (int x = 0; x < int(feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float)); - x++) { - cpu_val[cpu_table_accessor_->common_feature_value.EmbedxG2SumIndex() + x] = - gpu_val[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + x]; - } - } + feature_value_accessor_.DumpFill(cpu_val, gpu_val, cpu_table_accessor_, mf_dim); VLOG(5) << "dump to cpu "<< index << " : "<< feature_value_accessor_.ParseToString(gpu_val, feature_value_accessor_.common_feature_value.Dim(mf_dim)) << " ===== CPU:" << cpu_table_accessor_->ParseToString(cpu_val, downpour_value->size()); diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu index 94c7862fe36de3..88785f47ed4bcd 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -90,26 +90,7 @@ __global__ void PullCopy(float** dest, float* feature_value_ptr = (float*)((char*)src + uint64_t(i) * uint64_t(max_val_size)); int mf_dim = gpu_dim[x] - 3; - if (*(keys[x] + y) == 0) { - *(dest[x] + y * (mf_dim + 3)) = 0; - *(dest[x] + y * (mf_dim + 3) + 1) = 0; - *(dest[x] + y * (mf_dim + 3) + 2) = 0; - } else { - *(dest[x] + y * (mf_dim + 3)) = feature_value_ptr[feature_value_accessor.common_feature_value.ShowIndex()]; - *(dest[x] + y * (mf_dim + 3) + 1) = feature_value_ptr[feature_value_accessor.common_feature_value.ClickIndex()]; - *(dest[x] + y * (mf_dim + 3) + 2) = feature_value_ptr[feature_value_accessor.common_feature_value.EmbedWIndex()]; - } - - if (feature_value_ptr[feature_value_accessor.common_feature_value.MfSizeIndex()] == 0 || *(keys[x] + y) == 0) { - for (int j = 0; j < mf_dim; j++) { - *(dest[x] + y * (mf_dim + 3) + 3 + j) = 0; - } - } else { - for (int j = 0; j < mf_dim; j++) { - *(dest[x] + y * (mf_dim + 3) + 3 + j) = - feature_value_ptr[feature_value_accessor.common_feature_value.EmbedxWOffsetIndex(feature_value_ptr) + j]; - } - } + feature_value_accessor.Select(dest[x] + y * (mf_dim + 3), feature_value_ptr, keys[x] + y, mf_dim); } } From 2b87eff7fa6cb84988d032df9b6e1eccdb4f7ff0 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Fri, 1 Jul 2022 16:44:02 +0800 Subject: [PATCH 15/31] template;test=develop --- .../framework/fleet/heter_ps/hashtable_kernel.cu | 4 ++-- .../framework/fleet/heter_ps/heter_comm_kernel.cu | 13 +++++++------ paddle/fluid/framework/fleet/ps_gpu_wrapper.cu | 6 ++++-- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 9f210e545bc0ed..367925a051c119 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -83,13 +83,13 @@ __global__ void search_kernel(Table* table, } } -template +template __global__ void dy_mf_search_kernel(Table* table, const typename Table::key_type* const keys, char* vals, size_t len, size_t pull_feature_value_size, - CommonFeatureValueAccessor feature_value_accessor) { + FVAceessor feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; // return; if (i < len) { diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index a5ed2be90142bf..fe23c39b0e3387 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ 
b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -128,7 +128,7 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, } } -template +template __global__ void dy_mf_fill_shard_grads_kernel(KeyType* d_shard_keys, KeyType* d_keys, float* d_shard_grads, @@ -136,7 +136,7 @@ __global__ void dy_mf_fill_shard_grads_kernel(KeyType* d_shard_keys, T* idx, size_t len, size_t grad_value_size, - CommonFeatureValueAccessor feature_value_accessor) { + FVAceessor feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { d_shard_keys[i] = d_keys[idx[i]]; @@ -147,6 +147,7 @@ __global__ void dy_mf_fill_shard_grads_kernel(KeyType* d_shard_keys, } } +template __global__ void merge_gradients_kernel(const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, @@ -154,8 +155,8 @@ __global__ void merge_gradients_kernel(const uint32_t* offset, char* output, int n, size_t grad_value_size, - DynamicGradMerger& merger_, - CommonFeatureValueAccessor& feature_value_accessor) { + DynamicGradMerger& merger, + FVAceessor& feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n) { uint32_t start = offset[i]; @@ -174,13 +175,13 @@ __global__ void merge_gradients_kernel(const uint32_t* offset, } } -template +template __global__ void dy_mf_fill_dvals_kernel(float* d_shard_vals, float* d_vals, T* idx, size_t len, size_t val_size, - CommonFeatureValueAccessor feature_value_accessor) { + FVAceessor feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { uint64_t new_offset = uint64_t(idx[i]) * val_size; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu index 88785f47ed4bcd..1fd5a3b73d0ffc 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -66,6 +66,7 @@ __global__ void PullCopy(float** dest, } } +template __global__ void PullCopy(float** dest, const float* src, const int64_t* len, @@ -74,7 +75,7 @@ __global__ void PullCopy(float** dest, uint64_t** keys, uint64_t max_val_size, int* gpu_dim, - CommonFeatureValueAccessor feature_value_accessor) { + FVAceessor feature_value_accessor) { CUDA_KERNEL_LOOP(i, total_len) { int low = 0; int high = slot_num - 1; @@ -145,6 +146,7 @@ __global__ void PushCopy(FeaturePushValue* dest, } } +template __global__ void PushCopyWithPool(float* dest, float** src, int64_t* len, @@ -154,7 +156,7 @@ __global__ void PushCopyWithPool(float* dest, int* slot_vector, int* mf_dim_vector, size_t grad_value_size, - CommonFeatureValueAccessor feature_value_accessor) { + FVAceessor feature_value_accessor) { CUDA_KERNEL_LOOP(i, total_len) { int low = 0; int high = slot_num - 1; From 13aac52b600de463de83964097ea16afcefd7fb8 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 6 Jul 2022 12:21:13 +0000 Subject: [PATCH 16/31] fix adam accessor:template;test=develop --- .../framework/fleet/heter_ps/CMakeLists.txt | 8 +- .../framework/fleet/heter_ps/feature_value.h | 684 ++++++++++++------ .../fleet/heter_ps/graph_gpu_ps_table.h | 3 +- .../framework/fleet/heter_ps/hashtable.h | 16 +- .../fleet/heter_ps/hashtable_kernel.cu | 98 ++- .../framework/fleet/heter_ps/heter_comm.h | 24 +- .../framework/fleet/heter_ps/heter_comm_inl.h | 546 ++++++++------ .../fleet/heter_ps/heter_comm_kernel.cu | 142 ++-- .../fleet/heter_ps/heter_comm_kernel.h | 36 +- .../framework/fleet/heter_ps/heter_ps.cc | 32 +- .../framework/fleet/heter_ps/heter_ps.cu | 146 ++-- 
.../fluid/framework/fleet/heter_ps/heter_ps.h | 17 +- .../framework/fleet/heter_ps/heter_ps_base.h | 12 +- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 487 +++++++------ .../fluid/framework/fleet/ps_gpu_wrapper.cu | 325 +-------- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 158 ++-- .../fluid/framework/fleet/ps_gpu_wrapper.kps | 183 ++--- 17 files changed, 1577 insertions(+), 1340 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt index 7540c6147f4b72..9631502f4f05e0 100644 --- a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt +++ b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt @@ -9,16 +9,16 @@ if(WITH_GPU) endif() nv_library( heter_comm_kernel - SRCS heter_comm_kernel.cu feature_value.h + SRCS heter_comm_kernel.cu feature_value.h feature_value.cu DEPS ${HETERPS_DEPS}) nv_library( hashtable_kernel - SRCS hashtable_kernel.cu feature_value.h + SRCS hashtable_kernel.cu feature_value.h feature_value.cu DEPS ${HETERPS_DEPS}) nv_library( heter_comm - SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h - mem_pool.h + SRCS heter_comm.h feature_value.h feature_value.cu heter_resource.cc + heter_resource.h mem_pool.h DEPS ${HETERPS_DEPS} heter_comm_kernel hashtable_kernel) nv_test( test_heter_comm diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index 2486f34cc0e256..f2001cdf42ae5b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -19,26 +19,35 @@ limitations under the License. */ #include #include #include -#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/place.h" +#ifdef PADDLE_WITH_PSCORE +#include "paddle/fluid/distributed/ps/table/accessor.h" +#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h" +#include "paddle/fluid/distributed/ps/table/depends/feature_value.h" +#endif namespace paddle { namespace framework { #define MF_DIM 8 typedef uint64_t FeatureKey; +#define TYPEALIGN(ALIGNVAL, LEN) \ + (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1))) class FeatureValueAccessor { public: - __host__ __device__ FeatureValueAccessor() {} + __host__ __device__ FeatureValueAccessor() {} __host__ __device__ ~FeatureValueAccessor() {} - __host__ __device__ virtual int Configure(std::unordered_map config) { + __host__ __device__ virtual int Configure( + std::unordered_map config) { _config = config; Initialize(); return 0; } - __host__ __device__ virtual int Initialize() = 0; + __host__ __device__ virtual int Initialize() = 0; protected: std::unordered_map _config; @@ -64,47 +73,58 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { std::vector embedx_w; */ - __host__ __device__ int Dim() { return 9 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; } // has cpu_ptr(2) - __host__ __device__ int DimSize(size_t dim, int embedx_dim) { return sizeof(float); } - __host__ __device__ int Size() { return Dim() * sizeof(float); } // cpu_ptr:uint64=2float - __host__ __device__ int EmbedDim() { return embed_sgd_dim;} - __host__ __device__ int EmbedXDim() { return embedx_sgd_dim;} - __host__ __device__ int EmbedWDim() { return embedx_dim;} - __host__ __device__ int CpuPtrIndex() {return 0; } // cpuprt uint64 - __host__ __device__ int DeltaScoreIndex() { return CpuPtrIndex() + 2; } + __host__ __device__ int Dim() { + 
return 9 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; + } // has cpu_ptr(2) + __host__ __device__ int DimSize(size_t dim, int embedx_dim) { + return sizeof(float); + } + __host__ __device__ size_t Size() { + return TYPEALIGN(8, Dim() * sizeof(float)); + } // cpu_ptr:uint64=2float + __host__ __device__ int EmbedDim() { return embed_sgd_dim; } + __host__ __device__ int EmbedXDim() { return embedx_sgd_dim; } + __host__ __device__ int EmbedWDim() { return embedx_dim; } + __host__ __device__ int CpuPtrIndex() { return 0; } // cpu ptr, a uint64 + __host__ __device__ int DeltaScoreIndex() { return CpuPtrIndex() + 2; } __host__ __device__ int ShowIndex() { return DeltaScoreIndex() + 1; } __host__ __device__ int ClickIndex() { return ShowIndex() + 1; } __host__ __device__ int EmbedWIndex() { return ClickIndex() + 1; } __host__ __device__ int EmbedG2SumIndex() { return EmbedWIndex() + 1; } - __host__ __device__ int SlotIndex() { return EmbedG2SumIndex() + embed_sgd_dim; } + __host__ __device__ int SlotIndex() { + return EmbedG2SumIndex() + embed_sgd_dim; + } __host__ __device__ int MfDimIndex() { return SlotIndex() + 1; } - __host__ __device__ int MfSizeIndex() { return MfDimIndex() + 1; } // actual mf size (ex. 0) + __host__ __device__ int MfSizeIndex() { + return MfDimIndex() + 1; + } // actual mf size (e.g. 0 while no mf has been allocated) __host__ __device__ int EmbedxG2SumIndex() { return MfSizeIndex() + 1; } - __host__ __device__ int EmbedxWIndex() { return EmbedxG2SumIndex() + embedx_sgd_dim; } - + __host__ __device__ int EmbedxWIndex() { + return EmbedxG2SumIndex() + embedx_sgd_dim; + } // total float count, computed from mf_dim __host__ __device__ int Dim(int& mf_dim) { int tmp_embedx_sgd_dim = 1; - if (optimizer_type_ == 3) {//adam + if (optimizer_type_ == 3) { // adam tmp_embedx_sgd_dim = mf_dim * 2 + 2; - } else if (optimizer_type_ == 4) { //shared_adam + } else if (optimizer_type_ == 4) { // shared_adam tmp_embedx_sgd_dim = 4; } return 9 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim; } // total size in bytes, computed from mf_dim - __host__ __device__ int Size(int& mf_dim) { - return Dim(mf_dim) * sizeof(float); // cpu_ptr:2float + __host__ __device__ size_t Size(int& mf_dim) { + return TYPEALIGN(8, Dim(mf_dim) * sizeof(float)); // cpu_ptr:2float } // mf_size in bytes, computed from mf_dim - __host__ __device__ int MFSize(int& mf_dim) { + __host__ __device__ size_t MFSize(int& mf_dim) { int tmp_embedx_sgd_dim = 1; - if (optimizer_type_ == 3) { //adam + if (optimizer_type_ == 3) { // adam tmp_embedx_sgd_dim = mf_dim * 2 + 2; - } else if (optimizer_type_ == 4) { //shared_adam + } else if (optimizer_type_ == 4) { // shared_adam tmp_embedx_sgd_dim = 4; } return (tmp_embedx_sgd_dim + mf_dim) * sizeof(float); @@ -112,33 +132,42 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { __host__ __device__ int EmbedxG2SumOffsetIndex() { return 0; } __host__ __device__ int EmbedxWOffsetIndex(float* val) { - // has mf + // has mf int tmp_embedx_sgd_dim = 1; if (int(MfSize(val)) > 0) { - if (optimizer_type_ == 3) {//adam + if (optimizer_type_ == 3) { // adam tmp_embedx_sgd_dim = int(MfDim(val)) * 2 + 2; - } else if (optimizer_type_ == 4) { //shared_adam + } else if (optimizer_type_ == 4) { // shared_adam tmp_embedx_sgd_dim = 4; } - return EmbedxG2SumIndex() + tmp_embedx_sgd_dim; + return EmbedxG2SumIndex() + tmp_embedx_sgd_dim; } else { // no mf return 0; } } - - __host__ __device__ uint64_t CpuPtr(float*
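// Worked size arithmetic for the helpers above, assuming mf_dim = 8: for
// adam (optimizer_type_ == 3), embed_sgd_dim = 4 and tmp_embedx_sgd_dim =
// 8 * 2 + 2 = 18, so Dim(8) = 9 + 4 + 18 + 8 = 39 floats and Size(8) =
// TYPEALIGN(8, 156) = 160 bytes; for shared_adam (== 4), Dim(8) =
// 9 + 4 + 4 + 8 = 25 floats and Size(8) = TYPEALIGN(8, 100) = 104 bytes.
// TYPEALIGN(8, n) rounds n up to the next multiple of 8, which keeps the
// leading uint64 cpu_ptr of every pooled value 8-byte aligned.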
val) { + return *(reinterpret_cast(val)); + } + __host__ __device__ float& DeltaScore(float* val) { + return val[DeltaScoreIndex()]; + } __host__ __device__ float& Show(float* val) { return val[ShowIndex()]; } __host__ __device__ float& Click(float* val) { return val[ClickIndex()]; } __host__ __device__ float& Slot(float* val) { return val[SlotIndex()]; } __host__ __device__ float& MfDim(float* val) { return val[MfDimIndex()]; } __host__ __device__ float& MfSize(float* val) { return val[MfSizeIndex()]; } __host__ __device__ float& EmbedW(float* val) { return val[EmbedWIndex()]; } - __host__ __device__ float& EmbedG2Sum(float* val) { return val[EmbedG2SumIndex()]; } - __host__ __device__ float& EmbedxG2Sum(float* val) { return val[EmbedxG2SumIndex()]; } - __host__ __device__ float& EmbedxW(float* val) { return val[EmbedxWIndex()]; } + __host__ __device__ float& EmbedG2Sum(float* val) { + return val[EmbedG2SumIndex()]; + } + __host__ __device__ float& EmbedxG2Sum(float* val) { + return val[EmbedxG2SumIndex()]; + } + __host__ __device__ float& EmbedxW(float* val) { + return val[EmbedxWIndex()]; + } int embed_sgd_dim; int embedx_dim; @@ -158,14 +187,28 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { __host__ __device__ int Dim(int embedx_dim) { return 5 + embedx_dim; } - __host__ __device__ int DimSize(int dim, int embedx_dim) { return sizeof(float); } - __host__ __device__ int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } + __host__ __device__ int DimSize(int dim, int embedx_dim) { + return sizeof(float); + } + __host__ __device__ int Size(int embedx_dim) { + return TYPEALIGN(8, Dim(embedx_dim) * sizeof(float)); + } __host__ __device__ int SlotIndex() { return 0; } - __host__ __device__ int ShowIndex() { return CommonPushValue::SlotIndex() + 1; } - __host__ __device__ int ClickIndex() { return CommonPushValue::ShowIndex() + 1; } - __host__ __device__ int MfDimIndex() { return CommonPushValue::ClickIndex() + 1; } - __host__ __device__ int EmbedGIndex() { return CommonPushValue::MfDimIndex() + 1; } - __host__ __device__ int EmbedxGIndex() { return CommonPushValue::EmbedGIndex() + 1; } + __host__ __device__ int ShowIndex() { + return CommonPushValue::SlotIndex() + 1; + } + __host__ __device__ int ClickIndex() { + return CommonPushValue::ShowIndex() + 1; + } + __host__ __device__ int MfDimIndex() { + return CommonPushValue::ClickIndex() + 1; + } + __host__ __device__ int EmbedGIndex() { + return CommonPushValue::MfDimIndex() + 1; + } + __host__ __device__ int EmbedxGIndex() { + return CommonPushValue::EmbedGIndex() + 1; + } __host__ __device__ float& Slot(float* val) { return val[CommonPushValue::SlotIndex()]; } @@ -194,9 +237,13 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { std::vector embedx_w; */ - __host__ __device__ static int Dim(int embedx_dim) { return 3 + embedx_dim; } + __host__ __device__ static int Dim(int embedx_dim) { + return 3 + embedx_dim; + } __host__ __device__ int DimSize(size_t dim) { return sizeof(float); } - __host__ __device__ int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); } + __host__ __device__ int Size(int embedx_dim) { + return TYPEALIGN(8, Dim(embedx_dim) * sizeof(float)); + } __host__ __device__ int ShowIndex() { return 0; } __host__ __device__ int ClickIndex() { return 1; } __host__ __device__ int EmbedWIndex() { return 2; } @@ -215,21 +262,20 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { } }; - __host__ __device__ CommonFeatureValueAccessor() {} __host__ 
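// The same arithmetic applies to the gradient and pull layouts above,
// assuming embedx_dim = 8: a push value is Dim(8) = 5 + 8 = 13 floats
// ([slot, show, click, mf_dim, embed_g, embedx_g x 8]), i.e. Size(8) =
// TYPEALIGN(8, 52) = 56 bytes, and a pull value is Dim(8) = 3 + 8 = 11
// floats ([show, click, embed_w, embedx_w x 8]), i.e. TYPEALIGN(8, 44) =
// 48 bytes.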
__device__ ~CommonFeatureValueAccessor() {} __host__ __device__ virtual int Initialize() { int optimizer_type = (_config.find("optimizer_type") == _config.end()) - ? 1 - : int(_config["optimizer_type"]); + ? 1 + : int(_config["optimizer_type"]); int sparse_embedx_dim = (_config.find("embedx_dim") == _config.end()) ? 8 : int(_config["embedx_dim"]); - if (optimizer_type == 3) { //adam + if (optimizer_type == 3) { // adam common_feature_value.embed_sgd_dim = 4; common_feature_value.embedx_sgd_dim = sparse_embedx_dim * 2 + 2; - } else if (optimizer_type == 4) { //shared_adam + } else if (optimizer_type == 4) { // shared_adam common_feature_value.embed_sgd_dim = 4; common_feature_value.embedx_sgd_dim = 4; } else { @@ -242,168 +288,219 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { return 0; } - -// at build time, fill gpu_val from cpu_val -__host__ __device__ void BuildFill(float* gpu_val, - float* cpu_val, - paddle::distributed::CtrDymfAccessor* cpu_table_accessor, - int mf_dim, - size_t cpu_fv_dim) { - - gpu_val[common_feature_value.DeltaScoreIndex()] = - cpu_val[cpu_table_accessor->common_feature_value.DeltaScoreIndex()]; - gpu_val[common_feature_value.ShowIndex()] = - cpu_val[cpu_table_accessor->common_feature_value.ShowIndex()]; - gpu_val[common_feature_value.ClickIndex()] = - cpu_val[cpu_table_accessor->common_feature_value.ClickIndex()]; - gpu_val[common_feature_value.SlotIndex()] = - cpu_val[cpu_table_accessor->common_feature_value.SlotIndex()]; - gpu_val[common_feature_value.EmbedWIndex()] = - cpu_val[cpu_table_accessor->common_feature_value.EmbedWIndex()]; - for (int i = 0; i < common_feature_value.EmbedDim(); i++) { - gpu_val[common_feature_value.EmbedG2SumIndex() + i] = - cpu_val[cpu_table_accessor->common_feature_value.EmbedG2SumIndex() + i]; + // at build time, fill gpu_val from cpu_val + __host__ void BuildFill( + float* gpu_val, + void* cpu, + paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + int mf_dim) { +#ifdef PADDLE_WITH_PSCORE + paddle::distributed::FixedFeatureValue* cpu_ptr = + (paddle::distributed::FixedFeatureValue*)(cpu); + float* cpu_val = cpu_ptr->data(); + size_t cpu_dim = cpu_ptr->size(); + + gpu_val[common_feature_value.DeltaScoreIndex()] = + cpu_val[cpu_table_accessor->common_feature_value.DeltaScoreIndex()]; + gpu_val[common_feature_value.ShowIndex()] = + cpu_val[cpu_table_accessor->common_feature_value.ShowIndex()]; + gpu_val[common_feature_value.ClickIndex()] = + cpu_val[cpu_table_accessor->common_feature_value.ClickIndex()]; + gpu_val[common_feature_value.SlotIndex()] = + cpu_val[cpu_table_accessor->common_feature_value.SlotIndex()]; + gpu_val[common_feature_value.EmbedWIndex()] = + cpu_val[cpu_table_accessor->common_feature_value.EmbedWIndex()]; + for (int i = 0; i < common_feature_value.EmbedDim(); i++) { + gpu_val[common_feature_value.EmbedG2SumIndex() + i] = + cpu_val[cpu_table_accessor->common_feature_value.EmbedG2SumIndex() + + i]; + } + *(reinterpret_cast<uint64_t*>( + gpu_val + common_feature_value.CpuPtrIndex())) = (uint64_t)(cpu); + cpu_val[cpu_table_accessor->common_feature_value.MfDimIndex()] = + float(mf_dim); + gpu_val[common_feature_value.MfDimIndex()] = mf_dim; + if (cpu_dim > + cpu_table_accessor->GetAccessorInfo().dim - + cpu_table_accessor->GetAccessorInfo().mf_size / sizeof(float)) { + gpu_val[common_feature_value.MfSizeIndex()] = + common_feature_value.MFSize(mf_dim) / sizeof(float); + + for (int x = 0; + x < int(common_feature_value.MFSize(mf_dim) / sizeof(float)); + x++) { + gpu_val[common_feature_value.EmbedxG2SumIndex() + x] = cpu_val +
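// A minimal configuration sketch for the Initialize() path above (host
// side); the keys match the ones Initialize() reads, and the numeric
// values are assumptions for illustration:
//
//   CommonFeatureValueAccessor accessor;
//   std::unordered_map<std::string, float> config;
//   config["optimizer_type"] = 3;  // 3 = adam, 4 = shared_adam
//   config["embedx_dim"] = 8;
//   accessor.Configure(config);  // embed_sgd_dim = 4, embedx_sgd_dim = 18
//
// Keys missing from the map fall back to optimizer_type = 1 and
// embedx_dim = 8.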
[cpu_table_accessor->common_feature_value.EmbedxG2SumIndex() + x]; + } + } else { + gpu_val[common_feature_value.MfSizeIndex()] = 0; + for (int x = common_feature_value.EmbedxG2SumIndex(); + x < int(common_feature_value.Size(mf_dim) / sizeof(float)); + x++) { + gpu_val[x] = 0; + } + } +#endif } - cpu_val[cpu_table_accessor->common_feature_value.MfDimIndex()] = float(mf_dim); - gpu_val[common_feature_value.MfDimIndex()] = mf_dim; - if (cpu_fv_dim > cpu_table_accessor->GetAccessorInfo().dim - - cpu_table_accessor->GetAccessorInfo().mf_size / sizeof(float)) { - gpu_val[common_feature_value.MfSizeIndex()] = - common_feature_value.MFSize(mf_dim) / sizeof(float); - - for (int x = 0; x < int(common_feature_value.MFSize(mf_dim) / sizeof(float)); - x++) { - gpu_val[common_feature_value.EmbedxG2SumIndex() + x] = - cpu_val[cpu_table_accessor->common_feature_value.EmbedxG2SumIndex() + x]; + // at dump_to_cpu time, fill cpu_val from gpu_val + __host__ __device__ void DumpFill( + float* gpu_val, + paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + int mf_dim) { +#ifdef PADDLE_WITH_PSCORE + auto* downpour_value = + (paddle::distributed::FixedFeatureValue*)(*(reinterpret_cast<uint64_t*>( + gpu_val + common_feature_value.CpuPtrIndex()))); + size_t downpour_value_size = downpour_value->size(); + if (gpu_val[common_feature_value.MfSizeIndex()] > 0 && + downpour_value_size == + (cpu_table_accessor->GetAccessorInfo().dim - + int(cpu_table_accessor->GetAccessorInfo().mf_size / + sizeof(float)))) { // cpu_accessor + downpour_value->resize( + cpu_table_accessor->common_feature_value.Dim(mf_dim)); } - } else { - gpu_val[common_feature_value.MfSizeIndex()] = 0; - for (int x = common_feature_value.EmbedxG2SumIndex(); - x < int(common_feature_value.Size(mf_dim) / sizeof(float)); x++){ - gpu_val[x] = 0; + float* cpu_val = downpour_value->data(); + cpu_val[cpu_table_accessor->common_feature_value.DeltaScoreIndex()] = + gpu_val[common_feature_value.DeltaScoreIndex()]; + cpu_val[cpu_table_accessor->common_feature_value.ShowIndex()] = + gpu_val[common_feature_value.ShowIndex()]; + cpu_val[cpu_table_accessor->common_feature_value.ClickIndex()] = + gpu_val[common_feature_value.ClickIndex()]; + cpu_val[cpu_table_accessor->common_feature_value.EmbedWIndex()] = + gpu_val[common_feature_value.EmbedWIndex()]; + cpu_val[cpu_table_accessor->common_feature_value.SlotIndex()] = + gpu_val[common_feature_value.SlotIndex()]; + + for (int i = 0; i < common_feature_value.EmbedDim(); i++) { + cpu_val[cpu_table_accessor->common_feature_value.EmbedG2SumIndex() + i] =
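// Note on DumpFill above: the CPU-side FixedFeatureValue row is resized to
// the full per-key dimension only when the GPU copy actually owns an mf
// section (MfSizeIndex() > 0) and the CPU row is still the compact,
// mf-free size; rows whose mf was never materialized keep their small
// footprint. The mf block (g2sum state plus weights) is copied back under
// the same MfSizeIndex() > 0 guard just below.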
gpu_val[common_feature_value.EmbedG2SumIndex() + i]; - } - if (gpu_val[common_feature_value.MfSizeIndex()] > 0) { - - for (int x = 0; x < int(common_feature_value.MFSize(mf_dim) / sizeof(float)); - x++) { - cpu_val[cpu_table_accessor->common_feature_value.EmbedxG2SumIndex() + x] = - gpu_val[common_feature_value.EmbedxG2SumIndex() + x]; + if (gpu_val[common_feature_value.MfSizeIndex()] > 0) { + for (int x = 0; + x < int(common_feature_value.MFSize(mf_dim) / sizeof(float)); + x++) { + cpu_val[cpu_table_accessor->common_feature_value.EmbedxG2SumIndex() + + x] = gpu_val[common_feature_value.EmbedxG2SumIndex() + x]; + } } +#endif } -} - - -// in dy_mf_fill_dvals_kernel and dy_mf_search_kernel, the GPU kernel fills dest_val from src_val -__host__ __device__ void FeatureValueFill(float* dest_val, - float* src_val, - int mf_dim) { - *(reinterpret_cast<uint64_t*>(dest_val + common_feature_value.CpuPtrIndex())) = - *(reinterpret_cast<uint64_t*>(src_val + common_feature_value.CpuPtrIndex())); - dest_val[common_feature_value.DeltaScoreIndex()] = src_val[common_feature_value.DeltaScoreIndex()]; - dest_val[common_feature_value.ShowIndex()] = src_val[common_feature_value.ShowIndex()]; - dest_val[common_feature_value.ClickIndex()] = src_val[common_feature_value.ClickIndex()]; - dest_val[common_feature_value.EmbedWIndex()] = src_val[common_feature_value.EmbedWIndex()]; - for (int i = 0; i < common_feature_value.EmbedDim(); i++) { - dest_val[common_feature_value.EmbedG2SumIndex() + i] = - src_val[common_feature_value.EmbedG2SumIndex() + i]; - } - dest_val[common_feature_value.SlotIndex()] = src_val[common_feature_value.SlotIndex()]; - dest_val[common_feature_value.MfDimIndex()] = mf_dim; - dest_val[common_feature_value.MfSizeIndex()] = src_val[common_feature_value.MfSizeIndex()]; - for (int x = common_feature_value.EmbedxG2SumIndex(); - x < int(common_feature_value.Size(mf_dim) / sizeof(float)); x++){ - dest_val[x] = src_val[x]; + // in dy_mf_fill_dvals_kernel and dy_mf_search_kernel, the GPU kernel + // fills dest_val from src_val + __host__ __device__ void FeatureValueFill(float* dest_val, + float* src_val, + int mf_dim) { + *(reinterpret_cast<uint64_t*>(dest_val + + common_feature_value.CpuPtrIndex())) = + *(reinterpret_cast<uint64_t*>(src_val + + common_feature_value.CpuPtrIndex())); + dest_val[common_feature_value.DeltaScoreIndex()] = + src_val[common_feature_value.DeltaScoreIndex()]; + dest_val[common_feature_value.ShowIndex()] = + src_val[common_feature_value.ShowIndex()]; + dest_val[common_feature_value.ClickIndex()] = + src_val[common_feature_value.ClickIndex()]; + dest_val[common_feature_value.EmbedWIndex()] = + src_val[common_feature_value.EmbedWIndex()]; + for (int i = 0; i < common_feature_value.EmbedDim(); i++) { + dest_val[common_feature_value.EmbedG2SumIndex() + i] = + src_val[common_feature_value.EmbedG2SumIndex() + i]; + } + dest_val[common_feature_value.SlotIndex()] = + src_val[common_feature_value.SlotIndex()]; + dest_val[common_feature_value.MfDimIndex()] = mf_dim; + dest_val[common_feature_value.MfSizeIndex()] = + src_val[common_feature_value.MfSizeIndex()]; + + for (int x = common_feature_value.EmbedxG2SumIndex(); + x < int(common_feature_value.Size(mf_dim) / sizeof(float)); + x++) { + dest_val[x] = src_val[x]; + } } + // in dy_mf_fill_shard_grads_kernel and update_one, the GPU kernel + // fills dest_val from src_val + __host__ __device__ void PushValueFill(float* dest_val, + const float* src_val) { + dest_val[common_push_value.SlotIndex()] = + src_val[common_push_value.SlotIndex()]; + dest_val[common_push_value.ShowIndex()] = + src_val[common_push_value.ShowIndex()]; +
dest_val[common_push_value.ClickIndex()] = + src_val[common_push_value.ClickIndex()]; + dest_val[common_push_value.MfDimIndex()] = + src_val[common_push_value.MfDimIndex()]; + dest_val[common_push_value.EmbedGIndex()] = + src_val[common_push_value.EmbedGIndex()]; + + for (int x = 0; x < int(src_val[common_push_value.MfDimIndex()]); x++) { + dest_val[common_push_value.EmbedxGIndex() + x] = + src_val[common_push_value.EmbedxGIndex() + x]; + } + } -// in dy_mf_fill_shard_grads_kernel and update_one, the GPU kernel fills dest_val from src_val -__host__ __device__ void PushValueFill(float* dest_val, - const float* src_val) { - dest_val[common_push_value.SlotIndex()] = src_val[common_push_value.SlotIndex()]; - dest_val[common_push_value.ShowIndex()] = src_val[common_push_value.ShowIndex()]; - dest_val[common_push_value.ClickIndex()] = src_val[common_push_value.ClickIndex()]; - dest_val[common_push_value.MfDimIndex()] = src_val[common_push_value.MfDimIndex()]; - dest_val[common_push_value.EmbedGIndex()] = src_val[common_push_value.EmbedGIndex()]; + // in update_basic, the GPU kernel fills dest_val from src_val + __host__ __device__ void PushValueFillBasic(float* dest_val, + const float* src_val) { + dest_val[common_push_value.SlotIndex()] = + src_val[common_push_value.SlotIndex()]; + dest_val[common_push_value.ShowIndex()] = + src_val[common_push_value.ShowIndex()]; + dest_val[common_push_value.ClickIndex()] = + src_val[common_push_value.ClickIndex()]; + dest_val[common_push_value.MfDimIndex()] = + src_val[common_push_value.MfDimIndex()]; + dest_val[common_push_value.EmbedGIndex()] = + src_val[common_push_value.EmbedGIndex()]; + } - for (int x = 0; x < int(src_val[common_push_value.MfDimIndex()]); x++) { - dest_val[common_push_value.EmbedxGIndex() + x] = src_val[common_push_value.EmbedxGIndex() + x]; + // in merge_one, the GPU kernel accumulates the PushValue src_val into dest_val + __host__ __device__ void MergePushValue(float* dest_val, + const float* src_val) { + dest_val[common_push_value.ShowIndex()] += + src_val[common_push_value.ShowIndex()]; + dest_val[common_push_value.ClickIndex()] += + src_val[common_push_value.ClickIndex()]; + dest_val[common_push_value.EmbedGIndex()] += + src_val[common_push_value.EmbedGIndex()]; + for (int j = 0; j < int(dest_val[common_push_value.MfDimIndex()]); j++) { + dest_val[common_push_value.EmbedxGIndex() + j] += + src_val[common_push_value.EmbedxGIndex() + j]; + } } -} - - +
// in merge_basic, the GPU kernel accumulates the PushValue src_val into dest_val + __host__ __device__ void MergePushValueBasic(float* dest_val, + const float* src_val) { + dest_val[common_push_value.ShowIndex()] += + src_val[common_push_value.ShowIndex()]; + dest_val[common_push_value.ClickIndex()] += + src_val[common_push_value.ClickIndex()]; + dest_val[common_push_value.EmbedGIndex()] += + src_val[common_push_value.EmbedGIndex()]; } -} - - -// in merge_basic, the GPU kernel accumulates the PushValue src_val into dest_val -__host__ __device__ void MergePushValueBasic(float* dest_val, - const float* src_val) { - dest_val[common_push_value.ShowIndex()] += src_val[common_push_value.ShowIndex()]; - dest_val[common_push_value.ClickIndex()] += src_val[common_push_value.ClickIndex()]; - dest_val[common_push_value.EmbedGIndex()] += src_val[common_push_value.EmbedGIndex()]; -} - -// in PullCopy, the GPU kernel fills the PullValue back from the FeatureValue -__host__ __device__ void Select(float* dest_val, - float* src_val, - uint64_t* key, - int mf_dim) { - if (*key == 0) { + + // in PullCopy, the GPU kernel fills the PullValue back from the FeatureValue + __host__ __device__ void Select(float* dest_val, + float* src_val, + uint64_t* key, + int mf_dim) { + if (*key == 0) { *(dest_val + common_pull_value.ShowIndex()) = 0; *(dest_val + common_pull_value.ClickIndex()) = 0; *(dest_val + common_pull_value.EmbedWIndex()) = 0; } else { - *(dest_val + common_pull_value.ShowIndex()) = src_val[common_feature_value.ShowIndex()]; - *(dest_val + common_pull_value.ClickIndex()) = src_val[common_feature_value.ClickIndex()]; - *(dest_val + common_pull_value.EmbedWIndex()) = src_val[common_feature_value.EmbedWIndex()]; + *(dest_val + common_pull_value.ShowIndex()) = + src_val[common_feature_value.ShowIndex()]; + *(dest_val + common_pull_value.ClickIndex()) = + src_val[common_feature_value.ClickIndex()]; + *(dest_val + common_pull_value.EmbedWIndex()) = + src_val[common_feature_value.EmbedWIndex()]; } if (src_val[common_feature_value.MfSizeIndex()] == 0 || *key == 0) { @@ -412,14 +509,14 @@ __host__ __device__ void Select(float* dest_val, } } else { for (int j = 0; j < mf_dim; j++) { - *(dest_val + common_pull_value.EmbedxWIndex() + j) = + *(dest_val + common_pull_value.EmbedxWIndex() + j) = src_val[common_feature_value.EmbedxWOffsetIndex(src_val) + j]; } } -} - + } - __host__ __device__ std::string ParseToString(const float* v, int param_size) { + __host__ __device__ std::string ParseToString(const float* v, + int param_size) { /* uint64_t cpu_ptr; // 2float float delta_score; float show; float click; float embed_w; std::vector<float> embed_g2sum; float slot; float mf_dim; float mf_size; std::vector<float> embedx_g2sum; std::vector<float> embedx_w; */ std::stringstream os; - os << "cpuptr: " << common_feature_value.CpuPtr(const_cast<float*>(v)) << " delta_score: " << v[2] - << " show: " << v[3] << " click: " << v[4] - << " embed_w:" << v[5] << " embed_g2sum:"; + os << "cpuptr: " << common_feature_value.CpuPtr(const_cast<float*>(v)) + << " delta_score: " << v[2] << " show: " << v[3] << " click: " << v[4] + << " embed_w:" << v[5] << " embed_g2sum:"; for (int i = common_feature_value.EmbedG2SumIndex(); - i < common_feature_value.SlotIndex(); i++) { + i < common_feature_value.SlotIndex(); + i++) { os << " " << v[i]; } int mf_dim = int(common_feature_value.MfDim(const_cast<float*>(v))); - os << " slot: " << common_feature_value.Slot(const_cast<float*>(v)) - << " mf_dim: " << mf_dim - << " mf_size: " << common_feature_value.MfSize(const_cast<float*>(v)) - << " mf: "; + os << " slot: " << common_feature_value.Slot(const_cast<float*>(v)) + << " mf_dim: " << mf_dim + << " mf_size: " << common_feature_value.MfSize(const_cast<float*>(v)) + << " mf: ";
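// Pull-side contract implemented by Select() above: a zero (padding) key
// yields all-zero show/click/embed_w, and embedx weights are copied only
// when the value already owns an mf section (MfSizeIndex() > 0), with
// EmbedxWOffsetIndex() skipping the per-optimizer g2sum state that sits in
// front of the weights. A call sketch, with the buffer size assumed for
// mf_dim = 8:
//
//   float pulled[3 + 8];  // show, click, embed_w, embedx_w
//   accessor.Select(pulled, gpu_value, &key, 8);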
if (param_size > common_feature_value.EmbedxG2SumIndex()) { for (auto i = common_feature_value.EmbedxG2SumIndex(); - i < common_feature_value.Dim(mf_dim); ++i) { + i < common_feature_value.Dim(mf_dim); + ++i) { os << " " << v[i]; } } @@ -461,7 +560,6 @@ __host__ __device__ void Select(float* dest_val, CommonPullValue common_pull_value; }; - struct FeatureValue { float delta_score; float show; @@ -533,6 +631,182 @@ struct FeaturePushValue { } }; +class VirtualAccessor { + public: + virtual int Configure(std::unordered_map<std::string, float> config) = 0; + + virtual size_t GetFeatureValueSize(int& mf_dim) = 0; + + virtual size_t GetPushValueSize(int& mf_dim) = 0; + + // TODO: route the cpu_table_accessor type through the base class + virtual void BuildFill( + void* gpu_val, + void* cpu_val, + paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + int mf_dim) = 0; + + // TODO: route the cpu_table_accessor type through the base class + virtual void DumpFill( + float* gpu_val, + paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + int mf_dim) = 0; + + virtual void CopyForPull(const paddle::platform::Place& place, + uint64_t** gpu_keys, + const std::vector<float*>& values, + const float* total_values_gpu, + const int64_t* gpu_len, + const int slot_num, + const int hidden_size, + const int64_t total_length, + int* gpu_dim, + int feature_value_size) = 0; + + virtual void CopyForPush(const paddle::platform::Place& place, + const std::vector<const float*>& grad_values, + float* total_grad_values_gpu, + const std::vector<int64_t>& slot_lengths, + const uint64_t total_length, + const int batch_size, + size_t grad_value_size, + std::vector<int>& slot_vector, + std::vector<int>& slot_mf_dim_vector) = 0; + + virtual std::string ParseToString(const float* v, int param_size) = 0; +}; + +template <typename GPUAccessor> +class AccessorWrapper : public VirtualAccessor { + public: + explicit AccessorWrapper() {} + virtual ~AccessorWrapper() {} + AccessorWrapper(const AccessorWrapper&) = delete; + AccessorWrapper& operator=(const AccessorWrapper&) = delete; + + virtual int Configure(std::unordered_map<std::string, float> config) { + return gpu_accessor_.Configure(config); + } + + virtual size_t GetFeatureValueSize(int& mf_dim) { + return gpu_accessor_.common_feature_value.Size(mf_dim); + } + + virtual size_t GetPushValueSize(int& mf_dim) { + return gpu_accessor_.common_push_value.Size(mf_dim); + } + + virtual void BuildFill( + void* gpu_val, + void* cpu_val, + paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + int mf_dim) { + gpu_accessor_.BuildFill( + (float*)(gpu_val), cpu_val, cpu_table_accessor, mf_dim); + } + + virtual void DumpFill( + float* gpu_val, + paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + int mf_dim) { + gpu_accessor_.DumpFill(gpu_val, cpu_table_accessor, mf_dim); + } + + virtual void CopyForPull(const paddle::platform::Place& place, + uint64_t** gpu_keys, + const std::vector<float*>& values, + const float* total_values_gpu, + const int64_t* gpu_len, + const int slot_num, + const int hidden_size, + const int64_t total_length, + int* gpu_dim, + int feature_value_size) { + CopyForPullImpl(place, + gpu_keys, + values, + total_values_gpu, + gpu_len, + slot_num, + hidden_size, + total_length, + gpu_dim, + feature_value_size); + } + + virtual void CopyForPush(const paddle::platform::Place& place, + const std::vector<const float*>& grad_values, + float* total_grad_values_gpu, + const std::vector<int64_t>& slot_lengths, + const uint64_t total_length, + const int batch_size, + size_t grad_value_size, + std::vector<int>& slot_vector, + std::vector<int>& slot_mf_dim_vector) { + CopyForPushImpl(place, + grad_values, + total_grad_values_gpu, + slot_lengths, + total_length, +
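// Typical use of the VirtualAccessor machinery defined above: the concrete
// accessor is chosen once at runtime through the GlobalAccessorTransfor
// singleton defined just below, after which host code works through the
// type-erased interface while the CUDA kernels stay templated on the real
// GPUAccessor. A sketch, with the mf_dim value assumed:
//
//   GlobalAccessorTransfor::GetInstance().Init("CtrDymfAccessor");
//   VirtualAccessor* wrapper =
//       GlobalAccessorTransfor::GetInstance().GetAccessorWrapper();
//   int mf_dim = 8;
//   size_t val_bytes = wrapper->GetFeatureValueSize(mf_dim);   // 160 for adam
//   size_t grad_bytes = wrapper->GetPushValueSize(mf_dim);     // 56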
batch_size, + grad_value_size, + slot_vector, + slot_mf_dim_vector); + } + + void CopyForPullImpl(const paddle::platform::Place& place, + uint64_t** gpu_keys, + const std::vector& values, + const float* total_values_gpu, + const int64_t* gpu_len, + const int slot_num, + const int hidden_size, + const int64_t total_length, + int* gpu_dim, + int feature_value_size); + + void CopyForPushImpl(const paddle::platform::Place& place, + const std::vector& grad_values, + float* total_grad_values_gpu, + const std::vector& slot_lengths, + const uint64_t total_length, + const int batch_size, + size_t grad_value_size, + std::vector& slot_vector, + std::vector& slot_mf_dim_vector); + + virtual std::string ParseToString(const float* v, int param_size) { + return gpu_accessor_.ParseToString(v, param_size); + } + + GPUAccessor gpu_accessor_; +}; + +class GlobalAccessorTransfor { + public: + static GlobalAccessorTransfor& GetInstance() { + static GlobalAccessorTransfor ins; + return ins; + } + void Init(std::string accessor_type) { + if (accessor_wrapper_ptr_ != nullptr) { + return; + } + if (accessor_type == "CtrDymfAccessor") { + accessor_wrapper_ptr_ = new AccessorWrapper(); + } else { + VLOG(0) << "GlobalAccessorTransfor Init not support accessor_type:" + << accessor_type; + accessor_wrapper_ptr_ = new AccessorWrapper(); + } + } + VirtualAccessor* GetAccessorWrapper() { return accessor_wrapper_ptr_; } + + private: + VirtualAccessor* accessor_wrapper_ptr_ = nullptr; +}; + } // end namespace framework } // end namespace paddle + #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h index 9a6581c2ae5e33..0e0d525c93a73e 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h @@ -25,7 +25,8 @@ #ifdef PADDLE_WITH_HETERPS namespace paddle { namespace framework { -class GpuPsGraphTable : public HeterComm { +class GpuPsGraphTable + : public HeterComm { public: GpuPsGraphTable(std::shared_ptr resource, int topo_aware) : HeterComm(1, resource) { diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index 3269b7a24e9c48..38abe87495c4ee 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -137,8 +137,12 @@ class HashTable { size_t len, StreamType stream); - template - void get(const KeyType* d_keys, char* d_vals, size_t len, StreamType stream); + template + void get(const KeyType* d_keys, + char* d_vals, + size_t len, + StreamType stream, + FVAccessor& fv_accessor); void show(); @@ -189,12 +193,12 @@ class HashTable { << " push value size: " << push_grad_value_size_; } - void set_accessor(CommonFeatureValueAccessor& accessor) { - feature_value_accessor_ = accessor; - } + // void set_accessor(FVAccessor& accessor) { + // feature_value_accessor_ = accessor; + // } std::unique_ptr rwlock_{nullptr}; - CommonFeatureValueAccessor feature_value_accessor_; + // FVAccessor feature_value_accessor_; private: #if defined(PADDLE_WITH_CUDA) diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 367925a051c119..aba6289b00ee47 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -83,13 +83,13 @@ __global__ void search_kernel(Table* table, } } -template +template __global__ void 
dy_mf_search_kernel(Table* table, const typename Table::key_type* const keys, char* vals, size_t len, size_t pull_feature_value_size, - FVAceessor feature_value_accessor) { + FVAccessor feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; // return; if (i < len) { @@ -99,7 +99,8 @@ __global__ void dy_mf_search_kernel(Table* table, uint64_t offset = i * pull_feature_value_size; float* cur = (float*)(vals + offset); float* input = it->second; - int mf_dim = int(input[feature_value_accessor.common_feature_value.MfDimIndex()]); + int mf_dim = + int(input[feature_value_accessor.common_feature_value.MfDimIndex()]); feature_value_accessor.FeatureValueFill(cur, input, mf_dim); } @@ -200,17 +201,18 @@ void HashTable::get(const KeyType* d_keys, } template -template +template void HashTable::get(const KeyType* d_keys, char* d_vals, size_t len, - StreamType stream) { + StreamType stream, + FVAccessor& fv_accessor) { if (len == 0) { return; } const int grid_size = (len - 1) / BLOCK_SIZE_ + 1; dy_mf_search_kernel<<>>( - container_, d_keys, d_vals, len, pull_feature_value_size_, feature_value_accessor_); + container_, d_keys, d_vals, len, pull_feature_value_size_, fv_accessor); } template @@ -349,15 +351,19 @@ template class HashTable; template class HashTable; template class HashTable; -template void HashTable::get< - cudaStream_t>(const unsigned long* d_keys, - float* d_vals, - size_t len, - cudaStream_t stream); +template void HashTable::get( + const unsigned long* d_keys, + float* d_vals, + size_t len, + cudaStream_t stream); template void -HashTable::get( - const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t stream); +HashTable::get( + const unsigned long* d_keys, + char* d_vals, + size_t len, + cudaStream_t stream, + CommonFeatureValueAccessor& fv_accessor); template void HashTable::get(const long* d_keys, int* d_vals, @@ -366,6 +372,13 @@ template void HashTable::get(const long* d_keys, template void HashTable::get( const unsigned long* d_keys, int* d_vals, size_t len, cudaStream_t stream); +template void HashTable::get( + const unsigned long* d_keys, + unsigned long* d_vals, + size_t len, + cudaStream_t stream); +template void HashTable::get( + const unsigned long* d_keys, long* d_vals, size_t len, cudaStream_t stream); template void HashTable::get( const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream); template void HashTable::get(const long* d_keys, @@ -381,19 +394,19 @@ template void HashTable::get( // const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t // stream); -template void HashTable::insert< - cudaStream_t>(const unsigned long* d_keys, - const float* d_vals, - size_t len, - cudaStream_t stream); +template void HashTable::insert( + const unsigned long* d_keys, + const float* d_vals, + size_t len, + cudaStream_t stream); -template void HashTable:: - insert(const unsigned long* d_keys, - size_t len, - char* pool, - size_t feature_value_size, - size_t start_index, - cudaStream_t stream); +template void HashTable::insert( + const unsigned long* d_keys, + size_t len, + char* pool, + size_t feature_value_size, + size_t start_index, + cudaStream_t stream); template void HashTable::insert(const long* d_keys, const int* d_vals, @@ -427,19 +440,34 @@ template void HashTable::insert( size_t len, cudaStream_t stream); -template void HashTable:: - dump_to_cpu(int devid, cudaStream_t stream); +template void HashTable::insert( + const unsigned long* d_keys, + const unsigned long* d_vals, + size_t len, + cudaStream_t 
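// The long list of explicit instantiations in this file exists because
// hashtable_kernel.cu is compiled as its own translation unit: every
// (key, value, optimizer, accessor, stream) combination used by other
// .cc/.cu files must be instantiated here, or the linker cannot find the
// template bodies. Adding a new accessor or optimizer therefore requires a
// matching "template void HashTable<...>::get/update(...)" line as well.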
stream); + +template void HashTable::dump_to_cpu( + int devid, cudaStream_t stream); template void -HashTable::update(const unsigned long* d_keys, const char* d_grads, size_t len, - SparseAdagradOptimizer sgd, - cudaStream_t stream); -template void -HashTable::update(const unsigned long* d_keys, const char* d_grads, size_t len, - SparseAdamOptimizer sgd, - cudaStream_t stream); +HashTable::update( + const unsigned long* d_keys, + const char* d_grads, + size_t len, + SparseAdagradOptimizer sgd, + cudaStream_t stream); template void -HashTable::update(const unsigned long* d_keys, const char* d_grads, size_t len, +HashTable::update( + const unsigned long* d_keys, + const char* d_grads, + size_t len, + SparseAdamOptimizer sgd, + cudaStream_t stream); +template void HashTable::update< + SparseAdamSharedOptimizer, + cudaStream_t>(const unsigned long* d_keys, + const char* d_grads, + size_t len, SparseAdamSharedOptimizer sgd, cudaStream_t stream); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h index 00a1583ee36797..12f94f0ad07849 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm.h @@ -46,11 +46,15 @@ namespace framework { #define TYPEALIGN(ALIGNVAL, LEN) \ (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1))) -template +template class HeterComm { public: HeterComm(size_t capacity, std::shared_ptr resource); - HeterComm(size_t capacity, std::shared_ptr resource, + HeterComm(size_t capacity, + std::shared_ptr resource, CommonFeatureValueAccessor& accessor); virtual ~HeterComm(); HeterComm(const HeterComm&) = delete; @@ -67,11 +71,8 @@ class HeterComm { GradType* d_grads, size_t len, int& uniq_len); // NOLINT - void dynamic_merge_grad(int gpu_num, - KeyType* d_keys, - float* d_grads, - size_t len, - int& uniq_len); + void dynamic_merge_grad( + int gpu_num, KeyType* d_keys, float* d_grads, size_t len, int& uniq_len); void pull_sparse(int num, KeyType* d_keys, float* d_vals, size_t len); void build_ps(int num, KeyType* h_keys, @@ -152,8 +153,11 @@ class HeterComm { max_mf_dim_ = max_mf_dim; } - void set_accessor(CommonFeatureValueAccessor& accessor) { - feature_value_accessor_ = accessor; + void set_accessor(FVAccessor& accessor) { + feature_value_accessor_ = accessor; + // for (auto& ptr_table: ptr_tables_) { + // ptr_table->set_accessor(feature_value_accessor_); + // } } #endif @@ -288,8 +292,8 @@ class HeterComm { char* src_val, size_t val_size); + FVAccessor feature_value_accessor_; - CommonFeatureValueAccessor feature_value_accessor_; protected: using Table = HashTable; using PtrTable = HashTable; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 2e2d4b85174713..7f0b4528d45d49 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" +#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h" #include "paddle/fluid/platform/device_context.h" #ifdef PADDLE_WITH_XPU_KP @@ -24,8 +25,42 @@ limitations under the License. 
*/ namespace paddle { namespace framework { -template -HeterComm::HeterComm( +// template +// HeterComm::HeterComm( +// size_t capacity, std::shared_ptr resource) { +// VLOG(1) << "Construct new HeterComm"; +// resource_ = resource; +// storage_.resize(resource_->total_device()); +// multi_mf_dim_ = resource->multi_mf(); +// load_factor_ = FLAGS_gpugraph_hbm_table_load_factor; +// VLOG(0) << "load_factor = " << load_factor_; +// for (int i = 0; i < resource_->total_device(); ++i) { +// #if defined(PADDLE_WITH_CUDA) +// platform::CUDADeviceGuard guard(resource_->dev_id(i)); +// allocators_.push_back(std::make_shared( +// 8, 1, (unsigned int)-1, (size_t)-1, false, false)); // NOLINT +// #endif +// if (!multi_mf_dim_) { +// auto table = new Table(capacity / load_factor_); +// tables_.push_back(table); +// } else { +// VLOG(0) << "Error:use HeterComm Construct with accessor"; +// return; +// } +// if (multi_node_) { +// storage_[i].init(feanum_, resource_->dev_id(i)); +// } +// } +// heter_comm_kernel_ = std::make_unique(block_size_); +// init_path(); +// } + +template +HeterComm::HeterComm( size_t capacity, std::shared_ptr resource) { VLOG(1) << "Construct new HeterComm"; resource_ = resource; @@ -36,48 +71,22 @@ HeterComm::HeterComm( platform::CUDADeviceGuard guard(resource_->dev_id(i)); allocators_.push_back(std::make_shared( 8, 1, (unsigned int)-1, (size_t)-1, false, false)); // NOLINT -#endif - if (!multi_mf_dim_) { - auto table = new Table(capacity / load_factor_); - tables_.push_back(table); - } else { - VLOG(0) << "Error:use HeterComm Construct with accessor"; - return; - } - if (multi_node_) { - storage_[i].init(feanum_, resource_->dev_id(i)); - } - } - heter_comm_kernel_ = std::make_unique(block_size_); - init_path(); -} - -template -HeterComm::HeterComm( - size_t capacity, std::shared_ptr resource, - CommonFeatureValueAccessor& feature_value_accessor) { - VLOG(1) << "Construct new HeterComm"; - resource_ = resource; - storage_.resize(resource_->total_device()); - multi_mf_dim_ = resource->multi_mf(); - for (int i = 0; i < resource_->total_device(); ++i) { -#if defined(PADDLE_WITH_CUDA) - platform::CUDADeviceGuard guard(resource_->dev_id(i)); - allocators_.push_back(std::make_shared( - 8, 1, (unsigned int)-1, (size_t)-1, false, false)); // NOLINT #endif if (!multi_mf_dim_) { auto table = new Table(capacity / load_factor_); tables_.push_back(table); } else { max_mf_dim_ = resource_->max_mf_dim(); - feature_value_accessor_ = feature_value_accessor; - size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); - size_t grad_type_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); - VLOG(0) << " HeterComm init, max feature_value_size:" << val_type_size + auto accessor_wrapper_ptr = + GlobalAccessorTransfor::GetInstance().GetAccessorWrapper(); + size_t val_type_size = + accessor_wrapper_ptr->GetFeatureValueSize(max_mf_dim_); + size_t grad_type_size = + accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_); + VLOG(0) << " HeterComm init, max feature_value_size:" << val_type_size << ", feature_value_push_size:" << grad_type_size; auto ptr_table = new PtrTable(capacity / load_factor_); - ptr_table->set_accessor(feature_value_accessor_); + // ptr_table->set_accessor(feature_value_accessor_); ptr_table->set_feature_value_size(val_type_size, grad_type_size); ptr_tables_.push_back(ptr_table); } @@ -85,13 +94,17 @@ HeterComm::HeterComm( storage_[i].init(feanum_, resource_->dev_id(i)); } } - heter_comm_kernel_ = 
std::make_unique(block_size_, feature_value_accessor_); + // heter_comm_kernel_ = std::make_unique(block_size_, + // feature_value_accessor_); + heter_comm_kernel_ = std::make_unique(block_size_); init_path(); } - -template -void HeterComm::init_path() { +template +void HeterComm::init_path() { int total_device = resource_->total_device(); path_.resize(total_device); if (!topo_aware_) { @@ -143,14 +156,19 @@ void HeterComm::init_path() { } } -template +template template -void HeterComm::memory_copy(DstPlace dst_place, - void* dst, - SrcPlace src_place, - const void* src, - size_t count, - StreamType stream) { +<<<<<<< HEAD +void HeterComm::memory_copy( + DstPlace dst_place, + void* dst, + SrcPlace src_place, + const void* src, + size_t count, + StreamType stream) { #if defined(PADDLE_WITH_CUDA) cudaMemcpyAsync(dst, src, count, cudaMemcpyDefault, stream); if (stream == 0) { @@ -161,11 +179,12 @@ void HeterComm::memory_copy(DstPlace dst_place, #endif } -template -void HeterComm::create_storage(int start_index, - int end_index, - int keylen, - int vallen) { +template +void HeterComm::create_storage( + int start_index, int end_index, int keylen, int vallen) { #if defined(PADDLE_WITH_CUDA) auto& allocator = allocators_[start_index]; auto& nodes = path_[start_index][end_index].nodes_; @@ -199,9 +218,12 @@ void HeterComm::create_storage(int start_index, #endif } -template -void HeterComm::destroy_storage(int start_index, - int end_index) { +template +void HeterComm::destroy_storage( + int start_index, int end_index) { #if defined(PADDLE_WITH_CUDA) auto& allocator = allocators_[start_index]; auto& nodes = path_[start_index][end_index].nodes_; @@ -216,13 +238,17 @@ void HeterComm::destroy_storage(int start_index, #endif } -template -void HeterComm::walk_to_dest(int start_index, - int num, - int* h_left, - int* h_right, - KeyType* src_key, - GradType* src_val) { +template +void HeterComm::walk_to_dest( + int start_index, + int num, + int* h_left, + int* h_right, + KeyType* src_key, + GradType* src_val) { int need_copy_val = 0; if (src_val) { need_copy_val = 1; @@ -299,14 +325,18 @@ void HeterComm::walk_to_dest(int start_index, } } -template -void HeterComm::walk_to_dest(int start_index, - int gpu_num, - int* h_left, - int* h_right, - KeyType* src_key, - char* src_val, - size_t val_size) { +template +void HeterComm::walk_to_dest( + int start_index, + int gpu_num, + int* h_left, + int* h_right, + KeyType* src_key, + char* src_val, + size_t val_size) { int need_copy_val = 0; if (src_val) { need_copy_val = 1; @@ -359,13 +389,17 @@ void HeterComm::walk_to_dest(int start_index, } } -template -void HeterComm::walk_to_src(int start_index, - int gpu_num, - int* h_left, - int* h_right, - char* src_val, - size_t val_size) { +template +void HeterComm::walk_to_src( + int start_index, + int gpu_num, + int* h_left, + int* h_right, + char* src_val, + size_t val_size) { std::queue que; for (int i = 0; i < gpu_num; i++) { if (h_left[i] == -1 || h_right[i] == -1) { @@ -415,8 +449,11 @@ void HeterComm::walk_to_src(int start_index, } } -template -HeterComm::~HeterComm() { +template +HeterComm::~HeterComm() { if (!multi_mf_dim_) { for (auto& table : tables_) { delete table; @@ -434,15 +471,22 @@ HeterComm::~HeterComm() { } } -template -void HeterComm::show_one_table(int gpu_num) { +template +void HeterComm::show_one_table( + int gpu_num) { if (!multi_mf_dim_) { tables_[gpu_num]->show(); } } -template -int HeterComm::log2i(int x) { +template +int HeterComm::log2i(int x) { unsigned res = 0; while (x >>= 1) { ++res; @@ 
-450,13 +494,20 @@ int HeterComm::log2i(int x) { return res; } -template -int HeterComm::get_index_by_devid(int devid) { +template +int HeterComm::get_index_by_devid( + int devid) { return resource_->get_index_by_devid(devid); } -template -void HeterComm::set_sparse_sgd( +template +void HeterComm::set_sparse_sgd( const OptimizerConfig& optimizer_config) { for (int i = 0; i < resource_->total_device(); ++i) { AnyDeviceGuard guard(resource_->dev_id(i)); @@ -464,8 +515,11 @@ void HeterComm::set_sparse_sgd( } } -template -void HeterComm::set_embedx_sgd( +template +void HeterComm::set_embedx_sgd( const OptimizerConfig& optimizer_config) { for (int i = 0; i < resource_->total_device(); ++i) { AnyDeviceGuard guard(resource_->dev_id(i)); @@ -473,13 +527,17 @@ void HeterComm::set_embedx_sgd( } } -template -void HeterComm::build_ps(int dev_num, - KeyType* h_keys, - ValType* h_vals, - size_t len, - size_t chunk_size, - int stream_num) { +template +void HeterComm::build_ps( + int dev_num, + KeyType* h_keys, + ValType* h_vals, + size_t len, + size_t chunk_size, + int stream_num) { if (len <= 0) { return; } @@ -542,14 +600,18 @@ void HeterComm::build_ps(int dev_num, } } -template -void HeterComm::build_ps(int num, - KeyType* h_keys, - char* pool, - size_t len, - size_t feature_value_size, - size_t chunk_size, - int stream_num) { +template +void HeterComm::build_ps( + int num, + KeyType* h_keys, + char* pool, + size_t len, + size_t feature_value_size, + size_t chunk_size, + int stream_num) { if (len <= 0) { return; } @@ -604,8 +666,11 @@ void HeterComm::build_ps(int num, } } -template -void HeterComm::merge_grad( +template +void HeterComm::merge_grad( int dev_num, KeyType* d_keys, GradType* d_grads, @@ -678,13 +743,12 @@ void HeterComm::merge_grad( sync_stream(stream); } -template -void HeterComm::dynamic_merge_grad( - int gpu_num, - KeyType* d_keys, - float* d_grads, - size_t len, - int& uniq_len) { +template +void HeterComm::dynamic_merge_grad( + int gpu_num, KeyType* d_keys, float* d_grads, size_t len, int& uniq_len) { int dev_id = resource_->dev_id(gpu_num); platform::CUDAPlace place = platform::CUDAPlace(dev_id); platform::CUDADeviceGuard guard(dev_id); @@ -692,14 +756,15 @@ void HeterComm::dynamic_merge_grad( size_t temp_storage_bytes; - size_t grad_value_size = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + auto accessor_wrapper_ptr = + GlobalAccessorTransfor::GetInstance().GetAccessorWrapper(); + size_t grad_value_size = accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_); auto d_merge_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_merge_keys_ptr = reinterpret_cast(d_merge_keys->ptr()); auto d_merge_grads = memory::Alloc(place, len * grad_value_size); - float* d_merge_grads_ptr = - reinterpret_cast(d_merge_grads->ptr()); + float* d_merge_grads_ptr = reinterpret_cast(d_merge_grads->ptr()); auto d_fea_num_info = memory::Alloc(place, sizeof(uint32_t) * (len * 3 + 1)); uint32_t* d_fea_num_info_ptr = @@ -794,7 +859,8 @@ void HeterComm::dynamic_merge_grad( uniq_len, grad_value_size, merger_, - stream); + stream, + feature_value_accessor_); PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(d_grads, d_merge_grads_ptr, @@ -804,8 +870,11 @@ void HeterComm::dynamic_merge_grad( PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream)); } -template -void HeterComm::split_input_to_shard( +template +void HeterComm::split_input_to_shard( KeyType* d_keys, int* d_idx_ptr, size_t len, @@ -865,11 +934,12 @@ void 
HeterComm::split_input_to_shard( sync_stream(stream); } -template -void HeterComm::pull_sparse(int num, - KeyType* d_keys, - float* d_vals, - size_t len) { +template +void HeterComm::pull_sparse( + int num, KeyType* d_keys, float* d_vals, size_t len) { if (len == 0) { return; } @@ -916,7 +986,9 @@ void HeterComm::pull_sparse(int num, auto d_idx = memory::Alloc(place, len * sizeof(int)); int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); - size_t val_type_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); + auto accessor_wrapper_ptr = + GlobalAccessorTransfor::GetInstance().GetAccessorWrapper(); + size_t val_type_size = accessor_wrapper_ptr->GetFeatureValueSize(max_mf_dim_); VLOG(3) << "pull_sparse len:" << len << " val_type_size: " << val_type_size; auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType)); KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); @@ -967,7 +1039,8 @@ void HeterComm::pull_sparse(int num, ptr_tables_[i]->get(reinterpret_cast(node.key_storage), node.val_storage, h_right[i] - h_left[i] + 1, - resource_->remote_stream(i, num)); + resource_->remote_stream(i, num), + feature_value_accessor_); } for (int i = 0; i < total_device; ++i) { @@ -987,8 +1060,13 @@ void HeterComm::pull_sparse(int num, auto& node = path_[num][i].nodes_.front(); sync_stream(node.out_stream); } - heter_comm_kernel_->dy_mf_fill_dvals( - d_shard_vals_ptr, d_vals, d_idx_ptr, len, val_type_size, stream); + heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, + d_vals, + d_idx_ptr, + len, + val_type_size, + stream, + feature_value_accessor_); sync_stream(stream); @@ -1001,13 +1079,17 @@ void HeterComm::pull_sparse(int num, } #if defined(PADDLE_WITH_CUDA) -template +template template -void HeterComm::push_sparse(int dev_num, - KeyType* d_keys, - float* d_grads, - size_t len, - Sgd& sgd) { // NOLINT +void HeterComm::push_sparse( + int dev_num, + KeyType* d_keys, + float* d_grads, + size_t len, + Sgd& sgd) { // NOLINT if (len == 0) { return; } @@ -1015,8 +1097,9 @@ void HeterComm::push_sparse(int dev_num, int total_device = resource_->total_device(); int dev_id = resource_->dev_id(dev_num); - size_t grad_value_size = - TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); + auto accessor_wrapper_ptr = + GlobalAccessorTransfor::GetInstance().GetAccessorWrapper(); + size_t grad_value_size = accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_); DevPlace place = DevPlace(dev_id); AnyDeviceGuard guard(dev_id); auto stream = resource_->local_stream(dev_num, 0); @@ -1061,8 +1144,7 @@ void HeterComm::push_sparse(int dev_num, KeyType* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); auto d_shard_grads = memory::Alloc(place, len * grad_value_size); - float* d_shard_grads_ptr = - reinterpret_cast(d_shard_grads->ptr()); + float* d_shard_grads_ptr = reinterpret_cast(d_shard_grads->ptr()); int uniq_len = len; dynamic_merge_grad(dev_num, d_keys, d_grads, len, uniq_len); @@ -1073,92 +1155,93 @@ void HeterComm::push_sparse(int dev_num, d_keys, d_idx_ptr, uniq_len, d_left_ptr, d_right_ptr, dev_num); heter_comm_kernel_->dy_mf_fill_shard_grads(d_shard_keys_ptr, - d_keys, - d_shard_grads_ptr, - d_grads, - d_idx_ptr, - uniq_len, - grad_value_size, - stream); - } - - sync_stream(stream); - - auto dst_place = platform::CPUPlace(); - auto src_place = place; - memory_copy(dst_place, - h_left, - src_place, - d_left_ptr, - total_device * sizeof(int), - stream); - memory_copy(dst_place, - h_right, - src_place, - d_right_ptr, - total_device * 
sizeof(int), - stream); - - for (int i = 0; i < total_device; ++i) { - int shard_len = h_right[i] - h_left[i] + 1; - if (h_left[i] == -1 || h_right[i] == -1) { - continue; - } - create_storage( - dev_num, i, shard_len * sizeof(KeyType), shard_len * grad_value_size); - } + d_keys, + d_shard_grads_ptr, + d_grads, + d_idx_ptr, + uniq_len, + grad_value_size, + stream, + feature_value_accessor_); +} - walk_to_dest(dev_num, - total_device, - h_left, - h_right, - d_shard_keys_ptr, - reinterpret_cast(d_shard_grads_ptr), - grad_value_size); +sync_stream(stream); + +auto dst_place = platform::CPUPlace(); +auto src_place = place; +memory_copy(dst_place, + h_left, + src_place, + d_left_ptr, + total_device * sizeof(int), + stream); +memory_copy(dst_place, + h_right, + src_place, + d_right_ptr, + total_device * sizeof(int), + stream); + +for (int i = 0; i < total_device; ++i) { + int shard_len = h_right[i] - h_left[i] + 1; + if (h_left[i] == -1 || h_right[i] == -1) { + continue; } + create_storage( + dev_num, i, shard_len * sizeof(KeyType), shard_len * grad_value_size); +} - for (int i = 0; i < total_device; ++i) { - if (h_left[i] == -1 || h_right[i] == -1) { - continue; - } - auto& node = path_[dev_num][i].nodes_.back(); - sync_stream(node.in_stream); +walk_to_dest(dev_num, + total_device, + h_left, + h_right, + d_shard_keys_ptr, + reinterpret_cast(d_shard_grads_ptr), + grad_value_size); +} - AnyDeviceGuard guard(resource_->dev_id(i)); - ptr_tables_[i]->rwlock_->WRLock(); - ptr_tables_[i]->update(reinterpret_cast(node.key_storage), - node.val_storage, - h_right[i] - h_left[i] + 1, - sgd, - resource_->remote_stream(i, dev_num)); +for (int i = 0; i < total_device; ++i) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; } + auto& node = path_[dev_num][i].nodes_.back(); + sync_stream(node.in_stream); + + AnyDeviceGuard guard(resource_->dev_id(i)); + ptr_tables_[i]->rwlock_->WRLock(); + ptr_tables_[i]->update(reinterpret_cast(node.key_storage), + node.val_storage, + h_right[i] - h_left[i] + 1, + sgd, + resource_->remote_stream(i, dev_num)); +} - for (int i = 0; i < total_device; ++i) { - sync_stream(resource_->remote_stream(i, dev_num)); - if (h_left[i] != -1) { - if (!multi_mf_dim_) { - tables_[i]->rwlock_->UNLock(); - } else { - ptr_tables_[i]->rwlock_->UNLock(); - } +for (int i = 0; i < total_device; ++i) { + sync_stream(resource_->remote_stream(i, dev_num)); + if (h_left[i] != -1) { + if (!multi_mf_dim_) { + tables_[i]->rwlock_->UNLock(); + } else { + ptr_tables_[i]->rwlock_->UNLock(); } } +} - for (int i = 0; i < total_device; ++i) { - if (h_left[i] == -1 || h_right[i] == -1) { - continue; - - } - destroy_storage(dev_num, i); +for (int i = 0; i < total_device; ++i) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; } + destroy_storage(dev_num, i); +} } #elif defined(PADDLE_WITH_XPU_KP) -template -void HeterComm::push_sparse(int dev_num, - KeyType* d_keys, - GradType* d_grads, - size_t len) { +template +void HeterComm::push_sparse( + int dev_num, KeyType* d_keys, GradType* d_grads, size_t len) { if (len == 0) { return; } @@ -1294,9 +1377,12 @@ void HeterComm::push_sparse(int dev_num, #endif #if defined(PADDLE_WITH_CUDA) -template +template template -void HeterComm::update_one_table( +void HeterComm::update_one_table( int gpu_num, KeyType* d_keys, GradType* d_grads, @@ -1315,9 +1401,12 @@ void HeterComm::update_one_table( cudaStreamSynchronize(resource_->remote_stream(gpu_num, gpu_num)); } -template +template template -void HeterComm::push_sparse_multi_node( +void 
HeterComm::push_sparse_multi_node( int gpu_num, KeyType* d_keys, GradType* d_grads, @@ -1344,8 +1433,11 @@ void HeterComm::push_sparse_multi_node( sgd); } -template -int HeterComm::gather_one_node_grad( +template +int HeterComm::gather_one_node_grad( int gpu_num, KeyType* d_keys, GradType* d_grads, int len) { int total_gpu = resource_->total_device(); int dev_id = resource_->dev_id(gpu_num); @@ -1446,8 +1538,11 @@ int HeterComm::gather_one_node_grad( return ret; } -template -int HeterComm::gather_multi_node_grad( +template +int HeterComm::gather_multi_node_grad( int gpu_num, KeyType* d_keys, GradType* d_grads, int len) { int dev_id = resource_->dev_id(gpu_num); auto& storage = storage_[gpu_num]; @@ -1517,8 +1612,11 @@ int HeterComm::gather_multi_node_grad( } #endif -template -void HeterComm::end_pass() { +template +void HeterComm::end_pass() { int total_device = resource_->total_device(); std::vector threads; @@ -1539,8 +1637,10 @@ void HeterComm::end_pass() { } } -// template -// void HeterComm::dump_to_cpu(int index) { +// template +// void HeterComm::dump_to_cpu(int +// index) { // auto stream = resource_->local_stream(index, 0); // int dev_id = resource_->dev_id(index); // platform::CUDADeviceGuard guard(dev_id); diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index fe23c39b0e3387..15e31e450de72d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -128,26 +128,28 @@ __global__ void fill_dvals_kernel(ValType* d_shard_vals, } } -template -__global__ void dy_mf_fill_shard_grads_kernel(KeyType* d_shard_keys, - KeyType* d_keys, - float* d_shard_grads, - float* d_grads, - T* idx, - size_t len, - size_t grad_value_size, - FVAceessor feature_value_accessor) { +template +__global__ void dy_mf_fill_shard_grads_kernel( + KeyType* d_shard_keys, + KeyType* d_keys, + float* d_shard_grads, + float* d_grads, + T* idx, + size_t len, + size_t grad_value_size, + FVAccessor feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { d_shard_keys[i] = d_keys[idx[i]]; float* cur = (float*)((char*)d_shard_grads + i * grad_value_size); - float* shard_val = (float*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size); + float* shard_val = + (float*)((char*)d_grads + uint64_t(idx[i]) * grad_value_size); feature_value_accessor.PushValueFill(cur, shard_val); } } -template +template __global__ void merge_gradients_kernel(const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, @@ -156,38 +158,37 @@ __global__ void merge_gradients_kernel(const uint32_t* offset, int n, size_t grad_value_size, DynamicGradMerger& merger, - FVAceessor& feature_value_accessor) { + FVAccessor& feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n) { uint32_t start = offset[i]; uint32_t num = fea_num[i]; int ori_index = index[start]; float* out = (float*)(output + i * grad_value_size); - float* in = - (float*)(input + size_t(ori_index) * grad_value_size); + float* in = (float*)(input + size_t(ori_index) * grad_value_size); merger_.update_one(out, in, feature_value_accessor); for (int j = 1; j < num; ++j) { ori_index = index[start + j]; - float& rhs = - *(float*)(input + size_t(ori_index) * grad_value_size); + float& rhs = *(float*)(input + size_t(ori_index) * grad_value_size); merger_.merge_one(out, rhs, feature_value_accessor); } } } -template +template __global__ void 
dy_mf_fill_dvals_kernel(float* d_shard_vals, float* d_vals, T* idx, size_t len, size_t val_size, - FVAceessor feature_value_accessor) { + FVAccessor feature_value_accessor) { const size_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { uint64_t new_offset = uint64_t(idx[i]) * val_size; float* cur = (float*)((char*)d_vals + new_offset); float* shard_val = (float*)((char*)d_shard_vals + uint64_t(i) * val_size); - int mf_dim = int(shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]); + int mf_dim = int( + shard_val[feature_value_accessor.common_feature_value.MfDimIndex()]); feature_value_accessor.FeatureValueFill(cur, shard_val, mf_dim); } @@ -321,15 +322,20 @@ void HeterCommKernel::reduce_by_key(void* d_temp_storage, debug_synchronous)); } -template -void HeterCommKernel::dy_mf_fill_shard_grads(KeyType* d_shard_keys, - KeyType* d_keys, - GradType* d_shard_grads, - GradType* d_grads, - T* idx, - long long len, - size_t grad_value_size, - const StreamType& stream) { +template +void HeterCommKernel::dy_mf_fill_shard_grads( + KeyType* d_shard_keys, + KeyType* d_keys, + GradType* d_shard_grads, + GradType* d_grads, + T* idx, + long long len, + size_t grad_value_size, + const StreamType& stream, + CommonFeatureValueAccessor& feature_value_accessor) { int grid_size = (len - 1) / block_size_ + 1; size_t c_len = (size_t)len; dy_mf_fill_shard_grads_kernel<<>>( @@ -343,7 +349,7 @@ void HeterCommKernel::dy_mf_fill_shard_grads(KeyType* d_shard_keys, feature_value_accessor_); } -template +template void HeterCommKernel::merge_gradient(const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, @@ -352,23 +358,33 @@ void HeterCommKernel::merge_gradient(const uint32_t* offset, int n, size_t grad_value_size, DynamicGradMerger& merger_, - const StreamType& stream) { + const StreamType& stream, + FVAccessor& feature_value_accessor) { int grid_size = (n - 1) / block_size_ + 1; merge_gradients_kernel<<>>( - offset, fea_num, index, input, output, n, grad_value_size, merger_, feature_value_accessor_); + offset, + fea_num, + index, + input, + output, + n, + grad_value_size, + merger_, + feature_value_accessor_); } -template +template void HeterCommKernel::dy_mf_fill_dvals(float* d_shard_vals, float* d_vals, T* idx, long long len, size_t val_size, - const StreamType& stream) { + const StreamType& stream, + FVAccessor& feature_value_accessor) { int grid_size = (len - 1) / block_size_ + 1; size_t c_len = (size_t)len; dy_mf_fill_dvals_kernel<<>>( - d_shard_vals, d_vals, idx, c_len, val_size, feature_value_accessor_); + d_shard_vals, d_vals, idx, c_len, val_size, feature_value_accessor); } template void HeterCommKernel::fill_idx( @@ -412,17 +428,15 @@ template void HeterCommKernel::fill_shard_key( long long len, const cudaStream_t& stream); -template void HeterCommKernel::fill_shard_grads< - unsigned long, - float, - int, - cudaStream_t>(unsigned long* d_shard_keys, - unsigned long* d_keys, - float* d_shard_grads, - float* d_grads, - int* idx, - long long len, - const cudaStream_t& stream); +template void +HeterCommKernel::fill_shard_grads( + unsigned long* d_shard_keys, + unsigned long* d_keys, + float* d_shard_grads, + float* d_grads, + int* idx, + long long len, + const cudaStream_t& stream); template void HeterCommKernel::fill_dvals( @@ -477,19 +491,23 @@ template void HeterCommKernel::reduce_by_key< cudaStream_t stream, bool debug_synchronous); -template void HeterCommKernel::dy_mf_fill_shard_grads< - unsigned long, - int, - cudaStream_t>(unsigned long* d_shard_keys, - unsigned 
long* d_keys, - float* d_shard_grads, - float* d_grads, - int* idx, - long long len, - size_t grad_value_size, - const cudaStream_t& stream); - -template void HeterCommKernel::merge_gradient( +template void +HeterCommKernel::dy_mf_fill_shard_grads( + unsigned long* d_shard_keys, + unsigned long* d_keys, + float* d_shard_grads, + float* d_grads, + int* idx, + long long len, + size_t grad_value_size, + const cudaStream_t& stream, + CommonFeatureValueAccessor& feature_value_accessor); + +template void +HeterCommKernel::merge_gradient( const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, @@ -498,16 +516,18 @@ template void HeterCommKernel::merge_gradient( int n, size_t grad_value_size, DynamicGradMerger& merger_, - const cudaStream_t& stream); + const cudaStream_t& stream, + CommonFeatureValueAccessor& feature_value_accessor); template void HeterCommKernel:: - dy_mf_fill_dvals( + dy_mf_fill_dvals( float* d_shard_vals, float* d_vals, int* idx, long long len, size_t val_size, - const cudaStream_t& stream); + const cudaStream_t& stream, + CommonFeatureValueAccessor& feature_value_accessor); #endif } // namespace framework diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index 96ec86f5b36c25..bf82dfab165292 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -41,15 +41,18 @@ struct DynamicGradMerger { return out; } - __device__ __forceinline__ void update_one(float* output, const float* input, - CommonFeatureValueAccessor& feature_value_accessor) { + __device__ __forceinline__ void update_one( + float* output, + const float* input, + CommonFeatureValueAccessor& feature_value_accessor) { feature_value_accessor.PushValueFill(output, input); } - __device__ __forceinline__ void merge_one(float* output, const float* input, - CommonFeatureValueAccessor& feature_value_accessor) { + __device__ __forceinline__ void merge_one( + float* output, + const float* input, + CommonFeatureValueAccessor& feature_value_accessor) { feature_value_accessor.MergePushValue(output, input); - } }; @@ -58,7 +61,10 @@ class HeterCommKernel { HeterCommKernel() {} explicit HeterCommKernel(const int block_size) : block_size_(block_size) {} - explicit HeterCommKernel(const int block_size, CommonFeatureValueAccessor& feature_value_accessor) : block_size_(block_size), feature_value_accessor_(feature_value_accessor) {} + // explicit HeterCommKernel(const int block_size, CommonFeatureValueAccessor& + // feature_value_accessor) : block_size_(block_size), + // feature_value_accessor_(feature_value_accessor) {} + // explicit HeterCommKernel(const int block_size) : block_size_(block_size) {} template void fill_idx(T* idx, long long len, const StreamType& stream); @@ -139,7 +145,8 @@ class HeterCommKernel { template + typename StreamType, + typename FVAccessor> void dy_mf_fill_shard_grads(KeyType* d_shard_keys, KeyType* d_keys, float* d_shard_grads, @@ -147,9 +154,10 @@ class HeterCommKernel { T* idx, long long len, size_t grad_value_size, - const StreamType& stream); + const StreamType& stream, + FVAccessor& feature_value_accessor); - template + template void merge_gradient(const uint32_t* offset, const uint32_t* fea_num, const uint32_t* index, @@ -158,17 +166,19 @@ class HeterCommKernel { int n, size_t grad_value_size, DynamicGradMerger& merger_, - const StreamType& stream); + const StreamType& stream, + FVAccessor& feature_value_accessor); - template + 
template <typename T, typename StreamType, typename FVAccessor>
   void dy_mf_fill_dvals(float* d_shard_vals,
                         float* d_vals,
                         T* idx,
                         long long len,
                         size_t val_size,
-                        const StreamType& stream);
+                        const StreamType& stream,
+                        FVAccessor& feature_value_accessor);
 
-  CommonFeatureValueAccessor feature_value_accessor_;
+  // CommonFeatureValueAccessor feature_value_accessor_;
  private:
  int block_size_{256};
 };
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc
index 29b6c525971b12..4eff4a8ad55b94 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cc
@@ -22,19 +22,31 @@ namespace paddle {
 namespace framework {
 
 HeterPsBase* HeterPsBase::get_instance(
-    size_t capacity, std::shared_ptr<HeterPsResource> resource,
-    CommonFeatureValueAccessor feature_value_accessor,
+    size_t capacity,
+    std::shared_ptr<HeterPsResource> resource,
+    std::unordered_map<std::string, float> fleet_config,
+    std::string accessor_type,
     int optimizer_type) {
-  return new HeterPs(capacity, resource, feature_value_accessor, optimizer_type);
+  if (accessor_type == "CtrDymfAccessor" &&
+      (optimizer_type == 1 || optimizer_type == 3 || optimizer_type == 4)) {
+    return new HeterPs(
+        capacity, resource, fleet_config, accessor_type, optimizer_type);
+  } else {
+    VLOG(0) << " HeterPsBase get_instance Warning: now only support "
+               "CtrDymfAccessor, but get "
+            << accessor_type;
+    return new HeterPs(
+        capacity, resource, fleet_config, accessor_type, optimizer_type);
+  }
 }
 
-HeterPs::HeterPs(size_t capacity, std::shared_ptr<HeterPsResource> resource,
-                 CommonFeatureValueAccessor feature_value_accessor,
-                 int optimizer_type) {
-  comm_ =
-      std::make_shared<HeterComm<FeatureKey, FeatureValue, FeaturePushValue>>(
-          capacity, resource);
-  feature_value_accessor_ = feature_value_accessor;
+HeterPs::HeterPs(size_t capacity,
+                 std::shared_ptr<HeterPsResource> resource,
+                 std::unordered_map<std::string, float> fleet_config,
+                 std::string accessor_type,
+                 int optimizer_type) {
+  comm_ = std::make_shared<HeterComm<FeatureKey, float, float>>(
+      capacity, resource);
   optimizer_type_ = optimizer_type;
 }
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
index a653f08253b141..b059690990370e 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
@@ -22,88 +22,136 @@ namespace paddle {
 namespace framework {
 
 HeterPsBase* HeterPsBase::get_instance(
-    size_t capacity, std::shared_ptr<HeterPsResource> resource,
-    CommonFeatureValueAccessor feature_value_accessor,
+    size_t capacity,
+    std::shared_ptr<HeterPsResource> resource,
+    std::unordered_map<std::string, float> fleet_config,
+    std::string accessor_type,
     int optimizer_type) {
-  return new HeterPs(capacity, resource, feature_value_accessor, optimizer_type);
+  if (accessor_type == "CtrDymfAccessor" &&
+      (optimizer_type == 1 || optimizer_type == 3 || optimizer_type == 4)) {
+    return new HeterPs<CommonFeatureValueAccessor>(
+        capacity, resource, fleet_config, accessor_type, optimizer_type);
+  } else {
+    VLOG(0) << " HeterPsBase get_instance Warning: now only support "
+               "CtrDymfAccessor, but get "
+            << accessor_type;
+    return new HeterPs<CommonFeatureValueAccessor>(
+        capacity, resource, fleet_config, accessor_type, optimizer_type);
+  }
 }
 
-HeterPs::HeterPs(size_t capacity, std::shared_ptr<HeterPsResource> resource,
-                 CommonFeatureValueAccessor feature_value_accessor,
-                 int optimizer_type) {
-  comm_ =
-      std::make_shared<HeterComm<FeatureKey, float, float>>(
-          capacity, resource, feature_value_accessor);
-  feature_value_accessor_ = feature_value_accessor;
+template <typename FVAccessor>
+HeterPs<FVAccessor>::HeterPs(
+    size_t capacity,
+    std::shared_ptr<HeterPsResource> resource,
+    std::unordered_map<std::string, float> fleet_config,
+    std::string accessor_type,
+    int optimizer_type) {
+  comm_ = std::make_shared<HeterComm<FeatureKey, float, float, FVAccessor>>(
+      capacity, resource);
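+  // fleet_config carries the sparse-optimizer hyperparameters forwarded from
+  // the CPU table proto; Configure() is assumed to size the accessor's
+  // embed/embedx dims from them, mirroring CtrDymfAccessor::Initialize on
+  // the CPU side.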
+  feature_value_accessor_.Configure(fleet_config);
+  set_accessor(feature_value_accessor_);
+  accessor_type_ = accessor_type;
   optimizer_type_ = optimizer_type;
 }
 
-HeterPs::~HeterPs() {}
+template <typename FVAccessor>
+HeterPs<FVAccessor>::~HeterPs() {}
 
-void HeterPs::pull_sparse(int num,
-                          FeatureKey* d_keys,
-                          float* d_vals,
-                          size_t len) {
+template <typename FVAccessor>
+void HeterPs<FVAccessor>::pull_sparse(int num,
                                      FeatureKey* d_keys,
+                                      float* d_vals,
+                                      size_t len) {
   comm_->pull_sparse(num, d_keys, d_vals, len);
 }
 
-void HeterPs::build_ps(int num,
-                       FeatureKey* h_keys,
-                       char* pool,
-                       size_t len,
-                       size_t feature_value_size,
-                       size_t chunk_size,
-                       int stream_num) {
+template <typename FVAccessor>
+void HeterPs<FVAccessor>::build_ps(int num,
+                                   FeatureKey* h_keys,
+                                   char* pool,
+                                   size_t len,
+                                   size_t feature_value_size,
+                                   size_t chunk_size,
+                                   int stream_num) {
   comm_->build_ps(
       num, h_keys, pool, len, feature_value_size, chunk_size, stream_num);
 }
 
-int HeterPs::get_index_by_devid(int devid) {
+template <typename FVAccessor>
+int HeterPs<FVAccessor>::get_index_by_devid(int devid) {
   return comm_->get_index_by_devid(devid);
 }
 
-void HeterPs::set_sparse_sgd(const OptimizerConfig& optimizer_config) {
+template <typename FVAccessor>
+void HeterPs<FVAccessor>::set_sparse_sgd(
+    const OptimizerConfig& optimizer_config) {
   comm_->set_sparse_sgd(optimizer_config);
 }
 
-void HeterPs::set_embedx_sgd(const OptimizerConfig& optimizer_config) {
+template <typename FVAccessor>
+void HeterPs<FVAccessor>::set_embedx_sgd(
+    const OptimizerConfig& optimizer_config) {
   comm_->set_embedx_sgd(optimizer_config);
 }
 
-void HeterPs::end_pass() { comm_->end_pass(); }
-
-void HeterPs::show_one_table(int gpu_num) { comm_->show_one_table(gpu_num); }
-
-void HeterPs::push_sparse(int num,
-                          FeatureKey* d_keys,
-                          float* d_grads,
-                          size_t len) {
-  if (optimizer_type_ == 3) { //adam
-    auto optimizer = SparseAdamOptimizer(feature_value_accessor_);
-    VLOG(5) << "INTO push_sparse SparseAdamOptimizer, EmbedDim():" << optimizer.EmbedDim();
-    comm_->push_sparse(num, d_keys, d_grads, len, optimizer);
-  } else if (optimizer_type_ == 4) { //shared_adam
-    auto optimizer = SparseAdamSharedOptimizer(feature_value_accessor_);
-    VLOG(5) << "INTO push_sparse SparseAdamSharedOptimizer, EmbedDim():" << optimizer.EmbedDim();
-    comm_->push_sparse(num, d_keys, d_grads, len, optimizer);
+template <typename FVAccessor>
+void HeterPs<FVAccessor>::end_pass() {
+  comm_->end_pass();
+}
+
+template <typename FVAccessor>
+void HeterPs<FVAccessor>::show_one_table(int gpu_num) {
+  comm_->show_one_table(gpu_num);
+}
+
+template <typename FVAccessor>
+void HeterPs<FVAccessor>::push_sparse(int num,
+                                      FeatureKey* d_keys,
+                                      float* d_grads,
+                                      size_t len) {
+  if (accessor_type_ == "CtrDymfAccessor") {
+    if (optimizer_type_ == 3) {  // adam
+      auto optimizer = SparseAdamOptimizer(feature_value_accessor_);
+      VLOG(5) << "INTO push_sparse SparseAdamOptimizer, EmbedDim():"
+              << optimizer.EmbedDim();
+      comm_->push_sparse(num, d_keys, d_grads, len, optimizer);
+    } else if (optimizer_type_ == 4) {  // shared_adam
+      auto optimizer = SparseAdamSharedOptimizer(feature_value_accessor_);
+      VLOG(5) << "INTO push_sparse SparseAdamSharedOptimizer, EmbedDim():"
+              << optimizer.EmbedDim();
+      comm_->push_sparse(num, d_keys, d_grads, len, optimizer);
+    } else if (optimizer_type_ == 1) {  // adagrad
+      auto optimizer = SparseAdagradOptimizer(feature_value_accessor_);
+      VLOG(5) << "INTO push_sparse SparseAdagradOptimizer, EmbedDim():"
+              << optimizer.EmbedDim();
+      comm_->push_sparse(num, d_keys, d_grads, len, optimizer);
+    } else {
+      VLOG(0) << " push sparse Error: CtrDymfAccessor only support adagrad(1),"
+                 "adam(3) or shared_adam(4), but get optimizer type:"
+              << optimizer_type_;
+    }
   } else {
-    auto optimizer = SparseAdagradOptimizer(feature_value_accessor_);
-    VLOG(5) << "INTO push_sparse 
SparseAdagradOptimizer, EmbedDim():" << optimizer.EmbedDim(); - comm_->push_sparse(num, d_keys, d_grads, len, optimizer); + VLOG(0) << " push sparse Error: now only support CtrDymfAccessor, but get " + << accessor_type_; } } -void HeterPs::set_nccl_comm_and_size(const std::vector& inner_comms, - const std::vector& inter_comms, - int comm_size) { +template +void HeterPs::set_nccl_comm_and_size( + const std::vector& inner_comms, + const std::vector& inter_comms, + int comm_size) { comm_->set_nccl_comm_and_size(inner_comms, inter_comms, comm_size); } -void HeterPs::set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) { +template +void HeterPs::set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) { comm_->set_multi_mf_dim(multi_mf_dim, max_mf_dim); } -void HeterPs::set_accessor(CommonFeatureValueAccessor& accessor) { +template +void HeterPs::set_accessor(FVAccessor& accessor) { comm_->set_accessor(accessor); } diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h index 109facb5828cc2..439f5d6c818544 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps.h @@ -26,12 +26,15 @@ limitations under the License. */ namespace paddle { namespace framework { +template class HeterPs : public HeterPsBase { public: HeterPs() {} - HeterPs(size_t capacity, std::shared_ptr resource, - CommonFeatureValueAccessor feature_value_accessor, - int optimizer_type); + HeterPs(size_t capacity, + std::shared_ptr resource, + std::unordered_map fleet_config, + std::string accessor_type, + int optimizer_type); virtual ~HeterPs(); HeterPs(const HeterPs&) = delete; HeterPs& operator=(const HeterPs&) = delete; @@ -52,7 +55,8 @@ class HeterPs : public HeterPsBase { const std::vector& inter_comms, int comm_size) override; void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) override; - void set_accessor(CommonFeatureValueAccessor& accessor) override; + + void set_accessor(FVAccessor& accessor); #endif void set_sparse_sgd(const OptimizerConfig& optimizer_config) override; @@ -67,9 +71,10 @@ class HeterPs : public HeterPsBase { size_t len) override; private: - std::shared_ptr> comm_; + std::shared_ptr> comm_; #if defined(PADDLE_WITH_CUDA) - CommonFeatureValueAccessor feature_value_accessor_; + FVAccessor feature_value_accessor_; + std::string accessor_type_; int optimizer_type_; #endif }; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h index 0769b9280ef4d8..e45d1db71ccae9 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h @@ -50,7 +50,6 @@ class HeterPsBase { const std::vector& inter_comms, int comm_size) = 0; virtual void set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) = 0; - virtual void set_accessor(CommonFeatureValueAccessor& accessor) = 0; #endif virtual void end_pass() = 0; @@ -63,10 +62,13 @@ class HeterPsBase { virtual void set_sparse_sgd(const OptimizerConfig& optimizer_config) = 0; virtual void set_embedx_sgd(const OptimizerConfig& optimizer_config) = 0; - static HeterPsBase* get_instance(size_t capacity, - std::shared_ptr resource, - CommonFeatureValueAccessor feature_value_accessor, - int optimizer_type); + static HeterPsBase* get_instance( + size_t capacity, + std::shared_ptr resource, + // CommonFeatureValueAccessor feature_value_accessor, + std::unordered_map fleet_config, + std::string accessor_type, + int optimizer_type); }; } // end namespace 
framework diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 3f7ea8e141849f..62265c3df6a4dd 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -33,8 +33,8 @@ limitations under the License. */ #include #include -#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/framework/data_set.h" +#include "paddle/fluid/platform/timer.h" #if defined(PADDLE_WITH_PSCORE) #include "paddle/fluid/distributed/ps/table/depends/feature_value.h" #endif @@ -538,14 +538,14 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { &device_vals, &device_task_keys, &device_task_ptrs](int dev, int shard_id) { - // auto& task_keys = device_task_keys[shard_id]; + // auto& task_keys = device_task_keys[shard_id]; #ifdef PADDLE_WITH_PSLIB auto& task_ptrs = device_task_ptrs[shard_id]; #endif -// #ifdef PADDLE_WITH_PSCORE -// auto& task_ptrs = device_task_ptrs[shard_id]; -// #endif + // #ifdef PADDLE_WITH_PSCORE + // auto& task_ptrs = device_task_ptrs[shard_id]; + // #endif // int len = prefix_sum[dev][shard_id + 1] - prefix_sum[dev][shard_id]; // int cur = prefix_sum[dev][shard_id]; @@ -577,7 +577,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { } } #endif -// #ifdef PADDLE_WITH_PSCORE + // #ifdef PADDLE_WITH_PSCORE // for (int j = 0; j < len; ++j) { // device_keys[dev][cur + j] = task_keys[dev][j]; // float* ptr_val = task_ptrs[dev][j]->data(); @@ -603,7 +603,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { // } // } // } -// #endif + // #endif VLOG(3) << "GpuPs build hbmps done"; }; @@ -652,22 +652,26 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { return; } std::vector threads(device_num); - HeterPs_ = HeterPsBase::get_instance(size_max, resource_, feature_value_accessor_, optimizer_type_); + auto accessor_wrapper_ptr = + GlobalAccessorTransfor::GetInstance().GetAccessorWrapper(); + HeterPs_ = HeterPsBase::get_instance( + size_max, resource_, fleet_config_, accessor_class_, optimizer_type_); #ifdef PADDLE_WITH_CUDA HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_); HeterPs_->set_sparse_sgd(optimizer_config_); HeterPs_->set_embedx_sgd(optimizer_config_); #endif - auto build_dymf_mem_pool = [this, &gpu_task](int i, int j) { + auto build_dymf_mem_pool = [this, &gpu_task, &accessor_wrapper_ptr](int i, + int j) { this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_); // this->HeterPs_->set_accessor(feature_value_accessor_); int mf_dim = this->index_dim_vec_[j]; VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim - << " feature_value_dim:" << feature_value_accessor_.common_feature_value.Dim(mf_dim) - << " feature_value_size:" << feature_value_accessor_.common_feature_value.Size(mf_dim); - size_t feature_value_size = - TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(mf_dim)); + << " feature_value_size:" + << accessor_wrapper_ptr->GetFeatureValueSize(mf_dim); + size_t feature_value_size = + accessor_wrapper_ptr->GetFeatureValueSize(mf_dim); auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; size_t len = device_dim_keys.size(); @@ -675,12 +679,13 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { this->mem_pools_[i * this->multi_mf_dim_ + j] = new MemoryPool(len, feature_value_size); }; - auto build_dymf_hbm_pool = [this, &gpu_task](int i, int j) { + auto build_dymf_hbm_pool = [this, &gpu_task, 
&accessor_wrapper_ptr](int i, + int j) { auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; size_t len = device_dim_keys.size(); int mf_dim = this->index_dim_vec_[j]; size_t feature_value_size = - TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float))); + accessor_wrapper_ptr->GetFeatureValueSize(mf_dim); auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j]; platform::CUDADeviceGuard guard(resource_->dev_id(i)); @@ -703,84 +708,92 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { delete mem_pool; }; int thread_num = 16; - auto build_dynamic_mf_func = [this, &gpu_task, thread_num]( - int i, int j, int z) { - // this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_); - int mf_dim = this->index_dim_vec_[j]; - VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim; - // size_t feature_value_size = - // TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * sizeof(float))); - auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; - auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; - size_t len = device_dim_keys.size(); - CHECK(len == device_dim_ptrs.size()); - // this->mem_pools_[i * this->multi_mf_dim_ + j] = - // new MemoryPool(len, feature_value_size); - auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j]; - - // ============ add for multi-thread ================ - size_t len_per_thread = len / thread_num; - size_t remain = len % thread_num; - size_t left = 0, right = 0; - - size_t real_len = len_per_thread; - if ((size_t)z < remain) real_len++; - - if ((size_t)z < remain) { - left = z * (len_per_thread + 1); - right = left + real_len; - } else { - left = remain * (len_per_thread + 1) + (z - remain) * len_per_thread; - right = left + real_len; - } - // ============ add for multi-thread ================ + auto build_dynamic_mf_func = + [this, &gpu_task, thread_num](int i, int j, int z) { + // this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_); + int mf_dim = this->index_dim_vec_[j]; + VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim; + // size_t feature_value_size = + // TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * + // sizeof(float))); + auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; + auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; + size_t len = device_dim_keys.size(); + CHECK(len == device_dim_ptrs.size()); + // this->mem_pools_[i * this->multi_mf_dim_ + j] = + // new MemoryPool(len, feature_value_size); + auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j]; + + // ============ add for multi-thread ================ + size_t len_per_thread = len / thread_num; + size_t remain = len % thread_num; + size_t left = 0, right = 0; + + size_t real_len = len_per_thread; + if ((size_t)z < remain) real_len++; + + if ((size_t)z < remain) { + left = z * (len_per_thread + 1); + right = left + real_len; + } else { + left = remain * (len_per_thread + 1) + (z - remain) * len_per_thread; + right = left + real_len; + } + // ============ add for multi-thread ================ - for (size_t k = left; k < right; k++) { - FeatureValue* val = (FeatureValue*)(mem_pool->mem_address(k)); - float* ptr_val = device_dim_ptrs[k]->data(); - size_t dim = device_dim_ptrs[k]->size(); + for (size_t k = left; k < right; k++) { + void* val = mem_pool->mem_address(k); + float* ptr_val = device_dim_ptrs[k]->data(); + size_t dim = device_dim_ptrs[k]->size(); #ifdef PADDLE_WITH_PSLIB - val->delta_score = - ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - 
DownpourCtrDymfFeatureValue::delta_score_index()];
-          val->show = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
-                              DownpourCtrDymfFeatureValue::show_index()];
-          val->clk = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
-                             DownpourCtrDymfFeatureValue::click_index()];
-          val->slot = int(ptr_val[paddle::ps::DownpourCtrDymfAccessor::
-                                  DownpourCtrDymfFeatureValue::slot_index()]);
-          val->lr = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
-                            DownpourCtrDymfFeatureValue::embed_w_index()];
-          val->lr_g2sum =
+          val->delta_score =
+              ptr_val[paddle::ps::DownpourCtrDymfAccessor::
+                          DownpourCtrDymfFeatureValue::delta_score_index()];
+          val->show = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
+                                  DownpourCtrDymfFeatureValue::show_index()];
+          val->clk = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
+                                 DownpourCtrDymfFeatureValue::click_index()];
+          val->slot =
+              int(ptr_val[paddle::ps::DownpourCtrDymfAccessor::
+                              DownpourCtrDymfFeatureValue::slot_index()]);
+          val->lr = ptr_val[paddle::ps::DownpourCtrDymfAccessor::
+                                DownpourCtrDymfFeatureValue::embed_w_index()];
+          val->lr_g2sum =
+              ptr_val[paddle::ps::DownpourCtrDymfAccessor::
+                          DownpourCtrDymfFeatureValue::embed_g2sum_index()];
+          // TODO(xuefeng) set mf_dim while using DownpourCtrDymfAccessor
           ptr_val[paddle::ps::DownpourCtrDymfAccessor::
-                      DownpourCtrDymfFeatureValue::embed_g2sum_index()];
-          // TODO(xuefeng) set mf_dim while using DownpourCtrDymfAccessor
-          ptr_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
-              mf_dim_index()] = float(mf_dim);
-          val->mf_dim = mf_dim;
-          if (dim > 8) {  // CpuPS alreay expand as mf_dim
-            val->mf_size = mf_dim + 1;
-            for (int x = 0; x < val->mf_dim + 1; x++) {
-              val->mf[x] = ptr_val[x + 8];
-            }
-          } else {
-            val->mf_size = 0;
-            for (int x = 0; x < val->mf_dim + 1; x++) {
-              val->mf[x] = 0;
+                      DownpourCtrDymfFeatureValue::mf_dim_index()] =
+              float(mf_dim);
+          val->mf_dim = mf_dim;
+          if (dim > 8) {  // CpuPS already expanded to mf_dim
+            val->mf_size = mf_dim + 1;
+            for (int x = 0; x < val->mf_dim + 1; x++) {
+              val->mf[x] = ptr_val[x + 8];
+            }
+          } else {
+            val->mf_size = 0;
+            for (int x = 0; x < val->mf_dim + 1; x++) {
+              val->mf[x] = 0;
+            }
+          }
+        }
 #endif
 #ifdef PADDLE_WITH_PSCORE
-      VLOG(5) << "cpu build "<< k << " cpuptr: " << (uint64_t)(device_dim_ptrs[k])
-              << " |: "<< cpu_table_accessor_->ParseToString(ptr_val, dim);
-      feature_value_accessor_.BuildFill(val, ptr_val, cpu_table_accessor_, mf_dim, dim);
-      *(reinterpret_cast<uint64_t*>(val + feature_value_accessor_.common_feature_value.CpuPtrIndex())) = (uint64_t)(device_dim_ptrs[k]);
-      VLOG(5) << "build "<< k << " : "<< feature_value_accessor_.ParseToString(val, feature_value_accessor_.common_feature_value.Dim(mf_dim));
-    }
+          VLOG(5) << "cpu build " << k
+                  << " cpuptr: " << (uint64_t)(device_dim_ptrs[k])
+                  << " |: " << cpu_table_accessor_->ParseToString(ptr_val, dim);
+          accessor_wrapper_ptr->BuildFill(
+              val, device_dim_ptrs[k], cpu_table_accessor_, mf_dim);
+          VLOG(5) << "build " << k << " : "
+                  << accessor_wrapper_ptr->ParseToString(
+                         (float*)(val),
+                         int(accessor_wrapper_ptr->GetFeatureValueSize(mf_dim) /
+                             sizeof(float)));
        }
 #endif
-  threads.resize(device_num * multi_mf_dim_);
+  threads.resize(device_num * multi_mf_dim_);
   for (int i = 0; i < device_num; i++) {
     for (int j = 0; j < multi_mf_dim_; j++) {
       threads[i + j * device_num] = std::thread(build_dymf_mem_pool, i, j);
@@ -928,121 +941,129 @@ void PSGPUWrapper::EndPass() {
         std::max(keysize_max, current_task_->device_dim_keys_[i][j].size());
     }
   }
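+  // The accessor wrapper owns the flat row layout, so EndPass sizes each
+  // HBM row via GetFeatureValueSize(mf_dim) below instead of
+  // sizeof(FeatureValue).
+  auto accessor_wrapper_ptr =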
GlobalAccessorTransfor::GetInstance().GetAccessorWrapper(); int thread_num = 8; - auto dump_pool_to_cpu_func = [this, thread_num](int i, int j, int z) { - PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i))); - auto& hbm_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j]; - auto& device_keys = this->current_task_->device_dim_keys_[i][j]; - size_t len = device_keys.size(); - // ====== multi-thread process feasign================ - int len_per_thread = len / thread_num; - int remain = len % thread_num; - int left = -1, right = -1; - int real_len = len_per_thread; - if (z < remain) real_len++; - if (z < remain) { - left = z * (len_per_thread + 1); - right = left + real_len; - } else { - left = remain * (len_per_thread + 1) + (z - remain) * len_per_thread; - right = left + real_len; - } - // ============ multi-thread process feasign============ - int mf_dim = this->index_dim_vec_[j]; - size_t feature_value_size = - TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(mf_dim)); - VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim - << " key_len :" << len << " feature_value_size:" << feature_value_size; - - char* test_build_values = (char*)malloc(feature_value_size * real_len); - uint64_t offset = left * feature_value_size; - cudaMemcpy(test_build_values, - hbm_pool->mem() + offset, - feature_value_size * real_len, - cudaMemcpyDeviceToHost); - CHECK(len == hbm_pool->capacity()); - uint64_t unuse_key = std::numeric_limits::max(); - for (int i = left; i < right; ++i) { - if (device_keys[i] == unuse_key) { - continue; - } - size_t local_offset = (i - left) * feature_value_size; - float* gpu_val = (float*)(test_build_values + local_offset); + auto dump_pool_to_cpu_func = + [this, thread_num, &accessor_wrapper_ptr](int i, int j, int z) { + PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i))); + auto& hbm_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j]; + auto& device_keys = this->current_task_->device_dim_keys_[i][j]; + size_t len = device_keys.size(); + // ====== multi-thread process feasign================ + int len_per_thread = len / thread_num; + int remain = len % thread_num; + int left = -1, right = -1; + int real_len = len_per_thread; + if (z < remain) real_len++; + if (z < remain) { + left = z * (len_per_thread + 1); + right = left + real_len; + } else { + left = remain * (len_per_thread + 1) + (z - remain) * len_per_thread; + right = left + real_len; + } + // ============ multi-thread process feasign============ + int mf_dim = this->index_dim_vec_[j]; + size_t feature_value_size = + accessor_wrapper_ptr->GetFeatureValueSize(mf_dim); + VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim + << " key_len :" << len + << " feature_value_size:" << feature_value_size; + + char* test_build_values = (char*)malloc(feature_value_size * real_len); + uint64_t offset = left * feature_value_size; + cudaMemcpy(test_build_values, + hbm_pool->mem() + offset, + feature_value_size * real_len, + cudaMemcpyDeviceToHost); + CHECK(len == hbm_pool->capacity()); + uint64_t unuse_key = std::numeric_limits::max(); + for (int i = left; i < right; ++i) { + if (device_keys[i] == unuse_key) { + continue; + } + size_t local_offset = (i - left) * feature_value_size; + float* gpu_val = (float*)(test_build_values + local_offset); #ifdef PADDLE_WITH_PSLIB - auto* downpour_value = - (paddle::ps::DownpourFixedFeatureValue*)(gpu_val->cpu_ptr); - int downpour_value_size = downpour_value->size(); - if (gpu_val->mf_size > 0 && 
downpour_value_size == 8) { - downpour_value->resize(gpu_val->mf_dim + 1 + downpour_value_size); - } - float* cpu_val = downpour_value->data(); - cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - delta_score_index()] = gpu_val->delta_score; - cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - show_index()] = gpu_val->show; - cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - click_index()] = gpu_val->clk; - cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - embed_w_index()] = gpu_val->lr; - cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - embed_g2sum_index()] = gpu_val->lr_g2sum; - cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: - slot_index()] = gpu_val->slot; - - if (gpu_val->mf_size > 0) { - for (int x = 0; x < gpu_val->mf_dim + 1; x++) { - cpu_val[x + 8] = gpu_val->mf[x]; + auto* downpour_value = + (paddle::ps::DownpourFixedFeatureValue*)(gpu_val->cpu_ptr); + int downpour_value_size = downpour_value->size(); + if (gpu_val->mf_size > 0 && downpour_value_size == 8) { + downpour_value->resize(gpu_val->mf_dim + 1 + downpour_value_size); + } + float* cpu_val = downpour_value->data(); + cpu_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::delta_score_index()] = + gpu_val->delta_score; + cpu_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::show_index()] = + gpu_val->show; + cpu_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::click_index()] = + gpu_val->clk; + cpu_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::embed_w_index()] = + gpu_val->lr; + cpu_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::embed_g2sum_index()] = + gpu_val->lr_g2sum; + cpu_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::slot_index()] = + gpu_val->slot; + + if (gpu_val->mf_size > 0) { + for (int x = 0; x < gpu_val->mf_dim + 1; x++) { + cpu_val[x + 8] = gpu_val->mf[x]; + } + } } - } - } #endif #ifdef PADDLE_WITH_PSCORE - auto* downpour_value = - (paddle::distributed::FixedFeatureValue*)(*(reinterpret_cast(gpu_val+ feature_value_accessor_.common_feature_value.CpuPtrIndex()))); - size_t downpour_value_size = downpour_value->size(); - if (gpu_val[feature_value_accessor_.common_feature_value.MfSizeIndex()] > 0 && - downpour_value_size == (cpu_table_accessor_->GetAccessorInfo().dim - - int(cpu_table_accessor_->GetAccessorInfo().mf_size / sizeof(float)))) { // cpu_accessor - downpour_value->resize(cpu_table_accessor_->common_feature_value.Dim(mf_dim)); + accessor_wrapper_ptr->DumpFill(gpu_val, cpu_table_accessor_, mf_dim); + auto* downpour_value = (paddle::distributed::FixedFeatureValue*)(*( + reinterpret_cast(gpu_val))); + float* cpu_val = downpour_value->data(); + VLOG(5) << "dump to cpu " << index << " gpu_value: " + << accessor_wrapper_ptr->ParseToString( + gpu_val, + int(accessor_wrapper_ptr->GetFeatureValueSize(mf_dim) / + sizeof(float))) + << " \t cpu_value:" + << cpu_table_accessor_->ParseToString(cpu_val, + downpour_value->size()); } - float* cpu_val = downpour_value->data(); - - feature_value_accessor_.DumpFill(cpu_val, gpu_val, cpu_table_accessor_, mf_dim); - VLOG(5) << "dump to cpu "<< index << " : "<< feature_value_accessor_.ParseToString(gpu_val, feature_value_accessor_.common_feature_value.Dim(mf_dim)) - << " ===== CPU:" << cpu_table_accessor_->ParseToString(cpu_val, downpour_value->size()); - - } #endif - 
free(test_build_values);
-  };
-  if (multi_mf_dim_) {
-    VLOG(0) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_;
-    size_t device_num = heter_devices_.size();
-    std::vector<std::thread> threads(device_num * multi_mf_dim_ * thread_num);
-    for (size_t i = 0; i < device_num; i++) {
-      for (int j = 0; j < multi_mf_dim_; j++) {
-        for (int k = 0; k < thread_num; k++) {
-          threads[(i + j * device_num) * thread_num + k] =
-              std::thread(dump_pool_to_cpu_func, i, j, k);
-        }
-      }
-    }
-    for (std::thread& t : threads) {
-      t.join();
-    }
-  }
-  if (keysize_max != 0) {
-    HeterPs_->end_pass();
-  }
-
-  for (size_t i = 0; i < hbm_pools_.size(); i++) {
-    delete hbm_pools_[i];
-  }
-  gpu_task_pool_.Push(current_task_);
-  current_task_ = nullptr;
-  gpu_free_channel_->Put(current_task_);
-  timer.Pause();
-  VLOG(1) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
+        free(test_build_values);
+      };
+  if (multi_mf_dim_) {
+    VLOG(0) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_;
+    size_t device_num = heter_devices_.size();
+    std::vector<std::thread> threads(device_num * multi_mf_dim_ * thread_num);
+    for (size_t i = 0; i < device_num; i++) {
+      for (int j = 0; j < multi_mf_dim_; j++) {
+        for (int k = 0; k < thread_num; k++) {
+          threads[(i + j * device_num) * thread_num + k] =
+              std::thread(dump_pool_to_cpu_func, i, j, k);
+        }
+      }
+    }
+    for (std::thread& t : threads) {
+      t.join();
+    }
+  }
+  if (keysize_max != 0) {
+    HeterPs_->end_pass();
+  }
+
+  for (size_t i = 0; i < hbm_pools_.size(); i++) {
+    delete hbm_pools_[i];
+  }
+  gpu_task_pool_.Push(current_task_);
+  current_task_ = nullptr;
+  gpu_free_channel_->Put(current_task_);
+  timer.Pause();
+  VLOG(1) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
 }
 
 void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
@@ -1051,7 +1072,8 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
                               const std::vector<float*>& values,
                               const std::vector<int64_t>& slot_lengths,
                               const int hidden_size) {
-  VLOG(0) << "Warning:: recommand use pull_gpups_sparse op instead. This PullSparse is not used.";
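+  // NOTE: this overload is a stub kept for API compatibility; as the log
+  // below says, only the multi-mf-dim PullSparse further down is used.
+  VLOG(0) << "Warning:: recommend use pull_gpups_sparse op instead. This "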
+             "PullSparse is not used.";
 }
 
 void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
@@ -1069,9 +1091,12 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
 
   size_t feature_value_size = 0;
-  feature_value_size = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_));
-  VLOG(3) << "PullSparse max_dim:" << max_mf_dim_ << " feature_value_size:" << feature_value_size;
-
+  auto accessor_wrapper_ptr =
+      GlobalAccessorTransfor::GetInstance().GetAccessorWrapper();
+  feature_value_size = accessor_wrapper_ptr->GetFeatureValueSize(max_mf_dim_);
+  VLOG(3) << "PullSparse max_dim:" << max_mf_dim_
+          << " feature_value_size:" << feature_value_size;
+
 #ifdef PADDLE_WITH_CUDA
   VLOG(3) << "Begin Gpu Ps PullSparse";
   auto buf = memory::Alloc(place, total_length * feature_value_size);
@@ -1137,15 +1162,16 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
 
   VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
           << "]";
-  this->CopyForPull(place,
-                    gpu_keys,
-                    values,
-                    total_values_gpu,
-                    gpu_len,
-                    static_cast<int>(slot_lengths.size()),
-                    hidden_size,
-                    total_length,
-                    gpu_dim);
+  accessor_wrapper_ptr->CopyForPull(place,
+                                    gpu_keys,
+                                    values,
+                                    total_values_gpu,
+                                    gpu_len,
+                                    static_cast<int>(slot_lengths.size()),
+                                    hidden_size,
+                                    total_length,
+                                    gpu_dim,
+                                    val_type_size_);
 
   pull_gpups_timer.Pause();
 
@@ -1196,14 +1222,15 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
   VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
           << "]";
-  this->CopyForPull(place,
-                    xpu_keys,
-                    values,
-                    total_values_gpu,
-                    xpu_len,
-                    static_cast<int>(slot_lengths.size()),
-                    hidden_size,
-                    total_length);
+  accessor_wrapper_ptr->CopyForPull(place,
+                                    xpu_keys,
+                                    values,
+                                    total_values_gpu,
+                                    xpu_len,
+                                    static_cast<int>(slot_lengths.size()),
+                                    hidden_size,
+                                    total_length,
+                                    val_type_size_);
 #endif
   } else {
     PADDLE_THROW(platform::errors::PreconditionNotMet(
@@ -1230,12 +1257,13 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
       std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
   // #ifdef PADDLE_WITH_CUDA
   VLOG(3) << "Begin GPUPS PushSparseGrad";
-  size_t grad_value_size =
-      TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_));
+  auto accessor_wrapper_ptr =
+      GlobalAccessorTransfor::GetInstance().GetAccessorWrapper();
+  size_t grad_value_size = accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_);
   auto buf = memory::Alloc(place, total_length * grad_value_size);
-  VLOG(3) << "Push Sparse Max mf dimention: " << max_mf_dim_ << "grad_value_size:" << grad_value_size;
-  float* total_grad_values_gpu =
-      reinterpret_cast<float*>(buf->ptr());
+  VLOG(3) << "Push Sparse Max mf dimension: " << max_mf_dim_
+          << " grad_value_size:" << grad_value_size;
+  float* total_grad_values_gpu = reinterpret_cast<float*>(buf->ptr());
   if (platform::is_cpu_place(place)) {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Warning:: CPUPlace is not supported in GPUPS now."));
@@ -1247,13 +1275,15 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
     uint64_t* total_keys =
         reinterpret_cast<uint64_t*>(cached_total_keys_tensor.data<int64_t>());
     VLOG(3) << "Begin copy grad tensor to gpups struct";
-    this->CopyForPush(place,
-                      grad_values,
-                      total_grad_values_gpu,
-                      slot_lengths,
-                      total_length,
-                      batch_size,
-                      grad_value_size);
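+    // CopyForPush scatters the per-slot gradient tensors into flat push rows
+    // (slot, show, clk, mf_dim, embed_g, embedx_g) sized by GetPushValueSize.
+    accessor_wrapper_ptr->CopyForPush(place,
+                                      grad_values,
+                                      total_grad_values_gpu,
+                                      slot_lengths,
+                                      total_length,
+                                      batch_size,
+                                      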
grad_value_size, + slot_vector_, + slot_mf_dim_vector_); VLOG(3) << "Begin call PushSparseGPU in GPUPS, dev: " << devid_2_index << " len: " << total_length; @@ -1272,13 +1302,14 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place, uint64_t* total_keys = reinterpret_cast(cached_total_keys_tensor.data()); VLOG(3) << "Begin copy grad tensor to xpups struct"; - this->CopyForPush(place, - grad_values, - total_grad_values_gpu, - slot_lengths, - hidden_size, - total_length, - batch_size); + accessor_wrapper_ptr->CopyForPush(place, + grad_values, + total_grad_values_gpu, + slot_lengths, + hidden_size, + total_length, + batch_size, + slot_vector_); VLOG(3) << "Begin call PushSparseXPU in XPUPS, dev: " << devid_2_index << " len: " << total_length; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu index 1fd5a3b73d0ffc..f8624f48d08f3f 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cu @@ -26,75 +26,6 @@ limitations under the License. */ namespace paddle { namespace framework { -__global__ void PullCopy(float** dest, - const FeatureValue* src, - const int64_t* len, - int hidden, - int slot_num, - int total_len, - uint64_t** keys) { - CUDA_KERNEL_LOOP(i, total_len) { - int low = 0; - int high = slot_num - 1; - while (low < high) { - int mid = (low + high) / 2; - if (i < len[mid]) - high = mid; - else - low = mid + 1; - } - int x = low; - int y = i - (x ? len[x - 1] : 0); - if (*(keys[x] + y) == 0) { - *(dest[x] + y * hidden) = 0; - *(dest[x] + y * hidden + 1) = 0; - *(dest[x] + y * hidden + 2) = 0; - } else { - *(dest[x] + y * hidden) = (src + i)->show; - *(dest[x] + y * hidden + 1) = (src + i)->clk; - *(dest[x] + y * hidden + 2) = (src + i)->lr; - } - if ((src + i)->mf_size == 0 || *(keys[x] + y) == 0) { - for (int j = 0; j < hidden - 3; j++) { - *(dest[x] + y * hidden + 3 + j) = 0; - } - } else { - for (int j = 0; j < hidden - 3; j++) { - *(dest[x] + y * hidden + 3 + j) = (src + i)->mf[1 + j]; - } - } - } -} - -template -__global__ void PullCopy(float** dest, - const float* src, - const int64_t* len, - int slot_num, - int total_len, - uint64_t** keys, - uint64_t max_val_size, - int* gpu_dim, - FVAceessor feature_value_accessor) { - CUDA_KERNEL_LOOP(i, total_len) { - int low = 0; - int high = slot_num - 1; - while (low < high) { - int mid = (low + high) / 2; - if (i < len[mid]) - high = mid; - else - low = mid + 1; - } - int x = low; - int y = i - (x ? len[x - 1] : 0); - float* feature_value_ptr = - (float*)((char*)src + uint64_t(i) * uint64_t(max_val_size)); - int mf_dim = gpu_dim[x] - 3; - feature_value_accessor.Select(dest[x] + y * (mf_dim + 3), feature_value_ptr, keys[x] + y, mf_dim); - } -} - __global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys, const int64_t* len, @@ -146,110 +77,8 @@ __global__ void PushCopy(FeaturePushValue* dest, } } -template -__global__ void PushCopyWithPool(float* dest, - float** src, - int64_t* len, - int slot_num, - uint64_t total_len, - int bs, - int* slot_vector, - int* mf_dim_vector, - size_t grad_value_size, - FVAceessor feature_value_accessor) { - CUDA_KERNEL_LOOP(i, total_len) { - int low = 0; - int high = slot_num - 1; - while (low < high) { - int mid = (low + high) / 2; - if (i < len[mid]) - high = mid; - else - low = mid + 1; - } - int x = low; - int y = i - (x ? 
len[low - 1] : 0); - float* cur = - (float*)((char*)dest + i * grad_value_size); - - cur[feature_value_accessor.common_push_value.SlotIndex()] = - (float)slot_vector[x]; - int mf_dim = mf_dim_vector[x]; - cur[feature_value_accessor.common_push_value.MfDimIndex()] = mf_dim; - - cur[feature_value_accessor.common_push_value.ShowIndex()] = - *(src[x] + y * (mf_dim + 3)); - cur[feature_value_accessor.common_push_value.ClickIndex()] = - *(src[x] + y * (mf_dim + 3) + 1); - cur[feature_value_accessor.common_push_value.EmbedGIndex()] = - *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs; - for (int j = 0; j < mf_dim; j++) { - cur[feature_value_accessor.common_push_value.EmbedxGIndex() + j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs; - } - } -} PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; } -void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, - uint64_t** gpu_keys, - const std::vector& values, - const FeatureValue* total_values_gpu, - const int64_t* gpu_len, - const int slot_num, - const int hidden_size, - const int64_t total_length) { - auto stream = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)) - ->stream(); - auto buf_value = memory::Alloc(place, values.size() * sizeof(float*)); - float** gpu_values = reinterpret_cast(buf_value->ptr()); - cudaMemcpy(gpu_values, - values.data(), - values.size() * sizeof(float*), - cudaMemcpyHostToDevice); - - PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>( - gpu_values, - total_values_gpu, - gpu_len, - hidden_size, - slot_num, - total_length, - gpu_keys); - cudaStreamSynchronize(stream); -} - -void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, - uint64_t** gpu_keys, - const std::vector& values, - const float* total_values_gpu, - const int64_t* gpu_len, - const int slot_num, - const int hidden_size, - const int64_t total_length, - int* gpu_dim) { - auto stream = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)) - ->stream(); - auto buf_value = memory::Alloc(place, values.size() * sizeof(float*)); - float** gpu_values = reinterpret_cast(buf_value->ptr()); - cudaMemcpy(gpu_values, - values.data(), - values.size() * sizeof(float*), - cudaMemcpyHostToDevice); - PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>( - gpu_values, - total_values_gpu, - gpu_len, - slot_num, - total_length, - gpu_keys, - val_type_size_, - gpu_dim, - feature_value_accessor_); - cudaStreamSynchronize(stream); -} - void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place, uint64_t** origin_keys, uint64_t* total_keys, @@ -264,130 +93,26 @@ void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place, cudaStreamSynchronize(stream); } -void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, - const std::vector& grad_values, - FeaturePushValue* total_grad_values_gpu, - const std::vector& slot_lengths, - const int hidden_size, - const int64_t total_length, - const int batch_size) { - auto stream = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)) - ->stream(); - auto slot_lengths_lod = slot_lengths; - for (int i = 1; i < slot_lengths_lod.size(); i++) { - slot_lengths_lod[i] += slot_lengths_lod[i - 1]; - } - auto buf_grad_value = - memory::Alloc(place, grad_values.size() * sizeof(float*)); - auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t)); - auto buf_slot_vector = - memory::Alloc(place, slot_lengths_lod.size() * sizeof(int)); - - float** gpu_values = reinterpret_cast(buf_grad_value->ptr()); - int64_t* gpu_len = 
reinterpret_cast(buf_length->ptr()); - int* d_slot_vector = reinterpret_cast(buf_slot_vector->ptr()); - - cudaMemcpy(gpu_values, - grad_values.data(), - grad_values.size() * sizeof(float*), - cudaMemcpyHostToDevice); - cudaMemcpy(gpu_len, - slot_lengths_lod.data(), - slot_lengths.size() * sizeof(int64_t), - cudaMemcpyHostToDevice); - cudaMemcpy(d_slot_vector, - slot_vector_.data(), - slot_lengths_lod.size() * sizeof(int), - cudaMemcpyHostToDevice); - - PushCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>( - total_grad_values_gpu, - gpu_values, - gpu_len, - hidden_size, - slot_lengths.size(), - total_length, - batch_size, - d_slot_vector); - cudaStreamSynchronize(stream); -} - -void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, - const std::vector& grad_values, - float* total_grad_values_gpu, - const std::vector& slot_lengths, - const uint64_t total_length, - const int batch_size, - size_t grad_value_size) { - auto stream = dynamic_cast( - platform::DeviceContextPool::Instance().Get(place)) - ->stream(); - auto slot_lengths_lod = slot_lengths; - for (int i = 1; i < slot_lengths_lod.size(); i++) { - slot_lengths_lod[i] += slot_lengths_lod[i - 1]; - } - auto buf_grad_value = - memory::Alloc(place, grad_values.size() * sizeof(float*)); - auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t)); - auto buf_slot_vector = - memory::Alloc(place, slot_lengths_lod.size() * sizeof(int)); - auto buf_mf_dim_vector = - memory::Alloc(place, slot_lengths_lod.size() * sizeof(int)); - float** gpu_values = reinterpret_cast(buf_grad_value->ptr()); - int64_t* gpu_len = reinterpret_cast(buf_length->ptr()); - int* d_slot_vector = reinterpret_cast(buf_slot_vector->ptr()); - int* d_mf_dim_vector = reinterpret_cast(buf_mf_dim_vector->ptr()); - cudaMemcpy(gpu_values, - grad_values.data(), - grad_values.size() * sizeof(float*), - cudaMemcpyHostToDevice); - cudaMemcpy(gpu_len, - slot_lengths_lod.data(), - slot_lengths.size() * sizeof(int64_t), - cudaMemcpyHostToDevice); - cudaMemcpy(d_slot_vector, - slot_vector_.data(), - slot_lengths_lod.size() * sizeof(int), - cudaMemcpyHostToDevice); - cudaMemcpy(d_mf_dim_vector, - slot_mf_dim_vector_.data(), - slot_lengths_lod.size() * sizeof(int), - cudaMemcpyHostToDevice); - PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>( - total_grad_values_gpu, - gpu_values, - gpu_len, - slot_lengths.size(), - total_length, - batch_size, - d_slot_vector, - d_mf_dim_vector, - grad_value_size, - feature_value_accessor_); - cudaStreamSynchronize(stream); -} - void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff, float min_bound, float max_bound, float learning_rate, float initial_g2sum, - float initial_range, - float beta1_decay_rate, - float beta2_decay_rate, - float ada_epsilon) { + float initial_range, + float beta1_decay_rate, + float beta2_decay_rate, + float ada_epsilon) { optimizer_config_.set_sparse_sgd(nonclk_coeff, - clk_coeff, - min_bound, - max_bound, - learning_rate, - initial_g2sum, - initial_range, - beta1_decay_rate, - beta2_decay_rate, - ada_epsilon); + clk_coeff, + min_bound, + max_bound, + learning_rate, + initial_g2sum, + initial_range, + beta1_decay_rate, + beta2_decay_rate, + ada_epsilon); } void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds, @@ -395,19 +120,19 @@ void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds, float mf_initial_g2sum, float mf_initial_range, float mf_min_bound, - float mf_max_bound, - float mf_beta1_decay_rate, - float mf_beta2_decay_rate, - float 
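/* NOTE: mf_beta1_decay_rate, mf_beta2_decay_rate and mf_ada_epsilon are the
 * Adam hyperparameters for the embedx table; the non-Adam rules appear to
 * ignore them. */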
mf_ada_epsilon) { + float mf_max_bound, + float mf_beta1_decay_rate, + float mf_beta2_decay_rate, + float mf_ada_epsilon) { optimizer_config_.set_embedx_sgd(mf_create_thresholds, - mf_learning_rate, - mf_initial_g2sum, - mf_initial_range, - mf_min_bound, - mf_max_bound, - mf_beta1_decay_rate, - mf_beta2_decay_rate, - mf_ada_epsilon); + mf_learning_rate, + mf_initial_g2sum, + mf_initial_range, + mf_min_bound, + mf_max_bound, + mf_beta1_decay_rate, + mf_beta2_decay_rate, + mf_ada_epsilon); } } // end namespace framework diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index 555de56f1e0034..f0694cfb91b0e3 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -51,10 +51,10 @@ limitations under the License. */ #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/place.h" #ifdef PADDLE_WITH_PSCORE -#include "paddle/fluid/distributed/ps/wrapper/fleet.h" +#include "paddle/fluid/distributed/ps.pb.h" #include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h" -#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/ps/wrapper/fleet.h" #endif #ifdef PADDLE_WITH_PSLIB #include "afs_api.h" @@ -66,9 +66,6 @@ limitations under the License. */ namespace paddle { namespace framework { -#define TYPEALIGN(ALIGNVAL, LEN) \ - (((uint64_t)(LEN) + ((ALIGNVAL)-1)) & ~((uint64_t)((ALIGNVAL)-1))) - class Dataset; #ifdef PADDLE_WITH_PSLIB @@ -141,37 +138,6 @@ class PSGPUWrapper { const int64_t* gpu_len, int slot_num, int total_len); - void CopyForPull(const paddle::platform::Place& place, - uint64_t** gpu_keys, - const std::vector& values, - const FeatureValue* total_values_gpu, - const int64_t* gpu_len, - const int slot_num, - const int hidden_size, - const int64_t total_length); - void CopyForPull(const paddle::platform::Place& place, - uint64_t** gpu_keys, - const std::vector& values, - const float* total_values_gpu, - const int64_t* gpu_len, - const int slot_num, - const int hidden_size, - const int64_t total_length, - int* gpu_dim); - void CopyForPush(const paddle::platform::Place& place, - const std::vector& grad_values, - FeaturePushValue* total_grad_values_gpu, - const std::vector& slot_lengths, - const int hidden_size, - const int64_t total_length, - const int batch_size); - void CopyForPush(const paddle::platform::Place& place, - const std::vector& grad_values, - float* total_grad_values_gpu, - const std::vector& slot_lengths, - const uint64_t total_length, - const int batch_size, - size_t grad_value_size); void BuildGPUTask(std::shared_ptr gpu_task); void PreBuildTask(std::shared_ptr gpu_task); @@ -276,9 +242,9 @@ class PSGPUWrapper { float max_bound, float learning_rate, float initial_g2sum, - float initial_range, + float initial_range, float beta1_decay_rate, - float beta2_decay_rate, + float beta2_decay_rate, float ada_epsilon); void SetEmbedxSGD(float mf_create_thresholds, float mf_learning_rate, @@ -286,10 +252,10 @@ class PSGPUWrapper { float mf_initial_range, float mf_min_bound, float mf_max_bound, - float mf_beta1_decay_rate, + float mf_beta1_decay_rate, float mf_beta2_decay_rate, float mf_ada_epsilon); - + #ifdef PADDLE_WITH_PSCORE void add_sparse_optimizer( std::unordered_map& config, // NOLINT @@ -339,11 +305,11 @@ class PSGPUWrapper { void InitializeGPUServer(paddle::distributed::PSParameter ps_param) { auto sparse_table = - 
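/* NOTE: downpour_table_param(0) is taken to be the sparse embedding table;
 * its accessor proto (embedx_dim, sgd rules) seeds the GPU-side optimizer
 * config assembled below. */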
ps_param.server_param().downpour_server_param().downpour_table_param(0); + ps_param.server_param().downpour_server_param().downpour_table_param(0); auto sparse_table_accessor = sparse_table.accessor(); auto sparse_table_accessor_parameter = sparse_table_accessor.ctr_accessor_param(); - auto accessor_class = sparse_table_accessor.accessor_class(); + accessor_class_ = sparse_table_accessor.accessor_class(); std::unordered_map config; config["embedx_dim"] = sparse_table_accessor.embedx_dim(); @@ -351,18 +317,25 @@ class PSGPUWrapper { config["clk_coeff"] = sparse_table_accessor_parameter.click_coeff(); config["mf_create_thresholds"] = sparse_table_accessor.embedx_threshold(); +<<<<<<< HEAD if (accessor_class == "CtrDymfAccessor") { +======= + if (accessor_class_ == "CtrDymfAccessor") { +>>>>>>> b73bb24e28... fix adam accessor:template;test=develop // optimizer config for embed_w and embedx add_sparse_optimizer(config, sparse_table_accessor.embed_sgd_param()); - add_sparse_optimizer(config, sparse_table_accessor.embedx_sgd_param(), - "mf_"); + add_sparse_optimizer( + config, sparse_table_accessor.embedx_sgd_param(), "mf_"); } - feature_value_accessor_.Configure(config); + fleet_config_ = config; + GlobalAccessorTransfor::GetInstance().Init(accessor_class_); + GlobalAccessorTransfor::GetInstance().GetAccessorWrapper()->Configure( + config); InitializeGPUServer(config); } - #endif +#endif void InitializeGPUServer(std::unordered_map config) { float nonclk_coeff = (config.find("nonclk_coeff") == config.end()) @@ -373,9 +346,8 @@ class PSGPUWrapper { float min_bound = (config.find("min_bound") == config.end()) ? -10.0 : config["min_bound"]; - float max_bound = (config.find("max_bound") == config.end()) - ? 10.0 - : config["max_bound"]; + float max_bound = + (config.find("max_bound") == config.end()) ? 10.0 : config["max_bound"]; float learning_rate = (config.find("learning_rate") == config.end()) ? 0.05 : config["learning_rate"]; @@ -392,8 +364,8 @@ class PSGPUWrapper { ? 0.999 : config["beta2_decay_rate"]; float ada_epsilon = (config.find("ada_epsilon") == config.end()) - ? 1e-8 - : config["ada_epsilon"]; + ? 1e-8 + : config["ada_epsilon"]; // mf config settings float mf_create_thresholds = (config.find("mf_create_thresholds") == config.end()) @@ -414,15 +386,18 @@ class PSGPUWrapper { float mf_max_bound = (config.find("mf_max_bound") == config.end()) ? 10.0 : config["mf_max_bound"]; - float mf_beta1_decay_rate = (config.find("mf_beta1_decay_rate") == config.end()) - ? 0.9 - : config["mf_beta1_decay_rate"]; - float mf_beta2_decay_rate = (config.find("mf_beta2_decay_rate") == config.end()) - ? 0.999 - : config["mf_beta2_decay_rate"]; + float mf_beta1_decay_rate = + (config.find("mf_beta1_decay_rate") == config.end()) + ? 0.9 + : config["mf_beta1_decay_rate"]; + float mf_beta2_decay_rate = + (config.find("mf_beta2_decay_rate") == config.end()) + ? 0.999 + : config["mf_beta2_decay_rate"]; float mf_ada_epsilon = (config.find("mf_ada_epsilon") == config.end()) - ? 1e-8 - : config["mf_ada_epsilon"]; +<<<<<<< HEAD + ? 1e-8 + : config["mf_ada_epsilon"]; this->SetSparseSGD(nonclk_coeff, clk_coeff, min_bound, @@ -430,30 +405,53 @@ class PSGPUWrapper { learning_rate, initial_g2sum, initial_range, - beta1_decay_rate, - beta2_decay_rate, + beta1_decay_rate, + beta2_decay_rate, ada_epsilon); this->SetEmbedxSGD(mf_create_thresholds, mf_learning_rate, mf_initial_g2sum, mf_initial_range, mf_min_bound, - mf_max_bound, - mf_beta1_decay_rate, + mf_max_bound, + mf_beta1_decay_rate, mf_beta2_decay_rate, +======= + ? 
1e-8 + : config["mf_ada_epsilon"]; + + this->SetSparseSGD(nonclk_coeff, + clk_coeff, + min_bound, + max_bound, + learning_rate, + initial_g2sum, + initial_range, + beta1_decay_rate, + beta2_decay_rate, + ada_epsilon); + this->SetEmbedxSGD(mf_create_thresholds, + mf_learning_rate, + mf_initial_g2sum, + mf_initial_range, + mf_min_bound, + mf_max_bound, + mf_beta1_decay_rate, + mf_beta2_decay_rate, +>>>>>>> b73bb24e28... fix adam accessor:template;test=develop mf_ada_epsilon); // set optimizer type(naive,adagrad,std_adagrad,adam,share_adam) optimizer_type_ = (config.find("optimizer_type") == config.end()) - ? 1 - : int(config["optimizer_type"]); + ? 1 + : int(config["optimizer_type"]); embedx_dim_ = (config.find("embedx_dim") == config.end()) - ? 8 - : int(config["embedx_dim"]); - if (optimizer_type_ == 3) { //adam + ? 8 + : int(config["embedx_dim"]); + if (optimizer_type_ == 3) { // adam embed_sgd_dim_ = 4; embedx_sgd_dim_ = embedx_dim_ * 2 + 2; - } else if (optimizer_type_ == 4) { //shared_adam + } else if (optimizer_type_ == 4) { // shared_adam embed_sgd_dim_ = 4; embedx_sgd_dim_ = 4; } else { @@ -461,8 +459,9 @@ class PSGPUWrapper { embedx_sgd_dim_ = 1; } - VLOG(0) << "InitializeGPUServer embed_sgd_dim_:" << embed_sgd_dim_ << " embedx_sgd_dim_:" - << embedx_sgd_dim_ << " embedx_dim_:" << embedx_dim_ + VLOG(0) << "InitializeGPUServer embed_sgd_dim_:" << embed_sgd_dim_ + << " embedx_sgd_dim_:" << embedx_sgd_dim_ + << " embedx_dim_:" << embedx_dim_ << " optimizer_type_:" << optimizer_type_; } @@ -549,9 +548,13 @@ class PSGPUWrapper { for (size_t i = 0; i < slot_index_vec_.size(); i++) { slot_index_vec_[i] = dim_index_map[slot_mf_dim_vector_[i]]; } - val_type_size_ = TYPEALIGN(8, feature_value_accessor_.common_feature_value.Size(max_mf_dim_)); - grad_type_size_ = TYPEALIGN(8, feature_value_accessor_.common_push_value.Size(max_mf_dim_)); - VLOG(0) << "InitSlotInfo: val_type_size_" << val_type_size_ << " grad_type_size_:" << grad_type_size_; + + auto accessor_wrapper_ptr = + GlobalAccessorTransfor::GetInstance().GetAccessorWrapper(); + val_type_size_ = accessor_wrapper_ptr->GetFeatureValueSize(max_mf_dim_); + grad_type_size_ = accessor_wrapper_ptr->GetPushValueSize(max_mf_dim_); + VLOG(0) << "InitSlotInfo: val_type_size_" << val_type_size_ + << " grad_type_size_:" << grad_type_size_; slot_info_initialized_ = true; } #endif @@ -574,11 +577,11 @@ class PSGPUWrapper { #ifdef PADDLE_WITH_PSCORE void SetTableAccessor(paddle::distributed::ValueAccessor* accessor) { - cpu_table_accessor_ = dynamic_cast(accessor); + cpu_table_accessor_ = + dynamic_cast(accessor); } #endif - CommonFeatureValueAccessor feature_value_accessor_; private: static std::shared_ptr s_instance_; Dataset* dataset_; @@ -635,6 +638,8 @@ class PSGPUWrapper { int embed_sgd_dim_ = 1; int embedx_sgd_dim_ = 1; int embedx_dim_ = 8; + std::string accessor_class_; + std::unordered_map fleet_config_; #ifdef PADDLE_WITH_PSCORE paddle::distributed::CtrDymfAccessor* cpu_table_accessor_; #endif @@ -663,6 +668,7 @@ class PSGPUWrapper { std::vector> pull_thread_pool_; std::vector> hbm_thread_pool_; OptimizerConfig optimizer_config_; + protected: static bool is_initialized_; }; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps b/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps index 369a20874d42e3..3505bff72e90a1 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.kps @@ -28,9 +28,13 @@ limitations under the License. 
*/ namespace paddle { namespace framework { -__global__ void PullCopy(float* dest, const FeatureValue* src, - const long long* len, int hidden, int slot_num, - int total_len, unsigned long long* keys) { +__global__ void PullCopy(float* dest, + const FeatureValue* src, + const long long* len, + int hidden, + int slot_num, + int total_len, + unsigned long long* keys) { int cid = core_id(); int ncores = core_num(); if (cid >= ncores) { @@ -42,8 +46,8 @@ __global__ void PullCopy(float* dest, const FeatureValue* src, GM2LM(len, local_len, slot_num * sizeof(int64_t)); __global_ptr__ unsigned long long* local_keys[slot_num]; - GM2LM(keys, local_keys, - slot_num * sizeof(__global_ptr__ unsigned long long*)); + GM2LM( + keys, local_keys, slot_num * sizeof(__global_ptr__ unsigned long long*)); __global_ptr__ float* local_dest[slot_num]; GM2LM(dest, local_dest, slot_num * sizeof(__global_ptr__ float*)); @@ -64,10 +68,11 @@ __global__ void PullCopy(float* dest, const FeatureValue* src, // copy read_len (length) of slots' val to LM for (int k = 0; k < slot_len; k += read_len) { int real_read_len = min(read_len, slot_len - k); - GM2LM(src + dest_len + k, local_slot_vals, + GM2LM(src + dest_len + k, + local_slot_vals, real_read_len * sizeof(FeatureValue)); - GM2LM(local_keys[i] + k, local_slot_keys, - real_read_len * sizeof(uint64_t)); + GM2LM( + local_keys[i] + k, local_slot_keys, real_read_len * sizeof(uint64_t)); for (int j = 0; j < real_read_len; j++) { if (local_slot_keys[j] == 0) { local_dest_vals[j * hidden] = 0; @@ -89,7 +94,8 @@ __global__ void PullCopy(float* dest, const FeatureValue* src, } } } - LM2GM(local_dest_vals, local_dest[i] + k * hidden, + LM2GM(local_dest_vals, + local_dest[i] + k * hidden, real_read_len * hidden * sizeof(float)); } } @@ -97,7 +103,8 @@ __global__ void PullCopy(float* dest, const FeatureValue* src, __global__ void CopyKeysKernel(unsigned long long* src_keys, unsigned long long* dest_total_keys, - const long long* len, int slot_num, + const long long* len, + int slot_num, int total_len) { int cid = core_id(); int ncores = core_num(); @@ -110,7 +117,8 @@ __global__ void CopyKeysKernel(unsigned long long* src_keys, GM2LM(len, local_len, slot_num * sizeof(long long)); __global_ptr__ unsigned long long* local_keys[slot_num]; - GM2LM(src_keys, local_keys, + GM2LM(src_keys, + local_keys, slot_num * sizeof(__global_ptr__ unsigned long long*)); for (int i = thread_id; i < slot_num; i += nthreads) { @@ -123,16 +131,23 @@ __global__ void CopyKeysKernel(unsigned long long* src_keys, for (int k = 0; k < slot_len; k += read_len) { int real_read_len = min(read_len, slot_len - k); - GM2LM(local_keys[i] + k, local_slot_keys, + GM2LM(local_keys[i] + k, + local_slot_keys, real_read_len * sizeof(unsigned long long)); - LM2GM(local_slot_keys, dest_total_keys + dest_len + k, + LM2GM(local_slot_keys, + dest_total_keys + dest_len + k, real_read_len * sizeof(unsigned long long)); } } } -__global__ void PushCopy(FeaturePushValue* dest, float* src, long long* len, - int hidden, int slot_num, int total_len, int bs, +__global__ void PushCopy(FeaturePushValue* dest, + float* src, + long long* len, + int hidden, + int slot_num, + int total_len, + int bs, int* slot_vector) { int cid = core_id(); int ncores = core_num(); @@ -163,7 +178,8 @@ __global__ void PushCopy(FeaturePushValue* dest, float* src, long long* len, // copy read_len(length) of slots' grad to LM for (int k = 0; k < slot_len; k += read_len) { int real_read_len = min(read_len, slot_len - k); - GM2LM(local_src[i] + k * hidden, 
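/* NOTE: on XPU, GM2LM copies a tile from global (device) memory into the
 * core's local memory and LM2GM writes it back, so gradients are processed
 * in read_len-sized chunks per core. */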
local_slot_grads, + GM2LM(local_src[i] + k * hidden, + local_slot_grads, real_read_len * hidden * sizeof(float)); // copy from slots' grad to total grad for (int j = 0; j < real_read_len; j++) { @@ -176,7 +192,8 @@ __global__ void PushCopy(FeaturePushValue* dest, float* src, long long* len, local_slot_grads[j * hidden + 3 + m] * -1. * bs; } } - LM2GM(local_dest_grads, dest + dest_len + k, + LM2GM(local_dest_grads, + dest + dest_len + k, real_read_len * sizeof(FeaturePushValue)); } } @@ -184,40 +201,11 @@ __global__ void PushCopy(FeaturePushValue* dest, float* src, long long* len, PSGPUWrapper::~PSGPUWrapper() { delete HeterPs_; } -void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place, - uint64_t** gpu_keys, - const std::vector& values, - const FeatureValue* total_values_gpu, - const int64_t* gpu_len, const int slot_num, - const int hidden_size, - const int64_t total_length) { - XPUStream stream = nullptr; - auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); - stream = static_cast(dev_ctx) - ->x_context() - ->xpu_stream; - // float* buf_value = nullptr; - // xpu_malloc(reinterpret_cast(&buf_value), - // values.size() * sizeof(float*)); - // float** gpu_values = reinterpret_cast(&buf_value); - float* gpu_values = nullptr; - xpu_malloc(reinterpret_cast(&gpu_values), - values.size() * sizeof(float*)); - xpu_memcpy(gpu_values, values.data(), values.size() * sizeof(float*), - XPU_HOST_TO_DEVICE); - - // unsigned long long** c_keys = (unsigned long long**)gpu_keys; - unsigned long long* c_keys = reinterpret_cast(gpu_keys); - const long long* c_len = (const long long*)gpu_len; - PullCopy<<<2, 64, stream>>>(gpu_values, total_values_gpu, c_len, hidden_size, - slot_num, total_length, c_keys); - - xpu_wait(stream); -} - void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place, - uint64_t** origin_keys, uint64_t* total_keys, - const int64_t* gpu_len, int slot_num, + uint64_t** origin_keys, + uint64_t* total_keys, + const int64_t* gpu_len, + int slot_num, int total_len) { XPUStream stream = nullptr; auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); @@ -232,70 +220,49 @@ void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place, xpu_wait(stream); } -void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place, - const std::vector& grad_values, - FeaturePushValue* total_grad_values_gpu, - const std::vector& slot_lengths, - const int hidden_size, - const int64_t total_length, - const int batch_size) { - XPUStream stream = nullptr; - auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); - stream = static_cast(dev_ctx) - ->x_context() - ->xpu_stream; - auto slot_lengths_lod = slot_lengths; - for (size_t i = 1; i < slot_lengths_lod.size(); i++) { - slot_lengths_lod[i] += slot_lengths_lod[i - 1]; - } - - float* gpu_values = nullptr; - int64_t* gpu_len = nullptr; - int* d_slot_vector = nullptr; - - xpu_malloc(reinterpret_cast(&gpu_values), - grad_values.size() * sizeof(float*)); - xpu_malloc(reinterpret_cast(&gpu_len), - slot_lengths.size() * sizeof(int64_t)); - xpu_malloc(reinterpret_cast(&d_slot_vector), - slot_lengths_lod.size() * sizeof(int)); - - xpu_memcpy(gpu_values, grad_values.data(), - grad_values.size() * sizeof(float*), XPU_HOST_TO_DEVICE); - xpu_memcpy(gpu_len, slot_lengths_lod.data(), - slot_lengths.size() * sizeof(int64_t), XPU_HOST_TO_DEVICE); - xpu_memcpy(d_slot_vector, slot_vector_.data(), - slot_lengths_lod.size() * sizeof(int), XPU_HOST_TO_DEVICE); - - long long* c_len = (long long*)gpu_len; - PushCopy<<<2, 64, 
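/* NOTE: the <<<2, 64, stream>>> launch reads as 2 clusters x 64 cores per
 * cluster in XPU's CUDA-like syntax. This hand-written XPU CopyForPush is
 * deleted here; the GPU path now goes through the accessor wrapper, and no
 * XPU replacement appears in this hunk. */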
stream>>>(total_grad_values_gpu, gpu_values, c_len, - hidden_size, slot_lengths.size(), total_length, - batch_size, d_slot_vector); - xpu_wait(stream); -} - -void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, float clk_coeff, - float min_bound, float max_bound, - float learning_rate, float initial_g2sum, - float initial_range, float beta1_decay_rate, - float beta2_decay_rate, float ada_epsilon) { +void PSGPUWrapper::SetSparseSGD(float nonclk_coeff, + float clk_coeff, + float min_bound, + float max_bound, + float learning_rate, + float initial_g2sum, + float initial_range, + float beta1_decay_rate, + float beta2_decay_rate, + float ada_epsilon) { OptimizerConfig optimizer_config; - optimizer_config.set_sparse_sgd(nonclk_coeff, clk_coeff, min_bound, max_bound, - learning_rate, initial_g2sum, initial_range, - beta1_decay_rate, beta2_decay_rate, ada_epsilon); + optimizer_config.set_sparse_sgd(nonclk_coeff, + clk_coeff, + min_bound, + max_bound, + learning_rate, + initial_g2sum, + initial_range, + beta1_decay_rate, + beta2_decay_rate, + ada_epsilon); HeterPs_->set_sparse_sgd(optimizer_config); } void PSGPUWrapper::SetEmbedxSGD(float mf_create_thresholds, - float mf_learning_rate, float mf_initial_g2sum, - float mf_initial_range, float mf_min_bound, - float mf_max_bound, float mf_beta1_decay_rate, - float mf_beta2_decay_rate, float mf_ada_epsilon) { + float mf_learning_rate, + float mf_initial_g2sum, + float mf_initial_range, + float mf_min_bound, + float mf_max_bound, + float mf_beta1_decay_rate, + float mf_beta2_decay_rate, + float mf_ada_epsilon) { OptimizerConfig optimizer_config; - optimizer_config.set_embedx_sgd(mf_create_thresholds, mf_learning_rate, - mf_initial_g2sum, mf_initial_range, - mf_min_bound, mf_max_bound,mf_beta1_decay_rate, - mf_beta2_decay_rate, mf_ada_epsilon); + optimizer_config.set_embedx_sgd(mf_create_thresholds, + mf_learning_rate, + mf_initial_g2sum, + mf_initial_range, + mf_min_bound, + mf_max_bound, + mf_beta1_decay_rate, + mf_beta2_decay_rate, + mf_ada_epsilon); HeterPs_->set_embedx_sgd(optimizer_config); } From 7aa109f4db5363564f87133a65f7bd2198c6bac6 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 6 Jul 2022 13:41:56 +0000 Subject: [PATCH 17/31] fix; test=develop --- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 32 ------------ paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 50 ------------------- python/paddle/distributed/ps/the_one_ps.py | 1 + 3 files changed, 1 insertion(+), 82 deletions(-) diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 62265c3df6a4dd..d40245a3696180 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -577,33 +577,6 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { } } #endif - // #ifdef PADDLE_WITH_PSCORE - // for (int j = 0; j < len; ++j) { - // device_keys[dev][cur + j] = task_keys[dev][j]; - // float* ptr_val = task_ptrs[dev][j]->data(); - // FeatureValue& val = device_vals[dev][cur + j]; - // size_t dim = task_ptrs[dev][j]->size(); - // val.delta_score = ptr_val[2]; - // val.show = ptr_val[3]; - // val.clk = ptr_val[4]; - // val.slot = ptr_val[0]; - // val.lr = ptr_val[5]; - // val.lr_g2sum = ptr_val[6]; - // val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]); - - // if (dim > 7) { - // val.mf_size = MF_DIM + 1; - // for (int x = 0; x < val.mf_size; x++) { - // val.mf[x] = ptr_val[x + 7]; - // } - // } else { - // val.mf_size = 0; - // for (int x = 0; x < MF_DIM + 1; x++) { - // val.mf[x] = 0; - // } 
- // } - // } - // #endif VLOG(3) << "GpuPs build hbmps done"; }; @@ -713,15 +686,10 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { // this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_); int mf_dim = this->index_dim_vec_[j]; VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim; - // size_t feature_value_size = - // TYPEALIGN(8, sizeof(FeatureValue) + ((mf_dim + 1) * - // sizeof(float))); auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; size_t len = device_dim_keys.size(); CHECK(len == device_dim_ptrs.size()); - // this->mem_pools_[i * this->multi_mf_dim_ + j] = - // new MemoryPool(len, feature_value_size); auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j]; // ============ add for multi-thread ================ diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index f0694cfb91b0e3..c28eb5ad7f0056 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -317,12 +317,7 @@ class PSGPUWrapper { config["clk_coeff"] = sparse_table_accessor_parameter.click_coeff(); config["mf_create_thresholds"] = sparse_table_accessor.embedx_threshold(); -<<<<<<< HEAD - - if (accessor_class == "CtrDymfAccessor") { -======= if (accessor_class_ == "CtrDymfAccessor") { ->>>>>>> b73bb24e28... fix adam accessor:template;test=develop // optimizer config for embed_w and embedx add_sparse_optimizer(config, sparse_table_accessor.embed_sgd_param()); add_sparse_optimizer( @@ -395,31 +390,8 @@ class PSGPUWrapper { ? 0.999 : config["mf_beta2_decay_rate"]; float mf_ada_epsilon = (config.find("mf_ada_epsilon") == config.end()) -<<<<<<< HEAD ? 1e-8 : config["mf_ada_epsilon"]; - this->SetSparseSGD(nonclk_coeff, - clk_coeff, - min_bound, - max_bound, - learning_rate, - initial_g2sum, - initial_range, - beta1_decay_rate, - beta2_decay_rate, - ada_epsilon); - this->SetEmbedxSGD(mf_create_thresholds, - mf_learning_rate, - mf_initial_g2sum, - mf_initial_range, - mf_min_bound, - mf_max_bound, - mf_beta1_decay_rate, - mf_beta2_decay_rate, -======= - ? 1e-8 - : config["mf_ada_epsilon"]; - this->SetSparseSGD(nonclk_coeff, clk_coeff, min_bound, @@ -438,31 +410,12 @@ class PSGPUWrapper { mf_max_bound, mf_beta1_decay_rate, mf_beta2_decay_rate, ->>>>>>> b73bb24e28... fix adam accessor:template;test=develop mf_ada_epsilon); // set optimizer type(naive,adagrad,std_adagrad,adam,share_adam) optimizer_type_ = (config.find("optimizer_type") == config.end()) ? 1 : int(config["optimizer_type"]); - embedx_dim_ = (config.find("embedx_dim") == config.end()) - ? 
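/* NOTE: the embed_sgd_dim_/embedx_sgd_dim_ bookkeeping deleted here is now
 * derived on the accessor side instead (see CtrDymfAccessor::Dim), so the
 * wrapper no longer mirrors the optimizer state sizes. */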
8 - : int(config["embedx_dim"]); - if (optimizer_type_ == 3) { // adam - embed_sgd_dim_ = 4; - embedx_sgd_dim_ = embedx_dim_ * 2 + 2; - } else if (optimizer_type_ == 4) { // shared_adam - embed_sgd_dim_ = 4; - embedx_sgd_dim_ = 4; - } else { - embed_sgd_dim_ = 1; - embedx_sgd_dim_ = 1; - } - - VLOG(0) << "InitializeGPUServer embed_sgd_dim_:" << embed_sgd_dim_ - << " embedx_sgd_dim_:" << embedx_sgd_dim_ - << " embedx_dim_:" << embedx_dim_ - << " optimizer_type_:" << optimizer_type_; } void SetDate(int year, int month, int day) { @@ -635,9 +588,6 @@ class PSGPUWrapper { bool slot_info_initialized_ = false; int use_afs_api_ = 0; int optimizer_type_ = 1; - int embed_sgd_dim_ = 1; - int embedx_sgd_dim_ = 1; - int embedx_dim_ = 8; std::string accessor_class_; std::unordered_map fleet_config_; #ifdef PADDLE_WITH_PSCORE diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index 42d7c11eab2631..7d240983a1c289 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -824,6 +824,7 @@ def build_worker_desc(self): self.barrier_table_id = table.idx self.service._set( self.ps_desc.server_param.downpour_server_param.service_param) + self.fs_client._set(self.ps_desc.fs_client_param) return text_format.MessageToString(self.ps_desc) def build_server_desc(self): From 407644048deca533a8db95bb73fafd8f41e72673 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 6 Jul 2022 14:21:32 +0000 Subject: [PATCH 18/31] fix; test=develop --- .../framework/fleet/heter_ps/heter_comm_inl.h | 34 ------------------- .../fleet/heter_ps/heter_comm_kernel.h | 1 - 2 files changed, 35 deletions(-) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 7f0b4528d45d49..03cce63d9f0533 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -25,36 +25,6 @@ limitations under the License. 
*/ namespace paddle { namespace framework { -// template -// HeterComm::HeterComm( -// size_t capacity, std::shared_ptr resource) { -// VLOG(1) << "Construct new HeterComm"; -// resource_ = resource; -// storage_.resize(resource_->total_device()); -// multi_mf_dim_ = resource->multi_mf(); -// load_factor_ = FLAGS_gpugraph_hbm_table_load_factor; -// VLOG(0) << "load_factor = " << load_factor_; -// for (int i = 0; i < resource_->total_device(); ++i) { -// #if defined(PADDLE_WITH_CUDA) -// platform::CUDADeviceGuard guard(resource_->dev_id(i)); -// allocators_.push_back(std::make_shared( -// 8, 1, (unsigned int)-1, (size_t)-1, false, false)); // NOLINT -// #endif -// if (!multi_mf_dim_) { -// auto table = new Table(capacity / load_factor_); -// tables_.push_back(table); -// } else { -// VLOG(0) << "Error:use HeterComm Construct with accessor"; -// return; -// } -// if (multi_node_) { -// storage_[i].init(feanum_, resource_->dev_id(i)); -// } -// } -// heter_comm_kernel_ = std::make_unique(block_size_); -// init_path(); -// } template ::HeterComm( VLOG(0) << " HeterComm init, max feature_value_size:" << val_type_size << ", feature_value_push_size:" << grad_type_size; auto ptr_table = new PtrTable(capacity / load_factor_); - // ptr_table->set_accessor(feature_value_accessor_); ptr_table->set_feature_value_size(val_type_size, grad_type_size); ptr_tables_.push_back(ptr_table); } @@ -94,8 +63,6 @@ HeterComm::HeterComm( storage_[i].init(feanum_, resource_->dev_id(i)); } } - // heter_comm_kernel_ = std::make_unique(block_size_, - // feature_value_accessor_); heter_comm_kernel_ = std::make_unique(block_size_); init_path(); } @@ -161,7 +128,6 @@ template template -<<<<<<< HEAD void HeterComm::memory_copy( DstPlace dst_place, void* dst, diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index bf82dfab165292..82969ea45ba441 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -178,7 +178,6 @@ class HeterCommKernel { const StreamType& stream, FVAccessor& feature_value_accessor); - // CommonFeatureValueAccessor feature_value_accessor_; private: int block_size_{256}; }; From 1792f911ee5e4303672637638320770e8805d344 Mon Sep 17 00:00:00 2001 From: danleifeng Date: Wed, 6 Jul 2022 14:40:10 +0000 Subject: [PATCH 19/31] add feature_value.cu;test=develop --- cmake/cuda.cmake | 3 +- .../framework/fleet/heter_ps/feature_value.cu | 155 ++++++++++++++++++ .../framework/fleet/heter_ps/feature_value.h | 55 ++++--- .../framework/fleet/heter_ps/hashtable.h | 4 - .../fluid/framework/fleet/ps_gpu_wrapper.cc | 30 ++-- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 5 +- 6 files changed, 201 insertions(+), 51 deletions(-) create mode 100644 paddle/fluid/framework/fleet/heter_ps/feature_value.cu diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 87b943abd0106d..ea53b103c333e1 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -260,7 +260,8 @@ add_definitions("-DCUDA_VERSION_MINOR=\"${CUDA_VERSION_MINOR}\"") add_definitions("-DCUDA_TOOLKIT_ROOT_DIR=\"${CUDA_TOOLKIT_ROOT_DIR}\"") # setting nvcc arch flags -select_nvcc_arch_flags(NVCC_FLAGS_EXTRA) +#select_nvcc_arch_flags(NVCC_FLAGS_EXTRA) +set(NVCC_FLAGS_EXTRA "-gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA}") message(STATUS "NVCC_FLAGS_EXTRA: ${NVCC_FLAGS_EXTRA}") diff --git 
a/paddle/fluid/framework/fleet/heter_ps/feature_value.cu b/paddle/fluid/framework/fleet/heter_ps/feature_value.cu
new file mode 100644
index 00000000000000..eff345fe44caa8
--- /dev/null
+++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.cu
@@ -0,0 +1,155 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#ifdef PADDLE_WITH_HETERPS
+#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
+
+namespace paddle {
+namespace framework {
+
+template <typename FVAccessor>
+__global__ void PullCopy(float** dest, const float* src,
+                         const int64_t* len, int slot_num, int total_len,
+                         uint64_t** keys, uint64_t max_val_size, int* gpu_dim,
+                         FVAccessor feature_value_accessor) {
+  CUDA_KERNEL_LOOP(i, total_len) {
+    int low = 0;
+    int high = slot_num - 1;
+    while (low < high) {
+      int mid = (low + high) / 2;
+      if (i < len[mid])
+        high = mid;
+      else
+        low = mid + 1;
+    }
+    int x = low;
+    int y = i - (x ? len[x - 1] : 0);
+    float* feature_value_ptr =
+        (float*)((char*)src + uint64_t(i) * uint64_t(max_val_size));
+    int mf_dim = gpu_dim[x] - 3;
+    feature_value_accessor.Select(dest[x] + y * (mf_dim + 3),
+                                  feature_value_ptr, keys[x] + y, mf_dim);
+  }
+}
+
+template <typename FVAccessor>
+__global__ void PushCopyWithPool(float* dest, float** src,
+                                 int64_t* len, int slot_num,
+                                 uint64_t total_len, int bs, int* slot_vector,
+                                 int* mf_dim_vector, size_t grad_value_size,
+                                 FVAccessor feature_value_accessor) {
+  CUDA_KERNEL_LOOP(i, total_len) {
+    int low = 0;
+    int high = slot_num - 1;
+    while (low < high) {
+      int mid = (low + high) / 2;
+      if (i < len[mid])
+        high = mid;
+      else
+        low = mid + 1;
+    }
+    int x = low;
+    int y = i - (x ? len[low - 1] : 0);
+    float* cur = (float*)((char*)dest + i * grad_value_size);
+
+    cur[feature_value_accessor.common_push_value.SlotIndex()] =
+        (float)slot_vector[x];
+    int mf_dim = mf_dim_vector[x];
+    cur[feature_value_accessor.common_push_value.MfDimIndex()] = mf_dim;
+
+    cur[feature_value_accessor.common_push_value.ShowIndex()] =
+        *(src[x] + y * (mf_dim + 3));
+    cur[feature_value_accessor.common_push_value.ClickIndex()] =
+        *(src[x] + y * (mf_dim + 3) + 1);
+    cur[feature_value_accessor.common_push_value.EmbedGIndex()] =
+        *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs;
+    for (int j = 0; j < mf_dim; j++) {
+      cur[feature_value_accessor.common_push_value.EmbedxGIndex() + j] =
+          *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs;
+    }
+  }
+}
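+/* NOTE: PushCopyWithPool writes each gradient record at a grad_value_size
+ * stride in the order slot, mf_dim, show, click, embed_g, embedx_g[0..mf_dim).
+ * Embedding gradients are stored negated and scaled by the batch size, so the
+ * stored value is -bs * g; the GPU optimizer later divides by the pushed show
+ * count when applying the update. */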
+
+template <typename GPUAccessor>
+void AccessorWrapper<GPUAccessor>::CopyForPullImpl(
+    const paddle::platform::Place& place, uint64_t** gpu_keys,
+    const std::vector<float*>& values, const float* total_values_gpu,
+    const int64_t* gpu_len, const int slot_num, const int hidden_size,
+    const int64_t total_length, int* gpu_dim, int feature_value_size) {
+  auto stream = dynamic_cast<paddle::platform::CUDADeviceContext*>(
+                    paddle::platform::DeviceContextPool::Instance().Get(place))
+                    ->stream();
+  auto buf_value = memory::Alloc(place, values.size() * sizeof(float*));
+  float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
+  cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
+             cudaMemcpyHostToDevice);
+  PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
+      gpu_values, total_values_gpu, gpu_len, slot_num, total_length, gpu_keys,
+      feature_value_size, gpu_dim, gpu_accessor_);
+  cudaStreamSynchronize(stream);
+}
+
+template <typename GPUAccessor>
+void AccessorWrapper<GPUAccessor>::CopyForPushImpl(
+    const paddle::platform::Place& place,
+    const std::vector<const float*>& grad_values, float* total_grad_values_gpu,
+    const std::vector<int64_t>& slot_lengths, const uint64_t total_length,
+    const int batch_size, size_t grad_value_size,
+    std::vector<int>& slot_vector, std::vector<int>& slot_mf_dim_vector) {
+  auto stream = dynamic_cast<paddle::platform::CUDADeviceContext*>(
+                    paddle::platform::DeviceContextPool::Instance().Get(place))
+                    ->stream();
+  auto slot_lengths_lod = slot_lengths;
+  for (int i = 1; i < slot_lengths_lod.size(); i++) {
+    slot_lengths_lod[i] += slot_lengths_lod[i - 1];
+  }
+  auto buf_grad_value =
+      memory::Alloc(place, grad_values.size() * sizeof(float*));
+  auto buf_length = memory::Alloc(place, slot_lengths.size() * sizeof(int64_t));
+  auto buf_slot_vector =
+      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
+  auto buf_mf_dim_vector =
+      memory::Alloc(place, slot_lengths_lod.size() * sizeof(int));
+  float** gpu_values = reinterpret_cast<float**>(buf_grad_value->ptr());
+  int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
+  int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());
+  int* d_mf_dim_vector = reinterpret_cast<int*>(buf_mf_dim_vector->ptr());
+  cudaMemcpy(gpu_values, grad_values.data(),
+             grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice);
+  cudaMemcpy(gpu_len, slot_lengths_lod.data(),
+             slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
+  cudaMemcpy(d_slot_vector, slot_vector.data(),
+             slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
+  cudaMemcpy(d_mf_dim_vector, slot_mf_dim_vector.data(),
+             slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
+  PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
+      total_grad_values_gpu, gpu_values, gpu_len, slot_lengths.size(),
+      total_length, batch_size, d_slot_vector, d_mf_dim_vector,
+      grad_value_size, gpu_accessor_);
+  cudaStreamSynchronize(stream);
+}
+
+#ifdef PADDLE_WITH_PSCORE
+template class AccessorWrapper<CommonFeatureValueAccessor>;
+#endif
+
+}
+}
+#endif
\ No newline at end of file
diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h
index f2001cdf42ae5b..b869d322b6389a 100644
--- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h
+++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h
@@ -292,37 +292,38 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor {
   __host__ void BuildFill(
       float* gpu_val,
       void* cpu,
-      paddle::distributed::CtrDymfAccessor* cpu_table_accessor,
+      paddle::distributed::ValueAccessor* cpu_table_accessor,
       int mf_dim) {
 #ifdef PADDLE_WITH_PSCORE
+
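/* NOTE: BuildFill keeps the accessor-agnostic ValueAccessor* signature and
 * downcasts to CtrDymfAccessor internally, so the VirtualAccessor interface
 * never has to name the concrete CPU accessor type. */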
paddle::distributed::CtrDymfAccessor* cpu_accessor = dynamic_cast(cpu_table_accessor); paddle::distributed::FixedFeatureValue* cpu_ptr = (paddle::distributed::FixedFeatureValue*)(cpu); float* cpu_val = cpu_ptr->data(); size_t cpu_dim = cpu_ptr->size(); gpu_val[common_feature_value.DeltaScoreIndex()] = - cpu_val[cpu_table_accessor->common_feature_value.DeltaScoreIndex()]; + cpu_val[cpu_accessor->common_feature_value.DeltaScoreIndex()]; gpu_val[common_feature_value.ShowIndex()] = - cpu_val[cpu_table_accessor->common_feature_value.ShowIndex()]; + cpu_val[cpu_accessor->common_feature_value.ShowIndex()]; gpu_val[common_feature_value.ClickIndex()] = - cpu_val[cpu_table_accessor->common_feature_value.ClickIndex()]; + cpu_val[cpu_accessor->common_feature_value.ClickIndex()]; gpu_val[common_feature_value.SlotIndex()] = - cpu_val[cpu_table_accessor->common_feature_value.SlotIndex()]; + cpu_val[cpu_accessor->common_feature_value.SlotIndex()]; gpu_val[common_feature_value.EmbedWIndex()] = - cpu_val[cpu_table_accessor->common_feature_value.EmbedWIndex()]; + cpu_val[cpu_accessor->common_feature_value.EmbedWIndex()]; for (int i = 0; i < common_feature_value.EmbedDim(); i++) { gpu_val[common_feature_value.EmbedG2SumIndex() + i] = - cpu_val[cpu_table_accessor->common_feature_value.EmbedG2SumIndex() + + cpu_val[cpu_accessor->common_feature_value.EmbedG2SumIndex() + i]; } *(reinterpret_cast( gpu_val + common_feature_value.CpuPtrIndex())) = (uint64_t)(cpu); - cpu_val[cpu_table_accessor->common_feature_value.MfDimIndex()] = + cpu_val[cpu_accessor->common_feature_value.MfDimIndex()] = float(mf_dim); gpu_val[common_feature_value.MfDimIndex()] = mf_dim; if (cpu_dim > - cpu_table_accessor->GetAccessorInfo().dim - - cpu_table_accessor->GetAccessorInfo().mf_size / sizeof(float)) { + cpu_accessor->GetAccessorInfo().dim - + cpu_accessor->GetAccessorInfo().mf_size / sizeof(float)) { gpu_val[common_feature_value.MfSizeIndex()] = common_feature_value.MFSize(mf_dim) / sizeof(float); @@ -330,7 +331,7 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { x < int(common_feature_value.MFSize(mf_dim) / sizeof(float)); x++) { gpu_val[common_feature_value.EmbedxG2SumIndex() + x] = cpu_val - [cpu_table_accessor->common_feature_value.EmbedxG2SumIndex() + x]; + [cpu_accessor->common_feature_value.EmbedxG2SumIndex() + x]; } } else { gpu_val[common_feature_value.MfSizeIndex()] = 0; @@ -346,35 +347,37 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { // dump_to_cpu阶段从gpu_val赋值给cpu_val __host__ __device__ void DumpFill( float* gpu_val, - paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + paddle::distributed::ValueAccessor* cpu_table_accessor, int mf_dim) { #ifdef PADDLE_WITH_PSCORE + paddle::distributed::CtrDymfAccessor* cpu_accessor = dynamic_cast(cpu_table_accessor); + auto* downpour_value = (paddle::distributed::FixedFeatureValue*)(*(reinterpret_cast( gpu_val + common_feature_value.CpuPtrIndex()))); size_t downpour_value_size = downpour_value->size(); if (gpu_val[common_feature_value.MfSizeIndex()] > 0 && downpour_value_size == - (cpu_table_accessor->GetAccessorInfo().dim - - int(cpu_table_accessor->GetAccessorInfo().mf_size / + (cpu_accessor->GetAccessorInfo().dim - + int(cpu_accessor->GetAccessorInfo().mf_size / sizeof(float)))) { // cpu_accessor downpour_value->resize( - cpu_table_accessor->common_feature_value.Dim(mf_dim)); + cpu_accessor->common_feature_value.Dim(mf_dim)); } float* cpu_val = downpour_value->data(); - cpu_val[cpu_table_accessor->common_feature_value.DeltaScoreIndex()] = + 
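/* NOTE: DumpFill is the inverse of BuildFill: trained values are copied from
 * the GPU value layout back into the CPU-side FixedFeatureValue before the
 * table is saved. */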
cpu_val[cpu_accessor->common_feature_value.DeltaScoreIndex()] = gpu_val[common_feature_value.DeltaScoreIndex()]; - cpu_val[cpu_table_accessor->common_feature_value.ShowIndex()] = + cpu_val[cpu_accessor->common_feature_value.ShowIndex()] = gpu_val[common_feature_value.ShowIndex()]; - cpu_val[cpu_table_accessor->common_feature_value.ClickIndex()] = + cpu_val[cpu_accessor->common_feature_value.ClickIndex()] = gpu_val[common_feature_value.ClickIndex()]; - cpu_val[cpu_table_accessor->common_feature_value.EmbedWIndex()] = + cpu_val[cpu_accessor->common_feature_value.EmbedWIndex()] = gpu_val[common_feature_value.EmbedWIndex()]; - cpu_val[cpu_table_accessor->common_feature_value.SlotIndex()] = + cpu_val[cpu_accessor->common_feature_value.SlotIndex()] = gpu_val[common_feature_value.SlotIndex()]; for (int i = 0; i < common_feature_value.EmbedDim(); i++) { - cpu_val[cpu_table_accessor->common_feature_value.EmbedG2SumIndex() + i] = + cpu_val[cpu_accessor->common_feature_value.EmbedG2SumIndex() + i] = gpu_val[common_feature_value.EmbedG2SumIndex() + i]; } @@ -639,17 +642,15 @@ class VirtualAccessor { virtual size_t GetPushValueSize(int& mf_dim) = 0; - // TODO: 在基类里调用cpu_table_accessor类型 virtual void BuildFill( void* gpu_val, void* cpu_val, - paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + paddle::distributed::ValueAccessor* cpu_table_accessor, int mf_dim) = 0; - // TODO: 在基类里调用cpu_table_accessor类型 virtual void DumpFill( float* gpu_val, - paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + paddle::distributed::ValueAccessor* cpu_table_accessor, int mf_dim) = 0; virtual void CopyForPull(const paddle::platform::Place& place, @@ -699,7 +700,7 @@ class AccessorWrapper : public VirtualAccessor { virtual void BuildFill( void* gpu_val, void* cpu_val, - paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + paddle::distributed::ValueAccessor* cpu_table_accessor, int mf_dim) { gpu_accessor_.BuildFill( (float*)(gpu_val), cpu_val, cpu_table_accessor, mf_dim); @@ -707,7 +708,7 @@ class AccessorWrapper : public VirtualAccessor { virtual void DumpFill( float* gpu_val, - paddle::distributed::CtrDymfAccessor* cpu_table_accessor, + paddle::distributed::ValueAccessor* cpu_table_accessor, int mf_dim) { gpu_accessor_.DumpFill(gpu_val, cpu_table_accessor, mf_dim); } diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index 38abe87495c4ee..8803c738455b4b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -193,12 +193,8 @@ class HashTable { << " push value size: " << push_grad_value_size_; } - // void set_accessor(FVAccessor& accessor) { - // feature_value_accessor_ = accessor; - // } std::unique_ptr rwlock_{nullptr}; - // FVAccessor feature_value_accessor_; private: #if defined(PADDLE_WITH_CUDA) diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index d40245a3696180..2d5caf001cc9bd 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -711,8 +711,8 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { for (size_t k = left; k < right; k++) { void* val = mem_pool->mem_address(k); - float* ptr_val = device_dim_ptrs[k]->data(); - size_t dim = device_dim_ptrs[k]->size(); + // float* ptr_val = device_dim_ptrs[k]->data(); + // size_t dim = device_dim_ptrs[k]->size(); #ifdef PADDLE_WITH_PSLIB val->delta_score = 
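/* NOTE: the PSLIB branch still fills FeatureValue fields by hand from
 * ptr_val (commented out above, so this branch presumably is not built in
 * this configuration), while the PSCORE branch below delegates the same copy
 * to accessor_wrapper_ptr->BuildFill. */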
ptr_val[paddle::ps::DownpourCtrDymfAccessor:: @@ -748,9 +748,9 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { } #endif #ifdef PADDLE_WITH_PSCORE - VLOG(5) << "cpu build " << k - << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) - << " |: " << cpu_table_accessor_->ParseToString(ptr_val, dim); + // VLOG(5) << "cpu build " << k + // << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) + // << " |: " << cpu_table_accessor_->ParseToString(ptr_val, dim); accessor_wrapper_ptr->BuildFill( val, device_dim_ptrs[k], cpu_table_accessor_, mf_dim); VLOG(5) << "build " << k << " : " @@ -989,17 +989,15 @@ void PSGPUWrapper::EndPass() { #endif #ifdef PADDLE_WITH_PSCORE accessor_wrapper_ptr->DumpFill(gpu_val, cpu_table_accessor_, mf_dim); - auto* downpour_value = (paddle::distributed::FixedFeatureValue*)(*( - reinterpret_cast(gpu_val))); - float* cpu_val = downpour_value->data(); - VLOG(5) << "dump to cpu " << index << " gpu_value: " - << accessor_wrapper_ptr->ParseToString( - gpu_val, - int(accessor_wrapper_ptr->GetFeatureValueSize(mf_dim) / - sizeof(float))) - << " \t cpu_value:" - << cpu_table_accessor_->ParseToString(cpu_val, - downpour_value->size()); + // auto* downpour_value = (paddle::distributed::FixedFeatureValue*)(*( + // reinterpret_cast(gpu_val))); + // float* cpu_val = downpour_value->data(); + // VLOG(5) << "dump to cpu " << index << " gpu_value: " + // << accessor_wrapper_ptr->ParseToString(gpu_val, + // int(accessor_wrapper_ptr->GetFeatureValueSize(mf_dim) / sizeof(float))) + // << " \t cpu_value:" + // << cpu_table_accessor_->ParseToString(cpu_val, + // downpour_value->size()); } #endif free(test_build_values); diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index f7ba07cbae622d..42fbf8d3a19c14 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -531,8 +531,7 @@ class PSGPUWrapper { #ifdef PADDLE_WITH_PSCORE void SetTableAccessor(paddle::distributed::ValueAccessor* accessor) { - cpu_table_accessor_ = - dynamic_cast(accessor); + cpu_table_accessor_ = accessor; } #endif @@ -592,7 +591,7 @@ class PSGPUWrapper { std::string accessor_class_; std::unordered_map fleet_config_; #ifdef PADDLE_WITH_PSCORE - paddle::distributed::CtrDymfAccessor* cpu_table_accessor_; + paddle::distributed::ValueAccessor* cpu_table_accessor_; #endif #ifdef PADDLE_WITH_CUDA From 9f0ea759cdb1cd3a64f2ffc27990e933aae9c13e Mon Sep 17 00:00:00 2001 From: danleifeng Date: Thu, 7 Jul 2022 03:48:47 +0000 Subject: [PATCH 20/31] format; test=develop --- .../distributed/ps/table/ctr_dymf_accessor.cc | 10 +- .../distributed/ps/table/ctr_dymf_accessor.h | 5 +- .../distributed/ps/table/sparse_sgd_rule.cc | 18 +- .../distributed/ps/table/sparse_sgd_rule.h | 4 +- .../distributed/ps/wrapper/CMakeLists.txt | 9 +- .../framework/fleet/heter_ps/optimizer.cuh.h | 317 +++++++++++------- 6 files changed, 220 insertions(+), 143 deletions(-) diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc index cd34c9e0e7ea3b..4feee70fed751a 100644 --- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc @@ -43,7 +43,8 @@ int CtrDymfAccessor::Initialize() { if (_config.ctr_accessor_param().show_scale()) { _show_scale = true; } - VLOG(0) << " INTO CtrDymfAccessor::Initialize(); embed_sgd_dim:" << common_feature_value.embed_sgd_dim + VLOG(0) << " INTO CtrDymfAccessor::Initialize(); 
embed_sgd_dim:" + << common_feature_value.embed_sgd_dim << " embedx_dim:" << common_feature_value.embedx_dim << " embedx_sgd_dim:" << common_feature_value.embedx_sgd_dim; InitAccessorInfo(); @@ -182,9 +183,10 @@ int32_t CtrDymfAccessor::Create(float** values, size_t num) { value[common_feature_value.ClickIndex()] = 0; value[common_feature_value.SlotIndex()] = -1; value[common_feature_value.MfDimIndex()] = -1; - _embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(), - value + common_feature_value.EmbedG2SumIndex(), - false); // adam embed init not zero, adagrad embed init zero + _embed_sgd_rule->InitValue( + value + common_feature_value.EmbedWIndex(), + value + common_feature_value.EmbedG2SumIndex(), + false); // adam embed init not zero, adagrad embed init zero _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(), value + common_feature_value.EmbedxG2SumIndex(), false); diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h index 04ff2dbcd3a6dc..df8c27b1d82d20 100644 --- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h @@ -62,9 +62,9 @@ class CtrDymfAccessor : public ValueAccessor { // 根据mf_dim计算的总长度 int Dim(int& mf_dim) { int tmp_embedx_sgd_dim = 1; - if (optimizer_name == "SparseAdamSGDRule") {//adam + if (optimizer_name == "SparseAdamSGDRule") { // adam tmp_embedx_sgd_dim = mf_dim * 2 + 2; - } else if (optimizer_name == "SparseSharedAdamSGDRule") { //shared_adam + } else if (optimizer_name == "SparseSharedAdamSGDRule") { // shared_adam tmp_embedx_sgd_dim = 4; } return 7 + embed_sgd_dim + tmp_embedx_sgd_dim + mf_dim; @@ -73,7 +73,6 @@ class CtrDymfAccessor : public ValueAccessor { // 根据mf_dim计算的总byte数 int Size(int& mf_dim) { return (Dim(mf_dim)) * sizeof(float); } - float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; } float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; } float& Show(float* val) { return val[ShowIndex()]; } diff --git a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc index 49ee493dbef50a..014d6e450ab4ac 100644 --- a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc +++ b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc @@ -252,8 +252,8 @@ void SparseAdamSGDRule::InitValueWork(float* value, *(sgd + Beta2PowIndex()) = _beta2_decay_rate; } -void SparseSharedAdamSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param, - size_t emb_dim) { +void SparseSharedAdamSGDRule::LoadConfig( + const SparseCommonSGDRuleParameter& param, size_t emb_dim) { _embedding_dim = emb_dim; auto adam_param = param.adam(); learning_rate_ = adam_param.learning_rate(); @@ -273,8 +273,10 @@ void SparseSharedAdamSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& par } } -void SparseSharedAdamSGDRule::UpdateValueWork(float* w, float* sgd, const float* grad, - float scale) { +void SparseSharedAdamSGDRule::UpdateValueWork(float* w, + float* sgd, + const float* grad, + float scale) { float* gsum = sgd + GSumIndex(); float* g2sum = sgd + G2SumIndex(); float* beta1_pow = sgd + Beta1PowIndex(); @@ -292,7 +294,8 @@ void SparseSharedAdamSGDRule::UpdateValueWork(float* w, float* sgd, const float* double sum_g2sum = 0.0; for (int i = 0; i < _embedding_dim; i++) { // Calculation - double new_gsum = _beta1_decay_rate * gsum_ + (1 - _beta1_decay_rate) * g[i]; + double new_gsum = + _beta1_decay_rate * gsum_ + (1 - _beta1_decay_rate) * g[i]; double 
new_g2sum = _beta2_decay_rate * g2sum_ + (1 - _beta2_decay_rate) * g[i] * g[i]; w[i] = w[i] - lr * (new_gsum / (sqrt(new_g2sum) + _ada_epsilon)); @@ -307,8 +310,9 @@ void SparseSharedAdamSGDRule::UpdateValueWork(float* w, float* sgd, const float* (*beta2_pow) *= _beta2_decay_rate; } -void SparseSharedAdamSGDRule::InitValueWork(float* value, float* sgd, - bool zero_init) { +void SparseSharedAdamSGDRule::InitValueWork(float* value, + float* sgd, + bool zero_init) { for (int i = 0; i < _embedding_dim; ++i) { if (zero_init) { value[i] = 0.0; diff --git a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.h b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.h index aea7fa2cd85f14..473c823f5ae882 100644 --- a/paddle/fluid/distributed/ps/table/sparse_sgd_rule.h +++ b/paddle/fluid/distributed/ps/table/sparse_sgd_rule.h @@ -149,7 +149,9 @@ class SparseSharedAdamSGDRule : public SparseValueSGDRule { public: virtual void LoadConfig(const SparseCommonSGDRuleParameter& param, size_t emb_dim); - virtual void UpdateValueWork(float* w, float* sgd, const float* push_value, + virtual void UpdateValueWork(float* w, + float* sgd, + const float* push_value, float scale); virtual void InitValueWork(float* value, float* sgd, bool zero_init); virtual size_t Dim() { return 4; } diff --git a/paddle/fluid/distributed/ps/wrapper/CMakeLists.txt b/paddle/fluid/distributed/ps/wrapper/CMakeLists.txt index 352e3aa19eb09f..6abd68e5d0aa9a 100644 --- a/paddle/fluid/distributed/ps/wrapper/CMakeLists.txt +++ b/paddle/fluid/distributed/ps/wrapper/CMakeLists.txt @@ -1,9 +1,10 @@ get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) -set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses") -if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) - set(DISTRIBUTE_COMPILE_FLAGS - "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") +set(DISTRIBUTE_COMPILE_FLAGS + "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses" +) +if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) + set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") endif() set_source_files_properties(fleet.cc PROPERTIES COMPILE_FLAGS diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h index 96116ba954a07b..3a6f60fef858ba 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h @@ -38,33 +38,35 @@ class Optimizer { __device__ void update_value(const OptimizerConfig& optimizer_config, float& val, // NOLINT const float& grad) { - printf("Warning: update_value will not used. Please use dy_mf_update_value\n"); + printf( + "Warning: update_value will not used. 
Please use dy_mf_update_value\n"); } __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config, - float* ptr, const float* grad) { - } - + float* ptr, + const float* grad) {} + CommonFeatureValueAccessor feature_value_accessor_; size_t _embedding_dim; size_t _lr_embedding_dim; - - }; class SparseAdagradOptimizer : public Optimizer { public: - - __host__ SparseAdagradOptimizer(CommonFeatureValueAccessor feature_value_accessor): Optimizer(feature_value_accessor) { + __host__ SparseAdagradOptimizer( + CommonFeatureValueAccessor feature_value_accessor) + : Optimizer(feature_value_accessor) { _lr_embedding_dim = 1; _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); } - - __device__ void update_value_work(const OptimizerConfig& optimizer_config, int n, - float* w, - float* sgd, // NOLINT - const float* g, float scale) { + + __device__ void update_value_work(const OptimizerConfig& optimizer_config, + int n, + float* w, + float* sgd, // NOLINT + const float* g, + float scale) { float& g2sum = sgd[G2SumIndex()]; double add_g2sum = 0; double ratio = optimizer_config.mf_learning_rate * @@ -88,13 +90,15 @@ class SparseAdagradOptimizer : public Optimizer { __device__ void update_value(const OptimizerConfig& optimizer_config, float& val, // NOLINT const float& grad) { - printf("Warning: update_value will not used. Please use dy_mf_update_value\n"); + printf( + "Warning: update_value will not used. Please use dy_mf_update_value\n"); } __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config, - float* ptr, const float* grad) { + float* ptr, + const float* grad) { float g_show = grad[feature_value_accessor_.common_push_value.ShowIndex()]; - float g_click = grad[feature_value_accessor_.common_push_value.ClickIndex()]; - + float g_click = + grad[feature_value_accessor_.common_push_value.ClickIndex()]; ptr[feature_value_accessor_.common_feature_value.SlotIndex()] = grad[feature_value_accessor_.common_push_value.SlotIndex()]; @@ -102,61 +106,73 @@ class SparseAdagradOptimizer : public Optimizer { ptr[feature_value_accessor_.common_feature_value.ClickIndex()] += g_click; ptr[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] += optimizer_config.nonclk_coeff * (g_show - g_click) + - optimizer_config.clk_coeff * g_click; - - update_value_work(optimizer_config, 1, - ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(), - ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(), - grad + feature_value_accessor_.common_push_value.EmbedGIndex(), - g_show); - - int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); + optimizer_config.clk_coeff * g_click; + + update_value_work( + optimizer_config, + 1, + ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(), + ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(), + grad + feature_value_accessor_.common_push_value.EmbedGIndex(), + g_show); + + int mf_dim = + int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= - optimizer_config.nonclk_coeff * - (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - - ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + - optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { - ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - 
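/* NOTE: embedx (mf) weights are created lazily: only once the weighted
 * show/click score crosses mf_create_thresholds does the entry receive an
 * mf block, initialized uniformly within mf_initial_range. */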
feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); + optimizer_config.nonclk_coeff * + (ptr[feature_value_accessor_.common_feature_value + .ShowIndex()] - + ptr[feature_value_accessor_.common_feature_value + .ClickIndex()]) + + optimizer_config.clk_coeff * + ptr[feature_value_accessor_.common_feature_value + .ClickIndex()]) { + ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = + feature_value_accessor_.common_feature_value.MFSize(mf_dim) / + sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; curand_init(clock64(), tid_x, 0, &state); for (int i = 0; i < mf_dim; ++i) { - ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = + ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = (curand_uniform(&state)) * optimizer_config.mf_initial_range; } } } else { - update_value_work(optimizer_config, mf_dim, + update_value_work( + optimizer_config, + mf_dim, ptr + feature_value_accessor_.common_feature_value.EmbedxWIndex(), ptr + feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(), grad + feature_value_accessor_.common_push_value.EmbedxGIndex(), g_show); } } - - __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim();} + + __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim(); } __host__ __device__ size_t EmbedDim() { return _lr_embedding_dim; } __host__ __device__ size_t EmbedxDim() { return _embedding_dim; } __host__ __device__ size_t G2SumIndex() { return 0; } __host__ __device__ size_t EmbedxG2SumIndex() { return 0; } - }; class SparseAdamOptimizer : public Optimizer { public: - - __host__ SparseAdamOptimizer(CommonFeatureValueAccessor feature_value_accessor): Optimizer(feature_value_accessor) { + __host__ SparseAdamOptimizer( + CommonFeatureValueAccessor feature_value_accessor) + : Optimizer(feature_value_accessor) { _lr_embedding_dim = 1; _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); } - __device__ void update_lr(const OptimizerConfig& optimizer_config, int n, + __device__ void update_lr(const OptimizerConfig& optimizer_config, + int n, float* w, float* sgd, - const float* g, float scale) { + const float* g, + float scale) { float* moment1 = sgd + GSumIndex(); float* moment2 = sgd + G2SumIndex(); float* beta1_pow = sgd + Beta1PowIndex(); @@ -166,15 +182,19 @@ class SparseAdamOptimizer : public Optimizer { float beta2_pow_ = *beta2_pow; float epsilon = 1e-08; - double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) / (1.0 - beta1_pow_); + double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) / + (1.0 - beta1_pow_); for (int i = 0; i < n; ++i) { double scaled_grad = g[i] / scale; - double new_moment1 = optimizer_config.beta1_decay_rate * moment1[i] + (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad; - double new_moment2 = optimizer_config.beta2_decay_rate * moment2[i] + (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad; + double new_moment1 = + optimizer_config.beta1_decay_rate * moment1[i] + + (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad; + double new_moment2 = + optimizer_config.beta2_decay_rate * moment2[i] + + (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad; w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon)); - if (w[i] < optimizer_config.mf_min_bound) w[i] = optimizer_config.mf_min_bound; if (w[i] > optimizer_config.mf_max_bound) @@ -187,10 +207,12 @@ class SparseAdamOptimizer : public Optimizer { (*beta2_pow) *= 
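
update_lr here (and update_mf below) is a standard Adam step with both bias corrections folded into a single scalar ratio and the decay powers advanced once per call. The same arithmetic as a host-side reference; names are illustrative:

    #include <cmath>

    // Sketch of the bias-corrected Adam step used by SparseAdamOptimizer.
    void adam_step(int n, float* w, float* m1, float* m2, float& beta1_pow,
                   float& beta2_pow, const float* g, float scale, float lr,
                   float beta1, float beta2, float eps, float lo, float hi) {
      // Both 1/(1 - beta^t) bias corrections collapse into one ratio.
      double ratio = lr * std::sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow);
      for (int i = 0; i < n; ++i) {
        double gi = g[i] / scale;
        m1[i] = beta1 * m1[i] + (1.0 - beta1) * gi;
        m2[i] = beta2 * m2[i] + (1.0 - beta2) * gi * gi;
        w[i] += ratio * (m1[i] / (std::sqrt(m2[i]) + eps));
        if (w[i] < lo) w[i] = lo;
        if (w[i] > hi) w[i] = hi;
      }
      beta1_pow *= beta1;  // advance t -> t+1 once per feature update
      beta2_pow *= beta2;
    }
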
optimizer_config.beta2_decay_rate;
   }
 
-  __device__ void update_mf(const OptimizerConfig& optimizer_config, int n,
+  __device__ void update_mf(const OptimizerConfig& optimizer_config,
+                            int n,
                             float* w,
                             float* sgd,
-                            const float* g, float scale) {
+                            const float* g,
+                            float scale) {
     float* moment1 = sgd + EmbedxGSumIndex();
     float* moment2 = sgd + EmbedxG2SumIndex();
     float* beta1_pow = sgd + EmbedxBeta1PowIndex();
@@ -200,15 +222,19 @@ class SparseAdamOptimizer : public Optimizer {
     float beta2_pow_ = *beta2_pow;
     float epsilon = 1e-08;
 
-    double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) / (1.0 - beta1_pow_);
+    double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) /
+                   (1.0 - beta1_pow_);
     for (int i = 0; i < n; ++i) {
       double scaled_grad = g[i] / scale;
 
-      double new_moment1 = optimizer_config.beta1_decay_rate * moment1[i] + (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad;
-      double new_moment2 = optimizer_config.beta2_decay_rate * moment2[i] + (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad;
+      double new_moment1 =
+          optimizer_config.beta1_decay_rate * moment1[i] +
+          (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad;
+      double new_moment2 =
+          optimizer_config.beta2_decay_rate * moment2[i] +
+          (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad;
       w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon));
-
       if (w[i] < optimizer_config.mf_min_bound)
        w[i] = optimizer_config.mf_min_bound;
      if (w[i] > optimizer_config.mf_max_bound)
@@ -224,14 +250,15 @@ class SparseAdamOptimizer : public Optimizer {
   __device__ void update_value(const OptimizerConfig& optimizer_config,
                                float& val,  // NOLINT
                                const float& grad) {
-    printf("Warning: update_value will not used. Please use dy_mf_update_value\n");
+    printf(
+        "Warning: update_value will not be used. 
Please use dy_mf_update_value\n"); } __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config, - float* ptr, const float* grad) { - + float* ptr, + const float* grad) { float g_show = grad[feature_value_accessor_.common_push_value.ShowIndex()]; - float g_click = grad[feature_value_accessor_.common_push_value.ClickIndex()]; - + float g_click = + grad[feature_value_accessor_.common_push_value.ClickIndex()]; ptr[feature_value_accessor_.common_feature_value.SlotIndex()] = grad[feature_value_accessor_.common_push_value.SlotIndex()]; @@ -239,73 +266,95 @@ class SparseAdamOptimizer : public Optimizer { ptr[feature_value_accessor_.common_feature_value.ClickIndex()] += g_click; ptr[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] += optimizer_config.nonclk_coeff * (g_show - g_click) + - optimizer_config.clk_coeff * g_click; - - update_lr(optimizer_config, 1, - ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(), - ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(), - grad + feature_value_accessor_.common_push_value.EmbedGIndex(), - g_show); - int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); + optimizer_config.clk_coeff * g_click; + + update_lr( + optimizer_config, + 1, + ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(), + ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(), + grad + feature_value_accessor_.common_push_value.EmbedGIndex(), + g_show); + int mf_dim = + int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= - optimizer_config.nonclk_coeff * - (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - - ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + - optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { + optimizer_config.nonclk_coeff * + (ptr[feature_value_accessor_.common_feature_value + .ShowIndex()] - + ptr[feature_value_accessor_.common_feature_value + .ClickIndex()]) + + optimizer_config.clk_coeff * + ptr[feature_value_accessor_.common_feature_value + .ClickIndex()]) { ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); + feature_value_accessor_.common_feature_value.MFSize(mf_dim) / + sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; curand_init(clock64(), tid_x, 0, &state); for (int i = 0; i < mf_dim; ++i) { - ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = + ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = (curand_uniform(&state)) * optimizer_config.mf_initial_range; } - ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta1PowIndex()] = - optimizer_config.beta1_decay_rate; - ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta2PowIndex()] = - optimizer_config.beta2_decay_rate; + ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + + EmbedxBeta1PowIndex()] = optimizer_config.beta1_decay_rate; + ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + + EmbedxBeta2PowIndex()] = optimizer_config.beta2_decay_rate; } } else { - update_mf(optimizer_config, mf_dim, + update_mf( + optimizer_config, + mf_dim, ptr + feature_value_accessor_.common_feature_value.EmbedxWIndex(), ptr + 
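
The MfSizeIndex() == 0 branch above implements lazy materialization: an embedx vector is created only once the weighted show/click signal crosses mf_create_thresholds, then seeded uniformly from (0, mf_initial_range). The gate itself, restated as a sketch:

    // Sketch of the mf-creation gate shared by the optimizers above.
    bool should_create_mf(float show, float click, float nonclk_coeff,
                          float clk_coeff, float threshold) {
      // Non-click impressions (show - click) and clicks weigh in separately.
      return threshold <= nonclk_coeff * (show - click) + clk_coeff * click;
    }
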
feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(), grad + feature_value_accessor_.common_push_value.EmbedxGIndex(), g_show); } - // printf("EmbedxGIndex: %f, mf_gsum: %f, ", feature_value_accessor_.common_push_value.EmbedxGIndex(), + // printf("EmbedxGIndex: %f, mf_gsum: %f, ", + // feature_value_accessor_.common_push_value.EmbedxGIndex(), // ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex()]); } - + __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim(); } __host__ __device__ size_t EmbedDim() { return _lr_embedding_dim * 2 + 2; } __host__ __device__ size_t EmbedxDim() { return _embedding_dim * 2 + 2; } __host__ __device__ size_t GSumIndex() { return 0; } - __host__ __device__ size_t G2SumIndex() { return GSumIndex() + _lr_embedding_dim; } - __host__ __device__ size_t Beta1PowIndex() { return G2SumIndex() + _lr_embedding_dim; } + __host__ __device__ size_t G2SumIndex() { + return GSumIndex() + _lr_embedding_dim; + } + __host__ __device__ size_t Beta1PowIndex() { + return G2SumIndex() + _lr_embedding_dim; + } __host__ __device__ size_t Beta2PowIndex() { return Beta1PowIndex() + 1; } __host__ __device__ size_t EmbedxGSumIndex() { return 0; } - __host__ __device__ size_t EmbedxG2SumIndex() { return EmbedxGSumIndex() + _embedding_dim; } - __host__ __device__ size_t EmbedxBeta1PowIndex() { return EmbedxG2SumIndex() + _embedding_dim; } - __host__ __device__ size_t EmbedxBeta2PowIndex() { return EmbedxBeta1PowIndex() + 1; } - + __host__ __device__ size_t EmbedxG2SumIndex() { + return EmbedxGSumIndex() + _embedding_dim; + } + __host__ __device__ size_t EmbedxBeta1PowIndex() { + return EmbedxG2SumIndex() + _embedding_dim; + } + __host__ __device__ size_t EmbedxBeta2PowIndex() { + return EmbedxBeta1PowIndex() + 1; + } }; - class SparseAdamSharedOptimizer : public Optimizer { public: - - __host__ SparseAdamSharedOptimizer(CommonFeatureValueAccessor feature_value_accessor): Optimizer(feature_value_accessor) { + __host__ SparseAdamSharedOptimizer( + CommonFeatureValueAccessor feature_value_accessor) + : Optimizer(feature_value_accessor) { _lr_embedding_dim = 1; _embedding_dim = feature_value_accessor_.common_feature_value.EmbedWDim(); } - __device__ void update_value_work(const OptimizerConfig& optimizer_config, int n, - float* w, - float* sgd, - const float* g, float scale) { + __device__ void update_value_work(const OptimizerConfig& optimizer_config, + int n, + float* w, + float* sgd, + const float* g, + float scale) { float* moment1 = sgd + GSumIndex(); float* moment2 = sgd + G2SumIndex(); float* beta1_pow = sgd + Beta1PowIndex(); @@ -316,18 +365,22 @@ class SparseAdamSharedOptimizer : public Optimizer { float moment1_ = *moment1; float moment2_ = *moment2; float epsilon = 1e-08; - double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) / (1.0 - beta1_pow_); + double ratio = optimizer_config.learning_rate * sqrt(1.0 - beta2_pow_) / + (1.0 - beta1_pow_); double sum_mom1 = 0.0; double sum_mom2 = 0.0; for (int i = 0; i < n; ++i) { double scaled_grad = g[i] / scale; - double new_moment1 = optimizer_config.beta1_decay_rate * moment1_ + (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad; - double new_moment2 = optimizer_config.beta2_decay_rate * moment2_ + (1.0 - optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad; + double new_moment1 = + optimizer_config.beta1_decay_rate * moment1_ + + (1.0 - optimizer_config.beta1_decay_rate) * scaled_grad; + double new_moment2 = + optimizer_config.beta2_decay_rate * moment2_ + + (1.0 - 
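
The index helpers above fix the layout of the per-feature optimizer state that the sgd pointer addresses. For an embedding of width D, the Adam state is D first moments, D second moments, then the two decay powers; a restatement of those offsets as a sketch:

    #include <cstddef>

    // Offsets implied by GSumIndex/G2SumIndex/Beta1PowIndex/Beta2PowIndex.
    constexpr size_t kGSumOff = 0;                              // [0, D)
    constexpr size_t g2sum_off(size_t d) { return d; }          // [D, 2D)
    constexpr size_t beta1_pow_off(size_t d) { return 2 * d; }
    constexpr size_t beta2_pow_off(size_t d) { return 2 * d + 1; }
    constexpr size_t adam_state_floats(size_t d) { return 2 * d + 2; }
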
optimizer_config.beta2_decay_rate) * scaled_grad * scaled_grad; w[i] += ratio * (new_moment1 / (sqrt(new_moment2) + epsilon)); - if (w[i] < optimizer_config.mf_min_bound) w[i] = optimizer_config.mf_min_bound; if (w[i] > optimizer_config.mf_max_bound) @@ -346,14 +399,16 @@ class SparseAdamSharedOptimizer : public Optimizer { __device__ void update_value(const OptimizerConfig& optimizer_config, float& val, // NOLINT const float& grad) { - printf("Warning: update_value will not used. Please use dy_mf_update_value\n"); + printf( + "Warning: update_value will not used. Please use dy_mf_update_value\n"); } __device__ void dy_mf_update_value(const OptimizerConfig& optimizer_config, - float* ptr, const float* grad) { - + float* ptr, + const float* grad) { float g_show = grad[feature_value_accessor_.common_push_value.ShowIndex()]; - float g_click = grad[feature_value_accessor_.common_push_value.ClickIndex()]; + float g_click = + grad[feature_value_accessor_.common_push_value.ClickIndex()]; ptr[feature_value_accessor_.common_feature_value.SlotIndex()] = grad[feature_value_accessor_.common_push_value.SlotIndex()]; @@ -361,44 +416,54 @@ class SparseAdamSharedOptimizer : public Optimizer { ptr[feature_value_accessor_.common_feature_value.ClickIndex()] += g_click; ptr[feature_value_accessor_.common_feature_value.DeltaScoreIndex()] += optimizer_config.nonclk_coeff * (g_show - g_click) + - optimizer_config.clk_coeff * g_click; - - update_value_work(optimizer_config, 1, - ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(), - ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(), - grad + feature_value_accessor_.common_push_value.EmbedGIndex(), - g_show); - int mf_dim = int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); + optimizer_config.clk_coeff * g_click; + + update_value_work( + optimizer_config, + 1, + ptr + feature_value_accessor_.common_feature_value.EmbedWIndex(), + ptr + feature_value_accessor_.common_feature_value.EmbedG2SumIndex(), + grad + feature_value_accessor_.common_push_value.EmbedGIndex(), + g_show); + int mf_dim = + int(ptr[feature_value_accessor_.common_feature_value.MfDimIndex()]); if (ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] == 0) { if (optimizer_config.mf_create_thresholds <= - optimizer_config.nonclk_coeff * - (ptr[feature_value_accessor_.common_feature_value.ShowIndex()] - - ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) + - optimizer_config.clk_coeff * ptr[feature_value_accessor_.common_feature_value.ClickIndex()]) { - ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = - feature_value_accessor_.common_feature_value.MFSize(mf_dim) / sizeof(float); + optimizer_config.nonclk_coeff * + (ptr[feature_value_accessor_.common_feature_value + .ShowIndex()] - + ptr[feature_value_accessor_.common_feature_value + .ClickIndex()]) + + optimizer_config.clk_coeff * + ptr[feature_value_accessor_.common_feature_value + .ClickIndex()]) { + ptr[feature_value_accessor_.common_feature_value.MfSizeIndex()] = + feature_value_accessor_.common_feature_value.MFSize(mf_dim) / + sizeof(float); int tid_x = blockIdx.x * blockDim.x + threadIdx.x; curandState state; curand_init(clock64(), tid_x, 0, &state); for (int i = 0; i < mf_dim; ++i) { - ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = + ptr[feature_value_accessor_.common_feature_value.EmbedxWIndex() + i] = (curand_uniform(&state)) * optimizer_config.mf_initial_range; } - 
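
SparseAdamSharedOptimizer keeps one scalar moment pair for the whole vector instead of a pair per dimension, cutting the state to 4 floats regardless of width. The hunk above elides the reduction tail, so the refresh of the shared pair from the per-dimension averages below is an assumption; a host-side sketch:

    #include <cmath>

    // state = {m1, m2, beta1_pow, beta2_pow}, matching the index helpers.
    void shared_adam_step(int n, float* w, float* state, const float* g,
                          float scale, float lr, float beta1, float beta2,
                          float eps) {
      double ratio = lr * std::sqrt(1.0 - state[3]) / (1.0 - state[2]);
      double sum_m1 = 0.0, sum_m2 = 0.0;
      for (int i = 0; i < n; ++i) {
        double gi = g[i] / scale;
        double m1 = beta1 * state[0] + (1.0 - beta1) * gi;
        double m2 = beta2 * state[1] + (1.0 - beta2) * gi * gi;
        w[i] += ratio * (m1 / (std::sqrt(m2) + eps));
        sum_m1 += m1;
        sum_m2 += m2;
      }
      state[0] = sum_m1 / n;  // averaged first moment (assumed reduction)
      state[1] = sum_m2 / n;  // averaged second moment (assumed reduction)
      state[2] *= beta1;
      state[3] *= beta2;
    }
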
ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta1PowIndex()] = - optimizer_config.beta1_decay_rate; - ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + EmbedxBeta2PowIndex()] = - optimizer_config.beta2_decay_rate; + ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + + EmbedxBeta1PowIndex()] = optimizer_config.beta1_decay_rate; + ptr[feature_value_accessor_.common_feature_value.EmbedxG2SumIndex() + + EmbedxBeta2PowIndex()] = optimizer_config.beta2_decay_rate; } } else { - update_value_work(optimizer_config, mf_dim, + update_value_work( + optimizer_config, + mf_dim, ptr + feature_value_accessor_.common_feature_value.EmbedxWIndex(), ptr + feature_value_accessor_.common_feature_value.EmbedxG2SumIndex(), grad + feature_value_accessor_.common_push_value.EmbedxGIndex(), g_show); } } - + __host__ __device__ size_t Dim() { return EmbedDim() + EmbedxDim(); } __host__ __device__ size_t EmbedDim() { return 4; } __host__ __device__ size_t EmbedxDim() { return 4; } @@ -407,13 +472,17 @@ class SparseAdamSharedOptimizer : public Optimizer { __host__ __device__ size_t Beta1PowIndex() { return G2SumIndex() + 1; } __host__ __device__ size_t Beta2PowIndex() { return Beta1PowIndex() + 1; } __host__ __device__ size_t EmbedxGSumIndex() { return 0; } - __host__ __device__ size_t EmbedxG2SumIndex() { return EmbedxGSumIndex() + 1; } - __host__ __device__ size_t EmbedxBeta1PowIndex() { return EmbedxG2SumIndex() + 1; } - __host__ __device__ size_t EmbedxBeta2PowIndex() { return EmbedxBeta1PowIndex() + 1; } - + __host__ __device__ size_t EmbedxG2SumIndex() { + return EmbedxGSumIndex() + 1; + } + __host__ __device__ size_t EmbedxBeta1PowIndex() { + return EmbedxG2SumIndex() + 1; + } + __host__ __device__ size_t EmbedxBeta2PowIndex() { + return EmbedxBeta1PowIndex() + 1; + } }; - #endif } // end namespace framework } // end namespace paddle From 680a02822cb390d6ab0e0bc651704779ad3cc1bc Mon Sep 17 00:00:00 2001 From: danleifeng Date: Thu, 7 Jul 2022 04:08:16 +0000 Subject: [PATCH 21/31] format; test=develop --- paddle/fluid/distributed/ps/wrapper/fleet.cc | 2 +- .../framework/fleet/heter_ps/feature_value.cu | 145 +++++++++++------- .../framework/fleet/heter_ps/hashtable.h | 1 - .../framework/fleet/heter_ps/optimizer_conf.h | 10 +- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 10 +- 5 files changed, 103 insertions(+), 65 deletions(-) diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.cc b/paddle/fluid/distributed/ps/wrapper/fleet.cc index f474810bc87dd3..aa7962363ea309 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.cc +++ b/paddle/fluid/distributed/ps/wrapper/fleet.cc @@ -134,7 +134,7 @@ void FleetWrapper::InitWorker(const std::string& dist_desc, paddle::distributed::PSClientFactory::Create(ps_param)); worker_ptr_->Configure(ps_param, dense_pull_regions, ps_env_, index); #if defined PADDLE_WITH_HETERPS && defined PADDLE_WITH_PSCORE - VLOG(0) << "FleetWrapper::InitWorker InitializeGPUServer"; + VLOG(3) << "FleetWrapper::InitWorker InitializeGPUServer"; auto* accessor = worker_ptr_->GetTableAccessor(0); auto ps_gpu_wrapper = paddle::framework::PSGPUWrapper::GetInstance(); ps_gpu_wrapper->InitializeGPUServer(ps_param); diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.cu b/paddle/fluid/framework/fleet/heter_ps/feature_value.cu index eff345fe44caa8..560ce33b9af78d 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.cu +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.cu @@ -14,15 +14,18 @@ limitations under 
the License. */ #ifdef PADDLE_WITH_HETERPS #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" - namespace paddle { namespace framework { - template -__global__ void PullCopy(float** dest, const float* src, - const int64_t* len, int slot_num, int total_len, - uint64_t** keys, uint64_t max_val_size, int* gpu_dim, +__global__ void PullCopy(float** dest, + const float* src, + const int64_t* len, + int slot_num, + int total_len, + uint64_t** keys, + uint64_t max_val_size, + int* gpu_dim, FVAccessor feature_value_accessor) { CUDA_KERNEL_LOOP(i, total_len) { int low = 0; @@ -39,14 +42,20 @@ __global__ void PullCopy(float** dest, const float* src, float* feature_value_ptr = (float*)((char*)src + uint64_t(i) * uint64_t(max_val_size)); int mf_dim = gpu_dim[x] - 3; - feature_value_accessor.Select(dest[x] + y * (mf_dim + 3), feature_value_ptr, keys[x] + y, mf_dim); + feature_value_accessor.Select( + dest[x] + y * (mf_dim + 3), feature_value_ptr, keys[x] + y, mf_dim); } } template -__global__ void PushCopyWithPool(float* dest, float** src, - int64_t* len, int slot_num, uint64_t total_len, - int bs, int* slot_vector, int* mf_dim_vector, +__global__ void PushCopyWithPool(float* dest, + float** src, + int64_t* len, + int slot_num, + uint64_t total_len, + int bs, + int* slot_vector, + int* mf_dim_vector, size_t grad_value_size, FVAccessor feature_value_accessor) { CUDA_KERNEL_LOOP(i, total_len) { @@ -61,58 +70,71 @@ __global__ void PushCopyWithPool(float* dest, float** src, } int x = low; int y = i - (x ? len[low - 1] : 0); - float* cur = - (float*)((char*)dest + i * grad_value_size); + float* cur = (float*)((char*)dest + i * grad_value_size); - cur[feature_value_accessor.common_push_value.SlotIndex()] = + cur[feature_value_accessor.common_push_value.SlotIndex()] = (float)slot_vector[x]; int mf_dim = mf_dim_vector[x]; cur[feature_value_accessor.common_push_value.MfDimIndex()] = mf_dim; - cur[feature_value_accessor.common_push_value.ShowIndex()] = - *(src[x] + y * (mf_dim + 3)); - cur[feature_value_accessor.common_push_value.ClickIndex()] = - *(src[x] + y * (mf_dim + 3) + 1); - cur[feature_value_accessor.common_push_value.EmbedGIndex()] = - *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs; + cur[feature_value_accessor.common_push_value.ShowIndex()] = + *(src[x] + y * (mf_dim + 3)); + cur[feature_value_accessor.common_push_value.ClickIndex()] = + *(src[x] + y * (mf_dim + 3) + 1); + cur[feature_value_accessor.common_push_value.EmbedGIndex()] = + *(src[x] + y * (mf_dim + 3) + 2) * -1. * bs; for (int j = 0; j < mf_dim; j++) { - cur[feature_value_accessor.common_push_value.EmbedxGIndex() + j] = *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. * bs; + cur[feature_value_accessor.common_push_value.EmbedxGIndex() + j] = + *(src[x] + y * (mf_dim + 3) + 3 + j) * -1. 
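
Both kernels above recover (slot, row) coordinates from a flat element index by binary searching the inclusive prefix-sum array len. The search loop is partly elided by the hunk, so this host-side restatement is a sketch:

    #include <cstdint>

    // `len` is an inclusive prefix sum of per-slot lengths: slot x is the
    // first entry with len[x] > i, and y is the row within that slot.
    void locate(const int64_t* len, int slot_num, int64_t i, int* slot,
                int64_t* row) {
      int low = 0, high = slot_num - 1;
      while (low < high) {
        int mid = (low + high) / 2;
        if (i < len[mid]) {
          high = mid;
        } else {
          low = mid + 1;
        }
      }
      *slot = low;
      *row = i - (low ? len[low - 1] : 0);
    }
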
* bs; } } } template -void AccessorWrapper::CopyForPullImpl(const paddle::platform::Place& place, - uint64_t** gpu_keys, - const std::vector& values, - const float* total_values_gpu, - const int64_t* gpu_len, const int slot_num, - const int hidden_size, - const int64_t total_length, - int* gpu_dim, - int feature_value_size) { +void AccessorWrapper::CopyForPullImpl( + const paddle::platform::Place& place, + uint64_t** gpu_keys, + const std::vector& values, + const float* total_values_gpu, + const int64_t* gpu_len, + const int slot_num, + const int hidden_size, + const int64_t total_length, + int* gpu_dim, + int feature_value_size) { auto stream = dynamic_cast( paddle::platform::DeviceContextPool::Instance().Get(place)) ->stream(); auto buf_value = memory::Alloc(place, values.size() * sizeof(float*)); float** gpu_values = reinterpret_cast(buf_value->ptr()); - cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*), - cudaMemcpyHostToDevice); + cudaMemcpy(gpu_values, + values.data(), + values.size() * sizeof(float*), + cudaMemcpyHostToDevice); PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>( - gpu_values, total_values_gpu, gpu_len, slot_num, total_length, gpu_keys, - feature_value_size, gpu_dim, gpu_accessor_); + gpu_values, + total_values_gpu, + gpu_len, + slot_num, + total_length, + gpu_keys, + feature_value_size, + gpu_dim, + gpu_accessor_); cudaStreamSynchronize(stream); } template -void AccessorWrapper::CopyForPushImpl(const paddle::platform::Place& place, - const std::vector& grad_values, - float* total_grad_values_gpu, - const std::vector& slot_lengths, - const uint64_t total_length, - const int batch_size, size_t grad_value_size, - std::vector& slot_vector, - std::vector& slot_mf_dim_vector) { +void AccessorWrapper::CopyForPushImpl( + const paddle::platform::Place& place, + const std::vector& grad_values, + float* total_grad_values_gpu, + const std::vector& slot_lengths, + const uint64_t total_length, + const int batch_size, + size_t grad_value_size, + std::vector& slot_vector, + std::vector& slot_mf_dim_vector) { auto stream = dynamic_cast( paddle::platform::DeviceContextPool::Instance().Get(place)) ->stream(); @@ -131,18 +153,33 @@ void AccessorWrapper::CopyForPushImpl(const paddle::platform::Place int64_t* gpu_len = reinterpret_cast(buf_length->ptr()); int* d_slot_vector = reinterpret_cast(buf_slot_vector->ptr()); int* d_mf_dim_vector = reinterpret_cast(buf_mf_dim_vector->ptr()); - cudaMemcpy(gpu_values, grad_values.data(), - grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice); - cudaMemcpy(gpu_len, slot_lengths_lod.data(), - slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_slot_vector, slot_vector.data(), - slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice); - cudaMemcpy(d_mf_dim_vector, slot_mf_dim_vector.data(), - slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice); + cudaMemcpy(gpu_values, + grad_values.data(), + grad_values.size() * sizeof(float*), + cudaMemcpyHostToDevice); + cudaMemcpy(gpu_len, + slot_lengths_lod.data(), + slot_lengths.size() * sizeof(int64_t), + cudaMemcpyHostToDevice); + cudaMemcpy(d_slot_vector, + slot_vector.data(), + slot_lengths_lod.size() * sizeof(int), + cudaMemcpyHostToDevice); + cudaMemcpy(d_mf_dim_vector, + slot_mf_dim_vector.data(), + slot_lengths_lod.size() * sizeof(int), + cudaMemcpyHostToDevice); PushCopyWithPool<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>( - total_grad_values_gpu, gpu_values, gpu_len, slot_lengths.size(), - total_length, 
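
CopyForPullImpl and CopyForPushImpl share one launch recipe: stage the host array of per-slot device pointers into GPU memory, then launch one thread per element with a ceil-divided grid of (total_length + 1024 - 1) / 1024 blocks. A minimal sketch of the staging step; the helper name is illustrative, and it uses plain cudaMalloc where the real code goes through memory::Alloc:

    #include <cuda_runtime.h>
    #include <vector>

    // Stages the per-slot destination pointers on the device; the caller
    // launches the copy kernel on `stream` and frees the buffer afterwards.
    float** stage_slot_pointers(const std::vector<float*>& values,
                                cudaStream_t stream) {
      float** d_ptrs = nullptr;
      cudaMalloc(&d_ptrs, values.size() * sizeof(float*));
      cudaMemcpyAsync(d_ptrs, values.data(), values.size() * sizeof(float*),
                      cudaMemcpyHostToDevice, stream);
      return d_ptrs;
    }
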
batch_size, d_slot_vector, d_mf_dim_vector, - grad_value_size, gpu_accessor_); + total_grad_values_gpu, + gpu_values, + gpu_len, + slot_lengths.size(), + total_length, + batch_size, + d_slot_vector, + d_mf_dim_vector, + grad_value_size, + gpu_accessor_); cudaStreamSynchronize(stream); } @@ -150,6 +187,6 @@ void AccessorWrapper::CopyForPushImpl(const paddle::platform::Place template class AccessorWrapper; #endif -} -} -#endif \ No newline at end of file +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index 8803c738455b4b..f5a54a0387a8f3 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -193,7 +193,6 @@ class HashTable { << " push value size: " << push_grad_value_size_; } - std::unique_ptr rwlock_{nullptr}; private: diff --git a/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h b/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h index 8b301b9dbae015..2db259941c873a 100644 --- a/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h +++ b/paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h @@ -27,16 +27,16 @@ class OptimizerConfig { float learning_rate = 0.05; float initial_g2sum = 3.0; float initial_range = 0; - float beta1_decay_rate = 0.9; //adam - float beta2_decay_rate = 0.999; //adam + float beta1_decay_rate = 0.9; // adam + float beta2_decay_rate = 0.999; // adam float ada_epsilon = 1e-8; float mf_create_thresholds = 10; float mf_learning_rate = 0.05; float mf_initial_g2sum = 3.0; float mf_initial_range = 1e-4; - float mf_beta1_decay_rate = 0.9; //adam - float mf_beta2_decay_rate = 0.999; //adam + float mf_beta1_decay_rate = 0.9; // adam + float mf_beta2_decay_rate = 0.999; // adam float mf_min_bound = -10; float mf_max_bound = 10; float mf_ada_epsilon = 1e-8; @@ -48,7 +48,7 @@ class OptimizerConfig { float learning_rate, float initial_g2sum, float initial_range, - float beta1_decay_rate, + float beta1_decay_rate, float beta2_decay_rate, float ada_epsilon) { this->nonclk_coeff = nonclk_coeff; diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 2d5caf001cc9bd..bb40b7c0ae79aa 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -711,8 +711,8 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { for (size_t k = left; k < right; k++) { void* val = mem_pool->mem_address(k); - // float* ptr_val = device_dim_ptrs[k]->data(); - // size_t dim = device_dim_ptrs[k]->size(); + // float* ptr_val = device_dim_ptrs[k]->data(); + // size_t dim = device_dim_ptrs[k]->size(); #ifdef PADDLE_WITH_PSLIB val->delta_score = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: @@ -750,7 +750,8 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { #ifdef PADDLE_WITH_PSCORE // VLOG(5) << "cpu build " << k // << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) - // << " |: " << cpu_table_accessor_->ParseToString(ptr_val, dim); + // << " |: " << cpu_table_accessor_->ParseToString(ptr_val, + // dim); accessor_wrapper_ptr->BuildFill( val, device_dim_ptrs[k], cpu_table_accessor_, mf_dim); VLOG(5) << "build " << k << " : " @@ -994,7 +995,8 @@ void PSGPUWrapper::EndPass() { // float* cpu_val = downpour_value->data(); // VLOG(5) << "dump to cpu " << index << " gpu_value: " // << accessor_wrapper_ptr->ParseToString(gpu_val, - // int(accessor_wrapper_ptr->GetFeatureValueSize(mf_dim) / 
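
The optimizer_conf.h hunk above carries the hyper-parameter defaults shared by the GPU optimizers. Condensed into one illustrative struct (a sketch, not the real class):

    // Condensed view of the defaults in optimizer_conf.h.
    struct SparseAdamDefaults {
      float learning_rate = 0.05f;
      float beta1_decay_rate = 0.9f;    // adam first-moment decay
      float beta2_decay_rate = 0.999f;  // adam second-moment decay
      float ada_epsilon = 1e-8f;        // denominator guard
      float mf_initial_range = 1e-4f;   // embedx init scale
      float mf_min_bound = -10.0f;      // weights clipped to this range
      float mf_max_bound = 10.0f;
    };
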
sizeof(float))) + // int(accessor_wrapper_ptr->GetFeatureValueSize(mf_dim) / + // sizeof(float))) // << " \t cpu_value:" // << cpu_table_accessor_->ParseToString(cpu_val, // downpour_value->size()); From 4bf0be2e902478f002f236f62edfc5310830a29c Mon Sep 17 00:00:00 2001 From: danleifeng Date: Thu, 7 Jul 2022 05:44:13 +0000 Subject: [PATCH 22/31] format; test=develop --- cmake/cuda.cmake | 3 +- .../framework/fleet/heter_ps/feature_value.h | 63 +++++++++---------- 2 files changed, 29 insertions(+), 37 deletions(-) diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index ea53b103c333e1..87b943abd0106d 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -260,8 +260,7 @@ add_definitions("-DCUDA_VERSION_MINOR=\"${CUDA_VERSION_MINOR}\"") add_definitions("-DCUDA_TOOLKIT_ROOT_DIR=\"${CUDA_TOOLKIT_ROOT_DIR}\"") # setting nvcc arch flags -#select_nvcc_arch_flags(NVCC_FLAGS_EXTRA) -set(NVCC_FLAGS_EXTRA "-gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80") +select_nvcc_arch_flags(NVCC_FLAGS_EXTRA) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA}") message(STATUS "NVCC_FLAGS_EXTRA: ${NVCC_FLAGS_EXTRA}") diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index b869d322b6389a..da53cc647a6034 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -295,7 +295,8 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { paddle::distributed::ValueAccessor* cpu_table_accessor, int mf_dim) { #ifdef PADDLE_WITH_PSCORE - paddle::distributed::CtrDymfAccessor* cpu_accessor = dynamic_cast(cpu_table_accessor); + paddle::distributed::CtrDymfAccessor* cpu_accessor = + dynamic_cast(cpu_table_accessor); paddle::distributed::FixedFeatureValue* cpu_ptr = (paddle::distributed::FixedFeatureValue*)(cpu); float* cpu_val = cpu_ptr->data(); @@ -313,25 +314,22 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { cpu_val[cpu_accessor->common_feature_value.EmbedWIndex()]; for (int i = 0; i < common_feature_value.EmbedDim(); i++) { gpu_val[common_feature_value.EmbedG2SumIndex() + i] = - cpu_val[cpu_accessor->common_feature_value.EmbedG2SumIndex() + - i]; + cpu_val[cpu_accessor->common_feature_value.EmbedG2SumIndex() + i]; } *(reinterpret_cast( gpu_val + common_feature_value.CpuPtrIndex())) = (uint64_t)(cpu); - cpu_val[cpu_accessor->common_feature_value.MfDimIndex()] = - float(mf_dim); + cpu_val[cpu_accessor->common_feature_value.MfDimIndex()] = float(mf_dim); gpu_val[common_feature_value.MfDimIndex()] = mf_dim; - if (cpu_dim > - cpu_accessor->GetAccessorInfo().dim - - cpu_accessor->GetAccessorInfo().mf_size / sizeof(float)) { + if (cpu_dim > cpu_accessor->GetAccessorInfo().dim - + cpu_accessor->GetAccessorInfo().mf_size / sizeof(float)) { gpu_val[common_feature_value.MfSizeIndex()] = common_feature_value.MFSize(mf_dim) / sizeof(float); for (int x = 0; x < int(common_feature_value.MFSize(mf_dim) / sizeof(float)); x++) { - gpu_val[common_feature_value.EmbedxG2SumIndex() + x] = cpu_val - [cpu_accessor->common_feature_value.EmbedxG2SumIndex() + x]; + gpu_val[common_feature_value.EmbedxG2SumIndex() + x] = + cpu_val[cpu_accessor->common_feature_value.EmbedxG2SumIndex() + x]; } } else { gpu_val[common_feature_value.MfSizeIndex()] = 0; @@ -350,19 +348,18 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { paddle::distributed::ValueAccessor* cpu_table_accessor, int mf_dim) { #ifdef PADDLE_WITH_PSCORE - 
paddle::distributed::CtrDymfAccessor* cpu_accessor = dynamic_cast(cpu_table_accessor); + paddle::distributed::CtrDymfAccessor* cpu_accessor = + dynamic_cast(cpu_table_accessor); auto* downpour_value = (paddle::distributed::FixedFeatureValue*)(*(reinterpret_cast( gpu_val + common_feature_value.CpuPtrIndex()))); size_t downpour_value_size = downpour_value->size(); if (gpu_val[common_feature_value.MfSizeIndex()] > 0 && - downpour_value_size == - (cpu_accessor->GetAccessorInfo().dim - - int(cpu_accessor->GetAccessorInfo().mf_size / - sizeof(float)))) { // cpu_accessor - downpour_value->resize( - cpu_accessor->common_feature_value.Dim(mf_dim)); + downpour_value_size == (cpu_accessor->GetAccessorInfo().dim - + int(cpu_accessor->GetAccessorInfo().mf_size / + sizeof(float)))) { // cpu_accessor + downpour_value->resize(cpu_accessor->common_feature_value.Dim(mf_dim)); } float* cpu_val = downpour_value->data(); cpu_val[cpu_accessor->common_feature_value.DeltaScoreIndex()] = @@ -642,16 +639,14 @@ class VirtualAccessor { virtual size_t GetPushValueSize(int& mf_dim) = 0; - virtual void BuildFill( - void* gpu_val, - void* cpu_val, - paddle::distributed::ValueAccessor* cpu_table_accessor, - int mf_dim) = 0; + virtual void BuildFill(void* gpu_val, + void* cpu_val, + paddle::distributed::ValueAccessor* cpu_table_accessor, + int mf_dim) = 0; - virtual void DumpFill( - float* gpu_val, - paddle::distributed::ValueAccessor* cpu_table_accessor, - int mf_dim) = 0; + virtual void DumpFill(float* gpu_val, + paddle::distributed::ValueAccessor* cpu_table_accessor, + int mf_dim) = 0; virtual void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys, @@ -697,19 +692,17 @@ class AccessorWrapper : public VirtualAccessor { return gpu_accessor_.common_push_value.Size(mf_dim); } - virtual void BuildFill( - void* gpu_val, - void* cpu_val, - paddle::distributed::ValueAccessor* cpu_table_accessor, - int mf_dim) { + virtual void BuildFill(void* gpu_val, + void* cpu_val, + paddle::distributed::ValueAccessor* cpu_table_accessor, + int mf_dim) { gpu_accessor_.BuildFill( (float*)(gpu_val), cpu_val, cpu_table_accessor, mf_dim); } - virtual void DumpFill( - float* gpu_val, - paddle::distributed::ValueAccessor* cpu_table_accessor, - int mf_dim) { + virtual void DumpFill(float* gpu_val, + paddle::distributed::ValueAccessor* cpu_table_accessor, + int mf_dim) { gpu_accessor_.DumpFill(gpu_val, cpu_table_accessor, mf_dim); } From e0610b03d62f5668245371e6b94bed27ee21c49c Mon Sep 17 00:00:00 2001 From: danleifeng Date: Thu, 7 Jul 2022 12:59:03 +0000 Subject: [PATCH 23/31] format; test=develop --- paddle/fluid/distributed/ps/wrapper/fleet.cc | 20 +- .../framework/fleet/heter_ps/feature_value.h | 11 +- .../fleet/heter_ps/graph_gpu_ps_table.h | 3 +- .../framework/fleet/heter_ps/hashtable.h | 2 +- .../fleet/heter_ps/hashtable_kernel.cu | 13 +- .../framework/fleet/heter_ps/heter_comm.h | 3 - .../framework/fleet/heter_ps/heter_comm_inl.h | 119 +++--- .../fleet/heter_ps/heter_comm_kernel.cu | 16 +- .../fleet/heter_ps/heter_comm_kernel.h | 15 +- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 371 +++++++++--------- 10 files changed, 271 insertions(+), 302 deletions(-) diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.cc b/paddle/fluid/distributed/ps/wrapper/fleet.cc index aa7962363ea309..0f7ab1f8ff6b65 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.cc +++ b/paddle/fluid/distributed/ps/wrapper/fleet.cc @@ -612,12 +612,10 @@ void FleetWrapper::PushSparseFromTensorAsync( // in // ctr_accessor.h push_values.back()[0] = 2; // 
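
BuildFill and DumpFill above are inverse conversions between the CPU-side FixedFeatureValue and the flat GPU float record. The round trip works because BuildFill stashes the host pointer into the record's CpuPtrIndex slot as a uint64_t and DumpFill reads it back; reduced to a sketch with illustrative types:

    #include <cstddef>
    #include <cstdint>

    struct HostValue;  // stands in for paddle::distributed::FixedFeatureValue

    // BuildFill side: remember where the CPU copy lives.
    inline void stash_host_ptr(float* gpu_val, size_t cpu_ptr_index,
                               HostValue* cpu) {
      *reinterpret_cast<uint64_t*>(gpu_val + cpu_ptr_index) =
          reinterpret_cast<uint64_t>(cpu);
    }

    // DumpFill side: recover it to write stats and weights back.
    inline HostValue* load_host_ptr(float* gpu_val, size_t cpu_ptr_index) {
      return reinterpret_cast<HostValue*>(
          *reinterpret_cast<uint64_t*>(gpu_val + cpu_ptr_index));
    }
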
TODO(zhaocaibei123): slot - push_values.back()[1] = (static_cast(i) >= show_size - ? 1 - : static_cast(show_tensor[i])); - push_values.back()[2] = (static_cast(i) >= clk_size - ? 0 - : static_cast(clk_tensor[i])); + push_values.back()[1] = + (i >= show_size ? 1 : static_cast(show_tensor[i])); + push_values.back()[2] = + (i >= clk_size ? 0 : static_cast(clk_tensor[i])); float* data = push_values.back().data() + 3; memcpy(data, g + output_len, sizeof(float) * fea_dim); } @@ -641,12 +639,10 @@ void FleetWrapper::PushSparseFromTensorAsync( // slot show clk grad... consistent with CtrCommonPushValue defined in // ctr_accessor.h push_values.back()[0] = 2; // TODO(zhaocaibei123): slot - push_values.back()[1] = (static_cast(i) >= show_size - ? 1 - : static_cast(show_tensor[i])); - push_values.back()[2] = (static_cast(i) >= clk_size - ? 0 - : static_cast(clk_tensor[i])); + push_values.back()[1] = + (i >= show_size ? 1 : static_cast(show_tensor[i])); + push_values.back()[2] = + (i >= clk_size ? 0 : static_cast(clk_tensor[i])); float* data = push_values.back().data() + 3; memcpy(data, g + output_len, sizeof(float) * fea_dim); } diff --git a/paddle/fluid/framework/fleet/heter_ps/feature_value.h b/paddle/fluid/framework/fleet/heter_ps/feature_value.h index da53cc647a6034..ef4533d64eac2e 100644 --- a/paddle/fluid/framework/fleet/heter_ps/feature_value.h +++ b/paddle/fluid/framework/fleet/heter_ps/feature_value.h @@ -343,10 +343,9 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { } // dump_to_cpu阶段从gpu_val赋值给cpu_val - __host__ __device__ void DumpFill( - float* gpu_val, - paddle::distributed::ValueAccessor* cpu_table_accessor, - int mf_dim) { + __host__ void DumpFill(float* gpu_val, + paddle::distributed::ValueAccessor* cpu_table_accessor, + int mf_dim) { #ifdef PADDLE_WITH_PSCORE paddle::distributed::CtrDymfAccessor* cpu_accessor = dynamic_cast(cpu_table_accessor); @@ -382,8 +381,8 @@ class CommonFeatureValueAccessor : public FeatureValueAccessor { for (int x = 0; x < int(common_feature_value.MFSize(mf_dim) / sizeof(float)); x++) { - cpu_val[cpu_table_accessor->common_feature_value.EmbedxG2SumIndex() + - x] = gpu_val[common_feature_value.EmbedxG2SumIndex() + x]; + cpu_val[cpu_accessor->common_feature_value.EmbedxG2SumIndex() + x] = + gpu_val[common_feature_value.EmbedxG2SumIndex() + x]; } } #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h index 0e0d525c93a73e..a4bee2c19bbdaa 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h @@ -29,7 +29,8 @@ class GpuPsGraphTable : public HeterComm { public: GpuPsGraphTable(std::shared_ptr resource, int topo_aware) - : HeterComm(1, resource) { + : HeterComm( + 1, resource) { load_factor_ = 0.25; rw_lock.reset(new pthread_rwlock_t()); gpu_num = resource_->total_device(); diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h index f5a54a0387a8f3..43192df0c71f03 100644 --- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h @@ -156,7 +156,7 @@ class HashTable { template void update(const KeyType* d_keys, - const GradType* d_grads, + const float* d_grads, size_t len, Sgd sgd, StreamType stream); diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu index 39c15ebbad51c6..e7db238a13b750 100644 
--- a/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu @@ -376,8 +376,7 @@ template void HashTable::get( unsigned long* d_vals, size_t len, cudaStream_t stream); -template void HashTable::get( - const unsigned long* d_keys, long* d_vals, size_t len, cudaStream_t stream); + template void HashTable::get( const long* d_keys, unsigned long* d_vals, size_t len, cudaStream_t stream); template void HashTable::get(const long* d_keys, @@ -470,16 +469,6 @@ template void HashTable::update< SparseAdamSharedOptimizer sgd, cudaStream_t stream); -template void HashTable:: - update, - cudaStream_t>(const unsigned long* d_keys, - const char* d_grads, - size_t len, - Optimizer sgd, - cudaStream_t stream); - // template void HashTable::update< // Optimizer resource); - HeterComm(size_t capacity, - std::shared_ptr resource, - CommonFeatureValueAccessor& accessor); virtual ~HeterComm(); HeterComm(const HeterComm&) = delete; HeterComm& operator=(const HeterComm&) = delete; diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h index 353f02d026380a..f8657c8e895ad3 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h @@ -16,7 +16,6 @@ limitations under the License. */ #include #include "paddle/fluid/framework/fleet/heter_ps/feature_value.h" -#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h" #include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h" #include "paddle/fluid/platform/device_context.h" #ifdef PADDLE_WITH_XPU_KP @@ -1137,76 +1136,74 @@ void HeterComm::push_sparse( grad_value_size, stream, feature_value_accessor_); -} -sync_stream(stream); - -auto dst_place = platform::CPUPlace(); -auto src_place = place; -memory_copy(dst_place, - h_left, - src_place, - d_left_ptr, - total_device * sizeof(int), - stream); -memory_copy(dst_place, - h_right, - src_place, - d_right_ptr, - total_device * sizeof(int), - stream); - -for (int i = 0; i < total_device; ++i) { - int shard_len = h_right[i] - h_left[i] + 1; - if (h_left[i] == -1 || h_right[i] == -1) { - continue; + sync_stream(stream); + + auto dst_place = platform::CPUPlace(); + auto src_place = place; + memory_copy(dst_place, + h_left, + src_place, + d_left_ptr, + total_device * sizeof(int), + stream); + memory_copy(dst_place, + h_right, + src_place, + d_right_ptr, + total_device * sizeof(int), + stream); + + for (int i = 0; i < total_device; ++i) { + int shard_len = h_right[i] - h_left[i] + 1; + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + create_storage( + dev_num, i, shard_len * sizeof(KeyType), shard_len * grad_value_size); } - create_storage( - dev_num, i, shard_len * sizeof(KeyType), shard_len * grad_value_size); -} -walk_to_dest(dev_num, - total_device, - h_left, - h_right, - d_shard_keys_ptr, - reinterpret_cast(d_shard_grads_ptr), - grad_value_size); -} + walk_to_dest(dev_num, + total_device, + h_left, + h_right, + d_shard_keys_ptr, + reinterpret_cast(d_shard_grads_ptr), + grad_value_size); + + for (int i = 0; i < total_device; ++i) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + auto& node = path_[dev_num][i].nodes_.back(); + sync_stream(node.in_stream); -for (int i = 0; i < total_device; ++i) { - if (h_left[i] == -1 || h_right[i] == -1) { - continue; + AnyDeviceGuard guard(resource_->dev_id(i)); + ptr_tables_[i]->rwlock_->WRLock(); + 
ptr_tables_[i]->update(reinterpret_cast(node.key_storage), + node.val_storage, + h_right[i] - h_left[i] + 1, + sgd, + resource_->remote_stream(i, dev_num)); } - auto& node = path_[dev_num][i].nodes_.back(); - sync_stream(node.in_stream); - - AnyDeviceGuard guard(resource_->dev_id(i)); - ptr_tables_[i]->rwlock_->WRLock(); - ptr_tables_[i]->update(reinterpret_cast(node.key_storage), - node.val_storage, - h_right[i] - h_left[i] + 1, - sgd, - resource_->remote_stream(i, dev_num)); -} -for (int i = 0; i < total_device; ++i) { - sync_stream(resource_->remote_stream(i, dev_num)); - if (h_left[i] != -1) { - if (!multi_mf_dim_) { - tables_[i]->rwlock_->UNLock(); - } else { - ptr_tables_[i]->rwlock_->UNLock(); + for (int i = 0; i < total_device; ++i) { + sync_stream(resource_->remote_stream(i, dev_num)); + if (h_left[i] != -1) { + if (!multi_mf_dim_) { + tables_[i]->rwlock_->UNLock(); + } else { + ptr_tables_[i]->rwlock_->UNLock(); + } } } -} -for (int i = 0; i < total_device; ++i) { - if (h_left[i] == -1 || h_right[i] == -1) { - continue; + for (int i = 0; i < total_device; ++i) { + if (h_left[i] == -1 || h_right[i] == -1) { + continue; + } + destroy_storage(dev_num, i); } - destroy_storage(dev_num, i); -} } #elif defined(PADDLE_WITH_XPU_KP) diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu index 15e31e450de72d..ebf7e76527af0e 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu @@ -166,11 +166,11 @@ __global__ void merge_gradients_kernel(const uint32_t* offset, int ori_index = index[start]; float* out = (float*)(output + i * grad_value_size); float* in = (float*)(input + size_t(ori_index) * grad_value_size); - merger_.update_one(out, in, feature_value_accessor); + merger.update_one(out, in, feature_value_accessor); for (int j = 1; j < num; ++j) { ori_index = index[start + j]; - float& rhs = *(float*)(input + size_t(ori_index) * grad_value_size); - merger_.merge_one(out, rhs, feature_value_accessor); + in = (float*)(input + size_t(ori_index) * grad_value_size); + merger.merge_one(out, in, feature_value_accessor); } } } @@ -329,13 +329,13 @@ template >>( @@ -346,7 +346,7 @@ void HeterCommKernel::dy_mf_fill_shard_grads( idx, c_len, grad_value_size, - feature_value_accessor_); + feature_value_accessor); } template @@ -370,7 +370,7 @@ void HeterCommKernel::merge_gradient(const uint32_t* offset, n, grad_value_size, merger_, - feature_value_accessor_); + feature_value_accessor); } template diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h index 82969ea45ba441..57f0aff4b6e56b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h +++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h @@ -41,17 +41,15 @@ struct DynamicGradMerger { return out; } + template __device__ __forceinline__ void update_one( - float* output, - const float* input, - CommonFeatureValueAccessor& feature_value_accessor) { + float* output, const float* input, FVAccessor& feature_value_accessor) { feature_value_accessor.PushValueFill(output, input); } + template __device__ __forceinline__ void merge_one( - float* output, - const float* input, - CommonFeatureValueAccessor& feature_value_accessor) { + float* output, const float* input, FVAccessor& feature_value_accessor) { feature_value_accessor.MergePushValue(output, input); } }; @@ -61,11 +59,6 @@ class HeterCommKernel { 
HeterCommKernel() {} explicit HeterCommKernel(const int block_size) : block_size_(block_size) {} - // explicit HeterCommKernel(const int block_size, CommonFeatureValueAccessor& - // feature_value_accessor) : block_size_(block_size), - // feature_value_accessor_(feature_value_accessor) {} - // explicit HeterCommKernel(const int block_size) : block_size_(block_size) {} - template void fill_idx(T* idx, long long len, const StreamType& stream); diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index bb40b7c0ae79aa..f84ac1ea61e349 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -135,7 +135,7 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { std::string data_set_name = std::string(typeid(*dataset_).name()); if (data_set_name.find("SlotRecordDataset") != std::string::npos) { - SlotRecordDataset* dataset = dynamic_cast(dataset_); + SlotRecordDataset* dataset = (SlotRecordDataset*)(dataset_); auto input_channel = dataset->GetInputChannel(); VLOG(0) << "psgpu wrapperinputslotchannle size: " << input_channel->Size(); const std::deque& vec_data = input_channel->GetData(); @@ -185,7 +185,7 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr gpu_task) { } else { CHECK(data_set_name.find("MultiSlotDataset") != std::string::npos); VLOG(0) << "ps_gpu_wrapper use MultiSlotDataset"; - MultiSlotDataset* dataset = dynamic_cast(dataset_); + MultiSlotDataset* dataset = (MultiSlotDataset*)(dataset_); auto input_channel = dataset->GetInputChannel(); const std::deque& vec_data = input_channel->GetData(); @@ -272,6 +272,8 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { auto& local_dim_keys = gpu_task->feature_dim_keys_; auto& local_dim_ptr = gpu_task->value_dim_ptr_; + auto& device_keys = gpu_task->device_keys_; + auto& device_vals = gpu_task->device_values_; auto& device_dim_keys = gpu_task->device_dim_keys_; auto& device_dim_ptr = gpu_task->device_dim_ptr_; auto& device_dim_mutex = gpu_task->dim_mutex_; @@ -590,6 +592,17 @@ void PSGPUWrapper::BuildPull(std::shared_ptr gpu_task) { for (std::thread& t : threads) { t.join(); } + } else { + for (int i = 0; i < thread_keys_shard_num_; i++) { + for (int j = 0; j < device_num; j++) { + task_futures.emplace_back( + hbm_thread_pool_[i]->enqueue(prepare_dev_value_func, j, i)); + } + } + for (auto& f : task_futures) { + f.wait(); + } + task_futures.clear(); } timeline.Pause(); VLOG(0) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec() @@ -681,88 +694,90 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { delete mem_pool; }; int thread_num = 16; - auto build_dynamic_mf_func = - [this, &gpu_task, thread_num](int i, int j, int z) { - // this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_); - int mf_dim = this->index_dim_vec_[j]; - VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim; - auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; - auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; - size_t len = device_dim_keys.size(); - CHECK(len == device_dim_ptrs.size()); - auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j]; - - // ============ add for multi-thread ================ - size_t len_per_thread = len / thread_num; - size_t remain = len % thread_num; - size_t left = 0, right = 0; - - size_t real_len = len_per_thread; - if ((size_t)z < remain) real_len++; - - if ((size_t)z < remain) { - left = z * (len_per_thread + 1); - right = left + real_len; - } else { - 
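
The DynamicGradMerger change above turns update_one/merge_one into templates over the accessor, and merge_gradients_kernel drives them as a segmented reduction: all duplicate-key gradients, addressed through offset/index, fold into one output record per unique key. A serial sketch of that shape, with plain summation standing in for the field-aware MergePushValue:

    #include <cstddef>
    #include <cstdint>

    void merge_segment(const uint32_t* offset, const uint32_t* fea_num,
                       const uint32_t* index, const float* input, float* out,
                       size_t k, size_t stride /* floats per record */) {
      const uint32_t start = offset[k];
      const uint32_t num = fea_num[k];
      const float* first = input + size_t(index[start]) * stride;
      for (size_t d = 0; d < stride; ++d) out[d] = first[d];  // update_one
      for (uint32_t j = 1; j < num; ++j) {                    // merge_one
        const float* g = input + size_t(index[start + j]) * stride;
        for (size_t d = 0; d < stride; ++d) out[d] += g[d];
      }
    }
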
left = remain * (len_per_thread + 1) + (z - remain) * len_per_thread; - right = left + real_len; - } - // ============ add for multi-thread ================ + auto build_dynamic_mf_func = [this, + &gpu_task, + thread_num, + &accessor_wrapper_ptr](int i, int j, int z) { + // this->HeterPs_->set_multi_mf_dim(multi_mf_dim_, max_mf_dim_); + int mf_dim = this->index_dim_vec_[j]; + VLOG(0) << "building table: " << i << "with mf dim: " << mf_dim; + auto& device_dim_keys = gpu_task->device_dim_keys_[i][j]; + auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; + size_t len = device_dim_keys.size(); + CHECK(len == device_dim_ptrs.size()); + auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j]; + + // ============ add for multi-thread ================ + size_t len_per_thread = len / thread_num; + size_t remain = len % thread_num; + size_t left = 0, right = 0; + + size_t real_len = len_per_thread; + if ((size_t)z < remain) real_len++; + + if ((size_t)z < remain) { + left = z * (len_per_thread + 1); + right = left + real_len; + } else { + left = remain * (len_per_thread + 1) + (z - remain) * len_per_thread; + right = left + real_len; + } + // ============ add for multi-thread ================ - for (size_t k = left; k < right; k++) { - void* val = mem_pool->mem_address(k); - // float* ptr_val = device_dim_ptrs[k]->data(); - // size_t dim = device_dim_ptrs[k]->size(); #ifdef PADDLE_WITH_PSLIB - val->delta_score = - ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::delta_score_index()]; - val->show = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::show_index()]; - val->clk = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::click_index()]; - val->slot = - int(ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::slot_index()]); - val->lr = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::embed_w_index()]; - val->lr_g2sum = - ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::embed_g2sum_index()]; - // TODO(xuefeng) set mf_dim while using DownpourCtrDymfAccessor + for (size_t k = left; k < right; k++) { + float* ptr_val = device_dim_ptrs[k]->data(); + size_t dim = device_dim_ptrs[k]->size(); + val->delta_score = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: - DownpourCtrDymfFeatureValue::mf_dim_index()] = - float(mf_dim); - val->mf_dim = mf_dim; - if (dim > 8) { // CpuPS alreay expand as mf_dim - val->mf_size = mf_dim + 1; - for (int x = 0; x < val->mf_dim + 1; x++) { - val->mf[x] = ptr_val[x + 8]; - } - } else { - val->mf_size = 0; - for (int x = 0; x < val->mf_dim + 1; x++) { - val->mf[x] = 0; - } - } + DownpourCtrDymfFeatureValue::delta_score_index()]; + val->show = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::show_index()]; + val->clk = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::click_index()]; + val->slot = int(ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::slot_index()]); + val->lr = ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::embed_w_index()]; + val->lr_g2sum = + ptr_val[paddle::ps::DownpourCtrDymfAccessor:: + DownpourCtrDymfFeatureValue::embed_g2sum_index()]; + // TODO(xuefeng) set mf_dim while using DownpourCtrDymfAccessor + ptr_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: + mf_dim_index()] = float(mf_dim); + val->mf_dim = mf_dim; + if (dim > 8) { // CpuPS alreay expand as 
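
build_dynamic_mf_func and dump_pool_to_cpu_func split each key range across thread_num workers with the same arithmetic: the first len % thread_num workers take one extra element. For example, len = 10 with thread_num = 4 yields [0, 3), [3, 6), [6, 8), [8, 10). As a standalone sketch:

    #include <cstddef>

    // Sketch of the [left, right) partition used by the lambdas above.
    void thread_range(size_t len, size_t thread_num, size_t z, size_t* left,
                      size_t* right) {
      size_t per = len / thread_num;
      size_t remain = len % thread_num;
      size_t real_len = per + (z < remain ? 1 : 0);
      *left = z < remain ? z * (per + 1)
                         : remain * (per + 1) + (z - remain) * per;
      *right = *left + real_len;
    }
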
mf_dim + val->mf_size = mf_dim + 1; + for (int x = 0; x < val->mf_dim + 1; x++) { + val->mf[x] = ptr_val[x + 8]; + } + } else { + val->mf_size = 0; + for (int x = 0; x < val->mf_dim + 1; x++) { + val->mf[x] = 0; } + } + } #endif #ifdef PADDLE_WITH_PSCORE - // VLOG(5) << "cpu build " << k - // << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) - // << " |: " << cpu_table_accessor_->ParseToString(ptr_val, - // dim); - accessor_wrapper_ptr->BuildFill( - val, device_dim_ptrs[k], cpu_table_accessor_, mf_dim); - VLOG(5) << "build " << k << " : " - << accessor_wrapper_ptr->ParseToString( - (float*)(val), - int(accessor_wrapper_ptr->GetFeatureValueSize(mf_dim) / - sizeof(float))); - } + for (size_t k = left; k < right; k++) { + void* val = mem_pool->mem_address(k); + // VLOG(5) << "cpu build " << k + // << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) + // << " |: " << cpu_table_accessor_->ParseToString(ptr_val, + // dim); + accessor_wrapper_ptr->BuildFill( + val, device_dim_ptrs[k], cpu_table_accessor_, mf_dim); + VLOG(5) << "build " << k << " : " + << accessor_wrapper_ptr->ParseToString( + (float*)(val), + int(accessor_wrapper_ptr->GetFeatureValueSize(mf_dim) / + sizeof(float))); + } #endif + }; - threads.resize(device_num * multi_mf_dim_); + threads.resize(device_num * multi_mf_dim_); for (int i = 0; i < device_num; i++) { for (int j = 0; j < multi_mf_dim_; j++) { threads[i + j * device_num] = std::thread(build_dymf_mem_pool, i, j); @@ -913,126 +928,109 @@ void PSGPUWrapper::EndPass() { auto accessor_wrapper_ptr = GlobalAccessorTransfor::GetInstance().GetAccessorWrapper(); int thread_num = 8; - auto dump_pool_to_cpu_func = - [this, thread_num, &accessor_wrapper_ptr](int i, int j, int z) { - PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i))); - auto& hbm_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j]; - auto& device_keys = this->current_task_->device_dim_keys_[i][j]; - size_t len = device_keys.size(); - // ====== multi-thread process feasign================ - int len_per_thread = len / thread_num; - int remain = len % thread_num; - int left = -1, right = -1; - int real_len = len_per_thread; - if (z < remain) real_len++; - if (z < remain) { - left = z * (len_per_thread + 1); - right = left + real_len; - } else { - left = remain * (len_per_thread + 1) + (z - remain) * len_per_thread; - right = left + real_len; - } - // ============ multi-thread process feasign============ - int mf_dim = this->index_dim_vec_[j]; - size_t feature_value_size = - accessor_wrapper_ptr->GetFeatureValueSize(mf_dim); - VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim - << " key_len :" << len - << " feature_value_size:" << feature_value_size; - - char* test_build_values = (char*)malloc(feature_value_size * real_len); - uint64_t offset = left * feature_value_size; - cudaMemcpy(test_build_values, - hbm_pool->mem() + offset, - feature_value_size * real_len, - cudaMemcpyDeviceToHost); - CHECK(len == hbm_pool->capacity()); - uint64_t unuse_key = std::numeric_limits::max(); - for (int i = left; i < right; ++i) { - if (device_keys[i] == unuse_key) { - continue; - } - size_t local_offset = (i - left) * feature_value_size; - float* gpu_val = (float*)(test_build_values + local_offset); + auto dump_pool_to_cpu_func = [this, thread_num, &accessor_wrapper_ptr]( + int i, int j, int z) { + PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i))); + auto& hbm_pool = this->hbm_pools_[i * this->multi_mf_dim_ + j]; + auto& device_keys = this->current_task_->device_dim_keys_[i][j]; + 
+    size_t len = device_keys.size();
+    // ====== multi-thread process feasign================
+    int len_per_thread = len / thread_num;
+    int remain = len % thread_num;
+    int left = -1, right = -1;
+    int real_len = len_per_thread;
+    if (z < remain) real_len++;
+    if (z < remain) {
+      left = z * (len_per_thread + 1);
+      right = left + real_len;
+    } else {
+      left = remain * (len_per_thread + 1) + (z - remain) * len_per_thread;
+      right = left + real_len;
+    }
+    // ============ multi-thread process feasign============
+    int mf_dim = this->index_dim_vec_[j];
+    size_t feature_value_size =
+        accessor_wrapper_ptr->GetFeatureValueSize(mf_dim);
+    VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim
+            << " key_len :" << len
+            << " feature_value_size:" << feature_value_size;
+
+    char* test_build_values = (char*)malloc(feature_value_size * real_len);
+    uint64_t offset = left * feature_value_size;
+    cudaMemcpy(test_build_values,
+               hbm_pool->mem() + offset,
+               feature_value_size * real_len,
+               cudaMemcpyDeviceToHost);
+    CHECK(len == hbm_pool->capacity());
+    uint64_t unuse_key = std::numeric_limits<uint64_t>::max();
+    for (int i = left; i < right; ++i) {
+      if (device_keys[i] == unuse_key) {
+        continue;
+      }
+      size_t local_offset = (i - left) * feature_value_size;
+      float* gpu_val = (float*)(test_build_values + local_offset);
 #ifdef PADDLE_WITH_PSLIB
+      auto* downpour_value =
+          (paddle::ps::DownpourFixedFeatureValue*)(gpu_val->cpu_ptr);
+      int downpour_value_size = downpour_value->size();
+      if (gpu_val->mf_size > 0 && downpour_value_size == 8) {
+        downpour_value->resize(gpu_val->mf_dim + 1 + downpour_value_size);
+      }
+      float* cpu_val = downpour_value->data();
+      cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
+                  delta_score_index()] = gpu_val->delta_score;
+      cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
+                  show_index()] = gpu_val->show;
+      cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
+                  click_index()] = gpu_val->clk;
+      cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
+                  embed_w_index()] = gpu_val->lr;
+      cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
+                  embed_g2sum_index()] = gpu_val->lr_g2sum;
+      cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue::
+                  slot_index()] = gpu_val->slot;
+
+      if (gpu_val->mf_size > 0) {
+        for (int x = 0; x < gpu_val->mf_dim + 1; x++) {
+          cpu_val[x + 8] = gpu_val->mf[x];
+        }
+      }
 #endif
 #ifdef PADDLE_WITH_PSCORE
-          accessor_wrapper_ptr->DumpFill(gpu_val, cpu_table_accessor_, mf_dim);
-          // auto* downpour_value = (paddle::distributed::FixedFeatureValue*)(*(
-          //     reinterpret_cast<uint64_t*>(gpu_val)));
-          // float* cpu_val = downpour_value->data();
-          // VLOG(5) << "dump to cpu " << index << " gpu_value: "
-          //         << accessor_wrapper_ptr->ParseToString(gpu_val,
-          //         int(accessor_wrapper_ptr->GetFeatureValueSize(mf_dim) /
-          //         sizeof(float)))
-          //         << " \t cpu_value:"
-          //         << cpu_table_accessor_->ParseToString(cpu_val,
-          //         downpour_value->size());
-        }
+      accessor_wrapper_ptr->DumpFill(gpu_val, cpu_table_accessor_, mf_dim);
 #endif
-        free(test_build_values);
-};
-if (multi_mf_dim_) {
-  VLOG(0) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_;
-  size_t device_num = heter_devices_.size();
-  std::vector<std::thread> threads(device_num * multi_mf_dim_ * thread_num);
-  for (size_t i = 0; i < device_num; i++) {
-    for (int j = 0; j < multi_mf_dim_; j++) {
-      for (int k = 0; k < thread_num; k++) {
-        threads[(i + j * device_num) * thread_num + k] =
-            std::thread(dump_pool_to_cpu_func, i, j, k);
+    }
+    free(test_build_values);
+  };
+  if (multi_mf_dim_) {
+    VLOG(0) << "psgpu wrapper dump pool: multi_mf_dim_: " << multi_mf_dim_;
+    size_t device_num = heter_devices_.size();
+    std::vector<std::thread> threads(device_num * multi_mf_dim_ * thread_num);
+    for (size_t i = 0; i < device_num; i++) {
+      for (int j = 0; j < multi_mf_dim_; j++) {
+        for (int k = 0; k < thread_num; k++) {
+          threads[(i + j * device_num) * thread_num + k] =
+              std::thread(dump_pool_to_cpu_func, i, j, k);
+        }
       }
     }
+    for (std::thread& t : threads) {
+      t.join();
+    }
   }
-  for (std::thread& t : threads) {
-    t.join();
+  if (keysize_max != 0) {
+    HeterPs_->end_pass();
   }
-}
-if (keysize_max != 0) {
-  HeterPs_->end_pass();
-}
-for (size_t i = 0; i < hbm_pools_.size(); i++) {
-  delete hbm_pools_[i];
-}
-gpu_task_pool_.Push(current_task_);
-current_task_ = nullptr;
-gpu_free_channel_->Put(current_task_);
-timer.Pause();
-VLOG(1) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
-}
+  for (size_t i = 0; i < hbm_pools_.size(); i++) {
+    delete hbm_pools_[i];
+  }
+  gpu_task_pool_.Push(current_task_);
+  current_task_ = nullptr;
+  gpu_free_channel_->Put(current_task_);
+  timer.Pause();
+  VLOG(1) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
+}  // namespace framework
 
 void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
                               const int table_id,
@@ -1303,4 +1301,3 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,
 
 }  // end namespace framework
 }  // end namespace paddle
-// #endif
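The dump_pool_to_cpu_func lambda reformatted above carries the one non-trivial piece of arithmetic in this hunk: each of the thread_num workers takes a contiguous slice [left, right) of a device's key range, and the first len % thread_num workers take one extra key so the slices exactly tile the range. A minimal standalone sketch of that split, for reference only (the helper name feasign_range is hypothetical, not part of the patch):

# Sketch of the [left, right) slice computed per worker z in
# dump_pool_to_cpu_func; pure arithmetic, no Paddle dependency.
def feasign_range(length, thread_num, z):
    len_per_thread, remain = divmod(length, thread_num)
    real_len = len_per_thread + (1 if z < remain else 0)
    if z < remain:
        left = z * (len_per_thread + 1)
    else:
        left = remain * (len_per_thread + 1) + (z - remain) * len_per_thread
    return left, left + real_len

# The slices cover the range with no gaps, e.g. 10 keys over 3 workers:
assert [feasign_range(10, 3, z) for z in range(3)] == [(0, 4), (4, 7), (7, 10)]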
<< "building table: " << i << "with mf dim: " << mf_dim << " feature_value_size:" @@ -705,6 +704,8 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { auto& device_dim_ptrs = gpu_task->device_dim_ptr_[i][j]; size_t len = device_dim_keys.size(); CHECK(len == device_dim_ptrs.size()); + // this->mem_pools_[i * this->multi_mf_dim_ + j] = + // new MemoryPool(len, feature_value_size); auto& mem_pool = this->mem_pools_[i * this->multi_mf_dim_ + j]; // ============ add for multi-thread ================ @@ -724,8 +725,9 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { } // ============ add for multi-thread ================ -#ifdef PADDLE_WITH_PSLIB for (size_t k = left; k < right; k++) { +#ifdef PADDLE_WITH_PSLIB + float* val = (float*)(mem_pool->mem_address(k)); float* ptr_val = device_dim_ptrs[k]->data(); size_t dim = device_dim_ptrs[k]->size(); val->delta_score = @@ -757,24 +759,13 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr gpu_task) { val->mf[x] = 0; } } - } #endif #ifdef PADDLE_WITH_PSCORE - for (size_t k = left; k < right; k++) { - void* val = mem_pool->mem_address(k); - // VLOG(5) << "cpu build " << k - // << " cpuptr: " << (uint64_t)(device_dim_ptrs[k]) - // << " |: " << cpu_table_accessor_->ParseToString(ptr_val, - // dim); + float* val = (float*)(mem_pool->mem_address(k)); accessor_wrapper_ptr->BuildFill( val, device_dim_ptrs[k], cpu_table_accessor_, mf_dim); - VLOG(5) << "build " << k << " : " - << accessor_wrapper_ptr->ParseToString( - (float*)(val), - int(accessor_wrapper_ptr->GetFeatureValueSize(mf_dim) / - sizeof(float))); - } #endif + } }; threads.resize(device_num * multi_mf_dim_); @@ -925,9 +916,9 @@ void PSGPUWrapper::EndPass() { std::max(keysize_max, current_task_->device_dim_keys_[i][j].size()); } } + int thread_num = 8; auto accessor_wrapper_ptr = GlobalAccessorTransfor::GetInstance().GetAccessorWrapper(); - int thread_num = 8; auto dump_pool_to_cpu_func = [this, thread_num, &accessor_wrapper_ptr]( int i, int j, int z) { PADDLE_ENFORCE_GPU_SUCCESS(cudaSetDevice(this->resource_->dev_id(i))); @@ -954,7 +945,6 @@ void PSGPUWrapper::EndPass() { VLOG(0) << "dump pool to cpu table: " << i << "with mf dim: " << mf_dim << " key_len :" << len << " feature_value_size:" << feature_value_size; - char* test_build_values = (char*)malloc(feature_value_size * real_len); uint64_t offset = left * feature_value_size; cudaMemcpy(test_build_values, @@ -989,7 +979,6 @@ void PSGPUWrapper::EndPass() { embed_g2sum_index()] = gpu_val->lr_g2sum; cpu_val[paddle::ps::DownpourCtrDymfAccessor::DownpourCtrDymfFeatureValue:: slot_index()] = gpu_val->slot; - if (gpu_val->mf_size > 0) { for (int x = 0; x < gpu_val->mf_dim + 1; x++) { cpu_val[x + 8] = gpu_val->mf[x]; @@ -1030,7 +1019,7 @@ void PSGPUWrapper::EndPass() { gpu_free_channel_->Put(current_task_); timer.Pause(); VLOG(1) << "EndPass end, cost time: " << timer.ElapsedSec() << "s"; -} // namespace framework +} void PSGPUWrapper::PullSparse(const paddle::platform::Place& place, const int table_id, From 85fab755d1c04b955cbcfd41d42a04eaf2ebc5ca Mon Sep 17 00:00:00 2001 From: danleifeng Date: Thu, 7 Jul 2022 13:35:16 +0000 Subject: [PATCH 25/31] format; test=develop --- paddle/fluid/framework/fleet/ps_gpu_wrapper.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index b7099dec25092e..d9bb6e946f42d4 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc 
From 85fab755d1c04b955cbcfd41d42a04eaf2ebc5ca Mon Sep 17 00:00:00 2001
From: danleifeng
Date: Thu, 7 Jul 2022 13:35:16 +0000
Subject: [PATCH 25/31] format; test=develop

---
 paddle/fluid/framework/fleet/ps_gpu_wrapper.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
index b7099dec25092e..d9bb6e946f42d4 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -761,7 +761,7 @@ void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
     }
 #endif
 #ifdef PADDLE_WITH_PSCORE
-      float* val = (float*)(mem_pool->mem_address(k));
+      void* val = mem_pool->mem_address(k);
       accessor_wrapper_ptr->BuildFill(
           val, device_dim_ptrs[k], cpu_table_accessor_, mf_dim);
 #endif
@@ -1130,7 +1130,6 @@ void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
     pull_gpups_timer.Pause();
-#endif
   } else if (platform::is_xpu_place(place)) {
 #ifdef PADDLE_WITH_XPU_KP
     VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
@@ -1290,3 +1289,4 @@ void PSGPUWrapper::PushSparseGrad(const paddle::platform::Place& place,

 }  // end namespace framework
 }  // end namespace paddle
+#endif

From 1e99bbe407e3127ec939728bfa69cfc37d44577e Mon Sep 17 00:00:00 2001
From: danleifeng
Date: Thu, 7 Jul 2022 16:12:26 +0000
Subject: [PATCH 26/31] add ut; test=develop

---
 .../tests/unittests/test_dist_fleet_ps13.py  | 201 ++++++++++++++++++
 tools/parallel_UT_rule.py                    |   3 +-
 2 files changed, 203 insertions(+), 1 deletion(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_dist_fleet_ps13.py

diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps13.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps13.py
new file mode 100644
index 00000000000000..c5ae2365b07cda
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps13.py
@@ -0,0 +1,201 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import os
+
+os.environ["WITH_DISTRIBUTE"] = "ON"
+
+import unittest
+import tempfile
+import shutil
+
+import paddle
+import paddle.fluid as fluid
+import paddle.distributed.fleet.base.role_maker as role_maker
+import paddle.distributed.fleet as fleet
+
+paddle.enable_static()
+
+# For Net
+base_lr = 0.2
+emb_lr = base_lr * 3
+dict_dim = 1500
+emb_dim = 128
+hid_dim = 128
+margin = 0.1
+sample_rate = 1
+batch_size = 4
+
+
+# this unittest is tested for SparseSharedAdamSGDRule
+class TestPSPassWithBow(unittest.TestCase):
+
+    def net(self):
+
+        def get_acc(cos_q_nt, cos_q_pt, batch_size):
+            cond = fluid.layers.less_than(cos_q_nt, cos_q_pt)
+            cond = fluid.layers.cast(cond, dtype='float64')
+            cond_3 = fluid.layers.reduce_sum(cond)
+            acc = fluid.layers.elementwise_div(cond_3,
+                                               fluid.layers.fill_constant(
+                                                   shape=[1],
+                                                   value=batch_size * 1.0,
+                                                   dtype='float64'),
+                                               name="simnet_acc")
+            return acc
+
+        def get_loss(cos_q_pt, cos_q_nt):
+            loss_op1 = fluid.layers.elementwise_sub(
+                fluid.layers.fill_constant_batch_size_like(input=cos_q_pt,
+                                                           shape=[-1, 1],
+                                                           value=margin,
+                                                           dtype='float32'),
+                cos_q_pt)
+            loss_op2 = fluid.layers.elementwise_add(loss_op1, cos_q_nt)
+            loss_op3 = fluid.layers.elementwise_max(
+                fluid.layers.fill_constant_batch_size_like(input=loss_op2,
+                                                           shape=[-1, 1],
+                                                           value=0.0,
+                                                           dtype='float32'),
+                loss_op2)
+            avg_cost = fluid.layers.mean(loss_op3)
+            return avg_cost
+
+        is_distributed = False
+        is_sparse = True
+
+        # query
+        q = fluid.layers.data(name="query_ids",
+                              shape=[1],
+                              dtype="int64",
+                              lod_level=1)
+        # embedding
+        q_emb = fluid.contrib.layers.sparse_embedding(
+            input=q,
+            size=[dict_dim, emb_dim],
+            param_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=0.01),
+                name="__emb__",
+                learning_rate=emb_lr))
+        q_emb = fluid.layers.reshape(q_emb, [-1, emb_dim])
+        # vsum
+        q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
+        q_ss = fluid.layers.softsign(q_sum)
+        # fc layer after conv
+        q_fc = fluid.layers.fc(
+            input=q_ss,
+            size=hid_dim,
+            param_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=0.01),
+                name="__q_fc__",
+                learning_rate=base_lr))
+        # label data
+        label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+        # pt
+        pt = fluid.layers.data(name="pos_title_ids",
+                               shape=[1],
+                               dtype="int64",
+                               lod_level=1)
+        # embedding
+        pt_emb = fluid.contrib.layers.sparse_embedding(
+            input=pt,
+            size=[dict_dim, emb_dim],
+            param_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=0.01),
+                name="__emb__",
+                learning_rate=emb_lr))
+        pt_emb = fluid.layers.reshape(pt_emb, [-1, emb_dim])
+        # vsum
+        pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum')
+        pt_ss = fluid.layers.softsign(pt_sum)
+        # fc layer
+        pt_fc = fluid.layers.fc(
+            input=pt_ss,
+            size=hid_dim,
+            param_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=0.01),
+                name="__fc__",
+                learning_rate=base_lr),
+            bias_attr=fluid.ParamAttr(name="__fc_b__"))
+        # nt
+        nt = fluid.layers.data(name="neg_title_ids",
+                               shape=[1],
+                               dtype="int64",
+                               lod_level=1)
+        # embedding
+        nt_emb = fluid.contrib.layers.sparse_embedding(
+            input=nt,
+            size=[dict_dim, emb_dim],
+            param_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=0.01),
+                name="__emb__",
+                learning_rate=emb_lr))
+        nt_emb = fluid.layers.reshape(nt_emb, [-1, emb_dim])
+        # vsum
+        nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum')
+        nt_ss = fluid.layers.softsign(nt_sum)
+        # fc layer
+        nt_fc = fluid.layers.fc(
+            input=nt_ss,
+            size=hid_dim,
+            param_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=0.01),
+                name="__fc__",
+                learning_rate=base_lr),
+            bias_attr=fluid.ParamAttr(name="__fc_b__"))
+        cos_q_pt = fluid.layers.cos_sim(q_fc, pt_fc)
+        cos_q_nt = fluid.layers.cos_sim(q_fc, nt_fc)
+        # loss
+        avg_cost = get_loss(cos_q_pt, cos_q_nt)
+        # acc
+        acc = get_acc(cos_q_nt, cos_q_pt, batch_size)
+        return [avg_cost, acc, cos_q_pt]
+
+    def test(self):
+        os.environ["PADDLE_PSERVER_NUMS"] = "2"
+        os.environ["PADDLE_TRAINERS_NUM"] = "2"
+        os.environ["POD_IP"] = "127.0.0.1"
+        os.environ["PADDLE_PORT"] = "36001"
+        os.environ["PADDLE_TRAINER_ID"] = "0"
+        os.environ["PADDLE_TRAINERS_NUM"] = "2"
+        os.environ[
+            "PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001"
+        os.environ["TRAINING_ROLE"] = "PSERVER"
+
+        role = role_maker.PaddleCloudRoleMaker()
+        fleet.init(role)
+        loss, acc, _ = self.net()
+
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.a_sync = True
+
+        configs = {}
+        configs['__emb__'] = {
+            "table_parameters.__emb__.accessor.embed_sgd_param.name":
+            "SparseSharedAdamSGDRule",
+            "table_parameters.__emb__.accessor.embedx_sgd_param.name":
+            "SparseSharedAdamSGDRule",
+        }
+        strategy.sparse_table_configs = configs
+        optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+        optimizer.minimize(loss)
+
+        fleet.init_server()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py
index 559f2d95b915f6..53ab93f57ce567 100755
--- a/tools/parallel_UT_rule.py
+++ b/tools/parallel_UT_rule.py
@@ -671,7 +671,8 @@
     'test_trt_convert_reduce_sum',
     'save_quant2_model_lstm',
     'test_trt_convert_slice',
-    'test_quant2_int8_lstm_mkldnn'
+    'test_quant2_int8_lstm_mkldnn',
+    'test_dist_fleet_ps13'
 ]

 # mem=0 but always timeout or failed : It run 15 job each time in Single cases;
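The new unit test drives SparseSharedAdamSGDRule purely through strategy configuration; stripped of the network definition, the wiring is the snippet below, condensed from test_dist_fleet_ps13.py above (the table key '__emb__' is the ParamAttr name given to the sparse embeddings in the test):

import paddle

strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
# Per-table accessor override: both the embed and embedx parts of the
# "__emb__" sparse table use the shared-Adam SGD rule.
strategy.sparse_table_configs = {
    '__emb__': {
        "table_parameters.__emb__.accessor.embed_sgd_param.name":
        "SparseSharedAdamSGDRule",
        "table_parameters.__emb__.accessor.embedx_sgd_param.name":
        "SparseSharedAdamSGDRule",
    }
}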
From f6eb220592bd69396d6e9eaead0e36d5549070b5 Mon Sep 17 00:00:00 2001
From: danleifeng
Date: Tue, 19 Jul 2022 03:19:03 +0000
Subject: [PATCH 27/31] add ut; test=develop

---
 .../tests/unittests/test_fleet_distributed_strategy.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
index 455a7a30cfd185..7834c3afb0361a 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
@@ -334,6 +334,13 @@ def test_fleet_desc_configs(self):
             strategy.sparse_table_configs[0].accessor.embed_sgd_param.adagrad.
             initial_range, 0.0001)

+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        configs = {}
+        configs['emb'] = {"sparse_optimizer": "shared_adam"}
+        strategy.fleet_desc_configs = configs
+        self.assertEqual(strategy.sparse_table_configs[0]
+                         .accessor.embed_sgd_param.adam.beta1_decay_rate, 0.9)
+
     def test_trainer_desc_configs(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
         configs = {

From e49041c2ff85bd12fe45a5d51cfa1f02d6c35e9d Mon Sep 17 00:00:00 2001
From: danleifeng
Date: Tue, 19 Jul 2022 06:53:15 +0000
Subject: [PATCH 28/31] add ut; test=develop

---
 .../fluid/tests/unittests/test_fleet_distributed_strategy.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
index 7834c3afb0361a..9ac88c802111f1 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
@@ -338,8 +338,9 @@ def test_fleet_desc_configs(self):
         configs = {}
         configs['emb'] = {"sparse_optimizer": "shared_adam"}
         strategy.fleet_desc_configs = configs
-        self.assertEqual(strategy.sparse_table_configs[0]
-                         .accessor.embed_sgd_param.adam.beta1_decay_rate, 0.9)
+        self.assertEqual(
+            strategy.sparse_table_configs[0].accessor.embed_sgd_param.adam.
+            beta1_decay_rate, 0.9)

     def test_trainer_desc_configs(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
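The strategy test extended in the two patches above exercises a second configuration surface: fleet_desc_configs maps a per-embedding "sparse_optimizer" choice onto the accessor's SGD parameters, and "shared_adam" fills in the shared-Adam defaults, which the test pins through beta1_decay_rate. A minimal sketch of that round trip, assuming only the DistributedStrategy behavior the test itself asserts:

import paddle

strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.fleet_desc_configs = {'emb': {"sparse_optimizer": "shared_adam"}}
# Reading the generated table config back exposes the filled-in defaults.
adam = strategy.sparse_table_configs[0].accessor.embed_sgd_param.adam
print(adam.beta1_decay_rate)  # 0.9, as asserted in the test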
From f90ea1564daf72f6673284089f0aa1c3e2fe6f2d Mon Sep 17 00:00:00 2001
From: danleifeng
Date: Tue, 19 Jul 2022 08:03:44 +0000
Subject: [PATCH 29/31] change cmakelist; test=develop

---
 paddle/fluid/framework/fleet/CMakeLists.txt | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/framework/fleet/CMakeLists.txt b/paddle/fluid/framework/fleet/CMakeLists.txt
index 42235b7c484e34..4cf3ab8dc1a67d 100644
--- a/paddle/fluid/framework/fleet/CMakeLists.txt
+++ b/paddle/fluid/framework/fleet/CMakeLists.txt
@@ -25,10 +25,17 @@ endif()

 if(WITH_HETERPS)
   if(WITH_NCCL AND WITH_GPU)
-    nv_library(
-      ps_gpu_wrapper
-      SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
-      DEPS heter_ps gloo_wrapper ${BRPC_DEPS})
+    if(WITH_PSCORE)
+      nv_library(
+        ps_gpu_wrapper
+        SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
+        DEPS heter_ps gloo_wrapper ps_framework_proto ${BRPC_DEPS})
+    else()
+      nv_library(
+        ps_gpu_wrapper
+        SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
+        DEPS heter_ps gloo_wrapper ${BRPC_DEPS})
+    endif()
     add_subdirectory(heter_ps)
   elseif(WITH_XPU_KP)
     xpu_library(

From be6a31e2b3e46878e7cf39c32973a4f9a578d461 Mon Sep 17 00:00:00 2001
From: danleifeng
Date: Tue, 19 Jul 2022 09:15:30 +0000
Subject: [PATCH 30/31] change cmakelist; test=develop

---
 paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
index 42fbf8d3a19c14..0d1669a42b1e9f 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -51,10 +51,10 @@ limitations under the License. */
 #include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
 #include "paddle/fluid/platform/place.h"
 #ifdef PADDLE_WITH_PSCORE
-#include "paddle/fluid/distributed/ps.pb.h"
 #include "paddle/fluid/distributed/ps/table/accessor.h"
 #include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h"
 #include "paddle/fluid/distributed/ps/wrapper/fleet.h"
+#include "paddle/fluid/distributed/the_one_ps.pb.h"
 #endif
 #ifdef PADDLE_WITH_PSLIB
 #include "afs_api.h"
@@ -416,7 +416,7 @@ class PSGPUWrapper {
     // set optimizer type(naive,adagrad,std_adagrad,adam,share_adam)
     optimizer_type_ = (config.find("optimizer_type") == config.end())
                           ? 1
-                          : int(config["optimizer_type"]);
+                          : static_cast<int>(config["optimizer_type"]);
   }

   void SetDate(int year, int month, int day) {

From b82f4c8f70929ae68d1ead49150c8623eb152a83 Mon Sep 17 00:00:00 2001
From: danleifeng
Date: Tue, 19 Jul 2022 13:13:30 +0000
Subject: [PATCH 31/31] change cmakelist; test=develop

---
 paddle/fluid/distributed/ps/wrapper/CMakeLists.txt | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/paddle/fluid/distributed/ps/wrapper/CMakeLists.txt b/paddle/fluid/distributed/ps/wrapper/CMakeLists.txt
index 6abd68e5d0aa9a..c9cd883dabb69a 100644
--- a/paddle/fluid/distributed/ps/wrapper/CMakeLists.txt
+++ b/paddle/fluid/distributed/ps/wrapper/CMakeLists.txt
@@ -1,12 +1,5 @@
 get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)

-set(DISTRIBUTE_COMPILE_FLAGS
-    "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses"
-)
-if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
-  set(DISTRIBUTE_COMPILE_FLAGS "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
-endif()
-
 set_source_files_properties(fleet.cc PROPERTIES COMPILE_FLAGS
                                                 ${DISTRIBUTE_COMPILE_FLAGS})
 cc_library(