[SMT] Add quantifier operations (#6842)

This commit is contained in:
Martin Erhart 2024-03-22 16:56:50 +01:00 committed by GitHub
parent 8ec05233d1
commit 6f0fc3de79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 758 additions and 1 deletions

View File

@ -179,7 +179,8 @@ def YieldOp : SMTOp<"yield", [
Pure,
Terminator,
ReturnLike,
ParentOneOf<["smt::SolverOp", "smt::CheckOp"]>,
ParentOneOf<["smt::SolverOp", "smt::CheckOp",
"smt::ForallOp", "smt::ExistsOp"]>,
]> {
let summary = "terminator operation for various regions of SMT operations";
let arguments = (ins Variadic<AnyType>:$values);
@ -337,4 +338,71 @@ def ImpliesOp : SMTOp<"implies", [Pure]> {
let assemblyFormat = "$lhs `,` $rhs attr-dict";
}
class QuantifierOp<string mnemonic> : SMTOp<mnemonic, [
RecursivelySpeculatable,
RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"smt::YieldOp">,
]> {
let description = [{
This operation represents the }] # summary # [{ as described in the
[SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf).
It is part of the language itself rather than a theory or logic.
The operation specifies the name prefixes (as an optional attribute) and
types (as the types of the block arguments of the regions) of bound
variables that may be used in the 'body' of the operation. If a 'patterns'
region is specified, the block arguments must match the ones of the 'body'
region and (other than there) must be used at least once in the 'patterns'
region. It may also not contain any operations that bind variables, such as
quantifiers. While the 'body' region must always yield exactly one
`!smt.bool`-typed value, the 'patterns' region can yield an arbitrary number
(but at least one) of SMT values.
The 'no_patterns' attribute is only allowed when no 'patterns' region is
specified and forbids the solver to generate and use patterns for this
quantifier.
The 'weight' attribute indicates the importance of this quantifier being
instantiated compared to other quantifiers that may be present. The default
value is zero.
Both the 'no_patterns' and 'weight' attributes are annotations to the
quantifiers body term. Annotations and attributes are described in the
standard in sections 3.4, and 3.6 (specifically 3.6.5). SMT-LIB allows
adding custom attributes to provide solvers with additional metadata, e.g.,
hints such as above mentioned attributes. They are not part of the standard
themselves, but supported by common SMT solvers (e.g., Z3).
}];
let arguments = (ins DefaultValuedAttr<I32Attr, "0">:$weight,
UnitAttr:$noPattern,
OptionalAttr<StrArrayAttr>:$boundVarNames);
let regions = (region SizedRegion<1>:$body,
VariadicRegion<SizedRegion<1>>:$patterns);
let results = (outs BoolType:$result);
let builders = [
OpBuilder<(ins
"TypeRange":$boundVarTypes,
"function_ref<Value(OpBuilder &, Location, ValueRange)>":$bodyBuilder,
CArg<"std::optional<ArrayRef<StringRef>>", "std::nullopt">:$boundVarNames,
CArg<"function_ref<ValueRange(OpBuilder &, Location, ValueRange)>",
"{}">:$patternBuilder,
CArg<"uint32_t", "0">:$weight,
CArg<"bool", "false">:$noPattern)>
];
let skipDefaultBuilders = true;
let assemblyFormat = [{
($boundVarNames^)? (`no_pattern` $noPattern^)? (`weight` $weight^)?
attr-dict-with-keyword $body (`patterns` $patterns^)?
}];
let hasVerifier = true;
let hasRegionVerifier = true;
}
def ForallOp : QuantifierOp<"forall"> { let summary = "forall quantifier"; }
def ExistsOp : QuantifierOp<"exists"> { let summary = "exists quantifier"; }
#endif // CIRCT_DIALECT_SMT_SMTOPS_TD

View File

@ -313,5 +313,156 @@ ParseResult IntConstantOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
//===----------------------------------------------------------------------===//
// ForallOp
//===----------------------------------------------------------------------===//
template <typename QuantifierOp>
static LogicalResult verifyQuantifierRegions(QuantifierOp op) {
if (op.getBoundVarNames() &&
op.getBody().getNumArguments() != op.getBoundVarNames()->size())
return op.emitOpError(
"number of bound variable names must match number of block arguments");
if (op.getBody().front().getTerminator()->getNumOperands() != 1)
return op.emitOpError("must have exactly one yielded value");
if (!isa<BoolType>(
op.getBody().front().getTerminator()->getOperand(0).getType()))
return op.emitOpError("yielded value must be of '!smt.bool' type");
for (auto regionWithIndex : llvm::enumerate(op.getPatterns())) {
unsigned i = regionWithIndex.index();
Region &region = regionWithIndex.value();
if (op.getBody().getArgumentTypes() != region.getArgumentTypes())
return op.emitOpError()
<< "block argument number and types of the 'body' "
"and 'patterns' region #"
<< i << " must match";
if (region.front().getTerminator()->getNumOperands() < 1)
return op.emitOpError() << "'patterns' region #" << i
<< " must have at least one yielded value";
// All operations in the 'patterns' region must be SMT operations.
auto result = region.walk([&](Operation *childOp) {
if (!isa<SMTDialect>(childOp->getDialect())) {
auto diag = op.emitOpError()
<< "the 'patterns' region #" << i
<< " may only contain SMT dialect operations";
diag.attachNote(childOp->getLoc()) << "first non-SMT operation here";
return WalkResult::interrupt();
}
// There may be no quantifier (or other variable binding) operations in
// the 'patterns' region.
if (isa<ForallOp, ExistsOp>(childOp)) {
auto diag = op.emitOpError() << "the 'patterns' region #" << i
<< " must not contain "
"any variable binding operations";
diag.attachNote(childOp->getLoc()) << "first violating operation here";
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (result.wasInterrupted())
return failure();
}
return success();
}
template <typename Properties>
static void buildQuantifier(
OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
std::optional<ArrayRef<StringRef>> boundVarNames,
function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
uint32_t weight, bool noPattern) {
odsState.addTypes(BoolType::get(odsBuilder.getContext()));
if (weight != 0)
odsState.getOrAddProperties<Properties>().weight =
odsBuilder.getIntegerAttr(odsBuilder.getIntegerType(32), weight);
if (noPattern)
odsState.getOrAddProperties<Properties>().noPattern =
odsBuilder.getUnitAttr();
if (boundVarNames.has_value()) {
SmallVector<Attribute> boundVarNamesList;
for (StringRef str : *boundVarNames)
boundVarNamesList.emplace_back(odsBuilder.getStringAttr(str));
odsState.getOrAddProperties<Properties>().boundVarNames =
odsBuilder.getArrayAttr(boundVarNamesList);
}
{
OpBuilder::InsertionGuard guard(odsBuilder);
Region *region = odsState.addRegion();
Block *block = odsBuilder.createBlock(region);
block->addArguments(
boundVarTypes,
SmallVector<Location>(boundVarTypes.size(), odsState.location));
Value returnVal =
bodyBuilder(odsBuilder, odsState.location, block->getArguments());
odsBuilder.create<smt::YieldOp>(odsState.location, returnVal);
}
if (patternBuilder) {
Region *region = odsState.addRegion();
OpBuilder::InsertionGuard guard(odsBuilder);
Block *block = odsBuilder.createBlock(region);
block->addArguments(
boundVarTypes,
SmallVector<Location>(boundVarTypes.size(), odsState.location));
ValueRange returnVals =
patternBuilder(odsBuilder, odsState.location, block->getArguments());
odsBuilder.create<smt::YieldOp>(odsState.location, returnVals);
}
}
LogicalResult ForallOp::verify() {
if (!getPatterns().empty() && getNoPattern())
return emitOpError() << "patterns and the no_pattern attribute must not be "
"specified at the same time";
return success();
}
LogicalResult ForallOp::verifyRegions() {
return verifyQuantifierRegions(*this);
}
void ForallOp::build(
OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
std::optional<ArrayRef<StringRef>> boundVarNames,
function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
uint32_t weight, bool noPattern) {
buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
boundVarNames, patternBuilder, weight, noPattern);
}
//===----------------------------------------------------------------------===//
// ExistsOp
//===----------------------------------------------------------------------===//
LogicalResult ExistsOp::verify() {
if (!getPatterns().empty() && getNoPattern())
return emitOpError() << "patterns and the no_pattern attribute must not be "
"specified at the same time";
return success();
}
LogicalResult ExistsOp::verifyRegions() {
return verifyQuantifierRegions(*this);
}
void ExistsOp::build(
OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
std::optional<ArrayRef<StringRef>> boundVarNames,
function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
uint32_t weight, bool noPattern) {
buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
boundVarNames, patternBuilder, weight, noPattern);
}
#define GET_OP_CLASSES
#include "circt/Dialect/SMT/SMT.cpp.inc"

View File

@ -84,3 +84,94 @@ func.func @core(%in: i8) {
return
}
// CHECK-LABEL: func @quantifiers
func.func @quantifiers() {
// CHECK-NEXT: smt.forall ["a", "b"] weight 2 attributes {smt.some_attr} {
// CHECK-NEXT: ^bb0({{.*}}: !smt.bool, {{.*}}: !smt.bool):
// CHECK-NEXT: smt.eq
// CHECK-NEXT: smt.yield %{{.*}}
// CHECK-NEXT: } patterns {
// CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool):
// CHECK-NEXT: smt.yield %{{.*}}
// CHECK-NEXT: }, {
// CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool):
// CHECK-NEXT: smt.yield %{{.*}}
// CHECK-NEXT: }
%0 = smt.forall ["a", "b"] weight 2 attributes {smt.some_attr} {
^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
%1 = smt.eq %arg2, %arg3 : !smt.bool
smt.yield %1 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
smt.yield %arg2, %arg3 : !smt.bool, !smt.bool
}, {
^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
smt.yield %arg2, %arg3 : !smt.bool, !smt.bool
}
// CHECK-NEXT: smt.forall ["a", "b"] no_pattern attributes {smt.some_attr} {
// CHECK-NEXT: ^bb0({{.*}}: !smt.bool, {{.*}}: !smt.bool):
// CHECK-NEXT: smt.eq
// CHECK-NEXT: smt.yield %{{.*}}
// CHECK-NEXT: }
%1 = smt.forall ["a", "b"] no_pattern attributes {smt.some_attr} {
^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
%2 = smt.eq %arg2, %arg3 : !smt.bool
smt.yield %2 : !smt.bool
}
// CHECK-NEXT: smt.forall {
// CHECK-NEXT: smt.constant
// CHECK-NEXT: smt.yield %{{.*}}
// CHECK-NEXT: }
%2 = smt.forall {
%3 = smt.constant true
smt.yield %3 : !smt.bool
}
// CHECK-NEXT: smt.exists ["a", "b"] weight 2 attributes {smt.some_attr} {
// CHECK-NEXT: ^bb0({{.*}}: !smt.bool, {{.*}}: !smt.bool):
// CHECK-NEXT: smt.eq
// CHECK-NEXT: smt.yield %{{.*}}
// CHECK-NEXT: } patterns {
// CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool):
// CHECK-NEXT: smt.yield %{{.*}}
// CHECK-NEXT: }, {
// CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool):
// CHECK-NEXT: smt.yield %{{.*}}
// CHECK-NEXT: }
%3 = smt.exists ["a", "b"] weight 2 attributes {smt.some_attr} {
^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
%4 = smt.eq %arg2, %arg3 : !smt.bool
smt.yield %4 : !smt.bool {smt.some_attr}
} patterns {
^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
smt.yield %arg2, %arg3 : !smt.bool, !smt.bool
}, {
^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
smt.yield %arg2, %arg3 : !smt.bool, !smt.bool
}
// CHECK-NEXT: smt.exists no_pattern attributes {smt.some_attr} {
// CHECK-NEXT: ^bb0({{.*}}: !smt.bool, {{.*}}: !smt.bool):
// CHECK-NEXT: smt.eq
// CHECK-NEXT: smt.yield %{{.*}}
// CHECK-NEXT: }
%4 = smt.exists no_pattern attributes {smt.some_attr} {
^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
%5 = smt.eq %arg2, %arg3 : !smt.bool
smt.yield %5 : !smt.bool {smt.some_attr}
}
// CHECK-NEXT: smt.exists [] {
// CHECK-NEXT: smt.constant
// CHECK-NEXT: smt.yield %{{.*}}
// CHECK-NEXT: }
%5 = smt.exists [] {
%6 = smt.constant true
smt.yield %6 : !smt.bool
}
return
}

View File

@ -155,3 +155,261 @@ func.func @ite_type_mismatch(%a: !smt.bool, %b: !smt.bv<32>) {
"smt.ite"(%a, %a, %b) {} : (!smt.bool, !smt.bool, !smt.bv<32>) -> !smt.bool
return
}
// -----
func.func @forall_number_of_decl_names_must_match_num_args() {
// expected-error @below {{number of bound variable names must match number of block arguments}}
%1 = smt.forall ["a"] {
^bb0(%arg2: !smt.int, %arg3: !smt.int):
%2 = smt.eq %arg2, %arg3 : !smt.int
smt.yield %2 : !smt.bool
}
return
}
// -----
func.func @exists_number_of_decl_names_must_match_num_args() {
// expected-error @below {{number of bound variable names must match number of block arguments}}
%1 = smt.exists ["a"] {
^bb0(%arg2: !smt.int, %arg3: !smt.int):
%2 = smt.eq %arg2, %arg3 : !smt.int
smt.yield %2 : !smt.bool
}
return
}
// -----
func.func @forall_yield_must_have_exactly_one_bool_value() {
// expected-error @below {{yielded value must be of '!smt.bool' type}}
%1 = smt.forall ["a", "b"] {
^bb0(%arg2: !smt.int, %arg3: !smt.int):
%2 = smt.int.add %arg2, %arg3
smt.yield %2 : !smt.int
}
return
}
// -----
func.func @forall_yield_must_have_exactly_one_bool_value() {
// expected-error @below {{must have exactly one yielded value}}
%1 = smt.forall ["a", "b"] {
^bb0(%arg2: !smt.int, %arg3: !smt.int):
smt.yield
}
return
}
// -----
func.func @exists_yield_must_have_exactly_one_bool_value() {
// expected-error @below {{yielded value must be of '!smt.bool' type}}
%1 = smt.exists ["a", "b"] {
^bb0(%arg2: !smt.int, %arg3: !smt.int):
%2 = smt.int.add %arg2, %arg3
smt.yield %2 : !smt.int
}
return
}
// -----
func.func @exists_yield_must_have_exactly_one_bool_value() {
// expected-error @below {{must have exactly one yielded value}}
%1 = smt.exists ["a", "b"] {
^bb0(%arg2: !smt.int, %arg3: !smt.int):
smt.yield
}
return
}
// -----
func.func @exists_patterns_region_and_no_patterns_attr_are_mutually_exclusive() {
// expected-error @below {{patterns and the no_pattern attribute must not be specified at the same time}}
%1 = smt.exists ["a"] no_pattern {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
}
return
}
// -----
func.func @forall_patterns_region_and_no_patterns_attr_are_mutually_exclusive() {
// expected-error @below {{patterns and the no_pattern attribute must not be specified at the same time}}
%1 = smt.forall ["a"] no_pattern {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
}
return
}
// -----
func.func @exists_patterns_region_num_args() {
// expected-error @below {{block argument number and types of the 'body' and 'patterns' region #0 must match}}
%1 = smt.exists ["a"] {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
smt.yield %arg2, %arg3 : !smt.bool, !smt.bool
}
return
}
// -----
func.func @forall_patterns_region_num_args() {
// expected-error @below {{block argument number and types of the 'body' and 'patterns' region #0 must match}}
%1 = smt.forall ["a"] {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool, %arg3: !smt.bool):
smt.yield %arg2, %arg3 : !smt.bool, !smt.bool
}
return
}
// -----
func.func @exists_patterns_region_at_least_one_yielded_value() {
// expected-error @below {{'patterns' region #0 must have at least one yielded value}}
%1 = smt.exists ["a"] {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool):
smt.yield
}
return
}
// -----
func.func @forall_patterns_region_at_least_one_yielded_value() {
// expected-error @below {{'patterns' region #0 must have at least one yielded value}}
%1 = smt.forall ["a"] {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool):
smt.yield
}
return
}
// -----
func.func @exists_all_pattern_regions_tested() {
// expected-error @below {{'patterns' region #1 must have at least one yielded value}}
%1 = smt.exists ["a"] {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
}, {
^bb0(%arg2: !smt.bool):
smt.yield
}
return
}
// -----
func.func @forall_all_pattern_regions_tested() {
// expected-error @below {{'patterns' region #1 must have at least one yielded value}}
%1 = smt.forall ["a"] {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
}, {
^bb0(%arg2: !smt.bool):
smt.yield
}
return
}
// -----
func.func @exists_patterns_region_no_non_smt_operations() {
// expected-error @below {{'patterns' region #0 may only contain SMT dialect operations}}
%1 = smt.exists ["a"] {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool):
// expected-note @below {{first non-SMT operation here}}
hw.constant 0 : i32
smt.yield %arg2 : !smt.bool
}
return
}
// -----
func.func @forall_patterns_region_no_non_smt_operations() {
// expected-error @below {{'patterns' region #0 may only contain SMT dialect operations}}
%1 = smt.forall ["a"] {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool):
// expected-note @below {{first non-SMT operation here}}
hw.constant 0 : i32
smt.yield %arg2 : !smt.bool
}
return
}
// -----
func.func @exists_patterns_region_no_var_binding_operations() {
// expected-error @below {{'patterns' region #0 must not contain any variable binding operations}}
%1 = smt.exists ["a"] {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool):
// expected-note @below {{first violating operation here}}
smt.exists ["b"] {
^bb0(%arg3: !smt.bool):
smt.yield %arg3 : !smt.bool
}
smt.yield %arg2 : !smt.bool
}
return
}
// -----
func.func @forall_patterns_region_no_var_binding_operations() {
// expected-error @below {{'patterns' region #0 must not contain any variable binding operations}}
%1 = smt.forall ["a"] {
^bb0(%arg2: !smt.bool):
smt.yield %arg2 : !smt.bool
} patterns {
^bb0(%arg2: !smt.bool):
// expected-note @below {{first violating operation here}}
smt.forall ["b"] {
^bb0(%arg3: !smt.bool):
smt.yield %arg3 : !smt.bool
}
smt.yield %arg2 : !smt.bool
}
return
}

View File

@ -1,5 +1,6 @@
add_circt_unittest(CIRCTSMTTests
AttributeTest.cpp
QuantifierTest.cpp
)
target_link_libraries(CIRCTSMTTests

View File

@ -0,0 +1,188 @@
//===- QuantifierTest.cpp - SMT quantifier operation unit tests -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/SMT/SMTOps.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace circt;
using namespace smt;
namespace {
//===----------------------------------------------------------------------===//
// Test custom builders of ExistsOp
//===----------------------------------------------------------------------===//
TEST(QuantifierTest, ExistsBuilderWithPattern) {
MLIRContext context;
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));
OpBuilder builder(&context);
auto boolTy = BoolType::get(&context);
ExistsOp existsOp = builder.create<ExistsOp>(
loc, TypeRange{boolTy, boolTy},
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
return builder.create<AndOp>(loc, boundVars);
},
std::nullopt,
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
return boundVars;
},
/*weight=*/2);
SmallVector<char, 1024> buffer;
llvm::raw_svector_ostream stream(buffer);
existsOp.print(stream);
ASSERT_STREQ(
stream.str().str().c_str(),
"%0 = smt.exists weight 2 {\n^bb0(%arg0: !smt.bool, "
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
"!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
"smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
}
TEST(QuantifierTest, ExistsBuilderNoPattern) {
MLIRContext context;
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));
OpBuilder builder(&context);
auto boolTy = BoolType::get(&context);
ExistsOp existsOp = builder.create<ExistsOp>(
loc, TypeRange{boolTy, boolTy},
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
return builder.create<AndOp>(loc, boundVars);
},
ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
SmallVector<char, 1024> buffer;
llvm::raw_svector_ostream stream(buffer);
existsOp.print(stream);
ASSERT_STREQ(stream.str().str().c_str(),
"%0 = smt.exists [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
"!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
"smt.yield %0 : !smt.bool\n}\n");
}
TEST(QuantifierTest, ExistsBuilderDefault) {
MLIRContext context;
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));
OpBuilder builder(&context);
auto boolTy = BoolType::get(&context);
ExistsOp existsOp = builder.create<ExistsOp>(
loc, TypeRange{boolTy, boolTy},
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
return builder.create<AndOp>(loc, boundVars);
},
ArrayRef<StringRef>{"a", "b"});
SmallVector<char, 1024> buffer;
llvm::raw_svector_ostream stream(buffer);
existsOp.print(stream);
ASSERT_STREQ(stream.str().str().c_str(),
"%0 = smt.exists [\"a\", \"b\"] {\n^bb0(%arg0: !smt.bool, "
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
"%0 : !smt.bool\n}\n");
}
//===----------------------------------------------------------------------===//
// Test custom builders of ForallOp
//===----------------------------------------------------------------------===//
TEST(QuantifierTest, ForallBuilderWithPattern) {
MLIRContext context;
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));
OpBuilder builder(&context);
auto boolTy = BoolType::get(&context);
ForallOp forallOp = builder.create<ForallOp>(
loc, TypeRange{boolTy, boolTy},
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
return builder.create<AndOp>(loc, boundVars);
},
ArrayRef<StringRef>{"a", "b"},
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
return boundVars;
},
/*weight=*/2);
SmallVector<char, 1024> buffer;
llvm::raw_svector_ostream stream(buffer);
forallOp.print(stream);
ASSERT_STREQ(
stream.str().str().c_str(),
"%0 = smt.forall [\"a\", \"b\"] weight 2 {\n^bb0(%arg0: !smt.bool, "
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
"!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
"smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
}
TEST(QuantifierTest, ForallBuilderNoPattern) {
MLIRContext context;
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));
OpBuilder builder(&context);
auto boolTy = BoolType::get(&context);
ForallOp forallOp = builder.create<ForallOp>(
loc, TypeRange{boolTy, boolTy},
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
return builder.create<AndOp>(loc, boundVars);
},
ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
SmallVector<char, 1024> buffer;
llvm::raw_svector_ostream stream(buffer);
forallOp.print(stream);
ASSERT_STREQ(stream.str().str().c_str(),
"%0 = smt.forall [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
"!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
"smt.yield %0 : !smt.bool\n}\n");
}
TEST(QuantifierTest, ForallBuilderDefault) {
MLIRContext context;
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));
OpBuilder builder(&context);
auto boolTy = BoolType::get(&context);
ForallOp forallOp = builder.create<ForallOp>(
loc, TypeRange{boolTy, boolTy},
[](OpBuilder &builder, Location loc, ValueRange boundVars) {
return builder.create<AndOp>(loc, boundVars);
},
std::nullopt);
SmallVector<char, 1024> buffer;
llvm::raw_svector_ostream stream(buffer);
forallOp.print(stream);
ASSERT_STREQ(stream.str().str().c_str(),
"%0 = smt.forall {\n^bb0(%arg0: !smt.bool, "
"%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
"%0 : !smt.bool\n}\n");
}
} // namespace