Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GlobalIsel] Combine logic of icmps #77855

Merged
merged 4 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,12 @@ class CombinerHelper {
/// Combine selects.
bool matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo);

/// Combine ands,
bool matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo);

/// Combine ors,
bool matchOr(MachineInstr &MI, BuildFnTy &MatchInfo);

private:
/// Checks for legality of an indexed variant of \p LdSt.
bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
Expand Down Expand Up @@ -919,6 +925,12 @@ class CombinerHelper {
bool AllowUndefs);

std::optional<APInt> getConstantOrConstantSplatVector(Register Src);

/// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2)
/// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2)
/// into a single comparison using range-based reasoning.
bool tryFoldAndOrOrICmpsUsingRanges(GLogicalBinOp *Logic,
BuildFnTy &MatchInfo);
};
} // namespace llvm

Expand Down
128 changes: 128 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,134 @@ class GPhi : public GenericMachineInstr {
}
};

/// Represents a binary operation, i.e, x = y op z.
class GBinOp : public GenericMachineInstr {
public:
Register getLHSReg() const { return getReg(1); }
Register getRHSReg() const { return getReg(2); }

static bool classof(const MachineInstr *MI) {
switch (MI->getOpcode()) {
// Integer.
case TargetOpcode::G_ADD:
case TargetOpcode::G_SUB:
case TargetOpcode::G_MUL:
case TargetOpcode::G_SDIV:
case TargetOpcode::G_UDIV:
case TargetOpcode::G_SREM:
case TargetOpcode::G_UREM:
case TargetOpcode::G_SMIN:
case TargetOpcode::G_SMAX:
case TargetOpcode::G_UMIN:
case TargetOpcode::G_UMAX:
// Floating point.
case TargetOpcode::G_FMINNUM:
case TargetOpcode::G_FMAXNUM:
case TargetOpcode::G_FMINNUM_IEEE:
case TargetOpcode::G_FMAXNUM_IEEE:
case TargetOpcode::G_FMINIMUM:
case TargetOpcode::G_FMAXIMUM:
case TargetOpcode::G_FADD:
case TargetOpcode::G_FSUB:
case TargetOpcode::G_FMUL:
case TargetOpcode::G_FDIV:
case TargetOpcode::G_FPOW:
// Logical.
case TargetOpcode::G_AND:
case TargetOpcode::G_OR:
case TargetOpcode::G_XOR:
return true;
default:
return false;
}
};
};

/// Represents an integer binary operation.
class GIntBinOp : public GBinOp {
public:
static bool classof(const MachineInstr *MI) {
switch (MI->getOpcode()) {
case TargetOpcode::G_ADD:
case TargetOpcode::G_SUB:
case TargetOpcode::G_MUL:
case TargetOpcode::G_SDIV:
case TargetOpcode::G_UDIV:
case TargetOpcode::G_SREM:
case TargetOpcode::G_UREM:
case TargetOpcode::G_SMIN:
case TargetOpcode::G_SMAX:
case TargetOpcode::G_UMIN:
case TargetOpcode::G_UMAX:
return true;
default:
return false;
}
};
};

/// Represents a floating point binary operation.
class GFBinOp : public GBinOp {
public:
static bool classof(const MachineInstr *MI) {
switch (MI->getOpcode()) {
case TargetOpcode::G_FMINNUM:
case TargetOpcode::G_FMAXNUM:
case TargetOpcode::G_FMINNUM_IEEE:
case TargetOpcode::G_FMAXNUM_IEEE:
case TargetOpcode::G_FMINIMUM:
case TargetOpcode::G_FMAXIMUM:
case TargetOpcode::G_FADD:
case TargetOpcode::G_FSUB:
case TargetOpcode::G_FMUL:
case TargetOpcode::G_FDIV:
case TargetOpcode::G_FPOW:
return true;
default:
return false;
}
};
};

/// Represents a logical binary operation.
class GLogicalBinOp : public GBinOp {
public:
static bool classof(const MachineInstr *MI) {
switch (MI->getOpcode()) {
case TargetOpcode::G_AND:
case TargetOpcode::G_OR:
case TargetOpcode::G_XOR:
return true;
default:
return false;
}
};
};

/// Represents an integer addition.
class GAdd : public GIntBinOp {
public:
static bool classof(const MachineInstr *MI) {
return MI->getOpcode() == TargetOpcode::G_ADD;
};
};

/// Represents a logical and.
class GAnd : public GLogicalBinOp {
public:
static bool classof(const MachineInstr *MI) {
return MI->getOpcode() == TargetOpcode::G_AND;
};
};

/// Represents a logical or.
class GOr : public GLogicalBinOp {
public:
static bool classof(const MachineInstr *MI) {
return MI->getOpcode() == TargetOpcode::G_OR;
};
};

} // namespace llvm

#endif // LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H
14 changes: 13 additions & 1 deletion llvm/include/llvm/Target/GlobalISel/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,18 @@ def match_selects : GICombineRule<
[{ return Helper.matchSelect(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

def match_ands : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (wip_match_opcode G_AND):$root,
[{ return Helper.matchAnd(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

def match_ors : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (wip_match_opcode G_OR):$root,
[{ return Helper.matchOr(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

// FIXME: These should use the custom predicate feature once it lands.
def undef_combines : GICombineGroup<[undef_to_fp_zero, undef_to_int_zero,
undef_to_negative_one,
Expand Down Expand Up @@ -1314,7 +1326,7 @@ def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
intdiv_combines, mulh_combines, redundant_neg_operands,
and_or_disjoint_mask, fma_combines, fold_binop_into_select,
sub_add_reg, select_to_minmax, redundant_binop_in_equality,
fsub_to_fneg, commute_constant_to_rhs]>;
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors]>;

// A combine group used to for prelegalizer combiners at -O0. The combines in
// this group have been selected based on experiments to balance code size and
Expand Down
176 changes: 176 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -6643,3 +6644,178 @@ bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {

return false;
}

/// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2)
/// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2)
/// into a single comparison using range-based reasoning.
/// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges.
Comment on lines +6649 to +6652
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd think this would be possible just using tablegen patterns? Is there a way to plug the trickier ConstantRange handling at the end of the precondition into a mostly-tablegened pattern @Pierre-vh

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a private member function of the CombinerHelper. It is only called from matchAnd and matchOr.

bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges(GLogicalBinOp *Logic,
BuildFnTy &MatchInfo) {
assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;
Register DstReg = Logic->getReg(0);
Register LHS = Logic->getLHSReg();
Register RHS = Logic->getRHSReg();
unsigned Flags = Logic->getFlags();

// We need an G_ICMP on the LHS register.
GICmp *Cmp1 = getOpcodeDef<GICmp>(LHS, MRI);
if (!Cmp1)
return false;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be hasOneUse checks here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original code had one-use checks a few lines below. I moved them up.

// We need an G_ICMP on the RHS register.
GICmp *Cmp2 = getOpcodeDef<GICmp>(RHS, MRI);
if (!Cmp2)
return false;

APInt C1;
APInt C2;
std::optional<ValueAndVReg> MaybeC1 =
getIConstantVRegValWithLookThrough(Cmp1->getRHSReg(), MRI);
if (!MaybeC1)
return false;
C1 = MaybeC1->Value;

std::optional<ValueAndVReg> MaybeC2 =
getIConstantVRegValWithLookThrough(Cmp2->getRHSReg(), MRI);
if (!MaybeC2)
return false;
C2 = MaybeC2->Value;

Register R1 = Cmp1->getLHSReg();
Register R2 = Cmp2->getLHSReg();
CmpInst::Predicate Pred1 = Cmp1->getCond();
CmpInst::Predicate Pred2 = Cmp2->getCond();
LLT CmpTy = MRI.getType(Cmp1->getReg(0));
LLT CmpOperandTy = MRI.getType(R1);

// We build ands, adds, and constants of type CmpOperandTy.
// They must be legal to build.
if (!isLegalOrBeforeLegalizer({TargetOpcode::G_AND, CmpOperandTy}) ||
!isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, CmpOperandTy}) ||
!isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, CmpOperandTy}))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should use isConstantLegalOrBeforeLegalizer for correct vector handling

Copy link
Author

@tschuett tschuett Jan 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did. The intent is more obvious, but this combine only uses scalars.

return false;

// Look through add of a constant offset on R1, R2, or both operands. This
// allows us to interpret the R + C' < C'' range idiom into a proper range.
std::optional<APInt> Offset1;
std::optional<APInt> Offset2;
if (R1 != R2) {
if (GAdd *Add = getOpcodeDef<GAdd>(R1, MRI)) {
std::optional<ValueAndVReg> MaybeOffset1 =
getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI);
if (MaybeOffset1) {
R1 = Add->getLHSReg();
Offset1 = MaybeOffset1->Value;
}
}
if (GAdd *Add = getOpcodeDef<GAdd>(R2, MRI)) {
std::optional<ValueAndVReg> MaybeOffset2 =
getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI);
if (MaybeOffset2) {
R2 = Add->getLHSReg();
Offset2 = MaybeOffset2->Value;
}
}
}

if (R1 != R2)
return false;

// We calculate the icmp ranges including maybe offsets.
ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, C1);
if (Offset1)
CR1 = CR1.subtract(*Offset1);

ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, C2);
if (Offset2)
CR2 = CR2.subtract(*Offset2);

bool CreateMask = false;
APInt LowerDiff;
std::optional<ConstantRange> CR = CR1.exactUnionWith(CR2);
if (!CR) {
// We want to fold the icmps.
if (!MRI.hasOneNonDBGUse(Cmp1->getReg(0)) ||
!MRI.hasOneNonDBGUse(Cmp2->getReg(0)) || CR1.isWrappedSet() ||
CR2.isWrappedSet())
return false;

// Check whether we have equal-size ranges that only differ by one bit.
// In that case we can apply a mask to map one range onto the other.
LowerDiff = CR1.getLower() ^ CR2.getLower();
APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1);
APInt CR1Size = CR1.getUpper() - CR1.getLower();
if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff ||
CR1Size != CR2.getUpper() - CR2.getLower())
return false;

CR = CR1.getLower().ult(CR2.getLower()) ? CR1 : CR2;
CreateMask = true;
}

if (IsAnd)
CR = CR->inverse();

CmpInst::Predicate NewPred;
APInt NewC, Offset;
CR->getEquivalentICmp(NewPred, NewC, Offset);

// We take the result type of one of the original icmps, CmpTy, for
// the to be build icmp. The operand type, CmpOperandTy, is used for
// the other instructions and constants to be build. The types of
// the parameters and output are the same for add and and. CmpTy
// and the type of DstReg might differ. That is why we zext or trunc
// the icmp into the destination register.

MatchInfo = [=](MachineIRBuilder &B) {
if (CreateMask && Offset != 0) {
auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff);
auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask.
auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
auto Add = B.buildAdd(CmpOperandTy, And, OffsetC, Flags);
auto NewCon = B.buildConstant(CmpOperandTy, NewC);
auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon);
B.buildZExtOrTrunc(DstReg, ICmp);
} else if (CreateMask && Offset == 0) {
auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff);
auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask.
auto NewCon = B.buildConstant(CmpOperandTy, NewC);
auto ICmp = B.buildICmp(NewPred, CmpTy, And, NewCon);
B.buildZExtOrTrunc(DstReg, ICmp);
} else if (!CreateMask && Offset != 0) {
auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
auto Add = B.buildAdd(CmpOperandTy, R1, OffsetC, Flags);
auto NewCon = B.buildConstant(CmpOperandTy, NewC);
auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon);
B.buildZExtOrTrunc(DstReg, ICmp);
} else if (!CreateMask && Offset == 0) {
auto NewCon = B.buildConstant(CmpOperandTy, NewC);
auto ICmp = B.buildICmp(NewPred, CmpTy, R1, NewCon);
B.buildZExtOrTrunc(DstReg, ICmp);
} else {
assert(false && "unexpected configuration of CreateMask and Offset");
}
};
return true;
}

bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) {
GAnd *And = cast<GAnd>(&MI);

if (tryFoldAndOrOrICmpsUsingRanges(And, MatchInfo))
return true;

return false;
}

bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) {
GOr *Or = cast<GOr>(&MI);

if (tryFoldAndOrOrICmpsUsingRanges(Or, MatchInfo))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

return true;

return false;
}
Loading