Skip to content

Commit 974f00a

Browse files
committed
[AArch64][SVE] Fold constant multiply of element count
Summary: E.g. %0 = tail call i64 @llvm.aarch64.sve.cntw(i32 31) %mul = mul i64 %0, <const> Should emit: cntw x0, all, mul #<const> For <const> in the range 1-16. Patch by Kerry McLaughlin Reviewers: sdesmalen, huntergr, dancgr, rengolin, efriedma Reviewed By: sdesmalen Subscribers: tschuett, kristof.beyls, hiraditya, rkruppe, psnobl, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D71014
1 parent b237179 commit 974f00a

File tree

4 files changed

+126
-1
lines changed

4 files changed

+126
-1
lines changed

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,28 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
169169
return SelectSVELogicalImm(N, VT, Imm);
170170
}
171171

172+
// Returns a suitable CNT/INC/DEC/RDVL multiplier to calculate VSCALE*N.
173+
template<signed Min, signed Max, signed Scale, bool Shift>
174+
bool SelectCntImm(SDValue N, SDValue &Imm) {
175+
if (!isa<ConstantSDNode>(N))
176+
return false;
177+
178+
int64_t MulImm = cast<ConstantSDNode>(N)->getSExtValue();
179+
if (Shift)
180+
MulImm = 1 << MulImm;
181+
182+
if ((MulImm % std::abs(Scale)) != 0)
183+
return false;
184+
185+
MulImm /= Scale;
186+
if ((MulImm >= Min) && (MulImm <= Max)) {
187+
Imm = CurDAG->getTargetConstant(MulImm, SDLoc(N), MVT::i32);
188+
return true;
189+
}
190+
191+
return false;
192+
}
193+
172194
/// Form sequences of consecutive 64/128-bit registers for use in NEON
173195
/// instructions making use of a vector-list (e.g. ldN, tbl). Vecs must have
174196
/// between 1 and 4 elements. If it contains a single element that is returned

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

+22-1
Original file line numberDiff line numberDiff line change
@@ -9541,6 +9541,19 @@ AArch64TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor,
95419541
return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), SRA);
95429542
}
95439543

9544+
/// Returns true when the intrinsic ID reported by getIntrinsicID() for the
/// node behind \p S is one of the SVE element-count intrinsics
/// (aarch64_sve_cntb, aarch64_sve_cnth, aarch64_sve_cntw, aarch64_sve_cntd),
/// and false for every other ID.
static bool IsSVECntIntrinsic(SDValue S) {
  const auto IID = getIntrinsicID(S.getNode());
  return IID == Intrinsic::aarch64_sve_cntb ||
         IID == Intrinsic::aarch64_sve_cnth ||
         IID == Intrinsic::aarch64_sve_cntw ||
         IID == Intrinsic::aarch64_sve_cntd;
}
9556+
95449557
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
95459558
TargetLowering::DAGCombinerInfo &DCI,
95469559
const AArch64Subtarget *Subtarget) {
@@ -9551,9 +9564,18 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
95519564
if (!isa<ConstantSDNode>(N->getOperand(1)))
95529565
return SDValue();
95539566

9567+
SDValue N0 = N->getOperand(0);
95549568
ConstantSDNode *C = cast<ConstantSDNode>(N->getOperand(1));
95559569
const APInt &ConstValue = C->getAPIntValue();
95569570

9571+
// Allow the scaling to be folded into the `cnt` instruction by preventing
9572+
// the scaling from being obscured here. This makes it easier to pattern match.
9573+
if (IsSVECntIntrinsic(N0) ||
9574+
(N0->getOpcode() == ISD::TRUNCATE &&
9575+
(IsSVECntIntrinsic(N0->getOperand(0)))))
9576+
if (ConstValue.sge(1) && ConstValue.sle(16))
9577+
return SDValue();
9578+
95579579
// Multiplication of a power of two plus/minus one can be done more
95589580
// cheaply as shift+add/sub. For now, this is true unilaterally. If
95599581
// future CPUs have a cheaper MADD instruction, this may need to be
@@ -9564,7 +9586,6 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
95649586
// e.g. 6=3*2=(2+1)*2.
95659587
// TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
95669588
// which equals to (1+2)*16-(1+2).
9567-
SDValue N0 = N->getOperand(0);
95689589
// TrailingZeroes is used to test if the mul can be lowered to
95699590
// shift+add+shift.
95709591
unsigned TrailingZeroes = ConstValue.countTrailingZeros();

llvm/lib/Target/AArch64/SVEInstrFormats.td

+10
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,10 @@ def sve_incdec_imm : Operand<i32>, TImmLeaf<i32, [{
244244
let DecoderMethod = "DecodeSVEIncDecImm";
245245
}
246246

247+
// This allows i32 immediate extraction from i64 based arithmetic.
248+
def sve_cnt_mul_imm : ComplexPattern<i32, 1, "SelectCntImm<1, 16, 1, false>">;
249+
def sve_cnt_shl_imm : ComplexPattern<i32, 1, "SelectCntImm<1, 16, 1, true>">;
250+
247251
//===----------------------------------------------------------------------===//
248252
// SVE PTrue - These are used extensively throughout the pattern matching so
249253
// it's important we define them first.
@@ -635,6 +639,12 @@ multiclass sve_int_count<bits<3> opc, string asm, SDPatternOperator op> {
635639
def : InstAlias<asm # "\t$Rd",
636640
(!cast<Instruction>(NAME) GPR64:$Rd, 0b11111, 1), 2>;
637641

642+
def : Pat<(i64 (mul (op sve_pred_enum:$pattern), (sve_cnt_mul_imm i32:$imm))),
643+
(!cast<Instruction>(NAME) sve_pred_enum:$pattern, sve_incdec_imm:$imm)>;
644+
645+
def : Pat<(i64 (shl (op sve_pred_enum:$pattern), (i64 (sve_cnt_shl_imm i32:$imm)))),
646+
(!cast<Instruction>(NAME) sve_pred_enum:$pattern, sve_incdec_imm:$imm)>;
647+
638648
def : Pat<(i64 (op sve_pred_enum:$pattern)),
639649
(!cast<Instruction>(NAME) sve_pred_enum:$pattern, 1)>;
640650
}

llvm/test/CodeGen/AArch64/sve-intrinsics-counting-elems.ll

+72
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,24 @@ define i64 @cntb() {
1212
ret i64 %out
1313
}
1414

15+
define i64 @cntb_mul3() {
16+
; CHECK-LABEL: cntb_mul3:
17+
; CHECK: cntb x0, vl6, mul #3
18+
; CHECK-NEXT: ret
19+
%cnt = call i64 @llvm.aarch64.sve.cntb(i32 6)
20+
%out = mul i64 %cnt, 3
21+
ret i64 %out
22+
}
23+
24+
define i64 @cntb_mul4() {
25+
; CHECK-LABEL: cntb_mul4:
26+
; CHECK: cntb x0, vl8, mul #4
27+
; CHECK-NEXT: ret
28+
%cnt = call i64 @llvm.aarch64.sve.cntb(i32 8)
29+
%out = mul i64 %cnt, 4
30+
ret i64 %out
31+
}
32+
1533
;
1634
; CNTH
1735
;
@@ -24,6 +42,24 @@ define i64 @cnth() {
2442
ret i64 %out
2543
}
2644

45+
define i64 @cnth_mul5() {
46+
; CHECK-LABEL: cnth_mul5:
47+
; CHECK: cnth x0, vl7, mul #5
48+
; CHECK-NEXT: ret
49+
%cnt = call i64 @llvm.aarch64.sve.cnth(i32 7)
50+
%out = mul i64 %cnt, 5
51+
ret i64 %out
52+
}
53+
54+
define i64 @cnth_mul8() {
55+
; CHECK-LABEL: cnth_mul8:
56+
; CHECK: cnth x0, vl5, mul #8
57+
; CHECK-NEXT: ret
58+
%cnt = call i64 @llvm.aarch64.sve.cnth(i32 5)
59+
%out = mul i64 %cnt, 8
60+
ret i64 %out
61+
}
62+
2763
;
2864
; CNTW
2965
;
@@ -36,6 +72,24 @@ define i64 @cntw() {
3672
ret i64 %out
3773
}
3874

75+
define i64 @cntw_mul11() {
76+
; CHECK-LABEL: cntw_mul11:
77+
; CHECK: cntw x0, vl8, mul #11
78+
; CHECK-NEXT: ret
79+
%cnt = call i64 @llvm.aarch64.sve.cntw(i32 8)
80+
%out = mul i64 %cnt, 11
81+
ret i64 %out
82+
}
83+
84+
define i64 @cntw_mul2() {
85+
; CHECK-LABEL: cntw_mul2:
86+
; CHECK: cntw x0, vl6, mul #2
87+
; CHECK-NEXT: ret
88+
%cnt = call i64 @llvm.aarch64.sve.cntw(i32 6)
89+
%out = mul i64 %cnt, 2
90+
ret i64 %out
91+
}
92+
3993
;
4094
; CNTD
4195
;
@@ -48,6 +102,24 @@ define i64 @cntd() {
48102
ret i64 %out
49103
}
50104

105+
define i64 @cntd_mul15() {
106+
; CHECK-LABEL: cntd_mul15:
107+
; CHECK: cntd x0, vl16, mul #15
108+
; CHECK-NEXT: ret
109+
%cnt = call i64 @llvm.aarch64.sve.cntd(i32 9)
110+
%out = mul i64 %cnt, 15
111+
ret i64 %out
112+
}
113+
114+
define i64 @cntd_mul16() {
115+
; CHECK-LABEL: cntd_mul16:
116+
; CHECK: cntd x0, vl32, mul #16
117+
; CHECK-NEXT: ret
118+
%cnt = call i64 @llvm.aarch64.sve.cntd(i32 10)
119+
%out = mul i64 %cnt, 16
120+
ret i64 %out
121+
}
122+
51123
;
52124
; CNTP
53125
;

0 commit comments

Comments
 (0)