[backend-comparison] Add GitHub authentication to burnbench CLI (#1285)

* [backend-comparison] Add auth command to burnbench CLI

* [backend-comparison] Add --share argument to Burnbench CLI

* Cargo clippy fixes

* Fix typos

* Add comment to explain the FIVE_SECONDS constant

* Use num_args to force at least one arg value and make args required

In the run command, makes the --benches and --backends required
The manual check is no longer necessary

* Use and_then instead of match

* Simplify token verification

* Use map_or instead of match
This commit is contained in:
Sylvain Benner 2024-02-13 11:16:53 -05:00 committed by GitHub
parent 62809cdb30
commit 00b6c7d136
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 661 additions and 84 deletions

376
Cargo.lock generated
View File

@ -61,6 +61,12 @@ version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
@ -124,6 +130,25 @@ version = "1.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
[[package]]
name = "arboard"
version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aafb29b107435aa276664c1db8954ac27a6e105cdad3c88287a199eb0e313c08"
dependencies = [
"clipboard-win",
"core-graphics",
"image",
"log",
"objc",
"objc-foundation",
"objc_id",
"parking_lot 0.12.1",
"thiserror",
"winapi",
"x11rb",
]
[[package]]
name = "arrayvec"
version = "0.7.4"
@ -150,6 +175,17 @@ dependencies = [
"syn 2.0.48",
]
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi 0.1.19",
"libc",
"winapi",
]
[[package]]
name = "autocfg"
version = "1.1.0"
@ -160,15 +196,19 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
name = "backend-comparison"
version = "0.13.0"
dependencies = [
"arboard",
"burn",
"burn-common",
"clap",
"clap 4.4.18",
"crossterm",
"derive-new",
"dirs 5.0.1",
"github-device-flow",
"rand",
"ratatui",
"reqwest",
"serde_json",
"serial_test",
"strum",
"strum_macros",
]
@ -340,12 +380,9 @@ dependencies = [
"dashmap",
"derive-new",
"getrandom",
"indicatif",
"rand",
"reqwest",
"serde",
"spin",
"tokio",
"uuid",
"web-time",
]
@ -402,7 +439,6 @@ dependencies = [
name = "burn-dataset"
version = "0.13.0"
dependencies = [
"burn-common",
"csv",
"derive-new",
"dirs 5.0.1",
@ -412,10 +448,12 @@ dependencies = [
"globwalk",
"hound",
"image",
"indicatif",
"r2d2",
"r2d2_sqlite",
"rand",
"rayon",
"reqwest",
"rmp-serde",
"rstest",
"rusqlite",
@ -427,6 +465,7 @@ dependencies = [
"strum_macros",
"tempfile",
"thiserror",
"tokio",
]
[[package]]
@ -590,9 +629,9 @@ dependencies = [
[[package]]
name = "bytemuck"
version = "1.14.3"
version = "1.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f"
checksum = "ed2490600f404f2b94c167e31d3ed1d5f3c225a0f3b80230053b3e0b7b962bd9"
dependencies = [
"bytemuck_derive",
]
@ -726,6 +765,20 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f13690e35a5e4ace198e7beea2895d29f3a9cc55015fcebe6336bd2010af9eb"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-targets 0.52.0",
]
[[package]]
name = "cipher"
version = "0.4.4"
@ -738,31 +791,61 @@ dependencies = [
[[package]]
name = "clap"
version = "4.5.0"
version = "3.2.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80c21025abd42669a92efc996ef13cfb2c5c627858421ea58d5c3b331a6c134f"
checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123"
dependencies = [
"atty",
"bitflags 1.3.2",
"clap_derive 3.2.25",
"clap_lex 0.2.4",
"indexmap 1.9.3",
"once_cell",
"strsim 0.10.0",
"termcolor",
"textwrap",
]
[[package]]
name = "clap"
version = "4.4.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c"
dependencies = [
"clap_builder",
"clap_derive",
"clap_derive 4.4.7",
]
[[package]]
name = "clap_builder"
version = "4.5.0"
version = "4.4.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "458bf1f341769dfcf849846f65dffdf9146daa56bcd2a47cb4e1de9915567c99"
checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7"
dependencies = [
"anstream",
"anstyle",
"clap_lex",
"strsim 0.11.0",
"clap_lex 0.6.0",
"strsim 0.10.0",
]
[[package]]
name = "clap_derive"
version = "4.5.0"
version = "3.2.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "307bc0538d5f0f83b8248db3087aa92fe504e4691294d0c96c0eabc33f47ba47"
checksum = "ae6371b8bdc8b7d3959e9cf7b22d4435ef3e79e138688421ec654acf8c81b008"
dependencies = [
"heck",
"proc-macro-error",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "clap_derive"
version = "4.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442"
dependencies = [
"heck",
"proc-macro2",
@ -772,9 +855,29 @@ dependencies = [
[[package]]
name = "clap_lex"
version = "0.7.0"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce"
checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5"
dependencies = [
"os_str_bytes",
]
[[package]]
name = "clap_lex"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1"
[[package]]
name = "clipboard-win"
version = "4.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7191c27c2357d9b7ef96baac1773290d4ca63b24205b82a3fd8a0637afcf0362"
dependencies = [
"error-code",
"str-buf",
"winapi",
]
[[package]]
name = "cmake"
@ -858,6 +961,19 @@ version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f"
[[package]]
name = "core-graphics"
version = "0.22.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2581bbab3b8ffc6fcbd550bf46c355135d16e9ff2a6ea032ad6b9bf1d7efe4fb"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
"core-graphics-types",
"foreign-types 0.3.2",
"libc",
]
[[package]]
name = "core-graphics-types"
version = "0.1.3"
@ -1351,6 +1467,16 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "error-code"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64f18991e7bf11e7ffee451b5318b5c1a73c52d0d0ada6e5a3017c8c1ced6a21"
dependencies = [
"libc",
"str-buf",
]
[[package]]
name = "esaxx-rs"
version = "0.1.10"
@ -1766,6 +1892,16 @@ dependencies = [
"version_check",
]
[[package]]
name = "gethostname"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb65d4ba3173c56a500b555b532f72c42e8d1fe64962b518897f8959fae2c177"
dependencies = [
"libc",
"winapi",
]
[[package]]
name = "getrandom"
version = "0.2.12"
@ -1795,6 +1931,20 @@ version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "github-device-flow"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98852ab71f5613dac02a0d1b41f3ffaf993b69449904dd13a10575612a56074d"
dependencies = [
"chrono",
"clap 3.2.25",
"reqwest",
"serde",
"serde_derive",
"serde_json",
]
[[package]]
name = "gix-features"
version = "0.36.1"
@ -1874,7 +2024,7 @@ dependencies = [
"bstr",
"log",
"regex-automata",
"regex-syntax",
"regex-syntax 0.8.2",
]
[[package]]
@ -2061,6 +2211,15 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]]
name = "hermit-abi"
version = "0.3.4"
@ -2214,6 +2373,29 @@ dependencies = [
"tokio-native-tls",
]
[[package]]
name = "iana-time-zone"
version = "0.1.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]]
name = "ident_case"
version = "1.0.1"
@ -2306,9 +2488,9 @@ dependencies = [
[[package]]
name = "indicatif"
version = "0.17.8"
version = "0.17.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3"
checksum = "fb28741c9db9a713d93deb3bb9515c20788cef5815265bee4980e87bde7e0f25"
dependencies = [
"console",
"instant",
@ -2572,6 +2754,15 @@ dependencies = [
"stable_deref_trait",
]
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]]
name = "metal"
version = "0.27.0"
@ -2752,6 +2943,18 @@ dependencies = [
"cmake",
]
[[package]]
name = "nix"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
dependencies = [
"bitflags 1.3.2",
"cfg-if",
"libc",
"memoffset",
]
[[package]]
name = "nom"
version = "7.1.3"
@ -2823,7 +3026,7 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
"hermit-abi",
"hermit-abi 0.3.4",
"libc",
]
@ -2875,6 +3078,17 @@ dependencies = [
"objc_exception",
]
[[package]]
name = "objc-foundation"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1add1b659e36c9607c7aab864a76c7a4c2760cd0cd2e120f3fb8b952c7e22bf9"
dependencies = [
"block",
"objc",
"objc_id",
]
[[package]]
name = "objc_exception"
version = "0.1.2"
@ -2884,6 +3098,15 @@ dependencies = [
"cc",
]
[[package]]
name = "objc_id"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c92d4ddb4bd7b50d730c215ff871754d0da6b2178849f8a2a2ab69712d0c073b"
dependencies = [
"objc",
]
[[package]]
name = "object"
version = "0.32.2"
@ -3017,6 +3240,12 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "os_str_bytes"
version = "6.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1"
[[package]]
name = "overload"
version = "0.1.1"
@ -3197,6 +3426,30 @@ dependencies = [
"yansi",
]
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"syn 1.0.109",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]]
name = "proc-macro2"
version = "1.0.78"
@ -3499,7 +3752,7 @@ dependencies = [
"aho-corasick",
"memchr",
"regex-automata",
"regex-syntax",
"regex-syntax 0.8.2",
]
[[package]]
@ -3510,9 +3763,15 @@ checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax",
"regex-syntax 0.8.2",
]
[[package]]
name = "regex-syntax"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da"
[[package]]
name = "regex-syntax"
version = "0.8.2"
@ -4119,6 +4378,12 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "str-buf"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e08d8363704e6c71fc928674353e6b7c23dcea9d82d7012c8faf2a3a025f8d0"
[[package]]
name = "strsim"
version = "0.9.3"
@ -4131,12 +4396,6 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "strsim"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01"
[[package]]
name = "strum"
version = "0.25.0"
@ -4350,19 +4609,25 @@ dependencies = [
]
[[package]]
name = "thiserror"
version = "1.0.57"
name = "textwrap"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b"
checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d"
[[package]]
name = "thiserror"
version = "1.0.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.57"
version = "1.0.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471"
dependencies = [
"proc-macro2",
"quote",
@ -4449,16 +4714,16 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokenizers"
version = "0.15.2"
version = "0.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3dd47962b0ba36e7fd33518fbf1754d136fd1474000162bbf2a8b5fcb2d3654d"
checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812"
dependencies = [
"aho-corasick",
"derive_builder",
"esaxx-rs",
"getrandom",
"hf-hub",
"itertools 0.12.1",
"itertools 0.11.0",
"lazy_static",
"log",
"macro_rules_attribute",
@ -4469,7 +4734,7 @@ dependencies = [
"rayon",
"rayon-cond",
"regex",
"regex-syntax",
"regex-syntax 0.7.5",
"serde",
"serde_json",
"spm_precompiled",
@ -5057,6 +5322,15 @@ dependencies = [
"winapi",
]
[[package]]
name = "winapi-wsapoll"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44c17110f57155602a80dca10be03852116403c9ff3cd25b079d666f2aa3df6e"
dependencies = [
"winapi",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
@ -5236,6 +5510,28 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "x11rb"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1641b26d4dec61337c35a1b1aaf9e3cba8f46f0b43636c609ab0291a648040a"
dependencies = [
"gethostname",
"nix",
"winapi",
"winapi-wsapoll",
"x11rb-protocol",
]
[[package]]
name = "x11rb-protocol"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82d6c3f9a0fb6701fab8f6cea9b0c0bd5d6876f1f89f7fada07e558077c344bc"
dependencies = [
"nix",
]
[[package]]
name = "xattr"
version = "1.3.1"
@ -5258,7 +5554,7 @@ name = "xtask"
version = "0.4.0"
dependencies = [
"anyhow",
"clap",
"clap 4.4.18",
"derive_more",
"env_logger",
"log",

View File

@ -106,6 +106,10 @@ text_placeholder = "0.5.0"
pollster = "0.3"
wgpu = "0.18.0"
# Burnbench
arboard = "3.3.0"
github-device-flow = "0.2.0"
bincode = { version = "2.0.0-rc.3", features = [
"alloc",
"serde",

View File

@ -26,20 +26,23 @@ wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]
[dependencies]
arboard = { workspace = true }
burn = { path = "../burn", default-features = false }
burn-common = { path = "../burn-common", version = "0.13.0" }
clap = { workspace = true }
crossterm = { workspace = true, optional = true }
derive-new = { workspace = true }
dirs = { workspace = true }
github-device-flow = { workspace = true }
rand = { workspace = true }
ratatui = { workspace = true, optional = true }
reqwest = {workspace = true, features = ["blocking", "json"]}
serde_json = { workspace = true }
strum = { workspace = true }
strum_macros = { workspace = true }
[dev-dependencies]
serial_test = { workspace = true }
[[bench]]
name = "unary"

View File

@ -16,6 +16,10 @@ The end of options argument `--` is used to pass arguments to the `burnbench`
application. For instance `cargo run --bin burnbench -- list` passes the `list`
argument to `burnbench` effectively calling `burnbench list`.
### Commands
#### List benches and backends
To list all the available benches and backends use the `list` command:
```sh
@ -43,7 +47,9 @@ Available Benchmarks:
- unary
```
To execute a given benchmark against a specific backend we use the `run` command
#### Run benchmarks
To run a given benchmark against a specific backend we use the `run` command
with the arguments `--benches` and `--backends` respectively. In the following
example we execute the `unary` benchmark against the `wgpu-fusion` backend:
@ -72,9 +78,35 @@ Executing the following benchmark and backend combinations (Total: 4):
Running benchmarks...
```
#### Authentication and benchmarks sharing
Burnbench can upload benchmark results to our servers so that users can share
their results with the community and we can use this information to drive the
development of Burn.
Sharing results is opt-in and it is enabled with the `--share` arguments passed
to the `run` command:
```sh
> cargo run --bin burnbench -- run --share --benches unary --backends wgpu-fusion
```
To be able to upload results you must be authenticated. We only support GitHub
authentication. To authenticate run the `auth` command, then follow the URL
to enter your device code and authorize the Burnbench application:
```sh
> cargo run --bin burnbench -- run auth
```
If everything is fine you should get a confirmation in the terminal that your
token has been saved to the burn cache directory.
You can now use the `--share` argument to upload and share your benchmarks!
### Terminal UI
This is a work in progress.
This is a work in progress and is not usable for now.
## Execute benchmarks with cargo

View File

@ -1,5 +1,5 @@
use backend_comparison::burnbenchapp;
fn main() {
burnbenchapp::run()
burnbenchapp::execute();
}

View File

@ -0,0 +1,175 @@
use reqwest;
use std::io::Write;
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
use std::{
fs::{self, File},
path::{Path, PathBuf},
};
pub(crate) static CLIENT_ID: &str = "Iv1.84002254a02791f3";
static GITHUB_API_VERSION_HEADER: &str = "X-GitHub-Api-Version";
static GITHUB_API_VERSION: &str = "2022-11-28";
/// Return the file path for the auth cache on disk
pub(crate) fn get_auth_cache_file_path() -> PathBuf {
let home_dir = dirs::home_dir().expect("an home directory should exist");
let path_dir = home_dir.join(".cache").join("burn").join("burnbench");
#[cfg(test)]
let path_dir = path_dir.join("test");
let path = Path::new(&path_dir);
path.join("token.txt")
}
/// Returns true if the token is still valid
pub(crate) fn verify_token(token: &str) -> bool {
let client = reqwest::blocking::Client::new();
let response = client
.get("https://api.github.com/user")
.header(reqwest::header::USER_AGENT, "burnbench")
.header(reqwest::header::ACCEPT, "application/vnd.github+json")
.header(reqwest::header::AUTHORIZATION, format!("Bearer {}", token))
.header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
.send();
response.map_or(false, |resp| resp.status().is_success())
}
/// Save token in Burn cache directory and adjust file permissions
pub(crate) fn save_token(token: &str) {
let path = get_auth_cache_file_path();
fs::create_dir_all(path.parent().expect("path should have a parent directory"))
.expect("directory should be created");
let mut file = File::create(&path).expect("file should be created");
write!(file, "{}", token).expect("token should be written to file");
// On unix systems we lower the permissions on the cache file to be readable
// just by the current user
#[cfg(unix)]
fs::set_permissions(&path, fs::Permissions::from_mode(0o600))
.expect("permissions should be set to 600");
println!("✅ Token saved at location: {}", path.to_str().unwrap());
}
/// Return the token saved in the cache file
#[inline]
pub(crate) fn get_token_from_cache() -> Option<String> {
let path = get_auth_cache_file_path();
fs::read_to_string(path)
.ok()
.and_then(|contents| contents.lines().next().map(str::to_string))
}
#[cfg(test)]
use serial_test::serial;
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
fn cleanup_test_environment() {
let path = get_auth_cache_file_path();
if path.exists() {
fs::remove_file(&path).expect("should be able to delete the token file");
}
let parent_dir = path
.parent()
.expect("token file should have a parent directory");
if parent_dir.exists() {
fs::remove_dir_all(parent_dir).expect("should be able to delete the cache directory");
}
}
#[test]
#[serial]
fn test_save_token_when_file_does_not_exist() {
cleanup_test_environment();
let token = "unique_test_token";
// Ensure the file does not exist
let path = get_auth_cache_file_path();
if path.exists() {
fs::remove_file(&path).unwrap();
}
save_token(token);
assert_eq!(fs::read_to_string(path).unwrap(), token);
cleanup_test_environment();
}
#[test]
#[serial]
fn test_overwrite_saved_token_when_file_already_exists() {
cleanup_test_environment();
let initial_token = "initial_test_token";
let new_token = "new_test_token";
// Save initial token
save_token(initial_token);
// Save new token that should overwrite the initial one
save_token(new_token);
let path = get_auth_cache_file_path();
assert_eq!(fs::read_to_string(path).unwrap(), new_token);
cleanup_test_environment();
}
#[test]
#[serial]
fn test_get_saved_token_from_cache_when_it_exists() {
cleanup_test_environment();
let token = "existing_test_token";
// Save the token first
save_token(token);
// Now retrieve it
let retrieved_token = get_token_from_cache().unwrap();
assert_eq!(retrieved_token, token);
cleanup_test_environment();
}
#[test]
#[serial]
fn test_return_only_first_line_of_cache_as_token() {
cleanup_test_environment();
let path = get_auth_cache_file_path();
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).expect("directory tree should be created");
}
// Create a file with multiple lines
let mut file = File::create(&path).expect("test file should be created");
write!(file, "first_line_token\nsecond_line\nthird_line")
.expect("test file should contain several lines");
// Test that only the first line is returned as the token
let token = get_token_from_cache().expect("token should be present");
assert_eq!(
token, "first_line_token",
"The token should match only the first line of the file"
);
cleanup_test_environment();
}
#[test]
#[serial]
fn test_return_none_when_cache_file_does_not_exist() {
cleanup_test_environment();
let path = get_auth_cache_file_path();
// Ensure the file does not exist
if path.exists() {
fs::remove_file(&path).unwrap();
}
assert!(get_token_from_cache().is_none());
cleanup_test_environment();
}
#[test]
#[serial]
fn test_return_none_when_cache_file_exists_but_is_empty() {
cleanup_test_environment();
// Create an empty file
let path = get_auth_cache_file_path();
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).expect("directory tree should be created");
}
File::create(&path).expect("empty file should be created");
assert!(
get_token_from_cache().is_none(),
"Expected None for empty cache file, got Some"
);
cleanup_test_environment();
}
}

View File

@ -1,9 +1,21 @@
use arboard::Clipboard;
use clap::{Parser, Subcommand, ValueEnum};
use std::process::{Command, Stdio};
use github_device_flow::{self, DeviceFlow};
use std::{
process::{Command, Stdio},
thread, time,
};
use strum::IntoEnumIterator;
use strum_macros::{Display, EnumIter};
use super::App;
use crate::burnbenchapp::auth::{get_token_from_cache, verify_token};
use super::{
auth::{save_token, CLIENT_ID},
App,
};
const FIVE_SECONDS: time::Duration = time::Duration::new(5, 0);
/// Base trait to define an application
pub(crate) trait Application {
@ -24,6 +36,8 @@ struct Args {
#[derive(Subcommand, Debug)]
enum Commands {
/// Authenticate using GitHub
Auth,
/// List all available benchmarks and backends
List,
/// Runs benchmarks
@ -32,12 +46,16 @@ enum Commands {
#[derive(Parser, Debug)]
struct RunArgs {
/// Comma-separated list of backends to include
#[clap(short = 'B', long = "backends", value_name = "BACKEND,BACKEND,...", num_args(0..))]
/// Share the benchmark results by uploading them to Burn servers
#[clap(short = 's', long = "share")]
share: bool,
/// Space separated list of backends to include
#[clap(short = 'B', long = "backends", value_name = "BACKEND BACKEND ...", num_args(1..), required = true)]
backends: Vec<BackendValues>,
/// Comma-separated list of benches to run
#[clap(short = 'b', long = "benches", value_name = "BACKEND,BACKEND,...", num_args(0..))]
/// Space separated list of benches to run
#[clap(short = 'b', long = "benches", value_name = "BENCH BENCH ...", num_args(1..), required = true)]
benches: Vec<BenchmarkValues>,
}
@ -81,45 +99,92 @@ pub(crate) enum BenchmarkValues {
Unary,
}
pub fn run() {
pub fn execute() {
let args = Args::parse();
match args.command {
Commands::List => {
println!("Available Backends:");
for backend in BackendValues::iter() {
println!("- {}", backend);
}
Commands::Auth => command_auth(),
Commands::List => command_list(),
Commands::Run(run_args) => command_run(run_args),
}
}
println!("\nAvailable Benchmarks:");
for bench in BenchmarkValues::iter() {
println!("- {}", bench);
}
/// Create an access token from GitHub Burnbench application and store it
/// to be used with the user benchmark backend.
fn command_auth() {
let mut flow = match DeviceFlow::start(CLIENT_ID, None) {
Ok(flow) => flow,
Err(e) => {
eprintln!("Error authenticating: {}", e);
return;
}
Commands::Run(run_args) => {
if run_args.backends.is_empty() || run_args.benches.is_empty() {
println!("No backends or benchmarks specified. Please select at least one backend and one benchmark.");
return;
}
let total_combinations = run_args.backends.len() * run_args.benches.len();
println!(
"Executing the following benchmark and backend combinations (Total: {}):",
total_combinations
);
for backend in &run_args.backends {
for bench in &run_args.benches {
println!("- Benchmark: {}, Backend: {}", bench, backend);
}
}
let mut app = App::new();
app.init();
println!("Running benchmarks...");
app.run(&run_args.benches, &run_args.backends);
app.cleanup();
println!("Cleanup completed. Benchmark run(s) finished.");
};
println!("🌐 Please visit for following URL in your browser (CTRL+click if your terminal supports it):");
println!("\n {}\n", flow.verification_uri.clone().unwrap());
let user_code = flow.user_code.clone().unwrap();
println!("👉 And enter code: {}", &user_code);
if let Ok(mut clipboard) = Clipboard::new() {
if clipboard.set_text(user_code).is_ok() {
println!("📋 Code has been successfully copied to clipboard.")
};
};
// Wait for the minimum allowed interval to poll for authentication update
// see: https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps#step-3-app-polls-github-to-check-if-the-user-authorized-the-device
thread::sleep(FIVE_SECONDS);
match flow.poll(20) {
Ok(creds) => {
save_token(&creds.token);
}
Err(e) => eprint!("Authentication error: {}", e),
};
}
fn command_list() {
println!("Available Backends:");
for backend in BackendValues::iter() {
println!("- {}", backend);
}
println!("\nAvailable Benchmarks:");
for bench in BenchmarkValues::iter() {
println!("- {}", bench);
}
}
fn command_run(run_args: RunArgs) {
if run_args.share {
// Verify if a token is saved
let token = get_token_from_cache();
if token.is_none() {
eprintln!("You need to be authenticated to be able to share benchmark results.");
eprintln!("Run the command 'burnbench auth' to authenticate.");
return;
}
// TODO refresh the token when it is expired
// Check for the validity of the saved token
if !verify_token(&token.unwrap()) {
eprintln!("Your access token is no longer valid.");
eprintln!("Run the command 'burnbench auth' again to get a new token.");
return;
}
}
let total_combinations = run_args.backends.len() * run_args.benches.len();
println!(
"Executing the following benchmark and backend combinations (Total: {}):",
total_combinations
);
for backend in &run_args.backends {
for bench in &run_args.benches {
println!("- Benchmark: {}, Backend: {}", bench, backend);
}
}
let mut app = App::new();
app.init();
println!("Running benchmarks...");
app.run(&run_args.benches, &run_args.backends);
app.cleanup();
println!("Cleanup completed. Benchmark run(s) finished.");
if run_args.share {
println!("Sharing results...");
// TODO Post the results once backend can verify the GitHub access token
}
}

View File

@ -1,4 +1,6 @@
mod auth;
mod base;
pub use base::*;
#[cfg(feature = "tui")]