mirror of https://github.com/tracel-ai/burn.git
Migration/cubecl (#2041)
This commit is contained in:
parent
0d5025edbb
commit
19cd67a9e2
|
@ -255,10 +255,10 @@ dependencies = [
|
|||
"arboard",
|
||||
"burn",
|
||||
"burn-common",
|
||||
"burn-cuda",
|
||||
"burn-wgpu",
|
||||
"clap 4.5.9",
|
||||
"colored",
|
||||
"cubecl",
|
||||
"derive-new",
|
||||
"dirs 5.0.1",
|
||||
"github-device-flow",
|
||||
|
@ -469,41 +469,17 @@ dependencies = [
|
|||
name = "burn-common"
|
||||
version = "0.14.0"
|
||||
dependencies = [
|
||||
"cubecl-common",
|
||||
"dashmap",
|
||||
"data-encoding",
|
||||
"derive-new",
|
||||
"getrandom",
|
||||
"indicatif",
|
||||
"pollster",
|
||||
"rand",
|
||||
"rayon",
|
||||
"reqwest 0.12.5",
|
||||
"serde",
|
||||
"spin",
|
||||
"tokio",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-compute"
|
||||
version = "0.14.0"
|
||||
dependencies = [
|
||||
"async-channel",
|
||||
"burn-common",
|
||||
"derive-new",
|
||||
"dirs 5.0.1",
|
||||
"hashbrown 0.14.5",
|
||||
"log",
|
||||
"md5",
|
||||
"pollster",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serial_test",
|
||||
"spin",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-core"
|
||||
version = "0.14.0"
|
||||
|
@ -534,49 +510,6 @@ dependencies = [
|
|||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-cube"
|
||||
version = "0.14.0"
|
||||
dependencies = [
|
||||
"burn-compute",
|
||||
"burn-cube-macros",
|
||||
"burn-tensor",
|
||||
"bytemuck",
|
||||
"derive-new",
|
||||
"half",
|
||||
"log",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"trybuild",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-cube-macros"
|
||||
version = "0.14.0"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.71",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-cuda"
|
||||
version = "0.14.0"
|
||||
dependencies = [
|
||||
"burn-common",
|
||||
"burn-compute",
|
||||
"burn-cube",
|
||||
"burn-fusion",
|
||||
"burn-jit",
|
||||
"burn-tensor",
|
||||
"bytemuck",
|
||||
"cudarc",
|
||||
"derive-new",
|
||||
"half",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-dataset"
|
||||
version = "0.14.0"
|
||||
|
@ -653,7 +586,7 @@ dependencies = [
|
|||
"thiserror",
|
||||
"tracing-core",
|
||||
"tracing-subscriber",
|
||||
"zip 2.1.3",
|
||||
"zip 2.1.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -662,13 +595,12 @@ version = "0.14.0"
|
|||
dependencies = [
|
||||
"burn-autodiff",
|
||||
"burn-common",
|
||||
"burn-compute",
|
||||
"burn-cube",
|
||||
"burn-fusion",
|
||||
"burn-ndarray",
|
||||
"burn-tensor",
|
||||
"burn-tensor-testgen",
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
"derive-new",
|
||||
"half",
|
||||
"hashbrown 0.14.5",
|
||||
|
@ -727,6 +659,7 @@ dependencies = [
|
|||
"burn-common",
|
||||
"burn-tensor-testgen",
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
"derive-new",
|
||||
"half",
|
||||
"hashbrown 0.14.5",
|
||||
|
@ -768,19 +701,10 @@ dependencies = [
|
|||
name = "burn-wgpu"
|
||||
version = "0.14.0"
|
||||
dependencies = [
|
||||
"async-channel",
|
||||
"burn-common",
|
||||
"burn-compute",
|
||||
"burn-cube",
|
||||
"burn-fusion",
|
||||
"burn-jit",
|
||||
"burn-tensor",
|
||||
"bytemuck",
|
||||
"derive-new",
|
||||
"hashbrown 0.14.5",
|
||||
"log",
|
||||
"pollster",
|
||||
"wgpu",
|
||||
"cubecl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -850,9 +774,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "candle-core"
|
||||
version = "0.5.1"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "311d8dbe293aa3b5c34f6a57727fafd67d17a74fa8b65276501237c233b34ffd"
|
||||
checksum = "d5b18de020c2729dbf7ac390325312644808b6ba9b7962f1f724e9185b1d53c7"
|
||||
dependencies = [
|
||||
"accelerate-src",
|
||||
"byteorder",
|
||||
|
@ -877,18 +801,18 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "candle-kernels"
|
||||
version = "0.5.1"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d3b4b048ca298fb8be90b0f4d0fe68bdca9de956ab52bb6e381463d955f2b661"
|
||||
checksum = "8bc0a71be8b2f0950b63fd602a5e10a74a4f94a5fd63059ae455e96163389488"
|
||||
dependencies = [
|
||||
"bindgen_cuda",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.5.1"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d31136c9541c82b7de0937c9a58210ada38e17d70810e0eacc0a99d849d848d"
|
||||
checksum = "f889aacd02fd999620a0435133d7cf3b58c81ef9dd5e47c38939b7a72345ea86"
|
||||
dependencies = [
|
||||
"metal 0.27.0",
|
||||
"once_cell",
|
||||
|
@ -922,9 +846,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.1.5"
|
||||
version = "1.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "324c74f2155653c90b04f25b2a47a8a631360cb908f92a772695f430c7e31052"
|
||||
checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f"
|
||||
dependencies = [
|
||||
"jobserver",
|
||||
"libc",
|
||||
|
@ -1362,11 +1286,124 @@ dependencies = [
|
|||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cubecl"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
|
||||
dependencies = [
|
||||
"cubecl-core",
|
||||
"cubecl-cuda",
|
||||
"cubecl-linalg",
|
||||
"cubecl-wgpu",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cubecl-common"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"getrandom",
|
||||
"pollster",
|
||||
"rand",
|
||||
"serde",
|
||||
"spin",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cubecl-core"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-macros",
|
||||
"cubecl-runtime",
|
||||
"derive-new",
|
||||
"half",
|
||||
"log",
|
||||
"num-traits",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cubecl-cuda"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
"cubecl-core",
|
||||
"cubecl-runtime",
|
||||
"cudarc",
|
||||
"derive-new",
|
||||
"half",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cubecl-linalg"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-core",
|
||||
"cubecl-runtime",
|
||||
"half",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cubecl-macros"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.71",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cubecl-runtime"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
|
||||
dependencies = [
|
||||
"async-channel",
|
||||
"cubecl-common",
|
||||
"derive-new",
|
||||
"dirs 5.0.1",
|
||||
"hashbrown 0.14.5",
|
||||
"log",
|
||||
"md5",
|
||||
"pollster",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"spin",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cubecl-wgpu"
|
||||
version = "0.1.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
|
||||
dependencies = [
|
||||
"async-channel",
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
"cubecl-core",
|
||||
"cubecl-runtime",
|
||||
"derive-new",
|
||||
"hashbrown 0.14.5",
|
||||
"log",
|
||||
"pollster",
|
||||
"wgpu",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cudarc"
|
||||
version = "0.11.8"
|
||||
version = "0.11.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a56028291ec3b0f6711e2e1b2d597484d359833dcb68331ce89e538012f835c4"
|
||||
checksum = "e395cd01168d63af826749573071f3c5069b338ae473cab355d22db0b2bb5a0d"
|
||||
dependencies = [
|
||||
"half",
|
||||
"libloading 0.8.4",
|
||||
|
@ -1421,6 +1458,7 @@ version = "0.14.0"
|
|||
dependencies = [
|
||||
"burn",
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
"derive-new",
|
||||
"log",
|
||||
"serde",
|
||||
|
@ -2020,15 +2058,6 @@ dependencies = [
|
|||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gelu"
|
||||
version = "0.14.0"
|
||||
dependencies = [
|
||||
"burn-cube",
|
||||
"burn-cuda",
|
||||
"burn-wgpu",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gemm"
|
||||
version = "0.17.1"
|
||||
|
@ -2846,6 +2875,7 @@ dependencies = [
|
|||
"burn-import",
|
||||
"burn-wgpu",
|
||||
"console_error_panic_hook",
|
||||
"cubecl",
|
||||
"js-sys",
|
||||
"log",
|
||||
"serde",
|
||||
|
@ -4909,9 +4939,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
|||
|
||||
[[package]]
|
||||
name = "sdd"
|
||||
version = "1.6.0"
|
||||
version = "1.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8eb0dde0ccd15e337a3cf738a9a38115c6d8e74795d074e73973dad3d229a897"
|
||||
checksum = "85f05a494052771fc5bd0619742363b5e24e5ad72ab3111ec2e27925b8edc5f3"
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
|
@ -5706,7 +5736,7 @@ dependencies = [
|
|||
"serde",
|
||||
"serde_spanned",
|
||||
"toml_datetime",
|
||||
"winnow 0.6.13",
|
||||
"winnow 0.6.14",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -5826,20 +5856,6 @@ version = "0.2.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "trybuild"
|
||||
version = "1.0.97"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b1e5645f2ee8025c2f1d75e1138f2dd034d74e6ba54620f3c569ba2a2a1ea06"
|
||||
dependencies = [
|
||||
"glob",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"termcolor",
|
||||
"toml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
|
@ -6478,9 +6494,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "winnow"
|
||||
version = "0.6.13"
|
||||
version = "0.6.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59b5e5f6c299a3c7890b876a2a587f3115162487e704907d9b6cd29473052ba1"
|
||||
checksum = "374ec40a2d767a3c1b4972d9475ecd557356637be906f2cb3f7fe17a6eb5e22f"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
@ -6699,9 +6715,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "2.1.3"
|
||||
version = "2.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "775a2b471036342aa69bc5a602bc889cb0a06cda00477d0c69566757d5553d39"
|
||||
checksum = "e29ab4097989787b2029a5981c41b7bfb427b5a601e23f455daacb4d0360a9e9"
|
||||
dependencies = [
|
||||
"aes",
|
||||
"arbitrary",
|
||||
|
|
|
@ -16,7 +16,7 @@ members = [
|
|||
|
||||
exclude = [
|
||||
"examples/notebook",
|
||||
# "crates/burn-cuda" # comment this line to work on burn-cuda
|
||||
"crates/burn-cuda", # comment this line to work on burn-cuda
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
|
@ -27,7 +27,7 @@ license = "MIT OR Apache-2.0"
|
|||
|
||||
[workspace.dependencies]
|
||||
bytemuck = "1.16.1"
|
||||
candle-core = { version = "0.5.1" }
|
||||
candle-core = { version = "0.6.0" }
|
||||
clap = { version = "4.5.9", features = ["derive"] }
|
||||
colored = "2.1.0"
|
||||
console_error_panic_hook = "0.1.7"
|
||||
|
@ -140,6 +140,9 @@ nvml-wrapper = "0.10.0"
|
|||
sysinfo = "0.30.13"
|
||||
systemstat = "0.2.3"
|
||||
|
||||
cubecl = { version = "0.1.1", git = "https://github.com/tracel-ai/cubecl", default-features = false }
|
||||
cubecl-common = { version = "0.1.1", git = "https://github.com/tracel-ai/cubecl", default-features = false }
|
||||
|
||||
[profile.dev]
|
||||
debug = 0 # Speed up compilation time and not necessary.
|
||||
opt-level = 2
|
||||
|
|
|
@ -24,14 +24,14 @@ tch-cpu = ["burn/tch"]
|
|||
tch-gpu = ["burn/tch"]
|
||||
wgpu = ["burn/wgpu", "burn/autotune"]
|
||||
wgpu-fusion = ["wgpu", "burn/fusion"]
|
||||
cuda-jit = ["burn-cuda"]
|
||||
# cuda-jit = ["burn-cuda"]
|
||||
|
||||
[dependencies]
|
||||
arboard = { workspace = true }
|
||||
burn = { path = "../crates/burn", default-features = false }
|
||||
burn-common = { path = "../crates/burn-common", version = "0.14.0" }
|
||||
burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.14.0" }
|
||||
burn-cuda = { path = "../crates/burn-cuda", version = "0.14.0", optional = true }
|
||||
burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.14.0", optional = true }
|
||||
# burn-cuda = { path = "../crates/burn-cuda", version = "0.14.0", optional = true }
|
||||
clap = { workspace = true }
|
||||
colored = { workspace = true }
|
||||
derive-new = { workspace = true }
|
||||
|
@ -49,6 +49,7 @@ strum_macros = { workspace = true }
|
|||
sysinfo = { workspace = true, features = ["serde"] }
|
||||
wgpu = { workspace = true }
|
||||
wsl = { workspace = true }
|
||||
cubecl = { workspace = true, features = ["wgpu"] }
|
||||
|
||||
[dev-dependencies]
|
||||
rstest = { workspace = true }
|
||||
|
|
|
@ -29,7 +29,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
|
|||
}
|
||||
|
||||
fn execute(&self, (lhs, rhs): Self::Args) {
|
||||
lhs.clone().transpose().matmul(rhs.clone());
|
||||
lhs.clone().matmul(rhs.clone());
|
||||
}
|
||||
|
||||
fn prepare(&self) -> Self::Args {
|
||||
|
@ -52,11 +52,11 @@ fn bench<B: Backend>(
|
|||
token: Option<&str>,
|
||||
) {
|
||||
const D: usize = 3;
|
||||
let batch_size = 32;
|
||||
let m = 256;
|
||||
let k = 1024;
|
||||
let n = 256;
|
||||
let shape_lhs = [batch_size, k, m].into();
|
||||
let batch_size = 8;
|
||||
let m = 2048;
|
||||
let k = 2048;
|
||||
let n = 2048;
|
||||
let shape_lhs = [batch_size, m, k].into();
|
||||
let shape_rhs = [batch_size, k, n].into();
|
||||
|
||||
let benchmark = MatmulBenchmark::<B, D>::new(shape_lhs, shape_rhs, device.clone());
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use burn::serde::{Deserialize, Serialize};
|
||||
use burn_wgpu::GraphicsApi;
|
||||
use cubecl::wgpu::GraphicsApi;
|
||||
use std::collections::HashSet;
|
||||
use sysinfo;
|
||||
use wgpu;
|
||||
|
@ -51,7 +51,7 @@ impl BenchmarkSystemInfo {
|
|||
fn enumerate_gpus() -> Vec<String> {
|
||||
let instance = wgpu::Instance::default();
|
||||
let adapters: Vec<wgpu::Adapter> = instance
|
||||
.enumerate_adapters(burn_wgpu::AutoGraphicsApi::backend().into())
|
||||
.enumerate_adapters(cubecl::wgpu::AutoGraphicsApi::backend().into())
|
||||
.into_iter()
|
||||
.filter(|adapter| {
|
||||
let info = adapter.get_info();
|
||||
|
|
|
@ -11,8 +11,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-common"
|
|||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
std = ["rand/std", "data-encoding/std", "dep:pollster"]
|
||||
default = ["std", "cubecl-common/default"]
|
||||
std = ["cubecl-common/std"]
|
||||
doc = ["default"]
|
||||
network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
|
||||
rayon = ["dep:rayon"]
|
||||
|
@ -23,13 +23,7 @@ web-time = { version = "1.1.0" }
|
|||
|
||||
|
||||
[dependencies]
|
||||
# ** Please make sure all dependencies support no_std when std is disabled **
|
||||
rand = { workspace = true }
|
||||
spin = { workspace = true } # using in place of use std::sync::Mutex;
|
||||
derive-new = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
data-encoding = { workspace = true }
|
||||
pollster = { workspace = true, optional = true }
|
||||
|
||||
# Network downloader
|
||||
indicatif = { workspace = true, optional = true }
|
||||
|
@ -38,6 +32,7 @@ tokio = { workspace = true, optional = true }
|
|||
|
||||
# Parallel
|
||||
rayon = { workspace = true, optional = true }
|
||||
cubecl-common = { workspace = true, default-features = false }
|
||||
|
||||
[dev-dependencies]
|
||||
dashmap = { workspace = true }
|
||||
|
|
|
@ -1,307 +0,0 @@
|
|||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
use core::fmt::Display;
|
||||
use core::time::Duration;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[cfg(all(not(target_family = "wasm"), feature = "std"))]
|
||||
use std::time::Instant;
|
||||
#[cfg(all(target_family = "wasm", feature = "std"))]
|
||||
use web_time::Instant;
|
||||
|
||||
/// Results of a benchmark run.
|
||||
#[derive(new, Debug, Default, Clone, Serialize, Deserialize)]
|
||||
pub struct BenchmarkDurations {
|
||||
/// All durations of the run, in the order they were benchmarked
|
||||
pub durations: Vec<Duration>,
|
||||
}
|
||||
|
||||
impl BenchmarkDurations {
|
||||
/// Returns a tuple of durations: (min, max, median)
|
||||
fn min_max_median_durations(&self) -> (Duration, Duration, Duration) {
|
||||
let mut sorted = self.durations.clone();
|
||||
sorted.sort();
|
||||
let min = *sorted.first().unwrap();
|
||||
let max = *sorted.last().unwrap();
|
||||
let median = *sorted.get(sorted.len() / 2).unwrap();
|
||||
(min, max, median)
|
||||
}
|
||||
|
||||
/// Returns the median duration among all durations
|
||||
pub(crate) fn mean_duration(&self) -> Duration {
|
||||
self.durations.iter().sum::<Duration>() / self.durations.len() as u32
|
||||
}
|
||||
|
||||
/// Returns the variance durations for the durations
|
||||
pub(crate) fn variance_duration(&self, mean: Duration) -> Duration {
|
||||
let var = self
|
||||
.durations
|
||||
.iter()
|
||||
.map(|duration| {
|
||||
let tmp = duration.as_secs_f64() - mean.as_secs_f64();
|
||||
Duration::from_secs_f64(tmp * tmp)
|
||||
})
|
||||
.sum::<Duration>()
|
||||
/ self.durations.len() as u32;
|
||||
var
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for BenchmarkDurations {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
let computed = BenchmarkComputations::new(self);
|
||||
let BenchmarkComputations {
|
||||
mean,
|
||||
median,
|
||||
variance,
|
||||
min,
|
||||
max,
|
||||
} = computed;
|
||||
let num_sample = self.durations.len();
|
||||
|
||||
f.write_str(
|
||||
format!(
|
||||
"
|
||||
―――――――― Result ―――――――――
|
||||
Samples {num_sample}
|
||||
Mean {mean:.3?}
|
||||
Variance {variance:.3?}
|
||||
Median {median:.3?}
|
||||
Min {min:.3?}
|
||||
Max {max:.3?}
|
||||
―――――――――――――――――――――――――"
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Computed values from benchmark durations.
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
|
||||
pub struct BenchmarkComputations {
|
||||
/// Mean of all the durations.
|
||||
pub mean: Duration,
|
||||
/// Median of all the durations.
|
||||
pub median: Duration,
|
||||
/// Variance of all the durations.
|
||||
pub variance: Duration,
|
||||
/// Minimum duration amongst all durations.
|
||||
pub min: Duration,
|
||||
/// Maximum duration amongst all durations.
|
||||
pub max: Duration,
|
||||
}
|
||||
|
||||
impl BenchmarkComputations {
|
||||
/// Compute duration values and return a BenchmarkComputations struct
|
||||
pub fn new(durations: &BenchmarkDurations) -> Self {
|
||||
let mean = durations.mean_duration();
|
||||
let (min, max, median) = durations.min_max_median_durations();
|
||||
Self {
|
||||
mean,
|
||||
median,
|
||||
min,
|
||||
max,
|
||||
variance: durations.variance_duration(mean),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Benchmark trait.
|
||||
pub trait Benchmark {
|
||||
/// Benchmark arguments.
|
||||
type Args: Clone;
|
||||
|
||||
/// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't
|
||||
/// count as included in the duration.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This should not include warmup, the benchmark will be run at least one time without
|
||||
/// measuring the execution time.
|
||||
fn prepare(&self) -> Self::Args;
|
||||
/// Execute the benchmark and returns the time it took to complete.
|
||||
fn execute(&self, args: Self::Args);
|
||||
/// Number of samples per run required to have a statistical significance.
|
||||
fn num_samples(&self) -> usize {
|
||||
10
|
||||
}
|
||||
/// Name of the benchmark, should be short and it should match the name
|
||||
/// defined in the crate Cargo.toml
|
||||
fn name(&self) -> String;
|
||||
/// The options passed to the benchmark.
|
||||
fn options(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
/// Shapes dimensions
|
||||
fn shapes(&self) -> Vec<Vec<usize>> {
|
||||
vec![]
|
||||
}
|
||||
/// Wait for computed to be over
|
||||
fn sync(&self);
|
||||
/// Run the benchmark a number of times.
|
||||
fn run(&self) -> BenchmarkDurations {
|
||||
#[cfg(not(feature = "std"))]
|
||||
panic!("Attempting to run benchmark in a no-std environment");
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
{
|
||||
// Warmup
|
||||
let args = self.prepare();
|
||||
|
||||
self.execute(args.clone());
|
||||
self.sync();
|
||||
|
||||
let mut durations = Vec::with_capacity(self.num_samples());
|
||||
|
||||
for _ in 0..self.num_samples() {
|
||||
// Prepare
|
||||
self.sync();
|
||||
|
||||
// Execute the benchmark
|
||||
let start = Instant::now();
|
||||
self.execute(args.clone());
|
||||
self.sync();
|
||||
let end = Instant::now();
|
||||
|
||||
// Register the duration
|
||||
durations.push(end - start);
|
||||
}
|
||||
|
||||
BenchmarkDurations { durations }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of a benchmark run, with metadata
|
||||
#[derive(Default, Clone)]
|
||||
pub struct BenchmarkResult {
|
||||
/// Individual raw results of the run
|
||||
pub raw: BenchmarkDurations,
|
||||
/// Computed values for the run
|
||||
pub computed: BenchmarkComputations,
|
||||
/// Git commit hash of the commit in which the run occurred
|
||||
pub git_hash: String,
|
||||
/// Name of the benchmark
|
||||
pub name: String,
|
||||
/// Options passed to the benchmark
|
||||
pub options: Option<String>,
|
||||
/// Shape dimensions
|
||||
pub shapes: Vec<Vec<usize>>,
|
||||
/// Time just before the run
|
||||
pub timestamp: u128,
|
||||
}
|
||||
|
||||
impl Display for BenchmarkResult {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_str(
|
||||
format!(
|
||||
"
|
||||
Timestamp: {}
|
||||
Git Hash: {}
|
||||
Benchmarking - {}{}
|
||||
",
|
||||
self.timestamp, self.git_hash, self.name, self.raw
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
/// Runs the given benchmark on the device and prints result and information.
|
||||
pub fn run_benchmark<BM>(benchmark: BM) -> BenchmarkResult
|
||||
where
|
||||
BM: Benchmark,
|
||||
{
|
||||
let timestamp = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis();
|
||||
let output = std::process::Command::new("git")
|
||||
.args(["rev-parse", "HEAD"])
|
||||
.output()
|
||||
.unwrap();
|
||||
let git_hash = String::from_utf8(output.stdout).unwrap().trim().to_string();
|
||||
let durations = benchmark.run();
|
||||
BenchmarkResult {
|
||||
raw: durations.clone(),
|
||||
computed: BenchmarkComputations::new(&durations),
|
||||
git_hash,
|
||||
name: benchmark.name(),
|
||||
options: benchmark.options(),
|
||||
shapes: benchmark.shapes(),
|
||||
timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use alloc::vec;
|
||||
|
||||
#[test]
|
||||
fn test_min_max_median_durations_even_number_of_samples() {
|
||||
let durations = BenchmarkDurations {
|
||||
durations: vec![
|
||||
Duration::new(10, 0),
|
||||
Duration::new(20, 0),
|
||||
Duration::new(30, 0),
|
||||
Duration::new(40, 0),
|
||||
Duration::new(50, 0),
|
||||
],
|
||||
};
|
||||
let (min, max, median) = durations.min_max_median_durations();
|
||||
assert_eq!(min, Duration::from_secs(10));
|
||||
assert_eq!(max, Duration::from_secs(50));
|
||||
assert_eq!(median, Duration::from_secs(30));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_min_max_median_durations_odd_number_of_samples() {
|
||||
let durations = BenchmarkDurations {
|
||||
durations: vec![
|
||||
Duration::new(18, 5),
|
||||
Duration::new(20, 0),
|
||||
Duration::new(30, 0),
|
||||
Duration::new(40, 0),
|
||||
],
|
||||
};
|
||||
let (min, max, median) = durations.min_max_median_durations();
|
||||
assert_eq!(min, Duration::from_nanos(18000000005_u64));
|
||||
assert_eq!(max, Duration::from_secs(40));
|
||||
assert_eq!(median, Duration::from_secs(30));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mean_duration() {
|
||||
let durations = BenchmarkDurations {
|
||||
durations: vec![
|
||||
Duration::new(10, 0),
|
||||
Duration::new(20, 0),
|
||||
Duration::new(30, 0),
|
||||
Duration::new(40, 0),
|
||||
],
|
||||
};
|
||||
let mean = durations.mean_duration();
|
||||
assert_eq!(mean, Duration::from_secs(25));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_variance_duration() {
|
||||
let durations = BenchmarkDurations {
|
||||
durations: vec![
|
||||
Duration::new(10, 0),
|
||||
Duration::new(20, 0),
|
||||
Duration::new(30, 0),
|
||||
Duration::new(40, 0),
|
||||
Duration::new(50, 0),
|
||||
],
|
||||
};
|
||||
let mean = durations.mean_duration();
|
||||
let variance = durations.variance_duration(mean);
|
||||
assert_eq!(variance, Duration::from_secs(200));
|
||||
}
|
||||
}
|
|
@ -5,28 +5,10 @@
|
|||
//!
|
||||
//! This library contains common types used by other Burn crates that must be shared.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
/// Id module contains types for unique identifiers.
|
||||
pub mod id;
|
||||
|
||||
/// Rand module contains types for random number generation for non-std environments and for
|
||||
/// std environments.
|
||||
pub mod rand;
|
||||
|
||||
/// Stub module contains types for stubs for non-std environments and for std environments.
|
||||
pub mod stub;
|
||||
|
||||
/// Module for benchmarking any executable part
|
||||
pub mod benchmark;
|
||||
|
||||
/// Useful when you need to read async data without having to decorate each function with async
|
||||
/// notation.
|
||||
pub mod reader;
|
||||
|
||||
/// Synchronization type module, used both by ComputeServer and Backends.
|
||||
pub mod sync_type;
|
||||
pub use cubecl_common::*;
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
|
|
|
@ -1,45 +0,0 @@
|
|||
pub use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
use rand::distributions::Standard;
|
||||
use rand::prelude::Distribution;
|
||||
|
||||
/// Returns a seeded random number generator using entropy.
|
||||
#[cfg(feature = "std")]
|
||||
#[inline(always)]
|
||||
pub fn get_seeded_rng() -> StdRng {
|
||||
StdRng::from_entropy()
|
||||
}
|
||||
|
||||
/// Returns a seeded random number generator using a pre-generated seed.
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[inline(always)]
|
||||
pub fn get_seeded_rng() -> StdRng {
|
||||
const CONST_SEED: u64 = 42;
|
||||
StdRng::seed_from_u64(CONST_SEED)
|
||||
}
|
||||
|
||||
/// Generates random data from a thread-local RNG.
|
||||
#[cfg(feature = "std")]
|
||||
#[inline]
|
||||
pub fn gen_random<T>() -> T
|
||||
where
|
||||
Standard: Distribution<T>,
|
||||
{
|
||||
rand::thread_rng().gen()
|
||||
}
|
||||
|
||||
/// Generates random data from a mutex-protected RNG.
|
||||
#[cfg(not(feature = "std"))]
|
||||
#[inline]
|
||||
pub fn gen_random<T>() -> T
|
||||
where
|
||||
Standard: Distribution<T>,
|
||||
{
|
||||
use crate::stub::Mutex;
|
||||
static RNG: Mutex<Option<StdRng>> = Mutex::new(None);
|
||||
let mut rng = RNG.lock().unwrap();
|
||||
if rng.is_none() {
|
||||
*rng = Some(get_seeded_rng());
|
||||
}
|
||||
rng.as_mut().unwrap().gen()
|
||||
}
|
|
@ -1,54 +0,0 @@
|
|||
use alloc::{boxed::Box, sync::Arc, task::Wake, vec::Vec};
|
||||
use core::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
task::{Context, Poll, Waker},
|
||||
};
|
||||
|
||||
/// A future that is used to read resources from a compute server.
|
||||
pub type Reader = Pin<Box<dyn Future<Output = Vec<u8>> + Send>>;
|
||||
|
||||
/// Create a reader from a concrete value.
|
||||
pub fn reader_from_concrete(val: Vec<u8>) -> Reader {
|
||||
Box::pin(async move { val })
|
||||
}
|
||||
|
||||
struct DummyWaker;
|
||||
|
||||
impl Wake for DummyWaker {
|
||||
fn wake(self: Arc<Self>) {}
|
||||
fn wake_by_ref(self: &Arc<Self>) {}
|
||||
}
|
||||
|
||||
/// Read a future synchronously.
|
||||
///
|
||||
/// On WASM futures cannot block, so this only succeeds if the future returns immediately.
|
||||
/// If you want to handle this error, please use
|
||||
/// try_read_sync instead.
|
||||
pub fn read_sync<F: Future<Output = T>, T>(f: F) -> T {
|
||||
try_read_sync(f).expect("Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM. If possible, try using an async variant of this function instead.")
|
||||
}
|
||||
|
||||
/// Read a future synchronously.
|
||||
///
|
||||
/// On WASM futures cannot block, so this only succeeds if the future returns immediately.
|
||||
/// otherwise this returns None.
|
||||
pub fn try_read_sync<F: Future<Output = T>, T>(f: F) -> Option<T> {
|
||||
// Create a dummy context.
|
||||
let waker = Waker::from(Arc::new(DummyWaker));
|
||||
let mut context = Context::from_waker(&waker);
|
||||
|
||||
// Pin & poll the future. Some backends don't do async readbacks, and instead immediately get
|
||||
// the data. This let's us detect when a future is synchronous and doesn't require any waiting.
|
||||
let mut pinned = core::pin::pin!(f);
|
||||
|
||||
match pinned.as_mut().poll(&mut context) {
|
||||
Poll::Ready(output) => Some(output),
|
||||
// On platforms that support it, now just block on the future and drive it to completion.
|
||||
#[cfg(all(not(target_family = "wasm"), feature = "std"))]
|
||||
Poll::Pending => Some(pollster::block_on(pinned)),
|
||||
// Otherwise, just bail and return None - this futures will have to be read back asynchronously.
|
||||
#[cfg(any(target_family = "wasm", not(feature = "std")))]
|
||||
Poll::Pending => None,
|
||||
}
|
||||
}
|
|
@ -1,154 +0,0 @@
|
|||
#[cfg(not(feature = "std"))]
|
||||
use spin::{
|
||||
Mutex as MutexImported, MutexGuard, Once as OnceImported, RwLock as RwLockImported,
|
||||
RwLockReadGuard, RwLockWriteGuard,
|
||||
};
|
||||
#[cfg(feature = "std")]
|
||||
use std::sync::{
|
||||
Mutex as MutexImported, MutexGuard, OnceLock as OnceImported, RwLock as RwLockImported,
|
||||
RwLockReadGuard, RwLockWriteGuard,
|
||||
};
|
||||
|
||||
/// A mutual exclusion primitive useful for protecting shared data
|
||||
///
|
||||
/// This mutex will block threads waiting for the lock to become available. The
|
||||
/// mutex can also be statically initialized or created via a [Mutex::new]
|
||||
///
|
||||
/// [Mutex] wrapper to make `spin::Mutex` API compatible with `std::sync::Mutex` to swap
|
||||
#[derive(Debug)]
|
||||
pub struct Mutex<T> {
|
||||
inner: MutexImported<T>,
|
||||
}
|
||||
|
||||
impl<T> Mutex<T> {
|
||||
/// Creates a new mutex in an unlocked state ready for use.
|
||||
#[inline(always)]
|
||||
pub const fn new(value: T) -> Self {
|
||||
Self {
|
||||
inner: MutexImported::new(value),
|
||||
}
|
||||
}
|
||||
|
||||
/// Locks the mutex blocking the current thread until it is able to do so.
|
||||
#[inline(always)]
|
||||
pub fn lock(&self) -> Result<MutexGuard<T>, alloc::string::String> {
|
||||
#[cfg(not(feature = "std"))]
|
||||
{
|
||||
Ok(self.inner.lock())
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
{
|
||||
self.inner.lock().map_err(|err| err.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A reader-writer lock which is exclusively locked for writing or shared for reading.
|
||||
/// This reader-writer lock will block threads waiting for the lock to become available.
|
||||
/// The lock can also be statically initialized or created via a [RwLock::new]
|
||||
/// [RwLock] wrapper to make `spin::RwLock` API compatible with `std::sync::RwLock` to swap
|
||||
#[derive(Debug)]
|
||||
pub struct RwLock<T> {
|
||||
inner: RwLockImported<T>,
|
||||
}
|
||||
|
||||
impl<T> RwLock<T> {
|
||||
/// Creates a new reader-writer lock in an unlocked state ready for use.
|
||||
#[inline(always)]
|
||||
pub const fn new(value: T) -> Self {
|
||||
Self {
|
||||
inner: RwLockImported::new(value),
|
||||
}
|
||||
}
|
||||
|
||||
/// Locks this rwlock with shared read access, blocking the current thread
|
||||
/// until it can be acquired.
|
||||
#[inline(always)]
|
||||
pub fn read(&self) -> Result<RwLockReadGuard<T>, alloc::string::String> {
|
||||
#[cfg(not(feature = "std"))]
|
||||
{
|
||||
Ok(self.inner.read())
|
||||
}
|
||||
#[cfg(feature = "std")]
|
||||
{
|
||||
self.inner.read().map_err(|err| err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Locks this rwlock with exclusive write access, blocking the current thread
|
||||
/// until it can be acquired.
|
||||
#[inline(always)]
|
||||
pub fn write(&self) -> Result<RwLockWriteGuard<T>, alloc::string::String> {
|
||||
#[cfg(not(feature = "std"))]
|
||||
{
|
||||
Ok(self.inner.write())
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
{
|
||||
self.inner.write().map_err(|err| err.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A unique identifier for a running thread.
|
||||
///
|
||||
/// This module is a stub when no std is available to swap with std::thread::ThreadId.
|
||||
#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)]
|
||||
pub struct ThreadId(core::num::NonZeroU64);
|
||||
|
||||
/// A cell that provides lazy one-time initialization that implements [Sync] and [Send].
|
||||
///
|
||||
/// This module is a stub when no std is available to swap with [std::sync::OnceLock].
|
||||
pub struct SyncOnceCell<T>(OnceImported<T>);
|
||||
|
||||
impl<T: core::fmt::Debug> Default for SyncOnceCell<T> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: core::fmt::Debug> SyncOnceCell<T> {
|
||||
/// Create a new once.
|
||||
#[inline(always)]
|
||||
pub fn new() -> Self {
|
||||
Self(OnceImported::new())
|
||||
}
|
||||
|
||||
/// Initialize the cell with a value.
|
||||
#[inline(always)]
|
||||
pub fn initialized(value: T) -> Self {
|
||||
#[cfg(not(feature = "std"))]
|
||||
{
|
||||
let cell = OnceImported::initialized(value);
|
||||
Self(cell)
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
{
|
||||
let cell = OnceImported::new();
|
||||
cell.set(value).unwrap();
|
||||
|
||||
Self(cell)
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the contents of the cell, initializing it with `f` if the cell
|
||||
/// was empty.
|
||||
#[inline(always)]
|
||||
pub fn get_or_init<F>(&self, f: F) -> &T
|
||||
where
|
||||
F: FnOnce() -> T,
|
||||
{
|
||||
#[cfg(not(feature = "std"))]
|
||||
{
|
||||
self.0.call_once(f)
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
{
|
||||
self.0.get_or_init(f)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,8 +0,0 @@
|
|||
/// What kind of synchronization to use.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SyncType {
|
||||
/// Submit all outstanding tasks to the task queue if any.
|
||||
Flush,
|
||||
/// Submit all tasks to the task queue and wait for all of them to complete.
|
||||
Wait,
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
[package]
|
||||
authors = ["louisfd <louisfd94@gmail.com>", "Nathaniel Simard"]
|
||||
categories = ["science"]
|
||||
description = "Compute crate that helps creating high performance async backends."
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "data"]
|
||||
license.workspace = true
|
||||
name = "burn-compute"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-compute"
|
||||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = [
|
||||
"std",
|
||||
"channel-mutex",
|
||||
"channel-mpsc",
|
||||
"channel-cell",
|
||||
"storage-bytes",
|
||||
"autotune-persistent-cache",
|
||||
]
|
||||
std = ["burn-common/std"]
|
||||
channel-mutex = []
|
||||
channel-cell = []
|
||||
channel-mpsc = ["dep:async-channel", "dep:pollster"] # Assume std
|
||||
storage-bytes = []
|
||||
autotune-persistent-cache = ["dirs", "md5", "serde", "serde_json"] # Assume std
|
||||
|
||||
[dependencies]
|
||||
burn-common = { path = "../burn-common", version = "0.14.0", default-features = false }
|
||||
derive-new = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
log = { workspace = true }
|
||||
hashbrown = { workspace = true }
|
||||
dirs = { workspace = true, optional = true }
|
||||
serde = { workspace = true, optional = true }
|
||||
serde_json = { workspace = true, features = ["std"], optional = true }
|
||||
md5 = { workspace = true, optional = true }
|
||||
pollster = { workspace = true, optional = true }
|
||||
async-channel = { workspace = true, optional = true }
|
||||
|
||||
[target.'cfg(target_family = "wasm")'.dependencies]
|
||||
web-time = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
serial_test = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
|
||||
[[bench]]
|
||||
name = "dynamic"
|
||||
harness = false
|
|
@ -1 +0,0 @@
|
|||
../../LICENSE-APACHE
|
|
@ -1 +0,0 @@
|
|||
../../LICENSE-MIT
|
|
@ -1,7 +0,0 @@
|
|||
# Burn Compute
|
||||
|
||||
This crate helps creating high performance async backends.
|
||||
|
||||
- [x] Asynchronous kernel executions
|
||||
- [x] Memory allocation management
|
||||
- [x] Autotuning
|
|
@ -1,29 +0,0 @@
|
|||
use std::collections::LinkedList;
|
||||
|
||||
use burn_compute::{
|
||||
memory_management::{
|
||||
dynamic::{DynamicMemoryManagement, DynamicMemoryManagementOptions},
|
||||
MemoryManagement,
|
||||
},
|
||||
storage::BytesStorage,
|
||||
};
|
||||
|
||||
const MB: usize = 1024 * 1024;
|
||||
|
||||
fn main() {
|
||||
let start = std::time::Instant::now();
|
||||
let storage = BytesStorage::default();
|
||||
let mut mm = DynamicMemoryManagement::new(
|
||||
storage,
|
||||
DynamicMemoryManagementOptions::preset(2048 * MB, 32),
|
||||
);
|
||||
let mut handles = LinkedList::new();
|
||||
for _ in 0..100 * 2048 {
|
||||
if handles.len() >= 4000 {
|
||||
handles.pop_front();
|
||||
}
|
||||
let handle = mm.reserve(MB, || {});
|
||||
handles.push_back(handle);
|
||||
}
|
||||
println!("{:?}", start.elapsed());
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
use crate::{
|
||||
server::{Binding, ComputeServer, Handle},
|
||||
storage::ComputeStorage,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::{reader::Reader, sync_type::SyncType};
|
||||
|
||||
/// The ComputeChannel trait links the ComputeClient to the ComputeServer
|
||||
/// while ensuring thread-safety
|
||||
pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send + Sync {
|
||||
/// Given a binding, returns owned resource as bytes
|
||||
fn read(&self, binding: Binding<Server>) -> Reader;
|
||||
|
||||
/// Given a resource handle, return the storage resource.
|
||||
fn get_resource(
|
||||
&self,
|
||||
binding: Binding<Server>,
|
||||
) -> <Server::Storage as ComputeStorage>::Resource;
|
||||
|
||||
/// Given a resource as bytes, stores it and returns the resource handle
|
||||
fn create(&self, data: &[u8]) -> Handle<Server>;
|
||||
|
||||
/// Reserves `size` bytes in the storage, and returns a handle over them
|
||||
fn empty(&self, size: usize) -> Handle<Server>;
|
||||
|
||||
/// Executes the `kernel` over the given `bindings`.
|
||||
fn execute(
|
||||
&self,
|
||||
kernel: Server::Kernel,
|
||||
count: Server::DispatchOptions,
|
||||
bindings: Vec<Binding<Server>>,
|
||||
);
|
||||
|
||||
/// Perform some synchronization of commands on the server.
|
||||
fn sync(&self, sync_type: SyncType);
|
||||
}
|
|
@ -1,85 +0,0 @@
|
|||
use super::ComputeChannel;
|
||||
use crate::server::{Binding, ComputeServer, Handle};
|
||||
use crate::storage::ComputeStorage;
|
||||
use alloc::sync::Arc;
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
use burn_common::sync_type::SyncType;
|
||||
|
||||
/// A channel using a [ref cell](core::cell::RefCell) to access the server with mutability.
|
||||
///
|
||||
/// # Important
|
||||
///
|
||||
/// Only use this channel if you don't use any threading in your application, otherwise it will
|
||||
/// panic or cause undefined behaviors.
|
||||
///
|
||||
/// This is mosly useful for `no-std` environments where threads aren't supported, otherwise prefer
|
||||
/// the [mutex](super::MutexComputeChannel) or the [mpsc](super::MpscComputeChannel) channels.
|
||||
#[derive(Debug)]
|
||||
pub struct RefCellComputeChannel<Server> {
|
||||
server: Arc<core::cell::RefCell<Server>>,
|
||||
}
|
||||
|
||||
impl<S> Clone for RefCellComputeChannel<S> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
server: self.server.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server> RefCellComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
/// Create a new cell compute channel.
|
||||
pub fn new(server: Server) -> Self {
|
||||
Self {
|
||||
server: Arc::new(core::cell::RefCell::new(server)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer + Send,
|
||||
{
|
||||
fn read(&self, binding: Binding<Server>) -> Reader {
|
||||
self.server.borrow_mut().read(binding)
|
||||
}
|
||||
|
||||
fn get_resource(
|
||||
&self,
|
||||
binding: Binding<Server>,
|
||||
) -> <Server::Storage as ComputeStorage>::Resource {
|
||||
self.server.borrow_mut().get_resource(binding)
|
||||
}
|
||||
|
||||
fn create(&self, resource: &[u8]) -> Handle<Server> {
|
||||
self.server.borrow_mut().create(resource)
|
||||
}
|
||||
|
||||
fn empty(&self, size: usize) -> Handle<Server> {
|
||||
self.server.borrow_mut().empty(size)
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&self,
|
||||
kernel_description: Server::Kernel,
|
||||
count: Server::DispatchOptions,
|
||||
bindings: Vec<Binding<Server>>,
|
||||
) {
|
||||
self.server
|
||||
.borrow_mut()
|
||||
.execute(kernel_description, count, bindings)
|
||||
}
|
||||
|
||||
fn sync(&self, sync_type: SyncType) {
|
||||
self.server.borrow_mut().sync(sync_type)
|
||||
}
|
||||
}
|
||||
|
||||
/// This is unsafe, since no concurrency is supported by the `RefCell` channel.
|
||||
/// However using this channel should only be done in single threaded environments such as `no-std`.
|
||||
unsafe impl<Server: ComputeServer> Send for RefCellComputeChannel<Server> {}
|
||||
unsafe impl<Server: ComputeServer> Sync for RefCellComputeChannel<Server> {}
|
|
@ -1,17 +0,0 @@
|
|||
mod base;
|
||||
pub use base::*;
|
||||
|
||||
#[cfg(feature = "channel-mutex")]
|
||||
mod mutex;
|
||||
#[cfg(feature = "channel-mutex")]
|
||||
pub use mutex::*;
|
||||
|
||||
#[cfg(all(feature = "channel-mpsc", not(target_family = "wasm")))]
|
||||
mod mpsc;
|
||||
#[cfg(all(feature = "channel-mpsc", not(target_family = "wasm")))]
|
||||
pub use mpsc::*;
|
||||
|
||||
#[cfg(feature = "channel-cell")]
|
||||
mod cell;
|
||||
#[cfg(feature = "channel-cell")]
|
||||
pub use cell::*;
|
|
@ -1,181 +0,0 @@
|
|||
use burn_common::{reader::Reader, sync_type::SyncType};
|
||||
use std::{sync::Arc, thread};
|
||||
|
||||
use super::ComputeChannel;
|
||||
use crate::{
|
||||
server::{Binding, ComputeServer, Handle},
|
||||
storage::ComputeStorage,
|
||||
};
|
||||
|
||||
/// Create a channel using a [multi-producer, single-consumer channel to communicate with
|
||||
/// the compute server spawn on its own thread.
|
||||
#[derive(Debug)]
|
||||
pub struct MpscComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
state: Arc<MpscComputeChannelState<Server>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MpscComputeChannelState<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
_handle: thread::JoinHandle<()>,
|
||||
sender: async_channel::Sender<Message<Server>>,
|
||||
}
|
||||
|
||||
type Callback<Response> = async_channel::Sender<Response>;
|
||||
|
||||
enum Message<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
Read(Binding<Server>, Callback<Vec<u8>>),
|
||||
GetResource(
|
||||
Binding<Server>,
|
||||
Callback<<Server::Storage as ComputeStorage>::Resource>,
|
||||
),
|
||||
Create(Vec<u8>, Callback<Handle<Server>>),
|
||||
Empty(usize, Callback<Handle<Server>>),
|
||||
ExecuteKernel(
|
||||
(Server::Kernel, Server::DispatchOptions),
|
||||
Vec<Binding<Server>>,
|
||||
),
|
||||
Sync(SyncType, Callback<()>),
|
||||
}
|
||||
|
||||
impl<Server> MpscComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer + 'static,
|
||||
{
|
||||
/// Create a new mpsc compute channel.
|
||||
pub fn new(mut server: Server) -> Self {
|
||||
let (sender, receiver) = async_channel::unbounded();
|
||||
|
||||
let _handle = thread::spawn(move || {
|
||||
// Run the whole procedure as one blocking future. This is much simpler than trying
|
||||
// to use some multithreaded executor.
|
||||
pollster::block_on(async {
|
||||
while let Ok(message) = receiver.recv().await {
|
||||
match message {
|
||||
Message::Read(binding, callback) => {
|
||||
let data = server.read(binding).await;
|
||||
callback.send(data).await.unwrap();
|
||||
}
|
||||
Message::GetResource(binding, callback) => {
|
||||
let data = server.get_resource(binding);
|
||||
callback.send(data).await.unwrap();
|
||||
}
|
||||
Message::Create(data, callback) => {
|
||||
let handle = server.create(&data);
|
||||
callback.send(handle).await.unwrap();
|
||||
}
|
||||
Message::Empty(size, callback) => {
|
||||
let handle = server.empty(size);
|
||||
callback.send(handle).await.unwrap();
|
||||
}
|
||||
Message::ExecuteKernel(kernel, bindings) => {
|
||||
server.execute(kernel.0, kernel.1, bindings);
|
||||
}
|
||||
Message::Sync(sync_type, callback) => {
|
||||
server.sync(sync_type);
|
||||
callback.send(()).await.unwrap();
|
||||
}
|
||||
};
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
let state = Arc::new(MpscComputeChannelState { sender, _handle });
|
||||
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server: ComputeServer> Clone for MpscComputeChannel<Server> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
state: self.state.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer + 'static,
|
||||
{
|
||||
fn read(&self, binding: Binding<Server>) -> Reader {
|
||||
let sender = self.state.sender.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let (callback, response) = async_channel::unbounded();
|
||||
sender.send(Message::Read(binding, callback)).await.unwrap();
|
||||
handle_response(response.recv().await)
|
||||
})
|
||||
}
|
||||
|
||||
fn get_resource(
|
||||
&self,
|
||||
binding: Binding<Server>,
|
||||
) -> <Server::Storage as ComputeStorage>::Resource {
|
||||
let (callback, response) = async_channel::unbounded();
|
||||
|
||||
self.state
|
||||
.sender
|
||||
.send_blocking(Message::GetResource(binding, callback))
|
||||
.unwrap();
|
||||
|
||||
handle_response(response.recv_blocking())
|
||||
}
|
||||
|
||||
fn create(&self, data: &[u8]) -> Handle<Server> {
|
||||
let (callback, response) = async_channel::unbounded();
|
||||
|
||||
self.state
|
||||
.sender
|
||||
.send_blocking(Message::Create(data.to_vec(), callback))
|
||||
.unwrap();
|
||||
|
||||
handle_response(response.recv_blocking())
|
||||
}
|
||||
|
||||
fn empty(&self, size: usize) -> Handle<Server> {
|
||||
let (callback, response) = async_channel::unbounded();
|
||||
self.state
|
||||
.sender
|
||||
.send_blocking(Message::Empty(size, callback))
|
||||
.unwrap();
|
||||
|
||||
handle_response(response.recv_blocking())
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&self,
|
||||
kernel: Server::Kernel,
|
||||
count: Server::DispatchOptions,
|
||||
bindings: Vec<Binding<Server>>,
|
||||
) {
|
||||
self.state
|
||||
.sender
|
||||
.send_blocking(Message::ExecuteKernel((kernel, count), bindings))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn sync(&self, sync_type: SyncType) {
|
||||
let (callback, response) = async_channel::unbounded();
|
||||
self.state
|
||||
.sender
|
||||
.send_blocking(Message::Sync(sync_type, callback))
|
||||
.unwrap();
|
||||
handle_response(response.recv_blocking())
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_response<Response, Err: core::fmt::Debug>(response: Result<Response, Err>) -> Response {
|
||||
match response {
|
||||
Ok(val) => val,
|
||||
Err(err) => panic!("Can't connect to the server correctly {err:?}"),
|
||||
}
|
||||
}
|
|
@ -1,71 +0,0 @@
|
|||
use super::ComputeChannel;
|
||||
use crate::server::{Binding, ComputeServer, Handle};
|
||||
use crate::storage::ComputeStorage;
|
||||
use alloc::sync::Arc;
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
use burn_common::sync_type::SyncType;
|
||||
use spin::Mutex;
|
||||
|
||||
/// The MutexComputeChannel ensures thread-safety by locking the server
|
||||
/// on every operation
|
||||
#[derive(Debug)]
|
||||
pub struct MutexComputeChannel<Server> {
|
||||
server: Arc<Mutex<Server>>,
|
||||
}
|
||||
|
||||
impl<S> Clone for MutexComputeChannel<S> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
server: self.server.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<Server> MutexComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
/// Create a new mutex compute channel.
|
||||
pub fn new(server: Server) -> Self {
|
||||
Self {
|
||||
server: Arc::new(Mutex::new(server)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
{
|
||||
fn read(&self, handle: Binding<Server>) -> Reader {
|
||||
self.server.lock().read(handle)
|
||||
}
|
||||
|
||||
fn get_resource(
|
||||
&self,
|
||||
binding: Binding<Server>,
|
||||
) -> <Server::Storage as ComputeStorage>::Resource {
|
||||
self.server.lock().get_resource(binding)
|
||||
}
|
||||
|
||||
fn create(&self, data: &[u8]) -> Handle<Server> {
|
||||
self.server.lock().create(data)
|
||||
}
|
||||
|
||||
fn empty(&self, size: usize) -> Handle<Server> {
|
||||
self.server.lock().empty(size)
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&self,
|
||||
kernel: Server::Kernel,
|
||||
count: Server::DispatchOptions,
|
||||
handles: Vec<Binding<Server>>,
|
||||
) {
|
||||
self.server.lock().execute(kernel, count, handles)
|
||||
}
|
||||
|
||||
fn sync(&self, sync_type: SyncType) {
|
||||
self.server.lock().sync(sync_type)
|
||||
}
|
||||
}
|
|
@ -1,119 +0,0 @@
|
|||
use crate::{
|
||||
channel::ComputeChannel,
|
||||
server::{Binding, ComputeServer, Handle},
|
||||
storage::ComputeStorage,
|
||||
tune::{AutotuneOperationSet, Tuner},
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use alloc::{boxed::Box, sync::Arc};
|
||||
use burn_common::stub::RwLock;
|
||||
use burn_common::sync_type::SyncType;
|
||||
|
||||
/// The ComputeClient is the entry point to require tasks from the ComputeServer.
|
||||
/// It should be obtained for a specific device via the Compute struct.
|
||||
#[derive(Debug)]
|
||||
pub struct ComputeClient<Server: ComputeServer, Channel> {
|
||||
channel: Channel,
|
||||
tuner: Arc<RwLock<Tuner<Server::AutotuneKey>>>,
|
||||
features: Arc<Server::FeatureSet>,
|
||||
}
|
||||
|
||||
impl<S, C> Clone for ComputeClient<S, C>
|
||||
where
|
||||
S: ComputeServer,
|
||||
C: ComputeChannel<S>,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
channel: self.channel.clone(),
|
||||
tuner: self.tuner.clone(),
|
||||
features: self.features.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server, Channel> ComputeClient<Server, Channel>
|
||||
where
|
||||
Server: ComputeServer,
|
||||
Channel: ComputeChannel<Server>,
|
||||
{
|
||||
/// Create a new client.
|
||||
pub fn new(
|
||||
channel: Channel,
|
||||
tuner: Arc<RwLock<Tuner<Server::AutotuneKey>>>,
|
||||
features: Arc<Server::FeatureSet>,
|
||||
) -> Self {
|
||||
Self {
|
||||
channel,
|
||||
tuner,
|
||||
features,
|
||||
}
|
||||
}
|
||||
|
||||
/// Given a binding, returns owned resource as bytes.
|
||||
pub async fn read_async(&self, binding: Binding<Server>) -> Vec<u8> {
|
||||
self.channel.read(binding).await
|
||||
}
|
||||
|
||||
/// Given a binding, returns owned resource as bytes.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Panics if the read operation fails.
|
||||
pub fn read(&self, binding: Binding<Server>) -> Vec<u8> {
|
||||
burn_common::reader::read_sync(self.channel.read(binding))
|
||||
}
|
||||
|
||||
/// Given a resource handle, returns the storage resource.
|
||||
pub fn get_resource(
|
||||
&self,
|
||||
binding: Binding<Server>,
|
||||
) -> <Server::Storage as ComputeStorage>::Resource {
|
||||
self.channel.get_resource(binding)
|
||||
}
|
||||
|
||||
/// Given a resource, stores it and returns the resource handle.
|
||||
pub fn create(&self, data: &[u8]) -> Handle<Server> {
|
||||
self.channel.create(data)
|
||||
}
|
||||
|
||||
/// Reserves `size` bytes in the storage, and returns a handle over them.
|
||||
pub fn empty(&self, size: usize) -> Handle<Server> {
|
||||
self.channel.empty(size)
|
||||
}
|
||||
|
||||
/// Executes the `kernel` over the given `bindings`.
|
||||
pub fn execute(
|
||||
&self,
|
||||
kernel: Server::Kernel,
|
||||
count: Server::DispatchOptions,
|
||||
bindings: Vec<Binding<Server>>,
|
||||
) {
|
||||
self.channel.execute(kernel, count, bindings)
|
||||
}
|
||||
|
||||
/// Wait for the completion of every task in the server.
|
||||
pub fn sync(&self, sync_type: SyncType) {
|
||||
self.channel.sync(sync_type)
|
||||
}
|
||||
|
||||
/// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks
|
||||
pub fn autotune_execute(
|
||||
&self,
|
||||
autotune_operation_set: Box<dyn AutotuneOperationSet<Server::AutotuneKey>>,
|
||||
) {
|
||||
self.tuner
|
||||
.write()
|
||||
.unwrap()
|
||||
.execute_autotune(autotune_operation_set, self);
|
||||
}
|
||||
|
||||
/// Get the fastest kernel for the given autotune key if it exists.
|
||||
pub fn autotune_result(&self, key: &Server::AutotuneKey) -> Option<usize> {
|
||||
self.tuner.read().unwrap().autotune_fastest(key)
|
||||
}
|
||||
|
||||
/// Get the features supported by the compute server.
|
||||
pub fn features(&self) -> &Server::FeatureSet {
|
||||
self.features.as_ref()
|
||||
}
|
||||
}
|
|
@ -1,94 +0,0 @@
|
|||
use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
|
||||
use core::ops::DerefMut;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// The compute type has the responsibility to retrieve the correct compute client based on the
|
||||
/// given device.
|
||||
pub struct ComputeRuntime<Device, Server: ComputeServer, Channel> {
|
||||
clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
|
||||
}
|
||||
|
||||
impl<Device, Server, Channel> Default for ComputeRuntime<Device, Server, Channel>
|
||||
where
|
||||
Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
|
||||
Server: ComputeServer,
|
||||
Channel: ComputeChannel<Server>,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Device, Server, Channel> ComputeRuntime<Device, Server, Channel>
|
||||
where
|
||||
Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
|
||||
Server: ComputeServer,
|
||||
Channel: ComputeChannel<Server>,
|
||||
{
|
||||
/// Create a new compute.
|
||||
pub const fn new() -> Self {
|
||||
Self {
|
||||
clients: spin::Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the compute client for the given device.
|
||||
///
|
||||
/// Provide the init function to create a new client if it isn't already initialized.
|
||||
pub fn client<Init>(&self, device: &Device, init: Init) -> ComputeClient<Server, Channel>
|
||||
where
|
||||
Init: Fn() -> ComputeClient<Server, Channel>,
|
||||
{
|
||||
let mut clients = self.clients.lock();
|
||||
|
||||
if clients.is_none() {
|
||||
Self::register_inner(device, init(), &mut clients);
|
||||
}
|
||||
|
||||
match clients.deref_mut() {
|
||||
Some(clients) => match clients.get(device) {
|
||||
Some(client) => client.clone(),
|
||||
None => {
|
||||
let client = init();
|
||||
clients.insert(device.clone(), client.clone());
|
||||
client
|
||||
}
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register the compute client for the given device.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This function is mostly useful when the creation of the compute client can't be done
|
||||
/// synchronously and require special context.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If a client is already registered for the given device.
|
||||
pub fn register(&self, device: &Device, client: ComputeClient<Server, Channel>) {
|
||||
let mut clients = self.clients.lock();
|
||||
|
||||
Self::register_inner(device, client, &mut clients);
|
||||
}
|
||||
|
||||
fn register_inner(
|
||||
device: &Device,
|
||||
client: ComputeClient<Server, Channel>,
|
||||
clients: &mut Option<HashMap<Device, ComputeClient<Server, Channel>>>,
|
||||
) {
|
||||
if clients.is_none() {
|
||||
*clients = Some(HashMap::new());
|
||||
}
|
||||
|
||||
if let Some(clients) = clients {
|
||||
if clients.contains_key(device) {
|
||||
panic!("Client already created for device {:?}", device);
|
||||
}
|
||||
|
||||
clients.insert(device.clone(), client);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,175 +0,0 @@
|
|||
use alloc::sync::Arc;
|
||||
|
||||
#[macro_export(local_inner_macros)]
|
||||
/// Create a new storage ID type.
|
||||
macro_rules! storage_id_type {
|
||||
($name:ident) => {
|
||||
/// Storage ID.
|
||||
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
|
||||
pub struct $name {
|
||||
value: usize,
|
||||
}
|
||||
|
||||
impl $name {
|
||||
/// Create a new ID.
|
||||
pub fn new() -> Self {
|
||||
use core::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
static COUNTER: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
let value = COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
if value == usize::MAX {
|
||||
core::panic!("Memory ID overflowed");
|
||||
}
|
||||
Self { value }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for $name {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Reference to a buffer handle.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct HandleRef<Id> {
|
||||
id: Arc<Id>,
|
||||
all: Arc<()>,
|
||||
}
|
||||
|
||||
/// Reference to buffer binding.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct BindingRef<Id> {
|
||||
id: Id,
|
||||
_all: Arc<()>,
|
||||
}
|
||||
|
||||
impl<Id> BindingRef<Id>
|
||||
where
|
||||
Id: Clone + core::fmt::Debug,
|
||||
{
|
||||
/// The id associated to the buffer.
|
||||
pub(crate) fn id(&self) -> &Id {
|
||||
&self.id
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id> HandleRef<Id>
|
||||
where
|
||||
Id: Clone + core::fmt::Debug,
|
||||
{
|
||||
/// Create a new handle.
|
||||
pub(crate) fn new(id: Id) -> Self {
|
||||
Self {
|
||||
id: Arc::new(id),
|
||||
all: Arc::new(()),
|
||||
}
|
||||
}
|
||||
|
||||
/// The id associated to the handle.
|
||||
pub(crate) fn id(&self) -> &Id {
|
||||
&self.id
|
||||
}
|
||||
|
||||
/// Get the binding.
|
||||
pub(crate) fn binding(self) -> BindingRef<Id> {
|
||||
BindingRef {
|
||||
id: self.id.as_ref().clone(),
|
||||
_all: self.all,
|
||||
}
|
||||
}
|
||||
|
||||
/// If the handle can be mut.
|
||||
pub(crate) fn can_mut(&self) -> bool {
|
||||
// 1 memory management reference with 1 tensor reference.
|
||||
Arc::strong_count(&self.id) <= 2
|
||||
}
|
||||
|
||||
/// If the resource is free.
|
||||
pub(crate) fn is_free(&self) -> bool {
|
||||
Arc::strong_count(&self.all) <= 1
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export(local_inner_macros)]
|
||||
/// Create new memory ID types.
|
||||
macro_rules! memory_id_type {
|
||||
($id:ident, $handle:ident) => {
|
||||
/// Memory Handle.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct $handle {
|
||||
value: $crate::id::HandleRef<$id>,
|
||||
}
|
||||
|
||||
/// Memory ID.
|
||||
#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
|
||||
pub struct $id {
|
||||
pub(crate) value: usize,
|
||||
}
|
||||
|
||||
impl $handle {
|
||||
/// Create a new ID.
|
||||
pub(crate) fn new() -> Self {
|
||||
let value = Self::gen_id();
|
||||
Self {
|
||||
value: $crate::id::HandleRef::new($id { value }),
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_id() -> usize {
|
||||
static COUNTER: core::sync::atomic::AtomicUsize =
|
||||
core::sync::atomic::AtomicUsize::new(0);
|
||||
|
||||
let value = COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
|
||||
if value == usize::MAX {
|
||||
core::panic!("Memory ID overflowed");
|
||||
}
|
||||
|
||||
value
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::Deref for $handle {
|
||||
type Target = $crate::id::HandleRef<$id>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.value
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for $handle {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
($id:ident, $handle:ident, $binding:ident) => {
|
||||
memory_id_type!($id, $handle);
|
||||
|
||||
/// Binding of a memory handle.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct $binding {
|
||||
value: $crate::id::BindingRef<$id>,
|
||||
}
|
||||
|
||||
impl $handle {
|
||||
pub(crate) fn binding(self) -> $binding {
|
||||
$binding {
|
||||
value: self.value.binding(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::Deref for $binding {
|
||||
type Target = $crate::id::BindingRef<$id>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.value
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
|
@ -1,29 +0,0 @@
|
|||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
|
||||
//! Burn compute crate that helps creating high performance async backends.
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
mod id;
|
||||
|
||||
/// Compute channel module.
|
||||
pub mod channel;
|
||||
/// Compute client module.
|
||||
pub mod client;
|
||||
|
||||
/// Autotune module
|
||||
pub mod tune;
|
||||
|
||||
/// Memory management module.
|
||||
pub mod memory_management;
|
||||
/// Compute server module.
|
||||
pub mod server;
|
||||
/// Compute Storage module.
|
||||
pub mod storage;
|
||||
|
||||
mod compute;
|
||||
pub use compute::*;
|
|
@ -1,57 +0,0 @@
|
|||
use crate::storage::ComputeStorage;
|
||||
|
||||
/// The managed tensor buffer handle that points to some memory segment.
|
||||
/// It should not contain actual data.
|
||||
pub trait MemoryHandle<Binding>: Clone + Send + Sync + core::fmt::Debug {
|
||||
/// Checks if the underlying memory can be safely mutated.
|
||||
fn can_mut(&self) -> bool;
|
||||
/// Get the binding associated to the current handle.
|
||||
fn binding(self) -> Binding;
|
||||
}
|
||||
|
||||
/// Binding to a [memory handle](MemoryHandle).
|
||||
pub trait MemoryBinding: Clone + Send + Sync + core::fmt::Debug {}
|
||||
|
||||
/// The MemoryManagement trait encapsulates strategies for (de)allocating memory.
/// It is bound to the ComputeStorage trait, which does the actual (de)allocations.
///
/// The MemoryManagement can only reserve memory space or get the resource located at a space.
/// Modification of the resource data should be done directly on the resource.
pub trait MemoryManagement<Storage: ComputeStorage>: Send + core::fmt::Debug {
    /// The associated type that must implement [MemoryHandle].
    type Handle: MemoryHandle<Self::Binding>;
    /// The associated type that must implement [MemoryBinding]
    type Binding: MemoryBinding;

    /// Returns the resource from the storage at the specified handle
    fn get(&mut self, binding: Self::Binding) -> Storage::Resource;

    /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it
    ///
    /// NOTE(review): `sync` looks like a device-synchronization callback that
    /// implementations may invoke before reusing memory — confirm against the
    /// server implementations that supply it.
    fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle;

    /// Bypass the memory allocation algorithm to allocate data directly.
    ///
    /// # Notes
    ///
    /// Can be useful for servers that want specific control over memory.
    fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle;

    /// Bypass the memory allocation algorithm to deallocate data directly.
    ///
    /// # Notes
    ///
    /// Can be useful for servers that want specific control over memory.
    fn dealloc(&mut self, binding: Self::Binding);

    /// Fetch the storage used by the memory manager.
    ///
    /// # Notes
    ///
    /// The storage should probably not be used for allocations since the handles won't be
    /// compatible with the ones provided by the current trait. Prefer using the
    /// [alloc](MemoryManagement::alloc) and [dealloc](MemoryManagement::dealloc) functions.
    ///
    /// This is useful if you need to time the deallocations based on async computation, or to
    /// change the mode of storage for different reasons.
    fn storage(&mut self) -> &mut Storage;
}
|
|
@ -1,181 +0,0 @@
|
|||
use super::memory_pool::{
|
||||
MemoryExtensionStrategy, MemoryPool, MemoryPoolBinding, MemoryPoolHandle, RoundingStrategy,
|
||||
SmallMemoryPool,
|
||||
};
|
||||
use crate::storage::ComputeStorage;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use super::MemoryManagement;
|
||||
|
||||
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
pub struct DynamicMemoryManagement<Storage> {
    // Requests of at most this many bytes are served by `small_memory_pool`.
    min_chunk_alignment_offset: usize,
    // Dedicated pool for tiny allocations.
    small_memory_pool: SmallMemoryPool,
    // Regular pools; sorted by ascending `slice_max_size` in `new`.
    pools: Vec<MemoryPool>,
    // Per-pool options, kept in the same order as `pools`.
    options: Vec<MemoryPoolOptions>,
    // Backing storage that performs the actual (de)allocations.
    storage: Storage,
}
|
||||
|
||||
/// Options to initialize a [dynamic memory management](DynamicMemoryManagement).
#[derive(new, Debug)]
pub struct DynamicMemoryManagementOptions {
    // One entry per memory pool to create.
    pools: Vec<MemoryPoolOptions>,
    // Threshold (in bytes) below which requests go to the small-allocation pool.
    min_chunk_alignment_offset: usize,
}
|
||||
|
||||
/// Options to create a memory pool.
#[derive(Debug)]
pub struct MemoryPoolOptions {
    /// The amount of bytes used for each chunk in the memory pool.
    pub chunk_size: usize,
    /// The number of chunks allocated directly at creation.
    ///
    /// Useful when you know in advance how much memory you'll need.
    pub chunk_num_prealloc: usize,
    /// The max size in bytes a slice can take in the pool.
    pub slice_max_size: usize,
}
|
||||
|
||||
impl DynamicMemoryManagementOptions {
|
||||
/// Creates the options from device limits.
|
||||
pub fn preset(max_chunk_size: usize, min_chunk_alignment_offset: usize) -> Self {
|
||||
// Rounding down to a factor of 8.
|
||||
let max_chunk_size = (max_chunk_size / 8) * 8;
|
||||
|
||||
const MB: usize = 1024 * 1024;
|
||||
|
||||
let mut pools = Vec::new();
|
||||
|
||||
pools.push(MemoryPoolOptions {
|
||||
chunk_size: max_chunk_size,
|
||||
chunk_num_prealloc: 0,
|
||||
slice_max_size: max_chunk_size,
|
||||
});
|
||||
|
||||
let mut current = max_chunk_size;
|
||||
|
||||
while current >= 32 * MB {
|
||||
current /= 4;
|
||||
|
||||
pools.push(MemoryPoolOptions {
|
||||
chunk_size: current,
|
||||
chunk_num_prealloc: 0,
|
||||
// Creating max slices lower than the chunk size reduces fragmentation.
|
||||
slice_max_size: current / 2usize.pow(pools.len() as u32),
|
||||
});
|
||||
}
|
||||
|
||||
Self {
|
||||
pools,
|
||||
min_chunk_alignment_offset,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
|
||||
/// Creates a new instance using the given storage, merging_strategy strategy and slice strategy.
|
||||
pub fn new(mut storage: Storage, mut options: DynamicMemoryManagementOptions) -> Self {
|
||||
options
|
||||
.pools
|
||||
.sort_by(|pool1, pool2| usize::cmp(&pool1.slice_max_size, &pool2.slice_max_size));
|
||||
|
||||
let min_chunk_alignment_offset = options.min_chunk_alignment_offset;
|
||||
|
||||
let pools = options
|
||||
.pools
|
||||
.iter()
|
||||
.map(|option| {
|
||||
let mut pool = MemoryPool::new(
|
||||
MemoryExtensionStrategy::Never,
|
||||
RoundingStrategy::FixedAmount(option.chunk_size),
|
||||
min_chunk_alignment_offset,
|
||||
);
|
||||
|
||||
for _ in 0..option.chunk_num_prealloc {
|
||||
pool.alloc(&mut storage, option.chunk_size, || {});
|
||||
}
|
||||
|
||||
pool
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
min_chunk_alignment_offset,
|
||||
small_memory_pool: SmallMemoryPool::new(min_chunk_alignment_offset),
|
||||
pools,
|
||||
options: options.pools,
|
||||
storage,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Storage> core::fmt::Debug for DynamicMemoryManagement<Storage> {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_str(
|
||||
alloc::format!(
|
||||
"DynamicMemoryManagement {:?}",
|
||||
core::any::type_name::<Storage>(),
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagement<Storage> {
|
||||
type Handle = MemoryPoolHandle;
|
||||
type Binding = MemoryPoolBinding;
|
||||
|
||||
fn get(&mut self, binding: Self::Binding) -> Storage::Resource {
|
||||
if let Some(handle) = self.small_memory_pool.get(&mut self.storage, &binding) {
|
||||
return handle;
|
||||
}
|
||||
|
||||
for pool in &mut self.pools {
|
||||
if let Some(handle) = pool.get(&mut self.storage, &binding) {
|
||||
return handle;
|
||||
}
|
||||
}
|
||||
|
||||
panic!("No handle found in memory pools");
|
||||
}
|
||||
|
||||
fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
|
||||
if size <= self.min_chunk_alignment_offset {
|
||||
return self
|
||||
.small_memory_pool
|
||||
.reserve(&mut self.storage, size, sync);
|
||||
}
|
||||
|
||||
for (index, option) in self.options.iter().enumerate() {
|
||||
if size <= option.slice_max_size {
|
||||
let pool = &mut self.pools[index];
|
||||
return pool.reserve(&mut self.storage, size, sync);
|
||||
}
|
||||
}
|
||||
|
||||
panic!("No memory pool big enough to reserve {size} bytes.");
|
||||
}
|
||||
|
||||
fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
|
||||
if size <= self.min_chunk_alignment_offset {
|
||||
return self.small_memory_pool.alloc(&mut self.storage, size, sync);
|
||||
}
|
||||
|
||||
for (index, option) in self.options.iter().enumerate() {
|
||||
if size <= option.slice_max_size {
|
||||
let pool = &mut self.pools[index];
|
||||
return pool.alloc(&mut self.storage, size, sync);
|
||||
}
|
||||
}
|
||||
|
||||
panic!("No memory pool big enough to alloc {size} bytes.");
|
||||
}
|
||||
|
||||
fn dealloc(&mut self, _binding: Self::Binding) {
|
||||
// Can't dealloc slices.
|
||||
}
|
||||
|
||||
fn storage(&mut self) -> &mut Storage {
|
||||
&mut self.storage
|
||||
}
|
||||
}
|
|
@ -1,586 +0,0 @@
|
|||
use super::index::SearchIndex;
|
||||
use super::{
|
||||
ChunkHandle, ChunkId, MemoryChunk, MemoryPoolBinding, MemoryPoolHandle, MemorySlice,
|
||||
RingBuffer, SliceHandle, SliceId,
|
||||
};
|
||||
use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization};
|
||||
use alloc::vec::Vec;
|
||||
use hashbrown::{HashMap, HashSet};
|
||||
|
||||
/// Pool of fixed-policy memory chunks, each carved into reusable slices.
pub struct MemoryPool {
    // Every chunk owned by the pool, keyed by id.
    chunks: HashMap<ChunkId, Chunk>,
    // Every slice across all chunks, keyed by id.
    slices: HashMap<SliceId, Slice>,
    #[allow(unused)] // will be used when we rewrite memory extension
    memory_extension_strategy: MemoryExtensionStrategy,
    // Controls how much memory is actually allocated per chunk.
    rounding: RoundingStrategy,
    // Size-indexed lookup of chunks.
    chunk_index: SearchIndex<ChunkId>,
    // Circular scanner used to locate reusable free slices.
    ring: RingBuffer<Chunk, Slice>,
    // Chunks created since the last memory-extension pass.
    recently_added_chunks: Vec<ChunkId>,
    // Total bytes allocated for those recent chunks.
    recently_allocated_size: usize,
    // Alignment (bytes) every slice offset and padded size must respect.
    buffer_alignment: usize,
}
|
||||
|
||||
// Pending (slice, size) pair recorded while compacting slices into a new chunk.
struct SliceUpdate {
    slice_id: SliceId,
    size: usize,
}
|
||||
|
||||
/// A contiguous storage allocation plus the page tracking its slices.
#[derive(new, Debug)]
pub struct Chunk {
    /// Handle to the backing storage allocation.
    pub storage: StorageHandle,
    /// Ref-counted handle identifying this chunk.
    pub handle: ChunkHandle,
    /// Address map of the slices carved out of this chunk.
    pub slices: MemoryPage,
}
|
||||
|
||||
// TODO: consider using generic trait and decouple from Slice
/// Maps a start address (byte offset within the chunk) to the slice living there.
#[derive(new, Debug)]
pub struct MemoryPage {
    /// Start address -> slice id.
    pub slices: HashMap<usize, SliceId>,
}
|
||||
|
||||
impl MemoryPage {
|
||||
/// merge slice at first_slice_address with the next slice (if there is one and if it's free)
|
||||
/// return a boolean representing if a merge happened
|
||||
fn merge_with_next_slice(
|
||||
&mut self,
|
||||
first_slice_address: usize,
|
||||
slices: &mut HashMap<SliceId, Slice>,
|
||||
) -> bool {
|
||||
let first_slice_id = self.find_slice(first_slice_address).expect(
|
||||
"merge_with_next_slice shouldn't be called with a nonexistent first_slice address",
|
||||
);
|
||||
|
||||
let next_slice_address =
|
||||
first_slice_address + slices.get(&first_slice_id).unwrap().effective_size();
|
||||
|
||||
if let Some(next_slice_id) = self.find_slice(next_slice_address) {
|
||||
let (next_slice_eff_size, next_slice_is_free) = {
|
||||
let next_slice = slices.get(&next_slice_id).unwrap();
|
||||
(next_slice.effective_size(), next_slice.is_free())
|
||||
};
|
||||
if next_slice_is_free {
|
||||
let first_slice = slices.get_mut(&first_slice_id).unwrap();
|
||||
let first_slice_eff_size = first_slice.effective_size();
|
||||
let first_slice_offset = first_slice.storage.offset();
|
||||
|
||||
let merged_size = first_slice_eff_size + next_slice_eff_size;
|
||||
first_slice.storage.utilization = StorageUtilization::Slice {
|
||||
size: merged_size,
|
||||
offset: first_slice_offset,
|
||||
};
|
||||
first_slice.padding = 0;
|
||||
|
||||
// Cleanup of the extra slice
|
||||
self.slices.remove(&next_slice_address);
|
||||
slices.remove(&next_slice_id);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn find_slice(&self, address: usize) -> Option<SliceId> {
|
||||
let slice_id = self.slices.get(&address);
|
||||
slice_id.copied()
|
||||
}
|
||||
|
||||
fn insert_slice(&mut self, address: usize, slice: SliceId) {
|
||||
self.slices.insert(address, slice);
|
||||
}
|
||||
|
||||
fn slices_sorted_by_address(&self) -> Vec<SliceId> {
|
||||
let mut entries: Vec<(usize, SliceId)> = self.slices.clone().into_iter().collect();
|
||||
entries.sort_by_key(|&(key, _)| key);
|
||||
let sorted_slices: Vec<SliceId> = entries.into_iter().map(|(_, values)| values).collect();
|
||||
sorted_slices
|
||||
}
|
||||
}
|
||||
|
||||
/// A slice of a chunk's storage: the payload handle, its parent chunk, and
/// trailing padding that keeps the next slice aligned.
#[derive(new, Debug)]
pub struct Slice {
    /// Handle to the storage region (offset + size within the chunk).
    pub storage: StorageHandle,
    /// Ref-counted handle identifying this slice.
    pub handle: SliceHandle,
    /// Handle of the chunk this slice lives in.
    pub chunk: ChunkHandle,
    /// Trailing alignment padding, in bytes.
    pub padding: usize,
}
|
||||
|
||||
impl Slice {
    /// Total footprint of the slice in bytes: payload plus alignment padding.
    pub fn effective_size(&self) -> usize {
        self.storage.size() + self.padding
    }
}
|
||||
|
||||
// Slices smaller than this cannot start at a non-zero offset (enforced in
// `MemoryPool::create_slice` and `Slice::split`).
const MIN_SIZE_NEEDED_TO_OFFSET: usize = 16;
|
||||
|
||||
/// How a requested allocation size is rounded up to an actual chunk size.
pub enum RoundingStrategy {
    /// Always allocate chunks of exactly this many bytes.
    FixedAmount(usize),
    /// Allocate exactly the requested size.
    #[allow(unused)]
    None,
}

impl RoundingStrategy {
    /// Resolve the actual allocation size for a request of `size` bytes.
    fn alloc_size(&self, size: usize) -> usize {
        match *self {
            RoundingStrategy::FixedAmount(chunk_size) => {
                // A fixed-size chunk must be able to hold the whole request.
                assert!(chunk_size >= size);
                chunk_size
            }
            RoundingStrategy::None => size,
        }
    }
}
|
||||
|
||||
/// The strategy defines the frequency at which merging of free slices (defragmentation) occurs
#[allow(unused)]
#[derive(Debug)]
pub enum MemoryExtensionStrategy {
    /// Once every n calls to reserve.
    PeriodTick {
        /// Number of calls to be executed before triggering the defragmentation.
        period: usize,
        /// Current state. Should start at zero.
        state: usize,
    },
    /// Never defragment.
    Never,
}

#[allow(unused)]
impl MemoryExtensionStrategy {
    /// Create a new strategy with the given period.
    pub fn new_period_tick(period: usize) -> Self {
        MemoryExtensionStrategy::PeriodTick { period, state: 0 }
    }

    #[allow(unused)]
    fn should_extend_max_memory(&mut self) -> bool {
        // Advance the tick counter; fires exactly once every `period` calls.
        if let MemoryExtensionStrategy::PeriodTick { period, state } = self {
            *state = (*state + 1) % *period;
            *state == 0
        } else {
            // `Never` keeps no counter and never triggers.
            false
        }
    }
}
|
||||
|
||||
impl MemoryPool {
    /// Creates an empty pool with the given merging, rounding, and alignment policies.
    pub fn new(
        merging_strategy: MemoryExtensionStrategy,
        alloc_strategy: RoundingStrategy,
        buffer_alignment: usize,
    ) -> Self {
        Self {
            chunks: HashMap::new(),
            slices: HashMap::new(),
            memory_extension_strategy: merging_strategy,
            rounding: alloc_strategy,
            chunk_index: SearchIndex::new(),
            ring: RingBuffer::new(buffer_alignment),
            recently_added_chunks: Vec::new(),
            recently_allocated_size: 0,
            buffer_alignment,
        }
    }

    /// Returns the resource from the storage, for the specified handle.
    pub fn get<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        binding: &MemoryPoolBinding,
    ) -> Option<Storage::Resource> {
        self.slices
            .get(binding.slice.id())
            .map(|s| &s.storage)
            .map(|h| storage.get(h))
    }

    /// Reserves memory of specified size using the reserve algorithm, and return
    /// a handle to the reserved memory.
    ///
    /// Also clean ups, merging free slices together if permitted by the merging strategy
    pub fn reserve<Storage: ComputeStorage, Sync: FnOnce()>(
        &mut self,
        storage: &mut Storage,
        size: usize,
        sync: Sync,
    ) -> MemoryPoolHandle {
        // Try to reuse an existing free slice; fall back to a fresh allocation.
        let slice = self.get_free_slice(size);

        match slice {
            Some(slice) => MemoryPoolHandle {
                slice: slice.clone(),
            },
            None => self.alloc(storage, size, sync),
        }
    }

    /// Allocates a new chunk for `size` bytes, bypassing slice reuse.
    /// The requested size is rounded up according to the pool's strategy.
    pub fn alloc<Storage: ComputeStorage, Sync: FnOnce()>(
        &mut self,
        storage: &mut Storage,
        size: usize,
        // `sync` is accepted for interface parity but unused here.
        #[allow(unused)] sync: Sync,
    ) -> MemoryPoolHandle {
        let alloc_size = self.rounding.alloc_size(size);
        self.alloc_slice(storage, alloc_size, size)
    }

    /// Creates a chunk of `alloc_size` bytes and carves a `slice_size` slice
    /// out of it (plus an extra free slice for any remaining space).
    fn alloc_slice<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        alloc_size: usize,
        slice_size: usize,
    ) -> MemoryPoolHandle {
        let chunk_size = self.rounding.alloc_size(alloc_size);
        let handle_chunk = self.create_chunk(storage, chunk_size);
        // Re-read the size from storage: chunk creation may have added padding.
        let chunk_size = self.chunks.get(handle_chunk.id()).unwrap().storage.size();
        self.recently_added_chunks.push(*handle_chunk.id());
        self.recently_allocated_size += chunk_size;

        let chunk_id = *handle_chunk.id();
        let (slice, extra_slice) =
            self.allocate_slices(handle_chunk.clone(), chunk_size, slice_size);

        let handle_slice = slice.handle.clone();
        self.update_chunk_metadata(chunk_id, slice, extra_slice);

        MemoryPoolHandle {
            slice: handle_slice,
        }
    }

    /// Builds the requested slice at offset 0 and, if the chunk is larger
    /// than the request, a second free slice covering the remainder.
    fn allocate_slices(
        &self,
        handle_chunk: ChunkHandle,
        alloc_size: usize,
        slice_size: usize,
    ) -> (Slice, Option<Slice>) {
        let slice = self.create_slice(0, slice_size, handle_chunk.clone());

        let effective_size = slice.effective_size();

        let extra_slice = if effective_size < alloc_size {
            Some(self.create_slice(effective_size, alloc_size - effective_size, handle_chunk))
        } else {
            None
        };

        (slice, extra_slice)
    }

    /// Registers the freshly created slice(s) in the global slice map and in
    /// the chunk's address page.
    fn update_chunk_metadata(
        &mut self,
        chunk_id: ChunkId,
        slice: Slice,
        extra_slice: Option<Slice>,
    ) {
        let slice_id = *slice.handle.id();
        let slice_offset = slice.storage.offset();

        self.slices.insert(slice_id, slice);
        self.chunks
            .get_mut(&chunk_id)
            .unwrap()
            .slices
            .slices
            .insert(slice_offset, slice_id);

        if let Some(extra_slice) = extra_slice {
            let extra_slice_id = *extra_slice.handle.id();
            let extra_slice_offset = extra_slice.storage.offset();
            self.slices.insert(extra_slice_id, extra_slice);
            self.chunks
                .get_mut(&chunk_id)
                .unwrap()
                .slices
                .slices
                .insert(extra_slice_offset, extra_slice_id);
        }
    }

    /// Logs the ratio of free-slice bytes to total chunk bytes.
    ///
    /// NOTE(review): the numerator sums slices whose handle `is_free()`, so
    /// the logged "memory usage" is really the *free* fraction — confirm the
    /// intended meaning (possibly an inverted filter).
    #[allow(unused)]
    fn display_memory_usage(&self) {
        let total_memory_usage: f64 = self
            .chunks
            .values()
            .map(|chunk| chunk.storage.size() as f64)
            .sum();
        let effective_memory_usage: f64 = self
            .slices
            .values()
            .filter(|slice| slice.handle.is_free())
            .map(|slice| slice.storage.size() as f64)
            .sum();
        let ratio = 100.0 * effective_memory_usage / total_memory_usage;
        log::info!("the memory usage is {ratio}");
    }

    /// Finds a free slice that can contain the given size
    /// Returns a handle to the resized slice, or `None` when nothing fits.
    fn get_free_slice(&mut self, size: usize) -> Option<SliceHandle> {
        // Slices this small cannot be placed at a non-zero offset, so reuse
        // is not attempted (see MIN_SIZE_NEEDED_TO_OFFSET).
        if size < MIN_SIZE_NEEDED_TO_OFFSET {
            return None;
        }

        let padding = calculate_padding(size, self.buffer_alignment);
        let effective_size = size + padding;

        let slice_id =
            self.ring
                .find_free_slice(effective_size, &mut self.chunks, &mut self.slices)?;

        let slice = self.slices.get_mut(&slice_id).unwrap();
        let old_slice_size = slice.effective_size();

        let offset = match slice.storage.utilization {
            StorageUtilization::Full(_) => 0,
            StorageUtilization::Slice { offset, size: _ } => offset,
        };
        // Shrink the payload to the request and move the difference into
        // padding so the slice's footprint stays exactly the same.
        slice.storage.utilization = StorageUtilization::Slice { offset, size };
        let new_padding = old_slice_size - size;
        slice.padding = new_padding;
        assert_eq!(
            slice.effective_size(),
            old_slice_size,
            "new and old slice should have the same size"
        );

        Some(slice.handle.clone())
    }

    /// Creates a slice of size `size` upon the given chunk with the given offset.
    fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> Slice {
        assert_eq!(
            offset % self.buffer_alignment,
            0,
            "slice with offset {offset} needs to be a multiple of {}",
            self.buffer_alignment
        );
        if offset > 0 && size < MIN_SIZE_NEEDED_TO_OFFSET {
            panic!("tried to create slice of size {size} with an offset while the size needs to atleast be of size {MIN_SIZE_NEEDED_TO_OFFSET} for offset support");
        }
        let chunk = self.chunks.get(handle_chunk.id()).unwrap();
        let handle = SliceHandle::new();

        let storage = StorageHandle {
            id: chunk.storage.id.clone(),
            utilization: StorageUtilization::Slice { offset, size },
        };

        let padding = calculate_padding(size, self.buffer_alignment);

        Slice::new(storage, handle, chunk.handle.clone(), padding)
    }

    /// Creates a chunk of given size by allocating on the storage.
    fn create_chunk<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        size: usize,
    ) -> ChunkHandle {
        let padding = calculate_padding(size, self.buffer_alignment);
        let effective_size = size + padding;

        let storage = storage.alloc(effective_size);
        let handle = ChunkHandle::new();
        let id = *handle.id();

        // Make the chunk visible to the free-slice scanner and size index.
        self.ring.push_chunk(id);

        self.chunks.insert(
            id,
            Chunk::new(storage, handle.clone(), MemoryPage::new(HashMap::new())),
        );
        self.chunk_index.insert(id, size);

        handle
    }

    /// Compacts recently added chunks: gathers their slices and either moves
    /// them into one bigger chunk or just frees the empty chunks.
    #[allow(unused)]
    fn extend_max_memory<Storage: ComputeStorage>(&mut self, storage: &mut Storage) {
        let mut slices = Vec::<SliceUpdate>::new();

        let mut deallocations = HashSet::<ChunkId>::new();

        let mut chunks_total_size: usize = 0;

        for chunk_id in &self.recently_added_chunks {
            let chunk = self.chunks.get(chunk_id).unwrap();
            let chunk_id = *chunk.handle.id();
            // Preserve address order so relative slice layout survives the move.
            let sorted_slice = chunk.slices.slices_sorted_by_address();
            for slice_id in sorted_slice {
                let slice = self.slices.get(&slice_id).unwrap();
                let size = slice.storage.size();

                slices.push(SliceUpdate { slice_id, size });
            }
            chunks_total_size += chunk.storage.size();
            deallocations.insert(chunk_id);
        }

        if !slices.is_empty() {
            self.move_to_new_chunk(chunks_total_size, storage, &mut slices, &mut deallocations);
        } else {
            self.deallocate(storage, &mut deallocations);
        }
    }

    /// Frees the given chunks from the pool bookkeeping and the storage.
    fn deallocate<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        deallocations: &mut HashSet<ChunkId>,
    ) {
        for id in deallocations.drain() {
            let mut chunk = self.chunks.remove(&id).unwrap();
            self.ring.remove_chunk(id);

            // Sanity check: every surviving slice must already point at its
            // new chunk, not the one being freed.
            for (_address, slice_id) in chunk.slices.slices.drain() {
                let slice = self.slices.get(&slice_id).unwrap();
                let chunk_id = *slice.chunk.id();

                assert_ne!(chunk_id, id, "Chunk id should be updated");
            }

            self.chunk_index.remove(&id);
            storage.dealloc(chunk.storage.id);
        }
    }

    /// Copies all pending slices into one freshly allocated chunk, then frees
    /// the chunks they came from.
    fn move_to_new_chunk<Storage: ComputeStorage>(
        &mut self,
        alloc_size: usize,
        storage: &mut Storage,
        slices: &mut Vec<SliceUpdate>,
        deallocations: &mut HashSet<ChunkId>,
    ) {
        let chunk = self.create_chunk(storage, alloc_size);
        let storage_id = self.chunks.get(chunk.id()).unwrap().storage.id.clone();
        let mut offset = 0;
        let mut slices_ids: Vec<(usize, SliceId)> = Vec::new();

        for update in slices.drain(..) {
            let slice_id = update.slice_id;

            let slice = self.slices.get_mut(&slice_id).unwrap();
            let old_storage = slice.storage.clone();

            // Repoint the slice at the new chunk, then copy the bytes over.
            slice.chunk = chunk.clone();
            slice.storage = StorageHandle {
                id: storage_id.clone(),
                utilization: StorageUtilization::Slice {
                    offset,
                    size: update.size,
                },
            };
            storage.copy(&old_storage, &slice.storage);
            slices_ids.push((offset, slice_id));
            offset += slice.effective_size();
        }

        let chunk = self.chunks.get_mut(chunk.id()).unwrap();
        let chunk_handle = chunk.handle.clone();
        for (address, slice_id) in slices_ids.drain(..) {
            chunk.slices.insert_slice(address, slice_id);
        }
        let chunk_size = chunk.storage.size();
        let last_slice_size = chunk_size - offset;
        assert_eq!(last_slice_size % self.buffer_alignment, 0);
        if last_slice_size != 0 {
            // NOTE(review): the returned Slice is dropped here and never
            // registered in `self.slices` or the chunk page — confirm whether
            // the trailing free space is intentionally left untracked.
            self.create_slice(offset, last_slice_size, chunk_handle);
        }

        self.deallocate(storage, deallocations);
    }
}
|
||||
|
||||
/// Bytes of padding needed to round `size` up to the next multiple of
/// `buffer_alignment` (zero when already aligned).
fn calculate_padding(size: usize, buffer_alignment: usize) -> usize {
    match size % buffer_alignment {
        0 => 0,
        remainder => buffer_alignment - remainder,
    }
}
|
||||
|
||||
impl MemorySlice for Slice {
    fn is_free(&self) -> bool {
        self.handle.is_free()
    }

    /// A slice's size for ring-buffer purposes is its full footprint,
    /// padding included.
    fn size(&self) -> usize {
        self.effective_size()
    }

    /// Split this slice at `offset_slice` bytes: `self` keeps the first part,
    /// and the remainder is returned as a new slice (or absorbed as padding
    /// when it is smaller than one alignment unit).
    fn split(&mut self, offset_slice: usize, buffer_alignment: usize) -> Option<Self> {
        let size_new = self.effective_size() - offset_slice;
        let offset_new = self.storage.offset() + offset_slice;
        let old_size = self.effective_size();

        // Storage view for the would-be second half.
        let storage_new = StorageHandle {
            id: self.storage.id.clone(),
            utilization: StorageUtilization::Slice {
                offset: offset_new,
                size: size_new,
            },
        };

        // Shrink self to the first part.
        self.storage.utilization = StorageUtilization::Slice {
            offset: self.storage.offset(),
            size: offset_slice,
        };

        if offset_new > 0 && size_new < MIN_SIZE_NEEDED_TO_OFFSET {
            panic!("tried to create slice of size {size_new} with an offset while the size needs to atleast be of size {MIN_SIZE_NEEDED_TO_OFFSET} for offset support");
        }
        if offset_new % buffer_alignment != 0 {
            panic!("slice with offset {offset_new} needs to be a multiple of {buffer_alignment}");
        }
        let handle = SliceHandle::new();
        // Remainder too small to stand alone: keep it as padding on self so
        // the combined footprint is unchanged.
        if size_new < buffer_alignment {
            self.padding = old_size - offset_slice;
            assert_eq!(self.effective_size(), old_size);
            return None;
        }

        assert!(
            size_new >= buffer_alignment,
            "Size new > {buffer_alignment}"
        );
        self.padding = 0;
        // NOTE(review): padding of (size_new - alignment) equals padding of
        // size_new since alignment divides itself — the subtraction appears
        // redundant; confirm it isn't masking another intent.
        let padding = calculate_padding(size_new - buffer_alignment, buffer_alignment);
        Some(Slice::new(storage_new, handle, self.chunk.clone(), padding))
    }

    fn id(&self) -> SliceId {
        *self.handle.id()
    }

    /// Start address of whatever slice would come immediately after this one.
    fn next_slice_position(&self) -> usize {
        self.storage.offset() + self.effective_size()
    }
}
|
||||
|
||||
impl MemoryChunk<Slice> for Chunk {
    /// Delegates merging to the page, which owns the address -> slice map.
    fn merge_next_slice(
        &mut self,
        from_slice_index: usize,
        slices: &mut HashMap<SliceId, Slice>,
    ) -> bool {
        self.slices.merge_with_next_slice(from_slice_index, slices)
    }

    /// Slice registered at address `index`, if any.
    fn slice(&self, index: usize) -> Option<SliceId> {
        self.slices.find_slice(index)
    }

    /// Registers the slice both on this chunk's page and in the global map.
    fn insert_slice(
        &mut self,
        position: usize,
        slice: Slice,
        slices: &mut HashMap<SliceId, Slice>,
    ) {
        self.slices.insert_slice(position, slice.id());
        slices.insert(slice.id(), slice);
    }
}
|
|
@ -1,33 +0,0 @@
|
|||
use crate::memory_id_type;
|
||||
use crate::memory_management::{MemoryBinding, MemoryHandle};
|
||||
|
||||
// The ChunkId allows to keep track of how many references there are to a specific chunk.
memory_id_type!(ChunkId, ChunkHandle);
// The SliceId allows to keep track of how many references there are to a specific slice.
// The three-argument form additionally generates the `SliceBinding` type
// produced when a `SliceHandle` is consumed into a binding.
memory_id_type!(SliceId, SliceHandle, SliceBinding);
|
||||
|
||||
/// A tensor memory handle, wrapping the slice that backs the tensor's memory.
#[derive(Debug, Clone)]
pub struct MemoryPoolHandle {
    /// Handle of the underlying slice.
    pub slice: SliceHandle,
}
|
||||
|
||||
/// Binding of the [memory pool handle](MemoryPoolHandle).
#[derive(Debug, Clone)]
pub struct MemoryPoolBinding {
    /// Binding of the underlying slice.
    pub slice: SliceBinding,
}

impl MemoryBinding for MemoryPoolBinding {}
|
||||
|
||||
impl MemoryHandle<MemoryPoolBinding> for MemoryPoolHandle {
    /// Delegates to the slice handle (mutability rules come from `memory_id_type!`).
    fn can_mut(&self) -> bool {
        self.slice.can_mut()
    }

    /// Consumes the handle, binding the underlying slice.
    fn binding(self) -> MemoryPoolBinding {
        MemoryPoolBinding {
            slice: self.slice.binding(),
        }
    }
}
|
|
@ -1,65 +0,0 @@
|
|||
use alloc::collections::BTreeMap;
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
use core::hash::Hash;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// Data Structure that helps to search items by size efficiently.
pub struct SearchIndex<T> {
    // Size -> items of that size; BTreeMap keeps sizes ordered for range queries.
    items_per_size: BTreeMap<usize, Vec<T>>,
    // Reverse lookup: item -> its registered size.
    sizes_per_item: HashMap<T, usize>,
}

impl<T: PartialEq + Eq + Hash + Clone> SearchIndex<T> {
    /// Create a new item search index.
    pub fn new() -> Self {
        Self {
            items_per_size: BTreeMap::new(),
            sizes_per_item: HashMap::new(),
        }
    }

    /// Insert a new sized item into the search index.
    ///
    /// Re-inserting an existing item updates its size.
    pub fn insert(&mut self, item: T, size: usize) {
        // Drop any previous registration first so the two maps stay in sync.
        self.remove(&item);

        // Entry API: one lookup instead of the previous get_mut/else-insert.
        self.items_per_size.entry(size).or_default().push(item.clone());
        self.sizes_per_item.insert(item, size);
    }

    /// Find the item by size range.
    #[allow(unused)]
    pub fn find_by_size(
        &self,
        range: core::ops::Range<usize>,
    ) -> impl DoubleEndedIterator<Item = &T> {
        self.items_per_size.range(range).flat_map(|a| a.1)
    }

    /// Remove an item from the index.
    ///
    /// Silently does nothing when the item is not registered.
    pub fn remove(&mut self, item: &T) {
        let size = match self.sizes_per_item.remove(item) {
            Some(size) => size,
            None => return,
        };

        if let Some(values) = self.items_per_size.get_mut(&size) {
            // `position` replaces the previous manual index-scan loop.
            if let Some(index) = values.iter().position(|value| value == item) {
                values.remove(index);
            }
            // Drop empty buckets so the size map doesn't accumulate stale
            // entries over repeated insert/remove cycles.
            if values.is_empty() {
                self.items_per_size.remove(&size);
            }
        }
    }
}
|
|
@ -1,11 +0,0 @@
|
|||
// Size-indexed chunk lookup helper.
pub(crate) mod index;
// Circular free-slice scanner.
mod ring;

// Core pool implementation, handles, and the small-allocation pool.
mod base;
mod handle;
mod small;

pub use base::*;
pub use handle::*;
pub use ring::*;
pub use small::*;
|
|
@ -1,454 +0,0 @@
|
|||
use alloc::vec::Vec;
|
||||
use core::marker::PhantomData;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
use super::{ChunkId, SliceId};
|
||||
|
||||
/// Circular scanner over chunks that searches for reusable free slices,
/// resuming from where the previous search left off.
#[derive(Debug)]
pub struct RingBuffer<C: MemoryChunk<S>, S: MemorySlice> {
    // Chunk ids in insertion order; scanned circularly.
    queue: Vec<ChunkId>,
    // Position of each chunk id inside `queue`.
    chunk_positions: HashMap<ChunkId, usize>,
    // Scan cursors (reset to 0 whenever a chunk is removed).
    cursor_slice: usize,
    cursor_chunk: usize,
    // Only the id types are stored; these pin the chunk/slice type parameters.
    _s: PhantomData<S>,
    _c: PhantomData<C>,
    buffer_alignment: usize,
}
|
||||
|
||||
/// A chunk of storage owning an address-ordered collection of slices.
pub trait MemoryChunk<S: MemorySlice> {
    /// Merge the slice at `slice_position` with its successor when possible;
    /// returns whether a merge happened.
    fn merge_next_slice(&mut self, slice_position: usize, slices: &mut HashMap<SliceId, S>)
        -> bool;
    /// Id of the slice starting at `index`, if any.
    fn slice(&self, index: usize) -> Option<SliceId>;
    /// Insert `slice` at `position` and register it in `slices`.
    fn insert_slice(&mut self, position: usize, slice: S, slices: &mut HashMap<SliceId, S>);
}
|
||||
|
||||
/// A slice of a chunk as seen by the ring-buffer scanner.
pub trait MemorySlice: Sized {
    /// Whether the slice is currently unused and can be reclaimed.
    fn is_free(&self) -> bool;
    /// Footprint of the slice in bytes.
    fn size(&self) -> usize;
    /// Split at `offset`, returning the remainder as a new slice when it is
    /// large enough to stand alone.
    fn split(&mut self, offset: usize, buffer_alignment: usize) -> Option<Self>;
    /// Identifier of the slice.
    fn id(&self) -> SliceId;
    /// Start address of the position immediately after this slice.
    fn next_slice_position(&self) -> usize;
}
|
||||
|
||||
impl<C: MemoryChunk<S>, S: MemorySlice> RingBuffer<C, S> {
    /// Creates an empty ring buffer; `buffer_alignment` is forwarded to
    /// [`MemorySlice::split`] whenever a slice has to be cut down to size.
    pub fn new(buffer_alignment: usize) -> Self {
        Self {
            queue: Vec::new(),
            chunk_positions: HashMap::new(),
            cursor_slice: 0,
            cursor_chunk: 0,
            _s: PhantomData,
            _c: PhantomData,
            buffer_alignment,
        }
    }

    /// Appends a chunk at the end of the scan queue and records its position.
    pub fn push_chunk(&mut self, chunk_id: ChunkId) {
        self.queue.push(chunk_id);
        self.chunk_positions.insert(chunk_id, self.queue.len() - 1);
    }

    /// Removes a chunk from the queue. `Vec::remove` shifts every chunk that
    /// followed it, so the whole position index is rebuilt and the cursors are
    /// reset to scan from the start on the next search.
    pub fn remove_chunk(&mut self, chunk_id: ChunkId) {
        if let Some(position) = self.chunk_positions.remove(&chunk_id) {
            self.queue.remove(position);
        }

        self.chunk_positions.clear();

        for (pos, id) in self.queue.iter().enumerate() {
            self.chunk_positions.insert(*id, pos);
        }
        self.cursor_chunk = 0;
        self.cursor_slice = 0;
    }

    /// Finds a free slice of at least `size` bytes: first scans from the
    /// current cursor to the end of the queue, then wraps around and scans the
    /// chunks before the cursor. Returns `None` when no chunk can provide one.
    pub fn find_free_slice(
        &mut self,
        size: usize,
        chunks: &mut HashMap<ChunkId, C>,
        slices: &mut HashMap<SliceId, S>,
    ) -> Option<SliceId> {
        // Where the first pass started: the wrap-around pass only needs to
        // cover the chunks before that point.
        let max_second = self.cursor_chunk;
        let result = self.find_free_slice_in_all_chunks(size, chunks, slices, self.queue.len());

        if result.is_some() {
            return result;
        }

        // Nothing found after the cursor: wrap around to the earlier chunks.
        self.cursor_chunk = 0;
        self.cursor_slice = 0;
        self.find_free_slice_in_all_chunks(size, chunks, slices, max_second)
    }

    /// Walks the slices of one chunk starting at `slice_index`, merging
    /// adjacent free slices along the way. On success returns the index and id
    /// of a free slice that fits `size` (oversized slices are split first so
    /// the returned slice is trimmed to `size`).
    fn find_free_slice_in_chunk(
        &mut self,
        size: usize,
        chunk: &mut C,
        slices: &mut HashMap<SliceId, S>,
        mut slice_index: usize,
    ) -> Option<(usize, SliceId)> {
        while let Some(slice_id) = chunk.slice(slice_index) {
            //mutable borrow scope
            {
                let slice = slices.get_mut(&slice_id).unwrap();

                let is_big_enough = slice.size() >= size;
                let is_free = slice.is_free();

                if is_big_enough && is_free {
                    if slice.size() > size {
                        // Too big: split off the excess into a fresh free
                        // slice placed right after this one.
                        if let Some(new_slice) = slice.split(size, self.buffer_alignment) {
                            let new_slice_id = new_slice.id();
                            chunk.insert_slice(slice.next_slice_position(), new_slice, slices);
                            // Sanity check: the new slice must now be registered.
                            slices.get(&new_slice_id).unwrap();
                        }
                    }
                    return Some((slice_index, slice_id));
                }
            }
            {
                let slice = slices.get_mut(&slice_id).unwrap();
                let is_free = slice.is_free();
                // A free-but-too-small slice may grow by absorbing its free
                // neighbour; retry the same index after a successful merge.
                if is_free && chunk.merge_next_slice(slice_index, slices) {
                    continue;
                }
            }

            if let Some(slice) = slices.get(&slice_id) {
                slice_index = slice.next_slice_position();
            } else {
                panic!("current slice_id should still be valid after potential merge");
            }
        }

        None
    }

    /// Scans chunks `cursor_chunk..max_cursor_position`, resuming inside the
    /// first chunk at `cursor_slice`. On a hit, both cursors are updated so
    /// the next search continues right after the returned slice.
    fn find_free_slice_in_all_chunks(
        &mut self,
        size: usize,
        chunks: &mut HashMap<ChunkId, C>,
        slices: &mut HashMap<SliceId, S>,
        max_cursor_position: usize,
    ) -> Option<SliceId> {
        let start = self.cursor_chunk;
        let end = usize::min(self.queue.len(), max_cursor_position);
        let mut slice_index = self.cursor_slice;

        for chunk_index in start..end {
            // Only the first chunk is resumed mid-way; later ones start at 0.
            if chunk_index > start {
                slice_index = 0;
            }

            if let Some(id) = self.queue.get(chunk_index) {
                let chunk = chunks.get_mut(id).unwrap();
                let result = self.find_free_slice_in_chunk(size, chunk, slices, slice_index);

                if let Some((_cursor_slice, slice)) = result {
                    let slice = slices.get(&slice).unwrap();
                    self.cursor_slice = slice.next_slice_position();
                    self.cursor_chunk = chunk_index;
                    return Some(slice.id());
                }
            }
            self.cursor_chunk = chunk_index;
            self.cursor_slice = 0;
        }

        None
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::stub::*;
    use super::*;
    use alloc::vec;

    // An oversized free slice is split: the request gets 50 bytes and the
    // 50-byte remainder becomes a third slice.
    #[test]
    fn simple_1() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);

        let slice_1 = new_slice(0, 100, 0);
        let slice_2 = new_slice(1, 200, 1);
        let chunk_1 = new_chunk(0, vec![0, 1]);

        let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);

        ring.push_chunk(ChunkId { value: 0 });

        let slice = ring.find_free_slice(50, &mut chunks, &mut slices).unwrap();

        assert_eq!(slice, SliceId { value: 0 });
        assert_eq!(slices.get(&slice).unwrap().size, 50);
        assert_eq!(slices.len(), 3);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 3);
    }

    // A request larger than the first slice forces a merge of the two free
    // slices (100 + 200) before splitting off the 150 bytes asked for.
    #[test]
    fn simple_2() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);

        let slice_1 = new_slice(0, 100, 0);
        let slice_2 = new_slice(1, 200, 1);
        let chunk_1 = new_chunk(0, vec![0, 1]);

        let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);

        ring.push_chunk(ChunkId { value: 0 });

        let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap();

        assert_eq!(slice, SliceId { value: 0 });
        assert_eq!(slices.get(&slice).unwrap().size, 150);
        assert_eq!(slices.len(), 2);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 2);
    }

    // With two chunks, the first hit advances the cursor; the second request
    // wraps around and finds the free slice left in the first chunk.
    #[test]
    fn multiple_chunks() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);

        let slice_1 = new_slice(0, 100, 0);
        let slice_2 = new_slice(1, 200, 1);
        let slice_3 = new_slice(2, 200, 0);
        let slice_4 = new_slice(3, 200, 1);
        let chunk_1 = new_chunk(0, vec![0, 1]);
        let chunk_2 = new_chunk(1, vec![2, 3]);

        let mut slices = HashMap::from([
            (slice_1.id, slice_1),
            (slice_2.id, slice_2),
            (slice_3.id, slice_3),
            (slice_4.id, slice_4),
        ]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]);

        ring.push_chunk(ChunkId { value: 0 });
        ring.push_chunk(ChunkId { value: 1 });

        slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = false;
        slices.get_mut(&SliceId { value: 3 }).unwrap().is_free = false;

        let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap();

        assert_eq!(slice, SliceId { value: 2 });

        let slice = ring.find_free_slice(100, &mut chunks, &mut slices).unwrap();

        assert_eq!(slice, SliceId { value: 0 });
    }

    // An exact-size free slice is returned as-is: no split, no new slices.
    #[test]
    fn find_free_slice_with_exact_fit() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);

        let slice_1 = new_slice(0, 100, 0);
        let slice_2 = new_slice(1, 200, 1);
        let chunk_1 = new_chunk(0, vec![0, 1]);

        let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);

        ring.push_chunk(ChunkId { value: 0 });

        slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = false;
        slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true;

        let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap();

        assert_eq!(slice, SliceId { value: 1 });
        assert_eq!(slices.get(&slice).unwrap().size, 200);
        assert_eq!(slices.len(), 2);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 2);
    }

    // Three adjacent free slices (100 + 50 + 100) merge into one 250-byte
    // slice that exactly fits the request.
    #[test]
    fn find_free_slice_with_merging() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);

        let slice_1 = new_slice(0, 100, 0);
        let slice_2 = new_slice(1, 50, 1);
        let slice_3 = new_slice(2, 100, 2);
        let chunk_1 = new_chunk(0, vec![0, 1, 2]);

        let mut slices = HashMap::from([
            (slice_1.id, slice_1),
            (slice_2.id, slice_2),
            (slice_3.id, slice_3),
        ]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);

        ring.push_chunk(ChunkId { value: 0 });

        slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 2 }).unwrap().is_free = true;

        let slice = ring.find_free_slice(250, &mut chunks, &mut slices).unwrap();

        assert_eq!(slice, SliceId { value: 0 });
        assert_eq!(slices.get(&slice).unwrap().size, 250);
        assert_eq!(slices.len(), 1);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 1);
    }

    // Merging never crosses a chunk boundary: chunk 1 (50 + 50) cannot serve
    // 150 bytes, so the merged slices of chunk 2 (100 + 50) are used instead.
    #[test]
    fn find_free_slice_with_multiple_chunks_and_merging() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);

        let slice_1 = new_slice(0, 50, 0);
        let slice_2 = new_slice(1, 50, 1);
        let chunk_1 = new_chunk(0, vec![0, 1]);

        let slice_3 = new_slice(2, 100, 0);
        let slice_4 = new_slice(3, 50, 1);
        let chunk_2 = new_chunk(1, vec![2, 3]);

        let mut slices = HashMap::from([
            (slice_1.id, slice_1),
            (slice_2.id, slice_2),
            (slice_3.id, slice_3),
            (slice_4.id, slice_4),
        ]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]);

        ring.push_chunk(ChunkId { value: 0 });
        ring.push_chunk(ChunkId { value: 1 });

        slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 2 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 3 }).unwrap().is_free = true;

        let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap();

        assert_eq!(slices.get(&slice).unwrap().size, 150);
        assert_eq!(slices.len(), 2);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 1);
    }

    // Builds a test slice with the given id, size, and in-chunk position.
    fn new_slice(id: usize, size: usize, position: usize) -> TestSlice {
        TestSlice {
            id: SliceId { value: id },
            is_free: true,
            size,
            position,
        }
    }

    // Builds a test chunk holding the given slice ids, in order.
    fn new_chunk(id: usize, slices: Vec<usize>) -> TestChunk {
        TestChunk {
            id: ChunkId { value: id },
            slices: slices.into_iter().map(|i| SliceId { value: i }).collect(),
        }
    }
}
|
||||
|
||||
#[cfg(test)]
mod stub {
    use super::*;
    use burn_common::*;

    /// Minimal in-memory chunk used by the ring-buffer tests.
    #[derive(Debug)]
    pub struct TestChunk {
        pub id: ChunkId,
        // Slice ids in chunk order; positions are re-numbered after mutations.
        pub slices: Vec<SliceId>,
    }

    /// Minimal in-memory slice used by the ring-buffer tests.
    #[derive(Debug)]
    pub struct TestSlice {
        pub id: SliceId,
        pub is_free: bool,
        pub size: usize,
        // Index of this slice within its owning chunk.
        pub position: usize,
    }

    impl MemorySlice for TestSlice {
        fn is_free(&self) -> bool {
            self.is_free
        }

        fn size(&self) -> usize {
            self.size
        }

        // Shrinks `self` to `offset` bytes and returns a new free slice
        // holding the remainder; alignment is ignored in this stub.
        fn split(&mut self, offset: usize, _buffer_alignment: usize) -> Option<Self> {
            let size_remained = self.size - offset;
            self.size = offset;

            Some(Self {
                id: SliceId {
                    // Random id: tests only compare ids they created themselves.
                    value: rand::gen_random(),
                },
                is_free: true,
                size: size_remained,
                position: self.position + 1,
            })
        }

        fn id(&self) -> SliceId {
            self.id
        }

        fn next_slice_position(&self) -> usize {
            self.position + 1
        }
    }

    impl MemoryChunk<TestSlice> for TestChunk {
        // Absorbs the slice after `from_slice_index` into it when that next
        // slice is free, then re-numbers all remaining slice positions.
        fn merge_next_slice(
            &mut self,
            from_slice_index: usize,
            slices: &mut HashMap<SliceId, TestSlice>,
        ) -> bool {
            let slice_id_current = self.slices.get(from_slice_index).unwrap();
            let slice_id_next = self.slices.get(from_slice_index + 1);
            let slice_id_next = match slice_id_next {
                Some(val) => val,
                None => return false,
            };

            let slice_next = slices.get(slice_id_next).unwrap();
            let is_free = slice_next.is_free;
            let size = slice_next.size;

            let slice_current = slices.get_mut(slice_id_current).unwrap();

            if is_free {
                slice_current.size += size;
                slices.remove(slice_id_next);
                self.slices.remove(from_slice_index + 1);

                // Removal shifted the tail; refresh every slice's position.
                for (index, temp_slice_id) in self.slices.iter_mut().enumerate() {
                    let slice = slices.get_mut(temp_slice_id).unwrap();
                    slice.position = index;
                }
                return true;
            }

            false
        }

        fn slice(&self, index: usize) -> Option<SliceId> {
            self.slices.get(index).copied()
        }

        // Inserts the slice at `position` and re-numbers all positions, since
        // insertion shifts every slice that follows.
        fn insert_slice(
            &mut self,
            position: usize,
            slice: TestSlice,
            slices: &mut HashMap<SliceId, TestSlice>,
        ) {
            self.slices.insert(position, slice.id());
            slices.insert(slice.id(), slice);
            for (index, temp_slice_id) in self.slices.iter_mut().enumerate() {
                let temp_slice = slices.get_mut(temp_slice_id).unwrap();
                temp_slice.position = index;
            }
        }
    }
}
|
|
@ -1,226 +0,0 @@
|
|||
use super::{ChunkHandle, ChunkId, MemoryPoolBinding, MemoryPoolHandle, SliceHandle, SliceId};
|
||||
use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization};
|
||||
use alloc::vec::Vec;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// A memory pool that allocates chunks of one fixed size (the storage
/// alignment offset) and reuses them to minimize allocations.
///
/// - Only one slice is supported per chunk due to the limitations in WGPU where small allocations cannot be offset.
/// - The pool uses a ring buffer to efficiently manage and reuse chunks.
///
/// Fields:
/// - `chunks`: A hashmap storing the allocated chunks by their IDs.
/// - `slices`: A hashmap storing the slices by their IDs.
/// - `ring_buffer`: A vector used as a ring buffer to manage chunk reuse.
/// - `index`: The current position in the ring buffer.
pub struct SmallMemoryPool {
    chunks: HashMap<ChunkId, SmallChunk>,
    slices: HashMap<SliceId, SmallSlice>,
    ring_buffer: Vec<ChunkId>,
    index: usize,
    // Every chunk is allocated with exactly this size; reservations must fit in it.
    buffer_storage_alignment_offset: usize,
}
|
||||
|
||||
/// A chunk owned by the [`SmallMemoryPool`]: its backing storage plus the
/// single slice (if any) carved out of it.
#[derive(new, Debug)]
pub struct SmallChunk {
    pub storage: StorageHandle,
    #[allow(dead_code)]
    pub handle: ChunkHandle,
    // At most one slice per chunk (WGPU small allocations cannot be offset).
    pub slice: Option<SliceId>,
}
|
||||
|
||||
/// A slice over a [`SmallChunk`]: always starts at offset zero and is padded
/// so its effective size fills the whole chunk.
#[derive(new, Debug)]
pub struct SmallSlice {
    pub storage: StorageHandle,
    pub handle: SliceHandle,
    #[allow(dead_code)]
    pub chunk: ChunkHandle,
    // Bytes added after the requested size so the slice fills the chunk.
    pub padding: usize,
}
|
||||
|
||||
impl SmallSlice {
    /// Total footprint of the slice: requested storage size plus padding.
    pub fn effective_size(&self) -> usize {
        self.storage.size() + self.padding
    }
}
|
||||
|
||||
impl SmallMemoryPool {
    /// Creates an empty pool whose chunks will all be allocated with size
    /// `buffer_storage_alignment_offset`.
    pub fn new(buffer_storage_alignment_offset: usize) -> Self {
        Self {
            chunks: HashMap::new(),
            slices: HashMap::new(),
            ring_buffer: Vec::new(),
            index: 0,
            buffer_storage_alignment_offset,
        }
    }

    /// Returns the resource from the storage, for the specified handle.
    pub fn get<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        binding: &MemoryPoolBinding,
    ) -> Option<Storage::Resource> {
        self.slices
            .get(binding.slice.id())
            .map(|s| &s.storage)
            .map(|h| storage.get(h))
    }

    /// Reserves memory of specified size using the reserve algorithm, and return
    /// a handle to the reserved memory.
    ///
    /// Reuses a free slice from the ring buffer when one exists, otherwise
    /// allocates a fresh chunk. Panics if `size` exceeds the chunk size.
    pub fn reserve<Storage: ComputeStorage, Sync: FnOnce()>(
        &mut self,
        storage: &mut Storage,
        size: usize,
        sync: Sync,
    ) -> MemoryPoolHandle {
        // The whole reservation must fit into a single fixed-size chunk.
        assert!(size <= self.buffer_storage_alignment_offset);
        let slice = self.get_free_slice(size);

        match slice {
            Some(slice) => MemoryPoolHandle {
                slice: slice.clone(),
            },
            None => self.alloc(storage, size, sync),
        }
    }

    /// Allocates a brand-new chunk holding one slice of `size` bytes.
    /// Panics if `size` exceeds the chunk size.
    pub fn alloc<Storage: ComputeStorage, Sync: FnOnce()>(
        &mut self,
        storage: &mut Storage,
        size: usize,
        _sync: Sync,
    ) -> MemoryPoolHandle {
        assert!(size <= self.buffer_storage_alignment_offset);

        self.alloc_slice(storage, size)
    }

    // Allocates one chunk and carves its single slice out of it.
    fn alloc_slice<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        slice_size: usize,
    ) -> MemoryPoolHandle {
        let handle_chunk = self.create_chunk(storage, self.buffer_storage_alignment_offset);
        let chunk_id = *handle_chunk.id();
        let slice = self.allocate_slice(handle_chunk.clone(), slice_size);

        let handle_slice = slice.handle.clone();
        self.update_chunk_metadata(chunk_id, slice);

        MemoryPoolHandle {
            slice: handle_slice,
        }
    }

    // Builds the slice and asserts that, with padding, it fills the chunk.
    fn allocate_slice(&self, handle_chunk: ChunkHandle, slice_size: usize) -> SmallSlice {
        let slice = self.create_slice(0, slice_size, handle_chunk.clone());

        let effective_size = slice.effective_size();
        assert_eq!(effective_size, self.buffer_storage_alignment_offset);

        slice
    }

    // Registers the slice and links it as the chunk's single slice.
    fn update_chunk_metadata(&mut self, chunk_id: ChunkId, slice: SmallSlice) {
        let slice_id = *slice.handle.id();

        self.slices.insert(slice_id, slice);
        self.chunks.get_mut(&chunk_id).unwrap().slice = Some(slice_id);
    }

    // Scans the ring buffer at most one full revolution, starting at `index`,
    // for a chunk whose slice is free. `index` always advances so successive
    // calls rotate through the chunks.
    fn find_free_slice(&mut self) -> Option<SliceId> {
        if self.ring_buffer.is_empty() {
            return None;
        }
        for _ in 0..self.ring_buffer.len() {
            let chunk_id = self.ring_buffer.get(self.index).unwrap();
            let chunk = self.chunks.get(chunk_id).unwrap();
            let slice = self.slices.get(&chunk.slice.unwrap()).unwrap();
            self.index = (self.index + 1) % self.ring_buffer.len();
            if slice.handle.is_free() {
                return Some(*slice.handle.id());
            }
        }
        None
    }

    /// Finds a free slice, resizes it to `size` (adjusting the padding so the
    /// overall footprint is unchanged), and returns a handle to it.
    fn get_free_slice(&mut self, size: usize) -> Option<SliceHandle> {
        let slice_id = self.find_free_slice()?;

        let slice = self.slices.get_mut(&slice_id).unwrap();
        let old_slice_size = slice.effective_size();

        let offset = match slice.storage.utilization {
            StorageUtilization::Full(_) => 0,
            StorageUtilization::Slice { offset, size: _ } => offset,
        };
        // Small-pool slices always start at the beginning of their chunk.
        assert_eq!(offset, 0);
        slice.storage.utilization = StorageUtilization::Slice { offset, size };
        let new_padding = old_slice_size - size;
        slice.padding = new_padding;
        assert_eq!(
            slice.effective_size(),
            old_slice_size,
            "new and old slice should have the same size"
        );

        Some(slice.handle.clone())
    }

    /// Creates a slice of size `size` upon the given chunk with the given offset.
    fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> SmallSlice {
        // Only offset 0 is supported (one slice per chunk).
        assert_eq!(offset, 0);
        let chunk = self.chunks.get(handle_chunk.id()).unwrap();
        let handle = SliceHandle::new();

        let storage = StorageHandle {
            id: chunk.storage.id.clone(),
            utilization: StorageUtilization::Slice { offset, size },
        };

        let padding = calculate_padding(size, self.buffer_storage_alignment_offset);

        SmallSlice::new(storage, handle, chunk.handle.clone(), padding)
    }

    /// Creates a chunk of given size by allocating on the storage.
    fn create_chunk<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        size: usize,
    ) -> ChunkHandle {
        let padding = calculate_padding(size, self.buffer_storage_alignment_offset);
        let effective_size = size + padding;

        let storage = storage.alloc(effective_size);
        let handle = ChunkHandle::new();
        let id = *handle.id();

        // New chunks join the ring buffer so they can be reused later.
        self.ring_buffer.push(id);

        self.chunks
            .insert(id, SmallChunk::new(storage, handle.clone(), None));

        handle
    }

    // Deallocation is not implemented for the small pool yet.
    #[allow(unused)]
    fn deallocate<Storage: ComputeStorage>(&mut self, _storage: &mut Storage) {
        todo!()
    }
}
|
||||
|
||||
/// Number of padding bytes needed to round `size` up to the next multiple of
/// `buffer_storage_alignment_offset` (zero when `size` is already aligned).
fn calculate_padding(size: usize, buffer_storage_alignment_offset: usize) -> usize {
    match size % buffer_storage_alignment_offset {
        0 => 0,
        remainder => buffer_storage_alignment_offset - remainder,
    }
}
|
|
@ -1,10 +0,0 @@
|
|||
// Memory pools shared by the management strategies.
pub(crate) mod memory_pool;

// Common traits and id types for memory management.
mod base;

pub use base::*;

/// Dynamic memory management strategy.
pub mod dynamic;
/// Simple memory management strategy.
pub mod simple;
|
|
@ -1,559 +0,0 @@
|
|||
use crate::{
|
||||
memory_id_type,
|
||||
storage::{ComputeStorage, StorageHandle, StorageUtilization},
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
#[cfg(all(not(target_family = "wasm"), feature = "std"))]
|
||||
use std::time;
|
||||
#[cfg(all(target_family = "wasm", feature = "std"))]
|
||||
use web_time as time;
|
||||
|
||||
use super::{MemoryBinding, MemoryHandle, MemoryManagement};
|
||||
|
||||
// The ChunkId allows to keep track of how many references there are to a specific chunk.
// The macro generates the id type plus its handle/binding wrappers.
memory_id_type!(ChunkId, ChunkHandle, ChunkBinding);
// The SliceId allows to keep track of how many references there are to a specific slice.
memory_id_type!(SliceId, SliceHandle, SliceBinding);
|
||||
|
||||
/// A tensor memory handle, referring to either a chunk or a slice.
#[derive(Debug, Clone)]
pub enum SimpleHandle {
    /// A whole chunk of memory.
    Chunk(ChunkHandle),
    /// A slice of a chunk of memory.
    Slice(SliceHandle),
}
|
||||
|
||||
/// Binding of the [simple handle](SimpleHandle).
#[derive(Debug, Clone)]
pub enum SimpleBinding {
    /// Binding of the [chunk handle](ChunkHandle).
    Chunk(ChunkBinding),
    /// Binding of the [slice handle](SliceHandle).
    Slice(SliceBinding),
}
|
||||
|
||||
/// The strategy defines the frequency at which deallocation of unused memory chunks should occur.
#[derive(Debug)]
pub enum DeallocStrategy {
    /// Once every n calls to reserve.
    PeriodTick {
        /// Number of calls to be executed before triggering the deallocation.
        period: usize,
        /// Current state. Should start at zero.
        state: usize,
    },
    #[cfg(feature = "std")]
    /// Once every period of time
    PeriodTime {
        /// Amount of time that must elapse before triggering the deallocation.
        period: time::Duration,
        /// Current state. Should start at now.
        state: time::Instant,
    },
    /// Never deallocate.
    Never,
}
|
||||
|
||||
/// The strategy defines when to reuse chunk with slices.
#[derive(Debug)]
pub enum SliceStrategy {
    /// Never use slices.
    Never,
    /// Ratio needed before the chunk can be used as a slice. Between 0 and 1.
    Ratio(f32),
    /// When the reserved memory is at least the given number of bytes.
    MinimumSize(usize),
    /// When the reserved memory is at most the given number of bytes.
    MaximumSize(usize),
}
|
||||
|
||||
impl SliceStrategy {
|
||||
/// If the chunk can be used with a slice.
|
||||
pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool {
|
||||
if chunk_size < reserved_size {
|
||||
return false;
|
||||
}
|
||||
|
||||
match self {
|
||||
SliceStrategy::Never => false,
|
||||
SliceStrategy::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= *ratio,
|
||||
SliceStrategy::MinimumSize(bytes) => reserved_size >= *bytes,
|
||||
SliceStrategy::MaximumSize(bytes) => reserved_size <= *bytes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DeallocStrategy {
    /// Create a new strategy with the given period.
    pub fn new_period_tick(period: usize) -> Self {
        DeallocStrategy::PeriodTick { period, state: 0 }
    }

    // Advances the strategy's internal state and reports whether a
    // deallocation pass should run now.
    fn should_dealloc(&mut self) -> bool {
        match self {
            DeallocStrategy::PeriodTick { period, state } => {
                // Counts calls modulo `period`; fires on every period-th call.
                *state = (*state + 1) % *period;
                *state == 0
            }
            #[cfg(feature = "std")]
            DeallocStrategy::PeriodTime { period, state } => {
                // Fires once `period` has elapsed, then restarts the clock.
                if &state.elapsed() > period {
                    *state = time::Instant::now();
                    true
                } else {
                    false
                }
            }
            DeallocStrategy::Never => false,
        }
    }
}
|
||||
|
||||
/// A chunk of storage together with the slices allocated on top of it.
#[derive(new)]
struct Chunk {
    storage: StorageHandle,
    handle: ChunkHandle,
    // Currently at most one slice per chunk (see `create_slice`).
    slices: Vec<SliceId>,
}
|
||||
|
||||
/// A slice allocated on a chunk, always starting at offset zero.
#[derive(new)]
struct Slice {
    storage: StorageHandle,
    handle: SliceHandle,
    // It is important to keep the chunk handle inside the slice, since it increases the ref count
    // on the chunk id and make the `is_free` method return false until the slice is freed.
    //
    // TL;DR we can't only store the chunk id.
    chunk: ChunkHandle,
}
|
||||
|
||||
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
pub struct SimpleMemoryManagement<Storage> {
    // Allocated chunks, indexed by id.
    chunks: HashMap<ChunkId, Chunk>,
    // Slices carved out of the chunks, indexed by id.
    slices: HashMap<SliceId, Slice>,
    // Decides when a cleanup of free chunks should run.
    dealloc_strategy: DeallocStrategy,
    // Decides when an existing chunk may be reused through a slice.
    slice_strategy: SliceStrategy,
    // Backing compute storage where the actual allocations live.
    storage: Storage,
}
|
||||
|
||||
impl<Storage> core::fmt::Debug for SimpleMemoryManagement<Storage> {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_str(
|
||||
alloc::format!(
|
||||
"SimpleMemoryManagement {:?} - {:?}",
|
||||
self.dealloc_strategy,
|
||||
core::any::type_name::<Storage>(),
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Marker impl: `SimpleBinding` is usable wherever a memory binding is expected.
impl MemoryBinding for SimpleBinding {}
|
||||
|
||||
impl MemoryHandle<SimpleBinding> for SimpleHandle {
|
||||
fn can_mut(&self) -> bool {
|
||||
match &self {
|
||||
SimpleHandle::Chunk(id) => id.can_mut(),
|
||||
SimpleHandle::Slice(id) => id.can_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
fn binding(self) -> SimpleBinding {
|
||||
match self {
|
||||
Self::Chunk(handle) => SimpleBinding::Chunk(handle.binding()),
|
||||
Self::Slice(handle) => SimpleBinding::Slice(handle.binding()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Storage: ComputeStorage> MemoryManagement<Storage> for SimpleMemoryManagement<Storage> {
    type Handle = SimpleHandle;
    type Binding = SimpleBinding;

    /// Returns the resource from the storage, for the specified handle.
    ///
    /// Panics when the binding does not correspond to a known chunk or slice.
    fn get(&mut self, binding: Self::Binding) -> Storage::Resource {
        let storage = match binding {
            SimpleBinding::Chunk(chunk) => {
                &self
                    .chunks
                    .get(chunk.id())
                    .expect("Storage found for the given execution buffer handle")
                    .storage
            }
            SimpleBinding::Slice(slice) => {
                &self
                    .slices
                    .get(slice.id())
                    .expect("Storage found for the given execution buffer handle")
                    .storage
            }
        };

        self.storage.get(storage)
    }

    /// Reserves memory of specified size using the reserve algorithm, and return
    /// a handle to the reserved memory.
    ///
    /// Also clean ups, removing unused slices, and chunks if permitted by deallocation strategy.
    fn reserve<Sync: FnOnce()>(&mut self, size: usize, _sync: Sync) -> Self::Handle {
        // Free slices are always reclaimed before searching for space.
        self.cleanup_slices();

        let handle = self.reserve_algorithm(size);

        // Chunk cleanup only runs when the strategy says it is time.
        if self.dealloc_strategy.should_dealloc() {
            self.cleanup_chunks();
        }

        handle
    }

    /// Always allocates a fresh chunk of exactly `size` bytes.
    fn alloc<Sync: FnOnce()>(&mut self, size: usize, _sync: Sync) -> Self::Handle {
        self.create_chunk(size)
    }

    /// Deallocates a chunk. Slices cannot be deallocated manually — they are
    /// reclaimed by `cleanup_slices` once their handle is free.
    fn dealloc(&mut self, binding: Self::Binding) {
        match binding {
            SimpleBinding::Chunk(chunk) => {
                if let Some(chunk) = self.chunks.remove(chunk.id()) {
                    self.storage.dealloc(chunk.storage.id);
                }
            }
            SimpleBinding::Slice(_) => panic!("Can't dealloc slice manually"),
        }
    }

    /// Grants mutable access to the underlying compute storage.
    fn storage(&mut self) -> &mut Storage {
        &mut self.storage
    }
}
|
||||
|
||||
impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
    /// Creates a new instance using the given storage, deallocation strategy and slice strategy.
    pub fn new(
        storage: Storage,
        dealloc_strategy: DeallocStrategy,
        slice_strategy: SliceStrategy,
    ) -> Self {
        Self {
            chunks: HashMap::new(),
            slices: HashMap::new(),
            dealloc_strategy,
            slice_strategy,
            storage,
        }
    }

    // Core reservation logic: reuse an exact-size free chunk, slice a larger
    // one, or allocate a brand-new chunk.
    fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle {
        // Looks for a large enough, existing but unused chunk of memory.
        let chunk = self.find_free_chunk(size);

        match chunk {
            Some(chunk) => {
                if size == chunk.storage.size() {
                    // If there is one of exactly the same size, it reuses it.
                    SimpleHandle::Chunk(chunk.handle.clone())
                } else {
                    // Otherwise creates a slice of the right size upon it, always starting at zero.
                    self.create_slice(size, chunk.handle.clone())
                }
            }
            // If no chunk available, creates one of exactly the right size.
            None => self.create_chunk(size),
        }
    }

    /// Finds the smallest of the free and large enough chunks to fit `size`
    /// Returns the chunk's id and size.
    fn find_free_chunk(&self, size: usize) -> Option<&Chunk> {
        let mut size_diff_current = usize::MAX;
        let mut current = None;

        for chunk in self.chunks.values() {
            // If chunk is already used, we do not choose it
            if !chunk.handle.is_free() {
                continue;
            }

            let storage_size = chunk.storage.size();

            // If we find a chunk of exactly the right size, we stop searching altogether
            if size == storage_size {
                current = Some(chunk);
                break;
            }

            // Finds the smallest of the large enough chunks that can accept a slice
            // of the given size
            if self.slice_strategy.can_use_chunk(storage_size, size) {
                let size_diff = storage_size - size;

                if size_diff < size_diff_current {
                    current = Some(chunk);
                    size_diff_current = size_diff;
                }
            }
        }

        current
    }

    /// Creates a slice of size `size` upon the given chunk.
    ///
    /// For now slices must start at zero, therefore there can be only one per chunk
    fn create_slice(&mut self, size: usize, handle_chunk: ChunkHandle) -> SimpleHandle {
        let chunk = self.chunks.get_mut(handle_chunk.id()).unwrap();
        let handle_slice = SliceHandle::new();

        let storage = StorageHandle {
            id: chunk.storage.id.clone(),
            utilization: StorageUtilization::Slice { offset: 0, size },
        };

        if chunk.slices.is_empty() {
            self.slices.insert(
                *handle_slice.id(),
                Slice::new(storage, handle_slice.clone(), handle_chunk.clone()),
            );
        } else {
            panic!("Can't have more than 1 slice yet.");
        }

        chunk.slices.push(*handle_slice.id());

        SimpleHandle::Slice(handle_slice)
    }

    /// Creates a chunk of given size by allocating on the storage.
    fn create_chunk(&mut self, size: usize) -> SimpleHandle {
        let storage = self.storage.alloc(size);
        let handle = ChunkHandle::new();

        self.chunks.insert(
            *handle.id(),
            Chunk::new(storage, handle.clone(), Vec::new()),
        );

        SimpleHandle::Chunk(handle)
    }

    /// Deallocates free chunks and remove them from chunks map.
    fn cleanup_chunks(&mut self) {
        // Collect ids first: the map cannot be mutated while iterating it.
        let mut ids_to_remove = Vec::new();

        self.chunks.iter().for_each(|(chunk_id, chunk)| {
            if chunk.handle.is_free() {
                ids_to_remove.push(*chunk_id);
            }
        });

        ids_to_remove
            .iter()
            .map(|chunk_id| self.chunks.remove(chunk_id).unwrap())
            .for_each(|chunk| {
                self.storage.dealloc(chunk.storage.id);
            });
    }

    /// Removes free slices from slice map and corresponding chunks.
    fn cleanup_slices(&mut self) {
        // Collect ids first: the map cannot be mutated while iterating it.
        let mut ids_to_remove = Vec::new();

        self.slices.iter().for_each(|(slice_id, slice)| {
            if slice.handle.is_free() {
                ids_to_remove.push(*slice_id);
            }
        });

        ids_to_remove
            .iter()
            .map(|slice_id| self.slices.remove(slice_id).unwrap())
            .for_each(|slice| {
                // Detach the slice from its owning chunk so the chunk can be
                // reused or deallocated.
                let chunk = self.chunks.get_mut(slice.chunk.id()).unwrap();
                chunk.slices.retain(|id| id != slice.handle.id());
            });
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        memory_management::{MemoryHandle, MemoryManagement},
        storage::BytesStorage,
    };

    impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
        /// Reserve memory without running any synchronization callback.
        fn reserve_no_sync(&mut self, size: usize) -> SimpleHandle {
            self.reserve(size, || {})
        }
    }

    #[test]
    fn can_mut_with_single_tensor_reference() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );

        let chunk_size = 4;
        let simple_handle = memory_management.create_chunk(chunk_size);

        let x = simple_handle.clone();
        core::mem::drop(simple_handle);

        assert!(x.can_mut());
    }

    #[test]
    fn two_tensor_references_remove_mutability() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );

        let chunk_size = 4;
        let simple_handle = memory_management.create_chunk(chunk_size);

        let x = simple_handle.clone();

        assert!(!simple_handle.can_mut());
        assert!(!x.can_mut())
    }

    #[test]
    fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let _chunk_handle = memory_management.reserve_no_sync(chunk_size);
        let _new_handle = memory_management.reserve_no_sync(chunk_size);

        assert_eq!(memory_management.chunks.len(), 2);
    }

    // Fixed garbled test name (was `when_empty_chunk_is_cleaned_upexists_it_disappears`).
    #[test]
    fn when_empty_chunk_is_cleaned_up_it_disappears() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let chunk_handle = memory_management.reserve_no_sync(chunk_size);
        drop(chunk_handle);
        memory_management.cleanup_chunks();

        assert_eq!(memory_management.chunks.len(), 0);
    }

    #[test]
    fn never_dealloc_strategy_never_deallocs() {
        let mut never_dealloc = DeallocStrategy::Never;
        for _ in 0..20 {
            assert!(!never_dealloc.should_dealloc())
        }
    }

    #[test]
    fn period_tick_dealloc_strategy_should_dealloc_after_period() {
        let period = 3;
        let mut period_tick_dealloc = DeallocStrategy::new_period_tick(period);

        for _ in 0..3 {
            for _ in 0..period - 1 {
                assert!(!period_tick_dealloc.should_dealloc());
            }
            assert!(period_tick_dealloc.should_dealloc());
        }
    }

    #[test]
    fn slice_strategy_minimum_bytes() {
        let strategy = SliceStrategy::MinimumSize(100);

        assert!(strategy.can_use_chunk(200, 101));
        assert!(!strategy.can_use_chunk(200, 99));
    }

    #[test]
    fn slice_strategy_maximum_bytes() {
        let strategy = SliceStrategy::MaximumSize(100);

        assert!(strategy.can_use_chunk(200, 99));
        assert!(!strategy.can_use_chunk(200, 101));
    }

    #[test]
    fn slice_strategy_ratio() {
        let strategy = SliceStrategy::Ratio(0.9);

        assert!(strategy.can_use_chunk(200, 180));
        assert!(!strategy.can_use_chunk(200, 179));
    }

    #[test]
    fn test_handle_mutability() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Ratio(0.5),
        );
        let handle = memory_management.reserve_no_sync(10);

        let other_ref = handle.clone();

        assert!(!handle.can_mut(), "Handle can't be mut when multiple ref.");
        drop(other_ref);
        assert!(handle.can_mut(), "Handle should be mut when only one ref.");
    }

    #[test]
    fn test_slice_mutability() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Ratio(0.5),
        );
        let chunk = memory_management.reserve_no_sync(10);

        if let super::SimpleHandle::Slice(_) = chunk {
            panic!("Should be a chunk.")
        }

        drop(chunk);

        let slice = memory_management.reserve_no_sync(8);

        if let super::SimpleHandle::Chunk(_) = &slice {
            panic!("Should be a slice.")
        }

        if let super::SimpleHandle::Slice(slice) = slice {
            let other_ref = slice.clone();

            assert!(
                !slice.can_mut(),
                "Slice can't be mut when multiple ref to the same handle."
            );
            drop(other_ref);
            assert!(
                slice.can_mut(),
                "Slice should be mut when only one ref to the same handle."
            );
            assert!(
                !slice.is_free(),
                "Slice can't be reallocated when one ref still exist."
            );
        }
    }
}
|
|
@ -1,105 +0,0 @@
|
|||
use crate::{
|
||||
memory_management::{MemoryHandle, MemoryManagement},
|
||||
storage::ComputeStorage,
|
||||
tune::AutotuneKey,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::{reader::Reader, sync_type::SyncType};
|
||||
use core::fmt::Debug;
|
||||
|
||||
/// The compute server is responsible for handling resources and computations over resources.
///
/// Everything in the server is mutable, therefore it should be solely accessed through the
/// [compute channel](crate::channel::ComputeChannel) for thread safety.
pub trait ComputeServer: Send + core::fmt::Debug
where
    Self: Sized,
{
    /// The kernel type defines the computation algorithms.
    type Kernel: Send;
    /// Options when dispatching the kernel, eg. the number of executions.
    type DispatchOptions: Send;
    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
    type Storage: ComputeStorage;
    /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
    type MemoryManagement: MemoryManagement<Self::Storage>;
    /// The key used to cache operations used on specific inputs in autotune
    type AutotuneKey: AutotuneKey;
    /// Features supported by the compute server.
    type FeatureSet: Send + Sync;

    /// Given a handle, returns the owned resource as bytes.
    fn read(&mut self, binding: Binding<Self>) -> Reader;

    /// Given a resource handle, returns the storage resource.
    fn get_resource(
        &mut self,
        binding: Binding<Self>,
    ) -> <Self::Storage as ComputeStorage>::Resource;

    /// Given a resource as bytes, stores it and returns the memory handle.
    fn create(&mut self, data: &[u8]) -> Handle<Self>;

    /// Reserves `size` bytes in the storage, and returns a handle over them.
    fn empty(&mut self, size: usize) -> Handle<Self>;

    /// Executes the `kernel` over the given memory `handles`.
    ///
    /// Kernels have mutable access to every resource they are given
    /// and are responsible of determining which should be read or written.
    fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: Self::DispatchOptions,
        bindings: Vec<Binding<Self>>,
    );

    /// Wait for the completion of every task in the server.
    fn sync(&mut self, command: SyncType);
}
|
||||
|
||||
/// Server handle containing the [memory handle](MemoryManagement::Handle).
///
/// A `Handle` represents ownership of a memory region on the server; convert it
/// into a [`Binding`] to pass it to a kernel.
#[derive(new, Debug)]
pub struct Handle<Server: ComputeServer> {
    /// Memory handle.
    pub memory: <Server::MemoryManagement as MemoryManagement<Server::Storage>>::Handle,
}

/// Binding of a [tensor handle](Handle) to execute a kernel.
#[derive(new, Debug)]
pub struct Binding<Server: ComputeServer> {
    /// Memory binding.
    pub memory: <Server::MemoryManagement as MemoryManagement<Server::Storage>>::Binding,
}
|
||||
|
||||
impl<Server: ComputeServer> Handle<Server> {
|
||||
/// If the tensor handle can be reused inplace.
|
||||
pub fn can_mut(&self) -> bool {
|
||||
MemoryHandle::can_mut(&self.memory)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Server: ComputeServer> Handle<Server> {
|
||||
/// Convert the [handle](Handle) into a [binding](Binding).
|
||||
pub fn binding(self) -> Binding<Server> {
|
||||
Binding {
|
||||
memory: MemoryHandle::binding(self.memory),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Manual Clone impls: `#[derive(Clone)]` would require `Server: Clone`, which
// is not needed since only the memory handle/binding is cloned.
impl<Server: ComputeServer> Clone for Handle<Server> {
    fn clone(&self) -> Self {
        Self {
            memory: self.memory.clone(),
        }
    }
}

impl<Server: ComputeServer> Clone for Binding<Server> {
    fn clone(&self) -> Self {
        Self {
            memory: self.memory.clone(),
        }
    }
}
|
|
@ -1,64 +0,0 @@
|
|||
use crate::storage_id_type;
|
||||
|
||||
// This ID is used to map a handle to its actual data.
storage_id_type!(StorageId);

/// Defines if data uses a full memory chunk or a slice of it.
#[derive(Clone, Debug)]
pub enum StorageUtilization {
    /// Full memory chunk of specified size
    Full(usize),
    /// Slice of memory chunk with start index and size.
    Slice {
        /// The offset in bytes from the chunk start.
        offset: usize,
        /// The size of the slice in bytes.
        size: usize,
    },
}
|
||||
|
||||
/// Contains the [storage id](StorageId) of a resource and the way it is used.
#[derive(new, Clone, Debug)]
pub struct StorageHandle {
    /// Storage id.
    pub id: StorageId,
    /// How the storage is used.
    pub utilization: StorageUtilization,
}

impl StorageHandle {
    /// Returns the size the handle is pointing to in memory.
    pub fn size(&self) -> usize {
        match self.utilization {
            StorageUtilization::Full(size) => size,
            StorageUtilization::Slice { offset: _, size } => size,
        }
    }

    /// Returns the offset in bytes of the handle within its storage.
    ///
    /// # Panics
    ///
    /// Panics when the utilization is [StorageUtilization::Full], which no
    /// longer carries an offset.
    pub fn offset(&self) -> usize {
        match self.utilization {
            StorageUtilization::Full(..) => panic!("full size slice not supported anymore"),
            StorageUtilization::Slice { offset, .. } => offset,
        }
    }
}
|
||||
|
||||
/// Storage types are responsible for allocating and deallocating memory.
pub trait ComputeStorage: Send {
    /// The resource associated type determines the way data is implemented and how
    /// it can be accessed by kernels.
    type Resource: Send;

    /// Returns the underlying resource for a specified storage handle
    fn get(&mut self, handle: &StorageHandle) -> Self::Resource;

    /// Allocates `size` units of memory and returns a handle to it
    fn alloc(&mut self, size: usize) -> StorageHandle;

    /// Deallocates the memory pointed by the given storage id.
    fn dealloc(&mut self, id: StorageId);

    /// Copies the content pointed to by `from` into the region pointed to by `to`.
    // NOTE(review): implementations are presumed to require `from.size() == to.size()`
    // (the bytes implementation asserts it) — confirm before relying on other sizes.
    fn copy(&mut self, from: &StorageHandle, to: &StorageHandle);
}
|
|
@ -1,150 +0,0 @@
|
|||
use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
|
||||
use alloc::alloc::{alloc, dealloc, Layout};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// The bytes storage maps ids to pointers of bytes in a contiguous layout.
#[derive(Default)]
pub struct BytesStorage {
    // Owned heap allocations, keyed by storage id.
    memory: HashMap<StorageId, AllocatedBytes>,
}

impl core::fmt::Debug for BytesStorage {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Raw pointers are not useful to print; a static label is enough.
        f.write_str("BytesStorage")
    }
}

/// Can send to other threads.
// SAFETY: the storage owns its allocations and resources only carry raw
// pointers into them. NOTE(review): soundness relies on resources not being
// used after their allocation is deallocated — this invariant is upheld by
// callers, not enforced here; confirm at the memory-management layer.
unsafe impl Send for BytesStorage {}
unsafe impl Send for BytesResource {}
|
||||
|
||||
/// This struct is a pointer to a memory chunk or slice.
pub struct BytesResource {
    // Base pointer of the whole allocation (not of the slice).
    ptr: *mut u8,
    // Determines which part of the allocation this resource covers.
    utilization: StorageUtilization,
}

/// This struct refers to a specific (contiguous) layout of bytes.
struct AllocatedBytes {
    ptr: *mut u8,
    // Layout used at allocation time; required again for deallocation.
    layout: Layout,
}

impl BytesResource {
    // Resolves the (start pointer, length) pair for this resource, applying
    // the slice offset when applicable.
    fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
        match self.utilization {
            StorageUtilization::Full(len) => (self.ptr, len),
            StorageUtilization::Slice { offset, size } => unsafe { (self.ptr.add(offset), size) },
        }
    }

    /// Returns the resource as a mutable slice of bytes.
    // NOTE(review): the unconstrained lifetime 'a and `&self -> &mut` make it
    // possible to create aliasing mutable slices; soundness is presumed to be
    // guaranteed by the exclusive-access discipline of the compute server.
    pub fn write<'a>(&self) -> &'a mut [u8] {
        let (ptr, len) = self.get_exact_location_and_length();

        unsafe { core::slice::from_raw_parts_mut(ptr, len) }
    }

    /// Returns the resource as an immutable slice of bytes.
    pub fn read<'a>(&self) -> &'a [u8] {
        let (ptr, len) = self.get_exact_location_and_length();

        unsafe { core::slice::from_raw_parts(ptr, len) }
    }
}
|
||||
|
||||
impl ComputeStorage for BytesStorage {
|
||||
type Resource = BytesResource;
|
||||
|
||||
fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
|
||||
let allocated_bytes = self.memory.get_mut(&handle.id).unwrap();
|
||||
|
||||
BytesResource {
|
||||
ptr: allocated_bytes.ptr,
|
||||
utilization: handle.utilization.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn alloc(&mut self, size: usize) -> StorageHandle {
|
||||
let id = StorageId::new();
|
||||
let handle = StorageHandle {
|
||||
id: id.clone(),
|
||||
utilization: StorageUtilization::Full(size),
|
||||
};
|
||||
|
||||
unsafe {
|
||||
let layout = Layout::array::<u8>(size).unwrap();
|
||||
let ptr = alloc(layout);
|
||||
let memory = AllocatedBytes { ptr, layout };
|
||||
|
||||
self.memory.insert(id, memory);
|
||||
}
|
||||
|
||||
handle
|
||||
}
|
||||
|
||||
fn dealloc(&mut self, id: StorageId) {
|
||||
if let Some(memory) = self.memory.remove(&id) {
|
||||
unsafe {
|
||||
dealloc(memory.ptr, memory.layout);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) {
|
||||
assert_eq!(from.size(), to.size());
|
||||
|
||||
let input = self.get(from);
|
||||
let output = self.get(to);
|
||||
|
||||
for i in 0..from.size() {
|
||||
let offset = i + from.offset();
|
||||
let ptr_out = output.ptr.wrapping_add(offset);
|
||||
|
||||
let offset = i + to.offset();
|
||||
let ptr_in = input.ptr.wrapping_add(offset);
|
||||
|
||||
unsafe { *ptr_in = *ptr_out }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_can_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();
        let handle_1 = storage.alloc(64);

        assert_eq!(handle_1.size(), 64);
        storage.dealloc(handle_1.id);
    }

    #[test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let handle_1 = storage.alloc(64);
        // A view of 8 bytes starting at byte 24 of the same allocation.
        let handle_2 = StorageHandle::new(
            handle_1.id.clone(),
            StorageUtilization::Slice {
                offset: 24,
                size: 8,
            },
        );

        // Fill the whole chunk with 0..64.
        for (index, byte) in storage.get(&handle_1).write().iter_mut().enumerate() {
            *byte = index as u8;
        }

        let bytes = storage.get(&handle_2).read().to_vec();
        storage.dealloc(handle_1.id);
        assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]);
    }
}
|
|
@ -1,8 +0,0 @@
|
|||
mod base;
|
||||
|
||||
pub use base::*;
|
||||
|
||||
#[cfg(feature = "storage-bytes")]
|
||||
mod bytes_cpu;
|
||||
#[cfg(feature = "storage-bytes")]
|
||||
pub use bytes_cpu::*;
|
|
@ -1,9 +0,0 @@
|
|||
mod operation;
|
||||
mod tune_benchmark;
|
||||
mod tune_cache;
|
||||
mod tuner;
|
||||
|
||||
pub use operation::*;
|
||||
pub use tune_benchmark::*;
|
||||
pub use tune_cache::*;
|
||||
pub use tuner::*;
|
|
@ -1,69 +0,0 @@
|
|||
use alloc::boxed::Box;
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use core::fmt::{Debug, Display};
|
||||
use core::hash::Hash;
|
||||
|
||||
/// Default checksum for an operation set
///
/// Concatenates every candidate's name and hashes the result with md5, so the
/// cached entry is invalidated whenever the set of candidates changes.
#[cfg(feature = "autotune-persistent-cache")]
pub fn compute_checksum(autotunables: &[Box<dyn AutotuneOperation>]) -> String {
    let mut concatenated = String::new();
    for op in autotunables {
        concatenated += op.name();
    }
    format!("{:x}", md5::compute(concatenated))
}
|
||||
|
||||
/// Groups operations of the same type for autotune
pub trait AutotuneOperationSet<K>: Send {
    /// The key used in the tune cache
    fn key(&self) -> K;

    /// All candidate operations for autotuning this operation type
    /// Operations can run on toy tensors of relevant size
    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>>;

    /// Returns the operation for the given index, matching the order
    /// returned by autotunables. Operation obtained here runs on original tensors
    fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation>;

    /// Compute a checksum that can invalidate outdated cached auto-tune results.
    ///
    /// Defaults to hashing the concatenated names of all candidates.
    #[cfg(feature = "autotune-persistent-cache")]
    fn compute_checksum(&self) -> String {
        compute_checksum(&self.autotunables())
    }
}
|
||||
|
||||
/// Contains operation to run and inputs on which to run it
pub trait AutotuneOperation {
    /// Runs the operation
    fn execute(self: Box<Self>);

    /// The name of the operation.
    ///
    /// Defaults to the implementing type's full path.
    fn name(&self) -> &str {
        core::any::type_name::<Self>()
    }

    /// Clones the operation and inputs
    // An inherent `clone` returning a boxed trait object, since `Clone` itself
    // is not object-safe.
    fn clone(&self) -> Box<dyn AutotuneOperation>;
}
|
||||
|
||||
#[cfg(feature = "autotune-persistent-cache")]
/// Trait alias with support for persistent caching
// The serde bounds are needed to (de)serialize keys to the on-disk cache.
pub trait AutotuneKey:
    Clone
    + Debug
    + PartialEq
    + Eq
    + Hash
    + Display
    + serde::Serialize
    + serde::de::DeserializeOwned
    + Send
    + Sync
{
}
#[cfg(not(feature = "autotune-persistent-cache"))]
/// Trait alias
pub trait AutotuneKey: Clone + Debug + PartialEq + Eq + Hash + Display {}
impl AutotuneKey for String {}
|
|
@ -1,48 +0,0 @@
|
|||
use burn_common::benchmark::Benchmark;
|
||||
use burn_common::sync_type::SyncType;
|
||||
|
||||
use crate::channel::ComputeChannel;
|
||||
use crate::client::ComputeClient;
|
||||
use crate::server::ComputeServer;
|
||||
|
||||
use super::AutotuneOperation;
|
||||
use alloc::boxed::Box;
|
||||
use alloc::string::{String, ToString};
|
||||
|
||||
/// A benchmark that runs on server handles
#[derive(new)]
pub struct TuneBenchmark<S: ComputeServer, C> {
    // The candidate operation being timed.
    operation: Box<dyn AutotuneOperation>,
    // Client used to synchronize the server between runs.
    client: ComputeClient<S, C>,
}

// `Clone` for boxed operations, delegating to the trait's own object-safe
// `clone` method.
impl Clone for Box<dyn AutotuneOperation> {
    fn clone(&self) -> Self {
        self.as_ref().clone()
    }
}
|
||||
|
||||
impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
|
||||
type Args = Box<dyn AutotuneOperation>;
|
||||
|
||||
fn prepare(&self) -> Self::Args {
|
||||
self.operation.clone()
|
||||
}
|
||||
|
||||
fn num_samples(&self) -> usize {
|
||||
10
|
||||
}
|
||||
|
||||
fn execute(&self, operation: Self::Args) {
|
||||
AutotuneOperation::execute(operation);
|
||||
}
|
||||
|
||||
fn name(&self) -> String {
|
||||
"autotune".to_string()
|
||||
}
|
||||
|
||||
fn sync(&self) {
|
||||
// For benchmarks - we need to wait for all tasks to complete before returning.
|
||||
self.client.sync(SyncType::Wait);
|
||||
}
|
||||
}
|
|
@ -1,243 +0,0 @@
|
|||
#[cfg(feature = "autotune-persistent-cache")]
|
||||
mod std_imports {
|
||||
pub use std::fs;
|
||||
pub use std::fs::File;
|
||||
pub use std::io;
|
||||
pub use std::path::Path;
|
||||
pub use std::path::PathBuf;
|
||||
}
|
||||
#[cfg(feature = "autotune-persistent-cache")]
|
||||
use std_imports::*;
|
||||
|
||||
#[cfg(feature = "autotune-persistent-cache")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::AutotuneKey;
|
||||
use super::AutotuneOperation;
|
||||
use super::AutotuneOperationSet;
|
||||
use alloc::boxed::Box;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
#[cfg(feature = "autotune-persistent-cache")]
/// Return the file path for the persistent cache on disk
/// prefix should be the device id computed at the backend level
pub fn get_persistent_cache_file_path(prefix: &str) -> PathBuf {
    let home_dir = dirs::home_dir().expect("An home directory should exist");

    // ~/.cache/burn/autotune/<prefix>-autotune-cache.json
    home_dir
        .join(".cache")
        .join("burn")
        .join("autotune")
        .join(format!("{}-autotune-cache.json", prefix))
}
|
||||
|
||||
/// In-memory cache entry
#[derive(Debug)]
pub(crate) struct InMemoryCacheEntry {
    // Whether this entry's checksum has been validated against the persistent
    // cache since it was loaded from disk.
    #[cfg(feature = "autotune-persistent-cache")]
    checksum_checked: bool,
    // Index of the fastest candidate within the operation set.
    fastest_index: usize,
}

/// Persistent cache entry
#[cfg(feature = "autotune-persistent-cache")]
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct PersistentCacheEntry {
    // Checksum of the operation set at the time the entry was recorded; used
    // to invalidate stale results.
    checksum: String,
    fastest_index: usize,
}
|
||||
|
||||
/// Use to find and reuse the best kernel for some input
#[derive(Debug)]
pub(crate) struct TuneCache<K> {
    // Fast lookup used at execution time.
    in_memory_cache: HashMap<K, InMemoryCacheEntry>,
    // Mirror of the on-disk cache; kept in sync with `in_memory_cache`.
    #[cfg(feature = "autotune-persistent-cache")]
    persistent_cache: HashMap<K, PersistentCacheEntry>,
    // Used to derive the cache file name so different devices don't collide.
    #[cfg(feature = "autotune-persistent-cache")]
    device_id: String,
    #[cfg(feature = "autotune-persistent-cache")]
    name: String,
}

/// Result of the cache try
pub enum TuneCacheResult<K> {
    /// An operation is found and given
    Hit(Box<dyn AutotuneOperation>),
    /// No operation is found and the set is given back for ownership
    Miss(Box<dyn AutotuneOperationSet<K>>),
}
|
||||
|
||||
impl<K: AutotuneKey> TuneCache<K> {
    /// Creates a cache, loading previously persisted results from disk when the
    /// `autotune-persistent-cache` feature is enabled.
    pub(crate) fn new(
        #[cfg_attr(not(feature = "autotune-persistent-cache"), allow(unused_variables))] name: &str,
        #[cfg_attr(not(feature = "autotune-persistent-cache"), allow(unused_variables))]
        device_id: &str,
    ) -> Self {
        #[cfg(feature = "autotune-persistent-cache")]
        {
            let mut cache = TuneCache {
                in_memory_cache: HashMap::new(),
                persistent_cache: HashMap::new(),
                device_id: device_id.to_string(),
                name: name.to_string(),
            };
            // A failed load is non-fatal: autotuning just starts from scratch.
            if let Err(e) = cache.load() {
                log::warn!(
                    "Unable to load autotune cache. Cache will be ignored ({}).",
                    e
                );
            }
            cache
        }

        #[cfg(not(feature = "autotune-persistent-cache"))]
        {
            TuneCache {
                in_memory_cache: HashMap::new(),
            }
        }
    }

    /// Returns the fastest candidate index for `key`, if known.
    ///
    /// With persistent caching, entries loaded from disk are only returned once
    /// their checksum has been validated.
    pub(crate) fn find_fastest(&self, key: &K) -> Option<usize> {
        let val = self.in_memory_cache.get(key)?;

        #[cfg(feature = "autotune-persistent-cache")]
        if val.checksum_checked {
            Some(val.fastest_index)
        } else {
            None
        }

        #[cfg(not(feature = "autotune-persistent-cache"))]
        Some(val.fastest_index)
    }

    /// Looks up the fastest operation for the set's key.
    ///
    /// On a hit, the fastest operation is returned; on a miss (including a
    /// checksum mismatch for stale disk entries), ownership of the set is
    /// handed back so the caller can benchmark it.
    pub(crate) fn try_cache(
        &mut self,
        autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
    ) -> TuneCacheResult<K> {
        let key = autotune_operation_set.key();
        let result = self.in_memory_cache.get_mut(&key);

        #[cfg(feature = "autotune-persistent-cache")]
        {
            if let Some(InMemoryCacheEntry {
                checksum_checked,
                fastest_index,
            }) = result
            {
                // Entries loaded from disk are validated lazily, on first use.
                if !*checksum_checked {
                    let checksum = autotune_operation_set.compute_checksum();
                    let persistent_entry = self
                        .persistent_cache
                        .get(&key)
                        .expect("Both caches should be in sync");
                    if checksum != persistent_entry.checksum {
                        return TuneCacheResult::Miss(autotune_operation_set);
                    }
                    *checksum_checked = true;
                }
                return TuneCacheResult::Hit(autotune_operation_set.fastest(*fastest_index));
            }
        }

        #[cfg(not(feature = "autotune-persistent-cache"))]
        {
            if let Some(InMemoryCacheEntry { fastest_index, .. }) = result {
                return TuneCacheResult::Hit(autotune_operation_set.fastest(*fastest_index));
            }
        }

        TuneCacheResult::Miss(autotune_operation_set)
    }

    /// Records the fastest candidate index for `key` in the in-memory cache.
    pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
        self.in_memory_cache.insert(
            key,
            InMemoryCacheEntry {
                // Freshly benchmarked entries never need checksum validation.
                #[cfg(feature = "autotune-persistent-cache")]
                checksum_checked: true,
                fastest_index,
            },
        );
    }

    /// Records an entry in the persistent mirror; call [Self::save] to flush it
    /// to disk.
    #[cfg(feature = "autotune-persistent-cache")]
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
    ) {
        self.persistent_cache.insert(
            key,
            PersistentCacheEntry {
                checksum,
                fastest_index,
            },
        );
    }

    /// Load the persistent cache data from disk
    #[cfg(feature = "autotune-persistent-cache")]
    pub(crate) fn load(&mut self) -> Result<(), io::Error> {
        let file_path = self.get_persistent_cache_file_path();
        // note: reading file from memory is faster than using
        // serde from_reader with a buffered reader
        // see issue:
        // https://github.com/serde-rs/json/issues/160
        match fs::read_to_string(file_path) {
            Ok(data) => {
                let data: Vec<(K, PersistentCacheEntry)> = serde_json::from_str(&data)?;
                for (key, value) in data.into_iter() {
                    self.persistent_cache.insert(key, value);
                }
                Ok(())
            }
            Err(e) => {
                // A missing cache file is expected on first run.
                if e.kind() == std::io::ErrorKind::NotFound {
                    Ok(())
                } else {
                    Err(e)
                }
            }
        }?;
        // Mirror the loaded entries in memory, marked as not yet validated.
        for (key, entry) in self.persistent_cache.iter() {
            self.in_memory_cache.insert(
                key.clone(),
                InMemoryCacheEntry {
                    checksum_checked: false,
                    fastest_index: entry.fastest_index,
                },
            );
        }
        Ok(())
    }

    /// Save the persistent cache on disk
    #[cfg(feature = "autotune-persistent-cache")]
    pub(crate) fn save(&self) {
        let file_path = self.get_persistent_cache_file_path();
        if let Some(parent_dir) = file_path.parent() {
            if !parent_dir.exists() {
                fs::create_dir_all(parent_dir).unwrap_or_else(|_| {
                    panic!(
                        "Should be able to create directory '{}' for autotune persistent cache file",
                        parent_dir.to_str().unwrap())
                });
            }
        }
        let file = File::create(file_path.clone()).unwrap_or_else(|_| {
            panic!(
                "Should be able to open autotune persistent cache file '{}'",
                file_path.to_str().unwrap()
            )
        });
        let data = self.persistent_cache.iter().collect::<Vec<_>>();
        serde_json::to_writer_pretty(file, &data)
            .expect("Should be able to write to autotune persistent cache");
    }

    /// Return the file path for the persistent cache on disk
    #[cfg(feature = "autotune-persistent-cache")]
    pub fn get_persistent_cache_file_path(&self) -> PathBuf {
        get_persistent_cache_file_path(&format!("{}-{}", self.name, self.device_id))
    }
}
|
|
@ -1,124 +0,0 @@
|
|||
#[cfg(target_family = "wasm")]
|
||||
use web_time::Duration;
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
use core::time::Duration;
|
||||
|
||||
use alloc::boxed::Box;
|
||||
use alloc::string::ToString;
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::benchmark::{Benchmark, BenchmarkComputations, BenchmarkDurations};
|
||||
|
||||
use crate::channel::ComputeChannel;
|
||||
use crate::client::ComputeClient;
|
||||
use crate::server::ComputeServer;
|
||||
use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCache};
|
||||
|
||||
use super::AutotuneKey;
|
||||
|
||||
#[derive(Debug)]
/// Executes autotune benchmarking and caching
pub struct Tuner<K: AutotuneKey> {
    // Maps autotune keys to the fastest known candidate.
    tune_cache: TuneCache<K>,
}
|
||||
|
||||
#[allow(clippy::new_without_default)]
|
||||
impl<K: AutotuneKey> Tuner<K> {
|
||||
/// Returns a tuner with cache initialized from persistent cache
|
||||
pub fn new(name: &str, device_id: &str) -> Self {
|
||||
Self {
|
||||
tune_cache: TuneCache::new(name, device_id),
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch the fastest autotune operation index for an autotune key.
|
||||
pub fn autotune_fastest(&self, key: &K) -> Option<usize> {
|
||||
self.tune_cache.find_fastest(key)
|
||||
}
|
||||
|
||||
/// Execute the fastest autotune operation if known, otherwise perform some benchmarks before.
|
||||
pub fn execute_autotune<S, C>(
|
||||
&mut self,
|
||||
autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
|
||||
client: &ComputeClient<S, C>,
|
||||
) where
|
||||
S: ComputeServer,
|
||||
C: ComputeChannel<S>,
|
||||
{
|
||||
let operation = match self.tune_cache.try_cache(autotune_operation_set) {
|
||||
super::TuneCacheResult::Hit(ops) => ops,
|
||||
super::TuneCacheResult::Miss(set) => self.autotuning(set, client),
|
||||
};
|
||||
|
||||
AutotuneOperation::execute(operation);
|
||||
}
|
||||
|
||||
fn autotuning<S, C>(
|
||||
&mut self,
|
||||
autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
|
||||
client: &ComputeClient<S, C>,
|
||||
) -> Box<dyn AutotuneOperation>
|
||||
where
|
||||
S: ComputeServer,
|
||||
C: ComputeChannel<S>,
|
||||
{
|
||||
let key = autotune_operation_set.key();
|
||||
let autotunables = autotune_operation_set.autotunables();
|
||||
let mut names = Vec::with_capacity(autotunables.len());
|
||||
|
||||
let results: Vec<BenchmarkDurations> = autotunables
|
||||
.into_iter()
|
||||
.map(|op| {
|
||||
names.push(op.name().to_string());
|
||||
self.run_benchmark(op, client)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Finds the fastest operation, stores it and returns it
|
||||
let fastest_index = self.find_fastest(results);
|
||||
let fastest_name = names.get(fastest_index).unwrap();
|
||||
log::info!("Fastest result {fastest_name}-{key}");
|
||||
|
||||
self.tune_cache.cache_insert(key.clone(), fastest_index);
|
||||
#[cfg(feature = "autotune-persistent-cache")]
|
||||
{
|
||||
let checksum = autotune_operation_set.compute_checksum();
|
||||
self.tune_cache
|
||||
.persistent_cache_insert(key, checksum, fastest_index);
|
||||
self.tune_cache.save();
|
||||
}
|
||||
|
||||
match self.tune_cache.try_cache(autotune_operation_set) {
|
||||
super::TuneCacheResult::Hit(ops) => ops,
|
||||
super::TuneCacheResult::Miss(_) => panic!("We just inserted, should not miss"),
|
||||
}
|
||||
}
|
||||
|
||||
fn run_benchmark<S, C>(
|
||||
&mut self,
|
||||
operation: Box<dyn AutotuneOperation>,
|
||||
client: &ComputeClient<S, C>,
|
||||
) -> BenchmarkDurations
|
||||
where
|
||||
S: ComputeServer,
|
||||
C: ComputeChannel<S>,
|
||||
{
|
||||
TuneBenchmark::new(operation, client.clone()).run()
|
||||
}
|
||||
|
||||
fn find_fastest(&self, results: Vec<BenchmarkDurations>) -> usize {
|
||||
let mut smallest_duration = Duration::MAX;
|
||||
let mut fastest_tunable = None;
|
||||
|
||||
for (i, result) in results.into_iter().enumerate() {
|
||||
let computed = BenchmarkComputations::new(&result);
|
||||
|
||||
if computed.median < smallest_duration {
|
||||
smallest_duration = computed.median;
|
||||
fastest_tunable = Some(i);
|
||||
}
|
||||
}
|
||||
|
||||
fastest_tunable.expect("At least one kernel needed. ")
|
||||
}
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use super::DummyServer;
|
||||
use burn_common::stub::RwLock;
|
||||
use burn_compute::channel::MutexComputeChannel;
|
||||
use burn_compute::client::ComputeClient;
|
||||
use burn_compute::memory_management::simple::{
|
||||
DeallocStrategy, SimpleMemoryManagement, SliceStrategy,
|
||||
};
|
||||
use burn_compute::storage::BytesStorage;
|
||||
use burn_compute::tune::Tuner;
|
||||
use burn_compute::ComputeRuntime;
|
||||
|
||||
/// The dummy device.
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
|
||||
pub struct DummyDevice;
|
||||
|
||||
pub type DummyChannel = MutexComputeChannel<DummyServer>;
|
||||
pub type DummyClient = ComputeClient<DummyServer, DummyChannel>;
|
||||
|
||||
static RUNTIME: ComputeRuntime<DummyDevice, DummyServer, DummyChannel> = ComputeRuntime::new();
|
||||
pub static TUNER_DEVICE_ID: &str = "tests/dummy-device";
|
||||
pub static TUNER_PREFIX: &str = "dummy-tests/dummy-device";
|
||||
|
||||
pub fn init_client() -> ComputeClient<DummyServer, MutexComputeChannel<DummyServer>> {
|
||||
let storage = BytesStorage::default();
|
||||
let memory_management =
|
||||
SimpleMemoryManagement::new(storage, DeallocStrategy::Never, SliceStrategy::Never);
|
||||
let server = DummyServer::new(memory_management);
|
||||
let channel = MutexComputeChannel::new(server);
|
||||
let tuner = Arc::new(RwLock::new(Tuner::new("dummy", TUNER_DEVICE_ID)));
|
||||
ComputeClient::new(channel, tuner, Arc::new(()))
|
||||
}
|
||||
|
||||
pub fn client(device: &DummyDevice) -> DummyClient {
|
||||
RUNTIME.client(device, init_client)
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
use burn_compute::storage::BytesResource;
|
||||
|
||||
/// The DummyKernel trait should be implemented for every supported operation
pub trait DummyKernel: Sync + Send {
    /// Runs the kernel over `resources`; by convention the kernel itself
    /// decides which entries are read-only inputs and which are outputs.
    fn compute(&self, resources: &mut [BytesResource]);
}
|
||||
|
||||
/// Contains the algorithm for element-wise addition
|
||||
pub struct DummyElementwiseAddition;
|
||||
|
||||
impl DummyKernel for DummyElementwiseAddition {
|
||||
fn compute(&self, inputs: &mut [BytesResource]) {
|
||||
// Notice how the kernel is responsible for determining which inputs
|
||||
// are read-only and which are writable.
|
||||
let lhs = &inputs[0].read();
|
||||
let rhs = &inputs[1].read();
|
||||
let out = &mut inputs[2].write();
|
||||
|
||||
let size = lhs.len();
|
||||
|
||||
for i in 0..size {
|
||||
out[i] = lhs[i] + rhs[i];
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
// Test-only "dummy" backend: a pure-CPU compute stack used to exercise the
// burn-compute infrastructure without a real device.
mod compute;
mod kernel;
mod server;
mod tune;

pub use compute::*;
pub use kernel::*;
pub use server::*;
pub use tune::*;
|
|
@ -1,74 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use burn_common::{reader::reader_from_concrete, sync_type::SyncType};
|
||||
use burn_compute::{
|
||||
memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement},
|
||||
server::{Binding, ComputeServer, Handle},
|
||||
storage::{BytesResource, BytesStorage},
|
||||
};
|
||||
use derive_new::new;
|
||||
|
||||
use super::DummyKernel;
|
||||
|
||||
/// The dummy server is used to test the burn-compute infrastructure.
/// It uses simple memory management with a bytes storage on CPU, without asynchronous tasks.
#[derive(new, Debug)]
pub struct DummyServer<MM = SimpleMemoryManagement<BytesStorage>> {
    // Backing memory manager; defaults to the simple byte-based implementation.
    memory_management: MM,
}
|
||||
|
||||
impl<MM> ComputeServer for DummyServer<MM>
where
    MM: MemoryManagement<BytesStorage>,
{
    // No dispatch configuration needed: kernels run synchronously on CPU.
    type DispatchOptions = ();
    type Kernel = Arc<dyn DummyKernel>;
    type Storage = BytesStorage;
    type MemoryManagement = MM;
    type AutotuneKey = String;
    type FeatureSet = ();

    /// Copies the bytes behind `binding` into an owned, already-resolved reader.
    fn read(&mut self, binding: Binding<Self>) -> burn_common::reader::Reader {
        let bytes = self.memory_management.get(binding.memory);
        reader_from_concrete(bytes.read().to_vec())
    }

    /// Resolves a binding to its raw byte resource.
    fn get_resource(&mut self, binding: Binding<Self>) -> BytesResource {
        self.memory_management.get(binding.memory)
    }

    /// Allocates a new handle of `data.len()` bytes and copies `data` into it.
    fn create(&mut self, data: &[u8]) -> Handle<Self> {
        let handle = self.memory_management.reserve(data.len(), || {});
        let resource = self.memory_management.get(handle.clone().binding());

        let bytes = resource.write();

        for (i, val) in data.iter().enumerate() {
            bytes[i] = *val;
        }

        Handle::new(handle)
    }

    /// Allocates `size` bytes without initializing them.
    fn empty(&mut self, size: usize) -> Handle<Self> {
        Handle::new(self.memory_management.reserve(size, || {}))
    }

    /// Resolves every binding to its resource and runs the kernel synchronously.
    fn execute(
        &mut self,
        kernel: Self::Kernel,
        _count: Self::DispatchOptions,
        bindings: Vec<Binding<Self>>,
    ) {
        let mut resources = bindings
            .into_iter()
            .map(|binding| self.memory_management.get(binding.memory))
            .collect::<Vec<_>>();

        kernel.compute(&mut resources);
    }

    fn sync(&mut self, _: SyncType) {
        // Nothing to do with dummy backend.
    }
}
|
|
@ -1,32 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use burn_compute::{client::ComputeClient, server::Binding, tune::AutotuneOperation};
|
||||
use derive_new::new;
|
||||
|
||||
use crate::dummy::{DummyChannel, DummyKernel, DummyServer};
|
||||
|
||||
#[derive(new)]
|
||||
/// Extended kernel that accounts for additional parameters, i.e. needed
|
||||
/// information that does not count as an input/output.
|
||||
pub struct OneKernelAutotuneOperation {
|
||||
kernel: Arc<dyn DummyKernel>,
|
||||
client: ComputeClient<DummyServer, DummyChannel>,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
bindings: Vec<Binding<DummyServer>>,
|
||||
}
|
||||
|
||||
impl AutotuneOperation for OneKernelAutotuneOperation {
|
||||
/// Executes the operation on given bindings and server, with the additional parameters
|
||||
fn execute(self: Box<Self>) {
|
||||
self.client.execute(self.kernel.clone(), (), self.bindings);
|
||||
}
|
||||
|
||||
fn clone(&self) -> Box<dyn AutotuneOperation> {
|
||||
Box::new(Self {
|
||||
kernel: self.kernel.clone(),
|
||||
client: self.client.clone(),
|
||||
shapes: self.shapes.clone(),
|
||||
bindings: self.bindings.clone(),
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,106 +0,0 @@
|
|||
use std::{thread::sleep, time::Duration};
|
||||
|
||||
use burn_compute::storage::BytesResource;
|
||||
|
||||
use crate::dummy::DummyKernel;
|
||||
|
||||
// Artificial per-element delay used by the deliberately slow kernels below.
const SLEEP_MS: u64 = 1;

// Marker types implementing DummyKernel for the autotune tests: "SlowWrong"
// kernels sleep per element and copy lhs instead of computing, the CacheTest
// pair is fast or slow depending on whether the input length is 3, and
// ParameteredKernel reads an extra info buffer.
pub struct DummyElementwiseAdditionSlowWrong;
pub struct DummyElementwiseMultiplication;
pub struct DummyElementwiseMultiplicationSlowWrong;
pub struct CacheTestFastOn3;
pub struct CacheTestSlowOn3;
pub struct ParameteredKernel;
|
||||
|
||||
impl DummyKernel for DummyElementwiseAdditionSlowWrong {
|
||||
fn compute(&self, inputs: &mut [BytesResource]) {
|
||||
// Slow and wrong on purpose, for tests
|
||||
let lhs = &inputs[0].read();
|
||||
let out = &mut inputs[2].write();
|
||||
|
||||
let size = lhs.len();
|
||||
|
||||
for i in 0..size {
|
||||
sleep(Duration::from_millis(SLEEP_MS));
|
||||
out[i] = lhs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
impl DummyKernel for DummyElementwiseMultiplication {
|
||||
fn compute(&self, inputs: &mut [BytesResource]) {
|
||||
let lhs = &inputs[0].read();
|
||||
let rhs = &inputs[1].read();
|
||||
let out = &mut inputs[2].write();
|
||||
|
||||
let size = lhs.len();
|
||||
|
||||
for i in 0..size {
|
||||
out[i] = lhs[i] * rhs[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
impl DummyKernel for DummyElementwiseMultiplicationSlowWrong {
|
||||
fn compute(&self, inputs: &mut [BytesResource]) {
|
||||
// Slow and wrong on purpose, for tests
|
||||
let lhs = &inputs[0].read();
|
||||
let out = &mut inputs[2].write();
|
||||
|
||||
let size = lhs.len();
|
||||
|
||||
for i in 0..size {
|
||||
sleep(Duration::from_millis(SLEEP_MS));
|
||||
out[i] = lhs[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
impl DummyKernel for CacheTestFastOn3 {
|
||||
fn compute(&self, inputs: &mut [BytesResource]) {
|
||||
// This is an artificial kernel designed for testing cache only
|
||||
let lhs = &inputs[0].read();
|
||||
let out = &mut inputs[2].write();
|
||||
|
||||
let size = lhs.len();
|
||||
if size == 3 {
|
||||
out[..size].copy_from_slice(&lhs[..size]);
|
||||
} else {
|
||||
for i in 0..size {
|
||||
sleep(Duration::from_millis(SLEEP_MS));
|
||||
out[i] = lhs[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DummyKernel for CacheTestSlowOn3 {
|
||||
fn compute(&self, inputs: &mut [BytesResource]) {
|
||||
// This is an artificial kernel designed for testing cache only
|
||||
let lhs = &inputs[0].read();
|
||||
let rhs = &inputs[1].read();
|
||||
let out = &mut inputs[2].write();
|
||||
|
||||
let size = lhs.len();
|
||||
if size == 3 {
|
||||
for i in 0..size {
|
||||
sleep(Duration::from_millis(SLEEP_MS));
|
||||
out[i] = rhs[i];
|
||||
}
|
||||
} else {
|
||||
out[..size].copy_from_slice(&rhs[..size]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DummyKernel for ParameteredKernel {
|
||||
fn compute(&self, inputs: &mut [BytesResource]) {
|
||||
// This is an artificial kernel designed for info buffer
|
||||
let lhs = &inputs[0].read();
|
||||
let rhs = &inputs[1].read();
|
||||
let out = &mut inputs[2].write();
|
||||
let info = &inputs[3].read();
|
||||
|
||||
for i in 0..lhs.len() {
|
||||
out[i] = lhs[i] + rhs[i] + info[0];
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,8 +0,0 @@
|
|||
// Autotune test fixtures: wrapped operations, test kernels and operation sets.
mod autotune_operations;
mod kernels;
mod operation_sets;

pub use autotune_operations::*;
pub use kernels::*;
// Re-exported for the tests even though not every item is referenced.
#[allow(unused)]
pub use operation_sets::*;
|
|
@ -1,194 +0,0 @@
|
|||
#[cfg(feature = "autotune-persistent-cache")]
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "autotune-persistent-cache")]
|
||||
use burn_compute::tune::compute_checksum;
|
||||
use burn_compute::{
|
||||
server::Binding,
|
||||
tune::{AutotuneOperation, AutotuneOperationSet},
|
||||
};
|
||||
|
||||
use crate::dummy::{
|
||||
CacheTestFastOn3, CacheTestSlowOn3, DummyClient, DummyElementwiseAddition,
|
||||
DummyElementwiseMultiplication, DummyElementwiseMultiplicationSlowWrong, DummyServer,
|
||||
OneKernelAutotuneOperation,
|
||||
};
|
||||
|
||||
use super::DummyElementwiseAdditionSlowWrong;
|
||||
|
||||
pub struct AdditionAutotuneOperationSet {
|
||||
client: DummyClient,
|
||||
key: String,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
bindings: Vec<Binding<DummyServer>>,
|
||||
}
|
||||
|
||||
impl AdditionAutotuneOperationSet {
|
||||
#[allow(dead_code)]
|
||||
pub fn new(
|
||||
client: DummyClient,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
bindings: Vec<Binding<DummyServer>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
key: format!("{}-{}", "add", log_shape_input_key(&shapes)),
|
||||
shapes,
|
||||
bindings,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AutotuneOperationSet<String> for AdditionAutotuneOperationSet {
|
||||
fn key(&self) -> String {
|
||||
self.key.clone()
|
||||
}
|
||||
|
||||
fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
|
||||
vec![
|
||||
Box::new(OneKernelAutotuneOperation::new(
|
||||
Arc::new(DummyElementwiseAddition),
|
||||
self.client.clone(),
|
||||
self.shapes.clone(),
|
||||
self.bindings.clone(),
|
||||
)),
|
||||
Box::new(OneKernelAutotuneOperation::new(
|
||||
Arc::new(DummyElementwiseAdditionSlowWrong),
|
||||
self.client.clone(),
|
||||
self.shapes.clone(),
|
||||
self.bindings.clone(),
|
||||
)),
|
||||
]
|
||||
}
|
||||
|
||||
fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation> {
|
||||
self.autotunables()[fastest_index].clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MultiplicationAutotuneOperationSet {
|
||||
client: DummyClient,
|
||||
key: String,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
bindings: Vec<Binding<DummyServer>>,
|
||||
}
|
||||
|
||||
impl MultiplicationAutotuneOperationSet {
|
||||
#[allow(dead_code)]
|
||||
pub fn new(
|
||||
client: DummyClient,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
bindings: Vec<Binding<DummyServer>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
key: format!("{}-{}", "mul", log_shape_input_key(&shapes)),
|
||||
shapes,
|
||||
bindings,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AutotuneOperationSet<String> for MultiplicationAutotuneOperationSet {
|
||||
fn key(&self) -> String {
|
||||
self.key.clone()
|
||||
}
|
||||
|
||||
fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
|
||||
vec![
|
||||
Box::new(OneKernelAutotuneOperation::new(
|
||||
Arc::new(DummyElementwiseMultiplicationSlowWrong),
|
||||
self.client.clone(),
|
||||
self.shapes.clone(),
|
||||
self.bindings.clone(),
|
||||
)),
|
||||
Box::new(OneKernelAutotuneOperation::new(
|
||||
Arc::new(DummyElementwiseMultiplication),
|
||||
self.client.clone(),
|
||||
self.shapes.clone(),
|
||||
self.bindings.clone(),
|
||||
)),
|
||||
]
|
||||
}
|
||||
|
||||
fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation> {
|
||||
self.autotunables()[fastest_index].clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CacheTestAutotuneOperationSet {
|
||||
client: DummyClient,
|
||||
key: String,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
bindings: Vec<Binding<DummyServer>>,
|
||||
pub generate_random_checksum: bool,
|
||||
}
|
||||
|
||||
impl CacheTestAutotuneOperationSet {
|
||||
#[allow(dead_code)]
|
||||
pub fn new(
|
||||
client: DummyClient,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
bindings: Vec<Binding<DummyServer>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
key: format!("{}-{}", "cache_test", log_shape_input_key(&shapes)),
|
||||
shapes,
|
||||
bindings,
|
||||
generate_random_checksum: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AutotuneOperationSet<String> for CacheTestAutotuneOperationSet {
|
||||
fn key(&self) -> String {
|
||||
self.key.clone()
|
||||
}
|
||||
|
||||
fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
|
||||
vec![
|
||||
Box::new(OneKernelAutotuneOperation::new(
|
||||
Arc::new(CacheTestFastOn3),
|
||||
self.client.clone(),
|
||||
self.shapes.clone(),
|
||||
self.bindings.clone(),
|
||||
)),
|
||||
Box::new(OneKernelAutotuneOperation::new(
|
||||
Arc::new(CacheTestSlowOn3),
|
||||
self.client.clone(),
|
||||
self.shapes.clone(),
|
||||
self.bindings.clone(),
|
||||
)),
|
||||
]
|
||||
}
|
||||
|
||||
fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation> {
|
||||
self.autotunables()[fastest_index].clone()
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
fn compute_checksum(&self) -> String {
|
||||
if self.generate_random_checksum {
|
||||
let rand_string: String = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(16)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
rand_string
|
||||
} else {
|
||||
compute_checksum(&self.autotunables())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a shape-based key fragment: every dimension of the first shape is
/// rounded up to a power of two and appended as `"<n>,"`.
///
/// Only `shapes[0]` contributes, so operation sets whose leading shapes round
/// to the same powers of two share a key. Returns an empty string when no
/// shapes are given instead of panicking.
pub fn log_shape_input_key(shapes: &[Vec<usize>]) -> String {
    let mut hash = String::new();
    let lhs = match shapes.first() {
        Some(shape) => shape,
        // Robustness: the previous `&shapes[0]` panicked on an empty slice.
        None => return hash,
    };
    for &size in lhs {
        // Exact integer rounding; the previous f32 log2/ceil round trip lost
        // precision for sizes above 2^24 (e.g. 2^24 + 1 rounded down instead
        // of up). `next_power_of_two(0)` is 1, which matches the old
        // floating-point behavior for zero-sized dims.
        hash.push_str(size.next_power_of_two().to_string().as_str());
        hash.push(',');
    }
    hash
}
|
|
@ -1,292 +0,0 @@
|
|||
mod dummy;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::dummy::{client, DummyDevice, DummyElementwiseAddition};
|
||||
use burn_compute::ComputeRuntime;
|
||||
|
||||
#[allow(unused)]
|
||||
use serial_test::serial;
|
||||
|
||||
#[test]
fn created_resource_is_the_same_when_read() {
    let client = client(&DummyDevice);
    let data = Vec::from([0, 1, 2]);
    let handle = client.create(&data);

    // Round-tripping through the server must preserve the bytes.
    let read_back = client.read(handle.binding());

    assert_eq!(data, read_back)
}

#[test]
fn empty_allocates_memory() {
    let client = client(&DummyDevice);
    let handle = client.empty(4);

    let contents = client.read(handle.binding());

    // The allocation is uninitialized but must have the requested length.
    assert_eq!(contents.len(), 4);
}

#[test]
fn execute_elementwise_addition() {
    let client = client(&DummyDevice);
    let lhs = client.create(&[0, 1, 2]);
    let rhs = client.create(&[4, 4, 4]);
    let out = client.empty(3);

    client.execute(
        Arc::new(DummyElementwiseAddition),
        (),
        vec![lhs.binding(), rhs.binding(), out.clone().binding()],
    );

    assert_eq!(client.read(out.binding()), Vec::from([4, 5, 6]))
}
|
||||
|
||||
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_basic_addition_execution() {
    let client = client(&DummyDevice);

    let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let lhs = client.create(&[0, 1, 2]);
    let rhs = client.create(&[4, 4, 4]);
    let out = client.empty(3);
    let bindings = vec![lhs.binding(), rhs.binding(), out.clone().binding()];

    let operation_set =
        dummy::AdditionAutotuneOperationSet::new(client.clone(), shapes, bindings);
    client.autotune_execute(Box::new(operation_set));

    // If slow kernel was selected it would output [0, 1, 2]
    assert_eq!(client.read(out.binding()), Vec::from([4, 5, 6]));
}

#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_basic_multiplication_execution() {
    let client = client(&DummyDevice);

    let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let lhs = client.create(&[0, 1, 2]);
    let rhs = client.create(&[4, 4, 4]);
    let out = client.empty(3);
    let bindings = vec![lhs.binding(), rhs.binding(), out.clone().binding()];

    let operation_set =
        dummy::MultiplicationAutotuneOperationSet::new(client.clone(), shapes, bindings);
    client.autotune_execute(Box::new(operation_set));

    // If slow kernel was selected it would output [0, 1, 2]
    assert_eq!(client.read(out.binding()), Vec::from([0, 4, 8]));
}
|
||||
|
||||
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_same_key_return_a_cache_hit() {
    type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
    let runtime = Runtime::new();
    let client = runtime.client(&DummyDevice, dummy::init_client);

    // The key name depends on the shapes of the operation set (see
    // log_shape_input_key): both [1,3] and [1,4] round up to
    // 'cache_test-1,4', so the second run must hit the entry created by the
    // first run.
    let first_shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let first_lhs = client.create(&[0, 1, 2]);
    let first_rhs = client.create(&[4, 4, 4]);
    let first_out = client.empty(3);
    let first_handles = vec![first_lhs.binding(), first_rhs.binding(), first_out.binding()];

    let second_shapes = vec![vec![1, 4], vec![1, 4], vec![1, 4]];
    let second_lhs = client.create(&[0, 1, 2, 3]);
    let second_rhs = client.create(&[5, 6, 7, 8]);
    let second_out = client.empty(4);
    let second_handles = vec![
        second_lhs.binding(),
        second_rhs.binding(),
        second_out.clone().binding(),
    ];

    let first_set =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), first_shapes, first_handles);
    let second_set =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), second_shapes, second_handles);
    client.autotune_execute(Box::new(first_set));
    client.autotune_execute(Box::new(second_set));

    let obtained_resource = client.read(second_out.binding());

    // Cache should be hit, so CacheTestFastOn3 should be used, returning lhs
    assert_eq!(obtained_resource, Vec::from([0, 1, 2, 3]));
}

#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_no_cache_on_disk_return_a_cache_miss() {
    // delete the cache file
    let file_path = burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_PREFIX);
    let _ = std::fs::remove_file(file_path);

    type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
    let runtime = Runtime::new();
    let client = runtime.client(&DummyDevice, dummy::init_client);

    // Shapes [1,3] and [1,5] produce distinct keys ('cache_test-1,4' and
    // 'cache_test-1,8'), so the second execution cannot reuse the first entry.
    let first_shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let first_lhs = client.create(&[0, 1, 2]);
    let first_rhs = client.create(&[4, 4, 4]);
    let first_out = client.empty(3);
    let first_handles = vec![first_lhs.binding(), first_rhs.binding(), first_out.binding()];

    let second_shapes = vec![vec![1, 5], vec![1, 5], vec![1, 5]];
    let second_lhs = client.create(&[0, 1, 2, 3, 4]);
    let second_rhs = client.create(&[5, 6, 7, 8, 9]);
    let second_out = client.empty(5);
    let second_handles = vec![
        second_lhs.binding(),
        second_rhs.binding(),
        second_out.clone().binding(),
    ];

    let first_set =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), first_shapes, first_handles);
    let second_set =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), second_shapes, second_handles);
    client.autotune_execute(Box::new(first_set));
    client.autotune_execute(Box::new(second_set));

    // read the resource which should update the cache on disk
    let obtained_resource = client.read(second_out.binding());

    // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs
    assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8, 9]));
}
|
||||
|
||||
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() {
    use burn_common::sync_type::SyncType;

    // Remove the cache file's whole parent directory so saving the tune
    // results must recreate the path from scratch.
    let file_path = burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_PREFIX);
    let parent_dir = file_path
        .parent()
        .expect("Cache file should have a parent directory");
    let _ = std::fs::remove_dir_all(parent_dir);

    type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
    let runtime = Runtime::new();
    let client = runtime.client(&DummyDevice, dummy::init_client);

    let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let lhs = client.create(&[0, 1, 2]);
    let rhs = client.create(&[4, 4, 4]);
    let out = client.empty(3);
    let handles = vec![lhs.binding(), rhs.binding(), out.clone().binding()];

    let operation_set =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes, handles);
    client.autotune_execute(Box::new(operation_set));
    // ensure that the autotune operations are run and cached
    client.sync(SyncType::Wait);

    assert!(
        parent_dir.exists(),
        "Parent directory of the cache file should exist"
    );
    assert!(file_path.exists(), "Cache file should exist");
}
|
||||
|
||||
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_different_keys_return_a_cache_miss() {
    let client = client(&DummyDevice);

    // Shapes [1,3] and [1,5] round up to distinct keys ('cache_test-1,4' and
    // 'cache_test-1,8'), so the second set must be autotuned from scratch.
    let first_shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let first_lhs = client.create(&[0, 1, 2]);
    let first_rhs = client.create(&[4, 4, 4]);
    let first_out = client.empty(3);
    let first_handles = vec![first_lhs.binding(), first_rhs.binding(), first_out.binding()];

    let second_shapes = vec![vec![1, 5], vec![1, 5], vec![1, 5]];
    let second_lhs = client.create(&[0, 1, 2, 3, 4]);
    let second_rhs = client.create(&[5, 6, 7, 8, 9]);
    let second_out = client.empty(5);
    let second_handles = vec![
        second_lhs.binding(),
        second_rhs.binding(),
        second_out.clone().binding(),
    ];

    let first_set =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), first_shapes, first_handles);
    let second_set =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), second_shapes, second_handles);
    client.autotune_execute(Box::new(first_set));
    client.autotune_execute(Box::new(second_set));

    let obtained_resource = client.read(second_out.binding());

    // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs
    assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8, 9]));
}

#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_different_checksums_return_a_cache_miss() {
    use burn_common::sync_type::SyncType;

    type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
    let runtime = Runtime::new();
    let client = runtime.client(&DummyDevice, dummy::init_client);

    // Warm the persistent cache under key 'cache_test-1,4' (shapes [1,3] and
    // [1,4] round to the same key).
    let first_shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let first_lhs = client.create(&[0, 1, 2]);
    let first_rhs = client.create(&[4, 4, 4]);
    let first_out = client.empty(3);
    let first_handles = vec![first_lhs.binding(), first_rhs.binding(), first_out.binding()];
    let first_set =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), first_shapes, first_handles);
    client.autotune_execute(Box::new(first_set));
    client.sync(SyncType::Wait);

    // we use a second compute client in order to have freshly initialized autotune cache
    // and test invalidation of the cache when the checksum of the operation set is
    // different
    let runtime = Runtime::new();
    let client = runtime.client(&DummyDevice, dummy::init_client);

    let second_shapes = vec![vec![1, 4], vec![1, 4], vec![1, 4]];
    let second_lhs = client.create(&[0, 1, 2, 3]);
    let second_rhs = client.create(&[5, 6, 7, 8]);
    let second_out = client.empty(4);
    let second_handles = vec![
        second_lhs.binding(),
        second_rhs.binding(),
        second_out.clone().binding(),
    ];

    let mut second_set =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), second_shapes, second_handles);
    second_set.generate_random_checksum = true;
    client.autotune_execute(Box::new(second_set));
    client.sync(SyncType::Wait);

    let obtained_resource = client.read(second_out.binding());

    // Cache should be missed because the checksum on 4 is generated randomly
    // and thus is always different,
    // so CacheTestSlowOn3 (but faster on 4) should be used, returning rhs
    assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8]));
}
|
|
@ -1,27 +0,0 @@
|
|||
[package]
|
||||
authors = [
|
||||
"nathanielsimard <nathaniel.simard.42@gmail.com>",
|
||||
"louisfd <louisfd94@gmail.com>",
|
||||
]
|
||||
categories = ["science"]
|
||||
description = "TODO"
|
||||
edition.workspace = true
|
||||
keywords = []
|
||||
license.workspace = true
|
||||
name = "burn-cube-macros"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/burn-cube-macros"
|
||||
version.workspace = true
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[features]
|
||||
default = []
|
||||
std = []
|
||||
|
||||
[dependencies]
|
||||
proc-macro2 = { workspace = true }
|
||||
quote = { workspace = true }
|
||||
syn = { workspace = true }
|
||||
derive-new = { workspace = true }
|
|
@ -1,280 +0,0 @@
|
|||
use syn::{Member, Pat, PathArguments, Stmt};
|
||||
|
||||
use crate::tracker::VariableTracker;
|
||||
|
||||
/// Built-in cube index identifiers; these names are reserved and must not be
/// treated as user-declared variables by the analyzer.
pub const KEYWORDS: [&str; 20] = [
    "ABSOLUTE_POS",
    "ABSOLUTE_POS_X",
    "ABSOLUTE_POS_Y",
    "ABSOLUTE_POS_Z",
    "UNIT_POS",
    "UNIT_POS_X",
    "UNIT_POS_Y",
    "UNIT_POS_Z",
    "CUBE_POS",
    "CUBE_POS_X",
    "CUBE_POS_Y",
    "CUBE_POS_Z",
    "CUBE_DIM",
    "CUBE_DIM_X",
    "CUBE_DIM_Y",
    "CUBE_DIM_Z",
    "CUBE_COUNT",
    "CUBE_COUNT_X",
    "CUBE_COUNT_Y",
    "CUBE_COUNT_Z",
];
|
||||
|
||||
#[derive(Debug, Default)]
/// Reads the whole Cube code and accumulates information,
/// to generate a VariableTracker that records variable uses ahead of time.
pub(crate) struct VariableAnalyzer {
    // Accumulates declarations and uses discovered during the walk.
    variable_tracker: VariableTracker,
}
|
||||
|
||||
impl VariableAnalyzer {
|
||||
pub fn create_tracker(func: &syn::ItemFn) -> VariableTracker {
|
||||
let analyzer = VariableAnalyzer::default();
|
||||
analyzer.analyze(func)
|
||||
}
|
||||
}
|
||||
|
||||
impl VariableAnalyzer {
|
||||
/// Consumes the analyzer: records the signature's declarations, walks the
/// function body, and returns the populated tracker.
fn analyze(mut self, func: &syn::ItemFn) -> VariableTracker {
    // Build the vector of (Id, depth), using recursion
    self.signature_declarations(&func.sig);
    self.find_occurrences_in_stmts(&func.block.stmts, 0);

    self.variable_tracker
}
|
||||
|
||||
/// Registers every typed function parameter as a declaration at depth 0.
fn signature_declarations(&mut self, sig: &syn::Signature) {
    for input in &sig.inputs {
        match input {
            syn::FnArg::Typed(pat) => {
                let ident = &*pat.pat;
                // A comptime-typed parameter is tracked separately from
                // runtime variables.
                let is_comptime = is_ty_comptime(&pat.ty);

                match ident {
                    syn::Pat::Ident(pat_ident) => {
                        let id = &pat_ident.ident;
                        self.variable_tracker
                            .analyze_declare(id.to_string(), 0, is_comptime);
                    }
                    // Patterns other than plain identifiers are not supported.
                    _ => todo!("Analysis: unsupported ident {ident:?}"),
                }
            }
            // e.g. a `self` receiver is not supported in cube functions.
            _ => todo!("Analysis: unsupported input {input:?}"),
        }
    }
}
|
||||
|
||||
/// Walks a statement list, declaring `let`-bound locals and recursing into
/// initializer and trailing expressions; `depth` is the current lexical
/// nesting level.
fn find_occurrences_in_stmts(&mut self, stmts: &Vec<Stmt>, depth: u8) {
    for stmt in stmts {
        match stmt {
            // Declaration
            syn::Stmt::Local(local) => {
                let mut is_comptime = false;
                // Extract the bound identifier (if any) and, for typed
                // patterns, whether the explicit type marks it as comptime.
                let id = match &local.pat {
                    syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident),
                    syn::Pat::Type(pat_type) => {
                        is_comptime = is_ty_comptime(&pat_type.ty);
                        match &*pat_type.pat {
                            syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident),
                            _ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat),
                        }
                    }
                    // `let _ = ...` declares nothing.
                    syn::Pat::Wild(_) => None,
                    _ => todo!("Analysis: unsupported path {:?}", local.pat),
                };
                if let Some(id) = id {
                    self.variable_tracker
                        .analyze_declare(id.to_string(), depth, is_comptime);
                }
                // The initializer may itself use or declare variables.
                if let Some(local_init) = &local.init {
                    self.find_occurrences_in_expr(&local_init.expr, depth)
                }
            }
            syn::Stmt::Expr(expr, _) => self.find_occurrences_in_expr(expr, depth),
            _ => todo!("Analysis: unsupported stmt {stmt:?}"),
        }
    }
}
|
||||
|
||||
fn find_occurrences_in_expr(&mut self, expr: &syn::Expr, depth: u8) {
|
||||
match expr {
|
||||
syn::Expr::ForLoop(expr) => {
|
||||
self.find_occurrences_in_expr(&expr.expr, depth);
|
||||
|
||||
let depth = depth + 1;
|
||||
|
||||
if let syn::Pat::Ident(pat_ident) = &*expr.pat {
|
||||
let id = &pat_ident.ident;
|
||||
self.variable_tracker
|
||||
.analyze_declare(id.to_string(), depth, false);
|
||||
}
|
||||
|
||||
self.find_occurrences_in_stmts(&expr.body.stmts, depth);
|
||||
}
|
||||
syn::Expr::While(expr) => {
|
||||
let depth = depth + 1;
|
||||
|
||||
self.find_occurrences_in_expr(&expr.cond, depth);
|
||||
self.find_occurrences_in_stmts(&expr.body.stmts, depth);
|
||||
}
|
||||
syn::Expr::Loop(expr) => {
|
||||
let depth = depth + 1;
|
||||
|
||||
self.find_occurrences_in_stmts(&expr.body.stmts, depth);
|
||||
}
|
||||
syn::Expr::If(expr) => {
|
||||
let depth = depth + 1;
|
||||
|
||||
self.find_occurrences_in_expr(&expr.cond, depth);
|
||||
self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth);
|
||||
if let Some((_, expr)) = &expr.else_branch {
|
||||
match &**expr {
|
||||
syn::Expr::Block(expr_block) => {
|
||||
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
|
||||
}
|
||||
syn::Expr::If(expr) => {
|
||||
self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::Expr::Assign(expr) => {
|
||||
self.find_occurrences_in_expr(&expr.left, depth);
|
||||
self.find_occurrences_in_expr(&expr.right, depth);
|
||||
}
|
||||
syn::Expr::Index(expr) => {
|
||||
self.find_occurrences_in_expr(&expr.expr, depth);
|
||||
self.find_occurrences_in_expr(&expr.index, depth);
|
||||
}
|
||||
syn::Expr::Path(expr) => {
|
||||
if let Some(ident) = expr.path.get_ident() {
|
||||
if !KEYWORDS.contains(&ident.to_string().as_str()) {
|
||||
self.variable_tracker.analyze_reuse(ident, depth, None);
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::Expr::Binary(expr) => {
|
||||
self.find_occurrences_in_expr(&expr.left, depth);
|
||||
self.find_occurrences_in_expr(&expr.right, depth);
|
||||
}
|
||||
syn::Expr::Lit(_) => {}
|
||||
syn::Expr::Call(expr) => {
|
||||
match &*expr.func {
|
||||
syn::Expr::Path(expr_path) => {
|
||||
if let Some(first_segment) = expr_path.path.segments.first() {
|
||||
// Check if the path segment has generic arguments
|
||||
if let PathArguments::AngleBracketed(arguments) =
|
||||
&first_segment.arguments
|
||||
{
|
||||
// Extract the generic arguments
|
||||
for arg in &arguments.args {
|
||||
match arg {
|
||||
syn::GenericArgument::Type(_)
|
||||
| syn::GenericArgument::Constraint(_) => {}
|
||||
_ => todo!("Analysis: Generic {:?} not supported", arg),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => todo!("Analysis: unsupported func expr {:?}", expr.func),
|
||||
}
|
||||
for arg in expr.args.iter() {
|
||||
self.find_occurrences_in_expr(arg, depth);
|
||||
}
|
||||
}
|
||||
syn::Expr::MethodCall(expr) => {
|
||||
self.find_occurrences_in_expr(&expr.receiver, depth);
|
||||
for arg in expr.args.iter() {
|
||||
self.find_occurrences_in_expr(arg, depth);
|
||||
}
|
||||
}
|
||||
syn::Expr::Break(_) => {}
|
||||
syn::Expr::Return(expr) => {
|
||||
if expr.expr.is_some() {
|
||||
// Unsupported: handled in codegen.
|
||||
}
|
||||
}
|
||||
syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
||||
syn::Expr::Array(_expr) => {
|
||||
// No analysis since only literals are supported
|
||||
}
|
||||
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
||||
syn::Expr::Closure(expr) => {
|
||||
let depth = depth + 1;
|
||||
|
||||
for path in expr.inputs.iter() {
|
||||
let mut is_comptime = false;
|
||||
let ident = match path {
|
||||
Pat::Ident(pat_ident) => &pat_ident.ident,
|
||||
Pat::Type(pat_type) => {
|
||||
is_comptime = is_ty_comptime(&pat_type.ty);
|
||||
|
||||
if let Pat::Ident(pat_ident) = &*pat_type.pat {
|
||||
&pat_ident.ident
|
||||
} else {
|
||||
todo!("Analysis: {:?} not supported in closure inputs. ", path);
|
||||
}
|
||||
}
|
||||
_ => todo!("Analysis: {:?} not supported in closure inputs. ", path),
|
||||
};
|
||||
|
||||
self.variable_tracker
|
||||
.analyze_declare(ident.to_string(), depth, is_comptime);
|
||||
}
|
||||
|
||||
self.find_occurrences_in_expr(&expr.body, depth)
|
||||
}
|
||||
syn::Expr::Unary(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
||||
syn::Expr::Field(expr) => {
|
||||
if let Member::Named(attribute_ident) = &expr.member {
|
||||
if let syn::Expr::Path(struct_expr) = &*expr.base {
|
||||
let struct_ident = struct_expr
|
||||
.path
|
||||
.get_ident()
|
||||
.expect("Analysis: field access only supported on ident struct.");
|
||||
|
||||
self.variable_tracker.analyze_reuse(
|
||||
struct_ident,
|
||||
depth,
|
||||
Some(attribute_ident.to_string()),
|
||||
);
|
||||
} else {
|
||||
todo!("Analysis: field access only supported on ident struct.");
|
||||
}
|
||||
} else {
|
||||
todo!("Analysis: unnamed attribute not supported.");
|
||||
}
|
||||
}
|
||||
syn::Expr::Struct(expr) => {
|
||||
for field in expr.fields.iter() {
|
||||
self.find_occurrences_in_expr(&field.expr, depth)
|
||||
}
|
||||
}
|
||||
syn::Expr::Range(_range) => {
|
||||
// Error is handled during codegen.
|
||||
}
|
||||
_ => {
|
||||
// Error is handled during codegen.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether a type annotation refers to `Comptime` in any of its path segments.
fn is_ty_comptime(ty: &syn::Type) -> bool {
    match ty {
        syn::Type::Path(type_path) => type_path
            .path
            .segments
            .iter()
            .any(|segment| segment.ident == "Comptime"),
        _ => false,
    }
}
|
|
@ -1 +0,0 @@
|
|||
pub(crate) mod signature;
|
|
@ -1,70 +0,0 @@
|
|||
use quote::ToTokens;
|
||||
|
||||
use crate::tracker::VariableTracker;
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
/// Controls how `expand_sig` names the generated expansion function:
/// plain `__expand` for a function impl, `__expand_<name>` otherwise.
pub enum ExpandMode {
    // Expansion emitted as a `__expand` associated function.
    FuncImpl,
    // Expansion emitted as a `__expand_<method>` item.
    MethodImpl,
}
|
||||
|
||||
pub fn expand_sig(
|
||||
sig: &syn::Signature,
|
||||
visibility: &syn::Visibility,
|
||||
mut variable_tracker: Option<&mut VariableTracker>,
|
||||
mode: ExpandMode,
|
||||
) -> proc_macro2::TokenStream {
|
||||
let mut inputs = quote::quote!();
|
||||
|
||||
for input in &sig.inputs {
|
||||
match input {
|
||||
syn::FnArg::Typed(pat) => {
|
||||
let ident = pat.pat.clone();
|
||||
|
||||
if let syn::Pat::Ident(ident) = ident.as_ref() {
|
||||
if let Some(vars) = &mut variable_tracker {
|
||||
vars.codegen_declare(ident.ident.to_string(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
let ty = no_ref(pat.ty.as_ref());
|
||||
inputs.extend(quote::quote! {
|
||||
#ident: <#ty as burn_cube::frontend::CubeType>::ExpandType,
|
||||
});
|
||||
}
|
||||
_ => todo!("Only Typed inputs are supported"),
|
||||
}
|
||||
}
|
||||
|
||||
let mut output = quote::quote!();
|
||||
|
||||
match &sig.output {
|
||||
syn::ReturnType::Default => output.extend(quote::quote! { ()}),
|
||||
syn::ReturnType::Type(_, ty) => {
|
||||
let ty = no_ref(ty.as_ref());
|
||||
output.extend(quote::quote! {
|
||||
<#ty as burn_cube::frontend::CubeType>::ExpandType
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let ident = &sig.ident;
|
||||
let ident = match mode {
|
||||
ExpandMode::FuncImpl => syn::Ident::new("__expand".to_string().as_str(), ident.span()),
|
||||
_ => syn::Ident::new(format!("__expand_{ident}").as_str(), ident.span()),
|
||||
};
|
||||
|
||||
let generics = sig.generics.clone().into_token_stream();
|
||||
|
||||
quote::quote! {
|
||||
/// Expanded Cube function
|
||||
#visibility fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output
|
||||
}
|
||||
}
|
||||
|
||||
/// Strips one level of reference from a type, returning the referent type
/// (or the type unchanged when it is not a reference).
pub fn no_ref(ty: &syn::Type) -> &syn::Type {
    if let syn::Type::Reference(reference) = ty {
        &reference.elem
    } else {
        ty
    }
}
|
|
@ -1,94 +0,0 @@
|
|||
use proc_macro2::TokenStream;
|
||||
use quote::ToTokens;
|
||||
|
||||
use super::{expr::codegen_expr, variable::codegen_local};
|
||||
use crate::tracker::VariableTracker;
|
||||
|
||||
/// Codegen for a statement (generally one line)
|
||||
/// Entry point of code generation
|
||||
pub fn codegen_statement(
|
||||
statement: &syn::Stmt,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
match statement {
|
||||
syn::Stmt::Local(local) => codegen_local(local, loop_level, variable_tracker),
|
||||
syn::Stmt::Expr(expr, semi) => {
|
||||
let expr = codegen_expr(expr, loop_level, variable_tracker).tokens;
|
||||
|
||||
match semi {
|
||||
Some(_semi) => quote::quote!(
|
||||
#expr;
|
||||
),
|
||||
None => expr,
|
||||
}
|
||||
}
|
||||
_ => todo!("Codegen: statement {statement:?} not supported"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Codegen for a code block (a list of statements)
|
||||
pub(crate) fn codegen_block(
|
||||
block: &syn::Block,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
let mut statements = quote::quote!();
|
||||
|
||||
for statement in block.stmts.iter() {
|
||||
statements.extend(codegen_statement(statement, loop_level, variable_tracker));
|
||||
}
|
||||
|
||||
quote::quote! {
|
||||
{
|
||||
#statements
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of generating code for one expression.
pub(crate) struct Codegen {
    // The generated token stream for the expression.
    pub tokens: proc_macro2::TokenStream,
    // True when the expression is a comptime construct (set by `codegen_call`
    // for `Comptime::...` calls).
    pub is_comptime: bool,
    // Set when the expression is an `array[index]` access; presumably consumed
    // when the index appears on the left of an assignment — TODO confirm.
    pub array_indexing: Option<ArrayIndexing>,
}
|
||||
|
||||
/// The two halves of an `array[index]` expression, kept as separate token streams.
pub(crate) struct ArrayIndexing {
    pub array: proc_macro2::TokenStream,
    pub index: proc_macro2::TokenStream,
}
|
||||
|
||||
impl From<proc_macro2::TokenStream> for Codegen {
|
||||
fn from(tokens: proc_macro2::TokenStream) -> Self {
|
||||
Self {
|
||||
tokens,
|
||||
is_comptime: false,
|
||||
array_indexing: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Codegen {
|
||||
pub fn new<S: Into<proc_macro2::TokenStream>>(tokens: S, is_comptime: bool) -> Self {
|
||||
Self {
|
||||
tokens: tokens.into(),
|
||||
is_comptime,
|
||||
array_indexing: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn split(self) -> (proc_macro2::TokenStream, bool) {
|
||||
(self.tokens, self.is_comptime)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToTokens for Codegen {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
tokens.extend(self.tokens.clone());
|
||||
}
|
||||
fn into_token_stream(self) -> TokenStream
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.tokens
|
||||
}
|
||||
}
|
|
@ -1,193 +0,0 @@
|
|||
use proc_macro2::TokenStream;
|
||||
|
||||
use crate::{codegen_function::expr::codegen_expr, tracker::VariableTracker};
|
||||
|
||||
use super::{
|
||||
base::{codegen_block, Codegen},
|
||||
function::codegen_call,
|
||||
operation::codegen_binary,
|
||||
variable::{codegen_lit, codegen_path_var},
|
||||
};
|
||||
|
||||
/// Codegen of for loops
/// Supports range:
/// for i in range(start, end, unroll) {...}
pub(crate) fn codegen_for_loop(
    for_loop: &syn::ExprForLoop,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    let i = &for_loop.pat;

    // The loop variable lives one level deeper than the surrounding code.
    if let syn::Pat::Ident(pat_ident) = &*for_loop.pat {
        let id = &pat_ident.ident;
        variable_tracker.codegen_declare(id.to_string(), loop_level as u8 + 1);
    }

    // Error emitted for any iterator expression other than `range(...)`.
    let invalid_for_loop = || {
        syn::Error::new_spanned(
            &for_loop.expr,
            "Invalid for loop: use [range](cubecl::prelude::range] instead.",
        )
        .into_compile_error()
    };

    match for_loop.expr.as_ref() {
        syn::Expr::Call(call) => {
            let func_name = match call.func.as_ref() {
                syn::Expr::Path(path) => match path.path.get_ident() {
                    Some(ident) => ident,
                    None => return invalid_for_loop(),
                },
                _ => {
                    return invalid_for_loop();
                }
            };

            if &func_name.to_string() == "range" {
                let mut args = call.args.clone();

                // Arguments are popped from the BACK of `range(start, end,
                // unroll)`: unroll first, then end, then start.
                let unroll = codegen_expr(
                    &args.pop().unwrap().into_value(),
                    loop_level,
                    variable_tracker,
                );
                let end = codegen_expr(
                    &args.pop().unwrap().into_value(),
                    loop_level,
                    variable_tracker,
                );
                let start = codegen_expr(
                    &args.pop().unwrap().into_value(),
                    loop_level,
                    variable_tracker,
                );

                // Body is one loop level deeper.
                let block = codegen_block(&for_loop.body, loop_level + 1, variable_tracker);

                quote::quote! {
                    {
                        let _start = #start;
                        let _end = #end;
                        let _unroll = #unroll;
                        burn_cube::frontend::branch::range_expand(context, _start, _end, _unroll, |context, #i| #block);
                    }
                }
            } else {
                invalid_for_loop()
            }
        }
        _ => invalid_for_loop(),
    }
}
|
||||
|
||||
/// Codegen for condition of an if or a while
|
||||
pub(crate) fn codegen_cond(
|
||||
cond: &syn::Expr,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> Codegen {
|
||||
match cond {
|
||||
syn::Expr::Binary(expr) => codegen_binary(expr, loop_level, variable_tracker),
|
||||
syn::Expr::Lit(expr) => Codegen::new(codegen_lit(expr), false),
|
||||
syn::Expr::Path(expr) => codegen_path_var(expr, loop_level, variable_tracker),
|
||||
syn::Expr::Call(expr) => codegen_call(expr, loop_level, variable_tracker),
|
||||
_ => todo!("{cond:?} cond not supported"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Codegen for break statement
|
||||
pub(crate) fn codegen_break() -> TokenStream {
|
||||
quote::quote! {
|
||||
burn_cube::frontend::branch::break_expand(context);
|
||||
}
|
||||
}
|
||||
|
||||
/// Codegen for return statement
|
||||
pub(crate) fn codegen_return(expr_return: &syn::ExprReturn) -> TokenStream {
|
||||
if expr_return.expr.is_some() {
|
||||
return syn::Error::new_spanned(expr_return, "Only void return is supported.")
|
||||
.into_compile_error();
|
||||
}
|
||||
|
||||
quote::quote! {
|
||||
burn_cube::frontend::branch::return_expand(context);
|
||||
}
|
||||
}
|
||||
|
||||
/// Codegen for if and if/else statements
|
||||
/// Supports:
|
||||
/// if cond {...}
|
||||
/// if cond {...} else {...}
|
||||
/// if Comptime::get(...) {...} [else {...}]
|
||||
/// if Comptime::get(...) {...} [else if Comptime::get(...) {...}]* [else {...}]
|
||||
pub(crate) fn codegen_if(
|
||||
expr_if: &syn::ExprIf,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
let (cond, is_comptime) = codegen_cond(&expr_if.cond, loop_level, variable_tracker).split();
|
||||
let comptime_bool = if is_comptime {
|
||||
quote::quote! { Some(#cond) }
|
||||
} else {
|
||||
quote::quote! { None }
|
||||
};
|
||||
|
||||
let then_block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_tracker);
|
||||
|
||||
if let Some((_, expr)) = &expr_if.else_branch {
|
||||
let else_block = match &**expr {
|
||||
syn::Expr::Block(expr_block) => {
|
||||
codegen_block(&expr_block.block, loop_level + 1, variable_tracker)
|
||||
}
|
||||
|
||||
syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level + 1, variable_tracker),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
quote::quote! {
|
||||
{
|
||||
let _cond = #cond;
|
||||
burn_cube::frontend::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote::quote! {
|
||||
let _cond = #cond;
|
||||
burn_cube::frontend::branch::if_expand(context, #comptime_bool, _cond.into(), |context| #then_block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Codegen of loop
|
||||
pub(crate) fn codegen_loop(
|
||||
loop_expr: &syn::ExprLoop,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
let block = codegen_block(&loop_expr.body, loop_level + 1, variable_tracker);
|
||||
|
||||
quote::quote! {
|
||||
burn_cube::frontend::branch::loop_expand(context, |context| #block);
|
||||
}
|
||||
}
|
||||
|
||||
/// Codegen for while loop
|
||||
pub(crate) fn codegen_while_loop(
|
||||
while_loop: &syn::ExprWhile,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
let (cond, is_comptime) =
|
||||
codegen_cond(&while_loop.cond, loop_level + 1, variable_tracker).split();
|
||||
|
||||
if is_comptime {
|
||||
return syn::Error::new_spanned(while_loop.while_token, "Comptime not supported for while")
|
||||
.into_compile_error();
|
||||
}
|
||||
|
||||
let block = codegen_block(&while_loop.body, loop_level + 1, variable_tracker);
|
||||
|
||||
quote::quote! {
|
||||
burn_cube::frontend::branch::while_loop_expand(context, |context| #cond, |context| #block);
|
||||
}
|
||||
}
|
|
@ -1,99 +0,0 @@
|
|||
use crate::tracker::VariableTracker;
|
||||
use proc_macro2::TokenStream;
|
||||
|
||||
use super::{
|
||||
base::{codegen_block, Codegen},
|
||||
branch::{
|
||||
codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_return,
|
||||
codegen_while_loop,
|
||||
},
|
||||
function::{codegen_call, codegen_closure, codegen_expr_method_call},
|
||||
operation::{codegen_binary, codegen_unary},
|
||||
variable::{
|
||||
codegen_array_lit, codegen_assign, codegen_field, codegen_index, codegen_lit,
|
||||
codegen_path_var, codegen_struct,
|
||||
},
|
||||
};
|
||||
|
||||
/// Codegen for expressions
/// `Call` and `Paren` are handled first because they return a full
/// [`Codegen`] (carrying the comptime flag) directly; every other kind
/// yields plain tokens tagged as non-comptime.
pub(crate) fn codegen_expr(
    expr: &syn::Expr,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> Codegen {
    match expr {
        syn::Expr::Call(call) => codegen_call(call, loop_level, variable_tracker),
        syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_tracker),
        _ => {
            let mut array_indexing = None;
            let tokens = match expr {
                // Path, Binary and Unary also return a full `Codegen` directly.
                syn::Expr::Path(path) => {
                    return codegen_path_var(path, loop_level, variable_tracker)
                }
                syn::Expr::Binary(op) => return codegen_binary(op, loop_level, variable_tracker),
                syn::Expr::Unary(op) => return codegen_unary(op, loop_level, variable_tracker),
                syn::Expr::Lit(lit) => codegen_lit(lit),
                syn::Expr::Closure(closure) => {
                    codegen_closure(closure, loop_level, variable_tracker)
                }
                syn::Expr::Block(block) => codegen_expr_block(block, loop_level, variable_tracker),
                syn::Expr::Assign(assign) => codegen_assign(assign, loop_level, variable_tracker),
                syn::Expr::ForLoop(for_loop) => {
                    codegen_for_loop(for_loop, loop_level, variable_tracker)
                }
                syn::Expr::While(while_loop) => {
                    codegen_while_loop(while_loop, loop_level, variable_tracker)
                }
                syn::Expr::Loop(loop_expr) => codegen_loop(loop_expr, loop_level, variable_tracker),
                syn::Expr::Break(_) => codegen_break(),
                syn::Expr::Return(return_expr) => codegen_return(return_expr),
                syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_tracker),
                syn::Expr::MethodCall(call) => {
                    codegen_expr_method_call(call, loop_level, variable_tracker)
                }
                // Indexing: keep the (array, index) pair on the result —
                // presumably consumed when the index appears on the left of
                // an assignment; TODO confirm against `codegen_assign`.
                syn::Expr::Index(index) => {
                    let codegen = codegen_index(index, loop_level, variable_tracker);
                    array_indexing = codegen.array_indexing;
                    codegen.tokens
                }
                syn::Expr::Array(array) => codegen_array_lit(array),
                syn::Expr::Reference(reference) => {
                    codegen_ref(reference, loop_level, variable_tracker)
                }
                syn::Expr::Field(field) => codegen_field(field, loop_level, variable_tracker),
                syn::Expr::Struct(struct_) => codegen_struct(struct_, loop_level, variable_tracker),
                syn::Expr::Range(range) => syn::Error::new_spanned(
                    range,
                    "Range is not supported, use [range](cubecl::prelude::range) instead.",
                )
                .to_compile_error(),
                _ => {
                    syn::Error::new_spanned(expr, "Expression is not supported").to_compile_error()
                }
            };

            let mut codegen = Codegen::new(tokens, false);
            codegen.array_indexing = array_indexing;
            codegen
        }
    }
}
|
||||
|
||||
/// Codegen for an expression containing a block
|
||||
pub(crate) fn codegen_expr_block(
|
||||
block: &syn::ExprBlock,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
codegen_block(&block.block, loop_level, variable_tracker)
|
||||
}
|
||||
|
||||
pub(crate) fn codegen_ref(
|
||||
reference: &syn::ExprReference,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
// We ignore reference for the expansion.
|
||||
let inner = codegen_expr(&reference.expr, loop_level, variable_tracker);
|
||||
quote::quote! { #inner }
|
||||
}
|
|
@ -1,250 +0,0 @@
|
|||
use proc_macro2::{Span, TokenStream};
|
||||
use quote::quote_spanned;
|
||||
use syn::{
|
||||
punctuated::Punctuated, spanned::Spanned, AngleBracketedGenericArguments, Expr, Ident,
|
||||
PathArguments, Token,
|
||||
};
|
||||
|
||||
use crate::{codegen_function::expr::codegen_expr, tracker::VariableTracker};
|
||||
|
||||
use super::base::Codegen;
|
||||
|
||||
/// Codegen for method call
|
||||
/// Supports [expr].method(args)
|
||||
pub(crate) fn codegen_expr_method_call(
|
||||
call: &syn::ExprMethodCall,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
let receiver = codegen_expr(&call.receiver, loop_level, variable_tracker);
|
||||
let method_expand = syn::Ident::new(
|
||||
format!("{}_expand", call.method).as_str(),
|
||||
proc_macro2::Span::call_site(),
|
||||
);
|
||||
let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);
|
||||
|
||||
quote::quote!( {
|
||||
#expansion
|
||||
#receiver . #method_expand ( #variables )
|
||||
})
|
||||
}
|
||||
|
||||
/// Codegen for a closure
|
||||
pub(crate) fn codegen_closure(
|
||||
closure: &syn::ExprClosure,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
let mut inputs = quote::quote! {};
|
||||
for input in closure.inputs.iter() {
|
||||
let (ident, ty) = match input {
|
||||
syn::Pat::Ident(ident) => (&ident.ident, None),
|
||||
syn::Pat::Type(pat_type) => (
|
||||
if let syn::Pat::Ident(ident) = &*pat_type.pat {
|
||||
&ident.ident
|
||||
} else {
|
||||
return syn::Error::new_spanned(pat_type, "Unsupported input")
|
||||
.into_compile_error();
|
||||
},
|
||||
Some(pat_type.ty.clone()),
|
||||
),
|
||||
_ => return syn::Error::new_spanned(input, "Unsupported input").into_compile_error(),
|
||||
};
|
||||
|
||||
if let Some(ty) = ty {
|
||||
inputs.extend(quote::quote! {
|
||||
#ident : #ty,
|
||||
});
|
||||
} else {
|
||||
inputs.extend(quote::quote! {
|
||||
#ident,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let body = codegen_expr(closure.body.as_ref(), loop_level, variable_tracker);
|
||||
|
||||
quote::quote! {
|
||||
|context: &mut CubeContext, #inputs| #body
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps
/// [A[::<...>]?::]^* func[::<...>] (args)
/// to
/// [A[::<...>]?::]^* func_expand[::<...>] (context, args)
///
/// Also returns a bool that is true if it's comptime
pub(crate) fn codegen_call(
    call: &syn::ExprCall,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> Codegen {
    // We start with parsing the function path
    // into (segment ident, optional generic arguments) pairs.
    let path: Vec<(&Ident, Option<&AngleBracketedGenericArguments>)> = match call.func.as_ref() {
        syn::Expr::Path(expr_path) => {
            let mut path = Vec::new();
            for segment in expr_path.path.segments.iter() {
                let generics = if let PathArguments::AngleBracketed(arguments) = &segment.arguments
                {
                    Some(arguments)
                } else {
                    None
                };
                path.push((&segment.ident, generics));
            }
            path
        }
        _ => {
            return Codegen::new(
                syn::Error::new_spanned(&call.func, "Unsupported").into_compile_error(),
                false,
            )
        }
    };

    // Path
    let mut path_tokens = TokenStream::new();
    let mut is_comptime = false;
    let mut is_plain_func = true;
    let mut comptime_func: Option<String> = None;

    for (i, (ident, generics)) in path.iter().enumerate() {
        let name = ident.to_string();

        // `Comptime::...` calls are intercepted and expanded specially below;
        // the segment itself is dropped from the rebuilt path.
        if name == "Comptime" {
            is_comptime = true;
            continue;
        }

        // An uppercase-leading segment means the call goes through a type
        // (associated function), which uses the `__expand_<name>` convention
        // instead of `<name>::__expand`.
        if let Some(first_char) = name.chars().next() {
            if first_char.is_uppercase() {
                is_plain_func = false;
            }
        }

        // Last segment: emit the expand function name; earlier segments are
        // re-emitted verbatim with `::` separators.
        if i == path.len() - 1 {
            if is_comptime {
                comptime_func = Some(ident.to_string());
                break;
            }

            let func_name_expand = if is_plain_func {
                quote::quote! {
                    #ident::__expand
                }
            } else {
                let ident = syn::Ident::new(
                    format!("__expand_{ident}").as_str(),
                    proc_macro2::Span::call_site(),
                );
                quote::quote! {
                    #ident
                }
            };
            path_tokens.extend(quote_spanned! {func_name_expand.span() => #func_name_expand });
            if let Some(generics) = generics {
                path_tokens.extend(quote_spanned! {generics.span() => #generics });
            }
        } else if let Some(generics) = generics {
            path_tokens.extend(quote_spanned! {ident.span() => #ident });
            path_tokens.extend(quote_spanned! {generics.span() => #generics :: });
        } else {
            path_tokens.extend(quote_spanned! {ident.span() => #ident :: });
        }
    }

    // Arguments
    if let Some(func_name) = comptime_func {
        // Each supported `Comptime` function has a dedicated expansion.
        let tokens = match func_name.as_str() {
            // `get`/`new`: the argument is passed through unchanged.
            "get" | "new" => {
                let code = call.args.first().unwrap();
                quote::quote! {#code}
            }
            // `map`: arguments forwarded verbatim, without per-arg bindings.
            "map" => {
                let args = &call.args;

                // Codegen
                quote::quote! {
                    {
                        Comptime::map_expand(#args)
                    }
                }
            }
            "unwrap_or_else" => {
                let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);

                // Codegen
                quote::quote! {{
                    #expansion
                    Comptime::unwrap_or_else_expand(#variables)
                }}
            }
            "is_some" => {
                let code = call.args.first().unwrap();
                quote::quote! { #code.is_some() }
            }
            "vectorization" => {
                let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);

                // Codegen
                quote::quote! {{
                    #expansion
                    Comptime::vectorization_expand(#variables)
                }}
            }
            "vectorize" => {
                let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);

                // Codegen
                quote::quote! {{
                    #expansion
                    Comptime::vectorize_expand(#variables)
                }}
            }
            "runtime" => {
                let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);

                // Codegen
                quote::quote! {{
                    #expansion
                    Comptime::runtime_expand(#variables)
                }}
            }

            _ => panic!("Codegen: Comptime function {:?} does not exist", func_name),
        };

        Codegen::new(tokens, true)
    } else {
        // Regular call: bind each argument, then invoke the expand path.
        let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);

        // Codegen
        let tokens = quote::quote! {{
            #expansion
            #path_tokens (#variables)
        }};

        Codegen::new(tokens, false)
    }
}
|
||||
|
||||
fn codegen_args(
|
||||
args: &Punctuated<Expr, Token![,]>,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> (TokenStream, TokenStream) {
|
||||
let mut expansion = quote::quote! {};
|
||||
let mut variables = quote::quote! {};
|
||||
|
||||
variables.extend(quote::quote! { context, });
|
||||
|
||||
for (i, argument) in args.iter().enumerate() {
|
||||
let ident = Ident::new(format!("_var_{i}").as_str(), Span::call_site());
|
||||
let arg_token = codegen_expr(argument, loop_level, variable_tracker);
|
||||
expansion.extend(quote::quote! { let #ident = #arg_token; });
|
||||
variables.extend(quote::quote! { #ident, });
|
||||
}
|
||||
|
||||
(expansion, variables)
|
||||
}
|
|
@ -1,521 +0,0 @@
|
|||
use proc_macro2::{Span, TokenStream};
|
||||
use syn::{parse_quote, Generics, Ident};
|
||||
|
||||
/// Accumulated information about a kernel function signature, used to
/// generate the launch-side kernel struct and its plumbing.
#[derive(Default)]
struct Codegen {
    // Basic attributes.
    name: String,
    generics: Generics,
    fn_inputs: TokenStream,
    fn_output: TokenStream,
    // States to generate code.
    // Comptime values stored as kernel-struct fields: (type, field name).
    state_comptimes: Vec<(syn::Type, Ident)>,
    // Argument tokens forwarded when invoking the expand function.
    state_args: Vec<TokenStream>,
    // Runtime inputs (immutable args): (name, type).
    state_inputs: Vec<(Ident, syn::Type)>,
    // Runtime outputs (mutable args, by `mut` binding or `&mut` type): (name, type).
    state_outputs: Vec<(Ident, syn::Type)>,
}
|
||||
|
||||
impl Codegen {
|
||||
fn from_sig(sig: &syn::Signature) -> Self {
|
||||
let mut codegen = Codegen::default();
|
||||
|
||||
let mut first_letter = sig.ident.to_string();
|
||||
let second_part = first_letter.split_off(1);
|
||||
|
||||
codegen.name = format!("{}{}", first_letter.to_uppercase(), second_part);
|
||||
codegen.generics = sig.generics.clone();
|
||||
|
||||
let mut inputs = quote::quote!();
|
||||
|
||||
for input in &sig.inputs {
|
||||
let mut is_output = false;
|
||||
let mut comptime = false;
|
||||
|
||||
match input {
|
||||
syn::FnArg::Typed(pat) => {
|
||||
let (ty, ident) = match pat.pat.as_ref() {
|
||||
syn::Pat::Ident(ident) => {
|
||||
if ident.mutability.is_some() {
|
||||
is_output = true;
|
||||
}
|
||||
|
||||
if let syn::Type::Reference(ty) = pat.ty.as_ref() {
|
||||
if ty.mutability.is_some() {
|
||||
is_output = true;
|
||||
}
|
||||
};
|
||||
|
||||
if let syn::Type::Path(pat) = pat.ty.as_ref() {
|
||||
if let Some(name) = pat.path.segments.first() {
|
||||
let name = name.ident.to_string();
|
||||
|
||||
if name == "Comptime" {
|
||||
comptime = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(pat.ty.clone(), ident.ident.clone())
|
||||
}
|
||||
_ => panic!("Nop"),
|
||||
};
|
||||
|
||||
if comptime {
|
||||
codegen.state_args.push(quote::quote! {
|
||||
self.#ident
|
||||
});
|
||||
} else {
|
||||
codegen.state_args.push(quote::quote! {
|
||||
#ident
|
||||
});
|
||||
}
|
||||
|
||||
if comptime {
|
||||
let ty = no_ref(&ty);
|
||||
inputs.extend(quote::quote! {
|
||||
#ident: <#ty as burn_cube::frontend::CubeType>::ExpandType,
|
||||
});
|
||||
} else {
|
||||
let ty = no_ref(&ty);
|
||||
inputs.extend(quote::quote! {
|
||||
#ident: RuntimeArg<'a, #ty, R>,
|
||||
});
|
||||
}
|
||||
|
||||
if is_output {
|
||||
codegen
|
||||
.state_outputs
|
||||
.push((ident.clone(), no_ref(&ty).clone()));
|
||||
} else if comptime {
|
||||
codegen
|
||||
.state_comptimes
|
||||
.push((first_generic_ty(&ty).clone(), ident.clone()));
|
||||
} else {
|
||||
codegen
|
||||
.state_inputs
|
||||
.push((ident.clone(), no_ref(&ty).clone()));
|
||||
}
|
||||
}
|
||||
_ => panic!("Only Typed inputs are supported"),
|
||||
};
|
||||
}
|
||||
|
||||
let mut output = quote::quote!();
|
||||
|
||||
match &sig.output {
|
||||
syn::ReturnType::Default => output.extend(quote::quote! {()}),
|
||||
syn::ReturnType::Type(_, ty) => {
|
||||
output.extend(quote::quote! {
|
||||
<#ty as burn_cube::frontend::CubeType>::ExpandType
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
codegen.fn_inputs = inputs;
|
||||
codegen.fn_output = output;
|
||||
|
||||
codegen
|
||||
}
|
||||
|
||||
fn gen_kernel_struct(&self) -> TokenStream {
|
||||
let ident = Ident::new(&self.name, Span::call_site());
|
||||
let generics = add_runtime(self.generics.clone());
|
||||
let phantoms = self.phantoms(&generics, true);
|
||||
let mut comptimes = quote::quote! {};
|
||||
|
||||
for (ty, ident) in self.state_comptimes.iter() {
|
||||
comptimes.extend(quote::quote! {
|
||||
#ident: #ty,
|
||||
});
|
||||
}
|
||||
|
||||
quote::quote! {
|
||||
/// Kernel
|
||||
pub struct #ident #generics {
|
||||
settings: KernelSettings,
|
||||
#comptimes
|
||||
#phantoms
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_settings(&self) -> TokenStream {
|
||||
let mut variables = quote::quote! {};
|
||||
|
||||
for (pos, (ident, _ty)) in self.state_inputs.iter().enumerate() {
|
||||
variables.extend(quote::quote! {
|
||||
settings = ArgSettings::<R>::configure_input(&#ident, #pos, settings);
|
||||
});
|
||||
}
|
||||
|
||||
for (pos, (ident, _ty)) in self.state_outputs.iter().enumerate() {
|
||||
variables.extend(quote::quote! {
|
||||
settings = ArgSettings::<R>::configure_output(&#ident, #pos, settings);
|
||||
});
|
||||
}
|
||||
|
||||
quote::quote! {
|
||||
let mut settings = KernelSettings::default();
|
||||
settings = settings.cube_dim(cube_dim);
|
||||
#variables
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_register_input(&self) -> TokenStream {
|
||||
let generics = &self.generics;
|
||||
let mut variables = quote::quote! {};
|
||||
|
||||
for (pos, (_ident, ty)) in self.state_inputs.iter().enumerate() {
|
||||
variables.extend(quote::quote! {
|
||||
#pos => std::sync::Arc::new(<#ty as LaunchArgExpand>::expand(builder, settings.vectorization_input(#pos))),
|
||||
});
|
||||
}
|
||||
|
||||
quote::quote! {
|
||||
#[allow(unused)]
|
||||
fn register_input #generics(
|
||||
builder: &mut KernelBuilder,
|
||||
settings: &KernelSettings,
|
||||
position: usize,
|
||||
) -> std::sync::Arc<dyn core::any::Any> {
|
||||
match position {
|
||||
#variables
|
||||
_ => panic!("Input {position} is invalid."),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_register_output(&self) -> TokenStream {
|
||||
let generics = &self.generics;
|
||||
let mut variables = quote::quote! {};
|
||||
|
||||
for (pos, (_ident, ty)) in self.state_outputs.iter().enumerate() {
|
||||
variables.extend(quote::quote! {
|
||||
#pos => std::sync::Arc::new(<#ty as LaunchArgExpand>::expand_output(builder, settings.vectorization_output(#pos))),
|
||||
});
|
||||
}
|
||||
|
||||
quote::quote! {
|
||||
#[allow(unused)]
|
||||
fn register_output #generics (
|
||||
builder: &mut KernelBuilder,
|
||||
settings: &KernelSettings,
|
||||
position: usize,
|
||||
) -> std::sync::Arc<dyn core::any::Any> {
|
||||
match position {
|
||||
#variables
|
||||
_ => panic!("Input {position} is invalid."),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_define_impl(&self, expand: &TokenStream) -> TokenStream {
|
||||
let mut expand_args = quote::quote! { &mut builder.context, };
|
||||
|
||||
let mut variables = quote::quote! {};
|
||||
|
||||
for (pos, (ident, ty)) in self.state_inputs.iter().enumerate() {
|
||||
variables.extend(quote::quote! {
|
||||
let #ident: &<#ty as CubeType>::ExpandType = inputs
|
||||
.get(&#pos)
|
||||
.unwrap()
|
||||
.downcast_ref()
|
||||
.expect("Input type should be correct. It could be caused by an invalid kernel input/output alias.");
|
||||
});
|
||||
}
|
||||
|
||||
for (pos, (ident, ty)) in self.state_outputs.iter().enumerate() {
|
||||
variables.extend(quote::quote! {
|
||||
let #ident: &<#ty as CubeType>::ExpandType = outputs
|
||||
.get(&#pos)
|
||||
.unwrap()
|
||||
.downcast_ref()
|
||||
.expect("Output type should be correct. It could be caused by an invalid kernel input/output alias.");
|
||||
});
|
||||
}
|
||||
|
||||
for arg in self.state_args.iter() {
|
||||
expand_args.extend(quote::quote! {
|
||||
#arg.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
let expand_func = match self.generics.params.is_empty() {
|
||||
true => quote::quote! { #expand },
|
||||
false => {
|
||||
let generics = self.generics.split_for_impl().1;
|
||||
quote::quote! { #expand::#generics }
|
||||
}
|
||||
};
|
||||
|
||||
quote::quote! {
|
||||
#variables
|
||||
#expand_func(#expand_args);
|
||||
builder.build(self.settings.clone())
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_define_args(&self) -> TokenStream {
|
||||
let num_inputs = self.state_inputs.len();
|
||||
let num_outputs = self.state_outputs.len();
|
||||
|
||||
let register_input = self.gen_register_input();
|
||||
let register_output = self.gen_register_output();
|
||||
|
||||
let (register_input_call, register_output_call) = match self.generics.params.is_empty() {
|
||||
true => (
|
||||
quote::quote! { register_input },
|
||||
quote::quote! { register_output },
|
||||
),
|
||||
false => {
|
||||
let generics = self.generics.split_for_impl().1;
|
||||
|
||||
(
|
||||
quote::quote! { register_input::#generics },
|
||||
quote::quote! { register_output::#generics },
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let mut variables = quote::quote! {};
|
||||
|
||||
for (pos, (ident, ty)) in self.state_inputs.iter().enumerate() {
|
||||
variables.extend(quote::quote! {
|
||||
let #ident = <&#ty as CubeType>::ExpandType =
|
||||
*inputs.remove(&#pos).unwrap().downcast().unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
for (pos, (ident, ty)) in self.state_outputs.iter().enumerate() {
|
||||
variables.extend(quote::quote! {
|
||||
let #ident = <&mut #ty as CubeType>::ExpandType =
|
||||
*outputs.remove(&#pos).unwrap().downcast().unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
let mut tokens = quote::quote! {
|
||||
let mut builder = KernelBuilder::default();
|
||||
|
||||
let mut inputs: std::collections::BTreeMap<usize, std::sync::Arc<dyn core::any::Any>> = std::collections::BTreeMap::new();
|
||||
let mut outputs: std::collections::BTreeMap<usize, std::sync::Arc<dyn core::any::Any>> = std::collections::BTreeMap::new();
|
||||
|
||||
for mapping in self.settings.mappings.iter() {
|
||||
if !inputs.contains_key(&mapping.pos_input) {
|
||||
inputs.insert(
|
||||
mapping.pos_input,
|
||||
#register_input_call(&mut builder, &self.settings, mapping.pos_input),
|
||||
);
|
||||
}
|
||||
|
||||
let input = inputs.get(&mapping.pos_input).unwrap();
|
||||
outputs.insert(mapping.pos_output, input.clone());
|
||||
}
|
||||
|
||||
#register_input
|
||||
#register_output
|
||||
};
|
||||
|
||||
if num_inputs > 0 {
|
||||
tokens.extend(quote::quote! {
|
||||
for i in 0..#num_inputs {
|
||||
if !inputs.contains_key(&i) {
|
||||
inputs.insert(i, #register_input_call(&mut builder, &self.settings, i));
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if num_outputs > 0 {
|
||||
tokens.extend(quote::quote! {
|
||||
for i in 0..#num_outputs {
|
||||
if !outputs.contains_key(&i) {
|
||||
outputs.insert(i, #register_output_call(&mut builder, &self.settings, i));
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
tokens
|
||||
}
|
||||
|
||||
fn gen_compile_impl(&self, expand: &TokenStream) -> TokenStream {
|
||||
let ident = Ident::new(&self.name, Span::call_site());
|
||||
let generics = add_runtime(self.generics.clone());
|
||||
let (impl_gen, ty_gen, where_gen) = generics.split_for_impl();
|
||||
|
||||
let mut format_str = "{:?}-{}".to_string();
|
||||
for _ in 0..self.state_comptimes.len() {
|
||||
format_str.push_str("-{:?}");
|
||||
}
|
||||
|
||||
let mut format_args = quote::quote! { core::any::TypeId::of::<Self>(), self.settings, };
|
||||
|
||||
for (_, ident) in self.state_comptimes.iter() {
|
||||
format_args.extend(quote::quote! { self.#ident, });
|
||||
}
|
||||
|
||||
let define_args = self.gen_define_args();
|
||||
let define_impl = self.gen_define_impl(expand);
|
||||
|
||||
quote::quote! {
|
||||
impl #impl_gen Kernel for #ident #ty_gen #where_gen {
|
||||
fn define(&self) -> KernelDefinition {
|
||||
#define_args
|
||||
#define_impl
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!(#format_str, #format_args)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn phantoms(&self, generics: &Generics, declaration: bool) -> TokenStream {
|
||||
let mut phantoms = quote::quote! {};
|
||||
|
||||
for param in generics.params.iter() {
|
||||
let ty = match param {
|
||||
syn::GenericParam::Type(ty) => ty,
|
||||
_ => continue,
|
||||
};
|
||||
let ident = Ident::new(
|
||||
format!("_{}", ty.ident.to_string().to_lowercase()).as_str(),
|
||||
Span::call_site(),
|
||||
);
|
||||
let ty = &ty.ident;
|
||||
if declaration {
|
||||
phantoms.extend(quote::quote! {
|
||||
#ident: core::marker::PhantomData<#ty>,
|
||||
});
|
||||
} else {
|
||||
phantoms.extend(quote::quote! {
|
||||
#ident: core::marker::PhantomData::<#ty>,
|
||||
});
|
||||
}
|
||||
}
|
||||
phantoms
|
||||
}
|
||||
|
||||
fn gen_launch_body(&self) -> TokenStream {
|
||||
let ident = Ident::new(&self.name, Span::call_site());
|
||||
let generics = add_runtime(self.generics.clone());
|
||||
let phantoms = self.phantoms(&generics, false);
|
||||
|
||||
let mut comptimes = quote::quote! {};
|
||||
let settings = self.gen_settings();
|
||||
|
||||
let mut body = quote::quote! {
|
||||
let mut launcher = KernelLauncher::<R>::default();
|
||||
};
|
||||
|
||||
for (input, _) in self.state_inputs.iter() {
|
||||
body.extend(quote::quote! {
|
||||
#input.register(&mut launcher);
|
||||
});
|
||||
}
|
||||
|
||||
for (input, _) in self.state_outputs.iter() {
|
||||
body.extend(quote::quote! {
|
||||
#input.register(&mut launcher);
|
||||
});
|
||||
}
|
||||
|
||||
for (_ty, ident) in self.state_comptimes.iter() {
|
||||
comptimes.extend(quote::quote! {
|
||||
#ident,
|
||||
});
|
||||
}
|
||||
|
||||
let kernel = quote::quote! {
|
||||
#ident {
|
||||
settings,
|
||||
#comptimes
|
||||
#phantoms
|
||||
}
|
||||
};
|
||||
|
||||
quote::quote! {
|
||||
#settings
|
||||
|
||||
let kernel = #kernel;
|
||||
|
||||
#body
|
||||
|
||||
launcher.launch(cube_count, kernel, client);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn codegen_launch(sig: &syn::Signature) -> TokenStream {
|
||||
let codegen = Codegen::from_sig(sig);
|
||||
|
||||
let ident = &sig.ident;
|
||||
|
||||
let ident_expand = quote::quote! {
|
||||
__expand
|
||||
};
|
||||
|
||||
let generics = add_runtime(add_lifetime(sig.generics.clone()));
|
||||
let body = codegen.gen_launch_body();
|
||||
let kernel = codegen.gen_kernel_struct();
|
||||
let compile = codegen.gen_compile_impl(&ident_expand);
|
||||
let (inputs, output) = (codegen.fn_inputs, codegen.fn_output);
|
||||
let doc = format!("Launch the kernel [{ident}()] on the given runtime.");
|
||||
|
||||
quote::quote! {
|
||||
#kernel
|
||||
#compile
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[doc = #doc]
|
||||
pub fn launch #generics (
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
cube_count: CubeCount<R::Server>,
|
||||
cube_dim: CubeDim,
|
||||
#inputs
|
||||
) -> #output {
|
||||
#body;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_lifetime(mut generics: Generics) -> Generics {
|
||||
let lifetime: syn::Generics = parse_quote! {<'a>};
|
||||
|
||||
generics
|
||||
.params
|
||||
.insert(0, lifetime.params.into_iter().next().unwrap());
|
||||
generics
|
||||
}
|
||||
|
||||
pub fn add_runtime(mut generics: Generics) -> Generics {
|
||||
let runtime: syn::Generics = parse_quote! { <R: Runtime> };
|
||||
|
||||
generics
|
||||
.params
|
||||
.push(runtime.params.into_iter().next().unwrap());
|
||||
generics
|
||||
}
|
||||
|
||||
/// Extract the first generic argument of a path type, e.g. `T` from
/// `Comptime<T>`. Panics when the type has no angle-bracketed generic.
fn first_generic_ty(ty: &syn::Type) -> syn::Type {
    let path = match ty {
        syn::Type::Path(pat) => pat,
        _ => todo!(),
    };

    match &path.path.segments.first().unwrap().arguments {
        syn::PathArguments::AngleBracketed(args) => match args.args.first().unwrap() {
            syn::GenericArgument::Type(ty) => ty.clone(),
            _ => panic!("Should have a generic type"),
        },
        _ => panic!("Comptime must have a generic"),
    }
}
|
||||
|
||||
/// Strip one level of reference from a type, returning the referenced type
/// (or the type itself when it is not a reference).
fn no_ref(ty: &syn::Type) -> &syn::Type {
    if let syn::Type::Reference(reference) = ty {
        &reference.elem
    } else {
        ty
    }
}
|
|
@ -1,10 +0,0 @@
|
|||
mod base;
|
||||
mod branch;
|
||||
mod expr;
|
||||
mod function;
|
||||
mod launch;
|
||||
mod operation;
|
||||
mod variable;
|
||||
|
||||
pub(crate) use base::codegen_statement;
|
||||
pub(crate) use launch::codegen_launch;
|
|
@ -1,270 +0,0 @@
|
|||
use crate::tracker::VariableTracker;
|
||||
|
||||
use super::{base::Codegen, expr::codegen_expr};
|
||||
|
||||
/// Codegen for binary operations (+, -, *, etc.)
|
||||
pub(crate) fn codegen_binary(
|
||||
binary: &syn::ExprBinary,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> Codegen {
|
||||
let lhs = codegen_expr(&binary.left, loop_level, variable_tracker);
|
||||
let (lhs, is_comptime_lhs, lhs_array) = (lhs.tokens, lhs.is_comptime, lhs.array_indexing);
|
||||
let (rhs, is_comptime_rhs) = codegen_expr(&binary.right, loop_level, variable_tracker).split();
|
||||
|
||||
if is_comptime_lhs && is_comptime_rhs {
|
||||
return Codegen::new(
|
||||
quote::quote! {
|
||||
#binary
|
||||
},
|
||||
true,
|
||||
);
|
||||
}
|
||||
|
||||
Codegen::new(
|
||||
match binary.op {
|
||||
syn::BinOp::Add(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::add::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Sub(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::sub::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Mul(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::mul::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Div(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::div::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Rem(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::rem::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Ne(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::ne::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Gt(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::gt::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Ge(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::ge::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Lt(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::lt::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Le(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::le::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Eq(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::eq::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::AddAssign(_) => {
|
||||
if let Some(array) = lhs_array {
|
||||
let (array, index) = (array.array, array.index);
|
||||
|
||||
quote::quote! {
|
||||
{
|
||||
let _array = #array;
|
||||
let _index = #index;
|
||||
let _value = #rhs;
|
||||
burn_cube::frontend::add_assign_array_op::expand(context, _array, _index, _value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::BinOp::SubAssign(_) => {
|
||||
if let Some(array) = lhs_array {
|
||||
let (array, index) = (array.array, array.index);
|
||||
|
||||
quote::quote! {
|
||||
{
|
||||
let _array = #array;
|
||||
let _index = #index;
|
||||
let _value = #rhs;
|
||||
burn_cube::frontend::sub_assign_array_op::expand(context, _array, _index, _value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::BinOp::MulAssign(_) => {
|
||||
if let Some(array) = lhs_array {
|
||||
let (array, index) = (array.array, array.index);
|
||||
|
||||
quote::quote! {
|
||||
{
|
||||
let _array = #array;
|
||||
let _index = #index;
|
||||
let _value = #rhs;
|
||||
burn_cube::frontend::mul_assign_array_op::expand(context, _array, _index, _value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::BinOp::DivAssign(_) => {
|
||||
if let Some(array) = lhs_array {
|
||||
let (array, index) = (array.array, array.index);
|
||||
|
||||
quote::quote! {
|
||||
{
|
||||
let _array = #array;
|
||||
let _index = #index;
|
||||
let _value = #rhs;
|
||||
burn_cube::frontend::div_assign_array_op::expand(context, _array, _index, _value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::BinOp::And(_) => quote::quote! {
|
||||
{
|
||||
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::and::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Or(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::or::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::BitAnd(_) => quote::quote! {
|
||||
{
|
||||
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::bitand::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::BitXor(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::bitxor::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Shl(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::shl::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
syn::BinOp::Shr(_) => quote::quote! {
|
||||
{
|
||||
let _lhs = #lhs;
|
||||
let _rhs = #rhs;
|
||||
burn_cube::frontend::shr::expand(context, _lhs, _rhs)
|
||||
}
|
||||
},
|
||||
_ => todo!("Codegen: unsupported op {:?}", binary.op),
|
||||
},
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
/// Codegen for unary operations
|
||||
pub(crate) fn codegen_unary(
|
||||
unary: &syn::ExprUnary,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> Codegen {
|
||||
let (inner, is_comptime) = codegen_expr(&unary.expr, loop_level, variable_tracker).split();
|
||||
|
||||
if is_comptime {
|
||||
return Codegen::new(
|
||||
quote::quote! {
|
||||
#unary
|
||||
},
|
||||
true,
|
||||
);
|
||||
}
|
||||
|
||||
Codegen::new(
|
||||
match unary.op {
|
||||
syn::UnOp::Not(_) => quote::quote! {
|
||||
{
|
||||
let _inner = #inner;
|
||||
burn_cube::frontend::not::expand(context, _inner)
|
||||
}
|
||||
},
|
||||
_ => todo!("Codegen: unsupported op {:?}", unary.op),
|
||||
},
|
||||
false,
|
||||
)
|
||||
}
|
|
@ -1,322 +0,0 @@
|
|||
use proc_macro2::TokenStream;
|
||||
use quote::ToTokens;
|
||||
use syn::{punctuated::Punctuated, FieldValue, Lit, Member, PathArguments, Token};
|
||||
|
||||
use crate::{analyzer::KEYWORDS, codegen_function::expr::codegen_expr, tracker::VariableTracker};
|
||||
|
||||
use super::base::Codegen;
|
||||
|
||||
/// Codegen for literals
|
||||
pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream {
|
||||
match lit.lit {
|
||||
// We treat floats differently to avoid getting 4..into() for instance
|
||||
Lit::Float(_) => {
|
||||
let lit_str = lit.lit.to_token_stream().to_string();
|
||||
let float_lit = lit_str.parse::<f32>().unwrap();
|
||||
quote::quote! { #float_lit }
|
||||
}
|
||||
_ => {
|
||||
quote::quote! { #lit }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Codegen for arrays of literals
|
||||
pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream {
|
||||
let mut tokens = quote::quote! {};
|
||||
for element in array.elems.iter() {
|
||||
let token = match element {
|
||||
syn::Expr::Lit(lit) => codegen_lit(lit),
|
||||
_ => {
|
||||
return syn::Error::new_spanned(array, "Only arrays of literals are supported")
|
||||
.into_compile_error()
|
||||
}
|
||||
};
|
||||
tokens.extend(quote::quote! { #token, });
|
||||
}
|
||||
quote::quote! { [ #tokens ] }
|
||||
}
|
||||
|
||||
/// Codegen for a local declaration (let ...)
|
||||
/// Supports:
|
||||
/// let x = ...
|
||||
/// let x: T = ...
|
||||
/// let _ = ...
|
||||
/// let mut _ = ...
|
||||
pub(crate) fn codegen_local(
|
||||
local: &syn::Local,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
let let_tok = local.let_token;
|
||||
|
||||
let ident = match &local.pat {
|
||||
syn::Pat::Ident(ident) => ident.to_token_stream(),
|
||||
syn::Pat::Type(pat_type) => match &*pat_type.pat {
|
||||
syn::Pat::Ident(pat_ident) => pat_ident.to_token_stream(),
|
||||
_ => todo!("Codegen: Unsupported typed path {:?}", pat_type.pat),
|
||||
},
|
||||
syn::Pat::Wild(wild) => wild.underscore_token.to_token_stream(),
|
||||
_ => todo!("Codegen: Declaration {:?} is unsupported.", local.pat),
|
||||
};
|
||||
|
||||
variable_tracker.codegen_declare(ident.to_string(), loop_level as u8);
|
||||
|
||||
match local.init.as_ref() {
|
||||
Some(init) => {
|
||||
let (init, is_comptime) =
|
||||
codegen_expr(&init.expr, loop_level, variable_tracker).split();
|
||||
|
||||
if is_comptime {
|
||||
variable_tracker
|
||||
.set_as_comptime(ident.to_string(), loop_level as u8, None)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
if is_comptime {
|
||||
quote::quote! {
|
||||
#let_tok #ident = #init;
|
||||
}
|
||||
} else {
|
||||
quote::quote! {
|
||||
#let_tok #ident = {
|
||||
let _inner = #init;
|
||||
burn_cube::frontend::Init::init(_inner, context)
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
quote::quote! {
|
||||
#let_tok #ident;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Codegen for indexed access
|
||||
pub(crate) fn codegen_index(
|
||||
index: &syn::ExprIndex,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> Codegen {
|
||||
let array = codegen_expr(&index.expr, loop_level, variable_tracker);
|
||||
let index = codegen_expr(&index.index, loop_level, variable_tracker);
|
||||
|
||||
let tokens = quote::quote! {
|
||||
{
|
||||
let _array = #array;
|
||||
let _index = #index;
|
||||
burn_cube::frontend::index::expand(context, _array, _index)
|
||||
}
|
||||
};
|
||||
|
||||
let mut codegen = Codegen::new(tokens, false);
|
||||
codegen.array_indexing = Some(super::base::ArrayIndexing {
|
||||
array: array.tokens,
|
||||
index: index.tokens,
|
||||
});
|
||||
|
||||
codegen
|
||||
}
|
||||
|
||||
/// Codegen for assignation
|
||||
/// Supports:
|
||||
/// - scalar
|
||||
/// - indexed array
|
||||
pub(crate) fn codegen_assign(
|
||||
assign: &syn::ExprAssign,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
match assign.left.as_ref() {
|
||||
syn::Expr::Index(index) => {
|
||||
let array = codegen_expr(&index.expr, loop_level, variable_tracker);
|
||||
let index = codegen_expr(&index.index, loop_level, variable_tracker);
|
||||
let value = codegen_expr(&assign.right, loop_level, variable_tracker);
|
||||
|
||||
quote::quote! {
|
||||
{
|
||||
let _array = #array;
|
||||
let _index = #index;
|
||||
let _value = #value;
|
||||
burn_cube::frontend::index_assign::expand(context, _array, _index, _value)
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::Expr::Path(_) => {
|
||||
let lhs = codegen_expr(&assign.left, loop_level, variable_tracker);
|
||||
let rhs = codegen_expr(&assign.right, loop_level, variable_tracker);
|
||||
|
||||
quote::quote! {
|
||||
{
|
||||
let _assign_lhs = #lhs;
|
||||
let _assign_rhs = #rhs;
|
||||
burn_cube::frontend::assign::expand(context, _assign_rhs, _assign_lhs)
|
||||
}
|
||||
}
|
||||
}
|
||||
syn::Expr::Field(_) => {
|
||||
let lhs = codegen_expr(&assign.left, loop_level, variable_tracker);
|
||||
let rhs = codegen_expr(&assign.right, loop_level, variable_tracker);
|
||||
|
||||
quote::quote! {
|
||||
{
|
||||
let _assign_lhs = #lhs;
|
||||
let _assign_rhs = #rhs;
|
||||
burn_cube::frontend::assign::expand(context, _assign_rhs, _assign_lhs)
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => todo!("Assign of expr {:?} unsupported", assign.left),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn codegen_path_var(
|
||||
path: &syn::ExprPath,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> Codegen {
|
||||
let ident = match path.path.get_ident() {
|
||||
Some(ident) => ident,
|
||||
None => {
|
||||
return Codegen::new(
|
||||
quote::quote! {
|
||||
#path
|
||||
},
|
||||
false,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let name = ident.to_string();
|
||||
|
||||
if name == "None" {
|
||||
return Codegen::new(quote::quote! { None }, true);
|
||||
}
|
||||
|
||||
if KEYWORDS.contains(&name.as_str()) {
|
||||
Codegen::new(
|
||||
quote::quote! {
|
||||
#ident :: expand(context)
|
||||
},
|
||||
false,
|
||||
)
|
||||
} else {
|
||||
let (will_be_used_again, is_comptime) = variable_tracker
|
||||
.codegen_reuse(name, loop_level as u8, None)
|
||||
.unwrap_or((true, false));
|
||||
|
||||
let output = if will_be_used_again {
|
||||
quote::quote! {
|
||||
#ident.clone()
|
||||
}
|
||||
} else {
|
||||
quote::quote! {
|
||||
#ident
|
||||
}
|
||||
};
|
||||
|
||||
Codegen::new(output, is_comptime)
|
||||
}
|
||||
}
|
||||
|
||||
/// Codegen for a field used in rhs of a statement
|
||||
/// This function adds cloning when necessary
|
||||
pub(crate) fn codegen_field(
|
||||
field: &syn::ExprField,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
let (struct_, field) = if let Member::Named(attribute_ident) = &field.member {
|
||||
if let syn::Expr::Path(struct_expr) = &*field.base {
|
||||
let struct_ident = struct_expr
|
||||
.path
|
||||
.get_ident()
|
||||
.expect("Codegen: field access only supported on ident struct.");
|
||||
|
||||
(struct_ident, attribute_ident)
|
||||
} else {
|
||||
todo!("Codegen: field access only supported on ident struct.");
|
||||
}
|
||||
} else {
|
||||
todo!("Codegen: unnamed attribute not supported.");
|
||||
};
|
||||
|
||||
let (will_be_used_again, _) = variable_tracker
|
||||
.codegen_reuse(
|
||||
struct_.to_string(),
|
||||
loop_level as u8,
|
||||
Some(field.to_string()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
if will_be_used_again {
|
||||
quote::quote! {
|
||||
#struct_ . #field .clone()
|
||||
}
|
||||
} else {
|
||||
quote::quote! {
|
||||
#struct_ . #field
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Codegen for a struct declaration
|
||||
pub(crate) fn codegen_struct(
|
||||
struct_: &syn::ExprStruct,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
let mut deconstructed_path = Vec::new();
|
||||
for segment in struct_.path.segments.iter() {
|
||||
let generics = if let PathArguments::AngleBracketed(arguments) = &segment.arguments {
|
||||
Some(arguments)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
deconstructed_path.push((&segment.ident, generics));
|
||||
}
|
||||
|
||||
let (struct_name, generics) = deconstructed_path
|
||||
.pop()
|
||||
.expect("At least one ident in the path");
|
||||
|
||||
// This is hacky but using <struct_ as CubeType>::ExpandType {...} is experimental in Rust
|
||||
let expanded_struct_name = syn::Ident::new(
|
||||
format!("{}Expand", struct_name).as_str(),
|
||||
proc_macro2::Span::call_site(),
|
||||
);
|
||||
|
||||
deconstructed_path.push((&expanded_struct_name, generics));
|
||||
|
||||
// Reconstruct the path
|
||||
let mut path_tokens = quote::quote! {};
|
||||
for (ident, angle_bracketed_generics) in deconstructed_path {
|
||||
let ident_tokens = ident.to_token_stream();
|
||||
let generics_tokens = angle_bracketed_generics.to_token_stream();
|
||||
|
||||
path_tokens.extend(quote::quote! {
|
||||
#ident_tokens #generics_tokens
|
||||
});
|
||||
}
|
||||
|
||||
let fields = codegen_field_creation(&struct_.fields, loop_level, variable_tracker);
|
||||
quote::quote! {
|
||||
#path_tokens { #fields }
|
||||
}
|
||||
}
|
||||
|
||||
fn codegen_field_creation(
|
||||
fields: &Punctuated<FieldValue, Token![,]>,
|
||||
loop_level: usize,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
) -> TokenStream {
|
||||
let mut field_tokens = quote::quote! {};
|
||||
for field in fields.iter() {
|
||||
let field_name_token = &field.member;
|
||||
let field_value_token = codegen_expr(&field.expr, loop_level, variable_tracker);
|
||||
field_tokens.extend(quote::quote! { #field_name_token : #field_value_token, });
|
||||
}
|
||||
field_tokens
|
||||
}
|
|
@ -1,110 +0,0 @@
|
|||
use proc_macro2::TokenStream;
|
||||
|
||||
use crate::codegen_common::signature::{expand_sig, ExpandMode};
|
||||
|
||||
pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream {
|
||||
let mut expand_items = Vec::new();
|
||||
|
||||
for item in tr.items.iter() {
|
||||
match item {
|
||||
syn::TraitItem::Fn(func) => {
|
||||
let expand = expand_sig(
|
||||
&func.sig,
|
||||
&syn::Visibility::Inherited,
|
||||
None,
|
||||
ExpandMode::MethodImpl,
|
||||
);
|
||||
expand_items.push(syn::parse_quote!(#expand;));
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
tr.items.append(&mut expand_items);
|
||||
|
||||
quote::quote! {
|
||||
#tr
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an `__expand` method implementation for every function of an impl
/// block, then re-emit the impl with the new items appended.
pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
    let mut expand_items = Vec::new();

    for item in tr.items.iter() {
        match item {
            syn::ImplItem::Fn(func) => {
                // The generated body calls `<fn>::__expand(context, args...)`.
                let ident = &func.sig.ident;
                let ident = quote::quote! {#ident::__expand};
                let mut inputs = quote::quote!();

                // Forward every typed argument by name to the expand call.
                for input in &func.sig.inputs {
                    match input {
                        syn::FnArg::Typed(pat) => {
                            let ident = pat.pat.clone();
                            inputs.extend(quote::quote! {
                                #ident,
                            });
                        }
                        _ => todo!("Only Typed inputs are supported"),
                    }
                }

                // Signature of the generated `__expand` method.
                let expand = expand_sig(
                    &func.sig,
                    &syn::Visibility::Inherited,
                    None,
                    ExpandMode::MethodImpl,
                );

                // When the impl block itself is generic, the impl's generic
                // parameters are copied onto a clone of the function so the
                // nested `#[cube]` function can reference them.
                let tokens = if !tr.generics.params.is_empty() {
                    let mut func = func.clone();
                    for param in tr.generics.params.iter() {
                        func.sig.generics.params.push(param.clone());
                    }
                    register_expand(&func, &ident, expand, inputs)
                } else {
                    register_expand(func, &ident, expand, inputs)
                };

                expand_items.push(syn::parse2(tokens).unwrap());
            }
            _ => continue,
        }
    }
    tr.items.append(&mut expand_items);

    quote::quote! {
        #tr
    }
}
|
||||
|
||||
fn register_expand(
|
||||
func: &syn::ImplItemFn,
|
||||
name: &TokenStream,
|
||||
expand: proc_macro2::TokenStream,
|
||||
inputs: proc_macro2::TokenStream,
|
||||
) -> proc_macro2::TokenStream {
|
||||
let (func, func_expand) = if func.sig.generics.params.is_empty() {
|
||||
(
|
||||
quote::quote! { #func },
|
||||
quote::quote! {
|
||||
#name(context, #inputs)
|
||||
},
|
||||
)
|
||||
} else {
|
||||
let (_, gen, _) = &func.sig.generics.split_for_impl();
|
||||
(
|
||||
quote::quote! { #func },
|
||||
quote::quote! {
|
||||
#name::#gen(context, #inputs)
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
quote::quote! (
|
||||
#expand {
|
||||
#[cube]
|
||||
#func
|
||||
#func_expand
|
||||
}
|
||||
)
|
||||
}
|
|
@ -1,294 +0,0 @@
|
|||
use proc_macro::TokenStream;
|
||||
use quote::quote;
|
||||
use syn::Ident;
|
||||
|
||||
use super::GenericsCodegen;
|
||||
|
||||
/// Gathered information about a struct deriving `CubeType`/`CubeLaunch`, used
/// to generate its expand/launch companion types and their trait impls.
struct TypeCodegen {
    // Name of the user-defined struct.
    name: syn::Ident,
    // Name of the generated launch struct (`<Name>Launch`).
    name_launch: syn::Ident,
    // Name of the generated expand struct (`<Name>Expand`).
    name_expand: syn::Ident,
    // Fields of the original struct, mirrored in the generated types.
    fields: Vec<syn::Field>,
    // Assembles the lifetime/runtime/type generic-parameter sets.
    generics: GenericsCodegen,
    // Visibility of the original struct, reused for generated items.
    vis: syn::Visibility,
}
|
||||
|
||||
impl TypeCodegen {
    /// Generates the expand struct definition: same fields as the original
    /// type, but each field typed as `<T as CubeType>::ExpandType`.
    pub fn expand_ty(&self) -> proc_macro2::TokenStream {
        let mut fields = quote::quote! {};
        let name = &self.name_expand;

        for field in self.fields.iter() {
            let ident = &field.ident;
            let ty = &field.ty;
            let vis = &field.vis;

            fields.extend(quote! {
                #vis #ident: <#ty as CubeType>::ExpandType,
            });
        }

        let generics = self.generics.type_definitions();
        let vis = &self.vis;

        quote! {
            #[derive(Clone)]
            #vis struct #name #generics {
                #fields
            }
        }
    }

    /// Generates the launch struct definition: each field becomes the runtime
    /// argument type `<T as LaunchArg>::RuntimeArg<'a, R>`.
    // NOTE(review): unlike the expand struct, no visibility is emitted here,
    // so the launch struct is private — presumably intentional; confirm.
    pub fn launch_ty(&self) -> proc_macro2::TokenStream {
        let mut fields = quote::quote! {};
        let name = &self.name_launch;

        for field in self.fields.iter() {
            let ident = &field.ident;
            let ty = &field.ty;
            let vis = &field.vis;

            fields.extend(quote! {
                #vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>,
            });
        }

        let generics = self.generics.all_definitions();

        quote! {
            struct #name #generics {
                #fields
            }
        }
    }

    /// Generates the `new` constructor for the launch struct, taking one
    /// runtime argument per field.
    pub fn launch_new(&self) -> proc_macro2::TokenStream {
        let mut args = quote::quote! {};
        let mut fields = quote::quote! {};
        let name = &self.name_launch;

        for field in self.fields.iter() {
            let ident = &field.ident;
            let ty = &field.ty;
            let vis = &field.vis;

            args.extend(quote! {
                #vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>,
            });
            fields.extend(quote! {
                #ident,
            });
        }

        let generics_impl = self.generics.all_definitions();
        let generics_use = self.generics.all_in_use();
        let vis = &self.vis;

        quote! {
            impl #generics_impl #name #generics_use {
                /// New kernel
                #[allow(clippy::too_many_arguments)]
                #vis fn new(#args) -> Self {
                    Self {
                        #fields
                    }
                }
            }
        }
    }

    /// Implements `ArgSettings<R>` for the launch struct: registration and
    /// input/output configuration delegate to every field in declaration
    /// order, using the field's index as its argument position.
    pub fn arg_settings_impl(&self) -> proc_macro2::TokenStream {
        let mut register_body = quote::quote! {};
        let mut configure_input_body = quote::quote! {};
        let mut configure_output_body = quote::quote! {};
        let name = &self.name_launch;

        for (pos, field) in self.fields.iter().enumerate() {
            let ident = &field.ident;

            register_body.extend(quote! {
                self.#ident.register(launcher);
            });
            configure_input_body.extend(quote! {
                settings = ArgSettings::<R>::configure_input(&self.#ident, #pos, settings);
            });
            configure_output_body.extend(quote! {
                settings = ArgSettings::<R>::configure_output(&self.#ident, #pos, settings);
            });
        }

        let generics_impl = self.generics.all_definitions();
        let generics_use = self.generics.all_in_use();

        quote! {
            impl #generics_impl ArgSettings<R> for #name #generics_use {
                fn register(&self, launcher: &mut KernelLauncher<R>) {
                    #register_body
                }

                fn configure_input(&self, position: usize, mut settings: KernelSettings) -> KernelSettings {
                    #configure_input_body

                    settings
                }

                fn configure_output(&self, position: usize, mut settings: KernelSettings) -> KernelSettings {
                    #configure_output_body

                    settings
                }
            }
        }
    }

    /// Implements `CubeType` for the original struct, pointing its
    /// `ExpandType` at the generated expand struct.
    pub fn cube_type_impl(&self) -> proc_macro2::TokenStream {
        let name = &self.name;
        let name_expand = &self.name_expand;

        let generics_impl = self.generics.type_definitions();
        let generics_use = self.generics.type_in_use();

        quote! {
            impl #generics_impl CubeType for #name #generics_use {
                type ExpandType = #name_expand #generics_use;
            }
        }
    }

    /// Implements `LaunchArg` (RuntimeArg = launch struct) and
    /// `LaunchArgExpand` (field-wise expansion) for the original struct.
    pub fn launch_arg_impl(&self) -> proc_macro2::TokenStream {
        let mut body_input = quote::quote! {};
        let mut body_output = quote::quote! {};
        let name = &self.name;
        let name_launch = &self.name_launch;
        let name_expand = &self.name_expand;

        for field in self.fields.iter() {
            let ident = &field.ident;
            let ty = &field.ty;
            let vis = &field.vis;

            body_input.extend(quote! {
                #vis #ident: <#ty as LaunchArgExpand>::expand(builder, vectorization),
            });
            body_output.extend(quote! {
                #vis #ident: <#ty as LaunchArgExpand>::expand_output(builder, vectorization),
            });
        }

        let type_generics_impl = self.generics.type_definitions();
        let type_generics_use = self.generics.type_in_use();

        let runtime_generics_impl = self.generics.runtime_definitions();
        let all_generics_use = self.generics.all_in_use();

        quote! {
            impl #type_generics_impl LaunchArg for #name #type_generics_use {
                type RuntimeArg #runtime_generics_impl = #name_launch #all_generics_use;
            }

            impl #type_generics_impl LaunchArgExpand for #name #type_generics_use {
                fn expand(
                    builder: &mut KernelBuilder,
                    vectorization: burn_cube::ir::Vectorization,
                ) -> <Self as CubeType>::ExpandType {
                    #name_expand {
                        #body_input
                    }
                }
                fn expand_output(
                    builder: &mut KernelBuilder,
                    vectorization: burn_cube::ir::Vectorization,
                ) -> <Self as CubeType>::ExpandType {
                    #name_expand {
                        #body_output
                    }
                }
            }
        }
    }

    /// Implements `Init` for the expand struct by initializing each field in
    /// the cube context.
    pub fn expand_type_impl(&self) -> proc_macro2::TokenStream {
        let name_expand = &self.name_expand;
        let type_generics_impl = self.generics.type_definitions();
        let type_generics_use = self.generics.type_in_use();

        let mut body = quote::quote! {};
        for field in self.fields.iter() {
            let ident = &field.ident;
            body.extend(quote::quote! {
                #ident: Init::init(self.#ident, context),
            });
        }

        quote! {
            impl #type_generics_impl Init for #name_expand #type_generics_use {
                fn init(self, context: &mut CubeContext) -> Self {
                    Self {
                        #body
                    }
                }
            }
        }
    }
}
|
||||
|
||||
pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> TokenStream {
|
||||
let name = ast.ident.clone();
|
||||
let generics = ast.generics.clone();
|
||||
let visibility = ast.vis.clone();
|
||||
|
||||
let name_string = name.to_string();
|
||||
let name_expand = Ident::new(format!("{}Expand", name_string).as_str(), name.span());
|
||||
let name_launch = Ident::new(format!("{}Launch", name_string).as_str(), name.span());
|
||||
|
||||
let mut fields = Vec::new();
|
||||
|
||||
match &ast.data {
|
||||
syn::Data::Struct(struct_data) => {
|
||||
for field in struct_data.fields.iter() {
|
||||
fields.push(field.clone());
|
||||
}
|
||||
}
|
||||
syn::Data::Enum(_) => panic!("Only struct can be derived"),
|
||||
syn::Data::Union(_) => panic!("Only struct can be derived"),
|
||||
};
|
||||
|
||||
let codegen = TypeCodegen {
|
||||
name,
|
||||
name_launch,
|
||||
name_expand,
|
||||
fields,
|
||||
generics: GenericsCodegen::new(generics),
|
||||
vis: visibility,
|
||||
};
|
||||
|
||||
let expand_ty = codegen.expand_ty();
|
||||
let launch_ty = codegen.launch_ty();
|
||||
let launch_new = codegen.launch_new();
|
||||
|
||||
let cube_type_impl = codegen.cube_type_impl();
|
||||
let arg_settings_impl = codegen.arg_settings_impl();
|
||||
let launch_arg_impl = codegen.launch_arg_impl();
|
||||
let expand_type_impl = codegen.expand_type_impl();
|
||||
|
||||
if with_launch {
|
||||
quote! {
|
||||
#expand_ty
|
||||
#launch_ty
|
||||
#launch_new
|
||||
|
||||
#cube_type_impl
|
||||
#arg_settings_impl
|
||||
#launch_arg_impl
|
||||
#expand_type_impl
|
||||
}
|
||||
.into()
|
||||
} else {
|
||||
quote! {
|
||||
#expand_ty
|
||||
#cube_type_impl
|
||||
#expand_type_impl
|
||||
}
|
||||
.into()
|
||||
}
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
use proc_macro2::{Span, TokenStream};
|
||||
use quote::ToTokens;
|
||||
use syn::{GenericParam, Generics, Ident, Lifetime, LifetimeParam, TypeParam};
|
||||
|
||||
/// Assembles the generic-parameter sets needed by the generated types: the
/// `'a` lifetime and `R: Runtime` parameters used by launch types, plus the
/// generics declared on the user struct.
pub(crate) struct GenericsCodegen {
    // Generics holding only the `'a` lifetime parameter.
    arg_lifetime: syn::Generics,
    // Generics holding only the `R: Runtime` type parameter.
    arg_runtime: syn::Generics,
    // Generics declared on the user type.
    type_gens: syn::Generics,
}
|
||||
|
||||
impl GenericsCodegen {
|
||||
pub(crate) fn new(type_gens: syn::Generics) -> Self {
|
||||
Self {
|
||||
arg_lifetime: Self::arg_lifetime(),
|
||||
arg_runtime: Self::arg_runtime(),
|
||||
type_gens,
|
||||
}
|
||||
}
|
||||
|
||||
fn arg_lifetime() -> Generics {
|
||||
let mut generics = Generics::default();
|
||||
let lifetime =
|
||||
GenericParam::Lifetime(LifetimeParam::new(Lifetime::new("'a", Span::call_site())));
|
||||
generics.params.push(lifetime);
|
||||
generics
|
||||
}
|
||||
|
||||
fn arg_runtime() -> Generics {
|
||||
let mut generics = Generics::default();
|
||||
let mut runtime_param = TypeParam::from(Ident::new("R", Span::call_site()));
|
||||
runtime_param
|
||||
.bounds
|
||||
.push(syn::parse_str("Runtime").unwrap());
|
||||
let runtime = GenericParam::Type(runtime_param);
|
||||
generics.params.push(runtime);
|
||||
generics
|
||||
}
|
||||
|
||||
pub(crate) fn type_definitions(&self) -> TokenStream {
|
||||
self.type_gens.to_token_stream()
|
||||
}
|
||||
|
||||
pub(crate) fn type_in_use(&self) -> TokenStream {
|
||||
generics_in_use_codegen(self.type_gens.clone())
|
||||
}
|
||||
|
||||
pub(crate) fn runtime_definitions(&self) -> TokenStream {
|
||||
let mut generics = self.arg_runtime.clone();
|
||||
generics.params.extend(self.arg_lifetime.params.clone());
|
||||
generics.to_token_stream()
|
||||
}
|
||||
|
||||
pub(crate) fn all_definitions(&self) -> TokenStream {
|
||||
let mut generics = self.arg_lifetime.clone();
|
||||
generics.params.extend(self.arg_runtime.params.clone());
|
||||
generics.params.extend(self.type_gens.params.clone());
|
||||
generics.to_token_stream()
|
||||
}
|
||||
|
||||
pub(crate) fn all_in_use(&self) -> TokenStream {
|
||||
let mut generics = self.arg_lifetime.clone();
|
||||
generics.params.extend(self.arg_runtime.params.clone());
|
||||
generics.params.extend(self.type_gens.params.clone());
|
||||
generics_in_use_codegen(generics)
|
||||
}
|
||||
}
|
||||
|
||||
fn generics_in_use_codegen(generics: Generics) -> TokenStream {
|
||||
let mut tokens = quote::quote! {<};
|
||||
for generic in generics.params.iter() {
|
||||
let ident = match generic {
|
||||
GenericParam::Lifetime(param) => param.lifetime.to_token_stream(),
|
||||
GenericParam::Type(param) => param.ident.to_token_stream(),
|
||||
GenericParam::Const(_) => todo!("Const generic not supported"),
|
||||
};
|
||||
tokens.extend(quote::quote! { #ident, })
|
||||
}
|
||||
tokens.extend(quote::quote! {>});
|
||||
|
||||
tokens
|
||||
}
|
|
@ -1,5 +0,0 @@
|
|||
mod base;
|
||||
mod generics;
|
||||
|
||||
pub(crate) use base::*;
|
||||
pub(crate) use generics::*;
|
|
@ -1,182 +0,0 @@
|
|||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
mod analyzer;
|
||||
mod codegen_function;
|
||||
mod codegen_trait;
|
||||
mod codegen_type;
|
||||
mod tracker;
|
||||
|
||||
pub(crate) mod codegen_common;
|
||||
|
||||
use analyzer::VariableAnalyzer;
|
||||
use codegen_common::signature::{expand_sig, ExpandMode};
|
||||
use codegen_function::{codegen_launch, codegen_statement};
|
||||
use codegen_trait::{expand_trait_def, expand_trait_impl};
|
||||
use codegen_type::generate_cube_type;
|
||||
use proc_macro::TokenStream;
|
||||
use syn::{parse_macro_input, punctuated::Punctuated, token::Comma, Meta};
|
||||
use tracker::VariableTracker;
|
||||
|
||||
/// Controls what the `#[cube]` macro emits.
enum CubeMode {
    /// Generates the expanded version of the function
    Default,
    /// Panics and prints the generated code, useful when debugging
    /// Use by writing #[cube(debug)]
    Debug,
}
|
||||
|
||||
// Derive macro to define a cube type that is launched with a kernel
|
||||
#[proc_macro_derive(CubeLaunch)]
|
||||
pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
|
||||
let input = syn::parse(input).unwrap();
|
||||
|
||||
generate_cube_type(&input, true)
|
||||
}
|
||||
|
||||
// Derive macro to define a cube type that is not launched
|
||||
#[proc_macro_derive(CubeType)]
|
||||
pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
|
||||
let input = syn::parse(input).unwrap();
|
||||
|
||||
generate_cube_type(&input, false)
|
||||
}
|
||||
|
||||
/// Attributes parsed from a `#[cube(...)]` invocation.
struct SupportedAttributes {
    // Whether to emit normally or panic with the generated code (`debug`).
    mode: CubeMode,
    // Whether to also generate a kernel launch function (`launch`).
    launch: bool,
}
|
||||
|
||||
/// Attribute macro that expands a function, trait definition, or trait impl
/// so it can be used as (or within) a cube kernel.
/// Supported attributes: `debug`, `launch` (see `parse_attributes`).
#[proc_macro_attribute]
pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
    let args = parse_macro_input!(attr with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
    let attrs = parse_attributes(&args);

    let code: TokenStream = match syn::parse::<syn::Item>(tokens).unwrap() {
        syn::Item::Fn(func) => cube_fn(func, &attrs),
        syn::Item::Impl(item) => expand_trait_impl(item).into(),
        syn::Item::Trait(item) => expand_trait_def(item).into(),
        _ => panic!("Cube annotations only supported for functions"),
    };

    // In debug mode, abort compilation and print the generated code.
    match attrs.mode {
        CubeMode::Default => code,
        CubeMode::Debug => panic!("{code}"),
    }
}
|
||||
|
||||
fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream {
|
||||
let mut variable_tracker = VariableAnalyzer::create_tracker(&func);
|
||||
|
||||
match codegen_cube(&func, &mut variable_tracker, attrs.launch) {
|
||||
Ok(code) => code.into(),
|
||||
Err(err) => err.into(),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_attributes(args: &Punctuated<Meta, Comma>) -> SupportedAttributes {
|
||||
let mut mode = CubeMode::Default;
|
||||
let mut launch = false;
|
||||
|
||||
for arg in args.iter() {
|
||||
match arg {
|
||||
Meta::Path(path) => {
|
||||
if let Some(ident) = path.get_ident().map(|id| id.to_string()) {
|
||||
match ident.as_str() {
|
||||
"debug" => {
|
||||
mode = CubeMode::Debug;
|
||||
}
|
||||
"launch" => {
|
||||
launch = true;
|
||||
}
|
||||
_ => panic!("Attribute {ident} is not supported"),
|
||||
}
|
||||
} else {
|
||||
panic!("Only ident attribute supported");
|
||||
}
|
||||
}
|
||||
Meta::List(_) => panic!("No List attribute supported"),
|
||||
Meta::NameValue(_) => panic!("No NameValue attribute supported"),
|
||||
}
|
||||
}
|
||||
|
||||
SupportedAttributes { mode, launch }
|
||||
}
|
||||
|
||||
/// Generate the expanded version of a function marked with the cube macro.
///
/// On success, returns the original function followed by a module (named after
/// the function) containing the expand function and, optionally, the launch
/// function. On failure, returns the original function plus the collected
/// compile errors.
fn codegen_cube(
    func: &syn::ItemFn,
    variable_tracker: &mut VariableTracker,
    launch: bool,
) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
    let signature = expand_sig(
        &func.sig,
        &syn::Visibility::Public(Default::default()), // Always public, otherwise we can't import
        // it from an outside module.
        Some(variable_tracker),
        ExpandMode::FuncImpl,
    );
    let mut body = quote::quote! {};

    // Expand every statement of the original body in order.
    for statement in func.block.stmts.iter() {
        let tokens = codegen_statement(statement, 0, variable_tracker);
        body.extend(tokens);
    }

    let is_in_error = !variable_tracker.errors.is_empty();

    if is_in_error {
        // When there is an error, we don't generate the expand method, since it's only going to
        // create more errors that won't help fixing the issue.

        let mut code = quote::quote! {
            #[allow(dead_code)]
            #[allow(clippy::too_many_arguments)]
            #func
        };

        for err in variable_tracker.errors.drain(..) {
            code.extend(err.into_compile_error());
        }

        return Err(code);
    }

    // Piece of the module doc that mentions the launch function when present.
    let launch_doc = if launch {
        "and launch functions "
    } else {
        "function "
    };

    let launch = if launch {
        codegen_launch(&func.sig)
    } else {
        quote::quote! {}
    };

    let mod_name = &func.sig.ident;
    let vis = &func.vis;
    let doc = format!("Module containing the expand {launch_doc}of {mod_name}.");

    // The original function is kept as-is; the expanded version lives in a
    // sibling module named after the function.
    Ok(quote::quote! {
        #[allow(dead_code)]
        #[allow(clippy::too_many_arguments)]
        #func


        #[doc = #doc]
        #vis mod #mod_name {
            use super::*;

            #launch

            #[allow(unused_mut)]
            #[allow(clippy::too_many_arguments)]
            #signature {
                #body
            }

        }
    })
}
|
|
@ -1,244 +0,0 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
#[derive(new, Hash, PartialEq, Eq, Debug, Clone)]
/// Identifies a variable uniquely
pub struct VariableIdent {
    // Source name of the variable.
    name: String,
    // Redeclaration counter: 0 for the first declaration of this name in this
    // scope, incremented on each redeclaration (see `analyze_declare`).
    repeat: u8,
    // Scope the variable was declared in.
    scope: u8,
    // Set when the ident refers to a field access on a struct variable.
    field: Option<String>,
}
|
||||
|
||||
#[derive(new, Eq, PartialEq, Hash, Debug)]
/// Identifies a variable, with possible collisions when variables are redeclared
struct VariableKey {
    // Source name of the variable.
    name: String,
    // Scope the variable was declared in.
    scope: u8,
}
|
||||
|
||||
#[derive(Debug, Default)]
/// Tracks variable uses
pub(crate) struct VariableTracker {
    // For each variable name, every scope in which it was declared.
    scopes_declared: HashMap<String, Vec<u8>>,
    // Redeclaration counters filled during the analysis pass.
    analysis_repeats: HashMap<VariableKey, u8>,
    // Redeclaration counters filled again during the codegen pass, so uses
    // can be matched to the right redeclaration.
    codegen_repeats: HashMap<VariableKey, u8>,
    // Use counts and comptime flags per unique variable.
    variable_uses: HashMap<VariableIdent, VariableUse>,
    // Errors collected during analysis; emitted as compile errors by the
    // caller (see `codegen_cube`).
    pub errors: Vec<syn::Error>,
}
|
||||
|
||||
#[derive(Debug, Default)]
/// Encapsulates number of uses and whether this implies cloning
pub(crate) struct VariableUse {
    // Remaining number of uses; decremented during codegen.
    pub num_used: usize,
    // Whether the variable holds a comptime value.
    pub is_comptime: bool,
}
|
||||
|
||||
impl VariableUse {
|
||||
pub fn should_clone(&self) -> bool {
|
||||
self.num_used > 1
|
||||
}
|
||||
}
|
||||
|
||||
impl VariableTracker {
    /// During analysis, tracks a variable declaration
    pub(crate) fn analyze_declare(&mut self, name: String, scope: u8, is_comptime: bool) {
        // Remember that this name was declared in this scope (once per scope).
        if let Some(scopes) = self.scopes_declared.get_mut(&name) {
            if !scopes.contains(&scope) {
                scopes.push(scope);
            }
        } else {
            self.scopes_declared.insert(name.clone(), vec![scope]);
        }

        // Compute the redeclaration index: 0 on first declaration, then +1
        // for every redeclaration of the same name in the same scope.
        let key = VariableKey::new(name.clone(), scope);
        let repeat = if let Some(count) = self.analysis_repeats.get_mut(&key) {
            *count += 1;
            *count
        } else {
            self.analysis_repeats.insert(key, 0);
            0
        };

        // The declaration itself counts as one use.
        let analysis = VariableUse {
            num_used: 1,
            is_comptime,
        };
        let variable_ident = VariableIdent::new(name, repeat, scope, None);
        self.variable_uses.insert(variable_ident, analysis);
    }

    /// During analysis, tracks a variable use
    pub(crate) fn analyze_reuse(&mut self, ident: &syn::Ident, scope: u8, field: Option<String>) {
        let name = ident.to_string();

        // `None` looks like an ident syntactically but is never a variable.
        if name == "None" {
            return;
        }

        let scopes_declared = match self.scopes_declared.get(&name) {
            Some(val) => val,
            None => {
                self.errors
                    .push(syn::Error::new_spanned(ident, "Variable not declared"));
                return;
            }
        };

        // Resolve to the innermost declaring scope at or above the use site.
        let scope = *scopes_declared
            .iter()
            .filter(|s| **s <= scope)
            .max()
            .unwrap();
        let key = VariableKey::new(name.clone(), scope);

        // If the name and scope do not match a declared variable,
        // then we are using a variable declared in a parent scope, and
        // cloning must always happen, therefore no need for further analysis
        if let Some(repeat) = self.analysis_repeats.get(&key) {
            let variable = VariableIdent::new(name, *repeat, scope, field);
            self.analyze(&variable);
        }
    }

    /// Increments variable use and its parent struct if need be
    fn analyze(&mut self, variable_ident: &VariableIdent) {
        match self.variable_uses.get_mut(variable_ident) {
            Some(variable_use) => {
                variable_use.num_used += 1;
            }
            None => {
                // If variable was not inserted yet, it must be a field
                if variable_ident.field.is_some() {
                    let mut parent_ident = variable_ident.clone();
                    parent_ident.field = None;
                    let parent = self.variable_uses.get(&parent_ident).unwrap();

                    // The field inherits the comptime flag of its parent.
                    let attr_analysis = VariableUse {
                        num_used: 1,
                        is_comptime: parent.is_comptime,
                    };
                    self.variable_uses
                        .insert(variable_ident.clone(), attr_analysis);
                } else {
                    panic!("Variable not declared");
                }
            }
        };

        // Whether a field was previously seen or not, we must increase the use of the parent struct
        if variable_ident.field.is_some() {
            let mut declaration_ident = variable_ident.clone();
            declaration_ident.field = None;
            let declaration = self
                .variable_uses
                .get_mut(&declaration_ident)
                .unwrap_or_else(|| panic!("Struct {:?} does not exist", declaration_ident));
            declaration.num_used += 1;
        }
    }

    /// During codegen, tracks a variable declaration.
    /// This must be done again to know on what repeat a use occurs
    pub(crate) fn codegen_declare(&mut self, name: String, scope: u8) {
        let key = VariableKey::new(name.clone(), scope);
        if let Some(count) = self.codegen_repeats.get_mut(&key) {
            *count += 1;
        } else {
            self.codegen_repeats.insert(key, 0);
        }
    }

    /// During codegen, tracks a variable use.
    ///
    /// Decrements the remaining use counts and returns
    /// `(should_clone, is_comptime)` for the use site.
    pub(crate) fn codegen_reuse(
        &mut self,
        name: String,
        scope: u8,
        field: Option<String>,
    ) -> Result<(bool, bool), VariableReuseError> {
        let scopes_declared = self
            .scopes_declared
            .get(&name)
            .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?;
        // Innermost declaring scope at or above the use site.
        let scope_declared = *scopes_declared
            .iter()
            .filter(|s| **s <= scope)
            .max()
            .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?;

        let key = VariableKey::new(name.clone(), scope_declared);
        let repeat = self.codegen_repeats.get(&key).unwrap_or(&0);
        let ident = VariableIdent::new(name.clone(), *repeat, scope_declared, field.clone());

        // A field use also consumes one use of its parent struct.
        let should_clone_parent = if field.is_some() {
            let struct_ident = VariableIdent::new(name.clone(), *repeat, scope_declared, None);
            let parent_analysis = self
                .variable_uses
                .get_mut(&struct_ident)
                .ok_or_else(|| VariableNotFound::new(name.clone(), scope_declared, None))?;

            parent_analysis.num_used -= 1;
            parent_analysis.should_clone()
        } else {
            false
        };

        let analysis = self
            .variable_uses
            .get_mut(&ident)
            .ok_or_else(|| VariableNotFound::new(name, scope_declared, field))?;

        analysis.num_used -= 1;
        // Uses from a child scope always clone (scope_declared != scope).
        let should_clone =
            analysis.should_clone() || should_clone_parent || scope_declared != scope;
        Ok((should_clone, analysis.is_comptime))
    }

    /// Marks an already-tracked variable (or field) as comptime.
    pub fn set_as_comptime(
        &mut self,
        name: String,
        scope: u8,
        field: Option<String>,
    ) -> Result<(), VariableReuseError> {
        let scopes_declared = self
            .scopes_declared
            .get(&name)
            .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?;
        // Innermost declaring scope at or above the use site.
        let scope_declared = *scopes_declared
            .iter()
            .filter(|s| **s <= scope)
            .max()
            .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?;

        let key = VariableKey::new(name.clone(), scope_declared);
        let repeat = self.codegen_repeats.get(&key).unwrap_or(&0);
        let ident = VariableIdent::new(name.clone(), *repeat, scope_declared, field.clone());

        let analysis = self
            .variable_uses
            .get_mut(&ident)
            .ok_or_else(|| VariableNotFound::new(name, scope_declared, field))?;

        analysis.is_comptime = true;

        Ok(())
    }
}
|
||||
|
||||
#[derive(new, Debug)]
/// Raised when a variable is used without a matching declaration.
/// Fields are retained only for their `Debug` output.
pub struct VariableNotFound {
    _name: String,
    _scope: u8,
    _field: Option<String>,
}
|
||||
|
||||
#[derive(Debug)]
#[allow(dead_code)]
/// Errors that can occur while tracking variable reuse.
pub enum VariableReuseError {
    VariableNotFound(VariableNotFound),
}
|
||||
|
||||
// Enables `?` on `VariableNotFound` in functions returning
// `Result<_, VariableReuseError>` (see `codegen_reuse`).
impl From<VariableNotFound> for VariableReuseError {
    fn from(value: VariableNotFound) -> Self {
        Self::VariableNotFound(value)
    }
}
|
|
@ -1,37 +0,0 @@
|
|||
[package]
|
||||
authors = [
|
||||
"nathanielsimard <nathaniel.simard.42@gmail.com>",
|
||||
"louisfd <louisfd94@gmail.com>",
|
||||
]
|
||||
categories = ["science"]
|
||||
description = "Cube Compute Language (CubeCL) is a subset of Rust that can be executed on accelerators for compute intensive tasks."
|
||||
edition.workspace = true
|
||||
keywords = []
|
||||
license.workspace = true
|
||||
name = "burn-cube"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/burn-cube"
|
||||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["tensor"]
|
||||
std = []
|
||||
template = []
|
||||
tensor = ["burn-tensor"]
|
||||
export_tests = []
|
||||
|
||||
[dependencies]
|
||||
burn-compute = { path = "../burn-compute", version = "0.14.0", default-features = false }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.14.0", default-features = false, optional = true }
|
||||
|
||||
bytemuck = { workspace = true }
|
||||
half = { workspace = true, features = ["bytemuck"] }
|
||||
serde = { workspace = true }
|
||||
burn-cube-macros = { path = "../burn-cube-macros", version = "0.14.0" }
|
||||
derive-new = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
|
||||
log = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
trybuild = "1"
|
|
@ -1,202 +0,0 @@
|
|||
<div align="center">
|
||||
<img src="../burn-cube/assets/logo.drawio.svg" width="400px"/>
|
||||
|
||||
<br />
|
||||
<br />
|
||||
|
||||
[![Rust Version](https://img.shields.io/badge/Rust-1.79.0+-blue)](https://releases.rs/docs/1.79.0)
|
||||
![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)
|
||||
|
||||
---
|
||||
|
||||
**Multi-platform high-performance compute language extension for Rust.**
|
||||
<br/>
|
||||
|
||||
</div>
|
||||
|
||||
## TL;DR
|
||||
|
||||
With CubeCL, you can program your GPU using Rust leveraging zero-cost abstraction to create maintainable, flexible and optimal compute kernels.
|
||||
|
||||
## Motivation
|
||||
|
||||
The goal of CubeCL is to ease the pain of writing highly optimized compute kernels that are portable across hardware.
|
||||
There is currently no adequate solution when you want optimal performance while still being multi-platform.
|
||||
You either have to write custom kernels for different hardware, often with different languages such as CUDA, Metal, or ROCm.
|
||||
To fix this, we created a Just-in-Time compiler with three core features: **automatic vectorization**, **comptime**, and **autotune**!
|
||||
|
||||
These features are extremely useful for anyone writing high-performance kernels, even when portability is not a concern.
|
||||
They improve code composability, reusability, testability, and maintainability, all while staying optimal.
|
||||
|
||||
### Disclaimer & History
|
||||
|
||||
CubeCL is currently in **alpha**.
|
||||
The only supported runtimes are CUDA and WebGPU for now.
|
||||
It's easy to add more GPU runtimes and we intend to support Metal, ROCm, and Vulkan; contributions are welcome!
|
||||
We also want to have an optimized JIT CPU runtime with SIMD instructions, leveraging [Cranelift](https://cranelift.dev).
|
||||
|
||||
While CubeCL is currently in use in [Burn](https://burn.dev), there are still a lot of rough edges; it isn't refined yet.
|
||||
The project started as a WebGPU-only backend for Burn.
|
||||
As we optimized it, we realized that we needed an intermediate representation (IR) that could be optimized then compiled to WGSL.
|
||||
Having an IR made it easy to support another compilation target, so we made a CUDA runtime.
|
||||
However, writing kernels directly in that IR wasn't easy, so we created a Rust frontend using the [syn](https://github.com/dtolnay/syn) crate.
|
||||
Navigating the differences between CUDA and WebGPU, while leveraging both platforms, forced us to come up with general concepts that worked everywhere.
|
||||
Hence, CubeCL was born!
|
||||
|
||||
## Design
|
||||
|
||||
CubeCL is designed around - you guessed it - Cubes! More specifically, it's based on cuboids, because not all axes are the same size.
|
||||
Since all compute APIs need to map to the hardware, which are tiles that can be accessed using a 3D representation, our topology can easily be mapped to concepts from other APIs.
|
||||
|
||||
<div align="center">
|
||||
|
||||
### CubeCL - Topology
|
||||
|
||||
<img src="./assets/cubecl.drawio.svg" width="100%"/>
|
||||
<br />
|
||||
</div>
|
||||
<br />
|
||||
|
||||
_A cube is composed of units, so a 3x3x3 cube has 27 units that can be accessed by their positions along the x, y, and z axes.
|
||||
Similarly, a hyper-cube is composed of cubes, just as a cube is composed of units.
|
||||
Each cube in the hyper-cube can be accessed by its position relative to the hyper-cube along the x, y, and z axes.
|
||||
Hence, a hyper-cube of 3x3x3 will have 27 cubes.
|
||||
In this example, the total number of working units would be 27 x 27 = 729._
|
||||
|
||||
<details>
|
||||
<summary>Topology Equivalence 👇</summary>
|
||||
<br />
|
||||
|
||||
Since all topology variables are constant within the kernel entry point, we chose to use the Rust constant syntax with capital letters.
|
||||
Often when creating kernels, we don't always care about the relative position of a unit within a cube along each axis, but often we only care about its position in general.
|
||||
Therefore, each kind of variable also has its own axis-independent variable, which is often not present in other languages, except WebGPU with `local_invocation_index`.
|
||||
|
||||
<br />
|
||||
|
||||
| CubeCL | CUDA | WebGPU |
|
||||
| -------------- | ----------- | ---------------------- |
|
||||
| CUBE_COUNT | N/A | N/A |
|
||||
| CUBE_COUNT_X | gridDim.x | num_workgroups.x |
|
||||
| CUBE_COUNT_Y | gridDim.y | num_workgroups.y |
|
||||
| CUBE_COUNT_Z | gridDim.z | num_workgroups.z |
|
||||
| CUBE_POS | N/A | N/A |
|
||||
| CUBE_POS_X | blockIdx.x | workgroup.x |
|
||||
| CUBE_POS_Y | blockIdx.y | workgroup.y |
|
||||
| CUBE_POS_Z | blockIdx.z | workgroup.z |
|
||||
| CUBE_DIM | N/A | N/A |
|
||||
| CUBE_DIM_X | blockDim.x | workgroup_size.x |
|
||||
| CUBE_DIM_Y | blockDim.y | workgroup_size.y |
|
||||
| CUBE_DIM_Z | blockDim.z | workgroup_size.z |
|
||||
| UNIT_POS | N/A | local_invocation_index |
|
||||
| UNIT_POS_X | threadIdx.x | local_invocation_id.x |
|
||||
| UNIT_POS_Y | threadIdx.y | local_invocation_id.y |
|
||||
| UNIT_POS_Z | threadIdx.z | local_invocation_id.z |
|
||||
| SUBCUBE_DIM | warpSize | subgroup_size |
|
||||
| ABSOLUTE_POS | N/A | N/A |
|
||||
| ABSOLUTE_POS_X | N/A | global_id.x |
|
||||
| ABSOLUTE_POS_Y | N/A | global_id.y |
|
||||
| ABSOLUTE_POS_Z | N/A | global_id.z |
|
||||
|
||||
</details>
|
||||
|
||||
## Special Features
|
||||
|
||||
#### Automatic Vectorization
|
||||
|
||||
High-performance kernels should rely on SIMD instructions whenever possible, but doing so can quickly get pretty complicated!
|
||||
With CubeCL, you can specify the vectorization factor of each input variable when launching a kernel.
|
||||
Inside the kernel code, you still use only one type, which is dynamically vectorized and supports automatic broadcasting.
|
||||
The runtimes are able to compile kernels and have all the necessary information to use the best instruction!
|
||||
However, since the algorithmic behavior may depend on the vectorization factor, CubeCL allows you to access it directly in the kernel when needed, without any performance loss, using the comptime system!
|
||||
|
||||
#### Comptime
|
||||
|
||||
CubeCL isn't just a new compute language: though it feels like you are writing GPU kernels, you are, in fact, writing compiler plugins that you can fully customize!
|
||||
Comptime is a way to modify the compiler IR at runtime when compiling a kernel for the first time.
|
||||
|
||||
This enables lots of optimizations and flexibility without having to write many separate variants of the same kernels to ensure maximal performance.
|
||||
|
||||
| Feature | Description |
|
||||
| ------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Instruction Specialization** | Not all instructions are available on all hardware, but when a specialized one exists, it should be enabled with a simple if statement. |
|
||||
| **Automatic Vectorization** | When you can use SIMD instructions, you should! But since not all hardware supports the same vectorization factors, it can be injected at runtime! |
|
||||
| **Loop Unrolling** | You may want multiple flavors of the same kernel, with loop unrolling for only a certain range of values. This can be configured easily with Comptime. |
|
||||
| **Shape Specialization** | For deep learning kernels, it's often crucial to rely on different kernels for different input sizes; you can do it by passing the shape information as Comptime values. |
|
||||
| **Compile Time Calculation** | In general, you can calculate a constant using Rust runtime properties and inject it into a kernel during its compilation, to avoid recalculating it during each execution. |
|
||||
|
||||
#### Autotuning
|
||||
|
||||
Autotuning drastically simplifies kernel selection by running small benchmarks at runtime to figure out the best kernels with the best configurations to run on the current hardware; an essential feature for portability.
|
||||
This feature combines gracefully with comptime to test the effect of different comptime values on performance; sometimes it can be surprising!
|
||||
|
||||
Even if the benchmarks may add some overhead when running the application for the first time, the information gets cached on the device and will be reused.
|
||||
It is usually a no-brainer trade-off for throughput-oriented programs such as deep learning models.
|
||||
You can even ship the autotune cache with your program, reducing cold start time when you have more control over the deployment target.
|
||||
|
||||
## Example
|
||||
|
||||
CubeCL is designed to be easy to use for Rust programmers: it relies on the same syntax and is fully integrated with the language.
|
||||
You can simply add an attribute on the top of a Rust function for it to be executed on the GPU.
|
||||
|
||||
```rust
|
||||
#[cube(launch)]
|
||||
fn gelu<F: Float>(input: &Array<F>, output: &mut Array<F>) {
|
||||
if ABSOLUTE_POS < input.len() {
|
||||
    let x = input[ABSOLUTE_POS];
|
||||
let gelu = x * (1 + erf(x / sqrt(2))) / 2;
|
||||
output[ABSOLUTE_POS] = gelu;
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
type Runtime = CudaRuntime;
|
||||
|
||||
let device = Default::default();
|
||||
let client = Runtime::client(&device);
|
||||
|
||||
let input_handle = client.create(f32::as_bytes(&[-1., 0., 1., 5.]));
|
||||
    let output_handle = client.empty(4 * core::mem::size_of::<f32>());
|
||||
|
||||
gelu::launch::<F32, Runtime>(
|
||||
client,
|
||||
CubeCount::new(1, 1, 1),
|
||||
CubeDim::new(4, 1, 1),
|
||||
&input_handle,
|
||||
&output_handle,
|
||||
);
|
||||
|
||||
let output = client.read(output_handle.binding()).read_sync().unwrap();
|
||||
let output = f32::from_bytes(&output);
|
||||
|
||||
// Should be [-0.1587, 0.0000, 0.8413, 5.0000]
|
||||
println!("{output:?}");
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
The `cube` attribute generates the code that is needed to compile a kernel.
|
||||
In the case above, the function `gelu_expand` and `gelu_launch` are automatically generated.
|
||||
This allows you to compose Cube functions easily:
|
||||
|
||||
```rust
|
||||
|
||||
#[cube]
|
||||
fn gelu_scalar<F: Float>(x: F) -> F {
|
||||
x * (1 + erf(x / sqrt(2))) / 2
|
||||
}
|
||||
|
||||
#[cube(launch)]
|
||||
fn gelu<F: Float>(input: Array<F>, mut output: Array<F>) {
|
||||
    if ABSOLUTE_POS < input.len() {
|
||||
output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Note that you don't have to specify `launch` in a function that is only used by another Cube function.
|
||||
In addition, you can have return types without problem, which isn't the case when you are writing an entry point to a kernel using the `launch` attribute.
|
||||
The function `gelu_expand` will actually use `gelu_scalar_expand`, making it easy to combine your functions.
|
||||
|
||||
## Resources
|
||||
|
||||
If you have any questions or want to contribute, don't hesitate to join the [Discord](https://discord.gg/uPEBbYYDB6).
|
Binary file not shown.
Before Width: | Height: | Size: 32 KiB |
File diff suppressed because one or more lines are too long
Before Width: | Height: | Size: 240 KiB |
File diff suppressed because one or more lines are too long
Before Width: | Height: | Size: 807 KiB |
|
@ -1,21 +0,0 @@
|
|||
use crate::ir::{Elem, KernelDefinition};
|
||||
use std::fmt::Display;
|
||||
|
||||
/// Trait for compiled code representation.
///
/// Implementors hold the fully compiled form of a kernel and can render it
/// as text via [`Display`].
pub trait CompilerRepresentation: Display {
    /// Computes and returns the shared memory size, in bytes, required by the
    /// compiled kernel.
    fn shared_memory_size(&self) -> usize;
}

/// Compiles the representation into its own representation that can be formatted into tokens.
pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
    /// The representation for the compiled code.
    type Representation: CompilerRepresentation;

    /// Compiles the [kernel definition](KernelDefinition) into the compiler's representation.
    fn compile(kernel: KernelDefinition) -> Self::Representation;
    /// The size of the given element in bytes.
    fn elem_size(elem: Elem) -> usize;
    /// The maximal size of a shared memory, in bytes, supported by the target.
    fn max_shared_memory_size() -> usize;
}
|
|
@ -1,358 +0,0 @@
|
|||
use crate::compute::{CubeCount, KernelTask};
|
||||
use crate::frontend::TensorHandle;
|
||||
use crate::ir::Elem;
|
||||
use crate::pod::CubeElement;
|
||||
use crate::{calculate_cube_count_elemwise, Kernel, Runtime, SUBCUBE_DIM_APPROX};
|
||||
use burn_compute::client::ComputeClient;
|
||||
use burn_compute::server::{Binding, ComputeServer, Handle};
|
||||
|
||||
/// The position of the input or output to calculate the number of cubes to launch.
|
||||
pub enum CubeCountSettings<S: ComputeServer> {
|
||||
Input { pos: usize },
|
||||
Output { pos: usize },
|
||||
Custom(CubeCount<S>),
|
||||
}
|
||||
|
||||
/// Type-state builder for launching a kernel.
///
/// `Scalars` grows from `()` up to a 3-tuple of scalar slices as
/// `with_scalars` is called, so the number of scalar groups is tracked at
/// compile time.
pub struct Execution<'h, K, R: Runtime, Scalars> {
    // Scalar groups registered so far (unit when none).
    scalars: Scalars,
    // Client used to create buffers and submit the kernel.
    client: ComputeClient<R::Server, R::Channel>,
    // The kernel to launch.
    kernel: K,
    // Input tensor handles, bound before outputs.
    inputs: &'h [TensorHandle<'h, R>],
    // Output tensor handles.
    outputs: &'h [TensorHandle<'h, R>],
}
|
||||
|
||||
impl<'h, K, R: Runtime> Execution<'h, K, R, ()> {
|
||||
pub fn start(
|
||||
kernel: K,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
) -> Execution<'h, K, R, ()> {
|
||||
Execution {
|
||||
scalars: (),
|
||||
client,
|
||||
kernel,
|
||||
inputs: &[],
|
||||
outputs: &[],
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn inputs(self, inputs: &'h [TensorHandle<'h, R>]) -> Execution<'h, K, R, ()> {
|
||||
Execution {
|
||||
scalars: self.scalars,
|
||||
client: self.client,
|
||||
kernel: self.kernel,
|
||||
inputs,
|
||||
outputs: self.outputs,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn outputs(self, outputs: &'h [TensorHandle<'h, R>]) -> Execution<'h, K, R, ()> {
|
||||
Execution {
|
||||
scalars: self.scalars,
|
||||
client: self.client,
|
||||
kernel: self.kernel,
|
||||
inputs: self.inputs,
|
||||
outputs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'h, K, R> Execution<'h, K, R, ()>
where
    K: Kernel + 'static,
    R: Runtime,
{
    /// Attach a first group of scalars, moving the builder into the
    /// one-scalar-group state.
    pub fn with_scalars<E>(self, scalars: &[E]) -> Execution<'h, K, R, (&[E],)> {
        Execution {
            scalars: (scalars,),
            client: self.client,
            kernel: self.kernel,
            inputs: self.inputs,
            outputs: self.outputs,
        }
    }
    /// Execute a dynamic kernel.
    #[allow(unused)]
    pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        // No scalar groups were registered, so the scalar element types are
        // irrelevant; `f32` is used as a placeholder.
        execute_dynamic::<R, K, f32, f32, f32>(
            self.inputs,
            self.outputs,
            None,
            None,
            None,
            self.kernel,
            launch,
            self.client,
        )
    }
}
|
||||
|
||||
impl<'h, 'a, K, R, E> Execution<'h, K, R, (&'a [E],)>
where
    K: Kernel + 'static,
    R: Runtime,
    E: CubeElement,
{
    /// Attach a second group of scalars, moving the builder into the
    /// two-scalar-group state.
    pub fn with_scalars<'b, E2>(
        self,
        scalars: &'b [E2],
    ) -> Execution<'h, K, R, (&'a [E], &'b [E2])> {
        Execution {
            scalars: (self.scalars.0, scalars),
            client: self.client,
            kernel: self.kernel,
            inputs: self.inputs,
            outputs: self.outputs,
        }
    }

    /// Execute a dynamic kernel.
    #[allow(unused)]
    pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        // Only the first scalar group is present; `f32` stands in for the
        // unused second and third element types.
        execute_dynamic::<R, K, E, f32, f32>(
            self.inputs,
            self.outputs,
            Some(self.scalars.0),
            None,
            None,
            self.kernel,
            launch,
            self.client,
        )
    }
}
|
||||
|
||||
impl<'h, 'a, 'b, K, R, E1, E2> Execution<'h, K, R, (&'a [E1], &'b [E2])>
where
    K: Kernel + 'static,
    R: Runtime,
    E1: CubeElement,
    E2: CubeElement,
{
    /// Attach a third group of scalars, moving the builder into the
    /// three-scalar-group (final) state.
    #[allow(unused, clippy::type_complexity)]
    pub fn with_scalars<'c, E3>(
        self,
        scalars: &'c [E3],
    ) -> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])> {
        Execution {
            scalars: (self.scalars.0, self.scalars.1, scalars),
            client: self.client,
            kernel: self.kernel,
            inputs: self.inputs,
            outputs: self.outputs,
        }
    }
    /// Execute a dynamic kernel.
    #[allow(clippy::too_many_arguments)]
    pub fn execute(self, launch: CubeCountSettings<R::Server>)
    where
        K: Kernel + 'static,
        R: Runtime,
    {
        // Two scalar groups registered; `f32` stands in for the unused third.
        execute_dynamic::<R, K, E1, E2, f32>(
            self.inputs,
            self.outputs,
            Some(self.scalars.0),
            Some(self.scalars.1),
            None,
            self.kernel,
            launch,
            self.client,
        )
    }
}
|
||||
|
||||
impl<'h, 'a, 'b, 'c, K, R, E1, E2, E3> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])>
where
    K: Kernel + 'static,
    R: Runtime,
    E1: CubeElement,
    E2: CubeElement,
    E3: CubeElement,
{
    /// Execute a dynamic kernel.
    ///
    /// All three scalar groups are registered; no further `with_scalars`
    /// state exists.
    #[allow(unused)]
    pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        execute_dynamic::<R, K, E1, E2, E3>(
            self.inputs,
            self.outputs,
            Some(self.scalars.0),
            Some(self.scalars.1),
            Some(self.scalars.2),
            self.kernel,
            launch,
            self.client,
        )
    }
}
|
||||
|
||||
/// Registers all bindings (tensor handles, the info buffer, then scalar
/// buffers — in that order) and submits the kernel to the compute client.
#[allow(clippy::too_many_arguments)]
fn execute_dynamic<R, K, E1, E2, E3>(
    inputs: &[TensorHandle<R>],
    outputs: &[TensorHandle<R>],
    scalars_1: Option<&[E1]>,
    scalars_2: Option<&[E2]>,
    scalars_3: Option<&[E3]>,
    kernel: K,
    launch: CubeCountSettings<R::Server>,
    client: ComputeClient<R::Server, R::Channel>,
) where
    K: Kernel + 'static,
    R: Runtime,
    E1: CubeElement,
    E2: CubeElement,
    E3: CubeElement,
{
    let settings = execute_settings(
        inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
    );
    let mut handles = settings.handles_tensors;

    // Binding order matters: tensors first (already in `handles_tensors`),
    // then the info buffer, then the scalar buffers.
    handles.push(settings.handle_info.binding());
    for handle in settings.handles_scalars.into_iter() {
        handles.push(handle.binding());
    }

    let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
    client.execute(kernel, settings.cube_count, handles);
}
|
||||
|
||||
/// Everything needed to submit one kernel execution.
struct ExecuteSettings<R: Runtime> {
    // Bindings of all input and output tensors, in registration order.
    handles_tensors: Vec<Binding<R::Server>>,
    // Buffer holding rank, strides, shapes (and lengths when required).
    handle_info: Handle<R::Server>,
    // One buffer per registered scalar group, ordered float, int, uint.
    handles_scalars: Vec<Handle<R::Server>>,
    // Number of cubes to launch.
    cube_count: CubeCount<R::Server>,
}
|
||||
|
||||
/// Builds the launch settings: tensor bindings, the metadata ("info")
/// buffer, scalar buffers and the cube count.
///
/// The info buffer layout is order-sensitive; see the inline comment below.
fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
    inputs: &'a [TensorHandle<R>],
    outputs: &'a [TensorHandle<R>],
    scalars_1: Option<&[E1]>,
    scalars_2: Option<&[E2]>,
    scalars_3: Option<&[E3]>,
    launch: CubeCountSettings<R::Server>,
    client: &ComputeClient<R::Server, R::Channel>,
) -> ExecuteSettings<R> {
    let mut info = Vec::new();
    let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2);

    // Inner function to fill the info buffer.
    // The first entry is the rank; assumes all tensors share it — the rank is
    // only written once, from the first registered tensor.
    let mut register_info_tensor = |strides: &[usize], shape: &[usize]| {
        if info.is_empty() {
            info.push(strides.len() as u32);
        }

        for s in strides.iter() {
            info.push(*s as u32);
        }
        for s in shape.iter() {
            info.push(*s as u32);
        }
    };

    let mut num_elems_output = 0;

    // We start by registering the inputs.
    for (i, input) in inputs.iter().enumerate() {
        if let CubeCountSettings::Input { pos } = &launch {
            if i == *pos {
                num_elems_output = calculate_num_elems_dyn_rank(input.shape);
            }
        };
        register_info_tensor(input.strides, input.shape);
        handles.push(input.handle.clone().binding());
    }

    // Then we follow with the outputs.
    for (i, output) in outputs.iter().enumerate() {
        if let CubeCountSettings::Output { pos } = &launch {
            if i == *pos {
                num_elems_output = calculate_num_elems_dyn_rank(output.shape);
            }
        };
        register_info_tensor(output.strides, output.shape);
        handles.push(output.handle.clone().binding());
    }

    // Info buffer layout:
    // [2, I0stride0, I0stride1, I0shape0, I0shape1i, I1... O0..., I0len, I1len1, O0len]
    if R::require_array_lengths() {
        for input in inputs.iter() {
            let len = calculate_num_elems_dyn_rank(input.shape);
            info.push(len as u32);
        }

        for output in outputs.iter() {
            let len = calculate_num_elems_dyn_rank(output.shape);
            info.push(len as u32);
        }
    }

    let info = client.create(bytemuck::cast_slice(&info));

    // Finally we finish with the named bindings.
    let handles_scalars =
        create_scalar_handles::<R, E1, E2, E3>(scalars_1, scalars_2, scalars_3, client);

    // `num_elems_output` stays 0 when `launch` is `Custom`, which is fine
    // because the heuristic branch is not taken in that case.
    let cube_count = match launch {
        CubeCountSettings::Custom(count) => count,
        _ => calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX),
    };

    ExecuteSettings {
        handles_tensors: handles,
        handle_info: info,
        handles_scalars,
        cube_count,
    }
}
|
||||
|
||||
/// Creates one buffer per provided scalar group, emitted in element-type
/// priority order (float, int, uint) rather than argument order.
fn create_scalar_handles<R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
    scalars_0: Option<&[E1]>,
    scalars_1: Option<&[E2]>,
    scalars_2: Option<&[E3]>,
    client: &ComputeClient<R::Server, R::Channel>,
) -> Vec<Handle<R::Server>> {
    // It is crucial that scalars follow this order: float, int, uint
    let element_priority = |elem: Elem| match elem {
        Elem::Float(_) => 0,
        Elem::Int(_) => 1,
        Elem::UInt => 2,
        Elem::Bool => panic!("Bool scalars are not supported"),
    };
    let scalar_priorities: [usize; 3] = [
        element_priority(E1::cube_elem()),
        element_priority(E2::cube_elem()),
        element_priority(E3::cube_elem()),
    ];

    // Walk priorities 0..3 and emit the matching groups in that order.
    // The per-index branches are needed because E1/E2/E3 are distinct types.
    let mut handles_scalars = Vec::new();
    for i in 0..3 {
        for (j, scalar_priority) in scalar_priorities.iter().enumerate() {
            if scalar_priority == &i {
                if j == 0 {
                    if let Some(values) = &scalars_0 {
                        handles_scalars.push(client.create(bytemuck::cast_slice(values)));
                    }
                } else if j == 1 {
                    if let Some(values) = &scalars_1 {
                        handles_scalars.push(client.create(bytemuck::cast_slice(values)));
                    }
                } else if j == 2 {
                    if let Some(values) = &scalars_2 {
                        handles_scalars.push(client.create(bytemuck::cast_slice(values)));
                    }
                }
            }
        }
    }

    handles_scalars
}
|
||||
|
||||
/// Total number of elements for a tensor of the given `shape`.
///
/// An empty shape (rank 0) yields 1, matching the scalar convention and the
/// original accumulator-based implementation.
pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
    shape.iter().product()
}
|
|
@ -1,560 +0,0 @@
|
|||
use super::Compiler;
|
||||
use crate::{
|
||||
ir::{
|
||||
Binding, CubeDim, Elem, Item, KernelDefinition, Location, ReadingStrategy, Scope, Variable,
|
||||
Vectorization, Visibility,
|
||||
},
|
||||
Runtime,
|
||||
};
|
||||
|
||||
/// The kernel integrator allows you to create a [kernel definition](KernelDefinition) based on
/// [kernel expansion](KernelExpansion) and [kernel settings](KernelSettings).
#[derive(Clone)]
pub struct KernelIntegrator {
    // The expansion being integrated; its inputs/outputs are drained during
    // integration.
    expansion: KernelExpansion,
    // Bindings produced for input arrays.
    input_bindings: Vec<Binding>,
    // Bindings produced for output arrays.
    output_bindings: Vec<Binding>,
    // Named bindings (e.g. scalar groups keyed by element type).
    named_bindings: Vec<(String, Binding)>,
}
|
||||
|
||||
/// The information necessary to compile a [kernel definition](KernelDefinition).
#[derive(Clone)]
pub struct KernelExpansion {
    // Descriptions of the kernel inputs (arrays and scalar groups).
    pub inputs: Vec<InputInfo>,
    // Descriptions of the kernel outputs.
    pub outputs: Vec<OutputInfo>,
    // The kernel body.
    pub scope: Scope,
}
|
||||
|
||||
/// Simply indicate the output that can be replaced by the input.
#[derive(new, Clone, Copy, Debug)]
pub struct InplaceMapping {
    /// Input position.
    pub pos_input: usize,
    /// Output position.
    pub pos_output: usize,
}
|
||||
|
||||
/// A vectorization factor applied to a single input or output position,
/// as opposed to the global factor in [`KernelSettings`].
#[derive(Clone, Copy, Debug)]
enum VectorizationPartial {
    /// Vectorize the input at `pos` with the given factor.
    Input {
        pos: usize,
        vectorization: Vectorization,
    },
    /// Vectorize the output at `pos` with the given factor.
    Output {
        pos: usize,
        vectorization: Vectorization,
    },
}
|
||||
|
||||
/// Compilation settings for a kernel; its `Display` form is part of the
/// kernel cache id.
#[derive(Default, Clone)]
pub struct KernelSettings {
    // Inplace input→output mappings; drained during integration.
    pub mappings: Vec<InplaceMapping>,
    // Vectorization applied to all inputs and outputs, overriding partials.
    vectorization_global: Option<Vectorization>,
    // Per-position vectorization factors.
    vectorization_partial: Vec<VectorizationPartial>,
    // The cube dimensions the kernel is compiled for.
    cube_dim: CubeDim,
    // Per-input reading strategies (input id, strategy).
    pub reading_strategy: Vec<(u16, ReadingStrategy)>,
}
|
||||
|
||||
impl core::fmt::Display for KernelSettings {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// The goal of this implementation is to generate the shortest representation
|
||||
// that won't clash with any other compilation settings. This is crucial since we rely on
|
||||
// this representation to know when to compile a new version of a kernel.
|
||||
//
|
||||
// Each main section starts with a letter that can't be used by other main sections:
|
||||
//
|
||||
// * Mapping: m
|
||||
// * Input: i
|
||||
// * Output: o
|
||||
//
|
||||
// * Reading Strategy: r
|
||||
// * Output layout: o
|
||||
// * Plain: p
|
||||
//
|
||||
// * Vectorization Global: vg{factor}
|
||||
// * Vectorization Partial Input: v{factor}i{pos}
|
||||
// * Vectorization Partial Output: vo
|
||||
// * Cube Dim X: x
|
||||
// * Cube Dim Y: y
|
||||
// * Cube Dim Z: z
|
||||
f.write_str("m")?;
|
||||
for mapping in self.mappings.iter() {
|
||||
f.write_fmt(format_args!(
|
||||
"i{}o{}",
|
||||
mapping.pos_input, mapping.pos_output
|
||||
))?;
|
||||
}
|
||||
|
||||
f.write_str("r")?;
|
||||
|
||||
for (input, strategy) in self.reading_strategy.iter() {
|
||||
match strategy {
|
||||
ReadingStrategy::OutputLayout => f.write_fmt(format_args!("i{}o", input)),
|
||||
ReadingStrategy::Plain => f.write_fmt(format_args!("i{}p", input)),
|
||||
}?;
|
||||
}
|
||||
|
||||
match self.vectorization_global {
|
||||
Some(vectorization) => f.write_fmt(format_args!("vg{}", vectorization))?,
|
||||
None => f.write_str("vn")?,
|
||||
};
|
||||
|
||||
for vectorization in self.vectorization_partial.iter() {
|
||||
match vectorization {
|
||||
VectorizationPartial::Input { pos, vectorization } => {
|
||||
f.write_fmt(format_args!("v{vectorization}i{pos}"))?
|
||||
}
|
||||
VectorizationPartial::Output { pos, vectorization } => {
|
||||
f.write_fmt(format_args!("v{vectorization}o{pos}"))?
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
f.write_fmt(format_args!(
|
||||
"x{}y{}z{}",
|
||||
self.cube_dim.x, self.cube_dim.y, self.cube_dim.x
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelSettings {
|
||||
/// Compile the shader with vectorization enabled for all inputs and outputs.
|
||||
#[allow(dead_code)]
|
||||
pub fn vectorize_global(mut self, vectorization: Vectorization) -> Self {
|
||||
self.vectorization_global = Some(vectorization);
|
||||
self
|
||||
}
|
||||
|
||||
/// Compile the shader with vectorization enabled for an input.
|
||||
#[allow(dead_code)]
|
||||
pub fn vectorize_input(mut self, position: usize, vectorization: Vectorization) -> Self {
|
||||
// Not setting the vectorization factor when it's the default value reduces the kernel id
|
||||
// size.
|
||||
if vectorization == 1 {
|
||||
return self;
|
||||
}
|
||||
|
||||
self.vectorization_partial
|
||||
.push(VectorizationPartial::Input {
|
||||
pos: position,
|
||||
vectorization,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Compile the shader with vectorization enabled for an output.
|
||||
#[allow(dead_code)]
|
||||
pub fn vectorize_output(mut self, position: usize, vectorization: Vectorization) -> Self {
|
||||
// Not setting the vectorization factor when it's the default value reduces the kernel id
|
||||
// size.
|
||||
if vectorization == 1 {
|
||||
return self;
|
||||
}
|
||||
|
||||
self.vectorization_partial
|
||||
.push(VectorizationPartial::Output {
|
||||
pos: position,
|
||||
vectorization,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Fetch the vectorization for the provided input position.
|
||||
pub fn vectorization_input(&self, position: usize) -> Vectorization {
|
||||
if let Some(vec) = self.vectorization_global {
|
||||
return vec;
|
||||
}
|
||||
|
||||
for partial in self.vectorization_partial.iter() {
|
||||
if let VectorizationPartial::Input { pos, vectorization } = partial {
|
||||
if *pos == position {
|
||||
return *vectorization;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
1
|
||||
}
|
||||
|
||||
/// Fetch the vectorization for the provided output position.
|
||||
pub fn vectorization_output(&self, position: usize) -> Vectorization {
|
||||
if let Some(vec) = self.vectorization_global {
|
||||
return vec;
|
||||
}
|
||||
|
||||
for partial in self.vectorization_partial.iter() {
|
||||
if let VectorizationPartial::Output { pos, vectorization } = partial {
|
||||
if *pos == position {
|
||||
return *vectorization;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
1
|
||||
}
|
||||
|
||||
/// Compile the shader with inplace enabled by the given [mapping](InplaceMapping).
|
||||
///
|
||||
/// Notes:
|
||||
///
|
||||
/// You should favor using `dynamic_settings` when using fusion, since the mapping is going to
|
||||
/// be created from the runtime information.
|
||||
pub fn inplace(mut self, mappings: Vec<InplaceMapping>) -> Self {
|
||||
self.mappings = mappings;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set cube dimension.
|
||||
#[allow(dead_code)]
|
||||
pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
|
||||
self.cube_dim = cube_dim;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether `strides` describe a contiguous layout, i.e. strides are
/// non-increasing from the outermost to the innermost dimension.
#[allow(dead_code)]
fn is_contiguous(strides: &[usize]) -> bool {
    // Empty and single-element slices are trivially contiguous.
    strides.windows(2).all(|pair| pair[0] >= pair[1])
}
|
||||
|
||||
/// Information related to an input.
#[derive(Clone, Debug)]
pub enum InputInfo {
    /// An input array binding with the given item type and visibility.
    Array { item: Item, visibility: Visibility },
    /// A group of `size` scalars sharing the element type `elem`.
    Scalar { elem: Elem, size: usize },
}
|
||||
|
||||
impl InputInfo {
|
||||
/// The item type of the input.
|
||||
#[allow(dead_code)]
|
||||
pub fn item(&self) -> Item {
|
||||
match self {
|
||||
InputInfo::Array {
|
||||
item,
|
||||
visibility: _,
|
||||
} => *item,
|
||||
InputInfo::Scalar { elem, size: _ } => Item::new(*elem),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OutputInfo {
|
||||
/// The item type of the input.
|
||||
#[allow(dead_code)]
|
||||
pub fn item(&self) -> Item {
|
||||
match self {
|
||||
OutputInfo::ArrayWrite {
|
||||
item,
|
||||
local: _,
|
||||
position: _,
|
||||
} => *item,
|
||||
OutputInfo::InputArrayWrite {
|
||||
item,
|
||||
input: _,
|
||||
local: _,
|
||||
position: _,
|
||||
} => *item,
|
||||
OutputInfo::Array { item } => *item,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Information related to an output.
#[derive(Clone, Debug)]
pub enum OutputInfo {
    /// Write the local variable to a new array.
    ///
    /// This will create a new binding in the [kernel definition](KernelDefinition).
    ArrayWrite {
        item: Item,
        /// Id of the local variable to write out.
        local: u16,
        /// Position variable at which the write happens.
        position: Variable,
    },
    /// Write the local variable to an existing input binding.
    InputArrayWrite {
        item: Item,
        /// Position of the input binding written in place.
        input: u16,
        /// Id of the local variable to write out.
        local: u16,
        /// Position variable at which the write happens.
        position: Variable,
    },
    /// Simply register the output, but don't automatically add a write to it.
    ///
    /// Useful when a procedure writes to the output using operations.
    Array { item: Item },
}
|
||||
|
||||
impl OutputInfo {
|
||||
#[allow(dead_code)]
|
||||
pub fn elem_size<R: Runtime>(&self) -> usize {
|
||||
let elem = match self {
|
||||
OutputInfo::ArrayWrite {
|
||||
item,
|
||||
local: _,
|
||||
position: _,
|
||||
} => bool_elem(item.elem()),
|
||||
OutputInfo::InputArrayWrite {
|
||||
item,
|
||||
input: _,
|
||||
local: _,
|
||||
position: _,
|
||||
} => bool_elem(item.elem()),
|
||||
OutputInfo::Array { item } => bool_elem(item.elem()),
|
||||
};
|
||||
<R::Compiler as Compiler>::elem_size(elem)
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelIntegrator {
    /// Starts a new compilation.
    pub fn new(info: KernelExpansion) -> Self {
        Self {
            expansion: info,
            input_bindings: Default::default(),
            output_bindings: Default::default(),
            named_bindings: Default::default(),
        }
    }

    /// Performs the compilation with the provided [settings](KernelSettings).
    pub fn integrate(mut self, mut settings: KernelSettings) -> KernelDefinition {
        if let Some(vectorization) = settings.vectorization_global {
            self.expansion.scope.vectorize(vectorization);
        }

        // Order matters: inplace mappings (inside register_outputs) rewrite
        // input bindings registered here.
        self.register_inputs(&settings);
        self.register_outputs(&mut settings);

        let inputs = self.input_bindings;
        let outputs = self.output_bindings;
        let mut named = Vec::with_capacity(2);

        // The metadata buffer is always the first named binding.
        named.push((
            "info".to_string(),
            Binding {
                item: Item::new(Elem::UInt),
                visibility: Visibility::Read,
                location: Location::Storage,
                size: None, // We avoid putting the length here since it will force a new kernel
                // for each tensor rank.
            },
        ));

        for (name, binding) in self.named_bindings.into_iter() {
            named.push((name, binding));
        }

        KernelDefinition {
            inputs,
            outputs,
            named,
            cube_dim: settings.cube_dim,
            body: self.expansion.scope,
        }
    }

    /// Turns every [InputInfo] into a binding (arrays) or a named scalar
    /// group binding, applying the global vectorization if any.
    fn register_inputs(&mut self, settings: &KernelSettings) {
        for (id, strategy) in settings.reading_strategy.iter() {
            self.expansion.scope.update_read(*id, *strategy);
        }

        for input in self.expansion.inputs.drain(..) {
            match input {
                InputInfo::Array { item, visibility } => {
                    let item = if let Some(vectorization) = settings.vectorization_global {
                        item.vectorize(vectorization)
                    } else {
                        item
                    };

                    self.input_bindings.push(Binding {
                        item: bool_item(item),
                        visibility,
                        location: Location::Storage,
                        size: None,
                    });
                }
                InputInfo::Scalar { elem, size } => {
                    let elem = bool_elem(elem);

                    // Scalar groups are named by element type, e.g. "scalars_f32".
                    self.named_bindings.push((
                        format!("scalars_{}", elem),
                        Binding {
                            item: Item::new(elem),
                            visibility: Visibility::Read,
                            location: Location::Storage,
                            size: Some(size),
                        },
                    ));
                }
            }
        }
    }

    /// Turns every [OutputInfo] into a binding and registers the
    /// corresponding global writes in the scope.
    fn register_outputs(&mut self, settings: &mut KernelSettings) {
        let mut index = 0;

        if !settings.mappings.is_empty() {
            // Take ownership of the mappings so `self` can be borrowed mutably.
            let mut mappings = Vec::new();
            core::mem::swap(&mut settings.mappings, &mut mappings);

            for mapping in mappings {
                self.register_inplace_mapping(mapping);
            }
        }

        for array in self.expansion.outputs.drain(..) {
            match array {
                OutputInfo::ArrayWrite {
                    item,
                    local,
                    position,
                } => {
                    let item = if let Some(vectorization) = settings.vectorization_global {
                        item.vectorize(vectorization)
                    } else {
                        item
                    };
                    let item_adapted = bool_item(item);

                    self.output_bindings.push(Binding {
                        item: item_adapted,
                        visibility: Visibility::ReadWrite,
                        location: Location::Storage,
                        size: None,
                    });
                    self.expansion.scope.write_global(
                        Variable::Local {
                            id: local,
                            item,
                            depth: self.expansion.scope.depth,
                        },
                        Variable::GlobalOutputArray {
                            id: index,
                            item: item_adapted,
                        },
                        position,
                    );
                    index += 1;
                }
                OutputInfo::InputArrayWrite {
                    item,
                    input,
                    local,
                    position,
                } => {
                    let item = if let Some(vectorization) = settings.vectorization_global {
                        item.vectorize(vectorization)
                    } else {
                        item
                    };

                    // Writes to an existing input binding: no new output
                    // binding is created and `index` is not advanced.
                    self.expansion.scope.write_global(
                        Variable::Local {
                            id: local,
                            item,
                            depth: self.expansion.scope.depth,
                        },
                        Variable::GlobalInputArray {
                            id: input,
                            item: bool_item(item),
                        },
                        position,
                    );
                }
                OutputInfo::Array { item } => {
                    let item = if let Some(vectorization) = settings.vectorization_global {
                        item.vectorize(vectorization)
                    } else {
                        item
                    };
                    let elem_adapted = bool_item(item);

                    self.output_bindings.push(Binding {
                        item: elem_adapted,
                        visibility: Visibility::ReadWrite,
                        location: Location::Storage,
                        size: None,
                    });

                    index += 1;
                }
            }
        }
    }

    /// Rewrites the output at `mapping.pos_output` into an
    /// [OutputInfo::InputArrayWrite] targeting `mapping.pos_input`.
    fn register_inplace_mapping(&mut self, mapping: InplaceMapping) {
        let output = match self.expansion.outputs.get_mut(mapping.pos_output) {
            Some(output) => output,
            None => {
                // The mapping is handled differently, normally by cube itself.
                return;
            }
        };

        let (item, local, position) = match output {
            OutputInfo::ArrayWrite { item, local, position } => (item, local, position),
            OutputInfo::InputArrayWrite {
                item: _,
                input,
                local: _,
                position: _,
            } => {
                // Already mapped inplace; only the same input is allowed.
                assert_eq!(
                    *input, mapping.pos_input as u16,
                    "Can't use different inputs for the same output."
                );
                return;
            }
            OutputInfo::Array { item: _ } => panic!("Can't register an inplace operation for an array that isn't using a defined writing strategy."),
        };

        let item = match self.input_bindings.get_mut(mapping.pos_input) {
            Some(binding) => {
                // Update input visibility.
                binding.visibility = Visibility::ReadWrite;
                // Inputs modified inplace should be read without any specified layout.
                self.expansion
                    .scope
                    .update_read(mapping.pos_input as u16, ReadingStrategy::Plain);

                // Use the same item as the input.
                //
                // The output can be different (i.e inplace boolean operations on float bindings).
                binding.item
            }
            None => *item,
        };

        // Update the output.
        *output = OutputInfo::InputArrayWrite {
            item,
            input: mapping.pos_input as u16,
            local: *local,
            position: *position,
        };
    }
}
|
||||
|
||||
fn bool_item(ty: Item) -> Item {
|
||||
Item {
|
||||
elem: bool_elem(ty.elem),
|
||||
vectorization: ty.vectorization,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bool_elem(elem: Elem) -> Elem {
|
||||
match elem {
|
||||
// U32 are used for bool tensors
|
||||
Elem::Bool => Elem::UInt,
|
||||
_ => elem,
|
||||
}
|
||||
}
|
|
@ -1,8 +0,0 @@
|
|||
mod execution;
|
||||
mod integrator;
|
||||
|
||||
mod compiler;
|
||||
|
||||
pub use compiler::*;
|
||||
pub use execution::*;
|
||||
pub use integrator::*;
|
|
@ -1,104 +0,0 @@
|
|||
use crate::ir::{Elem, Item, Visibility};
|
||||
use crate::prelude::KernelDefinition;
|
||||
use crate::KernelSettings;
|
||||
use crate::{
|
||||
frontend::{CubeContext, ExpandElement},
|
||||
InputInfo, KernelExpansion, KernelIntegrator, OutputInfo,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Prepare a kernel to create a [kernel definition](crate::KernelDefinition).
pub struct KernelBuilder {
    /// Cube [context](CubeContext).
    pub context: CubeContext,
    // Registered inputs (arrays and scalar groups).
    inputs: Vec<InputInfo>,
    // Registered outputs.
    outputs: Vec<OutputInfo>,
    // Maps a scalar element type to the position of its group in `inputs`.
    indices: HashMap<Elem, usize>,
    // Next input array id.
    num_input: u16,
    // Next output array id.
    num_output: u16,
}
|
||||
|
||||
impl KernelBuilder {
|
||||
/// Register a scalar and return the [element](ExpandElement) to be used for kernel expansion.
|
||||
pub fn scalar(&mut self, elem: Elem) -> ExpandElement {
|
||||
let index = match self.indices.get_mut(&elem) {
|
||||
Some(index) => match self.inputs.get_mut(*index).unwrap() {
|
||||
InputInfo::Scalar { elem: _, size } => {
|
||||
*size += 1;
|
||||
*size as u16 - 1
|
||||
}
|
||||
_ => panic!("Should be a scalar."),
|
||||
},
|
||||
None => {
|
||||
self.indices.insert(elem, self.inputs.len());
|
||||
self.inputs.push(InputInfo::Scalar { size: 1, elem });
|
||||
0
|
||||
}
|
||||
};
|
||||
|
||||
self.context.scalar(index, elem)
|
||||
}
|
||||
|
||||
/// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
|
||||
pub fn output_tensor(&mut self, item: Item) -> ExpandElement {
|
||||
self.outputs.push(OutputInfo::Array { item });
|
||||
let variable = self.context.output(self.num_output, item);
|
||||
self.num_output += 1;
|
||||
|
||||
variable
|
||||
}
|
||||
|
||||
/// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
|
||||
pub fn input_tensor(&mut self, item: Item) -> ExpandElement {
|
||||
self.inputs.push(InputInfo::Array {
|
||||
item,
|
||||
visibility: Visibility::Read,
|
||||
});
|
||||
let variable = self.context.input(self.num_input, item);
|
||||
self.num_input += 1;
|
||||
variable
|
||||
}
|
||||
|
||||
/// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
|
||||
pub fn output_array(&mut self, item: Item) -> ExpandElement {
|
||||
self.outputs.push(OutputInfo::Array { item });
|
||||
let variable = self.context.output(self.num_output, item);
|
||||
self.num_output += 1;
|
||||
|
||||
variable
|
||||
}
|
||||
|
||||
/// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
|
||||
pub fn input_array(&mut self, item: Item) -> ExpandElement {
|
||||
self.inputs.push(InputInfo::Array {
|
||||
item,
|
||||
visibility: Visibility::Read,
|
||||
});
|
||||
let variable = self.context.input(self.num_input, item);
|
||||
self.num_input += 1;
|
||||
variable
|
||||
}
|
||||
|
||||
/// Build the [kernel definition](KernelDefinition).
|
||||
pub fn build(self, settings: KernelSettings) -> KernelDefinition {
|
||||
KernelIntegrator::new(KernelExpansion {
|
||||
scope: self.context.into_scope(),
|
||||
inputs: self.inputs,
|
||||
outputs: self.outputs,
|
||||
})
|
||||
.integrate(settings)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KernelBuilder {
    /// An empty builder with a fresh root context and no registered inputs or outputs.
    fn default() -> Self {
        Self {
            context: CubeContext::root(),
            inputs: Vec::new(),
            outputs: Vec::new(),
            indices: HashMap::new(),
            num_input: 0,
            num_output: 0,
        }
    }
}
|
|
@ -1,88 +0,0 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel};
|
||||
use alloc::sync::Arc;
|
||||
use burn_compute::server::{Binding, ComputeServer};
|
||||
|
||||
/// A kernel, compiled in the target language
pub struct CompiledKernel {
    /// Source code of the kernel
    pub source: String,
    /// Size of a cube for the compiled kernel
    pub cube_dim: CubeDim,
    /// The number of bytes used by the shared memory
    pub shared_mem_bytes: usize,
}
|
||||
|
||||
/// Kernel trait with the ComputeShader that will be compiled and cached based on the
/// provided id.
pub trait CubeTask: Send + Sync {
    /// Identifier for the kernel, used for caching kernel compilation.
    fn id(&self) -> String;
    /// Compile the kernel into source
    fn compile(&self) -> CompiledKernel;
}
|
||||
|
||||
/// Wraps a [kernel](Kernel) to create a [cube task](CubeTask).
#[derive(new)]
pub struct KernelTask<C: Compiler, K: Kernel> {
    // The kernel to compile when the task is executed.
    kernel_definition: K,
    // Marker tying the task to a compiler without storing one.
    _compiler: PhantomData<C>,
}
|
||||
|
||||
impl<C: Compiler, K: Kernel> CubeTask for KernelTask<C, K> {
    /// Lower the kernel's IR with the target compiler and render it to source.
    fn compile(&self) -> CompiledKernel {
        let gpu_ir = self.kernel_definition.define();
        // The cube dim is taken from the IR before it is consumed by the compiler.
        let cube_dim = gpu_ir.cube_dim;
        let lower_level_ir = C::compile(gpu_ir);
        let shared_mem_bytes = lower_level_ir.shared_memory_size();
        let source = lower_level_ir.to_string();

        CompiledKernel {
            source,
            cube_dim,
            shared_mem_bytes,
        }
    }

    fn id(&self) -> String {
        // NOTE(review): if `Kernel::id` returns an owned `String`, this `.clone()`
        // is redundant — confirm against the `Kernel` trait signature.
        self.kernel_definition.id().clone()
    }
}
|
||||
|
||||
impl CubeTask for Arc<dyn CubeTask> {
|
||||
fn compile(&self) -> CompiledKernel {
|
||||
self.as_ref().compile()
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
self.as_ref().id()
|
||||
}
|
||||
}
|
||||
|
||||
impl CubeTask for Box<dyn CubeTask> {
|
||||
fn compile(&self) -> CompiledKernel {
|
||||
self.as_ref().compile()
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
self.as_ref().id()
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides launch information specifying the number of work groups to be used by a compute shader.
pub enum CubeCount<S: ComputeServer> {
    /// Dispatch x,y,z work groups.
    Static(u32, u32, u32),
    /// Dispatch work groups based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
    Dynamic(Binding<S>),
}
|
||||
|
||||
// Manual impl: a derived `Clone` would add an `S: Clone` bound that compute
// servers don't necessarily satisfy; only the binding needs cloning.
impl<S: ComputeServer> Clone for CubeCount<S> {
    fn clone(&self) -> Self {
        match self {
            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
        }
    }
}
|
|
@ -1,337 +0,0 @@
|
|||
use crate::compute::{CubeCount, KernelTask};
|
||||
use crate::ir::{Elem, FloatKind, IntKind};
|
||||
use crate::prelude::ArrayHandle;
|
||||
use crate::KernelSettings;
|
||||
use crate::{calculate_num_elems_dyn_rank, frontend::TensorHandle, Kernel, Runtime};
|
||||
use burn_compute::client::ComputeClient;
|
||||
use burn_compute::server::Binding;
|
||||
use bytemuck::NoUninit;
|
||||
use num_traits::ToPrimitive;
|
||||
|
||||
/// Prepare a kernel for [launch](KernelLauncher::launch).
pub struct KernelLauncher<R: Runtime> {
    // Registered tensor/array arguments, in registration order.
    tensors: TensorState<R>,
    // One grouped buffer of scalars per element type.
    scalar_bf16: ScalarState<half::bf16>,
    scalar_f16: ScalarState<half::f16>,
    scalar_f32: ScalarState<f32>,
    scalar_f64: ScalarState<f64>,
    scalar_u32: ScalarState<u32>,
    scalar_i64: ScalarState<i64>,
    scalar_i32: ScalarState<i32>,
    // Element types in the order their first scalar was registered; binding
    // order must match the compilation order (see `into_bindings`).
    scalar_order: Vec<Elem>,
    pub settings: KernelSettings,
}
|
||||
|
||||
impl<R: Runtime> KernelLauncher<R> {
    /// Register a tensor to be launched.
    pub fn register_tensor(&mut self, tensor: &TensorHandle<'_, R>) {
        self.tensors.push(tensor);
    }

    /// Register an array to be launched.
    ///
    /// Arrays are registered through their tensor view (`as_tensor`).
    pub fn register_array(&mut self, array: &ArrayHandle<'_, R>) {
        self.tensors.push(&array.as_tensor());
    }

    /// Register a u32 scalar to be launched.
    pub fn register_u32(&mut self, scalar: u32) {
        self.register_scalar(Elem::UInt);
        self.scalar_u32.push(scalar);
    }

    /// Register a i32 scalar to be launched.
    pub fn register_i32(&mut self, scalar: i32) {
        self.register_scalar(Elem::Int(IntKind::I32));
        self.scalar_i32.push(scalar);
    }

    /// Register a i64 scalar to be launched.
    pub fn register_i64(&mut self, scalar: i64) {
        self.register_scalar(Elem::Int(IntKind::I64));
        self.scalar_i64.push(scalar);
    }

    /// Register a bf16 scalar to be launched.
    pub fn register_bf16(&mut self, scalar: half::bf16) {
        self.register_scalar(Elem::Float(FloatKind::BF16));
        self.scalar_bf16.push(scalar);
    }

    /// Register a f16 scalar to be launched.
    pub fn register_f16(&mut self, scalar: half::f16) {
        self.register_scalar(Elem::Float(FloatKind::F16));
        self.scalar_f16.push(scalar);
    }

    /// Register a f32 scalar to be launched.
    pub fn register_f32(&mut self, scalar: f32) {
        self.register_scalar(Elem::Float(FloatKind::F32));
        self.scalar_f32.push(scalar);
    }

    /// Register a f64 scalar to be launched.
    pub fn register_f64(&mut self, scalar: f64) {
        self.register_scalar(Elem::Float(FloatKind::F64));
        self.scalar_f64.push(scalar);
    }

    /// Launch the kernel.
    pub fn launch<K: Kernel>(
        self,
        cube_count: CubeCount<R::Server>,
        kernel: K,
        client: ComputeClient<R::Server, R::Channel>,
    ) {
        // Bindings are consumed first so they are created in compilation order.
        let bindings = self.into_bindings(&client);

        let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));

        client.execute(kernel, cube_count, bindings);
    }

    /// We need to create the bindings in the same order they are defined in the compilation step.
    ///
    /// The function [crate::KernelIntegrator::integrate] stars by registering the input tensors followed
    /// by the output tensors. Then the tensor metadata, and the scalars at the end. The scalars
    /// are registered in the same order they are added. This is why we store the scalar data type
    /// in the `scalar_order` vector, so that we can register them in the same order.
    fn into_bindings(
        mut self,
        client: &ComputeClient<R::Server, R::Channel>,
    ) -> Vec<Binding<R::Server>> {
        let mut bindings = Vec::new();

        // Tensors (and their metadata buffer) come first.
        self.tensors.register(client, &mut bindings);

        // Then one scalar buffer per element type, in first-registration order.
        // `scalar_order` contains each element type at most once (see `register_scalar`).
        for elem in self.scalar_order.drain(..) {
            match elem {
                Elem::Float(kind) => match kind {
                    FloatKind::F16 => self.scalar_f16.register::<R>(client, &mut bindings),
                    FloatKind::BF16 => self.scalar_bf16.register::<R>(client, &mut bindings),
                    FloatKind::F32 => self.scalar_f32.register::<R>(client, &mut bindings),
                    FloatKind::F64 => self.scalar_f64.register::<R>(client, &mut bindings),
                },
                Elem::Int(kind) => match kind {
                    IntKind::I32 => self.scalar_i32.register::<R>(client, &mut bindings),
                    IntKind::I64 => self.scalar_i64.register::<R>(client, &mut bindings),
                },
                Elem::UInt => self.scalar_u32.register::<R>(client, &mut bindings),
                Elem::Bool => panic!("Bool can't be passed as bindings."),
            }
        }

        bindings
    }

    // Record the element type the first time a scalar of that type is registered,
    // preserving first-use order for `into_bindings`.
    fn register_scalar(&mut self, elem: Elem) {
        if !self.scalar_order.contains(&elem) {
            self.scalar_order.push(elem);
        }
    }
}
|
||||
|
||||
/// Handles the tensor state.
pub enum TensorState<R: Runtime> {
    /// No tensor is registered yet.
    Empty,
    /// The registered tensors.
    Some {
        // One binding per registered tensor, in registration order.
        bindings: Vec<Binding<R::Server>>,
        // Packed rank/strides/shape info shared by all tensors (see `push`).
        metadata: Vec<u32>,
        // Element counts, appended to metadata only when the runtime requires them.
        lengths: Vec<u32>,
    },
}
|
||||
|
||||
/// Handles the scalar state of an element type
///
/// The scalars are grouped to reduce the number of buffers needed to send data to the compute device.
pub enum ScalarState<T> {
    /// No scalar of that type is registered yet.
    Empty,
    /// The registered scalars.
    Some(Vec<T>),
}
|
||||
|
||||
impl<R: Runtime> TensorState<R> {
    /// Push a new tensor to the state.
    ///
    /// Metadata layout: `metadata[0]` holds the (maximum) rank, followed by
    /// `strides` then `shape` for each registered tensor, in registration order.
    /// When a tensor with a higher rank arrives, previously registered entries
    /// are re-padded to the new rank.
    pub fn push(&mut self, tensor: &TensorHandle<'_, R>) {
        // Lazily initialize on the first tensor.
        if let TensorState::Empty = self {
            *self = TensorState::Some {
                bindings: Vec::with_capacity(1),
                metadata: Vec::new(),
                lengths: Vec::new(),
            };
        };

        let (bindings, metadata, lengths) = match self {
            TensorState::Empty => panic!("Should be init"),
            TensorState::Some {
                bindings,
                metadata,
                lengths,
            } => (bindings, metadata, lengths),
        };

        bindings.push(tensor.handle.clone().binding());

        let old_rank = if metadata.is_empty() {
            // First tensor: its rank becomes the reference rank.
            let rank = tensor.strides.len() as u32;
            metadata.push(rank);
            None
        } else if tensor.strides.len() > metadata[0] as usize {
            // Higher rank than seen so far: re-pad everything already registered.
            // NOTE(review): `bindings` already contains the new tensor here, so
            // `bindings.len()` counts it — verify `adjust_rank` expects that.
            let old_rank = metadata[0];
            let rank = tensor.strides.len() as u32;
            Self::adjust_rank(metadata, bindings.len(), rank);
            Some(old_rank)
        } else {
            None
        };

        Self::register_strides(tensor.strides, tensor.shape, old_rank, metadata);
        Self::register_shape(tensor.shape, old_rank, metadata);

        if R::require_array_lengths() {
            let len = calculate_num_elems_dyn_rank(tensor.shape);
            lengths.push(len as u32);
        }
    }

    // Re-emit the strides/shape of the `num_registered` tensors already stored,
    // padded up to the new `rank`.
    fn adjust_rank(metadata: &mut Vec<u32>, num_registered: usize, rank: u32) {
        let old_rank = metadata[0] as usize;
        let rank_diff = rank as usize - old_rank;
        let mut updated_metadata = Vec::with_capacity(2 * rank_diff * num_registered);
        // NOTE(review): `updated_metadata` starts without a rank entry at index 0,
        // yet `register_strides`/`register_shape` read `output[0]` as the rank —
        // confirm this is the intended layout.

        for pos in 0..num_registered {
            // Each tensor occupies `2 * old_rank` entries after the leading rank.
            let stride_index = (pos * old_rank * 2) + 1;
            let shape_index = stride_index + old_rank;

            let strides_old = &metadata[stride_index..stride_index + old_rank];
            let shape_old = &metadata[shape_index..shape_index + old_rank];

            Self::register_strides(
                strides_old,
                shape_old,
                Some(old_rank as u32),
                &mut updated_metadata,
            );
            Self::register_shape(shape_old, Some(old_rank as u32), &mut updated_metadata);
        }

        core::mem::swap(&mut updated_metadata, metadata);
    }

    // Append `strides` to `output`, prepending padding entries when the tensor's
    // rank is below the reference rank.
    fn register_strides<T: ToPrimitive>(
        strides: &[T],
        shape: &[T],
        old_rank: Option<u32>,
        output: &mut Vec<u32>,
    ) {
        let old_rank = if let Some(old_rank) = old_rank {
            let rank = output[0];
            // NOTE(review): the padding count here is `old_rank - rank`, while
            // `register_shape` uses `rank - old_rank` — confirm the asymmetry is
            // intentional (one of the two underflows if the ranks differ the
            // other way around).
            let rank_diff = old_rank - rank;
            let padded_strides = if rank_diff > 0 {
                // Padding value derived from the leading shape entries.
                // NOTE(review): this is a sum, not a product — confirm.
                shape
                    .iter()
                    .take(old_rank as usize)
                    .map(|a| a.to_u32().unwrap())
                    .sum::<u32>()
            } else {
                0
            };

            for _ in 0..rank_diff {
                output.push(padded_strides.to_u32().unwrap());
            }

            old_rank as usize
        } else {
            output[0] as usize // same as current.
        };

        for stride in strides.iter().take(old_rank) {
            output.push(stride.to_u32().unwrap());
        }
    }

    // Append `shape` to `output`, padding missing leading dimensions with 1.
    fn register_shape<T: ToPrimitive>(shape: &[T], old_rank: Option<u32>, output: &mut Vec<u32>) {
        let old_rank = if let Some(old_rank) = old_rank {
            let rank = output[0];
            let rank_diff = rank - old_rank;

            for _ in 0..rank_diff {
                output.push(1);
            }

            old_rank as usize
        } else {
            output[0] as usize // same as current
        };

        for elem in shape.iter().take(old_rank) {
            output.push(elem.to_u32().unwrap());
        }
    }

    // Consume the state and append all bindings — tensor buffers first, then a
    // single buffer holding the packed metadata (plus lengths, when required).
    fn register(
        self,
        client: &ComputeClient<R::Server, R::Channel>,
        bindings_global: &mut Vec<Binding<R::Server>>,
    ) {
        if let Self::Some {
            bindings,
            mut metadata,
            lengths,
        } = self
        {
            if R::require_array_lengths() {
                // Array lengths are appended at the end of the metadata buffer.
                for len in lengths {
                    metadata.push(len);
                }
            }

            bindings_global.extend(bindings);
            bindings_global.push(client.create(bytemuck::cast_slice(&metadata)).binding());
        }
    }
}
|
||||
|
||||
impl<T: NoUninit> ScalarState<T> {
|
||||
/// Add a new scalar value to the state.
|
||||
pub fn push(&mut self, val: T) {
|
||||
match self {
|
||||
ScalarState::Empty => *self = Self::Some(vec![val]),
|
||||
ScalarState::Some(values) => values.push(val),
|
||||
}
|
||||
}
|
||||
|
||||
fn register<R: Runtime>(
|
||||
&self,
|
||||
client: &ComputeClient<R::Server, R::Channel>,
|
||||
bindings: &mut Vec<Binding<R::Server>>,
|
||||
) {
|
||||
match self {
|
||||
ScalarState::Empty => (),
|
||||
ScalarState::Some(values) => {
|
||||
let handle = client.create(bytemuck::cast_slice(values));
|
||||
bindings.push(handle.binding());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> Default for KernelLauncher<R> {
    /// A launcher with no registered tensors or scalars and default settings.
    fn default() -> Self {
        Self {
            tensors: TensorState::Empty,
            scalar_bf16: ScalarState::Empty,
            scalar_f16: ScalarState::Empty,
            scalar_f32: ScalarState::Empty,
            scalar_f64: ScalarState::Empty,
            scalar_u32: ScalarState::Empty,
            scalar_i64: ScalarState::Empty,
            scalar_i32: ScalarState::Empty,
            scalar_order: Vec::new(),
            settings: Default::default(),
        }
    }
}
|
|
@ -1,7 +0,0 @@
|
|||
// Kernel construction (inputs/outputs registration).
mod builder;
// Compiled-kernel task types.
mod kernel;
// Kernel launch (bindings creation and dispatch).
mod launcher;

pub use builder::*;
pub use kernel::*;
pub use launcher::*;
|
|
@ -1,12 +0,0 @@
|
|||
/// Panic guard for Cube functions called outside of kernel expansion.
///
/// `#[cube]` function bodies are replaced during expansion; the original bodies
/// use this macro so that calling them directly fails loudly.
#[macro_export]
macro_rules! unexpanded {
    () => ({
        panic!("Unexpanded Cube functions should not be called. ");
    });
    ($msg:expr) => ({
        panic!($msg);
    });
    ($fmt:expr, $($arg:tt)*) => ({
        panic!($fmt, $($arg)*);
    });
}
|
|
@ -1,153 +0,0 @@
|
|||
use std::ops::Deref;
|
||||
|
||||
use crate::frontend::{CubeContext, ExpandElement, UInt};
|
||||
use crate::ir::{Branch, Elem, If, IfElse, Item, Loop, RangeLoop, Variable};
|
||||
|
||||
use super::comptime::Comptime;
|
||||
|
||||
/// Iterate over `start..end`, yielding [UInt] indices.
///
/// `_unroll` only affects the expansion pass (see `range_expand`); when the
/// function itself runs it is ignored.
pub fn range<S, E>(start: S, end: E, _unroll: Comptime<bool>) -> impl Iterator<Item = UInt>
where
    S: Into<UInt>,
    E: Into<UInt>,
{
    let start: UInt = start.into();
    let end: UInt = end.into();

    (start.val..end.val).map(UInt::new)
}
|
||||
|
||||
/// Expand a `range(..)` loop: either unroll it at expansion time (requires
/// constant bounds) or register a [RangeLoop] branch in the IR.
pub fn range_expand<F, S, E>(context: &mut CubeContext, start: S, end: E, unroll: bool, mut func: F)
where
    F: FnMut(&mut CubeContext, ExpandElement),
    S: Into<ExpandElement>,
    E: Into<ExpandElement>,
{
    let start: ExpandElement = start.into();
    let end: ExpandElement = end.into();

    if unroll {
        // Unrolling replays the body once per iteration, so both bounds must be
        // compile-time constants.
        let start = match start.deref() {
            Variable::ConstantScalar { value, .. } => *value as usize,
            _ => panic!("Only constant start can be unrolled."),
        };
        let end = match end.deref() {
            Variable::ConstantScalar { value, .. } => *value as usize,
            _ => panic!("Only constant end can be unrolled."),
        };

        for i in start..end {
            func(context, i.into())
        }
    } else {
        // Emit a single loop in the IR, with the body expanded in a child scope.
        let mut child = context.child();
        let index_ty = Item::new(Elem::UInt);
        // The loop variable is declared by the RangeLoop itself.
        let i = child.scope.borrow_mut().create_local_undeclared(index_ty);
        let i = ExpandElement::Plain(i);

        func(&mut child, i.clone());

        context.register(Branch::RangeLoop(RangeLoop {
            i: *i,
            start: *start,
            end: *end,
            scope: child.into_scope(),
        }));
    }
}
|
||||
|
||||
pub fn if_expand<IF>(
|
||||
context: &mut CubeContext,
|
||||
comptime_cond: Option<bool>,
|
||||
runtime_cond: ExpandElement,
|
||||
mut block: IF,
|
||||
) where
|
||||
IF: FnMut(&mut CubeContext),
|
||||
{
|
||||
match comptime_cond {
|
||||
Some(cond) => {
|
||||
if cond {
|
||||
block(context);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let mut child = context.child();
|
||||
|
||||
block(&mut child);
|
||||
|
||||
context.register(Branch::If(If {
|
||||
cond: *runtime_cond,
|
||||
scope: child.into_scope(),
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Expand an `if/else` statement. A comptime condition selects one branch at
/// expansion time; a runtime condition emits an [IfElse] with both scopes.
pub fn if_else_expand<IF, EL>(
    context: &mut CubeContext,
    comptime_cond: Option<bool>,
    runtime_cond: ExpandElement,
    mut then_block: IF,
    mut else_block: EL,
) where
    IF: FnMut(&mut CubeContext),
    EL: FnMut(&mut CubeContext),
{
    match comptime_cond {
        Some(cond) => {
            // Resolved at expansion time: only the taken branch is expanded.
            if cond {
                then_block(context);
            } else {
                else_block(context);
            }
        }
        None => {
            // Runtime condition: expand both branches in their own scopes.
            let mut then_child = context.child();
            then_block(&mut then_child);

            let mut else_child = context.child();
            else_block(&mut else_child);

            context.register(Branch::IfElse(IfElse {
                cond: *runtime_cond,
                scope_if: then_child.into_scope(),
                scope_else: else_child.into_scope(),
            }));
        }
    }
}
|
||||
|
||||
/// Expand a `break` statement by registering a [Branch::Break] in the current scope.
pub fn break_expand(context: &mut CubeContext) {
    context.register(Branch::Break);
}
|
||||
|
||||
/// Expand a `return` statement by registering a [Branch::Return] in the current scope.
pub fn return_expand(context: &mut CubeContext) {
    context.register(Branch::Return);
}
|
||||
|
||||
pub fn loop_expand<FB>(context: &mut CubeContext, mut block: FB)
|
||||
where
|
||||
FB: FnMut(&mut CubeContext),
|
||||
{
|
||||
let mut inside_loop = context.child();
|
||||
|
||||
block(&mut inside_loop);
|
||||
context.register(Branch::Loop(Loop {
|
||||
scope: inside_loop.into_scope(),
|
||||
}));
|
||||
}
|
||||
|
||||
/// Expand a `while` loop as an infinite [Loop] whose body first evaluates the
/// condition and breaks, then runs the loop body.
pub fn while_loop_expand<FC, FB>(context: &mut CubeContext, mut cond_fn: FC, mut block: FB)
where
    FC: FnMut(&mut CubeContext) -> ExpandElement,
    FB: FnMut(&mut CubeContext),
{
    let mut inside_loop = context.child();

    let cond: ExpandElement = cond_fn(&mut inside_loop);
    // NOTE(review): this breaks when `cond` is true, so `cond_fn` is expected to
    // produce the NEGATED loop condition — confirm against the cube macro
    // expansion.
    if_expand(&mut inside_loop, None, cond, break_expand);

    block(&mut inside_loop);
    context.register(Branch::Loop(Loop {
        scope: inside_loop.into_scope(),
    }));
}
|
|
@ -1,238 +0,0 @@
|
|||
//! This module exposes cooperative matrix-multiply and accumulate operations.
|
||||
//!
|
||||
//! Most of the functions are actually unsafe, since they mutate their input, even if they are
|
||||
//! passed as reference.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! This is a basic 16x16x16 matrix multiplication example.
|
||||
//!
|
||||
//! ```rust, ignore
|
||||
//! #[cube(launch)]
|
||||
//! pub fn example(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>) {
|
||||
//! let a = cmma::Matrix::<F16>::new(
|
||||
//! cmma::MatrixIdent::A,
|
||||
//! 16,
|
||||
//! 16,
|
||||
//! 16,
|
||||
//! cmma::MatrixLayout::RowMajor,
|
||||
//! );
|
||||
//! let b = cmma::Matrix::<F16>::new(
|
||||
//! cmma::MatrixIdent::B,
|
||||
//! 16,
|
||||
//! 16,
|
||||
//! 16,
|
||||
//! cmma::MatrixLayout::ColMajor,
|
||||
//! );
|
||||
//! let c = cmma::Matrix::<F32>::new(
|
||||
//! cmma::MatrixIdent::Accumulator,
|
||||
//! 16,
|
||||
//! 16,
|
||||
//! 16,
|
||||
//! cmma::MatrixLayout::Undefined,
|
||||
//! );
|
||||
//! cmma::fill::<F32>(&c, F32::new(0.0));
|
||||
//! cmma::load::<F16>(&a, lhs.as_slice(), UInt::new(16));
|
||||
//! cmma::load::<F16>(&b, rhs.as_slice(), UInt::new(16));
|
||||
//!
|
||||
//! cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
|
||||
//!
|
||||
//! cmma::store::<F32>(
|
||||
//! out.as_slice_mut(),
|
||||
//! &c,
|
||||
//! UInt::new(16),
|
||||
//! cmma::MatrixLayout::RowMajor,
|
||||
//! );
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
ir::{self, Operation},
|
||||
unexpanded,
|
||||
};
|
||||
|
||||
use super::{
|
||||
CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, Slice, SliceMut,
|
||||
UInt,
|
||||
};
|
||||
|
||||
pub use ir::{MatrixIdent, MatrixLayout};
|
||||
|
||||
/// A matrix represent a 2D grid of numbers.
///
/// They can either be in a [row major](MatrixLayout::RowMajor) or a
/// [column major](MatrixLayout::ColMajor) format.
pub struct Matrix<C: CubeType> {
    // Zero-sized on the host: the matrix only exists on the device, `C` records
    // its element type.
    _c: PhantomData<C>,
}

/// Expand type of [Matrix].
#[derive(Clone)]
pub struct MatrixExpand {
    // IR variable referring to the matrix during expansion.
    elem: ExpandElement,
}

impl<C: CubeType> CubeType for Matrix<C> {
    type ExpandType = MatrixExpand;
}

impl Init for MatrixExpand {
    fn init(self, _context: &mut CubeContext) -> Self {
        // Matrices need no per-scope initialization.
        self
    }
}
|
||||
|
||||
impl<C: CubePrimitive> Matrix<C> {
    /// Create a new matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    #[allow(unused_variables)]
    pub fn new(ident: MatrixIdent, m: u8, n: u8, k: u8, layout: MatrixLayout) -> Self {
        Matrix { _c: PhantomData }
    }

    /// Expand method of [Self::new]: registers the matrix in the IR and returns
    /// a handle to it.
    pub fn __expand_new(
        context: &mut CubeContext,
        ident: MatrixIdent,
        m: u8,
        n: u8,
        k: u8,
        layout: MatrixLayout,
    ) -> MatrixExpand {
        let elem = context.create_matrix(ir::Matrix {
            ident,
            m,
            n,
            k,
            elem: C::as_elem(),
            layout,
        });
        MatrixExpand { elem }
    }
}
|
||||
|
||||
/// Fill the matrix with the provided value.
///
/// Calling this outside kernel expansion panics (see `unexpanded!`).
#[allow(unused_variables)]
pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}

/// Module containing the expand function for [fill()].
pub mod fill {
    use super::*;

    /// Expand method of [fill()].
    pub fn __expand<C: CubeType>(
        context: &mut CubeContext,
        mat: MatrixExpand,
        value: ExpandElement,
    ) {
        context.register(Operation::CoopMma(ir::CoopMma::Fill {
            mat: *mat.elem,
            value: *value,
        }));
    }
}
|
||||
|
||||
/// Load the matrix with the provided array using the stride.
///
/// Calling this outside kernel expansion panics (see `unexpanded!`).
#[allow(unused_variables)]
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Slice<'_, C>, stride: UInt) {
    unexpanded!()
}

/// Module containing the expand function for [load()].
pub mod load {
    use super::*;

    /// Expand method of [load()].
    #[allow(unused_variables)]
    pub fn __expand<C: CubeType>(
        context: &mut CubeContext,
        mat: MatrixExpand,
        value: ExpandElementTyped<Slice<'static, C>>,
        stride: ExpandElement,
    ) {
        context.register(Operation::CoopMma(ir::CoopMma::Load {
            mat: *mat.elem,
            value: *value.expand,
            stride: *stride,
        }));
    }
}
|
||||
|
||||
/// Store the matrix in the given array following the given stride and layout.
///
/// Calling this outside kernel expansion panics (see `unexpanded!`).
#[allow(unused_variables)]
pub fn store<C: CubePrimitive>(
    output: &mut SliceMut<'_, C>,
    mat: &Matrix<C>,
    stride: UInt,
    layout: MatrixLayout,
) {
    unexpanded!()
}

/// Module containing the expand function for [store()].
pub mod store {
    use super::*;

    /// Expand method of [store()].
    #[allow(unused_variables)]
    pub fn __expand<C: CubePrimitive>(
        context: &mut CubeContext,
        output: ExpandElementTyped<SliceMut<'static, C>>,
        mat: MatrixExpand,
        stride: ExpandElement,
        layout: MatrixLayout,
    ) {
        context.register(Operation::CoopMma(ir::CoopMma::Store {
            output: *output.expand,
            mat: *mat.elem,
            stride: *stride,
            layout,
        }));
    }
}
|
||||
|
||||
/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix).
///
/// Calling this outside kernel expansion panics (see `unexpanded!`).
#[allow(unused_variables)]
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}

/// Module containing the expand function for [execute()].
pub mod execute {
    use super::*;

    /// Expand method of [execute()].
    pub fn __expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
        context: &mut CubeContext,
        mat_a: MatrixExpand,
        mat_b: MatrixExpand,
        mat_c: MatrixExpand,
        mat_d: MatrixExpand,
    ) {
        context.register(Operation::CoopMma(ir::CoopMma::Execute {
            mat_a: *mat_a.elem,
            mat_b: *mat_b.elem,
            mat_c: *mat_c.elem,
            mat_d: *mat_d.elem,
        }));
    }
}
|
|
@ -1,146 +0,0 @@
|
|||
use crate::{
|
||||
frontend::{CubeContext, CubeType},
|
||||
unexpanded,
|
||||
};
|
||||
|
||||
use super::{ExpandElement, Init, UInt, Vectorized};
|
||||
|
||||
#[derive(Clone, Copy)]
/// Encapsulates a value to signify it must be used at compilation time rather than in the kernel
///
/// Use `Comptime<Option<T>>` to have an alternate runtime behaviour if the compilation time value is not present
pub struct Comptime<T> {
    // The wrapped compile-time value.
    pub(crate) inner: T,
}
|
||||
|
||||
impl<T> Comptime<T> {
    /// Create a new Comptime. Useful when hardcoding values in
    /// Cube kernels. For instance:
    /// if Comptime::new(false) {...} never generates the inner code block
    pub fn new(inner: T) -> Self {
        Self { inner }
    }

    /// Get the inner value of a Comptime. For instance:
    /// let c = Comptime::new(false);
    /// if Comptime::get(c) {...}
    ///
    /// Only meaningful inside expanded Cube code; calling it directly panics.
    pub fn get(_comptime: Self) -> T {
        unexpanded!()
    }

    /// Apply a closure to the compile-time value. Placeholder: panics unless expanded.
    pub fn map<R, F: Fn(T) -> R>(_comptime: Self, _closure: F) -> Comptime<R> {
        unexpanded!()
    }

    /// Expanded version of [Self::map]: applies the closure at expansion time.
    pub fn map_expand<R, F: Fn(T) -> R>(inner: T, closure: F) -> R {
        closure(inner)
    }
}
|
||||
|
||||
impl<T: CubeType + Into<T::ExpandType>> Comptime<Option<T>> {
    /// Map a Comptime optional to a Comptime boolean that tell
    /// whether the optional contained a value
    pub fn is_some(comptime: Self) -> Comptime<bool> {
        Comptime::new(comptime.inner.is_some())
    }

    /// Return the inner value of the Comptime if it exists,
    /// otherwise tell how to compute it at runtime
    ///
    /// Placeholder: panics unless expanded.
    pub fn unwrap_or_else<F>(_comptime: Self, mut _alt: F) -> T
    where
        F: FnOnce() -> T,
    {
        unexpanded!()
    }

    /// Expanded version of unwrap_or_else
    pub fn unwrap_or_else_expand<F>(
        context: &mut CubeContext,
        t: Option<T>,
        alt: F,
    ) -> <T as CubeType>::ExpandType
    where
        F: FnOnce(&mut CubeContext) -> T::ExpandType,
    {
        match t {
            // Known at expansion time: convert the value to its expand type.
            Some(t) => t.into(),
            // Unknown at expansion time: emit the runtime fallback.
            None => alt(context),
        }
    }
}
|
||||
|
||||
impl<T: Clone + Init> CubeType for Comptime<T> {
    // A comptime value expands to the raw value itself, not to an IR variable.
    type ExpandType = T;
}

impl<T: Vectorized> Comptime<T> {
    /// Vectorization factor of the argument, as a comptime value.
    /// Placeholder: panics unless expanded.
    pub fn vectorization(_state: &T) -> Comptime<UInt> {
        unexpanded!()
    }

    /// Expanded version of [Self::vectorization].
    pub fn vectorization_expand(_context: &mut CubeContext, state: T) -> UInt {
        state.vectorization_factor()
    }
}

impl<T: Into<ExpandElement>> Comptime<T> {
    /// Lower a comptime value into the kernel as a runtime value.
    /// Placeholder: panics unless expanded.
    pub fn runtime(_comptime: Self) -> T {
        unexpanded!()
    }

    /// Expanded version of [Self::runtime].
    pub fn runtime_expand(_context: &mut CubeContext, inner: T) -> ExpandElement {
        inner.into()
    }
}
|
||||
|
||||
// Arithmetic on `Comptime<T>` operates directly on the wrapped compile-time
// values, so constant math can be written with ordinary operators.
impl<T: core::ops::Add<T, Output = T>> core::ops::Add for Comptime<T> {
    type Output = Comptime<T>;

    fn add(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner.add(rhs.inner))
    }
}

impl<T: core::ops::Sub<T, Output = T>> core::ops::Sub for Comptime<T> {
    type Output = Comptime<T>;

    fn sub(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner.sub(rhs.inner))
    }
}

impl<T: core::ops::Div<T, Output = T>> core::ops::Div for Comptime<T> {
    type Output = Comptime<T>;

    fn div(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner.div(rhs.inner))
    }
}

impl<T: core::ops::Mul<T, Output = T>> core::ops::Mul for Comptime<T> {
    type Output = Comptime<T>;

    fn mul(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner.mul(rhs.inner))
    }
}

impl<T: core::ops::Rem<T, Output = T>> core::ops::Rem for Comptime<T> {
    type Output = Comptime<T>;

    fn rem(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner.rem(rhs.inner))
    }
}
|
||||
|
||||
// Comparisons delegate to the wrapped compile-time values.
impl<T: core::cmp::PartialOrd + core::cmp::PartialEq> core::cmp::PartialEq for Comptime<T> {
    fn eq(&self, other: &Self) -> bool {
        core::cmp::PartialEq::eq(&self.inner, &other.inner)
    }
}

impl<T: core::cmp::PartialOrd + core::cmp::PartialEq> core::cmp::PartialOrd for Comptime<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        core::cmp::PartialOrd::partial_cmp(&self.inner, &other.inner)
    }
}
|
|
@ -1,145 +0,0 @@
|
|||
use crate::frontend::ExpandElement;
|
||||
use crate::ir::{self, Elem, Item, Operation, Scope};
|
||||
use alloc::rc::Rc;
|
||||
use core::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Pool of expand variables, keyed by `Item`, used to reuse locals that are
/// no longer referenced instead of allocating new ones.
#[derive(Default, Clone)]
pub struct VariablePool {
    // Shared, interior-mutable map so that clones of the pool (e.g. held by
    // child contexts) all observe the same set of variables.
    map: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,
}
|
||||
|
||||
impl VariablePool {
|
||||
/// Returns an old, not used anymore variable, if there exists one.
|
||||
pub fn reuse(&self, item: Item) -> Option<ExpandElement> {
|
||||
let map = self.map.borrow();
|
||||
|
||||
// Filter for candidate variables of the same Item
|
||||
let variables = map.get(&item)?;
|
||||
|
||||
// Among the candidates, take a variable if it's only referenced by the map
|
||||
// Arbitrarily takes the first it finds in reverse order.
|
||||
for variable in variables.iter().rev() {
|
||||
match variable {
|
||||
ExpandElement::Managed(var) => {
|
||||
if Rc::strong_count(var) == 1 {
|
||||
return Some(variable.clone());
|
||||
}
|
||||
}
|
||||
ExpandElement::Plain(_) => (),
|
||||
}
|
||||
}
|
||||
|
||||
// If no candidate was found, a new var will be needed
|
||||
None
|
||||
}
|
||||
|
||||
/// Insert a new variable in the map, which is classified by Item
|
||||
pub fn insert(&mut self, var: ExpandElement) {
|
||||
let mut map = self.map.borrow_mut();
|
||||
let item = var.item();
|
||||
|
||||
if let Some(variables) = map.get_mut(&item) {
|
||||
variables.push(var.clone());
|
||||
} else {
|
||||
map.insert(var.item(), vec![var.clone()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compilation context used while expanding cube functions into IR scopes.
pub struct CubeContext {
    // Root scope of the shader; locals, shared memory and local arrays are
    // allocated here (see `create_local` / `create_shared`).
    pub root: Rc<RefCell<Scope>>,
    // Current (possibly nested) scope where operations are registered. For
    // the root context this points to the same allocation as `root`.
    pub scope: Rc<RefCell<Scope>>,
    // Pool used to reuse local variables across the whole context tree.
    pub pool: VariablePool,
}
|
||||
|
||||
impl CubeContext {
    /// Create a new cube context, with a root scope
    /// A root scope is at the root of a compute shader
    /// Therefore there is one cube context per shader
    pub fn root() -> CubeContext {
        let root = Rc::new(RefCell::new(Scope::root()));
        // The root context's current scope *is* the root scope.
        let scope = root.clone();

        Self {
            pool: Default::default(),
            scope,
            root,
        }
    }

    /// Register an operation in the current (innermost) scope.
    pub fn register<O: Into<Operation>>(&mut self, op: O) {
        self.scope.borrow_mut().register(op)
    }

    /// Create a child context whose scope is a child of the current scope.
    ///
    /// The root scope and the variable pool are shared with the parent, so
    /// allocations still happen at the root and variables can be reused
    /// across sibling scopes.
    pub fn child(&mut self) -> CubeContext {
        let scope = self.scope.borrow_mut().child();

        Self {
            scope: Rc::new(RefCell::new(scope)),
            root: self.root.clone(),
            pool: self.pool.clone(),
        }
    }

    /// Consume the context and return its scope.
    ///
    /// `self.root` is dropped first so that, for the root context (where
    /// `root` and `scope` share one allocation, see [`CubeContext::root`]),
    /// `self.scope` holds the last `Rc` reference.
    ///
    /// Panics if another clone of the scope is still alive.
    pub fn into_scope(self) -> Scope {
        core::mem::drop(self.root);

        Rc::into_inner(self.scope)
            .expect("Only one reference")
            .into_inner()
    }

    /// When a new variable is required, we check if we can reuse an old one
    /// Otherwise we create a new one.
    pub fn create_local(&mut self, item: Item) -> ExpandElement {
        // Reuse an old variable if possible
        if let Some(var) = self.pool.reuse(item) {
            return var;
        }

        // Create a new variable at the root scope
        // Insert it in the variable pool for potential reuse
        let new = ExpandElement::Managed(Rc::new(self.root.borrow_mut().create_local(item)));
        self.pool.insert(new.clone());

        new
    }

    /// Create a new matrix element.
    pub fn create_matrix(&mut self, matrix: ir::Matrix) -> ExpandElement {
        let variable = self.scope.borrow_mut().create_matrix(matrix);
        ExpandElement::Plain(variable)
    }

    /// Create a new slice element.
    pub fn create_slice(&mut self, item: Item) -> ExpandElement {
        let variable = self.scope.borrow_mut().create_slice(item);
        ExpandElement::Plain(variable)
    }

    /// Create a shared-memory variable. Note: allocated at the root scope.
    pub fn create_shared(&mut self, item: Item, size: u32) -> ExpandElement {
        ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size))
    }

    /// Create a local array. Note: allocated at the root scope.
    pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {
        ExpandElement::Plain(self.root.borrow_mut().create_local_array(item, size))
    }

    /// Obtain the index-th input
    pub fn input(&mut self, id: u16, item: Item) -> ExpandElement {
        ExpandElement::Plain(crate::ir::Variable::GlobalInputArray { id, item })
    }

    /// Obtain the index-th output
    pub fn output(&mut self, id: u16, item: Item) -> ExpandElement {
        let var = crate::ir::Variable::GlobalOutputArray { id, item };
        // NOTE(review): presumably this registers the output with the scope
        // so it exists even if never explicitly written — confirm against
        // `Scope::write_global_custom`.
        self.scope.borrow_mut().write_global_custom(var);
        ExpandElement::Plain(var)
    }

    /// Obtain the index-th scalar
    pub fn scalar(&self, id: u16, elem: Elem) -> ExpandElement {
        ExpandElement::Plain(crate::ir::Variable::GlobalScalar { id, elem })
    }
}
|
|
@ -1,238 +0,0 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
compute::{KernelBuilder, KernelLauncher},
|
||||
frontend::CubeType,
|
||||
ir::{Item, Vectorization},
|
||||
unexpanded, KernelSettings, Runtime,
|
||||
};
|
||||
use crate::{
|
||||
frontend::{indexation::Index, CubeContext},
|
||||
prelude::{assign, index, index_assign, Comptime},
|
||||
};
|
||||
|
||||
use super::{
|
||||
ArgSettings, CubePrimitive, ExpandElement, ExpandElementTyped, Init, LaunchArg,
|
||||
LaunchArgExpand, TensorHandle, UInt,
|
||||
};
|
||||
|
||||
/// A contiguous array of elements.
pub struct Array<E> {
    // Marker only: the element type exists purely at the type level; the
    // actual storage lives in the generated kernel.
    _val: PhantomData<E>,
}

impl<C: CubeType> CubeType for Array<C> {
    // Arrays expand to a typed expand element carrying the array type.
    type ExpandType = ExpandElementTyped<Array<C>>;
}
|
||||
|
||||
impl<T: CubePrimitive + Clone> Array<T> {
|
||||
pub fn new<S: Index>(_size: S) -> Self {
|
||||
Array { _val: PhantomData }
|
||||
}
|
||||
|
||||
pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
|
||||
Array { _val: PhantomData }
|
||||
}
|
||||
|
||||
pub fn __expand_new<S: Index>(
|
||||
context: &mut CubeContext,
|
||||
size: S,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let size = size.value();
|
||||
let size = match size {
|
||||
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
|
||||
_ => panic!("Array need constant initialization value"),
|
||||
};
|
||||
context
|
||||
.create_local_array(Item::new(T::as_elem()), size)
|
||||
.into()
|
||||
}
|
||||
|
||||
pub fn __expand_vectorized<S: Index>(
|
||||
context: &mut CubeContext,
|
||||
size: S,
|
||||
vectorization_factor: UInt,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let size = size.value();
|
||||
let size = match size {
|
||||
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
|
||||
_ => panic!("Shared memory need constant initialization value"),
|
||||
};
|
||||
context
|
||||
.create_local_array(
|
||||
Item::vectorized(T::as_elem(), vectorization_factor.val as u8),
|
||||
size,
|
||||
)
|
||||
.into()
|
||||
}
|
||||
|
||||
pub fn to_vectorized(self, _vectorization_factor: Comptime<UInt>) -> T {
|
||||
unexpanded!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: CubeType> ExpandElementTyped<Array<C>> {
    /// Expand implementation of [`Array::to_vectorized`]: read the array's
    /// elements into a single local variable vectorized by the given factor.
    pub fn to_vectorized_expand(
        self,
        context: &mut CubeContext,
        vectorization_factor: UInt,
    ) -> ExpandElement {
        let factor = vectorization_factor.val;
        let var = self.expand.clone();
        // Destination local: same element type as the array, `factor` lanes.
        let mut new_var = context.create_local(Item::vectorized(var.item().elem(), factor as u8));
        if vectorization_factor.val == 1 {
            // Scalar case: assign the first element directly.
            // NOTE(review): this branch indexes through the typed element
            // (`self.clone()`) while the loop below uses the untyped
            // `self.expand` — confirm both expand identically.
            let element = index::expand(context, self.clone(), 0u32);
            assign::expand(context, element, new_var.clone());
        } else {
            // Vectorized case: gather `factor` consecutive elements into the
            // lanes of the new variable.
            for i in 0..factor {
                let element = index::expand(context, self.expand.clone(), i);
                new_var = index_assign::expand(context, new_var, i, element);
            }
        }
        new_var
    }
}
|
||||
|
||||
impl<C: CubeType> CubeType for &Array<C> {
    // A reference expands to the same typed element as the owned array.
    type ExpandType = ExpandElementTyped<Array<C>>;
}

impl<C: CubeType> Init for ExpandElementTyped<Array<C>> {
    fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
        // The type can't be deeply cloned/copied.
        self
    }
}
|
||||
|
||||
impl<E: CubeType> Array<E> {
    /// Obtain the array length
    pub fn len(&self) -> UInt {
        // Placeholder: only meaningful inside a cube kernel, where the
        // expansion machinery replaces this call.
        unexpanded!()
    }
}

impl<C: CubePrimitive> LaunchArg for Array<C> {
    // Arrays are passed to kernels through [`ArrayArg`].
    type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>;
}
|
||||
|
||||
impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
|
||||
fn expand(
|
||||
builder: &mut KernelBuilder,
|
||||
vectorization: Vectorization,
|
||||
) -> ExpandElementTyped<Array<C>> {
|
||||
builder
|
||||
.input_array(Item::vectorized(C::as_elem(), vectorization))
|
||||
.into()
|
||||
}
|
||||
fn expand_output(
|
||||
builder: &mut KernelBuilder,
|
||||
vectorization: Vectorization,
|
||||
) -> ExpandElementTyped<Array<C>> {
|
||||
builder
|
||||
.output_array(Item::vectorized(C::as_elem(), vectorization))
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Tensor representation with a reference to the [server handle](burn_compute::server::Handle).
pub struct ArrayHandle<'a, R: Runtime> {
    // Backing buffer on the compute server.
    pub handle: &'a burn_compute::server::Handle<R::Server>,
    // Element count stored as a rank-1 shape so it can be borrowed directly
    // as a tensor shape (see `as_tensor`).
    pub length: [usize; 1],
}

/// How an [`Array`] argument is supplied to a kernel launch.
pub enum ArrayArg<'a, R: Runtime> {
    /// The array is passed with an array handle.
    Handle {
        /// The array handle.
        handle: ArrayHandle<'a, R>,
        /// The vectorization factor.
        vectorization_factor: u8,
    },
    /// The array is aliasing another input array.
    Alias {
        /// The position of the input array.
        input_pos: usize,
    },
}
|
||||
|
||||
impl<'a, R: Runtime> ArgSettings<R> for ArrayArg<'a, R> {
|
||||
fn register(&self, launcher: &mut KernelLauncher<R>) {
|
||||
if let ArrayArg::Handle {
|
||||
handle,
|
||||
vectorization_factor: _,
|
||||
} = self
|
||||
{
|
||||
launcher.register_array(handle)
|
||||
}
|
||||
}
|
||||
|
||||
fn configure_input(&self, position: usize, settings: KernelSettings) -> KernelSettings {
|
||||
match self {
|
||||
Self::Handle {
|
||||
handle: _,
|
||||
vectorization_factor,
|
||||
} => settings.vectorize_input(position, *vectorization_factor),
|
||||
Self::Alias { input_pos: _ } => {
|
||||
panic!("Not yet supported, only output can be aliased for now.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn configure_output(&self, position: usize, mut settings: KernelSettings) -> KernelSettings {
|
||||
match self {
|
||||
Self::Handle {
|
||||
handle: _,
|
||||
vectorization_factor,
|
||||
} => settings.vectorize_output(position, *vectorization_factor),
|
||||
Self::Alias { input_pos } => {
|
||||
settings.mappings.push(crate::InplaceMapping {
|
||||
pos_input: *input_pos,
|
||||
pos_output: position,
|
||||
});
|
||||
settings
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, R: Runtime> ArrayArg<'a, R> {
|
||||
/// Create a new array argument.
|
||||
///
|
||||
/// Equivalent to using the [vectorized constructor](Self::vectorized) with a vectorization
|
||||
/// factor of 1.
|
||||
pub fn new(handle: &'a burn_compute::server::Handle<R::Server>, length: usize) -> Self {
|
||||
ArrayArg::Handle {
|
||||
handle: ArrayHandle::new(handle, length),
|
||||
vectorization_factor: 1,
|
||||
}
|
||||
}
|
||||
/// Create a new array argument specified with its vectorization factor.
|
||||
pub fn vectorized(
|
||||
vectorization_factor: u8,
|
||||
handle: &'a burn_compute::server::Handle<R::Server>,
|
||||
length: usize,
|
||||
) -> Self {
|
||||
ArrayArg::Handle {
|
||||
handle: ArrayHandle::new(handle, length),
|
||||
vectorization_factor,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, R: Runtime> ArrayHandle<'a, R> {
|
||||
pub fn new(handle: &'a burn_compute::server::Handle<R::Server>, length: usize) -> Self {
|
||||
Self {
|
||||
handle,
|
||||
length: [length],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_tensor(&self) -> TensorHandle<'_, R> {
|
||||
let shape = &self.length;
|
||||
|
||||
TensorHandle {
|
||||
handle: self.handle,
|
||||
strides: &[1],
|
||||
shape,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,277 +0,0 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
ir::{Operator, Variable, Vectorization},
|
||||
prelude::{init_expand, CubeContext, KernelBuilder, KernelLauncher},
|
||||
KernelSettings, Runtime,
|
||||
};
|
||||
use alloc::rc::Rc;
|
||||
|
||||
use super::{UInt, Vectorized};
|
||||
|
||||
/// Types used in a cube function must implement this trait.
///
/// Variables whose values will be known at runtime must
/// have ExpandElement as associated type.
/// Variables whose values will be known at compile time
/// must have the primitive type as associated type.
///
/// Note: Cube functions should be written using CubeTypes,
/// so that the code generated uses the associated ExpandType.
/// This allows Cube code to not necessitate cloning, which is cumbersome
/// in algorithmic code. The necessary cloning will automatically appear in
/// the generated code.
pub trait CubeType {
    /// The type the value becomes during expansion.
    type ExpandType: Clone + Init;
}

/// Trait to be implemented by [cube types](CubeType) implementations.
pub trait Init: Sized {
    /// Initialize a type within a [context](CubeContext).
    ///
    /// You can return the same value when the variable is a non-mutable data structure or
    /// if the type can not be deeply cloned/copied.
    fn init(self, context: &mut CubeContext) -> Self;
}
|
||||
|
||||
/// Defines how a [launch argument](LaunchArg) can be expanded.
///
/// Normally this type should be implemented two times for an argument.
/// Once for the reference and the other for the mutable reference. Often time, the reference
/// should expand the argument as an input while the mutable reference should expand the argument
/// as an output.
pub trait LaunchArgExpand: CubeType {
    /// Register an input variable during compilation that fills the [KernelBuilder].
    fn expand(
        builder: &mut KernelBuilder,
        vectorization: Vectorization,
    ) -> <Self as CubeType>::ExpandType;
    /// Register an output variable during compilation that fills the [KernelBuilder].
    /// Defaults to the input expansion.
    fn expand_output(
        builder: &mut KernelBuilder,
        vectorization: Vectorization,
    ) -> <Self as CubeType>::ExpandType {
        Self::expand(builder, vectorization)
    }
}
|
||||
|
||||
/// Defines a type that can be used as argument to a kernel.
pub trait LaunchArg: LaunchArgExpand + Send + Sync + 'static {
    /// The runtime argument for the kernel.
    type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
}

// The unit type is a valid (empty) launch argument: it registers nothing
// with the launcher and expands to nothing.
impl LaunchArg for () {
    type RuntimeArg<'a, R: Runtime> = ();
}

impl<R: Runtime> ArgSettings<R> for () {
    fn register(&self, _launcher: &mut KernelLauncher<R>) {
        // nothing to do
    }
}

impl LaunchArgExpand for () {
    // Expands to the unit expand type; the builder is untouched.
    fn expand(
        _builder: &mut KernelBuilder,
        _vectorization: Vectorization,
    ) -> <Self as CubeType>::ExpandType {
    }
}

impl CubeType for () {
    type ExpandType = ();
}

impl Init for () {
    fn init(self, _context: &mut CubeContext) -> Self {
        self
    }
}
|
||||
|
||||
/// Defines the argument settings used to launch a kernel.
pub trait ArgSettings<R: Runtime>: Send + Sync {
    /// Register the information to the [KernelLauncher].
    fn register(&self, launcher: &mut KernelLauncher<R>);
    /// Configure an input argument at the given position.
    ///
    /// Defaults to leaving the settings untouched.
    fn configure_input(&self, _position: usize, settings: KernelSettings) -> KernelSettings {
        settings
    }
    /// Configure an output argument at the given position.
    ///
    /// Defaults to leaving the settings untouched.
    fn configure_output(&self, _position: usize, settings: KernelSettings) -> KernelSettings {
        settings
    }
}
|
||||
|
||||
/// Reference to a JIT variable
#[derive(Clone, Debug)]
pub enum ExpandElement {
    /// Variable kept in the variable pool.
    Managed(Rc<Variable>),
    /// Variable not kept in the variable pool.
    Plain(Variable),
}

/// Expand type associated with a type.
#[derive(new)]
pub struct ExpandElementTyped<T> {
    // Untyped element; `T` only exists at compile time (see `_type`).
    pub(crate) expand: ExpandElement,
    // Zero-sized marker tying the element to its cube type.
    pub(crate) _type: PhantomData<T>,
}
|
||||
|
||||
impl<T> Vectorized for ExpandElementTyped<T> {
|
||||
fn vectorization_factor(&self) -> UInt {
|
||||
self.expand.vectorization_factor()
|
||||
}
|
||||
|
||||
fn vectorize(self, factor: UInt) -> Self {
|
||||
Self {
|
||||
expand: self.expand.vectorize(factor),
|
||||
_type: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Clone for ExpandElementTyped<T> {
    // Manual impl: `#[derive(Clone)]` would add a spurious `T: Clone` bound,
    // but only the `expand` element is actually cloned (`T` is phantom).
    fn clone(&self) -> Self {
        Self {
            expand: self.expand.clone(),
            _type: PhantomData,
        }
    }
}

impl<T> From<ExpandElement> for ExpandElementTyped<T> {
    // Attach a compile-time element type to an untyped expand element.
    fn from(expand: ExpandElement) -> Self {
        Self {
            expand,
            _type: PhantomData,
        }
    }
}

impl<T> From<ExpandElementTyped<T>> for ExpandElement {
    // Erase the compile-time element type.
    fn from(value: ExpandElementTyped<T>) -> Self {
        value.expand
    }
}
|
||||
|
||||
impl ExpandElement {
|
||||
pub fn can_mut(&self) -> bool {
|
||||
match self {
|
||||
ExpandElement::Managed(var) => {
|
||||
if let Variable::Local { .. } = var.as_ref() {
|
||||
Rc::strong_count(var) <= 2
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
ExpandElement::Plain(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::Deref for ExpandElement {
    type Target = Variable;

    /// Borrow the wrapped IR variable, regardless of how it is stored.
    fn deref(&self) -> &Self::Target {
        match self {
            ExpandElement::Managed(var) => var.as_ref(),
            ExpandElement::Plain(var) => var,
        }
    }
}
|
||||
|
||||
impl From<ExpandElement> for Variable {
|
||||
fn from(value: ExpandElement) -> Self {
|
||||
match value {
|
||||
ExpandElement::Managed(var) => *var,
|
||||
ExpandElement::Plain(var) => var,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Init for ExpandElement {
    /// Copy the value into a fresh variable when it could otherwise be
    /// observed through another reference; reuse it in place when safe.
    fn init(self, context: &mut CubeContext) -> Self {
        if self.can_mut() {
            // Can reuse inplace :)
            return self;
        }

        // Emit an `Assign` into a new variable via the expansion helper.
        let mut init = |elem: Self| init_expand(context, elem, Operator::Assign);

        match *self {
            // Scalars and locals are cheap to copy into a new variable.
            Variable::GlobalScalar { .. } => init(self),
            Variable::LocalScalar { .. } => init(self),
            Variable::ConstantScalar { .. } => init(self),
            Variable::Local { .. } => init(self),
            // Constant should be initialized since the new variable can be mutated afterward.
            // And it is assumed those values are cloned.
            Variable::Rank
            | Variable::UnitPos
            | Variable::UnitPosX
            | Variable::UnitPosY
            | Variable::UnitPosZ
            | Variable::CubePos
            | Variable::CubePosX
            | Variable::CubePosY
            | Variable::CubePosZ
            | Variable::CubeDim
            | Variable::CubeDimX
            | Variable::CubeDimY
            | Variable::CubeDimZ
            | Variable::CubeCount
            | Variable::CubeCountX
            | Variable::CubeCountY
            | Variable::CubeCountZ
            | Variable::SubcubeDim
            | Variable::AbsolutePos
            | Variable::AbsolutePosX
            | Variable::AbsolutePosY
            | Variable::AbsolutePosZ => init(self),
            // Array types can't be copied, so we should simply return the same variable.
            Variable::SharedMemory { .. }
            | Variable::GlobalInputArray { .. }
            | Variable::GlobalOutputArray { .. }
            | Variable::LocalArray { .. }
            | Variable::Slice { .. }
            | Variable::Matrix { .. } => self,
        }
    }
}
|
||||
|
||||
macro_rules! impl_init_for {
|
||||
($($t:ty),*) => {
|
||||
$(
|
||||
impl Init for $t {
|
||||
fn init(self, _context: &mut CubeContext) -> Self {
|
||||
panic!("Shouln't be called, only for comptime.")
|
||||
}
|
||||
}
|
||||
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
// Add all types used within comptime
|
||||
impl_init_for!(u32, bool, UInt);
|
||||
|
||||
impl<T: Init> Init for Option<T> {
|
||||
fn init(self, context: &mut CubeContext) -> Self {
|
||||
self.map(|o| Init::init(o, context))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CubeType> CubeType for Vec<T> {
    // A vector of cube types expands element-wise.
    type ExpandType = Vec<T::ExpandType>;
}

impl<T: CubeType> CubeType for &mut Vec<T> {
    // Mutable references share the owned vector's expand type.
    type ExpandType = Vec<T::ExpandType>;
}
|
||||
|
||||
impl<T: Init> Init for Vec<T> {
|
||||
fn init(self, context: &mut CubeContext) -> Self {
|
||||
self.into_iter().map(|e| e.init(context)).collect()
|
||||
}
|
||||
}
|
|
@ -1,28 +0,0 @@
|
|||
use crate::frontend::{CubePrimitive, CubeType, ExpandElement};
|
||||
use crate::ir::Elem;
|
||||
|
||||
use super::Vectorized;
|
||||
|
||||
// To be consistent with other primitive type.
/// Boolean type.
pub type Bool = bool;

impl CubeType for bool {
    // Booleans expand to plain expand elements like other primitives.
    type ExpandType = ExpandElement;
}

impl CubePrimitive for bool {
    /// Return the element type to use on GPU.
    fn as_elem() -> Elem {
        Elem::Bool
    }
}

impl Vectorized for bool {
    // NOTE(review): both methods are unimplemented and panic if called;
    // presumably booleans are never vectorized in practice — confirm.
    fn vectorization_factor(&self) -> crate::prelude::UInt {
        todo!()
    }

    fn vectorize(self, _factor: crate::prelude::UInt) -> Self {
        todo!()
    }
}
|
|
@ -1,31 +0,0 @@
|
|||
use crate::frontend::{assign, CubeContext, CubePrimitive, CubeType};
|
||||
use crate::ir::{Item, Variable};
|
||||
use crate::{frontend::ExpandElement, unexpanded};
|
||||
|
||||
/// Enable elegant casting from any to any CubeElem
|
||||
pub trait Cast: CubePrimitive {
|
||||
fn cast_from<From: CubePrimitive>(value: From) -> Self;
|
||||
|
||||
fn __expand_cast_from<From>(
|
||||
context: &mut CubeContext,
|
||||
value: From,
|
||||
) -> <Self as CubeType>::ExpandType
|
||||
where
|
||||
From: Into<ExpandElement>,
|
||||
{
|
||||
let value: ExpandElement = value.into();
|
||||
let var: Variable = *value;
|
||||
let new_var = context.create_local(Item::vectorized(
|
||||
<Self as CubePrimitive>::as_elem(),
|
||||
var.item().vectorization,
|
||||
));
|
||||
assign::expand(context, value, new_var.clone());
|
||||
new_var
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: CubePrimitive> Cast for P {
    // Placeholder: real casting happens in the expand function.
    fn cast_from<From: CubePrimitive>(_value: From) -> Self {
        unexpanded!()
    }
}
|
|
@ -1,49 +0,0 @@
|
|||
use crate::frontend::UInt;
|
||||
use crate::frontend::{CubeType, ExpandElement};
|
||||
use crate::ir::{Elem, Variable};
|
||||
|
||||
use super::Vectorized;
|
||||
|
||||
/// Form of CubeType that encapsulates all primitive types:
/// Numeric, UInt, Bool
pub trait CubePrimitive:
    CubeType<ExpandType = ExpandElement>
    + Vectorized
    + core::cmp::Eq
    + core::cmp::PartialEq
    + Send
    + Sync
    + 'static
    + Clone
    + Copy
{
    /// Return the element type to use on GPU
    fn as_elem() -> Elem;
}

/// Implements `From<$type> for ExpandElement` for plain Rust scalar types
/// that the IR can already represent as a `Variable` constant.
macro_rules! impl_into_expand_element {
    ($type:ty) => {
        impl From<$type> for ExpandElement {
            fn from(value: $type) -> Self {
                ExpandElement::Plain(Variable::from(value))
            }
        }
    };
}

impl_into_expand_element!(u32);
impl_into_expand_element!(usize);
impl_into_expand_element!(bool);
impl_into_expand_element!(f32);
impl_into_expand_element!(i32);
impl_into_expand_element!(i64);

/// Useful for Comptime
impl From<UInt> for ExpandElement {
    fn from(value: UInt) -> Self {
        // Constants are stored as f64 in the IR; every u32 value converts
        // exactly (f64 has a 53-bit mantissa).
        ExpandElement::Plain(crate::ir::Variable::ConstantScalar {
            value: value.val as f64,
            elem: UInt::as_elem(),
        })
    }
}
|
|
@ -1,233 +0,0 @@
|
|||
use half::{bf16, f16};
|
||||
|
||||
use crate::frontend::{Ceil, Cos, Erf, Exp, Floor, Log, Log1p, Powf, Recip, Sin, Sqrt, Tanh};
|
||||
use crate::frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement, Numeric};
|
||||
use crate::ir::{Elem, FloatKind, Item, Variable, Vectorization};
|
||||
|
||||
use crate::compute::{KernelBuilder, KernelLauncher};
|
||||
use crate::prelude::index_assign;
|
||||
use crate::{unexpanded, Runtime};
|
||||
|
||||
use super::{LaunchArgExpand, ScalarArgSettings, UInt, Vectorized};
|
||||
|
||||
/// Floating point numbers. Used as input in float kernels
pub trait Float:
    Numeric
    + Exp
    + Log
    + Log1p
    + Cos
    + Sin
    + Tanh
    + Powf
    + Sqrt
    + Floor
    + Ceil
    + Erf
    + Recip
    + core::ops::Index<UInt, Output = Self>
    + core::ops::IndexMut<UInt, Output = Self>
{
    /// Create a scalar float (vectorization factor of 1).
    fn new(val: f32) -> Self;
    /// Create a float broadcast over `vectorization` lanes.
    fn vectorized(val: f32, vectorization: UInt) -> Self;
    /// Create a vectorized float whose lanes are left unset.
    fn vectorized_empty(vectorization: UInt) -> Self;
    /// Expand function of [`new`](Float::new).
    fn __expand_new(context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType;
    /// Expand function of [`vectorized`](Float::vectorized).
    fn __expand_vectorized(
        context: &mut CubeContext,
        val: f32,
        vectorization: UInt,
    ) -> <Self as CubeType>::ExpandType;
    /// Expand function of [`vectorized_empty`](Float::vectorized_empty).
    fn __expand_vectorized_empty(
        context: &mut CubeContext,
        vectorization: UInt,
    ) -> <Self as CubeType>::ExpandType;
}
|
||||
|
||||
/// Generates a float wrapper type (`F16`, `BF16`, `F32`, `F64`) together
/// with every trait impl it needs to act as a cube primitive.
///
/// Each wrapper stores its comptime value as `f32` plus a vectorization
/// factor; `$primitive` is the matching runtime precision type.
macro_rules! impl_float {
    ($type:ident, $primitive:ty) => {
        #[derive(Clone, Copy)]
        pub struct $type {
            // Comptime value, always held as f32 regardless of precision.
            pub val: f32,
            // Number of lanes (1 = scalar).
            pub vectorization: u8,
        }

        impl CubeType for $type {
            type ExpandType = ExpandElement;
        }

        impl CubePrimitive for $type {
            /// Return the element type to use on GPU
            fn as_elem() -> Elem {
                Elem::Float(FloatKind::$type)
            }
        }

        impl Numeric for $type {
            type Primitive = $primitive;
        }

        impl Float for $type {
            fn new(val: f32) -> Self {
                Self {
                    val,
                    vectorization: 1,
                }
            }

            fn vectorized(val: f32, vectorization: UInt) -> Self {
                // A factor of 1 is canonicalized through `new`.
                if vectorization.val == 1 {
                    Self::new(val)
                } else {
                    Self {
                        val,
                        vectorization: vectorization.val as u8,
                    }
                }
            }

            fn vectorized_empty(vectorization: UInt) -> Self {
                Self::vectorized(0., vectorization)
            }

            fn __expand_new(
                _context: &mut CubeContext,
                val: f32,
            ) -> <Self as CubeType>::ExpandType {
                // Constants are stored as f64 in the IR.
                let new_var = Variable::ConstantScalar {
                    value: val as f64,
                    elem: Self::as_elem(),
                };
                ExpandElement::Plain(new_var)
            }

            fn __expand_vectorized(
                context: &mut CubeContext,
                val: f32,
                vectorization: UInt,
            ) -> <Self as CubeType>::ExpandType {
                if vectorization.val == 1 {
                    Self::__expand_new(context, val)
                } else {
                    // Broadcast the constant into every lane of a new
                    // vectorized local.
                    let mut new_var = context
                        .create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8));
                    for (i, element) in vec![val; vectorization.val as usize].iter().enumerate() {
                        new_var = index_assign::expand(context, new_var, i, *element);
                    }

                    new_var
                }
            }

            fn __expand_vectorized_empty(
                context: &mut CubeContext,
                vectorization: UInt,
            ) -> <Self as CubeType>::ExpandType {
                if vectorization.val == 1 {
                    Self::__expand_new(context, 0.)
                } else {
                    // Lanes are deliberately left unassigned.
                    context.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8))
                }
            }
        }

        impl core::ops::Index<UInt> for $type {
            type Output = Self;

            fn index(&self, _index: UInt) -> &Self::Output {
                // Placeholder: indexing is handled by the expand machinery.
                unexpanded!()
            }
        }

        impl core::ops::IndexMut<UInt> for $type {
            fn index_mut(&mut self, _index: UInt) -> &mut Self::Output {
                unexpanded!()
            }
        }

        impl LaunchArgExpand for $type {
            fn expand(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
                // Scalars cannot be vectorized at launch.
                assert_eq!(vectorization, 1, "Attempted to vectorize a scalar");
                builder.scalar($type::as_elem())
            }
        }

        impl Vectorized for $type {
            fn vectorization_factor(&self) -> UInt {
                UInt {
                    val: self.vectorization as u32,
                    vectorization: 1,
                }
            }

            fn vectorize(mut self, factor: UInt) -> Self {
                // NOTE(review): this copies `factor.vectorization` (the lane
                // count of the factor value itself, typically 1) rather than
                // `factor.val`, while `vectorization_factor` above reports
                // the factor through `val`. Confirm this asymmetry is
                // intentional.
                self.vectorization = factor.vectorization;
                self
            }
        }
    };
}

impl_float!(F16, f16);
impl_float!(BF16, bf16);
impl_float!(F32, f32);
impl_float!(F64, f64);
|
||||
|
||||
impl From<f32> for F32 {
|
||||
fn from(value: f32) -> Self {
|
||||
Self {
|
||||
val: value,
|
||||
vectorization: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<f32> for BF16 {
|
||||
fn from(value: f32) -> Self {
|
||||
Self {
|
||||
val: value,
|
||||
vectorization: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<f32> for F16 {
|
||||
fn from(value: f32) -> Self {
|
||||
Self {
|
||||
val: value,
|
||||
vectorization: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<f32> for F64 {
|
||||
fn from(value: f32) -> Self {
|
||||
Self {
|
||||
val: value,
|
||||
vectorization: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarArgSettings for f16 {
    /// Register a half-precision scalar with the launcher.
    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
        settings.register_f16(*self);
    }
}

impl ScalarArgSettings for bf16 {
    /// Register a bfloat16 scalar with the launcher.
    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
        settings.register_bf16(*self);
    }
}

impl ScalarArgSettings for f32 {
    /// Register a single-precision scalar with the launcher.
    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
        settings.register_f32(*self);
    }
}

impl ScalarArgSettings for f64 {
    /// Register a double-precision scalar with the launcher.
    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
        settings.register_f64(*self);
    }
}
|
|
@ -1,146 +0,0 @@
|
|||
use crate::compute::{KernelBuilder, KernelLauncher};
|
||||
use crate::frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement, Numeric};
|
||||
use crate::ir::{Elem, IntKind, Item, Variable, Vectorization};
|
||||
use crate::prelude::index_assign;
|
||||
use crate::Runtime;
|
||||
|
||||
use super::{LaunchArgExpand, ScalarArgSettings, UInt, Vectorized};
|
||||
|
||||
/// Signed integer. Used as input in int kernels
pub trait Int: Numeric + std::ops::Rem<Output = Self> {
    /// Create a scalar integer (vectorization factor of 1).
    fn new(val: i64) -> Self;
    /// Create an integer broadcast over `vectorization` lanes.
    fn vectorized(val: i64, vectorization: UInt) -> Self;
    /// Expand function of [`new`](Int::new).
    fn __expand_new(context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType;
    /// Expand function of [`vectorized`](Int::vectorized).
    fn __expand_vectorized(
        context: &mut CubeContext,
        val: i64,
        vectorization: UInt,
    ) -> <Self as CubeType>::ExpandType;
}
|
||||
|
||||
/// Generates a signed-integer wrapper type (`I32`, `I64`) together with the
/// trait impls it needs to act as a cube primitive.
///
/// Each wrapper stores its comptime value as `$primitive` plus a
/// vectorization factor.
macro_rules! impl_int {
    ($type:ident, $primitive:ty) => {
        #[derive(Clone, Copy)]
        pub struct $type {
            // Comptime value in the native precision.
            pub val: $primitive,
            // Number of lanes (1 = scalar).
            pub vectorization: u8,
        }

        impl CubeType for $type {
            type ExpandType = ExpandElement;
        }

        impl CubePrimitive for $type {
            fn as_elem() -> Elem {
                Elem::Int(IntKind::$type)
            }
        }

        impl Numeric for $type {
            type Primitive = $primitive;
        }

        impl Int for $type {
            fn new(val: i64) -> Self {
                Self {
                    val: val as $primitive,
                    vectorization: 1,
                }
            }

            fn vectorized(val: i64, vectorization: UInt) -> Self {
                // A factor of 1 is canonicalized through `new`.
                if vectorization.val == 1 {
                    Self::new(val)
                } else {
                    Self {
                        val: val as $primitive,
                        vectorization: vectorization.val as u8,
                    }
                }
            }

            fn __expand_new(
                _context: &mut CubeContext,
                val: i64,
            ) -> <Self as CubeType>::ExpandType {
                // Constants are stored as f64 in the IR.
                let new_var = Variable::ConstantScalar {
                    value: val as f64,
                    elem: Self::as_elem(),
                };
                ExpandElement::Plain(new_var)
            }

            fn __expand_vectorized(
                context: &mut CubeContext,
                val: i64,
                vectorization: UInt,
            ) -> <Self as CubeType>::ExpandType {
                if vectorization.val == 1 {
                    Self::__expand_new(context, val)
                } else {
                    // Broadcast the constant into every lane of a new
                    // vectorized local.
                    let mut new_var = context
                        .create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8));
                    for (i, element) in vec![val; vectorization.val as usize].iter().enumerate() {
                        new_var = index_assign::expand(context, new_var, i, *element);
                    }

                    new_var
                }
            }
        }

        impl LaunchArgExpand for $type {
            fn expand(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
                // Scalars cannot be vectorized at launch.
                assert_eq!(vectorization, 1, "Attempted to vectorize a scalar");
                builder.scalar($type::as_elem())
            }
        }

        impl Vectorized for $type {
            fn vectorization_factor(&self) -> UInt {
                UInt {
                    val: self.vectorization as u32,
                    vectorization: 1,
                }
            }

            fn vectorize(mut self, factor: UInt) -> Self {
                // NOTE(review): this copies `factor.vectorization` (usually
                // 1) rather than `factor.val`, while `vectorization_factor`
                // above reports the factor through `val`. Confirm this
                // asymmetry is intentional (same pattern as `impl_float`).
                self.vectorization = factor.vectorization;
                self
            }
        }
    };
}

impl_int!(I32, i32);
impl_int!(I64, i64);
|
||||
|
||||
impl From<i64> for I64 {
|
||||
fn from(value: i64) -> Self {
|
||||
Self {
|
||||
val: value,
|
||||
vectorization: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i32> for I32 {
|
||||
fn from(value: i32) -> Self {
|
||||
Self {
|
||||
val: value,
|
||||
vectorization: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarArgSettings for i32 {
|
||||
fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
|
||||
settings.register_i32(*self);
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarArgSettings for i64 {
|
||||
fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
|
||||
settings.register_i64(*self);
|
||||
}
|
||||
}
|
|
@ -1,27 +0,0 @@
|
|||
mod array;
|
||||
mod base;
|
||||
mod bool;
|
||||
mod cast;
|
||||
mod cube_elem;
|
||||
mod float;
|
||||
mod int;
|
||||
mod numeric;
|
||||
mod shared_memory;
|
||||
mod slice;
|
||||
mod tensor;
|
||||
mod uint;
|
||||
mod vectorized;
|
||||
|
||||
pub use array::*;
|
||||
pub use base::*;
|
||||
pub use bool::*;
|
||||
pub use cast::*;
|
||||
pub use cube_elem::*;
|
||||
pub use float::*;
|
||||
pub use int::*;
|
||||
pub use numeric::*;
|
||||
pub use shared_memory::*;
|
||||
pub use slice::*;
|
||||
pub use tensor::*;
|
||||
pub use uint::*;
|
||||
pub use vectorized::*;
|
|
@ -1,93 +0,0 @@
|
|||
use crate::compute::KernelLauncher;
|
||||
use crate::frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement};
|
||||
use crate::ir::{Item, Variable};
|
||||
use crate::prelude::Clamp;
|
||||
use crate::Runtime;
|
||||
use crate::{
|
||||
frontend::{index_assign, Abs, Max, Min, Remainder},
|
||||
unexpanded,
|
||||
};
|
||||
|
||||
use super::{ArgSettings, LaunchArg, LaunchArgExpand};
|
||||
|
||||
/// Type that encompasses both (unsigned or signed) integers and floats
|
||||
/// Used in kernels that should work for both.
|
||||
pub trait Numeric:
|
||||
Copy
|
||||
+ CubePrimitive
|
||||
+ LaunchArgExpand
|
||||
+ std::ops::Add<Output = Self>
|
||||
+ std::ops::AddAssign
|
||||
+ std::ops::SubAssign
|
||||
+ std::ops::MulAssign
|
||||
+ std::ops::DivAssign
|
||||
+ std::ops::Sub<Output = Self>
|
||||
+ std::ops::Mul<Output = Self>
|
||||
+ std::ops::Div<Output = Self>
|
||||
+ std::cmp::PartialOrd
|
||||
+ Abs
|
||||
+ Max
|
||||
+ Min
|
||||
+ Clamp
|
||||
+ Remainder
|
||||
{
|
||||
/// Create a new constant numeric.
|
||||
///
|
||||
/// Note: since this must work for both integer and float
|
||||
/// only the less expressive of both can be created (int)
|
||||
/// If a number with decimals is needed, use Float::new.
|
||||
///
|
||||
/// This method panics when unexpanded. For creating an element
|
||||
/// with a val, use the new method of the sub type.
|
||||
fn from_int(_val: i64) -> Self {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
type Primitive: ScalarArgSettings;
|
||||
|
||||
fn from_vec<const D: usize>(_vec: [i64; D]) -> Self {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
fn __expand_from_int(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = Variable::ConstantScalar {
|
||||
value: val as f64,
|
||||
elem: Self::as_elem(),
|
||||
};
|
||||
ExpandElement::Plain(new_var)
|
||||
}
|
||||
|
||||
fn __expand_from_vec<const D: usize>(
|
||||
context: &mut CubeContext,
|
||||
vec: [i64; D],
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let mut new_var = context.create_local(Item::vectorized(Self::as_elem(), vec.len() as u8));
|
||||
for (i, element) in vec.iter().enumerate() {
|
||||
new_var = index_assign::expand(context, new_var, i, *element);
|
||||
}
|
||||
|
||||
new_var
|
||||
}
|
||||
}
|
||||
|
||||
/// Similar to [ArgSettings], however only for scalar types that don't depend on the [Runtime]
|
||||
/// trait.
|
||||
pub trait ScalarArgSettings: Send + Sync {
|
||||
/// Register the information to the [KernelLauncher].
|
||||
fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub struct ScalarArg<T: Numeric> {
|
||||
elem: T::Primitive,
|
||||
}
|
||||
|
||||
impl<T: Numeric, R: Runtime> ArgSettings<R> for ScalarArg<T> {
|
||||
fn register(&self, launcher: &mut crate::compute::KernelLauncher<R>) {
|
||||
self.elem.register(launcher);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Numeric> LaunchArg for T {
|
||||
type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
|
||||
}
|
|
@ -1,63 +0,0 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
frontend::{indexation::Index, CubeContext, CubePrimitive, CubeType},
|
||||
ir::Item,
|
||||
};
|
||||
|
||||
use super::{ExpandElementTyped, Init, UInt};
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct SharedMemory<T: CubeType> {
|
||||
_val: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: CubePrimitive> Init for ExpandElementTyped<SharedMemory<T>> {
|
||||
fn init(self, _context: &mut CubeContext) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CubePrimitive> CubeType for SharedMemory<T> {
|
||||
type ExpandType = ExpandElementTyped<SharedMemory<T>>;
|
||||
}
|
||||
|
||||
impl<T: CubePrimitive + Clone> SharedMemory<T> {
|
||||
pub fn new<S: Index>(_size: S) -> Self {
|
||||
SharedMemory { _val: PhantomData }
|
||||
}
|
||||
|
||||
pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
|
||||
SharedMemory { _val: PhantomData }
|
||||
}
|
||||
|
||||
pub fn __expand_vectorized<S: Index>(
|
||||
context: &mut CubeContext,
|
||||
size: S,
|
||||
vectorization_factor: UInt,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let size = size.value();
|
||||
let size = match size {
|
||||
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
|
||||
_ => panic!("Shared memory need constant initialization value"),
|
||||
};
|
||||
let var = context.create_shared(
|
||||
Item::vectorized(T::as_elem(), vectorization_factor.val as u8),
|
||||
size,
|
||||
);
|
||||
ExpandElementTyped::new(var)
|
||||
}
|
||||
|
||||
pub fn __expand_new<S: Index>(
|
||||
context: &mut CubeContext,
|
||||
size: S,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let size = size.value();
|
||||
let size = match size {
|
||||
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
|
||||
_ => panic!("Shared memory need constant initialization value"),
|
||||
};
|
||||
let var = context.create_shared(Item::new(T::as_elem()), size);
|
||||
ExpandElementTyped::new(var)
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue