[EmitHLSCpp] Support max/min ops

This commit is contained in:
Hanchen Ye 2022-02-16 19:12:51 -06:00
parent d775ccc9f9
commit 3100f042f3
3 changed files with 53 additions and 6 deletions

View File

@ -61,13 +61,14 @@ public:
// Float binary expressions.
arith::CmpFOp, arith::AddFOp, arith::SubFOp, arith::MulFOp,
arith::DivFOp, arith::RemFOp,
arith::DivFOp, arith::RemFOp, arith::MaxFOp, arith::MinFOp,
// Integer binary expressions.
arith::CmpIOp, arith::AddIOp, arith::SubIOp, arith::MulIOp,
arith::DivSIOp, arith::RemSIOp, arith::DivUIOp, arith::RemUIOp,
arith::XOrIOp, arith::AndIOp, arith::OrIOp, arith::ShLIOp,
arith::ShRSIOp, arith::ShRUIOp,
arith::ShRSIOp, arith::ShRUIOp, arith::MaxSIOp, arith::MinSIOp,
arith::MaxUIOp, arith::MinUIOp,
// Special expressions.
SelectOp, ConstantOp, arith::ConstantOp, arith::TruncIOp,
@ -166,6 +167,8 @@ public:
HANDLE(arith::MulFOp);
HANDLE(arith::DivFOp);
HANDLE(arith::RemFOp);
HANDLE(arith::MaxFOp);
HANDLE(arith::MinFOp);
// Integer binary expressions.
HANDLE(arith::CmpIOp);
@ -182,6 +185,10 @@ public:
HANDLE(arith::ShLIOp);
HANDLE(arith::ShRSIOp);
HANDLE(arith::ShRUIOp);
HANDLE(arith::MaxSIOp);
HANDLE(arith::MinSIOp);
HANDLE(arith::MaxUIOp);
HANDLE(arith::MinUIOp);
// Special expressions.
HANDLE(SelectOp);

View File

@ -235,6 +235,7 @@ public:
/// Standard expression emitters.
void emitUnary(Operation *op, const char *syntax);
void emitBinary(Operation *op, const char *syntax);
template <typename OpType> void emitMaxMin(OpType op, const char *syntax);
/// Special expression emitters.
void emitSelect(SelectOp op);
@ -461,6 +462,8 @@ public:
bool visitOp(arith::MulFOp op) { return emitter.emitBinary(op, "*"), true; }
bool visitOp(arith::DivFOp op) { return emitter.emitBinary(op, "/"), true; }
bool visitOp(arith::RemFOp op) { return emitter.emitBinary(op, "%"), true; }
bool visitOp(arith::MaxFOp op) { return emitter.emitMaxMin(op, "max"), true; }
bool visitOp(arith::MinFOp op) { return emitter.emitMaxMin(op, "min"), true; }
/// Integer binary expressions.
bool visitOp(arith::CmpIOp op);
@ -477,6 +480,18 @@ public:
bool visitOp(arith::ShLIOp op) { return emitter.emitBinary(op, "<<"), true; }
bool visitOp(arith::ShRSIOp op) { return emitter.emitBinary(op, ">>"), true; }
bool visitOp(arith::ShRUIOp op) { return emitter.emitBinary(op, ">>"), true; }
bool visitOp(arith::MaxSIOp op) {
return emitter.emitMaxMin(op, "max"), true;
}
bool visitOp(arith::MinSIOp op) {
return emitter.emitMaxMin(op, "min"), true;
}
bool visitOp(arith::MaxUIOp op) {
return emitter.emitMaxMin(op, "max"), true;
}
bool visitOp(arith::MinUIOp op) {
return emitter.emitMaxMin(op, "min"), true;
}
/// Special expressions.
bool visitOp(SelectOp op) { return emitter.emitSelect(op), true; }
@ -1187,6 +1202,20 @@ void ModuleEmitter::emitBinary(Operation *op, const char *syntax) {
emitNestedLoopFooter(rank);
}
template <typename OpType>
void ModuleEmitter::emitMaxMin(OpType op, const char *syntax) {
auto rank = emitNestedLoopHeader(op.getResult());
indent();
emitValue(op.getResult());
os << " = " << syntax << "(";
emitValue(op.getLhs(), rank);
os << ", ";
emitValue(op.getRhs(), rank);
os << ");";
emitInfoAndNewLine(op);
emitNestedLoopFooter(rank);
}
/// Special expression emitters.
void ModuleEmitter::emitSelect(SelectOp op) {
unsigned rank = emitNestedLoopHeader(op.getResult());

View File

@ -26,10 +26,15 @@ func @test_integer_binary(%arg0: i32, %arg1: i32) -> i32 {
%10 = arith.shli %arg0, %9 : i32
// CHECK: >>
%11 = arith.shrsi %arg0, %10 : i32
// CHECK: *[[VAL_2:.*]] = [[ARG_0:.*]] >> [[VAL_1:.*]];
// CHECK: >>
%12 = arith.shrui %arg0, %11 : i32
return %12 : i32
// CHECK: int32_t [[VAL_3:.*]] = max([[ARG_0:.*]], [[VAL_2:.*]]);
%13 = arith.maxsi %arg0, %12 : i32
// CHECK: *[[VAL_4:.*]] = min([[ARG_0:.*]], [[VAL_3]]);
%14 = arith.minui %arg0, %13 : i32
return %14 : i32
}
func @test_float_binary_unary(%arg0: f32, %arg1: f32) -> f32 {
@ -71,7 +76,13 @@ func @test_float_binary_unary(%arg0: f32, %arg1: f32) -> f32 {
%16 = math.log2 %15 : f32
// CHECK: log10
%17 = math.log10 %16 : f32
return %17 : f32
// CHECK: float [[VAL_4:.*]] = max([[ARG_0:.*]], [[VAL_3:.*]]);
%18 = arith.maxf %arg0, %17 : f32
// CHECK: *[[VAL_5:.*]] = min([[ARG_0:.*]], [[VAL_4]]);
%19 = arith.minf %arg0, %18 : f32
return %19 : f32
}
func @test_special_expr(%arg0: i1, %arg1: index, %arg2: index) -> index {