[arcilator] Introduce integrated JIT for simulation execution (#6783)

This PR adds a JIT runtime for arcilator, backed by MLIR's ExecutionEngine. This JIT allows executing `arc.sim` operations directly from the arcilator binary.
This commit is contained in:
Théo Degioanni 2024-03-18 11:27:08 +01:00 committed by GitHub
parent 4b075de691
commit 83a8292085
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 388 additions and 117 deletions

View File

@ -145,6 +145,12 @@ else()
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
include_directories(SYSTEM ${MLIR_TABLEGEN_OUTPUT_DIR})
# If building as part of a unified build, whether or not MLIR's execution engine
# is enabled must be fetched from its subdirectory scope.
get_directory_property(MLIR_ENABLE_EXECUTION_ENGINE
DIRECTORY ${MLIR_MAIN_SRC_DIR}
DEFINITION MLIR_ENABLE_EXECUTION_ENGINE)
set(BACKEND_PACKAGE_STRING "${PACKAGE_STRING}")
set(CIRCT_GTEST_AVAILABLE 1)
@ -587,6 +593,16 @@ if(CIRCT_SLANG_FRONTEND_ENABLED)
endif()
endif()
#-------------------------------------------------------------------------------
# Arcilator JIT
#-------------------------------------------------------------------------------
if(MLIR_ENABLE_EXECUTION_ENGINE)
set(ARCILATOR_JIT_ENABLED 1)
else()
set(ARCILATOR_JIT_ENABLED 0)
endif()
#-------------------------------------------------------------------------------
# Directory setup
#-------------------------------------------------------------------------------

View File

@ -0,0 +1,26 @@
// RUN: arcilator %s --run --jit-entry=main | FileCheck %s
// REQUIRES: arcilator-jit
// CHECK: output = 5
hw.module @adder(in %a: i8, in %b: i8, out c: i8) {
%res = comb.add %a, %b : i8
hw.output %res : i8
}
func.func @main() {
%two = arith.constant 2 : i8
%three = arith.constant 3 : i8
arc.sim.instantiate @adder as %model {
arc.sim.set_input %model, "a" = %two : i8, !arc.sim.instance<@adder>
arc.sim.set_input %model, "b" = %three : i8, !arc.sim.instance<@adder>
arc.sim.step %model : !arc.sim.instance<@adder>
%res = arc.sim.get_port %model, "c" : i8, !arc.sim.instance<@adder>
arc.sim.emit "output", %res : i8
}
return
}

View File

@ -0,0 +1,50 @@
// RUN: arcilator %s --run --jit-entry=main | FileCheck %s
// REQUIRES: arcilator-jit
// CHECK: counter_value = 0
// CHECK-NEXT: counter_value = 1
// CHECK-NEXT: counter_value = 2
// CHECK-NEXT: counter_value = 3
// CHECK-NEXT: counter_value = 4
// CHECK-NEXT: counter_value = 5
// CHECK-NEXT: counter_value = 6
// CHECK-NEXT: counter_value = 7
// CHECK-NEXT: counter_value = 8
// CHECK-NEXT: counter_value = 9
// CHECK-NEXT: counter_value = a
hw.module @counter(in %clk: i1, out o: i8) {
%seq_clk = seq.to_clock %clk
%reg = seq.compreg %added, %seq_clk : i8
%one = hw.constant 1 : i8
%added = comb.add %reg, %one : i8
hw.output %reg : i8
}
func.func @main() {
%zero = arith.constant 0 : i1
%one = arith.constant 1 : i1
%lb = arith.constant 0 : index
%ub = arith.constant 10 : index
%step = arith.constant 1 : index
arc.sim.instantiate @counter as %model {
%init_val = arc.sim.get_port %model, "o" : i8, !arc.sim.instance<@counter>
arc.sim.emit "counter_value", %init_val : i8
scf.for %i = %lb to %ub step %step {
arc.sim.set_input %model, "clk" = %one : i1, !arc.sim.instance<@counter>
arc.sim.step %model : !arc.sim.instance<@counter>
arc.sim.set_input %model, "clk" = %zero : i1, !arc.sim.instance<@counter>
arc.sim.step %model : !arc.sim.instance<@counter>
%counter_val = arc.sim.get_port %model, "o" : i8, !arc.sim.instance<@counter>
arc.sim.emit "counter_value", %counter_val : i8
}
}
return
}

View File

@ -0,0 +1,8 @@
// RUN: ! (arcilator %s --run --jit-entry=unknown 2> %t) && FileCheck --input-file=%t %s
// REQUIRES: arcilator-jit
// CHECK: entry point not found: 'unknown'
func.func @main() {
return
}

View File

@ -0,0 +1,6 @@
// RUN: ! (arcilator %s --run --jit-entry=foo 2> %t) && FileCheck --input-file=%t %s
// REQUIRES: arcilator-jit
// CHECK: entry point 'foo' was found but on an operation of type 'llvm.mlir.global' while an LLVM function was expected
llvm.mlir.global @foo(0 : i32) : i32

View File

@ -0,0 +1,8 @@
// RUN: ! (arcilator %s --run --jit-entry=main 2> %t) && FileCheck --input-file=%t %s
// REQUIRES: arcilator-jit
// CHECK: entry point 'main' must have no arguments
func.func @main(%a: i32) {
return
}

View File

@ -0,0 +1,10 @@
// RUN: arcilator %s --run | FileCheck %s
// REQUIRES: arcilator-jit
// CHECK: result = 4
func.func @entry() {
%four = arith.constant 4 : i32
arc.sim.emit "result", %four : i32
return
}

View File

@ -78,7 +78,7 @@ tool_dirs = [
config.llvm_tools_dir
]
tools = [
'circt-opt', 'circt-translate', 'firtool', 'circt-rtl-sim.py',
'arcilator', 'circt-opt', 'circt-translate', 'firtool', 'circt-rtl-sim.py',
'equiv-rtl.sh', 'handshake-runner', 'hlstool', 'ibistool'
]
@ -206,6 +206,10 @@ if config.slang_frontend_enabled:
config.available_features.add('slang')
tools.append('circt-verilog')
# Add arcilator JIT if MLIR's execution engine is enabled.
if config.arcilator_jit_enabled:
config.available_features.add('arcilator-jit')
config.substitutions.append(('%driver', f'{config.driver}'))
llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -54,6 +54,7 @@ config.bindings_python_enabled = @CIRCT_BINDINGS_PYTHON_ENABLED@
config.bindings_tcl_enabled = @CIRCT_BINDINGS_TCL_ENABLED@
config.lec_enabled = "@CIRCT_LEC_ENABLED@"
config.slang_frontend_enabled = "@CIRCT_SLANG_FRONTEND_ENABLED@"
config.arcilator_jit_enabled = @ARCILATOR_JIT_ENABLED@
config.driver = "@CIRCT_SOURCE_DIR@/tools/circt-rtl-sim/driver.cpp"
# Support substitution of the tools_dir with user parameters. This is

View File

@ -355,6 +355,8 @@ struct SimInstantiateOpLowering
ConversionPatternRewriter::InsertionGuard guard(rewriter);
// FIXME: like the rest of MLIR, this assumes sizeof(intptr_t) ==
// sizeof(size_t) on the target architecture.
Type convertedIndex = typeConverter->convertType(rewriter.getIndexType());
LLVM::LLVMFuncOp mallocFunc =
@ -460,8 +462,9 @@ struct SimStepOpLowering : public ModelAwarePattern<arc::SimStepOp> {
}
};
/// Lowers SimEmitValueOp to a printf call. This pattern will mutate the global
/// module.
/// Lowers SimEmitValueOp to a printf call. The integer will be printed in its
/// entirety if it is of size up to size_t, and explicitly truncated otherwise.
/// This pattern will mutate the global module.
struct SimEmitValueOpLowering
: public OpConversionPattern<arc::SimEmitValueOp> {
using OpConversionPattern::OpConversionPattern;
@ -475,19 +478,38 @@ struct SimEmitValueOpLowering
Location loc = op.getLoc();
Value toPrint = rewriter.create<LLVM::IntToPtrOp>(
loc, LLVM::LLVMPointerType::get(getContext()), adaptor.getValue());
ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
// Cast the value to a size_t.
// FIXME: like the rest of MLIR, this assumes sizeof(intptr_t) ==
// sizeof(size_t) on the target architecture.
Value toPrint = adaptor.getValue();
DataLayout layout = DataLayout::closest(op);
llvm::TypeSize sizeOfSizeT =
layout.getTypeSizeInBits(rewriter.getIndexType());
assert(!sizeOfSizeT.isScalable() &&
sizeOfSizeT.getFixedValue() <= std::numeric_limits<unsigned>::max());
bool truncated = false;
if (valueType.getWidth() > sizeOfSizeT) {
toPrint = rewriter.create<LLVM::TruncOp>(
loc, IntegerType::get(getContext(), sizeOfSizeT.getFixedValue()),
toPrint);
truncated = true;
} else if (valueType.getWidth() < sizeOfSizeT)
toPrint = rewriter.create<LLVM::ZExtOp>(
loc, IntegerType::get(getContext(), sizeOfSizeT.getFixedValue()),
toPrint);
// Lookup of create printf function symbol.
auto printfFunc = LLVM::lookupOrCreateFn(
moduleOp, "printf", LLVM::LLVMPointerType::get(getContext()),
LLVM::LLVMVoidType::get(getContext()), true);
// Insert the format string if not already available.
SmallString<16> formatStrName{"_arc_sim_emit_"};
formatStrName.append(truncated ? "trunc_" : "full_");
formatStrName.append(adaptor.getValueName());
LLVM::GlobalOp formatStrGlobal;
if (!(formatStrGlobal =
@ -495,7 +517,10 @@ struct SimEmitValueOpLowering
ConversionPatternRewriter::InsertionGuard insertGuard(rewriter);
SmallString<16> formatStr = adaptor.getValueName();
formatStr.append(" = %0.8p\n");
formatStr.append(" = ");
if (truncated)
formatStr.append("(truncated) ");
formatStr.append("%zx\n");
SmallVector<char> formatStrVec{formatStr.begin(), formatStr.end()};
formatStrVec.push_back(0);

View File

@ -1,13 +1,18 @@
// RUN: arcilator %s --emit-mlir | FileCheck %s
hw.module @id(in %i: i8, in %j: i8, out o: i8) {
module attributes { dlti.dl_spec = #dlti.dl_spec<
#dlti.dl_entry<index, 16>
> } {
hw.module @id(in %i: i8, in %j: i8, out o: i8) {
hw.output %i : i8
}
}
// CHECK-DAG: llvm.mlir.global internal constant @[[format_str:.*]]("result = %0.8p\0A\00")
// CHECK-DAG: llvm.mlir.global internal constant @[[format_str2:.*]]("result2 = %0.8p\0A\00")
// CHECK-LABEL: llvm.func @full
func.func @full() {
// CHECK-DAG: llvm.mlir.global internal constant @[[format_str:.*]]("result = %zx\0A\00")
// CHECK-DAG: llvm.mlir.global internal constant @[[format_str2:.*]]("result2 = %zx\0A\00")
// CHECK-DAG: llvm.mlir.global internal constant @[[format_str_trunc:.*]]("result = (truncated) %zx\0A\00")
// CHECK-LABEL: llvm.func @full
func.func @full() {
%c = arith.constant 24 : i8
// CHECK-DAG: %[[c:.*]] = llvm.mlir.constant(24 : i8)
@ -16,33 +21,45 @@ func.func @full() {
// CHECK-DAG: %[[state:.*]] = llvm.call @malloc(%[[size:.*]]) :
// CHECK: "llvm.intr.memset"(%[[state]], %[[zero]], %[[size]]) <{isVolatile = false}>
arc.sim.instantiate @id as %model {
// CHECK-NEXT: llvm.store %[[c]], %[[state]] : i8
arc.sim.set_input %model, "i" = %c : i8, !arc.sim.instance<@id>
// CHECK-NEXT: llvm.store %[[c]], %[[state]] : i8
arc.sim.set_input %model, "i" = %c : i8, !arc.sim.instance<@id>
// CHECK-NEXT: %[[j_ptr:.*]] = llvm.getelementptr %[[state]][1] : (!llvm.ptr) -> !llvm.ptr, i8
// CHECK-NEXT: llvm.store %[[c]], %[[j_ptr]] : i8
arc.sim.set_input %model, "j" = %c : i8, !arc.sim.instance<@id>
// CHECK-NEXT: %[[j_ptr:.*]] = llvm.getelementptr %[[state]][1] : (!llvm.ptr) -> !llvm.ptr, i8
// CHECK-NEXT: llvm.store %[[c]], %[[j_ptr]] : i8
arc.sim.set_input %model, "j" = %c : i8, !arc.sim.instance<@id>
// CHECK-NEXT: llvm.call @id_eval(%[[state]])
arc.sim.step %model : !arc.sim.instance<@id>
// CHECK-NEXT: llvm.call @id_eval(%[[state]])
arc.sim.step %model : !arc.sim.instance<@id>
// CHECK-NEXT: %[[o_ptr:.*]] = llvm.getelementptr %[[state]][2] : (!llvm.ptr) -> !llvm.ptr, i8
// CHECK-NEXT: %[[result:.*]] = llvm.load %[[o_ptr]] : !llvm.ptr -> i8
%result = arc.sim.get_port %model, "o" : i8, !arc.sim.instance<@id>
// CHECK-NEXT: %[[o_ptr:.*]] = llvm.getelementptr %[[state]][2] : (!llvm.ptr) -> !llvm.ptr, i8
// CHECK-NEXT: %[[result:.*]] = llvm.load %[[o_ptr]] : !llvm.ptr -> i8
%result = arc.sim.get_port %model, "o" : i8, !arc.sim.instance<@id>
// CHECK-DAG: %[[to_print:.*]] = llvm.inttoptr %[[result]] : i8 to !llvm.ptr
// CHECK-DAG: %[[format_str_ptr:.*]] = llvm.mlir.addressof @[[format_str]] : !llvm.ptr
// CHECK: llvm.call @printf(%[[format_str_ptr]], %[[to_print]])
arc.sim.emit "result", %result : i8
// CHECK-DAG: %[[to_print:.*]] = llvm.zext %[[result]] : i8 to i16
// CHECK-DAG: %[[format_str_ptr:.*]] = llvm.mlir.addressof @[[format_str]] : !llvm.ptr
// CHECK: llvm.call @printf(%[[format_str_ptr]], %[[to_print]])
arc.sim.emit "result", %result : i8
// CHECK-DAG: %[[format_str2_ptr:.*]] = llvm.mlir.addressof @[[format_str2]] : !llvm.ptr
// CHECK: llvm.call @printf(%[[format_str2_ptr]], %[[to_print]])
arc.sim.emit "result2", %result : i8
// CHECK-DAG: %[[format_str2_ptr:.*]] = llvm.mlir.addressof @[[format_str2]] : !llvm.ptr
// CHECK: llvm.call @printf(%[[format_str2_ptr]], %[[to_print]])
arc.sim.emit "result2", %result : i8
// CHECK: llvm.call @printf(%[[format_str_ptr]], %[[to_print]])
arc.sim.emit "result", %result : i8
// CHECK: llvm.call @printf(%[[format_str_ptr]], %[[to_print]])
arc.sim.emit "result", %result : i8
}
// CHECK: llvm.call @free(%[[state]])
return
}
// CHECK-LABEL: llvm.func @trunc
func.func @trunc() {
%v = arith.constant 0 : i32
// CHECK-DAG: %[[val_i32:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-DAG: %[[val_truncated:.*]] = llvm.trunc %[[val_i32]] : i32 to i16
// CHECK-DAG: %[[format_str_trunc_ptr:.*]] = llvm.mlir.addressof @[[format_str_trunc]] : !llvm.ptr
// CHECK-DAG: llvm.call @printf(%[[format_str_trunc_ptr]], %[[val_truncated]])
arc.sim.emit "result", %v : i32
return
}
}

View File

@ -1,4 +1,10 @@
set(LLVM_LINK_COMPONENTS Support)
if(ARCILATOR_JIT_ENABLED)
add_compile_definitions(ARCILATOR_ENABLE_JIT)
set(ARCILATOR_JIT_LLVM_COMPONENTS native)
set(ARCILATOR_JIT_DEPS MLIRExecutionEngine)
endif()
set(LLVM_LINK_COMPONENTS Support ${ARCILATOR_JIT_LLVM_COMPONENTS})
add_circt_tool(arcilator arcilator.cpp)
target_link_libraries(arcilator
@ -16,11 +22,14 @@ target_link_libraries(arcilator
CIRCTSupport
CIRCTTransforms
MLIRBuiltinToLLVMIRTranslation
MLIRDLTIDialect
MLIRFuncInlinerExtension
MLIRLLVMIRTransforms
MLIRLLVMToLLVMIRTranslation
MLIRParser
MLIRTargetLLVMIRExport
${ARCILATOR_JIT_DEPS}
)
llvm_update_compile_flags(arcilator)

View File

@ -29,11 +29,14 @@
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OperationSupport.h"
@ -55,12 +58,12 @@
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include <iostream>
#include <optional>
using namespace llvm;
using namespace mlir;
using namespace circt;
using namespace arc;
@ -69,89 +72,97 @@ using namespace arc;
// Command Line Arguments
//===----------------------------------------------------------------------===//
static cl::OptionCategory mainCategory("arcilator Options");
static llvm::cl::OptionCategory mainCategory("arcilator Options");
static cl::opt<std::string> inputFilename(cl::Positional,
cl::desc("<input file>"),
cl::init("-"), cl::cat(mainCategory));
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"),
llvm::cl::cat(mainCategory));
static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
cl::value_desc("filename"),
cl::init("-"),
cl::cat(mainCategory));
static llvm::cl::opt<std::string>
outputFilename("o", llvm::cl::desc("Output filename"),
llvm::cl::value_desc("filename"), llvm::cl::init("-"),
llvm::cl::cat(mainCategory));
static cl::opt<bool> observePorts("observe-ports",
cl::desc("Make all ports observable"),
cl::init(false), cl::cat(mainCategory));
static llvm::cl::opt<bool>
observePorts("observe-ports", llvm::cl::desc("Make all ports observable"),
llvm::cl::init(false), llvm::cl::cat(mainCategory));
static cl::opt<bool> observeWires("observe-wires",
cl::desc("Make all wires observable"),
cl::init(false), cl::cat(mainCategory));
static llvm::cl::opt<bool>
observeWires("observe-wires", llvm::cl::desc("Make all wires observable"),
llvm::cl::init(false), llvm::cl::cat(mainCategory));
static cl::opt<bool>
observeNamedValues("observe-named-values",
cl::desc("Make values with `sv.namehint` observable"),
cl::init(false), cl::cat(mainCategory));
static llvm::cl::opt<bool> observeNamedValues(
"observe-named-values",
llvm::cl::desc("Make values with `sv.namehint` observable"),
llvm::cl::init(false), llvm::cl::cat(mainCategory));
static cl::opt<bool> observeRegisters("observe-registers",
cl::desc("Make all registers observable"),
cl::init(false), cl::cat(mainCategory));
static llvm::cl::opt<bool>
observeRegisters("observe-registers",
llvm::cl::desc("Make all registers observable"),
llvm::cl::init(false), llvm::cl::cat(mainCategory));
static cl::opt<bool>
static llvm::cl::opt<bool>
observeMemories("observe-memories",
cl::desc("Make all memory contents observable"),
cl::init(false), cl::cat(mainCategory));
llvm::cl::desc("Make all memory contents observable"),
llvm::cl::init(false), llvm::cl::cat(mainCategory));
static cl::opt<std::string> stateFile("state-file", cl::desc("State file"),
cl::value_desc("filename"), cl::init(""),
cl::cat(mainCategory));
static llvm::cl::opt<std::string> stateFile("state-file",
llvm::cl::desc("State file"),
llvm::cl::value_desc("filename"),
llvm::cl::init(""),
llvm::cl::cat(mainCategory));
static cl::opt<bool> shouldInline("inline", cl::desc("Inline arcs"),
cl::init(true), cl::cat(mainCategory));
static llvm::cl::opt<bool> shouldInline("inline", llvm::cl::desc("Inline arcs"),
llvm::cl::init(true),
llvm::cl::cat(mainCategory));
static cl::opt<bool> shouldDedup("dedup", cl::desc("Deduplicate arcs"),
cl::init(true), cl::cat(mainCategory));
static llvm::cl::opt<bool> shouldDedup("dedup",
llvm::cl::desc("Deduplicate arcs"),
llvm::cl::init(true),
llvm::cl::cat(mainCategory));
static cl::opt<bool> shouldDetectEnables(
static llvm::cl::opt<bool> shouldDetectEnables(
"detect-enables",
cl::desc("Infer enable conditions for states to avoid computation"),
cl::init(true), cl::cat(mainCategory));
llvm::cl::desc("Infer enable conditions for states to avoid computation"),
llvm::cl::init(true), llvm::cl::cat(mainCategory));
static cl::opt<bool> shouldDetectResets(
static llvm::cl::opt<bool> shouldDetectResets(
"detect-resets",
cl::desc("Infer reset conditions for states to avoid computation"),
cl::init(false), cl::cat(mainCategory));
llvm::cl::desc("Infer reset conditions for states to avoid computation"),
llvm::cl::init(false), llvm::cl::cat(mainCategory));
static cl::opt<bool>
static llvm::cl::opt<bool>
shouldMakeLUTs("lookup-tables",
cl::desc("Optimize arcs into lookup tables"), cl::init(true),
cl::cat(mainCategory));
llvm::cl::desc("Optimize arcs into lookup tables"),
llvm::cl::init(true), llvm::cl::cat(mainCategory));
static cl::opt<bool> printDebugInfo("print-debug-info",
cl::desc("Print debug information"),
cl::init(false), cl::cat(mainCategory));
static llvm::cl::opt<bool>
printDebugInfo("print-debug-info",
llvm::cl::desc("Print debug information"),
llvm::cl::init(false), llvm::cl::cat(mainCategory));
static cl::opt<bool>
verifyPasses("verify-each",
cl::desc("Run the verifier after each transformation pass"),
cl::init(true), cl::cat(mainCategory));
static llvm::cl::opt<bool> verifyPasses(
"verify-each",
llvm::cl::desc("Run the verifier after each transformation pass"),
llvm::cl::init(true), llvm::cl::cat(mainCategory));
static cl::opt<bool>
verifyDiagnostics("verify-diagnostics",
cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
cl::init(false), cl::Hidden, cl::cat(mainCategory));
static llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false), llvm::cl::Hidden, llvm::cl::cat(mainCategory));
static cl::opt<bool>
verbosePassExecutions("verbose-pass-executions",
cl::desc("Log executions of toplevel module passes"),
cl::init(false), cl::cat(mainCategory));
static llvm::cl::opt<bool> verbosePassExecutions(
"verbose-pass-executions",
llvm::cl::desc("Log executions of toplevel module passes"),
llvm::cl::init(false), llvm::cl::cat(mainCategory));
static cl::opt<bool>
splitInputFile("split-input-file",
cl::desc("Split the input file into pieces and process each "
"chunk independently"),
cl::init(false), cl::Hidden, cl::cat(mainCategory));
static llvm::cl::opt<bool> splitInputFile(
"split-input-file",
llvm::cl::desc("Split the input file into pieces and process each "
"chunk independently"),
llvm::cl::init(false), llvm::cl::Hidden, llvm::cl::cat(mainCategory));
// Options to control early-out from pipeline.
enum Until {
@ -163,7 +174,7 @@ enum Until {
UntilLLVMLowering,
UntilEnd
};
static auto runUntilValues = cl::values(
static auto runUntilValues = llvm::cl::values(
clEnumValN(UntilPreprocessing, "preproc", "Input preprocessing"),
clEnumValN(UntilArcConversion, "arc-conv", "Conversion of modules to arcs"),
clEnumValN(UntilArcOpt, "arc-opt", "Arc optimizations"),
@ -171,24 +182,30 @@ static auto runUntilValues = cl::values(
clEnumValN(UntilStateAlloc, "state-alloc", "State allocation"),
clEnumValN(UntilLLVMLowering, "llvm-lowering", "Lowering to LLVM"),
clEnumValN(UntilEnd, "all", "Run entire pipeline (default)"));
static cl::opt<Until>
runUntilBefore("until-before",
cl::desc("Stop pipeline before a specified point"),
runUntilValues, cl::init(UntilEnd), cl::cat(mainCategory));
static cl::opt<Until>
runUntilAfter("until-after",
cl::desc("Stop pipeline after a specified point"),
runUntilValues, cl::init(UntilEnd), cl::cat(mainCategory));
static llvm::cl::opt<Until> runUntilBefore(
"until-before", llvm::cl::desc("Stop pipeline before a specified point"),
runUntilValues, llvm::cl::init(UntilEnd), llvm::cl::cat(mainCategory));
static llvm::cl::opt<Until> runUntilAfter(
"until-after", llvm::cl::desc("Stop pipeline after a specified point"),
runUntilValues, llvm::cl::init(UntilEnd), llvm::cl::cat(mainCategory));
// Options to control the output format.
enum OutputFormat { OutputMLIR, OutputLLVM, OutputDisabled };
static cl::opt<OutputFormat> outputFormat(
cl::desc("Specify output format"),
cl::values(clEnumValN(OutputMLIR, "emit-mlir", "Emit MLIR dialects"),
clEnumValN(OutputLLVM, "emit-llvm", "Emit LLVM"),
clEnumValN(OutputDisabled, "disable-output",
"Do not output anything")),
cl::init(OutputLLVM), cl::cat(mainCategory));
enum OutputFormat { OutputMLIR, OutputLLVM, OutputRunJIT, OutputDisabled };
static llvm::cl::opt<OutputFormat> outputFormat(
llvm::cl::desc("Specify output format"),
llvm::cl::values(clEnumValN(OutputMLIR, "emit-mlir", "Emit MLIR dialects"),
clEnumValN(OutputLLVM, "emit-llvm", "Emit LLVM"),
clEnumValN(OutputRunJIT, "run",
"Run the simulation and emit its output"),
clEnumValN(OutputDisabled, "disable-output",
"Do not output anything")),
llvm::cl::init(OutputLLVM), llvm::cl::cat(mainCategory));
static llvm::cl::opt<std::string>
jitEntryPoint("jit-entry",
llvm::cl::desc("Name of the function containing the "
"simulation to run when output is set to run"),
llvm::cl::init("entry"), llvm::cl::cat(mainCategory));
//===----------------------------------------------------------------------===//
// Main Tool Logic
@ -380,6 +397,64 @@ static LogicalResult processBuffer(
if (failed(pmLlvm.run(module.get())))
return failure();
#ifdef ARCILATOR_ENABLE_JIT
// Handle JIT execution.
if (outputFormat == OutputRunJIT) {
Operation *toCall = module->lookupSymbol(jitEntryPoint);
if (!toCall) {
llvm::errs() << "entry point not found: '" << jitEntryPoint << "'\n";
return failure();
}
auto toCallFunc = llvm::dyn_cast<LLVM::LLVMFuncOp>(toCall);
if (!toCallFunc) {
llvm::errs() << "entry point '" << jitEntryPoint
<< "' was found but on an operation of type '"
<< toCall->getName()
<< "' while an LLVM function was expected\n";
return failure();
}
if (toCallFunc.getNumArguments() != 0) {
llvm::errs() << "entry point '" << jitEntryPoint
<< "' must have no arguments\n";
return failure();
}
mlir::ExecutionEngineOptions engineOptions;
engineOptions.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive;
engineOptions.transformer = mlir::makeOptimizingTransformer(
/*optLevel=*/3, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);
auto executionEngine =
mlir::ExecutionEngine::create(module.get(), engineOptions);
if (!executionEngine) {
llvm::handleAllErrors(
executionEngine.takeError(), [](const llvm::ErrorInfoBase &info) {
llvm::errs() << "failed to create execution engine: "
<< info.message() << "\n";
});
return failure();
}
auto expectedFunc = (*executionEngine)->lookupPacked(jitEntryPoint);
if (!expectedFunc) {
llvm::handleAllErrors(
expectedFunc.takeError(), [](const llvm::ErrorInfoBase &info) {
llvm::errs() << "failed to run simulation: " << info.message()
<< "\n";
});
return failure();
}
void (*simulationFunc)(void **) = *expectedFunc;
(*simulationFunc)(nullptr);
return success();
}
#endif // ARCILATOR_ENABLE_JIT
// Handle MLIR output.
if (runUntilBefore != UntilEnd || runUntilAfter != UntilEnd ||
outputFormat == OutputMLIR) {
@ -438,7 +513,7 @@ processInput(MLIRContext &context, TimingScope &ts,
return splitAndProcessBuffer(
std::move(input),
[&](std::unique_ptr<MemoryBuffer> buffer, raw_ostream &) {
[&](std::unique_ptr<llvm::MemoryBuffer> buffer, raw_ostream &) {
return processInputSplit(context, ts, std::move(buffer), outputFile);
},
llvm::outs());
@ -475,6 +550,7 @@ static LogicalResult executeArcilator(MLIRContext &context) {
comb::CombDialect,
emit::EmitDialect,
hw::HWDialect,
mlir::DLTIDialect,
mlir::LLVM::LLVMDialect,
mlir::arith::ArithDialect,
mlir::cf::ControlFlowDialect,
@ -510,11 +586,11 @@ static LogicalResult executeArcilator(MLIRContext &context) {
/// can `exit(0)` at the end of the program to avoid teardown of the MLIRContext
/// and modules inside of it (reducing compile time).
int main(int argc, char **argv) {
InitLLVM y(argc, argv);
llvm::InitLLVM y(argc, argv);
// Hide default LLVM options, other than for this tool.
// MLIR options are added below.
cl::HideUnrelatedOptions(mainCategory);
llvm::cl::HideUnrelatedOptions(mainCategory);
// Register passes before parsing command-line options, so that they are
// available for use with options like `--mlir-print-ir-before`.
@ -534,11 +610,26 @@ int main(int argc, char **argv) {
registerPassManagerCLOptions();
registerDefaultTimingManagerCLOptions();
registerAsmPrinterCLOptions();
cl::AddExtraVersionPrinter(
llvm::cl::AddExtraVersionPrinter(
[](raw_ostream &os) { os << getCirctVersion() << '\n'; });
// Parse pass names in main to ensure static initialization completed.
cl::ParseCommandLineOptions(argc, argv, "MLIR-based circuit simulator\n");
llvm::cl::ParseCommandLineOptions(argc, argv,
"MLIR-based circuit simulator\n");
if (outputFormat == OutputRunJIT) {
#ifdef ARCILATOR_ENABLE_JIT
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
#else
llvm::errs() << "This arcilator binary was not built with JIT support.\n";
llvm::errs() << "To enable JIT features, build arcilator with MLIR's "
"execution engine.\n";
llvm::errs() << "This can be achieved by building arcilator with the "
"host's LLVM target enabled.\n";
exit(1);
#endif // ARCILATOR_ENABLE_JIT
}
MLIRContext context;
auto result = executeArcilator(context);