From 6f0fc3de79f1877f9ccb3d8e044f3cb4276aa91f Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Fri, 22 Mar 2024 16:56:50 +0100 Subject: [PATCH] [SMT] Add quantifier operations (#6842) --- include/circt/Dialect/SMT/SMTOps.td | 70 +++++- lib/Dialect/SMT/SMTOps.cpp | 151 +++++++++++++ test/Dialect/SMT/basic.mlir | 91 ++++++++ test/Dialect/SMT/core-errors.mlir | 258 +++++++++++++++++++++++ unittests/Dialect/SMT/CMakeLists.txt | 1 + unittests/Dialect/SMT/QuantifierTest.cpp | 188 +++++++++++++++++ 6 files changed, 758 insertions(+), 1 deletion(-) create mode 100644 unittests/Dialect/SMT/QuantifierTest.cpp diff --git a/include/circt/Dialect/SMT/SMTOps.td b/include/circt/Dialect/SMT/SMTOps.td index a03ed0a1f9..4aa4c30b24 100644 --- a/include/circt/Dialect/SMT/SMTOps.td +++ b/include/circt/Dialect/SMT/SMTOps.td @@ -179,7 +179,8 @@ def YieldOp : SMTOp<"yield", [ Pure, Terminator, ReturnLike, - ParentOneOf<["smt::SolverOp", "smt::CheckOp"]>, + ParentOneOf<["smt::SolverOp", "smt::CheckOp", + "smt::ForallOp", "smt::ExistsOp"]>, ]> { let summary = "terminator operation for various regions of SMT operations"; let arguments = (ins Variadic:$values); @@ -337,4 +338,71 @@ def ImpliesOp : SMTOp<"implies", [Pure]> { let assemblyFormat = "$lhs `,` $rhs attr-dict"; } +class QuantifierOp : SMTOp, +]> { + let description = [{ + This operation represents the }] # summary # [{ as described in the + [SMT-LIB 2.6 standard](https://smtlib.cs.uiowa.edu/papers/smt-lib-reference-v2.6-r2021-05-12.pdf). + It is part of the language itself rather than a theory or logic. + + The operation specifies the name prefixes (as an optional attribute) and + types (as the types of the block arguments of the regions) of bound + variables that may be used in the 'body' of the operation. If a 'patterns' + region is specified, the block arguments must match the ones of the 'body' + region and (other than there) must be used at least once in the 'patterns' + region. It may also not contain any operations that bind variables, such as + quantifiers. While the 'body' region must always yield exactly one + `!smt.bool`-typed value, the 'patterns' region can yield an arbitrary number + (but at least one) of SMT values. + + The 'no_patterns' attribute is only allowed when no 'patterns' region is + specified and forbids the solver to generate and use patterns for this + quantifier. + + The 'weight' attribute indicates the importance of this quantifier being + instantiated compared to other quantifiers that may be present. The default + value is zero. + + Both the 'no_patterns' and 'weight' attributes are annotations to the + quantifiers body term. Annotations and attributes are described in the + standard in sections 3.4, and 3.6 (specifically 3.6.5). SMT-LIB allows + adding custom attributes to provide solvers with additional metadata, e.g., + hints such as above mentioned attributes. They are not part of the standard + themselves, but supported by common SMT solvers (e.g., Z3). + }]; + + let arguments = (ins DefaultValuedAttr:$weight, + UnitAttr:$noPattern, + OptionalAttr:$boundVarNames); + let regions = (region SizedRegion<1>:$body, + VariadicRegion>:$patterns); + let results = (outs BoolType:$result); + + let builders = [ + OpBuilder<(ins + "TypeRange":$boundVarTypes, + "function_ref":$bodyBuilder, + CArg<"std::optional>", "std::nullopt">:$boundVarNames, + CArg<"function_ref", + "{}">:$patternBuilder, + CArg<"uint32_t", "0">:$weight, + CArg<"bool", "false">:$noPattern)> + ]; + let skipDefaultBuilders = true; + + let assemblyFormat = [{ + ($boundVarNames^)? (`no_pattern` $noPattern^)? (`weight` $weight^)? + attr-dict-with-keyword $body (`patterns` $patterns^)? + }]; + + let hasVerifier = true; + let hasRegionVerifier = true; +} + +def ForallOp : QuantifierOp<"forall"> { let summary = "forall quantifier"; } +def ExistsOp : QuantifierOp<"exists"> { let summary = "exists quantifier"; } + #endif // CIRCT_DIALECT_SMT_SMTOPS_TD diff --git a/lib/Dialect/SMT/SMTOps.cpp b/lib/Dialect/SMT/SMTOps.cpp index aa8199d359..f047a495da 100644 --- a/lib/Dialect/SMT/SMTOps.cpp +++ b/lib/Dialect/SMT/SMTOps.cpp @@ -313,5 +313,156 @@ ParseResult IntConstantOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +//===----------------------------------------------------------------------===// +// ForallOp +//===----------------------------------------------------------------------===// + +template +static LogicalResult verifyQuantifierRegions(QuantifierOp op) { + if (op.getBoundVarNames() && + op.getBody().getNumArguments() != op.getBoundVarNames()->size()) + return op.emitOpError( + "number of bound variable names must match number of block arguments"); + if (op.getBody().front().getTerminator()->getNumOperands() != 1) + return op.emitOpError("must have exactly one yielded value"); + if (!isa( + op.getBody().front().getTerminator()->getOperand(0).getType())) + return op.emitOpError("yielded value must be of '!smt.bool' type"); + + for (auto regionWithIndex : llvm::enumerate(op.getPatterns())) { + unsigned i = regionWithIndex.index(); + Region ®ion = regionWithIndex.value(); + + if (op.getBody().getArgumentTypes() != region.getArgumentTypes()) + return op.emitOpError() + << "block argument number and types of the 'body' " + "and 'patterns' region #" + << i << " must match"; + if (region.front().getTerminator()->getNumOperands() < 1) + return op.emitOpError() << "'patterns' region #" << i + << " must have at least one yielded value"; + + // All operations in the 'patterns' region must be SMT operations. + auto result = region.walk([&](Operation *childOp) { + if (!isa(childOp->getDialect())) { + auto diag = op.emitOpError() + << "the 'patterns' region #" << i + << " may only contain SMT dialect operations"; + diag.attachNote(childOp->getLoc()) << "first non-SMT operation here"; + return WalkResult::interrupt(); + } + + // There may be no quantifier (or other variable binding) operations in + // the 'patterns' region. + if (isa(childOp)) { + auto diag = op.emitOpError() << "the 'patterns' region #" << i + << " must not contain " + "any variable binding operations"; + diag.attachNote(childOp->getLoc()) << "first violating operation here"; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + if (result.wasInterrupted()) + return failure(); + } + + return success(); +} + +template +static void buildQuantifier( + OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, + function_ref bodyBuilder, + std::optional> boundVarNames, + function_ref patternBuilder, + uint32_t weight, bool noPattern) { + odsState.addTypes(BoolType::get(odsBuilder.getContext())); + if (weight != 0) + odsState.getOrAddProperties().weight = + odsBuilder.getIntegerAttr(odsBuilder.getIntegerType(32), weight); + if (noPattern) + odsState.getOrAddProperties().noPattern = + odsBuilder.getUnitAttr(); + if (boundVarNames.has_value()) { + SmallVector boundVarNamesList; + for (StringRef str : *boundVarNames) + boundVarNamesList.emplace_back(odsBuilder.getStringAttr(str)); + odsState.getOrAddProperties().boundVarNames = + odsBuilder.getArrayAttr(boundVarNamesList); + } + { + OpBuilder::InsertionGuard guard(odsBuilder); + Region *region = odsState.addRegion(); + Block *block = odsBuilder.createBlock(region); + block->addArguments( + boundVarTypes, + SmallVector(boundVarTypes.size(), odsState.location)); + Value returnVal = + bodyBuilder(odsBuilder, odsState.location, block->getArguments()); + odsBuilder.create(odsState.location, returnVal); + } + if (patternBuilder) { + Region *region = odsState.addRegion(); + OpBuilder::InsertionGuard guard(odsBuilder); + Block *block = odsBuilder.createBlock(region); + block->addArguments( + boundVarTypes, + SmallVector(boundVarTypes.size(), odsState.location)); + ValueRange returnVals = + patternBuilder(odsBuilder, odsState.location, block->getArguments()); + odsBuilder.create(odsState.location, returnVals); + } +} + +LogicalResult ForallOp::verify() { + if (!getPatterns().empty() && getNoPattern()) + return emitOpError() << "patterns and the no_pattern attribute must not be " + "specified at the same time"; + + return success(); +} + +LogicalResult ForallOp::verifyRegions() { + return verifyQuantifierRegions(*this); +} + +void ForallOp::build( + OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, + function_ref bodyBuilder, + std::optional> boundVarNames, + function_ref patternBuilder, + uint32_t weight, bool noPattern) { + buildQuantifier(odsBuilder, odsState, boundVarTypes, bodyBuilder, + boundVarNames, patternBuilder, weight, noPattern); +} + +//===----------------------------------------------------------------------===// +// ExistsOp +//===----------------------------------------------------------------------===// + +LogicalResult ExistsOp::verify() { + if (!getPatterns().empty() && getNoPattern()) + return emitOpError() << "patterns and the no_pattern attribute must not be " + "specified at the same time"; + + return success(); +} + +LogicalResult ExistsOp::verifyRegions() { + return verifyQuantifierRegions(*this); +} + +void ExistsOp::build( + OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, + function_ref bodyBuilder, + std::optional> boundVarNames, + function_ref patternBuilder, + uint32_t weight, bool noPattern) { + buildQuantifier(odsBuilder, odsState, boundVarTypes, bodyBuilder, + boundVarNames, patternBuilder, weight, noPattern); +} + #define GET_OP_CLASSES #include "circt/Dialect/SMT/SMT.cpp.inc" diff --git a/test/Dialect/SMT/basic.mlir b/test/Dialect/SMT/basic.mlir index de42670294..0054b5c51f 100644 --- a/test/Dialect/SMT/basic.mlir +++ b/test/Dialect/SMT/basic.mlir @@ -84,3 +84,94 @@ func.func @core(%in: i8) { return } + +// CHECK-LABEL: func @quantifiers +func.func @quantifiers() { + // CHECK-NEXT: smt.forall ["a", "b"] weight 2 attributes {smt.some_attr} { + // CHECK-NEXT: ^bb0({{.*}}: !smt.bool, {{.*}}: !smt.bool): + // CHECK-NEXT: smt.eq + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } patterns { + // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool): + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: }, { + // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool): + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } + %0 = smt.forall ["a", "b"] weight 2 attributes {smt.some_attr} { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + %1 = smt.eq %arg2, %arg3 : !smt.bool + smt.yield %1 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + smt.yield %arg2, %arg3 : !smt.bool, !smt.bool + }, { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + smt.yield %arg2, %arg3 : !smt.bool, !smt.bool + } + + // CHECK-NEXT: smt.forall ["a", "b"] no_pattern attributes {smt.some_attr} { + // CHECK-NEXT: ^bb0({{.*}}: !smt.bool, {{.*}}: !smt.bool): + // CHECK-NEXT: smt.eq + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } + %1 = smt.forall ["a", "b"] no_pattern attributes {smt.some_attr} { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + %2 = smt.eq %arg2, %arg3 : !smt.bool + smt.yield %2 : !smt.bool + } + + // CHECK-NEXT: smt.forall { + // CHECK-NEXT: smt.constant + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } + %2 = smt.forall { + %3 = smt.constant true + smt.yield %3 : !smt.bool + } + + // CHECK-NEXT: smt.exists ["a", "b"] weight 2 attributes {smt.some_attr} { + // CHECK-NEXT: ^bb0({{.*}}: !smt.bool, {{.*}}: !smt.bool): + // CHECK-NEXT: smt.eq + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } patterns { + // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool, %{{.*}}: !smt.bool): + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: }, { + // CHECK-NEXT: ^bb0(%{{.*}}: !smt.bool): + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } + %3 = smt.exists ["a", "b"] weight 2 attributes {smt.some_attr} { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + %4 = smt.eq %arg2, %arg3 : !smt.bool + smt.yield %4 : !smt.bool {smt.some_attr} + } patterns { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + smt.yield %arg2, %arg3 : !smt.bool, !smt.bool + }, { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + smt.yield %arg2, %arg3 : !smt.bool, !smt.bool + } + + // CHECK-NEXT: smt.exists no_pattern attributes {smt.some_attr} { + // CHECK-NEXT: ^bb0({{.*}}: !smt.bool, {{.*}}: !smt.bool): + // CHECK-NEXT: smt.eq + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } + %4 = smt.exists no_pattern attributes {smt.some_attr} { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + %5 = smt.eq %arg2, %arg3 : !smt.bool + smt.yield %5 : !smt.bool {smt.some_attr} + } + + // CHECK-NEXT: smt.exists [] { + // CHECK-NEXT: smt.constant + // CHECK-NEXT: smt.yield %{{.*}} + // CHECK-NEXT: } + %5 = smt.exists [] { + %6 = smt.constant true + smt.yield %6 : !smt.bool + } + + return +} diff --git a/test/Dialect/SMT/core-errors.mlir b/test/Dialect/SMT/core-errors.mlir index 8f172681b1..feda0c4335 100644 --- a/test/Dialect/SMT/core-errors.mlir +++ b/test/Dialect/SMT/core-errors.mlir @@ -155,3 +155,261 @@ func.func @ite_type_mismatch(%a: !smt.bool, %b: !smt.bv<32>) { "smt.ite"(%a, %a, %b) {} : (!smt.bool, !smt.bool, !smt.bv<32>) -> !smt.bool return } + +// ----- + +func.func @forall_number_of_decl_names_must_match_num_args() { + // expected-error @below {{number of bound variable names must match number of block arguments}} + %1 = smt.forall ["a"] { + ^bb0(%arg2: !smt.int, %arg3: !smt.int): + %2 = smt.eq %arg2, %arg3 : !smt.int + smt.yield %2 : !smt.bool + } + return +} + +// ----- + +func.func @exists_number_of_decl_names_must_match_num_args() { + // expected-error @below {{number of bound variable names must match number of block arguments}} + %1 = smt.exists ["a"] { + ^bb0(%arg2: !smt.int, %arg3: !smt.int): + %2 = smt.eq %arg2, %arg3 : !smt.int + smt.yield %2 : !smt.bool + } + return +} + +// ----- + +func.func @forall_yield_must_have_exactly_one_bool_value() { + // expected-error @below {{yielded value must be of '!smt.bool' type}} + %1 = smt.forall ["a", "b"] { + ^bb0(%arg2: !smt.int, %arg3: !smt.int): + %2 = smt.int.add %arg2, %arg3 + smt.yield %2 : !smt.int + } + return +} + +// ----- + +func.func @forall_yield_must_have_exactly_one_bool_value() { + // expected-error @below {{must have exactly one yielded value}} + %1 = smt.forall ["a", "b"] { + ^bb0(%arg2: !smt.int, %arg3: !smt.int): + smt.yield + } + return +} + +// ----- + +func.func @exists_yield_must_have_exactly_one_bool_value() { + // expected-error @below {{yielded value must be of '!smt.bool' type}} + %1 = smt.exists ["a", "b"] { + ^bb0(%arg2: !smt.int, %arg3: !smt.int): + %2 = smt.int.add %arg2, %arg3 + smt.yield %2 : !smt.int + } + return +} + +// ----- + +func.func @exists_yield_must_have_exactly_one_bool_value() { + // expected-error @below {{must have exactly one yielded value}} + %1 = smt.exists ["a", "b"] { + ^bb0(%arg2: !smt.int, %arg3: !smt.int): + smt.yield + } + return +} + +// ----- + +func.func @exists_patterns_region_and_no_patterns_attr_are_mutually_exclusive() { + // expected-error @below {{patterns and the no_pattern attribute must not be specified at the same time}} + %1 = smt.exists ["a"] no_pattern { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } + return +} + +// ----- + +func.func @forall_patterns_region_and_no_patterns_attr_are_mutually_exclusive() { + // expected-error @below {{patterns and the no_pattern attribute must not be specified at the same time}} + %1 = smt.forall ["a"] no_pattern { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } + return +} + +// ----- + +func.func @exists_patterns_region_num_args() { + // expected-error @below {{block argument number and types of the 'body' and 'patterns' region #0 must match}} + %1 = smt.exists ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + smt.yield %arg2, %arg3 : !smt.bool, !smt.bool + } + return +} + +// ----- + +func.func @forall_patterns_region_num_args() { + // expected-error @below {{block argument number and types of the 'body' and 'patterns' region #0 must match}} + %1 = smt.forall ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool, %arg3: !smt.bool): + smt.yield %arg2, %arg3 : !smt.bool, !smt.bool + } + return +} + +// ----- + +func.func @exists_patterns_region_at_least_one_yielded_value() { + // expected-error @below {{'patterns' region #0 must have at least one yielded value}} + %1 = smt.exists ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + smt.yield + } + return +} + +// ----- + +func.func @forall_patterns_region_at_least_one_yielded_value() { + // expected-error @below {{'patterns' region #0 must have at least one yielded value}} + %1 = smt.forall ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + smt.yield + } + return +} + +// ----- + +func.func @exists_all_pattern_regions_tested() { + // expected-error @below {{'patterns' region #1 must have at least one yielded value}} + %1 = smt.exists ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + }, { + ^bb0(%arg2: !smt.bool): + smt.yield + } + return +} + +// ----- + +func.func @forall_all_pattern_regions_tested() { + // expected-error @below {{'patterns' region #1 must have at least one yielded value}} + %1 = smt.forall ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + }, { + ^bb0(%arg2: !smt.bool): + smt.yield + } + return +} + +// ----- + +func.func @exists_patterns_region_no_non_smt_operations() { + // expected-error @below {{'patterns' region #0 may only contain SMT dialect operations}} + %1 = smt.exists ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + // expected-note @below {{first non-SMT operation here}} + hw.constant 0 : i32 + smt.yield %arg2 : !smt.bool + } + return +} + +// ----- + +func.func @forall_patterns_region_no_non_smt_operations() { + // expected-error @below {{'patterns' region #0 may only contain SMT dialect operations}} + %1 = smt.forall ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + // expected-note @below {{first non-SMT operation here}} + hw.constant 0 : i32 + smt.yield %arg2 : !smt.bool + } + return +} + +// ----- + +func.func @exists_patterns_region_no_var_binding_operations() { + // expected-error @below {{'patterns' region #0 must not contain any variable binding operations}} + %1 = smt.exists ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + // expected-note @below {{first violating operation here}} + smt.exists ["b"] { + ^bb0(%arg3: !smt.bool): + smt.yield %arg3 : !smt.bool + } + smt.yield %arg2 : !smt.bool + } + return +} + +// ----- + +func.func @forall_patterns_region_no_var_binding_operations() { + // expected-error @below {{'patterns' region #0 must not contain any variable binding operations}} + %1 = smt.forall ["a"] { + ^bb0(%arg2: !smt.bool): + smt.yield %arg2 : !smt.bool + } patterns { + ^bb0(%arg2: !smt.bool): + // expected-note @below {{first violating operation here}} + smt.forall ["b"] { + ^bb0(%arg3: !smt.bool): + smt.yield %arg3 : !smt.bool + } + smt.yield %arg2 : !smt.bool + } + return +} diff --git a/unittests/Dialect/SMT/CMakeLists.txt b/unittests/Dialect/SMT/CMakeLists.txt index 2e9ba7eef8..cf89b70282 100644 --- a/unittests/Dialect/SMT/CMakeLists.txt +++ b/unittests/Dialect/SMT/CMakeLists.txt @@ -1,5 +1,6 @@ add_circt_unittest(CIRCTSMTTests AttributeTest.cpp + QuantifierTest.cpp ) target_link_libraries(CIRCTSMTTests diff --git a/unittests/Dialect/SMT/QuantifierTest.cpp b/unittests/Dialect/SMT/QuantifierTest.cpp new file mode 100644 index 0000000000..dd115d4a3f --- /dev/null +++ b/unittests/Dialect/SMT/QuantifierTest.cpp @@ -0,0 +1,188 @@ +//===- QuantifierTest.cpp - SMT quantifier operation unit tests -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/SMT/SMTOps.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace circt; +using namespace smt; + +namespace { + +//===----------------------------------------------------------------------===// +// Test custom builders of ExistsOp +//===----------------------------------------------------------------------===// + +TEST(QuantifierTest, ExistsBuilderWithPattern) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + OpBuilder builder(&context); + auto boolTy = BoolType::get(&context); + + ExistsOp existsOp = builder.create( + loc, TypeRange{boolTy, boolTy}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return builder.create(loc, boundVars); + }, + std::nullopt, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return boundVars; + }, + /*weight=*/2); + + SmallVector buffer; + llvm::raw_svector_ostream stream(buffer); + existsOp.print(stream); + + ASSERT_STREQ( + stream.str().str().c_str(), + "%0 = smt.exists weight 2 {\n^bb0(%arg0: !smt.bool, " + "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : " + "!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n " + "smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n"); +} + +TEST(QuantifierTest, ExistsBuilderNoPattern) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + OpBuilder builder(&context); + auto boolTy = BoolType::get(&context); + + ExistsOp existsOp = builder.create( + loc, TypeRange{boolTy, boolTy}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return builder.create(loc, boundVars); + }, + ArrayRef{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true); + + SmallVector buffer; + llvm::raw_svector_ostream stream(buffer); + existsOp.print(stream); + + ASSERT_STREQ(stream.str().str().c_str(), + "%0 = smt.exists [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: " + "!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n " + "smt.yield %0 : !smt.bool\n}\n"); +} + +TEST(QuantifierTest, ExistsBuilderDefault) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + OpBuilder builder(&context); + auto boolTy = BoolType::get(&context); + + ExistsOp existsOp = builder.create( + loc, TypeRange{boolTy, boolTy}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return builder.create(loc, boundVars); + }, + ArrayRef{"a", "b"}); + + SmallVector buffer; + llvm::raw_svector_ostream stream(buffer); + existsOp.print(stream); + + ASSERT_STREQ(stream.str().str().c_str(), + "%0 = smt.exists [\"a\", \"b\"] {\n^bb0(%arg0: !smt.bool, " + "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield " + "%0 : !smt.bool\n}\n"); +} + +//===----------------------------------------------------------------------===// +// Test custom builders of ForallOp +//===----------------------------------------------------------------------===// + +TEST(QuantifierTest, ForallBuilderWithPattern) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + OpBuilder builder(&context); + auto boolTy = BoolType::get(&context); + + ForallOp forallOp = builder.create( + loc, TypeRange{boolTy, boolTy}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return builder.create(loc, boundVars); + }, + ArrayRef{"a", "b"}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return boundVars; + }, + /*weight=*/2); + + SmallVector buffer; + llvm::raw_svector_ostream stream(buffer); + forallOp.print(stream); + + ASSERT_STREQ( + stream.str().str().c_str(), + "%0 = smt.forall [\"a\", \"b\"] weight 2 {\n^bb0(%arg0: !smt.bool, " + "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : " + "!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n " + "smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n"); +} + +TEST(QuantifierTest, ForallBuilderNoPattern) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + OpBuilder builder(&context); + auto boolTy = BoolType::get(&context); + + ForallOp forallOp = builder.create( + loc, TypeRange{boolTy, boolTy}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return builder.create(loc, boundVars); + }, + ArrayRef{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true); + + SmallVector buffer; + llvm::raw_svector_ostream stream(buffer); + forallOp.print(stream); + + ASSERT_STREQ(stream.str().str().c_str(), + "%0 = smt.forall [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: " + "!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n " + "smt.yield %0 : !smt.bool\n}\n"); +} + +TEST(QuantifierTest, ForallBuilderDefault) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + OpBuilder builder(&context); + auto boolTy = BoolType::get(&context); + + ForallOp forallOp = builder.create( + loc, TypeRange{boolTy, boolTy}, + [](OpBuilder &builder, Location loc, ValueRange boundVars) { + return builder.create(loc, boundVars); + }, + std::nullopt); + + SmallVector buffer; + llvm::raw_svector_ostream stream(buffer); + forallOp.print(stream); + + ASSERT_STREQ(stream.str().str().c_str(), + "%0 = smt.forall {\n^bb0(%arg0: !smt.bool, " + "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield " + "%0 : !smt.bool\n}\n"); +} + +} // namespace