Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor RAM transforms #1 (Remove setOperation) #873

Merged
merged 2 commits into from
Feb 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/RamOperation.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,6 @@ class RamNestedOperation : public RamOperation {
return *nestedOperation;
}

/** Set nested operation */
void setOperation(std::unique_ptr<RamOperation> nested) {
nestedOperation = std::move(nested);
}

/** Print */
void print(std::ostream& os, int tabpos) const override {
nestedOperation->print(os, tabpos + 1);
Expand Down
89 changes: 43 additions & 46 deletions src/RamTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace {
std::vector<RamCondition*> getConditions(const RamCondition* condition) {
std::vector<RamCondition*> conditions;
while (condition != nullptr) {
if (const RamAnd* ramAnd = dynamic_cast<const RamAnd*>(condition)) {
if (const auto* ramAnd = dynamic_cast<const RamAnd*>(condition)) {
conditions.push_back(ramAnd->getRHS().clone());
condition = &ramAnd->getLHS();
} else {
Expand Down Expand Up @@ -64,20 +64,16 @@ bool LevelConditionsTransformer::levelConditions(RamProgram& program) {
}
}

using RamNodeMapper::operator();

std::unique_ptr<RamNode> operator()(std::unique_ptr<RamNode> node) const override {
if (RamNestedOperation* nested = dynamic_cast<RamNestedOperation*>(node.get())) {
if (const RamFilter* filter = dynamic_cast<const RamFilter*>(&nested->getOperation())) {
const RamCondition& condition = filter->getCondition();
if (auto* filter = dynamic_cast<RamFilter*>(node.get())) {
const RamCondition& condition = filter->getCondition();

if (context->rcla->getLevel(&condition) == identifier) {
addCondition(std::unique_ptr<RamCondition>(condition.clone()));
if (context->rcla->getLevel(&condition) == identifier) {
addCondition(std::unique_ptr<RamCondition>(condition.clone()));

// skip this filter
nested->setOperation(std::unique_ptr<RamOperation>(filter->getOperation().clone()));
return (*this)(std::move(node));
}
// skip this filter
node->apply(*this);
return std::unique_ptr<RamOperation>(filter->getOperation().clone());
}
}

Expand All @@ -86,6 +82,21 @@ bool LevelConditionsTransformer::levelConditions(RamProgram& program) {
}
};

class RamFilterInsert : public RamNodeMapper {
std::unique_ptr<RamCondition> condition;

public:
RamFilterInsert(std::unique_ptr<RamCondition> c) : condition(std::move(c)) {}

std::unique_ptr<RamNode> operator()(std::unique_ptr<RamNode> node) const override {
if (nullptr != dynamic_cast<RamOperation*>(node.get())) {
return std::make_unique<RamFilter>(std::unique_ptr<RamCondition>(condition->clone()),
std::unique_ptr<RamOperation>(dynamic_cast<RamOperation*>(node.release())));
}
return node;
}
};

// Node-mapper that searches for and updates RAM scans nested in RAM inserts
class RamScanCapturer : public RamNodeMapper {
mutable bool modified = false;
Expand All @@ -99,18 +110,16 @@ bool LevelConditionsTransformer::levelConditions(RamProgram& program) {
}

std::unique_ptr<RamNode> operator()(std::unique_ptr<RamNode> node) const override {
if (RamScan* scan = dynamic_cast<RamScan*>(node.get())) {
if (auto* scan = dynamic_cast<RamScan*>(node.get())) {
RamFilterCapturer filterUpdate(context, scan->getIdentifier());
std::unique_ptr<RamScan> newScan = filterUpdate(std::unique_ptr<RamScan>(scan->clone()));
node->apply(filterUpdate);

// If a condition applies to this scan level, filter the scan based on the condition
if (std::unique_ptr<RamCondition> condition = filterUpdate.getCondition()) {
newScan->setOperation(std::make_unique<RamFilter>(std::move(condition),
std::unique_ptr<RamOperation>(newScan->getOperation().clone())));
RamFilterInsert filterInsert(std::move(condition));
node->apply(filterInsert);
modified = true;
}

node = std::move(newScan);
}

node->apply(*this);
Expand All @@ -132,7 +141,7 @@ bool LevelConditionsTransformer::levelConditions(RamProgram& program) {

std::unique_ptr<RamNode> operator()(std::unique_ptr<RamNode> node) const override {
// get all RAM inserts
if (RamInsert* insert = dynamic_cast<RamInsert*>(node.get())) {
if (auto* insert = dynamic_cast<RamInsert*>(node.get())) {
RamScanCapturer scanUpdate(context);
insert->apply(scanUpdate);

Expand All @@ -158,7 +167,7 @@ bool LevelConditionsTransformer::levelConditions(RamProgram& program) {
/** Get indexable element */
std::unique_ptr<RamValue> CreateIndicesTransformer::getIndexElement(
RamCondition* c, size_t& element, size_t identifier) {
if (RamBinaryRelation* binRelOp = dynamic_cast<RamBinaryRelation*>(c)) {
if (auto* binRelOp = dynamic_cast<RamBinaryRelation*>(c)) {
if (binRelOp->getOperator() == BinaryConstraintOp::EQ) {
if (auto* lhs = dynamic_cast<RamElementAccess*>(binRelOp->getLHS())) {
RamValue* rhs = binRelOp->getRHS();
Expand All @@ -182,7 +191,7 @@ std::unique_ptr<RamValue> CreateIndicesTransformer::getIndexElement(
}

std::unique_ptr<RamOperation> CreateIndicesTransformer::rewriteScan(const RamScan* scan) {
if (const RamFilter* filter = dynamic_cast<const RamFilter*>(&scan->getOperation())) {
if (const auto* filter = dynamic_cast<const RamFilter*>(&scan->getOperation())) {
const RamRelationReference& rel = scan->getRelation();
const size_t identifier = scan->getIdentifier();

Expand Down Expand Up @@ -240,19 +249,17 @@ bool CreateIndicesTransformer::createIndices(RamProgram& program) {
CreateIndicesTransformer* context;

public:
RamScanCapturer(CreateIndicesTransformer* c) : modified(false), context(c) {}
RamScanCapturer(CreateIndicesTransformer* c) : context(c) {}

bool getModified() const {
return modified;
}

std::unique_ptr<RamNode> operator()(std::unique_ptr<RamNode> node) const override {
if (RamNestedOperation* nested = dynamic_cast<RamNestedOperation*>(node.get())) {
if (const RamScan* scan = dynamic_cast<const RamScan*>(&nested->getOperation())) {
if (std::unique_ptr<RamOperation> op = context->rewriteScan(scan)) {
modified = true;
nested->setOperation(std::move(op));
}
if (auto* scan = dynamic_cast<RamScan*>(node.get())) {
if (std::unique_ptr<RamOperation> op = context->rewriteScan(scan)) {
modified = true;
node = std::move(op);
}
}
node->apply(*this);
Expand All @@ -274,14 +281,7 @@ bool CreateIndicesTransformer::createIndices(RamProgram& program) {

std::unique_ptr<RamNode> operator()(std::unique_ptr<RamNode> node) const override {
// get all RAM inserts
if (RamInsert* insert = dynamic_cast<RamInsert*>(node.get())) {
// TODO: better way to modify the child of a RAM insert
if (const RamScan* scan = dynamic_cast<const RamScan*>(&insert->getOperation())) {
if (std::unique_ptr<RamOperation> op = context->rewriteScan(scan)) {
modified = true;
insert->setOperation(std::move(op));
}
}
if (auto* insert = dynamic_cast<RamInsert*>(node.get())) {
RamScanCapturer scanUpdate(context);
insert->apply(scanUpdate);
if (!modified && scanUpdate.getModified()) {
Expand All @@ -297,7 +297,6 @@ bool CreateIndicesTransformer::createIndices(RamProgram& program) {

// level all RAM inserts
RamInsertCapturer insertUpdate(this);

program.getMain()->apply(insertUpdate);

return insertUpdate.getModified();
Expand All @@ -323,17 +322,15 @@ bool ConvertExistenceChecksTransformer::convertExistenceChecks(RamProgram& progr
while (!queue.empty()) {
const RamValue* val = queue.back();
queue.pop_back();
if (const RamElementAccess* elemAccess = dynamic_cast<const RamElementAccess*>(val)) {
if (const auto* elemAccess = dynamic_cast<const RamElementAccess*>(val)) {
if (context->rvla->getLevel(elemAccess) == identifier) {
return true;
}
} else if (const RamIntrinsicOperator* intrinsicOp =
dynamic_cast<const RamIntrinsicOperator*>(val)) {
} else if (const auto* intrinsicOp = dynamic_cast<const RamIntrinsicOperator*>(val)) {
for (const RamValue* arg : intrinsicOp->getArguments()) {
queue.push_back(arg);
}
} else if (const RamUserDefinedOperator* userDefinedOp =
dynamic_cast<const RamUserDefinedOperator*>(val)) {
} else if (const auto* userDefinedOp = dynamic_cast<const RamUserDefinedOperator*>(val)) {
for (const RamValue* arg : userDefinedOp->getArguments()) {
queue.push_back(arg);
}
Expand All @@ -343,14 +340,14 @@ bool ConvertExistenceChecksTransformer::convertExistenceChecks(RamProgram& progr
}

bool dependsOn(const RamCondition* condition, const size_t identifier) const {
if (const RamBinaryRelation* binRel = dynamic_cast<const RamBinaryRelation*>(condition)) {
if (const auto* binRel = dynamic_cast<const RamBinaryRelation*>(condition)) {
return dependsOn(binRel->getLHS(), identifier) || dependsOn(binRel->getRHS(), identifier);
}
return false;
}

std::unique_ptr<RamNode> operator()(std::unique_ptr<RamNode> node) const override {
if (RamRelationSearch* scan = dynamic_cast<RamRelationSearch*>(node.get())) {
if (auto* scan = dynamic_cast<RamRelationSearch*>(node.get())) {
const size_t identifier = scan->getIdentifier();
bool isExistCheck = true;
visitDepthFirst(scan->getOperation(), [&](const RamFilter& filter) {
Expand Down Expand Up @@ -388,7 +385,7 @@ bool ConvertExistenceChecksTransformer::convertExistenceChecks(RamProgram& progr
const RamValue* value = values.back();
values.pop_back();

if (const RamPack* pack = dynamic_cast<const RamPack*>(value)) {
if (const auto* pack = dynamic_cast<const RamPack*>(value)) {
const std::vector<RamValue*> args = pack->getArguments();
values.insert(values.end(), args.begin(), args.end());
} else if (const auto* intrinsicOp =
Expand Down Expand Up @@ -448,7 +445,7 @@ bool ConvertExistenceChecksTransformer::convertExistenceChecks(RamProgram& progr

std::unique_ptr<RamNode> operator()(std::unique_ptr<RamNode> node) const override {
// get all RAM inserts
if (RamInsert* insert = dynamic_cast<RamInsert*>(node.get())) {
if (auto* insert = dynamic_cast<RamInsert*>(node.get())) {
RamScanCapturer scanUpdate(context);
insert->apply(scanUpdate);

Expand Down