forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add symbolTable & symbolicDimProduct & symbolicDimMgr. (PaddlePaddle#…
…56351) * add symbolicDimProduct & symbolicDimMgr without method shape_constraint related * split ddim in phi, add a target ddim, used by pd_type * add pd_type.cc to ir_shape CMakeLists
- Loading branch information
Showing
8 changed files
with
394 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,9 @@ | ||
add_subdirectory(ir) | ||
file(GLOB_RECURSE SHAPE_SRCS "*.cc") | ||
ir_library( | ||
ir_shape | ||
SRCS | ||
${SHAPE_SRCS} | ||
${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc | ||
DEPS | ||
ddim | ||
ir_core) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/ir/dialect/shape/utils/shape_utils.h" | ||
#include <string> | ||
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h" | ||
namespace ir { | ||
|
||
bool compareSymbolicDimNames(const std::string& lhs, const std::string& rhs) { | ||
if (lhs.size() < 1 || lhs[0] != 'S' && lhs[0] != 'C') return lhs < rhs; | ||
if (rhs.size() < 1 || rhs[0] != 'S' && rhs[0] != 'C') return lhs < rhs; | ||
int64_t lhsIdx = 0, rhsIdx = 0; | ||
try { | ||
lhsIdx = stol(lhs.substr(1)); | ||
rhsIdx = stol(rhs.substr(1)); | ||
} catch (const std::exception& e) { | ||
IR_THROW("Invalid symbolic name"); | ||
} | ||
return (lhs[0] < rhs[0]) || (lhs[0] == rhs[0] && lhsIdx < rhsIdx); | ||
} | ||
|
||
ir::Operation* SymbolTable::lookup(const std::string& name) const { | ||
auto it = symbolTableMap_.find(name); | ||
return it != symbolTableMap_.end() ? it->second : nullptr; | ||
} | ||
|
||
const std::string SymbolTable::insert(ir::Operation* symbol) { | ||
std::string name; | ||
if (symbol->HasAttribute("sym_name")) { | ||
name = symbol->dyn_cast<SymbolicDim>().getSymName(); | ||
} | ||
// TODO(liujinnan): add constraint_func name branch. | ||
symbolTableMap_.insert({name, symbol}); | ||
return name; | ||
} | ||
|
||
const std::string SymbolicDimMgr::getNextName() { | ||
std::string name; | ||
do { | ||
name = "S" + std::to_string(nextSymbolicIdx_++); | ||
} while (!symbolNameSet_.insert(name).second); | ||
return name; | ||
} | ||
|
||
SymbolicDimMgr::SymbolicDimMgr(ir::ModuleOp m) : m_(m), symbolTable_(m_) {} | ||
|
||
SymbolicDim SymbolicDimMgr::newSymbolicDim(const std::string& name) { | ||
::ir::Builder builder = ::ir::Builder(m_.ir_context(), m_.block()); | ||
ir::dialect::SymbolicDim symbol = builder.Build<ir::dialect::SymbolicDim>( | ||
name.empty() ? getNextName() : name); | ||
symbolDimUnionSet_[symbol] = symbol; | ||
symbolTable_.insert(symbol); | ||
return symbol; | ||
} | ||
|
||
SymbolicDim SymbolicDimMgr::newConstantSymbolicDim(int64_t val) { | ||
auto it = constantSymbolicDimMap_.find(val); | ||
if (it == constantSymbolicDimMap_.end()) { | ||
auto name = "C" + std::to_string(val); | ||
it = constantSymbolicDimMap_ | ||
.insert(std::make_pair(val, newSymbolicDim(name))) | ||
.first; | ||
it->second.updateValue(val); | ||
} | ||
return getRootSymbolicDim(it->second); | ||
} | ||
|
||
std::vector<SymbolicDim> SymbolicDimMgr::createSymbolicDimsForRankedValue( | ||
ir::Value value) { | ||
std::vector<SymbolicDim> symbols; | ||
auto dims = value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims(); | ||
for (int idx = 0; idx < dims.size(); ++idx) { | ||
symbols.push_back( | ||
dims[idx] == -100000 // TODO(zhangbo): value = ShapedType::kDynamic | ||
? newSymbolicDim() | ||
: newConstantSymbolicDim(dims[idx])); | ||
} | ||
return symbols; | ||
} | ||
|
||
SymbolicDim SymbolicDimMgr::getRootSymbolicDim(SymbolicDim symbol) { | ||
SymbolicDim current = symbol; | ||
std::vector<SymbolicDim> path; | ||
while (symbolDimUnionSet_[current] != current) { | ||
path.push_back(current); | ||
current = symbolDimUnionSet_[current]; | ||
} | ||
for (SymbolicDim sym : path) symbolDimUnionSet_[sym] = current; | ||
return current; | ||
} | ||
|
||
bool SymbolicDimMgr::isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { | ||
SymbolicDim lhsRoot = getRootSymbolicDim(lhs); | ||
SymbolicDim rhsRoot = getRootSymbolicDim(rhs); | ||
return lhsRoot == rhsRoot; | ||
} | ||
|
||
bool SymbolicDimMgr::mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { | ||
SymbolicDim lhsRoot = getRootSymbolicDim(lhs); | ||
SymbolicDim rhsRoot = getRootSymbolicDim(rhs); | ||
|
||
if (lhsRoot != rhsRoot) { | ||
if (compareSymbolicDimNames(lhsRoot.getSymName(), rhsRoot.getSymName())) { | ||
if (!lhsRoot.merge(rhsRoot)) return false; | ||
symbolDimUnionSet_[rhsRoot] = lhsRoot; | ||
} else { | ||
if (!rhsRoot.merge(lhsRoot)) return false; | ||
symbolDimUnionSet_[lhsRoot] = rhsRoot; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
} // namespace ir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <functional> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
#include "paddle/ir/core/builtin_op.h" | ||
#include "paddle/ir/core/utils.h" | ||
#include "paddle/ir/dialect/shape/ir/shape_op.h" | ||
|
||
namespace ir { | ||
|
||
using ir::dialect::SymbolicDim; | ||
|
||
struct SymbolicDimProduct { | ||
std::vector<SymbolicDim> symbols; | ||
int64_t factor = 1; | ||
bool empty() { return factor == 1 && symbols.empty(); } | ||
}; | ||
|
||
inline bool operator==(const SymbolicDimProduct& lhs, | ||
const SymbolicDimProduct& rhs) { | ||
return lhs.factor == rhs.factor && lhs.symbols == rhs.symbols; | ||
} | ||
|
||
inline bool operator!=(const SymbolicDimProduct& lhs, | ||
const SymbolicDimProduct& rhs) { | ||
return !(lhs == rhs); | ||
} | ||
|
||
class SymbolTable { | ||
public: | ||
explicit SymbolTable(ir::Operation* symbolTableOp) | ||
: symbolTableOp_(symbolTableOp) {} | ||
ir::Operation* lookup(const std::string& name) const; | ||
const std::string insert(Operation* symbol); | ||
ir::Operation* getOp() const { return symbolTableOp_; } | ||
|
||
private: | ||
ir::Operation* symbolTableOp_; | ||
std::unordered_map<std::string, ir::Operation*> symbolTableMap_; | ||
}; | ||
|
||
struct SymDimHasher { | ||
size_t operator()(const ir::dialect::SymbolicDim& symbol) const noexcept { | ||
return std::hash<ir::Operation*>{}(symbol.operation()); | ||
} | ||
}; | ||
|
||
struct SymProductHasher { | ||
size_t operator()(const ir::SymbolicDimProduct& symProd) const noexcept { | ||
size_t hash = std::hash<size_t>{}(symProd.symbols.size()); | ||
for (auto& symbol : symProd.symbols) { | ||
hash = hash_combine(hash, SymDimHasher{}(symbol)); // NOLINT | ||
} | ||
hash = hash_combine(hash, std::hash<int64_t>{}(symProd.factor)); | ||
return hash; | ||
} | ||
}; | ||
|
||
class SymbolicDimMgr { | ||
public: | ||
explicit SymbolicDimMgr(ir::ModuleOp m); | ||
SymbolicDim newSymbolicDim(const std::string& name = {}); | ||
SymbolicDim newConstantSymbolicDim(int64_t val); | ||
std::vector<SymbolicDim> createSymbolicDimsForRankedValue(Value value); | ||
SymbolicDim getRootSymbolicDim(SymbolicDim symbol); | ||
bool isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); | ||
SymbolTable& symbolTable() { return symbolTable_; } | ||
bool mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); | ||
|
||
private: | ||
const std::string getNextName(); | ||
|
||
private: | ||
ir::ModuleOp m_; | ||
|
||
SymbolTable symbolTable_; | ||
|
||
int64_t nextSymbolicIdx_ = 0; | ||
|
||
std::unordered_set<std::string> symbolNameSet_; | ||
|
||
std::unordered_map<SymbolicDim, SymbolicDim, SymDimHasher> symbolDimUnionSet_; | ||
|
||
std::unordered_map<int64_t, SymbolicDim> constantSymbolicDimMap_; | ||
|
||
// productEqualityMap_[A][B] == true : Product[A] == Product[B] | ||
using SymbolicDimProductMap = std::unordered_map< | ||
SymbolicDimProduct, | ||
std::unordered_map<SymbolicDimProduct, bool, SymProductHasher>, | ||
SymProductHasher>; | ||
SymbolicDimProductMap productEqualityMap_; | ||
}; | ||
|
||
} // namespace ir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,3 +39,5 @@ collect_srcs( | |
kernel_factory.cc | ||
tensor_utils.cc | ||
utils/type_info.cc) | ||
|
||
cc_library(ddim SRCS ddim.cc) |
Oops, something went wrong.