First draft CUDA runtime (#1685)

Initial cuda runtime crate with a WIP compiler.
This commit is contained in:
Nathaniel Simard 2024-04-30 09:46:29 -04:00 committed by GitHub
parent ab501431b1
commit 587b8f80b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
49 changed files with 2910 additions and 108 deletions

200
Cargo.lock generated
View File

@ -132,19 +132,18 @@ checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519"
[[package]]
name = "arboard"
version = "3.3.2"
version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2041f1943049c7978768d84e6d0fd95de98b76d6c4727b09e78ec253d29fa58"
checksum = "9fb4009533e8ff8f1450a5bcbc30f4242a1d34442221f72314bea1f5dc9c7f89"
dependencies = [
"clipboard-win",
"core-graphics",
"image",
"image 0.25.1",
"log",
"objc",
"objc-foundation",
"objc_id",
"parking_lot 0.12.1",
"thiserror",
"objc2",
"objc2-app-kit",
"objc2-foundation",
"parking_lot 0.12.2",
"windows-sys 0.48.0",
"x11rb",
]
@ -339,6 +338,15 @@ dependencies = [
"generic-array",
]
[[package]]
name = "block2"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43ff7d91d3c1d568065b06c899777d1e48dcf76103a672a0adbc238a7f247f1e"
dependencies = [
"objc2",
]
[[package]]
name = "bstr"
version = "1.9.1"
@ -453,6 +461,22 @@ dependencies = [
"thiserror",
]
[[package]]
name = "burn-cuda"
version = "0.14.0"
dependencies = [
"burn-common",
"burn-compute",
"burn-fusion",
"burn-jit",
"burn-tensor",
"bytemuck",
"cudarc",
"derive-new",
"half",
"log",
]
[[package]]
name = "burn-dataset"
version = "0.14.0"
@ -466,7 +490,7 @@ dependencies = [
"gix-tempfile",
"globwalk",
"hound",
"image",
"image 0.24.9",
"r2d2",
"r2d2_sqlite",
"rand",
@ -549,6 +573,7 @@ dependencies = [
"burn-tensor-testgen",
"bytemuck",
"derive-new",
"half",
"hashbrown 0.14.5",
"log",
"num-traits",
@ -1116,7 +1141,7 @@ dependencies = [
"crossterm_winapi",
"libc",
"mio",
"parking_lot 0.12.1",
"parking_lot 0.12.2",
"signal-hook",
"signal-hook-mio",
"winapi",
@ -1287,7 +1312,7 @@ dependencies = [
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core 0.9.9",
"parking_lot_core 0.9.10",
]
[[package]]
@ -1559,9 +1584,9 @@ dependencies = [
[[package]]
name = "fastrand"
version = "2.0.2"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984"
checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
[[package]]
name = "fdeflate"
@ -1586,9 +1611,9 @@ dependencies = [
[[package]]
name = "flate2"
version = "1.0.29"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4556222738635b7a3417ae6130d8f52201e45a0c4d1a907f0826383adb5f85e7"
checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae"
dependencies = [
"crc32fast",
"miniz_oxide",
@ -1719,7 +1744,7 @@ checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f"
dependencies = [
"futures-core",
"lock_api",
"parking_lot 0.12.1",
"parking_lot 0.12.2",
]
[[package]]
@ -1996,7 +2021,7 @@ dependencies = [
"gix-fs",
"libc",
"once_cell",
"parking_lot 0.12.1",
"parking_lot 0.12.2",
"signal-hook",
"signal-hook-registry",
"tempfile",
@ -2439,6 +2464,19 @@ dependencies = [
"tiff",
]
[[package]]
name = "image"
version = "0.25.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11"
dependencies = [
"bytemuck",
"byteorder",
"num-traits",
"png",
"tiff",
]
[[package]]
name = "image-classification-web"
version = "0.14.0"
@ -2669,9 +2707,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
[[package]]
name = "lock_api"
version = "0.4.11"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45"
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
dependencies = [
"autocfg",
"scopeguard",
@ -3055,14 +3093,58 @@ dependencies = [
]
[[package]]
name = "objc-foundation"
version = "0.1.1"
name = "objc-sys"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1add1b659e36c9607c7aab864a76c7a4c2760cd0cd2e120f3fb8b952c7e22bf9"
checksum = "da284c198fb9b7b0603f8635185e85fbd5b64ee154b1ed406d489077de2d6d60"
[[package]]
name = "objc2"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4b25e1034d0e636cd84707ccdaa9f81243d399196b8a773946dcffec0401659"
dependencies = [
"block",
"objc",
"objc_id",
"objc-sys",
"objc2-encode",
]
[[package]]
name = "objc2-app-kit"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb79768a710a9a1798848179edb186d1af7e8a8679f369e4b8d201dd2a034047"
dependencies = [
"block2",
"objc2",
"objc2-core-data",
"objc2-foundation",
]
[[package]]
name = "objc2-core-data"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e092bc42eaf30a08844e6a076938c60751225ec81431ab89f5d1ccd9f958d6c"
dependencies = [
"block2",
"objc2",
"objc2-foundation",
]
[[package]]
name = "objc2-encode"
version = "4.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88658da63e4cc2c8adb1262902cd6af51094df0488b760d6fd27194269c0950a"
[[package]]
name = "objc2-foundation"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfaefe14254871ea16c7d88968c0ff14ba554712a20d76421eec52f0a7fb8904"
dependencies = [
"block2",
"objc2",
]
[[package]]
@ -3074,15 +3156,6 @@ dependencies = [
"cc",
]
[[package]]
name = "objc_id"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c92d4ddb4bd7b50d730c215ff871754d0da6b2178849f8a2a2ab69712d0c073b"
dependencies = [
"objc",
]
[[package]]
name = "object"
version = "0.32.2"
@ -3252,12 +3325,12 @@ dependencies = [
[[package]]
name = "parking_lot"
version = "0.12.1"
version = "0.12.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb"
dependencies = [
"lock_api",
"parking_lot_core 0.9.9",
"parking_lot_core 0.9.10",
]
[[package]]
@ -3276,15 +3349,15 @@ dependencies = [
[[package]]
name = "parking_lot_core"
version = "0.9.9"
version = "0.9.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e"
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
dependencies = [
"cfg-if",
"libc",
"redox_syscall 0.4.1",
"redox_syscall 0.5.1",
"smallvec",
"windows-targets 0.48.5",
"windows-targets 0.52.5",
]
[[package]]
@ -3542,7 +3615,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51de85fb3fb6524929c8a2eb85e6b6d363de4e8c48f9e2c2eac4944abc181c93"
dependencies = [
"log",
"parking_lot 0.12.1",
"parking_lot 0.12.2",
"scheduled-thread-pool",
]
@ -3699,6 +3772,15 @@ dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "redox_syscall"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e"
dependencies = [
"bitflags 2.5.0",
]
[[package]]
name = "redox_users"
version = "0.4.5"
@ -3972,9 +4054,9 @@ dependencies = [
[[package]]
name = "rustls-pki-types"
version = "1.4.1"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ecd36cc4259e3e4514335c4a138c6b43171a8d61d8f5c9348f9fc7529416f247"
checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54"
[[package]]
name = "rustls-webpki"
@ -4062,7 +4144,7 @@ version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19"
dependencies = [
"parking_lot 0.12.1",
"parking_lot 0.12.2",
]
[[package]]
@ -4185,7 +4267,7 @@ dependencies = [
"futures",
"log",
"once_cell",
"parking_lot 0.12.1",
"parking_lot 0.12.2",
"scc",
"serial_test_derive",
]
@ -4294,9 +4376,9 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "socket2"
version = "0.5.6"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871"
checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c"
dependencies = [
"libc",
"windows-sys 0.52.0",
@ -4906,9 +4988,9 @@ checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202"
[[package]]
name = "unicode-width"
version = "0.1.11"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85"
checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6"
[[package]]
name = "unicode-xid"
@ -4930,11 +5012,11 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "ureq"
version = "2.9.6"
version = "2.9.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11f214ce18d8b2cbe84ed3aa6486ed3f5b285cf8d8fbdbce9f3f767a724adc35"
checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd"
dependencies = [
"base64 0.21.7",
"base64 0.22.0",
"flate2",
"log",
"native-tls",
@ -5158,7 +5240,7 @@ dependencies = [
"js-sys",
"log",
"naga",
"parking_lot 0.12.1",
"parking_lot 0.12.2",
"profiling",
"raw-window-handle",
"smallvec",
@ -5186,7 +5268,7 @@ dependencies = [
"log",
"naga",
"once_cell",
"parking_lot 0.12.1",
"parking_lot 0.12.2",
"profiling",
"raw-window-handle",
"rustc-hash",
@ -5228,7 +5310,7 @@ dependencies = [
"ndk-sys",
"objc",
"once_cell",
"parking_lot 0.12.1",
"parking_lot 0.12.2",
"profiling",
"range-alloc",
"raw-window-handle",
@ -5289,11 +5371,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.6"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596"
checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b"
dependencies = [
"winapi",
"windows-sys 0.52.0",
]
[[package]]

View File

@ -13,7 +13,7 @@ version.workspace = true
# we depend on wgpu and autotune by default because we use the burn-wgpu crate to get system information
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle", "burn/cuda"]
candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
candle-accelerate = ["burn/candle", "burn/accelerate"]
ndarray = ["burn/ndarray"]

View File

@ -51,7 +51,7 @@ pub struct Handle<Server: ComputeServer> {
}
/// Binding of a [tensor handle](Handle) to execute a kernel.
#[derive(new)]
#[derive(new, Debug)]
pub struct Binding<Server: ComputeServer> {
/// Memory binding.
pub memory: <Server::MemoryManagement as MemoryManagement<Server::Storage>>::Binding,

View File

@ -72,7 +72,6 @@ autodiff = ["burn-autodiff"]
fusion = ["burn-wgpu?/fusion"]
## Backend features
cuda = ["burn-candle?/cuda"]
metal = ["burn-candle?/metal"]
accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"]
openblas = ["burn-ndarray?/blas-openblas"]
@ -84,6 +83,7 @@ template = ["burn-wgpu?/template"]
ndarray = ["burn-ndarray"]
tch = ["burn-tch"]
candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
wgpu = ["burn-wgpu"]
# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.

View File

@ -0,0 +1,40 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "CUDA backend for the Burn framework"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "gpu", "cuda"]
license.workspace = true
name = "burn-cuda"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/burn-cuda"
version.workspace = true
[features]
default = ["fusion", "burn-jit/default"]
fusion = ["burn-fusion", "burn-jit/fusion"]
autotune = ["burn-jit/autotune"]
doc = ["burn-jit/doc"]
std = ["burn-jit/std"]
[dependencies]
burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false }
burn-compute = { path = "../burn-compute", version = "0.14.0" }
burn-tensor = { path = "../burn-tensor", version = "0.14.0" }
burn-common = { path = "../burn-common", version = "0.14.0" }
burn-fusion = { path = "../burn-fusion", version = "0.14.0", optional = true }
half = { workspace = true }
bytemuck = { workspace = true }
cudarc = "0.10.0"
log = { workspace = true }
derive-new = { workspace = true }
[dev-dependencies]
burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false, features = [
"export_tests",
] }
[package.metadata.docs.rs]
features = ["doc"]

View File

@ -0,0 +1,5 @@
# Burn-Cuda
This backend is still a work in progress and not ready to be used.
See #1525

View File

@ -0,0 +1,442 @@
use super::Instruction;
use burn_jit::gpu::{self};
#[allow(clippy::too_many_arguments)]
#[derive(new, Clone, Debug, Default)]
pub struct CudaCompiler {
    // Flags recording which built-in values the compiled shader actually
    // uses; set lazily while compiling variables/metadata.
    shape: bool,
    stride: bool,
    // Number of global input/output bindings. Outputs are addressed after
    // inputs when computing metadata positions (see `compile_metadata`).
    num_inputs: usize,
    num_outputs: usize,
    // Shared memories and local arrays declared while walking the scope;
    // each is registered once (deduplicated by index/depth).
    shared_memories: Vec<super::SharedMemory>,
    local_arrays: Vec<super::LocalArray>,
    id: bool,
    rank: bool,
    invocation_index: bool,
    // One flag per axis (x, y, z) of the global invocation id.
    global_invocation_id: (bool, bool, bool),
}
impl burn_jit::Compiler for CudaCompiler {
    type Representation = super::ComputeShader;

    /// Lower a `burn_jit` compute shader into the CUDA representation,
    /// starting from a fresh compiler state.
    fn compile(shader: burn_jit::gpu::ComputeShader) -> Self::Representation {
        Self::default().compile_shader(shader)
    }

    /// Size in bytes of one element of `elem` once compiled.
    fn elem_size(elem: burn_jit::gpu::Elem) -> usize {
        Self::compile_elem(elem).size()
    }

    /// Maximum shared memory available to a kernel.
    fn max_shared_memory_size() -> usize {
        // TODO: Find out this value.
        usize::MAX
    }
}
impl CudaCompiler {
    /// Entry point: compile a whole shader, consuming `self` and the state
    /// accumulated while walking the body.
    fn compile_shader(mut self, mut value: gpu::ComputeShader) -> super::ComputeShader {
        self.num_inputs = value.inputs.len();
        self.num_outputs = value.outputs.len();

        let instructions = self.compile_scope(&mut value.body);
        let body = super::Body {
            instructions,
            // NOTE(review): stride/shape are forced on here even though the
            // compiler tracks them in `self.stride`/`self.shape` — confirm
            // whether the tracked flags should be used instead.
            stride: true,
            shape: true,
            shared_memories: self.shared_memories,
            local_arrays: self.local_arrays,
            rank: self.rank,
            id: self.id,
            invocation_index: self.invocation_index,
            global_invocation_id: self.global_invocation_id,
        };

        super::ComputeShader {
            inputs: value
                .inputs
                .into_iter()
                .map(Self::compile_binding)
                .collect(),
            outputs: value
                .outputs
                .into_iter()
                .map(Self::compile_binding)
                .collect(),
            named: value
                .named
                .into_iter()
                .map(|(name, binding)| (name, Self::compile_binding(binding)))
                .collect(),
            workgroup_size: value.workgroup_size,
            body,
        }
    }

    /// Compile a scope into a list of instructions: bound-check global array
    /// accesses, declare the scope's variables, then compile each operation.
    fn compile_scope(&mut self, value: &mut gpu::Scope) -> Vec<Instruction> {
        let mut instructions = Vec::new();
        let mut processing = value.process();

        for operation in &mut processing.operations {
            if let gpu::Operation::Operator(gpu::Operator::Index(operands)) = operation {
                // Replace all Index operators for global arrays with CheckedIndex procedures
                match operands.lhs {
                    gpu::Variable::GlobalInputArray(_, _)
                    | gpu::Variable::GlobalOutputArray(_, _) => {
                        *operation = gpu::Operation::Procedure(gpu::Procedure::CheckedIndex(
                            gpu::CheckedIndex {
                                lhs: operands.lhs,
                                rhs: operands.rhs,
                                out: operands.out,
                            },
                        ));
                    }
                    // Cannot perform bound check on non-global arrays, do nothing.
                    _ => (),
                }
            }

            if let gpu::Operation::Operator(gpu::Operator::IndexAssign(operands)) = operation {
                // Replace all IndexAssign operators of global arrays with CheckedIndexAssign procedures
                match operands.out {
                    gpu::Variable::GlobalInputArray(_, _)
                    | gpu::Variable::GlobalOutputArray(_, _) => {
                        *operation = gpu::Operation::Procedure(gpu::Procedure::CheckedIndexAssign(
                            gpu::CheckedIndexAssign {
                                lhs: operands.lhs,
                                rhs: operands.rhs,
                                out: operands.out,
                            },
                        ));
                    }
                    // Cannot perform bound check on non-global arrays, do nothing.
                    _ => (),
                }
            }
        }

        // Variable declarations come first so every later instruction can
        // reference them.
        for var in processing.variables {
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .operations
            .into_iter()
            .for_each(|op| self.compile_operation(&mut instructions, op, value));

        instructions
    }

    /// Dispatch one IR operation to the matching compile routine.
    fn compile_operation(
        &mut self,
        instructions: &mut Vec<Instruction>,
        operation: gpu::Operation,
        scope: &mut gpu::Scope,
    ) {
        match operation {
            gpu::Operation::Operator(op) => instructions.push(self.compile_instruction(op)),
            gpu::Operation::Procedure(proc) => self.compile_procedure(instructions, proc, scope),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                gpu::Synchronization::WorkgroupBarrier => {
                    instructions.push(Instruction::SyncThreads)
                }
            },
        }
    }

    /// Compile metadata queries (stride/shape/length) into lookups by
    /// position, where outputs are addressed after all inputs.
    fn compile_metadata(&mut self, metadata: gpu::Metadata) -> Instruction {
        match metadata {
            gpu::Metadata::Stride { dim, var, out } => {
                self.stride = true;
                let position = match var {
                    gpu::Variable::GlobalInputArray(idx, _) => idx as usize,
                    gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
                    _ => panic!("Only Input and Output have a stride, got: {:?}", var),
                };
                Instruction::Stride {
                    dim: self.compile_variable(dim),
                    position,
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Shape { dim, var, out } => {
                self.shape = true;
                let position = match var {
                    gpu::Variable::GlobalInputArray(idx, _) => idx as usize,
                    gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
                    _ => panic!("Only Input and Output have a shape, got {:?}", var),
                };
                Instruction::Shape {
                    dim: self.compile_variable(dim),
                    position,
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::ArrayLength { var, out } => super::Instruction::ArrayLength {
                input: self.compile_variable(var),
                out: self.compile_variable(out),
                num_inputs: self.num_inputs,
                num_outputs: self.num_outputs,
            },
        }
    }

    /// Compile control flow; nested scopes are compiled recursively.
    fn compile_branch(&mut self, instructions: &mut Vec<Instruction>, branch: gpu::Branch) {
        match branch {
            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
                cond: self.compile_variable(op.cond),
                instructions: self.compile_scope(&mut op.scope),
            }),
            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
                cond: self.compile_variable(op.cond),
                instructions_if: self.compile_scope(&mut op.scope_if),
                instructions_else: self.compile_scope(&mut op.scope_else),
            }),
            gpu::Branch::Return => instructions.push(Instruction::Return),
            gpu::Branch::Break => instructions.push(Instruction::Break),
            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
                i: self.compile_variable(range_loop.i),
                start: self.compile_variable(range_loop.start),
                end: self.compile_variable(range_loop.end),
                instructions: self.compile_scope(&mut range_loop.scope),
            }),
            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
                instructions: self.compile_scope(&mut op.scope),
            }),
        };
    }

    /// Compile a procedure by expanding it into the scope, then compiling
    /// the resulting operations in place.
    fn compile_procedure(
        &mut self,
        instructions: &mut Vec<Instruction>,
        proc: gpu::Procedure,
        scope: &mut gpu::Scope,
    ) {
        let mut compile = |scope: &mut gpu::Scope| {
            instructions.extend(self.compile_scope(scope));
        };

        match proc {
            gpu::Procedure::ReadGlobalWithLayout(proc) => {
                proc.expand(scope);
                compile(scope);
            }
            gpu::Procedure::ReadGlobal(proc) => {
                proc.expand(scope);
                compile(scope);
            }
            gpu::Procedure::WriteGlobal(proc) => {
                proc.expand(scope);
                compile(scope);
            }
            gpu::Procedure::ConditionalAssign(proc) => {
                proc.expand(scope);
                compile(scope);
            }
            gpu::Procedure::CheckedIndex(proc) => {
                proc.expand(scope);
                compile(scope);
            }
            gpu::Procedure::CheckedIndexAssign(proc) => {
                proc.expand(scope);
                compile(scope);
            }
            gpu::Procedure::IndexOffsetGlobalWithLayout(proc) => {
                proc.expand(scope);
                compile(scope);
            }
        }
    }

    /// Map each IR operator to its CUDA instruction. Unchecked index
    /// variants lower to the same instruction as the plain ones; `Recip`
    /// lowers to a `1.0 / x` division.
    fn compile_instruction(&mut self, value: gpu::Operator) -> Instruction {
        match value {
            gpu::Operator::Add(op) => Instruction::Add(self.compile_binary(op)),
            gpu::Operator::Mul(op) => Instruction::Mul(self.compile_binary(op)),
            gpu::Operator::Div(op) => Instruction::Div(self.compile_binary(op)),
            gpu::Operator::Sub(op) => Instruction::Sub(self.compile_binary(op)),
            gpu::Operator::Assign(op) => Instruction::Assign(self.compile_unary(op)),
            gpu::Operator::Index(op) => Instruction::Index(self.compile_binary(op)),
            gpu::Operator::UncheckedIndex(op) => Instruction::Index(self.compile_binary(op)),
            gpu::Operator::IndexAssign(op) => Instruction::IndexAssign(self.compile_binary(op)),
            gpu::Operator::UncheckedIndexAssign(op) => {
                Instruction::IndexAssign(self.compile_binary(op))
            }
            gpu::Operator::Modulo(op) => Instruction::Modulo(self.compile_binary(op)),
            gpu::Operator::Equal(op) => Instruction::Equal(self.compile_binary(op)),
            gpu::Operator::Lower(op) => Instruction::Lower(self.compile_binary(op)),
            gpu::Operator::Greater(op) => Instruction::Greater(self.compile_binary(op)),
            gpu::Operator::LowerEqual(op) => Instruction::LowerEqual(self.compile_binary(op)),
            gpu::Operator::GreaterEqual(op) => Instruction::GreaterEqual(self.compile_binary(op)),
            gpu::Operator::Abs(op) => Instruction::Abs(self.compile_unary(op)),
            gpu::Operator::Exp(op) => Instruction::Exp(self.compile_unary(op)),
            gpu::Operator::Log(op) => Instruction::Log(self.compile_unary(op)),
            gpu::Operator::Log1p(op) => Instruction::Log1p(self.compile_unary(op)),
            gpu::Operator::Cos(op) => Instruction::Cos(self.compile_unary(op)),
            gpu::Operator::Sin(op) => Instruction::Sin(self.compile_unary(op)),
            gpu::Operator::Tanh(op) => Instruction::Tanh(self.compile_unary(op)),
            gpu::Operator::Powf(op) => Instruction::Powf(self.compile_binary(op)),
            gpu::Operator::Sqrt(op) => Instruction::Sqrt(self.compile_unary(op)),
            gpu::Operator::Erf(op) => Instruction::Erf(self.compile_unary(op)),
            gpu::Operator::And(op) => Instruction::And(self.compile_binary(op)),
            gpu::Operator::Or(op) => Instruction::Or(self.compile_binary(op)),
            gpu::Operator::Not(op) => Instruction::Not(self.compile_unary(op)),
            gpu::Operator::Max(op) => Instruction::Max(self.compile_binary(op)),
            gpu::Operator::Min(op) => Instruction::Min(self.compile_binary(op)),
            gpu::Operator::NotEqual(op) => Instruction::NotEqual(self.compile_binary(op)),
            gpu::Operator::BitwiseAnd(op) => Instruction::BitwiseAnd(self.compile_binary(op)),
            gpu::Operator::BitwiseXor(op) => Instruction::BitwiseXor(self.compile_binary(op)),
            gpu::Operator::ShiftLeft(op) => Instruction::ShiftLeft(self.compile_binary(op)),
            gpu::Operator::ShiftRight(op) => Instruction::ShiftRight(self.compile_binary(op)),
            gpu::Operator::Clamp(op) => Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(op.out),
            },
            // Recip has no dedicated instruction: emit `out = 1.0 / input`.
            gpu::Operator::Recip(op) => Instruction::Div(super::BinaryInstruction {
                lhs: super::Variable::ConstantScalar(
                    1.0,
                    Self::compile_elem(op.input.item().elem()),
                ),
                rhs: self.compile_variable(op.input),
                out: self.compile_variable(op.out),
            }),
            gpu::Operator::Floor(op) => Instruction::Floor(self.compile_unary(op)),
            gpu::Operator::Ceil(op) => Instruction::Ceil(self.compile_unary(op)),
            gpu::Operator::Remainder(_op) => todo!(),
        }
    }

    /// Compile the three operands of a binary operator.
    fn compile_binary(&mut self, value: gpu::BinaryOperator) -> super::BinaryInstruction {
        super::BinaryInstruction {
            lhs: self.compile_variable(value.lhs),
            rhs: self.compile_variable(value.rhs),
            out: self.compile_variable(value.out),
        }
    }

    /// Compile the two operands of a unary operator.
    fn compile_unary(&mut self, value: gpu::UnaryOperator) -> super::UnaryInstruction {
        super::UnaryInstruction {
            input: self.compile_variable(value.input),
            out: self.compile_variable(value.out),
        }
    }

    /// Compile a variable reference, recording as a side effect which
    /// built-ins the shader uses and which shared memories / local arrays
    /// need to be declared.
    fn compile_variable(&mut self, value: gpu::Variable) -> super::Variable {
        match value {
            gpu::Variable::GlobalInputArray(index, item) => {
                super::Variable::GlobalInputArray(index, Self::compile_item(item))
            }
            gpu::Variable::GlobalScalar(index, elem) => {
                super::Variable::GlobalScalar(index, Self::compile_elem(elem), elem)
            }
            gpu::Variable::Local(index, item, scope_depth) => super::Variable::Local {
                index,
                item: Self::compile_item(item),
                scope_depth,
            },
            gpu::Variable::LocalScalar(index, elem, scope_depth) => super::Variable::LocalScalar {
                index,
                elem: Self::compile_elem(elem),
                scope_depth,
            },
            gpu::Variable::GlobalOutputArray(index, item) => {
                super::Variable::GlobalOutputArray(index, Self::compile_item(item))
            }
            // NOTE(review): the first field here is the constant's value
            // (cf. the `Recip` lowering passing `1.0`), despite the binding
            // name `index`.
            gpu::Variable::ConstantScalar(index, elem) => {
                super::Variable::ConstantScalar(index, Self::compile_elem(elem))
            }
            gpu::Variable::SharedMemory(index, item, size) => {
                let item = Self::compile_item(item);
                // Register each shared memory only once.
                if !self.shared_memories.iter().any(|s| s.index == index) {
                    self.shared_memories
                        .push(super::SharedMemory::new(index, item, size));
                }
                super::Variable::SharedMemory(index, item, size)
            }
            gpu::Variable::Id => {
                self.id = true;
                super::Variable::Id
            }
            gpu::Variable::Rank => {
                self.rank = true;
                super::Variable::Rank
            }
            gpu::Variable::LocalInvocationIndex => {
                self.invocation_index = true;
                super::Variable::LocalInvocationIndex
            }
            gpu::Variable::LocalInvocationIdX => super::Variable::LocalInvocationIdX,
            gpu::Variable::LocalInvocationIdY => super::Variable::LocalInvocationIdY,
            gpu::Variable::LocalInvocationIdZ => super::Variable::LocalInvocationIdZ,
            gpu::Variable::WorkgroupIdX => super::Variable::WorkgroupIdX,
            gpu::Variable::WorkgroupIdY => super::Variable::WorkgroupIdY,
            gpu::Variable::WorkgroupIdZ => super::Variable::WorkgroupIdZ,
            gpu::Variable::GlobalInvocationIdX => {
                self.global_invocation_id.0 = true;
                super::Variable::GlobalInvocationIdX
            }
            gpu::Variable::GlobalInvocationIdY => {
                self.global_invocation_id.1 = true;
                super::Variable::GlobalInvocationIdY
            }
            gpu::Variable::GlobalInvocationIdZ => {
                self.global_invocation_id.2 = true;
                super::Variable::GlobalInvocationIdZ
            }
            gpu::Variable::WorkgroupSizeX => super::Variable::WorkgroupSizeX,
            gpu::Variable::WorkgroupSizeY => super::Variable::WorkgroupSizeY,
            gpu::Variable::WorkgroupSizeZ => super::Variable::WorkgroupSizeZ,
            gpu::Variable::NumWorkgroupsX => super::Variable::NumWorkgroupsX,
            gpu::Variable::NumWorkgroupsY => super::Variable::NumWorkgroupsY,
            gpu::Variable::NumWorkgroupsZ => super::Variable::NumWorkgroupsZ,
            gpu::Variable::LocalArray(id, item, depth, size) => {
                let item = Self::compile_item(item);
                // Register each local array only once per (id, depth) pair.
                if !self
                    .local_arrays
                    .iter()
                    .any(|s| s.index == id && s.depth == depth)
                {
                    self.local_arrays
                        .push(super::LocalArray::new(id, item, depth, size));
                }
                super::Variable::LocalArray(id, item, depth, size)
            }
        }
    }

    /// Compile a buffer binding (item type + optional static size).
    fn compile_binding(binding: gpu::Binding) -> super::Binding {
        super::Binding {
            item: Self::compile_item(binding.item),
            size: binding.size,
        }
    }

    /// Map an IR item (vectorization width + element) to the CUDA item.
    fn compile_item(item: gpu::Item) -> super::Item {
        match item {
            gpu::Item::Vec4(elem) => super::Item::Vec4(Self::compile_elem(elem)),
            gpu::Item::Vec3(elem) => super::Item::Vec3(Self::compile_elem(elem)),
            gpu::Item::Vec2(elem) => super::Item::Vec2(Self::compile_elem(elem)),
            gpu::Item::Scalar(elem) => super::Item::Scalar(Self::compile_elem(elem)),
        }
    }

    /// Map an IR element type to the CUDA element type; f64 and i64 are
    /// rejected explicitly as unsupported.
    fn compile_elem(value: gpu::Elem) -> super::Elem {
        match value {
            gpu::Elem::Float(kind) => match kind {
                gpu::FloatKind::F16 => super::Elem::F16,
                gpu::FloatKind::BF16 => super::Elem::BF16,
                gpu::FloatKind::F32 => super::Elem::F32,
                gpu::FloatKind::F64 => panic!("f64 isn't supported yet"),
            },
            gpu::Elem::Int(kind) => match kind {
                gpu::IntKind::I32 => super::Elem::I32,
                gpu::IntKind::I64 => panic!("i64 isn't supported yet"),
            },
            gpu::Elem::UInt => super::Elem::U32,
            gpu::Elem::Bool => super::Elem::Bool,
        }
    }
}

View File

@ -0,0 +1,483 @@
use super::{Component, Elem, InstructionSettings, Item, Variable};
use std::fmt::Display;
pub trait Binary {
    /// Write the instruction(s) computing `out = lhs <op> rhs`.
    ///
    /// Vector items are emitted as a single native instruction when the
    /// operator's settings allow it and both operands share the same item
    /// type; otherwise the operation is unrolled component by component.
    fn format(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
    ) -> std::fmt::Result {
        let item = out.item();
        let settings = Self::settings(*item.elem());
        // Native vector forms are only valid when both operands have the
        // same item type.
        let same_operands = lhs.item() == rhs.item();

        match item {
            Item::Vec4(elem) if settings.native_vec4 && same_operands => {
                Self::format_native_vec4(f, lhs, rhs, out, elem)
            }
            Item::Vec4(elem) => Self::unroll_vec4(f, lhs, rhs, out, elem),
            Item::Vec3(elem) if settings.native_vec3 && same_operands => {
                Self::format_native_vec3(f, lhs, rhs, out, elem)
            }
            Item::Vec3(elem) => Self::unroll_vec3(f, lhs, rhs, out, elem),
            Item::Vec2(elem) if settings.native_vec2 && same_operands => {
                Self::format_native_vec2(f, lhs, rhs, out, elem)
            }
            Item::Vec2(elem) => Self::unroll_vec2(f, lhs, rhs, out, elem),
            Item::Scalar(elem) => Self::format_scalar(f, *lhs, *rhs, *out, elem),
        }
    }

    /// Vectorization capabilities of this operator; the default unrolls
    /// every vector width.
    fn settings(_elem: Elem) -> InstructionSettings {
        InstructionSettings::default()
    }

    /// Write the scalar form of the instruction. This is the only method
    /// implementors must provide.
    fn format_scalar<Lhs, Rhs, Out>(
        f: &mut std::fmt::Formatter<'_>,
        lhs: Lhs,
        rhs: Rhs,
        out: Out,
        elem: Elem,
    ) -> std::fmt::Result
    where
        Lhs: Component,
        Rhs: Component,
        Out: Component;

    /// Emit one instruction operating on whole vec4 values.
    fn format_native_vec4(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        Self::format_scalar(f, *lhs, *rhs, *out, elem)
    }

    /// Emit one instruction operating on whole vec3 values.
    fn format_native_vec3(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        Self::format_scalar(f, *lhs, *rhs, *out, elem)
    }

    /// Emit one instruction operating on whole vec2 values.
    fn format_native_vec2(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        Self::format_scalar(f, *lhs, *rhs, *out, elem)
    }

    /// Emit one scalar instruction per component of a vec2.
    fn unroll_vec2(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        for i in 0..2 {
            Self::format_scalar(f, lhs.index(i), rhs.index(i), out.index(i), elem)?;
        }
        Ok(())
    }

    /// Emit one scalar instruction per component of a vec3.
    fn unroll_vec3(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        for i in 0..3 {
            Self::format_scalar(f, lhs.index(i), rhs.index(i), out.index(i), elem)?;
        }
        Ok(())
    }

    /// Emit one scalar instruction per component of a vec4.
    fn unroll_vec4(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        for i in 0..4 {
            Self::format_scalar(f, lhs.index(i), rhs.index(i), out.index(i), elem)?;
        }
        Ok(())
    }
}
/// Declares a zero-sized operator type and implements [`Binary`] for it,
/// emitting the infix form `out = lhs <op> rhs;`.
macro_rules! operator {
    // Short form: no native vector support, every width is unrolled.
    ($name:ident, $op:expr) => {
        operator!(
            $name,
            $op,
            InstructionSettings {
                native_vec4: false,
                native_vec3: false,
                native_vec2: false,
            }
        );
    };
    // Full form with explicit vectorization settings.
    ($name:ident, $op:expr, $vectorization:expr) => {
        pub struct $name;

        impl Binary for $name {
            fn format_scalar<Lhs: Display, Rhs: Display, Out: Display>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out: Out,
                _elem: Elem,
            ) -> std::fmt::Result {
                f.write_fmt(format_args!("{out} = {lhs} {} {rhs};\n", $op))
            }

            #[allow(unused_variables)]
            fn settings(elem: Elem) -> InstructionSettings {
                $vectorization
            }
        }
    };
}
/// Declares a zero-sized operator type and implements [`Binary`] for it,
/// emitting the function-call form `out = f(lhs, rhs);`.
macro_rules! function {
    // Short form. NOTE(review): unlike `operator!`, the default here marks
    // vec2 as native (`native_vec2: true`) — confirm this is intentional
    // for all generated functions (powf/max/min).
    ($name:ident, $op:expr) => {
        function!(
            $name,
            $op,
            InstructionSettings {
                native_vec4: false,
                native_vec3: false,
                native_vec2: true,
            }
        );
    };
    // Full form with explicit vectorization settings.
    ($name:ident, $op:expr, $vectorization:expr) => {
        pub struct $name;

        impl Binary for $name {
            fn format_scalar<Lhs: Display, Rhs: Display, Out: Display>(
                f: &mut std::fmt::Formatter<'_>,
                lhs: Lhs,
                rhs: Rhs,
                out: Out,
                _elem: Elem,
            ) -> std::fmt::Result {
                f.write_fmt(format_args!("{out} = {}({lhs}, {rhs});\n", $op))
            }

            #[allow(unused_variables)]
            fn settings(elem: Elem) -> InstructionSettings {
                $vectorization
            }
        }
    };
}
// Arithmetic operators.
operator!(Add, "+");
operator!(Sub, "-");
operator!(Div, "/");
operator!(Mul, "*");
operator!(Modulo, "%");
// Comparison operators.
operator!(Equal, "==");
operator!(NotEqual, "!=");
operator!(Lower, "<");
operator!(LowerEqual, "<=");
operator!(Greater, ">");
operator!(GreaterEqual, ">=");
// Bitwise and logical operators.
operator!(ShiftLeft, "<<");
operator!(ShiftRight, ">>");
operator!(BitwiseAnd, "&");
operator!(BitwiseXor, "^");
operator!(Or, "||");
operator!(And, "&&");
// Call-style binary functions: `out = <fn>(lhs, rhs);`.
function!(Powf, "powf");
function!(Max, "max");
function!(Min, "min");
/// Lowers `out[lhs] = rhs` (element write through an index).
pub struct IndexAssign;
/// Lowers `out = lhs[rhs]` (element read through an index).
pub struct Index;
impl Binary for IndexAssign {
    /// Writes `out[lhs] = rhs`, inserting a cast when the element type of
    /// the value differs from the element type of the destination.
    fn format_scalar<Lhs, Rhs, Out>(
        f: &mut std::fmt::Formatter<'_>,
        lhs: Lhs,
        rhs: Rhs,
        out: Out,
        elem: Elem,
    ) -> std::fmt::Result
    where
        Lhs: Component,
        Rhs: Component,
        Out: Component,
    {
        let elem_rhs = rhs.elem();
        // Cast only when necessary.
        if elem != elem_rhs {
            if let Elem::Bool = elem_rhs {
                // A bool vector can't be cast as a whole: cast each component
                // and rebuild the vector.
                // NOTE(review): `rhs.item()` on an `IndexedVariable` returns
                // the item of the *underlying* variable, so an already-indexed
                // rhs still takes the Vec4 branch — confirm this is intended.
                match rhs.item() {
                    Item::Vec4(_) => {
                        f.write_fmt(format_args!("{out}[{lhs}] = make_uint4({elem}({rhs}.x), {elem}({rhs}.y), {elem}({rhs}.z), {elem}({rhs}.w));\n"))
                    },
                    // Bool vec3/vec2/scalar casts are not implemented yet and
                    // will panic during code generation.
                    Item::Vec3(_) => todo!(),
                    Item::Vec2(_) => todo!(),
                    Item::Scalar(_) => todo!(),
                }
            } else {
                f.write_fmt(format_args!("{out}[{lhs}] = {elem}({rhs});\n"))
            }
        } else {
            f.write_fmt(format_args!("{out}[{lhs}] = {rhs};\n"))
        }
    }
    /// Unrolls a two-wide assignment: the destination array stays the same,
    /// only the index (lhs) and value (rhs) are taken per lane.
    fn unroll_vec2(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        let lhs0 = lhs.index(0);
        let lhs1 = lhs.index(1);
        let rhs0 = rhs.index(0);
        let rhs1 = rhs.index(1);
        Self::format_scalar(f, lhs0, rhs0, *out, elem)?;
        Self::format_scalar(f, lhs1, rhs1, *out, elem)?;
        Ok(())
    }
    /// Same as [`Self::unroll_vec2`] for three lanes.
    fn unroll_vec3(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        let lhs0 = lhs.index(0);
        let lhs1 = lhs.index(1);
        let lhs2 = lhs.index(2);
        let rhs0 = rhs.index(0);
        let rhs1 = rhs.index(1);
        let rhs2 = rhs.index(2);
        Self::format_scalar(f, lhs0, rhs0, *out, elem)?;
        Self::format_scalar(f, lhs1, rhs1, *out, elem)?;
        Self::format_scalar(f, lhs2, rhs2, *out, elem)?;
        Ok(())
    }
    /// Same as [`Self::unroll_vec2`] for four lanes.
    fn unroll_vec4(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        let lhs0 = lhs.index(0);
        let lhs1 = lhs.index(1);
        let lhs2 = lhs.index(2);
        let lhs3 = lhs.index(3);
        let rhs0 = rhs.index(0);
        let rhs1 = rhs.index(1);
        let rhs2 = rhs.index(2);
        let rhs3 = rhs.index(3);
        Self::format_scalar(f, lhs0, rhs0, *out, elem)?;
        Self::format_scalar(f, lhs1, rhs1, *out, elem)?;
        Self::format_scalar(f, lhs2, rhs2, *out, elem)?;
        Self::format_scalar(f, lhs3, rhs3, *out, elem)?;
        Ok(())
    }
    fn format(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
    ) -> std::fmt::Result {
        // Writing into a local (possibly vectorized) value is not an array
        // store; delegate to the component-assignment path.
        if let Variable::Local {
            index: _,
            item: _,
            scope_depth: _,
        } = out
        {
            return IndexAssignVector::format(f, lhs, rhs, out);
        };
        let elem = out.elem();
        // Unrolling is driven by the vectorization of the index (lhs).
        match lhs.item() {
            Item::Vec4(_) => Self::unroll_vec4(f, lhs, rhs, out, elem),
            Item::Vec3(_) => Self::unroll_vec3(f, lhs, rhs, out, elem),
            Item::Vec2(_) => Self::unroll_vec2(f, lhs, rhs, out, elem),
            Item::Scalar(_) => Self::format_scalar(f, *lhs, *rhs, *out, elem),
        }
    }
}
impl Binary for Index {
    /// Lowers an element read. Reading from a local variable means picking a
    /// component out of a vectorized value rather than indexing a global
    /// array, which needs specialized handling.
    fn format(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
    ) -> std::fmt::Result {
        match lhs {
            Variable::Local { .. } => IndexVector::format(f, lhs, rhs, out),
            _ => Self::format_scalar(f, *lhs, *rhs, *out, out.elem()),
        }
    }

    /// Writes the plain array read: `out = lhs[rhs];`.
    fn format_scalar<Lhs, Rhs, Out>(
        f: &mut std::fmt::Formatter<'_>,
        lhs: Lhs,
        rhs: Rhs,
        out: Out,
        _elem: Elem,
    ) -> std::fmt::Result
    where
        Lhs: Component,
        Rhs: Component,
        Out: Component,
    {
        write!(f, "{out} = {lhs}[{rhs}];\n")
    }
}
/// The goal is to support indexing of vectorized types.
///
/// # Examples
///
/// ```c
/// float4 var;
/// float item = var[0]; // We want that.
/// float item = var.x; // So we compile to that.
/// ```
struct IndexVector;
/// The goal is to support indexing of vectorized types.
///
/// # Examples
///
/// ```c
/// float4 var;
///
/// var[0] = 1.0; // We want that.
/// var.x = 1.0; // So we compile to that.
/// ```
struct IndexAssignVector;
impl IndexVector {
    /// Writes `out = lhs[rhs]` where `lhs` is a vectorized local value.
    fn format(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
    ) -> std::fmt::Result {
        if let Variable::ConstantScalar(value, _elem) = rhs {
            // Compile-time index: select the component directly (e.g. `.x`).
            let index = *value as usize;
            let out = out.index(index);
            let lhs = lhs.index(index);
            write!(f, "{out} = {lhs};\n")
        } else {
            // Runtime index: reinterpret the vector as an array of scalars.
            let elem = out.elem();
            write!(f, "{out} = *(({elem}*)&{lhs} + {rhs});\n")
        }
    }
}
impl IndexAssignVector {
    /// Writes `out[lhs] = rhs` where `out` is a vectorized local value.
    fn format(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
    ) -> std::fmt::Result {
        if let Variable::ConstantScalar(value, _) = lhs {
            // Compile-time index: assign the component directly (e.g. `.x`).
            let index = *value as usize;
            let out = out.index(index);
            let rhs = rhs.index(index);
            write!(f, "{out} = {rhs};\n")
        } else {
            // Runtime index: reinterpret the vector as an array of scalars.
            let elem = out.elem();
            write!(f, "*(({elem}*)&{out} + {lhs}) = {rhs};\n")
        }
    }
}

View File

@ -0,0 +1,81 @@
use super::Instruction;
use std::fmt::Display;
/// A body is composed of a list of [instructions](Instruction).
#[derive(Debug, Clone)]
pub struct Body {
    pub instructions: Vec<Instruction>,
    // Shared-memory declarations emitted before the instructions.
    pub shared_memories: Vec<super::SharedMemory>,
    // Thread-local array declarations emitted before the instructions.
    pub local_arrays: Vec<super::LocalArray>,
    // The flags below record which built-in values the instructions use,
    // so the generated preamble only declares what is actually needed.
    pub stride: bool,
    pub shape: bool,
    pub id: bool,
    pub rank: bool,
    pub invocation_index: bool,
    pub global_invocation_id: (bool, bool, bool),
}
impl Display for Body {
    /// Writes the kernel body: a preamble declaring the built-in values that
    /// are used, shared-memory/local-array declarations, then the lowered
    /// instructions.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `id` is computed from the global invocation id, so either flag
        // forces its declaration.
        if self.id
            || self.global_invocation_id.0
            || self.global_invocation_id.1
            || self.global_invocation_id.2
        {
            f.write_str(
                "
int3 globalInvocationId = make_int3(
blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y,
blockIdx.z * blockDim.z + threadIdx.z
);
",
            )?;
        }
        // Flat thread id over the (x, y) grid.
        if self.id {
            f.write_str(
                "
uint id = globalInvocationId.y * (blockDim.x * gridDim.x) + globalInvocationId.x;
",
            )?;
        }
        // Flat index of the thread within its block.
        if self.invocation_index {
            f.write_str(
                "
int invocationIndex = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.x * blockDim.y);
",
            )?;
        }
        // `info[0]` holds the tensor rank; strides/shapes also need it.
        if self.rank || self.stride || self.shape {
            f.write_str("uint rank = info[0];\n")?;
        }
        // Each array contributes `rank` strides plus `rank` shapes in `info`.
        if self.stride || self.shape {
            f.write_str("uint rank_2 = rank * 2;\n")?;
        }
        for shared in self.shared_memories.iter() {
            f.write_fmt(format_args!(
                "__shared__ {} shared_memory_{}[{}];\n",
                shared.item, shared.index, shared.size
            ))?;
        }
        // Local arrays
        for array in self.local_arrays.iter() {
            f.write_fmt(format_args!(
                "{} l_arr_{}_{}[{}];\n\n",
                array.item, array.index, array.depth, array.size
            ))?;
        }
        for ops in self.instructions.iter() {
            f.write_fmt(format_args!("{ops}"))?;
        }
        Ok(())
    }
}

View File

@ -0,0 +1,309 @@
use burn_jit::gpu;
use half::{bf16, f16};
use std::fmt::Display;
/// Scalar element types supported by the CUDA dialect.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum Elem {
    F32,
    F16,
    BF16,
    I32,
    U32,
    Bool,
}
/// An element type together with its vectorization width (1, 2, 3 or 4).
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum Item {
    Vec4(Elem),
    Vec3(Elem),
    Vec2(Elem),
    Scalar(Elem),
}
impl Display for Elem {
    /// Writes the CUDA type name used for this element in generated source.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Elem::F16 => "f16",
            Elem::F32 => "float",
            Elem::BF16 => "bf16",
            Elem::I32 => "int",
            Elem::U32 => "uint",
            Elem::Bool => "bool",
        };
        f.write_str(name)
    }
}
impl Display for Item {
    /// Writes the CUDA type name for this item. Vectorized names are the
    /// scalar name (as produced by [`Elem`]'s `Display`) followed by the
    /// width, e.g. `float4`, `uint2`, `bf163` — exactly the strings the
    /// rest of the code generator expects.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Item::Vec4(elem) => write!(f, "{elem}4"),
            Item::Vec3(elem) => write!(f, "{elem}3"),
            Item::Vec2(elem) => write!(f, "{elem}2"),
            Item::Scalar(elem) => write!(f, "{elem}"),
        }
    }
}
/// Anything that can appear as an operand in generated source and knows its
/// own type ([`Item`]).
pub trait Component: Display {
    fn item(&self) -> Item;
    /// The element type of this component, derived from its item.
    fn elem(&self) -> Elem {
        *self.item().elem()
    }
}
impl Component for IndexedVariable {
    // NOTE(review): this reports the item of the underlying variable, not a
    // scalar, even though an indexed component denotes a single lane —
    // callers (e.g. `IndexAssign::format_scalar`) appear to rely on this;
    // confirm before changing.
    fn item(&self) -> Item {
        self.var.item()
    }
}
impl Component for Variable {
fn item(&self) -> Item {
match self {
Variable::GlobalInputArray(_, e) => *e,
Variable::GlobalOutputArray(_, e) => *e,
Variable::SharedMemory(_, e, _) => *e,
Variable::Local {
index: _,
item,
scope_depth: _,
} => *item,
Variable::ConstantScalar(_, e) => Item::Scalar(*e),
Variable::GlobalScalar(_, e, _) => Item::Scalar(*e),
Variable::Id => Item::Scalar(Elem::U32),
Variable::LocalInvocationIndex => Item::Scalar(Elem::U32),
Variable::LocalInvocationIdX => Item::Scalar(Elem::U32),
Variable::LocalInvocationIdY => Item::Scalar(Elem::U32),
Variable::LocalInvocationIdZ => Item::Scalar(Elem::U32),
Variable::Rank => Item::Scalar(Elem::U32),
Variable::LocalScalar {
index: _,
elem,
scope_depth: _,
} => Item::Scalar(*elem),
Variable::WorkgroupIdX => Item::Scalar(Elem::U32),
Variable::WorkgroupIdY => Item::Scalar(Elem::U32),
Variable::WorkgroupIdZ => Item::Scalar(Elem::U32),
Variable::GlobalInvocationIdX => Item::Scalar(Elem::U32),
Variable::GlobalInvocationIdY => Item::Scalar(Elem::U32),
Variable::GlobalInvocationIdZ => Item::Scalar(Elem::U32),
Variable::WorkgroupSizeX => Item::Scalar(Elem::U32),
Variable::WorkgroupSizeY => Item::Scalar(Elem::U32),
Variable::WorkgroupSizeZ => Item::Scalar(Elem::U32),
Variable::NumWorkgroupsX => Item::Scalar(Elem::U32),
Variable::NumWorkgroupsY => Item::Scalar(Elem::U32),
Variable::NumWorkgroupsZ => Item::Scalar(Elem::U32),
Variable::LocalArray(_, e, _, _) => *e,
}
}
}
/// Every kind of value the code generator can reference in generated source.
#[derive(Debug, Clone, Copy)]
pub enum Variable {
    GlobalInputArray(u16, Item),
    GlobalOutputArray(u16, Item),
    // Stores both the dialect element type and the original gpu element,
    // since scalars are grouped per-gpu-type in the `scalars_*` buffers.
    GlobalScalar(u16, Elem, gpu::Elem),
    // A compile-time constant; the value is stored as f64 and cast on use.
    ConstantScalar(f64, Elem),
    Local {
        index: u16,
        item: Item,
        scope_depth: u8,
    },
    LocalScalar {
        index: u16,
        elem: Elem,
        scope_depth: u8,
    },
    // (index, item, size in elements)
    SharedMemory(u16, Item, u32),
    // (index, item, scope depth, size in elements)
    LocalArray(u16, Item, u8, u32),
    // Built-in thread/block identifiers and dimensions below.
    Id,
    LocalInvocationIndex,
    LocalInvocationIdX,
    LocalInvocationIdY,
    LocalInvocationIdZ,
    Rank,
    WorkgroupIdX,
    WorkgroupIdY,
    WorkgroupIdZ,
    GlobalInvocationIdX,
    GlobalInvocationIdY,
    GlobalInvocationIdZ,
    WorkgroupSizeX,
    WorkgroupSizeY,
    WorkgroupSizeZ,
    NumWorkgroupsX,
    NumWorkgroupsY,
    NumWorkgroupsZ,
}
impl Display for Variable {
    /// Writes the identifier or expression this variable compiles to.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Variable::GlobalInputArray(number, _) => write!(f, "input_{number}"),
            Variable::GlobalOutputArray(number, _) => write!(f, "output_{number}"),
            Variable::GlobalScalar(number, _, elem) => write!(f, "scalars_{elem}[{number}]"),
            // Constants are emitted as a cast of the literal value.
            Variable::ConstantScalar(number, elem) => write!(f, "{elem}({number})"),
            Variable::Local {
                index, scope_depth, ..
            } => write!(f, "l_{scope_depth}_{index}"),
            Variable::LocalScalar {
                index, scope_depth, ..
            } => write!(f, "s_{scope_depth}_{index}"),
            Variable::SharedMemory(number, _, _) => write!(f, "shared_memory_{number}"),
            Variable::LocalArray(id, _, depth, _) => write!(f, "l_arr_{id}_{depth}"),
            // Built-ins map to CUDA intrinsics or preamble-declared values.
            Variable::Id => f.write_str("id"),
            Variable::LocalInvocationIndex => f.write_str("invocationIndex"),
            Variable::LocalInvocationIdX => f.write_str("threadIdx.x"),
            Variable::LocalInvocationIdY => f.write_str("threadIdx.y"),
            Variable::LocalInvocationIdZ => f.write_str("threadIdx.z"),
            Variable::Rank => f.write_str("rank"),
            Variable::WorkgroupIdX => f.write_str("blockIdx.x"),
            Variable::WorkgroupIdY => f.write_str("blockIdx.y"),
            Variable::WorkgroupIdZ => f.write_str("blockIdx.z"),
            Variable::WorkgroupSizeX => f.write_str("blockDim.x"),
            Variable::WorkgroupSizeY => f.write_str("blockDim.y"),
            Variable::WorkgroupSizeZ => f.write_str("blockDim.z"),
            Variable::NumWorkgroupsX => f.write_str("gridDim.x"),
            Variable::NumWorkgroupsY => f.write_str("gridDim.y"),
            Variable::NumWorkgroupsZ => f.write_str("gridDim.z"),
            Variable::GlobalInvocationIdX => f.write_str("globalInvocationId.x"),
            Variable::GlobalInvocationIdY => f.write_str("globalInvocationId.y"),
            Variable::GlobalInvocationIdZ => f.write_str("globalInvocationId.z"),
        }
    }
}
impl Variable {
    /// Whether the variable always denotes a single scalar value, regardless
    /// of the kernel's vectorization.
    pub fn is_always_scalar(&self) -> bool {
        match self {
            // Buffers and local values may be vectorized.
            Variable::GlobalInputArray(_, _)
            | Variable::GlobalOutputArray(_, _)
            | Variable::SharedMemory(_, _, _)
            | Variable::LocalArray(_, _, _, _)
            | Variable::Local { .. } => false,
            // Scalars and built-in indices/dimensions never are.
            Variable::GlobalScalar(_, _, _)
            | Variable::ConstantScalar(_, _)
            | Variable::LocalScalar { .. }
            | Variable::Id
            | Variable::LocalInvocationIndex
            | Variable::LocalInvocationIdX
            | Variable::LocalInvocationIdY
            | Variable::LocalInvocationIdZ
            | Variable::Rank
            | Variable::WorkgroupIdX
            | Variable::WorkgroupIdY
            | Variable::WorkgroupIdZ
            | Variable::GlobalInvocationIdX
            | Variable::GlobalInvocationIdY
            | Variable::GlobalInvocationIdZ
            | Variable::WorkgroupSizeX
            | Variable::WorkgroupSizeY
            | Variable::WorkgroupSizeZ
            | Variable::NumWorkgroupsX
            | Variable::NumWorkgroupsY
            | Variable::NumWorkgroupsZ => true,
        }
    }

    /// Creates a view on a single component of this (vectorized) variable.
    pub fn index(&self, index: usize) -> IndexedVariable {
        IndexedVariable { var: *self, index }
    }
}
/// A view on one component (lane) of a [`Variable`].
#[derive(Debug, Clone)]
pub struct IndexedVariable {
    var: Variable,
    index: usize,
}
impl Display for IndexedVariable {
    /// Writes `<var>.<component>` for vectorized variables (`.x`/`.y`/`.z`/
    /// `.w`), or just the variable itself when it is scalar.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let var = &self.var;
        // Width of the vectorized item; scalars have no component suffix.
        let width = match self.var.item() {
            Item::Vec4(_) => 4,
            Item::Vec3(_) => 3,
            Item::Vec2(_) => 2,
            Item::Scalar(_) => return write!(f, "{var}"),
        };
        if self.index >= width {
            // Out-of-range lane indices are a code-generator bug.
            unreachable!()
        }
        let component = ["x", "y", "z", "w"][self.index];
        write!(f, "{var}.{component}")
    }
}
impl Item {
    /// The scalar element type of this item, ignoring vectorization.
    pub fn elem(&self) -> &Elem {
        match self {
            Item::Vec4(elem) | Item::Vec3(elem) | Item::Vec2(elem) | Item::Scalar(elem) => elem,
        }
    }
}
impl Elem {
pub fn size(&self) -> usize {
match self {
Self::F32 => core::mem::size_of::<f32>(),
Self::F16 => core::mem::size_of::<f16>(),
Self::BF16 => core::mem::size_of::<bf16>(),
Self::I32 => core::mem::size_of::<i32>(),
Self::U32 => core::mem::size_of::<u32>(),
Self::Bool => core::mem::size_of::<bool>(),
}
}
}

View File

@ -0,0 +1,233 @@
use super::{binary::*, unary::*, Component, Variable};
use std::fmt::Display;
/// Operands of a binary operation: `out = op(lhs, rhs)`.
#[derive(Debug, Clone)]
pub struct BinaryInstruction {
    pub lhs: Variable,
    pub rhs: Variable,
    pub out: Variable,
}
/// Operands of a unary operation: `out = op(input)`.
#[derive(Debug, Clone)]
pub struct UnaryInstruction {
    pub input: Variable,
    pub out: Variable,
}
/// All instructions of the CUDA dialect; each lowers to one or more lines
/// of CUDA C source via its [`Display`] implementation.
#[derive(Debug, Clone)]
pub enum Instruction {
    /// Reads the length of a global array from the `info` metadata buffer.
    /// The total number of kernel arrays is needed to find the offset of the
    /// length section.
    ArrayLength {
        input: Variable,
        out: Variable,
        num_inputs: usize,
        num_outputs: usize,
    },
    /// Declares `var` without initializing it.
    DeclareVariable {
        var: Variable,
    },
    Modulo(BinaryInstruction),
    Add(BinaryInstruction),
    Div(BinaryInstruction),
    Mul(BinaryInstruction),
    Sub(BinaryInstruction),
    Index(BinaryInstruction),
    IndexAssign(BinaryInstruction),
    CheckedIndexAssign(BinaryInstruction),
    Assign(UnaryInstruction),
    /// `for (uint i = start; i < end; i++) { ... }`
    RangeLoop {
        i: Variable,
        start: Variable,
        end: Variable,
        instructions: Vec<Self>,
    },
    /// Infinite loop; exited via `Break`/`Return`.
    Loop {
        instructions: Vec<Self>,
    },
    If {
        cond: Variable,
        instructions: Vec<Self>,
    },
    IfElse {
        cond: Variable,
        instructions_if: Vec<Self>,
        instructions_else: Vec<Self>,
    },
    Return,
    Break,
    /// Reads stride `dim` of array `position` from the `info` buffer.
    Stride {
        dim: Variable,
        position: usize,
        out: Variable,
    },
    /// Reads shape `dim` of array `position` from the `info` buffer.
    Shape {
        dim: Variable,
        position: usize,
        out: Variable,
    },
    Equal(BinaryInstruction),
    NotEqual(BinaryInstruction),
    Lower(BinaryInstruction),
    Greater(BinaryInstruction),
    LowerEqual(BinaryInstruction),
    GreaterEqual(BinaryInstruction),
    Erf(UnaryInstruction),
    BitwiseAnd(BinaryInstruction),
    BitwiseXor(BinaryInstruction),
    ShiftLeft(BinaryInstruction),
    ShiftRight(BinaryInstruction),
    Abs(UnaryInstruction),
    Exp(UnaryInstruction),
    Log(UnaryInstruction),
    Log1p(UnaryInstruction),
    Cos(UnaryInstruction),
    Sin(UnaryInstruction),
    Tanh(UnaryInstruction),
    Powf(BinaryInstruction),
    Sqrt(UnaryInstruction),
    Min(BinaryInstruction),
    Max(BinaryInstruction),
    Not(UnaryInstruction),
    Or(BinaryInstruction),
    And(BinaryInstruction),
    /// `out = clamp(input, min_value, max_value)`, lowered as min-then-max.
    Clamp {
        input: Variable,
        min_value: Variable,
        max_value: Variable,
        out: Variable,
    },
    /// Block-wide barrier (`__syncthreads()`).
    SyncThreads,
    Ceil(UnaryInstruction),
    Floor(UnaryInstruction),
}
impl Display for Instruction {
    /// Lowers the instruction to CUDA C source code.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Instruction::Return => f.write_str("return;"),
            Instruction::Break => f.write_str("break;"),
            Instruction::DeclareVariable { var } => {
                let item = var.item();
                f.write_fmt(format_args!("{item} {var};\n"))
            }
            Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
            Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Index(it) => Index::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::IndexAssign(it) => IndexAssign::format(f, &it.lhs, &it.rhs, &it.out),
            // NOTE(review): lowered identically to `IndexAssign` — no bounds
            // check is emitted here; confirm the check is handled earlier in
            // the compilation pipeline or still to be implemented.
            Instruction::CheckedIndexAssign(it) => {
                IndexAssign::format(f, &it.lhs, &it.rhs, &it.out)
            }
            Instruction::Assign(it) => Assign::format(f, &it.input, &it.out),
            Instruction::RangeLoop {
                i,
                start,
                end,
                instructions,
            } => {
                f.write_fmt(format_args!(
                    "
for (uint {i} = {start}; {i} < {end}; {i}++) {{
"
                ))?;
                for instruction in instructions {
                    f.write_fmt(format_args!("{instruction}"))?;
                }
                f.write_str("}\n")
            }
            Instruction::Loop { instructions } => {
                f.write_fmt(format_args!("while (true) {{\n"))?;
                for i in instructions {
                    f.write_fmt(format_args!("{i}"))?;
                }
                f.write_str("}\n")
            }
            Instruction::If { cond, instructions } => {
                f.write_fmt(format_args!("if ({cond}) {{\n"))?;
                for i in instructions {
                    f.write_fmt(format_args!("{i}"))?;
                }
                f.write_str("}\n")
            }
            Instruction::IfElse {
                cond,
                instructions_if,
                instructions_else,
            } => {
                f.write_fmt(format_args!("if ({cond}) {{\n"))?;
                for i in instructions_if {
                    f.write_fmt(format_args!("{i}"))?;
                }
                f.write_str("} else {\n")?;
                for i in instructions_else {
                    f.write_fmt(format_args!("{i}"))?;
                }
                f.write_str("}\n")
            }
            // `info` layout (as read here): info[0] = rank, then for each
            // array `rank` strides followed by `rank` shapes (the `+ 1`
            // skips the rank entry itself).
            Instruction::Stride { dim, position, out } => f.write_fmt(format_args!(
                "{out} = info[({position} * rank_2) + {dim} + 1];\n"
            )),
            Instruction::Shape { dim, position, out } => f.write_fmt(format_args!(
                "{out} = info[({position} * rank_2) + rank + {dim} + 1];\n"
            )),
            Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Erf(it) => Erf::format(f, &it.input, &it.out),
            Instruction::Abs(it) => Abs::format(f, &it.input, &it.out),
            Instruction::Exp(it) => Exp::format(f, &it.input, &it.out),
            Instruction::Log(it) => Log::format(f, &it.input, &it.out),
            Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out),
            Instruction::Cos(it) => Cos::format(f, &it.input, &it.out),
            Instruction::Sin(it) => Sin::format(f, &it.input, &it.out),
            Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out),
            Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out),
            Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::Not(it) => Not::format(f, &it.input, &it.out),
            Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out),
            Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out),
            // Clamp lowers as min-then-max so values saturate at both bounds.
            Instruction::Clamp {
                input,
                min_value,
                max_value,
                out,
            } => f.write_fmt(format_args!(
                "
{out} = min({input}, {max_value});
{out} = max({out}, {min_value});
"
            )),
            Instruction::SyncThreads => f.write_str("__syncthreads();\n"),
            Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
            Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
            Instruction::ArrayLength {
                input,
                out,
                num_inputs,
                num_outputs,
            } => {
                // Lengths live after the stride/shape section, which holds
                // `offset * 2 * rank` entries; the index is 1-based because
                // info[0] is the rank.
                let offset = num_inputs + num_outputs;
                let index = match input {
                    Variable::GlobalInputArray(index, _) => *index as usize,
                    Variable::GlobalOutputArray(index, _) => *index as usize + num_inputs,
                    _ => panic!("Can only know the len of a global array."),
                } + 1;
                f.write_fmt(format_args!(
                    "{out} = info[({offset} * 2 * info[0]) + {index}];\n"
                ))
            }
        }
    }
}

View File

@ -0,0 +1,16 @@
// Formatting helpers for binary and unary operations.
pub mod binary;
pub mod unary;
// CUDA dialect: syntax elements and their source-code lowering.
mod base;
mod body;
mod element;
mod instruction;
mod settings;
mod shader;
pub use base::*;
pub use body::*;
pub use element::*;
pub use instruction::*;
pub use settings::*;
pub use shader::*;

View File

@ -0,0 +1,6 @@
/// Declares which vector widths an instruction supports natively; widths
/// that are not native are unrolled into per-lane scalar instructions.
#[derive(Debug, Default)]
pub struct InstructionSettings {
    pub native_vec4: bool,
    pub native_vec3: bool,
    pub native_vec2: bool,
}

View File

@ -0,0 +1,155 @@
// use super::{Body, Extension, Item};
use super::{Body, Item};
use burn_jit::{gpu::WorkgroupSize, CompilerRepresentation};
use std::fmt::Display;
/// Address space of a binding.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Location {
    Storage,
    #[allow(dead_code)]
    Workgroup,
}
/// Access mode of a binding.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Visibility {
    Read,
    ReadWrite,
}
/// A buffer bound to the generated kernel.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Binding {
    pub item: Item,
    // Element count — presumably `None` when only known at launch time;
    // TODO(review) confirm.
    pub size: Option<usize>,
}
/// A statically-sized shared-memory declaration (size in elements).
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct SharedMemory {
    pub index: u16,
    pub item: Item,
    pub size: u32,
}
/// A thread-local array declaration (size in elements).
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct LocalArray {
    pub index: u16,
    pub item: Item,
    // Scope depth at which the array was declared.
    pub depth: u8,
    pub size: u32,
}
impl LocalArray {
    /// Creates a new local-array declaration.
    pub fn new(index: u16, item: Item, depth: u8, size: u32) -> Self {
        Self { index, item, depth, size }
    }
}
impl SharedMemory {
    /// Creates a new shared-memory declaration.
    pub fn new(index: u16, item: Item, size: u32) -> Self {
        Self { index, item, size }
    }
}
/// A complete compiled kernel: its bindings, launch configuration and body.
/// Its `Display` implementation produces the CUDA C source.
#[derive(Debug, Clone)]
pub struct ComputeShader {
    pub inputs: Vec<Binding>,
    pub outputs: Vec<Binding>,
    // Extra named buffers (e.g. scalar and metadata buffers).
    pub named: Vec<(String, Binding)>,
    pub workgroup_size: WorkgroupSize,
    pub body: Body,
}
impl CompilerRepresentation for ComputeShader {
    /// Total amount of statically-declared shared memory, in bytes.
    fn shared_memory_size(&self) -> usize {
        self.body
            .shared_memories
            .iter()
            .map(|var| {
                // Bytes per declared element = lanes * scalar size.
                let num_lanes = match var.item {
                    Item::Vec4(_) => 4,
                    Item::Vec3(_) => 3,
                    Item::Vec2(_) => 2,
                    Item::Scalar(_) => 1,
                };
                (var.size as usize) * num_lanes * var.item.elem().size()
            })
            .sum()
    }
}
impl Display for ComputeShader {
    /// Writes the full CUDA C source of the kernel: a small preamble (the
    /// `uint` typedef and the `bool4` helper struct), the kernel signature
    /// with every binding as an array parameter, then the body.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "
typedef unsigned int uint;
extern \"C\" struct bool4 {{
bool x;
bool y;
bool z;
bool w;
}};
extern \"C\" __global__ void kernel(
",
        ))?;
        // Parameters appear in order: inputs, outputs, then named buffers,
        // separated by commas (no trailing comma).
        let params: Vec<String> = self
            .inputs
            .iter()
            .enumerate()
            .map(|(index, binding)| format!("{} input_{}[]", binding.item, index))
            .chain(
                self.outputs
                    .iter()
                    .enumerate()
                    .map(|(index, binding)| format!("{} output_{}[]", binding.item, index)),
            )
            .chain(
                self.named
                    .iter()
                    .map(|(name, binding)| format!("{} {}[]", binding.item, name)),
            )
            .collect();
        f.write_str(&params.join(","))?;
        f.write_str("\n) {\n")?;
        write!(f, "{}", self.body)?;
        f.write_str("\n}")?;
        Ok(())
    }
}
// Placeholder for inherent `ComputeShader` methods (none yet).
impl ComputeShader {}
impl Display for Location {
    /// Writes the textual name of the address space.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Location::Storage => "storage",
            Location::Workgroup => "workgroup",
        };
        f.write_str(name)
    }
}
impl Display for Visibility {
    /// Writes the textual name of the access mode.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Visibility::Read => "read",
            Visibility::ReadWrite => "read_write",
        };
        f.write_str(name)
    }
}

View File

@ -0,0 +1,210 @@
use super::{Component, Elem, InstructionSettings, Item, Variable};
use std::fmt::Display;
pub trait Unary {
    /// Writes `out = op(input)`, choosing between a native vectorized form
    /// and per-lane unrolling based on [`Self::settings`].
    fn format(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable,
        out: &Variable,
    ) -> std::fmt::Result {
        let item = out.item();
        let settings = Self::settings(*item.elem());
        match item {
            Item::Vec4(elem) if settings.native_vec4 => {
                Self::format_native_vec4(f, input, out, elem)
            }
            Item::Vec4(elem) => Self::unroll_vec4(f, input, out, elem),
            Item::Vec3(elem) if settings.native_vec3 => {
                Self::format_native_vec3(f, input, out, elem)
            }
            Item::Vec3(elem) => Self::unroll_vec3(f, input, out, elem),
            Item::Vec2(elem) if settings.native_vec2 => {
                Self::format_native_vec2(f, input, out, elem)
            }
            Item::Vec2(elem) => Self::unroll_vec2(f, input, out, elem),
            Item::Scalar(elem) => Self::format_scalar(f, *input, *out, elem),
        }
    }

    /// Which vector widths the operation supports natively; none by default.
    fn settings(_elem: Elem) -> InstructionSettings {
        InstructionSettings::default()
    }

    /// Writes the scalar form of the operation.
    fn format_scalar<Input, Out>(
        f: &mut std::fmt::Formatter<'_>,
        input: Input,
        out: Out,
        elem: Elem,
    ) -> std::fmt::Result
    where
        Input: Component,
        Out: Component;

    /// Native vectorized forms reuse the scalar syntax on whole vectors.
    fn format_native_vec4(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        Self::format_scalar(f, *input, *out, elem)
    }

    fn format_native_vec3(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        Self::format_scalar(f, *input, *out, elem)
    }

    fn format_native_vec2(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        Self::format_scalar(f, *input, *out, elem)
    }

    /// Expands a two-wide operation into one scalar instruction per lane.
    fn unroll_vec2(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        for lane in 0..2 {
            Self::format_scalar(f, input.index(lane), out.index(lane), elem)?;
        }
        Ok(())
    }

    /// Expands a three-wide operation into one scalar instruction per lane.
    fn unroll_vec3(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        for lane in 0..3 {
            Self::format_scalar(f, input.index(lane), out.index(lane), elem)?;
        }
        Ok(())
    }

    /// Expands a four-wide operation into one scalar instruction per lane.
    fn unroll_vec4(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable,
        out: &Variable,
        elem: Elem,
    ) -> std::fmt::Result {
        for lane in 0..4 {
            Self::format_scalar(f, input.index(lane), out.index(lane), elem)?;
        }
        Ok(())
    }
}
/// Generates a [`Unary`] implementation that lowers the operation as a call
/// to the given device function: `out = <fn>(input);`.
macro_rules! function {
    ($name:ident, $func:expr) => {
        pub struct $name;
        impl Unary for $name {
            fn format_scalar<Input: Display, Out: Display>(
                f: &mut std::fmt::Formatter<'_>,
                input: Input,
                out: Out,
                _elem: Elem,
            ) -> std::fmt::Result {
                f.write_fmt(format_args!("{out} = {}({input});\n", $func))
            }
        }
    };
}
// Call-style unary functions: `out = <fn>(input);`.
function!(Abs, "abs");
function!(Log, "log");
function!(Log1p, "log1p");
function!(Cos, "cos");
function!(Sin, "sin");
function!(Tanh, "tanh");
function!(Sqrt, "sqrt");
function!(Exp, "exp");
// NOTE(review): `erff` is the single-precision intrinsic — confirm behavior
// for f16/bf16 element types.
function!(Erf, "erff");
function!(Ceil, "ceil");
function!(Floor, "floor");
/// Logical negation: `out = !input;`.
pub struct Not;
impl Unary for Not {
    fn format_scalar<Input, Out>(
        f: &mut std::fmt::Formatter<'_>,
        input: Input,
        out: Out,
        _elem: Elem,
    ) -> std::fmt::Result
    where
        Input: Component,
        Out: Component,
    {
        write!(f, "{out} = !{input};\n")
    }
}
/// Plain assignment, inserting a cast when the element types differ.
pub struct Assign;
impl Unary for Assign {
    fn format_scalar<Input, Out>(
        f: &mut std::fmt::Formatter<'_>,
        input: Input,
        out: Out,
        elem: Elem,
    ) -> std::fmt::Result
    where
        Input: Component,
        Out: Component,
    {
        // Cast only when necessary.
        if elem == input.elem() {
            write!(f, "{out} = {input};\n")
        } else {
            write!(f, "{out} = {elem}({input});\n")
        }
    }
}

View File

@ -0,0 +1,5 @@
// CUDA compute backend: server (execution) and storage (memory) modules.
mod server;
mod storage;
pub use server::*;
pub use storage::*;

View File

@ -0,0 +1,226 @@
use super::storage::Binding;
use super::storage::CudaStorage;
use burn_compute::{
memory_management::MemoryManagement,
server::{self, ComputeServer},
};
use burn_jit::compute::{JitAutotuneKey, Kernel, WorkGroup};
use burn_jit::gpu::WorkgroupSize;
use cudarc::driver::sys::CUctx_st;
use cudarc::driver::sys::CUfunc_st;
use std::collections::HashMap;
use std::ffi::CStr;
use std::ffi::CString;
/// Compute server executing kernels on a CUDA device; the underlying context
/// is created lazily on first use (see [`CudaServerState`]).
#[derive(Debug)]
pub struct CudaServer<MM: MemoryManagement<CudaStorage>> {
    state: CudaServerState<MM>,
}
/// Lazily-initialized server state: the CUDA context is only created when
/// the `init` closure runs for the configured device.
pub(crate) enum CudaServerState<MM: MemoryManagement<CudaStorage>> {
    Uninitialized {
        device_index: usize,
        init: Box<dyn Fn(usize) -> CudaContext<MM>>,
    },
    Initialized {
        ctx: CudaContext<MM>,
    },
}
impl<MM: MemoryManagement<CudaStorage>> core::fmt::Debug for CudaServerState<MM> {
    // The `init` closure is not `Debug`, so the state is rendered opaquely.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Context")
    }
}
/// An initialized CUDA context: raw driver handles, the memory manager, and
/// a cache of already-compiled kernels keyed by kernel id.
#[derive(Debug)]
pub(crate) struct CudaContext<MM: MemoryManagement<CudaStorage>> {
    context: *mut CUctx_st,
    stream: cudarc::driver::sys::CUstream,
    memory_management: MM,
    module_names: HashMap<String, CompiledKernel>,
}
/// A kernel compiled through NVRTC and loaded into the driver, plus the
/// launch parameters recorded at compile time.
#[derive(Debug)]
struct CompiledKernel {
    workgroup_size: WorkgroupSize,
    shared_mem_bytes: usize,
    func: *mut CUfunc_st,
}
// SAFETY: the raw CUDA context/stream/function pointers inside are presumed
// to be used only by whichever thread currently owns the server —
// TODO(review) confirm this invariant is upheld by the compute channel.
unsafe impl<MM: MemoryManagement<CudaStorage>> Send for CudaServer<MM> {}
impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
    type Kernel = Kernel;
    type Storage = CudaStorage;
    type MemoryManagement = MM;
    type AutotuneKey = JitAutotuneKey;
    /// Copies a device buffer back to the host, synchronizing the stream
    /// before returning the data.
    fn read(&mut self, binding: server::Binding<Self>) -> burn_tensor::Reader<Vec<u8>> {
        let ctx = self.get_context();
        let resource = ctx.memory_management.get(binding.memory);
        // TODO: Check if it is possible to make this faster
        let mut data = vec![0; resource.size() as usize];
        unsafe {
            cudarc::driver::result::memcpy_dtoh_async(&mut data, resource.ptr, ctx.stream).unwrap();
        };
        // The async copy must complete before `data` can be returned.
        ctx.sync();
        burn_tensor::Reader::Concrete(data)
    }
    /// Allocates a device buffer and uploads `data` into it (async on the
    /// context's stream).
    fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
        let ctx = self.get_context();
        let handle = ctx.memory_management.reserve(data.len());
        let handle = server::Handle::new(handle);
        let binding = handle.clone().binding().memory;
        let resource = ctx.memory_management.get(binding);
        unsafe {
            cudarc::driver::result::memcpy_htod_async(resource.ptr, data, ctx.stream).unwrap();
        }
        handle
    }
    /// Allocates an uninitialized device buffer of `size` bytes.
    fn empty(&mut self, size: usize) -> server::Handle<Self> {
        let ctx = self.get_context();
        let handle = ctx.memory_management.reserve(size);
        server::Handle::new(handle)
    }
    /// Launches a kernel, compiling and caching it on first use.
    fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<server::Binding<Self>>) {
        let ctx = self.get_context();
        let kernel_id = kernel.id();
        let settings = kernel.launch_settings();
        if !ctx.module_names.contains_key(&kernel_id) {
            ctx.compile_kernel(&kernel_id, kernel);
        }
        // Resolve each memory binding to a raw device pointer for launch.
        let bindings = bindings
            .into_iter()
            .map(|binding| ctx.memory_management.get(binding.memory).as_binding())
            .collect();
        ctx.execute_task(kernel_id, settings.workgroup, bindings);
        // TODO: fix this
        // self.memory_management.storage().perform_deallocations();
    }
    /// Blocks until all queued work on the stream has completed.
    fn sync(&mut self) {
        let ctx = self.get_context();
        ctx.sync();
    }
}
impl<MM: MemoryManagement<CudaStorage>> CudaContext<MM> {
    /// Build a context from already-initialized driver handles.
    pub fn new(
        memory_management: MM,
        stream: cudarc::driver::sys::CUstream,
        context: *mut CUctx_st,
    ) -> Self {
        Self {
            context,
            memory_management,
            module_names: HashMap::new(),
            stream,
        }
    }

    /// Block the calling thread until the stream has drained.
    fn sync(&mut self) {
        unsafe {
            cudarc::driver::result::stream::synchronize(self.stream).unwrap();
        };
    }

    /// Compile `kernel` to PTX with NVRTC, load it as a module, and cache the
    /// entry function under `kernel_id`.
    ///
    /// Panics with the full NVRTC log and the kernel source on compilation failure.
    fn compile_kernel(&mut self, kernel_id: &str, kernel: Kernel) {
        let kernel_compiled = kernel.compile();
        let shared_mem_bytes = kernel_compiled.shared_mem_bytes;
        let workgroup_size = kernel_compiled.workgroup_size;

        let ptx = unsafe {
            let program = cudarc::nvrtc::result::create_program(kernel_compiled.source).unwrap();
            if cudarc::nvrtc::result::compile_program::<Vec<_>>(program, &[]).is_err() {
                // Compilation failed: recover the NVRTC log and surface it line by line.
                let log_raw = cudarc::nvrtc::result::get_program_log(program).unwrap();
                let log_ptr = log_raw.as_ptr();
                // NOTE(review): assumes the NVRTC log buffer is NUL-terminated — verify in cudarc.
                let log = CStr::from_ptr(log_ptr).to_str().unwrap();
                let mut message = "[Compilation Error] ".to_string();
                for line in log.split('\n') {
                    if !line.is_empty() {
                        message += format!("\n {line}").as_str();
                    }
                }
                // Recompile to recover the source: the first compilation's source
                // was moved into `create_program` above.
                let source = kernel.compile().source;
                panic!("{message}\n[Source] \n{source}");
            };
            cudarc::nvrtc::result::get_ptx(program).unwrap()
        };

        // All generated kernels share the same entry-point name.
        let func_name = CString::new("kernel".to_string()).unwrap();
        let func = unsafe {
            let module =
                cudarc::driver::result::module::load_data(ptx.as_ptr() as *const _).unwrap();
            cudarc::driver::result::module::get_function(module, func_name).unwrap()
        };

        self.module_names.insert(
            kernel_id.to_string(),
            CompiledKernel {
                workgroup_size,
                shared_mem_bytes,
                func,
            },
        );
    }

    /// Launch the cached kernel `kernel_id` asynchronously on the stream.
    ///
    /// `bindings` are pointers to device pointers, passed as the kernel parameters.
    /// Panics if `kernel_id` was never compiled.
    fn execute_task(
        &mut self,
        kernel_id: String,
        workgroup: WorkGroup,
        mut bindings: Vec<Binding>,
    ) {
        let kernel = self.module_names.get(&kernel_id).unwrap();
        let workgroup_size = kernel.workgroup_size;

        unsafe {
            cudarc::driver::result::launch_kernel(
                kernel.func,
                (workgroup.x, workgroup.y, workgroup.z),
                (workgroup_size.x, workgroup_size.y, workgroup_size.z),
                kernel.shared_mem_bytes as u32,
                self.stream,
                &mut bindings,
            )
            .unwrap();
        };
    }
}
impl<MM: MemoryManagement<CudaStorage>> CudaServer<MM> {
    /// Create a new cuda server.
    pub(crate) fn new(index: usize, init: Box<dyn Fn(usize) -> CudaContext<MM>>) -> Self {
        let state = CudaServerState::Uninitialized {
            device_index: index,
            init,
        };
        Self { state }
    }

    /// Return the device context, lazily initializing it on first access and
    /// making its CUDA context current before handing it back.
    fn get_context(&mut self) -> &mut CudaContext<MM> {
        // First access: run the stored initializer and swap the state over.
        if let CudaServerState::Uninitialized { device_index, init } = &self.state {
            let ctx = init(*device_index);
            self.state = CudaServerState::Initialized { ctx };
        }

        match &mut self.state {
            CudaServerState::Initialized { ctx } => {
                unsafe {
                    cudarc::driver::result::ctx::set_current(ctx.context).unwrap();
                };
                ctx
            }
            // Unreachable: the branch above guarantees initialization.
            CudaServerState::Uninitialized { .. } => panic!("Context should be initialized"),
        }
    }
}

View File

@ -0,0 +1,118 @@
use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use cudarc::driver::sys::CUstream;
use std::collections::HashMap;
/// Buffer storage for cuda.
pub struct CudaStorage {
    /// Live device allocations, keyed by storage id.
    memory: HashMap<StorageId, cudarc::driver::sys::CUdeviceptr>,
    /// Ids flagged by `dealloc`, actually freed by `perform_deallocations`.
    deallocations: Vec<StorageId>,
    /// Stream on which async alloc/free operations are enqueued.
    stream: cudarc::driver::sys::CUstream,
}
unsafe impl Send for CudaStorage {}
impl core::fmt::Debug for CudaStorage {
    /// Format the storage by its stream handle (the allocation map is omitted).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Write directly into the formatter instead of allocating an
        // intermediate `String` via `format!(...).as_str()`.
        write!(f, "CudaStorage {{ device: {:?} }}", self.stream)
    }
}
/// Keeps actual CUDA device pointers in a hashmap with ids as key.
impl CudaStorage {
    /// Create a new storage bound to the given CUDA stream.
    pub fn new(stream: CUstream) -> Self {
        Self {
            memory: HashMap::new(),
            deallocations: Vec::new(),
            stream,
        }
    }

    /// Actually deallocates buffers tagged to be deallocated.
    pub fn perform_deallocations(&mut self) {
        // Frees are enqueued asynchronously on the storage's own stream.
        for id in self.deallocations.drain(..) {
            if let Some(ptr) = self.memory.remove(&id) {
                unsafe {
                    cudarc::driver::result::free_async(ptr, self.stream).unwrap();
                }
            }
        }
    }
}
/// The memory resource that can be allocated for CUDA.
#[derive(new, Debug)]
pub struct CudaResource {
    /// The raw device pointer (64-bit device address).
    pub ptr: u64,
    /// Pointer to the device pointer, in the form kernel launches expect.
    pub binding: *mut std::ffi::c_void,
    /// How the resource is used.
    pub kind: CudaResourceKind,
}

// SAFETY: the resource only wraps plain pointer values — NOTE(review): confirm
// the binding pointer is never dereferenced on a thread other than its creator.
unsafe impl Send for CudaResource {}

/// A single kernel launch argument: a pointer to a buffer's device pointer.
pub type Binding = *mut std::ffi::c_void;
impl CudaResource {
    /// Expose the launch-argument view (pointer to the device pointer).
    pub fn as_binding(&self) -> Binding {
        self.binding
    }

    /// Number of bytes covered by this resource.
    pub fn size(&self) -> u64 {
        let bytes = match &self.kind {
            CudaResourceKind::Full { size } => *size,
            CudaResourceKind::Slice { size, .. } => *size,
        };
        bytes as u64
    }

    /// Byte offset of this resource within its buffer (zero for a full buffer).
    pub fn offset(&self) -> u64 {
        match &self.kind {
            CudaResourceKind::Full { .. } => 0,
            CudaResourceKind::Slice { offset, .. } => *offset as u64,
        }
    }
}
/// How the resource is used, either as a slice or fully.
/// Sizes and offsets are expressed in bytes.
#[derive(Debug)]
pub enum CudaResourceKind {
    /// Represents an entire buffer.
    Full { size: usize },
    /// A slice over a buffer.
    Slice { size: usize, offset: usize },
}
impl ComputeStorage for CudaStorage {
    type Resource = CudaResource;

    /// Resolve a storage handle into a launchable resource.
    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
        let ptr = self.memory.get(&handle.id).unwrap();
        // NOTE(review): the binding is a pointer into the hashmap's own storage; it
        // is only valid until the map is next mutated (insert/remove can rehash) —
        // confirm every binding is consumed before the next alloc/dealloc.
        // NOTE(review): for the `Slice` case the offset is NOT applied to `ptr` or
        // the binding here — verify offsets are handled by the caller.
        match handle.utilization {
            StorageUtilization::Full(size) => CudaResource::new(
                *ptr,
                ptr as *const cudarc::driver::sys::CUdeviceptr as *mut std::ffi::c_void,
                CudaResourceKind::Full { size },
            ),
            StorageUtilization::Slice { offset, size } => CudaResource::new(
                *ptr,
                ptr as *const cudarc::driver::sys::CUdeviceptr as *mut std::ffi::c_void,
                CudaResourceKind::Slice { size, offset },
            ),
        }
    }

    /// Allocate `size` bytes asynchronously on the storage stream.
    fn alloc(&mut self, size: usize) -> StorageHandle {
        let id = StorageId::new();
        let ptr = unsafe { cudarc::driver::result::malloc_async(self.stream, size).unwrap() };
        self.memory.insert(id.clone(), ptr);
        StorageHandle::new(id, StorageUtilization::Full(size))
    }

    /// Tag `id` for deallocation; memory is freed later by `perform_deallocations`.
    fn dealloc(&mut self, id: StorageId) {
        self.deallocations.push(id);
    }
}

View File

@ -0,0 +1,12 @@
use burn_tensor::backend::{DeviceId, DeviceOps};
/// A CUDA device, identified by its index among the visible devices.
#[derive(new, Clone, Debug, PartialEq, Eq, Default, Hash)]
pub struct CudaDevice {
    /// Index of the device as enumerated by the CUDA driver.
    pub index: usize,
}

impl DeviceOps for CudaDevice {
    fn id(&self) -> DeviceId {
        // Type id 0: there is a single kind of CUDA device.
        // NOTE(review): `as u32` silently truncates indices above u32::MAX —
        // unreachable in practice for device counts.
        DeviceId::new(0, self.index as u32)
    }
}

View File

@ -0,0 +1,42 @@
use burn_jit::JitElement;
use crate::compiler;
/// The base element trait for the CUDA backend.
pub trait CudaElement: JitElement {
    /// The matching element type in the CUDA compiler dialect.
    fn cuda_elem() -> compiler::Elem;
}

/// The float element type for the CUDA backend.
pub trait FloatElement: CudaElement + burn_jit::FloatElement {}

/// The int element type for the CUDA backend.
pub trait IntElement: CudaElement + burn_jit::IntElement {}

impl CudaElement for u32 {
    fn cuda_elem() -> compiler::Elem {
        compiler::Elem::U32
    }
}

impl CudaElement for i32 {
    fn cuda_elem() -> compiler::Elem {
        compiler::Elem::I32
    }
}

impl CudaElement for f32 {
    fn cuda_elem() -> compiler::Elem {
        compiler::Elem::F32
    }
}

impl CudaElement for half::bf16 {
    fn cuda_elem() -> compiler::Elem {
        compiler::Elem::BF16
    }
}

impl FloatElement for f32 {}
impl FloatElement for half::bf16 {}
impl IntElement for i32 {}

View File

@ -0,0 +1,29 @@
#[macro_use]
extern crate derive_new;
extern crate alloc;

mod compute;
mod device;
mod element;
mod runtime;

pub mod compiler;
pub use device::*;

use burn_jit::JitBackend;
use runtime::CudaRuntime;

/// The CUDA backend, generic over the float and int element types.
#[cfg(not(feature = "fusion"))]
pub type Cuda<F = f32, I = i32> = JitBackend<CudaRuntime, F, I>;

/// The CUDA backend with kernel fusion enabled.
#[cfg(feature = "fusion")]
pub type Cuda<F = f32, I = i32> = burn_fusion::Fusion<JitBackend<CudaRuntime, F, I>>;

#[cfg(test)]
mod tests {
    use super::*;

    pub type TestRuntime = crate::CudaRuntime;

    // Run the shared JIT backend test suite against the CUDA runtime.
    burn_jit::testgen_all!();
}

View File

@ -0,0 +1,81 @@
use burn_common::stub::RwLock;
use burn_compute::{
channel::MutexComputeChannel,
client::ComputeClient,
memory_management::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy},
tune::Tuner,
ComputeRuntime,
};
use burn_jit::Runtime;
use std::sync::Arc;
use crate::{
compiler::CudaCompiler,
compute::{CudaContext, CudaServer, CudaStorage},
device::CudaDevice,
};
/// Runtime that executes kernels on NVIDIA GPUs through the CUDA driver API.
#[derive(Debug)]
pub struct CudaRuntime;

/// Global registry of compute clients, one per CUDA device.
static RUNTIME: ComputeRuntime<CudaDevice, Server, MutexComputeChannel<Server>> =
    ComputeRuntime::new();

type Server = CudaServer<SimpleMemoryManagement<CudaStorage>>;
impl Runtime for CudaRuntime {
    type Compiler = CudaCompiler;
    type Server = CudaServer<SimpleMemoryManagement<CudaStorage>>;
    type Channel = MutexComputeChannel<CudaServer<SimpleMemoryManagement<CudaStorage>>>;
    type Device = CudaDevice;

    /// Return the (cached) compute client for `device`, creating it lazily.
    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
        // One-time driver setup for a device: retain its primary context, make
        // it current, and create a dedicated non-blocking stream.
        fn init(index: usize) -> CudaContext<SimpleMemoryManagement<CudaStorage>> {
            cudarc::driver::result::init().unwrap();
            let device_ptr = cudarc::driver::result::device::get(index as i32).unwrap();
            let ctx = unsafe {
                let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
                cudarc::driver::result::ctx::set_current(ctx).unwrap();
                ctx
            };

            let stream = cudarc::driver::result::stream::create(
                cudarc::driver::result::stream::StreamKind::NonBlocking,
            )
            .unwrap();
            let storage = CudaStorage::new(stream);
            // Deallocate on every tick; never hand out sub-slices of buffers.
            let memory_management = SimpleMemoryManagement::new(
                storage,
                DeallocStrategy::new_period_tick(1),
                SliceStrategy::Never,
            );
            CudaContext::new(memory_management, stream, ctx)
        }

        RUNTIME.client(device, move || {
            let server = CudaServer::new(device.index, Box::new(init));
            let tuner_device_id = tuner_device_id();
            ComputeClient::new(
                MutexComputeChannel::new(server),
                Arc::new(RwLock::new(Tuner::new(&tuner_device_id))),
            )
        })
    }

    fn name() -> &'static str {
        "cuda"
    }

    // CUDA kernels receive explicit array lengths for bounds checking.
    fn require_array_lengths() -> bool {
        true
    }
}
/// Identifier under which autotune results for this runtime are stored.
fn tuner_device_id() -> String {
    String::from("cuda")
}

View File

@ -37,6 +37,7 @@ log = { workspace = true }
num-traits = { workspace = true }
rand = { workspace = true }
spin = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
# Template
serde = { workspace = true }

View File

@ -1,11 +1,17 @@
use super::dialect::gpu;
use std::fmt::Display;
/// Trait for compiled code representation
pub trait CompilerRepresentation: Display {
/// Computes and returns the shared memory size
fn shared_memory_size(&self) -> usize;
}
/// Compiles the [gpu representation](gpu::ComputeShader) into its own representation that can be
/// formatted into tokens.
pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
/// The representation for the compiled code.
type Representation: Display;
type Representation: CompilerRepresentation;
/// Compiles the [gpu shader](gpu::ComputeShader) into the compiler's representation.
fn compile(shader: gpu::ComputeShader) -> Self::Representation;

View File

@ -217,12 +217,24 @@ macro_rules! gpu {
gpu!(binary $lhs, $rhs, $out)
));
};
// out = unchecked(lhs[rhs])
($scope:expr, $out:ident = unchecked($lhs:ident[$rhs:expr])) => {
$scope.register($crate::codegen::dialect::gpu::Operator::UncheckedIndex(
gpu!(binary $lhs, $rhs, $out)
));
};
// out[lhs] = rhs
($scope:expr, $out:ident[$lhs:ident] = $rhs:expr) => {
$scope.register($crate::codegen::dialect::gpu::Operator::IndexAssign(
gpu!(binary $lhs, $rhs, $out)
));
};
// unchecked(out[lhs]) = rhs
($scope:expr, unchecked($out:ident[$lhs:ident]) = $rhs:expr) => {
$scope.register($crate::codegen::dialect::gpu::Operator::UncheckedIndexAssign(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = |input|
($scope:expr, $out:ident = |$input:ident|) => {
gpu!($scope, $out = abs($input))

View File

@ -50,7 +50,9 @@ pub enum Operator {
Assign(UnaryOperator),
Modulo(BinaryOperator),
Index(BinaryOperator),
UncheckedIndex(BinaryOperator),
IndexAssign(BinaryOperator),
UncheckedIndexAssign(BinaryOperator),
And(BinaryOperator),
Or(BinaryOperator),
Not(UnaryOperator),

View File

@ -1,5 +1,6 @@
use super::{
ConditionalAssign, IndexOffsetGlobalWithLayout, ReadGlobal, ReadGlobalWithLayout, WriteGlobal,
CheckedIndex, CheckedIndexAssign, ConditionalAssign, IndexOffsetGlobalWithLayout, ReadGlobal,
ReadGlobalWithLayout, WriteGlobal,
};
use crate::codegen::dialect::gpu::Vectorization;
use serde::{Deserialize, Serialize};
@ -13,6 +14,8 @@ pub enum Procedure {
IndexOffsetGlobalWithLayout(IndexOffsetGlobalWithLayout),
ReadGlobal(ReadGlobal),
WriteGlobal(WriteGlobal),
CheckedIndex(CheckedIndex),
CheckedIndexAssign(CheckedIndexAssign),
ConditionalAssign(ConditionalAssign),
}
@ -22,14 +25,18 @@ impl Procedure {
Procedure::ReadGlobalWithLayout(op) => {
Procedure::ReadGlobalWithLayout(op.vectorize(vectorization))
}
Procedure::ReadGlobal(op) => Procedure::ReadGlobal(op.vectorize(vectorization)),
Procedure::WriteGlobal(op) => Procedure::WriteGlobal(op.vectorize(vectorization)),
Procedure::ConditionalAssign(proc) => {
Procedure::ConditionalAssign(proc.vectorize(vectorization))
}
Procedure::IndexOffsetGlobalWithLayout(op) => {
Procedure::IndexOffsetGlobalWithLayout(op.vectorize(vectorization))
}
Procedure::ReadGlobal(op) => Procedure::ReadGlobal(op.vectorize(vectorization)),
Procedure::WriteGlobal(op) => Procedure::WriteGlobal(op.vectorize(vectorization)),
Procedure::CheckedIndex(proc) => Procedure::CheckedIndex(proc.vectorize(vectorization)),
Procedure::CheckedIndexAssign(proc) => {
Procedure::CheckedIndexAssign(proc.vectorize(vectorization))
}
Procedure::ConditionalAssign(proc) => {
Procedure::ConditionalAssign(proc.vectorize(vectorization))
}
}
}
}

View File

@ -0,0 +1,74 @@
use crate::codegen::dialect::gpu::{macros::gpu, Item, Scope, Variable, Vectorization};
use serde::{Deserialize, Serialize};
/// Perform a check bound on the index (lhs) of value (rhs)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct CheckedIndex {
    pub lhs: Variable,
    pub rhs: Variable,
    pub out: Variable,
}

impl CheckedIndex {
    // Expand into: `out = (rhs < len(lhs)) ? lhs[rhs] : 0`.
    #[allow(missing_docs)]
    pub fn expand(self, scope: &mut Scope) {
        let lhs = self.lhs;
        let rhs = self.rhs;
        let out = self.out;
        let array_len = scope.create_local(Item::Scalar(crate::gpu::Elem::UInt));
        let inside_bound = scope.create_local(Item::Scalar(crate::gpu::Elem::Bool));

        gpu!(scope, array_len = len(lhs));
        gpu!(scope, inside_bound = rhs < array_len);

        gpu!(scope, if(inside_bound).then(|scope| {
            // In bounds: read without a further runtime check.
            gpu!(scope, out = unchecked(lhs[rhs]));
        }).else(|scope| {
            // Out of bounds: yield zero instead of undefined behavior.
            gpu!(scope, out = cast(0));
        }));
    }

    // Vectorize all three operands with the same factor.
    pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
        Self {
            lhs: self.lhs.vectorize(vectorization),
            rhs: self.rhs.vectorize(vectorization),
            out: self.out.vectorize(vectorization),
        }
    }
}
/// Perform a check bound on the index (lhs) of output before assigning the value (rhs)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct CheckedIndexAssign {
    pub lhs: Variable,
    pub rhs: Variable,
    pub out: Variable,
}

impl CheckedIndexAssign {
    // Expand into: `if (lhs < len(out)) { out[lhs] = rhs }` — out-of-bounds
    // writes are silently skipped.
    #[allow(missing_docs)]
    pub fn expand(self, scope: &mut Scope) {
        let lhs = self.lhs;
        let rhs = self.rhs;
        let out = self.out;
        let array_len = scope.create_local(Item::Scalar(crate::gpu::Elem::UInt));
        let inside_bound = scope.create_local(Item::Scalar(crate::gpu::Elem::Bool));

        gpu!(scope, array_len = len(out));
        gpu!(scope, inside_bound = lhs < array_len);

        gpu!(scope, if(inside_bound).then(|scope| {
            // In bounds: write without a further runtime check.
            gpu!(scope, unchecked(out[lhs]) = rhs);
        }));
    }

    // Vectorize all three operands with the same factor.
    pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
        Self {
            lhs: self.lhs.vectorize(vectorization),
            rhs: self.rhs.vectorize(vectorization),
            out: self.out.vectorize(vectorization),
        }
    }
}

View File

@ -1,9 +1,11 @@
mod assign;
mod base;
mod index;
mod read;
mod write;
pub use assign::*;
pub use base::*;
pub use index::*;
pub use read::*;
pub use write::*;

View File

@ -20,6 +20,8 @@ pub enum Visibility {
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum FloatKind {
F16,
BF16,
F32,
F64,
}
@ -68,7 +70,8 @@ pub enum Item {
}
impl Item {
pub(crate) fn elem(&self) -> Elem {
/// Fetch the elem of the item.
pub fn elem(&self) -> Elem {
match self {
Self::Vec4(elem) => *elem,
Self::Vec3(elem) => *elem,

View File

@ -63,7 +63,8 @@ impl Variable {
Variable::NumWorkgroupsZ => None,
}
}
pub(crate) fn item(&self) -> Item {
/// Fetch the item of the variable.
pub fn item(&self) -> Item {
match self {
Variable::GlobalInputArray(_, item) => *item,
Variable::GlobalOutputArray(_, item) => *item,

View File

@ -40,6 +40,7 @@ impl Operator {
Operator::Min(op) => Operator::Min(op.vectorize(vectorization)),
Operator::Add(op) => Operator::Add(op.vectorize(vectorization)),
Operator::Index(op) => Operator::Index(op.vectorize(vectorization)),
Operator::UncheckedIndex(op) => Operator::UncheckedIndex(op.vectorize(vectorization)),
Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)),
Operator::Mul(op) => Operator::Mul(op.vectorize(vectorization)),
Operator::Div(op) => Operator::Div(op.vectorize(vectorization)),
@ -74,6 +75,9 @@ impl Operator {
}
Operator::Modulo(op) => Operator::Modulo(op.vectorize(vectorization)),
Operator::IndexAssign(op) => Operator::IndexAssign(op.vectorize(vectorization)),
Operator::UncheckedIndexAssign(op) => {
Operator::UncheckedIndexAssign(op.vectorize(vectorization))
}
Operator::And(op) => Operator::And(op.vectorize(vectorization)),
Operator::Or(op) => Operator::Or(op.vectorize(vectorization)),
Operator::Not(op) => Operator::Not(op.vectorize(vectorization)),

View File

@ -285,6 +285,19 @@ fn execute_settings<'a, R: Runtime, E1: JitElement, E2: JitElement, E3: JitEleme
handles.push(output.handle.clone().binding());
}
// [2, I0stride0, I0stride1, I0shape0, I0shape1i, I1... O0..., I0len, I1len1, O0len]
if R::require_array_lengths() {
for input in inputs.iter() {
let len = calculate_num_elems_dyn_rank(input.shape);
info.push(len as u32);
}
for output in outputs.iter() {
let len = calculate_num_elems_dyn_rank(output.shape);
info.push(len as u32);
}
}
let info = client.create(bytemuck::cast_slice(&info));
// Finally we finish with the named bindings.

View File

@ -2,7 +2,9 @@ use std::marker::PhantomData;
#[cfg(feature = "template")]
use crate::template::TemplateKernel;
use crate::{gpu::WorkgroupSize, kernel::GpuComputeShaderPhase, Compiler};
use crate::{
codegen::CompilerRepresentation, gpu::WorkgroupSize, kernel::GpuComputeShaderPhase, Compiler,
};
use alloc::sync::Arc;
/// Kernel for JIT backends
@ -53,6 +55,8 @@ pub struct CompiledKernel {
pub source: String,
/// Size of a workgroup for the compiled kernel
pub workgroup_size: WorkgroupSize,
/// The number of bytes used by the share memory
pub shared_mem_bytes: usize,
}
/// Information needed to launch the kernel
@ -86,13 +90,14 @@ impl<C: Compiler, K: GpuComputeShaderPhase> JitKernel for FullCompilationPhase<C
fn compile(&self) -> CompiledKernel {
let gpu_ir = self.kernel.compile();
let workgroup_size = gpu_ir.workgroup_size;
let lower_level_ir = C::compile(gpu_ir);
let shared_mem_bytes = lower_level_ir.shared_memory_size();
let source = lower_level_ir.to_string();
CompiledKernel {
source,
workgroup_size,
shared_mem_bytes,
}
}

View File

@ -92,5 +92,26 @@ impl JitElement for f32 {
}
}
// bf16 element support; byte conversions are zero-copy reinterpretations via bytemuck.
impl JitElement for half::bf16 {
    fn type_name() -> &'static str {
        "bf16"
    }
    fn as_bytes(slice: &[Self]) -> &[u8] {
        bytemuck::cast_slice(slice)
    }
    fn from_bytes(bytes: &[u8]) -> &[Self] {
        bytemuck::cast_slice(bytes)
    }
    fn gpu_elem() -> gpu::Elem {
        gpu::Elem::Float(gpu::FloatKind::BF16)
    }
    fn maximum_value() -> Self {
        half::bf16::MAX
    }
    fn minimum_value() -> Self {
        half::bf16::MIN
    }
}
impl FloatElement for f32 {}
impl FloatElement for half::bf16 {}
impl IntElement for i32 {}

View File

@ -1,3 +1,4 @@
use crate::codegen::calculate_num_elems_dyn_rank;
use crate::codegen::Compilation;
use crate::codegen::CompilationInfo;
use crate::codegen::CompilationSettings;
@ -165,14 +166,14 @@ impl<R: Runtime> FusionKernel<R> {
let mut output_register = Vec::with_capacity(outputs_description_updated.len());
// We register the info and handles for the inputs.
for (handle, tensor) in handles_input.iter().zip(inputs_description_updated) {
for (handle, tensor) in handles_input.iter().zip(inputs_description_updated.iter()) {
register_info_tensor(&mut info, tensor, handle);
bindings.push(handle.handle.clone().binding());
}
// We register the info and handles for the outputs.
for (tensor, output_info) in outputs_description_updated
.into_iter()
.iter()
.zip(fusion_kernel.runtime_info.iter())
{
match output_info {
@ -204,6 +205,19 @@ impl<R: Runtime> FusionKernel<R> {
};
}
// [2, I0stride0, I0stride1, I0shape0, I0shape1i, I1... O0..., I0len, I1len1, O0len]
if R::require_array_lengths() {
for input in inputs_description_updated.iter() {
let len = calculate_num_elems_dyn_rank(&input.shape);
info.push(len as u32);
}
for output in outputs_description_updated.iter() {
let len = calculate_num_elems_dyn_rank(&output.shape);
info.push(len as u32);
}
}
// Create the info buffer.
bindings.push(client.create(bytemuck::cast_slice(&info)).binding());

View File

@ -219,6 +219,11 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::UncheckedIndex(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Sub(op) => mark_binary(
op,
&mut local_tensor_ids_input,
@ -343,6 +348,11 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::UncheckedIndexAssign(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::BitwiseAnd(op) => mark_binary(
op,
&mut local_tensor_ids_input,
@ -380,6 +390,12 @@ impl TraceBuilder {
gpu::Procedure::WriteGlobal(_) => {
// Nothing to do here.
}
gpu::Procedure::CheckedIndex(_) => {
// Nothing to do here.
}
gpu::Procedure::CheckedIndexAssign(_) => {
// Nothing to do here.
}
gpu::Procedure::ConditionalAssign(proc) => {
mark(&proc.cond, &mut local_tensor_ids_input);
mark(&proc.lhs, &mut local_tensor_ids_input);

View File

@ -104,6 +104,7 @@ pub enum MatmulStrategy {
Autotune,
}
#[cfg(feature = "autotune")]
#[cfg(not(feature = "autotune"))]
impl Default for MatmulStrategy {
fn default() -> Self {

View File

@ -19,7 +19,7 @@ pub(crate) mod codegen;
pub(crate) mod tune;
mod element;
pub use codegen::compiler::Compiler;
pub use codegen::compiler::{Compiler, CompilerRepresentation};
pub use codegen::dialect::gpu;
pub use element::{FloatElement, IntElement, JitElement};

View File

@ -28,4 +28,9 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug {
/// The runtime name.
fn name() -> &'static str;
/// Return true if global input array lengths should be added to kernel info.
fn require_array_lengths() -> bool {
false
}
}

View File

@ -43,9 +43,11 @@ where
fn compile(&self) -> CompiledKernel {
let source_template = self.kernel_source.source();
let source = source_template.complete();
CompiledKernel {
source,
workgroup_size: self.workgroup_size,
shared_mem_bytes: 0,
}
}

View File

@ -74,13 +74,9 @@ mod tests {
fn args(&self) -> Self::Args {
let device = Default::default();
(
TestTensor::random([32, 32], Distribution::Default, &device)
.into_data()
.convert(),
TestTensor::ones([32, 32], &device).into_data().convert(),
// Avoid div by zero.
TestTensor::random([32, 32], Distribution::Uniform(1., 3.), &device)
.into_data()
.convert(),
TestTensor::ones([32, 32], &device).into_data().convert(),
)
}

View File

@ -1,7 +1,7 @@
use super::Instruction;
use std::fmt::Display;
/// A body is composed of a list of [operations](Operation).
/// A body is composed of a list of [instructions](Instruction).
///
/// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size
/// X and Y, but with Z=1.

View File

@ -100,6 +100,8 @@ impl WgslCompiler {
fn compile_elem(value: gpu::Elem) -> wgsl::Elem {
match value {
gpu::Elem::Float(f) => match f {
gpu::FloatKind::F16 => panic!("f16 is not yet supported"),
gpu::FloatKind::BF16 => panic!("f64 is not a valid WgpuElement"),
gpu::FloatKind::F32 => wgsl::Elem::F32,
gpu::FloatKind::F64 => panic!("f64 is not a valid WgpuElement"),
},
@ -317,6 +319,14 @@ impl WgslCompiler {
proc.expand(scope);
compile(scope);
}
gpu::Procedure::CheckedIndex(proc) => {
proc.expand(scope);
compile(scope);
}
gpu::Procedure::CheckedIndexAssign(proc) => {
proc.expand(scope);
compile(scope);
}
gpu::Procedure::IndexOffsetGlobalWithLayout(proc) => {
proc.expand(scope);
compile(scope);
@ -381,6 +391,11 @@ impl WgslCompiler {
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::UncheckedIndex(op) => wgsl::Instruction::Index {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::Modulo(op) => wgsl::Instruction::Modulo {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
@ -499,6 +514,11 @@ impl WgslCompiler {
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::UncheckedIndexAssign(op) => wgsl::Instruction::IndexAssign {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
gpu::Operator::And(op) => wgsl::Instruction::And {
lhs: self.compile_variable(op.lhs),
rhs: self.compile_variable(op.rhs),
@ -593,6 +613,14 @@ fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extensio
wgsl::Instruction::Tanh { input, out: _ } => {
register_extension(wgsl::Extension::SafeTanh(input.item()))
}
wgsl::Instruction::If {
cond: _,
instructions,
} => {
for extension in register_extensions(instructions) {
register_extension(extension);
}
}
_ => {}
}
}

View File

@ -1,5 +1,5 @@
use super::{Body, Extension, Item};
use burn_jit::gpu::WorkgroupSize;
use burn_jit::{gpu::WorkgroupSize, CompilerRepresentation};
use std::fmt::Display;
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
@ -218,3 +218,10 @@ impl Display for Visibility {
}
}
}
impl CompilerRepresentation for ComputeShader {
fn shared_memory_size(&self) -> usize {
// not used in wgsl compiler
0
}
}

View File

@ -41,7 +41,7 @@ autodiff = ["burn-core/autodiff"]
fusion = ["burn-core/fusion"]
## Backend features
cuda = ["burn-core/cuda"]
candle-cuda = ["burn-core/candle-cuda"]
metal = ["burn-core/metal"]
accelerate = ["burn-core/accelerate"]
openblas = ["burn-core/openblas"]

View File

@ -15,7 +15,7 @@ use crate::utils::cargo::{run_cargo, run_cargo_with_path};
use crate::utils::process::{handle_child_process, run_command};
use crate::utils::rustup::{rustup_add_component, rustup_add_target};
use crate::utils::time::format_duration;
use crate::utils::workspace::{get_workspaces, WorkspaceMemberType};
use crate::utils::workspace::{get_workspace_members, WorkspaceMemberType};
use crate::utils::Params;
use crate::{endgroup, group};
@ -310,9 +310,13 @@ fn std_checks() {
// Check clippy lints
cargo_clippy();
// Produce documentation for each workspace
group!("Docs: workspaces");
cargo_doc(["--workspace", "--no-deps"].into());
// Produce documentation for each workspace member
group!("Docs: crates");
let mut params = Params::from(["--workspace", "--no-deps"]);
// Exclude burn-cuda on all platforms
params.params.push("--exclude".to_string());
params.params.push("burn-cuda".to_string());
cargo_doc(params);
endgroup!();
// Setup code coverage
@ -320,20 +324,23 @@ fn std_checks() {
setup_coverage();
}
// Build & test each workspace
let workspaces = get_workspaces(WorkspaceMemberType::Crate);
for workspace in workspaces {
if disable_wgpu && workspace.name == "burn-wgpu" {
// Build & test each member in workspace
let members = get_workspace_members(WorkspaceMemberType::Crate);
for member in members {
if disable_wgpu && member.name == "burn-wgpu" {
continue;
}
if member.name == "burn-cuda" {
// burn-cuda requires CUDA Toolkit which is not currently setup on our CI runners
continue;
}
if member.name == "burn-tch" {
continue;
}
if workspace.name == "burn-tch" {
continue;
}
group!("Checks: {}", workspace.name);
cargo_build(Params::from(["-p", &workspace.name]));
cargo_test(Params::from(["-p", &workspace.name]));
group!("Checks: {}", member.name);
cargo_build(Params::from(["-p", &member.name]));
cargo_test(Params::from(["-p", &member.name]));
endgroup!();
}
@ -381,18 +388,18 @@ fn check_typos() {
}
fn check_examples() {
let workspaces = get_workspaces(WorkspaceMemberType::Example);
for workspace in workspaces {
if workspace.name == "notebook" {
let members = get_workspace_members(WorkspaceMemberType::Example);
for member in members {
if member.name == "notebook" {
continue;
}
group!("Checks: Example - {}", workspace.name);
group!("Checks: Example - {}", member.name);
run_cargo_with_path(
"check",
["--examples"].into(),
HashMap::new(),
Some(workspace.path),
Some(member.path),
"Failed to check example",
);
endgroup!();

View File

@ -6,7 +6,7 @@ pub(crate) mod time;
pub(crate) mod workspace;
pub(crate) struct Params {
params: Vec<String>,
pub params: Vec<String>,
}
impl<const N: usize> From<[&str; N]> for Params {

View File

@ -25,8 +25,8 @@ impl WorkspaceMember {
}
}
/// Get project workspaces
pub(crate) fn get_workspaces(w_type: WorkspaceMemberType) -> Vec<WorkspaceMember> {
/// Get workspace crates
pub(crate) fn get_workspace_members(w_type: WorkspaceMemberType) -> Vec<WorkspaceMember> {
// Run `cargo metadata` command to get project metadata
let output = Command::new("cargo")
.arg("metadata")