diff --git a/deploy/runtools/firesim_topology_elements.py b/deploy/runtools/firesim_topology_elements.py index 4f5a2c9a..d72f8911 100644 --- a/deploy/runtools/firesim_topology_elements.py +++ b/deploy/runtools/firesim_topology_elements.py @@ -327,7 +327,6 @@ class FireSimServerNode(FireSimNode): all_paths.append([self.server_hardware_config.get_local_driver_path(), '']) all_paths.append([self.server_hardware_config.get_local_runtime_conf_path(), '']) - all_paths.append([self.server_hardware_config.get_local_assert_def_path(), '']) # shared libraries all_paths.append(["$RISCV/lib/libdwarf.so", "libdwarf.so.1"]) @@ -508,7 +507,6 @@ class FireSimSuperNodeServerNode(FireSimServerNode): all_paths.append([self.server_hardware_config.get_local_driver_path(), '']) all_paths.append([self.server_hardware_config.get_local_runtime_conf_path(), '']) - all_paths.append([self.server_hardware_config.get_local_assert_def_path(), '']) return all_paths class FireSimDummyServerNode(FireSimServerNode): diff --git a/deploy/runtools/runtime_config.py b/deploy/runtools/runtime_config.py index e103513e..3fc92123 100644 --- a/deploy/runtools/runtime_config.py +++ b/deploy/runtools/runtime_config.py @@ -79,14 +79,6 @@ class RuntimeHWConfig: runtime_conf_local = CUSTOM_RUNTIMECONFS_BASE + my_runtimeconfig return runtime_conf_local - # TODO: Delete this and bake the assertion definitions into the Driver - def get_local_assert_def_path(self): - """ return relative local path of the synthesized assertion definitions. """ - my_deploytriplet = self.get_deploytriplet_for_config() - gen_src_dir = LOCAL_DRIVERS_GENERATED_SRC + "/" + my_deploytriplet + "/" - assert_def_local = gen_src_dir + self.get_design_name() + ".asserts" - return assert_def_local - def get_boot_simulation_command(self, slotid, all_macs, all_rootfses, all_linklatencies, all_netbws, profile_interval, @@ -102,8 +94,8 @@ class RuntimeHWConfig: runtime parameters currently. """ # TODO: supernode support - tracefile = "+tracefile0=TRACEFILE0" if trace_enable else "" - autocounterfile = "+autocounter-filename0=AUTOCOUNTERFILE0" + tracefile = "+tracefile=TRACEFILE" if trace_enable else "" + autocounterfile = "+autocounter-filename=AUTOCOUNTERFILE" # this monstrosity boots the simulator, inside screen, inside script # the sed is in there to get rid of newlines in runtime confs @@ -135,10 +127,10 @@ class RuntimeHWConfig: zero_out_dram = "+zero-out-dram" if (enable_zerooutdram) else "" # TODO supernode support - dwarf_file_name = "+dwarf-file-name0=" + all_bootbinaries[0] + "-dwarf" + dwarf_file_name = "+dwarf-file-name=" + all_bootbinaries[0] + "-dwarf" - # TODO: supernode support (tracefile0, trace-select0.. etc) - basecommand = """screen -S fsim{slotid} -d -m bash -c "script -f -c 'stty intr ^] && sudo sudo LD_LIBRARY_PATH=.:$LD_LIBRARY_PATH ./{driver} +permissive $(sed \':a;N;$!ba;s/\\n/ /g\' {runtimeconf}) +slotid={slotid} +profile-interval={profile_interval} {zero_out_dram} {command_macs} {command_rootfses} {command_niclogs} {command_blkdev_logs} {tracefile} +trace-select0={trace_select} +trace-start0={trace_start} +trace-end0={trace_end} +trace-output-format0={trace_output_format} {dwarf_file_name} +autocounter-readrate0={autocounter_readrate} {autocounterfile} {command_linklatencies} {command_netbws} {command_shmemportnames} +permissive-off {command_bootbinaries} && stty intr ^c' uartlog"; sleep 1""".format( + # TODO: supernode support (tracefile, trace-select.. etc) + basecommand = """screen -S fsim{slotid} -d -m bash -c "script -f -c 'stty intr ^] && sudo sudo LD_LIBRARY_PATH=.:$LD_LIBRARY_PATH ./{driver} +permissive $(sed \':a;N;$!ba;s/\\n/ /g\' {runtimeconf}) +slotid={slotid} +profile-interval={profile_interval} {zero_out_dram} {command_macs} {command_rootfses} {command_niclogs} {command_blkdev_logs} {tracefile} +trace-select={trace_select} +trace-start={trace_start} +trace-end={trace_end} +trace-output-format={trace_output_format} {dwarf_file_name} +autocounter-readrate={autocounter_readrate} {autocounterfile} {command_linklatencies} {command_netbws} {command_shmemportnames} +permissive-off {command_bootbinaries} && stty intr ^c' uartlog"; sleep 1""".format( slotid=slotid, driver=driver, runtimeconf=runtimeconf, command_macs=command_macs, command_rootfses=command_rootfses, diff --git a/deploy/runtools/user_topology.py b/deploy/runtools/user_topology.py index 8272af45..2b9ada2a 100644 --- a/deploy/runtools/user_topology.py +++ b/deploy/runtools/user_topology.py @@ -370,9 +370,9 @@ class UserTopologies(object): # Spins up all of the precompiled, unnetworked targets def all_no_net_targets_config(self): hwdb_entries = [ - "fireboom-singlecore-no-nic-l2-llc4mb-ddr3", - "fireboom-singlecore-no-nic-l2-llc4mb-ddr3-ramopts", - "firesim-quadcore-no-nic-l2-llc4mb-ddr3", + "firesim-boom-singlecore-no-nic-l2-llc4mb-ddr3", + "firesim-rocket-quadcore-no-nic-l2-llc4mb-ddr3", + "firesim-rocket-quadcore-no-nic-l2-llc4mb-ddr3-halfrate", ] assert len(hwdb_entries) == self.no_net_num_nodes self.roots = [FireSimServerNode(hwdb_entries[x]) for x in range(self.no_net_num_nodes)] diff --git a/deploy/sample-backup-configs/sample_config_build.ini b/deploy/sample-backup-configs/sample_config_build.ini index d7735ab3..08249883 100644 --- a/deploy/sample-backup-configs/sample_config_build.ini +++ b/deploy/sample-backup-configs/sample_config_build.ini @@ -18,7 +18,8 @@ firesim-boom-singlecore-no-nic-l2-llc4mb-ddr3 firesim-boom-singlecore-nic-l2-llc4mb-ddr3 firesim-supernode-rocket-singlecore-nic-l2-lbp -firesim-boom-singlecore-no-nic-l2-llc4mb-ddr3-ramopts +#firesim-boom-singlecore-no-nic-l2-llc4mb-ddr3-ramopts +firesim-rocket-quadcore-no-nic-l2-llc4mb-ddr3-half-freq-uncore firesim-rocket-singlecore-gemmini-no-nic-l2-llc4mb-ddr3 @@ -33,7 +34,8 @@ firesim-boom-singlecore-no-nic-l2-llc4mb-ddr3 firesim-boom-singlecore-nic-l2-llc4mb-ddr3 firesim-supernode-rocket-singlecore-nic-l2-lbp -firesim-boom-singlecore-no-nic-l2-llc4mb-ddr3-ramopts +#firesim-boom-singlecore-no-nic-l2-llc4mb-ddr3-ramopts +firesim-rocket-quadcore-no-nic-l2-llc4mb-ddr3-half-freq-uncore firesim-rocket-singlecore-gemmini-no-nic-l2-llc4mb-ddr3 diff --git a/deploy/sample-backup-configs/sample_config_build_recipes.ini b/deploy/sample-backup-configs/sample_config_build_recipes.ini index 0b4ff038..63bb340c 100644 --- a/deploy/sample-backup-configs/sample_config_build_recipes.ini +++ b/deploy/sample-backup-configs/sample_config_build_recipes.ini @@ -77,6 +77,13 @@ PLATFORM_CONFIG=F85MHz_BaseF1Config instancetype=z1d.2xlarge deploytriplet=None +# Multiclock Temporary Example:ncept Quad-core, Rocket-based recipes +[firesim-rocket-quadcore-no-nic-l2-llc4mb-ddr3-half-freq-uncore] +DESIGN=FireSimMulticlockPOC +TARGET_CONFIG=DDR3FRFCFSLLC4MB_FireSimQuadRocketMulticlockConfig +PLATFORM_CONFIG=F90MHz_BaseF1Config +instancetype=z1d.2xlarge +deploytriplet=None # MIDAS Examples -- BUILD SUPPORT ONLY; Can't launch driver correctly on runfarm [midasexamples-gcd] diff --git a/deploy/sample-backup-configs/sample_config_hwdb.ini b/deploy/sample-backup-configs/sample_config_hwdb.ini index 4e3411e5..86708558 100644 --- a/deploy/sample-backup-configs/sample_config_hwdb.ini +++ b/deploy/sample-backup-configs/sample_config_hwdb.ini @@ -8,38 +8,39 @@ # Only AGFIs for the latest release of FireSim are guaranteed to be available. # If you are using an older version of FireSim, you will need to generate your # own images. + [firesim-boom-singlecore-nic-l2-llc4mb-ddr3] -agfi=agfi-0d80edc1e51eeed2a +agfi=agfi-0b00913f35fd3b17b deploytripletoverride=None customruntimeconfig=None [firesim-boom-singlecore-no-nic-l2-llc4mb-ddr3] -agfi=agfi-062b20613c52a2313 -deploytripletoverride=None -customruntimeconfig=None - -[firesim-boom-singlecore-no-nic-l2-llc4mb-ddr3-ramopts] -agfi=agfi-0cfb3cb54b7928d08 +agfi=agfi-06addb09b5f339914 deploytripletoverride=None customruntimeconfig=None [firesim-rocket-quadcore-nic-l2-llc4mb-ddr3] -agfi=agfi-0c98ad02959d4fd95 +agfi=agfi-0fb5684b14d71b79f deploytripletoverride=None customruntimeconfig=None [firesim-rocket-quadcore-no-nic-l2-llc4mb-ddr3] -agfi=agfi-04812c9eb510b7913 +agfi=agfi-05fb794ef6d3fa423 +deploytripletoverride=None +customruntimeconfig=None + +[firesim-rocket-quadcore-no-nic-l2-llc4mb-ddr3-half-freq-uncore] +agfi=agfi-0cde94735d54677e7 deploytripletoverride=None customruntimeconfig=None [firesim-rocket-singlecore-gemmini-no-nic-l2-llc4mb-ddr3] -agfi=agfi-006f97febc774fee1 +agfi=agfi-047da9dd6850b8e9d deploytripletoverride=None customruntimeconfig=None [firesim-supernode-rocket-singlecore-nic-l2-lbp] -agfi=agfi-05617d1eb6490d689 +agfi=agfi-0f0efddd34003118d deploytripletoverride=None customruntimeconfig=None diff --git a/sim/firesim-lib/src/main/cc/bridges/tracerv.cc b/sim/firesim-lib/src/main/cc/bridges/tracerv.cc index a0de08cc..e5743637 100644 --- a/sim/firesim-lib/src/main/cc/bridges/tracerv.cc +++ b/sim/firesim-lib/src/main/cc/bridges/tracerv.cc @@ -14,19 +14,6 @@ #include -// TODO: generate a header with these automatically - -// bitwidths for stuff in the trace. assume this order too. -#define VALID_WID 1 -#define IADDR_WID 40 -#define INSN_WID 32 -#define PRIV_WID 3 -#define EXCP_WID 1 -#define INT_WID 1 -#define CAUSE_WID 8 -#define TVAL_WID 40 -#define TOTAL_WID (VALID_WID + IADDR_WID + INSN_WID + PRIV_WID + EXCP_WID + INT_WID + CAUSE_WID + TVAL_WID) - // The maximum number of beats available in the FPGA-side FIFO #define QUEUE_DEPTH 6144 @@ -34,18 +21,28 @@ // useful for iterating on software side only without re-running on FPGA. //#define FIREPERF_LOGGER +constexpr uint64_t valid_mask = (1ULL << 40); + tracerv_t::tracerv_t( - simif_t *sim, std::vector &args, TRACERVBRIDGEMODULE_struct * mmio_addrs, int tracerno, long dma_addr) : bridge_driver_t(sim) -{ - static_assert(NUM_CORES <= 7, "TRACERV CURRENTLY ONLY SUPPORT <= 7 Cores/Instruction Streams"); - this->mmio_addrs = mmio_addrs; - this->dma_addr = dma_addr; + simif_t *sim, + std::vector &args, + TRACERVBRIDGEMODULE_struct * mmio_addrs, + long dma_addr, + const unsigned int max_core_ipc, + const char* const clock_domain_name, + const unsigned int clock_multiplier, + const unsigned int clock_divisor, + int tracerno) : + bridge_driver_t(sim), + mmio_addrs(mmio_addrs), + max_core_ipc(max_core_ipc), + clock_info(clock_domain_name, clock_multiplier, clock_divisor), + dma_addr(dma_addr) { + //Biancolin: move into elaboration + assert(this->max_core_ipc <= 7 && "TracerV only supports cores with a maximum IPC <= 7"); const char *tracefilename = NULL; const char *dwarf_file_name = NULL; - - for (int i = 0; i < NUM_CORES; i++) { - this->tracefiles[i] = NULL; - } + this->tracefile = NULL; this->trace_trigger_start = 0; this->trace_trigger_end = ULONG_MAX; @@ -55,18 +52,18 @@ tracerv_t::tracerv_t( long outputfmtselect = 0; - std::string num_equals = std::to_string(tracerno) + std::string("="); - std::string tracefile_arg = std::string("+tracefile") + num_equals; - std::string tracestart_arg = std::string("+trace-start") + num_equals; - std::string traceend_arg = std::string("+trace-end") + num_equals; - std::string traceselect_arg = std::string("+trace-select") + num_equals; + std::string suffix = std::string("="); + std::string tracefile_arg = std::string("+tracefile") + suffix; + std::string tracestart_arg = std::string("+trace-start") + suffix; + std::string traceend_arg = std::string("+trace-end") + suffix; + std::string traceselect_arg = std::string("+trace-select") + suffix; // Testing: provides a reference file to diff the collected trace against - std::string testoutput_arg = std::string("+trace-test-output") + std::to_string(tracerno); + std::string testoutput_arg = std::string("+trace-test-output"); // Formats the output before dumping the trace to file - std::string humanreadable_arg = std::string("+trace-humanreadable") + std::to_string(tracerno); + std::string humanreadable_arg = std::string("+trace-humanreadable"); - std::string trace_output_format_arg = std::string("+trace-output-format") + num_equals; - std::string dwarf_file_arg = std::string("+dwarf-file-name") + num_equals; + std::string trace_output_format_arg = std::string("+trace-output-format") + suffix; + std::string dwarf_file_arg = std::string("+dwarf-file-name") + suffix; for (auto &arg: args) { if (arg.find(tracefile_arg) == 0) { @@ -77,15 +74,26 @@ tracerv_t::tracerv_t( char *str = const_cast(arg.c_str()) + traceselect_arg.length(); this->trigger_selector = atol(str); } + // These next two arguments are overloaded to provide trigger start and + // stop condition information based on setting of the +trace-select if (arg.find(tracestart_arg) == 0) { + // Start and end cycles are given in decimal char *str = const_cast(arg.c_str()) + tracestart_arg.length(); - char * pEnd; - this->trace_trigger_start = trigger_selector==1 ? atol(str) : strtoul (str,&pEnd,16); + this->trace_trigger_start = this->clock_info.to_local_cycles(atol(str)); + // PCs values, and instruction and mask encodings are given in hex + uint64_t mask_and_insn = strtoul(str, NULL, 16); + this->trigger_start_insn = (uint32_t) mask_and_insn; + this->trigger_start_insn_mask = mask_and_insn >> 32; + this->trigger_start_pc = mask_and_insn; } if (arg.find(traceend_arg) == 0) { char *str = const_cast(arg.c_str()) + traceend_arg.length(); - char * pEnd; - this->trace_trigger_end = trigger_selector==1 ? atol(str) : strtoul (str,&pEnd,16); + this->trace_trigger_end = this->clock_info.to_local_cycles(atol(str)); + + uint64_t mask_and_insn = strtoul(str, NULL, 16); + this->trigger_stop_insn = (uint32_t) mask_and_insn; + this->trigger_stop_insn_mask = mask_and_insn >> 32; + this->trigger_stop_pc = mask_and_insn; } if (arg.find(testoutput_arg) == 0) { this->test_output = true; @@ -102,14 +110,13 @@ tracerv_t::tracerv_t( if (tracefilename) { // giving no tracefilename means we will create NO tracefiles - for (int i = 0; i < NUM_CORES; i++) { - std::string tfname = std::string(tracefilename) + std::string("-C") + std::to_string(i); - this->tracefiles[i] = fopen(tfname.c_str(), "w"); - if (!this->tracefiles[i]) { - fprintf(stderr, "Could not open Trace log file: %s\n", tracefilename); - abort(); - } + std::string tfname = std::string(tracefilename) + std::string("-C") + std::to_string(tracerno); + this->tracefile = fopen(tfname.c_str(), "w"); + if (!this->tracefile) { + fprintf(stderr, "Could not open Trace log file: %s\n", tracefilename); + abort(); } + fputs(this->clock_info.file_header().c_str(), this->tracefile); // This must be kept consistent with config_runtime.ini's output_format. // That file's comments are the single source of truth for this. @@ -126,7 +133,8 @@ tracerv_t::tracerv_t( fprintf(stderr, "Invalid trace format arg\n"); } } else { - fprintf(stderr, "TraceRV: Warning: No +tracefileN given!\n"); + fprintf(stderr, "TraceRV %d: Tracing disabled, since +tracefile was not provided.\n", tracerno); + this->trace_enabled = false; } if (fireperf) { @@ -134,23 +142,28 @@ tracerv_t::tracerv_t( fprintf(stderr, "+fireperf specified but no +dwarf-file-name given\n"); abort(); } - for (int i = 0; i < NUM_CORES; i++) { - this->trace_trackers[i] = new TraceTracker(this->dwarf_file_name, this->tracefiles[i]); - } + this->trace_tracker = new TraceTracker(this->dwarf_file_name, this->tracefile); } } tracerv_t::~tracerv_t() { - for (int i = 0; i < NUM_CORES; i++) { - if (this->tracefiles[i]) { - fclose(this->tracefiles[i]); - } + if (this->tracefile) { + fclose(this->tracefile); } free(this->mmio_addrs); } void tracerv_t::init() { - if (this->trigger_selector == 1) + if (!this->trace_enabled) { + // Explicitly disable token collection inthe bridge by only collecting + // tokens from cycle 0 to cycle 0, saving DMA bandwidth and improving FMR + write(this->mmio_addrs->triggerSelector, 1); + write(this->mmio_addrs->hostTriggerCycleCountStartHigh, 0); + write(this->mmio_addrs->hostTriggerCycleCountStartLow, 0); + write(this->mmio_addrs->hostTriggerCycleCountEndHigh, 0); + write(this->mmio_addrs->hostTriggerCycleCountEndLow, 0); + } + else if (this->trigger_selector == 1) { write(this->mmio_addrs->triggerSelector, this->trigger_selector); write(this->mmio_addrs->hostTriggerCycleCountStartHigh, this->trace_trigger_start >> 32); @@ -162,93 +175,97 @@ void tracerv_t::init() { else if (this->trigger_selector == 2) { write(this->mmio_addrs->triggerSelector, this->trigger_selector); - write(this->mmio_addrs->hostTriggerPCStartHigh, this->trace_trigger_start >> 32); - write(this->mmio_addrs->hostTriggerPCStartLow, this->trace_trigger_start & ((1ULL << 32) - 1)); - write(this->mmio_addrs->hostTriggerPCEndHigh, this->trace_trigger_end >> 32); - write(this->mmio_addrs->hostTriggerPCEndLow, this->trace_trigger_end & ((1ULL << 32) - 1)); - printf("TracerV: Collect trace from instruction address %lx to %lx\n", trace_trigger_start, trace_trigger_end); + write(this->mmio_addrs->hostTriggerPCStartHigh, this->trigger_start_pc >> 32); + write(this->mmio_addrs->hostTriggerPCStartLow, this->trigger_start_pc & ((1ULL << 32) - 1)); + write(this->mmio_addrs->hostTriggerPCEndHigh, this->trigger_stop_pc >> 32); + write(this->mmio_addrs->hostTriggerPCEndLow, this->trigger_stop_pc & ((1ULL << 32) - 1)); + printf("TracerV: Collect trace from instruction address %lx to %lx\n", trigger_start_pc, trigger_stop_pc); } else if (this->trigger_selector == 3) { write(this->mmio_addrs->triggerSelector, this->trigger_selector); - write(this->mmio_addrs->hostTriggerStartInst, this->trace_trigger_start & ((1ULL << 32) - 1)); - write(this->mmio_addrs->hostTriggerStartInstMask, this->trace_trigger_start >> 32); - write(this->mmio_addrs->hostTriggerEndInst, this->trace_trigger_end & ((1ULL << 32) - 1)); - write(this->mmio_addrs->hostTriggerEndInstMask, this->trace_trigger_end >> 32); + write(this->mmio_addrs->hostTriggerStartInst, this->trigger_start_insn); + write(this->mmio_addrs->hostTriggerStartInstMask, this->trigger_start_insn_mask); + write(this->mmio_addrs->hostTriggerEndInst, this->trigger_stop_insn); + write(this->mmio_addrs->hostTriggerEndInstMask, this->trigger_stop_insn_mask); printf("TracerV: Collect trace with start trigger instruction %x masked with %x, and end trigger instruction %x masked with %x\n", - this->trace_trigger_start & ((1ULL << 32) - 1), this->trace_trigger_start >> 32, - this->trace_trigger_end & ((1ULL << 32) - 1), this->trace_trigger_end >> 32); + this->trigger_start_insn, this->trigger_start_insn_mask, + this->trigger_stop_insn, this->trigger_stop_insn_mask); } else { + // Biancolin: should we not error here? write(this->mmio_addrs->triggerSelector, this->trigger_selector); - printf("TracerV: Collecting trace from %lu to %lu cycles\n", trace_trigger_start, trace_trigger_end); + printf("TracerV: No trigger selected. Collecting trace from %lu to %lu cycles\n", 0, ULONG_MAX); } + write(this->mmio_addrs->initDone, true); } -// defining this stores as human readable hex (e.g. open in VIM) -// undefining this stores as bin (e.g. open with vim hex mode) - -void tracerv_t::tick() { - uint64_t outfull = read(this->mmio_addrs->tracequeuefull); - +void tracerv_t::process_tokens(int num_beats) { + // TODO. as opt can mmap file and just load directly into it. alignas(4096) uint64_t OUTBUF[QUEUE_DEPTH * 8]; - - if (outfull) { - // TODO. as opt can mmap file and just load directly into it. - pull(dma_addr, (char*)OUTBUF, QUEUE_DEPTH * 64); - //check that a tracefile exists (one is enough) since the manager - //does not create a tracefile when trace_enable is disabled, but the - //TracerV bridge still exists, and no tracefiles are create be default. - if (this->tracefiles[0]) { - if (this->human_readable || this->test_output) { - for (int i = 0; i < QUEUE_DEPTH * 8; i+=8) { - if (this->test_output) { - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+7]); - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+6]); - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+5]); - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+4]); - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+3]); - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+2]); - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+1]); - fprintf(this->tracefiles[0], "%016lx\n", OUTBUF[i+0]); - } else { - for (int q = 0; q < NUM_CORES; q++) { - if ((OUTBUF[i+0+q] >> 40) & 0x1) { - fprintf(this->tracefiles[q], "C%d: %016llx, cycle: %016llx\n", q, OUTBUF[i+0+q], OUTBUF[i+7]); - } - } + pull(dma_addr, (char*)OUTBUF, num_beats * 64); + //check that a tracefile exists (one is enough) since the manager + //does not create a tracefile when trace_enable is disabled, but the + //TracerV bridge still exists, and no tracefile is created by default. + if (this->tracefile) { + if (this->human_readable || this->test_output) { + for (int i = 0; i < num_beats * 8; i+=8) { + if (this->test_output) { + fprintf(this->tracefile, "%016lx", OUTBUF[i+7]); + fprintf(this->tracefile, "%016lx", OUTBUF[i+6]); + fprintf(this->tracefile, "%016lx", OUTBUF[i+5]); + fprintf(this->tracefile, "%016lx", OUTBUF[i+4]); + fprintf(this->tracefile, "%016lx", OUTBUF[i+3]); + fprintf(this->tracefile, "%016lx", OUTBUF[i+2]); + fprintf(this->tracefile, "%016lx", OUTBUF[i+1]); + fprintf(this->tracefile, "%016lx\n", OUTBUF[i+0]); + // At least one valid instruction + } else { + for (int q = 0; q < max_core_ipc; q++) { + if (OUTBUF[i+q+1] & valid_mask) { + fprintf(this->tracefile, "Cycle: %016lld I%d: %016llx\n", OUTBUF[i+0], q, OUTBUF[i+q+1] & (~valid_mask)); + } else { + break; + } } } - } else if (this->fireperf) { - for (int i = 0; i < QUEUE_DEPTH * 8; i+=8) { - uint64_t cycle_internal = OUTBUF[i+7]; + } + } else if (this->fireperf) { - for (int q = 0; q < NUM_CORES; q++) { - if ((OUTBUF[i+0+q] >> 40) & 0x1) { - uint64_t iaddr = (uint64_t)((((int64_t)(OUTBUF[i+0+q])) << 24) >> 24); - this->trace_trackers[q]->addInstruction(iaddr, cycle_internal); + for (int i = 0; i < QUEUE_DEPTH * 8; i+=8) { + uint64_t cycle_internal = OUTBUF[i+0]; + + for (int q = 0; q < max_core_ipc; q++) { + if (OUTBUF[i+1+q] & valid_mask) { + uint64_t iaddr = (uint64_t)((((int64_t)(OUTBUF[i+1+q])) << 24) >> 24); + this->trace_tracker->addInstruction(iaddr, cycle_internal); #ifdef FIREPERF_LOGGER - fprintf(this->tracefiles[q], "%016llx", iaddr); - fprintf(this->tracefiles[q], "%016llx\n", cycle_internal); + fprintf(this->tracefile, "%016llx", iaddr); + fprintf(this->tracefile, "%016llx\n", cycle_internal); #endif //FIREPERF_LOGGER - } } } - } else { - for (int i = 0; i < QUEUE_DEPTH * 8; i+=8) { - // this stores as raw binary. stored as little endian. - // e.g. to get the same thing as the human readable above, - // flip all the bytes in each 512-bit line. - for (int q = 0; q < 8; q++) { - fwrite(OUTBUF + (i+q), sizeof(uint64_t), 1, this->tracefiles[0]); - } + } + } else { + for (int i = 0; i < QUEUE_DEPTH * 8; i+=8) { + // this stores as raw binary. stored as little endian. + // e.g. to get the same thing as the human readable above, + // flip all the bytes in each 512-bit line. + for (int q = 0; q < 8; q++) { + fwrite(OUTBUF + (i+q), sizeof(uint64_t), 1, this->tracefile); } } } } } +void tracerv_t::tick() { + if (this->trace_enabled) { + uint64_t outfull = read(this->mmio_addrs->tracequeuefull); + if (outfull) process_tokens(QUEUE_DEPTH); + } +} int tracerv_t::beats_available_stable() { size_t prev_beats_available = 0; @@ -264,64 +281,9 @@ int tracerv_t::beats_available_stable() { // Pull in any remaining tokens and flush them to file // WARNING: may not function correctly if the simulator is actively running void tracerv_t::flush() { - - alignas(4096) uint64_t OUTBUF[QUEUE_DEPTH * 8]; - size_t beats_available = beats_available_stable(); - - // TODO. as opt can mmap file and just load directly into it. - pull(dma_addr, (char*)OUTBUF, beats_available * 64); - //check that a tracefile exists (one is enough) since the manager - //does not create a tracefile when trace_enable is disabled, but the - //TracerV bridge still exists, and no tracefiles are create be default. - if (this->tracefiles[0]) { - if (this->human_readable || this->test_output) { - for (int i = 0; i < beats_available * 8; i+=8) { - - if (this->test_output) { - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+7]); - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+6]); - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+5]); - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+4]); - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+3]); - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+2]); - fprintf(this->tracefiles[0], "%016lx", OUTBUF[i+1]); - fprintf(this->tracefiles[0], "%016lx\n", OUTBUF[i+0]); - } else { - for (int q = 0; q < NUM_CORES; q++) { - if ((OUTBUF[i+0+q] >> 40) & 0x1) { - fprintf(this->tracefiles[q], "C%d: %016llx, cycle: %016llx\n", q, OUTBUF[i+0+q], OUTBUF[i+7]); - } - } - } - } - } else if (this->fireperf) { - for (int i = 0; i < beats_available * 8; i+=8) { - uint64_t cycle_internal = OUTBUF[i+7]; - - for (int q = 0; q < NUM_CORES; q++) { - if ((OUTBUF[i+0+q] >> 40) & 0x1) { - // is a valid instruction - // - // sign extended from sv39 - uint64_t iaddr = (uint64_t)((((int64_t)(OUTBUF[i+0+q])) << 24) >> 24); - this->trace_trackers[q]->addInstruction(iaddr, cycle_internal); -#ifdef FIREPERF_LOGGER - fprintf(this->tracefiles[q], "%016llx", iaddr); - fprintf(this->tracefiles[q], "%016llx\n", cycle_internal); -#endif //FIREPERF_LOGGER - } - } - } - } else { - for (int i = 0; i < beats_available * 8; i+=8) { - // this stores as raw binary. stored as little endian. - // e.g. to get the same thing as the human readable above, - // flip all the bytes in each 512-bit line. - for (int q = 0; q < 8; q++) { - fwrite(OUTBUF + (i+q), sizeof(uint64_t), 1, this->tracefiles[0]); - } - } - } + if (this->trace_enabled) { + size_t beats_available = beats_available_stable(); + process_tokens(beats_available); } } #endif // TRACERVBRIDGEMODULE_struct_guard diff --git a/sim/firesim-lib/src/main/cc/bridges/tracerv.h b/sim/firesim-lib/src/main/cc/bridges/tracerv.h index fed48cf3..45644336 100644 --- a/sim/firesim-lib/src/main/cc/bridges/tracerv.h +++ b/sim/firesim-lib/src/main/cc/bridges/tracerv.h @@ -3,20 +3,40 @@ #define __TRACERV_H #include "bridges/bridge_driver.h" +#include "bridges/clock_info.h" #include #include "bridges/tracerv/tracerv_processing.h" #include "bridges/tracerv/trace_tracker.h" -// TODO: get this automatically -#define NUM_CORES 1 - #ifdef TRACERVBRIDGEMODULE_struct_guard + +// Bridge Driver Instantiation Template +#define INSTANTIATE_TRACERV(FUNC,IDX) \ + TRACERVBRIDGEMODULE_ ## IDX ## _substruct_create; \ + FUNC(new tracerv_t( \ + this, \ + args, \ + TRACERVBRIDGEMODULE_ ## IDX ## _substruct, \ + TRACERVBRIDGEMODULE_ ## IDX ## _DMA_ADDR, \ + TRACERVBRIDGEMODULE_ ## IDX ## _max_core_ipc, \ + TRACERVBRIDGEMODULE_ ## IDX ## _clock_domain_name, \ + TRACERVBRIDGEMODULE_ ## IDX ## _clock_multiplier, \ + TRACERVBRIDGEMODULE_ ## IDX ## _clock_divisor, \ + IDX)); \ + class tracerv_t: public bridge_driver_t { public: - tracerv_t(simif_t *sim, std::vector &args, - TRACERVBRIDGEMODULE_struct * mmio_addrs, int tracervno, long dma_addr); + tracerv_t(simif_t *sim, + std::vector &args, + TRACERVBRIDGEMODULE_struct * mmio_addrs, + long dma_addr, + const unsigned int max_core_ipc, + const char* const clock_domain_name, + const unsigned int clock_multiplier, + const unsigned int clock_divisor, + int tracerno); ~tracerv_t(); virtual void init(); @@ -27,25 +47,38 @@ class tracerv_t: public bridge_driver_t private: TRACERVBRIDGEMODULE_struct * mmio_addrs; - simif_t* sim; - FILE * tracefiles[NUM_CORES]; + const int max_core_ipc; + ClockInfo clock_info; + + FILE * tracefile; uint64_t cur_cycle; uint64_t trace_trigger_start, trace_trigger_end; + uint32_t trigger_start_insn = 0; + uint32_t trigger_start_insn_mask = 0; + uint32_t trigger_stop_insn = 0; + uint32_t trigger_stop_insn_mask = 0; uint32_t trigger_selector; + uint64_t trigger_start_pc = 0; + uint64_t trigger_stop_pc = 0; // TODO: rename this from linuxbin ObjdumpedBinary * linuxbin; - TraceTracker * trace_trackers[NUM_CORES]; + TraceTracker * trace_tracker; bool human_readable = false; + // If no filename is provided, the instruction trace is not collected + // and the bridge drops all tokens to improve FMR + bool trace_enabled = true; // Used in unit testing to check TracerV is correctly pulling instuctions off the target bool test_output = false; long dma_addr; - void flush(); - int beats_available_stable(); std::string tracefilename; std::string dwarf_file_name; bool fireperf = false; + + void process_tokens(int num_beats); + int beats_available_stable(); + void flush(); }; #endif // TRACERVBRIDGEMODULE_struct_guard diff --git a/sim/firesim-lib/src/main/scala/bridges/BlockDevBridge.scala b/sim/firesim-lib/src/main/scala/bridges/BlockDevBridge.scala index 2501b23d..faf52f80 100644 --- a/sim/firesim-lib/src/main/scala/bridges/BlockDevBridge.scala +++ b/sim/firesim-lib/src/main/scala/bridges/BlockDevBridge.scala @@ -14,6 +14,7 @@ import testchipip.{BlockDeviceIO, BlockDeviceRequest, BlockDeviceData, BlockDevi class BlockDevBridgeTargetIO(implicit val p: Parameters) extends Bundle { val bdev = Flipped(new BlockDeviceIO) val reset = Input(Bool()) + val clock = Input(Clock()) } class BlockDevBridge(implicit p: Parameters) extends BlackBox @@ -25,9 +26,10 @@ class BlockDevBridge(implicit p: Parameters) extends BlackBox } object BlockDevBridge { - def apply(blkdevIO: BlockDeviceIO, reset: Bool)(implicit p: Parameters): BlockDevBridge = { + def apply(clock: Clock, blkdevIO: BlockDeviceIO, reset: Bool)(implicit p: Parameters): BlockDevBridge = { val ep = Module(new BlockDevBridge) ep.io.bdev <> blkdevIO + ep.io.clock := clock ep.io.reset := reset ep } diff --git a/sim/firesim-lib/src/main/scala/bridges/GroundTestBridge.scala b/sim/firesim-lib/src/main/scala/bridges/GroundTestBridge.scala index 1a9c278f..1233a505 100644 --- a/sim/firesim-lib/src/main/scala/bridges/GroundTestBridge.scala +++ b/sim/firesim-lib/src/main/scala/bridges/GroundTestBridge.scala @@ -16,15 +16,17 @@ class GroundTestBridge extends BlackBox } object GroundTestBridge { - def apply(success: Bool)(implicit p: Parameters): GroundTestBridge = { + def apply(clock: Clock, success: Bool)(implicit p: Parameters): GroundTestBridge = { val bridge = Module(new GroundTestBridge) bridge.io.success := success + bridge.io.clock := clock bridge } } class GroundTestBridgeTargetIO extends Bundle { val success = Input(Bool()) + val clock = Input(Clock()) } class GroundTestBridgeModule(implicit p: Parameters) diff --git a/sim/firesim-lib/src/main/scala/bridges/SerialBridge.scala b/sim/firesim-lib/src/main/scala/bridges/SerialBridge.scala index 778bc8d7..fd7b5dfe 100644 --- a/sim/firesim-lib/src/main/scala/bridges/SerialBridge.scala +++ b/sim/firesim-lib/src/main/scala/bridges/SerialBridge.scala @@ -18,9 +18,10 @@ class SerialBridge extends BlackBox with Bridge[HostPortIO[SerialBridgeTargetIO] } object SerialBridge { - def apply(port: SerialIO)(implicit p: Parameters): SerialBridge = { + def apply(clock: Clock, port: SerialIO)(implicit p: Parameters): SerialBridge = { val ep = Module(new SerialBridge) ep.io.serial <> port + ep.io.clock := clock ep } } @@ -28,6 +29,7 @@ object SerialBridge { class SerialBridgeTargetIO extends Bundle { val serial = Flipped(new SerialIO(testchipip.SerialAdapter.SERIAL_IF_WIDTH)) val reset = Input(Bool()) + val clock = Input(Clock()) } class SerialBridgeModule(implicit p: Parameters) extends BridgeModule[HostPortIO[SerialBridgeTargetIO]]()(p) { diff --git a/sim/firesim-lib/src/main/scala/bridges/SimpleNICBridge.scala b/sim/firesim-lib/src/main/scala/bridges/SimpleNICBridge.scala index 59049cef..4833a955 100644 --- a/sim/firesim-lib/src/main/scala/bridges/SimpleNICBridge.scala +++ b/sim/firesim-lib/src/main/scala/bridges/SimpleNICBridge.scala @@ -24,17 +24,24 @@ import TokenQueueConsts._ case object LoopbackNIC extends Field[Boolean](false) -class NICBridge(implicit p: Parameters) extends BlackBox with Bridge[HostPortIO[NICIOvonly], SimpleNICBridgeModule] { - val io = IO(Flipped(new NICIOvonly)) +class NICTargetIO extends Bundle { + val clock = Input(Clock()) + val nic = Flipped(new NICIOvonly) +} + +class NICBridge(implicit p: Parameters) extends BlackBox with Bridge[HostPortIO[NICTargetIO], SimpleNICBridgeModule] { + val io = IO(new NICTargetIO) val bridgeIO = HostPort(io) val constructorArg = None generateAnnotations() } + object NICBridge { - def apply(nicIO: NICIOvonly)(implicit p: Parameters): NICBridge = { + def apply(clock: Clock, nicIO: NICIOvonly)(implicit p: Parameters): NICBridge = { val ep = Module(new NICBridge) - ep.io <> nicIO + ep.io.nic <> nicIO + ep.io.clock := clock ep } } @@ -174,10 +181,10 @@ class HostToNICTokenGenerator(nTokens: Int)(implicit p: Parameters) extends Modu when (seedDone) { state := s_forward } } -class SimpleNICBridgeModule(implicit p: Parameters) extends BridgeModule[HostPortIO[NICIOvonly]]()(p) +class SimpleNICBridgeModule(implicit p: Parameters) extends BridgeModule[HostPortIO[NICTargetIO]]()(p) with BidirectionalDMA { val io = IO(new WidgetIO) - val hPort = IO(HostPort(Flipped(new NICIOvonly))) + val hPort = IO(HostPort(new NICTargetIO)) // DMA mixin parameters lazy val fromHostCPUQueueDepth = TOKEN_QUEUE_DEPTH lazy val toHostCPUQueueDepth = TOKEN_QUEUE_DEPTH @@ -190,7 +197,7 @@ class SimpleNICBridgeModule(implicit p: Parameters) extends BridgeModule[HostPor val bigtokenToNIC = Module(new BigTokenToNICTokenAdapter) val NICtokenToBig = Module(new NICTokenToBigTokenAdapter) - val target = hPort.hBits + val target = hPort.hBits.nic val tFireHelper = DecoupledHelper(hPort.toHost.hValid, hPort.fromHost.hReady) val tFire = tFireHelper.fire diff --git a/sim/firesim-lib/src/main/scala/bridges/TracerVBridge.scala b/sim/firesim-lib/src/main/scala/bridges/TracerVBridge.scala index 91585316..1e79a623 100644 --- a/sim/firesim-lib/src/main/scala/bridges/TracerVBridge.scala +++ b/sim/firesim-lib/src/main/scala/bridges/TracerVBridge.scala @@ -11,46 +11,70 @@ import freechips.rocketchip.rocket.TracedInstruction import freechips.rocketchip.subsystem.RocketTilesKey import freechips.rocketchip.tile.TileKey -import testchipip.{TraceOutputTop, DeclockedTracedInstruction, TracedInstructionWidths} +import testchipip.{TileTraceIO, DeclockedTracedInstruction, TracedInstructionWidths} +import midas.targetutils.TriggerSource import midas.widgets._ import testchipip.{StreamIO, StreamChannel} import junctions.{NastiIO, NastiKey} import TokenQueueConsts._ - case class TracerVKey( - insnWidths: Seq[TracedInstructionWidths], // Widths of variable length fields in each TI - vecSizes: Seq[Int] // The number of insns in each vec (= max insns retired at that core) + insnWidths: TracedInstructionWidths, // Widths of variable length fields in each TI + vecSize: Int // The number of insns in the traced insn vec (= max insns retired at that core) ) -class TracerVBridge(traceProto: Seq[Vec[DeclockedTracedInstruction]]) extends BlackBox - with Bridge[HostPortIO[TraceOutputTop], TracerVBridgeModule] { - val io = IO(Flipped(new TraceOutputTop(traceProto))) + +class TracerVTargetIO(insnWidths: TracedInstructionWidths, numInsns: Int) extends Bundle { + val trace = Input(new TileTraceIO(insnWidths, numInsns)) + val triggerCredit = Output(Bool()) + val triggerDebit = Output(Bool()) +} + +// Warning: If you're not going to use the companion object to instantiate this +// bridge you must call generate trigger annotations _in the parent module_. +// +// TODO: Generalize a mechanism to promote annotations from extracted bridges... +class TracerVBridge(insnWidths: TracedInstructionWidths, numInsns: Int) extends BlackBox + with Bridge[HostPortIO[TracerVTargetIO], TracerVBridgeModule] { + val io = IO(new TracerVTargetIO(insnWidths, numInsns)) val bridgeIO = HostPort(io) - val constructorArg = Some(TracerVKey(io.getWidths, io.getVecSizes)) + val constructorArg = Some(TracerVKey(insnWidths, numInsns)) generateAnnotations() + // Use in parent module: annotates the bridge instance's ports to indicate its trigger sources + // def generateTriggerAnnotations(): Unit = TriggerSource(io.triggerCredit, io.triggerDebit) + def generateTriggerAnnotations(): Unit = + TriggerSource.evenUnderReset(WireDefault(io.triggerCredit), WireDefault(io.triggerDebit)) } object TracerVBridge { - def apply(port: TraceOutputTop)(implicit p:Parameters): Seq[TracerVBridge] = { - val ep = Module(new TracerVBridge(port.getProto)) - ep.io <> port - Seq(ep) + def apply(tracedInsns: TileTraceIO)(implicit p:Parameters): TracerVBridge = { + val ep = Module(new TracerVBridge(tracedInsns.insnWidths, tracedInsns.numInsns)) + ep.generateTriggerAnnotations + ep.io.trace := tracedInsns + ep } } -class TracerVBridgeModule(key: TracerVKey)(implicit p: Parameters) extends BridgeModule[HostPortIO[TraceOutputTop]]()(p) +class TracerVBridgeModule(key: TracerVKey)(implicit p: Parameters) extends BridgeModule[HostPortIO[TracerVTargetIO]]()(p) with UnidirectionalDMAToHostCPU { val io = IO(new WidgetIO) - val hPort = IO(HostPort(Flipped(TraceOutputTop(key.insnWidths, key.vecSizes)))) + val hPort = IO(HostPort(new TracerVTargetIO(key.insnWidths, key.vecSize))) - val tFire = hPort.toHost.hValid && hPort.fromHost.hReady - //trigger conditions - val traces = hPort.hBits.traces.flatten + + // Mask off valid committed instructions when under reset + val traces = hPort.hBits.trace.insns.map({ unmasked => + val masked = WireDefault(unmasked) + masked.valid := unmasked.valid && !hPort.hBits.trace.reset + masked + }) private val pcWidth = traces.map(_.iaddr.getWidth).max private val insnWidth = traces.map(_.insn.getWidth).max + val cycleCountWidth = 64 + // Set after trigger-dependent memory-mapped registers have been set, to + // prevent spurious credits + val initDone = genWORegInit(Wire(Bool()), "initDone", false.B) //Program Counter trigger value can be configured externally val hostTriggerPCWidthOffset = pcWidth - p(CtrlNastiKey).dataBits val hostTriggerPCLowWidth = if (hostTriggerPCWidthOffset > 0) p(CtrlNastiKey).dataBits else pcWidth @@ -82,7 +106,7 @@ class TracerVBridgeModule(key: TracerVKey)(implicit p: Parameters) extends Bridg attach(hostTriggerCycleCountStartHigh, "hostTriggerCycleCountStartHigh", WriteOnly) attach(hostTriggerCycleCountStartLow, "hostTriggerCycleCountStartLow", WriteOnly) val hostTriggerCycleCountStart = Cat(hostTriggerCycleCountStartHigh, hostTriggerCycleCountStartLow) - val triggerCycleCountStart = RegInit(0.U(64.W)) + val triggerCycleCountStart = RegInit(0.U(cycleCountWidth.W)) triggerCycleCountStart := hostTriggerCycleCountStart val hostTriggerCycleCountEndHigh = RegInit(0.U(hostTriggerCycleCountHighWidth.W)) @@ -90,10 +114,10 @@ class TracerVBridgeModule(key: TracerVKey)(implicit p: Parameters) extends Bridg attach(hostTriggerCycleCountEndHigh, "hostTriggerCycleCountEndHigh", WriteOnly) attach(hostTriggerCycleCountEndLow, "hostTriggerCycleCountEndLow", WriteOnly) val hostTriggerCycleCountEnd = Cat(hostTriggerCycleCountEndHigh, hostTriggerCycleCountEndLow) - val triggerCycleCountEnd = RegInit(0.U(64.W)) + val triggerCycleCountEnd = RegInit(0.U(cycleCountWidth.W)) triggerCycleCountEnd := hostTriggerCycleCountEnd - val trace_cycle_counter = RegInit(0.U(64.W)) + val trace_cycle_counter = RegInit(0.U(cycleCountWidth.W)) //target instruction type trigger (trigger through target software) //can configure the trigger instruction type externally though simulation driver @@ -144,25 +168,21 @@ class TracerVBridgeModule(key: TracerVKey)(implicit p: Parameters) extends Bridg 2.U -> triggerPCValVec.reduce(_ || _), 3.U -> triggerInstValVec.reduce(_ || _))) - //TODO: for inter-widget triggering - //io.trigger_out.head <> trigger - if (p(midas.TraceTrigger)) { - BoringUtils.addSource(trigger, s"trace_trigger") - } + val tFireHelper = DecoupledHelper(outgoingPCISdat.io.enq.ready, hPort.toHost.hValid, hPort.fromHost.hReady, initDone) + + val triggerReg = RegEnable(trigger, false.B, tFireHelper.fire) + hPort.hBits.triggerDebit := !trigger && triggerReg + hPort.hBits.triggerCredit := trigger && !triggerReg // DMA mixin parameters lazy val toHostCPUQueueDepth = TOKEN_QUEUE_DEPTH lazy val dmaSize = BigInt((BIG_TOKEN_WIDTH / 8) * TOKEN_QUEUE_DEPTH) val uint_traces = (traces map (trace => Cat(trace.valid, trace.iaddr).pad(64))).reverse - outgoingPCISdat.io.enq.bits := Cat(Cat(trace_cycle_counter, - 0.U((outgoingPCISdat.io.enq.bits.getWidth - Cat(uint_traces).getWidth - trace_cycle_counter.getWidth).W)), - Cat(uint_traces)) + outgoingPCISdat.io.enq.bits := Cat(uint_traces :+ trace_cycle_counter.pad(64)).pad(BIG_TOKEN_WIDTH) - val tFireHelper = DecoupledHelper(outgoingPCISdat.io.enq.ready, hPort.toHost.hValid) hPort.toHost.hReady := tFireHelper.fire(hPort.toHost.hValid) - // We don't drive tokens back to the target. - hPort.fromHost.hValid := true.B + hPort.fromHost.hValid := tFireHelper.fire(hPort.fromHost.hReady) outgoingPCISdat.io.enq.valid := tFireHelper.fire(outgoingPCISdat.io.enq.ready, trigger) @@ -170,23 +190,13 @@ class TracerVBridgeModule(key: TracerVKey)(implicit p: Parameters) extends Bridg trace_cycle_counter := trace_cycle_counter + 1.U } - // This need to go on a debug switch - //when (outgoingPCISdat.io.enq.fire()) { - // hPort.hBits.traces.zipWithIndex.foreach({ case (bundle, bIdx) => - // printf("Tile %d Trace Bundle\n", bIdx.U) - // bundle.zipWithIndex.foreach({ case (insn, insnIdx) => - // printf(p"insn ${insnIdx}: ${insn}\n") - // //printf(b"insn ${insnIdx}, valid: ${insn.valid}") - // //printf(b"insn ${insnIdx}, iaddr: ${insn.iaddr}") - // //printf(b"insn ${insnIdx}, insn: ${insn.insn}") - // //printf(b"insn ${insnIdx}, priv: ${insn.priv}") - // //printf(b"insn ${insnIdx}, exception: ${insn.exception}") - // //printf(b"insn ${insnIdx}, interrupt: ${insn.interrupt}") - // //printf(b"insn ${insnIdx}, cause: ${insn.cause}") - // //printf(b"insn ${insnIdx}, tval: ${insn.tval}") - // }) - // }) - //} attach(outgoingPCISdat.io.deq.valid && !outgoingPCISdat.io.enq.ready, "tracequeuefull", ReadOnly) genCRFile() + override def genHeader(base: BigInt, sb: StringBuilder) { + import CppGenerationUtils._ + val headerWidgetName = getWName.toUpperCase + super.genHeader(base, sb) + sb.append(genConstStatic(s"${headerWidgetName}_max_core_ipc", UInt32(traces.size))) + emitClockDomainInfo(headerWidgetName, sb) + } } diff --git a/sim/firesim-lib/src/main/scala/bridges/UARTBridge.scala b/sim/firesim-lib/src/main/scala/bridges/UARTBridge.scala index e35535b1..1b71ecda 100644 --- a/sim/firesim-lib/src/main/scala/bridges/UARTBridge.scala +++ b/sim/firesim-lib/src/main/scala/bridges/UARTBridge.scala @@ -15,6 +15,7 @@ import sifive.blocks.devices.uart.{UARTPortIO, PeripheryUARTKey} // DOC include start: UART Bridge Target-Side Interface class UARTBridgeTargetIO extends Bundle { + val clock = Input(Clock()) val uart = Flipped(new UARTPortIO) // Note this reset is optional and used only to reset target-state modelled // in the bridge This reset just like any other Bool included in your target @@ -58,9 +59,10 @@ class UARTBridge(implicit p: Parameters) extends BlackBox // DOC include start: UART Bridge Companion Object object UARTBridge { - def apply(uart: UARTPortIO)(implicit p: Parameters): UARTBridge = { + def apply(clock: Clock, uart: UARTPortIO)(implicit p: Parameters): UARTBridge = { val ep = Module(new UARTBridge) ep.io.uart <> uart + ep.io.clock := clock ep } } diff --git a/sim/firesim-lib/src/main/scala/configs/CompilerConfigs.scala b/sim/firesim-lib/src/main/scala/configs/CompilerConfigs.scala index aed5904c..8e6009e6 100644 --- a/sim/firesim-lib/src/main/scala/configs/CompilerConfigs.scala +++ b/sim/firesim-lib/src/main/scala/configs/CompilerConfigs.scala @@ -57,13 +57,13 @@ class WithILATopWiringTransform extends Config((site, here, up) => { // Implements the AutoCounter performace counters features class WithAutoCounter extends Config((site, here, up) => { - case midas.TraceTrigger => true - case TargetTransforms => ((p: Parameters) => Seq(new midas.passes.AutoCounterTransform()(p))) +: up(TargetTransforms, site) + case midas.EnableAutoCounter => true }) class WithAutoCounterPrintf extends Config((site, here, up) => { + case midas.EnableAutoCounter => true + case midas.AutoCounterUsePrintfImpl => true case midas.SynthPrints => true - case TargetTransforms => ((p: Parameters) => Seq(new midas.passes.AutoCounterTransform(printcounter = true)(p))) +: up(TargetTransforms, site) }) class BaseF1Config extends Config( diff --git a/sim/midas/src/main/cc/bridges/autocounter.cc b/sim/midas/src/main/cc/bridges/autocounter.cc index 548562ca..4c727c12 100644 --- a/sim/midas/src/main/cc/bridges/autocounter.cc +++ b/sim/midas/src/main/cc/bridges/autocounter.cc @@ -15,33 +15,52 @@ #include autocounter_t::autocounter_t( - simif_t *sim, std::vector &args, AUTOCOUNTERBRIDGEMODULE_struct * mmio_addrs, AddressMap addr_map, int autocounterno) : bridge_driver_t(sim), addr_map(addr_map) -{ - this->mmio_addrs = mmio_addrs; + simif_t *sim, + std::vector &args, + AUTOCOUNTERBRIDGEMODULE_struct * mmio_addrs, + AddressMap addr_map, + const char* const clock_domain_name, + const unsigned int clock_multiplier, + const unsigned int clock_divisor, + int autocounterno) : + bridge_driver_t(sim), + mmio_addrs(mmio_addrs), + addr_map(addr_map), + clock_info(clock_domain_name, clock_multiplier, clock_divisor) { this->readrate = 0; this->autocounter_filename = "AUTOCOUNTER"; const char *autocounter_filename_in = NULL; - std::string num_equals = std::to_string(autocounterno) + std::string("="); - std::string readrate_arg = std::string("+autocounter-readrate") + num_equals; - std::string filename_arg = std::string("+autocounter-filename") + num_equals; + std::string readrate_arg = std::string("+autocounter-readrate="); + std::string filename_arg = std::string("+autocounter-filename="); for (auto &arg: args) { if (arg.find(readrate_arg) == 0) { char *str = const_cast(arg.c_str()) + readrate_arg.length(); - this->readrate = atol(str);; + uint64_t base_cycles = atol(str); + this->readrate = this->clock_info.to_local_cycles(base_cycles); + // TODO: Just fix this in the bridge by not sampling with a fixed frequency + if (this->clock_info.to_base_cycles(this->readrate) != base_cycles) { + fprintf(stderr, +"[AutoCounter] Warning: requested sample rate of %llu [base] cycles does not map to a whole number\n\ + of cycles in clock domain: %s, (%d/%d) of base clock.\n", + base_cycles, this->clock_info.domain_name, + this->clock_info.multiplier, this->clock_info.divisor); + fprintf(stderr, "[AutoCounter] Workaround: Pick a sample rate that is divisible by all clock divisors.\n"); + } + } if (arg.find(filename_arg) == 0) { autocounter_filename_in = const_cast(arg.c_str()) + filename_arg.length(); - this->autocounter_filename = std::string(autocounter_filename_in); + this->autocounter_filename = std::string(autocounter_filename_in) + std::to_string(autocounterno); } } autocounter_file.open(this->autocounter_filename, std::ofstream::out); if(!autocounter_file.is_open()) { - //throw std::runtime_error("Could not open autocounter output file\n"); throw std::runtime_error("Could not open output file: " + this->autocounter_filename); } + this->clock_info.emit_file_header(autocounter_file); } autocounter_t::~autocounter_t() { @@ -50,67 +69,46 @@ autocounter_t::~autocounter_t() { void autocounter_t::init() { cur_cycle = 0; - - write(addr_map.w_registers.at("readrate_low"), readrate & ((1ULL << 32) - 1)); + // Decrement the readrate by one to simplify the HW a little bit + write(addr_map.w_registers.at("readrate_low"), (readrate - 1) & ((1ULL << 32) - 1)); write(addr_map.w_registers.at("readrate_high"), this->readrate >> 32); - write(addr_map.w_registers.at("readdone"), 1); - + write(mmio_addrs->init_done, 1); } +bool autocounter_t::drain_sample() { + bool bridge_has_sample = read(addr_map.r_registers.at("countersready")); + + if (bridge_has_sample) { + cur_cycle = read(this->mmio_addrs->cycles_low); + cur_cycle |= ((uint64_t)read(this->mmio_addrs->cycles_high)) << 32; + autocounter_file << "Cycle " << cur_cycle << std::endl; + autocounter_file << "============================" << std::endl; + for (auto pair: addr_map.r_registers) { + + std::string low_prefix = std::string("autocounter_low_"); + std::string high_prefix = std::string("autocounter_high_"); + + if (pair.first.find("autocounter_low_") == 0) { + char *str = const_cast(pair.first.c_str()) + low_prefix.length(); + std::string countername(str); + uint64_t counter_val = ((uint64_t) (read(addr_map.r_registers.at(high_prefix + countername)))) << 32; + counter_val |= read(pair.second); + autocounter_file << "PerfCounter " << str << ": " << counter_val << std::endl; + } + + } + write(addr_map.w_registers.at("readdone"), 1); + autocounter_file << "" << std::endl; + } + return bridge_has_sample; +} void autocounter_t::tick() { - write(addr_map.w_registers.at("readdone"), 0); - if (read(addr_map.r_registers.at("countersready"))) { - write(addr_map.w_registers.at("readdone"), 1); - cur_cycle = read(this->mmio_addrs->cycles_low); - cur_cycle |= ((uint64_t)read(this->mmio_addrs->cycles_high)) << 32; - autocounter_file << "Cycle " << cur_cycle << std::endl; - autocounter_file << "============================" << std::endl; - for (auto pair: addr_map.r_registers) { - - std::string low_prefix = std::string("autocounter_low_"); - std::string high_prefix = std::string("autocounter_high_"); - - if (pair.first.find("autocounter_low_") == 0) { - char *str = const_cast(pair.first.c_str()) + low_prefix.length(); - std::string countername(str); - uint64_t counter_val = ((uint64_t) (read(addr_map.r_registers.at(high_prefix + countername)))) << 32; - counter_val |= read(pair.second); - autocounter_file << "PerfCounter " << str << ": " << counter_val << std::endl; - } - - } - write(addr_map.w_registers.at("readdone"), 1); - autocounter_file << "" << std::endl; - } + drain_sample(); } - void autocounter_t::finish() { - write(addr_map.w_registers.at("readdone"), 0); - if (read(addr_map.r_registers.at("countersready"))) { - write(addr_map.w_registers.at("readdone"), 1); - cur_cycle = read(this->mmio_addrs->cycles_low); - cur_cycle |= ((uint64_t)read(this->mmio_addrs->cycles_high)) << 32; - autocounter_file << "Cycle " << cur_cycle << std::endl; - autocounter_file << "============================" << std::endl; - for (auto pair: addr_map.r_registers) { - - std::string low_prefix = std::string("autocounter_low_"); - std::string high_prefix = std::string("autocounter_high_"); - - if (pair.first.find("autocounter_low_") == 0) { - char *str = const_cast(pair.first.c_str()) + low_prefix.length(); - std::string countername(str); - uint64_t counter_val = ((uint64_t) (read(addr_map.r_registers.at(high_prefix + countername)))) << 32; - counter_val |= read(pair.second); - autocounter_file << "PerfCounter " << str << ": " << counter_val << std::endl; - } - - } - write(addr_map.w_registers.at("readdone"), 1); - autocounter_file << "" << std::endl; - } + while(drain_sample()); } #endif // AUTOCOUNTERBRIDGEMODULE_struct_guard diff --git a/sim/midas/src/main/cc/bridges/autocounter.h b/sim/midas/src/main/cc/bridges/autocounter.h index 8a85a820..5a7dd617 100644 --- a/sim/midas/src/main/cc/bridges/autocounter.h +++ b/sim/midas/src/main/cc/bridges/autocounter.h @@ -3,17 +3,40 @@ #include "bridges/bridge_driver.h" #include "bridges/address_map.h" +#include "bridges/clock_info.h" #include #include -// TODO: get this automatically -#define NUM_CORES 1 +// Bridge Driver Instantiation Template +#define INSTANTIATE_AUTOCOUNTER(FUNC,IDX) \ + AUTOCOUNTERBRIDGEMODULE_ ## IDX ## _substruct_create; \ + FUNC(new autocounter_t( \ + this, \ + args, \ + AUTOCOUNTERBRIDGEMODULE_ ## IDX ## _substruct, \ + AddressMap(AUTOCOUNTERBRIDGEMODULE_ ## IDX ## _R_num_registers, \ + (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_ ## IDX ## _R_addrs, \ + (const char* const*) AUTOCOUNTERBRIDGEMODULE_ ## IDX ## _R_names, \ + AUTOCOUNTERBRIDGEMODULE_ ## IDX ## _W_num_registers, \ + (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_ ## IDX ## _W_addrs, \ + (const char* const*) AUTOCOUNTERBRIDGEMODULE_ ## IDX ## _W_names), \ + AUTOCOUNTERBRIDGEMODULE_ ## IDX ## _clock_domain_name, \ + AUTOCOUNTERBRIDGEMODULE_ ## IDX ## _clock_multiplier, \ + AUTOCOUNTERBRIDGEMODULE_ ## IDX ## _clock_divisor, \ + IDX)); \ + #ifdef AUTOCOUNTERBRIDGEMODULE_struct_guard class autocounter_t: public bridge_driver_t { public: - autocounter_t(simif_t *sim, std::vector &args, - AUTOCOUNTERBRIDGEMODULE_struct * mmio_addrs, AddressMap addr_map, int autocounterno); + autocounter_t(simif_t *sim, + std::vector &args, + AUTOCOUNTERBRIDGEMODULE_struct * mmio_addrs, + AddressMap addr_map, + const char* const clock_domain_name, + const unsigned int clock_multiplier, + const unsigned int clock_divisor, + int autocounterno); ~autocounter_t(); virtual void init(); @@ -23,13 +46,18 @@ class autocounter_t: public bridge_driver_t { virtual void finish(); private: + simif_t* sim; AUTOCOUNTERBRIDGEMODULE_struct * mmio_addrs; AddressMap addr_map; - simif_t* sim; + ClockInfo clock_info; uint64_t cur_cycle; uint64_t readrate; std::string autocounter_filename; std::ofstream autocounter_file; + + // Pulls a single sample from the Bridge, if available. + // Returns true if a sample was read + bool drain_sample(); }; #endif // AUTOCOUNTERWIDGET_struct_guard diff --git a/sim/midas/src/main/cc/bridges/clock_info.h b/sim/midas/src/main/cc/bridges/clock_info.h new file mode 100644 index 00000000..2a86cbd8 --- /dev/null +++ b/sim/midas/src/main/cc/bridges/clock_info.h @@ -0,0 +1,51 @@ +// See LICENSE for license details. + +#ifndef __CLOCK_INFO_H +#define __CLOCK_INFO_H + +#include + +/** + * Stores Bridge clock domain information and provides methods for converting + * from base clock cycles to cycles in the Bridge's clock domain ("local" cycles). + */ +class ClockInfo +{ +public: + ClockInfo( + const char* const clock_domain_name, + const unsigned int clock_multiplier, + const unsigned int clock_divisor): + domain_name(clock_domain_name), + multiplier(clock_multiplier), + divisor(clock_divisor) {}; + + const char* const domain_name; + const unsigned int multiplier; + const unsigned int divisor; + + // NB: These truncate and may be inexact, use with care + uint64_t to_local_cycles(uint64_t base_clock_cycles) { + return (base_clock_cycles * multiplier) / divisor; + }; + + uint64_t to_base_cycles(uint64_t local_clock_cycles) { + return (local_clock_cycles * divisor) / multiplier; + }; + + // Capture clock domain info in a string that can be prepended to + // driver-generated files so the user can disambiguate between them + std::string file_header() { + char buf[200]; + sprintf(buf, "# Clock Domain: %s, Relative Frequency: %d/%d of Base Clock\n", + domain_name, multiplier, divisor); + return std::string(buf); + }; + + void emit_file_header(std::ostream& os) { + os << file_header(); + } + +}; + +#endif // __CLOCK_INFO_H diff --git a/sim/midas/src/main/cc/bridges/synthesized_assertions.cc b/sim/midas/src/main/cc/bridges/synthesized_assertions.cc index 4b29e087..c589d76f 100644 --- a/sim/midas/src/main/cc/bridges/synthesized_assertions.cc +++ b/sim/midas/src/main/cc/bridges/synthesized_assertions.cc @@ -4,12 +4,6 @@ #include #include - -synthesized_assertions_t::synthesized_assertions_t(simif_t* sim, - ASSERTBRIDGEMODULE_struct * mmio_addrs): bridge_driver_t(sim) { - this->mmio_addrs = mmio_addrs; -}; - synthesized_assertions_t::~synthesized_assertions_t() { free(this->mmio_addrs); } @@ -17,22 +11,10 @@ synthesized_assertions_t::~synthesized_assertions_t() { void synthesized_assertions_t::tick() { if (read(this->mmio_addrs->fire)) { // Read assertion information - std::vector msgs; - std::ifstream file(std::string(TARGET_NAME) + ".asserts"); - std::string line; - std::ostringstream oss; - while (std::getline(file, line)) { - if (line == "0") { - msgs.push_back(oss.str()); - oss.str(std::string()); - } else { - oss << line << std::endl; - } - } assert_cycle = read(this->mmio_addrs->cycle_low); assert_cycle |= ((uint64_t)read(this->mmio_addrs->cycle_high)) << 32; assert_id = read(this->mmio_addrs->id); - std::cerr << msgs[assert_id]; + std::cerr << this->msgs[assert_id]; std::cerr << " at cycle: " << assert_cycle << std::endl; assert_fired = true; } diff --git a/sim/midas/src/main/cc/bridges/synthesized_assertions.h b/sim/midas/src/main/cc/bridges/synthesized_assertions.h index 84a0b51f..17482093 100644 --- a/sim/midas/src/main/cc/bridges/synthesized_assertions.h +++ b/sim/midas/src/main/cc/bridges/synthesized_assertions.h @@ -8,7 +8,15 @@ class synthesized_assertions_t: public bridge_driver_t { public: - synthesized_assertions_t(simif_t* sim, ASSERTBRIDGEMODULE_struct * mmio_addrs); + synthesized_assertions_t( + simif_t* sim, + ASSERTBRIDGEMODULE_struct * mmio_addrs, + unsigned int num_asserts, + const char* const* msgs) : + bridge_driver_t(sim), + mmio_addrs(mmio_addrs), + num_asserts(num_asserts), + msgs(msgs) {}; ~synthesized_assertions_t(); virtual void init() {}; virtual void tick(); @@ -21,6 +29,8 @@ class synthesized_assertions_t: public bridge_driver_t int assert_id; uint64_t assert_cycle; ASSERTBRIDGEMODULE_struct * mmio_addrs; + const unsigned int num_asserts; + const char* const* msgs; }; #endif // ASSERTBRIDGEMODULE_struct_guard diff --git a/sim/midas/src/main/cc/bridges/synthesized_prints.cc b/sim/midas/src/main/cc/bridges/synthesized_prints.cc index 56c10560..adbfedd9 100644 --- a/sim/midas/src/main/cc/bridges/synthesized_prints.cc +++ b/sim/midas/src/main/cc/bridges/synthesized_prints.cc @@ -15,7 +15,11 @@ synthesized_prints_t::synthesized_prints_t( const char* const* format_strings, const unsigned int* argument_counts, const unsigned int* argument_widths, - unsigned int dma_address): + unsigned int dma_address, + const char* const clock_domain_name, + const unsigned int clock_multiplier, + const unsigned int clock_divisor, + int printno) : bridge_driver_t(sim), mmio_addrs(mmio_addrs), print_count(print_count), @@ -25,17 +29,24 @@ synthesized_prints_t::synthesized_prints_t( format_strings(format_strings), argument_counts(argument_counts), argument_widths(argument_widths), - dma_address(dma_address) { + dma_address(dma_address), + clock_info(clock_domain_name, clock_multiplier, clock_divisor), + printno(printno) { assert((token_bytes & (token_bytes - 1)) == 0); assert(print_count > 0); - const char *printfilename = default_filename.c_str(); + auto printfilename = default_filename + std::to_string(printno); this->start_cycle = 0; this->end_cycle = -1ULL; + std::string num_equals = std::to_string(printno) + std::string("="); + // PlusArgs are shared across all Bridge Driver instances + // The file into which to emit captured prinfs. This is suffixed with the driver number std::string printfile_arg = std::string("+print-file="); + // The cycle at which to start printing in base clock cycles std::string printstart_arg = std::string("+print-start="); + // The cycle at which to stop printing in base clock cycles std::string printend_arg = std::string("+print-end="); // Does not format the printfs, before writing them to file std::string binary_arg = std::string("+print-binary"); @@ -49,17 +60,17 @@ synthesized_prints_t::synthesized_prints_t( this->batch_beats = desired_batch_beats; } - for (auto &arg: args) { + for (auto arg: args) { if (arg.find(printfile_arg) == 0) { - printfilename = const_cast(arg.c_str()) + printfile_arg.length(); + printfilename = arg.erase(0, printfile_arg.length()) + std::to_string(printno); } if (arg.find(printstart_arg) == 0) { char *str = const_cast(arg.c_str()) + printstart_arg.length(); - this->start_cycle = atol(str); + this->start_cycle = this->clock_info.to_local_cycles(atol(str)); } if (arg.find(printend_arg) == 0) { char *str = const_cast(arg.c_str()) + printend_arg.length(); - this->end_cycle = atol(str); + this->end_cycle = this->clock_info.to_local_cycles(atol(str)); } if (arg.find(binary_arg) == 0) { human_readable = false; @@ -70,13 +81,14 @@ synthesized_prints_t::synthesized_prints_t( } current_cycle = start_cycle; // We won't receive tokens until start_cycle; so fast-forward - this->printfile.open(printfilename, std::ios_base::out | std::ios_base::binary); + this->printfile.open(printfilename.c_str(), std::ios_base::out | std::ios_base::binary); if (!this->printfile.is_open()) { fprintf(stderr, "Could not open print log file: %s\n", printfilename); abort(); } this->printstream = &(this->printfile); + this->clock_info.emit_file_header(*(this->printstream)); widths.resize(print_count); // Used to reconstruct the relative position of arguments in the flattened argument_widths array diff --git a/sim/midas/src/main/cc/bridges/synthesized_prints.h b/sim/midas/src/main/cc/bridges/synthesized_prints.h index 96914afc..8fa41254 100644 --- a/sim/midas/src/main/cc/bridges/synthesized_prints.h +++ b/sim/midas/src/main/cc/bridges/synthesized_prints.h @@ -9,6 +9,27 @@ #include #include "bridge_driver.h" +#include "clock_info.h" + +// Bridge Driver Instantiation Template +#define INSTANTIATE_PRINTF(FUNC,IDX) \ + PRINTBRIDGEMODULE_ ## IDX ## _substruct_create; \ + FUNC(new synthesized_prints_t( \ + this, \ + args, \ + PRINTBRIDGEMODULE_ ## IDX ## _substruct, \ + PRINTBRIDGEMODULE_ ## IDX ## _print_count, \ + PRINTBRIDGEMODULE_ ## IDX ## _token_bytes, \ + PRINTBRIDGEMODULE_ ## IDX ## _idle_cycles_mask, \ + PRINTBRIDGEMODULE_ ## IDX ## _print_offsets, \ + PRINTBRIDGEMODULE_ ## IDX ## _format_strings, \ + PRINTBRIDGEMODULE_ ## IDX ## _argument_counts, \ + PRINTBRIDGEMODULE_ ## IDX ## _argument_widths, \ + PRINTBRIDGEMODULE_ ## IDX ## _DMA_ADDR, \ + PRINTBRIDGEMODULE_ ## IDX ## _clock_domain_name, \ + PRINTBRIDGEMODULE_ ## IDX ## _clock_multiplier, \ + PRINTBRIDGEMODULE_ ## IDX ## _clock_divisor, \ + IDX)); \ struct print_vars_t { std::vector data; @@ -34,7 +55,11 @@ class synthesized_prints_t: public bridge_driver_t const char* const* format_strings, const unsigned int* argument_counts, const unsigned int* argument_widths, - unsigned int dma_address); + unsigned int dma_address, + const char* const clock_domain_name, + const unsigned int clock_multiplier, + const unsigned int clock_divisor, + int printno); ~synthesized_prints_t(); virtual void init(); virtual void tick(); @@ -52,6 +77,8 @@ class synthesized_prints_t: public bridge_driver_t const unsigned int* argument_counts; const unsigned int* argument_widths; const unsigned int dma_address; + ClockInfo clock_info; + const int printno; // DMA batching parameters const size_t beat_bytes = DMA_DATA_BITS / 8; diff --git a/sim/midas/src/main/cc/simif.cc b/sim/midas/src/main/cc/simif.cc index 7682183b..2549b9d9 100644 --- a/sim/midas/src/main/cc/simif.cc +++ b/sim/midas/src/main/cc/simif.cc @@ -25,6 +25,8 @@ simif_t::simif_t() { this->loadmem_mmio_addrs = LOADMEMWIDGET_0_substruct; PEEKPOKEBRIDGEMODULE_0_substruct_create; this->defaultiowidget_mmio_addrs = PEEKPOKEBRIDGEMODULE_0_substruct; + CLOCKBRIDGEMODULE_0_substruct_create; + this->clock_bridge_mmio_addrs = CLOCKBRIDGEMODULE_0_substruct; } void simif_t::init(int argc, char** argv, bool log) { @@ -57,16 +59,16 @@ void simif_t::init(int argc, char** argv, bool log) { } uint64_t simif_t::actual_tcycle() { - write(this->defaultiowidget_mmio_addrs->tCycle_latch, 1); - data_t cycle_l = read(this->defaultiowidget_mmio_addrs->tCycle_0); - data_t cycle_h = read(this->defaultiowidget_mmio_addrs->tCycle_1); + write(this->clock_bridge_mmio_addrs->tCycle_latch, 1); + data_t cycle_l = read(this->clock_bridge_mmio_addrs->tCycle_0); + data_t cycle_h = read(this->clock_bridge_mmio_addrs->tCycle_1); return (((uint64_t) cycle_h) << 32) | cycle_l; } uint64_t simif_t::hcycle() { - write(this->defaultiowidget_mmio_addrs->hCycle_latch, 1); - data_t cycle_l = read(this->defaultiowidget_mmio_addrs->hCycle_0); - data_t cycle_h = read(this->defaultiowidget_mmio_addrs->hCycle_1); + write(this->clock_bridge_mmio_addrs->hCycle_latch, 1); + data_t cycle_l = read(this->clock_bridge_mmio_addrs->hCycle_0); + data_t cycle_h = read(this->clock_bridge_mmio_addrs->hCycle_1); return (((uint64_t) cycle_h) << 32) | cycle_l; } @@ -87,7 +89,7 @@ int simif_t::finish() { finish_sampling(); #endif - fprintf(stderr, "Runs %llu cycles\n", actual_tcycle()); + fprintf(stderr, "Ran %llu cycles (fastest target clock)\n", actual_tcycle()); fprintf(stderr, "[%s] %s Test", pass ? "PASS" : "FAIL", TARGET_NAME); if (!pass) { fprintf(stdout, " at cycle %llu", fail_t); } fprintf(stderr, "\nSEED: %ld\n", seed); diff --git a/sim/midas/src/main/cc/simif.h b/sim/midas/src/main/cc/simif.h index 1ccfbc97..b1899357 100644 --- a/sim/midas/src/main/cc/simif.h +++ b/sim/midas/src/main/cc/simif.h @@ -41,6 +41,7 @@ class simif_t SIMULATIONMASTER_struct * master_mmio_addrs; LOADMEMWIDGET_struct * loadmem_mmio_addrs; PEEKPOKEBRIDGEMODULE_struct * defaultiowidget_mmio_addrs; + CLOCKBRIDGEMODULE_struct * clock_bridge_mmio_addrs; midas_time_t sim_start_time; inline void take_steps(size_t n, bool blocking) { @@ -108,8 +109,8 @@ class simif_t // Returns an upper bound for the cycle reached by the target // If using blocking steps, this will be ~equivalent to actual_tcycle() uint64_t cycles(){ return t; }; - // Returns the current target cycle as measured by a hardware counter in the DefaultIOWidget - // (# of reset tokens generated) + // Returns the current target cycle of the fastest clock in the simulated system, based + // on the number of clock tokens enqueued (will report a larger number) uint64_t actual_tcycle(); // Returns the current host cycle as measured by a hardware counter uint64_t hcycle(); diff --git a/sim/midas/src/main/scala/midas/Config.scala b/sim/midas/src/main/scala/midas/Config.scala index decea7be..4c2cfa65 100644 --- a/sim/midas/src/main/scala/midas/Config.scala +++ b/sim/midas/src/main/scala/midas/Config.scala @@ -18,9 +18,26 @@ case object Platform extends Field[(Parameters) => PlatformShim] // Switches to synthesize prints and assertions case object SynthAsserts extends Field[Boolean] case object SynthPrints extends Field[Boolean] -case object TraceTrigger extends Field[Boolean] -// Exclude module instances from assertion and print synthesis -// Tuple of Parent Module (where the instance is instantiated) and the instance name + +// Auto Counter Switches +case object EnableAutoCounter extends Field[Boolean](false) +/** + * Chooses between the two implementation strategies for Auto Counter. + * + * True: Synthesized Printf Implementation + * - Generates a counter directly in the target module, adds a printf, + * and annotates it for printf synthesis + * - Pros: cycle-exact event resolution; counters are printed every time the event is asserted + * - Cons: considerably more resource intensive (64-bit values are synthesized in the printf) + * Biancolin: This seems like a waste of bandwidth? Maybe just print the message? + * + * False: Native Bridge Implementation (Default) + * - Wires out each annotated event (Bool) to a dedicated AutoCounter bridge. + * - Pros: More resource efficient; + * - Cons: Coarse event resolution (depends on the sampling frequency set in the bridge) + */ +case object AutoCounterUsePrintfImpl extends Field[Boolean](false) + case object EnableSnapshot extends Field[Boolean] case object HasDMAChannel extends Field[Boolean] case object KeepSamplesInMem extends Field[Boolean] @@ -45,7 +62,6 @@ class SimConfig extends Config((site, here, up) => { case DaisyWidth => 32 case SynthAsserts => false case SynthPrints => false - case TraceTrigger => false case EnableSnapshot => false case KeepSamplesInMem => true case CtrlNastiKey => NastiParameters(32, 32, 12) diff --git a/sim/midas/src/main/scala/midas/core/Channel.scala b/sim/midas/src/main/scala/midas/core/Channel.scala index 8550a397..9d9cad21 100644 --- a/sim/midas/src/main/scala/midas/core/Channel.scala +++ b/sim/midas/src/main/scala/midas/core/Channel.scala @@ -5,7 +5,7 @@ package core import freechips.rocketchip.config.{Parameters, Field} import freechips.rocketchip.unittest._ -import freechips.rocketchip.util.{DecoupledHelper} +import freechips.rocketchip.util.{DecoupledHelper, ShiftQueue} import freechips.rocketchip.tilelink.LFSR64 // Better than chisel's import chisel3._ @@ -20,36 +20,6 @@ import midas.core.SimUtils.{ChLeafType} // token streams irrevocably will introduce simulation non-determinism. case object GenerateTokenIrrevocabilityAssertions extends Field[Boolean](false) -// For now use the convention that clock ratios are set with respect to the transformed RTL -trait IsRationalClockRatio { - def numerator: Int - def denominator: Int - def isUnity() = numerator == denominator - def isReciprocal() = numerator == 1 - def isIntegral() = denominator == 1 - def inverse: IsRationalClockRatio -} - -case class RationalClockRatio(numerator: Int, denominator: Int) extends IsRationalClockRatio { - def inverse() = RationalClockRatio(denominator, numerator) -} - -case object UnityClockRatio extends IsRationalClockRatio { - val numerator = 1 - val denominator = 1 - def inverse() = UnityClockRatio -} - -case class ReciprocalClockRatio(denominator: Int) extends IsRationalClockRatio { - val numerator = 1 - def inverse = IntegralClockRatio(numerator = denominator) -} - -case class IntegralClockRatio(numerator: Int) extends IsRationalClockRatio { - val denominator = 1 - def inverse = ReciprocalClockRatio(denominator = numerator) -} - class PipeChannelIO[T <: ChLeafType](gen: T)(implicit p: Parameters) extends Bundle { val in = Flipped(Decoupled(gen)) val out = Decoupled(gen) @@ -61,15 +31,12 @@ class PipeChannelIO[T <: ChLeafType](gen: T)(implicit p: Parameters) extends Bun class PipeChannel[T <: ChLeafType]( val gen: T, - latency: Int, - clockRatio: IsRationalClockRatio = UnityClockRatio + latency: Int )(implicit p: Parameters) extends Module { - - require(clockRatio.isUnity) require(latency == 0 || latency == 1) val io = IO(new PipeChannelIO(gen)) - val tokens = Module(new Queue(gen, p(ChannelLen))) + val tokens = Module(new ShiftQueue(gen, 2)) tokens.io.enq <> io.in io.out <> tokens.io.deq @@ -77,6 +44,7 @@ class PipeChannel[T <: ChLeafType]( val initializing = RegNext(reset.toBool) when(initializing) { tokens.io.enq.valid := true.B + tokens.io.enq.bits := 0.U.asTypeOf(tokens.io.enq.bits) io.in.ready := false.B } } @@ -119,7 +87,7 @@ class PipeChannelUnitTest( override val testName = "PipeChannel Unit Test" val payloadWidth = 8 - val dut = Module(new PipeChannel(UInt(payloadWidth.W), latency, UnityClockRatio)) + val dut = Module(new PipeChannel(UInt(payloadWidth.W), latency)) val referenceInput = Wire(UInt(payloadWidth.W)) val referenceOutput = ShiftRegister(referenceInput, latency) @@ -232,24 +200,21 @@ class ReadyValidChannelIO[T <: Data](gen: T)(implicit p: Parameters) extends Bun class ReadyValidChannel[T <: Data]( gen: T, n: Int = 2, // Target queue depth - // Clock ratio (N/M) of deq interface (N) vs enq interface (M) - clockRatio: IsRationalClockRatio = UnityClockRatio )(implicit p: Parameters) extends Module { - require(clockRatio.isUnity, "CDC is not currently implemented") val io = IO(new ReadyValidChannelIO(gen)) - val enqFwdQ = Module(new Queue(ValidIO(gen), 2, flow = true)) + val enqFwdQ = Module(new ShiftQueue(ValidIO(gen), 2, flow = true)) enqFwdQ.io.enq.bits.valid := io.enq.target.valid enqFwdQ.io.enq.bits.bits := io.enq.target.bits enqFwdQ.io.enq.valid := io.enq.fwd.hValid io.enq.fwd.hReady := enqFwdQ.io.enq.ready - val deqRevQ = Module(new Queue(Bool(), 2, flow = true)) + val deqRevQ = Module(new ShiftQueue(Bool(), 2, flow = true)) deqRevQ.io.enq.bits := io.deq.target.ready deqRevQ.io.enq.valid := io.deq.rev.hValid io.deq.rev.hReady := deqRevQ.io.enq.ready - val reference = Module(new Queue(gen, n)) + val reference = Module(new ShiftQueue(gen, n)) val deqFwdFired = RegInit(false.B) val enqRevFired = RegInit(false.B) @@ -300,13 +265,13 @@ class ReadyValidChannelUnitTest( queueDepth: Int = 2, timeout: Int = 50000 )(implicit p: Parameters) extends UnitTest(timeout) { - override val testName = "PipeChannel ClockRatio: ${clockRatio.numerator}/${clockRatio.denominator}" + override val testName = "PipeChannel" val payloadType = UInt(8.W) val resetLength = 4 val dut = Module(new ReadyValidChannel(payloadType)) - val reference = Module(new Queue(payloadType, queueDepth)) + val reference = Module(new ShiftQueue(payloadType, queueDepth)) // Generates target-reset tokens def resetTokenGen(): Bool = { diff --git a/sim/midas/src/main/scala/midas/core/ClockDomainCrossing.scala b/sim/midas/src/main/scala/midas/core/ClockDomainCrossing.scala deleted file mode 100644 index c8f33b9c..00000000 --- a/sim/midas/src/main/scala/midas/core/ClockDomainCrossing.scala +++ /dev/null @@ -1,60 +0,0 @@ -// See LICENSE for license details. - -package midas.core - -import freechips.rocketchip.tilelink.LFSR64 // Better than chisel's - -import chisel3._ -import chisel3.util._ - -trait ClockUtils { - // Assume time is measured in ps - val timeStepBits = 32 - -} - -class GenericClockCrossing[T <: Data](gen: T) extends MultiIOModule with ClockUtils { - val enq = IO(Flipped(Decoupled(gen))) - val deq = IO(Decoupled(gen)) - val enqDomainTimeStep = IO(Input(UInt(timeStepBits.W))) - val deqDomainTimeStep = IO(Input(UInt(timeStepBits.W))) - - val enqTokens = Queue(enq, 2) - - // Deq Domain handling - val residualTime = Reg(UInt(timeStepBits.W)) - val hasResidualTime = RegInit(false.B) - val timeToNextEnqEdge = Mux(hasResidualTime, residualTime, enqDomainTimeStep) - val timeToNextDeqEdge = RegInit(0.U(timeStepBits.W)) - - val enqTokenVisible = timeToNextEnqEdge > timeToNextDeqEdge - val tokenWouldExpire = timeToNextEnqEdge < timeToNextDeqEdge + deqDomainTimeStep - - deq.valid := enqTokens.valid && enqTokenVisible - deq.bits := enqTokens.bits - enqTokens.ready := !enqTokenVisible || deq.ready && tokenWouldExpire - - val enqTokenExpiring = enqTokens.fire - val deqTokenReleased = deq.fire - - // CASE 1: This ENQ token is visible in the current deq token, but not visible in future DEQ tokens - // ENQ N | ENQ N1 | - // ... | DEQ M | DEQ M1 | - when (enqTokenExpiring && deqTokenReleased) { - hasResidualTime := false.B - timeToNextDeqEdge := timeToNextDeqEdge + deqDomainTimeStep - timeToNextEnqEdge - // Case 2: This ENQ token is no longer visible, generally Fast -> Slow) - // ENQ N | ENQ N+1 | ... - // DEQ M | DEQ M+1... - }.elsewhen(enqTokenExpiring) { - hasResidualTime := false.B - timeToNextDeqEdge := timeToNextDeqEdge - timeToNextEnqEdge - // Case 3: This ENQ token is visible in the current and possibly future output tokens - // ENQ M | ... - // ENQ N | ENQ N+1 | ... - }.elsewhen(deqTokenReleased) { - hasResidualTime := true.B - timeToNextDeqEdge := deqDomainTimeStep - residualTime := timeToNextEnqEdge - deqDomainTimeStep - } -} diff --git a/sim/midas/src/main/scala/midas/core/SimUtils.scala b/sim/midas/src/main/scala/midas/core/SimUtils.scala index 8a500203..abff0a6e 100644 --- a/sim/midas/src/main/scala/midas/core/SimUtils.scala +++ b/sim/midas/src/main/scala/midas/core/SimUtils.scala @@ -80,4 +80,11 @@ object SimUtils { def parsePortsSeq(io: Seq[(String, Data)], alsoFlattenRVPorts: Boolean = true): ParsePortsTuple = parsePorts(io, alsoFlattenRVPorts) + // Returns reference to all clocks + def findClocks(field: Data): Seq[Clock] = field match { + case c: Clock => Seq(c) + case b: Record => b.elements.flatMap({ case (_, field) => findClocks(field) }).toSeq + case v: Vec[_] => v.flatMap(findClocks) + case o => Seq() + } } diff --git a/sim/midas/src/main/scala/midas/core/SimWrapper.scala b/sim/midas/src/main/scala/midas/core/SimWrapper.scala index 820495b3..85cac1c7 100644 --- a/sim/midas/src/main/scala/midas/core/SimWrapper.scala +++ b/sim/midas/src/main/scala/midas/core/SimWrapper.scala @@ -14,7 +14,7 @@ import freechips.rocketchip.config.{Parameters, Field} import chisel3._ import chisel3.util._ -import chisel3.experimental.{Direction, ChiselAnnotation, annotate} +import chisel3.experimental.{Direction, chiselName, ChiselAnnotation, annotate} import chisel3.experimental.DataMirror.directionOf import firrtl.annotations.{SingleTargetAnnotation, ReferenceTarget} @@ -106,22 +106,21 @@ abstract class ChannelizedWrapperIO(chAnnos: Seq[FAMEChannelConnectionAnnotation val payloadTypeMap: Map[FAMEChannelConnectionAnnotation, Data] = chAnnos.collect({ // Target Decoupled Channels need to have their target-valid ReferenceTarget removed - case ch @ FAMEChannelConnectionAnnotation(_,DecoupledForwardChannel(_,Some(vsrc),_,_),Some(srcs),_) => + case ch @ FAMEChannelConnectionAnnotation(_,DecoupledForwardChannel(_,Some(vsrc),_,_), _, Some(srcs),_) => ch -> regenPayloadType(srcs.filterNot(_ == vsrc)) - case ch @ FAMEChannelConnectionAnnotation(_,DecoupledForwardChannel(_,_,_,Some(vsink)),_,Some(sinks)) => + case ch @ FAMEChannelConnectionAnnotation(_,DecoupledForwardChannel(_,_,_,Some(vsink)), _, _, Some(sinks)) => ch -> regenPayloadType(sinks.filterNot(_ == vsink)) }).toMap val wireTypeMap: Map[FAMEChannelConnectionAnnotation, ChLeafType] = chAnnos.collect({ - case ch @ FAMEChannelConnectionAnnotation(_,fame.PipeChannel(_),Some(srcs),_) => ch -> regenWireType(srcs) - case ch @ FAMEChannelConnectionAnnotation(_,fame.PipeChannel(_),_,Some(sinks)) => ch -> regenWireType(sinks) + case ch @ FAMEChannelConnectionAnnotation(_,fame.PipeChannel(_),_,Some(srcs),_) => ch -> regenWireType(srcs) + case ch @ FAMEChannelConnectionAnnotation(_,fame.PipeChannel(_),_,_,Some(sinks)) => ch -> regenWireType(sinks) }).toMap val wireElements = ArrayBuffer[(String, ReadyValidIO[Data])]() - val wirePortMap: Map[String, WirePortTuple] = chAnnos.collect({ - case ch @ FAMEChannelConnectionAnnotation(globalName, fame.PipeChannel(_),sources,sinks) => { + case ch @ FAMEChannelConnectionAnnotation(globalName, fame.PipeChannel(_), _, sources, sinks) => { val sinkP = sinks.map({ tRefs => val name = tRefs.head.ref.stripSuffix("_bits") val port = Flipped(Decoupled(wireTypeMap(ch))) @@ -152,7 +151,7 @@ abstract class ChannelizedWrapperIO(chAnnos: Seq[FAMEChannelConnectionAnnotation // Using a channel's globalName; look up it's associated port tuple val rvPortMap: Map[String, TargetRVPortTuple] = chAnnos.collect({ - case ch @ FAMEChannelConnectionAnnotation(globalName, info@DecoupledForwardChannel(_,_,_,_), leafSources, leafSinks) => + case ch @ FAMEChannelConnectionAnnotation(globalName, info@DecoupledForwardChannel(_,_,_,_), _, leafSources, leafSinks) => val sourcePortPair = leafSources.map({ tRefs => require(!tRefs.isEmpty, "FIXME: Are empty decoupleds OK?") val validTRef: ReferenceTarget = info.validSource.getOrElse(throw new RuntimeException( @@ -198,15 +197,30 @@ abstract class ChannelizedWrapperIO(chAnnos: Seq[FAMEChannelConnectionAnnotation val chNameToAnnoMap = chAnnos.map(anno => anno.globalName -> anno) } +class ClockRecord(numClocks: Int) extends Record { + override val elements = ListMap(Seq.tabulate(numClocks)(i => s"_$i" -> Clock()):_*) + override def cloneType = new ClockRecord(numClocks).asInstanceOf[this.type] +} + class TargetBoxIO(val chAnnos: Seq[FAMEChannelConnectionAnnotation], leafTypeMap: Map[ReferenceTarget, firrtl.ir.Port]) extends ChannelizedWrapperIO(chAnnos, leafTypeMap) { - val clock = Input(Clock()) + def regenClockType(refTargets: Seq[ReferenceTarget]): Data = refTargets.size match { + case 1 => Clock() + case size => new ClockRecord(refTargets.size) + } + + val clockElement: (String, DecoupledIO[Data]) = chAnnos.collectFirst({ + case ch @ FAMEChannelConnectionAnnotation(globalName, fame.TargetClockChannel(_), _, _, Some(sinks)) => + sinks.head.ref.stripSuffix("_bits") -> Flipped(Decoupled(regenClockType(sinks))) + }).get + + val hostClock = Input(Clock()) val hostReset = Input(Bool()) - override val elements = ListMap((wireElements ++ rvElements):_*) ++ + override val elements = ListMap((Seq(clockElement) ++ wireElements ++ rvElements):_*) ++ // Untokenized ports - ListMap("clock" -> clock, "hostReset" -> hostReset) + ListMap("hostClock" -> hostClock, "hostReset" -> hostReset) override def cloneType: this.type = new TargetBoxIO(chAnnos, leafTypeMap).asInstanceOf[this.type] } @@ -220,7 +234,14 @@ class SimWrapperChannels(val chAnnos: Seq[FAMEChannelConnectionAnnotation], leafTypeMap: Map[ReferenceTarget, firrtl.ir.Port]) extends ChannelizedWrapperIO(chAnnos, leafTypeMap) { - override val elements = ListMap((wireElements ++ rvElements):_*) + def regenClockType(refTargets: Seq[ReferenceTarget]): Vec[Bool] = Vec(refTargets.size, Bool()) + + val clockElement: (String, DecoupledIO[Vec[Bool]]) = chAnnos.collectFirst({ + case ch @ FAMEChannelConnectionAnnotation(globalName, fame.TargetClockChannel(_), _, _, Some(sinks)) => + sinks.head.ref.stripSuffix("_bits") -> Flipped(Decoupled(regenClockType(sinks))) + }).get + + override val elements = ListMap((Seq(clockElement) ++ wireElements ++ rvElements):_*) override def cloneType: this.type = new SimWrapperChannels(chAnnos, bridgeAnnos, leafTypeMap).asInstanceOf[this.type] } @@ -236,8 +257,8 @@ class SimWrapper(config: SimWrapperConfig)(implicit val p: Parameters) extends M // Remove all FCAs that are loopback channels. All non-loopback FCAs connect // to bridges and will be presented in the SimWrapper's IO val bridgeChAnnos = chAnnos.collect({ - case fca @ FAMEChannelConnectionAnnotation(_,_,_,None) => fca - case fca @ FAMEChannelConnectionAnnotation(_,_,None,_) => fca + case fca @ FAMEChannelConnectionAnnotation(_,_,_,_,None) => fca + case fca @ FAMEChannelConnectionAnnotation(_,_,_,None,_) => fca }) val channelPorts = IO(new SimWrapperChannels(bridgeChAnnos, bridgeAnnos, leafTypeMap)) @@ -249,7 +270,7 @@ class SimWrapper(config: SimWrapperConfig)(implicit val p: Parameters) extends M }) target.io.hostReset := reset.toBool - target.io.clock := clock + target.io.hostClock := clock import chisel3.ExplicitCompileOptions.NotStrict // FIXME def getPipeChannelType(chAnno: FAMEChannelConnectionAnnotation): ChLeafType = { @@ -279,6 +300,17 @@ class SimWrapper(config: SimWrapperConfig)(implicit val p: Parameters) extends M channel } + @chiselName + def genClockChannel(chAnno: FAMEChannelConnectionAnnotation): Unit = { + val clockTokens = channelPorts.clockElement._2 + target.io.clockElement._2.valid := clockTokens.valid + clockTokens.ready := target.io.clockElement._2.ready + target.io.clockElement._2.bits match { + case port: Clock => port := clockTokens.bits(0).asClock + case port: ClockRecord => port.elements.zip(clockTokens.bits).foreach({ case ((_, p), i) => p := i.asClock}) + } + } + // Helper functions to attach legacy SimReadyValidIO to true, dual-channel implementations of target ready-valid def bindRVChannelEnq[T <: Data](enq: SimReadyValidIO[T], port: TargetRVPortType): Unit = { val (fwdPort, revPort) = port @@ -314,12 +346,6 @@ class SimWrapper(config: SimWrapperConfig)(implicit val p: Parameters) extends M def genReadyValidChannel(chAnno: FAMEChannelConnectionAnnotation): ReadyValidChannel[Data] = { val chName = chAnno.globalName val strippedName = chName.stripSuffix("_fwd") - // Determine which bridge this channel belongs to by looking it up with the valid - //val bridgeClockRatio = io.bridges.find(_(rvInterface.valid)) match { - // case Some(bridge) => bridge.clockRatio - // case None => UnityClockRatio - //} - val bridgeClockRatio = UnityClockRatio // TODO: FIXME // A channel is considered "flipped" if it's sunk by the tranformed RTL (sourced by an bridge) val channel = Module(new ReadyValidChannel(getReadyValidChannelType(chAnno).cloneType)) @@ -346,11 +372,16 @@ class SimWrapper(config: SimWrapperConfig)(implicit val p: Parameters) extends M // Generate all ready-valid channels val rvChannels = chAnnos.collect({ - case ch @ FAMEChannelConnectionAnnotation(_,fame.DecoupledForwardChannel(_,_,_,_),_,_) => genReadyValidChannel(ch) + case ch @ FAMEChannelConnectionAnnotation(_,fame.DecoupledForwardChannel(_,_,_,_),_,_,_) => genReadyValidChannel(ch) }) // Generate all wire channels, excluding reset chAnnos.collect({ - case ch @ FAMEChannelConnectionAnnotation(name, fame.PipeChannel(latency),_,_) => genPipeChannel(ch, latency) + case ch @ FAMEChannelConnectionAnnotation(name, fame.PipeChannel(latency),_,_,_) => genPipeChannel(ch, latency) }) + + // Generate clock channels + val clockChannels = chAnnos.collect({case ch @ FAMEChannelConnectionAnnotation(_, fame.TargetClockChannel(_),_,_,_) => ch }) + require(clockChannels.size == 1) + genClockChannel(clockChannels.head) } diff --git a/sim/midas/src/main/scala/midas/models/dram/FASEDMemoryTimingModel.scala b/sim/midas/src/main/scala/midas/models/dram/FASEDMemoryTimingModel.scala index ab08a004..07209d05 100644 --- a/sim/midas/src/main/scala/midas/models/dram/FASEDMemoryTimingModel.scala +++ b/sim/midas/src/main/scala/midas/models/dram/FASEDMemoryTimingModel.scala @@ -166,6 +166,7 @@ class FuncModelProgrammableRegs extends Bundle with HasProgrammableRegisters { class FASEDTargetIO(implicit val p: Parameters) extends Bundle { val axi4 = Flipped(new NastiIO) val reset = Input(Bool()) + val clock = Input(Clock()) } class MemModelIO(implicit val p: Parameters) extends WidgetIO()(p){ @@ -566,9 +567,10 @@ class FASEDBridge(argument: CompleteConfig)(implicit p: Parameters) } object FASEDBridge { - def apply(axi4: AXI4Bundle, reset: Bool, cfg: CompleteConfig)(implicit p: Parameters): FASEDBridge = { + def apply(clock: Clock, axi4: AXI4Bundle, reset: Bool, cfg: CompleteConfig)(implicit p: Parameters): FASEDBridge = { val ep = Module(new FASEDBridge(cfg)(p.alterPartial({ case NastiKey => cfg.axi4Widths }))) ep.io.reset := reset + ep.io.clock := clock import chisel3.ExplicitCompileOptions.NotStrict ep.io.axi4 <> axi4 ep diff --git a/sim/midas/src/main/scala/midas/passes/AssertPass.scala b/sim/midas/src/main/scala/midas/passes/AssertPass.scala index 0ed8bddb..916220e6 100644 --- a/sim/midas/src/main/scala/midas/passes/AssertPass.scala +++ b/sim/midas/src/main/scala/midas/passes/AssertPass.scala @@ -2,9 +2,10 @@ package midas package passes import java.io.{File, FileWriter, Writer} +import scala.collection.mutable import firrtl._ -import firrtl.annotations.{CircuitName, ModuleName, ComponentName, ModuleTarget} +import firrtl.annotations._ import firrtl.ir._ import firrtl.Mappers._ import firrtl.WrappedExpression._ @@ -23,16 +24,18 @@ private[passes] class AssertPass( (implicit p: Parameters) extends firrtl.Transform { def inputForm = LowForm def outputForm = HighForm - override def name = "[MIDAS] Assertion Synthesis" + override def name = "[Golden Gate] Assertion Synthesis" - type Asserts = collection.mutable.HashMap[String, (Int, String)] + type Asserts = collection.mutable.HashMap[String, (Int, String, String)] type Messages = collection.mutable.HashMap[Int, String] private val asserts = collection.mutable.HashMap[String, Asserts]() private val messages = collection.mutable.HashMap[String, Messages]() - private val assertPorts = collection.mutable.HashMap[String, Port]() + private val assertPorts = collection.mutable.HashMap[String, (Port, Seq[ReferenceTarget])]() private val excludeInstAsserts = collection.mutable.HashSet[(String, String)]() + private def getNameOrCroak(ns: Namespace, name: String): String = { assert(ns.tryName(name)); name } + // Helper method to filter out module instances private def excludeInst(excludes: Seq[(String, String)]) (parentMod: String, inst: String): Boolean = @@ -44,11 +47,16 @@ private[passes] class AssertPass( namespace: Namespace) (s: Statement): Statement = s map synAsserts(mname, namespace) match { - case s: Stop if s.ret != 0 && !weq(s.en, zero) => + case Stop(info, ret, clk, en) if ret != 0 && !weq(en, zero) => val idx = asserts(mname).size val name = namespace newName s"assert_$idx" - asserts(mname)(s.en.serialize) = idx -> name - DefNode(s.info, name, s.en) + val clockName = clk match { + case Reference(name, _) => name + case WRef(name, _, _, _) => name + case o => throw new RuntimeException(s"$clk") + } + asserts(mname)(en.serialize) = (idx, name, clockName) + DefNode(info, name, en) case s => s } @@ -56,12 +64,12 @@ private[passes] class AssertPass( private def findMessages(mname: String) (s: Statement): Statement = s map findMessages(mname) match { - // work around a large design assert build failure + // work around a large design assert build failure // drop arguments and just show the format string //case s: Print if s.args.isEmpty => case s: Print => asserts(mname) get s.en.serialize match { - case Some((idx, str)) => + case Some((idx, str, _)) => messages(mname)(idx) = s.string.serialize EmptyStmt case _ => s @@ -69,17 +77,12 @@ private[passes] class AssertPass( case s => s } - private def transform(meta: StroberMetaData) - (m: DefModule): DefModule = { - val namespace = Namespace(m) - asserts(m.name) = new Asserts - messages(m.name) = new Messages - - def getChildren(ports: collection.mutable.Map[String, Port], + private class ModuleAssertInfo(m: Module, meta: StroberMetaData, mT: ModuleTarget) { + def getChildren(ports: collection.mutable.Map[String, (Port, Seq[ReferenceTarget])], instExcludes: collection.mutable.HashSet[(String, String)]) = { (meta.childInsts(m.name) .filterNot(instName => excludeInst(instExcludes.toSeq)(m.name, instName)) - .foldRight(Seq[(String, Port)]())((x, res) => + .foldRight(Seq[(String, (Port, Seq[ReferenceTarget]))]())((x, res) => ports get meta.instModMap(x -> m.name) match { case None => res case Some(p) => res :+ (x -> p) @@ -88,90 +91,155 @@ private[passes] class AssertPass( ) } - (m map synAsserts(m.name, namespace) - map findMessages(m.name)) match { - case m: Module => - val ports = collection.mutable.ArrayBuffer[Port]() - val stmts = collection.mutable.ArrayBuffer[Statement]() - // Connect asserts - val assertChildren = getChildren(assertPorts, excludeInstAsserts) - val assertWidth = asserts(m.name).size + ((assertChildren foldLeft 0)( - (res, x) => res + firrtl.bitWidth(x._2.tpe).toInt)) - if (assertWidth > 0) { - val tpe = UIntType(IntWidth(assertWidth)) - val port = Port(NoInfo, namespace.newName("midasAsserts"), Output, tpe) - val stmt = Connect(NoInfo, WRef(port.name), cat( - (assertChildren map (x => wsub(wref(x._1), x._2.name))) ++ - (asserts(m.name).values.toSeq sortWith (_._1 > _._1) map (x => wref(x._2))))) - assertPorts(m.name) = port - ports += port - stmts += stmt - } - m.copy(ports = m.ports ++ ports.toSeq, body = Block(m.body +: stmts.toSeq)) - case m: ExtModule => m - } + val assertChildren = getChildren(assertPorts, excludeInstAsserts) + val assertWidth = asserts(m.name).size + ((assertChildren foldLeft 0)({ + case (sum, (_, (assertP, _))) => sum + firrtl.bitWidth(assertP.tpe).toInt + })) + + def hasAsserts(): Boolean = assertWidth > 0 + + // Get references to assertion ports on all child instances + lazy val (childAsserts, childAssertClocksRTs): (Seq[WSubField], Seq[Seq[ReferenceTarget]]) = + (for ((childInstName, (assertPort, clockRTs)) <- assertChildren) yield { + val childWidth = firrtl.bitWidth(assertPort.tpe).toInt + val assertRef = wsub(wref(childInstName), assertPort.name) + val clockRefs = clockRTs.map(_.addHierarchy(m.name, childInstName)) + (assertRef, clockRefs) + }).unzip + + // Get references to all module-local synthesized assertions + val sortedLocalAsserts = asserts(m.name).values.toSeq.sortWith (_._1 > _._1) + val (localAsserts, localClocks) = sortedLocalAsserts.map({ case (_, en, clk) => (wref(en), mT.ref(clk)) }).unzip + + def allAsserts = childAsserts ++ localAsserts + def allClocks = childAssertClocksRTs.flatten ++ localClocks + def assertUInt = UIntType(IntWidth(assertWidth)) } - private var assertNum = 0 - def dump(writer: Writer, meta: StroberMetaData, mod: String, path: String) { - asserts(mod).values.toSeq sortWith (_._1 < _._1) foreach { case (idx, _) => - writer write s"[id: $assertNum, module: $mod, path: $path]\n" - writer write (messages(mod)(idx) replace ("""\n""", "\n")) - writer write "0\n" - assertNum += 1 + + private def replaceStopsAndFindMessages(m: DefModule): DefModule = m match { + case m: Module => + val namespace = Namespace(m) + asserts(m.name) = new Asserts + messages(m.name) = new Messages + m.map(synAsserts(m.name, namespace)) + .map(findMessages(m.name)) + case m: ExtModule => m + } + + private def wireSynthesizedAssertions(meta: StroberMetaData, cT: CircuitTarget) + (m: DefModule): DefModule = m match { + case m: Module => + val ports = collection.mutable.ArrayBuffer[Port]() + val stmts = collection.mutable.ArrayBuffer[Statement]() + val mT = cT.module(m.name) + val mInfo = new ModuleAssertInfo(m, meta, mT) + // Connect asserts + if (mInfo.hasAsserts) { + val namespace = Namespace(m) + val tpe = mInfo.assertUInt + val port = Port(NoInfo, namespace.newName("midasAsserts"), Output, tpe) + val assertConnect = Connect(NoInfo, WRef(port.name), cat(mInfo.allAsserts.reverse)) + assertPorts(m.name) = (port, mInfo.allClocks) + ports += port + stmts += assertConnect } - meta.childInsts(mod) - .filterNot(inst => excludeInstAsserts((mod, inst))) - .foreach(child => dump(writer, meta, meta.instModMap(child, mod), s"${path}.${child}")) - + m.copy(ports = m.ports ++ ports.toSeq, body = Block(m.body +: stmts.toSeq)) + case m: ExtModule => m } - def synthesizeAsserts(state: CircuitState): CircuitState = { - val c = state.circuit + def formatMessages(meta: StroberMetaData, topModule: String): Seq[String] = { + val formattedMessages = new mutable.ArrayBuffer[String]() + def dump(mod: String, path: String): Unit = { + formattedMessages ++= (asserts(mod).values.toSeq).sortWith(_._1 > _._1).map({ case (idx, _, _) => + s"module: $mod, path: $path]\n" + (messages(mod)(idx) replace ("""\n""", "\n")) + }) + meta.childInsts(mod) + .filterNot(inst => excludeInstAsserts((mod, inst))) + .foreach(child => dump(meta.instModMap(child, mod), s"${path}.${child}")) + } + dump(topModule, topModule) + formattedMessages.toSeq + } + + def synthesizeAsserts(state: CircuitState): CircuitState = { + + // Step 1: Grab module-based exclusions state.annotations.collect { case a @ (_: ExcludeInstanceAssertsAnnotation) => excludeInstAsserts += a.target } + // Step 2: Replace stop statements (asserts) and find associated message + val c = state.circuit.copy(modules = state.circuit.modules.map(replaceStopsAndFindMessages)) + val topModule = c.modules.collectFirst({ case m: Module if m.name == c.main => m }).get + val topMT = ModuleTarget(c.main, c.main) + val namespace = Namespace(topModule) + + // Step 3: Wire assertions to the top-level val meta = StroberMetaData(c) - val mods = postorder(c, meta)(transform(meta)) - val f = new FileWriter(new File(dir, s"${c.main}.asserts")) - dump(f, meta, c.main, c.main) - f.close + val mods = postorder(c, meta)(wireSynthesizedAssertions(meta, CircuitTarget(c.main))) + val formattedMessages = formatMessages(meta, c.main) - println(s"[MIDAS] total # of assertions synthesized: $assertNum") + val mInfo = new ModuleAssertInfo(topModule, meta, topMT) + println(s"[Golden Gate] total # of assertions synthesized: ${mInfo.assertWidth}") - val mName = ModuleName(c.main, CircuitName(c.main)) - val assertAnnos = if (assertNum > 0) { - val portName = assertPorts(c.main).name - val portRT = ModuleTarget(c.main, c.main).ref(portName) - val fcca = FAMEChannelConnectionAnnotation( - globalName = portName, - channelInfo = WireChannel, - sources = Some(Seq(portRT)), - sinks = None) + if (!mInfo.hasAsserts) state else { + // Step 4: Associate each assertion with a source clock + val postWiredState = state.copy(circuit = c.copy(modules = mods), form = MidForm) + val loweredState = Seq(new ResolveAndCheck, new HighFirrtlToMiddleFirrtl, new MiddleFirrtlToLowFirrtl).foldLeft(postWiredState)((state, xform) => xform.transform(state)) + val finder = new ClockSourceFinder(loweredState) + val clockMapping = mInfo.allClocks.map(cT => cT -> finder.findRootDriver(cT)).toMap + val rootClocks = mInfo.allClocks.map(clockMapping).flatten - val bridgeAnno = BridgeIOAnnotation( - target = portRT, - widget = Some((p: Parameters) => new AssertBridgeModule(assertNum)(p)), - channelMapping = Map("" -> portName) + // For each clock in clock channel, list associated assert indices + val groupedAsserts = rootClocks.zipWithIndex.groupBy(_._1).mapValues(values => values.map(_._2)) + + // Step 5: Re-wire the top-level module + val ports = collection.mutable.ArrayBuffer[Port]() + val stmts = collection.mutable.ArrayBuffer[Statement]() + val assertAnnos = collection.mutable.ArrayBuffer[Annotation]() + + // Step 5a: Connect all assertions to a single wire to match the order of our previous analysis + val allAssertsWire = DefWire(NoInfo, "allAsserts", mInfo.assertUInt) + val allAssertConnect = Connect(NoInfo, WRef(allAssertsWire), cat(mInfo.allAsserts.reverse)) + stmts ++= Seq(allAssertsWire, allAssertConnect) + + // Step 5b: Generate unique ports for each clock + for ((clockRT, asserts) <- groupedAsserts) { + val portName = namespace.newName(s"midasAsserts_${clockRT.ref}") + val clockPortName = namespace.newName(s"midasAsserts_${clockRT.ref}_clock") + val tpe = UIntType(IntWidth(asserts.size)) + val port = Port(NoInfo, portName, Output, tpe) + val clockPort = Port(NoInfo, clockPortName, Output, ClockType) + ports ++= Seq(port, clockPort) + val bitExtracts = asserts.map(idx => DoPrim(PrimOps.Bits, Seq(WRef(allAssertsWire)), Seq(idx, idx), UIntType(IntWidth(1)))) + val connectAsserts = Connect(NoInfo, WRef(port), cat(bitExtracts.reverse)) + val connectClock = Connect(NoInfo, WRef(clockPort), WRef(clockRT.ref)) + stmts ++= Seq(connectClock, connectAsserts) + + // Generate the bridge Annotation + val portRT = ModuleTarget(c.main, c.main).ref(portName) + val clockPortRT = ModuleTarget(c.main, c.main).ref(clockPortName) + val fcca = FAMEChannelConnectionAnnotation.source(portName, WireChannel, Some(clockPortRT), Seq(portRT)) + val assertMessages = asserts.map(formattedMessages(_)) + val bridgeAnno = BridgeIOAnnotation( + target = portRT, + widget = Some((p: Parameters) => new AssertBridgeModule(assertMessages)(p)), + channelMapping = Map("" -> portName) + ) + assertAnnos ++= Seq(fcca, bridgeAnno) + } + val wiredTopModule = topModule.copy(ports = topModule.ports ++ ports, + body = Block(topModule.body +: stmts.toSeq)) + + state.copy( + circuit = c.copy(modules = wiredTopModule +: mods.filterNot(_.name == c.main)), + form = HighForm, + annotations = state.annotations ++ assertAnnos ) - - Seq(fcca, bridgeAnno) - } else { - Seq() } - - state.copy( - circuit = c.copy(modules = mods), - form = HighForm, - annotations = state.annotations ++ assertAnnos) } def execute(state: CircuitState): CircuitState = { - if (p(SynthAsserts)) synthesizeAsserts(state) else { - // Still need to touch the file. - val f = new FileWriter(new File(dir, s"${state.circuit.main}.asserts")) - f.close - state - } + if (p(SynthAsserts)) synthesizeAsserts(state) else state } } diff --git a/sim/midas/src/main/scala/midas/passes/AutoCounterCoverTransform.scala b/sim/midas/src/main/scala/midas/passes/AutoCounterCoverTransform.scala index 0efb0df3..b1a6d959 100644 --- a/sim/midas/src/main/scala/midas/passes/AutoCounterCoverTransform.scala +++ b/sim/midas/src/main/scala/midas/passes/AutoCounterCoverTransform.scala @@ -1,23 +1,23 @@ // See LICENSE for license details. -package midas -package passes +package midas.passes import firrtl._ import firrtl.ir._ import firrtl.passes._ import firrtl.passes.wiring._ -import firrtl.Utils.throwInternalError +import firrtl.Utils.{throwInternalError, BoolType, one} import firrtl.annotations._ import firrtl.analyses.InstanceGraph import firrtl.transforms.TopWiring._ import freechips.rocketchip.util.property._ import freechips.rocketchip.util.WideCounter import freechips.rocketchip.config.{Parameters, Field} +import midas.{EnableAutoCounter, AutoCounterUsePrintfImpl} import midas.widgets._ import midas.targetutils._ import midas.passes.Utils.{widx, wsub} -import midas.passes.fame.FAMEChannelConnectionAnnotation +import midas.passes.fame.{WireChannel, FAMEChannelConnectionAnnotation, And, Or, Negate} import java.io._ import scala.io.Source @@ -31,347 +31,236 @@ class FireSimPropertyLibrary extends BasePropertyLibrary { def generateProperty(prop_param: BasePropertyParameters)(implicit sourceInfo: SourceInfo) { //requireIsHardware(prop_param.cond, "condition covered for counter is not hardware!") if (!(prop_param.cond.isLit) && chisel3.experimental.DataMirror.internal.isSynthesizable(prop_param.cond)) { - annotate(new ChiselAnnotation { def toFirrtl = AutoCounterCoverAnnotation(prop_param.cond.toNamed, prop_param.label, prop_param.message) }) + dontTouch(prop_param.cond) + dontTouch(chisel3.Module.reset) + dontTouch(chisel3.Module.clock) + annotate(new ChiselAnnotation { + val implicitClock = chisel3.Module.clock + val implicitReset = chisel3.Module.reset + def toFirrtl = AutoCounterFirrtlAnnotation(prop_param.cond.toNamed, + implicitClock.toNamed.toTarget, + implicitReset.toNamed.toTarget, + prop_param.label, + prop_param.message, + coverGenerated = true) + }) } } } //========================================================================= + /** -Take the annotated cover points and convert them to counters -**/ -class AutoCounterTransform(dir: File = new File("/tmp/"), printcounter: Boolean = false) + * Take the annotated cover points and convert them to counters + */ +class AutoCounterTransform(dir: File = new File("/tmp/")) (implicit p: Parameters) extends Transform with AutoCounterConsts { def inputForm: CircuitForm = LowForm - def outputForm: CircuitForm = LowForm - override def name = "[FireSim] AutoCounter Cover Transform" - val newAnnos = mutable.ArrayBuffer.empty[Annotation] - val autoCounterLabels = mutable.ArrayBuffer.empty[String] - val autoCounterMods = mutable.ArrayBuffer.empty[Module] - val autoCounterLabelsSourceMap = mutable.Map.empty[String, String] - val autoCounterReadableLabels = mutable.ArrayBuffer.empty[String] - val autoCounterReadableLabelsMap = mutable.Map.empty[String, String] - val autoCounterPortsMap = mutable.Map.empty[Int, String] - val autoCounterLabelsMap = mutable.Map.empty[String, String] + def outputForm: CircuitForm = MidForm + override def name = "[Golden Gate] AutoCounter Cover Transform" - private def makeCounter(label: String, hasTracerWidget: Boolean = false): CircuitState = { - import chisel3._ - import chisel3.core.MultiIOModule - import chisel3.experimental.ChiselAnnotation - import chisel3.util.experimental.BoringUtils - def countermodule() = new MultiIOModule { - //override def desiredName = "AutoCounter" - val in0 = IO(Input(Bool())) + val enableTransform = p(EnableAutoCounter) + val usePrintfImplementation = p(AutoCounterUsePrintfImpl) - //connect the trigger from the tracer widget - val trigger = WireDefault(true.B) - hasTracerWidget match { - case true => BoringUtils.addSink(trigger, s"trace_trigger") - case _ => trigger := true.B - } - //*****if we ever want to use the trigger to reset the counters****** - //withReset(reset = ~trigger) { - // val countfun = WideCounter(64, in0) - //} - - val countfun = WideCounter(64, in0) - val count = Wire(UInt(64.W)) - count := countfun - if (printcounter) { - when (in0 & trigger) { - printf(midas.targetutils.SynthesizePrintf(s"[AutoCounter] $label: %d\n",count)) - } - } else { - chisel3.core.dontTouch(count) - chisel3.experimental.annotate(new ChiselAnnotation { def toFirrtl = TopWiringAnnotation(count.toNamed, s"autocounter_") }) - //********In the future, when BoringUtils will be more rubust with TargetRefs*********** - //autoCounterLabels ++= Seq(s"AutoCounter_$label") - //BoringUtils.addSource(count, s"AutoCounter_$label") - } - } - - val chiselIR = chisel3.Driver.elaborate(() => countermodule()) - val annos = chiselIR.annotations.map(_.toFirrtl) - val firrtlIR = chisel3.Driver.toFirrtl(chiselIR) - val lowFirrtlIR = (new LowFirrtlCompiler()).compile(CircuitState(firrtlIR, ChirrtlForm, annos), Seq()) - lowFirrtlIR + // Gates each auto-counter event with the associated reset, moving the + // annotation's target to point at the new boolean (updatedAnnos). + // This is used in both implementation strategies. + private def gateEventsWithReset(coverTupleAnnoMap: Map[String, Seq[AutoCounterFirrtlAnnotation]], + updatedAnnos: mutable.ArrayBuffer[AutoCounterFirrtlAnnotation]) + (mod: DefModule): DefModule = mod match { + case m: Module if coverTupleAnnoMap.isDefinedAt(m.name) => + val coverAnnos = coverTupleAnnoMap(m.name) + val mT = coverAnnos.head.enclosingModuleTarget + val moduleNS = Namespace(mod) + val addedStmts = coverAnnos.flatMap({ anno => + val eventName = moduleNS.newName(anno.label) + updatedAnnos += anno.copy(target = mT.ref(eventName)) + Seq(DefWire(NoInfo, eventName, BoolType), + Connect(NoInfo, WRef(eventName), And(Negate(WRef(anno.reset.ref)), WRef(anno.target.ref)))) + }) + m.copy(body = Block(m.body, addedStmts:_*)) + case o => o } - private def onModule(topNS: Namespace, covertuples: Seq[(ReferenceTarget, String)], hasTracerWidget: Boolean = false)(mod: Module): Seq[Module] = { + private def onModulePrintfImpl(coverTupleAnnoMap: Map[String, Seq[AutoCounterFirrtlAnnotation]], + addedAnnos: mutable.ArrayBuffer[Annotation]) + (mod: DefModule): DefModule = mod match { + case m: Module if coverTupleAnnoMap.isDefinedAt(m.name) => + val coverAnnos = coverTupleAnnoMap(m.name) + val mT = coverAnnos.head.enclosingModuleTarget + val moduleNS = Namespace(mod) + val addedStmts = new mutable.ArrayBuffer[Statement] - val namespace = Namespace(mod) - val resetRef = mod.ports.collectFirst { case Port(_,"reset",_,_) => WRef("reset") } + val countType = UIntType(IntWidth(64)) + val zeroLit = UIntLiteral(0, IntWidth(64)) + val oneLit = UIntLiteral(1, IntWidth(64)) - //need some mutable lists - val newMods = mutable.ArrayBuffer.empty[Module] - val newInsts = mutable.ArrayBuffer.empty[WDefInstance] - val newCons = mutable.ArrayBuffer.empty[Connect] + addedStmts ++= coverAnnos.flatMap({ case AutoCounterFirrtlAnnotation(target, clock, reset, label, _, _) => + val countName = moduleNS.newName(label + "_counter") + val count = DefRegister(NoInfo, countName, countType, WRef(clock.ref), WRef(reset.ref), zeroLit) + val plusOneName = moduleNS.newName(label + "_plusOne") + val plusOne = DefNode(NoInfo, plusOneName, DoPrim(PrimOps.Add, Seq(WRef(count), oneLit), Seq.empty, countType)) + val countUpdate = Connect(NoInfo, WRef(count), Mux(WRef(target.ref), WRef(plusOne), WRef(count), countType)) - //for each annotated signal within this module - covertuples.foreach { case (target, label) => - //create counter - val countermodstate = makeCounter(label, hasTracerWidget = hasTracerWidget) - val countermod = countermodstate.circuit.modules match { - case Seq(one: firrtl.ir.Module) => one - case other => throwInternalError(s"Invalid resulting modules ${other.map(_.name)}") - } - //add to new modules list that will be added to the circuit - val newmodulename = topNS.newName(countermod.name) - val countermodn = countermod.copy(name = newmodulename) - val maincircuitname = target.circuitOpt.get - val renamemap = RenameMap(Map(ModuleTarget(countermodstate.circuit.main, countermod.name) -> Seq(ModuleTarget(maincircuitname, newmodulename)))) - newMods += countermodn - autoCounterMods += countermodn - autoCounterLabelsMap += countermodn.name -> label - newAnnos ++= countermodstate.annotations.toSeq.flatMap { case anno => anno.update(renamemap) } - //instantiate the counter - val instName = namespace.newName(s"autocounter_" + label) // Helps debug - val inst = WDefInstance(NoInfo, instName, countermodn.name, UnknownType) - //add to new instances list that will be added to the block - newInsts += inst + // Generate a trigger sink and annotate it + val triggerName = moduleNS.newName("trigger") + val trigger = DefWire(NoInfo, triggerName, BoolType) + addedStmts ++= Seq(trigger, Connect(NoInfo, WRef(trigger), one)) + addedAnnos += TriggerSinkAnnotation(mT.ref(triggerName), clock) - //create input connection to the counter - val wcons = { - val lhs = WSubField(WRef(inst.name),"in0") - val rhs = WRef(target.name) - Connect(NoInfo, lhs, rhs) - } - newCons += wcons - - val clocks = mod.ports.collect({ case Port(_,name,_,ClockType) => name}) - //create clock connection to the counter - val clkCon = { - val lhs = WSubField(WRef(inst.name), "clock") - val rhs = WRef(clocks(0)) - Connect(NoInfo, lhs, rhs) - } - newCons += clkCon - - //create reset connection to the counter - val reset = resetRef.getOrElse(UIntLiteral(0, IntWidth(1))) - val resetCon = Connect(NoInfo, WSubField(WRef(inst.name), "reset"), reset) - newCons += resetCon - } - - //add new block of statements to the module (with the new instantiations and connections) - val bodyx = Block(mod.body +: (newInsts ++ newCons)) - Seq(mod.copy(body = bodyx)) ++ newMods - } - - private def fixupCircuit(instate: CircuitState): CircuitState = { - val xforms = Seq( - //new WiringTransform, - new ResolveAndCheck - ) - (xforms foldLeft instate)((in, xform) => - xform runTransform in).copy(form=outputForm) - } - - //create the appropriate perf counters target widget - private def makeAutoCounterWidget(topNS: Namespace, numCounters: Int, maincircuit: Circuit, hasTracerWidget: Boolean = false): (ExtModule, Seq[Annotation]) = { - val bridgeName = topNS.newName("AutoCounterBridge") - val bridgeMT = ModuleTarget(maincircuit.main, bridgeName) - val ioName = "counters" - val portList = Seq.tabulate(numCounters)(idx => - Port(NoInfo, s"${ioName}_${idx}", Input, UIntType(IntWidth(counterWidth)))) - val extModule = ExtModule(NoInfo, bridgeName, portList, bridgeName, Seq()) - val channelAnnos = portList.map(port => - FAMEChannelConnectionAnnotation.source(port.name, fame.WireChannel, Seq(bridgeMT.ref(port.name)))) - val bridgeCtorArg = AutoCounterBridgeConstArgs(numCounters, autoCounterPortsMap, hasTracerWidget) - val bridgeAnnotation = InMemoryBridgeAnnotation(bridgeMT, channelAnnos.map(_.globalName), - (p: Parameters) => new AutoCounterBridgeModule(bridgeCtorArg)(p)) - (extModule, bridgeAnnotation +: channelAnnos) + // Now emit a printf using all the generated hardware + val printFormat = StringLit(s"""[AutoCounter] $label: %d\n""") + val printStmt = Print(NoInfo, printFormat, Seq(WRef(count)), + WRef(clock.ref), And(WRef(trigger), WRef(target.ref))) + addedAnnos += SynthPrintfAnnotation(Seq(Seq(mT.ref(countName))), mT, printFormat.string, Some(target.ref + "_print")) + Seq(count, plusOne, printStmt, countUpdate) + }) + m.copy(body = Block(m.body, addedStmts:_*)) + case o => o } - private def CreateTopCounterSources(instancepaths: Seq[Seq[WDefInstance]], state: CircuitState, topnamespace: Namespace): Seq[Statement] = { + private def implementViaPrintf( + state: CircuitState, + eventModuleMap: Map[String, Seq[AutoCounterFirrtlAnnotation]]): CircuitState = { - instancepaths.flatMap { case instpath => - val instpathnames = instpath.map {case WDefInstance(_,name,_,_) => name} - val path = instpathnames.tail.tail.mkString("_") - val portname = s"autocounter_" + path + "_count" - val fullportname = instpathnames.tail.head + "." + portname - val mod = instpath.last.module // WDefInstance.module is a string, should be equivalent to the module name - val oldlabel = autoCounterLabelsMap(mod) - val readablelabel = oldlabel + s"[" + instpathnames.dropRight(1).mkString(".") + s"]" - val newlabel = oldlabel + s"_" + instpathnames.dropRight(1).mkString("_") + val addedAnnos = new mutable.ArrayBuffer[Annotation]() + val updatedModules = state.circuit.modules.map( + onModulePrintfImpl(eventModuleMap, addedAnnos)) + state.copy(circuit = state.circuit.copy(modules = updatedModules), + annotations = state.annotations ++ addedAnnos) + } - //When the wiring transform gets fixed - //======================================================= - //newAnnos += SourceAnnotation(ModuleTarget(state.circuit.main, instpath.head.module).ref(fullportname).toNamed, s"AutoCounter_$newlabel") - //======================================================= - //Instead of wiring transform, manually connect the counter wire source from the DUT side - val sourceref = WSubField(WRef(instpathnames.tail.head), portname) - val wirename = topnamespace.newName(newlabel) - val medref = WRef(wirename) - autoCounterLabelsSourceMap += newlabel -> wirename - autoCounterReadableLabelsMap += newlabel -> readablelabel - Seq(DefWire(NoInfo, wirename, UIntType(IntWidth(64))), Connect(NoInfo, medref, sourceref)) - } - } + private def implementViaBridge( + state: CircuitState, + eventModuleMap: Map[String, Seq[AutoCounterFirrtlAnnotation]]): CircuitState = { + val labelMap = eventModuleMap.values.flatten.map(anno => anno.target -> anno.label).toMap + val bridgeTopWiringAnnos = eventModuleMap.values.flatten.map( + anno => BridgeTopWiringAnnotation(anno.target, anno.clock)) - //count the number of generated perf counters - //create an appropriate widget IO - //wire the counters to the widget - private def AddAutoCounterWidget(state: CircuitState, hasTracerWidget: Boolean = false): CircuitState = { + // Step 1: Call BridgeTopWiring, grouping all events by their source clock + val topWiringPrefix = "autocounter" + val wiredState = (new BridgeTopWiring(topWiringPrefix + "_")).execute( + state.copy(annotations = state.annotations ++ bridgeTopWiringAnnos)) + val outputAnnos = wiredState.annotations.collect({ case a: BridgeTopWiringOutputAnnotation => a }) + val groupedOutputs = outputAnnos.groupBy(_.srcClockPort) - //need to "remember/save" the "old/original" top level module, since the TopWiring transform - //punches signals out all the way to the top, and we want them punched out only - //through the DUT - val oldinstanceGraph = new InstanceGraph(state.circuit) - val top = oldinstanceGraph.moduleOrder.head.asInstanceOf[Module] + // Step 2: For each group of wired events, generate associated bridge annotations + val c = wiredState.circuit + val topModule = c.modules.collectFirst({ case m: Module if m.name == c.main => m }).get + val topMT = ModuleTarget(c.main, c.main) + val topNS = Namespace(topModule) + val addedPorts = mutable.ArrayBuffer[Port]() + val addedStmts = mutable.ArrayBuffer[Statement]() + val bridgeAnnos = for ((srcClockRT, oAnnos) <- groupedOutputs.toSeq.sortBy(_._1.ref)) yield { + val sinkClockRT = oAnnos.head.sinkClockPort + val fccas = oAnnos.map({ anno => + FAMEChannelConnectionAnnotation.source( + anno.topSink.ref, + WireChannel, + Some(sinkClockRT), + Seq(anno.topSink)) + }) - //punch out counter signals to the top - def topwiringtransform = new TopWiringTransform - val newstate: CircuitState = topwiringtransform.execute(state.copy(annotations = state.annotations ++ newAnnos)) + val labels = oAnnos.map({ anno => + val pathlessLabel = labelMap(anno.pathlessSource) + val instPath = anno.absoluteSource.circuit +: anno.absoluteSource.asPath.map(_._1.value) + anno.topSink.ref -> (pathlessLabel +: instPath).mkString("_") + }) - //cleanup the topwiring annotations - val srcannos = newAnnos.collect { - case a: TopWiringAnnotation => a - } - newAnnos --= srcannos + // Step 2b. Manually add a boolean channel to carry the trigger signal to the bridge + val triggerPortName = topNS.newName(s"${topWiringPrefix}_triggerEnable") + // Introduce an extra node until TriggerWriring supports port sinks + val triggerPortNode = topNS.newName(s"${topWiringPrefix}_triggerEnable_node") + addedPorts += Port(NoInfo, triggerPortName, Output, BoolType) + // In the event there are no trigger sources, default to enabled + addedStmts ++= Seq( + DefNode(NoInfo, triggerPortNode, one), + Connect(NoInfo, WRef(triggerPortName), WRef(triggerPortNode))) + val triggerPortRT = topMT.ref(triggerPortName) + val triggerFcca = FAMEChannelConnectionAnnotation.source( + triggerPortName, + WireChannel, + Some(sinkClockRT), + Seq(triggerPortRT)) - //Find the relevant ports/wires that were punched out to the top by finding the autocounter instances - val circuit = newstate.circuit - val topnamespace = Namespace(circuit) - val instanceGraph = new InstanceGraph(circuit) + val triggerSinkAnno = TriggerSinkAnnotation(topMT.ref(triggerPortNode), sinkClockRT) - val autocounterinsts = autoCounterMods.flatMap { case mod => instanceGraph.findInstancesInHierarchy(mod.name) } - - val numcounters = autocounterinsts.size - val sourceconnections = CreateTopCounterSources(autocounterinsts, newstate, topnamespace) - - - //create the bridge module (widget) - val (widgetMod, bridgeAnnos) = makeAutoCounterWidget(topnamespace, numcounters, newstate.circuit, hasTracerWidget) - - val topSort = instanceGraph.moduleOrder - - val widgetInstName = topnamespace.newName(s"AutoCounterBridge_inst") // Helps debug - val widgetInst = WDefInstance(NoInfo, widgetInstName, widgetMod.name, UnknownType) - val counterPorts = widgetMod.ports.map(_.name) - - - //When the wiring transform gets fixed - //wiring transform annotation to connect to the counters - //================================================ - //autoCounterLabels.zip(counterports).foreach { - // case(label, counterport) => { - // newAnnos += SinkAnnotation(ModuleTarget(newstate.circuit.main, newtop.name).ref(widgetInst.name).field(counterport).toNamed, label) - // } - //} - //================================================ - - - //Instead of wiring transform, manually connect the counter wire sinks on the Bridge side - val sinkconnections = autoCounterLabelsSourceMap.keys.zipWithIndex.flatMap { - case(label,i) => { - val sinkref = wsub(WRef(widgetInst.name), counterPorts(i)) - val medref = WRef(autoCounterLabelsSourceMap(label)) - //autoCounterPortsMap += i -> autoCounterReadableLabelsMap(label) - autoCounterPortsMap += i -> label - Seq(Connect(NoInfo, sinkref, medref)) - } - } - val newstatements = Seq(widgetInst) ++ sourceconnections ++ sinkconnections - - //update the body of the top level module - val bodyx = Block(top.body +: newstatements) - val newtop = top.copy(body = bodyx) - - newstate.copy( - circuit = newstate.circuit.copy(modules = topSort.tail ++ Seq(newtop, widgetMod)), - annotations = state.annotations ++ bridgeAnnos) - } - - def execute(state: CircuitState): CircuitState = { - - //collect annotation generate by the built in cover points in rocketchip - //(or manually inserted annotations) - val coverannos = state.annotations.collect { - case a: AutoCounterCoverAnnotation => a + val bridgeAnno = BridgeIOAnnotation( + target = topMT.ref(topWiringPrefix), + // We need to pass the name of the trigger port so each bridge can + // disambiguate between them and connect to the correct one in simulation mapping + widget = (p: Parameters) => new AutoCounterBridgeModule(labels, triggerPortName)(p), + channelNames = (triggerFcca +: fccas).map(_.globalName) + ) + Seq(bridgeAnno, triggerSinkAnno, triggerFcca) ++ fccas } - //collect annotations for manually annotated AutoCounter perf counters - val autocounterannos = state.annotations.collect { - case a: AutoCounterFirrtlAnnotation => a - } + val updatedCircuit = c.copy(modules = c.modules.map({ + case m: Module if m.name == c.main => m.copy(ports = m.ports ++ addedPorts, body = Block(m.body, addedStmts:_*)) + case o => o + })) - //identify if there is a TracerV bridge to supply a trigger signal - val hasTracerWidget = p(midas.TraceTrigger) - - //----if we want to identify the tracer widget based on bridge annotations, we can use this code segment, - //----but this requires the transform to be part of the firesim package rather than midas package - //val hasTracerWidget = state.annotations.collect({ case midas.widgets.SerializableBridgeAnnotation(_,_,widget, _) => - // widget match { - // case "firesim.bridges.TracerVBridgeModule" => true - // case _ => Nil - // }}).length > 0 + val cleanedAnnotations = wiredState.annotations.filterNot(outputAnnos.toSet) + CircuitState(updatedCircuit, wiredState.form, cleanedAnnotations ++ bridgeAnnos.flatten) + } + def doTransform(state: CircuitState): CircuitState = { //select/filter which modules do we want to actually look at, and generate counters for //this can be done in one of two way: //1. Using an input file called `covermodules.txt` in a directory declared in the transform concstructor //2. Using chisel annotations to be added in the Platform Config (in SimConfigs.scala). The annotations are // of the form AutoCounterModuleAnnotation("ModuleName") val modulesfile = new File(dir,"autocounter-covermodules.txt") - val filemoduleannos = mutable.ArrayBuffer.empty[AutoCounterCoverModuleFirrtlAnnotation] + val moduleAnnos = new mutable.ArrayBuffer[AutoCounterCoverModuleFirrtlAnnotation]() + val counterAnnos = new mutable.ArrayBuffer[AutoCounterFirrtlAnnotation]() + val remainingAnnos = new mutable.ArrayBuffer[Annotation]() if (modulesfile.exists()) { val sourcefile = scala.io.Source.fromFile(modulesfile.getPath()) val covermodulesnames = (for (line <- sourcefile.getLines()) yield line).toList sourcefile.close() - filemoduleannos ++= covermodulesnames.map {m: String => AutoCounterCoverModuleFirrtlAnnotation(ModuleTarget(state.circuit.main,m))} + moduleAnnos ++= covermodulesnames.map {m: String => AutoCounterCoverModuleFirrtlAnnotation(ModuleTarget(state.circuit.main,m))} + } + state.annotations.foreach { + case a: AutoCounterCoverModuleFirrtlAnnotation => moduleAnnos += a + case a: AutoCounterFirrtlAnnotation => counterAnnos += a + case o => remainingAnnos += o } - val moduleannos = (state.annotations.collect { - case a: AutoCounterCoverModuleFirrtlAnnotation => a - } ++ filemoduleannos).distinct //extract the module names from the methods mentioned previously - val covermodulesnames = moduleannos.map { case AutoCounterCoverModuleFirrtlAnnotation(ModuleTarget(_,m)) => m } + val covermodulesnames = moduleAnnos.map(_.target.module).distinct - if (!covermodulesnames.isEmpty) { - println("[AutoCounter]: Cover modules in AutoCounterTransform:") - println(covermodulesnames) - } - - //filter the cover annotations only by the modules that we want - val filtercoverannos = coverannos.filter{ case AutoCounterCoverAnnotation(ReferenceTarget(_,modname,_,_,_),l,m) => - covermodulesnames.contains(modname) } - - val allcounterannos = filtercoverannos ++ autocounterannos - //group the selected signal by modules, and attach label from the cover point to each signal - val selectedsignals = allcounterannos.map { case AutoCounterCoverAnnotation(target,l,m) => (target, l) - case AutoCounterFirrtlAnnotation(target,l,m) => (target, l) - } - .groupBy { case (ReferenceTarget(_,modname,_,_,_), l) => modname } + //collect annotations for manually annotated AutoCounter perf counters + val filteredCounterAnnos = counterAnnos.filter(_.shouldBeIncluded(covermodulesnames)) + // group the selected signal by modules, and attach label from the cover point to each signal + val selectedsignals = filteredCounterAnnos.groupBy(_.enclosingModule) if (!selectedsignals.isEmpty) { - println("[AutoCounter]: AutoCounter signals are:") - println(selectedsignals) + println("[AutoCounter] AutoCounter signals are:") + selectedsignals.foreach({ case (modName, localEvents) => + println(s" Module ${modName}") + localEvents.foreach({ anno => println(s" ${anno.label}: ${anno.message}") }) + }) - //create counters for each of the Bools in the filtered cover functions - val moduleNamespace = Namespace(state.circuit) - val modulesx: Seq[DefModule] = state.circuit.modules.map { - case mod: Module => - val covertuples = selectedsignals.getOrElse(mod.name, Seq()) - if (!covertuples.isEmpty) { - val mods = onModule(moduleNamespace, covertuples, hasTracerWidget = hasTracerWidget)(mod) - val newMods = mods.filter(_.name != mod.name) - assert(newMods.size + 1 == mods.size) // Sanity check - mods - } else { Seq(mod) } - case ext: ExtModule => Seq(ext) - }.flatten + // Common preprocessing: gate all annotated events with their associated reset + val updatedAnnos = new mutable.ArrayBuffer[AutoCounterFirrtlAnnotation]() + val updatedModules = state.circuit.modules.map((gateEventsWithReset(selectedsignals, updatedAnnos))) + val eventModuleMap = updatedAnnos.groupBy(_.enclosingModule) + val preppedState = state.copy(circuit = state.circuit.copy(modules = updatedModules), + annotations = remainingAnnos) - val statewithwidget = printcounter match { - case true => state.copy(circuit = state.circuit.copy(modules = modulesx), annotations = state.annotations ++ newAnnos) - case _ => AddAutoCounterWidget(state.copy(circuit = state.circuit.copy(modules = modulesx)), hasTracerWidget = hasTracerWidget) + if (usePrintfImplementation) { + implementViaPrintf(preppedState, eventModuleMap) + } else { + implementViaBridge(preppedState, eventModuleMap) } - - fixupCircuit(statewithwidget) } else { state } } -} + def execute(state: CircuitState): CircuitState = { + if (enableTransform) doTransform(state) else state + } +} diff --git a/sim/midas/src/main/scala/midas/passes/BridgeTopWiring.scala b/sim/midas/src/main/scala/midas/passes/BridgeTopWiring.scala new file mode 100644 index 00000000..f7ff5e0f --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/BridgeTopWiring.scala @@ -0,0 +1,212 @@ +package midas.passes + +import scala.collection.mutable + +import firrtl._ +import firrtl.analyses.InstanceGraph +import firrtl.annotations._ +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.transforms.TopWiring.{TopWiringAnnotation, TopWiringTransform, TopWiringOutputFilesAnnotation} +import TargetToken.{Instance, OfModule} + +import midas.passes.fame.RTRenamer +import midas.targetutils.FAMEAnnotation + +/** + * Provides signals for the transform to wire to the top-level of hte module hierarchy. + * + * @param target The signal to be plumbed to the top + * + * @param clock The clock to which this signal is sychronous. This will _not_ be wired. + * + */ +case class BridgeTopWiringAnnotation(target: ReferenceTarget, clock: ReferenceTarget) extends Annotation with FAMEAnnotation { + def update(renames: RenameMap): Seq[BridgeTopWiringAnnotation] = + Seq(this.copy(RTRenamer.exact(renames)(target), RTRenamer.exact(renames)(clock))) + + def toWiringAnnotation(prefix: String): TopWiringAnnotation = TopWiringAnnotation(target.pathlessTarget.toNamed, prefix) +} + +/** + * Provides reference targets to the newly generated top-level IO and a + * generated output-clock port that output is synchronous to. + * + * @param pathlessSource The original target passed in a [[BridgeTopWiringAnnotation]] + * + * @param absoluteSource An absolute reference target to the particular + * instance of that signal that drives the new output. NB: A single [[BridgeTopWiringAnnotation]] + * will generate as many output annotations as there are instances of the pathless source. + * + * @param topSink The new top-level port the source has been connected to + * + * @param srcClockPort The input clock associated with the source + * + * @param clockPort The new output clock port the topSink port to be referenced in FCCAs + * + */ +case class BridgeTopWiringOutputAnnotation( + pathlessSource: ReferenceTarget, + absoluteSource: ReferenceTarget, + topSink: ReferenceTarget, + srcClockPort: ReferenceTarget, + sinkClockPort: ReferenceTarget) extends Annotation with FAMEAnnotation { + + def update(renames: RenameMap): Seq[BridgeTopWiringOutputAnnotation] = { + val renameExact = RTRenamer.exact(renames) + Seq(BridgeTopWiringOutputAnnotation(renameExact(pathlessSource), + renameExact(absoluteSource), + renameExact(topSink), + renameExact(srcClockPort), + renameExact(sinkClockPort))) + } +} + +object BridgeTopWiring { + class NoClockSourceFoundException(data: ReferenceTarget, clock: ReferenceTarget) + extends Exception(s"Could not determine the source of clock ${clock} for target-to-wire ${data}") +} + +/** + * A utility transform used to implement features that are finely distrubuted + * through out the target, such as assertion and printf synthesis. This + * transform preforms most of the circuit modifications and analysis to emit + * BridgeIOAnnotations and FCCAs directly. For this pass to function + * correctly, the clock bridge must already be extracted. + * + * For each BridgeTopWiringAnnotation, this transform: + * 1) Wires out every instance of that signal to a unique port in the top-level module + * These will be referenced by Bridge FCCAs and will become simulation channels. + * 2) Determines the source clock (these are now inputs on the top-level module) to which each that port is synchronous + * + * For each clock that is synchronous with at least one output port: + * 1) Loop that clock back to a new output port (Bridge FCCAs will point at this clock) + * + * Finally emit a [[BridgeTopWiringOutputAnnotation]] for each created data-output port. + * + * @param prefix Provides the top-wiring prefix + * + */ + +class BridgeTopWiring(val prefix: String) extends firrtl.Transform { + def inputForm = MidForm + def outputForm = MidForm + case class TopWiringMapping(src: ComponentName, instPath: Seq[String]) { + // TopWiring doesn't return references to the top-level ports; we need to reconstruct those + // from an instance path and the prefix provided + val topMT = ModuleTarget(src.module.circuit.name, src.module.circuit.name) + + // The RT to then new top-level port + def portRT(): ReferenceTarget = topMT.ref(prefix + (instPath :+ src.name).mkString("_")) + + // A complete-path reference target to the source that drives this new output + def absoluteSourceRT(childInstGraph: Map[String, Map[Instance, OfModule]]): ReferenceTarget = { + instPath.foldLeft[IsModule](topMT)((target, instName) => { + val moduleName = target match { + case iT: InstanceTarget => iT.ofModule + case mT: ModuleTarget => mT.module + } + val instModuleName = childInstGraph(moduleName)(Instance(instName)) + target.instOf(instName, instModuleName.value) + }).ref(src.name) + } + + def getSourceSinkPair(iMaps: Map[String, Map[Instance, OfModule]]): + (ReferenceTarget, ReferenceTarget) = absoluteSourceRT(iMaps) -> portRT + } + + def execute(state: CircuitState): CircuitState = { + + require(state.annotations.collect({ case t: TopWiringAnnotation => t }).isEmpty, + "CircuitState cannot have existing TopWiring annotations before BridgeTopWiring.") + + val inputAnnos = state.annotations.collect({ case a: BridgeTopWiringAnnotation => a }) + val localClockMap = inputAnnos.map(anno => anno.target -> anno.clock).toMap + + // Step 1: Invoke top wiring + // Hacky: Instead of generated output files, instead sneak out the mappings from the TopWiring + // transform. + var topLevelOutputs = Seq[TopWiringMapping]() + def wiringAnnoOutputFunc(td: String, + mappings: Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)], + state: CircuitState): CircuitState = { + topLevelOutputs = mappings.unzip._1.map(inscrutable5Tuple => TopWiringMapping(inscrutable5Tuple._1, inscrutable5Tuple._4.dropRight(1))) + state + } + + val topWiringOFileAnno = TopWiringOutputFilesAnnotation("unused", wiringAnnoOutputFunc) + val topWiringAnnos = topWiringOFileAnno +: inputAnnos.map(_.toWiringAnnotation(prefix)).distinct + val wiredState = new TopWiringTransform().execute(state.copy(annotations = topWiringAnnos ++ state.annotations)) + + // Step 2: Reconstruct a map from source RT to newly wired reference target + val instanceMaps = new InstanceGraph(wiredState.circuit).getChildrenInstanceMap + .map({ case (m, set) => m.value -> set.toMap }) + .toMap + // localRTtoAbsRTs + val localToAbsSource = topLevelOutputs.groupBy(_.src).map({ case (src, mappings) => + val localRT = src.toTarget + val absoluteRTs = mappings.map({ m => + val srcRT = m.absoluteSourceRT(instanceMaps) + val clockRT = srcRT.copy(ref = localClockMap(localRT).ref) + (srcRT, clockRT) + }) + (localRT, absoluteRTs) + }).toMap + + val allAbsClockRTs = localToAbsSource.values.flatMap(_.unzip._2) + + // Maps complete source RT to the portRT it has been wired to + val absSrcToPort = topLevelOutputs.map(_.getSourceSinkPair(instanceMaps)).toMap + + // Step 3: Do clock analysis using complete paths to clocks + val findClockSourceAnnos = allAbsClockRTs.map(src => FindClockSourceAnnotation(src)).toSeq + val stateToAnalyze = wiredState.copy(annotations = findClockSourceAnnos) + val loweredState = Seq(new ResolveAndCheck, + new MiddleFirrtlToLowFirrtl, + FindClockSources).foldLeft(stateToAnalyze)((state, xform) => xform.transform(state)) + + val clockSourceMap = loweredState.annotations.collect({ + case ClockSourceAnnotation(qT, source) => qT -> source + }).toMap + + + // Step 4: Add new clock ports + val c = wiredState.circuit + val uniqueSources = clockSourceMap.values.flatten.toSeq.distinct + val addedPorts = new mutable.ArrayBuffer[Port]() + val addedConnects = new mutable.ArrayBuffer[Connect]() + val ns = Namespace(c.modules.find(_.name == c.main).get) + val src2sinkClockMap = (for (clock <- uniqueSources) yield { + val portName = ns.newName(s"${prefix}${clock.ref}") + val port = Port(NoInfo, portName, Output, ClockType) + addedConnects += Connect(NoInfo, WRef(port), WRef(clock.ref)) + addedPorts += port + clock -> clock.copy(ref = portName) + }).toMap + + val updatedCircuit = c.copy(modules = c.modules.map({ + case m: Module if m.name == c.main => m.copy(ports = m.ports ++ addedPorts, + body = Block(m.body, addedConnects:_*)) + case o => o + })) + + // Step 5: Generate output annotations + val outputAnnotations = (for ((localRT, absRTs) <- localToAbsSource) yield { + absRTs.map { case (absSourceRT, absClockRT) => + clockSourceMap(absClockRT) match { + case Some(clock) => BridgeTopWiringOutputAnnotation(localRT, absSourceRT, absSrcToPort(absSourceRT), clock, src2sinkClockMap(clock)) + case None => throw new BridgeTopWiring.NoClockSourceFoundException(absSourceRT, absClockRT) + } + } + }).flatten + + val updatedAnnotations = outputAnnotations.toSeq ++ wiredState.annotations.flatMap({ + case s: TopWiringAnnotation => None + case s: TopWiringOutputFilesAnnotation => None + case s: BridgeTopWiringAnnotation => None + case o => Some(o) + }) + + wiredState.copy(circuit = updatedCircuit, annotations = updatedAnnotations) + } +} diff --git a/sim/midas/src/main/scala/midas/passes/ChannelizeTargetIO.scala b/sim/midas/src/main/scala/midas/passes/ChannelizeTargetIO.scala deleted file mode 100644 index c8e2b964..00000000 --- a/sim/midas/src/main/scala/midas/passes/ChannelizeTargetIO.scala +++ /dev/null @@ -1,115 +0,0 @@ -// See LICENSE for license details. - -package midas.passes - -import firrtl._ -import firrtl.analyses.InstanceGraph -import firrtl.annotations.{ModuleTarget, ReferenceTarget, TargetToken} -import firrtl.ir._ -import firrtl.Mappers._ -import fame._ -import Utils._ - -import collection.mutable -import java.io.{File, FileWriter, StringWriter} - -import midas.core.SimUtils._ - -private[passes] class ChannelizeTargetIO(io: Seq[(String, chisel3.Data)]) extends firrtl.Transform { - - override def name = "[MIDAS] ChannelizeTargetIO" - def inputForm = LowForm - def outputForm = LowForm - - def execute(state: CircuitState): CircuitState = { - val circuit = state.circuit - - // From trivial channel excision - val topName = state.circuit.main - val topModule = state.circuit.modules.find(_.name == topName).collect({ - case m: Module => m - }).get - - val topModulePortMap: Map[String, Port] = topModule.ports.map({p => p.name -> p}).toMap - - val (wireSinks, wireSources, rvSinks, rvSources) = parsePortsSeq(io, alsoFlattenRVPorts = false) - - // Helper functions to generate annotations and ReferenceTargets - def portRefTarget(field: String) = ReferenceTarget(circuit.main, circuit.main, Nil, field, Nil) - - def wireSinkAnno(chName: String) = - FAMEChannelConnectionAnnotation(chName, PipeChannel(1), None, Some(Seq(portRefTarget(chName)))) - def wireSourceAnno(chName: String) = - FAMEChannelConnectionAnnotation(chName, PipeChannel(1), Some(Seq(portRefTarget(chName))), None) - - def decoupledRevSinkAnno(name: String, readyTarget: ReferenceTarget) = - FAMEChannelConnectionAnnotation(prefixWith(name, "rev"), DecoupledReverseChannel, None, Some(Seq(readyTarget))) - def decoupledRevSourceAnno(name: String, readyTarget: ReferenceTarget) = - FAMEChannelConnectionAnnotation(prefixWith(name, "rev"), DecoupledReverseChannel, Some(Seq(readyTarget)), None) - - def decoupledFwdSinkAnno(chName: String, - validTarget: ReferenceTarget, - readyTarget: ReferenceTarget, - leaves: Seq[ReferenceTarget]): FAMEChannelConnectionAnnotation = { - - val chInfo = DecoupledForwardChannel( - validSink = Some(validTarget), - readySource = Some(readyTarget), - validSource = None, - readySink = None) - FAMEChannelConnectionAnnotation(prefixWith(chName, "fwd"), chInfo, None, Some(leaves)) - } - - def decoupledFwdSourceAnno(chName: String, - validTarget: ReferenceTarget, - readyTarget: ReferenceTarget, - leaves: Seq[ReferenceTarget]): FAMEChannelConnectionAnnotation = { - - val chInfo = DecoupledForwardChannel( - validSource = Some(validTarget), - readySink = Some(readyTarget), - validSink = None, - readySource = None) - FAMEChannelConnectionAnnotation(prefixWith(chName, "fwd"), chInfo, Some(leaves), None) - } - - // Generate ReferenceTargets for the leaves in an RV payload - def getRVLeaves(name: String, field: chisel3.Data): Seq[ReferenceTarget] = field match { - case b: chisel3.Record => - b.elements.toSeq flatMap { case (n, e) => getRVLeaves(prefixWith(name, n), e) } - case v: chisel3.Vec[_] => - v.zipWithIndex flatMap { case (e, i) => getRVLeaves(prefixWith(name, i), e) } - case b: chisel3.Element => Seq(portRefTarget(name)) - case _ => throw new RuntimeException("Unexpected type in ready-valid payload") - } - - // Generate valid, ready, and payload reference targets for a decoupled interface - def getRVTargets(port: chisel3.Data, name: String): (ReferenceTarget, ReferenceTarget, Seq[ReferenceTarget]) = { - val validTarget = portRefTarget(prefixWith(name, "valid")) - val readyTarget = portRefTarget(prefixWith(name, "ready")) - val payloadTargets = getRVLeaves(prefixWith(name, "bits"), port) - (validTarget, readyTarget, Seq(validTarget) ++ payloadTargets) - } - - def rvSinkAnnos(chTuple: RVChTuple): Seq[FAMEChannelConnectionAnnotation] = chTuple match { - case (port, name) => - val (vT, rT, pTs) = getRVTargets(port.bits, name) - Seq(decoupledFwdSinkAnno(name, vT, rT, pTs), decoupledRevSourceAnno(name, rT)) - } - - def rvSourceAnnos(chTuple: RVChTuple): Seq[FAMEChannelConnectionAnnotation] = chTuple match { - case (port, name) => - val (vT, rT, pTs) = getRVTargets(port.bits, name) - Seq(decoupledFwdSourceAnno(name, vT, rT, pTs), decoupledRevSinkAnno(name, rT)) - } - - val chAnnos = - wireSinks .map({ case (_, chName) => wireSinkAnno(chName) }) ++ - wireSources.map({ case (_, chName) => wireSourceAnno(chName) }) ++ - rvSinks .flatMap(rvSinkAnnos) ++ - rvSources .flatMap(rvSourceAnnos) - - val f1Anno = FAMETransformAnnotation(FAME1Transform, ModuleTarget(topName, topName)) - state.copy(annotations = state.annotations ++ Seq(f1Anno) ++ chAnnos) - } -} diff --git a/sim/midas/src/main/scala/midas/passes/CheckCombLoops.scala b/sim/midas/src/main/scala/midas/passes/CheckCombLoops.scala new file mode 100644 index 00000000..3925af27 --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/CheckCombLoops.scala @@ -0,0 +1,290 @@ +// See LICENSE for license details. + +package midas.passes + +import scala.collection.mutable + +import firrtl._ +import firrtl.ir._ +import firrtl.passes.{Errors, PassException} +import firrtl.traversals.Foreachers._ +import firrtl.annotations._ +import firrtl.Utils.throwInternalError +import firrtl.graph.{MutableDiGraph,DiGraph} +import firrtl.analyses.InstanceGraph +import firrtl.options.{RegisteredTransform, ShellOption} + +object CheckCombLoops { + class CombLoopException(info: Info, mname: String, cycle: Seq[String]) extends PassException( + s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n")) +} + +case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation + +case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarget) extends Annotation { + if (!source.isLocal || !sink.isLocal || source.module != sink.module) { + throwInternalError(s"ExtModulePathAnnotation must connect two local targets from the same module") + } + + override def getTargets: Seq[ReferenceTarget] = Seq(source, sink) + + override def update(renames: RenameMap): Seq[Annotation] = { + val sources = renames.get(source).getOrElse(Seq(source)) + val sinks = renames.get(sink).getOrElse(Seq(sink)) + val paths = sources flatMap { s => sinks.map((s, _)) } + paths.collect { + case (source: ReferenceTarget, sink: ReferenceTarget) => ExtModulePathAnnotation(source, sink) + } + } +} + +case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget]) extends Annotation { + override def update(renames: RenameMap): Seq[Annotation] = { + val newSources = sources.flatMap { s => renames(s) }.collect {case x: ReferenceTarget if x.isLocal => x} + val newSinks = renames(sink).collect { case x: ReferenceTarget if x.isLocal => x} + newSinks.map(snk => CombinationalPath(snk, newSources)) + } +} + +case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None) + +/** Finds and detects combinational logic loops in a circuit, if any exist. Returns the input circuit with no + * modifications. + * + * @throws firrtl.transforms.CheckCombLoops.CombLoopException if a loop is found + * @note Input form: Low FIRRTL + * @note Output form: Low FIRRTL (identity transform) + * @note The pass looks for loops through combinational-read memories + * @note The pass relies on ExtModulePathAnnotations to find loops through ExtModules + * @note The pass will throw exceptions on "false paths" + */ +class CheckCombLoops extends Transform with RegisteredTransform { + def inputForm = LowForm + def outputForm = LowForm + + import CheckCombLoops._ + + val options = Seq( + new ShellOption[Unit]( + longOption = "no-check-comb-loops", + toAnnotationSeq = (_: Unit) => Seq(DontCheckCombLoopsAnnotation), + helpText = "Disable combinational loop checking" ) ) + + /* + * A case class that represents a net in the circuit. This is + * necessary since combinational loop checking is an analysis on the + * netlist of the circuit; the fields are specialized for low + * FIRRTL. Since all wires are ground types, a given ground type net + * may only be a subfield of an instance or a memory + * port. Therefore, it is uniquely specified within its module + * context by its name, its optional parent instance (a WDefInstance + * or WDefMemory), and its optional memory port name. + */ + + private type ConnectivityGraphs = mutable.HashMap[String,DiGraph[LogicNode]] + + private def toLogicNode(e: Expression): LogicNode = e match { + case idx: WSubIndex => + toLogicNode(idx.expr) + case r: WRef => + LogicNode(r.name) + case s: WSubField => + s.expr match { + case modref: WRef => + LogicNode(s.name,Some(modref.name)) + case memport: WSubField => + memport.expr match { + case memref: WRef => + LogicNode(s.name,Some(memref.name),Some(memport.name)) + case _ => throwInternalError(s"toLogicNode: unrecognized subsubfield expression - $memport") + } + case _ => throwInternalError(s"toLogicNode: unrecognized subfield expression - $s") + } + } + + + private def getExprDeps(deps: MutableDiGraph[LogicNode], v: LogicNode)(e: Expression): Unit = e match { + case r: WRef => deps.addEdgeIfValid(v, toLogicNode(r)) + case s: WSubField => deps.addEdgeIfValid(v, toLogicNode(s)) + case _ => e.foreach(getExprDeps(deps, v)) + } + + private def getStmtDeps( + simplifiedModules: mutable.Map[String,DiGraph[LogicNode]], + deps: MutableDiGraph[LogicNode])(s: Statement): Unit = s match { + case Connect(_,loc,expr) => + val lhs = toLogicNode(loc) + if (deps.contains(lhs)) { + getExprDeps(deps, lhs)(expr) + } + case w: DefWire => + deps.addVertex(LogicNode(w.name)) + case n: DefNode => + val lhs = LogicNode(n.name) + deps.addVertex(lhs) + getExprDeps(deps, lhs)(n.value) + case m: DefMemory if (m.readLatency == 0) => + for (rp <- m.readers) { + val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp))) + deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp)))) + deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp)))) + } + case i: WDefInstance => + val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name))) + iGraph.getVertices.foreach(deps.addVertex(_)) + iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } }) + case _ => + s.foreach(getStmtDeps(simplifiedModules,deps)) + } + + /* + * Recover the full path from a path passing through simplified + * instances. Since edges may pass through simplified instances, the + * hierarchy that the path passes through must be recursively + * recovered. + */ + private def expandInstancePaths( + m: String, + moduleGraphs: mutable.Map[String,DiGraph[LogicNode]], + moduleDeps: Map[String, Map[String,String]], + prefix: Seq[String], + path: Seq[LogicNode]): Seq[String] = { + def absNodeName(prefix: Seq[String], n: LogicNode) = + (prefix ++ n.inst ++ n.memport :+ n.name).mkString(".") + val pathNodes = (path zip path.tail) map { case (a, b) => + if (a.inst.isDefined && !a.memport.isDefined && a.inst == b.inst) { + val child = moduleDeps(m)(a.inst.get) + val newprefix = prefix :+ a.inst.get + val subpath = moduleGraphs(child).path(b.copy(inst=None),a.copy(inst=None)).tail.reverse + expandInstancePaths(child,moduleGraphs,moduleDeps,newprefix,subpath) + } else { + Seq(absNodeName(prefix,a)) + } + } + pathNodes.flatten :+ absNodeName(prefix, path.last) + } + + /* + * An SCC may contain more than one loop. In this case, the sequence + * of nodes forming the SCC cannot be interpreted as a simple + * cycle. However, it is desirable to print an error consisting of a + * loop rather than an arbitrary ordering of the SCC. This function + * operates on a pruned subgraph composed only of the SCC and finds + * a simple cycle by performing an arbitrary walk. + */ + private def findCycleInSCC[T](sccGraph: DiGraph[T]): Seq[T] = { + val walk = new mutable.ArrayBuffer[T] + val visited = new mutable.HashSet[T] + var current = sccGraph.getVertices.head + while (!visited.contains(current)) { + walk += current + visited += current + current = sccGraph.getEdges(current).head + } + walk.drop(walk.indexOf(current)).toSeq :+ current + } + + /* + * This implementation of combinational loop detection avoids ever + * generating a full netlist from the FIRRTL circuit. Instead, each + * module is converted to a netlist and analyzed locally, with its + * subinstances represented by trivial, simplified subgraphs. The + * overall outline of the process is: + * + * 1. Create a graph of module instance dependances + + * 2. Linearize this acyclic graph + * + * 3. Generate a local netlist; replace any instances with + * simplified subgraphs representing connectivity of their IOs + * + * 4. Check for nontrivial strongly connected components + * + * 5. Create a reduced representation of the netlist with only the + * module IOs as nodes, where output X (which must be a ground type, + * as only low FIRRTL is supported) will have an edge to input Y if + * and only if it combinationally depends on input Y. Associate this + * reduced graph with the module for future use. + */ + private def run(state: CircuitState): (CircuitState, Errors, ConnectivityGraphs, ConnectivityGraphs) = { + val c = state.circuit + val errors = new Errors() + val extModulePaths = state.annotations.groupBy { + case ann: ExtModulePathAnnotation => ModuleTarget(c.main, ann.source.module) + case ann: Annotation => CircuitTarget(c.main) + } + val moduleMap = c.modules.map({m => (m.name,m) }).toMap + val iGraph = new InstanceGraph(c).graph + val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap + val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) } + val moduleGraphs = new ConnectivityGraphs + val simplifiedModuleGraphs = new ConnectivityGraphs + topoSortedModules.foreach { + case em: ExtModule => + val portSet = em.ports.map(p => LogicNode(p.name)).toSet + val extModuleDeps = new MutableDiGraph[LogicNode] + portSet.foreach(extModuleDeps.addVertex(_)) + extModulePaths.getOrElse(ModuleTarget(c.main, em.name), Nil).collect { + case a: ExtModulePathAnnotation => extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref)) + } + moduleGraphs(em.name) = DiGraph(extModuleDeps).simplify(portSet) + simplifiedModuleGraphs(em.name) = moduleGraphs(em.name) + case m: Module => + val portSet = m.ports.map(p => LogicNode(p.name)).toSet + val internalDeps = new MutableDiGraph[LogicNode] + portSet.foreach(internalDeps.addVertex(_)) + m.foreach(getStmtDeps(simplifiedModuleGraphs, internalDeps)) + val moduleGraph = DiGraph(internalDeps) + moduleGraphs(m.name) = moduleGraph + simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify(portSet) + // Find combinational nodes with self-edges; this is *NOT* the same as length-1 SCCs! + for (unitLoopNode <- moduleGraph.getVertices.filter(v => moduleGraph.getEdges(v).contains(v))) { + errors.append(new CombLoopException(m.info, m.name, Seq(unitLoopNode.name))) + } + for (scc <- moduleGraph.findSCCs.filter(_.length > 1)) { + val sccSubgraph = moduleGraph.subgraph(scc.toSet) + val cycle = findCycleInSCC(sccSubgraph) + (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) }) + val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse) + errors.append(new CombLoopException(m.info, m.name, expandedCycle)) + } + case m => throwInternalError(s"Module ${m.name} has unrecognized type") + } + val mt = ModuleTarget(c.main, c.main) + val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect { case (from, tos) if tos.nonEmpty => + val sink = mt.ref(from.name) + val sources = tos.map(to => mt.ref(to.name)) + CombinationalPath(sink, sources.toSeq) + } + (state.copy(annotations = state.annotations ++ annos), errors, simplifiedModuleGraphs, moduleGraphs) + } + + /** + * Returns a Map from Module name to port connectivity + */ + def analyze(state: CircuitState): collection.Map[String,DiGraph[String]] = { + val (result, errors, connectivity, _) = run(state) + connectivity.map { + case (k, v) => (k, v.transformNodes(ln => ln.name)) + } + } + + /** + * Returns a Map from Module name to complete netlist connectivity + */ + def analyzeFull(state: CircuitState): collection.Map[String,DiGraph[LogicNode]] = { + run(state)._4 + } + + def execute(state: CircuitState): CircuitState = { + val dontRun = state.annotations.contains(DontCheckCombLoopsAnnotation) + if (dontRun) { + logger.warn("Skipping Combinational Loop Detection") + state + } else { + val (result, errors, connectivity, _) = run(state) + errors.trigger() + result + } + } +} diff --git a/sim/midas/src/main/scala/midas/passes/ClockSourceFinder.scala b/sim/midas/src/main/scala/midas/passes/ClockSourceFinder.scala new file mode 100644 index 00000000..f54c9fdd --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/ClockSourceFinder.scala @@ -0,0 +1,65 @@ +// See LICENSE for license details. + +package midas.passes + +import firrtl._ +import firrtl.ir._ +import firrtl.annotations._ +import firrtl.annotations.TargetToken.{OfModule, Instance, Field} + +/** + * Contains exception classes for [[midas.passes.ClockSourceFinder]] + */ +object ClockSourceFinder { + class MultipleDriversException(target: ReferenceTarget, module: String, drivers: Seq[String]) + extends Exception(s"Clock ${target} is driven by multiple signals in module ${module}: ${drivers}", null) +} + +/** A utility for finding the upstream drivers of arbitrary clock signals in a circuit. + * find and return that input port. + * + * @param state the CircuitState to analyze + */ +class ClockSourceFinder(state: CircuitState) { + import ClockSourceFinder._ + + private def inputClockNodes(m: DefModule) = m.ports.collect { + case Port(_, name, Input, ClockType) => LogicNode(name) + } + + private val moduleMap = state.circuit.modules.map(m => m.name -> m).toMap + private val inputClockNodeSets = state.circuit.modules.map(m => m.name -> inputClockNodes(m).toSet).toMap + private lazy val connectivity = new CheckCombLoops().analyzeFull(state) + + /** If a clock signal is directly driven by an input clock port at the top of a multi-module hierarchy, + * find and return that input port. + * + * @param queryTarget the downstream clock to analyze + * @return an option containing a "local" reference to the port in queryTarget.module that drives queryTarget, if any. + * @note The analysis is limited to the hierarchy rooted at the root module of queryTarget + */ + def findRootDriver(queryTarget: ReferenceTarget): Option[ReferenceTarget] = { + require(queryTarget.component.isEmpty) + def getPortDriver(rT: ReferenceTarget): Option[ReferenceTarget] = { + val portOption = rT.component.collectFirst { case Field(f) => f } + val node = portOption.map(p => LogicNode(p, Some(rT.ref))).getOrElse(LogicNode(rT.ref)) + val (inst, module) = rT.path.lastOption.getOrElse((Instance(rT.module), OfModule(rT.module))) + val drivingCone = connectivity(module.value).reachableFrom(node) + val drivingPorts = (drivingCone + node) & inputClockNodeSets(module.value) + if (drivingPorts.size == 0) { + None + } else if (drivingPorts.size > 1) { + throw new MultipleDriversException(queryTarget, module.value, drivingPorts.map(_.name).toSeq) + } else { + val drivingPort = drivingPorts.head.name + if (rT.path.isEmpty) { + Some(rT.copy(ref = drivingPort, component = Nil)) + } else { + val parentRT = rT.copy(path = rT.path.init, ref = inst.value, component = Seq(Field(drivingPort))) + getPortDriver(parentRT) + } + } + } + getPortDriver(queryTarget) + } +} diff --git a/sim/midas/src/main/scala/midas/passes/EnsureNoTargetIO.scala b/sim/midas/src/main/scala/midas/passes/EnsureNoTargetIO.scala index 20c35cb2..a668b4c8 100644 --- a/sim/midas/src/main/scala/midas/passes/EnsureNoTargetIO.scala +++ b/sim/midas/src/main/scala/midas/passes/EnsureNoTargetIO.scala @@ -18,7 +18,9 @@ import java.io.{File, FileWriter, StringWriter} // Ensures that there are no dangling IO on the target. All I/O coming off the DUT must be bound // to an Bridge BlackBox -private[passes] class EnsureNoTargetIO extends firrtl.Transform { +case class TargetMalformedException(message: String) extends RuntimeException(message) + +private[passes] object EnsureNoTargetIO extends firrtl.Transform { def inputForm = HighForm def outputForm = HighForm override def name = "[MIDAS] Ensure No Target IO" @@ -27,14 +29,20 @@ private[passes] class EnsureNoTargetIO extends firrtl.Transform { val topName = state.circuit.main val topModule = state.circuit.modules.find(_.name == topName).get - val nonClockPorts = topModule.ports.filter(_.tpe != ClockType) + val (clockPorts, nonClockPorts) = topModule.ports.partition(_.tpe == ClockType) + + if (!clockPorts.isEmpty) { + val exceptionMessage = "Your target design has the following unexpected clock ports:\n" + + clockPorts.map(_.name).mkString("\n") + + "\nRemove these ports and generate clocks for your simulated system using a ClockBridge." + throw TargetMalformedException(exceptionMessage) + } if (!nonClockPorts.isEmpty) { - val exceptionMessage = """ -Your target design has dangling IO. -You must bind the following top-level ports to an Bridge BlackBox: -""" + nonClockPorts.map(_.name).mkString("\n") - throw new Exception(exceptionMessage) + val exceptionMessage = "Your target design has the following unexpecte IO ports:\n" + + nonClockPorts.map(_.name).mkString("\n") + + "\nRemove these ports and instead bind their sources/sinks to a target-to-host Bridge." + throw TargetMalformedException(exceptionMessage) } state } diff --git a/sim/midas/src/main/scala/midas/passes/ExtractBridges.scala b/sim/midas/src/main/scala/midas/passes/ExtractBridges.scala index 838d5f16..c6fc0910 100644 --- a/sim/midas/src/main/scala/midas/passes/ExtractBridges.scala +++ b/sim/midas/src/main/scala/midas/passes/ExtractBridges.scala @@ -2,7 +2,7 @@ package midas.passes -import midas.widgets.{BridgeAnnotation} +import midas.widgets.{BridgeAnnotation, ClockBridgeAnnotation} import midas.passes.fame.{PromoteSubmodule, PromoteSubmoduleAnnotation, FAMEChannelConnectionAnnotation} import firrtl._ @@ -99,6 +99,14 @@ private[passes] class BridgeExtraction extends firrtl.Transform { topModule.foreach(getBridgeConnectivity(portInstPairs, instList)) val instMap = instList.toMap + val clockBridgeInsts = instList.map(inst => inst._1 -> bridgeAnnoMap(instMap(inst._1))) + .collect({ case (inst, cb: ClockBridgeAnnotation) => inst }) + + val bridgeInstMessage = "You must use a single ClockBridge instance to generate clocks for your simulated system." + assert(clockBridgeInsts.nonEmpty, s"No ClockBridge instances found. ${bridgeInstMessage}") + assert(clockBridgeInsts.size == 1, + s"Multiple ClockBridge instances found: ${clockBridgeInsts.mkString("\n")} ${bridgeInstMessage}") + val ioAnnotations = portInstPairs.flatMap({ case (port, inst) => val updatedBridgeAnno = bridgeAnnoMap(instMap(inst)).toIOAnnotation(port) val updatedFCAAnnos = fcaMap(instMap(inst)).map(_.moveFromBridge(port)) diff --git a/sim/midas/src/main/scala/midas/passes/FindClockSources.scala b/sim/midas/src/main/scala/midas/passes/FindClockSources.scala new file mode 100644 index 00000000..848c6a06 --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/FindClockSources.scala @@ -0,0 +1,36 @@ +// See LICENSE for license details. + +package midas.passes + +import firrtl._ +import firrtl.annotations._ + +import midas.passes.fame.RTRenamer + +case class FindClockSourceAnnotation( + target: ReferenceTarget, + originalTarget: Option[ReferenceTarget] = None) extends Annotation { + require(target.module == target.circuit, s"Queried leaf clock ${target} must provide an absolute instance path") + def update(renames: RenameMap): Seq[FindClockSourceAnnotation] = + Seq(this.copy(RTRenamer.exact(renames)(target), originalTarget.orElse(Some(target)))) +} + +case class ClockSourceAnnotation(queryTarget: ReferenceTarget, source: Option[ReferenceTarget]) extends Annotation { + def update(renames: RenameMap): Seq[ClockSourceAnnotation] = + Seq(this.copy(queryTarget, source.map(s => RTRenamer.exact(renames)(s)))) +} + +object FindClockSources extends firrtl.Transform { + def inputForm = LowForm + def outputForm = LowForm + + def execute(state: CircuitState): CircuitState = { + val queryAnnotations = state.annotations.collect({ case anno: FindClockSourceAnnotation => anno }) + val sourceFinder = new ClockSourceFinder(state) + val sourceMappings = queryAnnotations.map(qA => qA.target -> sourceFinder.findRootDriver(qA.target)).toMap + val clockSourceAnnotations = queryAnnotations.map(qAnno => + ClockSourceAnnotation(qAnno.originalTarget.getOrElse(qAnno.target), sourceMappings(qAnno.target))) + val prunedAnnos = state.annotations.filterNot(queryAnnotations.toSet) + state.copy(annotations = clockSourceAnnotations ++ prunedAnnos) + } +} diff --git a/sim/midas/src/main/scala/midas/passes/MidasTransforms.scala b/sim/midas/src/main/scala/midas/passes/MidasTransforms.scala index 5ffa4197..c412a886 100644 --- a/sim/midas/src/main/scala/midas/passes/MidasTransforms.scala +++ b/sim/midas/src/main/scala/midas/passes/MidasTransforms.scala @@ -14,6 +14,7 @@ import firrtl.ir._ import logger._ import firrtl.Mappers._ import firrtl.transforms.{DedupModules, DeadCodeElimination} +import firrtl.passes.wiring.{WiringTransform} import Utils._ import java.io.{File, FileWriter} @@ -44,18 +45,29 @@ private[midas] class MidasTransforms(implicit p: Parameters) extends Transform { firrtl.passes.SplitExpressions, firrtl.passes.CommonSubexpressionElimination, new firrtl.transforms.DeadCodeElimination, - new EnsureNoTargetIO, - // NB: Carelessly removing this pass will break the FireSim manager as we always - // need to generate the *.asserts file. Fix by baking into driver. + EnsureNoTargetIO, + new BridgeExtraction, + new ResolveAndCheck, + new EmitFirrtl("post-bridge-extraction.fir"), + new HighFirrtlToMiddleFirrtl, + new MiddleFirrtlToLowFirrtl, + new AutoCounterTransform, + new EmitFirrtl("post-autocounter.fir"), + new fame.EmitFAMEAnnotations("post-autocounter.json"), + new ResolveAndCheck, new AssertPass(dir), new PrintSynthesis(dir), new ResolveAndCheck, - new HighFirrtlToMiddleFirrtl, - new MiddleFirrtlToLowFirrtl, - new BridgeExtraction, - new ResolveAndCheck, - new MiddleFirrtlToLowFirrtl, - new fame.WrapTop, + new EmitFirrtl("post-debug-synthesis.fir"), + new fame.EmitFAMEAnnotations("post-debug-synthesis.json"), + // All trigger sources and sinks must exist in the target RTL before this pass runs + TriggerWiring, + new EmitFirrtl("post-trigger-wiring.fir"), + new fame.EmitFAMEAnnotations("post-trigger-wiring.json"), + // We should consider moving these lower + ChannelClockInfoAnalysis, + UpdateBridgeClockInfo, + fame.WrapTop, new ResolveAndCheck, new EmitFirrtl("post-wrap-top.fir")) ++ optionalTargetTransforms ++ @@ -66,16 +78,22 @@ private[midas] class MidasTransforms(implicit p: Parameters) extends Transform { new HighFirrtlToMiddleFirrtl, new MiddleFirrtlToLowFirrtl, new fame.FAMEDefaults, + new EmitFirrtl("post-fame-defaults.fir"), + new fame.EmitFAMEAnnotations("post-fame-defaults.json"), + fame.FindDefaultClocks, + new fame.EmitFAMEAnnotations("post-find-default-clocks.json"), new fame.ChannelExcision, - new fame.InferModelPorts, - new ResolveAndCheck, + new fame.EmitFAMEAnnotations("post-channel-excision.json"), new EmitFirrtl("post-channel-excision.fir"), + new fame.InferModelPorts, + new fame.EmitFAMEAnnotations("post-infer-model-ports.json"), new fame.FAMETransform, DefineAbstractClockGate, new EmitFirrtl("post-fame-transform.fir"), + new fame.EmitFAMEAnnotations("post-fame-transform.json"), + new ResolveAndCheck, new ResolveAndCheck, new fame.EmitAndWrapRAMModels, - ConnectHostClock, new EmitFirrtl("post-gen-sram-models.fir"), new ResolveAndCheck) ++ Seq( diff --git a/sim/midas/src/main/scala/midas/passes/PrintSynthesis.scala b/sim/midas/src/main/scala/midas/passes/PrintSynthesis.scala index 71d3f5b1..9745678f 100644 --- a/sim/midas/src/main/scala/midas/passes/PrintSynthesis.scala +++ b/sim/midas/src/main/scala/midas/passes/PrintSynthesis.scala @@ -24,16 +24,16 @@ import midas.targetutils.SynthPrintfAnnotation private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends firrtl.Transform { def inputForm = MidForm def outputForm = MidForm - override def name = "[MIDAS] Print Synthesis" + override def name = "[Golden Gate] Print Synthesis" private val printMods = new mutable.HashSet[ModuleTarget]() private val formatStringMap = new mutable.HashMap[ReferenceTarget, String]() + val topWiringPrefix = "synthesizedPrintf_" - // Generates a bundle to aggregate + // Generates a bundle containing a print's clock, enable, and argument fields def genPrintBundleType(print: Print): Type = BundleType(Seq( Field("enable", Default, BoolType)) ++ - print.args.zipWithIndex.map({ case (arg, idx) => Field(s"args_${idx}", Default, arg.tpe) }) - ) + print.args.zipWithIndex.map({ case (arg, idx) => Field(s"args_${idx}", Default, arg.tpe) })) def getPrintName(p: Print, anno: SynthPrintfAnnotation, ns: Namespace): String = { // If the user provided a name in the annotation use it; otherwise use the source locator @@ -44,45 +44,31 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends ns.newName(candidateName) } - // Hacky: Instead of generated output files, instead sneak out the mappings from the TopWiring - // transform. - type TopWiringSink = ((ComponentName, Type, Boolean, Seq[String], String), Int) - var topLevelOutputs = Seq[TopWiringSink]() - def wiringAnnoOutputFunc(td: String, - mappings: Seq[TopWiringSink], - state: CircuitState): CircuitState = { - topLevelOutputs = mappings - state - } - // Takes a single printPort and emits an FCCA for each field - def genFCCAsFromPort(mT: ModuleTarget, p: Port): Seq[FAMEChannelConnectionAnnotation] = { + def genFCCAsFromPort(p: Port, portRT: ReferenceTarget, clockRT: ReferenceTarget): Seq[FAMEChannelConnectionAnnotation] = { p.tpe match { - case BundleType(fields) => - fields.map(field => - FAMEChannelConnectionAnnotation( - p.name + "_" + field.name, + case BundleType(dataFields) => + dataFields.map(field => + FAMEChannelConnectionAnnotation.source( + portRT.ref + "_" + field.name, WireChannel, - sources = Some(Seq(mT.ref(p.name).field(field.name))), - sinks = None + clock = Some(clockRT), + Seq(portRT.field(field.name)) ) ) - case other => Seq() + case other => ??? } } def synthesizePrints(state: CircuitState, printfAnnos: Seq[SynthPrintfAnnotation]): CircuitState = { - require(state.annotations.collect({ case t: TopWiringAnnotation => t }).isEmpty, - "CircuitState cannot have existing TopWiring annotations before PrintSynthesis.") val c = state.circuit + def mTarget(m: Module): ModuleTarget = ModuleTarget(c.main, m.name) + def portRT(p: Port): ReferenceTarget = ModuleTarget(c.main, c.main).ref(p.name) + def portClockRT(p: Port): ReferenceTarget = portRT(p).field("clock") val modToAnnos = printfAnnos.groupBy(_.mod) - - val topWiringAnnos = mutable.ArrayBuffer[Annotation]( - TopWiringOutputFilesAnnotation("unused", wiringAnnoOutputFunc)) - - val topWiringPrefix = "synthesizedPrintf_" + val topWiringAnnos = mutable.ArrayBuffer[Annotation]() def onModule(m: DefModule): DefModule = m match { case m: Module if printMods(mTarget(m)) => @@ -92,63 +78,62 @@ private[passes] class PrintSynthesis(dir: File)(implicit p: Parameters) extends def onStmt(annos: Seq[SynthPrintfAnnotation], modNamespace: Namespace) (s: Statement): Statement = s.map(onStmt(annos, modNamespace)) match { - case p @ Print(_,format,args,_,en) if annos.exists(_.format == format.string) => + case p @ Print(_,format,args,clk ,en) if annos.exists(_.format == format.string) => val associatedAnno = annos.find(_.format == format.string).get val printName = getPrintName(p, associatedAnno, modNamespace) // Generate an aggregate with all of our arguments; this will be wired out val wire = DefWire(NoInfo, printName, genPrintBundleType(p)) - val enableConnect = Connect(NoInfo, wsub(WRef(wire), s"enable"), en) + val enableConnect = Connect(NoInfo, wsub(WRef(wire), "enable"), en) val argumentConnects = (p.args.zipWithIndex).map({ case (arg, idx) => Connect(NoInfo, wsub(WRef(wire), s"args_${idx}"), arg)}) val printBundleTarget = associatedAnno.mod.ref(printName) - topWiringAnnos += TopWiringAnnotation(printBundleTarget, topWiringPrefix) + val clockTarget = clk match { + case WRef(name,_,_,_) => associatedAnno.mod.ref(name) + case o => ??? + } + topWiringAnnos += BridgeTopWiringAnnotation(printBundleTarget, clockTarget) formatStringMap(printBundleTarget) = format.serialize Block(Seq(p, wire, enableConnect) ++ argumentConnects) case s => s } - + // Step 1: Find and replace printfs with stubs val processedCircuit = c.map(onModule) - val wiredState = (new TopWiringTransform).execute(state.copy( + + // Step 2: Wire out print stubs to top level module + val wiredState = (new BridgeTopWiring(topWiringPrefix)).execute(state.copy( circuit = processedCircuit, annotations = state.annotations ++ topWiringAnnos)) - val topModule = wiredState.circuit.modules.find(_.name == wiredState.circuit.main).get - val portMap: Map[String, Port] = topModule.ports.map(port => port.name -> port).toMap - val addedPrintPorts = topLevelOutputs.map({ case ((cname,_,_,path,prefix),_) => - // Look up the format string by regenerating the referenceTarget to the original print bundle - val formatString = formatStringMap(cname) - val port = portMap(prefix + path.mkString("_")) - (port, formatString) - }) + // Step 3: Group top-wired ports by their associated clock + val outputAnnos = wiredState.annotations.collect({ case a: BridgeTopWiringOutputAnnotation => a }) + val groupedPrints = outputAnnos.groupBy(_.sinkClockPort) - println(s"[MIDAS] total # of prints synthesized: ${addedPrintPorts.size}") + println(s"[Golden Gate] total # of printf instances synthesized: ${outputAnnos.size}") - val printRecordAnno = addedPrintPorts match { - case Nil => Seq() - case ports => { - // TODO: Generate sensible channel annotations once we can aggregate wire channels - val portName = topWiringPrefix.stripSuffix("_") - val mT = ModuleTarget(c.main, c.main) - val portRT = mT.ref(portName) + // Step 4: Generate FCCAs and Bridge Annotations for each clock domain + val topModule = wiredState.circuit.modules.find(_.name == c.main).get + val portMap = topModule.ports.map(p => portRT(p) -> p).toMap - val fccaAnnos = ports.flatMap({ case (port, _) => genFCCAsFromPort(mT, port) }) - val bridgeAnno = BridgeIOAnnotation( - target = portRT, - widget = (p: Parameters) => new PrintBridgeModule(addedPrintPorts)(p), - channelNames = fccaAnnos.map(_.globalName) - ) - bridgeAnno +: fccaAnnos - } + val printRecordAnnos = for ((clockRT, oAnnos) <- groupedPrints.toSeq.sortBy(_._1.ref)) yield { + val fccaAnnos = oAnnos.flatMap({ case BridgeTopWiringOutputAnnotation(_,_,oPortRT,_,oClockRT) => + genFCCAsFromPort(portMap(oPortRT), oPortRT, oClockRT) }) + + val portTuples = oAnnos.map({ case BridgeTopWiringOutputAnnotation(srcRT,_,oPortRT,_,_) => + portMap(oPortRT) -> formatStringMap(srcRT) }) + + val bridgeAnno = BridgeIOAnnotation( + target = ModuleTarget(c.main, c.main).ref(topWiringPrefix.stripSuffix("_")), + widget = (p: Parameters) => new PrintBridgeModule(portTuples)(p), + channelNames = fccaAnnos.map(_.globalName) + ) + bridgeAnno +: fccaAnnos } - // Remove added TopWiringAnnotations to prevent being reconsumed by a downstream pass - val cleanedAnnotations = wiredState.annotations.flatMap({ - case TopWiringAnnotation(_,_) => None - case otherAnno => Some(otherAnno) - }) - wiredState.copy(annotations = cleanedAnnotations ++ printRecordAnno) + // Remove added Annotations to prevent being reconsumed by a downstream pass + val cleanedAnnotations = wiredState.annotations.filterNot(outputAnnos.toSet) + wiredState.copy(annotations = cleanedAnnotations ++ printRecordAnnos.toSeq.flatten) } def execute(state: CircuitState): CircuitState = { diff --git a/sim/midas/src/main/scala/midas/passes/SimulationMapping.scala b/sim/midas/src/main/scala/midas/passes/SimulationMapping.scala index 185c50fb..8ae45195 100644 --- a/sim/midas/src/main/scala/midas/passes/SimulationMapping.scala +++ b/sim/midas/src/main/scala/midas/passes/SimulationMapping.scala @@ -14,7 +14,7 @@ import firrtl.Mappers._ import firrtl.passes.LowerTypes.loweredName import firrtl.Utils.{BoolType, splitRef, mergeRef, create_exps, gender, module_type} import firrtl.passes.wiring._ -import fame.{FAMEChannelConnectionAnnotation, FAMEChannelAnalysis, FAME1Transform} +import fame.{FAMEChannelConnectionAnnotation, FAMEChannelPortsAnnotation, FAMEChannelAnalysis, FAME1Transform} import Utils._ import freechips.rocketchip.config.Parameters @@ -128,7 +128,6 @@ private[passes] class SimulationMapping(targetName: String)(implicit val p: Para val transforms = Seq( new Fame1Instances, - new WiringTransform, new PreLinkRenaming(Namespace(innerCircuit))) val outerState = new LowFirrtlCompiler().compile(CircuitState(chirrtl, ChirrtlForm, annos), transforms) @@ -144,7 +143,7 @@ private[passes] class SimulationMapping(targetName: String)(implicit val p: Para // FIXME: Renamer complains if i leave these in val innerAnnos = loweredInnerState.annotations.filter(_ match { - case _: FAMEChannelConnectionAnnotation => false + case _: FAMEChannelConnectionAnnotation | _: FAMEChannelPortsAnnotation => false case _: BridgeIOAnnotation => false case _ => true }) diff --git a/sim/midas/src/main/scala/midas/passes/TargetClockAnalysis.scala b/sim/midas/src/main/scala/midas/passes/TargetClockAnalysis.scala new file mode 100644 index 00000000..4ed521f9 --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/TargetClockAnalysis.scala @@ -0,0 +1,49 @@ +// See LICENSE for license details. + +package midas.passes + +import midas.passes.fame.{FAMEChannelConnectionAnnotation, TargetClockChannel} +import midas.widgets.{RationalClock} + +import firrtl._ +import firrtl.annotations._ + + +/** + * [[ChannelClockInfoAnalysis]]'s output annotation. Maps channel global name + * (See [[FAMEChannelConnectionAnnotation]] to a clock info class. + */ +case class ChannelClockInfoAnnotation(infoMap: Map[String, RationalClock]) extends NoTargetAnnotation + +/** + * Returns a map from a channel's global name to a RationalClock case class which + * contains metadata about the target clock including its name and relative + * frequency to the base clock + * + */ +object ChannelClockInfoAnalysis extends Transform { + def inputForm = LowForm + def outputForm = LowForm + def analyze(state: CircuitState): Map[String, RationalClock] = { + val clockChannels = state.annotations.collect { + case FAMEChannelConnectionAnnotation(_,TargetClockChannel(clocks),_,_,Some(clockRTs)) => + clockRTs zip clocks + } + require(clockChannels.size == 1, + s"Expected exactly one clock channel annotation. Got: ${clockChannels.size}") + val sourceInfoMap = clockChannels.head.toMap + + // This relies on the assumption that the clock channel will not have its clock field set + val channelClocks = state.annotations.collect({ + case FAMEChannelConnectionAnnotation(name,_,Some(clock),_,_) => (name, clock) + }).toMap + + val finder = new ClockSourceFinder(state) + val clockSourceMap = channelClocks.map({ case (k, v) => v -> finder.findRootDriver(v) }).toMap + channelClocks.mapValues(sinkClock => sourceInfoMap(clockSourceMap(sinkClock).get)) + } + def execute(state: CircuitState): CircuitState = { + val infoAnno = ChannelClockInfoAnnotation(analyze(state)) + state.copy(annotations = infoAnno +: state.annotations) + } +} diff --git a/sim/midas/src/main/scala/midas/passes/TriggerWiring.scala b/sim/midas/src/main/scala/midas/passes/TriggerWiring.scala new file mode 100644 index 00000000..1f6ff51c --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/TriggerWiring.scala @@ -0,0 +1,255 @@ +//See LICENSE for license details. + +package midas.passes + +import midas.targetutils.{TriggerSourceAnnotation, TriggerSinkAnnotation} +import midas.passes.fame._ + +import freechips.rocketchip.util.DensePrefixSum +import firrtl._ +import firrtl.annotations._ +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.passes.wiring.{SinkAnnotation, SourceAnnotation, WiringTransform} +import firrtl.Utils.{zero, BoolType} + +import scala.collection.mutable + +/* + * Implements Golden Gate's trigger system by emitting target-hardware to aggregate + * all trigger source signals into trigger enables. + * + * Refer to the FireSim Docs for more detail, and for a schematic of the generated HW. + */ +private[passes] object TriggerWiring extends firrtl.Transform { + def inputForm = LowForm + def outputForm = HighForm + override def name = "[Golden Gate] Trigger Wiring" + val topWiringPrefix = "simulationTrigger_" + val sinkWiringKey = "trigger_sink" + + // Defines the width of credit and debit counters local to a specific clock domain + // For the trigger to function correctly: + // localWidth >= log2Ceil(max(localCreditSources, localDebitSources) * Ceil(N/M)) + // where: + // local{Credit,Debit}Sources are the number of the associated source in the local clock domain, + // the local clock frequency is (N/M) times the base frequency + val localCType = UIntType(IntWidth(16)) + // Defines the type of the global counter in the base clock domain. This + // just needs to be large enough to represent the largest number of credits + // the system will produce. Since there are only two of these, make it large + // to be safe. + val globalCType = UIntType(IntWidth(32)) + + // Masks off trigger sources when the are under reset. + private def gateEventsWithReset(sourceModuleMap: Map[String, Seq[TriggerSourceAnnotation]], + updatedAnnos: mutable.ArrayBuffer[TriggerSourceAnnotation]) + (mod: DefModule): DefModule = mod match { + case m: Module if sourceModuleMap.isDefinedAt(m.name) => + val annos = sourceModuleMap(m.name) + val mT = annos.head.enclosingModuleTarget + val moduleNS = Namespace(mod) + val addedStmts = annos.flatMap({ anno => + if (anno.reset.nonEmpty) { + val eventName = moduleNS.newName(anno.target.ref + "_masked") + updatedAnnos += anno.copy(target = mT.ref(eventName)) + Seq(DefNode(NoInfo, eventName, And(Negate(WRef(anno.reset.get.ref)), WRef(anno.target.ref)))) + } else { + updatedAnnos += anno + Nil + } + }) + m.copy(body = Block(m.body, addedStmts:_*)) + case o => o + } + + // Generates the sink-side hardware. See onStmtSink + private def onModuleSink(sinkAnnoModuleMap: Map[String, Seq[TriggerSinkAnnotation]], + addedAnnos: mutable.ArrayBuffer[Annotation]) + (m: DefModule): DefModule = m match { + case m: Module if sinkAnnoModuleMap.isDefinedAt(m.name) => + val sinkNameMap = sinkAnnoModuleMap(m.name).map(anno => anno.target.ref -> anno).toMap + val ns = Namespace(m) + m.map(onStmtSink(sinkNameMap, addedAnnos, ns)) + case o => o + } + + /** + * For each TriggerSink: + * 1) Emit a register that will synchronize the trigger signal to the to local domain (from the base one) + * 2) Emit a wiring annotation pointing at that register. + */ + private def onStmtSink(sinkAnnos: Map[String, TriggerSinkAnnotation], + addedAnnos: mutable.ArrayBuffer[Annotation], + ns: Namespace) + (s: Statement): Statement = s.map(onStmtSink(sinkAnnos, addedAnnos, ns)) match { + case node@DefNode(_,name,_) if sinkAnnos.isDefinedAt(name) => + val sinkAnno = sinkAnnos(name) + val mT = sinkAnno.enclosingModuleTarget + val triggerSyncName = ns.newName("trigger_sync") + val triggerSync = RegZeroPreset(NoInfo, triggerSyncName, BoolType, WRef(sinkAnno.clock.ref)) + addedAnnos += SinkAnnotation(mT.ref(triggerSyncName).toNamed, sinkWiringKey) + Block(triggerSync, node.copy(value = WRef(triggerSync))) + // Implement the missing cases? + case r@DefRegister(_,name,_,_,_,_) if sinkAnnos.isDefinedAt(name) => ??? + case s => s + } + + def execute(state: CircuitState): CircuitState = { + + val topModName = state.circuit.main + val topMod = state.circuit.modules.find(_.name == topModName).get + val prexistingPorts = topMod.ports + // 1) Collect Trigger Annotations, and generate BridgeTopWiring annotations + val srcCreditAnnos = new mutable.ArrayBuffer[TriggerSourceAnnotation]() + val srcDebitAnnos = new mutable.ArrayBuffer[TriggerSourceAnnotation]() + val sinkAnnos = new mutable.ArrayBuffer[TriggerSinkAnnotation]() + state.annotations.collect({ + case a: TriggerSourceAnnotation if a.sourceType => srcCreditAnnos += a + case a: TriggerSourceAnnotation => srcDebitAnnos += a + case a: TriggerSinkAnnotation => sinkAnnos += a + }) + + require(!(srcCreditAnnos.isEmpty && srcDebitAnnos.nonEmpty), "Provided trigger debit sources but no credit sources") + // It may make sense to relax this in the future + // This would enable trigger without posibility of disabling it in the future + require(!(srcDebitAnnos.isEmpty && srcCreditAnnos.nonEmpty), "Provided trigger credit sources but no debit sources") + + val updatedState = if ((srcCreditAnnos.isEmpty && srcDebitAnnos.isEmpty) || sinkAnnos.isEmpty) { + state + } else { + // Step 1) Gate credits and debits with their associated reset, if provided + val updatedAnnos = new mutable.ArrayBuffer[TriggerSourceAnnotation]() + val srcAnnoMap = (srcCreditAnnos ++ srcDebitAnnos).groupBy(_.enclosingModule) + val gatedCircuit = state.circuit.map(gateEventsWithReset(srcAnnoMap, updatedAnnos)) + val (gatedCredits, gatedDebits) = updatedAnnos.partition(_.sourceType) + + // Step 2) Use bridge topWiring to generate inter-module connectivity -- but drop the port list + val bridgeTopWiringAnnos = updatedAnnos.map(anno => BridgeTopWiringAnnotation(anno.target, anno.clock)) + val wiredState = (new BridgeTopWiring(topWiringPrefix)).execute(state.copy( + circuit = gatedCircuit, annotations = state.annotations ++ bridgeTopWiringAnnos)) + val wiredTopModule = wiredState.circuit.modules.collectFirst({ + case m@Module(_,name,_,_) if name == topModName => m + }).get + val otherModules = wiredState.circuit.modules.filter(_.name != topModName) + val addedPorts = wiredTopModule.ports.filterNot(p => prexistingPorts.contains(p)) + + // Step 3: Group top-wired outputs by their associated clock + val outputAnnos = wiredState.annotations.collect({ case a: BridgeTopWiringOutputAnnotation => a }) + val groupedTriggers = outputAnnos.groupBy(_.srcClockPort) + + // Step 4: Convert port assignments to wire assignments + val portName2WireMap = addedPorts.map(p => p.name -> DefWire(NoInfo, p.name, p.tpe)).toMap + def updateAssignments(stmt: Statement): Statement = stmt.map(updateAssignments) match { + case c@Connect(_, WRef(name,_,_,_), _) if portName2WireMap.isDefinedAt(name) => + val defWire = portName2WireMap(name) + Block(defWire, c.copy(loc = WRef(defWire))) + case o => o + } + + val portRemovedBody = wiredTopModule.body.map(updateAssignments) + + // 5) Per-clock-domain: count local credits and debits + val ns = Namespace(wiredTopModule) + val addedStmts = new mutable.ArrayBuffer[Statement]() + + def addReduce(bools: Seq[WRef]): WRef = DensePrefixSum(bools)({ case (a, b) => + val name = ns.newTemp + val node = DefNode(NoInfo, name, DoPrim(PrimOps.Add, Seq(a, b), Seq.empty, UnknownType)) + addedStmts += node + WRef(node) + }).last + + def counter(name: String, tpe: UIntType, clock: WRef, incr: WRef): (DefRegister, DefNode) = { + val countName = ns.newName(name) + val count = RegZeroPreset(NoInfo, countName, tpe, clock) + val nextName = ns.newName(countName + "_next") + val next = DefNode(NoInfo, nextName, DoPrim(PrimOps.Add, Seq(WRef(count), incr), Seq.empty, tpe)) + val countUpdate = Connect(NoInfo, WRef(count), WRef(next)) + addedStmts ++= Seq(count, next, countUpdate) + (count, next) + } + + // Add-reduces a set of UInt signals, before adding it to a running count, + // returning a reference to value the counter will take on in the next cycle. + def doAccounting(counterType: UIntType, clock: WRef)(name: String, bools: Seq[WRef]): WRef = + WRef(counter(name, counterType, clock, addReduce(bools))._2) + + // In each clock domain, add up all of the credits and debits + val (localCredits, localDebits) = (for ((clockRT, oAnnos) <- groupedTriggers) yield { + val credits = oAnnos.collect { + case a if gatedCredits.exists(_.target == a.pathlessSource) => WRef(portName2WireMap(a.topSink.ref)) + } + val debits = oAnnos.collect { + case a if gatedDebits.exists(_.target == a.pathlessSource) => WRef(portName2WireMap(a.topSink.ref)) + } + def doLocalAccounting = doAccounting(localCType, WRef(clockRT.ref)) _ + + val domainName = clockRT.ref + (doLocalAccounting(s"${domainName}_credits", credits), doLocalAccounting(s"${domainName}_debits", debits)) + }).unzip + + // Step 6) Synchronize and aggregate local counts into global counts in the base clock domain + val refClockRT = wiredState.annotations.collectFirst({ + case FAMEChannelConnectionAnnotation(_,TargetClockChannel(_),_,_,Some(clock :: _)) => clock + }).get + + // We only need to use a single register to synchronize a signal in GG, we use two here + // to measure how much the local count has changed between base clock cycles (hence, Diff). + def syncAndDiff(next: WRef): WRef = { + val name = next.name + val syncNameS1 = ns.newName(s"${name}_count_sync_s1") + val syncS1 = RegZeroPreset(NoInfo, syncNameS1, localCType, WRef(refClockRT.ref)) + val syncNameS2 = ns.newName(s"${name}_count_sync_s2") + val syncS2 = RegZeroPreset(NoInfo, syncNameS2, localCType, WRef(refClockRT.ref)) + val diffName = ns.newName(s"${name}_diff") + val diffNode = DefNode(NoInfo, diffName, DoPrim(PrimOps.Sub, Seq(WRef(syncS1), WRef(syncS2)), Seq.empty, localCType)) + addedStmts ++= Seq( + syncS1, syncS2, diffNode, + Connect(NoInfo, WRef(syncS1), next), + Connect(NoInfo, WRef(syncS2), WRef(syncS1)) + ) + WRef(diffNode) + } + val creditUpdates = localCredits.map(syncAndDiff).toSeq + val debitUpdates = localDebits.map(syncAndDiff).toSeq + + // Add together the changes in local debits and credits, and apply them + // to the global count + def doGlobalAccounting = doAccounting(globalCType, WRef(refClockRT.ref)) _ + val totalCredit = doGlobalAccounting("totalCredits", creditUpdates) + val totalDebit = doGlobalAccounting("totalDebits", debitUpdates) + + // Step 7) Generate the trigger enable, and prep all sinks for Wiring by + // adding a synchronization register + val triggerName = ns.newName("trigger_source") + val triggerSource = DefNode(NoInfo, triggerName, Neq(totalCredit, totalDebit)) + val triggerSourceRT = ModuleTarget(topModName, topModName).ref(triggerName) + addedStmts += triggerSource + val topModWithTrigger = wiredTopModule.copy(ports = prexistingPorts, body = Block(portRemovedBody, addedStmts:_*)) + val updatedCircuit = wiredState.circuit.copy(modules = topModWithTrigger +: otherModules) + + // Step 8) Wire the generated trigger to all sinks using the WiringTranform + val sinkModuleMap = sinkAnnos.groupBy(_.target.module) + val wiringAnnos = new mutable.ArrayBuffer[Annotation] + wiringAnnos += SourceAnnotation(triggerSourceRT.toNamed, sinkWiringKey) + val preSinkWiringCircuit = updatedCircuit.map(onModuleSink(sinkModuleMap, wiringAnnos)) + val preSinkWiringState = CircuitState(preSinkWiringCircuit, HighForm, wiredState.annotations ++ wiringAnnos) + + val sinkWiringTransforms = Seq( + new ResolveAndCheck, + new HighFirrtlToMiddleFirrtl, + new WiringTransform, + new ResolveAndCheck) + sinkWiringTransforms.foldLeft(preSinkWiringState)((in, xform) => xform.runTransform(in)) + } + + val cleanedAnnos = updatedState.annotations.flatMap({ + case a: TriggerSourceAnnotation => None + case a: TriggerSinkAnnotation => None + case a: BridgeTopWiringOutputAnnotation => None + case o => Some(o) + }) + updatedState.copy(annotations = cleanedAnnos) + } +} diff --git a/sim/midas/src/main/scala/midas/passes/UpdateBridgeClockInfo.scala b/sim/midas/src/main/scala/midas/passes/UpdateBridgeClockInfo.scala new file mode 100644 index 00000000..ddad8928 --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/UpdateBridgeClockInfo.scala @@ -0,0 +1,33 @@ +// See LICENSE for license details. + +package midas.passes + +import midas.widgets.{RationalClock, BridgeIOAnnotation} + +import firrtl._ +import firrtl.annotations._ + +/** + * Determines which clock each bridge is synchronous with, and updates that bridge's IO annotation + * to include it's domain clock info. + * + */ +object UpdateBridgeClockInfo extends Transform { + def inputForm = LowForm + def outputForm = LowForm + def execute(state: CircuitState): CircuitState = { + val infoMaps = state.annotations.collect { + case ChannelClockInfoAnnotation(map) => map + } + require(infoMaps.size == 1, + s"Expected exactly one ChannelClockInfoAnnotation. Got: ${infoMaps.size}") + val infoMap = infoMaps.head + val annosx = state.annotations.map({ + case a@BridgeIOAnnotation(_,_,None,_,_,_) => + // There will be some cases where this is left unpopulated, i.e., for the clockBridge + a.copy(clockInfo = infoMap.get(a.channelMapping.values.head)) + case o => o + }) + state.copy(annotations = annosx) + } +} diff --git a/sim/midas/src/main/scala/midas/passes/fame/Annotations.scala b/sim/midas/src/main/scala/midas/passes/fame/Annotations.scala index 1b6d9d28..3caee7ef 100644 --- a/sim/midas/src/main/scala/midas/passes/fame/Annotations.scala +++ b/sim/midas/src/main/scala/midas/passes/fame/Annotations.scala @@ -3,22 +3,25 @@ package midas.passes.fame import firrtl._ import annotations._ +import midas.targetutils.FAMEAnnotation +import midas.widgets.RationalClock + /** * An annotation that describes the ports that constitute one channel * from the perspective of a particular module that will be replaced * by a simulation model. Note that this describes the channels as * they appear locally from within the module, so this annotation will * apply to *all* instances of that module. - * + * * Upon creation, this annotation is associated with a particular * target RTL module M that will eventually be transformed into a FAME * model. This module must only be instantiated at the top level. - * + * * @param localName refers to the name of the channel within the scope of the * eventual FAME model. This will be used as the channel’s port * name in the model. It will also be used to identify * microarchitectural state associated with the channel - * + * * @param ports a list of the ports that are grouped into the channel. The * ReferenceTargets should be rooted at M, since this information is * local to the module. This is also what associates the annotation @@ -26,12 +29,13 @@ import annotations._ */ case class FAMEChannelPortsAnnotation( localName: String, - ports: Seq[ReferenceTarget]) extends Annotation { + clockPort: Option[ReferenceTarget], + ports: Seq[ReferenceTarget]) extends Annotation with FAMEAnnotation { def update(renames: RenameMap): Seq[Annotation] = { val renamer = RTRenamer.exact(renames) - Seq(FAMEChannelPortsAnnotation(localName, ports.map(renamer))) + Seq(FAMEChannelPortsAnnotation(localName, clockPort.map(renamer), ports.map(renamer))) } - override def getTargets: Seq[ReferenceTarget] = ports + override def getTargets: Seq[ReferenceTarget] = clockPort ++: ports } /** @@ -42,15 +46,20 @@ case class FAMEChannelPortsAnnotation( * * @param channelInfo describes the type of the channel (Wire, Forward/Reverse * Decoupled) + * + * @param clock the *source* of the clock (if any) associated with this channel + * + * @note The clock source must be a port on the model side of the channel */ case class FAMEChannelConnectionAnnotation( globalName: String, channelInfo: FAMEChannelInfo, + clock: Option[ReferenceTarget], sources: Option[Seq[ReferenceTarget]], - sinks: Option[Seq[ReferenceTarget]]) extends Annotation with HasSerializationHints { + sinks: Option[Seq[ReferenceTarget]]) extends Annotation with FAMEAnnotation with HasSerializationHints { def update(renames: RenameMap): Seq[Annotation] = { val renamer = RTRenamer.exact(renames) - Seq(FAMEChannelConnectionAnnotation(globalName, channelInfo.update(renames), sources.map(_.map(renamer)), sinks.map(_.map(renamer)))) + Seq(FAMEChannelConnectionAnnotation(globalName, channelInfo.update(renames), clock.map(renamer), sources.map(_.map(renamer)), sinks.map(_.map(renamer)))) } def typeHints(): Seq[Class[_]] = Seq(channelInfo.getClass) @@ -60,7 +69,7 @@ case class FAMEChannelConnectionAnnotation( def updateRT(rT: ReferenceTarget): ReferenceTarget = ModuleTarget(rT.circuit, rT.circuit).ref(portName).field(rT.ref) require(sources == None || sinks == None, "Bridge-connected channels cannot loopback") - val rTs = sources.getOrElse(sinks.get) ++ (channelInfo match { + val rTs = sources.getOrElse(sinks.get) ++ clock ++ (channelInfo match { case i: DecoupledForwardChannel => Seq(i.readySink.getOrElse(i.readySource.get)) case other => Seq() }) @@ -69,22 +78,41 @@ case class FAMEChannelConnectionAnnotation( copy(globalName = s"${portName}_${globalName}").update(localRenames).head.asInstanceOf[this.type] } - override def getTargets: Seq[ReferenceTarget] = sources.toSeq.flatten ++ sinks.toSeq.flatten + override def getTargets: Seq[ReferenceTarget] = clock ++: (sources.toSeq.flatten ++ sinks.toSeq.flatten) } -// Helper factory methods for generating bridge annotations that have only sinks or sources +// Helper factory methods for generating common patterns object FAMEChannelConnectionAnnotation { + def implicitlyClockedLoopback( + globalName: String, + channelInfo: FAMEChannelInfo, + sources: Seq[ReferenceTarget], + sinks: Seq[ReferenceTarget]): FAMEChannelConnectionAnnotation = + FAMEChannelConnectionAnnotation(globalName, channelInfo, None, Some(sources), Some(sinks)) + def sink( globalName: String, channelInfo: FAMEChannelInfo, + clock: Option[ReferenceTarget], sinks: Seq[ReferenceTarget]): FAMEChannelConnectionAnnotation = - FAMEChannelConnectionAnnotation(globalName, channelInfo, None, Some(sinks)) + FAMEChannelConnectionAnnotation(globalName, channelInfo, clock, None, Some(sinks)) + + def implicitlyClockedSink( + globalName: String, + channelInfo: FAMEChannelInfo, + sinks: Seq[ReferenceTarget]): FAMEChannelConnectionAnnotation = sink(globalName, channelInfo, None, sinks) def source( globalName: String, channelInfo: FAMEChannelInfo, + clock: Option[ReferenceTarget], sources: Seq[ReferenceTarget]): FAMEChannelConnectionAnnotation = - FAMEChannelConnectionAnnotation(globalName, channelInfo, Some(sources), None) + FAMEChannelConnectionAnnotation(globalName, channelInfo, clock, Some(sources), None) + + def implicitlyClockedSource( + globalName: String, + channelInfo: FAMEChannelInfo, + sources: Seq[ReferenceTarget]): FAMEChannelConnectionAnnotation = source(globalName, channelInfo, None, sources) } /** @@ -104,7 +132,7 @@ sealed trait FAMEChannelInfo { */ case class PipeChannel(val latency: Int) extends FAMEChannelInfo -/** +/** * Indicates that a channel connection is the reverse (ready) half of * a decoupled target connection. Since the forward half incorporates * references to the ready signals, this channel contains no signal @@ -113,17 +141,22 @@ case class PipeChannel(val latency: Int) extends FAMEChannelInfo case object DecoupledReverseChannel extends FAMEChannelInfo /** - * Indicates that a channel connection is the reverse (ready) half of + * Indicates that a channel connection carries target clocks + */ +case class TargetClockChannel(clockInfo: Seq[RationalClock]) extends FAMEChannelInfo + +/** + * Indicates that a channel connection is the forward (valid) half of * a decoupled target connection. - * + * * @param readySink sink port component of the corresponding reverse channel - * + * * @param validSource valid port component from this channel's sources - * + * * @param readySource source port component of the corresponding reverse channel - * + * * @param validSink valid port component from this channel's sinks - * + * * @note (readySink, validSource) are on one model, (readySource, validSink) on the other */ case class DecoupledForwardChannel( @@ -150,15 +183,6 @@ object DecoupledForwardChannel { DecoupledForwardChannel(Some(ready), Some(valid), None, None) } -/** - * Indicates that a particular instance is a FAME Model - */ -case class FAMEModelAnnotation(target: InstanceTarget) extends SingleTargetAnnotation[InstanceTarget] { - def targets = Seq(target) - def duplicate(n: InstanceTarget) = this.copy(n) -} - - /** * Specifies what form of FAME transform should be applied when * generated a simulation model from a target module. @@ -184,7 +208,9 @@ case object FAME1Transform extends FAMETransformType * this is a ModuleTarget, all instances at the top level will be * transformed identically. */ -case class FAMETransformAnnotation(transformType: FAMETransformType, target: ModuleTarget) extends SingleTargetAnnotation[ModuleTarget] { +case class FAMETransformAnnotation( + transformType: FAMETransformType, + target: ModuleTarget) extends SingleTargetAnnotation[ModuleTarget] with FAMEAnnotation { def targets = Seq(target) def duplicate(n: ModuleTarget) = this.copy(transformType, n) } @@ -194,17 +220,18 @@ case class FAMETransformAnnotation(transformType: FAMETransformType, target: Mod * level in the hierarchy. The specified instance will be pulled out * of its parent module and will reside in its "grandparent" module * after the PromoteSubmodule transform has run. - * + * * @param target The instance to be promoted. Note that this must * be a *local* instance target, as all instances of the parent * module will be transformed identically. */ -case class PromoteSubmoduleAnnotation(target: InstanceTarget) extends SingleTargetAnnotation[InstanceTarget] { +case class PromoteSubmoduleAnnotation( + target: InstanceTarget) extends SingleTargetAnnotation[InstanceTarget] with FAMEAnnotation { def targets = Seq(target) def duplicate(n: InstanceTarget) = this.copy(n) } -abstract class FAMEGlobalSignal extends SingleTargetAnnotation[ReferenceTarget] { +abstract class FAMEGlobalSignal extends SingleTargetAnnotation[ReferenceTarget] with FAMEAnnotation { val target: ReferenceTarget def targets = Seq(target) def duplicate(n: ReferenceTarget): FAMEGlobalSignal @@ -218,7 +245,7 @@ case class FAMEHostReset(target: ReferenceTarget) extends FAMEGlobalSignal { def duplicate(t: ReferenceTarget): FAMEHostReset = this.copy(t) } -abstract class MemPortAnnotation extends Annotation { +abstract class MemPortAnnotation extends Annotation with FAMEAnnotation { val en: ReferenceTarget val addr: ReferenceTarget } @@ -287,3 +314,24 @@ case class ModelReadWritePort( } override def getTargets: Seq[ReferenceTarget] = Seq(wmode, rdata, wdata, wmask, addr, en) } + +/** + * A pass that dumps all FAME annotations to a file for debugging. + */ +class EmitFAMEAnnotations(fileName: String) extends firrtl.Transform { + import firrtl.options.TargetDirAnnotation + def inputForm = UnknownForm + def outputForm = UnknownForm + + override def name = s"[Golden Gate] Debugging FAME Annotation Emission Pass: $fileName" + + def execute(state: CircuitState) = { + val targetDir = state.annotations.collectFirst { case TargetDirAnnotation(dir) => dir } + val dirName = targetDir.getOrElse(".") + val outputFile = new java.io.PrintWriter(s"${dirName}/${fileName}") + val fameAnnos = state.annotations.collect { case fa: FAMEAnnotation => fa } + outputFile.write(JsonProtocol.serialize(fameAnnos)) + outputFile.close() + state + } +} diff --git a/sim/midas/src/main/scala/midas/passes/fame/ChannelExcision.scala b/sim/midas/src/main/scala/midas/passes/fame/ChannelExcision.scala index 0da9c658..d689847c 100644 --- a/sim/midas/src/main/scala/midas/passes/fame/ChannelExcision.scala +++ b/sim/midas/src/main/scala/midas/passes/fame/ChannelExcision.scala @@ -11,13 +11,15 @@ import Utils._ import firrtl.passes.MemPortUtils import annotations.{ModuleTarget, ReferenceTarget, Annotation, SingleTargetAnnotation} +import midas.targetutils.FirrtlFAMEModelAnnotation + import scala.collection.mutable class ChannelExcision extends Transform { def inputForm = LowForm def outputForm = LowForm - val addedChannelAnnos = new mutable.ArrayBuffer[FAMEModelAnnotation]() + val addedChannelAnnos = new mutable.ArrayBuffer[FirrtlFAMEModelAnnotation]() val pipeChannels = new mutable.HashMap[(ReferenceTarget, ReferenceTarget), String]() @@ -29,19 +31,19 @@ class ChannelExcision extends Transform { def portTarget(p: Port) = topTarget.ref(p.name) def onStmt(addedPorts: mutable.ArrayBuffer[Port])(s: Statement): Statement = s.map(onStmt(addedPorts)) match { - case c @ Connect(_, lhs @ WSubField(WRef(lhsiname, _, InstanceKind, _), lhspname, _, _), - rhs @ WSubField(WRef(rhsiname, _, InstanceKind, _), rhspname, _, _)) => + case c @ Connect(_, lhs @ WSubField(WRef(lhsiname, _, InstanceKind, _), lhspname, lType, _), + rhs @ WSubField(WRef(rhsiname, _, InstanceKind, _), rhspname, rType, _)) => val lhsTarget = subfieldTarget(lhsiname, lhspname) val rhsTarget = subfieldTarget(rhsiname, rhspname) - pipeChannels.get((lhsTarget, rhsTarget)) match { - case Some(chName) => - val srcP = Port(NoInfo, s"${rhsiname}_${rhspname}_source", Output, lhs.tpe) - val sinkP = Port(NoInfo, s"${lhsiname}_${lhspname}_sink", Input, rhs.tpe) - addedPorts ++= Seq(srcP, sinkP) - renames.record(lhsTarget, portTarget(sinkP)) - renames.record(rhsTarget, portTarget(srcP)) - Block(Seq(Connect(NoInfo, lhs, WRef(sinkP)), Connect(NoInfo, WRef(srcP), rhs))) - case None => c + if (pipeChannels.contains((lhsTarget, rhsTarget)) || lType == ClockType) { + val srcP = Port(NoInfo, s"${rhsiname}_${rhspname}_source", Output, lType) + val sinkP = Port(NoInfo, s"${lhsiname}_${lhspname}_sink", Input, rType) + addedPorts ++= Seq(srcP, sinkP) + renames.record(lhsTarget, portTarget(sinkP)) + renames.record(rhsTarget, portTarget(srcP)) + Block(Seq(Connect(NoInfo, lhs, WRef(sinkP)), Connect(NoInfo, WRef(srcP), rhs))) + } else { + c } case s => s } @@ -56,7 +58,7 @@ class ChannelExcision extends Transform { // Step 1: Analysis -> build a map from reference targets to channel name state.annotations.collect({ - case fta@ FAMEChannelConnectionAnnotation(name, PipeChannel(_), Some(srcs), Some(sinks)) => + case fta@ FAMEChannelConnectionAnnotation(name, PipeChannel(_), _, Some(srcs), Some(sinks)) => sinks.zip(srcs).foreach({ pipeChannels(_) = name }) }) diff --git a/sim/midas/src/main/scala/midas/passes/fame/EmitAndWrapRAMModels.scala b/sim/midas/src/main/scala/midas/passes/fame/EmitAndWrapRAMModels.scala index 35e5b64c..5410e3d0 100644 --- a/sim/midas/src/main/scala/midas/passes/fame/EmitAndWrapRAMModels.scala +++ b/sim/midas/src/main/scala/midas/passes/fame/EmitAndWrapRAMModels.scala @@ -90,7 +90,7 @@ class ReadPort(val anno: ModelReadPort, val ports: Seq[Port]) extends IsMemoryPo val iSignals = Seq(anno.addr, anno.en) val iValids = iSignals.map(valid) Seq( - Connect(NoInfo, readCmd.valid, Reduce.and(iValids)), + Connect(NoInfo, readCmd.valid, And.reduce(iValids)), Connect(NoInfo, readCmd.bits("en"), bits(anno.en)), Connect(NoInfo, readCmd.bits("addr"), bits(anno.addr)), Connect(NoInfo, ready(anno.addr), readCmd.ready), @@ -110,7 +110,7 @@ class WritePort(val anno: ModelWritePort, val ports: Seq[Port]) extends IsMemory iSignals.flatMap(rT => Seq( Connect(NoInfo, ready(rT), writeCmd.ready) )) ++ Seq( - Connect(NoInfo, writeCmd.valid, Reduce.and(iValids)), + Connect(NoInfo, writeCmd.valid, And.reduce(iValids)), Connect(NoInfo, writeCmd.bits("en"), bits(anno.en)), Connect(NoInfo, writeCmd.bits("addr"), bits(anno.addr)), Connect(NoInfo, writeCmd.bits("data"), bits(anno.data)), diff --git a/sim/midas/src/main/scala/midas/passes/fame/ExtractModel.scala b/sim/midas/src/main/scala/midas/passes/fame/ExtractModel.scala index 9ec30c3a..6c9b1937 100644 --- a/sim/midas/src/main/scala/midas/passes/fame/ExtractModel.scala +++ b/sim/midas/src/main/scala/midas/passes/fame/ExtractModel.scala @@ -11,6 +11,8 @@ import Utils._ import firrtl.passes.MemPortUtils import annotations.{InstanceTarget, Annotation, SingleTargetAnnotation} +import midas.targetutils.FirrtlFAMEModelAnnotation + import scala.collection.mutable import mutable.{LinkedHashSet, LinkedHashMap} @@ -20,7 +22,7 @@ class ExtractModel extends Transform { def promoteModels(state: CircuitState): CircuitState = { val anns = state.annotations.flatMap { - case a @ FAMEModelAnnotation(it) if (it.module != it.circuit) => Seq(a, PromoteSubmoduleAnnotation(it)) + case a @ FirrtlFAMEModelAnnotation(it) if (it.module != it.circuit) => Seq(a, PromoteSubmoduleAnnotation(it)) case a => Seq(a) } if (anns.toSeq == state.annotations.toSeq) { diff --git a/sim/midas/src/main/scala/midas/passes/fame/FAMEDefaults.scala b/sim/midas/src/main/scala/midas/passes/fame/FAMEDefaults.scala index ff95975f..1306a292 100644 --- a/sim/midas/src/main/scala/midas/passes/fame/FAMEDefaults.scala +++ b/sim/midas/src/main/scala/midas/passes/fame/FAMEDefaults.scala @@ -8,6 +8,8 @@ import ir._ import annotations._ import collection.mutable.{ArrayBuffer, LinkedHashSet} +import midas.targetutils.FAMEAnnotation + // Assumes: AQB form // Run after ExtractModel // Label all unbound top-level ports as wire channels @@ -21,16 +23,12 @@ class FAMEDefaults extends Transform { override def execute(state: CircuitState): CircuitState = { val analysis = new FAMEChannelAnalysis(state, FAME1Transform) val topModule = state.circuit.modules.find(_.name == state.circuit.main).get.asInstanceOf[Module] - val globalSignals = state.annotations.collect({ case g: FAMEGlobalSignal => g.target.ref }).toSet - val channelNames = state.annotations.collect({ case fca: FAMEChannelConnectionAnnotation => fca.globalName }) + val fameAnnos = state.annotations.collect({ case fa: FAMEAnnotation => fa }) // for performance, avoid other annos + val globalSignals = fameAnnos.collect({ case g: FAMEGlobalSignal => g.target.ref }).toSet + val channelNames = fameAnnos.collect({ case fca: FAMEChannelConnectionAnnotation => fca.globalName }) val channelNS = Namespace(channelNames) def isGlobal(topPort: Port) = globalSignals.contains(topPort.name) def isBound(topPort: Port) = analysis.channelsByPort.contains(analysis.topTarget.ref(topPort.name)) - val defaultExtChannelAnnos = topModule.ports.filterNot(isGlobal).filterNot(isBound).flatMap({ - case Port(_, _, _, ClockType) => None // FIXME: Reject the clock in RC's debug interface - case Port(_, name, Input, _) => Some(FAMEChannelConnectionAnnotation(channelNS.newName(name), WireChannel, None, Some(Seq(analysis.topTarget.ref(name))))) - case Port(_, name, Output, _) => Some(FAMEChannelConnectionAnnotation(channelNS.newName(name), WireChannel, Some(Seq(analysis.topTarget.ref(name))), None)) - }) val channelModules = new LinkedHashSet[String] // TODO: find modules to absorb into channels, don't label as FAME models val defaultLoopbackAnnos = new ArrayBuffer[FAMEChannelConnectionAnnotation] val defaultModelAnnos = new ArrayBuffer[FAMETransformAnnotation] @@ -41,16 +39,16 @@ class FAMEDefaults extends Transform { wi case c @ Connect(_, WSubField(WRef(lhsiname, _, InstanceKind, _), lhspname, _, _), WSubField(WRef(rhsiname, _, InstanceKind, _), rhspname, _, _)) => if (c.loc.tpe != ClockType && c.expr.tpe != ClockType) { - defaultLoopbackAnnos += FAMEChannelConnectionAnnotation( + defaultLoopbackAnnos += FAMEChannelConnectionAnnotation.implicitlyClockedLoopback( channelNS.newName(s"${rhsiname}_${rhspname}__to__${lhsiname}_${lhspname}"), WireChannel, - Some(Seq(topTarget.ref(rhsiname).field(rhspname))), - Some(Seq(topTarget.ref(lhsiname).field(lhspname)))) + Seq(topTarget.ref(rhsiname).field(rhspname)), + Seq(topTarget.ref(lhsiname).field(lhspname))) } c case s => s } topModule.body.map(onStmt) - state.copy(annotations = state.annotations ++ defaultExtChannelAnnos ++ defaultLoopbackAnnos ++ defaultModelAnnos) + state.copy(annotations = state.annotations ++ defaultLoopbackAnnos ++ defaultModelAnnos) } } diff --git a/sim/midas/src/main/scala/midas/passes/fame/FAMETransform.scala b/sim/midas/src/main/scala/midas/passes/fame/FAMETransform.scala index 1044cabf..42a25582 100644 --- a/sim/midas/src/main/scala/midas/passes/fame/FAMETransform.scala +++ b/sim/midas/src/main/scala/midas/passes/fame/FAMETransform.scala @@ -9,6 +9,7 @@ import ir._ import Mappers._ import firrtl.Utils.{BoolType, kind, ceilLog2, one} import firrtl.passes.MemPortUtils +import firrtl.transforms.DontTouchAnnotation import annotations._ import scala.collection.mutable import mutable.{LinkedHashSet, LinkedHashMap} @@ -26,202 +27,229 @@ trait FAME1Channel { def name: String def direction: Direction def ports: Seq[Port] - def tpe: Type = FAMEChannelAnalysis.getHostDecoupledChannelType(name, ports) - def portName: String - def asPort: Port = Port(NoInfo, portName, direction, tpe) - def isReady: Expression = WSubField(WRef(asPort), "ready", BoolType) - def isValid: Expression = WSubField(WRef(asPort), "valid", BoolType) - def isFiring: Expression = Reduce.and(Seq(isReady, isValid)) - def replacePortRef(wr: WRef): WSubField = { - if (ports.size > 1) { - WSubField(WSubField(WRef(asPort), "bits"), FAMEChannelAnalysis.removeCommonPrefix(wr.name, name)._1) - } else { - WSubField(WRef(asPort), "bits") - } - } + def isValid: Expression + def asHostModelPort: Option[Port] = None + def replacePortRef(wr: WRef): Expression } -case class FAME1InputChannel(val name: String, val ports: Seq[Port]) extends FAME1Channel { +trait InputChannel { + this: FAME1Channel => val direction = Input val portName = s"${name}_sink" - def genTokenLogic(finishing: WRef): Seq[Statement] = { - Seq(Connect(NoInfo, isReady, finishing)) + def setReady(readyCond: Expression): Statement +} + +trait HasModelPort { + this: FAME1Channel => + override def isValid = WSubField(WRef(asHostModelPort.get), "valid", BoolType) + def isReady = WSubField(WRef(asHostModelPort.get), "ready", BoolType) + def isFiring: Expression = And(isReady, isValid) + def setReady(advanceCycle: Expression): Statement = Connect(NoInfo, isReady, advanceCycle) + + override def asHostModelPort: Option[Port] = { + val tpe = FAMEChannelAnalysis.getHostDecoupledChannelType(name, ports) + direction match { + case Input => Some(Port(NoInfo, s"${name}_sink", Input, tpe)) + case Output => Some(Port(NoInfo, s"${name}_source", Output, tpe)) + } + } + + def replacePortRef(wr: WRef): Expression = { + val payload = WSubField(WRef(asHostModelPort.get), "bits") + if (ports.size == 1) payload else WSubField(payload, FAMEChannelAnalysis.removeCommonPrefix(wr.name, name)._1) } } -case class FAME1OutputChannel(val name: String, val ports: Seq[Port], val firedReg: DefRegister) extends FAME1Channel { +trait FAME1DataChannel extends FAME1Channel with HasModelPort { + def clockDomainEnable: Expression + def firedReg: DefRegister + def isFired = WRef(firedReg) + def isFiredOrFiring = Or(isFired, isFiring) + def updateFiredReg(finishing: WRef): Statement = { + Connect(NoInfo, isFired, Mux(finishing, Negate(clockDomainEnable), isFiredOrFiring, BoolType)) + } +} + +case class FAME1ClockChannel(name: String, ports: Seq[Port]) extends FAME1Channel with InputChannel with HasModelPort + +case class VirtualClockChannel(targetClock: Port) extends FAME1Channel with InputChannel { + val name = "VirtualClockChannel" + val ports = Seq(targetClock) + val isValid: Expression = UIntLiteral(1) + def setReady(advanceCycle: Expression): Statement = EmptyStmt + def replacePortRef(wr: WRef): Expression = UIntLiteral(1) +} + +case class FAME1InputChannel( + name: String, + clockDomainEnable: Expression, + ports: Seq[Port], + firedReg: DefRegister) extends FAME1DataChannel with InputChannel { + override def setReady(advanceCycle: Expression): Statement = { + Connect(NoInfo, isReady, And(advanceCycle, Negate(isFired))) + } +} + +case class FAME1OutputChannel( + name: String, + clockDomainEnable: Expression, + ports: Seq[Port], + firedReg: DefRegister) extends FAME1DataChannel { val direction = Output val portName = s"${name}_source" - val isFired = WRef(firedReg) - val isFiredOrFiring = Reduce.or(Seq(isFired, isFiring)) - def genTokenLogic(finishing: WRef, ccDeps: Iterable[FAME1InputChannel]): Seq[Statement] = { - val regUpdate = Connect( - NoInfo, - isFired, - Mux(finishing, - UIntLiteral(0, IntWidth(1)), - isFiredOrFiring, - BoolType)) - val setValid = Connect( - NoInfo, - isValid, - Reduce.and(ccDeps.map(_.isValid) ++ Seq(Negate(isFired)))) - Seq(regUpdate, setValid) - } -} - -object ChannelCCDependencyGraph { - def apply(m: Module): LinkedHashMap[FAME1OutputChannel, LinkedHashSet[FAME1InputChannel]] = { - new LinkedHashMap[FAME1OutputChannel, LinkedHashSet[FAME1InputChannel]] - } -} - -object PatientMemTransformer { - def apply(mem: DefMemory, finishing: Expression, memClock: WRef, ns: Namespace): Block = { - val shim = DefWire(NoInfo, mem.name, MemPortUtils.memType(mem)) - val newMem = mem.copy(name = ns.newName(mem.name)) - val defaultConnect = Connect(NoInfo, WRef(shim), WRef(newMem.name, shim.tpe, MemKind)) - val syncReadPorts = (newMem.readers ++ newMem.readwriters).filter(rp => mem.readLatency > 0) - val preserveReads = syncReadPorts.flatMap { - case rpName => - val addrWidth = IntWidth(ceilLog2(mem.depth) max 1) - val dummyReset = DefWire(NoInfo, ns.newName(s"${mem.name}_${rpName}_dummyReset"), BoolType) - val tieOff = Connect(NoInfo, WRef(dummyReset), UIntLiteral(0)) - val addrReg = new DefRegister(NoInfo, ns.newName(s"${mem.name}_${rpName}"), - UIntType(addrWidth), memClock, WRef(dummyReset), UIntLiteral(0, addrWidth)) - val updateReg = Connect(NoInfo, WRef(addrReg), WSubField(WSubField(WRef(shim), rpName), "addr")) - val useReg = Connect(NoInfo, MemPortUtils.memPortField(newMem, rpName, "addr"), WRef(addrReg)) - Seq(dummyReset, tieOff, addrReg, Conditionally(NoInfo, finishing, updateReg, useReg)) - } - val gateWrites = (newMem.writers ++ newMem.readwriters).map { - case wpName => - Conditionally( - NoInfo, - Negate(finishing), - Connect(NoInfo, MemPortUtils.memPortField(newMem, wpName, "en"), UIntLiteral(0, IntWidth(1))), - EmptyStmt) - } - new Block(Seq(shim, newMem, defaultConnect) ++ preserveReads ++ gateWrites) - } -} - -object PatientSSMTransformer { - def apply(m: Module, analysis: FAMEChannelAnalysis)(implicit triggerName: String): Module = { - val ns = Namespace(m) - val clocks = m.ports.filter(_.tpe == ClockType) - // TODO: turn this back on - // assert(clocks.length == 1) - val finishing = new Port(NoInfo, ns.newName(triggerName), Input, BoolType) - val hostClock = clocks.find(_.name == "clock").getOrElse(clocks.head) // TODO: naming convention for host clock - def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match { - case conn @ Connect(info, lhs, _) if (kind(lhs) == RegKind) => - Conditionally(info, WRef(finishing), conn, EmptyStmt) - case s: Stop => s.copy(en = DoPrim(PrimOps.And, Seq(WRef(finishing), s.en), Seq.empty, BoolType)) - case p: Print => p.copy(en = DoPrim(PrimOps.And, Seq(WRef(finishing), p.en), Seq.empty, BoolType)) - case mem: DefMemory => PatientMemTransformer(mem, WRef(finishing), WRef(hostClock), ns) - case wi: WDefInstance if analysis.syncNativeModules.contains(analysis.moduleTarget(wi)) => - new Block(Seq(wi, Connect(wi.info, WSubField(WRef(wi), triggerName), WRef(finishing)))) - case s => s - } - Module(m.info, m.name, m.ports :+ finishing, m.body.map(onStmt)) + def setValid(finishing: WRef, ccDeps: Iterable[FAME1InputChannel]): Statement = { + Connect(NoInfo, isValid, And.reduce(ccDeps.map(_.isValid).toSeq :+ Negate(isFired))) } } +// Multi-clock timestep: +// When finishing is high, dequeue token from clock channel +// - Use to initialize isFired for all channels (with negation) +// - Finishing is gated with clock channel valid object FAMEModuleTransformer { - def apply(m: Module, analysis: FAMEChannelAnalysis)(implicit triggerName: String): Module = { - // Step 0: Special signals & bookkeeping + def apply(m: Module, analysis: FAMEChannelAnalysis): Module = { + // Step 0: Bookkeeping for port structure conventions implicit val ns = Namespace(m) - val clocks = m.ports.filter(_.tpe == ClockType) - // TODO: turn this back to == 1 - assert(clocks.length >= 1) - val hostClock = clocks.find(_.name == "clock").getOrElse(clocks.head) // TODO: naming convention for host clock - val hostReset = HostReset.makePort(ns) - def createHostReg(name: String = "host", width: Width = IntWidth(1)): DefRegister = { - new DefRegister(NoInfo, ns.newName(name), UIntType(width), WRef(hostClock), WRef(hostReset), UIntLiteral(0, width)) - } - val finishing = DefWire(NoInfo, ns.newName(triggerName), BoolType) - /* - * The conjuction of the following expressions enables the target-clock - * Failing to keep the target clock-gated during reset can lead to spurious - * target-clock edges leading to non-deterministically initialized target state - * in registers that are not reset - */ - val clockGatePredicates = Seq(WRef(finishing), DoPrim(PrimOps.Not, Seq(WRef(hostReset)), Seq.empty, BoolType)) - val buf = InstanceInfo(DefineAbstractClockGate.blackbox).connect("I", WRef(hostClock)) - .connect("CE", DoPrim(PrimOps.And, clockGatePredicates, Seq.empty, BoolType)) - val targetClock = SignalInfo(buf.decl, buf.assigns, WSubField(buf.ref, "O", ClockType, SourceFlow)) - - // Step 1: Build channels val mTarget = ModuleTarget(analysis.circuit.main, m.name) - val inChannels = (analysis.modelInputChannelPortMap(mTarget)).map({ - case(cName, ports) => new FAME1InputChannel(cName, ports) - }) - val inChannelMap = new LinkedHashMap[String, FAME1InputChannel] ++ - (inChannels.flatMap(c => c.ports.map(p => (p.name, c)))) + val clocks: Seq[Port] = m.ports.filter(_.tpe == ClockType) + val portsByName = m.ports.map(p => p.name -> p).toMap + assert(clocks.length >= 1) - val outChannels = analysis.modelOutputChannelPortMap(mTarget).map({ - case(cName, ports) => - val firedReg = createHostReg(name = ns.newName(s"${cName}_fired")) - new FAME1OutputChannel(cName, ports, firedReg) - }) - val outChannelMap = new LinkedHashMap[String, FAME1OutputChannel] ++ - (outChannels.flatMap(c => c.ports.map(p => (p.name, c)))) - val decls = Seq(finishing) ++ outChannels.map(_.firedReg) + // Multi-clock management step 1: Add host clock + reset ports, finishing wire + // TODO: Should finishing be a WrappedComponent? + // TODO: Avoid static naming convention. + val hostReset = Port(NoInfo, WrapTop.hostResetName, Input, BoolType) + val hostClock = Port(NoInfo, WrapTop.hostClockName, Input, ClockType) + val finishing = DefWire(NoInfo, "targetCycleFinishing", BoolType) + assert(ns.tryName(hostReset.name) && ns.tryName(hostClock.name) && ns.tryName(finishing.name)) + def hostFlagReg(suggestName: String, resetVal: UIntLiteral = UIntLiteral(0)): DefRegister = { + DefRegister(NoInfo, ns.newName(suggestName), BoolType, WRef(hostClock), WRef(hostReset), resetVal) + } + + // Multi-clock management step 2: Build clock flags and clock channel + def isClockChannel(info: (String, (Option[Port], Seq[Port]))) = info match { + case (_, (clk, ports)) => clk.isEmpty && ports.forall(_.tpe == ClockType) + } - // Step 2: Find combinational dependencies - val ccChecker = new firrtl.transforms.CheckCombLoops + val clockChannel = analysis.modelInputChannelPortMap(mTarget).find(isClockChannel) match { + case Some((name, (None, ports))) => FAME1ClockChannel(name, ports) + case Some(_) => ??? // Clock channel cannot have an associated clock domain + case None => VirtualClockChannel(clocks.head) // Virtual clock channel for single-clock models + } + + /* + * NB: Failing to keep the target clock-gated during FPGA initialization + * can lead to spurious updates or metastability in target state elements. + * Keeping all target-clocks gated through the latter stages of FPGA + * initialization and reset ensures all target state elements are + * initialized with a deterministic set of initial values. + */ + val nReset = DoPrim(PrimOps.Not, Seq(WRef(hostReset)), Seq.empty, BoolType) + + // Multi-clock management step 4: Generate clock buffers for all target clocks + val targetClockBufs: Seq[SignalInfo] = clockChannel.ports.map { en => + val enableReg = hostFlagReg(s"${en.name}_enabled", resetVal = UIntLiteral(1)) + val buf = WDefInstance(ns.newName(s"${en.name}_buffer"), DefineAbstractClockGate.blackbox.name) + val clockFlag = DoPrim(PrimOps.AsUInt, Seq(clockChannel.replacePortRef(WRef(en))), Nil, BoolType) + val connects = Block(Seq( + Connect(NoInfo, WRef(enableReg), Mux(WRef(finishing), clockFlag, WRef(enableReg), BoolType)), + Connect(NoInfo, WSubField(WRef(buf), "I"), WRef(hostClock)), + Connect(NoInfo, WSubField(WRef(buf), "CE"), And(And(WRef(enableReg), WRef(finishing)), nReset)))) + SignalInfo(Block(Seq(enableReg, buf)), connects, WSubField(WRef(buf), "O", ClockType, SourceFlow)) + } + + // Multi-clock management step 5: Generate target clock substitution map + def asWE(p: Port) = WrappedExpression.we(WRef(p)) + val replaceClocksMap = (clockChannel.ports.map(p => asWE(p)) zip targetClockBufs.map(_.ref)).toMap + + // LI-BDN transformation step 1: Build channels + // TODO: get rid of the analysis calls; we just need connectivity & annotations val portDeps = analysis.connectivity(m.name) + + def genMetadata(info: (String, (Option[Port], Seq[Port]))) = info match { + case (cName, (Some(clock), ports)) => + // must be driven by one clock input port + // TODO: this should not include muxes in connectivity! + val srcClockPorts = portDeps.getEdges(clock.name).map(portsByName(_)) + assert(srcClockPorts.size == 1) + val clockRef = WRef(srcClockPorts.head) + val clockFlag = DoPrim(PrimOps.AsUInt, Seq(clockChannel.replacePortRef(clockRef)), Nil, BoolType) + val firedReg = hostFlagReg(suggestName = ns.newName(s"${cName}_fired")) + (cName, clockFlag, ports, firedReg) + case (cName, (None, ports)) => clockChannel match { + case vc: VirtualClockChannel => + val firedReg = hostFlagReg(suggestName = ns.newName(s"${cName}_fired")) + (cName, UIntLiteral(1), ports, firedReg) + case _ => + throw new RuntimeException(s"Channel ${cName} has no associated clock.") + } + } + + // LinkedHashMap.from is 2.13-only :( + def stableMap[K, V](contents: Iterable[(K, V)]) = new LinkedHashMap[K, V] ++= contents + + // Have to filter out the clock channel from the input channels + val inChannelInfo = analysis.modelInputChannelPortMap(mTarget).filterNot(isClockChannel(_)).toSeq + val inChannelMetadata = inChannelInfo.map(genMetadata(_)) + val inChannels = inChannelMetadata.map((FAME1InputChannel.apply _).tupled) + val inChannelMap = stableMap(inChannels.flatMap(c => c.ports.map(p => p.name -> c))) + + val outChannelInfo = analysis.modelOutputChannelPortMap(mTarget).toSeq + val outChannelMetadata = outChannelInfo.map(genMetadata(_)) + val outChannels = outChannelMetadata.map((FAME1OutputChannel.apply _).tupled) + val outChannelMap = stableMap(outChannels.flatMap(c => c.ports.map(p => p.name -> c))) + + // LI-BDN transformation step 2: find combinational dependencies among channels val ccDeps = new LinkedHashMap[FAME1OutputChannel, LinkedHashSet[FAME1InputChannel]] portDeps.getEdgeMap.collect({ case (o, iSet) if outChannelMap.contains(o) => // Only add input channels, since output might depend on output RHS ref ccDeps.getOrElseUpdate(outChannelMap(o), new LinkedHashSet[FAME1InputChannel]) ++= iSet.flatMap(inChannelMap.get(_)) }) - // Step 3: transform ports - val transformedPorts = clocks ++ Seq(hostReset) ++ inChannels.map(_.asPort) ++ outChannels.map(_.asPort) + // LI-BDN transformation step 3: transform ports (includes new clock ports) + val transformedPorts = hostClock +: hostReset +: (clockChannel +: inChannels ++: outChannels).flatMap(_.asHostModelPort) - // Step 4: Replace refs and gate state updates + // LI-BDN transformation step 4: replace port and clock references and gate state updates + val clockChannelPortNames = clockChannel.ports.map(_.name).toSet def onExpr(expr: Expression): Expression = expr.map(onExpr) match { - case iWR @ WRef(name, tpe, PortKind, SourceFlow) if tpe != ClockType => + case iWR @ WRef(name, tpe, PortKind, SourceFlow) if tpe != ClockType => // Generally SourceFlow references to ports will be input channels, but RTL may use // an assignment to an output port as something akin to a wire, so check output ports too. inChannelMap.getOrElse(name, outChannelMap(name)).replacePortRef(iWR) case oWR @ WRef(name, tpe, PortKind, SinkFlow) if tpe != ClockType => outChannelMap(name).replacePortRef(oWR) - case wr: WRef if wr.name == hostClock.name => - // Replace host clock references with target clock references - targetClock.ref - case e => e + case cWR @ WRef(name, ClockType, PortKind, SourceFlow) if clockChannelPortNames(name) => + replaceClocksMap(WrappedExpression.we(cWR)) + case e => e map onExpr } - // This is vestigial from when we supported two means of implementing clock - // gating; This will be removed when multiclock is implemented - val targetStateTrigger = one - - def onStmt(stmt: Statement): Statement = stmt.map(onStmt).map(onExpr) match { - case conn @ Connect(info, lhs, _) if (kind(lhs) == RegKind) => - Conditionally(info, targetStateTrigger, conn, EmptyStmt) - case mem: DefMemory => PatientMemTransformer(mem, targetStateTrigger, WRef(hostClock), ns) - case wi: WDefInstance if analysis.syncNativeModules.contains(analysis.moduleTarget(wi)) => - new Block(Seq(wi, Connect(wi.info, WSubField(WRef(wi), triggerName), targetStateTrigger))) - case s: Stop => s.copy(en = DoPrim(PrimOps.And, Seq(targetStateTrigger, s.en), Seq.empty, BoolType)) - case p: Print => p.copy(en = DoPrim(PrimOps.And, Seq(targetStateTrigger, p.en), Seq.empty, BoolType)) - case s => s + def onStmt(stmt: Statement): Statement = stmt match { + case Connect(info, WRef(name, ClockType, PortKind, flow), rhs) => + // Don't substitute gated clock for LHS expressions + Connect(info, WRef(name, ClockType, WireKind, flow), onExpr(rhs)) + case s => s map onStmt map onExpr } - val transformedStmts = Seq(m.body.map(onStmt)) + val updatedBody = onStmt(m.body) - // Step 5: Add firing rules for output channels, trigger end of cycle - val ruleStmts = new mutable.ArrayBuffer[Statement] - ruleStmts ++= outChannels.flatMap(o => o.genTokenLogic(WRef(finishing), ccDeps(o))) - ruleStmts ++= inChannels.flatMap(i => i.genTokenLogic(WRef(finishing))) - ruleStmts += Connect(NoInfo, WRef(finishing), - Reduce.and(outChannels.map(_.isFiredOrFiring) ++ inChannels.map(_.isValid))) + // LI-BDN transformation step 5: add firing rules for output channels, trigger end of cycle + // This is modified for multi-clock, as each channel fires only when associated clock is enabled + val allFiredOrFiring = And.reduce(outChannels.map(_.isFiredOrFiring) ++ inChannels.map(_.isValid)) + + val channelStateRules = (inChannels ++ outChannels).map(c => c.updateFiredReg(WRef(finishing))) + val inputRules = inChannels.map(i => i.setReady(WRef(finishing))) + val outputRules = outChannels.map(o => o.setValid(WRef(finishing), ccDeps(o))) + val topRules = Seq(clockChannel.setReady(allFiredOrFiring), + Connect(NoInfo, WRef(finishing), And(allFiredOrFiring, clockChannel.isValid))) + + // Keep output clock ports around as wires just for convenience to keep connects legal + val clockOutputsAsWires = m.ports.collect { case Port(i, n, Output, ClockType) => DefWire(i, n, ClockType) } // Statements have to be conservatively ordered to satisfy declaration order - val allStmts = targetClock.decl +: (decls ++ transformedStmts ++ ruleStmts) :+ targetClock.assigns - Module(m.info, m.name, transformedPorts, new Block(allStmts)) + val decls = finishing +: clockOutputsAsWires ++: targetClockBufs.map(_.decl) ++: (inChannels ++ outChannels).map(_.firedReg) + val assigns = targetClockBufs.map(_.assigns) ++ channelStateRules ++ inputRules ++ outputRules ++ topRules + Module(m.info, m.name, transformedPorts, Block(decls ++: updatedBody +: assigns)) } } @@ -231,8 +259,10 @@ class FAMETransform extends Transform { def updateNonChannelConnects(analysis: FAMEChannelAnalysis)(stmt: Statement): Statement = stmt.map(updateNonChannelConnects(analysis)) match { case wi: WDefInstance if (analysis.transformedModules.contains(analysis.moduleTarget(wi))) => - val resetConn = Connect(NoInfo, WSubField(WRef(wi), "hostReset"), WRef(analysis.hostReset.ref, BoolType)) - Block(Seq(wi, resetConn)) + val clockConn = Connect(NoInfo, WSubField(WRef(wi), WrapTop.hostClockName), WRef(analysis.hostClock.ref, ClockType)) + val resetConn = Connect(NoInfo, WSubField(WRef(wi), WrapTop.hostResetName), WRef(analysis.hostReset.ref, BoolType)) + Block(Seq(wi, clockConn, resetConn)) + case Connect(_, lhs, rhs) if (lhs.tpe == ClockType) => EmptyStmt // drop ancillary clock connects case Connect(_, WRef(name, _, _, _), _) if (analysis.staleTopPorts.contains(analysis.topTarget.ref(name))) => EmptyStmt case Connect(_, _, WRef(name, _, _, _)) if (analysis.staleTopPorts.contains(analysis.topTarget.ref(name))) => EmptyStmt case s => s @@ -254,9 +284,9 @@ class FAMETransform extends Transform { analysis.sourcePorts(c).map(rt => (rt, rt.copy(ref = s"${c}_source").field("bits").field(FAMEChannelAnalysis.removeCommonPrefix(rt.ref, c)._1))) }) - def renamePorts(suffix: String, lookup: ModuleTarget => Map[String, Seq[Port]]) + def renamePorts(suffix: String, lookup: ModuleTarget => Map[String, (Option[Port], Seq[Port])]) (mT: ModuleTarget): Seq[(ReferenceTarget, ReferenceTarget)] = { - lookup(mT).toSeq.flatMap({ case (cName, pList) => + lookup(mT).toSeq.flatMap({ case (cName, (clockOption, pList)) => pList.map({ port => val decoupledTarget = mT.ref(s"${cName}${suffix}").field("bits") if (pList.size == 1) @@ -264,6 +294,7 @@ class FAMETransform extends Transform { else (mT.ref(port.name), decoupledTarget.field(FAMEChannelAnalysis.removeCommonPrefix(port.name, cName)._1)) }) + // TODO: rename clock to nothing, since it is deleted }) } def renameModelInputs: ModuleTarget => Seq[(ReferenceTarget, ReferenceTarget)] = renamePorts("_sink", analysis.modelInputChannelPortMap) @@ -292,12 +323,20 @@ class FAMETransform extends Transform { val analysis = new FAMEChannelAnalysis(state, FAME1Transform) // TODO: pick a value that does not collide implicit val triggerName = "finishing" + + val toTransform = analysis.transformedModules val transformedModules = c.modules.map { case m: Module if (m.name == c.main) => transformTop(m, analysis) - case m: Module if (analysis.transformedModules.contains(ModuleTarget(c.main,m.name))) => FAMEModuleTransformer(m, analysis) - case m: Module if (analysis.syncNativeModules.contains(ModuleTarget(c.main, m.name))) => PatientSSMTransformer(m, analysis) - case m => m + case m: Module if (toTransform.contains(ModuleTarget(c.main, m.name))) => FAMEModuleTransformer(m, analysis) + case m => m // TODO (Albert): revisit this; currently, not transforming nested modules } - state.copy(circuit = c.copy(modules = transformedModules), renames = Some(hostDecouplingRenames(analysis))) + + val filteredAnnos = state.annotations.filter { + case DontTouchAnnotation(rt) if toTransform.contains(rt.moduleTarget) => false + case _ => true + } + + val newCircuit = c.copy(modules = transformedModules) + CircuitState(newCircuit, outputForm, filteredAnnos, Some(hostDecouplingRenames(analysis))) } } diff --git a/sim/midas/src/main/scala/midas/passes/fame/FAMEUtils.scala b/sim/midas/src/main/scala/midas/passes/fame/FAMEUtils.scala index 531df676..a57bd1c4 100644 --- a/sim/midas/src/main/scala/midas/passes/fame/FAMEUtils.scala +++ b/sim/midas/src/main/scala/midas/passes/fame/FAMEUtils.scala @@ -6,6 +6,7 @@ import firrtl._ import ir._ import Utils._ import Mappers._ +import traversals.Foreachers._ import graph.DiGraph import analyses.InstanceGraph import transforms.CheckCombLoops @@ -50,11 +51,6 @@ private[fame] object FAMEChannelAnalysis { def getHostDecoupledChannelType(name: String, ports: Seq[Port]): Type = Decouple(getHostDecoupledChannelPayloadType(name, ports)) } -private [fame] object HostReset { - def makePort(ns: Namespace): Port = - new Port(NoInfo, ns.newName("hostReset"), Input, Utils.BoolType) -} - private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: FAMETransformType) { // TODO: only transform submodules of model modules // TODO: add renames! @@ -101,6 +97,7 @@ private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: F transformedModules += mt case fca: FAMEChannelConnectionAnnotation => channels += fca.globalName + fca.clock.foreach({ rt => channelsByPort(rt) = fca.globalName }) fca.sinks.toSeq.flatten.foreach({ rt => channelsByPort(rt) = fca.globalName }) fca.sources.toSeq.flatten.foreach({ rt => channelsByPort(rt) = fca.globalName }) }) @@ -109,11 +106,17 @@ private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: F val topConnects = new LinkedHashMap[ReferenceTarget, ReferenceTarget] val inputChannels = new LinkedHashMap[ModuleTarget, mutable.Set[String]] with MultiMap[ModuleTarget, String] val outputChannels = new LinkedHashMap[ModuleTarget, mutable.Set[String]] with MultiMap[ModuleTarget, String] - moduleNodes(topTarget).asInstanceOf[Module].body.map(getTopConnects) - def getTopConnects(stmt: Statement): Statement = stmt.map(getTopConnects) match { + getTopConnects(moduleNodes(topTarget).asInstanceOf[Module].body) + + def getTopConnects(stmt: Statement): Unit = stmt match { case WDefInstance(_, iname, mname, _) => moduleOfInstance(iname) = mname - EmptyStmt + case Connect(_, WRef(tpname, ClockType, _, _), WSubField(WRef(iname, _, _, _), pname, ClockType, _)) => + // Clock connect, don't make any channels + // The clock in a FAMEChannelConnectionAnnotation is the clock from model to bridge + val tpRef = topTarget.ref(tpname) + val child = topTarget.instOf(iname, moduleOfInstance(iname)) + topConnects(tpRef) = child.ref(pname) case Connect(_, WRef(tpname, _, _, _), WSubField(WRef(iname, _, _, _), pname, _, _)) => val tpRef = topTarget.ref(tpname) channelsByPort.get(tpRef).foreach({ cname => @@ -121,7 +124,6 @@ private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: F topConnects(tpRef) = child.ref(pname) outputChannels.addBinding(child.ofModuleTarget, cname) }) - EmptyStmt case Connect(_, WSubField(WRef(iname, _, _, _), pname, _, _), WRef(tpname, _, _, _)) => val tpRef = topTarget.ref(tpname) channelsByPort.get(tpRef).foreach({ cname => @@ -129,20 +131,27 @@ private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: F topConnects(tpRef) = child.ref(pname) inputChannels.addBinding(child.ofModuleTarget, cname) }) - EmptyStmt - case s => EmptyStmt + case s => s.foreach(getTopConnects) } val transformedSinks = new LinkedHashSet[String] val transformedSources = new LinkedHashSet[String] val sinkModel = new LinkedHashMap[String, InstanceTarget] val sourceModel = new LinkedHashMap[String, InstanceTarget] + + // clock ports don't go from one model to the other -> only one map needed + val modelClockPort = new LinkedHashMap[String, Option[ReferenceTarget]] val sinkPorts = new LinkedHashMap[String, Seq[ReferenceTarget]] val sourcePorts = new LinkedHashMap[String, Seq[ReferenceTarget]] val staleTopPorts = new LinkedHashSet[ReferenceTarget] state.annotations.collect({ case fca: FAMEChannelConnectionAnnotation => channels += fca.globalName + + // Clock port always gets recorded and marked for deletion in FAME transform + modelClockPort(fca.globalName) = fca.clock + staleTopPorts ++= fca.clock + val sinks = fca.sinks.toSeq.flatten sinkPorts(fca.globalName) = sinks sinks.headOption.filter(rt => transformedModules.contains(ModuleTarget(rt.circuit, topConnects(rt).encapsulatingModule))).foreach({ rt => @@ -151,6 +160,7 @@ private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: F transformedSinks += fca.globalName staleTopPorts ++= sinks }) + val sources = fca.sources.toSeq.flatten sourcePorts(fca.globalName) = sources sources.headOption.filter(rt => transformedModules.contains(ModuleTarget(rt.circuit, topConnects(rt).encapsulatingModule))).foreach({ rt => @@ -161,42 +171,48 @@ private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: F }) }) + val hostClock = state.annotations.collect({ case FAMEHostClock(rt) => rt }).head val hostReset = state.annotations.collect({ case FAMEHostReset(rt) => rt }).head - def inputPortsByChannel(mTarget: ModuleTarget): Map[String, Seq[Port]] = { + private def irPortFromGlobalTarget(mt: ModuleTarget)(rt: ReferenceTarget): Option[Port] = { + val modelPort = topConnects(rt).pathlessTarget + Some(modelPort).filter(_.module == mt.module).map(portNodes(_)) + } + + def portsByInputChannel(mTarget: ModuleTarget): Map[String, (Option[Port], Seq[Port])] = { val iChannels = inputChannels.get(mTarget).toSet.flatten iChannels.map({ - cname => (cname, sinkPorts(cname).map(topConnects(_).pathlessTarget).map(portNodes(_))) + cname => (cname, (modelClockPort(cname).flatMap(irPortFromGlobalTarget(mTarget)), sinkPorts(cname).map(rt => irPortFromGlobalTarget(mTarget)(rt).get))) }).toMap } - def outputPortsByChannel(mTarget: ModuleTarget): Map[String, Seq[Port]] = { + def portsByOutputChannel(mTarget: ModuleTarget): Map[String, (Option[Port], Seq[Port])] = { val oChannels = outputChannels.get(mTarget).toSet.flatten oChannels.map({ - cname => (cname, sourcePorts(cname).map(topConnects(_).pathlessTarget).map(portNodes(_))) + cname => (cname, (modelClockPort(cname).flatMap(irPortFromGlobalTarget(mTarget)), sourcePorts(cname).map(rt => irPortFromGlobalTarget(mTarget)(rt).get))) }).toMap } lazy val modelPorts = { val mPorts = new LinkedHashMap[ModuleTarget, mutable.Set[FAMEChannelPortsAnnotation]] with MultiMap[ModuleTarget, FAMEChannelPortsAnnotation] state.annotations.collect({ - case fcp@FAMEChannelPortsAnnotation(_, port :: ps) => mPorts.addBinding(port.moduleTarget, fcp) + case fcp @ FAMEChannelPortsAnnotation(_, _, port :: ps) => mPorts.addBinding(port.moduleTarget, fcp) }) mPorts } // Looks up all FAMEChannelPortAnnotations bound to a model module, to generate a Map - // from channel name to port list - private def genModelChannelPortMap(direction: Option[Direction])(mTarget: ModuleTarget): Map[String, Seq[Port]] = { + // from channel name to clock option and port list + private def genModelChannelPortMap(direction: Option[Direction])(mTarget: ModuleTarget): Map[String, (Option[Port], Seq[Port])] = { modelPorts(mTarget).collect({ - case FAMEChannelPortsAnnotation(name, ports) if direction == None || portNodes(ports.head).direction == direction.get => - (name, ports.map(portNodes(_))) + case FAMEChannelPortsAnnotation(name, clock, ports) if direction == None || portNodes(ports.head).direction == direction.get => + (name, (clock.map(portNodes(_)), ports.map(portNodes(_)))) }).toMap } - def modelInputChannelPortMap: ModuleTarget => Map[String, Seq[Port]] = genModelChannelPortMap(Some(Input)) - def modelOutputChannelPortMap: ModuleTarget => Map[String, Seq[Port]] = genModelChannelPortMap(Some(Output)) - def modelChannelPortMap: ModuleTarget => Map[String, Seq[Port]] = genModelChannelPortMap(None) + def modelInputChannelPortMap: ModuleTarget => Map[String, (Option[Port], Seq[Port])] = genModelChannelPortMap(Some(Input)) + def modelOutputChannelPortMap: ModuleTarget => Map[String, (Option[Port], Seq[Port])] = genModelChannelPortMap(Some(Output)) + def modelChannelPortMap: ModuleTarget => Map[String, (Option[Port], Seq[Port])] = genModelChannelPortMap(None) def getSinkHostDecoupledChannelType(cName: String): Type = { FAMEChannelAnalysis.getHostDecoupledChannelType(cName, sinkPorts(cName).map(portNodes(_))) @@ -210,28 +226,34 @@ private[fame] class FAMEChannelAnalysis(val state: CircuitState, val fameType: F // - Used to produce port annotations in InferModelPorts // - Reran to look up port names on model instances class ModulePortDeduper(val mTarget: ModuleTarget) { - val visitedLeafPort = new LinkedHashSet[Port]() - val visitedChannel = new LinkedHashMap[Seq[Port], String]() val channelDedups = new LinkedHashMap[String, String] - def channelSharesPorts(ps: Seq[Port]): Boolean = ps.map(visitedLeafPort).reduce(_ || _) - def channelIsDuplicate(ps: Seq[Port]): Boolean = visitedChannel.contains(ps) - def dedupPortLists(pList: Map[String, Seq[Port]]): Map[String, Seq[Port]] = pList.flatMap({ - case (cName, Nil) => throw new RuntimeException(s"Channel ${cName} is empty (has no associate ports)") - case (_, ports) if channelSharesPorts(ports) && !channelIsDuplicate(ports) => + private val visitedLeafPort = new LinkedHashSet[Port]() + private val visitedChannel = new LinkedHashMap[(Option[Port], Seq[Port]), String]() + + private def channelIsDuplicate(ps: (Option[Port], Seq[Port])): Boolean = visitedChannel.contains(ps) + private def channelSharesPorts(ps: (Option[Port], Seq[Port])): Boolean = ps match { + case (clk, ports) => ports.exists(visitedLeafPort(_)) // clock can be shared + } + + private def dedupPortLists(pList: Map[String, (Option[Port], Seq[Port])]): Map[String, (Option[Port], Seq[Port])] = pList.flatMap({ + case (cName, (_, Nil)) => throw new RuntimeException(s"Channel ${cName} is empty (has no associated ports)") + case (_, clockAndPorts) if channelSharesPorts(clockAndPorts) && !channelIsDuplicate(clockAndPorts) => throw new RuntimeException("Channel definition has partially overlapping ports with existing channel definition") - case (cName, ports) if channelIsDuplicate(ports) => - channelDedups(cName) = visitedChannel(ports) + case (cName, clockAndPorts) if channelIsDuplicate(clockAndPorts) => + channelDedups(cName) = visitedChannel(clockAndPorts) None - case (cName, ports) => - visitedChannel(ports) = cName + case (cName, (clock, ports)) => + visitedChannel((clock, ports)) = cName + visitedLeafPort ++= clock visitedLeafPort ++= ports channelDedups(cName) = cName - Some(cName, ports) + Some(cName, (clock, ports)) }).toMap - val inputPortMap = dedupPortLists(inputPortsByChannel(mTarget)) - val outputPortMap = dedupPortLists(outputPortsByChannel(mTarget)) + private val inputPortMap = dedupPortLists(portsByInputChannel(mTarget)) + private val outputPortMap = dedupPortLists(portsByOutputChannel(mTarget)) + val completePortMap = inputPortMap ++ outputPortMap } diff --git a/sim/midas/src/main/scala/midas/passes/fame/FindDefaultClocks.scala b/sim/midas/src/main/scala/midas/passes/fame/FindDefaultClocks.scala new file mode 100644 index 00000000..310d1cb6 --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/fame/FindDefaultClocks.scala @@ -0,0 +1,75 @@ +// See LICENSE for license details. + +package midas.passes.fame + +import firrtl._ +import firrtl.traversals.Foreachers._ +import firrtl.ir._ +import firrtl.annotations._ + +import collection.mutable + +/** In general, multi-clock Golden Gate simulations contain exactly one "hub" model that coordinates the clock domains + * and has a clock channel. Before this pass runs, loopback channels (those that run from one model to another) have no + * associated clock. The FAME transform expects that all channels that connect to the hub model have an associated + * clock, so this pass finds the top-level clock connection that drives the single clock of each "satellite" (non-hub) + * model and associates the clock output of the hub model driving this clock with all channels connecting the hub and + * the satellite. Channels between satellite models are not changed. + */ +object FindDefaultClocks extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + case class ModelInstance(name: String) + case class ClockConnection(source: ReferenceTarget, sink: ReferenceTarget) + type ClockConnMap = mutable.Map[ModelInstance, ClockConnection] + + def recordDefaultClocks(topTarget: ModuleTarget, defaultClocks: ClockConnMap)(stmt: Statement): Unit = stmt match { + case Connect(_, WSubField(WRef(lInst, _, InstanceKind, _), lPort, ClockType, _), WSubField(WRef(rInst, _, InstanceKind, _), rPort, ClockType, _)) => + // LHS must be "satellite" model receiving clock from "hub" model on RHS + defaultClocks(ModelInstance(lInst)) = ClockConnection(topTarget.ref(rInst).field(rPort), topTarget.ref(lInst).field(lPort)) + case s => s.foreach(recordDefaultClocks(topTarget, defaultClocks)) + } + + def addDefaultClocks(hubModel: ModelInstance, defaultClocks: ClockConnMap)(anno: Annotation): Annotation = anno match { + case fcca @ FAMEChannelConnectionAnnotation(name, info, None, Some(sources), Some(sinks)) => + // unclocked loopback channel + val sourceModelInst = ModelInstance(sources.head.ref) + val sinkModelInst = ModelInstance(sinks.head.ref) + if (sourceModelInst == hubModel) { + fcca.copy(clock = Some(defaultClocks(sinkModelInst).source)) + } else if (sinkModelInst == hubModel) { + fcca.copy(clock = Some(defaultClocks(sourceModelInst).source)) + } else { + fcca + } + case a => a + } + + def getStmts(stmt: Statement): Seq[Statement] = stmt match { + case Block(stmts) => stmts.flatMap(getStmts(_)) + case s => Seq(s) + } + + def execute(state: CircuitState): CircuitState = { + val topModule = state.circuit.modules.find(_.name == state.circuit.main).get.asInstanceOf[Module] + val cTarget = CircuitTarget(state.circuit.main) + val topTarget = cTarget.module(topModule.name) + val defaultClocks = new mutable.LinkedHashMap[ModelInstance, ClockConnection] + + // Find an arbitrary clock channel sink + val refClock = state.annotations.collectFirst { + case FAMEChannelConnectionAnnotation(_, TargetClockChannel(_), None, _, Some(sinks)) => sinks.head + } + + // Get the wrapper top port reference it points to + val refClockPort = refClock.get.ref + + val hubModel = getStmts(topModule.body).collectFirst { + case Connect(_, WSubField(WRef(lInst, _, _, _), _, _, _), WRef(`refClockPort`, _, PortKind, _)) => ModelInstance(lInst) + } + + topModule.body.foreach(recordDefaultClocks(topTarget, defaultClocks)) + state.copy(annotations = state.annotations.map(addDefaultClocks(hubModel.get, defaultClocks))) + } +} diff --git a/sim/midas/src/main/scala/midas/passes/fame/InferModelPorts.scala b/sim/midas/src/main/scala/midas/passes/fame/InferModelPorts.scala index f5d353a4..b9cd52cf 100644 --- a/sim/midas/src/main/scala/midas/passes/fame/InferModelPorts.scala +++ b/sim/midas/src/main/scala/midas/passes/fame/InferModelPorts.scala @@ -17,16 +17,23 @@ class InferModelPorts extends Transform { def inputForm = LowForm def outputForm = LowForm + def portAnnos(mt: ModuleTarget, cName: String, clk: Option[Port], ports: Seq[Port]): Seq[Annotation] = { + val clkRT = clk.map(p => mt.ref(p.name)) + val portsRT = ports.map(p => mt.ref(p.name)) + val fcpa = FAMEChannelPortsAnnotation(cName, clkRT, portsRT) + // Label all the channel ports with don't touch so as to prevent + // annotation renaming from breaking downstream + fcpa +: (clkRT ++: portsRT).map(DontTouchAnnotation(_)) + } + override def execute(state: CircuitState): CircuitState = { val analysis = new FAMEChannelAnalysis(state, FAME1Transform) val cTarget = CircuitTarget(state.circuit.main) - val modelChannelPortsAnnos = analysis.modulePortDedupers.flatMap(deduper => - deduper.completePortMap.flatMap({ case (cName, ports) => Seq( - FAMEChannelPortsAnnotation(cName, ports.map(p => deduper.mTarget.ref(p.name)))) ++ - // Label all the channel ports with don't touch so as to prevent - // annotation renaming from breaking downstream - ports.map(p => DontTouchAnnotation(deduper.mTarget.ref(p.name))) - })) + val modelChannelPortsAnnos = analysis.modulePortDedupers.flatMap { + case deduper => deduper.completePortMap.flatMap { + case (cName, (clk, ports)) => portAnnos(deduper.mTarget, cName, clk, ports) + } + } state.copy(annotations = state.annotations ++ modelChannelPortsAnnos) } } diff --git a/sim/midas/src/main/scala/midas/passes/fame/LabelSRAMModels.scala b/sim/midas/src/main/scala/midas/passes/fame/LabelSRAMModels.scala index 240002d7..ec7b4b2e 100644 --- a/sim/midas/src/main/scala/midas/passes/fame/LabelSRAMModels.scala +++ b/sim/midas/src/main/scala/midas/passes/fame/LabelSRAMModels.scala @@ -7,7 +7,7 @@ import Mappers._ import ir._ import annotations._ import collection.mutable.ArrayBuffer -import midas.targetutils.FirrtlMemModelAnnotation +import midas.targetutils.{FirrtlMemModelAnnotation, FirrtlFAMEModelAnnotation} class LabelSRAMModels extends Transform { def inputForm = HighForm @@ -41,7 +41,7 @@ class LabelSRAMModels extends Transform { val wrapper = mem2Module(mem).copy(name = moduleNS.newName(mem.name)) val wrapperTarget = ModuleTarget(circ.main, wrapper.name) memModules += wrapper - memModelAnnotations += FAMEModelAnnotation(mt.instOf(mem.name, wrapper.name)) + memModelAnnotations += FirrtlFAMEModelAnnotation(mt.instOf(mem.name, wrapper.name)) memModelAnnotations ++= mem.readers.map(rp => ModelReadPort(wrapperTarget.ref(rp))) memModelAnnotations ++= mem.writers.map(rp => ModelWritePort(wrapperTarget.ref(rp))) memModelAnnotations ++= mem.readwriters.map(rp => ModelReadWritePort(wrapperTarget.ref(rp))) diff --git a/sim/midas/src/main/scala/midas/passes/fame/MultiThreadFAME5Models.scala b/sim/midas/src/main/scala/midas/passes/fame/MultiThreadFAME5Models.scala new file mode 100644 index 00000000..847c733c --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/fame/MultiThreadFAME5Models.scala @@ -0,0 +1,213 @@ +// See LICENSE for license details. + +package midas.passes.fame + +import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ +import firrtl.annotations.ModuleTarget +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.Utils.BoolType + +import midas.targetutils.FirrtlEnableModelMultiThreadingAnnotation + +import collection.mutable + +import midas.passes._ + +// PREREQUISITE: Top-level simulator is already fully FAME-Transformed +// ASSUMPTION: Each channel is bulk-connected to EXACTLY ONE top-level port +// ASSUMPTION: The direction of a channel port on a model mirrors its data flow direction + +trait ReadyValidSignal { + val ref: Expression + def ready: WSubField = WSubField(ref, "ready", BoolType, UNKNOWNGENDER) + def valid: WSubField = WSubField(ref, "valid", BoolType, UNKNOWNGENDER) + def bits: WSubField = WSubField(ref, "bits", UnknownType, UNKNOWNGENDER) +} + +case class ReadyValidSink(ref: Expression) extends ReadyValidSignal + +case class ReadyValidSource(ref: Expression) extends ReadyValidSignal + +object FAME5Info { + def info = FileInfo(StringLit("@ [Added during FAME5Transform]")) +} + +object Counter { + def apply(nVals: Integer, hostClock: Expression, hostReset: Expression)(implicit ns: Namespace): SignalInfo = { + val maxLit = UIntLiteral(BigInt(nVals - 1)) + val decl = DefRegister(FAME5Info.info, ns.newName("threadIdx"), maxLit.tpe, hostClock, hostReset, UIntLiteral(0)) + val ref = WRef(decl) + val inc = DoPrim(PrimOps.Add, Seq(ref, UIntLiteral(1)), Nil, UnknownType) + val wrap = DoPrim(PrimOps.Eq, Seq(ref, maxLit), Nil, BoolType) + val assign = Connect(FAME5Info.info, ref, Mux(wrap, UIntLiteral(0), inc, UnknownType)) + SignalInfo(decl, assign, ref) + } +} + +class StaticArbiter(counter: SignalInfo) { + private def muxIdx(select: Expression, signals: Seq[Expression]): Expression = { + signals.zipWithIndex.tail.foldLeft(signals.head) { case (fval, (tval, idx)) => + Mux(DoPrim(PrimOps.Eq, Seq(select, UIntLiteral(idx)), Nil, BoolType), tval, fval, UnknownType) + } + } + + private def counterMask(counter: SignalInfo, idx: Integer, signal: WSubField): DoPrim = { + val eq = DoPrim(PrimOps.Eq, Seq(counter.ref, UIntLiteral(BigInt(idx))), Nil, BoolType) + DoPrim(PrimOps.And, Seq(signal, eq), Nil, BoolType) + } + + def mux(sink: ReadyValidSink, sources: Seq[ReadyValidSource]): Statement = { + val valid = muxIdx(counter.ref, sources.map(_.valid)) + val validConn = Connect(FAME5Info.info, sink.valid, valid) + val bits = muxIdx(counter.ref, sources.map(_.bits)) + val bitsConn = Connect(FAME5Info.info, sink.bits, bits) + val readyConns = sources.zipWithIndex.map { + case (source, idx) => Connect(FAME5Info.info, source.ready, counterMask(counter, idx, sink.ready)) + } + Block(validConn +: bitsConn +: readyConns) + } + + def demux(sinks: Seq[ReadyValidSink], source: ReadyValidSource): Statement = { + val ready = muxIdx(counter.ref, sinks.map(_.ready)) + val readyConn = Connect(FAME5Info.info, source.ready, ready) + val bitsConns = sinks.map(sink => Connect(FAME5Info.info, sink.bits, source.bits)) + val validConns = sinks.zipWithIndex.map { + case (sink, idx) => Connect(FAME5Info.info, sink.valid, counterMask(counter, idx, source.valid)) + } + Block(readyConn +: bitsConns ++: validConns) + } +} + +object MultiThreadFAME5Models extends Transform { + def inputForm = HighForm + def outputForm = HighForm + + type TopoMap = Map[OfModule, Map[String, mutable.Map[Instance, WRef]]] + + private def analyzeAndPruneTopo(fame5InstMap: Map[Instance, OfModule], topo: TopoMap)(stmt: Statement): Statement = { + def processChannelConn(inst: Instance, modelPort: String, topConn: WRef) = { + if (fame5InstMap.contains(inst)) { + topo(fame5InstMap(inst))(modelPort)(inst) = topConn + EmptyStmt + } else { + stmt + } + } + + stmt match { + case Connect(_, WSubField(WRef(iName, _, InstanceKind, _), ipName, BundleType(_), _), wr: WRef) => + // Could infer that this is an input channel, but that would be an extra assumption + processChannelConn(Instance(iName), ipName, wr) + case Connect(_, wr: WRef, WSubField(WRef(iName, _, InstanceKind, _), ipName, BundleType(_), _)) => + // Could infer that this is an output channel, but that would be an extra assumption + processChannelConn(Instance(iName), ipName, wr) + case Connect(_, WSubField(WRef(iName, _, InstanceKind, _), _, _, _), _) if fame5InstMap.contains(Instance(iName)) => + EmptyStmt // prune non-channel connections + case WDefInstance(_, iName, mName, _) if fame5InstMap.contains(Instance(iName)) => + EmptyStmt + case s => + s.map(analyzeAndPruneTopo(fame5InstMap, topo)) + } + } + + private def findFAME5(modInsts: collection.Map[OfModule, mutable.LinkedHashSet[Instance]])(stmt: Statement): Unit = { + stmt match { + case WDefInstance(_, iName, mName, _) => modInsts.get(OfModule(mName)).foreach(iSet => iSet += Instance(iName)) + case s => s.foreach(findFAME5(modInsts)) + } + } + + override def execute(state: CircuitState): CircuitState = { + val moduleDefs = state.circuit.modules.collect({ case m: Module => OfModule(m.name) -> m}).toMap + + val top = moduleDefs(OfModule(state.circuit.main)) + implicit val ns = Namespace(top) + + val hostClock = WRef(top.ports.find(_.name == WrapTop.hostClockName).get) + val hostReset = WRef(top.ports.find(_.name == WrapTop.hostResetName).get) + + // Populate keys from annotations, values from traversing statements + val fame5RawInstances = new mutable.LinkedHashMap[OfModule, mutable.LinkedHashSet[Instance]] + state.annotations.foreach { + case FirrtlEnableModelMultiThreadingAnnotation(ModuleTarget(_, m)) => + fame5RawInstances(OfModule(m)) = new mutable.LinkedHashSet[Instance] + case _ => + } + + top.body.foreach(findFAME5(fame5RawInstances)) + + // filter models with one instance to avoid needless multithreading + val fame5InstancesByModule = fame5RawInstances.filter { case (k, v) => v.size > 1 } + val fame5ModulesByInstance = fame5InstancesByModule.flatMap({ case (k, v) => v.map(vv => vv -> k) }).toMap + + // Maps from an (OfModule, PortName) pair + // It's actually nested (rather than indexed by tuple) for convenience + // We don't track the direction in this map, since we can just find it later + val fame5Topo: TopoMap = fame5InstancesByModule.map({ + case (m, iSet) => m -> moduleDefs(m).ports.collect({ + case Port(_, name, _, BundleType(_)) => name -> new mutable.HashMap[Instance, WRef] + }).toMap + }).toMap + + val prunedTopoTopBody = top.body.map(analyzeAndPruneTopo(fame5ModulesByInstance, fame5Topo)) + + // For now, just one FAME5 model + assert(fame5InstancesByModule.keySet.size <= 1) + + val nThreads = fame5InstancesByModule.headOption.map(_._2.size).getOrElse(1) + + val circuitNS = Namespace(state.circuit) + val threadedModuleNames = state.circuit.modules.collect({ + // Don't replace blackbox instances! TODO: Check for illegal blackboxes. + case m: Module => m.name -> circuitNS.newName(s"${m.name}_threaded") + }).toMap + + val threadedInstances = fame5InstancesByModule.map({ + case (m, insts) => m -> WDefInstance(FAME5Info.info, ns.newName(s"${m.value}_threaded"), threadedModuleNames(m.value), UnknownType) + }).toMap + + val threadCounters = fame5InstancesByModule.map { case (m, insts) => m -> Counter(insts.size, hostClock, hostReset) } + + val multiThreadedConns: Seq[Statement] = fame5Topo.toSeq.flatMap { + case (mod, connsByPort) => + val arbiter = new StaticArbiter(threadCounters(mod)) + connsByPort.toSeq.map { + case (port, connsByInstance) => + val canonicalOrderedConns = fame5InstancesByModule(mod).map(inst => connsByInstance(inst)).toSeq + if (moduleDefs(mod).ports.exists(p => p.name == port && p.direction == Input)) { + val sink = ReadyValidSink(WSubField(WRef(threadedInstances(mod)), port)) + val sources = canonicalOrderedConns.map(s => ReadyValidSource(s)) + arbiter.mux(sink, sources) + } else { + val sinks = canonicalOrderedConns.map(s => ReadyValidSink(s)) + val source = ReadyValidSource(WSubField(WRef(threadedInstances(mod)), port)) + arbiter.demux(sinks, source) + } + } + } + + val insts = threadedInstances.toSeq.map { case (k, v) => v } // keep ordering + val counters = threadCounters.toSeq.map { case (k, v) => v } // keep ordering + val clockConns = insts.map(i => Connect(FAME5Info.info, WSubField(WRef(i), WrapTop.hostClockName), WRef(WrapTop.hostClockName))) + val resetConns = insts.map(i => Connect(FAME5Info.info, WSubField(WRef(i), WrapTop.hostResetName), WRef(WrapTop.hostResetName))) + + val prologue = insts ++: clockConns ++: resetConns ++: counters.flatMap(c => Seq(c.decl, c.assigns)) + val multiThreadedTopBody = Block(prologue ++: prunedTopoTopBody +: multiThreadedConns) + + val transformedModules = state.circuit.modules.flatMap { + case m: Module if (m.name == state.circuit.main) => + Seq(m.copy(body = multiThreadedTopBody)) + case m: Module => + val threaded = MuxingMultiThreader(threadedModuleNames)(m, nThreads) // all threaded by same amount, many get pruned + Seq(m, threaded) + case m => Seq(m) + } + + // TODO: Renames! + + state.copy(circuit = state.circuit.copy(modules = transformedModules)) + } +} diff --git a/sim/midas/src/main/scala/midas/passes/fame/MultiThreader.scala b/sim/midas/src/main/scala/midas/passes/fame/MultiThreader.scala new file mode 100644 index 00000000..78955a3b --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/fame/MultiThreader.scala @@ -0,0 +1,211 @@ +// See LICENSE for license details. + +package midas.passes.fame + +import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ +import firrtl.Utils.{BoolType, kind, zero, one} + +import midas.passes.SignalInfo + +import collection.mutable +import collection.immutable.HashMap + +/* TODO: + * + * This got really messy because of the need to emulate gated + * clocks. This involves recovering the enable from the gated clock + * and using it to select either the output of the logic or the + * shadowed previous value of the currently active register slot. + * + * Also TODO: maybe use fewer than 200 characters per line. + */ + +// Utility to help "float" instance declarations to the top of a block for convenience +object SeparateInstanceDecls { + private def onStmt(insts: mutable.ArrayBuffer[WDefInstance])(stmt: Statement): Statement = stmt match { + case wi: WDefInstance => + insts.append(wi) + EmptyStmt + case s => s.map(onStmt(insts)) + } + + def apply(stmt: Statement): (Seq[WDefInstance], Statement) = { + val insts = new mutable.ArrayBuffer[WDefInstance] + val otherStmts = onStmt(insts)(stmt) + (insts.toSeq, otherStmts) + } +} + +object AddHostClockAndReset { + val hostClock = Port(FAME5Info.info, WrapTop.hostClockName, Input, ClockType) + val hostReset = Port(FAME5Info.info, WrapTop.hostResetName, Input, BoolType) + + def apply(m: Module): Module = { + m.copy(ports = m.ports ++ Seq(hostClock, hostReset).filterNot(p => m.ports.map(_.name).contains(p.name))) + } + + def apply(wi: WDefInstance): Statement = { + // Adds connections to instance hostClock/hostReset ports + val hcConn = Connect(FAME5Info.info, WSubField(WRef(wi), WrapTop.hostClockName), WRef(WrapTop.hostClockName)) + val hrConn = Connect(FAME5Info.info, WSubField(WRef(wi), WrapTop.hostResetName), WRef(WrapTop.hostResetName)) + Block(Seq(wi, hcConn, hrConn)) + } +} + +object Toggle { + def apply(r: DefRegister): Statement = { + Connect(FAME5Info.info, WRef(r), DoPrim(PrimOps.Not, Seq(WRef(r)), Nil, r.tpe)) + } +} + +// Invalidates WIR _.tpe (types of memories change) +object MultiThreader { + + type FreshNames = Map[String, Seq[String]] + + def renameRegs(freshNames: FreshNames, n: BigInt, ns: Namespace, stmt: Statement): FreshNames = stmt match { + case Block(stmts) => + stmts.foldLeft(freshNames) { case (rm, s) => renameRegs(rm, n, ns, s) } + case Conditionally(_, _, cons, alt) => + renameRegs(renameRegs(freshNames, n, ns, cons), n, ns, alt) + case reg: DefRegister => + // One extra register to shadow value for clock-gated registers + val regNames = (0 to n.intValue()).map(i => ns.newName(s"${reg.name}_slot_${i}")) + freshNames.updated(reg.name, regNames) + case s => freshNames + } + + def replaceRegRefsLHS(freshNames: FreshNames)(expr: Expression): Expression = expr match { + case wr @ WRef(name, _, RegKind, _) => wr.copy(name = freshNames(name).head) + case e => e.map(replaceRegRefsLHS(freshNames)) + } + + def replaceRegRefsRHS(freshNames: FreshNames)(expr: Expression): Expression = expr match { + // 2nd-to-last slot feeds logic; last slot holds a shadow value + case wr @ WRef(name, _, RegKind, _) => wr.copy(name = freshNames(name).init.last) + case e => e.map(replaceRegRefsRHS(freshNames)) + } + + def replaceRegRefs(freshNames: FreshNames)(expr: Expression): Expression = expr match { + case wr @ WRef(name, _, RegKind, SinkFlow) => wr.copy(name = freshNames(name).head) + // 2nd-to-last slot feeds logic; last slot holds a shadow value + case wr @ WRef(name, _, RegKind, SourceFlow) => wr.copy(name = freshNames(name).init.last) + case e => e.map(replaceRegRefs(freshNames)) + } + + def transformDepth(depth: BigInt, n: BigInt): BigInt = { + require(n.bitCount == 1) // pow2 threads for now + depth * n + } + + def transformAddr(counter: DefRegister, expr: Expression): Expression = { + DoPrim(PrimOps.Cat, Seq(expr, WRef(counter)), Nil, UnknownType) + } + + def updateReg(reg: DefRegister, slotName: String): DefRegister = { + // Keep self-resets as self-resets + val newInit = reg.init match { + case wr: WRef if (wr.name == reg.name) => wr.copy(name = slotName) + case e => e + } + // All slot registers are free-running + reg.copy(name = slotName, init = newInit, clock = WRef(WrapTop.hostClockName)) + } + + def multiThread(freshNames: FreshNames, edgeStatus: collection.Map[WrappedExpression, SignalInfo], n: BigInt, counter: DefRegister)(stmt: Statement): Statement = { + stmt match { + case mem: DefMemory => + assert(mem.readLatency == 0 && mem.writeLatency == 1, "Memories must be transformed with VerilogMemDelays before multithreading") + mem.copy(depth = transformDepth(mem.depth, n), readLatency = mem.readLatency * n.intValue()) + case reg: DefRegister => + val newRegs: Seq[DefRegister] = freshNames(reg.name).map(alias => updateReg(reg, alias)) + // Muxing happens between first two stages + val useNew = edgeStatus(WrappedExpression(reg.clock)).ref + val gatedUpdate = Connect(FAME5Info.info, WRef(newRegs.tail.head), Mux(useNew, WRef(newRegs.head), WRef(newRegs.last), UnknownType)) + // Other stages are straight connections + val directPairs = newRegs.tail zip newRegs.tail.tail + val directConns = directPairs.map { case (a, b) => Connect(FAME5Info.info, WRef(b), WRef(a)) } + Block(newRegs ++: gatedUpdate +: directConns) + case Connect(info, lhs @ WSubField(p: WSubField, "addr", _, _), rhs) if kind(lhs) == MemKind => + Connect(info, replaceRegRefsLHS(freshNames)(lhs), transformAddr(counter, replaceRegRefsRHS(freshNames)(rhs))) + case Connect(info, lhs, rhs) => + // TODO: LHS vs RHS is kind of a hack + // We need a new method to swap register refs on the LHS because VerilogMemDelays puts in un-gendered register refs + Connect(info, replaceRegRefsLHS(freshNames)(lhs), replaceRegRefsRHS(freshNames)(rhs)) + case s => s.map(multiThread(freshNames, edgeStatus, n, counter)).map(replaceRegRefsRHS(freshNames)) + } + } + + def findClocks(clocks: mutable.Set[WrappedExpression])(stmt: Statement): Unit = { + def findClocksExpr(expr: Expression): Unit = { + if (expr.tpe == ClockType && Utils.gender(expr) == MALE) { + clocks += WrappedExpression(expr) + } + expr.foreach(findClocksExpr) + } + + stmt.foreach(findClocksExpr) + stmt.foreach(findClocks(clocks)) + } + + def apply(threadedModuleNames: Map[String, String])(module: Module, n: BigInt): Module = { + // TODO: this is ugly and uses copied code instead of bumping FIRRTL + // Simplify all memories first + val loweredMod = (new MemDelayAndReadwriteTransformer(module)).transformed.asInstanceOf[Module] + + val ns = Namespace(loweredMod) + val hostClock = WRef(WrapTop.hostClockName) + val hostReset = WRef(WrapTop.hostResetName) + + val clocks = new mutable.LinkedHashSet[WrappedExpression] + loweredMod.body.foreach(findClocks(clocks)) + + val edgeStatus = new mutable.LinkedHashMap[WrappedExpression, SignalInfo] + clocks.foreach { + case we => we.e1 match { + case WRef(WrapTop.hostClockName, _, _, _) => + // Optimization -- don't generate this gate recovery stuff for host clock + edgeStatus(we) = SignalInfo(EmptyStmt, EmptyStmt, UIntLiteral(1)) + case e => + val edgeCount = DefRegister(FAME5Info.info, ns.newName("edgeCount"), BoolType, e, hostReset, UIntLiteral(0)) + val updateCount = DefRegister(FAME5Info.info, ns.newName("updateCount"), BoolType, hostClock, hostReset, UIntLiteral(0)) + val neq = DoPrim(PrimOps.Neq, Seq(WRef(edgeCount), WRef(updateCount)), Nil, BoolType) + val trackUpdates = Conditionally(FAME5Info.info, neq, Toggle(updateCount), EmptyStmt) + edgeStatus(we) = SignalInfo(Block(Seq(edgeCount, updateCount)), Block(Seq(Toggle(edgeCount), trackUpdates)), neq) + } + } + + val tidxMax = UIntLiteral(n-1) + val tidxType = UIntType(tidxMax.width) + val tidxRef = WRef(ns.newName("threadIdx"), tidxType, RegKind) + val tidxDecl = DefRegister(FAME5Info.info, tidxRef.name, tidxType, hostClock, zero, tidxRef) + val tidxUpdate = Mux( + DoPrim(PrimOps.Eq, Seq(tidxRef, tidxMax), Nil, BoolType), + UIntLiteral(0), + DoPrim(PrimOps.Add, Seq(tidxRef, one), Nil, BoolType), + tidxType) + val tidxConn = Connect(FAME5Info.info, tidxRef, tidxUpdate) + + val freshNames = renameRegs(HashMap.empty, n, ns, loweredMod.body) + val threaded = multiThread(freshNames, edgeStatus, n, tidxDecl)(loweredMod.body) + + // Uses only threaded instances + val (iDecls, body) = SeparateInstanceDecls(threaded) + + val threadedChildren = iDecls.map { + case i if (threadedModuleNames.contains(i.module)) => + AddHostClockAndReset(i.copy(module = threadedModuleNames(i.module))) + case i => i + } + + val clockGaters = edgeStatus.toSeq.map { case (k, v) => v } + val threadedBody = Block(threadedChildren ++ clockGaters.map(_.decl) ++ Seq(tidxDecl, tidxConn, body) ++ clockGaters.map(_.assigns)) + + val hostPorts = Seq(Port(FAME5Info.info, hostClock.name, Input, ClockType), Port(FAME5Info.info, hostReset.name, Input, BoolType)) + val newPorts = module.ports ++ hostPorts.filterNot(p => module.ports.map(_.name).contains(p.name)) + Module(FAME5Info.info ++ module.info, threadedModuleNames(module.name), newPorts, threadedBody) + } +} diff --git a/sim/midas/src/main/scala/midas/passes/fame/MuxingMultithreader.scala b/sim/midas/src/main/scala/midas/passes/fame/MuxingMultithreader.scala new file mode 100644 index 00000000..e5d9db31 --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/fame/MuxingMultithreader.scala @@ -0,0 +1,106 @@ +// See LICENSE for license details. + +package midas.passes.fame + +import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.passes.MemPortUtils._ +import firrtl.traversals.Foreachers._ +import firrtl.Utils.{BoolType, kind} + +import collection.mutable.ArrayBuffer + +// Invalidates WIR _.tpe (types of memories change) +object MuxingMultiThreader { + + type FreshNames = Map[String, Seq[String]] + + val rPortName = "read" + val wPortName = "write" + + def rField(mem: DefMemory, field: String): Expression = memPortField(mem, rPortName, field) + def wField(mem: DefMemory, field: String): Expression = memPortField(mem, wPortName, field) + + def onExprRHS(expr: Expression): Expression = expr match { + case WRef(name, tpe, RegKind, _) => + WSubField(WSubField(WRef(name, tpe, MemKind), rPortName), "data") + case e => e.map(onExprRHS) + } + + def regWriteAsMemWrite(info: Info, name: String, tpe: Type, rhs: Expression): Statement = { + val infos = FAME5Info.info ++ info + val wPort = WSubField(WRef(name), wPortName) + val dataConnect = Connect(infos, WSubField(wPort, "data"), rhs) + val enConnect = Connect(infos, WSubField(wPort, "en"), UIntLiteral(1)) + Block(Seq(dataConnect, enConnect)) + } + + def onStmt(newResets: ArrayBuffer[Statement], nThreads: BigInt, tIdx: Expression)(stmt: Statement): Statement = stmt match { + case DefRegister(info, name, tpe, clock, reset, init) => + val infos = FAME5Info.info ++ info + val mem = DefMemory(infos, name, tpe, nThreads, 1, 0, Seq(rPortName), Seq(wPortName), Nil) + val rClockConn = Connect(infos, rField(mem, "clk"), clock) + val rEnConn = Connect(infos, rField(mem, "en"), UIntLiteral(1)) + val rAddrConn = Connect(infos, rField(mem, "addr"), tIdx) + val wClockConn = Connect(infos, wField(mem, "clk"), clock) + val wEnDefault = Connect(infos, wField(mem, "en"), UIntLiteral(0)) + val wMaskConn = Connect(infos, wField(mem, "mask"), UIntLiteral(1)) + val wAddrConn = Connect(infos, wField(mem, "addr"), tIdx) + if (WrappedExpression(init) != WrappedExpression(WRef(name))) { + val doReset = regWriteAsMemWrite(info, name, tpe, onExprRHS(init)) + newResets += Conditionally(infos, reset, doReset, EmptyStmt) + } + Block(Seq(mem, rClockConn, rEnConn, rAddrConn, wClockConn, wEnDefault, wMaskConn, wAddrConn)) + case Connect(info, lhs @ WSubField(p: WSubField, "addr", _, _), rhs) if kind(lhs) == MemKind => + Connect(FAME5Info.info ++ info, lhs, DoPrim(PrimOps.Cat, Seq(onExprRHS(rhs), tIdx), Nil, UnknownType)) + case Connect(info, WRef(name, tpe, RegKind, _), rhs) => + regWriteAsMemWrite(info, name, tpe, onExprRHS(rhs)) + case Connect(_, lhs, _) if (kind(lhs) == RegKind) => + throw CustomTransformException(new IllegalArgumentException(s"Cannot handle complex register assignment to ${lhs}")) + case mem: DefMemory => + require(mem.readLatency == 0, "Memories must be transformed with VerilogMemDelays before multithreading") + require(mem.readLatency == 0, "Memories must have one-cycle write latency") + require(nThreads.bitCount == 1, "Models may only be threaded by pow2 threads for now") + mem.copy(depth = mem.depth * nThreads) + case s => s.map(onStmt(newResets, nThreads, tIdx)).map(onExprRHS) + } + + def apply(threadedModuleNames: Map[String, String])(module: Module, n: BigInt): Module = { + // TODO: this is ugly and uses copied code instead of bumping FIRRTL + // Simplify all memories first + val loweredMod = (new MemDelayAndReadwriteTransformer(module)).transformed.asInstanceOf[Module] + val ns = Namespace(loweredMod) + + val hostClock = WRef(WrapTop.hostClockName) + val hostReset = WRef(WrapTop.hostResetName) + + val tIdxMax = UIntLiteral(n-1) + val tIdxType = UIntType(tIdxMax.width) + val tIdxRef = WRef(ns.newName("threadIdx"), tIdxType, RegKind) + val tIdxDecl = DefRegister(FAME5Info.info, tIdxRef.name, tIdxType, hostClock, UIntLiteral(0), tIdxRef) + val tIdxUpdate = Mux( + DoPrim(PrimOps.Eq, Seq(tIdxRef, tIdxMax), Nil, BoolType), + UIntLiteral(0), + DoPrim(PrimOps.Add, Seq(tIdxRef, UIntLiteral(1)), Nil, BoolType), + tIdxType) + val tIdxConn = Connect(FAME5Info.info, tIdxRef, tIdxUpdate) + + // Resets transformed to conditional stores to threading RAMs + val newResets = new ArrayBuffer[Statement] + val threaded = onStmt(newResets, n, tIdxRef)(loweredMod.body) + + // Uses only threaded instances + val (iDecls, threadedImpl) = SeparateInstanceDecls(threaded) + + // TODO: earlier in the compiler, every module should get hostClock/hostReset ports hooked up + val threadedChildren = iDecls.map { + case i if (threadedModuleNames.contains(i.module)) => + AddHostClockAndReset(i.copy(module = threadedModuleNames(i.module))) + case i => i + } + + val threadedBody = Block(threadedChildren ++: tIdxDecl +: tIdxConn +: threadedImpl +: newResets.toSeq) + AddHostClockAndReset(Module(FAME5Info.info ++ module.info, threadedModuleNames(module.name), module.ports, threadedBody)) + } +} diff --git a/sim/midas/src/main/scala/midas/passes/fame/PatientSSMTransformers.scala b/sim/midas/src/main/scala/midas/passes/fame/PatientSSMTransformers.scala new file mode 100644 index 00000000..edb874db --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/fame/PatientSSMTransformers.scala @@ -0,0 +1,67 @@ +// See LICENSE for license details. + +package midas.passes.fame + +import java.io.{PrintWriter, File} + +import firrtl._ +import ir._ +import Mappers._ +import firrtl.Utils.{BoolType, kind, ceilLog2, one} +import firrtl.passes.MemPortUtils +import annotations._ +import scala.collection.mutable +import mutable.{LinkedHashSet, LinkedHashMap} + +import midas.passes._ + +object PatientMemTransformer { + def apply(mem: DefMemory, finishing: Expression, memClock: WRef, ns: Namespace): Block = { + val shim = DefWire(NoInfo, mem.name, MemPortUtils.memType(mem)) + val newMem = mem.copy(name = ns.newName(mem.name)) + val defaultConnect = Connect(NoInfo, WRef(shim), WRef(newMem.name, shim.tpe, MemKind)) + val syncReadPorts = (newMem.readers ++ newMem.readwriters).filter(rp => mem.readLatency > 0) + val preserveReads = syncReadPorts.flatMap { + case rpName => + val addrWidth = IntWidth(ceilLog2(mem.depth) max 1) + val dummyReset = DefWire(NoInfo, ns.newName(s"${mem.name}_${rpName}_dummyReset"), BoolType) + val tieOff = Connect(NoInfo, WRef(dummyReset), UIntLiteral(0)) + val addrReg = new DefRegister(NoInfo, ns.newName(s"${mem.name}_${rpName}"), + UIntType(addrWidth), memClock, WRef(dummyReset), UIntLiteral(0, addrWidth)) + val updateReg = Connect(NoInfo, WRef(addrReg), WSubField(WSubField(WRef(shim), rpName), "addr")) + val useReg = Connect(NoInfo, MemPortUtils.memPortField(newMem, rpName, "addr"), WRef(addrReg)) + Seq(dummyReset, tieOff, addrReg, Conditionally(NoInfo, finishing, updateReg, useReg)) + } + val gateWrites = (newMem.writers ++ newMem.readwriters).map { + case wpName => + Conditionally( + NoInfo, + Negate(finishing), + Connect(NoInfo, MemPortUtils.memPortField(newMem, wpName, "en"), UIntLiteral(0, IntWidth(1))), + EmptyStmt) + } + new Block(Seq(shim, newMem, defaultConnect) ++ preserveReads ++ gateWrites) + } +} + +object PatientSSMTransformer { + def apply(m: Module, analysis: FAMEChannelAnalysis)(implicit triggerName: String): Module = { + val ns = Namespace(m) + val clocks = m.ports.filter(_.tpe == ClockType) + // TODO: turn this back on + // assert(clocks.length == 1) + val finishing = new Port(NoInfo, ns.newName(triggerName), Input, BoolType) + val hostClock = clocks.find(_.name == "clock").getOrElse(clocks.head) // TODO: naming convention for host clock + def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match { + case conn @ Connect(info, lhs, _) if (kind(lhs) == RegKind) => + Conditionally(info, WRef(finishing), conn, EmptyStmt) + case s: Stop => s.copy(en = DoPrim(PrimOps.And, Seq(WRef(finishing), s.en), Seq.empty, BoolType)) + case p: Print => p.copy(en = DoPrim(PrimOps.And, Seq(WRef(finishing), p.en), Seq.empty, BoolType)) + case mem: DefMemory => PatientMemTransformer(mem, WRef(finishing), WRef(hostClock), ns) + case wi: WDefInstance if analysis.syncNativeModules.contains(analysis.moduleTarget(wi)) => + new Block(Seq(wi, Connect(wi.info, WSubField(WRef(wi), triggerName), WRef(finishing)))) + case s => s + } + Module(m.info, m.name, m.ports :+ finishing, m.body.map(onStmt)) + } +} diff --git a/sim/midas/src/main/scala/midas/passes/fame/RTLUtils.scala b/sim/midas/src/main/scala/midas/passes/fame/RTLUtils.scala index e8658fb7..27668e14 100644 --- a/sim/midas/src/main/scala/midas/passes/fame/RTLUtils.scala +++ b/sim/midas/src/main/scala/midas/passes/fame/RTLUtils.scala @@ -48,11 +48,31 @@ object Negate { def apply(arg: Expression): Expression = DoPrim(PrimOps.Not, Seq(arg), Seq.empty, arg.tpe) } -object Reduce { - private def _reduce(op: PrimOp, args: Iterable[Expression]): Expression = { - args.tail.foldLeft(args.head){ (l, r) => DoPrim(op, Seq(l, r), Seq.empty, UIntType(IntWidth(1))) } +sealed trait BinaryBooleanOp { + def op: PrimOp + def apply(l: Expression, r: Expression): DoPrim = DoPrim(op, Seq(l, r), Nil, BoolType) + def reduce(args: Iterable[Expression]): Expression = { + require(args.nonEmpty) + args.tail.foldLeft(args.head){ (l, r) => apply(l, r) } } - def and(args: Iterable[Expression]): Expression = _reduce(PrimOps.And, args) - def or(args: Iterable[Expression]): Expression = _reduce(PrimOps.Or, args) } +object And extends BinaryBooleanOp { + val op = PrimOps.And +} + +object Or extends BinaryBooleanOp { + val op = PrimOps.Or +} + +object Neq extends BinaryBooleanOp { + val op = PrimOps.Neq +} + +/** Generates a DefRegister with no reset, relying instead on FPGA programming + * to preset the register to 0 + */ +object RegZeroPreset { + def apply(info: Info, name: String, tpe: Type, clock: Expression): DefRegister = + DefRegister(info, name, tpe, clock, zero, WRef(name)) +} diff --git a/sim/midas/src/main/scala/midas/passes/fame/TrivialChannelExcision.scala b/sim/midas/src/main/scala/midas/passes/fame/TrivialChannelExcision.scala index 699a18ed..e6b068e6 100644 --- a/sim/midas/src/main/scala/midas/passes/fame/TrivialChannelExcision.scala +++ b/sim/midas/src/main/scala/midas/passes/fame/TrivialChannelExcision.scala @@ -34,9 +34,9 @@ class TrivialChannelExcision extends Transform { val fame1Anno = FAMETransformAnnotation(FAME1Transform, ModuleTarget(topName, topChildren.head.module)) val fameChannelAnnos = topModule.ports.collect({ case ip @ Port(_, name, Input, tpe) if !specialSignals.contains(name) => - FAMEChannelConnectionAnnotation(name, WireChannel, None, Some(Seq(ReferenceTarget(topName, topName, Nil, name, Nil)))) + FAMEChannelConnectionAnnotation.implicitlyClockedSink(name, WireChannel, Seq(ReferenceTarget(topName, topName, Nil, name, Nil))) case op @ Port(_, name, Output, tpe) => - FAMEChannelConnectionAnnotation(name, WireChannel, Some(Seq(ReferenceTarget(topName, topName, Nil, name, Nil))), None) + FAMEChannelConnectionAnnotation.implicitlyClockedSource(name, WireChannel, Seq(ReferenceTarget(topName, topName, Nil, name, Nil))) }) state.copy(annotations = state.annotations ++ Seq(fame1Anno) ++ fameChannelAnnos) } diff --git a/sim/midas/src/main/scala/midas/passes/fame/VerilogMemDelays.scala b/sim/midas/src/main/scala/midas/passes/fame/VerilogMemDelays.scala new file mode 100644 index 00000000..758fecde --- /dev/null +++ b/sim/midas/src/main/scala/midas/passes/fame/VerilogMemDelays.scala @@ -0,0 +1,168 @@ +// See LICENSE for license details. + +package midas.passes.fame + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import firrtl.WrappedExpression._ +import firrtl.traversals.Foreachers._ +import firrtl.passes.MemPortUtils._ + +import collection.mutable + +object MemDelayAndReadwriteTransformer { + // Representation of a group of signals and associated valid signals + case class WithValid(valid: Expression, payload: Seq[Expression]) + + // Grouped statements that are split into declarations and connects to ease ordering + case class SplitStatements(decls: Seq[Statement], conns: Seq[Connect]) + + // Utilities for generating hardware + def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType) + def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType) + def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r) + def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe)) + + // Utilities for working with WithValid groups + def connect(l: WithValid, r: WithValid): Seq[Connect] = { + val paired = (l.valid +: l.payload) zip (r.valid +: r.payload) + paired.map { case (le, re) => connect(le, re) } + } + + def condConnect(l: WithValid, r: WithValid): Seq[Connect] = { + connect(l.valid, r.valid) +: (l.payload zip r.payload).map { case (le, re) => condConnect(r.valid)(le, re) } + } + + // Internal representation of a pipeline stage with an associated valid signal + private case class PipeStageWithValid(idx: Int, ref: WithValid, stmts: SplitStatements = SplitStatements(Nil, Nil)) + + // Utilities for creating legal names for registers + private val metaChars = raw"[\[\]\.]".r + private def flatName(e: Expression) = metaChars.replaceAllIn(e.serialize, "_") + + // Pipeline a group of signals with an associated valid signal. Gate registers when possible. + def pipelineWithValid(ns: Namespace)( + clock: Expression, + depth: Int, + src: WithValid, + nameTemplate: Option[WithValid] = None): (WithValid, Seq[Statement], Seq[Connect]) = { + + def asReg(e: Expression) = DefRegister(NoInfo, e.serialize, e.tpe, clock, zero, e) + val template = nameTemplate.getOrElse(src) + + val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) { case prev => + def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind) + val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef)) + val regs = (ref.valid +: ref.payload).map(asReg) + PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref))) + } + (stages.last.ref, stages.flatMap(_.stmts.decls), stages.flatMap(_.stmts.conns)) + } +} + +/** + * This class performs the primary work of the transform: splitting readwrite ports into separate + * read and write ports while simultaneously compiling memory latencies to combinational-read + * memories with delay pipelines. It is represented as a class that takes a module as a constructor + * argument, as it encapsulates the mutable state required to analyze and transform one module. + * + * @note The final transformed module is found in the (sole public) field [[transformed]] + */ +class MemDelayAndReadwriteTransformer(m: DefModule) { + import MemDelayAndReadwriteTransformer._ + + private val ns = Namespace(m) + private val netlist = new collection.mutable.HashMap[WrappedExpression, Expression] + private val exprReplacements = new collection.mutable.HashMap[WrappedExpression, Expression] + private val newConns = new mutable.ArrayBuffer[Connect] + + private def findMemConns(s: Statement): Unit = s match { + case Connect(_, loc, expr) if (kind(loc) == MemKind) => netlist(we(loc)) = expr + case _ => s.foreach(findMemConns) + } + + private def swapMemRefs(e: Expression): Expression = e map swapMemRefs match { + case sf: WSubField => exprReplacements.getOrElse(we(sf), sf) + case ex => ex + } + + private def transform(s: Statement): Statement = s.map(transform) match { + case mem: DefMemory => + // Per-memory bookkeeping + val portNS = Namespace(mem.readers ++ mem.writers) + val rMap = mem.readwriters.map(rw => (rw -> portNS.newName(s"${rw}_r"))).toMap + val wMap = mem.readwriters.map(rw => (rw -> portNS.newName(s"${rw}_w"))).toMap + val newReaders = mem.readers ++ mem.readwriters.map(rMap(_)) + val newWriters = mem.writers ++ mem.readwriters.map(wMap(_)) + val newMem = DefMemory(mem.info, mem.name, mem.dataType, mem.depth, 1, 0, newReaders, newWriters, Nil) + val rCmdDelay = if (mem.readUnderWrite == "old") 0 else mem.readLatency + val rRespDelay = if (mem.readUnderWrite == "old") mem.readLatency else 0 + val wCmdDelay = mem.writeLatency - 1 + + val readStmts = (mem.readers ++ mem.readwriters).map { case r => + def oldDriver(f: String) = netlist(we(memPortField(mem, r, f))) + def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f) + val clk = oldDriver("clk") + + // Pack sources of read command inputs into WithValid object -> different for readwriter + val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en") + val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr"))) + val cmdSink = WithValid(newField("en"), Seq(newField("addr"))) + val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) + val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) + + // Pipeline read response using *last* command pipe stage enable as the valid signal + val resp = WithValid(cmdPiped.valid, Seq(newField("data"))) + val respPipeNameTemplate = Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names + val (respPiped, respDecls, respConns) = pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate) + + // Make sure references to the read data get appropriately substituted + val oldRDataName = if (rMap.contains(r)) "rdata" else "data" + exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head + + // Return all statements; they're separated so connects can go after all declarations + SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns) + } + + val writeStmts = (mem.writers ++ mem.readwriters).map { case w => + def oldDriver(f: String) = netlist(we(memPortField(mem, w, f))) + def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f) + val clk = oldDriver("clk") + + // Pack sources of write command inputs into WithValid object -> different for readwriter + val cmdSrc = if (wMap.contains(w)) { + val en = AND(oldDriver("en"), oldDriver("wmode")) + WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata"))) + } else { + WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data"))) + } + + // Pipeline write command, connect to memory + val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data"))) + val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) + val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) + + // Return all statements; they're separated so connects can go after all declarations + SplitStatements(cmdDecls, cmdConns ++ cmdPortConns) + } + + newConns ++= (readStmts ++ writeStmts).flatMap(_.conns) + Block(newMem +: (readStmts ++ writeStmts).flatMap(_.decls)) + case sx: Connect if kind(sx.loc) == MemKind => EmptyStmt // Filter old mem connections + case sx => sx.map(swapMemRefs) + } + + val transformed = m match { + case mod: Module => + findMemConns(mod.body) + mod.copy(body = Block(transform(mod.body) +: newConns.toSeq)) + case mod => mod + } +} + +object VerilogMemDelays extends passes.Pass { + def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform)) +} diff --git a/sim/midas/src/main/scala/midas/passes/fame/WrapTop.scala b/sim/midas/src/main/scala/midas/passes/fame/WrapTop.scala index f652a31f..3a7a0998 100644 --- a/sim/midas/src/main/scala/midas/passes/fame/WrapTop.scala +++ b/sim/midas/src/main/scala/midas/passes/fame/WrapTop.scala @@ -7,41 +7,55 @@ import ir._ import annotations._ // Wrap the top module of a circuit with another module -class WrapTop extends Transform { +object WrapTop extends Transform { def inputForm = HighForm def outputForm = HighForm + // TODO: Make these names flexible + // Previously, it looked like they were flexible in the code, but they weren't + // This refactors them to reflect that fact + val topWrapperName = "FAMETop" + val hostClockName = "hostClock" + val hostResetName = "hostReset" + + def checkNames(c: Circuit, top: DefModule) = { + val ns = Namespace(top.ports.map(_.name)) + val portsOK = ns.tryName(hostClockName) && ns.tryName(hostResetName) + portsOK && ns.tryName(top.name) && Namespace(c).tryName(topWrapperName) + } + override def execute(state: CircuitState): CircuitState = { val topName = state.circuit.main val topModule = state.circuit.modules.find(_.name == topName).get - val circuitNS = Namespace(state.circuit) - val topWrapperName = circuitNS.newName("FAMETop") - val topWrapperNS = Namespace(topModule.ports.map(_.name)) - val topInstance = WDefInstance(topWrapperNS.newName(topName), topName) + assert(checkNames(state.circuit, topModule)) + + val topInstance = WDefInstance(topName, topName) val portConnections = topModule.ports.map({ case ip @ Port(_, name, Input, _) => Connect(NoInfo, WSubField(WRef(topInstance), name), WRef(ip)) case op @ Port(_, name, Output, _) => Connect(NoInfo, WRef(op), WSubField(WRef(topInstance), name)) }) - val clocks = topModule.ports.filter(_.tpe == ClockType) - val hostClock = clocks.find(_.name == "clock").getOrElse(clocks.head) - val hostReset = HostReset.makePort(topWrapperNS) + + val hostClock = Port(NoInfo, hostClockName, Input, ClockType) + val hostReset = Port(NoInfo, hostResetName, Input, Utils.BoolType) + + val oldCircuitTarget = CircuitTarget(topName) val topWrapperTarget = ModuleTarget(topWrapperName, topWrapperName) - val topWrapper = Module(NoInfo, topWrapperName, topModule.ports :+ hostReset, Block(topInstance +: portConnections)) + val topWrapper = Module(NoInfo, topWrapperName, hostClock +: hostReset +: topModule.ports, Block(topInstance +: portConnections)) val specialPortAnnotations = Seq(FAMEHostClock(topWrapperTarget.ref(hostClock.name)), FAMEHostReset(topWrapperTarget.ref(hostReset.name))) + val renames = RenameMap() val newCircuit = Circuit(state.circuit.info, topWrapper +: state.circuit.modules, topWrapperName) // Make channel annotations point at top-level ports + val fccaRenames = RenameMap() + fccaRenames.record(oldCircuitTarget.module(topName), oldCircuitTarget.module(topWrapperName)) + val updatedAnnotations = state.annotations.map({ - case fca: FAMEChannelConnectionAnnotation => - fca.copy(sinks = fca.sinks.map(_.map(_.copy(module = topWrapperName))), sources = fca.sources.map(_.map(_.copy(module = topWrapperName)))) - case a => a - }).map({ // Also update targets in info fields - case fca @ FAMEChannelConnectionAnnotation(_,info@DecoupledForwardChannel(_,_,_,_),_,_) => - fca.copy(channelInfo = info.copy( - readySink = info.readySink. map(_.copy(module = topWrapperName)), - validSource = info.validSource.map(_.copy(module = topWrapperName)), - readySource = info.readySource.map(_.copy(module = topWrapperName)), - validSink = info.validSink. map(_.copy(module = topWrapperName)))) + case fcca: FAMEChannelConnectionAnnotation => + val renamedInfo = fcca.channelInfo match { + case fwd: DecoupledForwardChannel => fwd.update(fccaRenames) + case info => info + } + fcca.copy(channelInfo = renamedInfo).update(fccaRenames).head // always returns 1 case a => a }) diff --git a/sim/midas/src/main/scala/midas/passes/package.scala b/sim/midas/src/main/scala/midas/passes/package.scala index f6e09394..cff6b8ad 100644 --- a/sim/midas/src/main/scala/midas/passes/package.scala +++ b/sim/midas/src/main/scala/midas/passes/package.scala @@ -15,22 +15,12 @@ import scala.language.implicitConversions package object passes { - /** * A utility for keeping statements defining and connecting signals to a piece of hardware * together with a reference to the component. This is useful for passes that insert hardware, * since the "collateral" of that object can be kept in one place. */ - trait WrappedComponent { - val decl: Statement - val assigns: Statement - val ref: Expression - } - - /** - * Holds the definition of a signal along with the statements that assign to it and its reference. - */ - case class SignalInfo(decl: Statement, assigns: Statement, ref: Expression) extends WrappedComponent + case class SignalInfo(decl: Statement, assigns: Statement, ref: Expression) /** * A utility for creating a wire that "echoes" the value of an existing expression. @@ -44,29 +34,6 @@ package object passes { } } - object InstanceInfo { - def apply(m: DefModule)(implicit ns: Namespace): InstanceInfo = { - val inst = fame.Instantiate(m, ns.newName(m.name)) - InstanceInfo(inst, Block(Nil), WRef(inst)) - } - } - - /** - * Holds the declaration of an instance, along with the set of statements that create connections - * to its ports, along with a reference to the instance. - */ - case class InstanceInfo(decl: WDefInstance, assigns: Block, ref: WRef) extends WrappedComponent { - def addAssign(s: Statement): InstanceInfo = { - copy(assigns = Block(assigns.stmts :+ s)) - } - def connect(pName: String, rhs: Expression): InstanceInfo = { - addAssign(Connect(NoInfo, WSubField(ref, pName), rhs)) - } - def connect(lhs: Expression, pName: String): InstanceInfo = { - addAssign(Connect(NoInfo, lhs, WSubField(ref, pName))) - } - } - /** * This pass ensures that the AbstractClockGate blackbox is defined in a circuit, so that it can * later be instantiated. The blackbox clock gate has the following signature: diff --git a/sim/midas/src/main/scala/midas/widgets/Assert.scala b/sim/midas/src/main/scala/midas/widgets/Assert.scala index 79c4acc9..8b938075 100644 --- a/sim/midas/src/main/scala/midas/widgets/Assert.scala +++ b/sim/midas/src/main/scala/midas/widgets/Assert.scala @@ -15,7 +15,8 @@ class AssertBundle(val numAsserts: Int) extends Bundle { val asserts = Output(UInt(numAsserts.W)) } -class AssertBridgeModule(numAsserts: Int)(implicit p: Parameters) extends BridgeModule[HostPortIO[UInt]]()(p) { +class AssertBridgeModule(assertMessages: Seq[String])(implicit p: Parameters) extends BridgeModule[HostPortIO[UInt]]()(p) { + val numAsserts = assertMessages.size val io = IO(new WidgetIO()) val hPort = IO(HostPort(Input(UInt(numAsserts.W)))) val resume = WireInit(false.B) @@ -43,4 +44,12 @@ class AssertBridgeModule(numAsserts: Int)(implicit p: Parameters) extends Bridge genROReg(cycles >> 32, "cycle_high") Pulsify(genWORegInit(resume, "resume", false.B), pulseLength = 1) genCRFile() + + override def genHeader(base: BigInt, sb: StringBuilder) { + import CppGenerationUtils._ + val headerWidgetName = getWName.toUpperCase + super.genHeader(base, sb) + sb.append(genConstStatic(s"${headerWidgetName}_assert_count", UInt32(assertMessages.size))) + sb.append(genArray(s"${headerWidgetName}_assert_messages", assertMessages.map(CStrLit))) + } } diff --git a/sim/midas/src/main/scala/midas/widgets/AutoCounterBridge.scala b/sim/midas/src/main/scala/midas/widgets/AutoCounterBridge.scala index 18a0996c..740ff412 100644 --- a/sim/midas/src/main/scala/midas/widgets/AutoCounterBridge.scala +++ b/sim/midas/src/main/scala/midas/widgets/AutoCounterBridge.scala @@ -3,7 +3,6 @@ package widgets import chisel3._ import chisel3.util._ -import chisel3.experimental.{DataMirror, Direction} import chisel3.util.experimental.BoringUtils import freechips.rocketchip.config.{Parameters, Field} import freechips.rocketchip.diplomacy.AddressSet @@ -13,28 +12,29 @@ trait AutoCounterConsts { val counterWidth = 64 } -class AutoCounterBundle(val numCounters: Int) extends Bundle with AutoCounterConsts { - val counters = Input(Vec(numCounters, UInt(counterWidth.W))) +class AutoCounterBundle(eventNames: Seq[String], triggerName: String) extends Record { + val triggerEnable = Input(Bool()) + val events = eventNames.map(_ -> Input(Bool())) + val elements = collection.immutable.ListMap(((triggerName, triggerEnable) +: + events):_*) + override def cloneType = new AutoCounterBundle(eventNames, triggerName).asInstanceOf[this.type] } -case class AutoCounterBridgeConstArgs(numcounters: Int, autoCounterPortsMap: scala.collection.mutable.Map[Int, String], hastracerwidget: Boolean = false) - -class AutoCounterToHostToken(val numCounters: Int) extends Bundle { - val data_out = Vec(numCounters, UInt(64.W)) - val cycle = UInt(64.W) +class AutoCounterToHostToken(val numCounters: Int) extends Bundle with AutoCounterConsts { + val data_out = Vec(numCounters, UInt(counterWidth.W)) + val cycle = UInt(counterWidth.W) } -class AutoCounterBridgeModule(constructorArg: AutoCounterBridgeConstArgs)(implicit p: Parameters) extends BridgeModule[HostPortIO[AutoCounterBundle]]()(p) { - - val numCounters = constructorArg.numcounters - val labels = constructorArg.autoCounterPortsMap - val hastracerwidget = constructorArg.hastracerwidget - val trigger = WireDefault(true.B) +class AutoCounterBridgeModule(events: Seq[(String, String)], triggerName: String)(implicit p: Parameters) + extends BridgeModule[HostPortIO[AutoCounterBundle]]()(p) with AutoCounterConsts { + val numCounters = events.size + val (portNames, labels) = events.unzip val io = IO(new WidgetIO()) - val hPort = IO(HostPort(new AutoCounterBundle(numCounters))) - val cycles = RegInit(0.U(64.W)) - val acc_cycles = RegInit(0.U(64.W)) + val hPort = IO(HostPort(new AutoCounterBundle(portNames, triggerName))) + val trigger = hPort.hBits.triggerEnable + val cycles = RegInit(0.U(counterWidth.W)) + val acc_cycles = RegInit(0.U(counterWidth.W)) val hostCyclesWidthOffset = 64 - p(CtrlNastiKey).dataBits val hostCyclesLowWidth = if (hostCyclesWidthOffset > 0) p(CtrlNastiKey).dataBits else 64 @@ -50,83 +50,64 @@ class AutoCounterBridgeModule(constructorArg: AutoCounterBridgeConstArgs)(implic val readrate_low = RegInit(0.U(hostReadrateLowWidth.W)) val readrate_high = RegInit(0.U(hostReadrateHighWidth.W)) - val readrate = Wire(UInt(64.W)) - val readrate_dly = RegInit(0.U(64.W)) - readrate := Cat(readrate_high, readrate_low) + val readrate = Cat(readrate_high, readrate_low) + val initDone = RegInit(false.B) - val acc_counters = RegInit(VecInit(Seq.fill(numCounters)(0.U(64.W)))) - - hastracerwidget match { - case true => BoringUtils.addSink(trigger, s"trace_trigger") - case _ => trigger := true.B - } - - val tFireHelper = DecoupledHelper(hPort.toHost.hValid, hPort.fromHost.hReady) + val tFireHelper = DecoupledHelper(hPort.toHost.hValid, hPort.fromHost.hReady, initDone) val targetFire = tFireHelper.fire // We only sink tokens, so tie off the return channel hPort.fromHost.hValid := true.B when (targetFire) { cycles := cycles + 1.U - readrate_dly := readrate } - val periodcycles = RegInit(1.U(64.W)) - when (targetFire & (readrate =/= readrate_dly)) { - periodcycles := readrate - 2.U - } .elsewhen (targetFire & (periodcycles === 0.U) & (readrate > 0.U)) { - periodcycles := readrate - 1.U - } .elsewhen (targetFire & (cycles > 0.U) & (readrate > 0.U)) { - periodcycles := periodcycles - 1.U + val counters = hPort.hBits.events.unzip._2.map({ increment => + val count = RegInit(0.U(counterWidth.W)) + when (targetFire && increment) { + count := count + 1.U + } + count + }).toSeq + + val periodcycles = RegInit(0.U(64.W)) + val isSampleCycle = periodcycles === readrate + when (targetFire && isSampleCycle) { + periodcycles := 0.U + } .elsewhen (targetFire) { + periodcycles := periodcycles + 1.U } - val btht_queue = Module(new Queue(new AutoCounterToHostToken(numCounters), 10)) + val btht_queue = Module(new Queue(new AutoCounterToHostToken(numCounters), 2)) - btht_queue.io.enq.valid := (periodcycles === 0.U) & (cycles > 0.U) & (cycles >= readrate-1.U) & targetFire & trigger - btht_queue.io.enq.bits.data_out := hPort.hBits.counters + btht_queue.io.enq.valid := isSampleCycle & targetFire & trigger + btht_queue.io.enq.bits.data_out := VecInit(counters) btht_queue.io.enq.bits.cycle := cycles - hPort.toHost.hReady := btht_queue.io.enq.ready & hPort.fromHost.hReady + hPort.toHost.hReady := targetFire + val (lowCountAddrs, highCountAddrs) = (for ((counter, label) <- btht_queue.io.deq.bits.data_out.zip(labels)) yield { + val lowAddr = attach(counter(hostCounterLowWidth-1, 0), s"autocounter_low_${label}", ReadOnly) + val highAddr = attach(counter >> hostCounterLowWidth, s"autocounter_high_${label}", ReadOnly) + (lowAddr, highAddr) + }).unzip //communication with the driver - val readdone = RegInit(false.B) - val readdone_dly = RegInit(false.B) - val readdone_negedge = Wire(Bool()) - val readdone_posedge = Wire(Bool()) - val med = RegInit(false.B) - readdone_dly := readdone - readdone_posedge := readdone & ~readdone_dly - readdone_negedge := readdone_dly & ~readdone - - - when (btht_queue.io.deq.fire()) { - for (i <- 0 to numCounters-1) { acc_counters(i) := btht_queue.io.deq.bits.data_out(i) } - acc_cycles := btht_queue.io.deq.bits.cycle - med := false.B - } .elsewhen (readdone_posedge) { - med := true.B - } .elsewhen (readdone_negedge) { - med := false.B - } - - btht_queue.io.deq.ready := med - - - labels.keys.foreach { - case (index) => { - attach(acc_counters(index)(hostCounterLowWidth-1, 0), s"autocounter_low_${labels(index)}", ReadOnly) - attach(acc_counters(index) >> hostCounterLowWidth, s"autocounter_high_${labels(index)}", ReadOnly) - } - } - - - attach(acc_cycles(hostCyclesLowWidth-1, 0), "cycles_low", ReadOnly) - attach(acc_cycles >> hostCyclesLowWidth, "cycles_high", ReadOnly) + attach(btht_queue.io.deq.bits.cycle(hostCyclesLowWidth-1, 0), "cycles_low", ReadOnly) + attach(btht_queue.io.deq.bits.cycle >> hostCyclesLowWidth, "cycles_high", ReadOnly) attach(readrate_low, "readrate_low", WriteOnly) attach(readrate_high, "readrate_high", WriteOnly) + attach(initDone, "init_done", WriteOnly) attach(btht_queue.io.deq.valid, "countersready", ReadOnly) - attach(readdone, "readdone", WriteOnly) - + Pulsify(genWORegInit(btht_queue.io.deq.ready, "readdone", false.B), 1) + override def genHeader(base: BigInt, sb: StringBuilder) { + headerComment(sb) + // Exclude counter addresses as their names can vary across AutoCounter instances, but + // we only generate a single struct typedef + val headerWidgetName = wName.getOrElse(name).toUpperCase + crRegistry.genHeader(headerWidgetName, base, sb, lowCountAddrs ++ highCountAddrs) + crRegistry.genArrayHeader(headerWidgetName, base, sb) + emitClockDomainInfo(headerWidgetName, sb) + } genCRFile() } diff --git a/sim/midas/src/main/scala/midas/widgets/Bridge.scala b/sim/midas/src/main/scala/midas/widgets/Bridge.scala index a79d02b1..6780225f 100644 --- a/sim/midas/src/main/scala/midas/widgets/Bridge.scala +++ b/sim/midas/src/main/scala/midas/widgets/Bridge.scala @@ -7,7 +7,7 @@ import midas.core.{SimWrapperChannels, SimUtils} import midas.core.SimUtils.{RVChTuple} import midas.passes.fame.{FAMEChannelConnectionAnnotation,DecoupledForwardChannel, PipeChannel, DecoupledReverseChannel, WireChannel, JsonProtocol, HasSerializationHints} -import freechips.rocketchip.config.Parameters +import freechips.rocketchip.config.{Parameters, Field} import chisel3._ import chisel3.util._ @@ -23,11 +23,23 @@ import scala.reflect.runtime.{universe => ru} * */ +// Set in FPGA Top before the BridgeModule is generated +case object TargetClockInfo extends Field[Option[RationalClock]] + abstract class TokenizedRecord extends Record with HasChannels abstract class BridgeModule[HostPortType <: TokenizedRecord] (implicit p: Parameters) extends Widget()(p) { def hPort: HostPortType + def clockDomainInfo: RationalClock = p(TargetClockInfo).get + def emitClockDomainInfo(headerWidgetName: String, sb: StringBuilder): Unit = { + import CppGenerationUtils._ + val RationalClock(domainName, mul, div) = clockDomainInfo + sb.append(genStatic(s"${headerWidgetName}_clock_domain_name", CStrLit(domainName))) + sb.append(genConstStatic(s"${headerWidgetName}_clock_multiplier", UInt32(mul))) + sb.append(genConstStatic(s"${headerWidgetName}_clock_divisor", UInt32(div))) + } + } trait Bridge[HPType <: TokenizedRecord, WidgetType <: BridgeModule[HPType]] { @@ -85,9 +97,17 @@ trait HasChannels { * Implementation follows * */ + private[midas] def getClock(): Clock = { + val allTargetClocks = SimUtils.findClocks(targetPortProto) + require(allTargetClocks.nonEmpty, + s"Target-side bridge interface of ${targetPortProto.getClass} has no clock field.") + require(allTargetClocks.size == 1, + s"Target-side bridge interface of ${targetPortProto.getClass} has ${allTargetClocks.size} clocks but must define only one.") + allTargetClocks.head + } - private[midas] def inputChannelNames(): Seq[String] = inputWireChannels.map(_._2) private[midas] def outputChannelNames(): Seq[String] = outputWireChannels.map(_._2) + private[midas] def inputChannelNames(): Seq[String] = inputWireChannels.map(_._2) private def getRVChannelNames(channels: Seq[RVChTuple]): Seq[String] = channels.flatMap({ channel => @@ -113,9 +133,9 @@ trait HasChannels { for ((field, chName) <- channels) { annotate(new ChiselAnnotation { def toFirrtl = if (bridgeSunk) { - FAMEChannelConnectionAnnotation.source(chName, PipeChannel(latency), Seq(field.toNamed.toTarget)) + FAMEChannelConnectionAnnotation.source(chName, PipeChannel(latency), Some(getClock.toNamed.toTarget), Seq(field.toNamed.toTarget)) } else { - FAMEChannelConnectionAnnotation.sink (chName, PipeChannel(latency), Seq(field.toNamed.toTarget)) + FAMEChannelConnectionAnnotation.sink(chName, PipeChannel(latency), Some(getClock.toNamed.toTarget), Seq(field.toNamed.toTarget)) } }) } @@ -127,6 +147,7 @@ trait HasChannels { // Generate the forward channel annotation val (fwdChName, revChName) = SimUtils.rvChannelNamePair(chName) annotate(new ChiselAnnotation { def toFirrtl = { + val clockTarget = Some(getClock.toNamed.toTarget) val validTarget = field.valid.toNamed.toTarget val readyTarget = field.ready.toNamed.toTarget val leafTargets = Seq(validTarget) ++ lowerAggregateIntoLeafTargets(field.bits) @@ -135,6 +156,7 @@ trait HasChannels { FAMEChannelConnectionAnnotation.source( fwdChName, DecoupledForwardChannel.source(validTarget, readyTarget), + clockTarget, leafTargets ) } else { @@ -142,17 +164,19 @@ trait HasChannels { FAMEChannelConnectionAnnotation.sink( fwdChName, DecoupledForwardChannel.sink(validTarget, readyTarget), + clockTarget, leafTargets ) } }}) annotate(new ChiselAnnotation { def toFirrtl = { + val clockTarget = Some(getClock.toNamed.toTarget) val readyTarget = Seq(field.ready.toNamed.toTarget) if (bridgeSunk) { - FAMEChannelConnectionAnnotation.sink(revChName, DecoupledReverseChannel, readyTarget) + FAMEChannelConnectionAnnotation.sink(revChName, DecoupledReverseChannel, clockTarget, readyTarget) } else { - FAMEChannelConnectionAnnotation.source(revChName, DecoupledReverseChannel, readyTarget) + FAMEChannelConnectionAnnotation.source(revChName, DecoupledReverseChannel, clockTarget, readyTarget) } }}) } diff --git a/sim/midas/src/main/scala/midas/widgets/BridgeAnnotations.scala b/sim/midas/src/main/scala/midas/widgets/BridgeAnnotations.scala index c4d334cf..8598e7ba 100644 --- a/sim/midas/src/main/scala/midas/widgets/BridgeAnnotations.scala +++ b/sim/midas/src/main/scala/midas/widgets/BridgeAnnotations.scala @@ -99,6 +99,9 @@ case class InMemoryBridgeAnnotation( * @param channelMapping A mapping from the channel names initially emitted by the Chisel Module, to uniquified global ones * to find associated FCCAs for this bridge * + * @param clockInfo Contains information about the domain in which the bridge is instantiated. + * This will always be nonEmpty for bridges instantiated in the input FIRRTL + * * @param widget An optional lambda to elaborate the host-land BridgeModule. See InMemoryBridgeAnnotation * * @param widgetClass The BridgeModule's full class name. See SerializableBridgeAnnotation @@ -110,6 +113,7 @@ case class InMemoryBridgeAnnotation( private[midas] case class BridgeIOAnnotation( val target: ReferenceTarget, channelMapping: Map[String, String], + clockInfo: Option[RationalClock] = None, widget: Option[(Parameters) => BridgeModule[_ <: TokenizedRecord]] = None, widgetClass: Option[String] = None, widgetConstructorKey: Option[_ <: AnyRef] = None) extends SingleTargetAnnotation[ReferenceTarget] { @@ -119,17 +123,20 @@ private[midas] case class BridgeIOAnnotation( // Elaborates the BridgeModule using the lambda if it exists // Otherwise, uses reflection to find the constructor for the class given by // widgetClass, passing it the widgetConstructorKey - def elaborateWidget(implicit p: Parameters): BridgeModule[_ <: TokenizedRecord] = widget match { - case Some(elaborator) => elaborator(p) - case None => - println(s"Instantiating bridge ${target.ref} of type ${widgetClass.get}") - val constructor = Class.forName(widgetClass.get).getConstructors()(0) - (widgetConstructorKey match { - case Some(key) => - println(s" With constructor arguments: $key") - constructor.newInstance(key, p) - case None => constructor.newInstance(p) - }).asInstanceOf[BridgeModule[_ <: TokenizedRecord]] + def elaborateWidget(implicit p: Parameters): BridgeModule[_ <: TokenizedRecord] = { + val px = p alterPartial { case TargetClockInfo => clockInfo } + widget match { + case Some(elaborator) => elaborator(px) + case None => + println(s"Instantiating bridge ${target.ref} of type ${widgetClass.get}") + val constructor = Class.forName(widgetClass.get).getConstructors()(0) + (widgetConstructorKey match { + case Some(key) => + println(s" With constructor arguments: $key") + constructor.newInstance(key, px) + case None => constructor.newInstance(px) + }).asInstanceOf[BridgeModule[_ <: TokenizedRecord]] + } } } @@ -139,6 +146,6 @@ private[midas] object BridgeIOAnnotation { def apply(target: ReferenceTarget, widget: (Parameters) => BridgeModule[_ <: TokenizedRecord], channelNames: Seq[String]): BridgeIOAnnotation = - BridgeIOAnnotation(target, channelNames.map(p => p -> p).toMap, Some(widget)) + BridgeIOAnnotation(target, channelNames.map(p => p -> p).toMap, widget = Some(widget)) } diff --git a/sim/midas/src/main/scala/midas/widgets/ChannelizedHostPort.scala b/sim/midas/src/main/scala/midas/widgets/ChannelizedHostPort.scala index 800fbdaf..70cd678a 100644 --- a/sim/midas/src/main/scala/midas/widgets/ChannelizedHostPort.scala +++ b/sim/midas/src/main/scala/midas/widgets/ChannelizedHostPort.scala @@ -21,7 +21,7 @@ abstract class ChannelizedHostPortIO(protected val targetPortProto: Data) extend lazy val fieldToChannelMap = Map((_inputWireChannels ++ _outputWireChannels):_*) private def getLeafDirs(token: Data): Seq[Direction] = token match { - case c: Clock => Seq() + case c: Clock => throw new Exception("Tokens cannot contain clock fields") case b: Record => b.elements.flatMap({ case (_, e) => getLeafDirs(e)}).toSeq case v: Vec[_] => v.flatMap(getLeafDirs) case b: Bits => Seq(directionOf(b)) diff --git a/sim/midas/src/main/scala/midas/widgets/ClockBridge.scala b/sim/midas/src/main/scala/midas/widgets/ClockBridge.scala new file mode 100644 index 00000000..c584c9e5 --- /dev/null +++ b/sim/midas/src/main/scala/midas/widgets/ClockBridge.scala @@ -0,0 +1,207 @@ +// See LICENSE for license details. + +package midas.widgets + +import midas.core.{SimWrapperChannels, SimUtils} +import midas.core.SimUtils.{RVChTuple} +import midas.passes.fame.{FAMEChannelConnectionAnnotation, TargetClockChannel} + +import freechips.rocketchip.config.Parameters +import freechips.rocketchip.util.DensePrefixSum + +import chisel3._ +import chisel3.util._ +import chisel3.experimental.{BaseModule, Direction, ChiselAnnotation, annotate} +import firrtl.annotations.{ModuleTarget, ReferenceTarget} + +/** + * Defines a generated clock as a rational multiple of some reference clock. The generated + * clock has a frequency (multiplier / divisor) times that of reference. + * + * @param name An identifier for the associated clock domain + * + * @param multiplier See class comment. + * + * @param divisor See class comment. + */ +case class RationalClock(name: String, multiplier: Int, divisor: Int) + +sealed trait ClockBridgeConsts { + val clockChannelName = "clocks" + val refClockDomain = "baseClock" +} + +/** + * A custom bridge annotation for the Clock Bridge. Unique so that we can + * trivially match against it in bridge extraction. + * + * @param target The target-side module for the CB + * + * @param clocks The associated clock information for each output clock (including the base). + */ + +case class ClockBridgeAnnotation(val target: ModuleTarget, clocks: Seq[RationalClock]) + extends BridgeAnnotation with ClockBridgeConsts { + val channelNames = Seq(clockChannelName) + def duplicate(n: ModuleTarget) = this.copy(target) + def toIOAnnotation(port: String): BridgeIOAnnotation = { + val channelMapping = channelNames.map(oldName => oldName -> s"${port}_$oldName") + BridgeIOAnnotation( + target.copy(module = target.circuit).ref(port), + channelMapping.toMap, + widget = Some((p: Parameters) => new ClockBridgeModule(clocks)(p)) + ) + } +} + +/** + * The default target-side clock bridge. Generates a "base clock" and a vector of + * additional clocks related to that base clock. Simulation times are + * generally expressed in terms of this base clock. + * + * @param additionalClocks Rational clock information for each additional + * clock beyond the base + */ +class RationalClockBridge(additionalClocks: RationalClock*) extends BlackBox with ClockBridgeConsts { + outer => + // Always generate the base (element 0 in our output vec) + val baseClock = RationalClock(refClockDomain, 1, 1) + val allClocks = baseClock +: additionalClocks + val io = IO(new Bundle { + val clocks = Output(Vec(allClocks.size, Clock())) + }) + + // Generate the bridge annotation + annotate(new ChiselAnnotation { def toFirrtl = ClockBridgeAnnotation( outer.toTarget, allClocks) }) + annotate(new ChiselAnnotation { def toFirrtl = + FAMEChannelConnectionAnnotation( + clockChannelName, + channelInfo = TargetClockChannel(allClocks), + clock = None, // Clock channels do not have a reference clock + sinks = Some(io.clocks.map(_.toTarget)), + sources = None + ) + }) +} + +/** + * The host-land clock bridge interface. This consists of a single channel, + * carrying clock tokens. A clock token is a Vec[Bool], one element per clock, When a bit is set, + * that clock domain will fire in the simulator time step that consumes this clock token. + * + * NB: The target-time elapsed between tokens is not necessarily constant. + * + * @param numClocks The total number of clocks in the channel (inclusive of the base clock) + * + */ +class ClockTokenVector(numClocks: Int) extends TokenizedRecord with ClockBridgeConsts { + def targetPortProto(): Vec[Bool] = Vec(numClocks, Bool()) + val clocks = new DecoupledIO(targetPortProto) + + def outputWireChannels = Seq(clocks -> clockChannelName) + def inputWireChannels = Seq() + def outputRVChannels = Seq() + def inputRVChannels = Seq() + + def connectChannels2Port(bridgeAnno: BridgeIOAnnotation, simIo: SimWrapperChannels): Unit = { + val local2globalName = bridgeAnno.channelMapping.toMap + for (localName <- outputChannelNames) { + simIo.clockElement._2 <> elements(localName) + } + } + + val elements = collection.immutable.ListMap(clockChannelName -> clocks) + override def cloneType(): this.type = new ClockTokenVector(numClocks).asInstanceOf[this.type] + def generateAnnotations(): Unit = {} +} + +/** + * The host-side implementation. Based on provided a clock information, generates a clock token stream + * which will be sunk by the FAME-1 hub model. This token stream does not + * depend on the runtime-behavior of the target, allowing this bridge run + * ahead of the rest of the simulation. + * + * Target and host time measurements provided by simif_t are facilitated with MMIO to this bridge + * + * @param clockInfo Clock frequency information for each target clock + * + */ +class ClockBridgeModule(clockInfo: Seq[RationalClock])(implicit p: Parameters) + extends BridgeModule[ClockTokenVector] { + val io = IO(new WidgetIO()) + val hPort = IO(new ClockTokenVector(clockInfo.size)) + val phaseRelationships = clockInfo map { cInfo => (cInfo.multiplier, cInfo.divisor) } + val clockTokenGen = Module(new RationalClockTokenGenerator(phaseRelationships)) + hPort.clocks <> clockTokenGen.io + + val hCycleName = "hCycle" + val hCycle = genWideRORegInit(0.U(64.W), hCycleName) + hCycle := hCycle + 1.U + + // Count the number of clock tokens for which the fastest clock is scheduled to fire + // --> Use to calculate FMR + val tCycleFastest = genWideRORegInit(0.U(64.W), "tCycle") + val fastestClockIdx = (phaseRelationships).map({ case (n, d) => n.toDouble / d }) + .zipWithIndex + .sortBy(_._1) + .last._2 + + when (hPort.clocks.fire && hPort.clocks.bits(fastestClockIdx)) { + tCycleFastest := tCycleFastest + 1.U + } + genCRFile() +} + +/** + * Finds a virtual fast-clock whose period is the GCD of the periods of all requested + * clocks, and returns the period of each requested clock as an integer multiple of that + * high-frequency virtual clock. + */ +object FindScaledPeriodGCD { + def apply(phaseRelationships: Seq[(Int, Int)]): Seq[BigInt] = { + val periodDivisors = phaseRelationships.unzip._1 + val productOfDivisors = periodDivisors.foldLeft(BigInt(1))(_ * _) + val scaledMultipliers = phaseRelationships.map({ case (divisor, multiplier) => multiplier * productOfDivisors / divisor }) + val gcdOfScaledPeriods = scaledMultipliers.reduce((a, b) => a.gcd(b)) + val reducedPeriods = scaledMultipliers.map(_ / gcdOfScaledPeriods) + reducedPeriods + } +} + +/** + * Generates an infinite clock token stream based on rational relationship of each clock. + * To improve simulator FMR, this module always produces non-zero clock tokens + * + * @param phaseRelationships multiplier, divisor pairs for each clock + */ +class RationalClockTokenGenerator(phaseRelationships: Seq[(Int, Int)]) extends Module { + val numClocks = phaseRelationships.size + val io = IO(new DecoupledIO(Vec(numClocks, Bool()))) + // The clock token stream is known a priori! + io.valid := true.B + + // Determine the number of virtual-clock cycles for each target clock. + val clockPeriodicity = FindScaledPeriodGCD(phaseRelationships) + val counterWidth = clockPeriodicity.map(p => log2Ceil(p + 1)).reduce((a, b) => math.max(a, b)) + + // This is an arbitrarily selected number; feel free to increase it. If we + // need more time resolution we can trivially pipeline this thing. + val maxCounterWidth = 16 + require(counterWidth <= maxCounterWidth, "Ensure this circuit doesn't blow up") + + // For each target clock, count the number of virtual cycles until the next expected clock edge + val timeToNextEdge = RegInit(VecInit(Seq.fill(numClocks)(0.U(counterWidth.W)))) + // Find the smallest number of virtual-clock cycles that must must advance + // before one real clock would fire. + val minStepsToEdge = DensePrefixSum(timeToNextEdge)({ case (a, b) => Mux(a < b, a, b) }).last + + // Advance the virtual clock (minStepsToEdge) cycles, and determine which + // target clocks have an edge at that time to populate the clock token + io.bits := VecInit(for ((reg, period) <- timeToNextEdge.zip(clockPeriodicity)) yield { + val clockFiring = reg === minStepsToEdge + when (io.ready) { + reg := Mux(clockFiring, period.U, reg - minStepsToEdge) + } + clockFiring + }) +} diff --git a/sim/midas/src/main/scala/midas/widgets/CppGeneration.scala b/sim/midas/src/main/scala/midas/widgets/CppGeneration.scala index 7af3112c..2ae8694a 100644 --- a/sim/midas/src/main/scala/midas/widgets/CppGeneration.scala +++ b/sim/midas/src/main/scala/midas/widgets/CppGeneration.scala @@ -3,12 +3,12 @@ package midas package widgets -trait CPPLiteral { +sealed trait CPPLiteral { def typeString: String def toC: String } -trait IntLikeLiteral extends CPPLiteral { +sealed trait IntLikeLiteral extends CPPLiteral { def bitWidth: Int def literalSuffix: String def value: BigInt @@ -31,7 +31,7 @@ case class UInt64(value: BigInt) extends IntLikeLiteral { case class CStrLit(val value: String) extends CPPLiteral { def typeString = "const char* const" - def toC = "\"%s\"".format(value) + def toC = "R\"ESC(%s)ESC\"".format(value) } object CppGenerationUtils { @@ -64,6 +64,7 @@ object CppGenerationUtils { def genComment(str: String): String = "// %s\n".format(str) + implicit def toStrLit(str: String): CStrLit = CStrLit(str) } diff --git a/sim/midas/src/main/scala/midas/widgets/HostPort.scala b/sim/midas/src/main/scala/midas/widgets/HostPort.scala index 6e41fe94..78871831 100644 --- a/sim/midas/src/main/scala/midas/widgets/HostPort.scala +++ b/sim/midas/src/main/scala/midas/widgets/HostPort.scala @@ -23,7 +23,6 @@ import scala.collection.mutable * (It is also possible to use this for very simple models where the one or * more outputs depend combinationally on a _single_ input token (the toHost * field)) - * */ // We're using a Record here because reflection in Bundle prematurely initializes our lazy vals @@ -64,7 +63,7 @@ class HostPortIO[+T <: Data](protected val targetPortProto: T) extends Tokenized field := tokenChannel.bits toHostChannels += tokenChannel } - + for ((field, localName) <- outputWireChannels) { val tokenChannel = simIo.wireInputPortMap(local2globalName(localName)) tokenChannel.bits := field @@ -105,6 +104,9 @@ class HostPortIO[+T <: Data](protected val targetPortProto: T) extends Tokenized // Enqueue into the toHost channels only once all toHost channels can accept the token val fromHostHelper = DecoupledHelper((fromHost.hValid +: fromHostChannels.map(_.ready)):_*) fromHostChannels.foreach(ch => ch.valid := fromHostHelper.fire(ch.ready)) + + // Tie off the target clock; these should be unused in the BridgeModule + SimUtils.findClocks(hBits).map(_ := false.B.asClock) } def generateAnnotations(): Unit = { diff --git a/sim/midas/src/main/scala/midas/widgets/Lib.scala b/sim/midas/src/main/scala/midas/widgets/Lib.scala index 2f45f2c0..a829c888 100644 --- a/sim/midas/src/main/scala/midas/widgets/Lib.scala +++ b/sim/midas/src/main/scala/midas/widgets/Lib.scala @@ -225,16 +225,18 @@ class MCRFileMap() { case (e: RegisterEntry, addr) => mcrIO.bindReg(e, addr) } - def genHeader(prefix: String, base: BigInt, sb: StringBuilder): Unit = { + def genHeader(prefix: String, base: BigInt, sb: StringBuilder, addrsToExclude: Seq[Int] = Nil): Unit = { // get widget name with no widget number (prefix includes it) val prefix_no_num = prefix.split("_")(0) + val filteredRegs = name2addr.toList.filterNot({ case (_, idx) => addrsToExclude.contains(idx) }) + // emit generic struct for this widget type. guarded so it only gets // defined once sb append s"#ifndef ${prefix_no_num}_struct_guard\n" sb append s"#define ${prefix_no_num}_struct_guard\n" sb append s"typedef struct ${prefix_no_num}_struct {\n" - name2addr.toList foreach { case (regName, idx) => + filteredRegs foreach { case (regName, idx) => sb append s" unsigned long ${regName};\n" } sb append s"} ${prefix_no_num}_struct;\n" @@ -248,7 +250,7 @@ class MCRFileMap() { sb append s"#define ${prefix}_substruct_create \\\n" // assume the widget destructor will free this sb append s"${prefix_no_num}_struct * ${prefix}_substruct = (${prefix_no_num}_struct *) malloc(sizeof(${prefix_no_num}_struct)); \\\n" - name2addr.toList foreach { case (regName, idx) => + filteredRegs foreach { case (regName, idx) => val address = base + idx sb append s"${prefix}_substruct->${regName} = ${address}; \\\n" } diff --git a/sim/midas/src/main/scala/midas/widgets/PeekPokeIO.scala b/sim/midas/src/main/scala/midas/widgets/PeekPokeIO.scala index f02cecef..a5228f70 100644 --- a/sim/midas/src/main/scala/midas/widgets/PeekPokeIO.scala +++ b/sim/midas/src/main/scala/midas/widgets/PeekPokeIO.scala @@ -51,9 +51,6 @@ class PeekPokeBridgeModule(key: PeekPokeKey)(implicit p: Parameters) extends Bri val tCycle = genWideRORegInit(0.U(64.W), tCycleName) val tCycleAdvancing = WireInit(false.B) - val hCycleName = "hCycle" - val hCycle = genWideRORegInit(0.U(64.W), hCycleName) - // needs back pressure from reset queues io.idle := cycleHorizon === 0.U @@ -137,8 +134,6 @@ class PeekPokeBridgeModule(key: PeekPokeKey)(implicit p: Parameters) extends Bri tCycleAdvancing := true.B } - hCycle := hCycle + 1.U - when (io.step.fire) { cycleHorizon := io.step.bits } @@ -201,8 +196,10 @@ object PeekPokeTokenizedIO { } class PeekPokeTargetIO(targetIO: Seq[(String, Data)], withReset: Boolean) extends Record { + val clock = Input(Clock()) val reset = if (withReset) Some(Output(Bool())) else None override val elements = ListMap(( + Seq("clock" -> clock) ++ reset.map("reset" -> _).toSeq ++ targetIO.map({ case (name, field) => name -> Flipped(chiselTypeOf(field)) }) ):_*) @@ -219,10 +216,11 @@ class PeekPokeBridge(targetIO: Seq[(String, Data)], reset: Option[Bool]) extends object PeekPokeBridge { @chiselName - def apply(reset: Bool, ioList: (String, Data)*): PeekPokeBridge = { + def apply(clock: Clock, reset: Bool, ioList: (String, Data)*): PeekPokeBridge = { val peekPokeBridge = Module(new PeekPokeBridge(ioList, Some(reset))) ioList.foreach({ case (name, field) => field <> peekPokeBridge.io.elements(name) }) reset := peekPokeBridge.io.reset.get + peekPokeBridge.io.clock := clock peekPokeBridge } } diff --git a/sim/midas/src/main/scala/midas/widgets/Print.scala b/sim/midas/src/main/scala/midas/widgets/Print.scala index 83269921..163872d1 100644 --- a/sim/midas/src/main/scala/midas/widgets/Print.scala +++ b/sim/midas/src/main/scala/midas/widgets/Print.scala @@ -17,11 +17,13 @@ class PrintRecord(portType: firrtl.ir.BundleType, val formatString: String) exte def regenLeafType(tpe: firrtl.ir.Type): Data = tpe match { case firrtl.ir.UIntType(width: firrtl.ir.IntWidth) => UInt(width.width.toInt.W) case firrtl.ir.SIntType(width: firrtl.ir.IntWidth) => SInt(width.width.toInt.W) + case firrtl.ir.SIntType(width: firrtl.ir.IntWidth) => SInt(width.width.toInt.W) case badType => throw new RuntimeException(s"Unexpected type in PrintBundle: ${badType}") } val args: Seq[(String, Data)] = portType.fields.collect({ - case firrtl.ir.Field(name, _, tpe) if name != "enable" => (name -> Output(regenLeafType(tpe))) + case firrtl.ir.Field(name, _, tpe) if name != "enable" && name != "clock" => + (name -> Output(regenLeafType(tpe))) }) val enable = Output(Bool()) @@ -184,6 +186,7 @@ class PrintBridgeModule(printPorts: Seq[(firrtl.ir.Port, String)])(implicit p: P sb.append(genArray(s"${headerWidgetName}_format_strings", formatStrings)) sb.append(genArray(s"${headerWidgetName}_argument_counts", argumentCounts)) sb.append(genArray(s"${headerWidgetName}_argument_widths", argumentWidths)) + emitClockDomainInfo(headerWidgetName, sb) } genCRFile() } diff --git a/sim/midas/src/test/scala/midas/BridgeTopWiringSpec.scala b/sim/midas/src/test/scala/midas/BridgeTopWiringSpec.scala new file mode 100644 index 00000000..a47f3d9b --- /dev/null +++ b/sim/midas/src/test/scala/midas/BridgeTopWiringSpec.scala @@ -0,0 +1,70 @@ +// See LICENSE for license details. + +package goldengate.tests + +import midas.passes._ + +import firrtl._ +import firrtl.ir._ +import firrtl.annotations._ +// Switch to FIRRTL in 3.2 +import midas.firrtl.testutils._ + +class BridgeTopWiringSpec extends MiddleTransformSpec with FirrtlRunners { + + def transform = new BridgeTopWiring("t_") + + "The signal x in module A" should s"should be wired to Top with the correct clocks" in { + val input = + """circuit Top : + | module Top : + | input clock1 : Clock + | input clock2 : Clock + | inst a1 of A + | a1.clock <= clock1 + | inst a2 of A + | a2.clock <= clock2 + | module A : + | input clock : Clock + | wire x : UInt<1> + | x <= UInt(1) + """.stripMargin + val aIT = ModuleTarget("Top", "A") + val topMT = ModuleTarget("Top", "Top") + val annos = Seq(BridgeTopWiringAnnotation(aIT.ref("x"), aIT.ref("clock"))) + val checkAnnos = Seq( + BridgeTopWiringOutputAnnotation(aIT.ref("x"), + aIT.addHierarchy("Top", "a1").ref("x"), + topMT.ref("t_a1_x"), + topMT.ref("t_clock1")), + BridgeTopWiringOutputAnnotation(aIT.ref("x"), + aIT.addHierarchy("Top", "a2").ref("x"), + topMT.ref("t_a2_x"), + topMT.ref("t_clock2"))) + val check = + """circuit Top : + | module Top : + | input clock1: Clock + | input clock2: Clock + | output t_a1_x: UInt<1> + | output t_a2_x: UInt<1> + | output t_clock1: Clock + | output t_clock2: Clock + | inst a1 of A + | inst a2 of A + | a1.clock <= clock1 + | a2.clock <= clock2 + | t_a1_x <= a1.t_x + | t_a2_x <= a2.t_x + | t_clock1 <= clock1 + | t_clock2 <= clock2 + | module A : + | input clock : Clock + | output t_x : UInt<1> + | wire x : UInt<1> + | x <= UInt(1) + | t_x <= x + """.stripMargin + executeWithAnnos(input, check, annos, checkAnnos) + } +} diff --git a/sim/midas/src/test/scala/midas/FindClockSourcesSpec.scala b/sim/midas/src/test/scala/midas/FindClockSourcesSpec.scala new file mode 100644 index 00000000..81f94a18 --- /dev/null +++ b/sim/midas/src/test/scala/midas/FindClockSourcesSpec.scala @@ -0,0 +1,111 @@ + +package goldengate.tests + +import midas.passes._ + +import firrtl._ +import firrtl.ir._ +import firrtl.annotations._ +// Switch to FIRRTL in 3.2 +import midas.firrtl.testutils._ + + +class FindClockSourceSpec extends LowTransformSpec { + def transform = FindClockSources + + def executeAnnosOnly(input: String, annotations: Seq[Annotation], checkAnnotations: Seq[Annotation]): CircuitState = { + val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) + + annotations.foreach { anno => + logger.debug(anno.serialize) + } + + finalState.annotations.toSeq.foreach { anno => + logger.debug(anno.serialize) + } + + val csAnnos = finalState.annotations.collect { case a: ClockSourceAnnotation => a } + checkAnnotations.foreach { check => csAnnos should contain (check) } + finalState + } + + "FindClockSources" should s"find sources for submodule clocks" in { + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input clock2 : Clock + | inst a1 of A + | a1.clock <= clock + | module A : + | input clock : Clock + """.stripMargin + val aClockRT = ModuleTarget("Top", "A").addHierarchy("Top", "a1").ref("clock") + val topClockRT = ModuleTarget("Top", "Top").ref("clock") + val annos = Seq(FindClockSourceAnnotation(aClockRT)) + val checkAnnos = Seq(ClockSourceAnnotation(aClockRT, Some(topClockRT))) + executeAnnosOnly(input, annos, checkAnnos) + } + + it should s"find sources that pass through other modules" in { + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input clock2 : Clock + | inst a1 of A + | inst p of PassThru + | p.iclock <= clock + | a1.clock <= p.oclock + | module A : + | input clock : Clock + | module PassThru : + | input iclock : Clock + | output oclock : Clock + | oclock <= iclock + """.stripMargin + val aClockRT = ModuleTarget("Top", "A").addHierarchy("Top", "a1").ref("clock") + val topClockRT = ModuleTarget("Top", "Top").ref("clock") + val annos = Seq(FindClockSourceAnnotation(aClockRT)) + val checkAnnos = Seq(ClockSourceAnnotation(aClockRT, Some(topClockRT))) + executeAnnosOnly(input, annos, checkAnnos) + } + + it should s"find sources for clocks at the root of the module hiearchy" in { + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input clock2 : Clock + | wire c : Clock + | c <= clock2 + """.stripMargin + val topClockRT = ModuleTarget("Top", "Top").ref("clock2") + val wireClockRT = ModuleTarget("Top", "Top").ref("c") + val annos = Seq(FindClockSourceAnnotation(wireClockRT)) + val checkAnnos = Seq(ClockSourceAnnotation(wireClockRT, Some(topClockRT))) + executeAnnosOnly(input, annos, checkAnnos) + } + + it should s"find sources for intermediate nodes in a chain of clock connections" in { + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input clock2 : Clock + | wire c : Clock + | wire d : Clock + | c <= clock2 + | d <= c + """.stripMargin + val topClockRT = ModuleTarget("Top", "Top").ref("clock2") + val cClockRT = ModuleTarget("Top", "Top").ref("c") + val dClockRT = ModuleTarget("Top", "Top").ref("d") + val annos = Seq(FindClockSourceAnnotation(cClockRT), + FindClockSourceAnnotation(dClockRT)) + val checkAnnos = Seq(ClockSourceAnnotation(cClockRT, Some(topClockRT)), + ClockSourceAnnotation(dClockRT, Some(topClockRT))) + executeAnnosOnly(input, annos, checkAnnos) + } + +} diff --git a/sim/midas/src/test/scala/midas/testutils/FirrtlSpec.scala b/sim/midas/src/test/scala/midas/testutils/FirrtlSpec.scala new file mode 100644 index 00000000..9a2497eb --- /dev/null +++ b/sim/midas/src/test/scala/midas/testutils/FirrtlSpec.scala @@ -0,0 +1,405 @@ +// See LICENSE for license details. + +package midas.firrtl.testutils + +import java.io._ +import java.security.Permission + +import com.typesafe.scalalogging.LazyLogging + +import scala.sys.process._ +import org.scalatest._ +import org.scalatest.prop._ + +import firrtl._ +import firrtl.ir._ +import firrtl.Parser.{IgnoreInfo, UseInfo} +import firrtl.analyses.{GetNamespace, InstanceGraph, ModuleNamespaceAnnotation} +import firrtl.annotations._ +import firrtl.transforms.{DontTouchAnnotation, NoDedupAnnotation, RenameModules} +import firrtl.util.BackendCompilationUtilities +import scala.collection.mutable + +class CheckLowForm extends SeqTransform { + def inputForm = LowForm + def outputForm = LowForm + def transforms = Seq( + passes.CheckHighForm + ) +} + +trait FirrtlRunners extends BackendCompilationUtilities { + + val cppHarnessResourceName: String = "/firrtl/testTop.cpp" + /** Extra transforms to run by default */ + val extraCheckTransforms = Seq(new CheckLowForm) + + private class RenameTop(newTopPrefix: String) extends Transform { + def inputForm: LowForm.type = LowForm + def outputForm: LowForm.type = LowForm + + def execute(state: CircuitState): CircuitState = { + val namespace = state.annotations.collectFirst { + case m: ModuleNamespaceAnnotation => m + }.get.namespace + + val newTopName = namespace.newName(newTopPrefix) + val modulesx = state.circuit.modules.map { + case mod: Module if mod.name == state.circuit.main => mod.mapString(_ => newTopName) + case other => other + } + + state.copy(circuit = state.circuit.copy(main = newTopName, modules = modulesx)) + } + } + + /** Check equivalence of Firrtl transforms using yosys + * + * @param input string containing Firrtl source + * @param customTransforms Firrtl transforms to test for equivalence + * @param customAnnotations Optional Firrtl annotations + * @param resets tell yosys which signals to set for SAT, format is (timestep, signal, value) + */ + def firrtlEquivalenceTest(input: String, + customTransforms: Seq[Transform] = Seq.empty, + customAnnotations: AnnotationSeq = Seq.empty, + resets: Seq[(Int, String, Int)] = Seq.empty): Unit = { + val circuit = Parser.parse(input.split("\n").toIterator) + val compiler = new MinimumVerilogCompiler + val prefix = circuit.main + val testDir = createTestDirectory(prefix + "_equivalence_test") + val firrtlWriter = new PrintWriter(s"${testDir.getAbsolutePath}/$prefix.fir") + firrtlWriter.write(input) + firrtlWriter.close() + + val customVerilog = compiler.compileAndEmit(CircuitState(circuit, HighForm, customAnnotations), + new GetNamespace +: new RenameTop(s"${prefix}_custom") +: customTransforms) + val namespaceAnnotation = customVerilog.annotations.collectFirst { case m: ModuleNamespaceAnnotation => m }.get + val customTop = customVerilog.circuit.main + val customFile = new PrintWriter(s"${testDir.getAbsolutePath}/$customTop.v") + customFile.write(customVerilog.getEmittedCircuit.value) + customFile.close() + + val referenceVerilog = compiler.compileAndEmit(CircuitState(circuit, HighForm, Seq(namespaceAnnotation)), + Seq(new RenameModules, new RenameTop(s"${prefix}_reference"))) + val referenceTop = referenceVerilog.circuit.main + val referenceFile = new PrintWriter(s"${testDir.getAbsolutePath}/$referenceTop.v") + referenceFile.write(referenceVerilog.getEmittedCircuit.value) + referenceFile.close() + + assert(yosysExpectSuccess(customTop, referenceTop, testDir, resets)) + } + + /** Compiles input Firrtl to Verilog */ + def compileToVerilog(input: String, annotations: AnnotationSeq = Seq.empty): String = { + val circuit = Parser.parse(input.split("\n").toIterator) + val compiler = new VerilogCompiler + val res = compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms) + res.getEmittedCircuit.value + } + /** Compile a Firrtl file + * + * @param prefix is the name of the Firrtl file without path or file extension + * @param srcDir directory where all Resources for this test are located + * @param annotations Optional Firrtl annotations + */ + def compileFirrtlTest( + prefix: String, + srcDir: String, + customTransforms: Seq[Transform] = Seq.empty, + annotations: AnnotationSeq = Seq.empty): File = { + val testDir = createTestDirectory(prefix) + copyResourceToFile(s"${srcDir}/${prefix}.fir", new File(testDir, s"${prefix}.fir")) + + val optionsManager = new ExecutionOptionsManager(prefix) with HasFirrtlOptions { + commonOptions = CommonOptions(topName = prefix, targetDirName = testDir.getPath) + firrtlOptions = FirrtlExecutionOptions( + infoModeName = "ignore", + customTransforms = customTransforms ++ extraCheckTransforms, + annotations = annotations.toList) + } + firrtl.Driver.execute(optionsManager) + + testDir + } + /** Execute a Firrtl Test + * + * @param prefix is the name of the Firrtl file without path or file extension + * @param srcDir directory where all Resources for this test are located + * @param verilogPrefixes names of option Verilog resources without path or file extension + * @param annotations Optional Firrtl annotations + */ + def runFirrtlTest( + prefix: String, + srcDir: String, + verilogPrefixes: Seq[String] = Seq.empty, + customTransforms: Seq[Transform] = Seq.empty, + annotations: AnnotationSeq = Seq.empty) = { + val testDir = compileFirrtlTest(prefix, srcDir, customTransforms, annotations) + val harness = new File(testDir, s"top.cpp") + copyResourceToFile(cppHarnessResourceName, harness) + + // Note file copying side effect + val verilogFiles = verilogPrefixes map { vprefix => + val file = new File(testDir, s"$vprefix.v") + copyResourceToFile(s"$srcDir/$vprefix.v", file) + file + } + + verilogToCpp(prefix, testDir, verilogFiles, harness).! + cppToExe(prefix, testDir).! + assert(executeExpectingSuccess(prefix, testDir)) + } +} + +trait FirrtlMatchers extends Matchers { + def dontTouch(path: String): Annotation = { + val parts = path.split('.') + require(parts.size >= 2, "Must specify both module and component!") + val name = ComponentName(parts.tail.mkString("."), ModuleName(parts.head, CircuitName("Top"))) + DontTouchAnnotation(name) + } + def dontDedup(mod: String): Annotation = { + require(mod.split('.').size == 1, "Can only specify a Module, not a component or instance") + NoDedupAnnotation(ModuleName(mod, CircuitName("Top"))) + } + // Replace all whitespace with a single space and remove leading and + // trailing whitespace + // Note this is intended for single-line strings, no newlines + def normalized(s: String): String = { + require(!s.contains("\n")) + s.replaceAll("\\s+", " ").trim + } + /** Helper to make circuits that are the same appear the same */ + def canonicalize(circuit: Circuit): Circuit = { + import firrtl.Mappers._ + def onModule(mod: DefModule) = mod.map(firrtl.Utils.squashEmpty) + circuit.map(onModule) + } + def parse(str: String) = Parser.parse(str.split("\n").toIterator, UseInfo) + /** Helper for executing tests + * compiler will be run on input then emitted result will each be split into + * lines and normalized. + */ + def executeTest( + input: String, + expected: Seq[String], + compiler: Compiler, + annotations: Seq[Annotation] = Seq.empty) = { + val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) + val lines = finalState.getEmittedCircuit.value split "\n" map normalized + for (e <- expected) { + lines should contain (e) + } + } +} + +object FirrtlCheckers extends FirrtlMatchers { + import matchers._ + implicit class TestingFunctionsOnCircuitState(val state: CircuitState) extends AnyVal { + def search(pf: PartialFunction[Any, Boolean]): Boolean = state.circuit.search(pf) + } + implicit class TestingFunctionsOnCircuit(val circuit: Circuit) extends AnyVal { + def search(pf: PartialFunction[Any, Boolean]): Boolean = { + val f = pf.lift + def rec(node: Any): Boolean = { + f(node) match { + // If the partial function is defined on this node, return its result + case Some(res) => res + // Otherwise keep digging + case None => + require(node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode], + "Error! Unexpected FirrtlNode that does not implement Product!") + val iter = node match { + case p: Product => p.productIterator + case i: Iterable[Any] => i.iterator + case _ => Iterator.empty + } + iter.foldLeft(false) { + case (res, elt) => if (res) res else rec(elt) + } + } + } + rec(circuit) + } + } + + /** Checks that the emitted circuit has the expected line, both will be normalized */ + def containLine(expectedLine: String) = containLines(expectedLine) + + /** Checks that the emitted circuit has the expected lines in order, all lines will be normalized */ + def containLines(expectedLines: String*) = new CircuitStateStringsMatcher(expectedLines) + + class CircuitStateStringsMatcher(expectedLines: Seq[String]) extends Matcher[CircuitState] { + override def apply(state: CircuitState): MatchResult = { + val emitted = state.getEmittedCircuit.value + MatchResult( + emitted.split("\n").map(normalized).containsSlice(expectedLines.map(normalized)), + emitted + "\n did not contain \"" + expectedLines + "\"", + s"${state.circuit.main} contained $expectedLines" + ) + } + } + + def containTree(pf: PartialFunction[Any, Boolean]) = new CircuitStatePFMatcher(pf) + + class CircuitStatePFMatcher(pf: PartialFunction[Any, Boolean]) extends Matcher[CircuitState] { + override def apply(state: CircuitState): MatchResult = { + MatchResult( + state.search(pf), + state.circuit.serialize + s"\n did not contain $pf", + s"${state.circuit.main} contained $pf" + ) + } + } +} + +abstract class FirrtlPropSpec extends PropSpec with PropertyChecks with FirrtlRunners with LazyLogging + +abstract class FirrtlFlatSpec extends FlatSpec with FirrtlRunners with FirrtlMatchers with LazyLogging + +// Who tests the testers? +class TestFirrtlFlatSpec extends FirrtlFlatSpec { + import FirrtlCheckers._ + + val c = parse(""" + |circuit Test: + | module Test : + | input in : UInt<8> + | output out : UInt<8> + | out <= in + |""".stripMargin) + val state = CircuitState(c, ChirrtlForm) + val compiled = (new LowFirrtlCompiler).compileAndEmit(state, List.empty) + + // While useful, ScalaTest helpers should be used over search + behavior of "Search" + + it should "be supported on Circuit" in { + assert(c search { + case Connect(_, Reference("out",_), Reference("in",_)) => true + }) + } + it should "be supported on CircuitStates" in { + assert(state search { + case Connect(_, Reference("out",_), Reference("in",_)) => true + }) + } + it should "be supported on the results of compilers" in { + assert(compiled search { + case Connect(_, WRef("out",_,_,_), WRef("in",_,_,_)) => true + }) + } + + // Use these!!! + behavior of "ScalaTest helpers" + + they should "work for lines of emitted text" in { + compiled should containLine (s"input in : UInt<8>") + compiled should containLine (s"output out : UInt<8>") + compiled should containLine (s"out <= in") + } + + they should "work for partial functions matching on subtrees" in { + val UInt8 = UIntType(IntWidth(8)) // BigInt unapply is weird + compiled should containTree { case Port(_, "in", Input, UInt8) => true } + compiled should containTree { case Port(_, "out", Output, UInt8) => true } + compiled should containTree { case Connect(_, WRef("out",_,_,_), WRef("in",_,_,_)) => true } + } +} + +/** Super class for execution driven Firrtl tests */ +abstract class ExecutionTest(name: String, dir: String, vFiles: Seq[String] = Seq.empty) extends FirrtlPropSpec { + property(s"$name should execute correctly") { + runFirrtlTest(name, dir, vFiles) + } +} +/** Super class for compilation driven Firrtl tests */ +abstract class CompilationTest(name: String, dir: String) extends FirrtlPropSpec { + property(s"$name should compile correctly") { + compileFirrtlTest(name, dir) + } +} + +trait Utils { + + /** Run some Scala thunk and return STDOUT and STDERR as strings. + * @param thunk some Scala code + * @return a tuple containing STDOUT, STDERR, and what the thunk returns + */ + def grabStdOutErr[T](thunk: => T): (String, String, T) = { + val stdout, stderr = new ByteArrayOutputStream() + val ret = scala.Console.withOut(stdout) { scala.Console.withErr(stderr) { thunk } } + (stdout.toString, stderr.toString, ret) + } + + /** Encodes a System.exit exit code + * @param status the exit code + */ + private case class ExitException(status: Int) extends SecurityException(s"Found a sys.exit with code $status") + + /** A security manager that converts calls to System.exit into [[ExitException]]s by explicitly disabling the ability of + * a thread to actually exit. For more information, see: + * - https://docs.oracle.com/javase/tutorial/essential/environment/security.html + */ + private class ExceptOnExit extends SecurityManager { + override def checkPermission(perm: Permission): Unit = {} + override def checkPermission(perm: Permission, context: Object): Unit = {} + override def checkExit(status: Int): Unit = { + super.checkExit(status) + throw ExitException(status) + } + } + + /** Encodes a file that some code tries to write to + * @param the file name + */ + private case class WriteException(file: String) extends SecurityException(s"Tried to write to file $file") + + /** A security manager that converts writes to any file into [[WriteException]]s. + */ + private class ExceptOnWrite extends SecurityManager { + override def checkPermission(perm: Permission): Unit = {} + override def checkPermission(perm: Permission, context: Object): Unit = {} + override def checkWrite(file: String): Unit = { + super.checkWrite(file) + throw WriteException(file) + } + } + + /** Run some Scala code (a thunk) in an environment where all System.exit are caught and returned. This avoids a + * situation where a test results in something actually exiting and killing the entire test. This is necessary if you + * want to test a command line program, e.g., the `main` method of [[firrtl.options.Stage Stage]]. + * + * NOTE: THIS WILL NOT WORK IN SITUATIONS WHERE THE THUNK IS CATCHING ALL [[Exception]]s OR [[Throwable]]s, E.G., + * SCOPT. IF THIS IS HAPPENING THIS WILL NOT WORK. REPEAT THIS WILL NOT WORK. + * @param thunk some Scala code + * @return either the output of the thunk (`Right[T]`) or an exit code (`Left[Int]`) + */ + def catchStatus[T](thunk: => T): Either[Int, T] = { + try { + System.setSecurityManager(new ExceptOnExit()) + Right(thunk) + } catch { + case ExitException(a) => Left(a) + } finally { + System.setSecurityManager(null) + } + } + + /** Run some Scala code (a thunk) in an environment where file writes are caught and the file that a program tries to + * write to is returned. This is useful if you want to test that some thunk either tries to write to a specific file + * or doesn't try to write at all. + */ + def catchWrites[T](thunk: => T): Either[String, T] = { + try { + System.setSecurityManager(new ExceptOnWrite()) + Right(thunk) + } catch { + case WriteException(a) => Left(a) + } finally { + System.setSecurityManager(null) + } + } + +} diff --git a/sim/midas/src/test/scala/midas/testutils/PassTests.scala b/sim/midas/src/test/scala/midas/testutils/PassTests.scala new file mode 100644 index 00000000..0e7b1306 --- /dev/null +++ b/sim/midas/src/test/scala/midas/testutils/PassTests.scala @@ -0,0 +1,105 @@ +// See LICENSE for license details. + +package midas.firrtl.testutils + +import java.io.{StringWriter,Writer} +import org.scalatest.{FlatSpec, Matchers} +import org.scalatest.junit.JUnitRunner +import firrtl.ir.Circuit +import firrtl.Parser.UseInfo +import firrtl.passes.{Pass, PassExceptions, RemoveEmpty} +import firrtl._ +import firrtl.annotations._ +import logger._ + +// An example methodology for testing Firrtl Passes +// Spec class should extend this class +abstract class SimpleTransformSpec extends FlatSpec with FirrtlMatchers with Compiler with LazyLogging { + // Utility function + def squash(c: Circuit): Circuit = RemoveEmpty.run(c) + + // Executes the test. Call in tests. + // annotations cannot have default value because scalatest trait Suite has a default value + def execute(input: String, check: String, annotations: Seq[Annotation]): CircuitState = { + val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) + val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize + val expected = parse(check).serialize + logger.debug(actual) + logger.debug(expected) + (actual) should be (expected) + finalState + } + + def executeWithAnnos(input: String, check: String, annotations: Seq[Annotation], + checkAnnotations: Seq[Annotation]): CircuitState = { + val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) + val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize + val expected = parse(check).serialize + logger.debug(actual) + logger.debug(expected) + (actual) should be (expected) + + annotations.foreach { anno => + logger.debug(anno.serialize) + } + + finalState.annotations.toSeq.foreach { anno => + logger.debug(anno.serialize) + } + checkAnnotations.foreach { check => + (finalState.annotations.toSeq) should contain (check) + } + finalState + } + // Executes the test, should throw an error + // No default to be consistent with execute + def failingexecute(input: String, annotations: Seq[Annotation]): Exception = { + intercept[PassExceptions] { + compile(CircuitState(parse(input), ChirrtlForm, annotations), Seq.empty) + } + } +} + +class CustomResolveAndCheck(form: CircuitForm) extends SeqTransform { + def inputForm = form + def outputForm = form + def transforms: Seq[Transform] = Seq[Transform](new ResolveAndCheck) +} + +trait LowTransformSpec extends SimpleTransformSpec { + def emitter = new LowFirrtlEmitter + def transform: Transform + def transforms: Seq[Transform] = Seq( + new ChirrtlToHighFirrtl(), + new IRToWorkingIR(), + new ResolveAndCheck(), + new HighFirrtlToMiddleFirrtl(), + new MiddleFirrtlToLowFirrtl(), + new CustomResolveAndCheck(LowForm), + transform + ) +} + +trait MiddleTransformSpec extends SimpleTransformSpec { + def emitter = new MiddleFirrtlEmitter + def transform: Transform + def transforms: Seq[Transform] = Seq( + new ChirrtlToHighFirrtl(), + new IRToWorkingIR(), + new ResolveAndCheck(), + new HighFirrtlToMiddleFirrtl(), + new CustomResolveAndCheck(MidForm), + transform + ) +} + +trait HighTransformSpec extends SimpleTransformSpec { + def emitter = new HighFirrtlEmitter + def transform: Transform + def transforms = Seq( + new ChirrtlToHighFirrtl(), + new IRToWorkingIR(), + new CustomResolveAndCheck(HighForm), + transform + ) +} diff --git a/sim/midas/targetutils/src/main/scala/midas/annotations.scala b/sim/midas/targetutils/src/main/scala/midas/annotations.scala index 614cc60d..899bff94 100644 --- a/sim/midas/targetutils/src/main/scala/midas/annotations.scala +++ b/sim/midas/targetutils/src/main/scala/midas/annotations.scala @@ -3,11 +3,10 @@ package midas.targetutils import chisel3._ -import chisel3.experimental.{BaseModule, ChiselAnnotation} +import chisel3.experimental.{BaseModule, ChiselAnnotation, annotate} import firrtl.{RenameMap} -import firrtl.annotations.{NoTargetAnnotation, SingleTargetAnnotation, ComponentName} // Deprecated -import firrtl.annotations.{ReferenceTarget, ModuleTarget, AnnotationException} +import firrtl.annotations._ // This is currently consumed by a transformation that runs after MIDAS's core // transformations In FireSim, targeting an F1 host, these are consumed by the @@ -33,9 +32,11 @@ private[midas] class ReferenceTargetRenamer(renames: RenameMap) { // TODO: determine order for multiple renames, or just check of == 1 rename? def exactRename(rt: ReferenceTarget): ReferenceTarget = { val renameMatches = renames.get(rt).getOrElse(Seq(rt)).collect({ case rt: ReferenceTarget => rt }) - assert(renameMatches.length == 1) + assert(renameMatches.length == 1, + s"${rt} should be renamed exactly once. Suggested renames: ${renameMatches}") renameMatches.head } + def apply(rt: ReferenceTarget): Seq[ReferenceTarget] = { renames.get(rt).getOrElse(Seq(rt)).collect({ case rt: ReferenceTarget => rt }) } @@ -92,7 +93,51 @@ object SynthesizePrintf { // TODO: Accept a printable -> need to somehow get the format string from } -// This labels a target Mem so that it is extracted and replaced with a separate model + +/** + * A mixed-in ancestor trait for all FAME annotations, useful for type-casing. + */ +trait FAMEAnnotation { + this: Annotation => +} + +/** + * This labels an instance so that it is extracted as a separate FAME model. + */ +case class FAMEModelAnnotation(target: BaseModule) extends chisel3.experimental.ChiselAnnotation { + def toFirrtl: FirrtlFAMEModelAnnotation = { + val parent = ModuleTarget(target.toNamed.circuit.name, target.parentModName) + FirrtlFAMEModelAnnotation(parent.instOf(target.instanceName, target.name)) + } +} + +case class FirrtlFAMEModelAnnotation( + target: InstanceTarget) extends SingleTargetAnnotation[InstanceTarget] with FAMEAnnotation { + def targets = Seq(target) + def duplicate(n: InstanceTarget) = this.copy(n) +} + +/** + * This specifies that the module should be automatically multi-threaded (Chisel annotator). + */ +case class EnableModelMultiThreadingAnnotation(target: BaseModule) extends chisel3.experimental.ChiselAnnotation { + def toFirrtl: FirrtlEnableModelMultiThreadingAnnotation = { + FirrtlEnableModelMultiThreadingAnnotation(target.toNamed.toTarget) + } +} + +/** + * This specifies that the module should be automatically multi-threaded (FIRRTL annotation). + */ +case class FirrtlEnableModelMultiThreadingAnnotation( + target: ModuleTarget) extends SingleTargetAnnotation[ModuleTarget] with FAMEAnnotation { + def targets = Seq(target) + def duplicate(n: ModuleTarget) = this.copy(n) +} + +/** + * This labels a target Mem so that it is extracted and replaced with a separate model. + */ case class MemModelAnnotation[T <: chisel3.Data](target: chisel3.MemBase[T]) extends chisel3.experimental.ChiselAnnotation { def toFirrtl = FirrtlMemModelAnnotation(target.toNamed.toTarget) @@ -116,38 +161,177 @@ object ExcludeInstanceAsserts { } -//AutoCounter annotations - -case class AutoCounterCoverAnnotation(target: ReferenceTarget, label: String, message: String) extends - SingleTargetAnnotation[ReferenceTarget] { - def duplicate(n: ReferenceTarget) = this.copy(target = n) -} - -case class AutoCounterFirrtlAnnotation(target: ReferenceTarget, label: String, message: String) extends - SingleTargetAnnotation[ReferenceTarget] { - def duplicate(n: ReferenceTarget) = this.copy(target = n) +/** + * AutoCounter annotations. Do not emit the FIRRTL annotations unless you are + * writing a target transformation, use the Chisel-side [[PerfCounter]] object + * instead. + * + */ +case class AutoCounterFirrtlAnnotation( + target: ReferenceTarget, + clock: ReferenceTarget, + reset: ReferenceTarget, + label: String, + message: String, + coverGenerated: Boolean = false) + extends firrtl.annotations.Annotation { + def update(renames: RenameMap): Seq[firrtl.annotations.Annotation] = { + val renamer = new ReferenceTargetRenamer(renames) + val renamedTarget = renamer.exactRename(target) + val renamedClock = renamer.exactRename(clock) + val renamedReset = renamer.exactRename(reset) + Seq(this.copy(target = renamedTarget, clock = renamedClock, reset = renamedReset)) + } + // The AutoCounter tranform will reject this annotation if it's not enclosed + def shouldBeIncluded(modList: Seq[String]): Boolean = !coverGenerated || modList.contains(target.module) + def enclosingModule(): String = target.module + def enclosingModuleTarget(): ModuleTarget = ModuleTarget(target.circuit, enclosingModule) } case class AutoCounterCoverModuleFirrtlAnnotation(target: ModuleTarget) extends - SingleTargetAnnotation[ModuleTarget] { + SingleTargetAnnotation[ModuleTarget] with FAMEAnnotation { def duplicate(n: ModuleTarget) = this.copy(target = n) } -import chisel3.experimental.ChiselAnnotation case class AutoCounterCoverModuleAnnotation(target: String) extends ChiselAnnotation { //TODO: fix the CircuitName arguemnt of ModuleTarget after chisel implements Target //It currently doesn't matter since the transform throws away the circuit name def toFirrtl = AutoCounterCoverModuleFirrtlAnnotation(ModuleTarget("",target)) } -case class AutoCounterAnnotation(target: chisel3.Data, label: String, message: String) extends ChiselAnnotation { - def toFirrtl = AutoCounterFirrtlAnnotation(target.toNamed.toTarget, label, message) +object PerfCounter { + /** + * Annotates a Bool representing a target event (ex. L1 D$ miss) that + * should be tracked by AutoCounter + * + * @param target The event + * + * @param clock The clock to which this event is sychronized. + * + * @param reset If the event is asserted while under the provide reset, it + * is not counted. TODO: This should be made optional. + * + * @param label A verilog-friendly identifier for the event signal + * + * @param message A description of the event. + * + */ + def apply(target: chisel3.Bool, + clock: chisel3.Clock, + reset: Reset, + label: String, + message: String): Unit = { + dontTouch(reset) + dontTouch(target) + dontTouch(clock) + annotate(new ChiselAnnotation { + def toFirrtl = AutoCounterFirrtlAnnotation(target.toTarget, clock.toTarget, + reset.toTarget, label, message) + }) + } + + /** + * A simplified variation of the full apply method above that uses the + * implicit clock and reset. + */ + def apply(target: chisel3.Bool, label: String, message: String): Unit = + apply(target, Module.clock, Module.reset, label, message) } -object PerfCounter { - def apply(target: chisel3.Data, label: String, message: String): Unit = { - chisel3.experimental.annotate(AutoCounterAnnotation(target, label, message)) +// Need serialization utils to be upstreamed to FIRRTL before i can use these. +//sealed trait TriggerSourceType +//case object Credit extends TriggerSourceType +//case object Debit extends TriggerSourceType + +case class TriggerSourceAnnotation( + target: ReferenceTarget, + clock: ReferenceTarget, + reset: Option[ReferenceTarget], + sourceType: Boolean) extends Annotation with FAMEAnnotation{ + def update(renames: RenameMap): Seq[firrtl.annotations.Annotation] = { + val renamer = new ReferenceTargetRenamer(renames) + val renamedTarget = renamer.exactRename(target) + val renamedClock = renamer.exactRename(clock) + val renamedReset = reset map renamer.exactRename + Seq(this.copy(target = renamedTarget, clock = renamedClock, reset = renamedReset)) + } + def enclosingModuleTarget(): ModuleTarget = ModuleTarget(target.circuit, target.module) + def enclosingModule(): String = target.module +} + + +case class TriggerSinkAnnotation( + target: ReferenceTarget, + clock: ReferenceTarget) extends Annotation with FAMEAnnotation { + def update(renames: RenameMap): Seq[firrtl.annotations.Annotation] = { + val renamer = new ReferenceTargetRenamer(renames) + val renamedTarget = renamer.exactRename(target) + val renamedClock = renamer.exactRename(clock) + Seq(this.copy(target = renamedTarget, clock = renamedClock)) + } + def enclosingModuleTarget(): ModuleTarget = ModuleTarget(target.circuit, target.module) +} + +object TriggerSource { + private def annotateTrigger(tpe: Boolean)(target: Bool, reset: Option[Bool]): Unit = { + // Hack: Create dummy nodes until chisel-side instance annotations have been improved + val clock = WireDefault(Module.clock) + dontTouch(target) + dontTouch(clock) + reset.map(dontTouch.apply) + annotate(new ChiselAnnotation { + def toFirrtl = TriggerSourceAnnotation(target.toNamed.toTarget, clock.toNamed.toTarget, reset.map(_.toTarget), tpe) + }) + } + def annotateCredit = annotateTrigger(true) _ + def annotateDebit = annotateTrigger(false) _ + + /** + * Methods to annotate a Boolean as a trigger credit or debit. Credits and + * debits issued while the module's implicit reset is asserted are not + * counted. + */ + def credit(credit: Bool): Unit = annotateCredit(credit, Some(Module.reset.toBool)) + def debit(debit: Bool): Unit = annotateDebit(debit, Some(Module.reset.toBool)) + def apply(creditSig: Bool, debitSig: Bool): Unit = { + credit(creditSig) + debit(debitSig) + } + + /** + * Variations of the above methods that count credits and debits provided + * while the implicit reset is asserted. + */ + def creditEvenUnderReset(credit: Bool): Unit = annotateCredit(credit, None) + def debitEvenUnderReset(debit: Bool): Unit = annotateDebit(debit, None) + def evenUnderReset(creditSig: Bool, debitSig: Bool): Unit = { + creditEvenUnderReset(creditSig) + debitEvenUnderReset(debitSig) } } - +object TriggerSink { + /** + * Marks a bool as receiving the global trigger signal. + * + * @param target A Bool node that will be driven with the trigger + * + * @param noSourceDefault The value that the trigger signal should take on + * if no trigger soruces are found in the target. This is a temporary parameter required + * while this apply method generates a wire. Otherwise this can be punted to the target's RTL. + */ + def apply(target: Bool, noSourceDefault: =>Bool = true.B): Unit = { + // Hack: Create dummy nodes until chisel-side instance annotations have been improved + val targetWire = WireDefault(noSourceDefault) + val clock = Module.clock + target := targetWire + dontTouch(targetWire) + // Both the provided node and the generated one need to be dontTouched to stop + // constProp from optimizing the down stream logic(?) + dontTouch(target) + dontTouch(clock) + annotate(new ChiselAnnotation { + def toFirrtl = TriggerSinkAnnotation(targetWire.toTarget, clock.toTarget) + }) + } +} diff --git a/sim/src/main/cc/fasedtests/fasedtests_top.cc b/sim/src/main/cc/fasedtests/fasedtests_top.cc index 03575b2f..365eb566 100644 --- a/sim/src/main/cc/fasedtests/fasedtests_top.cc +++ b/sim/src/main/cc/fasedtests/fasedtests_top.cc @@ -49,32 +49,6 @@ uint64_t host_mem_offset = -0x80000000LL; host_mem_offset += (1ULL << FASEDMEMORYTIMINGMODEL_0_target_addr_bits); #endif -// There can only be one instance of assert and print widgets as their IO is -// uniquely generated by a FIRRTL transform -#ifdef ASSERTIONBRIDGEMODULE_struct_guard - #ifdef ASSERTIONBRIDGEMODULE_0_PRESENT - ASSERTIONBRIDGEMODULE_0_substruct_create; - add_bridge_driver(new synthesized_assertions_t(this, ASSERTIONBRIDGEMODULE_0_substruct)); - #endif -#endif - -#ifdef PRINTBRIDGEMODULE_struct_guard - #ifdef PRINTBRIDGEMODULE_0_PRESENT - PRINTBRIDGEMODULE_0_substruct_create; - print_bridge = new synthesized_prints_t(this, - args, - PRINTBRIDGEMODULE_0_substruct, - PRINTBRIDGEMODULE_0_print_count, - PRINTBRIDGEMODULE_0_token_bytes, - PRINTBRIDGEMODULE_0_idle_cycles_mask, - PRINTBRIDGEMODULE_0_print_offsets, - PRINTBRIDGEMODULE_0_format_strings, - PRINTBRIDGEMODULE_0_argument_counts, - PRINTBRIDGEMODULE_0_argument_widths, - PRINTBRIDGEMODULE_0_DMA_ADDR); - add_bridge_driver(print_bridge); - #endif -#endif // Add functions you'd like to periodically invoke on a paused simulator here. if (profile_interval != -1) { register_task([this](){ return this->profile_models();}, 0); diff --git a/sim/src/main/cc/firesim/firesim_top.cc b/sim/src/main/cc/firesim/firesim_top.cc index 681f6fd7..d980bed9 100644 --- a/sim/src/main/cc/firesim/firesim_top.cc +++ b/sim/src/main/cc/firesim/firesim_top.cc @@ -309,36 +309,52 @@ uint64_t host_mem_offset = -0x80000000LL; #ifdef TRACERVBRIDGEMODULE_struct_guard #ifdef TRACERVBRIDGEMODULE_0_PRESENT - TRACERVBRIDGEMODULE_0_substruct_create; - add_bridge_driver(new tracerv_t(this, args, TRACERVBRIDGEMODULE_0_substruct, 0, TRACERVBRIDGEMODULE_0_DMA_ADDR)); + INSTANTIATE_TRACERV(add_bridge_driver, 0) #endif #ifdef TRACERVBRIDGEMODULE_1_PRESENT - TRACERVBRIDGEMODULE_1_substruct_create; - add_bridge_driver(new tracerv_t(this, args, TRACERVBRIDGEMODULE_1_substruct, 1, TRACERVBRIDGEMODULE_1_DMA_ADDR)); + INSTANTIATE_TRACERV(add_bridge_driver, 1) #endif #ifdef TRACERVBRIDGEMODULE_2_PRESENT - TRACERVBRIDGEMODULE_2_substruct_create; - add_bridge_driver(new tracerv_t(this, args, TRACERVBRIDGEMODULE_2_substruct, 2, TRACERVBRIDGEMODULE_2_DMA_ADDR)); + INSTANTIATE_TRACERV(add_bridge_driver, 2) #endif #ifdef TRACERVBRIDGEMODULE_3_PRESENT - TRACERVBRIDGEMODULE_3_substruct_create; - add_bridge_driver(new tracerv_t(this, args, TRACERVBRIDGEMODULE_3_substruct, 3, TRACERVBRIDGEMODULE_3_DMA_ADDR)); + INSTANTIATE_TRACERV(add_bridge_driver, 3) #endif #ifdef TRACERVBRIDGEMODULE_4_PRESENT - TRACERVBRIDGEMODULE_4_substruct_create; - add_bridge_driver(new tracerv_t(this, args, TRACERVBRIDGEMODULE_4_substruct, 4, TRACERVBRIDGEMODULE_4_DMA_ADDR)); + INSTANTIATE_TRACERV(add_bridge_driver, 4) #endif #ifdef TRACERVBRIDGEMODULE_5_PRESENT - TRACERVBRIDGEMODULE_5_substruct_create; - add_bridge_driver(new tracerv_t(this, args, TRACERVBRIDGEMODULE_5_substruct, 5, TRACERVBRIDGEMODULE_5_DMA_ADDR)); + INSTANTIATE_TRACERV(add_bridge_driver, 5) #endif #ifdef TRACERVBRIDGEMODULE_6_PRESENT - TRACERVBRIDGEMODULE_6_substruct_create; - add_bridge_driver(new tracerv_t(this, args, TRACERVBRIDGEMODULE_6_substruct, 6, TRACERVBRIDGEMODULE_6_DMA_ADDR)); + INSTANTIATE_TRACERV(add_bridge_driver, 6) #endif #ifdef TRACERVBRIDGEMODULE_7_PRESENT - TRACERVBRIDGEMODULE_7_substruct_create; - add_bridge_driver(new tracerv_t(this, args, TRACERVBRIDGEMODULE_7_substruct, 7, TRACERVBRIDGEMODULE_7_DMA_ADDR)); + INSTANTIATE_TRACERV(add_bridge_driver, 7) + #endif + #ifdef TRACERVBRIDGEMODULE_8_PRESENT + INSTANTIATE_TRACERV(add_bridge_driver, 8) + #endif + #ifdef TRACERVBRIDGEMODULE_9_PRESENT + INSTANTIATE_TRACERV(add_bridge_driver, 9) + #endif + #ifdef TRACERVBRIDGEMODULE_10_PRESENT + INSTANTIATE_TRACERV(add_bridge_driver, 10) + #endif + #ifdef TRACERVBRIDGEMODULE_11_PRESENT + INSTANTIATE_TRACERV(add_bridge_driver, 11) + #endif + #ifdef TRACERVBRIDGEMODULE_12_PRESENT + INSTANTIATE_TRACERV(add_bridge_driver, 12) + #endif + #ifdef TRACERVBRIDGEMODULE_13_PRESENT + INSTANTIATE_TRACERV(add_bridge_driver, 13) + #endif + #ifdef TRACERVBRIDGEMODULE_14_PRESENT + INSTANTIATE_TRACERV(add_bridge_driver, 14) + #endif + #ifdef TRACERVBRIDGEMODULE_15_PRESENT + INSTANTIATE_TRACERV(add_bridge_driver, 15) #endif #endif @@ -385,117 +401,118 @@ uint64_t host_mem_offset = -0x80000000LL; #endif #endif -#ifdef AUTOCOUNTERBRIDGEMODULE_struct_guard - #ifdef AUTOCOUNTERBRIDGEMODULE_0_PRESENT - AUTOCOUNTERBRIDGEMODULE_0_substruct_create; - add_bridge_driver(new autocounter_t( - this, args, AUTOCOUNTERBRIDGEMODULE_0_substruct, - AddressMap(AUTOCOUNTERBRIDGEMODULE_0_R_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_0_R_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_0_R_names, - AUTOCOUNTERBRIDGEMODULE_0_W_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_0_W_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_0_W_names), 0)); - #endif - #ifdef AUTOCOUNTERBRIDGEMODULE_1_PRESENT - AUTOCOUNTERBRIDGEMODULE_1_substruct_create; - add_bridge_driver(new autocounter_t( - this, args, AUTOCOUNTERBRIDGEMODULE_1_substruct, - AddressMap(AUTOCOUNTERBRIDGEMODULE_1_R_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_1_R_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_1_R_names, - AUTOCOUNTERBRIDGEMODULE_1_W_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_1_W_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_1_W_names), 1)); - #endif - #ifdef AUTOCOUNTERBRIDGEMODULE_2_PRESENT - AUTOCOUNTERBRIDGEMODULE_2_substruct_create; - add_bridge_driver(new autocounter_t( - this, args, AUTOCOUNTERBRIDGEMODULE_2_substruct, - AddressMap(AUTOCOUNTERBRIDGEMODULE_2_R_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_2_R_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_2_R_names, - AUTOCOUNTERBRIDGEMODULE_2_W_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_2_W_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_2_W_names), 2)); - #endif - #ifdef AUTCOUNTERBRIDGEMODULE_3_PRESENT - AUTOCOUNTERBRIDGEMODULE_3_substruct_create; - add_bridge_driver(new autocounter_t( - this, args, AUTOCOUNTERBRIDGEMODULE_3_substruct, - AddressMap(AUTOCOUNTERBRIDGEMODULE_3_R_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_3_R_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_3_R_names, - AUTOCOUNTERBRIDGEMODULE_3_W_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_3_W_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_3_W_names), 3)); - #endif - #ifdef AUTOCOUNTERBRIDGEMODULE_4_PRESENT - AUTOCOUNTERBRIDGEMODULE_4_substruct_create; - add_bridge_driver(new autocounter_t( - this, args, AUTOCOUNTERBRIDGEMODULE_4_substruct, - AddressMap(AUTOCOUNTERBRIDGEMODULE_4_R_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_4_R_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_4_R_names, - AUTOCOUNTERBRIDGEMODULE_4_W_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_4_W_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_4_W_names), 4)); - #endif - #ifdef AUTOCOUNTERBRIDGEMODULE_5_PRESENT - AUTOCOUNTERBRIDGEMODULE_5_substruct_create; - add_bridge_driver(new autocounter_t( - this, args, AUTOCOUNTERBRIDGEMODULE_5_substruct, - AddressMap(AUTOCOUNTERBRIDGEMODULE_5_R_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_5_R_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_5_R_names, - AUTOCOUNTERBRIDGEMODULE_5_W_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_5_W_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_5_W_names), 5)); - #endif - #ifdef AUTOCOUNTERBRIDGEMODULE_6_PRESENT - AUTOCOUNTERBRIDGEMODULE_6_substruct_create; - add_bridge_driver(new autocounter_t( - this, args, AUTOCOUNTERBRIDGEMODULE_6_substruct, - AddressMap(AUTOCOUNTERBRIDGEMODULE_6_R_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_6_R_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_6_R_names, - AUTOCOUNTERBRIDGEMODULE_6_W_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_6_W_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_6_W_names), 6)); - #endif - #ifdef AUTOCOUNTERBRIDGEMODULE_7_PRESENT - AUTOCOUNTERBRIDGEMODULE_7_substruct_create; - add_bridge_driver(new autocounter_t( - this, args, AUTOCOUNTERBRIDGEMODULE_7_substruct, - AddressMap(AUTOCOUNTERBRIDGEMODULE_7_R_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_7_R_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_7_R_names, - AUTOCOUNTERBRIDGEMODULE_7_W_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_7_W_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_7_W_names), 7)); - #endif +#ifdef AUTOCOUNTERBRIDGEMODULE_0_PRESENT + INSTANTIATE_AUTOCOUNTER(add_bridge_driver, 0) +#endif +#ifdef AUTOCOUNTERBRIDGEMODULE_1_PRESENT + INSTANTIATE_AUTOCOUNTER(add_bridge_driver, 1) +#endif +#ifdef AUTOCOUNTERBRIDGEMODULE_2_PRESENT + INSTANTIATE_AUTOCOUNTER(add_bridge_driver, 2) +#endif +#ifdef AUTOCOUNTERBRIDGEMODULE_3_PRESENT + INSTANTIATE_AUTOCOUNTER(add_bridge_driver, 3) +#endif +#ifdef AUTOCOUNTERBRIDGEMODULE_4_PRESENT + INSTANTIATE_AUTOCOUNTER(add_bridge_driver, 4) +#endif +#ifdef AUTOCOUNTERBRIDGEMODULE_5_PRESENT + INSTANTIATE_AUTOCOUNTER(add_bridge_driver, 5) +#endif +#ifdef AUTOCOUNTERBRIDGEMODULE_6_PRESENT + INSTANTIATE_AUTOCOUNTER(add_bridge_driver, 6) +#endif +#ifdef AUTOCOUNTERBRIDGEMODULE_7_PRESENT + INSTANTIATE_AUTOCOUNTER(add_bridge_driver, 7) #endif -// There can only be one instance of assert and print widgets as their IO is -// uniquely generated by a FIRRTL transform #ifdef ASSERTBRIDGEMODULE_0_PRESENT ASSERTBRIDGEMODULE_0_substruct_create - add_bridge_driver(new synthesized_assertions_t(this, ASSERTBRIDGEMODULE_0_substruct)); + add_bridge_driver(new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_0_substruct, + ASSERTBRIDGEMODULE_0_assert_count, + ASSERTBRIDGEMODULE_0_assert_messages)); +#endif +#ifdef ASSERTBRIDGEMODULE_1_PRESENT + ASSERTBRIDGEMODULE_1_substruct_create + add_bridge_driver(new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_1_substruct, + ASSERTBRIDGEMODULE_1_assert_count, + ASSERTBRIDGEMODULE_1_assert_messages)); +#endif +#ifdef ASSERTBRIDGEMODULE_2_PRESENT + ASSERTBRIDGEMODULE_2_substruct_create + add_bridge_driver(new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_2_substruct, + ASSERTBRIDGEMODULE_2_assert_count, + ASSERTBRIDGEMODULE_2_assert_messages)); +#endif +#ifdef ASSERTBRIDGEMODULE_3_PRESENT + ASSERTBRIDGEMODULE_3_substruct_create + add_bridge_driver(new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_3_substruct, + ASSERTBRIDGEMODULE_3_assert_count, + ASSERTBRIDGEMODULE_3_assert_messages)); +#endif +#ifdef ASSERTBRIDGEMODULE_3_PRESENT + ASSERTBRIDGEMODULE_3_substruct_create + add_bridge_driver(new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_3_substruct, + ASSERTBRIDGEMODULE_3_assert_count, + ASSERTBRIDGEMODULE_3_assert_messages)); +#endif +#ifdef ASSERTBRIDGEMODULE_4_PRESENT + ASSERTBRIDGEMODULE_4_substruct_create + add_bridge_driver(new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_4_substruct, + ASSERTBRIDGEMODULE_4_assert_count, + ASSERTBRIDGEMODULE_4_assert_messages)); +#endif +#ifdef ASSERTBRIDGEMODULE_5_PRESENT + ASSERTBRIDGEMODULE_5_substruct_create + add_bridge_driver(new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_5_substruct, + ASSERTBRIDGEMODULE_5_assert_count, + ASSERTBRIDGEMODULE_5_assert_messages)); +#endif +#ifdef ASSERTBRIDGEMODULE_6_PRESENT + ASSERTBRIDGEMODULE_6_substruct_create + add_bridge_driver(new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_6_substruct, + ASSERTBRIDGEMODULE_6_assert_count, + ASSERTBRIDGEMODULE_6_assert_messages)); +#endif +#ifdef ASSERTBRIDGEMODULE_7_PRESENT + ASSERTBRIDGEMODULE_7_substruct_create + add_bridge_driver(new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_7_substruct, + ASSERTBRIDGEMODULE_7_assert_count, + ASSERTBRIDGEMODULE_7_assert_messages)); #endif #ifdef PRINTBRIDGEMODULE_0_PRESENT - PRINTBRIDGEMODULE_0_substruct_create; - add_bridge_driver(new synthesized_prints_t(this, - args, - PRINTBRIDGEMODULE_0_substruct, - PRINTBRIDGEMODULE_0_print_count, - PRINTBRIDGEMODULE_0_token_bytes, - PRINTBRIDGEMODULE_0_idle_cycles_mask, - PRINTBRIDGEMODULE_0_print_offsets, - PRINTBRIDGEMODULE_0_format_strings, - PRINTBRIDGEMODULE_0_argument_counts, - PRINTBRIDGEMODULE_0_argument_widths, - PRINTBRIDGEMODULE_0_DMA_ADDR)); + INSTANTIATE_PRINTF(add_bridge_driver,0) +#endif +#ifdef PRINTBRIDGEMODULE_1_PRESENT + INSTANTIATE_PRINTF(add_bridge_driver,1) +#endif +#ifdef PRINTBRIDGEMODULE_2_PRESENT + INSTANTIATE_PRINTF(add_bridge_driver,2) +#endif +#ifdef PRINTBRIDGEMODULE_3_PRESENT + INSTANTIATE_PRINTF(add_bridge_driver,3) +#endif +#ifdef PRINTBRIDGEMODULE_4_PRESENT + INSTANTIATE_PRINTF(add_bridge_driver,4) +#endif +#ifdef PRINTBRIDGEMODULE_5_PRESENT + INSTANTIATE_PRINTF(add_bridge_driver,5) +#endif +#ifdef PRINTBRIDGEMODULE_6_PRESENT + INSTANTIATE_PRINTF(add_bridge_driver,6) +#endif +#ifdef PRINTBRIDGEMODULE_7_PRESENT + INSTANTIATE_PRINTF(add_bridge_driver,7) #endif // Add functions you'd like to periodically invoke on a paused simulator here. if (profile_interval != -1) { @@ -576,6 +593,7 @@ void firesim_top_t::run() { fprintf(stderr, "time elapsed: %.1f s, simulation speed = %.2f KHz\n", sim_time, sim_speed); } double fmr = ((double) hcycles / end_cycle); + // This returns the FMR of the fastest target clock fprintf(stderr, "FPGA-Cycles-to-Model-Cycles Ratio (FMR): %.2f\n", fmr); expect(!exitcode, NULL); diff --git a/sim/src/main/cc/midasexamples/AssertModule.h b/sim/src/main/cc/midasexamples/AssertModule.h index 1cfd9413..1f3eb744 100644 --- a/sim/src/main/cc/midasexamples/AssertModule.h +++ b/sim/src/main/cc/midasexamples/AssertModule.h @@ -9,7 +9,10 @@ public: synthesized_assertions_t * assert_endpoint; AssertModule_t(int argc, char** argv) { ASSERTBRIDGEMODULE_0_substruct_create; - assert_endpoint = new synthesized_assertions_t(this, ASSERTBRIDGEMODULE_0_substruct); + assert_endpoint = new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_0_substruct, + ASSERTBRIDGEMODULE_0_assert_count, + ASSERTBRIDGEMODULE_0_assert_messages); }; void run() { int assertions_thrown = 0; diff --git a/sim/src/main/cc/midasexamples/AutoCounterCoverModule.h b/sim/src/main/cc/midasexamples/AutoCounterCoverModule.h index c4d213dc..8753d67d 100644 --- a/sim/src/main/cc/midasexamples/AutoCounterCoverModule.h +++ b/sim/src/main/cc/midasexamples/AutoCounterCoverModule.h @@ -11,7 +11,9 @@ class AutoCounterCoverModule_t: public autocounter_module_t, virtual simif_t public: AutoCounterCoverModule_t(int argc, char** argv): autocounter_module_t(argc, argv) {}; virtual void run() { - autocounter_endpoint->init(); + for (auto &autocounter_endpoint: autocounter_endpoints) { + autocounter_endpoint->init(); + } poke(reset, 1); poke(io_a, 0); step(1); diff --git a/sim/src/main/cc/midasexamples/AutoCounterModule.h b/sim/src/main/cc/midasexamples/AutoCounterModule.h index e7f9d31b..e2108040 100644 --- a/sim/src/main/cc/midasexamples/AutoCounterModule.h +++ b/sim/src/main/cc/midasexamples/AutoCounterModule.h @@ -8,26 +8,24 @@ class autocounter_module_t: virtual simif_t { public: - std::unique_ptr autocounter_endpoint; + std::vector autocounter_endpoints; autocounter_module_t(int argc, char** argv) { - AUTOCOUNTERBRIDGEMODULE_0_substruct_create; std::vector args(argv + 1, argv + argc); - autocounter_endpoint = std::unique_ptr(new autocounter_t(this, - args, - AUTOCOUNTERBRIDGEMODULE_0_substruct, - AddressMap(AUTOCOUNTERBRIDGEMODULE_0_R_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_0_R_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_0_R_names, - AUTOCOUNTERBRIDGEMODULE_0_W_num_registers, - (const unsigned int*) AUTOCOUNTERBRIDGEMODULE_0_W_addrs, - (const char* const*) AUTOCOUNTERBRIDGEMODULE_0_W_names), 0)); + INSTANTIATE_AUTOCOUNTER(autocounter_endpoints.push_back, 0) +#ifdef AUTOCOUNTERBRIDGEMODULE_1_PRESENT + INSTANTIATE_AUTOCOUNTER(autocounter_endpoints.push_back, 1) +#endif }; void run_and_collect(int cycles) { step(cycles, false); while (!done()) { - autocounter_endpoint->tick(); + for (auto &autocounter_endpoint: autocounter_endpoints) { + autocounter_endpoint->tick(); + } + } + for (auto &autocounter_endpoint: autocounter_endpoints) { + autocounter_endpoint->finish(); } - autocounter_endpoint->finish(); }; }; @@ -37,7 +35,9 @@ class AutoCounterModule_t: public autocounter_module_t, virtual simif_t public: AutoCounterModule_t(int argc, char** argv): autocounter_module_t(argc, argv) {}; virtual void run() { - autocounter_endpoint->init(); + for (auto &autocounter_endpoint: autocounter_endpoints) { + autocounter_endpoint->init(); + } poke(reset, 1); poke(io_a, 0); step(1); diff --git a/sim/src/main/cc/midasexamples/Driver.cc b/sim/src/main/cc/midasexamples/Driver.cc index 614bb9d7..aa10799a 100644 --- a/sim/src/main/cc/midasexamples/Driver.cc +++ b/sim/src/main/cc/midasexamples/Driver.cc @@ -33,14 +33,28 @@ #include "PrintfModule.h" #elif defined DESIGNNAME_NarrowPrintfModule #include "NarrowPrintfModule.h" +#elif defined DESIGNNAME_MulticlockPrintfModule +#include "MulticlockPrintfModule.h" #elif defined DESIGNNAME_AutoCounterModule #include "AutoCounterModule.h" #elif defined DESIGNNAME_AutoCounterCoverModule #include "AutoCounterCoverModule.h" +#elif defined DESIGNNAME_AutoCounterPrintfModule +#include "PrintfModule.h" +#elif defined DESIGNNAME_MulticlockAutoCounterModule +#include "MulticlockAutoCounterModule.h" #elif defined DESIGNNAME_Accumulator #include "Accumulator.h" #elif defined DESIGNNAME_VerilogAccumulator #include "VerilogAccumulator.h" +#elif defined DESIGNNAME_TrivialMulticlock +#include "TrivialMulticlock.h" +#elif defined DESIGNNAME_MulticlockAssertModule +#include "MulticlockAssertModule.h" +#elif defined DESIGNNAME_TriggerWiringModule +#include "TriggerWiringModule.h" +#elif defined DESIGNNAME_TwoAdders +#include "TwoAdders.h" #endif class dut_emul_t: diff --git a/sim/src/main/cc/midasexamples/MulticlockAssertModule.h b/sim/src/main/cc/midasexamples/MulticlockAssertModule.h new file mode 100644 index 00000000..1951bca5 --- /dev/null +++ b/sim/src/main/cc/midasexamples/MulticlockAssertModule.h @@ -0,0 +1,45 @@ +//See LICENSE for license details. + +#include "simif.h" +#include "bridges/synthesized_assertions.h" + +class MulticlockAssertModule_t: virtual simif_t +{ +public: + std::vector assert_endpoints; + synthesized_assertions_t * full_rate_assert_ep; + synthesized_assertions_t * half_rate_assert_ep; + MulticlockAssertModule_t(int argc, char** argv) { + ASSERTBRIDGEMODULE_0_substruct_create; + ASSERTBRIDGEMODULE_1_substruct_create; + full_rate_assert_ep = new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_0_substruct, + ASSERTBRIDGEMODULE_0_assert_count, + ASSERTBRIDGEMODULE_0_assert_messages); + half_rate_assert_ep = new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_1_substruct, + ASSERTBRIDGEMODULE_1_assert_count, + ASSERTBRIDGEMODULE_1_assert_messages); + assert_endpoints.push_back(full_rate_assert_ep); + assert_endpoints.push_back(half_rate_assert_ep); + }; + void run() { + int assertions_thrown = 0; + poke(reset, 0); + poke(fullrate_pulseLength, 2); + poke(fullrate_cycle, 186); + poke(halfrate_pulseLength, 2); + poke(halfrate_cycle, 129); + step(256, false); + while (!done()) { + for (auto ep: assert_endpoints) { + ep->tick(); + if (ep->terminate()) { + ep->resume(); + assertions_thrown++; + } + } + } + expect(assertions_thrown == 3, "EXPECT: Two assertions thrown"); + }; +}; diff --git a/sim/src/main/cc/midasexamples/MulticlockAutoCounterModule.h b/sim/src/main/cc/midasexamples/MulticlockAutoCounterModule.h new file mode 100644 index 00000000..b7a59e98 --- /dev/null +++ b/sim/src/main/cc/midasexamples/MulticlockAutoCounterModule.h @@ -0,0 +1,21 @@ +//See LICENSE for license details. + +#include "AutoCounterModule.h" + +#ifdef DESIGNNAME_MulticlockAutoCounterModule +class MulticlockAutoCounterModule_t: public autocounter_module_t, virtual simif_t +{ +public: + MulticlockAutoCounterModule_t(int argc, char** argv): autocounter_module_t(argc, argv) {}; + virtual void run() { + for (auto &autocounter_endpoint: autocounter_endpoints) { + autocounter_endpoint->init(); + } + poke(reset, 1); + step(1); + poke(reset, 0); + run_and_collect(3000); + }; +}; +#endif //DESIGNNAME_MulticlockAutoCounterModule + diff --git a/sim/src/main/cc/midasexamples/MulticlockPrintfModule.h b/sim/src/main/cc/midasexamples/MulticlockPrintfModule.h new file mode 100644 index 00000000..d22ca0e6 --- /dev/null +++ b/sim/src/main/cc/midasexamples/MulticlockPrintfModule.h @@ -0,0 +1,20 @@ +//See LICENSE for license details. + +#include "PrintfModule.h" + +#ifdef DESIGNNAME_MulticlockPrintfModule +class MulticlockPrintfModule_t: public print_module_t, virtual simif_t +{ +public: + MulticlockPrintfModule_t(int argc, char** argv): print_module_t(argc, argv) {}; + virtual void run() { + for (auto &print_endpoint: print_endpoints) { + print_endpoint->init(); + } + step(1); + poke(reset, 0); + run_and_collect_prints(256); + }; +}; +#endif //DESIGNNAME_MulticlockPrintfModule + diff --git a/sim/src/main/cc/midasexamples/NarrowPrintfModule.h b/sim/src/main/cc/midasexamples/NarrowPrintfModule.h index 3b03b39b..d5e55b29 100644 --- a/sim/src/main/cc/midasexamples/NarrowPrintfModule.h +++ b/sim/src/main/cc/midasexamples/NarrowPrintfModule.h @@ -6,7 +6,9 @@ class NarrowPrintfModule_t: public print_module_t, virtual simif_t public: NarrowPrintfModule_t(int argc, char** argv): print_module_t(argc, argv) {}; virtual void run() { - print_endpoint->init(); + for (auto &print_endpoint: print_endpoints) { + print_endpoint->init(); + } poke(reset, 1); poke(io_enable, 0); step(1); diff --git a/sim/src/main/cc/midasexamples/PrintfModule.h b/sim/src/main/cc/midasexamples/PrintfModule.h index 94bbe8b8..11ad7e00 100644 --- a/sim/src/main/cc/midasexamples/PrintfModule.h +++ b/sim/src/main/cc/midasexamples/PrintfModule.h @@ -7,30 +7,28 @@ class print_module_t: virtual simif_t { - public: - std::unique_ptr print_endpoint; - print_module_t(int argc, char** argv) { - PRINTBRIDGEMODULE_0_substruct_create; - std::vector args(argv + 1, argv + argc); - print_endpoint = std::unique_ptr(new synthesized_prints_t(this, - args, - PRINTBRIDGEMODULE_0_substruct, - PRINTBRIDGEMODULE_0_print_count, - PRINTBRIDGEMODULE_0_token_bytes, - PRINTBRIDGEMODULE_0_idle_cycles_mask, - PRINTBRIDGEMODULE_0_print_offsets, - PRINTBRIDGEMODULE_0_format_strings, - PRINTBRIDGEMODULE_0_argument_counts, - PRINTBRIDGEMODULE_0_argument_widths, - PRINTBRIDGEMODULE_0_DMA_ADDR)); - }; - void run_and_collect_prints(int cycles) { - step(cycles, false); - while (!done()) { - print_endpoint->tick(); - } - print_endpoint->finish(); - }; + public: + std::vector print_endpoints; + print_module_t(int argc, char** argv) { + std::vector args(argv + 1, argv + argc); +#ifdef PRINTBRIDGEMODULE_0_PRESENT + INSTANTIATE_PRINTF(print_endpoints.push_back,0) +#endif +#ifdef PRINTBRIDGEMODULE_1_PRESENT + INSTANTIATE_PRINTF(print_endpoints.push_back,1) +#endif + }; + void run_and_collect_prints(int cycles) { + step(cycles, false); + while (!done()) { + for (auto &print_endpoint: print_endpoints) { + print_endpoint->tick(); + } + } + for (auto &print_endpoint: print_endpoints) { + print_endpoint->finish(); + } + }; }; #ifdef DESIGNNAME_PrintfModule @@ -39,7 +37,9 @@ class PrintfModule_t: public print_module_t, virtual simif_t public: PrintfModule_t(int argc, char** argv): print_module_t(argc, argv) {}; virtual void run() { - print_endpoint->init(); + for (auto &print_endpoint: print_endpoints) { + print_endpoint->init(); + } poke(reset, 1); poke(io_a, 0); poke(io_b, 0); @@ -52,3 +52,23 @@ public: }; }; #endif //DESIGNNAME_PrintfModule + +#ifdef DESIGNNAME_AutoCounterPrintfModule +class AutoCounterPrintfModule_t: public print_module_t, virtual simif_t +{ +public: + AutoCounterPrintfModule_t(int argc, char** argv): print_module_t(argc, argv) {}; + virtual void run() { + for (auto &print_endpoint: print_endpoints) { + print_endpoint->init(); + } + poke(reset, 1); + poke(io_a, 0); + step(1); + poke(reset, 0); + step(1); + poke(io_a, 1); + run_and_collect_prints(3000); + }; +}; +#endif // DESIGNNAME_AutoCounterPrintf diff --git a/sim/src/main/cc/midasexamples/TriggerWiringModule.h b/sim/src/main/cc/midasexamples/TriggerWiringModule.h new file mode 100644 index 00000000..ba85feaf --- /dev/null +++ b/sim/src/main/cc/midasexamples/TriggerWiringModule.h @@ -0,0 +1,49 @@ +//See LICENSE for license details. + +#include "simif.h" +#include "bridges/synthesized_assertions.h" + +class TriggerWiringModule_t: virtual simif_t +{ +public: + std::vector assert_endpoints; + TriggerWiringModule_t(int argc, char** argv) { + ASSERTBRIDGEMODULE_0_substruct_create; + ASSERTBRIDGEMODULE_1_substruct_create; + assert_endpoints.push_back(new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_0_substruct, + ASSERTBRIDGEMODULE_0_assert_count, + ASSERTBRIDGEMODULE_0_assert_messages)); + assert_endpoints.push_back(new synthesized_assertions_t(this, + ASSERTBRIDGEMODULE_1_substruct, + ASSERTBRIDGEMODULE_1_assert_count, + ASSERTBRIDGEMODULE_1_assert_messages)); + }; + bool simulation_complete() { + bool is_complete = false; + for (auto &e: assert_endpoints) { + is_complete |= e->terminate(); + } + return is_complete; + } + int exit_code(){ + for (auto &e: assert_endpoints) { + if (e->exit_code()) + return e->exit_code(); + } + return 0; + } + void run() { + int assertions_thrown = 0; + poke(reset, 1); + step(1); + poke(reset, 0); + step(10000, false); + while (!done() && !simulation_complete()) { + for (auto ep: assert_endpoints) { + ep->tick(); + } + } + expect(!exit_code(), "No assertions should be thrown"); + } +}; diff --git a/sim/src/main/cc/midasexamples/TrivialMulticlock.h b/sim/src/main/cc/midasexamples/TrivialMulticlock.h new file mode 100644 index 00000000..4dd973ef --- /dev/null +++ b/sim/src/main/cc/midasexamples/TrivialMulticlock.h @@ -0,0 +1,58 @@ +//See LICENSE for license details. + +#include "simif.h" + +class MulticlockChecker { + public: + simif_t * sim; + uint32_t field_address; + int numerator, denominator; + int cycle = -1; + uint32_t expected_value; + uint32_t fast_domain_reg, slow_domain_reg, fast_domain_reg_out; + + MulticlockChecker(simif_t * sim, uint32_t field_address, int numerator, int denominator): + sim(sim), field_address(field_address), numerator(numerator), denominator(denominator) {}; + void expect_and_update(uint64_t poked_value){ + if (cycle > 1 ) sim->expect(field_address, fast_domain_reg_out); + if (cycle < 1) { + fast_domain_reg_out = slow_domain_reg; + slow_domain_reg = fast_domain_reg; + } else { + fast_domain_reg_out = slow_domain_reg; + if (((cycle * numerator) / denominator) > (((cycle - 1) * numerator)/ denominator)) { + // TODO: Handle the case where numerator * cycle is not a multiple of the division + //if (((cycle * numerator) % denominator) != 0) { + // fast_domain_reg_out = slow_domain_reg; + // slow_domain_reg = poked_value; + //} else { + slow_domain_reg = fast_domain_reg; + //} + } + } + fast_domain_reg = poked_value; + cycle++; + }; +}; + +class TrivialMulticlock_t: virtual simif_t +{ +public: + TrivialMulticlock_t(int argc, char** argv) {} + void run() { + uint64_t limit = 256; + std::vector checkers; + checkers.push_back(new MulticlockChecker(this, halfOut, 1, 2)); + checkers.push_back(new MulticlockChecker(this, thirdOut, 1, 3)); + // Resolve bug in PeekPoke Bridge + //checkers.push_back(new MulticlockChecker(this, threeSeventhsOut, 3, 7)); + + uint32_t current = rand_next(limit); + for(int i = 1; i < 1024; i++){ + for(auto checker: checkers) checker->expect_and_update(i); + poke(in, i); + current = rand_next(limit); + step(1); + } + } +}; diff --git a/sim/src/main/cc/midasexamples/TwoAdders.h b/sim/src/main/cc/midasexamples/TwoAdders.h new file mode 100644 index 00000000..7ca074c7 --- /dev/null +++ b/sim/src/main/cc/midasexamples/TwoAdders.h @@ -0,0 +1,34 @@ +//See LICENSE for license details. + +#include "simif.h" +#include "stdio.h" + +#define NTESTS 6 + +class TwoAdders_t: virtual simif_t +{ +public: + TwoAdders_t(int argc, char** argv) {} + uint32_t i0[NTESTS] = { 4, 8, 13, 26, 19, 0 }; + uint32_t i1[NTESTS] = { 31, 11, 99, 27, 43, 0 }; + uint32_t i2[NTESTS] = { 28, 7, 30, 2, 88, 0 }; + uint32_t i3[NTESTS] = { 67, 29, 50, 80, 59, 0 }; + void run() { + int i; + target_reset(); + poke(io_i0, i0[0]); + poke(io_i1, i1[0]); + poke(io_i2, i2[0]); + poke(io_i3, i3[0]); + step(1); + for (i = 1; i < NTESTS; i++) { + poke(io_i0, i0[i]); + poke(io_i1, i1[i]); + poke(io_i2, i2[i]); + poke(io_i3, i3[i]); + step(1); // has latency of 1 cycle + expect(io_o0, i0[i-1] + i1[i-1]); + expect(io_o1, i2[i-1] + i3[i-1]); + } + } +}; diff --git a/sim/src/main/makefrag/firesim/Makefrag b/sim/src/main/makefrag/firesim/Makefrag index e579e344..ed6f0a37 100644 --- a/sim/src/main/makefrag/firesim/Makefrag +++ b/sim/src/main/makefrag/firesim/Makefrag @@ -105,9 +105,9 @@ NET_MACADDR ?= $(shell printf '00:00:00:00:00:%02x' $$(($(NET_SLOT)+2))) nic_args = +shmemportname0=$(NET_SHMEMPORTNAME) +macaddr0=$(NET_MACADDR) \ +niclog0=niclog$(NET_SLOT) +linklatency0=$(NET_LINK_LATENCY) \ +netbw0=$(NET_BW) +netburst0=8 $(NET_LOOPBACK) -tracer_args = +tracefile0=TRACEFILE0 +tracer_args = +tracefile=TRACEFILE blkdev_args = +blkdev-in-mem0=128 +blkdev-log0=blkdev-log$(NET_SLOT) -autocounter_args = +autocounter-readrate0=1000 +autocounter-filename0=AUTOCOUNTERFILE0 +autocounter_args = +autocounter-readrate=1000 +autocounter-filename=AUTOCOUNTERFILE # Neglecting this +arg will make the simulator use the same step size as on the # FPGA. This will make ML simulation more closely match results seen on the # FPGA at the expense of dramatically increased target runtime diff --git a/sim/src/main/scala/fasedtests/AXI4Fuzzer.scala b/sim/src/main/scala/fasedtests/AXI4Fuzzer.scala index ebd150dc..18404f6a 100644 --- a/sim/src/main/scala/fasedtests/AXI4Fuzzer.scala +++ b/sim/src/main/scala/fasedtests/AXI4Fuzzer.scala @@ -11,7 +11,7 @@ import freechips.rocketchip.config.Parameters import junctions.{NastiKey, NastiParameters} import midas.models.{FASEDBridge, AXI4EdgeSummary, CompleteConfig} -import midas.widgets.{PeekPokeBridge} +import midas.widgets.{PeekPokeBridge, RationalClockBridge} object AXI4Printf { def apply(axi4: AXI4Bundle): Unit = { @@ -90,18 +90,18 @@ class AXI4FuzzerDUT(implicit p: Parameters) extends LazyModule with HasFuzzTarge } class AXI4Fuzzer(implicit val p: Parameters) extends RawModule { - val clock = IO(Input(Clock())) val reset = WireInit(false.B) - + val clockBridge = Module(new RationalClockBridge()) + val clock = clockBridge.io.clocks(0) withClockAndReset(clock, reset) { val fuzzer = Module((LazyModule(new AXI4FuzzerDUT)).module) val nastiKey = NastiParameters(fuzzer.axi4.r.bits.data.getWidth, fuzzer.axi4.ar.bits.addr.getWidth, fuzzer.axi4.ar.bits.id.getWidth) - val fasedInstance = FASEDBridge(fuzzer.axi4, reset, + val fasedInstance = FASEDBridge(clock, fuzzer.axi4, reset, CompleteConfig(p(firesim.configs.MemModelKey), nastiKey, Some(AXI4EdgeSummary(fuzzer.axi4Edge)))) - val peekPokeBridge = PeekPokeBridge(reset, + val peekPokeBridge = PeekPokeBridge(clock, reset, ("done", fuzzer.done), ("error", fuzzer.error)) } diff --git a/sim/src/main/scala/midasexamples/AssertModule.scala b/sim/src/main/scala/midasexamples/AssertModule.scala index 1ba9c74b..697a35be 100644 --- a/sim/src/main/scala/midasexamples/AssertModule.scala +++ b/sim/src/main/scala/midasexamples/AssertModule.scala @@ -3,6 +3,7 @@ package firesim.midasexamples import chisel3._ +import midas.widgets.{RationalClockBridge, PeekPokeBridge, RationalClock} class ChildModule extends Module { val io = IO(new Bundle { @@ -10,7 +11,6 @@ class ChildModule extends Module { }) assert(!io.pred, "Pred asserted") } - class AssertModuleDUT extends Module { val io = IO(new Bundle { val cycleToFail = Input(UInt(16.W)) @@ -31,3 +31,71 @@ class AssertModuleDUT extends Module { } class AssertModule extends PeekPokeMidasExampleHarness(() => new AssertModuleDUT) + +class RegisteredAssertModule extends Module { + val io = IO(new Bundle { + val pred = Input(Bool()) + }) + assert(!RegNext(io.pred), "Pred asserted") +} + +class DualClockModule extends Module { + val io = IO(new Bundle { + val clockB = Input(Clock()) + val a = Input(Bool()) + val b = Input(Bool()) + val c = Input(Bool()) + val d = Input(Bool()) + }) + + withClock(io.clockB) { + assert(!RegNext(io.a), "io.a asserted") + val modB = Module(new RegisteredAssertModule) + modB.io.pred := io.b + } + + assert(!RegNext(io.c), "io.c asserted") + val modA = Module(new RegisteredAssertModule) + modA.io.pred := io.d +} + +class StimulusGenerator extends MultiIOModule { + val input = IO(new Bundle { + val cycle = Input(UInt(16.W)) + val pulseLength = Input(UInt(4.W)) + }) + val pred = IO(Output(Bool())) + // Here i'm relying on zero-intialization of state instead of reset + val cycleCount = Reg(UInt(16.W)) + cycleCount := cycleCount + 1.U + + val pulseLengthRemaining = Reg(UInt(4.W)) + when(input.cycle === cycleCount && input.pulseLength =/= 0.U) { + pulseLengthRemaining := input.pulseLength - 1.U + }.elsewhen(pulseLengthRemaining =/= 0.U) { + pulseLengthRemaining := pulseLengthRemaining - 1.U + } + + pred := pulseLengthRemaining =/= 0.U || input.cycle === cycleCount +} + + +class MulticlockAssertModule extends RawModule { + val clockBridge = Module(new RationalClockBridge(RationalClock("HalfRate", 1, 2))) + val List(refClock, div2Clock) = clockBridge.io.clocks.toList + val reset = WireInit(false.B) + withClockAndReset(refClock, reset) { + val fullRateMod = Module(new RegisteredAssertModule) + val fullRatePulseGen = Module(new StimulusGenerator) + fullRateMod.io.pred := fullRatePulseGen.pred + + val halfRateMod = Module(new RegisteredAssertModule) + halfRateMod.clock := div2Clock + val halfRatePulseGen = Module(new StimulusGenerator) + halfRateMod.io.pred := halfRatePulseGen.pred + + val peekPokeBridge = PeekPokeBridge(refClock, reset, + ("fullrate", fullRatePulseGen.input), + ("halfrate", halfRatePulseGen.input)) + } +} diff --git a/sim/src/main/scala/midasexamples/AutoCounterModule.scala b/sim/src/main/scala/midasexamples/AutoCounterModule.scala index 6d62aa08..cccb751b 100644 --- a/sim/src/main/scala/midasexamples/AutoCounterModule.scala +++ b/sim/src/main/scala/midasexamples/AutoCounterModule.scala @@ -9,7 +9,10 @@ import chisel3.core.MultiIOModule import midas.targetutils.{PerfCounter, AutoCounterCoverModuleAnnotation} import freechips.rocketchip.util.property._ -class AutoCounterModuleDUT extends Module { +class AutoCounterModuleDUT( + printfPrefix: String = "AUTOCOUNTER_PRINT ", + instPath: String = "AutoCounterModule_AutoCounterModuleDUT", + clockDivision: Int = 1) extends Module { val io = IO(new Bundle { val a = Input(Bool()) }) @@ -29,18 +32,19 @@ class AutoCounterModuleDUT extends Module { //--------VALIDATION--------------- + val samplePeriod = 1000 / clockDivision val enabled_printcount = freechips.rocketchip.util.WideCounter(64, io.a) val enabled4_printcount = freechips.rocketchip.util.WideCounter(64, enabled4) val oddlfsr_printcount = freechips.rocketchip.util.WideCounter(64, childInst.io.oddlfsr) val cycle_print = Reg(UInt(64.W)) cycle_print := cycle_print + 1.U - when ((cycle_print >= 1000.U) & (cycle_print % 1000.U === 0.U)) { - printf("AUTOCOUNTER_PRINT Cycle %d\n", cycle_print) - printf("AUTOCOUNTER_PRINT ============================\n") - printf("AUTOCOUNTER_PRINT PerfCounter ENABLED_AutoCounterModule_AutoCounterModuleDUT: %d\n", enabled_printcount) - printf("AUTOCOUNTER_PRINT PerfCounter ENABLED_DIV_4_AutoCounterModule_AutoCounterModuleDUT: %d\n", enabled4_printcount) - printf("AUTOCOUNTER_PRINT PerfCounter ODD_LFSR_AutoCounterModule_AutoCounterModuleDUT_childInst: %d\n", oddlfsr_printcount) - printf("AUTOCOUNTER_PRINT \n") + when ((cycle_print >= (samplePeriod - 1).U) & (cycle_print % samplePeriod.U === (samplePeriod - 1).U)) { + printf(s"${printfPrefix}Cycle %d\n", cycle_print) + printf(s"${printfPrefix}============================\n") + printf(s"${printfPrefix}PerfCounter ENABLED_${instPath}: %d\n", enabled_printcount) + printf(s"${printfPrefix}PerfCounter ENABLED_DIV_4_${instPath}: %d\n", enabled4_printcount) + printf(s"${printfPrefix}PerfCounter ODD_LFSR_${instPath}_childInst: %d\n", oddlfsr_printcount) + printf(s"${printfPrefix}\n") } } @@ -80,9 +84,10 @@ class AutoCounterCoverModuleDUT extends Module { when (cycle8) { cycle8_printcount := cycle8_printcount + 1.U } + val samplePeriod = 1000 val cycle_print = Reg(UInt(64.W)) cycle_print := cycle_print + 1.U - when ((cycle_print >= 1000.U) & (cycle_print % 1000.U === 0.U)) { + when ((cycle_print >= (samplePeriod - 1).U) & (cycle_print % 1000.U === (samplePeriod - 1).U)) { printf("AUTOCOUNTER_PRINT Cycle %d\n", cycle_print) printf("AUTOCOUNTER_PRINT ============================\n") printf("AUTOCOUNTER_PRINT PerfCounter CYCLES_DIV_8_AutoCounterCoverModule_AutoCounterCoverModuleDUT: %d\n", cycle8_printcount) @@ -93,3 +98,23 @@ class AutoCounterCoverModuleDUT extends Module { class AutoCounterCoverModule extends PeekPokeMidasExampleHarness(() => new AutoCounterCoverModuleDUT) +class AutoCounterPrintfDUT extends Module { + val io = IO(new Bundle { + val a = Input(Bool()) + }) + + val childInst = Module(new AutoCounterModuleChild) + childInst.io.c := io.a + + //--------VALIDATION--------------- + + val oddlfsr_printcount = freechips.rocketchip.util.WideCounter(64, childInst.io.oddlfsr) + val cycle_print = Reg(UInt(39.W)) + cycle_print := cycle_print + 1.U + when (childInst.io.oddlfsr) { + printf("SYNTHESIZED_PRINT CYCLE: %d [AutoCounter] ODD_LFSR: %d\n", cycle_print, oddlfsr_printcount) + } +} + +class AutoCounterPrintfModule extends PeekPokeMidasExampleHarness(() => new AutoCounterPrintfDUT) + diff --git a/sim/src/main/scala/midasexamples/Config.scala b/sim/src/main/scala/midasexamples/Config.scala index 524326f6..f2591413 100644 --- a/sim/src/main/scala/midasexamples/Config.scala +++ b/sim/src/main/scala/midasexamples/Config.scala @@ -13,11 +13,11 @@ import firesim.configs.WithDefaultMemModel class NoConfig extends Config(Parameters.empty) // This is incomplete and must be mixed into a complete platform config class DefaultF1Config extends Config(new Config((site, here, up) => { - case DesiredHostFrequency => 75 - case SynthAsserts => true - case midas.GenerateMultiCycleRamModels => true - case SynthPrints => true - case TargetTransforms => ((p: Parameters) => Seq(new midas.passes.AutoCounterTransform()(p))) +: up(TargetTransforms, site) + case DesiredHostFrequency => 75 + case SynthAsserts => true + case GenerateMultiCycleRamModels => true + case SynthPrints => true + case EnableAutoCounter => true }) ++ new Config(new firesim.configs.WithEC2F1Artefacts ++ new WithDefaultMemModel ++ new midas.F1Config)) class PointerChaserConfig extends Config((site, here, up) => { @@ -28,3 +28,8 @@ class PointerChaserConfig extends Config((site, here, up) => { case NastiKey => NastiParameters(dataBits = 64, addrBits = 32, idBits = 3) case Seed => System.currentTimeMillis }) + +class AutoCounterPrintf extends Config((site, here, up) => { + case AutoCounterUsePrintfImpl => true +}) + diff --git a/sim/src/main/scala/midasexamples/GCD.scala b/sim/src/main/scala/midasexamples/GCD.scala index 9e659500..f23fcc0e 100644 --- a/sim/src/main/scala/midasexamples/GCD.scala +++ b/sim/src/main/scala/midasexamples/GCD.scala @@ -4,18 +4,21 @@ package firesim.midasexamples import chisel3._ import chisel3.util.unless -import chisel3.experimental.{withClock} +import chisel3.experimental.{withClock, annotate} import midas.widgets.PeekPokeBridge +import midas.targetutils.FAMEModelAnnotation -class GCDDUT extends Module { - val io = IO(new Bundle { - val a = Input(UInt(16.W)) - val b = Input(UInt(16.W)) - val e = Input(Bool()) - val z = Output(UInt(16.W)) - val v = Output(Bool()) - }) +class GCDIO extends Bundle { + val a = Input(UInt(16.W)) + val b = Input(UInt(16.W)) + val e = Input(Bool()) + val z = Output(UInt(16.W)) + val v = Output(Bool()) +} + +class GCDInner extends Module { + val io = IO(new GCDIO) val x = Reg(UInt()) val y = Reg(UInt()) when (x > y) { x := x - y } @@ -23,9 +26,53 @@ class GCDDUT extends Module { when (io.e) { x := io.a; y := io.b } io.z := x io.v := y === 0.U - - assert(!io.e || io.a =/= 0.U && io.b =/= 0.U, "Inputs to GCD cannot be 0") + // TODO: this assertion fails spuriously with deduped, extracted models + // assert(!io.e || io.a =/= 0.U && io.b =/= 0.U, "Inputs to GCD cannot be 0") printf("X: %d, Y:%d\n", x, y) } +class GCDDUT extends Module { + val io = IO(new GCDIO) + val inner1 = Module(new GCDInner) + annotate(FAMEModelAnnotation(inner1)) + val inner2 = Module(new GCDInner) + annotate(FAMEModelAnnotation(inner2)) + + val select = RegInit(false.B) + select := !select + + val done1 = RegInit(false.B) + val result1 = Reg(UInt()) + + when (io.v) { + done1 := false.B + } .elsewhen (inner1.io.v) { + done1 := true.B + result1 := inner1.io.z + } + + val done2 = RegInit(false.B) + val result2 = Reg(UInt()) + + when (io.v) { + done2 := false.B + } .elsewhen (inner2.io.v) { + done2 := true.B + result2 := inner2.io.z + } + + inner1.io.a := io.a + inner1.io.b := io.b + inner1.io.e := io.e + + inner2.io.a := io.b + inner2.io.b := io.a + inner2.io.e := io.e + + io.z := Mux(select, result1, result2) + io.v := done1 && done2 + + assert(!done1 || !done2 || (result1 === result2), "Outputs do not match!") +} + class GCD extends PeekPokeMidasExampleHarness(() => new GCDDUT) diff --git a/sim/src/main/scala/midasexamples/MulticlockAutoCounterModule.scala b/sim/src/main/scala/midasexamples/MulticlockAutoCounterModule.scala new file mode 100644 index 00000000..9696b51b --- /dev/null +++ b/sim/src/main/scala/midasexamples/MulticlockAutoCounterModule.scala @@ -0,0 +1,34 @@ +//See LICENSE for license details. + +package firesim.midasexamples + +import chisel3._ +import freechips.rocketchip.util.ResetCatchAndSync +import midas.widgets.{RationalClockBridge, PeekPokeBridge, RationalClock} + + +// Instantiates two of the autocounter duts from the single-clock test +// Use two separate prefixes so that we can partition the output +// from verilator/vcs and compare against the two files produced by +// the bridges +class MulticlockAutoCounterModule extends RawModule { + val clockBridge = Module(new RationalClockBridge(RationalClock("ThirdRate", 1, 3))) + val List(refClock, div2Clock) = clockBridge.io.clocks.toList + val reset = WireInit(false.B) + val resetHalfRate = ResetCatchAndSync(div2Clock, reset.toBool) + // Used to let printfs that emit the correct validation output + val instPath = "MulticlockAutoCounterModule_AutoCounterModuleDUT" + withClockAndReset(refClock, reset) { + val lfsr = chisel3.util.LFSR16() + val fullRateMod = Module(new AutoCounterModuleDUT(instPath = instPath)) + fullRateMod.io.a := lfsr(0) + val peekPokeBridge = PeekPokeBridge(refClock, reset) + } + withClockAndReset(div2Clock, resetHalfRate) { + val lfsr = chisel3.util.LFSR16() + val fullRateMod = Module(new AutoCounterModuleDUT("AUTOCOUNTER_PRINT_THIRDRATE ", + instPath = instPath + "_1", + clockDivision = 3)) + fullRateMod.io.a := lfsr(0) + } +} diff --git a/sim/src/main/scala/midasexamples/MulticlockPrintfModule.scala b/sim/src/main/scala/midasexamples/MulticlockPrintfModule.scala new file mode 100644 index 00000000..ee7d2111 --- /dev/null +++ b/sim/src/main/scala/midasexamples/MulticlockPrintfModule.scala @@ -0,0 +1,32 @@ +//See LICENSE for license details. + +package firesim.midasexamples + +import chisel3._ +import freechips.rocketchip.util.ResetCatchAndSync +import midas.widgets.{RationalClockBridge, PeekPokeBridge, RationalClock} + + +// Instantiates two of the printf duts from the single-clock test +// Use two separate prefixes so that we can partition the output +// from verilator/vcs and compare against the two files produced by +// the bridges +class MulticlockPrintfModule extends RawModule { + val clockBridge = Module(new RationalClockBridge(RationalClock("HalfRate",1 , 2))) + val List(refClock, div2Clock) = clockBridge.io.clocks.toList + val reset = WireInit(false.B) + val resetHalfRate = ResetCatchAndSync(div2Clock, reset.toBool) + withClockAndReset(refClock, reset) { + val lfsr = chisel3.util.LFSR16() + val fullRateMod = Module(new PrintfModuleDUT) + fullRateMod.io.a := lfsr(0) + fullRateMod.io.b := ~lfsr(0) + val peekPokeBridge = PeekPokeBridge(refClock, reset) + } + withClockAndReset(div2Clock, resetHalfRate) { + val lfsr = chisel3.util.LFSR16() + val fullRateMod = Module(new PrintfModuleDUT("SYNTHESIZED_PRINT_HALFRATE ")) + fullRateMod.io.a := lfsr(0) + fullRateMod.io.b := ~lfsr(0) + } +} diff --git a/sim/src/main/scala/midasexamples/PointerChaser.scala b/sim/src/main/scala/midasexamples/PointerChaser.scala index 1a711746..fed873a6 100644 --- a/sim/src/main/scala/midasexamples/PointerChaser.scala +++ b/sim/src/main/scala/midasexamples/PointerChaser.scala @@ -113,8 +113,8 @@ class PointerChaser(implicit val p: Parameters) extends RawModule { val fasedInstance = Module(new FASEDBridge(CompleteConfig(LatencyPipeConfig(BaseParams(16,16)), p(NastiKey)))) fasedInstance.io.axi4 <> pointerChaser.io.nasti fasedInstance.io.reset := reset - val peekPokeBridge = PeekPokeBridge(reset, - ("io_startAddr", pointerChaser.io.startAddr), - ("io_result", pointerChaser.io.result)) + val peekPokeBridge = PeekPokeBridge(clock, reset, + ("io_startAddr", pointerChaser.io.startAddr), + ("io_result", pointerChaser.io.result)) } } diff --git a/sim/src/main/scala/midasexamples/PrintfModule.scala b/sim/src/main/scala/midasexamples/PrintfModule.scala index 0c206482..4bc7902d 100644 --- a/sim/src/main/scala/midasexamples/PrintfModule.scala +++ b/sim/src/main/scala/midasexamples/PrintfModule.scala @@ -7,37 +7,36 @@ import chisel3.util.LFSR16 import midas.targetutils.SynthesizePrintf -class PrintfModuleDUT extends Module { +class PrintfModuleDUT(printfPrefix: String = "SYNTHESIZED_PRINT ") extends Module { val io = IO(new Bundle { val a = Input(Bool()) val b = Input(Bool()) }) val cycle = RegInit(0.U(16.W)) - - when(io.a) { cycle := cycle + 1.U } + cycle := cycle + 1.U // Printf format strings must be prefixed with "SYNTHESIZED_PRINT CYCLE: %d" // so they can be pulled out of RTL simulators log and sorted within a cycle // As the printf order will be different betwen RTL simulator and synthesized stream - printf(SynthesizePrintf("SYNTHESIZED_PRINT CYCLE: %d\n", cycle)) + printf(SynthesizePrintf(s"${printfPrefix}CYCLE: %d\n", cycle)) val wideArgument = VecInit(Seq.fill(33)(WireInit(cycle))).asUInt - printf(SynthesizePrintf("SYNTHESIZED_PRINT CYCLE: %d wideArgument: %x\n", cycle, wideArgument)) // argument width > DMA width + printf(SynthesizePrintf(s"${printfPrefix}CYCLE: %d wideArgument: %x\n", cycle, wideArgument)) // argument width > DMA width - val childInst = Module(new PrintfModuleChild) + val childInst = Module(new PrintfModuleChild(printfPrefix)) childInst.c := io.a childInst.cycle := cycle - printf(SynthesizePrintf("thi$!sn+taS/\neName", "SYNTHESIZED_PRINT CYCLE: %d constantArgument: %x\n", cycle, 1.U(8.W))) + printf(SynthesizePrintf("thi$!sn+taS/\neName", s"${printfPrefix}CYCLE: %d constantArgument: %x\n", cycle, 1.U(8.W))) } -class PrintfModuleChild extends MultiIOModule { +class PrintfModuleChild(printfPrefix: String) extends MultiIOModule { val c = IO(Input(Bool())) val cycle = IO(Input(UInt(16.W))) val lfsr = chisel3.util.LFSR16(c) - printf(SynthesizePrintf("SYNTHESIZED_PRINT CYCLE: %d LFSR: %x\n", cycle, lfsr)) + printf(SynthesizePrintf(s"${printfPrefix}CYCLE: %d LFSR: %x\n", cycle, lfsr)) //when (lsfr(0)) { // printf(SynthesizePrintf(p"SYNTHESIZED_PRINT CYCLE: ${cycle} LFSR is odd")) diff --git a/sim/src/main/scala/midasexamples/TriggerWiringModule.scala b/sim/src/main/scala/midasexamples/TriggerWiringModule.scala new file mode 100644 index 00000000..1a677f8d --- /dev/null +++ b/sim/src/main/scala/midasexamples/TriggerWiringModule.scala @@ -0,0 +1,119 @@ +//See LICENSE for license details. + +package firesim.midasexamples + +import midas.widgets.{RationalClockBridge, PeekPokeBridge, RationalClock} +import midas.targetutils.{TriggerSource, TriggerSink} +import freechips.rocketchip.util.{DensePrefixSum, ResetCatchAndSync} +import chisel3._ +import chisel3.util._ +import chisel3.experimental.chiselName + +import scala.collection.mutable + +class TriggerSinkModule extends MultiIOModule { + val reference = IO(Input(Bool())) + val generated = WireDefault(true.B) + TriggerSink(generated) + assert(reference === generated) +} + +class TriggerSourceModule extends MultiIOModule { + val referenceCredit = IO(Output(Bool())) + private val lfsr = LFSR16() + val credit = lfsr(0) + TriggerSource.credit(credit) + referenceCredit := ~reset.toBool && credit + + val referenceDebit = IO(Output(Bool())) + val debit = ShiftRegister(lfsr(0), 5) + TriggerSource.debit(debit) + referenceDebit := ~reset.toBool && debit +} + +class ReferenceSourceCounters(numCredits: Int, numDebits: Int) extends MultiIOModule { + def counterType = UInt(16.W) + val inputCredits = IO(Input(Vec(numCredits, Bool()))) + val inputDebits = IO(Input(Vec(numCredits, Bool()))) + val totalCredit = IO(Output(counterType)) + val totalDebit = IO(Output(counterType)) + + def doAccounting(values: Seq[Bool]): UInt = { + val total = Reg(counterType) + val update = total + PopCount(values) + total := update + update + } + totalCredit := doAccounting(inputCredits) + totalDebit := doAccounting(inputDebits) + + @chiselName + def synchAndDiff(count: UInt): UInt = { + val sync = RegNext(count) + val syncLast = RegNext(sync) + sync - syncLast + } + def syncAndDiffCredits(): UInt = synchAndDiff(totalCredit) + def syncAndDiffDebits(): UInt = synchAndDiff(totalDebit) +} + +object ReferenceSourceCounters { + def apply(credits: Seq[Bool], debits: Seq[Bool]): ReferenceSourceCounters = { + val m = Module(new ReferenceSourceCounters(credits.size, debits.size)) + m.inputCredits := VecInit(credits) + m.inputDebits := VecInit(debits) + m + } +} + +// This test target implements in Chisel what the Trigger Transformation should +// implement in FIRRTL. The test fails if the firrtl-generated trigger-enables, +// as seen by all nodes with a trigger sink, fail to match their references. +class TriggerWiringModule extends RawModule { + val clockBridge = Module(new RationalClockBridge(RationalClock("HalfRate", 1, 2))) + val refClock :: div2Clock :: _ = clockBridge.io.clocks.toList + val refSourceCounts = new mutable.ArrayBuffer[ReferenceSourceCounters]() + val refSinks = new mutable.ArrayBuffer[Bool]() + val reset = WireInit(false.B) + val resetHalfRate = ResetCatchAndSync(div2Clock, reset.toBool) + withClockAndReset(refClock, reset) { + val peekPokeBridge = PeekPokeBridge(refClock, reset) + val src = Module(new TriggerSourceModule) + val sink = Module(new TriggerSinkModule) + + // Reference Hardware + refSourceCounts += ReferenceSourceCounters(Seq(src.referenceCredit), Seq(src.referenceDebit)) + refSinks += { + val syncReg = Reg(Bool()) + sink.reference := syncReg + syncReg + } + } + + withClockAndReset(div2Clock, resetHalfRate) { + val src = Module(new TriggerSourceModule) + val sink = Module(new TriggerSinkModule) + // Reference Hardware + refSourceCounts += ReferenceSourceCounters(Seq(src.referenceCredit), Seq(src.referenceDebit)) + refSinks += { + val syncReg = Reg(Bool()) + sink.reference := syncReg + syncReg + } + } + + @chiselName + class ReferenceImpl { + val refTotalCredit = Reg(UInt(32.W)) + val refTotalDebit = Reg(UInt(32.W)) + val refCreditNext = refTotalCredit + DensePrefixSum(refSourceCounts.map(_.syncAndDiffCredits))(_ + _).last + val refDebitNext = refTotalDebit + DensePrefixSum(refSourceCounts.map(_.syncAndDiffDebits))(_ + _).last + refTotalCredit := refCreditNext + refTotalDebit := refDebitNext + val refTriggerEnable = refCreditNext =/= refDebitNext + refSinks foreach { _ := refTriggerEnable } + } + + // Reference Trigger Enable + withClock(refClock) { new ReferenceImpl } +} diff --git a/sim/src/main/scala/midasexamples/TrivialMulticlock.scala b/sim/src/main/scala/midasexamples/TrivialMulticlock.scala new file mode 100644 index 00000000..4b22be7c --- /dev/null +++ b/sim/src/main/scala/midasexamples/TrivialMulticlock.scala @@ -0,0 +1,53 @@ +//See LICENSE for license details. + +package firesim.midasexamples + +import chisel3._ + +import midas.widgets.{RationalClockBridge, PeekPokeBridge, RationalClock} + +class RegisterModule extends MultiIOModule { + def dataType = UInt(32.W) + val in = IO(Input(dataType)) + val out = IO(Output(dataType)) + val slowClock = IO(Input(Clock())) + // Register the input and output in the fast domain so that the PeekPoke + // bridge is synchronous with the design. The clock crossing is contained in this module + val regIn = RegNext(in) + val regOut = Reg(in.cloneType) + out := regOut + withClock(slowClock) { + regOut := RegNext(regIn) + } +} + +class TrivialMulticlock extends RawModule { + // TODO: Resolve bug in PeekPoke bridge for 3/7 case + //val clockBridge = Module(new RationalClockBridge(1000, (1,2), (1,3), (3,7))) + //val List(fullRate, halfRate, thirdRate, threeSeventhsRate) = clockBridge.io.clocks.toList + val clockBridge = Module(new RationalClockBridge(RationalClock("HalfRate", 1, 2), + RationalClock("ThirdRate", 1, 3))) + val List(fullRate, halfRate, thirdRate) = clockBridge.io.clocks.toList + val reset = WireInit(false.B) + + withClockAndReset(fullRate, reset) { + val halfRateInst = Module(new RegisterModule) + halfRateInst.slowClock := halfRate + val thirdRateInst = Module(new RegisterModule) + thirdRateInst.slowClock := thirdRate + thirdRateInst.in := halfRateInst.in + val threeSeventhsRateInst = Module(new RegisterModule) + // TODO: See above + //threeSeventhsRateInst.slowClock := threeSeventhsRate + threeSeventhsRateInst.slowClock := fullRate + threeSeventhsRateInst.in := halfRateInst.in + + // TODO: Remove reset + val peekPokeBridge = PeekPokeBridge(fullRate, + reset, + ("in", halfRateInst.in), + ("halfOut",halfRateInst.out), + ("thirdOut", thirdRateInst.out), + ("threeSeventhsOut", threeSeventhsRateInst.out)) + } +} diff --git a/sim/src/main/scala/midasexamples/TwoAdders.scala b/sim/src/main/scala/midasexamples/TwoAdders.scala new file mode 100644 index 00000000..d54ae347 --- /dev/null +++ b/sim/src/main/scala/midasexamples/TwoAdders.scala @@ -0,0 +1,60 @@ +//See LICENSE for license details. + +package firesim.midasexamples + +import chisel3._ +import chisel3.util.unless +import chisel3.experimental.{withClock, annotate} + +import midas.widgets.PeekPokeBridge +import midas.targetutils._ + + +class AdderIO extends Bundle { + val x = Input(UInt(16.W)) + val y = Input(UInt(16.W)) + val z = Output(UInt(16.W)) +} + +class PipeAdder extends Module { + val io = IO(new AdderIO) + val mem = Mem(2, UInt(16.W)) + mem.write(1.U, io.x + io.y) + val memout = mem.read(1.U) + io.z := memout +} + +class DoublePipeAdder extends Module { + val io = IO(new AdderIO) + val logic = Module(new PipeAdder) + logic.io.x := io.x + logic.io.y := io.y + io.z := RegNext(logic.io.z) +} + +class TwoAddersDUT extends Module { + val io = IO(new Bundle { + val i0 = Input(UInt(16.W)) + val i1 = Input(UInt(16.W)) + val i2 = Input(UInt(16.W)) + val i3 = Input(UInt(16.W)) + val o0 = Output(UInt(16.W)) + val o1 = Output(UInt(16.W)) + }) + + val a0 = Module(new DoublePipeAdder) + annotate(FAMEModelAnnotation(a0)) + annotate(EnableModelMultiThreadingAnnotation(a0)) + a0.io.x := io.i0 + a0.io.y := io.i1 + io.o0 := a0.io.z + + val a1 = Module(new DoublePipeAdder) + annotate(FAMEModelAnnotation(a1)) + annotate(EnableModelMultiThreadingAnnotation(a1)) + a1.io.x := io.i2 + a1.io.y := io.i3 + io.o1 := a1.io.z +} + +class TwoAdders extends PeekPokeMidasExampleHarness(() => new TwoAddersDUT) diff --git a/sim/src/main/scala/midasexamples/Util.scala b/sim/src/main/scala/midasexamples/Util.scala index a39db999..15baea9a 100644 --- a/sim/src/main/scala/midasexamples/Util.scala +++ b/sim/src/main/scala/midasexamples/Util.scala @@ -5,17 +5,17 @@ package firesim.midasexamples import chisel3._ import chisel3.experimental.{withClock} -import midas.widgets.PeekPokeBridge +import midas.widgets.{RationalClockBridge, PeekPokeBridge} // A simple MIDAS harness that generates a legacy // module DUT (it has a single io: Data member) and connects all of // its IO to a PeekPokeBridge class PeekPokeMidasExampleHarness(dutGen: () => Module) extends RawModule { - val clock = IO(Input(Clock())) + val clock = Module(new RationalClockBridge()).io.clocks.head val reset = WireInit(false.B) withClockAndReset(clock, reset) { val dut = Module(dutGen()) - val peekPokeBridge = PeekPokeBridge(reset, ("io", dut.io)) + val peekPokeBridge = PeekPokeBridge(clock, reset, ("io", dut.io)) } } diff --git a/sim/src/test/scala/midasexamples/TutorialSuite.scala b/sim/src/test/scala/midasexamples/TutorialSuite.scala index c7e82ddf..29fea9ec 100644 --- a/sim/src/test/scala/midasexamples/TutorialSuite.scala +++ b/sim/src/test/scala/midasexamples/TutorialSuite.scala @@ -10,6 +10,7 @@ import firesim.util.GeneratorArgs abstract class TutorialSuite( val targetName: String, // See GeneratorUtils targetConfigs: String = "NoConfig", + platformConfigs: String = "HostDebugFeatures_DefaultF1Config", tracelen: Int = 8, simulationArgs: Seq[String] = Seq() ) extends firesim.TestSuiteCommon with firesim.util.HasFireSimGeneratorUtilities { @@ -25,7 +26,7 @@ abstract class TutorialSuite( targetConfigProject = "firesim.midasexamples", targetConfigs = targetConfigs, platformConfigProject = "firesim.midasexamples", - platformConfigs = "HostDebugFeatures_DefaultF1Config") + platformConfigs = platformConfigs) val args = Seq(s"+tracelen=$tracelen") ++ simulationArgs val commonMakeArgs = Seq(s"TARGET_PROJECT=midasexamples", @@ -66,52 +67,34 @@ abstract class TutorialSuite( } } - // Checks that the synthesized print log in ${genDir}/${synthPrintLog} matches the - // printfs from the RTL simulator - def diffSynthesizedPrints(synthPrintLog: String) { - behavior of "synthesized print log" - it should "match the logs produced by the verilated design" in { - def printLines(filename: File): Seq[String] = { - val lines = Source.fromFile(filename).getLines.toList - lines.filter(_.startsWith("SYNTHESIZED_PRINT")).sorted + // Checks that a bridge generated log in ${genDir}/${synthLog} matches output + // generated directly by the RTL simulator (usually with printfs) + def diffSynthesizedLog(synthLog: String, + stdoutPrefix: String = "SYNTHESIZED_PRINT ", + synthPrefix: String = "SYNTHESIZED_PRINT ", + synthLinesToDrop: Int = 0) { + behavior of s"${synthLog}" + it should "match the prints generated by the verilated design" in { + def printLines(filename: File, prefix: String, linesToDrop: Int = 0): Seq[String] = { + // Drop the first line from all files as it is either a header in the synthesized file, + // or some unrelated output from verlator + val lines = Source.fromFile(filename).getLines.toList.drop(1) + lines.filter(_.startsWith(prefix)) + .dropRight(linesToDrop) + .map(_.stripPrefix(prefix).replaceAll(" +", " ")) + .sorted } - val verilatedOutput = printLines(new File(outDir, s"/${targetName}.${backendSimulator}.out")) - val synthPrintOutput = printLines(new File(genDir, s"/${synthPrintLog}")) - assert(verilatedOutput.size == synthPrintOutput.size && verilatedOutput.nonEmpty) + val verilatedOutput = printLines(new File(outDir, s"/${targetName}.${backendSimulator}.out"), stdoutPrefix) + val synthPrintOutput = printLines(new File(genDir, s"/${synthLog}"), synthPrefix, synthLinesToDrop) + assert(verilatedOutput.size == synthPrintOutput.size && verilatedOutput.nonEmpty, + s"\nSynthesized output had length ${synthPrintOutput.size}. Expected ${verilatedOutput.size}") for ( (vPrint, sPrint) <- verilatedOutput.zip(synthPrintOutput) ) { assert(vPrint == sPrint) } } } - // Checks that the synthesized print log in ${genDir}/${synthPrintLog} matches the - // printfs from the RTL simulator - def diffAutoCounterOutput(autocounterOutputLog: String, referenceFile: String) { - behavior of "AutoCounter output log" - it should "match the logs commited based on the design intent" in { - def printLines(filename: File): Seq[String] = { - val lines = Source.fromFile(filename).getLines.toList - lines.sorted - } - - def printVerilatorLines(filename: File): Seq[String] = { - val lines = Source.fromFile(filename).getLines.toList - val stripedlines = lines.filter(_.startsWith("AUTOCOUNTER_PRINT")).map(line => line.stripPrefix("AUTOCOUNTER_PRINT").trim.replaceAll(" +", " ")) - stripedlines.sorted - } - - //val referenceOutput = printLines(new File(outDir, s"/${referenceFile}")) - val referenceOutput = printVerilatorLines(new File(outDir, s"/${targetName}.${backendSimulator}.out")) - val autocounterOutput = printLines(new File(genDir, s"/${autocounterOutputLog}")) - assert(referenceOutput.size == autocounterOutput.size && referenceOutput.nonEmpty) - for ( (rPrint, acPrint) <- referenceOutput.zip(autocounterOutput) ) { - assert(rPrint == acPrint) - } - } - } - - clean mkdirs elaborate runTest(backendSimulator) @@ -133,20 +116,52 @@ class RiscF1Test extends TutorialSuite("Risc") class RiscSRAMF1Test extends TutorialSuite("RiscSRAM") class AssertModuleF1Test extends TutorialSuite("AssertModule") class AutoCounterModuleF1Test extends TutorialSuite("AutoCounterModule", - simulationArgs = Seq("+autocounter-readrate0=1000", "+autocounter-filename0=AUTOCOUNTERFILE0")) { - diffAutoCounterOutput("AUTOCOUNTERFILE0", "AutoCounterModule.autocounter.out") + simulationArgs = Seq("+autocounter-readrate=1000", "+autocounter-filename=AUTOCOUNTERFILE")) { + diffSynthesizedLog("AUTOCOUNTERFILE0", stdoutPrefix = "AUTOCOUNTER_PRINT ", synthPrefix = "") } class AutoCounterCoverModuleF1Test extends TutorialSuite("AutoCounterCoverModule", - simulationArgs = Seq("+autocounter-readrate0=1000", "+autocounter-filename0=AUTOCOUNTERFILE0")) { - diffAutoCounterOutput("AUTOCOUNTERFILE0", "AutoCounterCoverModule.autocounter.out") + simulationArgs = Seq("+autocounter-readrate=1000", "+autocounter-filename=AUTOCOUNTERFILE")) { + diffSynthesizedLog("AUTOCOUNTERFILE0", stdoutPrefix = "AUTOCOUNTER_PRINT ", synthPrefix = "") + +} +class AutoCounterPrintfF1Test extends TutorialSuite("AutoCounterPrintfModule", + simulationArgs = Seq("+print-file=synthprinttest.out"), + platformConfigs = "AutoCounterPrintf_HostDebugFeatures_DefaultF1Config") { + diffSynthesizedLog("synthprinttest.out0", stdoutPrefix = "SYNTHESIZED_PRINT CYCLE", synthPrefix = "CYCLE") } class PrintfModuleF1Test extends TutorialSuite("PrintfModule", simulationArgs = Seq("+print-no-cycle-prefix", "+print-file=synthprinttest.out")) { - diffSynthesizedPrints("synthprinttest.out") + diffSynthesizedLog("synthprinttest.out0") } class NarrowPrintfModuleF1Test extends TutorialSuite("NarrowPrintfModule", simulationArgs = Seq("+print-no-cycle-prefix", "+print-file=synthprinttest.out")) { - diffSynthesizedPrints("synthprinttest.out") + diffSynthesizedLog("synthprinttest.out0") } -// MIDAS 2.0 compiler tests + class WireInterconnectF1Test extends TutorialSuite("WireInterconnect") +class TrivialMulticlockF1Test extends TutorialSuite("TrivialMulticlock") { + runTest("verilator", true) + runTest("vcs", true) +} + +class TriggerWiringModuleF1Test extends TutorialSuite("TriggerWiringModule") + +class MulticlockAssertF1Test extends TutorialSuite("MulticlockAssertModule") + +class MulticlockPrintF1Test extends TutorialSuite("MulticlockPrintfModule", + simulationArgs = Seq("+print-file=synthprinttest.out", + "+print-no-cycle-prefix")) { + diffSynthesizedLog("synthprinttest.out0") + diffSynthesizedLog("synthprinttest.out1", + stdoutPrefix = "SYNTHESIZED_PRINT_HALFRATE ", + synthPrefix = "SYNTHESIZED_PRINT_HALFRATE ", + synthLinesToDrop = 4) // Correspondes to a single cycle of extra output +} + +class MulticlockAutoCounterF1Test extends TutorialSuite("MulticlockAutoCounterModule", + simulationArgs = Seq("+autocounter-readrate=1000", "+autocounter-filename=AUTOCOUNTERFILE")) { + diffSynthesizedLog("AUTOCOUNTERFILE0", "AUTOCOUNTER_PRINT ", "") + diffSynthesizedLog("AUTOCOUNTERFILE1", "AUTOCOUNTER_PRINT_THIRDRATE ", "") +} +// Basic test for deduplicated extracted models +class TwoAddersF1Test extends TutorialSuite("TwoAdders") diff --git a/target-design/chipyard b/target-design/chipyard index 8c6b66d9..fbc47af6 160000 --- a/target-design/chipyard +++ b/target-design/chipyard @@ -1 +1 @@ -Subproject commit 8c6b66d9b2ca8b12a1507cdcd63cd91091fed687 +Subproject commit fbc47af67cb8df379347b26e30c3b7ade75306b7