[fleet_executor] Dist model run method Implementation #39194
Merged
Changes from all commits (23 commits, all by FeixLiu):

f6a10e1  run method for dist model
1779f23  bug fix
e9be30c  bug fix
a510c7b  fix the init value of memory hold
d17c195  formate refine
0ae9a26  bug fix for pp comm init
2b36ea7  bug fix
6513ca6  add feed var dtype check
9a14a81  bug fix
6f65dbd  bug fix and update log
884a8db  add more error log
925d350  add pp stage checker and feed fetch vars checkerw
9cce72c  prune logic branch
139a28f  add timer
e5d607e  add ut for thw whole work flow
c4e30da  update ut
312f679  update license
b9fa0bf  bug fix for ut
11deb17  update ut
eb8f265  rename file
dcdd46f  minor fix
de4c1cf  for cla
2765213  add compile flag
@@ -13,11 +13,13 @@
// limitations under the License.

#include <glog/logging.h>
#include <chrono>  // NOLINT

#include "paddle/fluid/distributed/fleet_executor/dist_model.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h"
@@ -37,10 +39,110 @@ bool IsPersistable(const framework::VarDesc *var) {
  }
  return false;
}

bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
                                 framework::LoDTensor *input_tensor,
                                 const platform::Place &place) {
  VLOG(3) << "Loading data from DistModelTensor for " << input_data.name;
  framework::DDim dims = framework::make_ddim(input_data.shape);
  void *input_tensor_ptr;
  if (input_data.dtype == DistModelDataType::INT64) {
    input_tensor_ptr = input_tensor->mutable_data<int64_t>(dims, place);
  } else if (input_data.dtype == DistModelDataType::FLOAT32) {
    input_tensor_ptr = input_tensor->mutable_data<float>(dims, place);
  } else if (input_data.dtype == DistModelDataType::INT32) {
    input_tensor_ptr = input_tensor->mutable_data<int32_t>(dims, place);
  } else {
    // Q(fleet exe dev): for input/output, should we support fp16
    LOG(ERROR) << "unsupported feed type " << input_data.dtype;
    return false;
  }

  PADDLE_ENFORCE_NOT_NULL(
      input_tensor_ptr,
      paddle::platform::errors::Fatal(
          "LoDTensor creation failed. DistModel loaded data failed."));
  PADDLE_ENFORCE_NOT_NULL(input_data.data.data(),
                          paddle::platform::errors::InvalidArgument(
                              "DistModelTensor contains no data."));

  if (platform::is_cpu_place(place)) {
    VLOG(3) << "Loading data for CPU.";
    std::memcpy(static_cast<void *>(input_tensor_ptr), input_data.data.data(),
                input_data.data.length());
  } else if (platform::is_gpu_place(place)) {
    VLOG(3) << "Loading data for GPU.";
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    auto *dev_ctx =
        dynamic_cast<const platform::CUDADeviceContext *>(pool.Get(place));
    auto gpu_place = place;
    memory::Copy(gpu_place, static_cast<void *>(input_tensor_ptr),
                 platform::CPUPlace(), input_data.data.data(),
                 input_data.data.length(), dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "Paddle wasn't compiled with CUDA, but place is GPU."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "DistModel only supports CPU and GPU."));
  }

  framework::LoD dst_lod;
  for (auto &src_lod : input_data.lod) {
    dst_lod.emplace_back(src_lod);
  }
  input_tensor->set_lod(dst_lod);
  return true;
}
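For orientation, here is a minimal sketch of how a caller might populate the DistModelTensor that this loader consumes. The field names (name, shape, dtype, data, lod) are taken from their uses in this diff; the buffer-copy details are assumptions for illustration, not the library's documented API.

// Sketch: building a 2x3 float32 input on the host (field names as used above).
std::vector<float> host = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
DistModelTensor input;
input.name = "x";                          // must match a feed var in the program
input.shape = {2, 3};                      // becomes the LoDTensor's DDim
input.dtype = DistModelDataType::FLOAT32;  // must match the feed var's dtype
input.data.Resize(host.size() * sizeof(float));
std::memcpy(input.data.data(), host.data(), host.size() * sizeof(float));
// input.lod stays empty for a plain (non-LoD) tensor.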

std::string DistModelDTypeToString(DistModelDataType dtype) {
  switch (dtype) {
    case DistModelDataType::FLOAT32:
      return "float32";
    case DistModelDataType::FLOAT16:
      return "float16";
    case DistModelDataType::INT64:
      return "int64";
    case DistModelDataType::INT32:
      return "int32";
    case DistModelDataType::INT8:
      return "int8";
  }
  return "NOT SUPPORT DTYPE";
}

bool IsPPFirstStage(const DistModelConfig &config) {
  return config.local_rank - config.mp_degree < 0;
}

bool IsPPLastStage(const DistModelConfig &config) {
  return config.local_rank + config.mp_degree >= config.nranks;
}
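These predicates assume a stage-major rank layout: the first pipeline stage owns ranks [0, mp_degree) and the last owns [nranks - mp_degree, nranks). A self-contained sketch of the arithmetic, using example values (nranks = 8, mp_degree = 2) that are assumptions for illustration:

#include <cassert>
struct Cfg { int local_rank, mp_degree, nranks; };
bool First(const Cfg &c) { return c.local_rank - c.mp_degree < 0; }
bool Last(const Cfg &c) { return c.local_rank + c.mp_degree >= c.nranks; }
int main() {
  // With 8 ranks and mp_degree 2: ranks 0-1 form the first pp stage,
  // ranks 6-7 form the last pp stage.
  assert(First({0, 2, 8}) && First({1, 2, 8}) && !First({2, 2, 8}));
  assert(Last({6, 2, 8}) && Last({7, 2, 8}) && !Last({5, 2, 8}));
  return 0;
}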

class DistModelTimer {
 public:
  void tic() { tic_time = std::chrono::high_resolution_clock::now(); }
  double toc() {
    std::chrono::high_resolution_clock::time_point toc_time =
        std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> time_elapse =
        std::chrono::duration_cast<std::chrono::duration<double>>(toc_time -
                                                                  tic_time);
    double time_elapse_in_ms =
        static_cast<double>(time_elapse.count()) * 1000.0;
    return time_elapse_in_ms;
  }

 private:
  std::chrono::high_resolution_clock::time_point tic_time;
};
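Note that toc() never resets tic_time: every call returns the cumulative milliseconds since the single tic(). That is why Run() further down reports per-phase costs as differences (fleet_exe_elapse - feed_elapse, and so on). A minimal usage sketch:

DistModelTimer timer;
timer.tic();
// ... phase 1 ...
double t1 = timer.toc();     // ms elapsed since tic()
// ... phase 2 ...
double t2 = timer.toc();     // still measured from the same tic()
double phase2_ms = t2 - t1;  // per-phase cost is a difference of toc() readings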

}  // namespace

bool DistModel::Init() {
  /* TODO(fleet exe dev): implement this funct */
  carrier_id_ = "inference";
  bool init_method = (!config_.model_dir.empty() || config_.program_desc);
  PADDLE_ENFORCE_EQ(init_method, true,
                    platform::errors::InvalidArgument(
@@ -127,10 +229,9 @@ bool DistModel::CommInit() {
    InsertCommOp("mp_comm_id", mp_group_nranks, mp_group_rank, peer_endpoints,
                 comm_init_block, config_.mp_ring_id);
  }
-  if (config_.pp_degree) {
-    // NOTE: the last pp stage doesn't need init pp comm
+  if (config_.pp_degree > 1) {
    VLOG(3) << "Init comm group for pp.";
-    if (config_.local_rank - config_.mp_degree >= 0) {
+    if (!IsPPFirstStage(config_)) {
      PADDLE_ENFORCE_EQ(config_.pp_upstream_ring_id >= 0, true,
                        platform::errors::InvalidArgument(
                            "pp upstream ring id must be provided for "

@@ -143,7 +244,7 @@ bool DistModel::CommInit() {
                   comm_init_block, config_.pp_upstream_ring_id);
    }

-    if (config_.local_rank + config_.mp_degree < config_.nranks) {
+    if (!IsPPLastStage(config_)) {
      PADDLE_ENFORCE_EQ(config_.pp_downstream_ring_id >= 0, true,
                        platform::errors::InvalidArgument(
                            "pp downstream ring id must be provided for "

@@ -326,7 +427,7 @@ bool DistModel::PrepareFleetExe() {
    id_to_rank.insert({i, i});
  }
  fleet_exe.reset(new FleetExecutor(executor_desc_));
-  fleet_exe->Init("inference", *(program_.get()), scope_.get(), place_, 1,
+  fleet_exe->Init(carrier_id_, *(program_.get()), scope_.get(), place_, 1,
                   {task_node_.get()}, id_to_rank);
  return true;
}

@@ -340,24 +441,194 @@ bool DistModel::PrepareFeedAndFetch() {
        feeds_.resize(idx + 1);
      }
      feeds_[idx] = op;
-      feed_names_[op->Output("Out")[0]] = idx;
-      idx_to_feeds_[idx] = op->Output("Out")[0];
+      std::string var_name = op->Output("Out")[0];
+      feed_names_[var_name] = idx;
+      idx_to_feeds_[idx] = var_name;
      framework::VarDesc *real_var = program_->Block(0).FindVar(var_name);
      if (!real_var) {
        LOG(ERROR)
            << "The output of feed ops [" << var_name
            << "] cannot be found in the program. Check the inference program.";
        return false;
      }
      if (real_var->GetDataType() == framework::proto::VarType::FP32) {
        feeds_to_dtype_.insert({var_name, DistModelDataType::FLOAT32});
      } else if (real_var->GetDataType() == framework::proto::VarType::INT32) {
        feeds_to_dtype_.insert({var_name, DistModelDataType::INT32});
      } else if (real_var->GetDataType() == framework::proto::VarType::INT64) {
        feeds_to_dtype_.insert({var_name, DistModelDataType::INT64});
      } else {
        LOG(ERROR) << "Don't support feed var dtype for: "
                   << real_var->GetDataType();
        return false;
      }
    } else if (op->Type() == "fetch") {
      VLOG(3) << "fetch op with fetch var: " << op->Input("X")[0];
      int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
      if (fetches_.size() <= static_cast<size_t>(idx)) {
        fetches_.resize(idx + 1);
      }
      fetches_[idx] = op;
-      id_to_fetches_[idx] = op->Input("X")[0];
+      idx_to_fetches_[idx] = op->Input("X")[0];
    }
  }

  if (config_.pp_degree == 1) {
    if (feeds_.size() == 0) {
      LOG(ERROR) << "No feed ops in the inf program, please check the program.";
      return false;
    }
    if (fetches_.size() == 0) {
      LOG(ERROR) << "No fetch op in the inf program, please check the program.";
      return false;
    }
  } else {
    if (IsPPFirstStage(config_)) {
      if (feeds_.size() == 0) {
        LOG(ERROR) << "Feed ops are needed for the first pp stage.";
        return false;
      } else {

[Inline review comment] Are you sure this else is not a mistake?

        LOG(WARNING) << "No feed ops in non-first pp stage.";
      }
    } else if (feeds_.size() > 0) {
      LOG(WARNING) << "Feed op is found in the non-first stage of pp.";
    }
    if (IsPPLastStage(config_)) {
      if (fetches_.size() == 0) {
        LOG(ERROR) << "Fetch op is needed for the last pp stage.";
        return false;
      } else {

[Inline review comment] Same question for this else.

        LOG(WARNING) << "No fetch op in non-last pp stage.";
      }
    } else if (fetches_.size() > 0) {
      LOG(WARNING) << "Fetch op is found in the non-last stage of pp.";
    }
  }
  return true;
}
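Summarizing the branches above, the feed/fetch layout this check accepts per pipeline stage is:

// pp_degree == 1 : feed and fetch ops are both required.
// pp_degree > 1  :
//   first pp stage : feed ops required (error if missing);
//   last pp stage  : fetch ops required (error if missing);
//   other stages   : stray feed/fetch ops only produce warnings.

Note that the warnings inside the two else branches fire when the required ops are present, yet their messages refer to the wrong stage; that mismatch is what the inline review comments flag.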

bool DistModel::FeedData(const std::vector<DistModelTensor> &input_data,
                         framework::Scope *scope) {
  VLOG(3) << "DistModel is feeding data.";
  if (input_data.size() != feeds_.size()) {
    LOG(ERROR) << "Should provide " << feeds_.size() << " feeds, but got "
               << input_data.size() << " data.";
    return false;
  }
  feed_tensors_.resize(feeds_.size());
  for (size_t i = 0; i < input_data.size(); ++i) {
    // feed each data separately
    framework::LoDTensor *input_tensor = &(feed_tensors_[i]);
    if (!LoadDataFromDistModelTensor(input_data[i], input_tensor, place_)) {
      LOG(ERROR) << "Fail to load data from tensor " << input_data[i].name;
      return false;
    }
    std::string target_name = input_data[i].name;
    if (feed_names_.find(target_name) == feed_names_.end()) {
      LOG(ERROR) << "The input name [" << target_name
                 << "] cannot be found in the program."
                 << " DistModel loads data failed.";
      return false;
    }
    if (input_data[i].dtype != feeds_to_dtype_[target_name]) {
      LOG(ERROR) << "Feed var [" << target_name << "] expected dtype is: "
                 << DistModelDTypeToString(feeds_to_dtype_[target_name])
                 << ". But received dtype is: "
                 << DistModelDTypeToString(input_data[i].dtype) << ".";
      return false;
    }
    int feed_idx = feed_names_[target_name];
    framework::SetFeedVariable(scope, *input_tensor, "feed", feed_idx);
  }
  return true;
}

bool DistModel::FetchResults(std::vector<DistModelTensor> *output_data,
                             framework::Scope *scope) {
  VLOG(3) << "DistModel is fetching results.";
  output_data->resize(fetches_.size());
  for (size_t i = 0; i < fetches_.size(); ++i) {
    int idx = BOOST_GET_CONST(int, fetches_[i]->GetAttr("col"));
    VLOG(3) << "Fetching data for [" << idx_to_fetches_[idx] << "]";
    PADDLE_ENFORCE_EQ(
        static_cast<size_t>(idx), i,
        platform::errors::InvalidArgument(
            "Fetch op's col attr(%d) should be equal to the index(%d)", idx,
            i));
    framework::FetchType &fetch_var =
        framework::GetFetchVariable(*scope, "fetch", idx);
    auto &fetch = BOOST_GET(framework::LoDTensor, fetch_var);
    auto type = fetch.type();
    auto output = &(output_data->at(i));
    output->name = idx_to_fetches_[idx];
    bool rst = false;
    if (type == framework::proto::VarType::FP32) {
      rst = FetchResult<float>(fetch, output);
      output->dtype = DistModelDataType::FLOAT32;
    } else if (type == framework::proto::VarType::INT64) {
      rst = FetchResult<int64_t>(fetch, output);
      output->dtype = DistModelDataType::INT64;
    } else if (type == framework::proto::VarType::INT32) {
      rst = FetchResult<int32_t>(fetch, output);
      output->dtype = DistModelDataType::INT32;
    } else {
      LOG(ERROR) << "DistModel meets unknown fetch data type. DistModel only "
                    "supports float32, int64 and int32 fetch type for now.";
    }
    if (!rst) {
      LOG(ERROR) << "DistModel fails to fetch result " << idx_to_fetches_[idx];
      return false;
    }
  }
  return true;
}

-void DistModel::Run(const std::vector<DistModelTensor> &input_data,
template <typename T>
bool DistModel::FetchResult(const framework::LoDTensor &fetch,
                            DistModelTensor *output_data) {
  auto shape = framework::vectorize(fetch.dims());
  output_data->shape.assign(shape.begin(), shape.end());
  const T *data = fetch.data<T>();
  int64_t num_elems = fetch.numel();
  output_data->data.Resize(num_elems * sizeof(T));
  // The output of fetch op is always on the cpu, no need switch on place
  memcpy(output_data->data.data(), data, num_elems * sizeof(T));
  output_data->lod.clear();
  for (auto &level : fetch.lod()) {
    output_data->lod.emplace_back(level.begin(), level.end());
  }
  return true;
}
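Mirroring this copy-out, a caller can read a float32 result back from the returned DistModelTensor. A sketch, assuming data.data() exposes the raw buffer and data.length() its byte size (both appear elsewhere in this diff):

const DistModelTensor &out = outputs[0];  // as filled by FetchResults above
if (out.dtype == DistModelDataType::FLOAT32) {
  const float *vals = static_cast<const float *>(out.data.data());
  size_t num = out.data.length() / sizeof(float);
  for (size_t j = 0; j < num; ++j) {
    // consume vals[j]; out.shape holds the dims, out.lod the LoD levels
  }
}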

bool DistModel::Run(const std::vector<DistModelTensor> &input_data,
                    std::vector<DistModelTensor> *output_data) {
-  /* TODO(fleet exe dev): implement this funct */
+  // TODO(fleet exe dev): support pipeline inf mode
  VLOG(3) << "DistModel run for once.";

  DistModelTimer timer;
  timer.tic();

  if (!FeedData(input_data, scope_.get())) {
    LOG(ERROR) << "DistModel failed at feeding data.";
    return false;
  }
  double feed_elapse = timer.toc();
  VLOG(3) << "Finish loading data, cost " << feed_elapse << "ms.";

  fleet_exe->Run(carrier_id_);
  double fleet_exe_elapse = timer.toc();
  VLOG(3) << "Finish FleetExe running, cost " << fleet_exe_elapse - feed_elapse
          << "ms.";

  if (!FetchResults(output_data, scope_.get())) {
    LOG(ERROR) << "DistModel failed at fetching result.";
    return false;
  }
  double fetch_elapse = timer.toc();
  VLOG(3) << "Finish fetching data, cost " << fetch_elapse - fleet_exe_elapse
          << "ms.";
  VLOG(3) << "DistModel finish inf, cost " << fetch_elapse << "ms";
  return true;
}
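Putting it together, an end-to-end caller would look roughly like the sketch below. Only Init() and Run() appear in this diff; the constructor and any config fields beyond model_dir / program_desc are assumptions for illustration.

DistModelConfig config;
config.model_dir = "/path/to/inference/model";  // or provide config.program_desc
DistModel model(config);  // hypothetical ctor; the diff only shows Init()/Run()
if (!model.Init()) {
  LOG(ERROR) << "DistModel init failed.";
}

std::vector<DistModelTensor> inputs{/* built as in the feed sketch above */};
std::vector<DistModelTensor> outputs;
if (!model.Run(inputs, &outputs)) {
  LOG(ERROR) << "DistModel run failed.";
}
// Each outputs[i] carries the name, dtype, shape, data and lod set by FetchResults.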

}  // namespace distributed
Review thread:

[Review comment] Is the fetch really necessary?

[Reply] If we don't fetch, how do we get the results out? Or do you mean that, later on, when the first carrier sends back its message, the fetch would happen on the first pp stage?

[Reply] Don't retrieve the results, just compute.