[SMT] Add lowering to LLVM IR (#6902)

This commit is contained in:
Martin Erhart 2024-04-20 10:11:47 +02:00 committed by GitHub
parent 3a08dce574
commit 5cf1ff57f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 1352 additions and 2 deletions

View File

@ -41,6 +41,7 @@
#include "circt/Conversion/MooreToCore.h"
#include "circt/Conversion/PipelineToHW.h"
#include "circt/Conversion/SCFToCalyx.h"
#include "circt/Conversion/SMTToZ3LLVM.h"
#include "circt/Conversion/SeqToSV.h"
#include "circt/Conversion/SimToSV.h"
#include "circt/Conversion/VerifToSMT.h"

View File

@ -729,6 +729,23 @@ def LowerArcToLLVM : Pass<"lower-arc-to-llvm", "mlir::ModuleOp"> {
];
}
//===----------------------------------------------------------------------===//
// ConvertSMTToZ3LLVM
//===----------------------------------------------------------------------===//
def LowerSMTToZ3LLVM : Pass<"lower-smt-to-z3-llvm", "mlir::ModuleOp"> {
let summary = "Lower the SMT dialect to LLVM IR calling the Z3 API";
let dependentDialects = [
"smt::SMTDialect", "mlir::LLVM::LLVMDialect", "mlir::scf::SCFDialect",
"mlir::cf::ControlFlowDialect"
];
let options = [
Option<"debug", "debug", "bool", "false",
"Insert additional printf calls printing the solver's state to "
"stdout (e.g. at check-sat operations) for debugging purposes">,
];
}
//===----------------------------------------------------------------------===//
// ConvertSeqToSV
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,80 @@
//===- SMTToZ3LLVM.h --------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CIRCT_CONVERSION_SMTTOZ3LLVM_H
#define CIRCT_CONVERSION_SMTTOZ3LLVM_H
#include "circt/Support/LLVM.h"
#include "circt/Support/Namespace.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/ADT/StringRef.h"
#include <memory>
namespace circt {
#define GEN_PASS_DECL_LOWERSMTTOZ3LLVM
#include "circt/Conversion/Passes.h.inc"
/// A symbol cache for LLVM globals and functions relevant to SMT lowering
/// patterns.
struct SMTGlobalsHandler {
/// Creates the LLVM global operations to store the pointers to the solver and
/// the context and returns a 'SMTGlobalHandler' initialized with those new
/// globals.
static SMTGlobalsHandler create(OpBuilder &builder, ModuleOp module);
/// Initializes the caches and keeps track of the given globals to store the
/// pointers to the SMT solver and context. It is assumed that the passed
/// global operations are of the correct (or at least compatible) form. E.g.,
/// ```
/// llvm.mlir.global internal @ctx() {alignment = 8 : i64} : !llvm.ptr {
/// %0 = llvm.mlir.zero : !llvm.ptr
/// llvm.return %0 : !llvm.ptr
/// }
/// ```
SMTGlobalsHandler(ModuleOp module, mlir::LLVM::GlobalOp solver,
mlir::LLVM::GlobalOp ctx);
/// Initializes the caches and keeps track of the given globals to store the
/// pointers to the SMT solver and context. It is assumed that the passed
/// global operations are of the correct (or at least compatible) form. E.g.,
/// ```
/// llvm.mlir.global internal @ctx() {alignment = 8 : i64} : !llvm.ptr {
/// %0 = llvm.mlir.zero : !llvm.ptr
/// llvm.return %0 : !llvm.ptr
/// }
/// ```
SMTGlobalsHandler(Namespace &&names, mlir::LLVM::GlobalOp solver,
mlir::LLVM::GlobalOp ctx);
/// The global storing the pointer to the SMT solver object currently active.
const mlir::LLVM::GlobalOp solver;
/// The global storing the pointer to the SMT context object currently active.
const mlir::LLVM::GlobalOp ctx;
Namespace names;
DenseMap<StringRef, mlir::LLVM::LLVMFuncOp> funcMap;
DenseMap<Block *, Value> ctxCache;
DenseMap<Block *, Value> solverCache;
DenseMap<StringRef, mlir::LLVM::GlobalOp> stringCache;
};
/// Populate the given type converter with the SMT to LLVM type conversions.
void populateSMTToZ3LLVMTypeConverter(TypeConverter &converter);
/// Add the SMT to LLVM IR conversion patterns to 'patterns'. A
/// 'SMTGlobalHandler' object has to be passed which acts as a symbol cache for
/// LLVM globals and functions.
void populateSMTToZ3LLVMConversionPatterns(
RewritePatternSet &patterns, TypeConverter &converter,
SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options);
} // namespace circt
#endif // CIRCT_CONVERSION_SMTTOZ3LLVM_H

View File

@ -247,7 +247,7 @@ def EqOp : SMTOp<"eq", [Pure, SameTypeOperands]> {
`and (and (= a b) (= b c)) (= c d)`.
}];
let arguments = (ins Variadic<AnySMTType>:$inputs);
let arguments = (ins Variadic<AnyNonFuncSMTType>:$inputs);
let results = (outs BoolType:$result);
let builders = [
@ -288,7 +288,7 @@ def DistinctOp : SMTOp<"distinct", [Pure, SameTypeOperands]> {
```
}];
let arguments = (ins Variadic<AnySMTType>:$inputs);
let arguments = (ins Variadic<AnyNonFuncSMTType>:$inputs);
let results = (outs BoolType:$result);
let builders = [

View File

@ -13,6 +13,10 @@ set(CIRCT_INTEGRATION_TEST_DEPENDS
handshake-runner
)
if (MLIR_ENABLE_EXECUTION_ENGINE)
list(APPEND CIRCT_INTEGRATION_TEST_DEPENDS mlir-cpu-runner)
endif()
# If Python bindings are available to build then enable the tests.
if(CIRCT_BINDINGS_PYTHON_ENABLED)
list(APPEND CIRCT_INTEGRATION_TEST_DEPENDS CIRCTPythonModules)

View File

@ -0,0 +1,96 @@
// RUN: circt-opt %s --lower-smt-to-z3-llvm --canonicalize | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void --shared-libs=%libz3 | \
// RUN: FileCheck %s
// RUN: circt-opt %s --lower-smt-to-z3-llvm=debug=true --canonicalize | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void --shared-libs=%libz3 | \
// RUN: FileCheck %s
// REQUIRES: libz3
// REQUIRES: mlir-cpu-runner
func.func @entry() {
%false = llvm.mlir.constant(0 : i1) : i1
// CHECK: sat
// CHECK: Res: 1
smt.solver () : () -> () {
%c42_bv65 = smt.bv.constant #smt.bv<42> : !smt.bv<65>
%1 = smt.declare_fun : !smt.bv<65>
%2 = smt.declare_fun "a" : !smt.bv<65>
%3 = smt.eq %c42_bv65, %1, %2 : !smt.bv<65>
func.call @check(%3) : (!smt.bool) -> ()
smt.yield
}
// CHECK: sat
// CHECK: Res: 1
smt.solver () : () -> () {
%c0_bv8 = smt.bv.constant #smt.bv<0> : !smt.bv<8>
%c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8>
%2 = smt.distinct %c0_bv8, %c-1_bv8 : !smt.bv<8>
func.call @check(%2) : (!smt.bool) -> ()
smt.yield
}
// CHECK: sat
// CHECK: Res: 1
smt.solver () : () -> () {
%0 = smt.declare_fun : !smt.func<(!smt.bv<4>) !smt.array<[!smt.int -> !smt.sort<"uninterpreted_sort"[!smt.bool]>]>>
%1 = smt.declare_fun : !smt.func<(!smt.bv<4>) !smt.array<[!smt.int -> !smt.sort<"uninterpreted_sort"[!smt.bool]>]>>
%c0_bv4 = smt.bv.constant #smt.bv<0> : !smt.bv<4>
%2 = smt.apply_func %0(%c0_bv4) : !smt.func<(!smt.bv<4>) !smt.array<[!smt.int -> !smt.sort<"uninterpreted_sort"[!smt.bool]>]>>
%3 = smt.apply_func %1(%c0_bv4) : !smt.func<(!smt.bv<4>) !smt.array<[!smt.int -> !smt.sort<"uninterpreted_sort"[!smt.bool]>]>>
%4 = smt.eq %2, %3 : !smt.array<[!smt.int -> !smt.sort<"uninterpreted_sort"[!smt.bool]>]>
func.call @check(%4) : (!smt.bool) -> ()
smt.yield
}
// CHECK: unsat
// CHECK: Res: -1
smt.solver (%false) : (i1) -> () {
^bb0(%arg0: i1):
%c0_bv32 = smt.bv.constant #smt.bv<0> : !smt.bv<32>
%0 = scf.if %arg0 -> !smt.bv<32> {
%1 = smt.declare_fun : !smt.bv<32>
scf.yield %1 : !smt.bv<32>
} else {
%c1_bv32 = smt.bv.constant #smt.bv<-1> : !smt.bv<32>
scf.yield %c1_bv32 : !smt.bv<32>
}
%1 = smt.eq %c0_bv32, %0 : !smt.bv<32>
func.call @check(%1) : (!smt.bool) -> ()
smt.yield
}
return
}
func.func @check(%expr: !smt.bool) {
smt.assert %expr
%0 = smt.check sat {
%1 = llvm.mlir.addressof @sat : !llvm.ptr
llvm.call @printf(%1) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr) -> i32
%c1 = llvm.mlir.constant(1 : i32) : i32
smt.yield %c1 : i32
} unknown {
%1 = llvm.mlir.addressof @unknown : !llvm.ptr
llvm.call @printf(%1) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr) -> i32
%c0 = llvm.mlir.constant(0 : i32) : i32
smt.yield %c0 : i32
} unsat {
%1 = llvm.mlir.addressof @unsat : !llvm.ptr
llvm.call @printf(%1) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr) -> i32
%c-1 = llvm.mlir.constant(-1 : i32) : i32
smt.yield %c-1 : i32
} -> i32
%1 = llvm.mlir.addressof @res : !llvm.ptr
llvm.call @printf(%1, %0) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr, i32) -> i32
return
}
llvm.func @printf(!llvm.ptr, ...) -> i32
llvm.mlir.global private constant @res("Res: %d\n\00") {addr_space = 0 : i32}
llvm.mlir.global private constant @sat("sat\n\00") {addr_space = 0 : i32}
llvm.mlir.global private constant @unsat("unsat\n\00") {addr_space = 0 : i32}
llvm.mlir.global private constant @unknown("unknown\n\00") {addr_space = 0 : i32}

View File

@ -207,6 +207,16 @@ if config.z3_path != "":
tools.append('z3')
config.available_features.add('z3')
# Enable libz3 if it has been detected.
if config.z3_library != "":
tools.append(ToolSubst(f"%libz3", config.z3_library))
config.available_features.add('libz3')
# Add mlir-cpu-runner if the execution engine is built.
if config.mlir_enable_execution_engine:
config.available_features.add('mlir-cpu-runner')
tools.append('mlir-cpu-runner')
# Add circt-verilog if the Slang frontend is enabled.
if config.slang_frontend_enabled:
config.available_features.add('slang')

View File

@ -54,6 +54,8 @@ config.bindings_python_enabled = @CIRCT_BINDINGS_PYTHON_ENABLED@
config.bindings_tcl_enabled = @CIRCT_BINDINGS_TCL_ENABLED@
config.lec_enabled = "@CIRCT_LEC_ENABLED@"
config.z3_path = "@Z3_PATH@"
config.z3_library = "@Z3_LIBRARIES@"
config.mlir_enable_execution_engine = "@MLIR_ENABLE_EXECUTION_ENGINE@"
config.slang_frontend_enabled = "@CIRCT_SLANG_FRONTEND_ENABLED@"
config.arcilator_jit_enabled = @ARCILATOR_JIT_ENABLED@
config.driver = "@CIRCT_SOURCE_DIR@/tools/circt-rtl-sim/driver.cpp"

View File

@ -32,6 +32,7 @@ add_mlir_public_c_api_library(CIRCTCAPIConversion
CIRCTSCFToCalyx
CIRCTSeqToSV
CIRCTSimToSV
CIRCTSMTToZ3LLVM
CIRCTCFToHandshake
CIRCTVerifToSMT
CIRCTVerifToSV

View File

@ -29,6 +29,7 @@ add_subdirectory(SeqToSV)
add_subdirectory(SimToSV)
add_subdirectory(CFToHandshake)
add_subdirectory(VerifToSMT)
add_subdirectory(SMTToZ3LLVM)
add_subdirectory(VerifToSV)
add_subdirectory(CalyxNative)

View File

@ -0,0 +1,18 @@
add_circt_conversion_library(CIRCTSMTToZ3LLVM
LowerSMTToZ3LLVM.cpp
DEPENDS
CIRCTConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
CIRCTSMT
CIRCTSupport
MLIRLLVMCommonConversion
MLIRSCFToControlFlow
MLIRControlFlowToLLVM
MLIRFuncToLLVM
MLIRTransforms
)

View File

@ -0,0 +1,925 @@
//===- LowerSMTToZ3LLVM.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "circt/Conversion/SMTToZ3LLVM.h"
#include "circt/Dialect/SMT/SMTOps.h"
#include "circt/Support/Namespace.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "lower-smt-to-z3-llvm"
namespace circt {
#define GEN_PASS_DEF_LOWERSMTTOZ3LLVM
#include "circt/Conversion/Passes.h.inc"
} // namespace circt
using namespace mlir;
using namespace circt;
using namespace smt;
//===----------------------------------------------------------------------===//
// SMTGlobalHandler implementation
//===----------------------------------------------------------------------===//
SMTGlobalsHandler SMTGlobalsHandler::create(OpBuilder &builder,
ModuleOp module) {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(module.getBody());
SymbolCache symCache;
symCache.addDefinitions(module);
Namespace names;
names.add(symCache);
Location loc = module.getLoc();
auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
auto createGlobal = [&](StringRef namePrefix) {
auto global = builder.create<LLVM::GlobalOp>(
loc, ptrTy, false, LLVM::Linkage::Internal, names.newName(namePrefix),
Attribute{}, /*alignment=*/8);
OpBuilder::InsertionGuard g(builder);
builder.createBlock(&global.getInitializer());
Value res = builder.create<LLVM::ZeroOp>(loc, ptrTy);
builder.create<LLVM::ReturnOp>(loc, res);
return global;
};
auto ctxGlobal = createGlobal("ctx");
auto solverGlobal = createGlobal("solver");
return SMTGlobalsHandler(std::move(names), solverGlobal, ctxGlobal);
}
SMTGlobalsHandler::SMTGlobalsHandler(Namespace &&names,
mlir::LLVM::GlobalOp solver,
mlir::LLVM::GlobalOp ctx)
: solver(solver), ctx(ctx), names(names) {}
SMTGlobalsHandler::SMTGlobalsHandler(ModuleOp module,
mlir::LLVM::GlobalOp solver,
mlir::LLVM::GlobalOp ctx)
: solver(solver), ctx(ctx) {
SymbolCache symCache;
symCache.addDefinitions(module);
names.add(symCache);
}
//===----------------------------------------------------------------------===//
// Lowering Pattern Base
//===----------------------------------------------------------------------===//
namespace {
template <typename OpTy>
class SMTLoweringPattern : public OpConversionPattern<OpTy> {
public:
SMTLoweringPattern(const TypeConverter &typeConverter, MLIRContext *context,
SMTGlobalsHandler &globals,
const LowerSMTToZ3LLVMOptions &options)
: OpConversionPattern<OpTy>(typeConverter, context), globals(globals),
options(options) {}
private:
Value buildGlobalPtrToGlobal(OpBuilder &builder, Location loc,
LLVM::GlobalOp global,
DenseMap<Block *, Value> &cache) const {
Block *block = builder.getBlock();
if (auto iter = cache.find(block); iter != cache.end())
return iter->getSecond();
Value globalAddr = builder.create<LLVM::AddressOfOp>(loc, global);
return cache[block] = builder.create<LLVM::LoadOp>(
loc, LLVM::LLVMPointerType::get(builder.getContext()),
globalAddr);
}
protected:
/// A convenience function to get the pointer to the context from the 'global'
/// operation. The result is cached for each basic block, i.e., it is assumed
/// that this function is never called in the same basic block again at a
/// location (insertion point of the 'builder') not dominating all previous
/// locations this function was called at.
Value buildContextPtr(OpBuilder &builder, Location loc) const {
return buildGlobalPtrToGlobal(builder, loc, globals.ctx, globals.ctxCache);
}
/// A convenience function to get the pointer to the solver from the 'global'
/// operation. The result is cached for each basic block, i.e., it is assumed
/// that this function is never called in the same basic block again at a
/// location (insertion point of the 'builder') not dominating all previous
/// locations this function was called at.
Value buildSolverPtr(OpBuilder &builder, Location loc) const {
return buildGlobalPtrToGlobal(builder, loc, globals.solver,
globals.solverCache);
}
/// Create a `llvm.call` operation to a function with the given 'name' and
/// 'type'. If there does not already exist a (external) function with that
/// name create a matching external function declaration.
LLVM::CallOp buildCall(OpBuilder &builder, Location loc, StringRef name,
LLVM::LLVMFunctionType funcType,
ValueRange args) const {
auto &funcOp = globals.funcMap[name];
if (!funcOp) {
OpBuilder::InsertionGuard guard(builder);
auto module =
builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
builder.setInsertionPointToEnd(module.getBody());
funcOp = LLVM::lookupOrCreateFn(module, name, funcType.getParams(),
funcType.getReturnType(),
funcType.getVarArg());
}
return builder.create<LLVM::CallOp>(loc, funcOp, args);
}
/// Build a global constant for the given string and construct an 'addressof'
/// operation at the current 'builder' insertion point to get a pointer to it.
/// Multiple calls with the same string will reuse the same global. It is
/// guaranteed that the symbol of the global will be unique.
Value buildString(OpBuilder &builder, Location loc, StringRef str) const {
auto &global = globals.stringCache[str];
if (!global) {
OpBuilder::InsertionGuard guard(builder);
auto module =
builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
builder.setInsertionPointToEnd(module.getBody());
auto arrayTy =
LLVM::LLVMArrayType::get(builder.getI8Type(), str.size() + 1);
auto strAttr = builder.getStringAttr(str.str() + '\00');
global = builder.create<LLVM::GlobalOp>(
loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
globals.names.newName("str"), strAttr);
}
return builder.create<LLVM::AddressOfOp>(loc, global);
}
/// Most API functions require a pointer to the the Z3 context object as the
/// first argument. This helper function prepends this pointer value to the
/// call for convenience.
LLVM::CallOp buildAPICallWithContext(OpBuilder &builder, Location loc,
StringRef name, Type returnType,
ValueRange args = {}) const {
auto ctx = buildContextPtr(builder, loc);
SmallVector<Value> arguments;
arguments.emplace_back(ctx);
arguments.append(SmallVector<Value>(args));
return buildCall(
builder, loc, name,
LLVM::LLVMFunctionType::get(
returnType, SmallVector<Type>(ValueRange(arguments).getTypes())),
arguments);
}
/// Most API functions we need to call return a 'Z3_AST' object which is a
/// pointer in LLVM. This helper function simplifies calling those API
/// functions.
Value buildPtrAPICall(OpBuilder &builder, Location loc, StringRef name,
ValueRange args = {}) const {
return buildAPICallWithContext(
builder, loc, name,
LLVM::LLVMPointerType::get(builder.getContext()), args)
->getResult(0);
}
/// Build a value representing the SMT sort given with 'type'.
Value buildSort(OpBuilder &builder, Location loc, Type type) const {
// NOTE: if a type not handled by this switch is passed, an assertion will
// be triggered.
return TypeSwitch<Type, Value>(type)
.Case([&](smt::IntType ty) {
return buildPtrAPICall(builder, loc, "Z3_mk_int_sort");
})
.Case([&](smt::BitVectorType ty) {
Value bitwidth = builder.create<LLVM::ConstantOp>(
loc, builder.getI32Type(), ty.getWidth());
return buildPtrAPICall(builder, loc, "Z3_mk_bv_sort", {bitwidth});
})
.Case([&](smt::BoolType ty) {
return buildPtrAPICall(builder, loc, "Z3_mk_bool_sort");
})
.Case([&](smt::SortType ty) {
Value str = buildString(builder, loc, ty.getIdentifier());
Value sym =
buildPtrAPICall(builder, loc, "Z3_mk_string_symbol", {str});
return buildPtrAPICall(builder, loc, "Z3_mk_uninterpreted_sort",
{sym});
})
.Case([&](smt::ArrayType ty) {
return buildPtrAPICall(builder, loc, "Z3_mk_array_sort",
{buildSort(builder, loc, ty.getDomainType()),
buildSort(builder, loc, ty.getRangeType())});
});
}
SMTGlobalsHandler &globals;
const LowerSMTToZ3LLVMOptions &options;
};
//===----------------------------------------------------------------------===//
// Lowering Patterns
//===----------------------------------------------------------------------===//
/// The 'smt.declare_fun' operation is used to declare both constants and
/// functions. The Z3 API, however, uses two different functions. Therefore,
/// depending on the result type of this operation, one of the following two
/// API functions is used to create the symbolic value:
/// ```
/// Z3_ast Z3_API Z3_mk_fresh_const(Z3_context c, Z3_string prefix, Z3_sort ty);
/// Z3_func_decl Z3_API Z3_mk_fresh_func_decl(
/// Z3_context c, Z3_string prefix, unsigned domain_size,
/// Z3_sort const domain[], Z3_sort range);
/// ```
struct DeclareFunOpLowering : public SMTLoweringPattern<DeclareFunOp> {
using SMTLoweringPattern::SMTLoweringPattern;
LogicalResult
matchAndRewrite(DeclareFunOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
// Create the name prefix.
Value prefix;
if (adaptor.getNamePrefix())
prefix = buildString(rewriter, loc, *adaptor.getNamePrefix());
else
prefix = rewriter.create<LLVM::ZeroOp>(
loc, LLVM::LLVMPointerType::get(getContext()));
// Handle the constant value case.
if (!isa<SMTFuncType>(op.getType())) {
Value sort = buildSort(rewriter, loc, op.getType());
Value constDecl =
buildPtrAPICall(rewriter, loc, "Z3_mk_fresh_const", {prefix, sort});
rewriter.replaceOp(op, constDecl);
return success();
}
// Otherwise, we declare a function.
Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
auto funcType = cast<SMTFuncType>(op.getResult().getType());
Value rangeSort = buildSort(rewriter, loc, funcType.getRangeType());
Type arrTy =
LLVM::LLVMArrayType::get(llvmPtrTy, funcType.getDomainTypes().size());
Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
for (auto [i, ty] : llvm::enumerate(funcType.getDomainTypes())) {
Value sort = buildSort(rewriter, loc, ty);
domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, sort, i);
}
Value one =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
Value domainStorage =
rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
Value domainSize = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), funcType.getDomainTypes().size());
Value decl =
buildPtrAPICall(rewriter, loc, "Z3_mk_fresh_func_decl",
{prefix, domainSize, domainStorage, rangeSort});
rewriter.replaceOp(op, decl);
return success();
}
};
/// Lower the 'smt.apply_func' operation to Z3 API calls of the form:
/// ```
/// Z3_ast Z3_API Z3_mk_app(Z3_context c, Z3_func_decl d,
/// unsigned num_args, Z3_ast const args[]);
/// ```
struct ApplyFuncOpLowering : public SMTLoweringPattern<ApplyFuncOp> {
using SMTLoweringPattern::SMTLoweringPattern;
LogicalResult
matchAndRewrite(ApplyFuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
Type arrTy = LLVM::LLVMArrayType::get(llvmPtrTy, adaptor.getArgs().size());
// Create an array of the function arguments.
Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
for (auto [i, arg] : llvm::enumerate(adaptor.getArgs()))
domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, arg, i);
// Store the array on the stack.
Value one =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
Value domainStorage =
rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
// Call the API function with a pointer to the function, the number of
// arguments, and the pointer to the arguments stored on the stack.
Value domainSize = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), adaptor.getArgs().size());
Value returnVal =
buildPtrAPICall(rewriter, loc, "Z3_mk_app",
{adaptor.getFunc(), domainSize, domainStorage});
rewriter.replaceOp(op, returnVal);
return success();
}
};
/// Lower the `smt.bv.constant` operation to either
/// ```
/// Z3_ast Z3_API Z3_mk_unsigned_int64(Z3_context c, uint64_t v, Z3_sort ty);
/// ```
/// if the bit-vector fits into a 64-bit integer or convert it to a string and
/// use the sligtly slower but arbitrary precision API function:
/// ```
/// Z3_ast Z3_API Z3_mk_numeral(Z3_context c, Z3_string numeral, Z3_sort ty);
/// ```
/// Note that there is also an API function taking an array of booleans, and
/// while those are typically compiled to 'i8' in LLVM they don't necessarily
/// have to (I think).
struct BVConstantOpLowering : public SMTLoweringPattern<smt::BVConstantOp> {
using SMTLoweringPattern::SMTLoweringPattern;
LogicalResult
matchAndRewrite(smt::BVConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
unsigned width = op.getType().getWidth();
auto bvSort = buildSort(rewriter, loc, op.getResult().getType());
APInt val = adaptor.getValue().getValue();
if (width <= 64) {
Value bvConst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI64Type(), val.getZExtValue());
Value res = buildPtrAPICall(rewriter, loc, "Z3_mk_unsigned_int64",
{bvConst, bvSort});
rewriter.replaceOp(op, res);
return success();
}
std::string str;
llvm::raw_string_ostream stream(str);
stream << val;
Value bvString = buildString(rewriter, loc, str);
Value bvNumeral =
buildPtrAPICall(rewriter, loc, "Z3_mk_numeral", {bvString, bvSort});
rewriter.replaceOp(op, bvNumeral);
return success();
}
};
/// Some of the Z3 API supports a variadic number of operands for some
/// operations (in particular if the expansion would lead to a super-linear
/// increase in operations such as with the ':pairwise' attribute). Those API
/// calls take an 'unsigned' argument indicating the size of an array of
/// pointers to the operands.
template <typename SourceTy>
struct VariadicSMTPattern : public SMTLoweringPattern<SourceTy> {
using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
VariadicSMTPattern(const TypeConverter &typeConverter, MLIRContext *context,
SMTGlobalsHandler &globals,
const LowerSMTToZ3LLVMOptions &options,
StringRef apiFuncName, unsigned minNumArgs)
: SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
apiFuncName(apiFuncName), minNumArgs(minNumArgs) {}
LogicalResult
matchAndRewrite(SourceTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (adaptor.getOperands().size() < minNumArgs)
return failure();
Location loc = op.getLoc();
Value numOperands = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), op->getNumOperands());
Value constOne =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
Type arrTy = LLVM::LLVMArrayType::get(ptrTy, op->getNumOperands());
Value storage =
rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
for (auto [i, operand] : llvm::enumerate(adaptor.getOperands()))
array = rewriter.create<LLVM::InsertValueOp>(
loc, array, operand, ArrayRef<int64_t>{(int64_t)i});
rewriter.create<LLVM::StoreOp>(loc, array, storage);
rewriter.replaceOp(op,
SMTLoweringPattern<SourceTy>::buildPtrAPICall(
rewriter, loc, apiFuncName, {numOperands, storage}));
return success();
}
private:
StringRef apiFuncName;
unsigned minNumArgs;
};
/// Lower an SMT operation to a function call with the name 'apiFuncName' with
/// arguments matching the operands one-to-one.
template <typename SourceTy>
struct OneToOneSMTPattern : public SMTLoweringPattern<SourceTy> {
using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
OneToOneSMTPattern(const TypeConverter &typeConverter, MLIRContext *context,
SMTGlobalsHandler &globals,
const LowerSMTToZ3LLVMOptions &options,
StringRef apiFuncName, unsigned numOperands)
: SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
apiFuncName(apiFuncName), numOperands(numOperands) {}
LogicalResult
matchAndRewrite(SourceTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (adaptor.getOperands().size() != numOperands)
return failure();
rewriter.replaceOp(
op, SMTLoweringPattern<SourceTy>::buildPtrAPICall(
rewriter, op.getLoc(), apiFuncName, adaptor.getOperands()));
return success();
}
private:
StringRef apiFuncName;
unsigned numOperands;
};
/// A pattern to lower SMT operations with a variadic number of operands
/// modelling the ':chainable' attribute in SMT to binary operations.
template <typename SourceTy>
class LowerChainableSMTPattern : public SMTLoweringPattern<SourceTy> {
using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
LogicalResult
matchAndRewrite(SourceTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (adaptor.getOperands().size() <= 2)
return failure();
Location loc = op.getLoc();
SmallVector<Value> elements;
for (int i = 1, e = adaptor.getOperands().size(); i < e; ++i) {
Value val = rewriter.create<SourceTy>(
loc, op->getResultTypes(),
ValueRange{adaptor.getOperands()[i - 1], adaptor.getOperands()[i]});
elements.push_back(val);
}
rewriter.replaceOpWithNewOp<smt::AndOp>(op, elements);
return success();
}
};
/// The 'smt.solver' operation has a region that corresponds to the lifetime of
/// the Z3 context and one solver instance created within this context.
/// To create a context, a Z3 configuration has to be built first and various
/// configuration parameters can be set before creating a context from it. Once
/// we have a context, we can create a solver and store a pointer to the context
/// and the solver in an LLVM global such that operations in the child region
/// have access to them. While the context created with `Z3_mk_context` takes
/// care of the reference counting of `Z3_AST` objects, it still requires manual
/// reference counting of `Z3_solver` objects, therefore, we need to increase
/// the ref. counter of the solver we get from `Z3_mk_solver` and must decrease
/// it again once we don't need it anymore. Finally, the configuration object
/// can be deleted.
/// ```
/// Z3_config Z3_API Z3_mk_config(void);
/// void Z3_API Z3_set_param_value(Z3_config c, Z3_string param_id,
/// Z3_string param_value);
/// Z3_context Z3_API Z3_mk_context(Z3_config c);
/// Z3_solver Z3_API Z3_mk_solver(Z3_context c);
/// void Z3_API Z3_solver_inc_ref(Z3_context c, Z3_solver s);
/// void Z3_API Z3_del_config(Z3_config c);
/// ```
/// At the end of the solver lifetime, we have to tell the context that we
/// don't need the solver anymore and delete the context itself.
/// ```
/// void Z3_API Z3_solver_dec_ref(Z3_context c, Z3_solver s);
/// void Z3_API Z3_del_context(Z3_context c);
/// ```
/// Note that the solver created here is a combined solver. There might be some
/// potential for optimization by creating more specialized solvers supported by
/// the Z3 API according the the kind of operations present in the body region.
struct SolverOpLowering : public SMTLoweringPattern<SolverOp> {
using SMTLoweringPattern::SMTLoweringPattern;
LogicalResult
matchAndRewrite(SolverOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
auto ptrTy = LLVM::LLVMPointerType::get(getContext());
auto voidTy = LLVM::LLVMVoidType::get(getContext());
auto ptrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, ptrTy);
auto ptrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, ptrTy);
auto ptrPtrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy});
// Create the configuration.
Value config = buildCall(rewriter, loc, "Z3_mk_config",
LLVM::LLVMFunctionType::get(ptrTy, {}), {})
.getResult();
// In debug-mode, we enable proofs such that we can fetch one in the 'unsat'
// region of each 'smt.check' operation.
if (options.debug) {
Value paramKey = buildString(rewriter, loc, "proof");
Value paramValue = buildString(rewriter, loc, "true");
buildCall(rewriter, loc, "Z3_set_param_value",
LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, ptrTy}),
{config, paramKey, paramValue});
}
// Create the context and store a pointer to it in the global variable.
Value ctx = buildCall(rewriter, loc, "Z3_mk_context", ptrToPtrFunc, config)
.getResult();
Value ctxAddr =
rewriter.create<LLVM::AddressOfOp>(loc, globals.ctx).getResult();
rewriter.create<LLVM::StoreOp>(loc, ctx, ctxAddr);
// Delete the configuration again.
buildCall(rewriter, loc, "Z3_del_config", ptrToVoidFunc, {config});
// Create a solver instance, increase its reference counter, and store a
// pointer to it in the global variable.
Value solver = buildCall(rewriter, loc, "Z3_mk_solver", ptrToPtrFunc, ctx)
->getResult(0);
buildCall(rewriter, loc, "Z3_solver_inc_ref", ptrPtrToVoidFunc,
{ctx, solver});
Value solverAddr =
rewriter.create<LLVM::AddressOfOp>(loc, globals.solver).getResult();
rewriter.create<LLVM::StoreOp>(loc, solver, solverAddr);
// This assumes that no constant hoisting of the like happens inbetween
// the patterns defined in this pass because once the solver initialization
// and deallocation calls are inserted and the body region is inlined,
// canonicalizations and folders applied inbetween lowering patterns might
// hoist the SMT constants which means they would access uninitialized
// global variables once they are lowered.
SmallVector<Type> convertedTypes;
if (failed(
typeConverter->convertTypes(op->getResultTypes(), convertedTypes)))
return failure();
func::FuncOp funcOp;
{
OpBuilder::InsertionGuard guard(rewriter);
auto module = op->getParentOfType<ModuleOp>();
rewriter.setInsertionPointToEnd(module.getBody());
funcOp = rewriter.create<func::FuncOp>(
loc, globals.names.newName("solver"),
rewriter.getFunctionType(adaptor.getInputs().getTypes(),
convertedTypes));
rewriter.inlineRegionBefore(op.getBodyRegion(), funcOp.getBody(),
funcOp.end());
}
ValueRange results =
rewriter.create<func::CallOp>(loc, funcOp, adaptor.getInputs())
->getResults();
// At the end of the region, decrease the solver's reference counter and
// delete the context.
// NOTE: we cannot use the convenience helper here because we don't want to
// load the context from the global but use the result from the 'mk_context'
// call directly for two reasons:
// * avoid an unnecessary load
// * the caching mechanism of the context does not work here because it
// would reuse the loaded context from a earlier solver
buildCall(rewriter, loc, "Z3_solver_dec_ref", ptrPtrToVoidFunc,
{ctx, solver});
buildCall(rewriter, loc, "Z3_del_context", ptrToVoidFunc, ctx);
rewriter.replaceOp(op, results);
return success();
}
};
/// Lower `smt.assert` operations to Z3 API calls of the form:
/// ```
/// void Z3_API Z3_solver_assert(Z3_context c, Z3_solver s, Z3_ast a);
/// ```
struct AssertOpLowering : public SMTLoweringPattern<AssertOp> {
using SMTLoweringPattern::SMTLoweringPattern;
LogicalResult
matchAndRewrite(AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
buildAPICallWithContext(
rewriter, loc, "Z3_solver_assert",
LLVM::LLVMVoidType::get(getContext()),
{buildSolverPtr(rewriter, loc), adaptor.getInput()});
rewriter.eraseOp(op);
return success();
}
};
/// Lower `smt.yield` operations to `scf.yield` operations. This not necessary
/// for the yield in `smt.solver` or in quantifiers since they are deleted
/// directly by the parent operation, but makes the lowering of the `smt.check`
/// operation simpler and more convenient since the regions get translated
/// directly to regions of `scf.if` operations.
struct YieldOpLowering : public SMTLoweringPattern<YieldOp> {
using SMTLoweringPattern::SMTLoweringPattern;
LogicalResult
matchAndRewrite(YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (op->getParentOfType<func::FuncOp>()) {
rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getValues());
return success();
}
if (op->getParentOfType<LLVM::LLVMFuncOp>()) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getValues());
return success();
}
if (isa<scf::SCFDialect>(op->getParentOp()->getDialect())) {
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getValues());
return success();
}
return failure();
}
};
/// Lower `smt.check` operations to Z3 API calls and control-flow operations.
/// ```
/// Z3_lbool Z3_API Z3_solver_check(Z3_context c, Z3_solver s);
///
/// typedef enum
/// {
/// Z3_L_FALSE = -1, // means unsatisfiable here
/// Z3_L_UNDEF, // means unknown here
/// Z3_L_TRUE // means satisfiable here
/// } Z3_lbool;
/// ```
struct CheckOpLowering : public SMTLoweringPattern<CheckOp> {
using SMTLoweringPattern::SMTLoweringPattern;
LogicalResult
matchAndRewrite(CheckOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
auto printfType =
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrTy}, true);
auto getHeaderString = [](const std::string &title) {
unsigned titleSize = title.size() + 2; // Add a space left and right
return std::string((80 - titleSize) / 2, '-') + " " + title + " " +
std::string((80 - titleSize + 1) / 2, '-') + "\n%s\n" +
std::string(80, '-') + "\n";
};
// Get the pointer to the solver instance.
Value solver = buildSolverPtr(rewriter, loc);
// In debug-mode, print the state of the solver before calling 'check-sat'
// on it. This prints the asserted SMT expressions.
if (options.debug) {
auto solverStringPtr =
buildPtrAPICall(rewriter, loc, "Z3_solver_to_string", {solver});
auto solverFormatString =
buildString(rewriter, loc, getHeaderString("Solver"));
buildCall(rewriter, op.getLoc(), "printf", printfType,
{solverFormatString, solverStringPtr});
}
// Convert the result types of the `smt.check` operation.
SmallVector<Type> resultTypes;
if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
return failure();
// Call 'check-sat' and check if the assertions are satisfiable.
Value checkResult =
buildAPICallWithContext(rewriter, loc, "Z3_solver_check",
rewriter.getI32Type(), {solver})
->getResult(0);
Value constOne =
rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), 1);
Value isSat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
checkResult, constOne);
// Simply inline the 'sat' region into the 'then' region of the 'scf.if'
auto satIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isSat);
rewriter.inlineRegionBefore(op.getSatRegion(), satIfOp.getThenRegion(),
satIfOp.getThenRegion().end());
// Otherwise, the 'else' block checks if the assertions are unsatisfiable or
// unknown. The corresponding regions can also be simply inlined into the
// two branches of this nested if-statement as well.
rewriter.createBlock(&satIfOp.getElseRegion());
Value constNegOne =
rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), -1);
Value isUnsat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
checkResult, constNegOne);
auto unsatIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isUnsat);
rewriter.create<scf::YieldOp>(loc, unsatIfOp->getResults());
rewriter.inlineRegionBefore(op.getUnsatRegion(), unsatIfOp.getThenRegion(),
unsatIfOp.getThenRegion().end());
rewriter.inlineRegionBefore(op.getUnknownRegion(),
unsatIfOp.getElseRegion(),
unsatIfOp.getElseRegion().end());
rewriter.replaceOp(op, satIfOp->getResults());
if (options.debug) {
// In debug-mode, if the assertions are unsatisfiable we can print the
// proof.
rewriter.setInsertionPointToStart(unsatIfOp.thenBlock());
auto proof = buildPtrAPICall(rewriter, op.getLoc(), "Z3_solver_get_proof",
{solver});
auto stringPtr =
buildPtrAPICall(rewriter, op.getLoc(), "Z3_ast_to_string", {proof});
auto formatString =
buildString(rewriter, op.getLoc(), getHeaderString("Proof"));
buildCall(rewriter, op.getLoc(), "printf", printfType,
{formatString, stringPtr});
// In debug mode, if the assertions are satisfiable we can print the model
// (effectively a counter-example).
rewriter.setInsertionPointToStart(satIfOp.thenBlock());
auto model = buildPtrAPICall(rewriter, op.getLoc(), "Z3_solver_get_model",
{solver});
auto modelStringPtr =
buildPtrAPICall(rewriter, op.getLoc(), "Z3_model_to_string", {model});
auto modelFormatString =
buildString(rewriter, op.getLoc(), getHeaderString("Model"));
buildCall(rewriter, op.getLoc(), "printf", printfType,
{modelFormatString, modelStringPtr});
}
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Pass Implementation
//===----------------------------------------------------------------------===//
namespace {
struct LowerSMTToZ3LLVMPass
: public circt::impl::LowerSMTToZ3LLVMBase<LowerSMTToZ3LLVMPass> {
using Base::Base;
void runOnOperation() override;
};
} // namespace
void circt::populateSMTToZ3LLVMTypeConverter(TypeConverter &converter) {
converter.addConversion([](smt::BoolType type) {
return LLVM::LLVMPointerType::get(type.getContext());
});
converter.addConversion([](smt::BitVectorType type) {
return LLVM::LLVMPointerType::get(type.getContext());
});
converter.addConversion([](smt::ArrayType type) {
return LLVM::LLVMPointerType::get(type.getContext());
});
converter.addConversion([](smt::IntType type) {
return LLVM::LLVMPointerType::get(type.getContext());
});
converter.addConversion([](smt::SMTFuncType type) {
return LLVM::LLVMPointerType::get(type.getContext());
});
converter.addConversion([](smt::SortType type) {
return LLVM::LLVMPointerType::get(type.getContext());
});
}
void circt::populateSMTToZ3LLVMConversionPatterns(
RewritePatternSet &patterns, TypeConverter &converter,
SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options) {
#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS) \
patterns.add<VariadicSMTPattern<OP>>(/*NOLINT(bugprone-macro-parentheses)*/ \
converter, patterns.getContext(), \
globals, options, APINAME, \
MIN_NUM_ARGS);
// Lower `smt.distinct` operations which allows a variadic number of operands
// according to the `:pairwise` attribute. The Z3 API function supports a
// variadic number of operands as well, i.e., a direct lowering is possible:
// ```
// Z3_ast Z3_API Z3_mk_distinct(Z3_context c, unsigned num_args, Z3_ast const
// args[])
// ```
// The API function requires num_args > 1 which is guaranteed to be satisfied
// because `smt.distinct` is verified to have > 1 operands.
ADD_VARIADIC_PATTERN(DistinctOp, "Z3_mk_distinct", 2);
// Lower `smt.and` operations which allows a variadic number of operands
// according to the `:left-assoc` attribute. The Z3 API function supports a
// variadic number of operands as well, i.e., a direct lowering is possible:
// ```
// Z3_ast Z3_API Z3_mk_and(Z3_context c, unsigned num_args, Z3_ast const
// args[])
// ```
// The API function requires num_args > 0. This is not guaranteed by the
// `smt.and` operation and thus the pattern will not apply when no operand is
// present. The constant folder of the operation is assumed to fold this to
// a constant 'true' (neutral element of AND).
ADD_VARIADIC_PATTERN(AndOp, "Z3_mk_and", 1);
#undef ADD_VARIADIC_PATTERN
// Lower `smt.eq` operations which allows a variadic number of operands
// according to the `:chainable` attribute. The Z3 API function does not
// support a variadic number of operands, but exactly two:
// ```
// Z3_ast Z3_API Z3_mk_eq(Z3_context c, Z3_ast l, Z3_ast r)
// ```
// As a result, we first apply a rewrite pattern that unfolds chainable
// operators and then lower it one-to-one to the API function. In this case,
// this means:
// ```
// eq(a,b,c,d) ->
// and(eq(a,b), eq(b,c), eq(c,d)) ->
// and(Z3_mk_eq(ctx, a, b), Z3_mk_eq(ctx, b, c), Z3_mk_eq(ctx, c, d))
// ```
// The patterns for `smt.and` will then do the remaining work.
patterns.add<LowerChainableSMTPattern<EqOp>>(converter, patterns.getContext(),
globals, options);
patterns.add<OneToOneSMTPattern<EqOp>>(converter, patterns.getContext(),
globals, options, "Z3_mk_eq", 2);
// Other lowering patterns. Refer to their implementation directly for more
// information.
patterns.add<BVConstantOpLowering, DeclareFunOpLowering, AssertOpLowering,
CheckOpLowering, SolverOpLowering, ApplyFuncOpLowering,
YieldOpLowering>(converter, patterns.getContext(), globals,
options);
}
void LowerSMTToZ3LLVMPass::runOnOperation() {
LowerSMTToZ3LLVMOptions options;
options.debug = debug;
// Set up the type converter
LLVMTypeConverter converter(&getContext());
populateSMTToZ3LLVMTypeConverter(converter);
RewritePatternSet patterns(&getContext());
// Populate the func to LLVM conversion patterns for two reasons:
// * Typically functions are represented using `func.func` and including the
// patterns to lower them here is more convenient for most lowering
// pipelines (avoids running another pass).
// * Already having `llvm.func` in the input or lowering `func.func` before
// the SMT in the body leads to issues because the SCF conversion patterns
// don't take the type converter into consideration and thus create blocks
// with the old types for block arguments. However, the conversion happens
// top-down and thus are assumed to be converted by the parent function op
// which at that point would have already been lowered (and the blocks are
// also not there when doing everything in one pass, i.e.,
// `populateAnyFunctionOpInterfaceTypeConversionPattern` does not have any
// effect as well). Are the SCF lowering patterns actually broken and should
// take a type-converter?
populateFuncToLLVMConversionPatterns(converter, patterns);
// Populate SCF to CF and CF to LLVM lowering patterns because we create
// `scf.if` operations in the lowering patterns for convenience (given the
// above issue we might want to lower to LLVM directly; or fix upstream?)
populateSCFToControlFlowConversionPatterns(patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
// Create the globals to store the context and solver and populate the SMT
// lowering patterns.
OpBuilder builder(&getContext());
auto globals = SMTGlobalsHandler::create(builder, getOperation());
populateSMTToZ3LLVMConversionPatterns(patterns, converter, globals, options);
// Do a full conversion. This assumes that all other dialects have been
// lowered before this pass already.
LLVMConversionTarget target(getContext());
target.addLegalOp<mlir::ModuleOp>();
target.addLegalOp<scf::YieldOp>();
if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
return signalPassFailure();
}

View File

@ -0,0 +1,194 @@
// RUN: circt-opt %s --lower-smt-to-z3-llvm | FileCheck %s
// RUN: circt-opt %s --lower-smt-to-z3-llvm=debug=true | FileCheck %s --check-prefix=CHECK-DEBUG
// CHECK-LABEL: llvm.mlir.global internal @ctx_0()
// CHECK-NEXT: llvm.mlir.zero : !llvm.ptr
// CHECK-NEXT: llvm.return
// CHECK-LABEL: llvm.mlir.global internal @solver_0()
// CHECK-NEXT: llvm.mlir.zero : !llvm.ptr
// CHECK-NEXT: llvm.return
// CHECK-LABEL: llvm.mlir.global internal @ctx()
llvm.mlir.global internal @ctx() {alignment = 8 : i64} : !llvm.ptr {
%0 = llvm.mlir.zero : !llvm.ptr
llvm.return %0 : !llvm.ptr
}
// CHECK-LABEL: llvm.mlir.global internal @solver()
llvm.mlir.global internal @solver() {alignment = 8 : i64} : !llvm.ptr {
%0 = llvm.mlir.zero : !llvm.ptr
llvm.return %0 : !llvm.ptr
}
// CHECK-LABEL: llvm.func @test
// CHECK: [[CONFIG:%.+]] = llvm.call @Z3_mk_config() : () -> !llvm.ptr
// CHECK-DEBUG: [[PROOF_STR:%.+]] = llvm.mlir.addressof @str{{.*}} : !llvm.ptr
// CHECK-DEBUG: [[TRUE_STR:%.+]] = llvm.mlir.addressof @str{{.*}} : !llvm.ptr
// CHECK-DEBUG: llvm.call @Z3_set_param_value({{.*}}, [[PROOF_STR]], [[TRUE_STR]]) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
// CHECK: [[CTX:%.+]] = llvm.call @Z3_mk_context([[CONFIG]]) : (!llvm.ptr) -> !llvm.ptr
// CHECK: [[CTX_ADDR:%.+]] = llvm.mlir.addressof @ctx_0 : !llvm.ptr
// CHECK: llvm.store [[CTX]], [[CTX_ADDR]] : !llvm.ptr, !llvm.ptr
// CHECK: llvm.call @Z3_del_config([[CONFIG]]) : (!llvm.ptr) -> ()
// CHECK: [[SOLVER:%.+]] = llvm.call @Z3_mk_solver([[CTX]]) : (!llvm.ptr) -> !llvm.ptr
// CHECK: llvm.call @Z3_solver_inc_ref([[CTX]], [[SOLVER]]) : (!llvm.ptr, !llvm.ptr) -> ()
// CHECK: [[SOLVER_ADDR:%.+]] = llvm.mlir.addressof @solver_0 : !llvm.ptr
// CHECK: llvm.store [[SOLVER]], [[SOLVER_ADDR]] : !llvm.ptr, !llvm.ptr
// CHECK: llvm.call @solver
// CHECK: llvm.call @Z3_solver_dec_ref([[CTX]], [[SOLVER]]) : (!llvm.ptr, !llvm.ptr) -> ()
// CHECK: llvm.call @Z3_del_context([[CTX]]) : (!llvm.ptr) -> ()
// CHECK: llvm.return
// CHECK-LABEL: llvm.func @solver
func.func @test(%arg0: i32) {
%0 = smt.solver (%arg0) : (i32) -> (i32) {
^bb0(%arg1: i32):
// CHECK: [[STR:%.+]] = llvm.mlir.addressof @str{{.*}} : !llvm.ptr
// CHECK: [[CTX_ADDR:%.+]] = llvm.mlir.addressof @ctx_0 : !llvm.ptr
// CHECK: [[CTX:%.+]] = llvm.load [[CTX_ADDR]] : !llvm.ptr -> !llvm.ptr
// CHECK: [[INT_SORT:%.+]] = llvm.call @Z3_mk_int_sort([[CTX]]) : (!llvm.ptr) -> !llvm.ptr
// CHECK: [[BOOL_SORT:%.+]] = llvm.call @Z3_mk_bool_sort([[CTX]]) : (!llvm.ptr) -> !llvm.ptr
// CHECK: [[ARRAY_SORT:%.+]] = llvm.call @Z3_mk_array_sort([[CTX]], [[INT_SORT]], [[BOOL_SORT]]) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.ptr
// CHECK: llvm.call @Z3_mk_fresh_const([[CTX]], [[STR]], [[ARRAY_SORT]]) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.ptr
// Test: declare constant, array, int, bool types
%1 = smt.declare_fun "a" : !smt.array<[!smt.int -> !smt.bool]>
// CHECK: [[ZERO:%.+]] = llvm.mlir.zero : !llvm.ptr
// CHECK: [[STR:%.+]] = llvm.mlir.addressof @str{{.*}} : !llvm.ptr
// CHECK: [[SYM:%.+]] = llvm.call @Z3_mk_string_symbol([[CTX]], [[STR]]) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
// CHECK: [[SORT:%.+]] = llvm.call @Z3_mk_uninterpreted_sort([[CTX]], [[SYM]]) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
// CHECK: [[ARR0:%.+]] = llvm.mlir.undef : !llvm.array<2 x ptr>
// CHECK: [[C65:%.+]] = llvm.mlir.constant(65 : i32) : i32
// CHECK: [[BV_SORT:%.+]] = llvm.call @Z3_mk_bv_sort([[CTX]], [[C65]]) : (!llvm.ptr, i32) -> !llvm.ptr
// CHECK: [[ARR1:%.+]] = llvm.insertvalue [[BV_SORT]], [[ARR0]][0] : !llvm.array<2 x ptr>
// CHECK: [[C4:%.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: [[BV_SORT:%.+]] = llvm.call @Z3_mk_bv_sort([[CTX]], [[C4]]) : (!llvm.ptr, i32) -> !llvm.ptr
// CHECK: [[ARR2:%.+]] = llvm.insertvalue [[BV_SORT]], [[ARR1]][1] : !llvm.array<2 x ptr>
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[STORAGE:%.+]] = llvm.alloca [[C1]] x !llvm.array<2 x ptr> : (i32) -> !llvm.ptr
// CHECK: llvm.store [[ARR2]], [[STORAGE]] : !llvm.array<2 x ptr>, !llvm.ptr
// CHECK: [[C2:%.+]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: [[V2:%.+]] = llvm.call @Z3_mk_fresh_func_decl([[CTX]], [[ZERO]], [[C2]], [[STORAGE]], [[SORT]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> !llvm.ptr
// Test: declare function, bit-vector, uninterpreted sort types
%2 = smt.declare_fun : !smt.func<(!smt.bv<65>, !smt.bv<4>) !smt.sort<"uninterpreted_sort"[!smt.bool]>>
// CHECK: [[C4:%.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: [[BV_SORT:%.+]] = llvm.call @Z3_mk_bv_sort([[CTX]], [[C4]]) : (!llvm.ptr, i32) -> !llvm.ptr
// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: [[BV0:%.+]] = llvm.call @Z3_mk_unsigned_int64([[CTX]], [[C0]], [[BV_SORT]]) : (!llvm.ptr, i64, !llvm.ptr) -> !llvm.ptr
%c0_bv4 = smt.bv.constant #smt.bv<0> : !smt.bv<4>
// CHECK: [[C65:%.+]] = llvm.mlir.constant(65 : i32) : i32
// CHECK: [[BV_SORT:%.+]] = llvm.call @Z3_mk_bv_sort([[CTX]], [[C65]]) : (!llvm.ptr, i32) -> !llvm.ptr
// CHECK: [[STR:%.+]] = llvm.mlir.addressof @str{{.*}} : !llvm.ptr
// CHECK: [[BV42:%.+]] = llvm.call @Z3_mk_numeral([[CTX]], [[STR]], [[BV_SORT]]) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.ptr
%c42_bv65 = smt.bv.constant #smt.bv<42> : !smt.bv<65>
// CHECK: [[ARR0:%.+]] = llvm.mlir.undef : !llvm.array<2 x ptr>
// CHECK: [[ARR1:%.+]] = llvm.insertvalue [[BV42]], [[ARR0]][0] : !llvm.array<2 x ptr>
// CHECK: [[ARR2:%.+]] = llvm.insertvalue [[BV0]], [[ARR1]][1] : !llvm.array<2 x ptr>
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[STORAGE:%.+]] = llvm.alloca [[C1]] x !llvm.array<2 x ptr> : (i32) -> !llvm.ptr
// CHECK: llvm.store [[ARR2]], [[STORAGE]] : !llvm.array<2 x ptr>, !llvm.ptr
// CHECK: [[C2:%.+]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: llvm.call @Z3_mk_app([[CTX]], [[V2]], [[C2]], [[STORAGE]]) : (!llvm.ptr, !llvm.ptr, i32, !llvm.ptr) -> !llvm.ptr
%3 = smt.apply_func %2(%c42_bv65, %c0_bv4) : !smt.func<(!smt.bv<65>, !smt.bv<4>) !smt.sort<"uninterpreted_sort"[!smt.bool]>>
// CHECK: [[EQ2:%.+]] = llvm.call @Z3_mk_eq([[CTX]], [[BV0]], [[BV0]]) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.ptr
// CHECK: [[EQ3:%.+]] = llvm.call @Z3_mk_eq([[CTX]], [[BV0]], [[BV0]]) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.ptr
// CHECK: [[C2:%.+]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[STORAGE:%.+]] = llvm.alloca [[C1]] x !llvm.array<2 x ptr> : (i32) -> !llvm.ptr
// CHECK: [[ARR0:%.+]] = llvm.mlir.undef : !llvm.array<2 x ptr>
// CHECK: [[ARR1:%.+]] = llvm.insertvalue [[EQ2]], [[ARR0]][0] : !llvm.array<2 x ptr>
// CHECK: [[ARR2:%.+]] = llvm.insertvalue [[EQ3]], [[ARR1]][1] : !llvm.array<2 x ptr>
// CHECK: llvm.store [[ARR2]], [[STORAGE]] : !llvm.array<2 x ptr>, !llvm.ptr
// CHECK: [[EQ0:%.+]] = llvm.call @Z3_mk_and([[CTX]], [[C2]], [[STORAGE]]) : (!llvm.ptr, i32, !llvm.ptr) -> !llvm.ptr
%4 = smt.eq %c0_bv4, %c0_bv4, %c0_bv4 : !smt.bv<4>
// CHECK: [[EQ1:%.+]] = llvm.call @Z3_mk_eq([[CTX]], [[BV0]], [[BV0]]) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.ptr
%5 = smt.eq %c0_bv4, %c0_bv4 : !smt.bv<4>
// CHECK-NEXT: [[THREE:%.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK-NEXT: [[ONE:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: [[STORAGE:%.+]] = llvm.alloca [[ONE]] x !llvm.array<3 x ptr> : (i32) -> !llvm.ptr
// CHECK-NEXT: [[A0:%.+]] = llvm.mlir.undef : !llvm.array<3 x ptr>
// CHECK-NEXT: [[A1:%.+]] = llvm.insertvalue [[BV0]], [[A0]][0] : !llvm.array<3 x ptr>
// CHECK-NEXT: [[A2:%.+]] = llvm.insertvalue [[BV0]], [[A1]][1] : !llvm.array<3 x ptr>
// CHECK-NEXT: [[A3:%.+]] = llvm.insertvalue [[BV0]], [[A2]][2] : !llvm.array<3 x ptr>
// CHECK-NEXT: llvm.store [[A3]], [[STORAGE]] : !llvm.array<3 x ptr>, !llvm.ptr
// CHECK-NEXT: [[DISTINCT:%.+]] = llvm.call @Z3_mk_distinct([[CTX]], [[THREE]], [[STORAGE]]) : (!llvm.ptr, i32, !llvm.ptr) -> !llvm.ptr
%6 = smt.distinct %c0_bv4, %c0_bv4, %c0_bv4 : !smt.bv<4>
// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[STORAGE:%.+]] = llvm.alloca [[C1]] x !llvm.array<3 x ptr> : (i32) -> !llvm.ptr
// CHECK: [[ARR0:%.+]] = llvm.mlir.undef : !llvm.array<3 x ptr>
// CHECK: [[ARR1:%.+]] = llvm.insertvalue [[EQ0]], [[ARR0]][0] : !llvm.array<3 x ptr>
// CHECK: [[ARR2:%.+]] = llvm.insertvalue [[EQ1]], [[ARR1]][1] : !llvm.array<3 x ptr>
// CHECK: [[ARR3:%.+]] = llvm.insertvalue [[DISTINCT]], [[ARR2]][2] : !llvm.array<3 x ptr>
// CHECK: llvm.store [[ARR3]], [[STORAGE]] : !llvm.array<3 x ptr>, !llvm.ptr
// CHECK: [[AND:%.+]] = llvm.call @Z3_mk_and([[CTX]], [[C3]], [[STORAGE]]) : (!llvm.ptr, i32, !llvm.ptr) -> !llvm.ptr
%7 = smt.and %4, %5, %6
// CHECK: [[S_ADDR:%.+]] = llvm.mlir.addressof @solver_0 : !llvm.ptr
// CHECK: [[S:%.+]] = llvm.load [[S_ADDR]] : !llvm.ptr -> !llvm.ptr
// CHECK: llvm.call @Z3_solver_assert([[CTX]], [[S]], [[AND]]) : (!llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
smt.assert %7
// CHECK-DEBUG: [[SOLVER_STR:%.+]] = llvm.call @Z3_solver_to_string({{.*}}, {{.*}}) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
// CHECK-DEBUG: [[FMT_STR:%.+]] = llvm.mlir.addressof @str{{.*}} : !llvm.ptr
// CHECK-DEBUG: llvm.call @printf([[FMT_STR]], [[SOLVER_STR]]) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK: [[CHECK:%.+]] = llvm.call @Z3_solver_check([[CTX]], [[S]]) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[IS_SAT:%.+]] = llvm.icmp "eq" [[CHECK]], [[C1]] : i32
// CHECK: llvm.cond_br [[IS_SAT]], ^[[BB1:.+]], ^[[BB2:.+]]
// CHECK: ^[[BB1]]:
// CHECK-DEBUG: [[CTX_ADDR:%.+]] = llvm.mlir.addressof @ctx_0 : !llvm.ptr
// CHECK-DEBUG: [[CTX0:%.+]] = llvm.load [[CTX_ADDR]] : !llvm.ptr -> !llvm.ptr
// CHECK-DEBUG: [[MODEL:%.+]] = llvm.call @Z3_solver_get_model([[CTX0]], {{.*}}) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
// CHECK-DEBUG: [[MODEL_STR:%.+]] = llvm.call @Z3_model_to_string([[CTX0]], [[MODEL]]) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
// CHECK-DEBUG: [[FMT_STR:%.+]] = llvm.mlir.addressof @str{{.*}} : !llvm.ptr
// CHECK-DEBUG: llvm.call @printf([[FMT_STR]], [[MODEL_STR]]) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.br ^[[BB7:.+]]([[C1]] : i32)
// CHECK: ^[[BB2]]:
// CHECK: [[CNEG1:%.+]] = llvm.mlir.constant(-1 : i32) : i32
// CHECK: [[IS_UNSAT:%.+]] = llvm.icmp "eq" [[CHECK]], [[CNEG1]] : i32
// CHECK: llvm.cond_br [[IS_UNSAT]], ^[[BB3:.+]], ^[[BB4:.+]]
// CHECK: ^[[BB3]]:
// CHECK-DEBUG: [[CTX_ADDR:%.+]] = llvm.mlir.addressof @ctx_0 : !llvm.ptr
// CHECK-DEBUG: [[CTX1:%.+]] = llvm.load [[CTX_ADDR]] : !llvm.ptr -> !llvm.ptr
// CHECK-DEBUG: [[PROOF:%.+]] = llvm.call @Z3_solver_get_proof([[CTX1]], {{.*}}) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
// CHECK-DEBUG: [[PROOF_STR:%.+]] = llvm.call @Z3_ast_to_string([[CTX1]], [[PROOF]]) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
// CHECK-DEBUG: [[FMT_STR:%.+]] = llvm.mlir.addressof @str{{.*}} : !llvm.ptr
// CHECK-DEBUG: llvm.call @printf([[FMT_STR]], [[PROOF_STR]]) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr, !llvm.ptr) -> i32
// CHECK: [[CNEG1:%.+]] = llvm.mlir.constant(-1 : i32) : i32
// CHECK: llvm.br ^[[BB5:.+]]([[CNEG1:%.+]] : i32)
// CHECK: ^[[BB4]]:
// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.br ^[[BB5]]([[C0]] : i32)
// CHECK: ^[[BB5]]([[ARG0:%.+]]: i32):
// CHECK: llvm.br ^[[BB6:.+]]
// CHECK: ^[[BB6]]:
// CHECK: llvm.br ^[[BB7]]([[ARG0]] : i32)
// CHECK: ^[[BB7]]({{.+}}: i32):
// CHECK: llvm.br
%8 = smt.check sat {
%c1 = llvm.mlir.constant(1 : i32) : i32
smt.yield %c1 : i32
} unknown {
%c0 = llvm.mlir.constant(0 : i32) : i32
smt.yield %c0 : i32
} unsat {
%c-1 = llvm.mlir.constant(-1 : i32) : i32
smt.yield %c-1 : i32
} -> i32
smt.yield %8 : i32
}
// CHECK: llvm.return
return
}

View File

@ -84,6 +84,7 @@ target_link_libraries(circt-opt
CIRCTPipelineToHW
CIRCTPipelineTransforms
CIRCTSMT
CIRCTSMTToZ3LLVM
CIRCTSV
CIRCTSVTransforms
CIRCTHWArith