Skip to content

Commit

Permalink
add symbolTable & symbolicDimProduct & symbolicDimMgr. (PaddlePaddle#…
Browse files Browse the repository at this point in the history
…56351)

* add symbolicDimProduct & symbolicDimMgr without method shape_constraint related

* split ddim in phi, add a target ddim, used by pd_type

* add pd_type.cc to ir_shape CMakeLists
  • Loading branch information
liuruyan authored and BeingGod committed Sep 9, 2023
1 parent d93ff52 commit f11e47f
Show file tree
Hide file tree
Showing 8 changed files with 394 additions and 27 deletions.
10 changes: 9 additions & 1 deletion paddle/ir/dialect/shape/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
add_subdirectory(ir)
file(GLOB_RECURSE SHAPE_SRCS "*.cc")
ir_library(
ir_shape
SRCS
${SHAPE_SRCS}
${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc
DEPS
ddim
ir_core)
2 changes: 0 additions & 2 deletions paddle/ir/dialect/shape/ir/CMakeLists.txt

This file was deleted.

31 changes: 30 additions & 1 deletion paddle/ir/dialect/shape/ir/shape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void SymbolicDim::Build(
argument.AddAttribute("knownNonSizeZero", attr_knownNonSizeZero);
}

std::string SymbolicDim::getSymName() {
const std::string SymbolicDim::getSymName() {
return attribute<ir::StrAttribute>("sym_name").AsString();
}
int64_t SymbolicDim::getValue() {
Expand Down Expand Up @@ -103,6 +103,35 @@ void SymbolicDim::updateKnownNonSizeZero(bool attrValue) {
ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue));
}

bool SymbolicDim::isDynamic() {
return getValue() == -100000;
} // TODO(zhangbo): getValue() == ShapedType::kDynamic;

bool SymbolicDim::merge(SymbolicDim other) {
if (!isDynamic() && !other.isDynamic() && getValue() != other.getValue())
return false;
if (isDynamic() && !other.isDynamic()) updateValue(other.getValue());

bool knownNonNegativeFlag =
getKnownNonNegative() || other.getKnownNonNegative();
bool knownNegativeOneFlag =
getKnownNegativeOne() || other.getKnownNegativeOne();
bool knownNonSizeOneFlag = getKnownNonSizeOne() ||
other.getKnownNonSizeOne() || knownNegativeOneFlag;
bool knownNonSizeZeroFlag = getKnownNonSizeZero() ||
other.getKnownNonSizeZero() ||
knownNegativeOneFlag;

if (knownNonNegativeFlag && knownNegativeOneFlag) return false;

updateKnownNonSizeZero(knownNonSizeZeroFlag);
updateKnownNonSizeOne(knownNonSizeOneFlag);
updateKnownNegativeOne(knownNegativeOneFlag);
updateKnownNonNegative(knownNonNegativeFlag);

return true;
}

} // namespace dialect
} // namespace ir

Expand Down
7 changes: 5 additions & 2 deletions paddle/ir/dialect/shape/ir/shape_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class IR_API SymbolicDim : public Op<SymbolicDim> {
bool knownNegativeOne = false,
bool knownNonSizeOne = false,
bool knownNonSizeZero = false);
std::string getSymName();
const std::string getSymName();
int64_t getValue();
bool getKnownNonNegative();
bool getKnownNegativeOne();
Expand All @@ -46,11 +46,14 @@ class IR_API SymbolicDim : public Op<SymbolicDim> {

void updateSymName(std::string attrValue);
void updateValue(int64_t attrValue);

void updateKnownNonNegative(bool attrValue);
void updateKnownNegativeOne(bool attrValue);
void updateKnownNonSizeOne(bool attrValue);
void updateKnownNonSizeZero(bool attrValue);

bool isDynamic();
bool merge(SymbolicDim other);

void Verify() {}
};

Expand Down
125 changes: 125 additions & 0 deletions paddle/ir/dialect/shape/utils/shape_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/ir/dialect/shape/utils/shape_utils.h"
#include <string>
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
namespace ir {

bool compareSymbolicDimNames(const std::string& lhs, const std::string& rhs) {
if (lhs.size() < 1 || lhs[0] != 'S' && lhs[0] != 'C') return lhs < rhs;
if (rhs.size() < 1 || rhs[0] != 'S' && rhs[0] != 'C') return lhs < rhs;
int64_t lhsIdx = 0, rhsIdx = 0;
try {
lhsIdx = stol(lhs.substr(1));
rhsIdx = stol(rhs.substr(1));
} catch (const std::exception& e) {
IR_THROW("Invalid symbolic name");
}
return (lhs[0] < rhs[0]) || (lhs[0] == rhs[0] && lhsIdx < rhsIdx);
}

ir::Operation* SymbolTable::lookup(const std::string& name) const {
auto it = symbolTableMap_.find(name);
return it != symbolTableMap_.end() ? it->second : nullptr;
}

const std::string SymbolTable::insert(ir::Operation* symbol) {
std::string name;
if (symbol->HasAttribute("sym_name")) {
name = symbol->dyn_cast<SymbolicDim>().getSymName();
}
// TODO(liujinnan): add constraint_func name branch.
symbolTableMap_.insert({name, symbol});
return name;
}

const std::string SymbolicDimMgr::getNextName() {
std::string name;
do {
name = "S" + std::to_string(nextSymbolicIdx_++);
} while (!symbolNameSet_.insert(name).second);
return name;
}

SymbolicDimMgr::SymbolicDimMgr(ir::ModuleOp m) : m_(m), symbolTable_(m_) {}

SymbolicDim SymbolicDimMgr::newSymbolicDim(const std::string& name) {
::ir::Builder builder = ::ir::Builder(m_.ir_context(), m_.block());
ir::dialect::SymbolicDim symbol = builder.Build<ir::dialect::SymbolicDim>(
name.empty() ? getNextName() : name);
symbolDimUnionSet_[symbol] = symbol;
symbolTable_.insert(symbol);
return symbol;
}

SymbolicDim SymbolicDimMgr::newConstantSymbolicDim(int64_t val) {
auto it = constantSymbolicDimMap_.find(val);
if (it == constantSymbolicDimMap_.end()) {
auto name = "C" + std::to_string(val);
it = constantSymbolicDimMap_
.insert(std::make_pair(val, newSymbolicDim(name)))
.first;
it->second.updateValue(val);
}
return getRootSymbolicDim(it->second);
}

std::vector<SymbolicDim> SymbolicDimMgr::createSymbolicDimsForRankedValue(
ir::Value value) {
std::vector<SymbolicDim> symbols;
auto dims = value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
for (int idx = 0; idx < dims.size(); ++idx) {
symbols.push_back(
dims[idx] == -100000 // TODO(zhangbo): value = ShapedType::kDynamic
? newSymbolicDim()
: newConstantSymbolicDim(dims[idx]));
}
return symbols;
}

SymbolicDim SymbolicDimMgr::getRootSymbolicDim(SymbolicDim symbol) {
SymbolicDim current = symbol;
std::vector<SymbolicDim> path;
while (symbolDimUnionSet_[current] != current) {
path.push_back(current);
current = symbolDimUnionSet_[current];
}
for (SymbolicDim sym : path) symbolDimUnionSet_[sym] = current;
return current;
}

bool SymbolicDimMgr::isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) {
SymbolicDim lhsRoot = getRootSymbolicDim(lhs);
SymbolicDim rhsRoot = getRootSymbolicDim(rhs);
return lhsRoot == rhsRoot;
}

bool SymbolicDimMgr::mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) {
SymbolicDim lhsRoot = getRootSymbolicDim(lhs);
SymbolicDim rhsRoot = getRootSymbolicDim(rhs);

if (lhsRoot != rhsRoot) {
if (compareSymbolicDimNames(lhsRoot.getSymName(), rhsRoot.getSymName())) {
if (!lhsRoot.merge(rhsRoot)) return false;
symbolDimUnionSet_[rhsRoot] = lhsRoot;
} else {
if (!rhsRoot.merge(lhsRoot)) return false;
symbolDimUnionSet_[lhsRoot] = rhsRoot;
}
}
return true;
}

} // namespace ir
109 changes: 109 additions & 0 deletions paddle/ir/dialect/shape/utils/shape_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <unordered_map>
#include <unordered_set>
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/utils.h"
#include "paddle/ir/dialect/shape/ir/shape_op.h"

namespace ir {

using ir::dialect::SymbolicDim;

struct SymbolicDimProduct {
std::vector<SymbolicDim> symbols;
int64_t factor = 1;
bool empty() { return factor == 1 && symbols.empty(); }
};

inline bool operator==(const SymbolicDimProduct& lhs,
const SymbolicDimProduct& rhs) {
return lhs.factor == rhs.factor && lhs.symbols == rhs.symbols;
}

inline bool operator!=(const SymbolicDimProduct& lhs,
const SymbolicDimProduct& rhs) {
return !(lhs == rhs);
}

class SymbolTable {
public:
explicit SymbolTable(ir::Operation* symbolTableOp)
: symbolTableOp_(symbolTableOp) {}
ir::Operation* lookup(const std::string& name) const;
const std::string insert(Operation* symbol);
ir::Operation* getOp() const { return symbolTableOp_; }

private:
ir::Operation* symbolTableOp_;
std::unordered_map<std::string, ir::Operation*> symbolTableMap_;
};

struct SymDimHasher {
size_t operator()(const ir::dialect::SymbolicDim& symbol) const noexcept {
return std::hash<ir::Operation*>{}(symbol.operation());
}
};

struct SymProductHasher {
size_t operator()(const ir::SymbolicDimProduct& symProd) const noexcept {
size_t hash = std::hash<size_t>{}(symProd.symbols.size());
for (auto& symbol : symProd.symbols) {
hash = hash_combine(hash, SymDimHasher{}(symbol)); // NOLINT
}
hash = hash_combine(hash, std::hash<int64_t>{}(symProd.factor));
return hash;
}
};

class SymbolicDimMgr {
public:
explicit SymbolicDimMgr(ir::ModuleOp m);
SymbolicDim newSymbolicDim(const std::string& name = {});
SymbolicDim newConstantSymbolicDim(int64_t val);
std::vector<SymbolicDim> createSymbolicDimsForRankedValue(Value value);
SymbolicDim getRootSymbolicDim(SymbolicDim symbol);
bool isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs);
SymbolTable& symbolTable() { return symbolTable_; }
bool mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs);

private:
const std::string getNextName();

private:
ir::ModuleOp m_;

SymbolTable symbolTable_;

int64_t nextSymbolicIdx_ = 0;

std::unordered_set<std::string> symbolNameSet_;

std::unordered_map<SymbolicDim, SymbolicDim, SymDimHasher> symbolDimUnionSet_;

std::unordered_map<int64_t, SymbolicDim> constantSymbolicDimMap_;

// productEqualityMap_[A][B] == true : Product[A] == Product[B]
using SymbolicDimProductMap = std::unordered_map<
SymbolicDimProduct,
std::unordered_map<SymbolicDimProduct, bool, SymProductHasher>,
SymProductHasher>;
SymbolicDimProductMap productEqualityMap_;
};

} // namespace ir
2 changes: 2 additions & 0 deletions paddle/phi/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,5 @@ collect_srcs(
kernel_factory.cc
tensor_utils.cc
utils/type_info.cc)

cc_library(ddim SRCS ddim.cc)
Loading

0 comments on commit f11e47f

Please sign in to comment.