From c4eac86ce53b5b046da948f7437534fe9a6843d5 Mon Sep 17 00:00:00 2001 From: Sylvain Benner Date: Tue, 2 Apr 2024 09:27:49 -0400 Subject: [PATCH] [backend-comparison] Add all choice to --benches and --backends (#1567) + Make some tweaks in logs --- backend-comparison/src/burnbenchapp/base.rs | 41 ++++++++++++++------- backend-comparison/src/persistence/base.rs | 5 ++- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/backend-comparison/src/burnbenchapp/base.rs b/backend-comparison/src/burnbenchapp/base.rs index ea567e97b..8ebac3f9a 100644 --- a/backend-comparison/src/burnbenchapp/base.rs +++ b/backend-comparison/src/burnbenchapp/base.rs @@ -64,6 +64,8 @@ struct RunArgs { #[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)] pub(crate) enum BackendValues { + #[strum(to_string = "all")] + All, #[strum(to_string = "candle-cpu")] CandleCpu, #[strum(to_string = "candle-cuda")] @@ -90,6 +92,8 @@ pub(crate) enum BackendValues { #[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)] pub(crate) enum BenchmarkValues { + #[strum(to_string = "all")] + All, #[strum(to_string = "binary")] Binary, #[strum(to_string = "custom-gelu")] @@ -142,20 +146,26 @@ fn command_run(run_args: RunArgs) { if run_args.share { tokens = get_tokens(); } - let total_combinations = run_args.backends.len() * run_args.benches.len(); - println!( - "Executing benchmark and backend combinations in total: {}", - total_combinations - ); + // collect benchmarks and benches to execute + let mut backends = run_args.backends.clone(); + if backends.contains(&BackendValues::All) { + backends = BackendValues::iter() + .filter(|b| b != &BackendValues::All) + .collect(); + } + let mut benches = run_args.benches.clone(); + if benches.contains(&BenchmarkValues::All) { + benches = BenchmarkValues::iter() + .filter(|b| b != &BenchmarkValues::All) + .collect(); + } + + let total_combinations = backends.len() * benches.len(); let mut app = App::new(); app.init(); - println!("Running benchmarks...\n"); + println!("Running {} benchmark(s)...\n", total_combinations); let access_token = tokens.map(|t| t.access_token); - app.run( - &run_args.benches, - &run_args.backends, - access_token.as_deref(), - ); + app.run(&benches, &backends, access_token.as_deref()); app.cleanup(); } @@ -177,6 +187,8 @@ pub(crate) fn run_backend_comparison_benchmarks( backends: &[BackendValues], token: Option<&str>, ) { + let total_count = backends.len() * benches.len(); + let mut current_index = 0; // Prefix and postfix for titles let filler = ["="; 10].join(""); @@ -195,9 +207,10 @@ pub(crate) fn run_backend_comparison_benchmarks( for backend in backends.iter() { let bench_str = bench.to_string(); let backend_str = backend.to_string(); + current_index += 1; println!( - "{}Benchmarking {} on {}{}", - filler, bench_str, backend_str, filler + "{} ({}/{}) Benchmarking {} on {} {}", + filler, current_index, total_count, bench_str, backend_str, filler ); let url = format!("{}benchmarks", super::USER_BENCHMARK_SERVER_URL); let mut args = vec![ @@ -244,7 +257,7 @@ pub(crate) fn run_backend_comparison_benchmarks( }; } println!( - "{}Benchmark Results{}\n\n{}", + "{} Benchmark Results {}\n\n{}", filler, filler, benchmark_results ); fs::remove_file(benchmark_results_file).ok(); diff --git a/backend-comparison/src/persistence/base.rs b/backend-comparison/src/persistence/base.rs index 7ca683ff0..cb0c9632f 100644 --- a/backend-comparison/src/persistence/base.rs +++ b/backend-comparison/src/persistence/base.rs @@ -255,7 +255,8 @@ impl Display for BenchmarkCollection { let mut max_feature_len = "Feature".len(); for record in self.records.iter() { max_name_len = max_name_len.max(record.results.name.len()); - max_backend_len = max_backend_len.max(record.backend.len()); + // + 2 because if the added backticks + max_backend_len = max_backend_len.max(record.backend.len() + 2); max_device_len = max_device_len.max(record.device.len()); max_feature_len = max_feature_len.max(record.feature.len()); } @@ -276,7 +277,7 @@ impl Display for BenchmarkCollection { "| {: