[EmitHLSCpp] Support max/min ops
This commit is contained in:
parent
d775ccc9f9
commit
3100f042f3
|
@ -61,13 +61,14 @@ public:
|
|||
|
||||
// Float binary expressions.
|
||||
arith::CmpFOp, arith::AddFOp, arith::SubFOp, arith::MulFOp,
|
||||
arith::DivFOp, arith::RemFOp,
|
||||
arith::DivFOp, arith::RemFOp, arith::MaxFOp, arith::MinFOp,
|
||||
|
||||
// Integer binary expressions.
|
||||
arith::CmpIOp, arith::AddIOp, arith::SubIOp, arith::MulIOp,
|
||||
arith::DivSIOp, arith::RemSIOp, arith::DivUIOp, arith::RemUIOp,
|
||||
arith::XOrIOp, arith::AndIOp, arith::OrIOp, arith::ShLIOp,
|
||||
arith::ShRSIOp, arith::ShRUIOp,
|
||||
arith::ShRSIOp, arith::ShRUIOp, arith::MaxSIOp, arith::MinSIOp,
|
||||
arith::MaxUIOp, arith::MinUIOp,
|
||||
|
||||
// Special expressions.
|
||||
SelectOp, ConstantOp, arith::ConstantOp, arith::TruncIOp,
|
||||
|
@ -166,6 +167,8 @@ public:
|
|||
HANDLE(arith::MulFOp);
|
||||
HANDLE(arith::DivFOp);
|
||||
HANDLE(arith::RemFOp);
|
||||
HANDLE(arith::MaxFOp);
|
||||
HANDLE(arith::MinFOp);
|
||||
|
||||
// Integer binary expressions.
|
||||
HANDLE(arith::CmpIOp);
|
||||
|
@ -182,6 +185,10 @@ public:
|
|||
HANDLE(arith::ShLIOp);
|
||||
HANDLE(arith::ShRSIOp);
|
||||
HANDLE(arith::ShRUIOp);
|
||||
HANDLE(arith::MaxSIOp);
|
||||
HANDLE(arith::MinSIOp);
|
||||
HANDLE(arith::MaxUIOp);
|
||||
HANDLE(arith::MinUIOp);
|
||||
|
||||
// Special expressions.
|
||||
HANDLE(SelectOp);
|
||||
|
|
|
@ -235,6 +235,7 @@ public:
|
|||
/// Standard expression emitters.
|
||||
void emitUnary(Operation *op, const char *syntax);
|
||||
void emitBinary(Operation *op, const char *syntax);
|
||||
template <typename OpType> void emitMaxMin(OpType op, const char *syntax);
|
||||
|
||||
/// Special expression emitters.
|
||||
void emitSelect(SelectOp op);
|
||||
|
@ -461,6 +462,8 @@ public:
|
|||
bool visitOp(arith::MulFOp op) { return emitter.emitBinary(op, "*"), true; }
|
||||
bool visitOp(arith::DivFOp op) { return emitter.emitBinary(op, "/"), true; }
|
||||
bool visitOp(arith::RemFOp op) { return emitter.emitBinary(op, "%"), true; }
|
||||
bool visitOp(arith::MaxFOp op) { return emitter.emitMaxMin(op, "max"), true; }
|
||||
bool visitOp(arith::MinFOp op) { return emitter.emitMaxMin(op, "min"), true; }
|
||||
|
||||
/// Integer binary expressions.
|
||||
bool visitOp(arith::CmpIOp op);
|
||||
|
@ -477,6 +480,18 @@ public:
|
|||
bool visitOp(arith::ShLIOp op) { return emitter.emitBinary(op, "<<"), true; }
|
||||
bool visitOp(arith::ShRSIOp op) { return emitter.emitBinary(op, ">>"), true; }
|
||||
bool visitOp(arith::ShRUIOp op) { return emitter.emitBinary(op, ">>"), true; }
|
||||
bool visitOp(arith::MaxSIOp op) {
|
||||
return emitter.emitMaxMin(op, "max"), true;
|
||||
}
|
||||
bool visitOp(arith::MinSIOp op) {
|
||||
return emitter.emitMaxMin(op, "min"), true;
|
||||
}
|
||||
bool visitOp(arith::MaxUIOp op) {
|
||||
return emitter.emitMaxMin(op, "max"), true;
|
||||
}
|
||||
bool visitOp(arith::MinUIOp op) {
|
||||
return emitter.emitMaxMin(op, "min"), true;
|
||||
}
|
||||
|
||||
/// Special expressions.
|
||||
bool visitOp(SelectOp op) { return emitter.emitSelect(op), true; }
|
||||
|
@ -1187,6 +1202,20 @@ void ModuleEmitter::emitBinary(Operation *op, const char *syntax) {
|
|||
emitNestedLoopFooter(rank);
|
||||
}
|
||||
|
||||
template <typename OpType>
|
||||
void ModuleEmitter::emitMaxMin(OpType op, const char *syntax) {
|
||||
auto rank = emitNestedLoopHeader(op.getResult());
|
||||
indent();
|
||||
emitValue(op.getResult());
|
||||
os << " = " << syntax << "(";
|
||||
emitValue(op.getLhs(), rank);
|
||||
os << ", ";
|
||||
emitValue(op.getRhs(), rank);
|
||||
os << ");";
|
||||
emitInfoAndNewLine(op);
|
||||
emitNestedLoopFooter(rank);
|
||||
}
|
||||
|
||||
/// Special expression emitters.
|
||||
void ModuleEmitter::emitSelect(SelectOp op) {
|
||||
unsigned rank = emitNestedLoopHeader(op.getResult());
|
||||
|
|
|
@ -26,10 +26,15 @@ func @test_integer_binary(%arg0: i32, %arg1: i32) -> i32 {
|
|||
%10 = arith.shli %arg0, %9 : i32
|
||||
// CHECK: >>
|
||||
%11 = arith.shrsi %arg0, %10 : i32
|
||||
|
||||
// CHECK: *[[VAL_2:.*]] = [[ARG_0:.*]] >> [[VAL_1:.*]];
|
||||
// CHECK: >>
|
||||
%12 = arith.shrui %arg0, %11 : i32
|
||||
return %12 : i32
|
||||
|
||||
// CHECK: int32_t [[VAL_3:.*]] = max([[ARG_0:.*]], [[VAL_2:.*]]);
|
||||
%13 = arith.maxsi %arg0, %12 : i32
|
||||
|
||||
// CHECK: *[[VAL_4:.*]] = min([[ARG_0:.*]], [[VAL_3]]);
|
||||
%14 = arith.minui %arg0, %13 : i32
|
||||
return %14 : i32
|
||||
}
|
||||
|
||||
func @test_float_binary_unary(%arg0: f32, %arg1: f32) -> f32 {
|
||||
|
@ -71,7 +76,13 @@ func @test_float_binary_unary(%arg0: f32, %arg1: f32) -> f32 {
|
|||
%16 = math.log2 %15 : f32
|
||||
// CHECK: log10
|
||||
%17 = math.log10 %16 : f32
|
||||
return %17 : f32
|
||||
|
||||
// CHECK: float [[VAL_4:.*]] = max([[ARG_0:.*]], [[VAL_3:.*]]);
|
||||
%18 = arith.maxf %arg0, %17 : f32
|
||||
|
||||
// CHECK: *[[VAL_5:.*]] = min([[ARG_0:.*]], [[VAL_4]]);
|
||||
%19 = arith.minf %arg0, %18 : f32
|
||||
return %19 : f32
|
||||
}
|
||||
|
||||
func @test_special_expr(%arg0: i1, %arg1: index, %arg2: index) -> index {
|
||||
|
|
Loading…
Reference in New Issue