From 587b8f80b3a0f77ce0d75a377c0adf5943537de6 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Tue, 30 Apr 2024 09:46:29 -0400 Subject: [PATCH] First draft CUDA runtime (#1685) Initial cuda runtime crate with a WIP compiler. --- Cargo.lock | 200 +++++--- backend-comparison/Cargo.toml | 2 +- crates/burn-compute/src/server.rs | 2 +- crates/burn-core/Cargo.toml | 2 +- crates/burn-cuda/Cargo.toml | 40 ++ crates/burn-cuda/README.md | 5 + crates/burn-cuda/src/compiler/base.rs | 442 ++++++++++++++++ crates/burn-cuda/src/compiler/binary.rs | 483 ++++++++++++++++++ crates/burn-cuda/src/compiler/body.rs | 81 +++ crates/burn-cuda/src/compiler/element.rs | 309 +++++++++++ crates/burn-cuda/src/compiler/instruction.rs | 233 +++++++++ crates/burn-cuda/src/compiler/mod.rs | 16 + crates/burn-cuda/src/compiler/settings.rs | 6 + crates/burn-cuda/src/compiler/shader.rs | 155 ++++++ crates/burn-cuda/src/compiler/unary.rs | 210 ++++++++ crates/burn-cuda/src/compute/mod.rs | 5 + crates/burn-cuda/src/compute/server.rs | 226 ++++++++ crates/burn-cuda/src/compute/storage.rs | 118 +++++ crates/burn-cuda/src/device.rs | 12 + crates/burn-cuda/src/element.rs | 42 ++ crates/burn-cuda/src/lib.rs | 29 ++ crates/burn-cuda/src/runtime.rs | 81 +++ crates/burn-jit/Cargo.toml | 1 + crates/burn-jit/src/codegen/compiler.rs | 8 +- .../src/codegen/dialect/gpu/macros.rs | 12 + .../src/codegen/dialect/gpu/operation.rs | 2 + .../src/codegen/dialect/gpu/procedure/base.rs | 19 +- .../codegen/dialect/gpu/procedure/index.rs | 74 +++ .../src/codegen/dialect/gpu/procedure/mod.rs | 2 + .../src/codegen/dialect/gpu/shader.rs | 5 +- .../src/codegen/dialect/gpu/variable.rs | 3 +- .../src/codegen/dialect/gpu/vectorization.rs | 4 + crates/burn-jit/src/codegen/kernel.rs | 13 + crates/burn-jit/src/compute/kernel.rs | 9 +- crates/burn-jit/src/element.rs | 21 + crates/burn-jit/src/fusion/kernel.rs | 18 +- crates/burn-jit/src/fusion/tracing/builder.rs | 16 + crates/burn-jit/src/kernel/matmul/base.rs | 1 + crates/burn-jit/src/lib.rs | 2 +- crates/burn-jit/src/runtime.rs | 5 + crates/burn-jit/src/template/base.rs | 2 + .../burn-tensor/src/tests/clone_invariance.rs | 8 +- crates/burn-wgpu/src/compiler/wgsl/body.rs | 2 +- .../burn-wgpu/src/compiler/wgsl/compiler.rs | 28 + crates/burn-wgpu/src/compiler/wgsl/shader.rs | 9 +- crates/burn/Cargo.toml | 2 +- xtask/src/runchecks.rs | 47 +- xtask/src/utils/mod.rs | 2 +- xtask/src/utils/workspace.rs | 4 +- 49 files changed, 2910 insertions(+), 108 deletions(-) create mode 100644 crates/burn-cuda/Cargo.toml create mode 100644 crates/burn-cuda/README.md create mode 100644 crates/burn-cuda/src/compiler/base.rs create mode 100644 crates/burn-cuda/src/compiler/binary.rs create mode 100644 crates/burn-cuda/src/compiler/body.rs create mode 100644 crates/burn-cuda/src/compiler/element.rs create mode 100644 crates/burn-cuda/src/compiler/instruction.rs create mode 100644 crates/burn-cuda/src/compiler/mod.rs create mode 100644 crates/burn-cuda/src/compiler/settings.rs create mode 100644 crates/burn-cuda/src/compiler/shader.rs create mode 100644 crates/burn-cuda/src/compiler/unary.rs create mode 100644 crates/burn-cuda/src/compute/mod.rs create mode 100644 crates/burn-cuda/src/compute/server.rs create mode 100644 crates/burn-cuda/src/compute/storage.rs create mode 100644 crates/burn-cuda/src/device.rs create mode 100644 crates/burn-cuda/src/element.rs create mode 100644 crates/burn-cuda/src/lib.rs create mode 100644 crates/burn-cuda/src/runtime.rs create mode 100644 crates/burn-jit/src/codegen/dialect/gpu/procedure/index.rs diff --git a/Cargo.lock b/Cargo.lock index 0d3dd4aba..e69bfdc1c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -132,19 +132,18 @@ checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" [[package]] name = "arboard" -version = "3.3.2" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2041f1943049c7978768d84e6d0fd95de98b76d6c4727b09e78ec253d29fa58" +checksum = "9fb4009533e8ff8f1450a5bcbc30f4242a1d34442221f72314bea1f5dc9c7f89" dependencies = [ "clipboard-win", "core-graphics", - "image", + "image 0.25.1", "log", - "objc", - "objc-foundation", - "objc_id", - "parking_lot 0.12.1", - "thiserror", + "objc2", + "objc2-app-kit", + "objc2-foundation", + "parking_lot 0.12.2", "windows-sys 0.48.0", "x11rb", ] @@ -339,6 +338,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block2" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43ff7d91d3c1d568065b06c899777d1e48dcf76103a672a0adbc238a7f247f1e" +dependencies = [ + "objc2", +] + [[package]] name = "bstr" version = "1.9.1" @@ -453,6 +461,22 @@ dependencies = [ "thiserror", ] +[[package]] +name = "burn-cuda" +version = "0.14.0" +dependencies = [ + "burn-common", + "burn-compute", + "burn-fusion", + "burn-jit", + "burn-tensor", + "bytemuck", + "cudarc", + "derive-new", + "half", + "log", +] + [[package]] name = "burn-dataset" version = "0.14.0" @@ -466,7 +490,7 @@ dependencies = [ "gix-tempfile", "globwalk", "hound", - "image", + "image 0.24.9", "r2d2", "r2d2_sqlite", "rand", @@ -549,6 +573,7 @@ dependencies = [ "burn-tensor-testgen", "bytemuck", "derive-new", + "half", "hashbrown 0.14.5", "log", "num-traits", @@ -1116,7 +1141,7 @@ dependencies = [ "crossterm_winapi", "libc", "mio", - "parking_lot 0.12.1", + "parking_lot 0.12.2", "signal-hook", "signal-hook-mio", "winapi", @@ -1287,7 +1312,7 @@ dependencies = [ "hashbrown 0.14.5", "lock_api", "once_cell", - "parking_lot_core 0.9.9", + "parking_lot_core 0.9.10", ] [[package]] @@ -1559,9 +1584,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.0.2" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" +checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" [[package]] name = "fdeflate" @@ -1586,9 +1611,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.29" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4556222738635b7a3417ae6130d8f52201e45a0c4d1a907f0826383adb5f85e7" +checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" dependencies = [ "crc32fast", "miniz_oxide", @@ -1719,7 +1744,7 @@ checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot 0.12.1", + "parking_lot 0.12.2", ] [[package]] @@ -1996,7 +2021,7 @@ dependencies = [ "gix-fs", "libc", "once_cell", - "parking_lot 0.12.1", + "parking_lot 0.12.2", "signal-hook", "signal-hook-registry", "tempfile", @@ -2439,6 +2464,19 @@ dependencies = [ "tiff", ] +[[package]] +name = "image" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11" +dependencies = [ + "bytemuck", + "byteorder", + "num-traits", + "png", + "tiff", +] + [[package]] name = "image-classification-web" version = "0.14.0" @@ -2669,9 +2707,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "lock_api" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" dependencies = [ "autocfg", "scopeguard", @@ -3055,14 +3093,58 @@ dependencies = [ ] [[package]] -name = "objc-foundation" -version = "0.1.1" +name = "objc-sys" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1add1b659e36c9607c7aab864a76c7a4c2760cd0cd2e120f3fb8b952c7e22bf9" +checksum = "da284c198fb9b7b0603f8635185e85fbd5b64ee154b1ed406d489077de2d6d60" + +[[package]] +name = "objc2" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4b25e1034d0e636cd84707ccdaa9f81243d399196b8a773946dcffec0401659" dependencies = [ - "block", - "objc", - "objc_id", + "objc-sys", + "objc2-encode", +] + +[[package]] +name = "objc2-app-kit" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb79768a710a9a1798848179edb186d1af7e8a8679f369e4b8d201dd2a034047" +dependencies = [ + "block2", + "objc2", + "objc2-core-data", + "objc2-foundation", +] + +[[package]] +name = "objc2-core-data" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e092bc42eaf30a08844e6a076938c60751225ec81431ab89f5d1ccd9f958d6c" +dependencies = [ + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-encode" +version = "4.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88658da63e4cc2c8adb1262902cd6af51094df0488b760d6fd27194269c0950a" + +[[package]] +name = "objc2-foundation" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfaefe14254871ea16c7d88968c0ff14ba554712a20d76421eec52f0a7fb8904" +dependencies = [ + "block2", + "objc2", ] [[package]] @@ -3074,15 +3156,6 @@ dependencies = [ "cc", ] -[[package]] -name = "objc_id" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c92d4ddb4bd7b50d730c215ff871754d0da6b2178849f8a2a2ab69712d0c073b" -dependencies = [ - "objc", -] - [[package]] name = "object" version = "0.32.2" @@ -3252,12 +3325,12 @@ dependencies = [ [[package]] name = "parking_lot" -version = "0.12.1" +version = "0.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb" dependencies = [ "lock_api", - "parking_lot_core 0.9.9", + "parking_lot_core 0.9.10", ] [[package]] @@ -3276,15 +3349,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.9" +version = "0.9.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.4.1", + "redox_syscall 0.5.1", "smallvec", - "windows-targets 0.48.5", + "windows-targets 0.52.5", ] [[package]] @@ -3542,7 +3615,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51de85fb3fb6524929c8a2eb85e6b6d363de4e8c48f9e2c2eac4944abc181c93" dependencies = [ "log", - "parking_lot 0.12.1", + "parking_lot 0.12.2", "scheduled-thread-pool", ] @@ -3699,6 +3772,15 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "redox_syscall" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +dependencies = [ + "bitflags 2.5.0", +] + [[package]] name = "redox_users" version = "0.4.5" @@ -3972,9 +4054,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecd36cc4259e3e4514335c4a138c6b43171a8d61d8f5c9348f9fc7529416f247" +checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54" [[package]] name = "rustls-webpki" @@ -4062,7 +4144,7 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" dependencies = [ - "parking_lot 0.12.1", + "parking_lot 0.12.2", ] [[package]] @@ -4185,7 +4267,7 @@ dependencies = [ "futures", "log", "once_cell", - "parking_lot 0.12.1", + "parking_lot 0.12.2", "scc", "serial_test_derive", ] @@ -4294,9 +4376,9 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "socket2" -version = "0.5.6" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" dependencies = [ "libc", "windows-sys 0.52.0", @@ -4906,9 +4988,9 @@ checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" [[package]] name = "unicode-width" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" +checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" [[package]] name = "unicode-xid" @@ -4930,11 +5012,11 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.9.6" +version = "2.9.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11f214ce18d8b2cbe84ed3aa6486ed3f5b285cf8d8fbdbce9f3f767a724adc35" +checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" dependencies = [ - "base64 0.21.7", + "base64 0.22.0", "flate2", "log", "native-tls", @@ -5158,7 +5240,7 @@ dependencies = [ "js-sys", "log", "naga", - "parking_lot 0.12.1", + "parking_lot 0.12.2", "profiling", "raw-window-handle", "smallvec", @@ -5186,7 +5268,7 @@ dependencies = [ "log", "naga", "once_cell", - "parking_lot 0.12.1", + "parking_lot 0.12.2", "profiling", "raw-window-handle", "rustc-hash", @@ -5228,7 +5310,7 @@ dependencies = [ "ndk-sys", "objc", "once_cell", - "parking_lot 0.12.1", + "parking_lot 0.12.2", "profiling", "range-alloc", "raw-window-handle", @@ -5289,11 +5371,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.6" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" dependencies = [ - "winapi", + "windows-sys 0.52.0", ] [[package]] diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 783942564..269592c34 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -13,7 +13,7 @@ version.workspace = true # we depend on wgpu and autotune by default because we use the burn-wgpu crate to get system information default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"] candle-cpu = ["burn/candle"] -candle-cuda = ["burn/candle", "burn/cuda"] +candle-cuda = ["burn/candle-cuda"] candle-metal = ["burn/candle", "burn/metal"] candle-accelerate = ["burn/candle", "burn/accelerate"] ndarray = ["burn/ndarray"] diff --git a/crates/burn-compute/src/server.rs b/crates/burn-compute/src/server.rs index ef09a9304..3ce2c738a 100644 --- a/crates/burn-compute/src/server.rs +++ b/crates/burn-compute/src/server.rs @@ -51,7 +51,7 @@ pub struct Handle { } /// Binding of a [tensor handle](Handle) to execute a kernel. -#[derive(new)] +#[derive(new, Debug)] pub struct Binding { /// Memory binding. pub memory: >::Binding, diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index 6a0fd2f82..7dcbf4601 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -72,7 +72,6 @@ autodiff = ["burn-autodiff"] fusion = ["burn-wgpu?/fusion"] ## Backend features -cuda = ["burn-candle?/cuda"] metal = ["burn-candle?/metal"] accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"] openblas = ["burn-ndarray?/blas-openblas"] @@ -84,6 +83,7 @@ template = ["burn-wgpu?/template"] ndarray = ["burn-ndarray"] tch = ["burn-tch"] candle = ["burn-candle"] +candle-cuda = ["candle", "burn-candle/cuda"] wgpu = ["burn-wgpu"] # Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files. diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml new file mode 100644 index 000000000..e76e2a0ec --- /dev/null +++ b/crates/burn-cuda/Cargo.toml @@ -0,0 +1,40 @@ +[package] +authors = ["nathanielsimard "] +categories = ["science"] +description = "CUDA backend for the Burn framework" +edition.workspace = true +keywords = ["deep-learning", "machine-learning", "gpu", "cuda"] +license.workspace = true +name = "burn-cuda" +readme.workspace = true +repository = "https://github.com/tracel-ai/burn/tree/main/burn-cuda" +version.workspace = true + +[features] +default = ["fusion", "burn-jit/default"] +fusion = ["burn-fusion", "burn-jit/fusion"] +autotune = ["burn-jit/autotune"] +doc = ["burn-jit/doc"] +std = ["burn-jit/std"] + +[dependencies] +burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false } +burn-compute = { path = "../burn-compute", version = "0.14.0" } +burn-tensor = { path = "../burn-tensor", version = "0.14.0" } +burn-common = { path = "../burn-common", version = "0.14.0" } +burn-fusion = { path = "../burn-fusion", version = "0.14.0", optional = true } +half = { workspace = true } + +bytemuck = { workspace = true } +cudarc = "0.10.0" + +log = { workspace = true } +derive-new = { workspace = true } + +[dev-dependencies] +burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false, features = [ + "export_tests", +] } + +[package.metadata.docs.rs] +features = ["doc"] \ No newline at end of file diff --git a/crates/burn-cuda/README.md b/crates/burn-cuda/README.md new file mode 100644 index 000000000..1133d344a --- /dev/null +++ b/crates/burn-cuda/README.md @@ -0,0 +1,5 @@ +# Burn-Cuda + +This backend is still a work in progress and not ready to be used. + +See #1525 diff --git a/crates/burn-cuda/src/compiler/base.rs b/crates/burn-cuda/src/compiler/base.rs new file mode 100644 index 000000000..0cd17ff91 --- /dev/null +++ b/crates/burn-cuda/src/compiler/base.rs @@ -0,0 +1,442 @@ +use super::Instruction; +use burn_jit::gpu::{self}; + +#[allow(clippy::too_many_arguments)] +#[derive(new, Clone, Debug, Default)] +pub struct CudaCompiler { + shape: bool, + stride: bool, + num_inputs: usize, + num_outputs: usize, + shared_memories: Vec, + local_arrays: Vec, + id: bool, + rank: bool, + invocation_index: bool, + global_invocation_id: (bool, bool, bool), +} + +impl burn_jit::Compiler for CudaCompiler { + type Representation = super::ComputeShader; + + fn compile(shader: burn_jit::gpu::ComputeShader) -> Self::Representation { + let compiler = Self::default(); + compiler.compile_shader(shader) + } + + fn elem_size(elem: burn_jit::gpu::Elem) -> usize { + Self::compile_elem(elem).size() + } + + fn max_shared_memory_size() -> usize { + // TODO: Find out this value. + usize::MAX + } +} + +impl CudaCompiler { + fn compile_shader(mut self, mut value: gpu::ComputeShader) -> super::ComputeShader { + self.num_inputs = value.inputs.len(); + self.num_outputs = value.outputs.len(); + + let instructions = self.compile_scope(&mut value.body); + let body = super::Body { + instructions, + stride: true, + shape: true, + shared_memories: self.shared_memories, + local_arrays: self.local_arrays, + rank: self.rank, + id: self.id, + invocation_index: self.invocation_index, + global_invocation_id: self.global_invocation_id, + }; + + super::ComputeShader { + inputs: value + .inputs + .into_iter() + .map(Self::compile_binding) + .collect(), + outputs: value + .outputs + .into_iter() + .map(Self::compile_binding) + .collect(), + named: value + .named + .into_iter() + .map(|(name, binding)| (name, Self::compile_binding(binding))) + .collect(), + workgroup_size: value.workgroup_size, + body, + } + } + + fn compile_scope(&mut self, value: &mut gpu::Scope) -> Vec { + let mut instructions = Vec::new(); + let mut processing = value.process(); + + for operation in &mut processing.operations { + if let gpu::Operation::Operator(gpu::Operator::Index(operands)) = operation { + // Replace all Index operators for global arrays with CheckedIndexAssign procedures + match operands.lhs { + gpu::Variable::GlobalInputArray(_, _) + | gpu::Variable::GlobalOutputArray(_, _) => { + *operation = gpu::Operation::Procedure(gpu::Procedure::CheckedIndex( + gpu::CheckedIndex { + lhs: operands.lhs, + rhs: operands.rhs, + out: operands.out, + }, + )); + } + // Cannot perform bound check on non-global arrays, do nothing. + _ => (), + } + } + if let gpu::Operation::Operator(gpu::Operator::IndexAssign(operands)) = operation { + // Replace all IndexAssign operators of global arrays with CheckedIndexAssign procedures + match operands.out { + gpu::Variable::GlobalInputArray(_, _) + | gpu::Variable::GlobalOutputArray(_, _) => { + *operation = gpu::Operation::Procedure(gpu::Procedure::CheckedIndexAssign( + gpu::CheckedIndexAssign { + lhs: operands.lhs, + rhs: operands.rhs, + out: operands.out, + }, + )); + } + // Cannot perform bound check on non-global arrays, do nothing. + _ => (), + } + } + } + + for var in processing.variables { + instructions.push(Instruction::DeclareVariable { + var: self.compile_variable(var), + }); + } + + processing + .operations + .into_iter() + .for_each(|op| self.compile_operation(&mut instructions, op, value)); + + instructions + } + + fn compile_operation( + &mut self, + instructions: &mut Vec, + operation: gpu::Operation, + scope: &mut gpu::Scope, + ) { + match operation { + gpu::Operation::Operator(op) => instructions.push(self.compile_instruction(op)), + gpu::Operation::Procedure(proc) => self.compile_procedure(instructions, proc, scope), + gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op)), + gpu::Operation::Branch(val) => self.compile_branch(instructions, val), + gpu::Operation::Synchronization(val) => match val { + gpu::Synchronization::WorkgroupBarrier => { + instructions.push(Instruction::SyncThreads) + } + }, + } + } + + fn compile_metadata(&mut self, metadata: gpu::Metadata) -> Instruction { + match metadata { + gpu::Metadata::Stride { dim, var, out } => { + self.stride = true; + let position = match var { + gpu::Variable::GlobalInputArray(idx, _) => idx as usize, + gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize, + _ => panic!("Only Input and Output have a stride, got: {:?}", var), + }; + Instruction::Stride { + dim: self.compile_variable(dim), + position, + out: self.compile_variable(out), + } + } + gpu::Metadata::Shape { dim, var, out } => { + self.shape = true; + let position = match var { + gpu::Variable::GlobalInputArray(idx, _) => idx as usize, + gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize, + _ => panic!("Only Input and Output have a shape, got {:?}", var), + }; + Instruction::Shape { + dim: self.compile_variable(dim), + position, + out: self.compile_variable(out), + } + } + gpu::Metadata::ArrayLength { var, out } => super::Instruction::ArrayLength { + input: self.compile_variable(var), + out: self.compile_variable(out), + num_inputs: self.num_inputs, + num_outputs: self.num_outputs, + }, + } + } + + fn compile_branch(&mut self, instructions: &mut Vec, branch: gpu::Branch) { + match branch { + gpu::Branch::If(mut op) => instructions.push(Instruction::If { + cond: self.compile_variable(op.cond), + instructions: self.compile_scope(&mut op.scope), + }), + gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse { + cond: self.compile_variable(op.cond), + instructions_if: self.compile_scope(&mut op.scope_if), + instructions_else: self.compile_scope(&mut op.scope_else), + }), + gpu::Branch::Return => instructions.push(Instruction::Return), + gpu::Branch::Break => instructions.push(Instruction::Break), + gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop { + i: self.compile_variable(range_loop.i), + start: self.compile_variable(range_loop.start), + end: self.compile_variable(range_loop.end), + instructions: self.compile_scope(&mut range_loop.scope), + }), + gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop { + instructions: self.compile_scope(&mut op.scope), + }), + }; + } + fn compile_procedure( + &mut self, + instructions: &mut Vec, + proc: gpu::Procedure, + scope: &mut gpu::Scope, + ) { + let mut compile = |scope: &mut gpu::Scope| { + instructions.extend(self.compile_scope(scope)); + }; + + match proc { + gpu::Procedure::ReadGlobalWithLayout(proc) => { + proc.expand(scope); + compile(scope); + } + gpu::Procedure::ReadGlobal(proc) => { + proc.expand(scope); + compile(scope); + } + gpu::Procedure::WriteGlobal(proc) => { + proc.expand(scope); + compile(scope); + } + gpu::Procedure::ConditionalAssign(proc) => { + proc.expand(scope); + compile(scope); + } + gpu::Procedure::CheckedIndex(proc) => { + proc.expand(scope); + compile(scope); + } + gpu::Procedure::CheckedIndexAssign(proc) => { + proc.expand(scope); + compile(scope); + } + gpu::Procedure::IndexOffsetGlobalWithLayout(proc) => { + proc.expand(scope); + compile(scope); + } + } + } + + fn compile_instruction(&mut self, value: gpu::Operator) -> Instruction { + match value { + gpu::Operator::Add(op) => Instruction::Add(self.compile_binary(op)), + gpu::Operator::Mul(op) => Instruction::Mul(self.compile_binary(op)), + gpu::Operator::Div(op) => Instruction::Div(self.compile_binary(op)), + gpu::Operator::Sub(op) => Instruction::Sub(self.compile_binary(op)), + gpu::Operator::Assign(op) => Instruction::Assign(self.compile_unary(op)), + gpu::Operator::Index(op) => Instruction::Index(self.compile_binary(op)), + gpu::Operator::UncheckedIndex(op) => Instruction::Index(self.compile_binary(op)), + gpu::Operator::IndexAssign(op) => Instruction::IndexAssign(self.compile_binary(op)), + gpu::Operator::UncheckedIndexAssign(op) => { + Instruction::IndexAssign(self.compile_binary(op)) + } + gpu::Operator::Modulo(op) => Instruction::Modulo(self.compile_binary(op)), + gpu::Operator::Equal(op) => Instruction::Equal(self.compile_binary(op)), + gpu::Operator::Lower(op) => Instruction::Lower(self.compile_binary(op)), + gpu::Operator::Greater(op) => Instruction::Greater(self.compile_binary(op)), + gpu::Operator::LowerEqual(op) => Instruction::LowerEqual(self.compile_binary(op)), + gpu::Operator::GreaterEqual(op) => Instruction::GreaterEqual(self.compile_binary(op)), + gpu::Operator::Abs(op) => Instruction::Abs(self.compile_unary(op)), + gpu::Operator::Exp(op) => Instruction::Exp(self.compile_unary(op)), + gpu::Operator::Log(op) => Instruction::Log(self.compile_unary(op)), + gpu::Operator::Log1p(op) => Instruction::Log1p(self.compile_unary(op)), + gpu::Operator::Cos(op) => Instruction::Cos(self.compile_unary(op)), + gpu::Operator::Sin(op) => Instruction::Sin(self.compile_unary(op)), + gpu::Operator::Tanh(op) => Instruction::Tanh(self.compile_unary(op)), + gpu::Operator::Powf(op) => Instruction::Powf(self.compile_binary(op)), + gpu::Operator::Sqrt(op) => Instruction::Sqrt(self.compile_unary(op)), + gpu::Operator::Erf(op) => Instruction::Erf(self.compile_unary(op)), + gpu::Operator::And(op) => Instruction::And(self.compile_binary(op)), + gpu::Operator::Or(op) => Instruction::Or(self.compile_binary(op)), + gpu::Operator::Not(op) => Instruction::Not(self.compile_unary(op)), + gpu::Operator::Max(op) => Instruction::Max(self.compile_binary(op)), + gpu::Operator::Min(op) => Instruction::Min(self.compile_binary(op)), + gpu::Operator::NotEqual(op) => Instruction::NotEqual(self.compile_binary(op)), + gpu::Operator::BitwiseAnd(op) => Instruction::BitwiseAnd(self.compile_binary(op)), + gpu::Operator::BitwiseXor(op) => Instruction::BitwiseXor(self.compile_binary(op)), + gpu::Operator::ShiftLeft(op) => Instruction::ShiftLeft(self.compile_binary(op)), + gpu::Operator::ShiftRight(op) => Instruction::ShiftRight(self.compile_binary(op)), + gpu::Operator::Clamp(op) => Instruction::Clamp { + input: self.compile_variable(op.input), + min_value: self.compile_variable(op.min_value), + max_value: self.compile_variable(op.max_value), + out: self.compile_variable(op.out), + }, + gpu::Operator::Recip(op) => Instruction::Div(super::BinaryInstruction { + lhs: super::Variable::ConstantScalar( + 1.0, + Self::compile_elem(op.input.item().elem()), + ), + rhs: self.compile_variable(op.input), + out: self.compile_variable(op.out), + }), + gpu::Operator::Floor(op) => Instruction::Floor(self.compile_unary(op)), + gpu::Operator::Ceil(op) => Instruction::Ceil(self.compile_unary(op)), + gpu::Operator::Remainder(_op) => todo!(), + } + } + + fn compile_binary(&mut self, value: gpu::BinaryOperator) -> super::BinaryInstruction { + super::BinaryInstruction { + lhs: self.compile_variable(value.lhs), + rhs: self.compile_variable(value.rhs), + out: self.compile_variable(value.out), + } + } + + fn compile_unary(&mut self, value: gpu::UnaryOperator) -> super::UnaryInstruction { + super::UnaryInstruction { + input: self.compile_variable(value.input), + out: self.compile_variable(value.out), + } + } + + fn compile_variable(&mut self, value: gpu::Variable) -> super::Variable { + match value { + gpu::Variable::GlobalInputArray(index, item) => { + super::Variable::GlobalInputArray(index, Self::compile_item(item)) + } + gpu::Variable::GlobalScalar(index, elem) => { + super::Variable::GlobalScalar(index, Self::compile_elem(elem), elem) + } + gpu::Variable::Local(index, item, scope_depth) => super::Variable::Local { + index, + item: Self::compile_item(item), + scope_depth, + }, + gpu::Variable::LocalScalar(index, elem, scope_depth) => super::Variable::LocalScalar { + index, + elem: Self::compile_elem(elem), + scope_depth, + }, + gpu::Variable::GlobalOutputArray(index, item) => { + super::Variable::GlobalOutputArray(index, Self::compile_item(item)) + } + gpu::Variable::ConstantScalar(index, elem) => { + super::Variable::ConstantScalar(index, Self::compile_elem(elem)) + } + gpu::Variable::SharedMemory(index, item, size) => { + let item = Self::compile_item(item); + if !self.shared_memories.iter().any(|s| s.index == index) { + self.shared_memories + .push(super::SharedMemory::new(index, item, size)); + } + super::Variable::SharedMemory(index, item, size) + } + gpu::Variable::Id => { + self.id = true; + super::Variable::Id + } + gpu::Variable::Rank => { + self.rank = true; + super::Variable::Rank + } + gpu::Variable::LocalInvocationIndex => { + self.invocation_index = true; + super::Variable::LocalInvocationIndex + } + gpu::Variable::LocalInvocationIdX => super::Variable::LocalInvocationIdX, + gpu::Variable::LocalInvocationIdY => super::Variable::LocalInvocationIdY, + gpu::Variable::LocalInvocationIdZ => super::Variable::LocalInvocationIdZ, + gpu::Variable::WorkgroupIdX => super::Variable::WorkgroupIdX, + gpu::Variable::WorkgroupIdY => super::Variable::WorkgroupIdY, + gpu::Variable::WorkgroupIdZ => super::Variable::WorkgroupIdZ, + gpu::Variable::GlobalInvocationIdX => { + self.global_invocation_id.0 = true; + super::Variable::GlobalInvocationIdX + } + gpu::Variable::GlobalInvocationIdY => { + self.global_invocation_id.1 = true; + super::Variable::GlobalInvocationIdY + } + gpu::Variable::GlobalInvocationIdZ => { + self.global_invocation_id.2 = true; + super::Variable::GlobalInvocationIdZ + } + gpu::Variable::WorkgroupSizeX => super::Variable::WorkgroupSizeX, + gpu::Variable::WorkgroupSizeY => super::Variable::WorkgroupSizeY, + gpu::Variable::WorkgroupSizeZ => super::Variable::WorkgroupSizeZ, + gpu::Variable::NumWorkgroupsX => super::Variable::NumWorkgroupsX, + gpu::Variable::NumWorkgroupsY => super::Variable::NumWorkgroupsY, + gpu::Variable::NumWorkgroupsZ => super::Variable::NumWorkgroupsZ, + gpu::Variable::LocalArray(id, item, depth, size) => { + let item = Self::compile_item(item); + if !self + .local_arrays + .iter() + .any(|s| s.index == id && s.depth == depth) + { + self.local_arrays + .push(super::LocalArray::new(id, item, depth, size)); + } + super::Variable::LocalArray(id, item, depth, size) + } + } + } + + fn compile_binding(binding: gpu::Binding) -> super::Binding { + super::Binding { + item: Self::compile_item(binding.item), + size: binding.size, + } + } + + fn compile_item(item: gpu::Item) -> super::Item { + match item { + gpu::Item::Vec4(elem) => super::Item::Vec4(Self::compile_elem(elem)), + gpu::Item::Vec3(elem) => super::Item::Vec3(Self::compile_elem(elem)), + gpu::Item::Vec2(elem) => super::Item::Vec2(Self::compile_elem(elem)), + gpu::Item::Scalar(elem) => super::Item::Scalar(Self::compile_elem(elem)), + } + } + + fn compile_elem(value: gpu::Elem) -> super::Elem { + match value { + gpu::Elem::Float(kind) => match kind { + gpu::FloatKind::F16 => super::Elem::F16, + gpu::FloatKind::BF16 => super::Elem::BF16, + gpu::FloatKind::F32 => super::Elem::F32, + gpu::FloatKind::F64 => panic!("f64 isn't supported yet"), + }, + gpu::Elem::Int(kind) => match kind { + gpu::IntKind::I32 => super::Elem::I32, + gpu::IntKind::I64 => panic!("i64 isn't supported yet"), + }, + gpu::Elem::UInt => super::Elem::U32, + gpu::Elem::Bool => super::Elem::Bool, + } + } +} diff --git a/crates/burn-cuda/src/compiler/binary.rs b/crates/burn-cuda/src/compiler/binary.rs new file mode 100644 index 000000000..f2fae1c71 --- /dev/null +++ b/crates/burn-cuda/src/compiler/binary.rs @@ -0,0 +1,483 @@ +use super::{Component, Elem, InstructionSettings, Item, Variable}; +use std::fmt::Display; + +pub trait Binary { + fn format( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + ) -> std::fmt::Result { + let item = out.item(); + let settings = Self::settings(*item.elem()); + + match item { + Item::Vec4(elem) => { + if settings.native_vec4 && lhs.item() == rhs.item() { + Self::format_native_vec4(f, lhs, rhs, out, elem) + } else { + Self::unroll_vec4(f, lhs, rhs, out, elem) + } + } + Item::Vec3(elem) => { + if settings.native_vec3 && lhs.item() == rhs.item() { + Self::format_native_vec3(f, lhs, rhs, out, elem) + } else { + Self::unroll_vec3(f, lhs, rhs, out, elem) + } + } + Item::Vec2(elem) => { + if settings.native_vec2 && lhs.item() == rhs.item() { + Self::format_native_vec2(f, lhs, rhs, out, elem) + } else { + Self::unroll_vec2(f, lhs, rhs, out, elem) + } + } + Item::Scalar(elem) => Self::format_scalar(f, *lhs, *rhs, *out, elem), + } + } + + fn settings(_elem: Elem) -> InstructionSettings { + InstructionSettings::default() + } + + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + lhs: Lhs, + rhs: Rhs, + out: Out, + elem: Elem, + ) -> std::fmt::Result + where + Lhs: Component, + Rhs: Component, + Out: Component; + + fn format_native_vec4( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + Self::format_scalar(f, *lhs, *rhs, *out, elem) + } + + fn format_native_vec3( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + Self::format_scalar(f, *lhs, *rhs, *out, elem) + } + + fn format_native_vec2( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + Self::format_scalar(f, *lhs, *rhs, *out, elem) + } + + fn unroll_vec2( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + let lhs0 = lhs.index(0); + let lhs1 = lhs.index(1); + + let rhs0 = rhs.index(0); + let rhs1 = rhs.index(1); + + let out0 = out.index(0); + let out1 = out.index(1); + + Self::format_scalar(f, lhs0, rhs0, out0, elem)?; + Self::format_scalar(f, lhs1, rhs1, out1, elem)?; + + Ok(()) + } + + fn unroll_vec3( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + let lhs0 = lhs.index(0); + let lhs1 = lhs.index(1); + let lhs2 = lhs.index(2); + + let rhs0 = rhs.index(0); + let rhs1 = rhs.index(1); + let rhs2 = rhs.index(2); + + let out0 = out.index(0); + let out1 = out.index(1); + let out2 = out.index(2); + + Self::format_scalar(f, lhs0, rhs0, out0, elem)?; + Self::format_scalar(f, lhs1, rhs1, out1, elem)?; + Self::format_scalar(f, lhs2, rhs2, out2, elem)?; + + Ok(()) + } + + fn unroll_vec4( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + let lhs0 = lhs.index(0); + let lhs1 = lhs.index(1); + let lhs2 = lhs.index(2); + let lhs3 = lhs.index(3); + + let rhs0 = rhs.index(0); + let rhs1 = rhs.index(1); + let rhs2 = rhs.index(2); + let rhs3 = rhs.index(3); + + let out0 = out.index(0); + let out1 = out.index(1); + let out2 = out.index(2); + let out3 = out.index(3); + + Self::format_scalar(f, lhs0, rhs0, out0, elem)?; + Self::format_scalar(f, lhs1, rhs1, out1, elem)?; + Self::format_scalar(f, lhs2, rhs2, out2, elem)?; + Self::format_scalar(f, lhs3, rhs3, out3, elem)?; + + Ok(()) + } +} + +macro_rules! operator { + ($name:ident, $op:expr) => { + operator!( + $name, + $op, + InstructionSettings { + native_vec4: false, + native_vec3: false, + native_vec2: false, + } + ); + }; + ($name:ident, $op:expr, $vectorization:expr) => { + pub struct $name; + + impl Binary for $name { + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + lhs: Lhs, + rhs: Rhs, + out: Out, + _elem: Elem, + ) -> std::fmt::Result { + f.write_fmt(format_args!("{out} = {lhs} {} {rhs};\n", $op)) + } + + #[allow(unused_variables)] + fn settings(elem: Elem) -> InstructionSettings { + $vectorization + } + } + }; +} + +macro_rules! function { + ($name:ident, $op:expr) => { + function!( + $name, + $op, + InstructionSettings { + native_vec4: false, + native_vec3: false, + native_vec2: true, + } + ); + }; + ($name:ident, $op:expr, $vectorization:expr) => { + pub struct $name; + + impl Binary for $name { + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + lhs: Lhs, + rhs: Rhs, + out: Out, + _elem: Elem, + ) -> std::fmt::Result { + f.write_fmt(format_args!("{out} = {}({lhs}, {rhs});\n", $op)) + } + + #[allow(unused_variables)] + fn settings(elem: Elem) -> InstructionSettings { + $vectorization + } + } + }; +} + +operator!(Add, "+"); +operator!(Sub, "-"); +operator!(Div, "/"); +operator!(Mul, "*"); +operator!(Modulo, "%"); +operator!(Equal, "=="); +operator!(NotEqual, "!="); +operator!(Lower, "<"); +operator!(LowerEqual, "<="); +operator!(Greater, ">"); +operator!(GreaterEqual, ">="); +operator!(ShiftLeft, "<<"); +operator!(ShiftRight, ">>"); +operator!(BitwiseAnd, "&"); +operator!(BitwiseXor, "^"); +operator!(Or, "||"); +operator!(And, "&&"); + +function!(Powf, "powf"); +function!(Max, "max"); +function!(Min, "min"); + +pub struct IndexAssign; +pub struct Index; + +impl Binary for IndexAssign { + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + lhs: Lhs, + rhs: Rhs, + out: Out, + elem: Elem, + ) -> std::fmt::Result + where + Lhs: Component, + Rhs: Component, + Out: Component, + { + let elem_rhs = rhs.elem(); + // Cast only when necessary. + if elem != elem_rhs { + if let Elem::Bool = elem_rhs { + match rhs.item() { + Item::Vec4(_) => { + f.write_fmt(format_args!("{out}[{lhs}] = make_uint4({elem}({rhs}.x), {elem}({rhs}.y), {elem}({rhs}.z), {elem}({rhs}.w));\n")) + }, + Item::Vec3(_) => todo!(), + Item::Vec2(_) => todo!(), + Item::Scalar(_) => todo!(), + } + } else { + f.write_fmt(format_args!("{out}[{lhs}] = {elem}({rhs});\n")) + } + } else { + f.write_fmt(format_args!("{out}[{lhs}] = {rhs};\n")) + } + } + + fn unroll_vec2( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + let lhs0 = lhs.index(0); + let lhs1 = lhs.index(1); + + let rhs0 = rhs.index(0); + let rhs1 = rhs.index(1); + + Self::format_scalar(f, lhs0, rhs0, *out, elem)?; + Self::format_scalar(f, lhs1, rhs1, *out, elem)?; + + Ok(()) + } + + fn unroll_vec3( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + let lhs0 = lhs.index(0); + let lhs1 = lhs.index(1); + let lhs2 = lhs.index(2); + + let rhs0 = rhs.index(0); + let rhs1 = rhs.index(1); + let rhs2 = rhs.index(2); + + Self::format_scalar(f, lhs0, rhs0, *out, elem)?; + Self::format_scalar(f, lhs1, rhs1, *out, elem)?; + Self::format_scalar(f, lhs2, rhs2, *out, elem)?; + + Ok(()) + } + + fn unroll_vec4( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + let lhs0 = lhs.index(0); + let lhs1 = lhs.index(1); + let lhs2 = lhs.index(2); + let lhs3 = lhs.index(3); + + let rhs0 = rhs.index(0); + let rhs1 = rhs.index(1); + let rhs2 = rhs.index(2); + let rhs3 = rhs.index(3); + + Self::format_scalar(f, lhs0, rhs0, *out, elem)?; + Self::format_scalar(f, lhs1, rhs1, *out, elem)?; + Self::format_scalar(f, lhs2, rhs2, *out, elem)?; + Self::format_scalar(f, lhs3, rhs3, *out, elem)?; + + Ok(()) + } + + fn format( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + ) -> std::fmt::Result { + if let Variable::Local { + index: _, + item: _, + scope_depth: _, + } = out + { + return IndexAssignVector::format(f, lhs, rhs, out); + }; + + let elem = out.elem(); + + match lhs.item() { + Item::Vec4(_) => Self::unroll_vec4(f, lhs, rhs, out, elem), + Item::Vec3(_) => Self::unroll_vec3(f, lhs, rhs, out, elem), + Item::Vec2(_) => Self::unroll_vec2(f, lhs, rhs, out, elem), + Item::Scalar(_) => Self::format_scalar(f, *lhs, *rhs, *out, elem), + } + } +} + +impl Binary for Index { + fn format( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + ) -> std::fmt::Result { + if let Variable::Local { + index: _, + item: _, + scope_depth: _, + } = lhs + { + return IndexVector::format(f, lhs, rhs, out); + } + + Self::format_scalar(f, *lhs, *rhs, *out, out.elem()) + } + + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + lhs: Lhs, + rhs: Rhs, + out: Out, + _elem: Elem, + ) -> std::fmt::Result + where + Lhs: Component, + Rhs: Component, + Out: Component, + { + f.write_fmt(format_args!("{out} = {lhs}[{rhs}];\n")) + } +} + +/// The goal is to support indexing of vectorized types. +/// +/// # Examples +/// +/// ```c +/// float4 rhs; +/// float item = var[0]; // We want that. +/// float item = var.x; // So we compile to that. +/// ``` +struct IndexVector; + +/// The goal is to support indexing of vectorized types. +/// +/// # Examples +/// +/// ```c +/// float4 var; +/// +/// var[0] = 1.0; // We want that. +/// var.x = 1.0; // So we compile to that. +/// ``` +struct IndexAssignVector; + +impl IndexVector { + fn format( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + ) -> std::fmt::Result { + let index = match rhs { + Variable::ConstantScalar(value, _elem) => *value as usize, + _ => { + let elem = out.elem(); + return f.write_fmt(format_args!("{out} = *(({elem}*)&{lhs} + {rhs});\n")); + } + }; + + let out = out.index(index); + let lhs = lhs.index(index); + + f.write_fmt(format_args!("{out} = {lhs};\n")) + } +} + +impl IndexAssignVector { + fn format( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + ) -> std::fmt::Result { + let index = match lhs { + Variable::ConstantScalar(value, _) => *value as usize, + _ => { + let elem = out.elem(); + return f.write_fmt(format_args!("*(({elem}*)&{out} + {lhs}) = {rhs};\n")); + } + }; + + let out = out.index(index); + let rhs = rhs.index(index); + + f.write_fmt(format_args!("{out} = {rhs};\n")) + } +} diff --git a/crates/burn-cuda/src/compiler/body.rs b/crates/burn-cuda/src/compiler/body.rs new file mode 100644 index 000000000..09d33049f --- /dev/null +++ b/crates/burn-cuda/src/compiler/body.rs @@ -0,0 +1,81 @@ +use super::Instruction; +use std::fmt::Display; + +/// A body is composed of a list of [instructions](Instruction). +#[derive(Debug, Clone)] +pub struct Body { + pub instructions: Vec, + pub shared_memories: Vec, + pub local_arrays: Vec, + pub stride: bool, + pub shape: bool, + pub id: bool, + pub rank: bool, + pub invocation_index: bool, + pub global_invocation_id: (bool, bool, bool), +} + +impl Display for Body { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.id + || self.global_invocation_id.0 + || self.global_invocation_id.1 + || self.global_invocation_id.2 + { + f.write_str( + " + int3 globalInvocationId = make_int3( + blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.z * blockDim.z + threadIdx.z + ); +", + )?; + } + + if self.id { + f.write_str( + " + uint id = globalInvocationId.y * (blockDim.x * gridDim.x) + globalInvocationId.x; +", + )?; + } + + if self.invocation_index { + f.write_str( + " + int invocationIndex = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.x * blockDim.y); + ", + )?; + } + + if self.rank || self.stride || self.shape { + f.write_str("uint rank = info[0];\n")?; + } + + if self.stride || self.shape { + f.write_str("uint rank_2 = rank * 2;\n")?; + } + + for shared in self.shared_memories.iter() { + f.write_fmt(format_args!( + "__shared__ {} shared_memory_{}[{}];\n", + shared.item, shared.index, shared.size + ))?; + } + + // Local arrays + for array in self.local_arrays.iter() { + f.write_fmt(format_args!( + "{} l_arr_{}_{}[{}];\n\n", + array.item, array.index, array.depth, array.size + ))?; + } + + for ops in self.instructions.iter() { + f.write_fmt(format_args!("{ops}"))?; + } + + Ok(()) + } +} diff --git a/crates/burn-cuda/src/compiler/element.rs b/crates/burn-cuda/src/compiler/element.rs new file mode 100644 index 000000000..1d324d17a --- /dev/null +++ b/crates/burn-cuda/src/compiler/element.rs @@ -0,0 +1,309 @@ +use burn_jit::gpu; +use half::{bf16, f16}; +use std::fmt::Display; + +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum Elem { + F32, + F16, + BF16, + I32, + U32, + Bool, +} + +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum Item { + Vec4(Elem), + Vec3(Elem), + Vec2(Elem), + Scalar(Elem), +} + +impl Display for Elem { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Elem::F16 => f.write_str("f16"), + Elem::F32 => f.write_str("float"), + Elem::BF16 => f.write_str("bf16"), + Elem::I32 => f.write_str("int"), + Elem::U32 => f.write_str("uint"), + Elem::Bool => f.write_str("bool"), + } + } +} + +impl Display for Item { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Item::Vec4(elem) => match elem { + Elem::F32 => f.write_str("float4"), + Elem::I32 => f.write_str("int4"), + Elem::U32 => f.write_str("uint4"), + Elem::Bool => f.write_str("bool4"), + Elem::BF16 => f.write_str("bf164"), + Elem::F16 => f.write_str("f164"), + }, + Item::Vec3(elem) => match elem { + Elem::F32 => f.write_str("float3"), + Elem::I32 => f.write_str("int3"), + Elem::U32 => f.write_str("uint3"), + Elem::Bool => f.write_str("bool3"), + Elem::BF16 => f.write_str("bf163"), + Elem::F16 => f.write_str("f163"), + }, + Item::Vec2(elem) => match elem { + Elem::F32 => f.write_str("float2"), + Elem::I32 => f.write_str("int2"), + Elem::U32 => f.write_str("uint2"), + Elem::Bool => f.write_str("bool2"), + Elem::BF16 => f.write_str("bf162"), + Elem::F16 => f.write_str("f162"), + }, + Item::Scalar(elem) => f.write_fmt(format_args!("{elem}")), + } + } +} + +pub trait Component: Display { + fn item(&self) -> Item; + fn elem(&self) -> Elem { + *self.item().elem() + } +} + +impl Component for IndexedVariable { + fn item(&self) -> Item { + self.var.item() + } +} +impl Component for Variable { + fn item(&self) -> Item { + match self { + Variable::GlobalInputArray(_, e) => *e, + Variable::GlobalOutputArray(_, e) => *e, + Variable::SharedMemory(_, e, _) => *e, + Variable::Local { + index: _, + item, + scope_depth: _, + } => *item, + Variable::ConstantScalar(_, e) => Item::Scalar(*e), + Variable::GlobalScalar(_, e, _) => Item::Scalar(*e), + Variable::Id => Item::Scalar(Elem::U32), + Variable::LocalInvocationIndex => Item::Scalar(Elem::U32), + Variable::LocalInvocationIdX => Item::Scalar(Elem::U32), + Variable::LocalInvocationIdY => Item::Scalar(Elem::U32), + Variable::LocalInvocationIdZ => Item::Scalar(Elem::U32), + Variable::Rank => Item::Scalar(Elem::U32), + Variable::LocalScalar { + index: _, + elem, + scope_depth: _, + } => Item::Scalar(*elem), + Variable::WorkgroupIdX => Item::Scalar(Elem::U32), + Variable::WorkgroupIdY => Item::Scalar(Elem::U32), + Variable::WorkgroupIdZ => Item::Scalar(Elem::U32), + Variable::GlobalInvocationIdX => Item::Scalar(Elem::U32), + Variable::GlobalInvocationIdY => Item::Scalar(Elem::U32), + Variable::GlobalInvocationIdZ => Item::Scalar(Elem::U32), + Variable::WorkgroupSizeX => Item::Scalar(Elem::U32), + Variable::WorkgroupSizeY => Item::Scalar(Elem::U32), + Variable::WorkgroupSizeZ => Item::Scalar(Elem::U32), + Variable::NumWorkgroupsX => Item::Scalar(Elem::U32), + Variable::NumWorkgroupsY => Item::Scalar(Elem::U32), + Variable::NumWorkgroupsZ => Item::Scalar(Elem::U32), + Variable::LocalArray(_, e, _, _) => *e, + } + } +} + +#[derive(Debug, Clone, Copy)] +pub enum Variable { + GlobalInputArray(u16, Item), + GlobalOutputArray(u16, Item), + GlobalScalar(u16, Elem, gpu::Elem), + ConstantScalar(f64, Elem), + Local { + index: u16, + item: Item, + scope_depth: u8, + }, + LocalScalar { + index: u16, + elem: Elem, + scope_depth: u8, + }, + SharedMemory(u16, Item, u32), + LocalArray(u16, Item, u8, u32), + Id, + LocalInvocationIndex, + LocalInvocationIdX, + LocalInvocationIdY, + LocalInvocationIdZ, + Rank, + WorkgroupIdX, + WorkgroupIdY, + WorkgroupIdZ, + GlobalInvocationIdX, + GlobalInvocationIdY, + GlobalInvocationIdZ, + WorkgroupSizeX, + WorkgroupSizeY, + WorkgroupSizeZ, + NumWorkgroupsX, + NumWorkgroupsY, + NumWorkgroupsZ, +} + +impl Display for Variable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")), + Variable::LocalScalar { + index, + elem: _, + scope_depth, + } => f.write_fmt(format_args!("s_{scope_depth}_{index}")), + Variable::Local { + index, + item: _, + scope_depth, + } => f.write_fmt(format_args!("l_{scope_depth}_{index}")), + Variable::GlobalOutputArray(number, _) => f.write_fmt(format_args!("output_{number}")), + Variable::GlobalScalar(number, _, elem) => { + f.write_fmt(format_args!("scalars_{elem}[{number}]")) + } + Variable::ConstantScalar(number, elem) => f.write_fmt(format_args!("{elem}({number})")), + Variable::SharedMemory(number, _, _) => { + f.write_fmt(format_args!("shared_memory_{number}")) + } + Variable::Id => f.write_str("id"), + Variable::LocalInvocationIndex => f.write_str("invocationIndex"), + Variable::LocalInvocationIdX => f.write_str("threadIdx.x"), + Variable::LocalInvocationIdY => f.write_str("threadIdx.y"), + Variable::LocalInvocationIdZ => f.write_str("threadIdx.z"), + Variable::Rank => f.write_str("rank"), + Variable::WorkgroupIdX => f.write_str("blockIdx.x"), + Variable::WorkgroupIdY => f.write_str("blockIdx.y"), + Variable::WorkgroupIdZ => f.write_str("blockIdx.z"), + Variable::WorkgroupSizeX => f.write_str("blockDim.x"), + Variable::WorkgroupSizeY => f.write_str("blockDim.y"), + Variable::WorkgroupSizeZ => f.write_str("blockDim.z"), + Variable::NumWorkgroupsX => f.write_str("gridDim.x"), + Variable::NumWorkgroupsY => f.write_str("gridDim.y"), + Variable::NumWorkgroupsZ => f.write_str("gridDim.z"), + Variable::GlobalInvocationIdX => f.write_str("globalInvocationId.x"), + Variable::GlobalInvocationIdY => f.write_str("globalInvocationId.y"), + Variable::GlobalInvocationIdZ => f.write_str("globalInvocationId.z"), + Variable::LocalArray(id, _item, depth, _size) => { + f.write_fmt(format_args!("l_arr_{}_{}", id, depth)) + } + } + } +} + +impl Variable { + pub fn is_always_scalar(&self) -> bool { + match self { + Variable::GlobalScalar(_, _, _) => true, + Variable::ConstantScalar(_, _) => true, + Variable::LocalScalar { + index: _, + elem: _, + scope_depth: _, + } => true, + Variable::Id => true, + Variable::LocalInvocationIndex => true, + Variable::LocalInvocationIdX => true, + Variable::LocalInvocationIdY => true, + Variable::LocalInvocationIdZ => true, + Variable::Rank => true, + Variable::GlobalInputArray(_, _) => false, + Variable::GlobalOutputArray(_, _) => false, + Variable::SharedMemory(_, _, _) => false, + Variable::Local { + index: _, + item: _, + scope_depth: _, + } => false, + Variable::WorkgroupIdX => true, + Variable::WorkgroupIdY => true, + Variable::WorkgroupIdZ => true, + Variable::GlobalInvocationIdX => true, + Variable::GlobalInvocationIdY => true, + Variable::GlobalInvocationIdZ => true, + Variable::WorkgroupSizeX => true, + Variable::WorkgroupSizeY => true, + Variable::WorkgroupSizeZ => true, + Variable::NumWorkgroupsX => true, + Variable::NumWorkgroupsY => true, + Variable::NumWorkgroupsZ => true, + Variable::LocalArray(_, _, _, _) => false, + } + } + + pub fn index(&self, index: usize) -> IndexedVariable { + IndexedVariable { var: *self, index } + } +} + +#[derive(Debug, Clone)] +pub struct IndexedVariable { + var: Variable, + index: usize, +} + +impl Display for IndexedVariable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let var = &self.var; + let item = self.var.item(); + + match item { + Item::Vec4(_) => match self.index { + 0 => f.write_fmt(format_args!("{var}.x"))?, + 1 => f.write_fmt(format_args!("{var}.y"))?, + 2 => f.write_fmt(format_args!("{var}.z"))?, + 3 => f.write_fmt(format_args!("{var}.w"))?, + _ => unreachable!(), + }, + Item::Vec3(_) => match self.index { + 0 => f.write_fmt(format_args!("{var}.x"))?, + 1 => f.write_fmt(format_args!("{var}.y"))?, + 2 => f.write_fmt(format_args!("{var}.z"))?, + _ => unreachable!(), + }, + Item::Vec2(_) => match self.index { + 0 => f.write_fmt(format_args!("{var}.x"))?, + 1 => f.write_fmt(format_args!("{var}.y"))?, + _ => unreachable!(), + }, + Item::Scalar(_) => f.write_fmt(format_args!("{var}"))?, + } + + Ok(()) + } +} +impl Item { + pub fn elem(&self) -> &Elem { + match self { + Item::Vec4(e) => e, + Item::Vec3(e) => e, + Item::Vec2(e) => e, + Item::Scalar(e) => e, + } + } +} + +impl Elem { + pub fn size(&self) -> usize { + match self { + Self::F32 => core::mem::size_of::(), + Self::F16 => core::mem::size_of::(), + Self::BF16 => core::mem::size_of::(), + Self::I32 => core::mem::size_of::(), + Self::U32 => core::mem::size_of::(), + Self::Bool => core::mem::size_of::(), + } + } +} diff --git a/crates/burn-cuda/src/compiler/instruction.rs b/crates/burn-cuda/src/compiler/instruction.rs new file mode 100644 index 000000000..3538f3dcf --- /dev/null +++ b/crates/burn-cuda/src/compiler/instruction.rs @@ -0,0 +1,233 @@ +use super::{binary::*, unary::*, Component, Variable}; +use std::fmt::Display; + +#[derive(Debug, Clone)] +pub struct BinaryInstruction { + pub lhs: Variable, + pub rhs: Variable, + pub out: Variable, +} + +#[derive(Debug, Clone)] +pub struct UnaryInstruction { + pub input: Variable, + pub out: Variable, +} + +#[derive(Debug, Clone)] +pub enum Instruction { + ArrayLength { + input: Variable, + out: Variable, + num_inputs: usize, + num_outputs: usize, + }, + DeclareVariable { + var: Variable, + }, + Modulo(BinaryInstruction), + Add(BinaryInstruction), + Div(BinaryInstruction), + Mul(BinaryInstruction), + Sub(BinaryInstruction), + Index(BinaryInstruction), + IndexAssign(BinaryInstruction), + CheckedIndexAssign(BinaryInstruction), + Assign(UnaryInstruction), + RangeLoop { + i: Variable, + start: Variable, + end: Variable, + instructions: Vec, + }, + Loop { + instructions: Vec, + }, + If { + cond: Variable, + instructions: Vec, + }, + IfElse { + cond: Variable, + instructions_if: Vec, + instructions_else: Vec, + }, + Return, + Break, + Stride { + dim: Variable, + position: usize, + out: Variable, + }, + Shape { + dim: Variable, + position: usize, + out: Variable, + }, + Equal(BinaryInstruction), + NotEqual(BinaryInstruction), + Lower(BinaryInstruction), + Greater(BinaryInstruction), + LowerEqual(BinaryInstruction), + GreaterEqual(BinaryInstruction), + Erf(UnaryInstruction), + BitwiseAnd(BinaryInstruction), + BitwiseXor(BinaryInstruction), + ShiftLeft(BinaryInstruction), + ShiftRight(BinaryInstruction), + Abs(UnaryInstruction), + Exp(UnaryInstruction), + Log(UnaryInstruction), + Log1p(UnaryInstruction), + Cos(UnaryInstruction), + Sin(UnaryInstruction), + Tanh(UnaryInstruction), + Powf(BinaryInstruction), + Sqrt(UnaryInstruction), + Min(BinaryInstruction), + Max(BinaryInstruction), + Not(UnaryInstruction), + Or(BinaryInstruction), + And(BinaryInstruction), + Clamp { + input: Variable, + min_value: Variable, + max_value: Variable, + out: Variable, + }, + SyncThreads, + Ceil(UnaryInstruction), + Floor(UnaryInstruction), +} + +impl Display for Instruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Instruction::Return => f.write_str("return;"), + Instruction::Break => f.write_str("break;"), + Instruction::DeclareVariable { var } => { + let item = var.item(); + f.write_fmt(format_args!("{item} {var};\n")) + } + Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out), + Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Index(it) => Index::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::IndexAssign(it) => IndexAssign::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::CheckedIndexAssign(it) => { + IndexAssign::format(f, &it.lhs, &it.rhs, &it.out) + } + Instruction::Assign(it) => Assign::format(f, &it.input, &it.out), + Instruction::RangeLoop { + i, + start, + end, + instructions, + } => { + f.write_fmt(format_args!( + " +for (uint {i} = {start}; {i} < {end}; {i}++) {{ +" + ))?; + for instruction in instructions { + f.write_fmt(format_args!("{instruction}"))?; + } + + f.write_str("}\n") + } + + Instruction::Loop { instructions } => { + f.write_fmt(format_args!("while (true) {{\n"))?; + for i in instructions { + f.write_fmt(format_args!("{i}"))?; + } + f.write_str("}\n") + } + Instruction::If { cond, instructions } => { + f.write_fmt(format_args!("if ({cond}) {{\n"))?; + for i in instructions { + f.write_fmt(format_args!("{i}"))?; + } + f.write_str("}\n") + } + Instruction::IfElse { + cond, + instructions_if, + instructions_else, + } => { + f.write_fmt(format_args!("if ({cond}) {{\n"))?; + for i in instructions_if { + f.write_fmt(format_args!("{i}"))?; + } + f.write_str("} else {\n")?; + for i in instructions_else { + f.write_fmt(format_args!("{i}"))?; + } + f.write_str("}\n") + } + Instruction::Stride { dim, position, out } => f.write_fmt(format_args!( + "{out} = info[({position} * rank_2) + {dim} + 1];\n" + )), + Instruction::Shape { dim, position, out } => f.write_fmt(format_args!( + "{out} = info[({position} * rank_2) + rank + {dim} + 1];\n" + )), + Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Erf(it) => Erf::format(f, &it.input, &it.out), + Instruction::Abs(it) => Abs::format(f, &it.input, &it.out), + Instruction::Exp(it) => Exp::format(f, &it.input, &it.out), + Instruction::Log(it) => Log::format(f, &it.input, &it.out), + Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out), + Instruction::Cos(it) => Cos::format(f, &it.input, &it.out), + Instruction::Sin(it) => Sin::format(f, &it.input, &it.out), + Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out), + Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out), + Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Not(it) => Not::format(f, &it.input, &it.out), + Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Clamp { + input, + min_value, + max_value, + out, + } => f.write_fmt(format_args!( + " +{out} = min({input}, {max_value}); +{out} = max({out}, {min_value}); + " + )), + Instruction::SyncThreads => f.write_str("__syncthreads();\n"), + Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out), + Instruction::Floor(it) => Floor::format(f, &it.input, &it.out), + Instruction::ArrayLength { + input, + out, + num_inputs, + num_outputs, + } => { + let offset = num_inputs + num_outputs; + let index = match input { + Variable::GlobalInputArray(index, _) => *index as usize, + Variable::GlobalOutputArray(index, _) => *index as usize + num_inputs, + _ => panic!("Can only know the len of a global array."), + } + 1; + f.write_fmt(format_args!( + "{out} = info[({offset} * 2 * info[0]) + {index}];\n" + )) + } + } + } +} diff --git a/crates/burn-cuda/src/compiler/mod.rs b/crates/burn-cuda/src/compiler/mod.rs new file mode 100644 index 000000000..5f83e66c3 --- /dev/null +++ b/crates/burn-cuda/src/compiler/mod.rs @@ -0,0 +1,16 @@ +pub mod binary; +pub mod unary; + +mod base; +mod body; +mod element; +mod instruction; +mod settings; +mod shader; + +pub use base::*; +pub use body::*; +pub use element::*; +pub use instruction::*; +pub use settings::*; +pub use shader::*; diff --git a/crates/burn-cuda/src/compiler/settings.rs b/crates/burn-cuda/src/compiler/settings.rs new file mode 100644 index 000000000..09e35427f --- /dev/null +++ b/crates/burn-cuda/src/compiler/settings.rs @@ -0,0 +1,6 @@ +#[derive(Debug, Default)] +pub struct InstructionSettings { + pub native_vec4: bool, + pub native_vec3: bool, + pub native_vec2: bool, +} diff --git a/crates/burn-cuda/src/compiler/shader.rs b/crates/burn-cuda/src/compiler/shader.rs new file mode 100644 index 000000000..3d9bd6d83 --- /dev/null +++ b/crates/burn-cuda/src/compiler/shader.rs @@ -0,0 +1,155 @@ +// use super::{Body, Extension, Item}; +use super::{Body, Item}; +use burn_jit::{gpu::WorkgroupSize, CompilerRepresentation}; +use std::fmt::Display; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum Location { + Storage, + #[allow(dead_code)] + Workgroup, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum Visibility { + Read, + ReadWrite, +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Binding { + pub item: Item, + pub size: Option, +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct SharedMemory { + pub index: u16, + pub item: Item, + pub size: u32, +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct LocalArray { + pub index: u16, + pub item: Item, + pub depth: u8, + pub size: u32, +} + +impl LocalArray { + pub fn new(index: u16, item: Item, depth: u8, size: u32) -> Self { + Self { + index, + item, + depth, + size, + } + } +} + +impl SharedMemory { + pub fn new(index: u16, item: Item, size: u32) -> Self { + Self { index, item, size } + } +} + +#[derive(Debug, Clone)] +pub struct ComputeShader { + pub inputs: Vec, + pub outputs: Vec, + pub named: Vec<(String, Binding)>, + pub workgroup_size: WorkgroupSize, + pub body: Body, +} + +impl CompilerRepresentation for ComputeShader { + fn shared_memory_size(&self) -> usize { + let mut current = 0usize; + + for var in self.body.shared_memories.iter() { + let factor = match var.item { + Item::Vec4(_) => 4, + Item::Vec3(_) => 3, + Item::Vec2(_) => 2, + Item::Scalar(_) => 1, + }; + + let elem_size_bytes = var.item.elem().size(); + current += (var.size as usize) * factor * elem_size_bytes; + } + + current + } +} + +impl Display for ComputeShader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!( + " +typedef unsigned int uint; + +extern \"C\" struct bool4 {{ + bool x; + bool y; + bool z; + bool w; +}}; + +extern \"C\" __global__ void kernel( +", + ))?; + + let num_bindings = self.inputs.len() + self.outputs.len() + self.named.len(); + let mut binding_index = 0; + for (index, binding) in self.inputs.iter().enumerate() { + binding_index += 1; + f.write_fmt(format_args!("{} input_{}[]", binding.item, index))?; + if binding_index < num_bindings { + f.write_str(",")?; + } + } + for (index, binding) in self.outputs.iter().enumerate() { + binding_index += 1; + f.write_fmt(format_args!("{} output_{}[]", binding.item, index))?; + if binding_index < num_bindings { + f.write_str(",")?; + } + } + for (name, binding) in self.named.iter() { + binding_index += 1; + f.write_fmt(format_args!("{} {}[]", binding.item, name))?; + + if binding_index < num_bindings { + f.write_str(",")?; + } + } + + f.write_str("\n) {\n")?; + + f.write_fmt(format_args!("{}", self.body))?; + f.write_str("\n}")?; + + Ok(()) + } +} + +impl ComputeShader {} + +impl Display for Location { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Location::Storage => f.write_str("storage"), + Location::Workgroup => f.write_str("workgroup"), + } + } +} + +impl Display for Visibility { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Visibility::Read => f.write_str("read"), + Visibility::ReadWrite => f.write_str("read_write"), + } + } +} diff --git a/crates/burn-cuda/src/compiler/unary.rs b/crates/burn-cuda/src/compiler/unary.rs new file mode 100644 index 000000000..1731b1b6b --- /dev/null +++ b/crates/burn-cuda/src/compiler/unary.rs @@ -0,0 +1,210 @@ +use super::{Component, Elem, InstructionSettings, Item, Variable}; +use std::fmt::Display; + +pub trait Unary { + fn format( + f: &mut std::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + ) -> std::fmt::Result { + let item = out.item(); + let settings = Self::settings(*item.elem()); + + match item { + Item::Vec4(elem) => { + if settings.native_vec4 { + Self::format_native_vec4(f, input, out, elem) + } else { + Self::unroll_vec4(f, input, out, elem) + } + } + Item::Vec3(elem) => { + if settings.native_vec3 { + Self::format_native_vec3(f, input, out, elem) + } else { + Self::unroll_vec3(f, input, out, elem) + } + } + Item::Vec2(elem) => { + if settings.native_vec2 { + Self::format_native_vec2(f, input, out, elem) + } else { + Self::unroll_vec2(f, input, out, elem) + } + } + Item::Scalar(elem) => Self::format_scalar(f, *input, *out, elem), + } + } + + fn settings(_elem: Elem) -> InstructionSettings { + InstructionSettings::default() + } + + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + input: Input, + out: Out, + elem: Elem, + ) -> std::fmt::Result + where + Input: Component, + Out: Component; + + fn format_native_vec4( + f: &mut std::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + Self::format_scalar(f, *input, *out, elem) + } + + fn format_native_vec3( + f: &mut std::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + Self::format_scalar(f, *input, *out, elem) + } + + fn format_native_vec2( + f: &mut std::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + Self::format_scalar(f, *input, *out, elem) + } + + fn unroll_vec2( + f: &mut std::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + let input0 = input.index(0); + let input1 = input.index(1); + + let out0 = out.index(0); + let out1 = out.index(1); + + Self::format_scalar(f, input0, out0, elem)?; + Self::format_scalar(f, input1, out1, elem)?; + + Ok(()) + } + + fn unroll_vec3( + f: &mut std::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + let input0 = input.index(0); + let input1 = input.index(1); + let input2 = input.index(2); + + let out0 = out.index(0); + let out1 = out.index(1); + let out2 = out.index(2); + + Self::format_scalar(f, input0, out0, elem)?; + Self::format_scalar(f, input1, out1, elem)?; + Self::format_scalar(f, input2, out2, elem)?; + + Ok(()) + } + + fn unroll_vec4( + f: &mut std::fmt::Formatter<'_>, + input: &Variable, + out: &Variable, + elem: Elem, + ) -> std::fmt::Result { + let input0 = input.index(0); + let input1 = input.index(1); + let input2 = input.index(2); + let input3 = input.index(3); + + let out0 = out.index(0); + let out1 = out.index(1); + let out2 = out.index(2); + let out3 = out.index(3); + + Self::format_scalar(f, input0, out0, elem)?; + Self::format_scalar(f, input1, out1, elem)?; + Self::format_scalar(f, input2, out2, elem)?; + Self::format_scalar(f, input3, out3, elem)?; + + Ok(()) + } +} + +macro_rules! function { + ($name:ident, $func:expr) => { + pub struct $name; + + impl Unary for $name { + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + input: Input, + out: Out, + _elem: Elem, + ) -> std::fmt::Result { + f.write_fmt(format_args!("{out} = {}({input});\n", $func)) + } + } + }; +} + +function!(Abs, "abs"); +function!(Log, "log"); +function!(Log1p, "log1p"); +function!(Cos, "cos"); +function!(Sin, "sin"); +function!(Tanh, "tanh"); +function!(Sqrt, "sqrt"); +function!(Exp, "exp"); +function!(Erf, "erff"); +function!(Ceil, "ceil"); +function!(Floor, "floor"); + +pub struct Not; + +impl Unary for Not { + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + input: Input, + out: Out, + _elem: Elem, + ) -> std::fmt::Result + where + Input: Component, + Out: Component, + { + f.write_fmt(format_args!("{out} = !{input};\n")) + } +} + +pub struct Assign; + +impl Unary for Assign { + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + input: Input, + out: Out, + elem: Elem, + ) -> std::fmt::Result + where + Input: Component, + Out: Component, + { + // Cast only when necessary. + if elem != input.elem() { + f.write_fmt(format_args!("{out} = {elem}({input});\n")) + } else { + f.write_fmt(format_args!("{out} = {input};\n")) + } + } +} diff --git a/crates/burn-cuda/src/compute/mod.rs b/crates/burn-cuda/src/compute/mod.rs new file mode 100644 index 000000000..4139c3868 --- /dev/null +++ b/crates/burn-cuda/src/compute/mod.rs @@ -0,0 +1,5 @@ +mod server; +mod storage; + +pub use server::*; +pub use storage::*; diff --git a/crates/burn-cuda/src/compute/server.rs b/crates/burn-cuda/src/compute/server.rs new file mode 100644 index 000000000..1be228f7b --- /dev/null +++ b/crates/burn-cuda/src/compute/server.rs @@ -0,0 +1,226 @@ +use super::storage::Binding; +use super::storage::CudaStorage; +use burn_compute::{ + memory_management::MemoryManagement, + server::{self, ComputeServer}, +}; +use burn_jit::compute::{JitAutotuneKey, Kernel, WorkGroup}; +use burn_jit::gpu::WorkgroupSize; +use cudarc::driver::sys::CUctx_st; +use cudarc::driver::sys::CUfunc_st; +use std::collections::HashMap; +use std::ffi::CStr; +use std::ffi::CString; + +#[derive(Debug)] +pub struct CudaServer> { + state: CudaServerState, +} + +pub(crate) enum CudaServerState> { + Uninitialized { + device_index: usize, + init: Box CudaContext>, + }, + Initialized { + ctx: CudaContext, + }, +} + +impl> core::fmt::Debug for CudaServerState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("Context") + } +} + +#[derive(Debug)] +pub(crate) struct CudaContext> { + context: *mut CUctx_st, + stream: cudarc::driver::sys::CUstream, + memory_management: MM, + module_names: HashMap, +} + +#[derive(Debug)] +struct CompiledKernel { + workgroup_size: WorkgroupSize, + shared_mem_bytes: usize, + func: *mut CUfunc_st, +} + +unsafe impl> Send for CudaServer {} + +impl> ComputeServer for CudaServer { + type Kernel = Kernel; + type Storage = CudaStorage; + type MemoryManagement = MM; + type AutotuneKey = JitAutotuneKey; + + fn read(&mut self, binding: server::Binding) -> burn_tensor::Reader> { + let ctx = self.get_context(); + let resource = ctx.memory_management.get(binding.memory); + // TODO: Check if it is possible to make this faster + let mut data = vec![0; resource.size() as usize]; + unsafe { + cudarc::driver::result::memcpy_dtoh_async(&mut data, resource.ptr, ctx.stream).unwrap(); + }; + + ctx.sync(); + + burn_tensor::Reader::Concrete(data) + } + + fn create(&mut self, data: &[u8]) -> server::Handle { + let ctx = self.get_context(); + let handle = ctx.memory_management.reserve(data.len()); + let handle = server::Handle::new(handle); + let binding = handle.clone().binding().memory; + let resource = ctx.memory_management.get(binding); + + unsafe { + cudarc::driver::result::memcpy_htod_async(resource.ptr, data, ctx.stream).unwrap(); + } + + handle + } + + fn empty(&mut self, size: usize) -> server::Handle { + let ctx = self.get_context(); + let handle = ctx.memory_management.reserve(size); + server::Handle::new(handle) + } + + fn execute(&mut self, kernel: Self::Kernel, bindings: Vec>) { + let ctx = self.get_context(); + let kernel_id = kernel.id(); + let settings = kernel.launch_settings(); + + if !ctx.module_names.contains_key(&kernel_id) { + ctx.compile_kernel(&kernel_id, kernel); + } + + let bindings = bindings + .into_iter() + .map(|binding| ctx.memory_management.get(binding.memory).as_binding()) + .collect(); + + ctx.execute_task(kernel_id, settings.workgroup, bindings); + // TODO: fix this + // self.memory_management.storage().perform_deallocations(); + } + + fn sync(&mut self) { + let ctx = self.get_context(); + ctx.sync(); + } +} + +impl> CudaContext { + pub fn new( + memory_management: MM, + stream: cudarc::driver::sys::CUstream, + context: *mut CUctx_st, + ) -> Self { + Self { + context, + memory_management, + module_names: HashMap::new(), + stream, + } + } + + fn sync(&mut self) { + unsafe { + cudarc::driver::result::stream::synchronize(self.stream).unwrap(); + }; + } + + fn compile_kernel(&mut self, kernel_id: &str, kernel: Kernel) { + let kernel_compiled = kernel.compile(); + let shared_mem_bytes = kernel_compiled.shared_mem_bytes; + let workgroup_size = kernel_compiled.workgroup_size; + + let ptx = unsafe { + let program = cudarc::nvrtc::result::create_program(kernel_compiled.source).unwrap(); + if cudarc::nvrtc::result::compile_program::>(program, &[]).is_err() { + let log_raw = cudarc::nvrtc::result::get_program_log(program).unwrap(); + let log_ptr = log_raw.as_ptr(); + let log = CStr::from_ptr(log_ptr).to_str().unwrap(); + let mut message = "[Compilation Error] ".to_string(); + for line in log.split('\n') { + if !line.is_empty() { + message += format!("\n {line}").as_str(); + } + } + let source = kernel.compile().source; + panic!("{message}\n[Source] \n{source}"); + }; + cudarc::nvrtc::result::get_ptx(program).unwrap() + }; + + let func_name = CString::new("kernel".to_string()).unwrap(); + let func = unsafe { + let module = + cudarc::driver::result::module::load_data(ptx.as_ptr() as *const _).unwrap(); + cudarc::driver::result::module::get_function(module, func_name).unwrap() + }; + + self.module_names.insert( + kernel_id.to_string(), + CompiledKernel { + workgroup_size, + shared_mem_bytes, + func, + }, + ); + } + + fn execute_task( + &mut self, + kernel_id: String, + workgroup: WorkGroup, + mut bindings: Vec, + ) { + let kernel = self.module_names.get(&kernel_id).unwrap(); + let workgroup_size = kernel.workgroup_size; + + unsafe { + cudarc::driver::result::launch_kernel( + kernel.func, + (workgroup.x, workgroup.y, workgroup.z), + (workgroup_size.x, workgroup_size.y, workgroup_size.z), + kernel.shared_mem_bytes as u32, + self.stream, + &mut bindings, + ) + .unwrap(); + }; + } +} + +impl> CudaServer { + /// Create a new cuda server. + pub(crate) fn new(index: usize, init: Box CudaContext>) -> Self { + Self { + state: CudaServerState::Uninitialized { + device_index: index, + init, + }, + } + } + + fn get_context(&mut self) -> &mut CudaContext { + if let CudaServerState::Uninitialized { device_index, init } = &self.state { + let ctx = init(*device_index); + self.state = CudaServerState::Initialized { ctx }; + } + if let CudaServerState::Initialized { ctx } = &mut self.state { + unsafe { + cudarc::driver::result::ctx::set_current(ctx.context).unwrap(); + }; + ctx + } else { + panic!("Context should be initialized"); + } + } +} diff --git a/crates/burn-cuda/src/compute/storage.rs b/crates/burn-cuda/src/compute/storage.rs new file mode 100644 index 000000000..59395632c --- /dev/null +++ b/crates/burn-cuda/src/compute/storage.rs @@ -0,0 +1,118 @@ +use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}; +use cudarc::driver::sys::CUstream; +use std::collections::HashMap; + +/// Buffer storage for cuda. +pub struct CudaStorage { + memory: HashMap, + deallocations: Vec, + stream: cudarc::driver::sys::CUstream, +} + +unsafe impl Send for CudaStorage {} + +impl core::fmt::Debug for CudaStorage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(format!("CudaStorage {{ device: {:?} }}", self.stream).as_str()) + } +} + +/// Keeps actual wgpu buffer references in a hashmap with ids as key. +impl CudaStorage { + /// Create a new storage on the given [device](wgpu::Device). + pub fn new(stream: CUstream) -> Self { + Self { + memory: HashMap::new(), + deallocations: Vec::new(), + stream, + } + } + + /// Actually deallocates buffers tagged to be deallocated. + pub fn perform_deallocations(&mut self) { + for id in self.deallocations.drain(..) { + if let Some(ptr) = self.memory.remove(&id) { + unsafe { + cudarc::driver::result::free_async(ptr, self.stream).unwrap(); + } + } + } + } +} + +/// The memory resource that can be allocated for wgpu. +#[derive(new, Debug)] +pub struct CudaResource { + /// The wgpu buffer. + pub ptr: u64, + pub binding: *mut std::ffi::c_void, + /// How the resource is used. + pub kind: CudaResourceKind, +} + +unsafe impl Send for CudaResource {} + +pub type Binding = *mut std::ffi::c_void; + +impl CudaResource { + /// Return the binding view of the buffer. + pub fn as_binding(&self) -> Binding { + self.binding + } + + /// Return the buffer size. + pub fn size(&self) -> u64 { + match self.kind { + CudaResourceKind::Full { size } => size as u64, + CudaResourceKind::Slice { size, offset: _ } => size as u64, + } + } + + /// Return the buffer offset. + pub fn offset(&self) -> u64 { + match self.kind { + CudaResourceKind::Full { size: _ } => 0, + CudaResourceKind::Slice { size: _, offset } => offset as u64, + } + } +} + +/// How the resource is used, either as a slice or fully. +#[derive(Debug)] +pub enum CudaResourceKind { + /// Represents an entire buffer. + Full { size: usize }, + /// A slice over a buffer. + Slice { size: usize, offset: usize }, +} + +impl ComputeStorage for CudaStorage { + type Resource = CudaResource; + + fn get(&mut self, handle: &StorageHandle) -> Self::Resource { + let ptr = self.memory.get(&handle.id).unwrap(); + match handle.utilization { + StorageUtilization::Full(size) => CudaResource::new( + *ptr, + ptr as *const cudarc::driver::sys::CUdeviceptr as *mut std::ffi::c_void, + CudaResourceKind::Full { size }, + ), + StorageUtilization::Slice { offset, size } => CudaResource::new( + *ptr, + ptr as *const cudarc::driver::sys::CUdeviceptr as *mut std::ffi::c_void, + CudaResourceKind::Slice { size, offset }, + ), + } + } + + fn alloc(&mut self, size: usize) -> StorageHandle { + let id = StorageId::new(); + let ptr = unsafe { cudarc::driver::result::malloc_async(self.stream, size).unwrap() }; + self.memory.insert(id.clone(), ptr); + StorageHandle::new(id, StorageUtilization::Full(size)) + } + + fn dealloc(&mut self, id: StorageId) { + self.deallocations.push(id); + } +} diff --git a/crates/burn-cuda/src/device.rs b/crates/burn-cuda/src/device.rs new file mode 100644 index 000000000..04a601443 --- /dev/null +++ b/crates/burn-cuda/src/device.rs @@ -0,0 +1,12 @@ +use burn_tensor::backend::{DeviceId, DeviceOps}; + +#[derive(new, Clone, Debug, PartialEq, Eq, Default, Hash)] +pub struct CudaDevice { + pub index: usize, +} + +impl DeviceOps for CudaDevice { + fn id(&self) -> DeviceId { + DeviceId::new(0, self.index as u32) + } +} diff --git a/crates/burn-cuda/src/element.rs b/crates/burn-cuda/src/element.rs new file mode 100644 index 000000000..e125a5a2b --- /dev/null +++ b/crates/burn-cuda/src/element.rs @@ -0,0 +1,42 @@ +use burn_jit::JitElement; + +use crate::compiler; + +/// The base element trait for the wgpu backend. +pub trait CudaElement: JitElement { + fn cuda_elem() -> compiler::Elem; +} + +/// The float element type for the wgpu backend. +pub trait FloatElement: CudaElement + burn_jit::FloatElement {} + +/// The int element type for the wgpu backend. +pub trait IntElement: CudaElement + burn_jit::IntElement {} + +impl CudaElement for u32 { + fn cuda_elem() -> compiler::Elem { + compiler::Elem::U32 + } +} + +impl CudaElement for i32 { + fn cuda_elem() -> compiler::Elem { + compiler::Elem::I32 + } +} + +impl CudaElement for f32 { + fn cuda_elem() -> compiler::Elem { + compiler::Elem::F32 + } +} + +impl CudaElement for half::bf16 { + fn cuda_elem() -> compiler::Elem { + compiler::Elem::BF16 + } +} + +impl FloatElement for f32 {} +impl FloatElement for half::bf16 {} +impl IntElement for i32 {} diff --git a/crates/burn-cuda/src/lib.rs b/crates/burn-cuda/src/lib.rs new file mode 100644 index 000000000..2d065b3e6 --- /dev/null +++ b/crates/burn-cuda/src/lib.rs @@ -0,0 +1,29 @@ +#[macro_use] +extern crate derive_new; +extern crate alloc; + +mod compute; +mod device; +mod element; +mod runtime; + +pub mod compiler; +pub use device::*; + +use burn_jit::JitBackend; +use runtime::CudaRuntime; + +#[cfg(not(feature = "fusion"))] +pub type Cuda = JitBackend; + +#[cfg(feature = "fusion")] +pub type Cuda = burn_fusion::Fusion>; + +#[cfg(test)] +mod tests { + use super::*; + + pub type TestRuntime = crate::CudaRuntime; + + burn_jit::testgen_all!(); +} diff --git a/crates/burn-cuda/src/runtime.rs b/crates/burn-cuda/src/runtime.rs new file mode 100644 index 000000000..6b745528b --- /dev/null +++ b/crates/burn-cuda/src/runtime.rs @@ -0,0 +1,81 @@ +use burn_common::stub::RwLock; +use burn_compute::{ + channel::MutexComputeChannel, + client::ComputeClient, + memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy}, + tune::Tuner, + ComputeRuntime, +}; +use burn_jit::Runtime; +use std::sync::Arc; + +use crate::{ + compiler::CudaCompiler, + compute::{CudaContext, CudaServer, CudaStorage}, + device::CudaDevice, +}; + +#[derive(Debug)] +pub struct CudaRuntime; + +// static RUNTIME: ComputeRuntime> = +static RUNTIME: ComputeRuntime> = + ComputeRuntime::new(); + +type Server = CudaServer>; + +impl Runtime for CudaRuntime { + type Compiler = CudaCompiler; + type Server = CudaServer>; + + // type Channel = MutexComputeChannel>>; + type Channel = MutexComputeChannel>>; + type Device = CudaDevice; + + fn client(device: &Self::Device) -> ComputeClient { + fn init(index: usize) -> CudaContext> { + cudarc::driver::result::init().unwrap(); + let device_ptr = cudarc::driver::result::device::get(index as i32).unwrap(); + + let ctx = unsafe { + let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap(); + cudarc::driver::result::ctx::set_current(ctx).unwrap(); + ctx + }; + + let stream = cudarc::driver::result::stream::create( + cudarc::driver::result::stream::StreamKind::NonBlocking, + ) + .unwrap(); + let storage = CudaStorage::new(stream); + let memory_management = SimpleMemoryManagement::new( + storage, + DeallocStrategy::new_period_tick(1), + SliceStrategy::Never, + ); + CudaContext::new(memory_management, stream, ctx) + } + + RUNTIME.client(device, move || { + let server = CudaServer::new(device.index, Box::new(init)); + + let tuner_device_id = tuner_device_id(); + ComputeClient::new( + MutexComputeChannel::new(server), + Arc::new(RwLock::new(Tuner::new(&tuner_device_id))), + ) + }) + } + + fn name() -> &'static str { + "cuda" + } + + fn require_array_lengths() -> bool { + true + } +} + +fn tuner_device_id() -> String { + "cuda".into() +} diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index a40e60489..d0f727e5d 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -37,6 +37,7 @@ log = { workspace = true } num-traits = { workspace = true } rand = { workspace = true } spin = { workspace = true } +half = { workspace = true, features = ["bytemuck"] } # Template serde = { workspace = true } diff --git a/crates/burn-jit/src/codegen/compiler.rs b/crates/burn-jit/src/codegen/compiler.rs index 4b3a59d56..3941d5eab 100644 --- a/crates/burn-jit/src/codegen/compiler.rs +++ b/crates/burn-jit/src/codegen/compiler.rs @@ -1,11 +1,17 @@ use super::dialect::gpu; use std::fmt::Display; +/// Trait for compiled code representation +pub trait CompilerRepresentation: Display { + /// Computes and returns the shared memory size + fn shared_memory_size(&self) -> usize; +} + /// Compiles the [gpu representation](gpu::ComputeShader) into its own representation that can be /// formatted into tokens. pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug { /// The representation for the compiled code. - type Representation: Display; + type Representation: CompilerRepresentation; /// Compiles the [gpu shader](gpu::ComputeShader) into the compiler's representation. fn compile(shader: gpu::ComputeShader) -> Self::Representation; diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs index 836d23d40..ecb095f35 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs @@ -217,12 +217,24 @@ macro_rules! gpu { gpu!(binary $lhs, $rhs, $out) )); }; + // out = unchecked(lhs[rhs]) + ($scope:expr, $out:ident = unchecked($lhs:ident[$rhs:expr])) => { + $scope.register($crate::codegen::dialect::gpu::Operator::UncheckedIndex( + gpu!(binary $lhs, $rhs, $out) + )); + }; // out[lhs] = rhs ($scope:expr, $out:ident[$lhs:ident] = $rhs:expr) => { $scope.register($crate::codegen::dialect::gpu::Operator::IndexAssign( gpu!(binary $lhs, $rhs, $out) )); }; + // unchecked(out[lhs]) = rhs + ($scope:expr, unchecked($out:ident[$lhs:ident]) = $rhs:expr) => { + $scope.register($crate::codegen::dialect::gpu::Operator::UncheckedIndexAssign( + gpu!(binary $lhs, $rhs, $out) + )); + }; // out = |input| ($scope:expr, $out:ident = |$input:ident|) => { gpu!($scope, $out = abs($input)) diff --git a/crates/burn-jit/src/codegen/dialect/gpu/operation.rs b/crates/burn-jit/src/codegen/dialect/gpu/operation.rs index f44a8ae30..fa41fc70b 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/operation.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/operation.rs @@ -50,7 +50,9 @@ pub enum Operator { Assign(UnaryOperator), Modulo(BinaryOperator), Index(BinaryOperator), + UncheckedIndex(BinaryOperator), IndexAssign(BinaryOperator), + UncheckedIndexAssign(BinaryOperator), And(BinaryOperator), Or(BinaryOperator), Not(UnaryOperator), diff --git a/crates/burn-jit/src/codegen/dialect/gpu/procedure/base.rs b/crates/burn-jit/src/codegen/dialect/gpu/procedure/base.rs index 9055c95eb..86a5eccaf 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/procedure/base.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/procedure/base.rs @@ -1,5 +1,6 @@ use super::{ - ConditionalAssign, IndexOffsetGlobalWithLayout, ReadGlobal, ReadGlobalWithLayout, WriteGlobal, + CheckedIndex, CheckedIndexAssign, ConditionalAssign, IndexOffsetGlobalWithLayout, ReadGlobal, + ReadGlobalWithLayout, WriteGlobal, }; use crate::codegen::dialect::gpu::Vectorization; use serde::{Deserialize, Serialize}; @@ -13,6 +14,8 @@ pub enum Procedure { IndexOffsetGlobalWithLayout(IndexOffsetGlobalWithLayout), ReadGlobal(ReadGlobal), WriteGlobal(WriteGlobal), + CheckedIndex(CheckedIndex), + CheckedIndexAssign(CheckedIndexAssign), ConditionalAssign(ConditionalAssign), } @@ -22,14 +25,18 @@ impl Procedure { Procedure::ReadGlobalWithLayout(op) => { Procedure::ReadGlobalWithLayout(op.vectorize(vectorization)) } - Procedure::ReadGlobal(op) => Procedure::ReadGlobal(op.vectorize(vectorization)), - Procedure::WriteGlobal(op) => Procedure::WriteGlobal(op.vectorize(vectorization)), - Procedure::ConditionalAssign(proc) => { - Procedure::ConditionalAssign(proc.vectorize(vectorization)) - } Procedure::IndexOffsetGlobalWithLayout(op) => { Procedure::IndexOffsetGlobalWithLayout(op.vectorize(vectorization)) } + Procedure::ReadGlobal(op) => Procedure::ReadGlobal(op.vectorize(vectorization)), + Procedure::WriteGlobal(op) => Procedure::WriteGlobal(op.vectorize(vectorization)), + Procedure::CheckedIndex(proc) => Procedure::CheckedIndex(proc.vectorize(vectorization)), + Procedure::CheckedIndexAssign(proc) => { + Procedure::CheckedIndexAssign(proc.vectorize(vectorization)) + } + Procedure::ConditionalAssign(proc) => { + Procedure::ConditionalAssign(proc.vectorize(vectorization)) + } } } } diff --git a/crates/burn-jit/src/codegen/dialect/gpu/procedure/index.rs b/crates/burn-jit/src/codegen/dialect/gpu/procedure/index.rs new file mode 100644 index 000000000..95b0797e2 --- /dev/null +++ b/crates/burn-jit/src/codegen/dialect/gpu/procedure/index.rs @@ -0,0 +1,74 @@ +use crate::codegen::dialect::gpu::{macros::gpu, Item, Scope, Variable, Vectorization}; +use serde::{Deserialize, Serialize}; + +/// Perform a check bound on the index (lhs) of value (rhs) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[allow(missing_docs)] +pub struct CheckedIndex { + pub lhs: Variable, + pub rhs: Variable, + pub out: Variable, +} + +impl CheckedIndex { + #[allow(missing_docs)] + pub fn expand(self, scope: &mut Scope) { + let lhs = self.lhs; + let rhs = self.rhs; + let out = self.out; + let array_len = scope.create_local(Item::Scalar(crate::gpu::Elem::UInt)); + let inside_bound = scope.create_local(Item::Scalar(crate::gpu::Elem::Bool)); + + gpu!(scope, array_len = len(lhs)); + gpu!(scope, inside_bound = rhs < array_len); + + gpu!(scope, if(inside_bound).then(|scope| { + gpu!(scope, out = unchecked(lhs[rhs])); + }).else(|scope| { + gpu!(scope, out = cast(0)); + })); + } + + pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { + Self { + lhs: self.lhs.vectorize(vectorization), + rhs: self.rhs.vectorize(vectorization), + out: self.out.vectorize(vectorization), + } + } +} + +/// Perform a check bound on the index (lhs) of output before assigning the value (rhs) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[allow(missing_docs)] +pub struct CheckedIndexAssign { + pub lhs: Variable, + pub rhs: Variable, + pub out: Variable, +} + +impl CheckedIndexAssign { + #[allow(missing_docs)] + pub fn expand(self, scope: &mut Scope) { + let lhs = self.lhs; + let rhs = self.rhs; + let out = self.out; + let array_len = scope.create_local(Item::Scalar(crate::gpu::Elem::UInt)); + let inside_bound = scope.create_local(Item::Scalar(crate::gpu::Elem::Bool)); + + gpu!(scope, array_len = len(out)); + gpu!(scope, inside_bound = lhs < array_len); + + gpu!(scope, if(inside_bound).then(|scope| { + gpu!(scope, unchecked(out[lhs]) = rhs); + })); + } + + pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { + Self { + lhs: self.lhs.vectorize(vectorization), + rhs: self.rhs.vectorize(vectorization), + out: self.out.vectorize(vectorization), + } + } +} diff --git a/crates/burn-jit/src/codegen/dialect/gpu/procedure/mod.rs b/crates/burn-jit/src/codegen/dialect/gpu/procedure/mod.rs index 7994faff1..a537fc04d 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/procedure/mod.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/procedure/mod.rs @@ -1,9 +1,11 @@ mod assign; mod base; +mod index; mod read; mod write; pub use assign::*; pub use base::*; +pub use index::*; pub use read::*; pub use write::*; diff --git a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs index a4651f029..fbd0deb19 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs @@ -20,6 +20,8 @@ pub enum Visibility { #[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)] #[allow(missing_docs)] pub enum FloatKind { + F16, + BF16, F32, F64, } @@ -68,7 +70,8 @@ pub enum Item { } impl Item { - pub(crate) fn elem(&self) -> Elem { + /// Fetch the elem of the item. + pub fn elem(&self) -> Elem { match self { Self::Vec4(elem) => *elem, Self::Vec3(elem) => *elem, diff --git a/crates/burn-jit/src/codegen/dialect/gpu/variable.rs b/crates/burn-jit/src/codegen/dialect/gpu/variable.rs index cd52ea413..4a6a336d6 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/variable.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/variable.rs @@ -63,7 +63,8 @@ impl Variable { Variable::NumWorkgroupsZ => None, } } - pub(crate) fn item(&self) -> Item { + /// Fetch the item of the variable. + pub fn item(&self) -> Item { match self { Variable::GlobalInputArray(_, item) => *item, Variable::GlobalOutputArray(_, item) => *item, diff --git a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs index b81fd248e..1bec46323 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/vectorization.rs @@ -40,6 +40,7 @@ impl Operator { Operator::Min(op) => Operator::Min(op.vectorize(vectorization)), Operator::Add(op) => Operator::Add(op.vectorize(vectorization)), Operator::Index(op) => Operator::Index(op.vectorize(vectorization)), + Operator::UncheckedIndex(op) => Operator::UncheckedIndex(op.vectorize(vectorization)), Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)), Operator::Mul(op) => Operator::Mul(op.vectorize(vectorization)), Operator::Div(op) => Operator::Div(op.vectorize(vectorization)), @@ -74,6 +75,9 @@ impl Operator { } Operator::Modulo(op) => Operator::Modulo(op.vectorize(vectorization)), Operator::IndexAssign(op) => Operator::IndexAssign(op.vectorize(vectorization)), + Operator::UncheckedIndexAssign(op) => { + Operator::UncheckedIndexAssign(op.vectorize(vectorization)) + } Operator::And(op) => Operator::And(op.vectorize(vectorization)), Operator::Or(op) => Operator::Or(op.vectorize(vectorization)), Operator::Not(op) => Operator::Not(op.vectorize(vectorization)), diff --git a/crates/burn-jit/src/codegen/kernel.rs b/crates/burn-jit/src/codegen/kernel.rs index c74acfa87..523471965 100644 --- a/crates/burn-jit/src/codegen/kernel.rs +++ b/crates/burn-jit/src/codegen/kernel.rs @@ -285,6 +285,19 @@ fn execute_settings<'a, R: Runtime, E1: JitElement, E2: JitElement, E3: JitEleme handles.push(output.handle.clone().binding()); } + // [2, I0stride0, I0stride1, I0shape0, I0shape1i, I1... O0..., I0len, I1len1, O0len] + if R::require_array_lengths() { + for input in inputs.iter() { + let len = calculate_num_elems_dyn_rank(input.shape); + info.push(len as u32); + } + + for output in outputs.iter() { + let len = calculate_num_elems_dyn_rank(output.shape); + info.push(len as u32); + } + } + let info = client.create(bytemuck::cast_slice(&info)); // Finally we finish with the named bindings. diff --git a/crates/burn-jit/src/compute/kernel.rs b/crates/burn-jit/src/compute/kernel.rs index 4b9fa7d4c..4d9d0e604 100644 --- a/crates/burn-jit/src/compute/kernel.rs +++ b/crates/burn-jit/src/compute/kernel.rs @@ -2,7 +2,9 @@ use std::marker::PhantomData; #[cfg(feature = "template")] use crate::template::TemplateKernel; -use crate::{gpu::WorkgroupSize, kernel::GpuComputeShaderPhase, Compiler}; +use crate::{ + codegen::CompilerRepresentation, gpu::WorkgroupSize, kernel::GpuComputeShaderPhase, Compiler, +}; use alloc::sync::Arc; /// Kernel for JIT backends @@ -53,6 +55,8 @@ pub struct CompiledKernel { pub source: String, /// Size of a workgroup for the compiled kernel pub workgroup_size: WorkgroupSize, + /// The number of bytes used by the share memory + pub shared_mem_bytes: usize, } /// Information needed to launch the kernel @@ -86,13 +90,14 @@ impl JitKernel for FullCompilationPhase CompiledKernel { let gpu_ir = self.kernel.compile(); let workgroup_size = gpu_ir.workgroup_size; - let lower_level_ir = C::compile(gpu_ir); + let shared_mem_bytes = lower_level_ir.shared_memory_size(); let source = lower_level_ir.to_string(); CompiledKernel { source, workgroup_size, + shared_mem_bytes, } } diff --git a/crates/burn-jit/src/element.rs b/crates/burn-jit/src/element.rs index 5fd78e484..c5aa160aa 100644 --- a/crates/burn-jit/src/element.rs +++ b/crates/burn-jit/src/element.rs @@ -92,5 +92,26 @@ impl JitElement for f32 { } } +impl JitElement for half::bf16 { + fn type_name() -> &'static str { + "bf16" + } + fn as_bytes(slice: &[Self]) -> &[u8] { + bytemuck::cast_slice(slice) + } + fn from_bytes(bytes: &[u8]) -> &[Self] { + bytemuck::cast_slice(bytes) + } + fn gpu_elem() -> gpu::Elem { + gpu::Elem::Float(gpu::FloatKind::BF16) + } + fn maximum_value() -> Self { + half::bf16::MAX + } + fn minimum_value() -> Self { + half::bf16::MIN + } +} impl FloatElement for f32 {} +impl FloatElement for half::bf16 {} impl IntElement for i32 {} diff --git a/crates/burn-jit/src/fusion/kernel.rs b/crates/burn-jit/src/fusion/kernel.rs index 0304bfeb4..c15a41720 100644 --- a/crates/burn-jit/src/fusion/kernel.rs +++ b/crates/burn-jit/src/fusion/kernel.rs @@ -1,3 +1,4 @@ +use crate::codegen::calculate_num_elems_dyn_rank; use crate::codegen::Compilation; use crate::codegen::CompilationInfo; use crate::codegen::CompilationSettings; @@ -165,14 +166,14 @@ impl FusionKernel { let mut output_register = Vec::with_capacity(outputs_description_updated.len()); // We register the info and handles for the inputs. - for (handle, tensor) in handles_input.iter().zip(inputs_description_updated) { + for (handle, tensor) in handles_input.iter().zip(inputs_description_updated.iter()) { register_info_tensor(&mut info, tensor, handle); bindings.push(handle.handle.clone().binding()); } // We register the info and handles for the outputs. for (tensor, output_info) in outputs_description_updated - .into_iter() + .iter() .zip(fusion_kernel.runtime_info.iter()) { match output_info { @@ -204,6 +205,19 @@ impl FusionKernel { }; } + // [2, I0stride0, I0stride1, I0shape0, I0shape1i, I1... O0..., I0len, I1len1, O0len] + if R::require_array_lengths() { + for input in inputs_description_updated.iter() { + let len = calculate_num_elems_dyn_rank(&input.shape); + info.push(len as u32); + } + + for output in outputs_description_updated.iter() { + let len = calculate_num_elems_dyn_rank(&output.shape); + info.push(len as u32); + } + } + // Create the info buffer. bindings.push(client.create(bytemuck::cast_slice(&info)).binding()); diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index 0377e19e7..1cec52dda 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -219,6 +219,11 @@ impl TraceBuilder { &mut local_tensor_ids_input, &mut local_tensor_ids_output, ), + gpu::Operator::UncheckedIndex(op) => mark_binary( + op, + &mut local_tensor_ids_input, + &mut local_tensor_ids_output, + ), gpu::Operator::Sub(op) => mark_binary( op, &mut local_tensor_ids_input, @@ -343,6 +348,11 @@ impl TraceBuilder { &mut local_tensor_ids_input, &mut local_tensor_ids_output, ), + gpu::Operator::UncheckedIndexAssign(op) => mark_binary( + op, + &mut local_tensor_ids_input, + &mut local_tensor_ids_output, + ), gpu::Operator::BitwiseAnd(op) => mark_binary( op, &mut local_tensor_ids_input, @@ -380,6 +390,12 @@ impl TraceBuilder { gpu::Procedure::WriteGlobal(_) => { // Nothing to do here. } + gpu::Procedure::CheckedIndex(_) => { + // Nothing to do here. + } + gpu::Procedure::CheckedIndexAssign(_) => { + // Nothing to do here. + } gpu::Procedure::ConditionalAssign(proc) => { mark(&proc.cond, &mut local_tensor_ids_input); mark(&proc.lhs, &mut local_tensor_ids_input); diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index ba1f3571d..60bb659da 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -104,6 +104,7 @@ pub enum MatmulStrategy { Autotune, } +#[cfg(feature = "autotune")] #[cfg(not(feature = "autotune"))] impl Default for MatmulStrategy { fn default() -> Self { diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs index e178eabae..fa08bb292 100644 --- a/crates/burn-jit/src/lib.rs +++ b/crates/burn-jit/src/lib.rs @@ -19,7 +19,7 @@ pub(crate) mod codegen; pub(crate) mod tune; mod element; -pub use codegen::compiler::Compiler; +pub use codegen::compiler::{Compiler, CompilerRepresentation}; pub use codegen::dialect::gpu; pub use element::{FloatElement, IntElement, JitElement}; diff --git a/crates/burn-jit/src/runtime.rs b/crates/burn-jit/src/runtime.rs index 91a194be9..77681482e 100644 --- a/crates/burn-jit/src/runtime.rs +++ b/crates/burn-jit/src/runtime.rs @@ -28,4 +28,9 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug { /// The runtime name. fn name() -> &'static str; + + /// Return true if global input array lengths should be added to kernel info. + fn require_array_lengths() -> bool { + false + } } diff --git a/crates/burn-jit/src/template/base.rs b/crates/burn-jit/src/template/base.rs index 09f9728f5..0b467b9ca 100644 --- a/crates/burn-jit/src/template/base.rs +++ b/crates/burn-jit/src/template/base.rs @@ -43,9 +43,11 @@ where fn compile(&self) -> CompiledKernel { let source_template = self.kernel_source.source(); let source = source_template.complete(); + CompiledKernel { source, workgroup_size: self.workgroup_size, + shared_mem_bytes: 0, } } diff --git a/crates/burn-tensor/src/tests/clone_invariance.rs b/crates/burn-tensor/src/tests/clone_invariance.rs index 40b7599e9..1bcd885c0 100644 --- a/crates/burn-tensor/src/tests/clone_invariance.rs +++ b/crates/burn-tensor/src/tests/clone_invariance.rs @@ -74,13 +74,9 @@ mod tests { fn args(&self) -> Self::Args { let device = Default::default(); ( - TestTensor::random([32, 32], Distribution::Default, &device) - .into_data() - .convert(), + TestTensor::ones([32, 32], &device).into_data().convert(), // Avoid div by zero. - TestTensor::random([32, 32], Distribution::Uniform(1., 3.), &device) - .into_data() - .convert(), + TestTensor::ones([32, 32], &device).into_data().convert(), ) } diff --git a/crates/burn-wgpu/src/compiler/wgsl/body.rs b/crates/burn-wgpu/src/compiler/wgsl/body.rs index a48affa52..dfa11638c 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/body.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/body.rs @@ -1,7 +1,7 @@ use super::Instruction; use std::fmt::Display; -/// A body is composed of a list of [operations](Operation). +/// A body is composed of a list of [instructions](Instruction). /// /// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size /// X and Y, but with Z=1. diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs index 87d2a31b9..9a267a85b 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs @@ -100,6 +100,8 @@ impl WgslCompiler { fn compile_elem(value: gpu::Elem) -> wgsl::Elem { match value { gpu::Elem::Float(f) => match f { + gpu::FloatKind::F16 => panic!("f16 is not yet supported"), + gpu::FloatKind::BF16 => panic!("f64 is not a valid WgpuElement"), gpu::FloatKind::F32 => wgsl::Elem::F32, gpu::FloatKind::F64 => panic!("f64 is not a valid WgpuElement"), }, @@ -317,6 +319,14 @@ impl WgslCompiler { proc.expand(scope); compile(scope); } + gpu::Procedure::CheckedIndex(proc) => { + proc.expand(scope); + compile(scope); + } + gpu::Procedure::CheckedIndexAssign(proc) => { + proc.expand(scope); + compile(scope); + } gpu::Procedure::IndexOffsetGlobalWithLayout(proc) => { proc.expand(scope); compile(scope); @@ -381,6 +391,11 @@ impl WgslCompiler { rhs: self.compile_variable(op.rhs), out: self.compile_variable(op.out), }, + gpu::Operator::UncheckedIndex(op) => wgsl::Instruction::Index { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(op.out), + }, gpu::Operator::Modulo(op) => wgsl::Instruction::Modulo { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), @@ -499,6 +514,11 @@ impl WgslCompiler { rhs: self.compile_variable(op.rhs), out: self.compile_variable(op.out), }, + gpu::Operator::UncheckedIndexAssign(op) => wgsl::Instruction::IndexAssign { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(op.out), + }, gpu::Operator::And(op) => wgsl::Instruction::And { lhs: self.compile_variable(op.lhs), rhs: self.compile_variable(op.rhs), @@ -593,6 +613,14 @@ fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec { register_extension(wgsl::Extension::SafeTanh(input.item())) } + wgsl::Instruction::If { + cond: _, + instructions, + } => { + for extension in register_extensions(instructions) { + register_extension(extension); + } + } _ => {} } } diff --git a/crates/burn-wgpu/src/compiler/wgsl/shader.rs b/crates/burn-wgpu/src/compiler/wgsl/shader.rs index 8229ac451..87de8b925 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/shader.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/shader.rs @@ -1,5 +1,5 @@ use super::{Body, Extension, Item}; -use burn_jit::gpu::WorkgroupSize; +use burn_jit::{gpu::WorkgroupSize, CompilerRepresentation}; use std::fmt::Display; #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -218,3 +218,10 @@ impl Display for Visibility { } } } + +impl CompilerRepresentation for ComputeShader { + fn shared_memory_size(&self) -> usize { + // not used in wgsl compiler + 0 + } +} diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index a9cd196b5..ea62a1058 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -41,7 +41,7 @@ autodiff = ["burn-core/autodiff"] fusion = ["burn-core/fusion"] ## Backend features -cuda = ["burn-core/cuda"] +candle-cuda = ["burn-core/candle-cuda"] metal = ["burn-core/metal"] accelerate = ["burn-core/accelerate"] openblas = ["burn-core/openblas"] diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index b9d7fe4c5..850b07c81 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -15,7 +15,7 @@ use crate::utils::cargo::{run_cargo, run_cargo_with_path}; use crate::utils::process::{handle_child_process, run_command}; use crate::utils::rustup::{rustup_add_component, rustup_add_target}; use crate::utils::time::format_duration; -use crate::utils::workspace::{get_workspaces, WorkspaceMemberType}; +use crate::utils::workspace::{get_workspace_members, WorkspaceMemberType}; use crate::utils::Params; use crate::{endgroup, group}; @@ -310,9 +310,13 @@ fn std_checks() { // Check clippy lints cargo_clippy(); - // Produce documentation for each workspace - group!("Docs: workspaces"); - cargo_doc(["--workspace", "--no-deps"].into()); + // Produce documentation for each workspace member + group!("Docs: crates"); + let mut params = Params::from(["--workspace", "--no-deps"]); + // Exclude burn-cuda on all platforms + params.params.push("--exclude".to_string()); + params.params.push("burn-cuda".to_string()); + cargo_doc(params); endgroup!(); // Setup code coverage @@ -320,20 +324,23 @@ fn std_checks() { setup_coverage(); } - // Build & test each workspace - let workspaces = get_workspaces(WorkspaceMemberType::Crate); - for workspace in workspaces { - if disable_wgpu && workspace.name == "burn-wgpu" { + // Build & test each member in workspace + let members = get_workspace_members(WorkspaceMemberType::Crate); + for member in members { + if disable_wgpu && member.name == "burn-wgpu" { + continue; + } + if member.name == "burn-cuda" { + // burn-cuda requires CUDA Toolkit which is not currently setup on our CI runners + continue; + } + if member.name == "burn-tch" { continue; } - if workspace.name == "burn-tch" { - continue; - } - - group!("Checks: {}", workspace.name); - cargo_build(Params::from(["-p", &workspace.name])); - cargo_test(Params::from(["-p", &workspace.name])); + group!("Checks: {}", member.name); + cargo_build(Params::from(["-p", &member.name])); + cargo_test(Params::from(["-p", &member.name])); endgroup!(); } @@ -381,18 +388,18 @@ fn check_typos() { } fn check_examples() { - let workspaces = get_workspaces(WorkspaceMemberType::Example); - for workspace in workspaces { - if workspace.name == "notebook" { + let members = get_workspace_members(WorkspaceMemberType::Example); + for member in members { + if member.name == "notebook" { continue; } - group!("Checks: Example - {}", workspace.name); + group!("Checks: Example - {}", member.name); run_cargo_with_path( "check", ["--examples"].into(), HashMap::new(), - Some(workspace.path), + Some(member.path), "Failed to check example", ); endgroup!(); diff --git a/xtask/src/utils/mod.rs b/xtask/src/utils/mod.rs index 55345eb5a..7e176cc74 100644 --- a/xtask/src/utils/mod.rs +++ b/xtask/src/utils/mod.rs @@ -6,7 +6,7 @@ pub(crate) mod time; pub(crate) mod workspace; pub(crate) struct Params { - params: Vec, + pub params: Vec, } impl From<[&str; N]> for Params { diff --git a/xtask/src/utils/workspace.rs b/xtask/src/utils/workspace.rs index 4d9a2e19b..027d26dd9 100644 --- a/xtask/src/utils/workspace.rs +++ b/xtask/src/utils/workspace.rs @@ -25,8 +25,8 @@ impl WorkspaceMember { } } -/// Get project workspaces -pub(crate) fn get_workspaces(w_type: WorkspaceMemberType) -> Vec { +/// Get workspace crates +pub(crate) fn get_workspace_members(w_type: WorkspaceMemberType) -> Vec { // Run `cargo metadata` command to get project metadata let output = Command::new("cargo") .arg("metadata")