Make compatible with thumbv6m-none-eabi + add raspberry pi pico example (#2096)

* Made compatible with thumbv6m-none-eabi

* Added example of no_std on rp2040

* Added documentation on usage in no_std

* Rename rp2040 example and add README.md
This commit is contained in:
Bjorn Beishline 2024-08-23 04:39:39 -07:00 committed by GitHub
parent 48a64d3b8a
commit 17de832c6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 6173 additions and 155 deletions

267
Cargo.lock generated
View File

@ -44,7 +44,7 @@ dependencies = [
"getrandom", "getrandom",
"once_cell", "once_cell",
"version_check", "version_check",
"zerocopy 0.7.35", "zerocopy",
] ]
[[package]] [[package]]
@ -532,6 +532,7 @@ dependencies = [
"hashbrown 0.14.5", "hashbrown 0.14.5",
"log", "log",
"num-traits", "num-traits",
"portable-atomic-util",
"rand", "rand",
"regex", "regex",
"rmp-serde", "rmp-serde",
@ -671,9 +672,10 @@ dependencies = [
"derive-new", "derive-new",
"libm", "libm",
"matrixmultiply", "matrixmultiply",
"ndarray", "ndarray 0.16.1",
"num-traits", "num-traits",
"openblas-src", "openblas-src",
"portable-atomic-util",
"rand", "rand",
"spin", "spin",
] ]
@ -789,9 +791,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495"
[[package]] [[package]]
name = "bytes" name = "bytes"
version = "1.7.0" version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca2be1d5c43812bae364ee3f30b3afcb7877cf59f4aeb94c66f313a41d2fac9" checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50"
[[package]] [[package]]
name = "bytesize" name = "bytesize"
@ -841,7 +843,7 @@ dependencies = [
"rand", "rand",
"rand_distr", "rand_distr",
"rayon", "rayon",
"safetensors 0.4.3", "safetensors 0.4.4",
"thiserror", "thiserror",
"yoke", "yoke",
"zip 1.1.4", "zip 1.1.4",
@ -894,9 +896,9 @@ dependencies = [
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.7" version = "1.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26a5c3fd7bfa1ce3897a3a3501d362b2d87b7f2583ebcb4a949ec25911025cbc" checksum = "e9e8aabfac534be767c909e0690571677d49f41bd8465ae876fe043d52ba5292"
dependencies = [ dependencies = [
"jobserver", "jobserver",
"libc", "libc",
@ -924,6 +926,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]] [[package]]
name = "chrono" name = "chrono"
version = "0.4.38" version = "0.4.38"
@ -1211,9 +1219,9 @@ dependencies = [
[[package]] [[package]]
name = "core-foundation-sys" name = "core-foundation-sys"
version = "0.8.6" version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]] [[package]]
name = "core-graphics" name = "core-graphics"
@ -1241,9 +1249,9 @@ dependencies = [
[[package]] [[package]]
name = "cpufeatures" name = "cpufeatures"
version = "0.2.12" version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" checksum = "51e852e6dc9a5bed1fae92dd2375037bf2b768725bf3be87811edee3249d09ad"
dependencies = [ dependencies = [
"libc", "libc",
] ]
@ -1380,7 +1388,7 @@ dependencies = [
[[package]] [[package]]
name = "cubecl" name = "cubecl"
version = "0.1.1" version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e" source = "git+https://github.com/tracel-ai/cubecl?rev=034f667da6e92a81b7da9f303e8507db944cc2a4#034f667da6e92a81b7da9f303e8507db944cc2a4"
dependencies = [ dependencies = [
"cubecl-core", "cubecl-core",
"cubecl-cuda", "cubecl-cuda",
@ -1391,7 +1399,7 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-common" name = "cubecl-common"
version = "0.1.1" version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e" source = "git+https://github.com/tracel-ai/cubecl?rev=034f667da6e92a81b7da9f303e8507db944cc2a4#034f667da6e92a81b7da9f303e8507db944cc2a4"
dependencies = [ dependencies = [
"derive-new", "derive-new",
"getrandom", "getrandom",
@ -1406,7 +1414,7 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-core" name = "cubecl-core"
version = "0.1.1" version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e" source = "git+https://github.com/tracel-ai/cubecl?rev=034f667da6e92a81b7da9f303e8507db944cc2a4#034f667da6e92a81b7da9f303e8507db944cc2a4"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"cubecl-macros", "cubecl-macros",
@ -1421,7 +1429,7 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-cuda" name = "cubecl-cuda"
version = "0.1.1" version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e" source = "git+https://github.com/tracel-ai/cubecl?rev=034f667da6e92a81b7da9f303e8507db944cc2a4#034f667da6e92a81b7da9f303e8507db944cc2a4"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"cubecl-common", "cubecl-common",
@ -1436,7 +1444,7 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-linalg" name = "cubecl-linalg"
version = "0.1.1" version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e" source = "git+https://github.com/tracel-ai/cubecl?rev=034f667da6e92a81b7da9f303e8507db944cc2a4#034f667da6e92a81b7da9f303e8507db944cc2a4"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"cubecl-core", "cubecl-core",
@ -1447,7 +1455,7 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-macros" name = "cubecl-macros"
version = "0.1.1" version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e" source = "git+https://github.com/tracel-ai/cubecl?rev=034f667da6e92a81b7da9f303e8507db944cc2a4#034f667da6e92a81b7da9f303e8507db944cc2a4"
dependencies = [ dependencies = [
"derive-new", "derive-new",
"proc-macro2", "proc-macro2",
@ -1458,9 +1466,10 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-runtime" name = "cubecl-runtime"
version = "0.1.1" version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e" source = "git+https://github.com/tracel-ai/cubecl?rev=034f667da6e92a81b7da9f303e8507db944cc2a4#034f667da6e92a81b7da9f303e8507db944cc2a4"
dependencies = [ dependencies = [
"async-channel", "async-channel",
"cfg_aliases 0.2.1",
"cubecl-common", "cubecl-common",
"derive-new", "derive-new",
"dirs 5.0.1", "dirs 5.0.1",
@ -1477,7 +1486,7 @@ dependencies = [
[[package]] [[package]]
name = "cubecl-wgpu" name = "cubecl-wgpu"
version = "0.1.1" version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=bee7886b5c3016c425d244136f77442655097f3e#bee7886b5c3016c425d244136f77442655097f3e" source = "git+https://github.com/tracel-ai/cubecl?rev=034f667da6e92a81b7da9f303e8507db944cc2a4#034f667da6e92a81b7da9f303e8507db944cc2a4"
dependencies = [ dependencies = [
"async-channel", "async-channel",
"bytemuck", "bytemuck",
@ -1990,14 +1999,14 @@ dependencies = [
[[package]] [[package]]
name = "filetime" name = "filetime"
version = "0.2.23" version = "0.2.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd" checksum = "bf401df4a4e3872c4fe8151134cf483738e74b67fc934d6532c882b3d24a4550"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
"redox_syscall 0.4.1", "libredox",
"windows-sys 0.52.0", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@ -2567,7 +2576,7 @@ dependencies = [
"futures-sink", "futures-sink",
"futures-util", "futures-util",
"http 0.2.12", "http 0.2.12",
"indexmap 2.2.6", "indexmap 2.4.0",
"slab", "slab",
"tokio", "tokio",
"tokio-util", "tokio-util",
@ -2586,7 +2595,7 @@ dependencies = [
"futures-core", "futures-core",
"futures-sink", "futures-sink",
"http 1.1.0", "http 1.1.0",
"indexmap 2.2.6", "indexmap 2.4.0",
"slab", "slab",
"tokio", "tokio",
"tokio-util", "tokio-util",
@ -2905,9 +2914,9 @@ dependencies = [
[[package]] [[package]]
name = "hyper-util" name = "hyper-util"
version = "0.1.6" version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-channel", "futures-channel",
@ -3050,9 +3059,9 @@ dependencies = [
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "2.2.6" version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" checksum = "93ead53efc7ea8ed3cfb0c79fc8023fbb782a5432b52830b6518941cebe6505c"
dependencies = [ dependencies = [
"equivalent", "equivalent",
"hashbrown 0.14.5", "hashbrown 0.14.5",
@ -3174,9 +3183,9 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0"
[[package]] [[package]]
name = "js-sys" name = "js-sys"
version = "0.3.69" version = "0.3.70"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a"
dependencies = [ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
@ -3251,6 +3260,7 @@ checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d"
dependencies = [ dependencies = [
"bitflags 2.6.0", "bitflags 2.6.0",
"libc", "libc",
"redox_syscall 0.5.3",
] ]
[[package]] [[package]]
@ -3490,9 +3500,9 @@ dependencies = [
[[package]] [[package]]
name = "mio" name = "mio"
version = "1.0.1" version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec"
dependencies = [ dependencies = [
"hermit-abi 0.3.9", "hermit-abi 0.3.9",
"libc", "libc",
@ -3574,17 +3584,17 @@ dependencies = [
[[package]] [[package]]
name = "naga" name = "naga"
version = "22.0.0" version = "22.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09eeccb9b50f4f7839b214aa3e08be467159506a986c18e0702170ccf720a453" checksum = "8bd5a652b6faf21496f2cfd88fc49989c8db0825d1f6746b1a71a6ede24a63ad"
dependencies = [ dependencies = [
"arrayvec", "arrayvec",
"bit-set", "bit-set",
"bitflags 2.6.0", "bitflags 2.6.0",
"cfg_aliases", "cfg_aliases 0.1.1",
"codespan-reporting", "codespan-reporting",
"hexf-parse", "hexf-parse",
"indexmap 2.2.6", "indexmap 2.4.0",
"log", "log",
"rustc-hash", "rustc-hash",
"spirv", "spirv",
@ -3623,6 +3633,19 @@ name = "ndarray"
version = "0.15.6" version = "0.15.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"rawpointer",
]
[[package]]
name = "ndarray"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
dependencies = [ dependencies = [
"cblas-sys", "cblas-sys",
"libc", "libc",
@ -3630,6 +3653,8 @@ dependencies = [
"num-complex", "num-complex",
"num-integer", "num-integer",
"num-traits", "num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer", "rawpointer",
"rayon", "rayon",
] ]
@ -3958,9 +3983,9 @@ dependencies = [
[[package]] [[package]]
name = "object" name = "object"
version = "0.36.2" version = "0.36.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" checksum = "27b64972346851a39438c60b341ebc01bba47464ae329e55cf343eb93964efd9"
dependencies = [ dependencies = [
"memchr", "memchr",
] ]
@ -4436,7 +4461,7 @@ dependencies = [
"comfy-table", "comfy-table",
"either", "either",
"hashbrown 0.14.5", "hashbrown 0.14.5",
"indexmap 2.2.6", "indexmap 2.4.0",
"num-traits", "num-traits",
"once_cell", "once_cell",
"polars-arrow", "polars-arrow",
@ -4575,7 +4600,7 @@ dependencies = [
"either", "either",
"hashbrown 0.14.5", "hashbrown 0.14.5",
"hex", "hex",
"indexmap 2.2.6", "indexmap 2.4.0",
"memchr", "memchr",
"num-traits", "num-traits",
"polars-arrow", "polars-arrow",
@ -4725,7 +4750,7 @@ dependencies = [
"ahash", "ahash",
"bytemuck", "bytemuck",
"hashbrown 0.14.5", "hashbrown 0.14.5",
"indexmap 2.2.6", "indexmap 2.4.0",
"num-traits", "num-traits",
"once_cell", "once_cell",
"polars-error", "polars-error",
@ -4766,11 +4791,11 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
version = "0.2.18" version = "0.2.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dee4364d9f3b902ef14fab8a1ddffb783a1cb6b4bba3bfc1fa3922732c7de97f" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04"
dependencies = [ dependencies = [
"zerocopy 0.6.6", "zerocopy",
] ]
[[package]] [[package]]
@ -4884,7 +4909,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a16027030d4ec33e423385f73bb559821827e9ec18c50e7874e4d6de5a4e96f" checksum = "1a16027030d4ec33e423385f73bb559821827e9ec18c50e7874e4d6de5a4e96f"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"indexmap 2.2.6", "indexmap 2.4.0",
"log", "log",
"protobuf", "protobuf",
"protobuf-support", "protobuf-support",
@ -5094,9 +5119,9 @@ dependencies = [
[[package]] [[package]]
name = "ravif" name = "ravif"
version = "0.11.9" version = "0.11.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5797d09f9bd33604689e87e8380df4951d4912f01b63f71205e2abd4ae25e6b6" checksum = "a8f0bfd976333248de2078d350bfdf182ff96e168a24d23d2436cef320dd4bdd"
dependencies = [ dependencies = [
"avif-serialize", "avif-serialize",
"imgref", "imgref",
@ -5202,15 +5227,6 @@ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
] ]
[[package]]
name = "redox_syscall"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa"
dependencies = [
"bitflags 1.3.2",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.3" version = "0.5.3"
@ -5349,7 +5365,7 @@ dependencies = [
"once_cell", "once_cell",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"rustls-pemfile 2.1.2", "rustls-pemfile 2.1.3",
"serde", "serde",
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
@ -5367,9 +5383,9 @@ dependencies = [
[[package]] [[package]]
name = "rgb" name = "rgb"
version = "0.8.45" version = "0.8.48"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ade4539f42266ded9e755c605bdddf546242b2c961b03b06a7375260788a0523" checksum = "0f86ae463694029097b846d8f99fd5536740602ae00022c0c50c5600720b2f71"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
] ]
@ -5520,7 +5536,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba"
dependencies = [ dependencies = [
"openssl-probe", "openssl-probe",
"rustls-pemfile 2.1.2", "rustls-pemfile 2.1.3",
"rustls-pki-types", "rustls-pki-types",
"schannel", "schannel",
"security-framework", "security-framework",
@ -5537,9 +5553,9 @@ dependencies = [
[[package]] [[package]]
name = "rustls-pemfile" name = "rustls-pemfile"
version = "2.1.2" version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"rustls-pki-types", "rustls-pki-types",
@ -5547,9 +5563,9 @@ dependencies = [
[[package]] [[package]]
name = "rustls-pki-types" name = "rustls-pki-types"
version = "1.7.0" version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-webpki"
@ -5586,9 +5602,9 @@ dependencies = [
[[package]] [[package]]
name = "safetensors" name = "safetensors"
version = "0.4.3" version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ced76b22c7fba1162f11a5a75d9d8405264b467a07ae0c9c29be119b9297db9" checksum = "7725d4d98fa515472f43a6e2bbf956c48e06b89bb50593a040e5945160214450"
dependencies = [ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
@ -5615,9 +5631,9 @@ dependencies = [
[[package]] [[package]]
name = "scc" name = "scc"
version = "2.1.6" version = "2.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05ccfb12511cdb770157ace92d7dda771e498445b78f9886e8cdbc5140a4eced" checksum = "79da19444d9da7a9a82b80ecf059eceba6d3129d84a8610fd25ff2364f255466"
dependencies = [ dependencies = [
"sdd", "sdd",
] ]
@ -5648,9 +5664,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]] [[package]]
name = "sdd" name = "sdd"
version = "2.1.0" version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "177258b64c0faaa9ffd3c65cd3262c2bc7e2588dbbd9c1641d0346145c1bbda8" checksum = "0495e4577c672de8254beb68d01a9b62d0e8a13c099edecdbedccce3223cd29f"
[[package]] [[package]]
name = "security-framework" name = "security-framework"
@ -6219,7 +6235,7 @@ dependencies = [
"half", "half",
"lazy_static", "lazy_static",
"libc", "libc",
"ndarray", "ndarray 0.15.6",
"rand", "rand",
"safetensors 0.3.3", "safetensors 0.3.3",
"thiserror", "thiserror",
@ -6426,7 +6442,7 @@ dependencies = [
"backtrace", "backtrace",
"bytes", "bytes",
"libc", "libc",
"mio 1.0.1", "mio 1.0.2",
"pin-project-lite", "pin-project-lite",
"socket2", "socket2",
"tokio-macros", "tokio-macros",
@ -6505,7 +6521,7 @@ version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1" checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1"
dependencies = [ dependencies = [
"indexmap 2.2.6", "indexmap 2.4.0",
"toml_datetime", "toml_datetime",
"winnow 0.5.40", "winnow 0.5.40",
] ]
@ -6516,7 +6532,7 @@ version = "0.22.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d" checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d"
dependencies = [ dependencies = [
"indexmap 2.2.6", "indexmap 2.4.0",
"serde", "serde",
"serde_spanned", "serde_spanned",
"toml_datetime", "toml_datetime",
@ -6555,15 +6571,15 @@ dependencies = [
[[package]] [[package]]
name = "tower-layer" name = "tower-layer"
version = "0.3.2" version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
[[package]] [[package]]
name = "tower-service" name = "tower-service"
version = "0.3.2" version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]] [[package]]
name = "tracing" name = "tracing"
@ -6728,9 +6744,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]] [[package]]
name = "ureq" name = "ureq"
version = "2.10.0" version = "2.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72139d247e5f97a3eff96229a7ae85ead5328a39efe76f8bf5a06313d505b6ea" checksum = "b74fc6b57825be3373f7054754755f03ac3a8f5d70015ccad699ba2029956f4a"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"flate2", "flate2",
@ -6835,19 +6851,20 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]] [[package]]
name = "wasm-bindgen" name = "wasm-bindgen"
version = "0.2.92" version = "0.2.93"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"once_cell",
"wasm-bindgen-macro", "wasm-bindgen-macro",
] ]
[[package]] [[package]]
name = "wasm-bindgen-backend" name = "wasm-bindgen-backend"
version = "0.2.92" version = "0.2.93"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b"
dependencies = [ dependencies = [
"bumpalo", "bumpalo",
"log", "log",
@ -6860,9 +6877,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-futures" name = "wasm-bindgen-futures"
version = "0.4.42" version = "0.4.43"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"js-sys", "js-sys",
@ -6872,9 +6889,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-macro" name = "wasm-bindgen-macro"
version = "0.2.92" version = "0.2.93"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf"
dependencies = [ dependencies = [
"quote", "quote",
"wasm-bindgen-macro-support", "wasm-bindgen-macro-support",
@ -6882,9 +6899,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-macro-support" name = "wasm-bindgen-macro-support"
version = "0.2.92" version = "0.2.93"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -6895,9 +6912,9 @@ dependencies = [
[[package]] [[package]]
name = "wasm-bindgen-shared" name = "wasm-bindgen-shared"
version = "0.2.92" version = "0.2.93"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484"
[[package]] [[package]]
name = "wasm-logger" name = "wasm-logger"
@ -6927,9 +6944,9 @@ dependencies = [
[[package]] [[package]]
name = "web-sys" name = "web-sys"
version = "0.3.69" version = "0.3.70"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0"
dependencies = [ dependencies = [
"js-sys", "js-sys",
"wasm-bindgen", "wasm-bindgen",
@ -6967,7 +6984,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1d1c4ba43f80542cf63a0a6ed3134629ae73e8ab51e4b765a67f3aa062eb433" checksum = "e1d1c4ba43f80542cf63a0a6ed3134629ae73e8ab51e4b765a67f3aa062eb433"
dependencies = [ dependencies = [
"arrayvec", "arrayvec",
"cfg_aliases", "cfg_aliases 0.1.1",
"document-features", "document-features",
"js-sys", "js-sys",
"log", "log",
@ -6987,16 +7004,16 @@ dependencies = [
[[package]] [[package]]
name = "wgpu-core" name = "wgpu-core"
version = "22.0.0" version = "22.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0f191908a21968991463fcf3b42cb6c9648c0fb7fa301b8fc733bc21a9ed9bd" checksum = "0348c840d1051b8e86c3bcd31206080c5e71e5933dabd79be1ce732b0b2f089a"
dependencies = [ dependencies = [
"arrayvec", "arrayvec",
"bit-vec", "bit-vec",
"bitflags 2.6.0", "bitflags 2.6.0",
"cfg_aliases", "cfg_aliases 0.1.1",
"document-features", "document-features",
"indexmap 2.2.6", "indexmap 2.4.0",
"log", "log",
"naga", "naga",
"once_cell", "once_cell",
@ -7022,7 +7039,7 @@ dependencies = [
"bit-set", "bit-set",
"bitflags 2.6.0", "bitflags 2.6.0",
"block", "block",
"cfg_aliases", "cfg_aliases 0.1.1",
"core-graphics-types", "core-graphics-types",
"d3d12", "d3d12",
"glow", "glow",
@ -7102,11 +7119,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]] [[package]]
name = "winapi-util" name = "winapi-util"
version = "0.1.8" version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@ -7368,9 +7385,9 @@ dependencies = [
[[package]] [[package]]
name = "xml-rs" name = "xml-rs"
version = "0.8.20" version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "791978798f0597cfc70478424c2b4fdc2b7a8024aaff78497ef00f24ef674193" checksum = "539a77ee7c0de333dcc6da69b177380a0b81e0dacfa4f7344c465a36871ee601"
[[package]] [[package]]
name = "xtask" name = "xtask"
@ -7422,34 +7439,14 @@ dependencies = [
"synstructure", "synstructure",
] ]
[[package]]
name = "zerocopy"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "854e949ac82d619ee9a14c66a1b674ac730422372ccb759ce0c39cabcf2bf8e6"
dependencies = [
"byteorder",
"zerocopy-derive 0.6.6",
]
[[package]] [[package]]
name = "zerocopy" name = "zerocopy"
version = "0.7.35" version = "0.7.35"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
dependencies = [ dependencies = [
"zerocopy-derive 0.7.35", "byteorder",
] "zerocopy-derive",
[[package]]
name = "zerocopy-derive"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "125139de3f6b9d625c39e2efdd73d41bdac468ccd556556440e322be0e1bbd91"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.74",
] ]
[[package]] [[package]]
@ -7534,7 +7531,7 @@ dependencies = [
"crc32fast", "crc32fast",
"crossbeam-utils", "crossbeam-utils",
"displaydoc", "displaydoc",
"indexmap 2.2.6", "indexmap 2.4.0",
"num_enum", "num_enum",
"thiserror", "thiserror",
] ]
@ -7555,7 +7552,7 @@ dependencies = [
"displaydoc", "displaydoc",
"flate2", "flate2",
"hmac", "hmac",
"indexmap 2.2.6", "indexmap 2.4.0",
"lzma-rs", "lzma-rs",
"memchr", "memchr",
"pbkdf2 0.12.2", "pbkdf2 0.12.2",
@ -7597,7 +7594,7 @@ version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9"
dependencies = [ dependencies = [
"zstd-safe 7.2.0", "zstd-safe 7.2.1",
] ]
[[package]] [[package]]
@ -7612,18 +7609,18 @@ dependencies = [
[[package]] [[package]]
name = "zstd-safe" name = "zstd-safe"
version = "7.2.0" version = "7.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa556e971e7b568dc775c136fc9de8c779b1c2fc3a63defaafadffdbd3181afa" checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059"
dependencies = [ dependencies = [
"zstd-sys", "zstd-sys",
] ]
[[package]] [[package]]
name = "zstd-sys" name = "zstd-sys"
version = "2.0.12+zstd.1.5.6" version = "2.0.13+zstd.1.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a4e40c320c3cb459d9a9ff6de98cff88f4751ee9275d140e2be94a2b74e4c13" checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa"
dependencies = [ dependencies = [
"cc", "cc",
"pkg-config", "pkg-config",

View File

@ -16,6 +16,8 @@ members = [
exclude = [ exclude = [
"examples/notebook", "examples/notebook",
"examples/raspberry-pi-pico", # will cause dependency building issues otherwise
# "crates/burn-cuda", # comment this line to work on burn-cuda
] ]
[workspace.package] [workspace.package]
@ -72,7 +74,11 @@ serde_bytes = { version = "0.11.15", default-features = false, features = [
] } # alloc for no_std ] } # alloc for no_std
serde_rusqlite = "0.35.0" serde_rusqlite = "0.35.0"
serial_test = "3.1.1" serial_test = "3.1.1"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] } spin = { version = "0.9.8", features = [
"mutex",
"spin_mutex",
"portable-atomic",
] }
strum = "0.26.3" strum = "0.26.3"
strum_macros = "0.26.4" strum_macros = "0.26.4"
syn = { version = "2.0.74", features = ["full", "extra-traits"] } syn = { version = "2.0.74", features = ["full", "extra-traits"] }
@ -118,7 +124,7 @@ half = { version = "2.4.1", features = [
"num-traits", "num-traits",
"serde", "serde",
], default-features = false } ], default-features = false }
ndarray = { version = "0.15.6", default-features = false } ndarray = { version = "0.16.0", default-features = false }
matrixmultiply = { version = "0.3.9", default-features = false } matrixmultiply = { version = "0.3.9", default-features = false }
openblas-src = "0.10.9" openblas-src = "0.10.9"
blas-src = { version = "0.10.0", default-features = false } blas-src = { version = "0.10.0", default-features = false }
@ -142,9 +148,11 @@ nvml-wrapper = "0.10.0"
sysinfo = "0.30.13" sysinfo = "0.30.13"
systemstat = "0.2.3" systemstat = "0.2.3"
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
### For the main burn branch. ### ### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "bee7886b5c3016c425d244136f77442655097f3e" } cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "034f667da6e92a81b7da9f303e8507db944cc2a4" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "bee7886b5c3016c425d244136f77442655097f3e" } cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "034f667da6e92a81b7da9f303e8507db944cc2a4" }
### For local development. ### ### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl" } # cubecl = { path = "../cubecl/crates/cubecl" }
# cubecl-common = { path = "../cubecl/crates/cubecl-common" } # cubecl-common = { path = "../cubecl/crates/cubecl-common" }

View File

@ -336,3 +336,240 @@ CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
## RP
**Source**:
- https://github.com/embassy-rs/embassy/blob/main/examples/rp/Cargo.toml
- https://github.com/embassy-rs/embassy/blob/main/examples/rp/build.rs
- https://github.com/embassy-rs/embassy/blob/main/examples/rp/memory.x
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright (c) Embassy project contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
MIT license
Copyright (c) Embassy project contributors
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

View File

@ -6,6 +6,7 @@ authors = [
"Dilshod Tadjibaev", "Dilshod Tadjibaev",
"Guillaume Lagrange", "Guillaume Lagrange",
"Sylvain Benner", "Sylvain Benner",
"Bjorn Beishline"
] ]
language = "en" language = "en"
multilingual = false multilingual = false

View File

@ -31,4 +31,4 @@
- [Custom WGPU Kernel](./advanced/backend-extension/custom-wgpu-kernel.md) - [Custom WGPU Kernel](./advanced/backend-extension/custom-wgpu-kernel.md)
- [Custom Optimizer]() - [Custom Optimizer]()
- [WebAssembly]() - [WebAssembly]()
- [No-Std]() - [No-Std](./advanced/no-std.md)

View File

@ -0,0 +1,96 @@
# No Standard Library
In this section, you will learn how to run an onnx inference model on an embedded system, with no standard library support on a Raspberry Pi Pico. This should be universally applicable to other platforms. All the code can be found under the
[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/raspberry-pi-pico).
## Step-by-Step Guide
Let's walk through the process of running an embedded ONNX model:
### Setup
Follow the [embassy guide](https://embassy.dev/book/#_getting_started) for your specific environment. Once setup, you should have something similar to the following.
```
./inference
├── Cargo.lock
├── Cargo.toml
├── build.rs
├── memory.x
└── src
└── main.rs
```
Some other dependencies have to be added
```toml
[dependencies]
embedded-alloc = "0.5.1" # Only if there is no default allocator for your chip
burn = { version = "0.14", default-features = false, features = ["ndarray"] } # Backend must be ndarray
[build-dependencies]
burn-import = { version = "0.14" } # Used to auto generate the rust code to import the model
```
### Import the Model
Follow the directions to [import models](./import/README.md).
Use the following ModelGen config
```rs
ModelGen::new()
.input(my_model)
.out_dir("model/")
.record_type(RecordType::Bincode)
.embed_states(true)
.run_from_script();
```
### Global Allocator
First define a global allocator (if you are on a no_std system without alloc).
```rs
use embedded_alloc::Heap;
#[global_allocator]
static HEAP: Heap = Heap::empty();
#[embassy_executor::main]
async fn main(_spawner: Spawner) {
{
use core::mem::MaybeUninit;
const HEAP_SIZE: usize = 100 * 1024; // This is dependent on the model size in memory.
static mut HEAP_MEM: [MaybeUninit<u8>; HEAP_SIZE] = [MaybeUninit::uninit(); HEAP_SIZE];
unsafe { HEAP.init(HEAP_MEM.as_ptr() as usize, HEAP_SIZE) }
}
}
```
### Define Backend
We are using ndarray, so we just need to define the NdArray backend as usual
```rs
use burn::{backend::NdArray, tensor::Tensor};
type Backend = NdArray<f32>;
type BackendDeice = <Backend as burn::tensor::backend::Backend>::Device;
```
Then inside the `main` function add
```rs
use your_model::Model;
// Get a default device for the backend
let device = BackendDeice::default();
// Create a new model and load the state
let model: Model<Backend> = Model::default();
```
### Running the Model
To run the model, just call it as you would normally
```rs
// Define the tensor
let input = Tensor::<Backend, 2>::from_floats([[input]], &device);
// Run the model on the input
let output = model.forward(input);
```
## Conclusion
Running a model in a no_std environment is pretty much identical to a normal environment. All that is needed is a global allocator.

View File

@ -32,7 +32,7 @@ tokio = { workspace = true, optional = true }
# Parallel # Parallel
rayon = { workspace = true, optional = true } rayon = { workspace = true, optional = true }
cubecl-common = { workspace = true } cubecl-common = { workspace = true, default-features = false }
[dev-dependencies] [dev-dependencies]
dashmap = { workspace = true } dashmap = { workspace = true }

View File

@ -122,9 +122,6 @@ derive-new = { workspace = true }
log = { workspace = true, optional = true } log = { workspace = true, optional = true }
rand = { workspace = true, features = ["std_rng"] } # Default enables std rand = { workspace = true, features = ["std_rng"] } # Default enables std
# Using in place of use std::sync::Mutex when std is disabled
spin = { workspace = true, features = ["mutex", "spin_mutex"] }
# The same implementation of HashMap in std but with no_std support (only alloc crate is needed) # The same implementation of HashMap in std but with no_std support (only alloc crate is needed)
hashbrown = { workspace = true, features = ["serde"] } # no_std compatible hashbrown = { workspace = true, features = ["serde"] } # no_std compatible
@ -139,6 +136,10 @@ serde_json = { workspace = true, features = ["alloc"] } #Default enables std
thiserror = { workspace = true, optional = true } thiserror = { workspace = true, optional = true }
regex = { workspace = true, optional = true } regex = { workspace = true, optional = true }
num-traits = { workspace = true } num-traits = { workspace = true }
spin = { workspace = true } # Using in place of use std::sync::Mutex when std is disabled
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic-util = { workspace = true }
[dev-dependencies] [dev-dependencies]
tempfile = { workspace = true } tempfile = { workspace = true }

View File

@ -5,9 +5,14 @@ use crate::module::{
}; };
use alloc::string::ToString; use alloc::string::ToString;
use alloc::sync::Arc;
use alloc::vec::Vec; use alloc::vec::Vec;
#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;
#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;
use burn_common::stub::Mutex; use burn_common::stub::Mutex;
use burn_tensor::{ use burn_tensor::{
backend::{AutodiffBackend, Backend}, backend::{AutodiffBackend, Backend},

View File

@ -56,6 +56,9 @@ openblas-src = { workspace = true, optional = true }
rand = { workspace = true } rand = { workspace = true }
spin = { workspace = true } # using in place of use std::sync::Mutex; spin = { workspace = true } # using in place of use std::sync::Mutex;
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic-util = { workspace = true }
[dev-dependencies] [dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.14.0", default-features = false, features = [ burn-autodiff = { path = "../burn-autodiff", version = "0.14.0", default-features = false, features = [
"export_tests", "export_tests",

View File

@ -612,7 +612,7 @@ fn arg<E: NdArrayElement, const D: usize>(
idx as i64 idx as i64
}); });
let output = output.into_shape(Dim(reshape.as_slice())).unwrap(); let output = output.to_shape(Dim(reshape.as_slice())).unwrap();
NdArrayTensor { NdArrayTensor {
array: output.into_shared(), array: output.into_shared(),

View File

@ -209,7 +209,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
}); });
let output = output let output = output
.into_shape([batch_size, out_channels, out_height, out_width]) .to_shape([batch_size, out_channels, out_height, out_width])
.unwrap() .unwrap()
.into_dyn() .into_dyn()
.into_shared(); .into_shared();
@ -437,7 +437,7 @@ pub(crate) fn conv3d<E: FloatNdArrayElement, Q: QuantElement>(
}); });
let output = output let output = output
.into_shape([batch_size, out_channels, out_depth, out_height, out_width]) .to_shape([batch_size, out_channels, out_depth, out_height, out_width])
.unwrap() .unwrap()
.into_dyn() .into_dyn()
.into_shared(); .into_shared();

View File

@ -70,10 +70,10 @@ macro_rules! reshape {
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim); let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match $array.is_standard_layout() { let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match $array.is_standard_layout() {
true => $array true => $array
.into_shape(dim) .to_shape(dim)
.expect("Safe to change shape without relayout") .expect("Safe to change shape without relayout")
.into_shared(), .into_shared(),
false => $array.reshape(dim), false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
}; };
let array = array.into_dyn(); let array = array.into_dyn();

View File

@ -6,12 +6,14 @@ The continuous integration (CI) should build with additional targets:
* `wasm32-unknown-unknown` - WebAssembly * `wasm32-unknown-unknown` - WebAssembly
* `thumbv7m-none-eabi` - ARM Cortex-M3 * `thumbv7m-none-eabi` - ARM Cortex-M3
* `thumbv6m-none-eabi` - ARM Cortex-M0+
Shell commands to build and test the package: Shell commands to build and test the package:
```sh ```sh
# install the new targets if not installed previously # install the new targets if not installed previously
rustup target add thumbv6m-none-eabi
rustup target add thumbv7m-none-eabi rustup target add thumbv7m-none-eabi
rustup target add wasm32-unknown-unknown rustup target add wasm32-unknown-unknown
@ -19,6 +21,7 @@ rustup target add wasm32-unknown-unknown
cargo build # regular build cargo build # regular build
cargo build --target thumbv7m-none-eabi cargo build --target thumbv7m-none-eabi
cargo build --target wasm32-unknown-unknown cargo build --target wasm32-unknown-unknown
RUSTFLAGS="--cfg portable_atomic_unsafe_assume_single_core" cargo build --target thumbv6m-none-eabi
# test # test
cargo test cargo test

View File

@ -0,0 +1,10 @@
[target.'cfg(all(target_arch = "arm", target_os = "none"))']
rustflags = ["--cfg", "portable_atomic_critical_section"]
runner = "probe-rs run --chip RP2040"
# runner = "elf2uf2-rs -d -s"
[build]
target = "thumbv6m-none-eabi" # Cortex-M0 and Cortex-M0+
[env]
DEFMT_LOG = "debug"

5336
examples/raspberry-pi-pico/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,37 @@
[package]
authors = ["Bjorn Beishline (@bjorntheprogrammer)"]
edition = "2021"
name = "raspberry-pi-pico"
license = "MIT OR Apache-2.0"
version = "0.1.0"
[dependencies]
embassy-embedded-hal = { version = "0.2.0", git = "https://github.com/embassy-rs/embassy.git", rev = "d3ff0b184861fecf7d11e14bb90a39711e10176d", features = ["defmt"] }
embassy-executor = { version = "0.6.0", git = "https://github.com/embassy-rs/embassy.git", rev = "d3ff0b184861fecf7d11e14bb90a39711e10176d", features = ["task-arena-size-98304", "arch-cortex-m", "executor-thread", "executor-interrupt", "defmt", "integrated-timers"] }
embassy-time = { version = "0.3.2", git = "https://github.com/embassy-rs/embassy.git", rev = "d3ff0b184861fecf7d11e14bb90a39711e10176d", features = ["defmt", "defmt-timestamp-uptime"] }
embassy-rp = { version = "0.2.0", git = "https://github.com/embassy-rs/embassy.git", rev = "d3ff0b184861fecf7d11e14bb90a39711e10176d", features = ["defmt", "unstable-pac", "time-driver", "critical-section-impl"] }
defmt = "0.3"
defmt-rtt = "0.4"
fixed = "1.23.1"
fixed-macro = "1.2"
#cortex-m = { version = "0.7.6", features = ["critical-section-single-core"] }
cortex-m = { version = "0.7.6", features = ["inline-asm"] }
cortex-m-rt = "0.7.0"
critical-section = "1.1"
panic-probe = { version = "0.3", features = ["print-defmt"] }
portable-atomic = { version = "1.5", features = ["critical-section"] }
embedded-alloc = "0.5.1"
burn = { path = "../../crates/burn", default-features = false, features = ["ndarray"] }
[build-dependencies]
burn-import = { path = "../../crates/burn-import" }
[profile.release]
debug = 2
[profile.dev]
lto = true
opt-level = "z"

View File

@ -0,0 +1,41 @@
# Running Onnx Inference on the Raspberry Pi Pico
This example shows how to run an inference on a no_std, no atomic pointer, and no heap environment.
## Setup
1. Install raspberry pi pico target `rustup target add thumbv6m-none-eabi`
2. Install [`probe-rs`](https://probe.rs/docs/getting-started/installation/). This is optional, install `elf2uf2-rs` to use the usb boot with `cargo install elf2uf2-rs`.
3. Have a [compatible probe](https://probe.rs/docs/getting-started/probe-setup/) to flash to the raspberry pi pico. This is optional, alternatively, modify `.cargo/config.toml` and uncomment the runner to use `elf2uf2-rs`.
If you are using `elfuf2-rs` logging will not go to your serial port, add logging by using `embassy-usb`.
## Running
Run as usual with `cargo run`
## Project Structure
The project is structured as follows
```
raspberry-pi-pico
├── Cargo.lock
├── Cargo.toml
├── README.md
├── build.rs
├── memory.x
├── src
│ ├── bin
│ │ └── main.rs
│ ├── lib.rs
│ └── model
│ ├── mod.rs
│ └── sine.onnx
└── tensorflow
├── requirements.txt
└── train.py
```
Everything is standard with any other cargo project except for the `memory.x`, the `model` directory, and the `tensorflow` directory.
The `memory.x` file contains the memory layout of the chip.
The `tensorflow` directory contains a python script which generates the onnx model using tensorflow, using the requirements from `requirements.txt`.
The onnx model will be outputted to `src/model/sine.onnx`. The `build.rs` script will generate a rust file which takes in the `sine.onnx` file and generates an import, which gets exposed in `mod.rs`.

View File

@ -0,0 +1,50 @@
//! This build script copies the `memory.x` file from the crate root into
//! a directory where the linker can always find it at build time.
//! For many projects this is optional, as the linker always searches the
//! project root directory -- wherever `Cargo.toml` is. However, if you
//! are using a workspace or have a more complicated build setup, this
//! build script becomes required. Additionally, by requesting that
//! Cargo re-run the build script whenever `memory.x` is changed,
//! updating `memory.x` ensures a rebuild of the application with the
//! new memory settings.
use burn_import::burn::graph::RecordType;
use burn_import::onnx::ModelGen;
use std::env;
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
fn main() {
// Put `memory.x` in our output directory and ensure it's
// on the linker search path.
let out = &PathBuf::from(env::var_os("OUT_DIR").unwrap());
File::create(out.join("memory.x"))
.unwrap()
.write_all(include_bytes!("memory.x"))
.unwrap();
println!("cargo:rustc-link-search={}", out.display());
// By default, Cargo will re-run a build script whenever
// any file in the project changes. By specifying `memory.x`
// here, we ensure the build script is only re-run when
// `memory.x` is changed.
println!("cargo:rerun-if-changed=memory.x");
println!("cargo:rustc-link-arg-bins=--nmagic");
println!("cargo:rustc-link-arg-bins=-Tlink.x");
println!("cargo:rustc-link-arg-bins=-Tlink-rp.x");
println!("cargo:rustc-link-arg-bins=-Tdefmt.x");
generate_model();
}
fn generate_model() {
// Generate the model code from the ONNX file.
ModelGen::new()
.input("src/model/sine.onnx")
.out_dir("model/")
.record_type(RecordType::Bincode)
.embed_states(true)
.run_from_script();
}

View File

@ -0,0 +1,17 @@
MEMORY {
BOOT2 : ORIGIN = 0x10000000, LENGTH = 0x100
FLASH : ORIGIN = 0x10000100, LENGTH = 2048K - 0x100
/* Pick one of the two options for RAM layout */
/* OPTION A: Use all RAM banks as one big block */
/* Reasonable, unless you are doing something */
/* really particular with DMA or other concurrent */
/* access that would benefit from striping */
RAM : ORIGIN = 0x20000000, LENGTH = 264K
/* OPTION B: Keep the unstriped sections separate */
/* RAM: ORIGIN = 0x20000000, LENGTH = 256K */
/* SCRATCH_A: ORIGIN = 0x20040000, LENGTH = 4K */
/* SCRATCH_B: ORIGIN = 0x20041000, LENGTH = 4K */
}

View File

@ -0,0 +1,58 @@
#![no_std]
#![no_main]
use burn::{backend::NdArray, tensor::Tensor};
use defmt::*;
use embassy_executor::Spawner;
use onnx_inference_rp2040::sine::Model;
use {defmt_rtt as _, panic_probe as _};
use embassy_rp as _;
use embedded_alloc::Heap;
type Backend = NdArray<f32>;
type BackendDeice = <Backend as burn::tensor::backend::Backend>::Device;
#[global_allocator]
static HEAP: Heap = Heap::empty();
#[embassy_executor::main]
async fn main(_spawner: Spawner) {
{
use core::mem::MaybeUninit;
const HEAP_SIZE: usize = 100 * 1024;
static mut HEAP_MEM: [MaybeUninit<u8>; HEAP_SIZE] = [MaybeUninit::uninit(); HEAP_SIZE];
unsafe { HEAP.init(HEAP_MEM.as_ptr() as usize, HEAP_SIZE) }
}
// Get a default device for the backend
let device = BackendDeice::default();
// Create a new model and load the state
let model: Model<Backend> = Model::default();
// Define input
let mut input = 0.0;
loop {
if input > 2.0 { input = 0.0 }
input += 0.05;
// Run the model
let output = run_model(&model, &device, input);
// Output the values
match output.into_primitive().tensor().array.as_slice() {
Some(slice) => info!("input: {} - output: {}", input, slice),
None => defmt::panic!("Failed to get value")
};
}
}
fn run_model<'a>(model: &Model<NdArray>, device: &BackendDeice, input: f32) -> Tensor<Backend, 2> {
// Define the tensor
let input = Tensor::<Backend, 2>::from_floats([[input]], &device);
// Run the model on the input
let output = model.forward(input);
output
}

View File

@ -0,0 +1,4 @@
#![no_std]
pub mod model;
pub use model::sine;

View File

@ -0,0 +1,3 @@
pub mod sine {
include!(concat!(env!("OUT_DIR"), "/model/sine.rs"));
}

Binary file not shown.

View File

@ -0,0 +1,4 @@
tensorflow==2.15.1
tf2onnx==1.16.1
onnx==1.16.2
numpy==2.0.1

View File

@ -0,0 +1,82 @@
# Originally copied and modified from:
# https://github.com/tensorflow/tensorflow/blob/e0b19f6ef223af40e2e6d1d21b8464c1b2ebee8f/tensorflow/lite/micro/examples/hello_world/train/train_hello_world_model.ipynb
# under the following license: Apache License 2.0
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tf2onnx
import onnx
import math
from pathlib import Path
def main():
# Define paths to model files
MODELS_DIR = '../src/model/'
os.makedirs(MODELS_DIR, exist_ok=True)
MODEL_ONNX = MODELS_DIR + 'sine.onnx'
np.random.seed(1)
# Number of sample datapoints
SAMPLES = 1000
# Generate a uniformly distributed set of random numbers in the range from
# 0 to 2π, which covers a complete sine wave oscillation
x_values = np.random.uniform(
low=0, high=2*math.pi, size=SAMPLES).astype(np.float32)
# Shuffle the values to guarantee they're not in order
np.random.shuffle(x_values)
# Calculate the corresponding sine values
y_values = np.sin(x_values).astype(np.float32)
# Add a small random number to each y value to mimic real world data
y_values += 0.1 * np.random.randn(*y_values.shape)
# We'll use 60% of our data for training and 20% for testing. The remaining
# 20% will be used for validation. Calculate the indices of each section.
TRAIN_SPLIT = int(0.6 * SAMPLES)
TEST_SPLIT = int(0.2 * SAMPLES + TRAIN_SPLIT)
# Use np.split to chop our data into three parts.
# The second argument to np.split is an array of indices where the data
# will be split. We provide two indices, so the data will be divided into
# three chunks.
x_train, x_test, x_validate = np.split(x_values, [TRAIN_SPLIT, TEST_SPLIT])
y_train, y_test, y_validate = np.split(y_values, [TRAIN_SPLIT, TEST_SPLIT])
# Double check that our splits add up correctly
assert (x_train.size + x_validate.size + x_test.size) == SAMPLES
model = tf.keras.Sequential()
# First layer takes a scalar input and feeds it through 16 "neurons". The
# neurons decide whether to activate based on the 'relu' activation
# function.
model.add(keras.layers.Dense(16, activation='relu', input_shape=(1,)))
# The new second layer may help the network learn more complex
# representations
model.add(keras.layers.Dense(16, activation='relu'))
# Final layer is a single neuron, since we want to output a single value
model.add(keras.layers.Dense(1))
# Compile the model using a standard optimizer and loss function for
# regression
model.compile(optimizer='adam', loss='mse', metrics=['mae'])
history = model.fit(x_train, y_train, epochs=500, batch_size=64,
validation_data=(x_validate, y_validate))
# Use from_function for tf functions
onnx_model, _ = tf2onnx.convert.from_keras(model, opset=16)
onnx.save(onnx_model, MODEL_ONNX)
print("Onnx model generated at", Path(MODEL_ONNX).absolute())
if __name__ == '__main__':
main()

View File

@ -22,6 +22,7 @@ use crate::{endgroup, group};
// Targets constants // Targets constants
const WASM32_TARGET: &str = "wasm32-unknown-unknown"; const WASM32_TARGET: &str = "wasm32-unknown-unknown";
const ARM_TARGET: &str = "thumbv7m-none-eabi"; const ARM_TARGET: &str = "thumbv7m-none-eabi";
const ARM_NO_ATOMIC_PTR_TARGET: &str = "thumbv6m-none-eabi";
#[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)] #[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)]
pub(crate) enum CheckType { pub(crate) enum CheckType {
@ -81,12 +82,12 @@ impl CheckType {
} }
/// Run cargo build command /// Run cargo build command
fn cargo_build(params: Params) { fn cargo_build(params: Params, envs: Option<HashMap<&str, String>>) {
// Run cargo build // Run cargo build
run_cargo( run_cargo(
"build", "build",
params + "--color=always", params + "--color=always",
HashMap::new(), envs.unwrap_or_default(),
"Failed to run cargo build", "Failed to run cargo build",
); );
} }
@ -155,7 +156,10 @@ fn build_and_test_no_std<const N: usize>(crate_name: &str, extra_args: [&str; N]
group!("Checks: {} (no-std)", crate_name); group!("Checks: {} (no-std)", crate_name);
// Run cargo build --no-default-features // Run cargo build --no-default-features
cargo_build(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args); cargo_build(
Params::from(["-p", crate_name, "--no-default-features"]) + extra_args,
None,
);
// Run cargo test --no-default-features // Run cargo test --no-default-features
cargo_test(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args); cargo_test(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args);
@ -169,6 +173,7 @@ fn build_and_test_no_std<const N: usize>(crate_name: &str, extra_args: [&str; N]
"--target", "--target",
WASM32_TARGET, WASM32_TARGET,
]) + extra_args, ]) + extra_args,
None,
); );
// Run cargo build --no-default-features --target thumbv7m-none-eabi // Run cargo build --no-default-features --target thumbv7m-none-eabi
@ -180,6 +185,22 @@ fn build_and_test_no_std<const N: usize>(crate_name: &str, extra_args: [&str; N]
"--target", "--target",
ARM_TARGET, ARM_TARGET,
]) + extra_args, ]) + extra_args,
None,
);
// Run cargo build --no-default-features --target thumbv6m-none-eabi
cargo_build(
Params::from([
"-p",
crate_name,
"--no-default-features",
"--target",
ARM_NO_ATOMIC_PTR_TARGET,
]) + extra_args,
Some(HashMap::from([(
"RUSTFLAGS",
"--cfg portable_atomic_unsafe_assume_single_core".to_string(),
)])),
); );
endgroup!(); endgroup!();
@ -228,6 +249,9 @@ fn no_std_checks() {
// Install ARM target // Install ARM target
rustup_add_target(ARM_TARGET); rustup_add_target(ARM_TARGET);
// Install ARM no atomic ptr target
rustup_add_target(ARM_NO_ATOMIC_PTR_TARGET);
// Run checks for the following crates // Run checks for the following crates
build_and_test_no_std("burn", []); build_and_test_no_std("burn", []);
build_and_test_no_std("burn-core", []); build_and_test_no_std("burn-core", []);
@ -265,7 +289,7 @@ fn burn_dataset_features_std() {
group!("Checks: burn-dataset (all-features)"); group!("Checks: burn-dataset (all-features)");
// Run cargo build --all-features // Run cargo build --all-features
cargo_build(["-p", "burn-dataset", "--all-features"].into()); cargo_build(["-p", "burn-dataset", "--all-features"].into(), None);
// Run cargo test --all-features // Run cargo test --all-features
cargo_test(["-p", "burn-dataset", "--all-features"].into()); cargo_test(["-p", "burn-dataset", "--all-features"].into());
@ -334,7 +358,7 @@ fn std_checks() {
} }
group!("Checks: {}", member.name); group!("Checks: {}", member.name);
cargo_build(Params::from(["-p", &member.name])); cargo_build(Params::from(["-p", &member.name]), None);
cargo_test(Params::from(["-p", &member.name])); cargo_test(Params::from(["-p", &member.name]));
endgroup!(); endgroup!();
} }
@ -373,6 +397,7 @@ fn check_typos() {
// Run typos command as child process // Run typos command as child process
let typos = Command::new("typos") let typos = Command::new("typos")
.args(["--exclude", "**/*.onnx"])
.stdout(Stdio::inherit()) // Send stdout directly to terminal .stdout(Stdio::inherit()) // Send stdout directly to terminal
.stderr(Stdio::inherit()) // Send stderr directly to terminal .stderr(Stdio::inherit()) // Send stderr directly to terminal
.spawn() .spawn()