Skip to content

Commit

Permalink
[PIR] normalize the use of value. (#55322)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored Sep 17, 2023
1 parent b0b71d2 commit 50669e0
Show file tree
Hide file tree
Showing 19 changed files with 220 additions and 209 deletions.
8 changes: 3 additions & 5 deletions paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ std::vector<ir::Tensor> CollectInputTensor(
std::vector<ir::Tensor>* func_args,
std::unordered_map<::pir::Value, ir::Tensor>* tensor_map) {
std::vector<ir::Tensor> tensors;
for (auto& operand : op->operands()) {
CHECK(operand);
auto in_value = operand.source();
for (auto in_value : op->operands_source()) {
VLOG(4) << "input tensor name: " << CompatibleInfo::ValueName(in_value);
// NOTE(Aurelius84): Need always to create placeholder for input tensor.
ir::Tensor tensor = details::GetTensor(in_value);
Expand All @@ -72,7 +70,7 @@ std::vector<ir::Tensor> CollectInputTensor(
return tensors;
}

void CollectOutputInfo(const ::pir::Operation* op,
void CollectOutputInfo(::pir::Operation* op,
std::vector<Type>* out_types,
std::vector<std::vector<int>>* out_shapes) {
auto op_results = op->results();
Expand Down Expand Up @@ -359,7 +357,7 @@ std::vector<ir::Expr> OpLowererImpl::LowerOps(

std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
const ::pir::Operation* op,
::pir::Operation* op,
std::unordered_map<::pir::Value, ir::Tensor>* tensor_map,
std::vector<ir::Tensor>* op_func_arg_tensors) {
VLOG(4) << "Do lower with Compute, op: " << op->name();
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
*/
std::vector<ir::LoweredFunc> DoOpLower(
std::shared_ptr<hlir::framework::OpImpl> op_impl,
const ::pir::Operation* op,
::pir::Operation* op,
std::unordered_map<::pir::Value, ir::Tensor>* tensor_map,
std::vector<ir::Tensor>* op_func_arg_tensors);

Expand Down
3 changes: 1 addition & 2 deletions paddle/cinn/hlir/framework/new_ir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ std::vector<std::string> CompatibleInfo::InputNames(const ::pir::Operation& op,
return names;
}

std::vector<std::string> CompatibleInfo::OutputNames(
const ::pir::Operation& op) {
std::vector<std::string> CompatibleInfo::OutputNames(::pir::Operation& op) {
std::vector<std::string> names;
for (int i = 0; i < op.num_results(); ++i) {
auto value = op.result(i);
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/framework/new_ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct CompatibleInfo {
static std::vector<std::string> InputNames(const ::pir::Operation& op,
bool allow_duplicate = false);

static std::vector<std::string> OutputNames(const ::pir::Operation& op);
static std::vector<std::string> OutputNames(::pir::Operation& op); // NOLINT
};

} // namespace newir
Expand Down
5 changes: 1 addition & 4 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,7 @@ void BindOperation(py::module *m) {
)DOC");
op.def("name", &Operation::name)
.def("get_parent_block",
py::overload_cast<>(&Operation::GetParent),
return_value_policy::reference)
.def("get_parent_block",
py::overload_cast<>(&Operation::GetParent, py::const_),
&Operation::GetParent,
return_value_policy::reference)
.def("num_operands", &Operation::num_operands)
.def("num_results", &Operation::num_results)
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ void Block::AddArgument(Type type) {

bool Block::TopoOrderCheck(const OpListType &op_list) {
std::unordered_set<Value> visited_values;
for (const Operation *op : op_list) {
for (Operation *op : op_list) {
if (op->num_operands() > 0) {
for (size_t i = 0; i < op->num_operands(); ++i) {
auto operand = op->operand_source(i);
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/block_argument.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ namespace detail {
class BlockArgumentImpl : public ValueImpl {
public:
static bool classof(const ValueImpl &value) {
return value.kind() == BLOCK_ARGUMENT_INDEX;
return value.kind() == BLOCK_ARG_IDX;
}

private:
BlockArgumentImpl(Type type, Block *owner, uint32_t index)
: ValueImpl(type, BLOCK_ARGUMENT_INDEX), owner_(owner), index_(index) {}
: ValueImpl(type, BLOCK_ARG_IDX), owner_(owner), index_(index) {}

~BlockArgumentImpl();
// access construction and owner
Expand Down
16 changes: 8 additions & 8 deletions paddle/pir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ void IrPrinter::PrintOperation(Operation* op) {
PrintGeneralOperation(op);
}

void IrPrinter::PrintGeneralOperation(const Operation* op) {
void IrPrinter::PrintGeneralOperation(Operation* op) {
// TODO(lyk): add API to get opresults directly
PrintOpResult(op);
os << " =";
Expand All @@ -160,7 +160,7 @@ void IrPrinter::PrintGeneralOperation(const Operation* op) {
PrintOpReturnType(op);
}

void IrPrinter::PrintFullOperation(const Operation* op) {
void IrPrinter::PrintFullOperation(Operation* op) {
PrintGeneralOperation(op);
if (op->num_regions() > 0) {
os << newline;
Expand All @@ -186,7 +186,7 @@ void IrPrinter::PrintBlock(const Block* block) {
os << "}\n";
}

void IrPrinter::PrintValue(const Value& v) {
void IrPrinter::PrintValue(Value v) {
if (!v) {
os << "<<NULL VALUE>>";
return;
Expand All @@ -204,7 +204,7 @@ void IrPrinter::PrintValue(const Value& v) {
os << new_name;
}

void IrPrinter::PrintOpResult(const Operation* op) {
void IrPrinter::PrintOpResult(Operation* op) {
os << " (";
auto num_op_result = op->num_results();
std::vector<OpResult> op_results;
Expand All @@ -220,7 +220,7 @@ void IrPrinter::PrintOpResult(const Operation* op) {
os << ")";
}

void IrPrinter::PrintAttributeMap(const Operation* op) {
void IrPrinter::PrintAttributeMap(Operation* op) {
AttributeMap attributes = op->attributes();
std::map<std::string, Attribute, std::less<std::string>> order_attributes(
attributes.begin(), attributes.end());
Expand All @@ -239,7 +239,7 @@ void IrPrinter::PrintAttributeMap(const Operation* op) {
os << "}";
}

void IrPrinter::PrintOpOperands(const Operation* op) {
void IrPrinter::PrintOpOperands(Operation* op) {
os << " (";
auto num_op_operands = op->num_operands();
std::vector<Value> op_operands;
Expand All @@ -255,7 +255,7 @@ void IrPrinter::PrintOpOperands(const Operation* op) {
os << ")";
}

void IrPrinter::PrintOperandsType(const Operation* op) {
void IrPrinter::PrintOperandsType(Operation* op) {
auto num_op_operands = op->num_operands();
std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands);
Expand All @@ -276,7 +276,7 @@ void IrPrinter::PrintOperandsType(const Operation* op) {
os << ")";
}

void IrPrinter::PrintOpReturnType(const Operation* op) {
void IrPrinter::PrintOpReturnType(Operation* op) {
auto num_op_result = op->num_results();
std::vector<Type> op_result_types;
op_result_types.reserve(num_op_result);
Expand Down
16 changes: 8 additions & 8 deletions paddle/pir/core/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,24 @@ class IR_API IrPrinter : public BasicIrPrinter {
/// @brief dispatch to custom printer function or PrintGeneralOperation
void PrintOperation(Operation* op);
/// @brief print operation itself without its regions
void PrintGeneralOperation(const Operation* op);
void PrintGeneralOperation(Operation* op);
/// @brief print operation and its regions
void PrintFullOperation(const Operation* op);
void PrintFullOperation(Operation* op);

void PrintRegion(const Region& Region);
void PrintBlock(const Block* block);

void PrintValue(const Value& v);
void PrintValue(Value v);

void PrintOpResult(const Operation* op);
void PrintOpResult(Operation* op);

void PrintAttributeMap(const Operation* op);
void PrintAttributeMap(Operation* op);

void PrintOpOperands(const Operation* op);
void PrintOpOperands(Operation* op);

void PrintOperandsType(const Operation* op);
void PrintOperandsType(Operation* op);

void PrintOpReturnType(const Operation* op);
void PrintOpReturnType(Operation* op);

private:
size_t cur_var_number_{0};
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/op_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ bool OpResult::operator==(const OpResult &other) const {
return impl_ == other.impl_;
}

OpResult::OpResult(const detail::OpResultImpl *impl) : Value(impl) {}
OpResult::OpResult(detail::OpResultImpl *impl) : Value(impl) {}

} // namespace pir
2 changes: 1 addition & 1 deletion paddle/pir/core/op_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class IR_API OpResult : public Value {

private:
friend Operation;
OpResult(const detail::OpResultImpl *impl); // NOLINT
OpResult(detail::OpResultImpl *impl); // NOLINT
// Access classof annd dyn_cast_from.
friend Value;
static bool classof(Value value);
Expand Down
33 changes: 16 additions & 17 deletions paddle/pir/core/op_result_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pir/core/op_result_impl.h"

#include <cassert>
#include "paddle/pir/core/operation.h"

namespace pir {
namespace detail {
Expand All @@ -22,31 +21,31 @@ uint32_t OpResultImpl::index() const {
if (const auto *outline_result = dyn_cast<OpOutlineResultImpl>(this)) {
return outline_result->index();
}
return dyn_cast<OpInlineResultImpl>(this)->index();
return static_cast<const OpInlineResultImpl *>(this)->index();
}

OpResultImpl::~OpResultImpl() { assert(use_empty()); }
OpResultImpl::~OpResultImpl() {
if (!use_empty()) {
LOG(FATAL) << "Destoryed a op_result that is still in use. \n"
<< "The owner op type is:" << owner()->name();
}
}

Operation *OpResultImpl::owner() const {
Operation *OpResultImpl::owner() {
// For inline result, pointer offset index to obtain the address of op.
if (const auto *result = dyn_cast<OpInlineResultImpl>(this)) {
if (auto *result = dyn_cast<OpInlineResultImpl>(this)) {
result += result->index() + 1;
return reinterpret_cast<Operation *>(
const_cast<OpInlineResultImpl *>(result));
return reinterpret_cast<Operation *>(result);
}
// For outline result, pointer offset outline_index to obtain the address of
// maximum inline result.
const OpOutlineResultImpl *outline_result =
(const OpOutlineResultImpl *)(this);
outline_result +=
(outline_result->outline_index_ - GetMaxInlineResultIndex());
auto *outline_result = static_cast<OpOutlineResultImpl *>(this);
outline_result += (outline_result->index() - MAX_INLINE_RESULT_IDX);
// The offset of the maximum inline result distance op is
// GetMaxInlineResultIndex.
const auto *inline_result =
reinterpret_cast<const OpInlineResultImpl *>(outline_result);
inline_result += (GetMaxInlineResultIndex() + 1);
return reinterpret_cast<Operation *>(
const_cast<OpInlineResultImpl *>(inline_result));
auto *inline_result = reinterpret_cast<OpInlineResultImpl *>(outline_result);
inline_result += OUTLINE_RESULT_IDX;
return reinterpret_cast<Operation *>(inline_result);
}

} // namespace detail
Expand Down
21 changes: 7 additions & 14 deletions paddle/pir/core/op_result_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,20 @@ class OpResultImpl : public ValueImpl {
using ValueImpl::ValueImpl;

static bool classof(const ValueImpl &value) {
return value.kind() <= OUTLINE_OP_RESULT_INDEX;
return value.kind() <= OUTLINE_RESULT_IDX;
}

///
/// \brief Get the parent operation of this result.(op_ptr = value_ptr +
/// index)
///
Operation *owner() const;
Operation *owner();

///
/// \brief Get the result index of the operation result.
///
uint32_t index() const;

///
/// \brief Get the maximum number of results that can be stored inline.
///
static uint32_t GetMaxInlineResultIndex() {
return OUTLINE_OP_RESULT_INDEX - 1;
}

~OpResultImpl();
};

Expand All @@ -58,13 +51,13 @@ class OpInlineResultImpl : public OpResultImpl {
public:
OpInlineResultImpl(Type type, uint32_t result_index)
: OpResultImpl(type, result_index) {
if (result_index > GetMaxInlineResultIndex()) {
if (result_index > MAX_INLINE_RESULT_IDX) {
throw("Inline result index should not exceed MaxInlineResultIndex(5)");
}
}

static bool classof(const ValueImpl &value) {
return value.kind() < OUTLINE_OP_RESULT_INDEX;
return value.kind() < OUTLINE_RESULT_IDX;
}

uint32_t index() const { return kind(); }
Expand All @@ -77,15 +70,15 @@ class OpInlineResultImpl : public OpResultImpl {
class OpOutlineResultImpl : public OpResultImpl {
public:
OpOutlineResultImpl(Type type, uint32_t outline_index)
: OpResultImpl(type, OUTLINE_OP_RESULT_INDEX),
outline_index_(outline_index) {}
: OpResultImpl(type, OUTLINE_RESULT_IDX), outline_index_(outline_index) {}

static bool classof(const ValueImpl &value) {
return value.kind() == OUTLINE_OP_RESULT_INDEX;
return value.kind() == OUTLINE_RESULT_IDX;
}

uint32_t index() const { return outline_index_; }

private:
uint32_t outline_index_;
};

Expand Down
Loading

0 comments on commit 50669e0

Please sign in to comment.