Migration/cubecl (#2041)

This commit is contained in:
Nathaniel Simard 2024-07-22 11:08:40 -04:00 committed by GitHub
parent 0d5025edbb
commit 19cd67a9e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
341 changed files with 723 additions and 32736 deletions

260
Cargo.lock generated
View File

@ -255,10 +255,10 @@ dependencies = [
"arboard",
"burn",
"burn-common",
"burn-cuda",
"burn-wgpu",
"clap 4.5.9",
"colored",
"cubecl",
"derive-new",
"dirs 5.0.1",
"github-device-flow",
@ -469,41 +469,17 @@ dependencies = [
name = "burn-common"
version = "0.14.0"
dependencies = [
"cubecl-common",
"dashmap",
"data-encoding",
"derive-new",
"getrandom",
"indicatif",
"pollster",
"rand",
"rayon",
"reqwest 0.12.5",
"serde",
"spin",
"tokio",
"web-time",
]
[[package]]
name = "burn-compute"
version = "0.14.0"
dependencies = [
"async-channel",
"burn-common",
"derive-new",
"dirs 5.0.1",
"hashbrown 0.14.5",
"log",
"md5",
"pollster",
"rand",
"serde",
"serde_json",
"serial_test",
"spin",
"web-time",
]
[[package]]
name = "burn-core"
version = "0.14.0"
@ -534,49 +510,6 @@ dependencies = [
"thiserror",
]
[[package]]
name = "burn-cube"
version = "0.14.0"
dependencies = [
"burn-compute",
"burn-cube-macros",
"burn-tensor",
"bytemuck",
"derive-new",
"half",
"log",
"num-traits",
"serde",
"trybuild",
]
[[package]]
name = "burn-cube-macros"
version = "0.14.0"
dependencies = [
"derive-new",
"proc-macro2",
"quote",
"syn 2.0.71",
]
[[package]]
name = "burn-cuda"
version = "0.14.0"
dependencies = [
"burn-common",
"burn-compute",
"burn-cube",
"burn-fusion",
"burn-jit",
"burn-tensor",
"bytemuck",
"cudarc",
"derive-new",
"half",
"log",
]
[[package]]
name = "burn-dataset"
version = "0.14.0"
@ -653,7 +586,7 @@ dependencies = [
"thiserror",
"tracing-core",
"tracing-subscriber",
"zip 2.1.3",
"zip 2.1.4",
]
[[package]]
@ -662,13 +595,12 @@ version = "0.14.0"
dependencies = [
"burn-autodiff",
"burn-common",
"burn-compute",
"burn-cube",
"burn-fusion",
"burn-ndarray",
"burn-tensor",
"burn-tensor-testgen",
"bytemuck",
"cubecl",
"derive-new",
"half",
"hashbrown 0.14.5",
@ -727,6 +659,7 @@ dependencies = [
"burn-common",
"burn-tensor-testgen",
"bytemuck",
"cubecl",
"derive-new",
"half",
"hashbrown 0.14.5",
@ -768,19 +701,10 @@ dependencies = [
name = "burn-wgpu"
version = "0.14.0"
dependencies = [
"async-channel",
"burn-common",
"burn-compute",
"burn-cube",
"burn-fusion",
"burn-jit",
"burn-tensor",
"bytemuck",
"derive-new",
"hashbrown 0.14.5",
"log",
"pollster",
"wgpu",
"cubecl",
]
[[package]]
@ -850,9 +774,9 @@ dependencies = [
[[package]]
name = "candle-core"
version = "0.5.1"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "311d8dbe293aa3b5c34f6a57727fafd67d17a74fa8b65276501237c233b34ffd"
checksum = "d5b18de020c2729dbf7ac390325312644808b6ba9b7962f1f724e9185b1d53c7"
dependencies = [
"accelerate-src",
"byteorder",
@ -877,18 +801,18 @@ dependencies = [
[[package]]
name = "candle-kernels"
version = "0.5.1"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3b4b048ca298fb8be90b0f4d0fe68bdca9de956ab52bb6e381463d955f2b661"
checksum = "8bc0a71be8b2f0950b63fd602a5e10a74a4f94a5fd63059ae455e96163389488"
dependencies = [
"bindgen_cuda",
]
[[package]]
name = "candle-metal-kernels"
version = "0.5.1"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d31136c9541c82b7de0937c9a58210ada38e17d70810e0eacc0a99d849d848d"
checksum = "f889aacd02fd999620a0435133d7cf3b58c81ef9dd5e47c38939b7a72345ea86"
dependencies = [
"metal 0.27.0",
"once_cell",
@ -922,9 +846,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.1.5"
version = "1.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "324c74f2155653c90b04f25b2a47a8a631360cb908f92a772695f430c7e31052"
checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f"
dependencies = [
"jobserver",
"libc",
@ -1362,11 +1286,124 @@ dependencies = [
"memchr",
]
[[package]]
name = "cubecl"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
dependencies = [
"cubecl-core",
"cubecl-cuda",
"cubecl-linalg",
"cubecl-wgpu",
]
[[package]]
name = "cubecl-common"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
dependencies = [
"derive-new",
"getrandom",
"pollster",
"rand",
"serde",
"spin",
"web-time",
]
[[package]]
name = "cubecl-core"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
dependencies = [
"bytemuck",
"cubecl-macros",
"cubecl-runtime",
"derive-new",
"half",
"log",
"num-traits",
"serde",
]
[[package]]
name = "cubecl-cuda"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
dependencies = [
"bytemuck",
"cubecl-common",
"cubecl-core",
"cubecl-runtime",
"cudarc",
"derive-new",
"half",
"log",
]
[[package]]
name = "cubecl-linalg"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
dependencies = [
"bytemuck",
"cubecl-core",
"cubecl-runtime",
"half",
]
[[package]]
name = "cubecl-macros"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
dependencies = [
"derive-new",
"proc-macro2",
"quote",
"syn 2.0.71",
]
[[package]]
name = "cubecl-runtime"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
dependencies = [
"async-channel",
"cubecl-common",
"derive-new",
"dirs 5.0.1",
"hashbrown 0.14.5",
"log",
"md5",
"pollster",
"serde",
"serde_json",
"spin",
"web-time",
]
[[package]]
name = "cubecl-wgpu"
version = "0.1.1"
source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c"
dependencies = [
"async-channel",
"bytemuck",
"cubecl-common",
"cubecl-core",
"cubecl-runtime",
"derive-new",
"hashbrown 0.14.5",
"log",
"pollster",
"wgpu",
]
[[package]]
name = "cudarc"
version = "0.11.8"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a56028291ec3b0f6711e2e1b2d597484d359833dcb68331ce89e538012f835c4"
checksum = "e395cd01168d63af826749573071f3c5069b338ae473cab355d22db0b2bb5a0d"
dependencies = [
"half",
"libloading 0.8.4",
@ -1421,6 +1458,7 @@ version = "0.14.0"
dependencies = [
"burn",
"bytemuck",
"cubecl",
"derive-new",
"log",
"serde",
@ -2020,15 +2058,6 @@ dependencies = [
"slab",
]
[[package]]
name = "gelu"
version = "0.14.0"
dependencies = [
"burn-cube",
"burn-cuda",
"burn-wgpu",
]
[[package]]
name = "gemm"
version = "0.17.1"
@ -2846,6 +2875,7 @@ dependencies = [
"burn-import",
"burn-wgpu",
"console_error_panic_hook",
"cubecl",
"js-sys",
"log",
"serde",
@ -4909,9 +4939,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "sdd"
version = "1.6.0"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eb0dde0ccd15e337a3cf738a9a38115c6d8e74795d074e73973dad3d229a897"
checksum = "85f05a494052771fc5bd0619742363b5e24e5ad72ab3111ec2e27925b8edc5f3"
[[package]]
name = "security-framework"
@ -5706,7 +5736,7 @@ dependencies = [
"serde",
"serde_spanned",
"toml_datetime",
"winnow 0.6.13",
"winnow 0.6.14",
]
[[package]]
@ -5826,20 +5856,6 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "trybuild"
version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b1e5645f2ee8025c2f1d75e1138f2dd034d74e6ba54620f3c569ba2a2a1ea06"
dependencies = [
"glob",
"serde",
"serde_derive",
"serde_json",
"termcolor",
"toml",
]
[[package]]
name = "typenum"
version = "1.17.0"
@ -6478,9 +6494,9 @@ dependencies = [
[[package]]
name = "winnow"
version = "0.6.13"
version = "0.6.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59b5e5f6c299a3c7890b876a2a587f3115162487e704907d9b6cd29473052ba1"
checksum = "374ec40a2d767a3c1b4972d9475ecd557356637be906f2cb3f7fe17a6eb5e22f"
dependencies = [
"memchr",
]
@ -6699,9 +6715,9 @@ dependencies = [
[[package]]
name = "zip"
version = "2.1.3"
version = "2.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "775a2b471036342aa69bc5a602bc889cb0a06cda00477d0c69566757d5553d39"
checksum = "e29ab4097989787b2029a5981c41b7bfb427b5a601e23f455daacb4d0360a9e9"
dependencies = [
"aes",
"arbitrary",

View File

@ -16,7 +16,7 @@ members = [
exclude = [
"examples/notebook",
# "crates/burn-cuda" # comment this line to work on burn-cuda
"crates/burn-cuda", # comment this line to work on burn-cuda
]
[workspace.package]
@ -27,7 +27,7 @@ license = "MIT OR Apache-2.0"
[workspace.dependencies]
bytemuck = "1.16.1"
candle-core = { version = "0.5.1" }
candle-core = { version = "0.6.0" }
clap = { version = "4.5.9", features = ["derive"] }
colored = "2.1.0"
console_error_panic_hook = "0.1.7"
@ -140,6 +140,9 @@ nvml-wrapper = "0.10.0"
sysinfo = "0.30.13"
systemstat = "0.2.3"
cubecl = { version = "0.1.1", git = "https://github.com/tracel-ai/cubecl", default-features = false }
cubecl-common = { version = "0.1.1", git = "https://github.com/tracel-ai/cubecl", default-features = false }
[profile.dev]
debug = 0 # Speed up compilation time and not necessary.
opt-level = 2

View File

@ -24,14 +24,14 @@ tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]
cuda-jit = ["burn-cuda"]
# cuda-jit = ["burn-cuda"]
[dependencies]
arboard = { workspace = true }
burn = { path = "../crates/burn", default-features = false }
burn-common = { path = "../crates/burn-common", version = "0.14.0" }
burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.14.0" }
burn-cuda = { path = "../crates/burn-cuda", version = "0.14.0", optional = true }
burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.14.0", optional = true }
# burn-cuda = { path = "../crates/burn-cuda", version = "0.14.0", optional = true }
clap = { workspace = true }
colored = { workspace = true }
derive-new = { workspace = true }
@ -49,6 +49,7 @@ strum_macros = { workspace = true }
sysinfo = { workspace = true, features = ["serde"] }
wgpu = { workspace = true }
wsl = { workspace = true }
cubecl = { workspace = true, features = ["wgpu"] }
[dev-dependencies]
rstest = { workspace = true }

View File

@ -29,7 +29,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
}
fn execute(&self, (lhs, rhs): Self::Args) {
lhs.clone().transpose().matmul(rhs.clone());
lhs.clone().matmul(rhs.clone());
}
fn prepare(&self) -> Self::Args {
@ -52,11 +52,11 @@ fn bench<B: Backend>(
token: Option<&str>,
) {
const D: usize = 3;
let batch_size = 32;
let m = 256;
let k = 1024;
let n = 256;
let shape_lhs = [batch_size, k, m].into();
let batch_size = 8;
let m = 2048;
let k = 2048;
let n = 2048;
let shape_lhs = [batch_size, m, k].into();
let shape_rhs = [batch_size, k, n].into();
let benchmark = MatmulBenchmark::<B, D>::new(shape_lhs, shape_rhs, device.clone());

View File

@ -1,5 +1,5 @@
use burn::serde::{Deserialize, Serialize};
use burn_wgpu::GraphicsApi;
use cubecl::wgpu::GraphicsApi;
use std::collections::HashSet;
use sysinfo;
use wgpu;
@ -51,7 +51,7 @@ impl BenchmarkSystemInfo {
fn enumerate_gpus() -> Vec<String> {
let instance = wgpu::Instance::default();
let adapters: Vec<wgpu::Adapter> = instance
.enumerate_adapters(burn_wgpu::AutoGraphicsApi::backend().into())
.enumerate_adapters(cubecl::wgpu::AutoGraphicsApi::backend().into())
.into_iter()
.filter(|adapter| {
let info = adapter.get_info();

View File

@ -11,8 +11,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-common"
version.workspace = true
[features]
default = ["std"]
std = ["rand/std", "data-encoding/std", "dep:pollster"]
default = ["std", "cubecl-common/default"]
std = ["cubecl-common/std"]
doc = ["default"]
network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
rayon = ["dep:rayon"]
@ -23,13 +23,7 @@ web-time = { version = "1.1.0" }
[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
rand = { workspace = true }
spin = { workspace = true } # using in place of use std::sync::Mutex;
derive-new = { workspace = true }
serde = { workspace = true }
data-encoding = { workspace = true }
pollster = { workspace = true, optional = true }
# Network downloader
indicatif = { workspace = true, optional = true }
@ -38,6 +32,7 @@ tokio = { workspace = true, optional = true }
# Parallel
rayon = { workspace = true, optional = true }
cubecl-common = { workspace = true, default-features = false }
[dev-dependencies]
dashmap = { workspace = true }

View File

@ -1,307 +0,0 @@
use alloc::format;
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;
use core::fmt::Display;
use core::time::Duration;
use serde::{Deserialize, Serialize};
#[cfg(all(not(target_family = "wasm"), feature = "std"))]
use std::time::Instant;
#[cfg(all(target_family = "wasm", feature = "std"))]
use web_time::Instant;
/// Results of a benchmark run.
#[derive(new, Debug, Default, Clone, Serialize, Deserialize)]
pub struct BenchmarkDurations {
/// All durations of the run, in the order they were benchmarked
pub durations: Vec<Duration>,
}
impl BenchmarkDurations {
/// Returns a tuple of durations: (min, max, median)
fn min_max_median_durations(&self) -> (Duration, Duration, Duration) {
let mut sorted = self.durations.clone();
sorted.sort();
let min = *sorted.first().unwrap();
let max = *sorted.last().unwrap();
let median = *sorted.get(sorted.len() / 2).unwrap();
(min, max, median)
}
/// Returns the median duration among all durations
pub(crate) fn mean_duration(&self) -> Duration {
self.durations.iter().sum::<Duration>() / self.durations.len() as u32
}
/// Returns the variance durations for the durations
pub(crate) fn variance_duration(&self, mean: Duration) -> Duration {
let var = self
.durations
.iter()
.map(|duration| {
let tmp = duration.as_secs_f64() - mean.as_secs_f64();
Duration::from_secs_f64(tmp * tmp)
})
.sum::<Duration>()
/ self.durations.len() as u32;
var
}
}
impl Display for BenchmarkDurations {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let computed = BenchmarkComputations::new(self);
let BenchmarkComputations {
mean,
median,
variance,
min,
max,
} = computed;
let num_sample = self.durations.len();
f.write_str(
format!(
"
Result
Samples {num_sample}
Mean {mean:.3?}
Variance {variance:.3?}
Median {median:.3?}
Min {min:.3?}
Max {max:.3?}
"
)
.as_str(),
)
}
}
/// Computed values from benchmark durations.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct BenchmarkComputations {
/// Mean of all the durations.
pub mean: Duration,
/// Median of all the durations.
pub median: Duration,
/// Variance of all the durations.
pub variance: Duration,
/// Minimum duration amongst all durations.
pub min: Duration,
/// Maximum duration amongst all durations.
pub max: Duration,
}
impl BenchmarkComputations {
/// Compute duration values and return a BenchmarkComputations struct
pub fn new(durations: &BenchmarkDurations) -> Self {
let mean = durations.mean_duration();
let (min, max, median) = durations.min_max_median_durations();
Self {
mean,
median,
min,
max,
variance: durations.variance_duration(mean),
}
}
}
/// Benchmark trait.
pub trait Benchmark {
/// Benchmark arguments.
type Args: Clone;
/// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't
/// count as included in the duration.
///
/// # Notes
///
/// This should not include warmup, the benchmark will be run at least one time without
/// measuring the execution time.
fn prepare(&self) -> Self::Args;
/// Execute the benchmark and returns the time it took to complete.
fn execute(&self, args: Self::Args);
/// Number of samples per run required to have a statistical significance.
fn num_samples(&self) -> usize {
10
}
/// Name of the benchmark, should be short and it should match the name
/// defined in the crate Cargo.toml
fn name(&self) -> String;
/// The options passed to the benchmark.
fn options(&self) -> Option<String> {
None
}
/// Shapes dimensions
fn shapes(&self) -> Vec<Vec<usize>> {
vec![]
}
/// Wait for computed to be over
fn sync(&self);
/// Run the benchmark a number of times.
fn run(&self) -> BenchmarkDurations {
#[cfg(not(feature = "std"))]
panic!("Attempting to run benchmark in a no-std environment");
#[cfg(feature = "std")]
{
// Warmup
let args = self.prepare();
self.execute(args.clone());
self.sync();
let mut durations = Vec::with_capacity(self.num_samples());
for _ in 0..self.num_samples() {
// Prepare
self.sync();
// Execute the benchmark
let start = Instant::now();
self.execute(args.clone());
self.sync();
let end = Instant::now();
// Register the duration
durations.push(end - start);
}
BenchmarkDurations { durations }
}
}
}
/// Result of a benchmark run, with metadata
#[derive(Default, Clone)]
pub struct BenchmarkResult {
/// Individual raw results of the run
pub raw: BenchmarkDurations,
/// Computed values for the run
pub computed: BenchmarkComputations,
/// Git commit hash of the commit in which the run occurred
pub git_hash: String,
/// Name of the benchmark
pub name: String,
/// Options passed to the benchmark
pub options: Option<String>,
/// Shape dimensions
pub shapes: Vec<Vec<usize>>,
/// Time just before the run
pub timestamp: u128,
}
impl Display for BenchmarkResult {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(
format!(
"
Timestamp: {}
Git Hash: {}
Benchmarking - {}{}
",
self.timestamp, self.git_hash, self.name, self.raw
)
.as_str(),
)
}
}
#[cfg(feature = "std")]
/// Runs the given benchmark on the device and prints result and information.
pub fn run_benchmark<BM>(benchmark: BM) -> BenchmarkResult
where
BM: Benchmark,
{
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis();
let output = std::process::Command::new("git")
.args(["rev-parse", "HEAD"])
.output()
.unwrap();
let git_hash = String::from_utf8(output.stdout).unwrap().trim().to_string();
let durations = benchmark.run();
BenchmarkResult {
raw: durations.clone(),
computed: BenchmarkComputations::new(&durations),
git_hash,
name: benchmark.name(),
options: benchmark.options(),
shapes: benchmark.shapes(),
timestamp,
}
}
#[cfg(test)]
mod tests {
use super::*;
use alloc::vec;
#[test]
fn test_min_max_median_durations_even_number_of_samples() {
let durations = BenchmarkDurations {
durations: vec![
Duration::new(10, 0),
Duration::new(20, 0),
Duration::new(30, 0),
Duration::new(40, 0),
Duration::new(50, 0),
],
};
let (min, max, median) = durations.min_max_median_durations();
assert_eq!(min, Duration::from_secs(10));
assert_eq!(max, Duration::from_secs(50));
assert_eq!(median, Duration::from_secs(30));
}
#[test]
fn test_min_max_median_durations_odd_number_of_samples() {
let durations = BenchmarkDurations {
durations: vec![
Duration::new(18, 5),
Duration::new(20, 0),
Duration::new(30, 0),
Duration::new(40, 0),
],
};
let (min, max, median) = durations.min_max_median_durations();
assert_eq!(min, Duration::from_nanos(18000000005_u64));
assert_eq!(max, Duration::from_secs(40));
assert_eq!(median, Duration::from_secs(30));
}
#[test]
fn test_mean_duration() {
let durations = BenchmarkDurations {
durations: vec![
Duration::new(10, 0),
Duration::new(20, 0),
Duration::new(30, 0),
Duration::new(40, 0),
],
};
let mean = durations.mean_duration();
assert_eq!(mean, Duration::from_secs(25));
}
#[test]
fn test_variance_duration() {
let durations = BenchmarkDurations {
durations: vec![
Duration::new(10, 0),
Duration::new(20, 0),
Duration::new(30, 0),
Duration::new(40, 0),
Duration::new(50, 0),
],
};
let mean = durations.mean_duration();
let variance = durations.variance_duration(mean);
assert_eq!(variance, Duration::from_secs(200));
}
}

View File

@ -5,28 +5,10 @@
//!
//! This library contains common types used by other Burn crates that must be shared.
#[macro_use]
extern crate derive_new;
/// Id module contains types for unique identifiers.
pub mod id;
/// Rand module contains types for random number generation for non-std environments and for
/// std environments.
pub mod rand;
/// Stub module contains types for stubs for non-std environments and for std environments.
pub mod stub;
/// Module for benchmarking any executable part
pub mod benchmark;
/// Useful when you need to read async data without having to decorate each function with async
/// notation.
pub mod reader;
/// Synchronization type module, used both by ComputeServer and Backends.
pub mod sync_type;
pub use cubecl_common::*;
extern crate alloc;

View File

@ -1,45 +0,0 @@
pub use rand::{rngs::StdRng, Rng, SeedableRng};
use rand::distributions::Standard;
use rand::prelude::Distribution;
/// Returns a seeded random number generator using entropy.
#[cfg(feature = "std")]
#[inline(always)]
pub fn get_seeded_rng() -> StdRng {
StdRng::from_entropy()
}
/// Returns a seeded random number generator using a pre-generated seed.
#[cfg(not(feature = "std"))]
#[inline(always)]
pub fn get_seeded_rng() -> StdRng {
const CONST_SEED: u64 = 42;
StdRng::seed_from_u64(CONST_SEED)
}
/// Generates random data from a thread-local RNG.
#[cfg(feature = "std")]
#[inline]
pub fn gen_random<T>() -> T
where
Standard: Distribution<T>,
{
rand::thread_rng().gen()
}
/// Generates random data from a mutex-protected RNG.
#[cfg(not(feature = "std"))]
#[inline]
pub fn gen_random<T>() -> T
where
Standard: Distribution<T>,
{
use crate::stub::Mutex;
static RNG: Mutex<Option<StdRng>> = Mutex::new(None);
let mut rng = RNG.lock().unwrap();
if rng.is_none() {
*rng = Some(get_seeded_rng());
}
rng.as_mut().unwrap().gen()
}

View File

@ -1,54 +0,0 @@
use alloc::{boxed::Box, sync::Arc, task::Wake, vec::Vec};
use core::{
future::Future,
pin::Pin,
task::{Context, Poll, Waker},
};
/// A future that is used to read resources from a compute server.
pub type Reader = Pin<Box<dyn Future<Output = Vec<u8>> + Send>>;
/// Create a reader from a concrete value.
pub fn reader_from_concrete(val: Vec<u8>) -> Reader {
Box::pin(async move { val })
}
struct DummyWaker;
impl Wake for DummyWaker {
fn wake(self: Arc<Self>) {}
fn wake_by_ref(self: &Arc<Self>) {}
}
/// Read a future synchronously.
///
/// On WASM futures cannot block, so this only succeeds if the future returns immediately.
/// If you want to handle this error, please use
/// try_read_sync instead.
pub fn read_sync<F: Future<Output = T>, T>(f: F) -> T {
try_read_sync(f).expect("Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM. If possible, try using an async variant of this function instead.")
}
/// Read a future synchronously.
///
/// On WASM futures cannot block, so this only succeeds if the future returns immediately.
/// otherwise this returns None.
pub fn try_read_sync<F: Future<Output = T>, T>(f: F) -> Option<T> {
// Create a dummy context.
let waker = Waker::from(Arc::new(DummyWaker));
let mut context = Context::from_waker(&waker);
// Pin & poll the future. Some backends don't do async readbacks, and instead immediately get
// the data. This let's us detect when a future is synchronous and doesn't require any waiting.
let mut pinned = core::pin::pin!(f);
match pinned.as_mut().poll(&mut context) {
Poll::Ready(output) => Some(output),
// On platforms that support it, now just block on the future and drive it to completion.
#[cfg(all(not(target_family = "wasm"), feature = "std"))]
Poll::Pending => Some(pollster::block_on(pinned)),
// Otherwise, just bail and return None - this futures will have to be read back asynchronously.
#[cfg(any(target_family = "wasm", not(feature = "std")))]
Poll::Pending => None,
}
}

View File

@ -1,154 +0,0 @@
#[cfg(not(feature = "std"))]
use spin::{
Mutex as MutexImported, MutexGuard, Once as OnceImported, RwLock as RwLockImported,
RwLockReadGuard, RwLockWriteGuard,
};
#[cfg(feature = "std")]
use std::sync::{
Mutex as MutexImported, MutexGuard, OnceLock as OnceImported, RwLock as RwLockImported,
RwLockReadGuard, RwLockWriteGuard,
};
/// A mutual exclusion primitive useful for protecting shared data
///
/// This mutex will block threads waiting for the lock to become available. The
/// mutex can also be statically initialized or created via a [Mutex::new]
///
/// [Mutex] wrapper to make `spin::Mutex` API compatible with `std::sync::Mutex` to swap
#[derive(Debug)]
pub struct Mutex<T> {
inner: MutexImported<T>,
}
impl<T> Mutex<T> {
/// Creates a new mutex in an unlocked state ready for use.
#[inline(always)]
pub const fn new(value: T) -> Self {
Self {
inner: MutexImported::new(value),
}
}
/// Locks the mutex blocking the current thread until it is able to do so.
#[inline(always)]
pub fn lock(&self) -> Result<MutexGuard<T>, alloc::string::String> {
#[cfg(not(feature = "std"))]
{
Ok(self.inner.lock())
}
#[cfg(feature = "std")]
{
self.inner.lock().map_err(|err| err.to_string())
}
}
}
/// A reader-writer lock which is exclusively locked for writing or shared for reading.
/// This reader-writer lock will block threads waiting for the lock to become available.
/// The lock can also be statically initialized or created via a [RwLock::new]
/// [RwLock] wrapper to make `spin::RwLock` API compatible with `std::sync::RwLock` to swap
#[derive(Debug)]
pub struct RwLock<T> {
inner: RwLockImported<T>,
}
impl<T> RwLock<T> {
/// Creates a new reader-writer lock in an unlocked state ready for use.
#[inline(always)]
pub const fn new(value: T) -> Self {
Self {
inner: RwLockImported::new(value),
}
}
/// Locks this rwlock with shared read access, blocking the current thread
/// until it can be acquired.
#[inline(always)]
pub fn read(&self) -> Result<RwLockReadGuard<T>, alloc::string::String> {
#[cfg(not(feature = "std"))]
{
Ok(self.inner.read())
}
#[cfg(feature = "std")]
{
self.inner.read().map_err(|err| err.to_string())
}
}
/// Locks this rwlock with exclusive write access, blocking the current thread
/// until it can be acquired.
#[inline(always)]
pub fn write(&self) -> Result<RwLockWriteGuard<T>, alloc::string::String> {
#[cfg(not(feature = "std"))]
{
Ok(self.inner.write())
}
#[cfg(feature = "std")]
{
self.inner.write().map_err(|err| err.to_string())
}
}
}
/// A unique identifier for a running thread.
///
/// This module is a stub when no std is available to swap with std::thread::ThreadId.
#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)]
pub struct ThreadId(core::num::NonZeroU64);
/// A cell that provides lazy one-time initialization that implements [Sync] and [Send].
///
/// This module is a stub when no std is available to swap with [std::sync::OnceLock].
pub struct SyncOnceCell<T>(OnceImported<T>);
impl<T: core::fmt::Debug> Default for SyncOnceCell<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: core::fmt::Debug> SyncOnceCell<T> {
/// Create a new once.
#[inline(always)]
pub fn new() -> Self {
Self(OnceImported::new())
}
/// Initialize the cell with a value.
#[inline(always)]
pub fn initialized(value: T) -> Self {
#[cfg(not(feature = "std"))]
{
let cell = OnceImported::initialized(value);
Self(cell)
}
#[cfg(feature = "std")]
{
let cell = OnceImported::new();
cell.set(value).unwrap();
Self(cell)
}
}
/// Gets the contents of the cell, initializing it with `f` if the cell
/// was empty.
#[inline(always)]
pub fn get_or_init<F>(&self, f: F) -> &T
where
F: FnOnce() -> T,
{
#[cfg(not(feature = "std"))]
{
self.0.call_once(f)
}
#[cfg(feature = "std")]
{
self.0.get_or_init(f)
}
}
}

View File

@ -1,8 +0,0 @@
/// What kind of synchronization to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncType {
/// Submit all outstanding tasks to the task queue if any.
Flush,
/// Submit all tasks to the task queue and wait for all of them to complete.
Wait,
}

View File

@ -1,51 +0,0 @@
[package]
authors = ["louisfd <louisfd94@gmail.com>", "Nathaniel Simard"]
categories = ["science"]
description = "Compute crate that helps creating high performance async backends."
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "data"]
license.workspace = true
name = "burn-compute"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-compute"
version.workspace = true
[features]
default = [
"std",
"channel-mutex",
"channel-mpsc",
"channel-cell",
"storage-bytes",
"autotune-persistent-cache",
]
std = ["burn-common/std"]
channel-mutex = []
channel-cell = []
channel-mpsc = ["dep:async-channel", "dep:pollster"] # Assume std
storage-bytes = []
autotune-persistent-cache = ["dirs", "md5", "serde", "serde_json"] # Assume std
[dependencies]
burn-common = { path = "../burn-common", version = "0.14.0", default-features = false }
derive-new = { workspace = true }
spin = { workspace = true }
log = { workspace = true }
hashbrown = { workspace = true }
dirs = { workspace = true, optional = true }
serde = { workspace = true, optional = true }
serde_json = { workspace = true, features = ["std"], optional = true }
md5 = { workspace = true, optional = true }
pollster = { workspace = true, optional = true }
async-channel = { workspace = true, optional = true }
[target.'cfg(target_family = "wasm")'.dependencies]
web-time = { workspace = true }
[dev-dependencies]
serial_test = { workspace = true }
rand = { workspace = true }
[[bench]]
name = "dynamic"
harness = false

View File

@ -1 +0,0 @@
../../LICENSE-APACHE

View File

@ -1 +0,0 @@
../../LICENSE-MIT

View File

@ -1,7 +0,0 @@
# Burn Compute
This crate helps creating high performance async backends.
- [x] Asynchronous kernel executions
- [x] Memory allocation management
- [x] Autotuning

View File

@ -1,29 +0,0 @@
use std::collections::LinkedList;
use burn_compute::{
memory_management::{
dynamic::{DynamicMemoryManagement, DynamicMemoryManagementOptions},
MemoryManagement,
},
storage::BytesStorage,
};
const MB: usize = 1024 * 1024;
fn main() {
let start = std::time::Instant::now();
let storage = BytesStorage::default();
let mut mm = DynamicMemoryManagement::new(
storage,
DynamicMemoryManagementOptions::preset(2048 * MB, 32),
);
let mut handles = LinkedList::new();
for _ in 0..100 * 2048 {
if handles.len() >= 4000 {
handles.pop_front();
}
let handle = mm.reserve(MB, || {});
handles.push_back(handle);
}
println!("{:?}", start.elapsed());
}

View File

@ -1,36 +0,0 @@
use crate::{
server::{Binding, ComputeServer, Handle},
storage::ComputeStorage,
};
use alloc::vec::Vec;
use burn_common::{reader::Reader, sync_type::SyncType};
/// The ComputeChannel trait links the ComputeClient to the ComputeServer
/// while ensuring thread-safety
pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send + Sync {
/// Given a binding, returns owned resource as bytes
fn read(&self, binding: Binding<Server>) -> Reader;
/// Given a resource handle, return the storage resource.
fn get_resource(
&self,
binding: Binding<Server>,
) -> <Server::Storage as ComputeStorage>::Resource;
/// Given a resource as bytes, stores it and returns the resource handle
fn create(&self, data: &[u8]) -> Handle<Server>;
/// Reserves `size` bytes in the storage, and returns a handle over them
fn empty(&self, size: usize) -> Handle<Server>;
/// Executes the `kernel` over the given `bindings`.
fn execute(
&self,
kernel: Server::Kernel,
count: Server::DispatchOptions,
bindings: Vec<Binding<Server>>,
);
/// Perform some synchronization of commands on the server.
fn sync(&self, sync_type: SyncType);
}

View File

@ -1,85 +0,0 @@
use super::ComputeChannel;
use crate::server::{Binding, ComputeServer, Handle};
use crate::storage::ComputeStorage;
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::reader::Reader;
use burn_common::sync_type::SyncType;
/// A channel using a [ref cell](core::cell::RefCell) to access the server with mutability.
///
/// # Important
///
/// Only use this channel if you don't use any threading in your application, otherwise it will
/// panic or cause undefined behaviors.
///
/// This is mostly useful for `no-std` environments where threads aren't supported, otherwise prefer
/// the [mutex](super::MutexComputeChannel) or the [mpsc](super::MpscComputeChannel) channels.
#[derive(Debug)]
pub struct RefCellComputeChannel<Server> {
    // Single-threaded shared ownership of the server; `Arc` makes clones cheap.
    server: Arc<core::cell::RefCell<Server>>,
}
impl<S> Clone for RefCellComputeChannel<S> {
fn clone(&self) -> Self {
Self {
server: self.server.clone(),
}
}
}
impl<Server> RefCellComputeChannel<Server>
where
    Server: ComputeServer,
{
    /// Create a new cell compute channel owning `server` behind a `RefCell`.
    pub fn new(server: Server) -> Self {
        let server = Arc::new(core::cell::RefCell::new(server));
        Self { server }
    }
}
impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
where
    Server: ComputeServer + Send,
{
    // Every method takes a short-lived mutable borrow of the server; the
    // `RefMut` guard is dropped at the end of the method body.
    fn read(&self, binding: Binding<Server>) -> Reader {
        let mut server = self.server.borrow_mut();
        server.read(binding)
    }
    fn get_resource(
        &self,
        binding: Binding<Server>,
    ) -> <Server::Storage as ComputeStorage>::Resource {
        let mut server = self.server.borrow_mut();
        server.get_resource(binding)
    }
    fn create(&self, resource: &[u8]) -> Handle<Server> {
        let mut server = self.server.borrow_mut();
        server.create(resource)
    }
    fn empty(&self, size: usize) -> Handle<Server> {
        let mut server = self.server.borrow_mut();
        server.empty(size)
    }
    fn execute(
        &self,
        kernel: Server::Kernel,
        count: Server::DispatchOptions,
        bindings: Vec<Binding<Server>>,
    ) {
        let mut server = self.server.borrow_mut();
        server.execute(kernel, count, bindings)
    }
    fn sync(&self, sync_type: SyncType) {
        let mut server = self.server.borrow_mut();
        server.sync(sync_type)
    }
}
/// This is unsafe, since no concurrency is supported by the `RefCell` channel.
/// However using this channel should only be done in single threaded environments such as `no-std`.
// SAFETY: `RefCell` is not thread-safe. These impls are only sound because this
// channel is documented above to be used exclusively from a single thread.
unsafe impl<Server: ComputeServer> Send for RefCellComputeChannel<Server> {}
unsafe impl<Server: ComputeServer> Sync for RefCellComputeChannel<Server> {}

View File

@ -1,17 +0,0 @@
// The base channel trait is always compiled in.
mod base;
pub use base::*;
// Optional channel implementations, each gated behind a cargo feature.
#[cfg(feature = "channel-mutex")]
mod mutex;
#[cfg(feature = "channel-mutex")]
pub use mutex::*;
// The mpsc channel spawns a server thread, so it is unavailable on wasm.
#[cfg(all(feature = "channel-mpsc", not(target_family = "wasm")))]
mod mpsc;
#[cfg(all(feature = "channel-mpsc", not(target_family = "wasm")))]
pub use mpsc::*;
#[cfg(feature = "channel-cell")]
mod cell;
#[cfg(feature = "channel-cell")]
pub use cell::*;

View File

@ -1,181 +0,0 @@
use burn_common::{reader::Reader, sync_type::SyncType};
use std::{sync::Arc, thread};
use super::ComputeChannel;
use crate::{
server::{Binding, ComputeServer, Handle},
storage::ComputeStorage,
};
/// A channel using a multi-producer, single-consumer queue to communicate with
/// the compute server, which is spawned on its own thread.
#[derive(Debug)]
pub struct MpscComputeChannel<Server>
where
    Server: ComputeServer,
{
    // Shared sender + join handle; cloning the channel clones this `Arc`.
    state: Arc<MpscComputeChannelState<Server>>,
}
#[derive(Debug)]
struct MpscComputeChannelState<Server>
where
    Server: ComputeServer,
{
    // Keeps the server thread's join handle alive as long as any clone exists.
    _handle: thread::JoinHandle<()>,
    // Producer side used by all clones to enqueue messages for the server thread.
    sender: async_channel::Sender<Message<Server>>,
}
/// Per-request response channel handed to the server thread with each message.
type Callback<Response> = async_channel::Sender<Response>;
/// Requests processed sequentially by the server thread.
enum Message<Server>
where
    Server: ComputeServer,
{
    /// Read the bound resource back as bytes.
    Read(Binding<Server>, Callback<Vec<u8>>),
    /// Fetch the underlying storage resource for a binding.
    GetResource(
        Binding<Server>,
        Callback<<Server::Storage as ComputeStorage>::Resource>,
    ),
    /// Store the given bytes and reply with a handle.
    Create(Vec<u8>, Callback<Handle<Server>>),
    /// Reserve the given number of bytes and reply with a handle.
    Empty(usize, Callback<Handle<Server>>),
    /// Execute a kernel with its dispatch options over the bindings (no reply).
    ExecuteKernel(
        (Server::Kernel, Server::DispatchOptions),
        Vec<Binding<Server>>,
    ),
    /// Synchronize the server, replying once done.
    Sync(SyncType, Callback<()>),
}
impl<Server> MpscComputeChannel<Server>
where
    Server: ComputeServer + 'static,
{
    /// Create a new mpsc compute channel.
    ///
    /// The server is moved onto a dedicated thread that drains `Message`s
    /// sequentially; responses travel back through per-request callback channels.
    pub fn new(mut server: Server) -> Self {
        let (sender, receiver) = async_channel::unbounded();
        let _handle = thread::spawn(move || {
            // Run the whole procedure as one blocking future. This is much simpler than trying
            // to use some multithreaded executor.
            pollster::block_on(async {
                // The loop ends when every sender is dropped, which terminates the thread.
                while let Ok(message) = receiver.recv().await {
                    match message {
                        Message::Read(binding, callback) => {
                            let data = server.read(binding).await;
                            callback.send(data).await.unwrap();
                        }
                        Message::GetResource(binding, callback) => {
                            let data = server.get_resource(binding);
                            callback.send(data).await.unwrap();
                        }
                        Message::Create(data, callback) => {
                            let handle = server.create(&data);
                            callback.send(handle).await.unwrap();
                        }
                        Message::Empty(size, callback) => {
                            let handle = server.empty(size);
                            callback.send(handle).await.unwrap();
                        }
                        Message::ExecuteKernel(kernel, bindings) => {
                            // Fire-and-forget: execution has no callback to notify.
                            server.execute(kernel.0, kernel.1, bindings);
                        }
                        Message::Sync(sync_type, callback) => {
                            server.sync(sync_type);
                            callback.send(()).await.unwrap();
                        }
                    };
                }
            });
        });
        let state = Arc::new(MpscComputeChannelState { sender, _handle });
        Self { state }
    }
}
impl<Server: ComputeServer> Clone for MpscComputeChannel<Server> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
}
}
}
impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
where
    Server: ComputeServer + 'static,
{
    fn read(&self, binding: Binding<Server>) -> Reader {
        let sender = self.state.sender.clone();
        // Lazy: the request is only sent once the returned future is polled.
        Box::pin(async move {
            let (callback, response) = async_channel::unbounded();
            sender.send(Message::Read(binding, callback)).await.unwrap();
            handle_response(response.recv().await)
        })
    }
    fn get_resource(
        &self,
        binding: Binding<Server>,
    ) -> <Server::Storage as ComputeStorage>::Resource {
        // Blocking request/response round-trip with the server thread.
        let (callback, response) = async_channel::unbounded();
        self.state
            .sender
            .send_blocking(Message::GetResource(binding, callback))
            .unwrap();
        handle_response(response.recv_blocking())
    }
    fn create(&self, data: &[u8]) -> Handle<Server> {
        let (callback, response) = async_channel::unbounded();
        self.state
            .sender
            // The bytes are copied so the message can outlive the borrow.
            .send_blocking(Message::Create(data.to_vec(), callback))
            .unwrap();
        handle_response(response.recv_blocking())
    }
    fn empty(&self, size: usize) -> Handle<Server> {
        let (callback, response) = async_channel::unbounded();
        self.state
            .sender
            .send_blocking(Message::Empty(size, callback))
            .unwrap();
        handle_response(response.recv_blocking())
    }
    fn execute(
        &self,
        kernel: Server::Kernel,
        count: Server::DispatchOptions,
        bindings: Vec<Binding<Server>>,
    ) {
        // Fire-and-forget: no response channel for kernel execution.
        self.state
            .sender
            .send_blocking(Message::ExecuteKernel((kernel, count), bindings))
            .unwrap()
    }
    fn sync(&self, sync_type: SyncType) {
        let (callback, response) = async_channel::unbounded();
        self.state
            .sender
            .send_blocking(Message::Sync(sync_type, callback))
            .unwrap();
        handle_response(response.recv_blocking())
    }
}
/// Unwrap a channel response, panicking with a diagnostic when the server
/// thread is gone or the channel is closed.
fn handle_response<Response, Err: core::fmt::Debug>(response: Result<Response, Err>) -> Response {
    response.unwrap_or_else(|err| panic!("Can't connect to the server correctly {err:?}"))
}

View File

@ -1,71 +0,0 @@
use super::ComputeChannel;
use crate::server::{Binding, ComputeServer, Handle};
use crate::storage::ComputeStorage;
use alloc::sync::Arc;
use alloc::vec::Vec;
use burn_common::reader::Reader;
use burn_common::sync_type::SyncType;
use spin::Mutex;
/// The MutexComputeChannel ensures thread-safety by locking the server
/// on every operation
#[derive(Debug)]
pub struct MutexComputeChannel<Server> {
    // Server shared across clones, guarded by a spin mutex (works in no-std).
    server: Arc<Mutex<Server>>,
}
impl<S> Clone for MutexComputeChannel<S> {
fn clone(&self) -> Self {
Self {
server: self.server.clone(),
}
}
}
impl<Server> MutexComputeChannel<Server>
where
    Server: ComputeServer,
{
    /// Create a new mutex compute channel wrapping `server`.
    pub fn new(server: Server) -> Self {
        let server = Arc::new(Mutex::new(server));
        Self { server }
    }
}
impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
where
    Server: ComputeServer,
{
    // Each method acquires the lock, performs the operation, and releases the
    // guard when the temporary goes out of scope at the end of the expression.
    // Parameter names follow the trait declaration (`binding`/`bindings`) for
    // consistency with the other channel implementations.
    fn read(&self, binding: Binding<Server>) -> Reader {
        self.server.lock().read(binding)
    }
    fn get_resource(
        &self,
        binding: Binding<Server>,
    ) -> <Server::Storage as ComputeStorage>::Resource {
        self.server.lock().get_resource(binding)
    }
    fn create(&self, data: &[u8]) -> Handle<Server> {
        self.server.lock().create(data)
    }
    fn empty(&self, size: usize) -> Handle<Server> {
        self.server.lock().empty(size)
    }
    fn execute(
        &self,
        kernel: Server::Kernel,
        count: Server::DispatchOptions,
        bindings: Vec<Binding<Server>>,
    ) {
        self.server.lock().execute(kernel, count, bindings)
    }
    fn sync(&self, sync_type: SyncType) {
        self.server.lock().sync(sync_type)
    }
}

View File

@ -1,119 +0,0 @@
use crate::{
channel::ComputeChannel,
server::{Binding, ComputeServer, Handle},
storage::ComputeStorage,
tune::{AutotuneOperationSet, Tuner},
};
use alloc::vec::Vec;
use alloc::{boxed::Box, sync::Arc};
use burn_common::stub::RwLock;
use burn_common::sync_type::SyncType;
/// The ComputeClient is the entry point to require tasks from the ComputeServer.
/// It should be obtained for a specific device via the Compute struct.
#[derive(Debug)]
pub struct ComputeClient<Server: ComputeServer, Channel> {
    // Transport used to reach the server (mutex, mpsc, cell, ...).
    channel: Channel,
    // Cached autotune results, shared between clones.
    tuner: Arc<RwLock<Tuner<Server::AutotuneKey>>>,
    // Feature set supported by the server, shared between clones.
    features: Arc<Server::FeatureSet>,
}
impl<S, C> Clone for ComputeClient<S, C>
where
S: ComputeServer,
C: ComputeChannel<S>,
{
fn clone(&self) -> Self {
Self {
channel: self.channel.clone(),
tuner: self.tuner.clone(),
features: self.features.clone(),
}
}
}
impl<Server, Channel> ComputeClient<Server, Channel>
where
    Server: ComputeServer,
    Channel: ComputeChannel<Server>,
{
    /// Create a new client from its channel, shared tuner and feature set.
    pub fn new(
        channel: Channel,
        tuner: Arc<RwLock<Tuner<Server::AutotuneKey>>>,
        features: Arc<Server::FeatureSet>,
    ) -> Self {
        Self {
            channel,
            tuner,
            features,
        }
    }
    /// Given a binding, returns owned resource as bytes.
    pub async fn read_async(&self, binding: Binding<Server>) -> Vec<u8> {
        self.channel.read(binding).await
    }
    /// Given a binding, returns owned resource as bytes.
    ///
    /// # Remarks
    /// Panics if the read operation fails.
    pub fn read(&self, binding: Binding<Server>) -> Vec<u8> {
        // Blocks on the async reader until the bytes arrive.
        burn_common::reader::read_sync(self.channel.read(binding))
    }
    /// Given a resource handle, returns the storage resource.
    pub fn get_resource(
        &self,
        binding: Binding<Server>,
    ) -> <Server::Storage as ComputeStorage>::Resource {
        self.channel.get_resource(binding)
    }
    /// Given a resource, stores it and returns the resource handle.
    pub fn create(&self, data: &[u8]) -> Handle<Server> {
        self.channel.create(data)
    }
    /// Reserves `size` bytes in the storage, and returns a handle over them.
    pub fn empty(&self, size: usize) -> Handle<Server> {
        self.channel.empty(size)
    }
    /// Executes the `kernel` over the given `bindings`, dispatched with `count`.
    pub fn execute(
        &self,
        kernel: Server::Kernel,
        count: Server::DispatchOptions,
        bindings: Vec<Binding<Server>>,
    ) {
        self.channel.execute(kernel, count, bindings)
    }
    /// Wait for the completion of every task in the server.
    pub fn sync(&self, sync_type: SyncType) {
        self.channel.sync(sync_type)
    }
    /// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks.
    pub fn autotune_execute(
        &self,
        autotune_operation_set: Box<dyn AutotuneOperationSet<Server::AutotuneKey>>,
    ) {
        self.tuner
            .write()
            .unwrap()
            .execute_autotune(autotune_operation_set, self);
    }
    /// Get the fastest kernel for the given autotune key if it exists.
    pub fn autotune_result(&self, key: &Server::AutotuneKey) -> Option<usize> {
        self.tuner.read().unwrap().autotune_fastest(key)
    }
    /// Get the features supported by the compute server.
    pub fn features(&self) -> &Server::FeatureSet {
        self.features.as_ref()
    }
}

View File

@ -1,94 +0,0 @@
use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
use core::ops::DerefMut;
use hashbrown::HashMap;
/// The compute type has the responsibility to retrieve the correct compute client based on the
/// given device.
pub struct ComputeRuntime<Device, Server: ComputeServer, Channel> {
    // Lazily-initialized map from device to its client. Guarded by a spin lock
    // so the runtime can live in a `const`/`static` context (see `new`).
    clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
}
impl<Device, Server, Channel> Default for ComputeRuntime<Device, Server, Channel>
where
    Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
    Server: ComputeServer,
    Channel: ComputeChannel<Server>,
{
    /// Equivalent to [`ComputeRuntime::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<Device, Server, Channel> ComputeRuntime<Device, Server, Channel>
where
    Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
    Server: ComputeServer,
    Channel: ComputeChannel<Server>,
{
    /// Create a new compute.
    pub const fn new() -> Self {
        Self {
            clients: spin::Mutex::new(None),
        }
    }
    /// Get the compute client for the given device.
    ///
    /// Provide the init function to create a new client if it isn't already initialized.
    pub fn client<Init>(&self, device: &Device, init: Init) -> ComputeClient<Server, Channel>
    where
        Init: Fn() -> ComputeClient<Server, Channel>,
    {
        let mut clients = self.clients.lock();
        // Lazily create the map, then do a single lookup-or-insert via the
        // entry API instead of a `contains_key` check followed by `get`.
        clients
            .get_or_insert_with(HashMap::new)
            .entry(device.clone())
            .or_insert_with(init)
            .clone()
    }
    /// Register the compute client for the given device.
    ///
    /// # Note
    ///
    /// This function is mostly useful when the creation of the compute client can't be done
    /// synchronously and require special context.
    ///
    /// # Panics
    ///
    /// If a client is already registered for the given device.
    pub fn register(&self, device: &Device, client: ComputeClient<Server, Channel>) {
        let mut clients = self.clients.lock();
        Self::register_inner(device, client, &mut clients);
    }
    /// Insert `client` for `device`, creating the map on first use.
    ///
    /// Panics if a client already exists for `device`.
    fn register_inner(
        device: &Device,
        client: ComputeClient<Server, Channel>,
        clients: &mut Option<HashMap<Device, ComputeClient<Server, Channel>>>,
    ) {
        let clients = clients.get_or_insert_with(HashMap::new);
        if clients.contains_key(device) {
            panic!("Client already created for device {:?}", device);
        }
        clients.insert(device.clone(), client);
    }
}

View File

@ -1,175 +0,0 @@
use alloc::sync::Arc;
#[macro_export(local_inner_macros)]
/// Create a new storage ID type.
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Storage ID.
        #[derive(Clone, Hash, PartialEq, Eq, Debug)]
        pub struct $name {
            value: usize,
        }
        impl $name {
            /// Create a new ID.
            pub fn new() -> Self {
                use core::sync::atomic::{AtomicUsize, Ordering};
                // One shared counter per generated ID type; each call hands out
                // the next unique value.
                static COUNTER: AtomicUsize = AtomicUsize::new(0);
                let value = COUNTER.fetch_add(1, Ordering::Relaxed);
                if value == usize::MAX {
                    core::panic!("Memory ID overflowed");
                }
                Self { value }
            }
        }
        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
/// Reference to a buffer handle.
#[derive(Clone, Debug)]
pub struct HandleRef<Id> {
    // Shared id: its strong count tracks how many handles alias the buffer.
    id: Arc<Id>,
    // Counts every live reference to the buffer, bindings included.
    all: Arc<()>,
}
/// Reference to buffer binding.
#[derive(Clone, Debug)]
pub struct BindingRef<Id> {
    id: Id,
    // Keeps the originating handle's `all` count alive while the binding exists.
    _all: Arc<()>,
}
impl<Id> BindingRef<Id>
where
    Id: Clone + core::fmt::Debug,
{
    /// The id associated to the buffer.
    pub(crate) fn id(&self) -> &Id {
        &self.id
    }
}
impl<Id> HandleRef<Id>
where
    Id: Clone + core::fmt::Debug,
{
    /// Create a new handle.
    pub(crate) fn new(id: Id) -> Self {
        Self {
            id: Arc::new(id),
            all: Arc::new(()),
        }
    }
    /// The id associated to the handle.
    pub(crate) fn id(&self) -> &Id {
        &self.id
    }
    /// Get the binding, consuming the handle but keeping the `all` count alive.
    pub(crate) fn binding(self) -> BindingRef<Id> {
        BindingRef {
            id: self.id.as_ref().clone(),
            _all: self.all,
        }
    }
    /// If the handle can be mut.
    pub(crate) fn can_mut(&self) -> bool {
        // 1 memory management reference with 1 tensor reference.
        Arc::strong_count(&self.id) <= 2
    }
    /// If the resource is free.
    pub(crate) fn is_free(&self) -> bool {
        // Only the memory management side still holds `all`: no binding is alive.
        Arc::strong_count(&self.all) <= 1
    }
}
#[macro_export(local_inner_macros)]
/// Create new memory ID types.
macro_rules! memory_id_type {
    ($id:ident, $handle:ident) => {
        /// Memory Handle.
        #[derive(Clone, Debug)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }
        /// Memory ID.
        #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
        pub struct $id {
            pub(crate) value: usize,
        }
        impl $handle {
            /// Create a new ID.
            pub(crate) fn new() -> Self {
                let value = Self::gen_id();
                Self {
                    value: $crate::id::HandleRef::new($id { value }),
                }
            }
            // One atomic counter per generated handle type; ids are unique per type.
            fn gen_id() -> usize {
                static COUNTER: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);
                let value = COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
                if value == usize::MAX {
                    core::panic!("Memory ID overflowed");
                }
                value
            }
        }
        // Handles deref to `HandleRef` so `id()`, `can_mut()`, `is_free()` are available.
        impl core::ops::Deref for $handle {
            type Target = $crate::id::HandleRef<$id>;
            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }
        impl Default for $handle {
            fn default() -> Self {
                Self::new()
            }
        }
    };
    ($id:ident, $handle:ident, $binding:ident) => {
        // Reuse the two-argument arm, then layer the binding type on top.
        memory_id_type!($id, $handle);
        /// Binding of a memory handle.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }
        impl $handle {
            /// Get the binding, consuming this handle.
            pub(crate) fn binding(self) -> $binding {
                $binding {
                    value: self.value.binding(),
                }
            }
        }
        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;
            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }
    };
}

View File

@ -1,29 +0,0 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
//! Burn compute crate that helps creating high performance async backends.
extern crate alloc;
// `derive_new` provides the `#[derive(new)]` constructor derive used across the crate.
#[macro_use]
extern crate derive_new;
// Internal ID/handle plumbing shared by storage and memory management.
mod id;
/// Compute channel module.
pub mod channel;
/// Compute client module.
pub mod client;
/// Autotune module
pub mod tune;
/// Memory management module.
pub mod memory_management;
/// Compute server module.
pub mod server;
/// Compute Storage module.
pub mod storage;
mod compute;
pub use compute::*;

View File

@ -1,57 +0,0 @@
use crate::storage::ComputeStorage;
/// The managed tensor buffer handle that points to some memory segment.
/// It should not contain actual data.
pub trait MemoryHandle<Binding>: Clone + Send + Sync + core::fmt::Debug {
    /// Checks if the underlying memory can be safely mutated.
    fn can_mut(&self) -> bool;
    /// Get the binding associated to the current handle.
    fn binding(self) -> Binding;
}
/// Binding to a [memory handle](MemoryHandle).
pub trait MemoryBinding: Clone + Send + Sync + core::fmt::Debug {}
/// The MemoryManagement trait encapsulates strategies for (de)allocating memory.
/// It is bound to the ComputeStorage trait, which does the actual (de)allocations.
///
/// The MemoryManagement can only reserve memory space or get the resource located at a space.
/// Modification of the resource data should be done directly on the resource.
pub trait MemoryManagement<Storage: ComputeStorage>: Send + core::fmt::Debug {
    /// The associated type that must implement [MemoryHandle].
    type Handle: MemoryHandle<Self::Binding>;
    /// The associated type that must implement [MemoryBinding].
    type Binding: MemoryBinding;
    /// Returns the resource from the storage at the specified handle
    fn get(&mut self, binding: Self::Binding) -> Storage::Resource;
    /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it
    fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle;
    /// Bypass the memory allocation algorithm to allocate data directly.
    ///
    /// # Notes
    ///
    /// Can be useful for servers that want specific control over memory.
    fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle;
    /// Bypass the memory allocation algorithm to deallocate data directly.
    ///
    /// # Notes
    ///
    /// Can be useful for servers that want specific control over memory.
    fn dealloc(&mut self, binding: Self::Binding);
    /// Fetch the storage used by the memory manager.
    ///
    /// # Notes
    ///
    /// The storage should probably not be used for allocations since the handles won't be
    /// compatible with the ones provided by the current trait. Prefer using the
    /// [alloc](MemoryManagement::alloc) and [dealloc](MemoryManagement::dealloc) functions.
    ///
    /// This is useful if you need to time the deallocations based on async computation, or to
    /// change the mode of storage for different reasons.
    fn storage(&mut self) -> &mut Storage;
}

View File

@ -1,181 +0,0 @@
use super::memory_pool::{
MemoryExtensionStrategy, MemoryPool, MemoryPoolBinding, MemoryPoolHandle, RoundingStrategy,
SmallMemoryPool,
};
use crate::storage::ComputeStorage;
use alloc::vec::Vec;
use super::MemoryManagement;
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
pub struct DynamicMemoryManagement<Storage> {
    // Requests at or below this size are served by the small memory pool.
    min_chunk_alignment_offset: usize,
    small_memory_pool: SmallMemoryPool,
    // Pools sorted by ascending `slice_max_size` (see `new`).
    pools: Vec<MemoryPool>,
    // Options kept in the same order as `pools` for size-based lookup.
    options: Vec<MemoryPoolOptions>,
    storage: Storage,
}
/// Options to initialize a [dynamic memory management](DynamicMemoryManagement).
#[derive(new, Debug)]
pub struct DynamicMemoryManagementOptions {
    pools: Vec<MemoryPoolOptions>,
    // Also used as the size threshold below which the small memory pool serves requests.
    min_chunk_alignment_offset: usize,
}
/// Options to create a memory pool.
#[derive(Debug)]
pub struct MemoryPoolOptions {
    /// The amount of bytes used for each chunk in the memory pool.
    pub chunk_size: usize,
    /// The number of chunks allocated directly at creation.
    ///
    /// Useful when you know in advance how much memory you'll need.
    pub chunk_num_prealloc: usize,
    /// The max size in bytes a slice can take in the pool.
    pub slice_max_size: usize,
}
impl DynamicMemoryManagementOptions {
/// Creates the options from device limits.
pub fn preset(max_chunk_size: usize, min_chunk_alignment_offset: usize) -> Self {
// Rounding down to a factor of 8.
let max_chunk_size = (max_chunk_size / 8) * 8;
const MB: usize = 1024 * 1024;
let mut pools = Vec::new();
pools.push(MemoryPoolOptions {
chunk_size: max_chunk_size,
chunk_num_prealloc: 0,
slice_max_size: max_chunk_size,
});
let mut current = max_chunk_size;
while current >= 32 * MB {
current /= 4;
pools.push(MemoryPoolOptions {
chunk_size: current,
chunk_num_prealloc: 0,
// Creating max slices lower than the chunk size reduces fragmentation.
slice_max_size: current / 2usize.pow(pools.len() as u32),
});
}
Self {
pools,
min_chunk_alignment_offset,
}
}
}
impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
    /// Creates a new instance using the given storage and memory pool options.
    pub fn new(mut storage: Storage, mut options: DynamicMemoryManagementOptions) -> Self {
        // Sort pools by ascending max slice size so `reserve`/`alloc` can pick
        // the first (smallest) pool that fits a request.
        options
            .pools
            .sort_by(|pool1, pool2| usize::cmp(&pool1.slice_max_size, &pool2.slice_max_size));
        let min_chunk_alignment_offset = options.min_chunk_alignment_offset;
        let pools = options
            .pools
            .iter()
            .map(|option| {
                let mut pool = MemoryPool::new(
                    MemoryExtensionStrategy::Never,
                    RoundingStrategy::FixedAmount(option.chunk_size),
                    min_chunk_alignment_offset,
                );
                // Pre-allocate chunks when requested so first reservations are cheap.
                for _ in 0..option.chunk_num_prealloc {
                    pool.alloc(&mut storage, option.chunk_size, || {});
                }
                pool
            })
            .collect();
        Self {
            min_chunk_alignment_offset,
            small_memory_pool: SmallMemoryPool::new(min_chunk_alignment_offset),
            pools,
            // Keep options in the same (sorted) order as `pools`.
            options: options.pools,
            storage,
        }
    }
}
impl<Storage> core::fmt::Debug for DynamicMemoryManagement<Storage> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(
alloc::format!(
"DynamicMemoryManagement {:?}",
core::any::type_name::<Storage>(),
)
.as_str(),
)
}
}
impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagement<Storage> {
    type Handle = MemoryPoolHandle;
    type Binding = MemoryPoolBinding;
    fn get(&mut self, binding: Self::Binding) -> Storage::Resource {
        // Check the small pool first, then every sized pool in ascending order.
        if let Some(handle) = self.small_memory_pool.get(&mut self.storage, &binding) {
            return handle;
        }
        for pool in &mut self.pools {
            if let Some(handle) = pool.get(&mut self.storage, &binding) {
                return handle;
            }
        }
        panic!("No handle found in memory pools");
    }
    fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
        // Tiny requests are served by the dedicated small pool.
        if size <= self.min_chunk_alignment_offset {
            return self
                .small_memory_pool
                .reserve(&mut self.storage, size, sync);
        }
        // `options`/`pools` are sorted by `slice_max_size`, so the first match
        // is the smallest pool that fits.
        for (index, option) in self.options.iter().enumerate() {
            if size <= option.slice_max_size {
                let pool = &mut self.pools[index];
                return pool.reserve(&mut self.storage, size, sync);
            }
        }
        panic!("No memory pool big enough to reserve {size} bytes.");
    }
    fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
        if size <= self.min_chunk_alignment_offset {
            return self.small_memory_pool.alloc(&mut self.storage, size, sync);
        }
        for (index, option) in self.options.iter().enumerate() {
            if size <= option.slice_max_size {
                let pool = &mut self.pools[index];
                return pool.alloc(&mut self.storage, size, sync);
            }
        }
        panic!("No memory pool big enough to alloc {size} bytes.");
    }
    fn dealloc(&mut self, _binding: Self::Binding) {
        // Can't dealloc slices.
    }
    fn storage(&mut self) -> &mut Storage {
        &mut self.storage
    }
}

View File

@ -1,586 +0,0 @@
use super::index::SearchIndex;
use super::{
ChunkHandle, ChunkId, MemoryChunk, MemoryPoolBinding, MemoryPoolHandle, MemorySlice,
RingBuffer, SliceHandle, SliceId,
};
use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization};
use alloc::vec::Vec;
use hashbrown::{HashMap, HashSet};
// A pool of storage chunks subdivided into slices that are handed out to callers.
pub struct MemoryPool {
    chunks: HashMap<ChunkId, Chunk>,
    slices: HashMap<SliceId, Slice>,
    #[allow(unused)] // will be used when we rewrite memory extension
    memory_extension_strategy: MemoryExtensionStrategy,
    // Decides how much is actually allocated for a given request size.
    rounding: RoundingStrategy,
    // Secondary index over chunk ids.
    chunk_index: SearchIndex<ChunkId>,
    // Used by `get_free_slice` to search for a reusable free slice.
    ring: RingBuffer<Chunk, Slice>,
    recently_added_chunks: Vec<ChunkId>,
    recently_allocated_size: usize,
    buffer_alignment: usize,
}
// NOTE(review): not referenced in this portion of the file — presumably a
// pending slice-resize record used by later `MemoryPool` methods; confirm.
struct SliceUpdate {
    slice_id: SliceId,
    size: usize,
}
#[derive(new, Debug)]
pub struct Chunk {
    pub storage: StorageHandle,
    pub handle: ChunkHandle,
    pub slices: MemoryPage,
}
// TODO: consider using generic trait and decouple from Slice
#[derive(new, Debug)]
pub struct MemoryPage {
    // Slice ids keyed by their start address within the chunk.
    pub slices: HashMap<usize, SliceId>,
}
impl MemoryPage {
    /// merge slice at first_slice_address with the next slice (if there is one and if it's free)
    /// return a boolean representing if a merge happened
    fn merge_with_next_slice(
        &mut self,
        first_slice_address: usize,
        slices: &mut HashMap<SliceId, Slice>,
    ) -> bool {
        let first_slice_id = self.find_slice(first_slice_address).expect(
            "merge_with_next_slice shouldn't be called with a nonexistent first_slice address",
        );
        // The merge candidate starts right where the first slice ends.
        let next_slice_address =
            first_slice_address + slices.get(&first_slice_id).unwrap().effective_size();
        if let Some(next_slice_id) = self.find_slice(next_slice_address) {
            let (next_slice_eff_size, next_slice_is_free) = {
                let next_slice = slices.get(&next_slice_id).unwrap();
                (next_slice.effective_size(), next_slice.is_free())
            };
            if next_slice_is_free {
                // Grow the first slice to cover both, then drop the second.
                let first_slice = slices.get_mut(&first_slice_id).unwrap();
                let first_slice_eff_size = first_slice.effective_size();
                let first_slice_offset = first_slice.storage.offset();
                let merged_size = first_slice_eff_size + next_slice_eff_size;
                first_slice.storage.utilization = StorageUtilization::Slice {
                    size: merged_size,
                    offset: first_slice_offset,
                };
                first_slice.padding = 0;
                // Cleanup of the extra slice
                self.slices.remove(&next_slice_address);
                slices.remove(&next_slice_id);
                return true;
            }
            return false;
        }
        false
    }
    /// Id of the slice beginning at `address`, if any.
    fn find_slice(&self, address: usize) -> Option<SliceId> {
        let slice_id = self.slices.get(&address);
        slice_id.copied()
    }
    /// Register `slice` as starting at `address` within this page.
    fn insert_slice(&mut self, address: usize, slice: SliceId) {
        self.slices.insert(address, slice);
    }
    /// Ids of all slices in this page, ordered by their start address.
    fn slices_sorted_by_address(&self) -> Vec<SliceId> {
        // Iterate by reference instead of cloning the whole map
        // (`SliceId` is `Copy`, see `find_slice`).
        let mut entries: Vec<(usize, SliceId)> =
            self.slices.iter().map(|(&address, &id)| (address, id)).collect();
        entries.sort_by_key(|&(address, _)| address);
        entries.into_iter().map(|(_, id)| id).collect()
    }
}
// A sub-range of a chunk handed out to a caller.
#[derive(new, Debug)]
pub struct Slice {
    pub storage: StorageHandle,
    pub handle: SliceHandle,
    // Keeps the owning chunk alive while the slice exists.
    pub chunk: ChunkHandle,
    // Extra bytes appended so the next slice stays aligned.
    pub padding: usize,
}
impl Slice {
    /// Bytes this slice occupies in its chunk: payload plus alignment padding.
    pub fn effective_size(&self) -> usize {
        let payload = self.storage.size();
        payload + self.padding
    }
}
// Slices smaller than this cannot start at a non-zero offset (see `create_slice`).
const MIN_SIZE_NEEDED_TO_OFFSET: usize = 16;
// How a requested size is turned into an actual allocation size.
pub enum RoundingStrategy {
    /// Always allocate chunks of the given fixed size.
    FixedAmount(usize),
    #[allow(unused)]
    None,
}
impl RoundingStrategy {
fn alloc_size(&self, size: usize) -> usize {
match self {
RoundingStrategy::FixedAmount(chunk_size) => {
assert!(*chunk_size >= size);
*chunk_size
}
RoundingStrategy::None => size,
}
}
}
/// The strategy defines the frequency at which merging of free slices (defragmentation) occurs.
#[allow(unused)]
#[derive(Debug)]
pub enum MemoryExtensionStrategy {
    /// Once every n calls to reserve.
    PeriodTick {
        /// Number of calls to be executed before triggering the defragmentation.
        period: usize,
        /// Current state. Should start at zero.
        state: usize,
    },
    /// Never defragment.
    Never,
}
#[allow(unused)]
impl MemoryExtensionStrategy {
/// Create a new strategy with the given period.
pub fn new_period_tick(period: usize) -> Self {
MemoryExtensionStrategy::PeriodTick { period, state: 0 }
}
#[allow(unused)]
fn should_extend_max_memory(&mut self) -> bool {
match self {
MemoryExtensionStrategy::PeriodTick { period, state } => {
*state = (*state + 1) % *period;
*state == 0
}
MemoryExtensionStrategy::Never => false,
}
}
}
impl MemoryPool {
pub fn new(
merging_strategy: MemoryExtensionStrategy,
alloc_strategy: RoundingStrategy,
buffer_alignment: usize,
) -> Self {
Self {
chunks: HashMap::new(),
slices: HashMap::new(),
memory_extension_strategy: merging_strategy,
rounding: alloc_strategy,
chunk_index: SearchIndex::new(),
ring: RingBuffer::new(buffer_alignment),
recently_added_chunks: Vec::new(),
recently_allocated_size: 0,
buffer_alignment,
}
}
/// Returns the resource from the storage, for the specified handle.
pub fn get<Storage: ComputeStorage>(
&mut self,
storage: &mut Storage,
binding: &MemoryPoolBinding,
) -> Option<Storage::Resource> {
self.slices
.get(binding.slice.id())
.map(|s| &s.storage)
.map(|h| storage.get(h))
}
/// Reserves memory of specified size using the reserve algorithm, and return
/// a handle to the reserved memory.
///
/// Also clean ups, merging free slices together if permitted by the merging strategy
pub fn reserve<Storage: ComputeStorage, Sync: FnOnce()>(
&mut self,
storage: &mut Storage,
size: usize,
sync: Sync,
) -> MemoryPoolHandle {
let slice = self.get_free_slice(size);
match slice {
Some(slice) => MemoryPoolHandle {
slice: slice.clone(),
},
None => self.alloc(storage, size, sync),
}
}
pub fn alloc<Storage: ComputeStorage, Sync: FnOnce()>(
&mut self,
storage: &mut Storage,
size: usize,
#[allow(unused)] sync: Sync,
) -> MemoryPoolHandle {
let alloc_size = self.rounding.alloc_size(size);
self.alloc_slice(storage, alloc_size, size)
}
fn alloc_slice<Storage: ComputeStorage>(
&mut self,
storage: &mut Storage,
alloc_size: usize,
slice_size: usize,
) -> MemoryPoolHandle {
let chunk_size = self.rounding.alloc_size(alloc_size);
let handle_chunk = self.create_chunk(storage, chunk_size);
let chunk_size = self.chunks.get(handle_chunk.id()).unwrap().storage.size();
self.recently_added_chunks.push(*handle_chunk.id());
self.recently_allocated_size += chunk_size;
let chunk_id = *handle_chunk.id();
let (slice, extra_slice) =
self.allocate_slices(handle_chunk.clone(), chunk_size, slice_size);
let handle_slice = slice.handle.clone();
self.update_chunk_metadata(chunk_id, slice, extra_slice);
MemoryPoolHandle {
slice: handle_slice,
}
}
fn allocate_slices(
&self,
handle_chunk: ChunkHandle,
alloc_size: usize,
slice_size: usize,
) -> (Slice, Option<Slice>) {
let slice = self.create_slice(0, slice_size, handle_chunk.clone());
let effective_size = slice.effective_size();
let extra_slice = if effective_size < alloc_size {
Some(self.create_slice(effective_size, alloc_size - effective_size, handle_chunk))
} else {
None
};
(slice, extra_slice)
}
fn update_chunk_metadata(
&mut self,
chunk_id: ChunkId,
slice: Slice,
extra_slice: Option<Slice>,
) {
let slice_id = *slice.handle.id();
let slice_offset = slice.storage.offset();
self.slices.insert(slice_id, slice);
self.chunks
.get_mut(&chunk_id)
.unwrap()
.slices
.slices
.insert(slice_offset, slice_id);
if let Some(extra_slice) = extra_slice {
let extra_slice_id = *extra_slice.handle.id();
let extra_slice_offset = extra_slice.storage.offset();
self.slices.insert(extra_slice_id, extra_slice);
self.chunks
.get_mut(&chunk_id)
.unwrap()
.slices
.slices
.insert(extra_slice_offset, extra_slice_id);
}
}
#[allow(unused)]
fn display_memory_usage(&self) {
let total_memory_usage: f64 = self
.chunks
.values()
.map(|chunk| chunk.storage.size() as f64)
.sum();
let effective_memory_usage: f64 = self
.slices
.values()
.filter(|slice| slice.handle.is_free())
.map(|slice| slice.storage.size() as f64)
.sum();
let ratio = 100.0 * effective_memory_usage / total_memory_usage;
log::info!("the memory usage is {ratio}");
}
/// Finds a free slice that can contain the given size
/// Returns the chunk's id and size.
fn get_free_slice(&mut self, size: usize) -> Option<SliceHandle> {
    // Requests too small to support an offset are never served from the free
    // list — presumably the caller allocates a dedicated chunk instead (TODO
    // confirm against callers).
    if size < MIN_SIZE_NEEDED_TO_OFFSET {
        return None;
    }
    // Search with the aligned (padded) size so the reused slice stays aligned.
    let padding = calculate_padding(size, self.buffer_alignment);
    let effective_size = size + padding;
    let slice_id =
        self.ring
            .find_free_slice(effective_size, &mut self.chunks, &mut self.slices)?;
    let slice = self.slices.get_mut(&slice_id).unwrap();
    let old_slice_size = slice.effective_size();
    let offset = match slice.storage.utilization {
        StorageUtilization::Full(_) => 0,
        StorageUtilization::Slice { offset, size: _ } => offset,
    };
    // Shrink the visible size to the request; the difference is turned into
    // padding so the slice's overall footprint stays unchanged.
    slice.storage.utilization = StorageUtilization::Slice { offset, size };
    let new_padding = old_slice_size - size;
    slice.padding = new_padding;
    assert_eq!(
        slice.effective_size(),
        old_slice_size,
        "new and old slice should have the same size"
    );
    Some(slice.handle.clone())
}
/// Creates a slice of size `size` upon the given chunk with the given offset.
///
/// Panics when the offset is not a multiple of the buffer alignment, or when
/// a non-zero offset is requested for a slice too small to be offset.
fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> Slice {
    assert_eq!(
        offset % self.buffer_alignment,
        0,
        "slice with offset {offset} needs to be a multiple of {}",
        self.buffer_alignment
    );
    if offset > 0 && size < MIN_SIZE_NEEDED_TO_OFFSET {
        panic!("tried to create slice of size {size} with an offset while the size needs to atleast be of size {MIN_SIZE_NEEDED_TO_OFFSET} for offset support");
    }
    let chunk = self.chunks.get(handle_chunk.id()).unwrap();
    // The slice views a sub-range of the chunk's storage allocation.
    let storage = StorageHandle {
        id: chunk.storage.id.clone(),
        utilization: StorageUtilization::Slice { offset, size },
    };
    Slice::new(
        storage,
        SliceHandle::new(),
        chunk.handle.clone(),
        calculate_padding(size, self.buffer_alignment),
    )
}
/// Creates a chunk of given size by allocating on the storage.
fn create_chunk<Storage: ComputeStorage>(
    &mut self,
    storage: &mut Storage,
    size: usize,
) -> ChunkHandle {
    // Round the allocation up so the chunk stays aligned.
    let effective_size = size + calculate_padding(size, self.buffer_alignment);
    let storage_handle = storage.alloc(effective_size);
    let handle = ChunkHandle::new();
    let id = *handle.id();
    // Make the chunk visible to the ring-buffer scanner and the pool indexes.
    self.ring.push_chunk(id);
    let chunk = Chunk::new(storage_handle, handle.clone(), MemoryPage::new(HashMap::new()));
    self.chunks.insert(id, chunk);
    self.chunk_index.insert(id, size);
    handle
}
#[allow(unused)]
/// Gathers the slices of every recently added chunk and migrates them into a
/// single consolidated chunk, deallocating the originals — effectively
/// defragmenting the recently grown portion of the pool.
fn extend_max_memory<Storage: ComputeStorage>(&mut self, storage: &mut Storage) {
    let mut slices = Vec::<SliceUpdate>::new();
    let mut deallocations = HashSet::<ChunkId>::new();
    let mut chunks_total_size: usize = 0;
    for chunk_id in &self.recently_added_chunks {
        let chunk = self.chunks.get(chunk_id).unwrap();
        let chunk_id = *chunk.handle.id();
        // Keep address order so relative slice placement survives the move.
        let sorted_slice = chunk.slices.slices_sorted_by_address();
        for slice_id in sorted_slice {
            let slice = self.slices.get(&slice_id).unwrap();
            let size = slice.storage.size();
            slices.push(SliceUpdate { slice_id, size });
        }
        chunks_total_size += chunk.storage.size();
        deallocations.insert(chunk_id);
    }
    // With no slices to migrate, only the (empty) chunks need freeing.
    if !slices.is_empty() {
        self.move_to_new_chunk(chunks_total_size, storage, &mut slices, &mut deallocations);
    } else {
        self.deallocate(storage, &mut deallocations);
    }
}
/// Releases every chunk listed in `deallocations`: removes it from the pool's
/// bookkeeping and frees its backing storage.
fn deallocate<Storage: ComputeStorage>(
    &mut self,
    storage: &mut Storage,
    deallocations: &mut HashSet<ChunkId>,
) {
    for id in deallocations.drain() {
        let mut chunk = self.chunks.remove(&id).unwrap();
        self.ring.remove_chunk(id);
        for (_address, slice_id) in chunk.slices.slices.drain() {
            let slice = self.slices.get(&slice_id).unwrap();
            let chunk_id = *slice.chunk.id();
            // Slices must already point at their new chunk before the old
            // one is freed (see `move_to_new_chunk`).
            assert_ne!(chunk_id, id, "Chunk id should be updated");
        }
        self.chunk_index.remove(&id);
        storage.dealloc(chunk.storage.id);
    }
}
/// Allocates one chunk of `alloc_size` bytes, copies every slice listed in
/// `slices` into it back-to-back (updating each slice's chunk and storage
/// handles), then deallocates the chunks listed in `deallocations`.
fn move_to_new_chunk<Storage: ComputeStorage>(
    &mut self,
    alloc_size: usize,
    storage: &mut Storage,
    slices: &mut Vec<SliceUpdate>,
    deallocations: &mut HashSet<ChunkId>,
) {
    let chunk = self.create_chunk(storage, alloc_size);
    let storage_id = self.chunks.get(chunk.id()).unwrap().storage.id.clone();
    // Running write offset into the new chunk.
    let mut offset = 0;
    let mut slices_ids: Vec<(usize, SliceId)> = Vec::new();
    for update in slices.drain(..) {
        let slice_id = update.slice_id;
        let slice = self.slices.get_mut(&slice_id).unwrap();
        let old_storage = slice.storage.clone();
        // Re-home the slice on the new chunk at the running offset.
        slice.chunk = chunk.clone();
        slice.storage = StorageHandle {
            id: storage_id.clone(),
            utilization: StorageUtilization::Slice {
                offset,
                size: update.size,
            },
        };
        storage.copy(&old_storage, &slice.storage);
        slices_ids.push((offset, slice_id));
        offset += slice.effective_size();
    }
    let chunk = self.chunks.get_mut(chunk.id()).unwrap();
    let chunk_handle = chunk.handle.clone();
    for (address, slice_id) in slices_ids.drain(..) {
        chunk.slices.insert_slice(address, slice_id);
    }
    let chunk_size = chunk.storage.size();
    let last_slice_size = chunk_size - offset;
    assert_eq!(last_slice_size % self.buffer_alignment, 0);
    if last_slice_size != 0 {
        // NOTE(review): the tail slice created here is dropped immediately —
        // it is never inserted into `self.slices` or the chunk's slice table.
        // Confirm whether the trailing free space is intentionally untracked.
        self.create_slice(offset, last_slice_size, chunk_handle);
    }
    self.deallocate(storage, deallocations);
}
}
/// Number of padding bytes required to round `size` up to the next multiple
/// of `buffer_alignment` (zero when `size` is already aligned).
fn calculate_padding(size: usize, buffer_alignment: usize) -> usize {
    match size % buffer_alignment {
        0 => 0,
        remainder => buffer_alignment - remainder,
    }
}
impl MemorySlice for Slice {
    fn is_free(&self) -> bool {
        self.handle.is_free()
    }

    /// The slice's full footprint, including its alignment padding.
    fn size(&self) -> usize {
        self.effective_size()
    }

    /// Splits the slice at `offset_slice`: `self` keeps the first part and
    /// the returned slice covers the remainder. Returns `None` when the
    /// remainder is smaller than the buffer alignment, in which case it is
    /// absorbed into `self` as padding instead.
    fn split(&mut self, offset_slice: usize, buffer_alignment: usize) -> Option<Self> {
        let size_new = self.effective_size() - offset_slice;
        let offset_new = self.storage.offset() + offset_slice;
        let old_size = self.effective_size();
        // Storage view for the would-be tail slice.
        let storage_new = StorageHandle {
            id: self.storage.id.clone(),
            utilization: StorageUtilization::Slice {
                offset: offset_new,
                size: size_new,
            },
        };
        // Shrink `self` to the head part.
        self.storage.utilization = StorageUtilization::Slice {
            offset: self.storage.offset(),
            size: offset_slice,
        };
        if offset_new > 0 && size_new < MIN_SIZE_NEEDED_TO_OFFSET {
            panic!("tried to create slice of size {size_new} with an offset while the size needs to atleast be of size {MIN_SIZE_NEEDED_TO_OFFSET} for offset support");
        }
        if offset_new % buffer_alignment != 0 {
            panic!("slice with offset {offset_new} needs to be a multiple of {buffer_alignment}");
        }
        let handle = SliceHandle::new();
        // Remainder too small to stand alone: keep it as padding on `self`.
        if size_new < buffer_alignment {
            self.padding = old_size - offset_slice;
            assert_eq!(self.effective_size(), old_size);
            return None;
        }
        assert!(
            size_new >= buffer_alignment,
            "Size new > {buffer_alignment}"
        );
        self.padding = 0;
        // NOTE(review): padding is computed from `size_new - buffer_alignment`
        // rather than `size_new` — confirm this offset-by-one-alignment is
        // intentional.
        let padding = calculate_padding(size_new - buffer_alignment, buffer_alignment);
        Some(Slice::new(storage_new, handle, self.chunk.clone(), padding))
    }

    fn id(&self) -> SliceId {
        *self.handle.id()
    }

    /// Address immediately after this slice within its chunk.
    fn next_slice_position(&self) -> usize {
        self.storage.offset() + self.effective_size()
    }
}
/// Adapter letting the ring buffer treat a pool [`Chunk`] as a generic
/// memory chunk made of [`Slice`]s; all calls delegate to the chunk's
/// memory page, which owns the address-keyed slice table.
impl MemoryChunk<Slice> for Chunk {
    fn merge_next_slice(
        &mut self,
        from_slice_index: usize,
        slices: &mut HashMap<SliceId, Slice>,
    ) -> bool {
        self.slices.merge_with_next_slice(from_slice_index, slices)
    }

    fn slice(&self, index: usize) -> Option<SliceId> {
        self.slices.find_slice(index)
    }

    fn insert_slice(
        &mut self,
        position: usize,
        slice: Slice,
        slices: &mut HashMap<SliceId, Slice>,
    ) {
        // Capture the id before the slice is moved into the global map.
        let id = slice.id();
        self.slices.insert_slice(position, id);
        slices.insert(id, slice);
    }
}

View File

@ -1,33 +0,0 @@
use crate::memory_id_type;
use crate::memory_management::{MemoryBinding, MemoryHandle};
// The ChunkId allows to keep track of how many references there are to a specific chunk.
memory_id_type!(ChunkId, ChunkHandle);
// The SliceId allows to keep track of how many references there are to a specific slice.
// The extra `SliceBinding` argument additionally generates a binding type for slices.
memory_id_type!(SliceId, SliceHandle, SliceBinding);
/// A tensor memory handle, referring to a slice of a memory chunk.
#[derive(Debug, Clone)]
pub struct MemoryPoolHandle {
    /// Handle of the slice backing this allocation.
    pub slice: SliceHandle,
}
/// Binding of the [memory pool handle](MemoryPoolHandle).
#[derive(Debug, Clone)]
pub struct MemoryPoolBinding {
    /// Binding of the underlying slice.
    pub slice: SliceBinding,
}
impl MemoryBinding for MemoryPoolBinding {}

impl MemoryHandle<MemoryPoolBinding> for MemoryPoolHandle {
    /// Whether the memory behind this handle may be mutated; delegates to
    /// the slice handle.
    fn can_mut(&self) -> bool {
        self.slice.can_mut()
    }

    /// Consumes the handle, producing the binding of its slice.
    fn binding(self) -> MemoryPoolBinding {
        MemoryPoolBinding {
            slice: self.slice.binding(),
        }
    }
}

View File

@ -1,65 +0,0 @@
use alloc::collections::BTreeMap;
use alloc::vec;
use alloc::vec::Vec;
use core::hash::Hash;
use hashbrown::HashMap;
/// Data Structure that helps to search items by size efficiently.
///
/// Each item is indexed both ways: `items_per_size` groups items under their
/// size (ordered, enabling range queries), while `sizes_per_item` maps an
/// item back to its size so it can be removed or re-inserted cheaply.
pub struct SearchIndex<T> {
    /// Items grouped by size, ordered so ranges of sizes can be scanned.
    items_per_size: BTreeMap<usize, Vec<T>>,
    /// Reverse lookup: the size currently recorded for each item.
    sizes_per_item: HashMap<T, usize>,
}

impl<T: PartialEq + Eq + Hash + Clone> SearchIndex<T> {
    /// Create a new item search index.
    pub fn new() -> Self {
        Self {
            items_per_size: BTreeMap::new(),
            sizes_per_item: HashMap::new(),
        }
    }

    /// Insert a new sized item into the search index.
    ///
    /// If the item was already present, it is first removed so it is never
    /// indexed under two different sizes at once.
    pub fn insert(&mut self, item: T, size: usize) {
        self.remove(&item);
        // Entry API: a single map lookup whether or not the bucket exists.
        self.items_per_size
            .entry(size)
            .or_default()
            .push(item.clone());
        self.sizes_per_item.insert(item, size);
    }

    /// Find the item by size range.
    #[allow(unused)]
    pub fn find_by_size(
        &self,
        range: core::ops::Range<usize>,
    ) -> impl DoubleEndedIterator<Item = &T> {
        self.items_per_size.range(range).flat_map(|a| a.1)
    }

    /// Remove an item from the index. Unknown items are a no-op.
    pub fn remove(&mut self, item: &T) {
        let size = match self.sizes_per_item.remove(item) {
            Some(size) => size,
            None => return,
        };
        if let Some(values) = self.items_per_size.get_mut(&size) {
            // Items are unique (insert removes any previous entry first),
            // so this drops at most one element.
            values.retain(|value| value != item);
            // Drop empty buckets so the map doesn't accumulate dead keys.
            if values.is_empty() {
                self.items_per_size.remove(&size);
            }
        }
    }
}

View File

@ -1,11 +0,0 @@
// Size-based search index shared by the pool internals.
pub(crate) mod index;
// Ring-buffer scanner over chunks and slices.
mod ring;
// Core memory pool implementation.
mod base;
// Handle and binding types for pool allocations.
mod handle;
// Pool specialized for very small allocations.
mod small;
pub use base::*;
pub use handle::*;
pub use ring::*;
pub use small::*;

View File

@ -1,454 +0,0 @@
use alloc::vec::Vec;
use core::marker::PhantomData;
use hashbrown::HashMap;
use super::{ChunkId, SliceId};
/// Circular scanner over chunks, used to find (and merge/split) free slices.
#[derive(Debug)]
pub struct RingBuffer<C: MemoryChunk<S>, S: MemorySlice> {
    /// Chunk ids in insertion order; scanned circularly during searches.
    queue: Vec<ChunkId>,
    /// Cached position of each chunk id inside `queue`.
    chunk_positions: HashMap<ChunkId, usize>,
    /// Slice position to resume from within the current chunk.
    cursor_slice: usize,
    /// Index into `queue` of the chunk the search is currently on.
    cursor_chunk: usize,
    _s: PhantomData<S>,
    _c: PhantomData<C>,
    /// Alignment forwarded to slice splitting.
    buffer_alignment: usize,
}
/// A chunk of memory that owns an ordered collection of slices.
pub trait MemoryChunk<S: MemorySlice> {
    /// Try to merge the slice at `slice_position` with its successor;
    /// returns whether a merge happened.
    fn merge_next_slice(&mut self, slice_position: usize, slices: &mut HashMap<SliceId, S>)
        -> bool;
    /// The id of the slice at the given position within the chunk, if any.
    fn slice(&self, index: usize) -> Option<SliceId>;
    /// Register `slice` at `position` in the chunk and in the global map.
    fn insert_slice(&mut self, position: usize, slice: S, slices: &mut HashMap<SliceId, S>);
}
/// A slice of a memory chunk that can be queried, split, and merged by the
/// ring buffer.
pub trait MemorySlice: Sized {
    /// Whether the slice is currently unused and may be reallocated.
    fn is_free(&self) -> bool;
    /// Size of the slice in bytes.
    fn size(&self) -> usize;
    /// Split the slice at `offset`, returning the tail part when one is created.
    fn split(&mut self, offset: usize, buffer_alignment: usize) -> Option<Self>;
    /// Identifier of the slice.
    fn id(&self) -> SliceId;
    /// Position right after this slice within its chunk.
    fn next_slice_position(&self) -> usize;
}
impl<C: MemoryChunk<S>, S: MemorySlice> RingBuffer<C, S> {
    /// Creates an empty ring buffer; `buffer_alignment` is forwarded to
    /// slice splitting.
    pub fn new(buffer_alignment: usize) -> Self {
        Self {
            queue: Vec::new(),
            chunk_positions: HashMap::new(),
            cursor_slice: 0,
            cursor_chunk: 0,
            _s: PhantomData,
            _c: PhantomData,
            buffer_alignment,
        }
    }

    /// Appends a chunk at the end of the scan queue.
    pub fn push_chunk(&mut self, chunk_id: ChunkId) {
        self.queue.push(chunk_id);
        self.chunk_positions.insert(chunk_id, self.queue.len() - 1);
    }

    /// Removes a chunk from the queue, rebuilding all cached positions and
    /// resetting both cursors.
    pub fn remove_chunk(&mut self, chunk_id: ChunkId) {
        if let Some(position) = self.chunk_positions.remove(&chunk_id) {
            self.queue.remove(position);
        }
        // Positions after the removal point shifted, so rebuild them all.
        self.chunk_positions.clear();
        for (pos, id) in self.queue.iter().enumerate() {
            self.chunk_positions.insert(*id, pos);
        }
        self.cursor_chunk = 0;
        self.cursor_slice = 0;
    }

    /// Finds a free slice of at least `size` bytes, scanning from the current
    /// cursor to the end of the queue and then wrapping around to the start.
    pub fn find_free_slice(
        &mut self,
        size: usize,
        chunks: &mut HashMap<ChunkId, C>,
        slices: &mut HashMap<SliceId, S>,
    ) -> Option<SliceId> {
        let max_second = self.cursor_chunk;
        let result = self.find_free_slice_in_all_chunks(size, chunks, slices, self.queue.len());
        if result.is_some() {
            return result;
        }
        // Wrap around: rescan from the beginning up to the original cursor.
        self.cursor_chunk = 0;
        self.cursor_slice = 0;
        self.find_free_slice_in_all_chunks(size, chunks, slices, max_second)
    }

    /// Scans one chunk starting at `slice_index`, merging adjacent free
    /// slices and splitting an oversized match; returns the position and id
    /// of a suitable slice.
    fn find_free_slice_in_chunk(
        &mut self,
        size: usize,
        chunk: &mut C,
        slices: &mut HashMap<SliceId, S>,
        mut slice_index: usize,
    ) -> Option<(usize, SliceId)> {
        while let Some(slice_id) = chunk.slice(slice_index) {
            //mutable borrow scope
            {
                let slice = slices.get_mut(&slice_id).unwrap();
                let is_big_enough = slice.size() >= size;
                let is_free = slice.is_free();
                if is_big_enough && is_free {
                    if slice.size() > size {
                        // Split off the excess into a new slice right after.
                        if let Some(new_slice) = slice.split(size, self.buffer_alignment) {
                            let new_slice_id = new_slice.id();
                            chunk.insert_slice(slice.next_slice_position(), new_slice, slices);
                            // Sanity check: the new slice must be registered.
                            slices.get(&new_slice_id).unwrap();
                        }
                    }
                    return Some((slice_index, slice_id));
                }
            }
            {
                let slice = slices.get_mut(&slice_id).unwrap();
                let is_free = slice.is_free();
                // Merging grows the current slice, so retry it before moving on.
                if is_free && chunk.merge_next_slice(slice_index, slices) {
                    continue;
                }
            }
            if let Some(slice) = slices.get(&slice_id) {
                slice_index = slice.next_slice_position();
            } else {
                panic!("current slice_id should still be valid after potential merge");
            }
        }
        None
    }

    /// Scans chunks in `cursor_chunk..max_cursor_position`, updating the
    /// cursors as it goes; only the first chunk visited resumes mid-chunk.
    fn find_free_slice_in_all_chunks(
        &mut self,
        size: usize,
        chunks: &mut HashMap<ChunkId, C>,
        slices: &mut HashMap<SliceId, S>,
        max_cursor_position: usize,
    ) -> Option<SliceId> {
        let start = self.cursor_chunk;
        let end = usize::min(self.queue.len(), max_cursor_position);
        let mut slice_index = self.cursor_slice;
        for chunk_index in start..end {
            if chunk_index > start {
                // Only the first chunk resumes from the saved slice cursor.
                slice_index = 0;
            }
            if let Some(id) = self.queue.get(chunk_index) {
                let chunk = chunks.get_mut(id).unwrap();
                let result = self.find_free_slice_in_chunk(size, chunk, slices, slice_index);
                if let Some((_cursor_slice, slice)) = result {
                    let slice = slices.get(&slice).unwrap();
                    // Resume the next search right after the returned slice.
                    self.cursor_slice = slice.next_slice_position();
                    self.cursor_chunk = chunk_index;
                    return Some(slice.id());
                }
            }
            self.cursor_chunk = chunk_index;
            self.cursor_slice = 0;
        }
        None
    }
}
#[cfg(test)]
mod tests {
    use super::stub::*;
    use super::*;
    use alloc::vec;

    /// An oversized free slice is split so the request gets an exact fit.
    #[test]
    fn simple_1() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);
        let slice_1 = new_slice(0, 100, 0);
        let slice_2 = new_slice(1, 200, 1);
        let chunk_1 = new_chunk(0, vec![0, 1]);
        let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);
        ring.push_chunk(ChunkId { value: 0 });
        let slice = ring.find_free_slice(50, &mut chunks, &mut slices).unwrap();
        assert_eq!(slice, SliceId { value: 0 });
        assert_eq!(slices.get(&slice).unwrap().size, 50);
        // The split added one slice: 50 (returned) + 50 (new) + 200.
        assert_eq!(slices.len(), 3);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 3);
    }

    /// A request larger than the first slice merges it with its neighbor,
    /// then splits the merged slice to the exact size.
    #[test]
    fn simple_2() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);
        let slice_1 = new_slice(0, 100, 0);
        let slice_2 = new_slice(1, 200, 1);
        let chunk_1 = new_chunk(0, vec![0, 1]);
        let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);
        ring.push_chunk(ChunkId { value: 0 });
        let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap();
        assert_eq!(slice, SliceId { value: 0 });
        assert_eq!(slices.get(&slice).unwrap().size, 150);
        assert_eq!(slices.len(), 2);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 2);
    }

    /// The search skips occupied slices and wraps across chunks.
    #[test]
    fn multiple_chunks() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);
        let slice_1 = new_slice(0, 100, 0);
        let slice_2 = new_slice(1, 200, 1);
        let slice_3 = new_slice(2, 200, 0);
        let slice_4 = new_slice(3, 200, 1);
        let chunk_1 = new_chunk(0, vec![0, 1]);
        let chunk_2 = new_chunk(1, vec![2, 3]);
        let mut slices = HashMap::from([
            (slice_1.id, slice_1),
            (slice_2.id, slice_2),
            (slice_3.id, slice_3),
            (slice_4.id, slice_4),
        ]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]);
        ring.push_chunk(ChunkId { value: 0 });
        ring.push_chunk(ChunkId { value: 1 });
        // Only slices 0 and 2 are free.
        slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = false;
        slices.get_mut(&SliceId { value: 3 }).unwrap().is_free = false;
        let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap();
        assert_eq!(slice, SliceId { value: 2 });
        let slice = ring.find_free_slice(100, &mut chunks, &mut slices).unwrap();
        assert_eq!(slice, SliceId { value: 0 });
    }

    /// An exact-size free slice is returned without splitting.
    #[test]
    fn find_free_slice_with_exact_fit() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);
        let slice_1 = new_slice(0, 100, 0);
        let slice_2 = new_slice(1, 200, 1);
        let chunk_1 = new_chunk(0, vec![0, 1]);
        let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);
        ring.push_chunk(ChunkId { value: 0 });
        slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = false;
        slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true;
        let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap();
        assert_eq!(slice, SliceId { value: 1 });
        assert_eq!(slices.get(&slice).unwrap().size, 200);
        // No split happened: the slice count is unchanged.
        assert_eq!(slices.len(), 2);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 2);
    }

    /// Three adjacent free slices are merged to satisfy one large request.
    #[test]
    fn find_free_slice_with_merging() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);
        let slice_1 = new_slice(0, 100, 0);
        let slice_2 = new_slice(1, 50, 1);
        let slice_3 = new_slice(2, 100, 2);
        let chunk_1 = new_chunk(0, vec![0, 1, 2]);
        let mut slices = HashMap::from([
            (slice_1.id, slice_1),
            (slice_2.id, slice_2),
            (slice_3.id, slice_3),
        ]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);
        ring.push_chunk(ChunkId { value: 0 });
        slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 2 }).unwrap().is_free = true;
        let slice = ring.find_free_slice(250, &mut chunks, &mut slices).unwrap();
        assert_eq!(slice, SliceId { value: 0 });
        assert_eq!(slices.get(&slice).unwrap().size, 250);
        // All three slices collapsed into one.
        assert_eq!(slices.len(), 1);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 1);
    }

    /// Merging happens per chunk: a request spanning one chunk's slices is
    /// served there even when another chunk also has free space.
    #[test]
    fn find_free_slice_with_multiple_chunks_and_merging() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);
        let slice_1 = new_slice(0, 50, 0);
        let slice_2 = new_slice(1, 50, 1);
        let chunk_1 = new_chunk(0, vec![0, 1]);
        let slice_3 = new_slice(2, 100, 0);
        let slice_4 = new_slice(3, 50, 1);
        let chunk_2 = new_chunk(1, vec![2, 3]);
        let mut slices = HashMap::from([
            (slice_1.id, slice_1),
            (slice_2.id, slice_2),
            (slice_3.id, slice_3),
            (slice_4.id, slice_4),
        ]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]);
        ring.push_chunk(ChunkId { value: 0 });
        ring.push_chunk(ChunkId { value: 1 });
        slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 2 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 3 }).unwrap().is_free = true;
        let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap();
        assert_eq!(slices.get(&slice).unwrap().size, 150);
        assert_eq!(slices.len(), 2);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 1);
    }

    /// Builds a free test slice at the given position within its chunk.
    fn new_slice(id: usize, size: usize, position: usize) -> TestSlice {
        TestSlice {
            id: SliceId { value: id },
            is_free: true,
            size,
            position,
        }
    }

    /// Builds a test chunk from raw slice id values.
    fn new_chunk(id: usize, slices: Vec<usize>) -> TestChunk {
        TestChunk {
            id: ChunkId { value: id },
            slices: slices.into_iter().map(|i| SliceId { value: i }).collect(),
        }
    }
}
#[cfg(test)]
mod stub {
    use super::*;
    use burn_common::*;

    /// Minimal chunk implementation used by the ring-buffer tests.
    #[derive(Debug)]
    pub struct TestChunk {
        pub id: ChunkId,
        /// Slice ids ordered by their position within the chunk.
        pub slices: Vec<SliceId>,
    }

    /// Minimal slice implementation; `position` is its index in the chunk.
    #[derive(Debug)]
    pub struct TestSlice {
        pub id: SliceId,
        pub is_free: bool,
        pub size: usize,
        pub position: usize,
    }

    impl MemorySlice for TestSlice {
        fn is_free(&self) -> bool {
            self.is_free
        }

        fn size(&self) -> usize {
            self.size
        }

        /// Keeps the first `offset` bytes in `self` and returns the
        /// remainder as a new free slice with a random id.
        fn split(&mut self, offset: usize, _buffer_alignment: usize) -> Option<Self> {
            let size_remained = self.size - offset;
            self.size = offset;
            Some(Self {
                id: SliceId {
                    value: rand::gen_random(),
                },
                is_free: true,
                size: size_remained,
                position: self.position + 1,
            })
        }

        fn id(&self) -> SliceId {
            self.id
        }

        fn next_slice_position(&self) -> usize {
            self.position + 1
        }
    }

    impl MemoryChunk<TestSlice> for TestChunk {
        /// Absorbs the next slice into the one at `from_slice_index` when the
        /// next slice is free, then re-numbers every slice's position.
        fn merge_next_slice(
            &mut self,
            from_slice_index: usize,
            slices: &mut HashMap<SliceId, TestSlice>,
        ) -> bool {
            let slice_id_current = self.slices.get(from_slice_index).unwrap();
            let slice_id_next = self.slices.get(from_slice_index + 1);
            let slice_id_next = match slice_id_next {
                Some(val) => val,
                None => return false,
            };
            let slice_next = slices.get(slice_id_next).unwrap();
            let is_free = slice_next.is_free;
            let size = slice_next.size;
            let slice_current = slices.get_mut(slice_id_current).unwrap();
            if is_free {
                slice_current.size += size;
                slices.remove(slice_id_next);
                self.slices.remove(from_slice_index + 1);
                // Positions shifted after the removal; refresh them all.
                for (index, temp_slice_id) in self.slices.iter_mut().enumerate() {
                    let slice = slices.get_mut(temp_slice_id).unwrap();
                    slice.position = index;
                }
                return true;
            }
            false
        }

        fn slice(&self, index: usize) -> Option<SliceId> {
            self.slices.get(index).copied()
        }

        fn insert_slice(
            &mut self,
            position: usize,
            slice: TestSlice,
            slices: &mut HashMap<SliceId, TestSlice>,
        ) {
            self.slices.insert(position, slice.id());
            slices.insert(slice.id(), slice);
            // Keep every slice's cached position in sync with its index.
            for (index, temp_slice_id) in self.slices.iter_mut().enumerate() {
                let temp_slice = slices.get_mut(temp_slice_id).unwrap();
                temp_slice.position = index;
            }
        }
    }
}

View File

@ -1,226 +0,0 @@
use super::{ChunkHandle, ChunkId, MemoryPoolBinding, MemoryPoolHandle, SliceHandle, SliceId};
use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization};
use alloc::vec::Vec;
use hashbrown::HashMap;
/// A memory pool dedicated to very small allocations: every chunk has the
/// same fixed size (`buffer_storage_alignment_offset` bytes) and is reused
/// to minimize allocations.
///
/// - Only one slice is supported per chunk due to the limitations in WGPU where small allocations cannot be offset.
/// - The pool uses a ring buffer to efficiently manage and reuse chunks.
pub struct SmallMemoryPool {
    /// The allocated chunks, keyed by their IDs.
    chunks: HashMap<ChunkId, SmallChunk>,
    /// The slices, keyed by their IDs (at most one per chunk).
    slices: HashMap<SliceId, SmallSlice>,
    /// Chunk ids rotated through when looking for a reusable chunk.
    ring_buffer: Vec<ChunkId>,
    /// Current position in the ring buffer.
    index: usize,
    /// Fixed chunk size; also the alignment every slice is padded to.
    buffer_storage_alignment_offset: usize,
}
/// A fixed-size chunk of the small pool, backing at most one slice.
#[derive(new, Debug)]
pub struct SmallChunk {
    /// Handle to the backing storage allocation.
    pub storage: StorageHandle,
    #[allow(dead_code)]
    pub handle: ChunkHandle,
    /// The single slice carved from this chunk, once allocated.
    pub slice: Option<SliceId>,
}
/// The single slice of a small-pool chunk.
#[derive(new, Debug)]
pub struct SmallSlice {
    /// Storage view (offset and size) backing this slice.
    pub storage: StorageHandle,
    pub handle: SliceHandle,
    // Presumably kept so the chunk stays referenced while the slice lives
    // (cf. the equivalent comment on `Slice` in simple.rs) — confirm.
    #[allow(dead_code)]
    pub chunk: ChunkHandle,
    /// Trailing bytes rounding the slice up to the fixed chunk size.
    pub padding: usize,
}
impl SmallSlice {
    /// Total footprint of the slice: the bytes visible through its storage
    /// handle plus the trailing alignment padding.
    pub fn effective_size(&self) -> usize {
        let visible = self.storage.size();
        visible + self.padding
    }
}
impl SmallMemoryPool {
    /// Creates an empty pool whose chunks will all have the given fixed size.
    pub fn new(buffer_storage_alignment_offset: usize) -> Self {
        Self {
            chunks: HashMap::new(),
            slices: HashMap::new(),
            ring_buffer: Vec::new(),
            index: 0,
            buffer_storage_alignment_offset,
        }
    }

    /// Returns the resource from the storage, for the specified handle.
    pub fn get<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        binding: &MemoryPoolBinding,
    ) -> Option<Storage::Resource> {
        self.slices
            .get(binding.slice.id())
            .map(|s| &s.storage)
            .map(|h| storage.get(h))
    }

    /// Reserves memory of specified size using the reserve algorithm, and return
    /// a handle to the reserved memory.
    ///
    /// Also clean ups, merging free slices together if permitted by the merging strategy
    pub fn reserve<Storage: ComputeStorage, Sync: FnOnce()>(
        &mut self,
        storage: &mut Storage,
        size: usize,
        sync: Sync,
    ) -> MemoryPoolHandle {
        // Requests must fit in the pool's fixed chunk size.
        assert!(size <= self.buffer_storage_alignment_offset);
        let slice = self.get_free_slice(size);
        match slice {
            Some(slice) => MemoryPoolHandle {
                slice: slice.clone(),
            },
            None => self.alloc(storage, size, sync),
        }
    }

    /// Allocates a fresh chunk for the request, bypassing reuse.
    pub fn alloc<Storage: ComputeStorage, Sync: FnOnce()>(
        &mut self,
        storage: &mut Storage,
        size: usize,
        _sync: Sync,
    ) -> MemoryPoolHandle {
        assert!(size <= self.buffer_storage_alignment_offset);
        self.alloc_slice(storage, size)
    }

    /// Allocates one fixed-size chunk and carves the requested slice from it.
    fn alloc_slice<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        slice_size: usize,
    ) -> MemoryPoolHandle {
        let handle_chunk = self.create_chunk(storage, self.buffer_storage_alignment_offset);
        let chunk_id = *handle_chunk.id();
        let slice = self.allocate_slice(handle_chunk.clone(), slice_size);
        let handle_slice = slice.handle.clone();
        self.update_chunk_metadata(chunk_id, slice);
        MemoryPoolHandle {
            slice: handle_slice,
        }
    }

    fn allocate_slice(&self, handle_chunk: ChunkHandle, slice_size: usize) -> SmallSlice {
        let slice = self.create_slice(0, slice_size, handle_chunk.clone());
        let effective_size = slice.effective_size();
        // Padding rounds every slice up to exactly one chunk.
        assert_eq!(effective_size, self.buffer_storage_alignment_offset);
        slice
    }

    /// Records the slice as the single slice of its chunk.
    fn update_chunk_metadata(&mut self, chunk_id: ChunkId, slice: SmallSlice) {
        let slice_id = *slice.handle.id();
        self.slices.insert(slice_id, slice);
        self.chunks.get_mut(&chunk_id).unwrap().slice = Some(slice_id);
    }

    /// Rotates through the ring buffer (at most one full turn) looking for a
    /// chunk whose slice handle is free.
    fn find_free_slice(&mut self) -> Option<SliceId> {
        if self.ring_buffer.is_empty() {
            return None;
        }
        for _ in 0..self.ring_buffer.len() {
            let chunk_id = self.ring_buffer.get(self.index).unwrap();
            let chunk = self.chunks.get(chunk_id).unwrap();
            let slice = self.slices.get(&chunk.slice.unwrap()).unwrap();
            // Advance the cursor regardless, so searches round-robin.
            self.index = (self.index + 1) % self.ring_buffer.len();
            if slice.handle.is_free() {
                return Some(*slice.handle.id());
            }
        }
        None
    }

    /// Finds a free slice that can contain the given size
    /// Returns the chunk's id and size.
    fn get_free_slice(&mut self, size: usize) -> Option<SliceHandle> {
        let slice_id = self.find_free_slice()?;
        let slice = self.slices.get_mut(&slice_id).unwrap();
        let old_slice_size = slice.effective_size();
        let offset = match slice.storage.utilization {
            StorageUtilization::Full(_) => 0,
            StorageUtilization::Slice { offset, size: _ } => offset,
        };
        // Small-pool slices always start at the beginning of their chunk.
        assert_eq!(offset, 0);
        // Shrink the visible size; the remainder becomes padding so the
        // overall footprint stays unchanged.
        slice.storage.utilization = StorageUtilization::Slice { offset, size };
        let new_padding = old_slice_size - size;
        slice.padding = new_padding;
        assert_eq!(
            slice.effective_size(),
            old_slice_size,
            "new and old slice should have the same size"
        );
        Some(slice.handle.clone())
    }

    /// Creates a slice of size `size` upon the given chunk with the given offset.
    fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> SmallSlice {
        assert_eq!(offset, 0);
        let chunk = self.chunks.get(handle_chunk.id()).unwrap();
        let handle = SliceHandle::new();
        let storage = StorageHandle {
            id: chunk.storage.id.clone(),
            utilization: StorageUtilization::Slice { offset, size },
        };
        let padding = calculate_padding(size, self.buffer_storage_alignment_offset);
        SmallSlice::new(storage, handle, chunk.handle.clone(), padding)
    }

    /// Creates a chunk of given size by allocating on the storage.
    fn create_chunk<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        size: usize,
    ) -> ChunkHandle {
        let padding = calculate_padding(size, self.buffer_storage_alignment_offset);
        let effective_size = size + padding;
        let storage = storage.alloc(effective_size);
        let handle = ChunkHandle::new();
        let id = *handle.id();
        self.ring_buffer.push(id);
        self.chunks
            .insert(id, SmallChunk::new(storage, handle.clone(), None));
        handle
    }

    #[allow(unused)]
    fn deallocate<Storage: ComputeStorage>(&mut self, _storage: &mut Storage) {
        // NOTE(review): deallocation is not implemented for the small pool.
        todo!()
    }
}
/// Padding needed to round `size` up to the next multiple of
/// `buffer_storage_alignment_offset`; zero when already aligned.
fn calculate_padding(size: usize, buffer_storage_alignment_offset: usize) -> usize {
    let remainder = size % buffer_storage_alignment_offset;
    if remainder == 0 {
        0
    } else {
        buffer_storage_alignment_offset - remainder
    }
}

View File

@ -1,10 +0,0 @@
pub(crate) mod memory_pool;
mod base;
pub use base::*;
/// Dynamic memory management strategy.
pub mod dynamic;
/// Simple memory management strategy.
pub mod simple;

View File

@ -1,559 +0,0 @@
use crate::{
memory_id_type,
storage::{ComputeStorage, StorageHandle, StorageUtilization},
};
use alloc::vec::Vec;
use hashbrown::HashMap;
#[cfg(all(not(target_family = "wasm"), feature = "std"))]
use std::time;
#[cfg(all(target_family = "wasm", feature = "std"))]
use web_time as time;
use super::{MemoryBinding, MemoryHandle, MemoryManagement};
// The ChunkId allows to keep track of how many references there are to a specific chunk.
// The third argument additionally generates the `ChunkBinding` type.
memory_id_type!(ChunkId, ChunkHandle, ChunkBinding);
// The SliceId allows to keep track of how many references there are to a specific slice.
// The third argument additionally generates the `SliceBinding` type.
memory_id_type!(SliceId, SliceHandle, SliceBinding);
/// A tensor memory handle, referring to either a chunk or a slice.
///
/// See [`SimpleBinding`] for the consumed form produced by `binding`.
#[derive(Debug, Clone)]
pub enum SimpleHandle {
    /// A whole chunk of memory.
    Chunk(ChunkHandle),
    /// A slice of a chunk of memory.
    Slice(SliceHandle),
}
/// Binding of the [simple handle](SimpleHandle).
///
/// Mirrors the handle's chunk/slice distinction.
#[derive(Debug, Clone)]
pub enum SimpleBinding {
    /// Binding of the [chunk handle](ChunkHandle).
    Chunk(ChunkBinding),
    /// Binding of the [slice handle](SliceHandle)
    Slice(SliceBinding),
}
/// The strategy defines the frequency at which deallocation of unused memory chunks should occur.
#[derive(Debug)]
pub enum DeallocStrategy {
    /// Once every n calls to reserve.
    PeriodTick {
        /// Number of calls to be executed before triggering the deallocation.
        period: usize,
        /// Current state. Should start at zero.
        state: usize,
    },
    #[cfg(feature = "std")]
    /// Once every period of time
    PeriodTime {
        /// Amount of time to wait between deallocation passes.
        period: time::Duration,
        /// Current state. Should start at now.
        state: time::Instant,
    },
    /// Never deallocate.
    Never,
}
/// The strategy defines when to reuse chunk with slices.
#[derive(Debug)]
pub enum SliceStrategy {
    /// Never use slices.
    Never,
    /// Ratio needed before the chunk can be used as a slice. Between 0 and 1.
    Ratio(f32),
    /// When the reserved memory is at least the given number of bytes.
    MinimumSize(usize),
    /// When the reserved memory is less than the given number of bytes.
    MaximumSize(usize),
}

impl SliceStrategy {
    /// If the chunk can be used with a slice.
    pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool {
        // A chunk can never back a reservation larger than itself.
        if chunk_size < reserved_size {
            return false;
        }
        match *self {
            Self::Never => false,
            Self::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= ratio,
            Self::MinimumSize(bytes) => reserved_size >= bytes,
            Self::MaximumSize(bytes) => reserved_size <= bytes,
        }
    }
}
impl DeallocStrategy {
    /// Create a new strategy with the given period.
    pub fn new_period_tick(period: usize) -> Self {
        DeallocStrategy::PeriodTick { period, state: 0 }
    }

    /// Whether a deallocation pass should run now, advancing internal state.
    fn should_dealloc(&mut self) -> bool {
        match self {
            DeallocStrategy::Never => false,
            DeallocStrategy::PeriodTick { period, state } => {
                // Wrap the counter; fires exactly once per `period` calls.
                *state = (*state + 1) % *period;
                *state == 0
            }
            #[cfg(feature = "std")]
            DeallocStrategy::PeriodTime { period, state } => {
                let expired = &state.elapsed() > period;
                if expired {
                    // Restart the timer for the next period.
                    *state = time::Instant::now();
                }
                expired
            }
        }
    }
}
/// A chunk of allocated storage and the slices carved out of it.
#[derive(new)]
struct Chunk {
    /// Handle to the backing storage allocation.
    storage: StorageHandle,
    handle: ChunkHandle,
    /// Slices currently carved out of this chunk.
    slices: Vec<SliceId>,
}
/// A slice carved from a chunk, viewing a sub-range of its storage.
#[derive(new)]
struct Slice {
    /// Storage view (offset and size) backing this slice.
    storage: StorageHandle,
    handle: SliceHandle,
    // It is important to keep the chunk handle inside the slice, since it increases the ref count
    // on the chunk id and make the `is_free` method return false until the slice is freed.
    //
    // TL;DR we can't only store the chunk id.
    chunk: ChunkHandle,
}
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
pub struct SimpleMemoryManagement<Storage> {
    /// The allocated chunks, keyed by their IDs.
    chunks: HashMap<ChunkId, Chunk>,
    /// The slices carved from chunks, keyed by their IDs.
    slices: HashMap<SliceId, Slice>,
    /// When to free unused chunks.
    dealloc_strategy: DeallocStrategy,
    /// When a chunk may be reused as a slice.
    slice_strategy: SliceStrategy,
    /// The underlying compute storage allocations are made on.
    storage: Storage,
}
impl<Storage> core::fmt::Debug for SimpleMemoryManagement<Storage> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(
alloc::format!(
"SimpleMemoryManagement {:?} - {:?}",
self.dealloc_strategy,
core::any::type_name::<Storage>(),
)
.as_str(),
)
}
}
impl MemoryBinding for SimpleBinding {}

impl MemoryHandle<SimpleBinding> for SimpleHandle {
    /// Whether the memory behind this handle may be mutated; delegates to
    /// the wrapped chunk or slice handle.
    fn can_mut(&self) -> bool {
        match self {
            Self::Chunk(handle) => handle.can_mut(),
            Self::Slice(handle) => handle.can_mut(),
        }
    }

    /// Consumes the handle into the matching binding variant.
    fn binding(self) -> SimpleBinding {
        match self {
            SimpleHandle::Chunk(handle) => SimpleBinding::Chunk(handle.binding()),
            SimpleHandle::Slice(handle) => SimpleBinding::Slice(handle.binding()),
        }
    }
}
impl<Storage: ComputeStorage> MemoryManagement<Storage> for SimpleMemoryManagement<Storage> {
    type Handle = SimpleHandle;
    type Binding = SimpleBinding;

    /// Returns the resource from the storage, for the specified handle.
    fn get(&mut self, binding: Self::Binding) -> Storage::Resource {
        let storage = match binding {
            SimpleBinding::Chunk(chunk) => {
                &self
                    .chunks
                    .get(chunk.id())
                    .expect("Storage found for the given execution buffer handle")
                    .storage
            }
            SimpleBinding::Slice(slice) => {
                &self
                    .slices
                    .get(slice.id())
                    .expect("Storage found for the given execution buffer handle")
                    .storage
            }
        };
        self.storage.get(storage)
    }

    /// Reserves memory of specified size using the reserve algorithm, and return
    /// a handle to the reserved memory.
    ///
    /// Also clean ups, removing unused slices, and chunks if permitted by deallocation strategy.
    fn reserve<Sync: FnOnce()>(&mut self, size: usize, _sync: Sync) -> Self::Handle {
        // `cleanup_slices`/`cleanup_chunks` are defined elsewhere in this file.
        self.cleanup_slices();
        let handle = self.reserve_algorithm(size);
        if self.dealloc_strategy.should_dealloc() {
            self.cleanup_chunks();
        }
        handle
    }

    /// Allocates a dedicated chunk of exactly `size` bytes.
    fn alloc<Sync: FnOnce()>(&mut self, size: usize, _sync: Sync) -> Self::Handle {
        self.create_chunk(size)
    }

    /// Manually deallocates a chunk; panics when given a slice binding.
    fn dealloc(&mut self, binding: Self::Binding) {
        match binding {
            SimpleBinding::Chunk(chunk) => {
                if let Some(chunk) = self.chunks.remove(chunk.id()) {
                    self.storage.dealloc(chunk.storage.id);
                }
            }
            SimpleBinding::Slice(_) => panic!("Can't dealloc slice manually"),
        }
    }

    fn storage(&mut self) -> &mut Storage {
        &mut self.storage
    }
}
impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
    /// Creates a new instance using the given storage, deallocation strategy and slice strategy.
    pub fn new(
        storage: Storage,
        dealloc_strategy: DeallocStrategy,
        slice_strategy: SliceStrategy,
    ) -> Self {
        Self {
            chunks: HashMap::new(),
            slices: HashMap::new(),
            dealloc_strategy,
            slice_strategy,
            storage,
        }
    }
    /// Reuses an existing free chunk (whole, or through a slice) when one fits,
    /// otherwise allocates a fresh chunk of exactly `size`.
    fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle {
        // Looks for a large enough, existing but unused chunk of memory.
        let chunk = self.find_free_chunk(size);
        match chunk {
            Some(chunk) => {
                if size == chunk.storage.size() {
                    // If there is one of exactly the same size, it reuses it.
                    SimpleHandle::Chunk(chunk.handle.clone())
                } else {
                    // Otherwise creates a slice of the right size upon it, always starting at zero.
                    self.create_slice(size, chunk.handle.clone())
                }
            }
            // If no chunk available, creates one of exactly the right size.
            None => self.create_chunk(size),
        }
    }
    /// Finds the smallest of the free and large enough chunks to fit `size`
    /// Returns the chunk's id and size.
    fn find_free_chunk(&self, size: usize) -> Option<&Chunk> {
        // Best-fit search: track the candidate with the smallest wasted space.
        let mut size_diff_current = usize::MAX;
        let mut current = None;
        for chunk in self.chunks.values() {
            // If chunk is already used, we do not choose it
            if !chunk.handle.is_free() {
                continue;
            }
            let storage_size = chunk.storage.size();
            // If we find a chunk of exactly the right size, we stop searching altogether
            if size == storage_size {
                current = Some(chunk);
                break;
            }
            // Finds the smallest of the large enough chunks that can accept a slice
            // of the given size
            if self.slice_strategy.can_use_chunk(storage_size, size) {
                let size_diff = storage_size - size;
                if size_diff < size_diff_current {
                    current = Some(chunk);
                    size_diff_current = size_diff;
                }
            }
        }
        current
    }
    /// Creates a slice of size `size` upon the given chunk.
    ///
    /// For now slices must start at zero, therefore there can be only one per chunk
    fn create_slice(&mut self, size: usize, handle_chunk: ChunkHandle) -> SimpleHandle {
        let chunk = self.chunks.get_mut(handle_chunk.id()).unwrap();
        let handle_slice = SliceHandle::new();
        // The slice shares the chunk's storage id but covers only `size` bytes
        // from offset zero.
        let storage = StorageHandle {
            id: chunk.storage.id.clone(),
            utilization: StorageUtilization::Slice { offset: 0, size },
        };
        if chunk.slices.is_empty() {
            self.slices.insert(
                *handle_slice.id(),
                Slice::new(storage, handle_slice.clone(), handle_chunk.clone()),
            );
        } else {
            panic!("Can't have more than 1 slice yet.");
        }
        chunk.slices.push(*handle_slice.id());
        SimpleHandle::Slice(handle_slice)
    }
    /// Creates a chunk of given size by allocating on the storage.
    fn create_chunk(&mut self, size: usize) -> SimpleHandle {
        let storage = self.storage.alloc(size);
        let handle = ChunkHandle::new();
        self.chunks.insert(
            *handle.id(),
            Chunk::new(storage, handle.clone(), Vec::new()),
        );
        SimpleHandle::Chunk(handle)
    }
    /// Deallocates free chunks and remove them from chunks map.
    fn cleanup_chunks(&mut self) {
        // Two passes: ids are collected first because the map cannot be
        // mutated while it is being iterated.
        let mut ids_to_remove = Vec::new();
        self.chunks.iter().for_each(|(chunk_id, chunk)| {
            if chunk.handle.is_free() {
                ids_to_remove.push(*chunk_id);
            }
        });
        ids_to_remove
            .iter()
            .map(|chunk_id| self.chunks.remove(chunk_id).unwrap())
            .for_each(|chunk| {
                self.storage.dealloc(chunk.storage.id);
            });
    }
    /// Removes free slices from slice map and corresponding chunks.
    fn cleanup_slices(&mut self) {
        let mut ids_to_remove = Vec::new();
        self.slices.iter().for_each(|(slice_id, slice)| {
            if slice.handle.is_free() {
                ids_to_remove.push(*slice_id);
            }
        });
        // Removing the `Slice` drops its `ChunkHandle`, which lowers the
        // chunk's ref count and may make the chunk free again.
        ids_to_remove
            .iter()
            .map(|slice_id| self.slices.remove(slice_id).unwrap())
            .for_each(|slice| {
                let chunk = self.chunks.get_mut(slice.chunk.id()).unwrap();
                chunk.slices.retain(|id| id != slice.handle.id());
            });
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        memory_management::{MemoryHandle, MemoryManagement},
        storage::BytesStorage,
    };
    impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
        // Test-only helper: reserve without caring about the sync callback.
        fn reserve_no_sync(&mut self, size: usize) -> SimpleHandle {
            self.reserve(size, || {})
        }
    }
    #[test]
    fn can_mut_with_single_tensor_reference() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let simple_handle = memory_management.create_chunk(chunk_size);
        let x = simple_handle.clone();
        core::mem::drop(simple_handle);
        assert!(x.can_mut());
    }
    #[test]
    fn two_tensor_references_remove_mutability() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let simple_handle = memory_management.create_chunk(chunk_size);
        let x = simple_handle.clone();
        assert!(!simple_handle.can_mut());
        assert!(!x.can_mut())
    }
    #[test]
    fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let _chunk_handle = memory_management.reserve_no_sync(chunk_size);
        let _new_handle = memory_management.reserve_no_sync(chunk_size);
        assert_eq!(memory_management.chunks.len(), 2);
    }
    // Renamed from `when_empty_chunk_is_cleaned_upexists_it_disappears` (typo).
    #[test]
    fn when_empty_chunk_is_cleaned_up_it_disappears() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let chunk_handle = memory_management.reserve_no_sync(chunk_size);
        drop(chunk_handle);
        memory_management.cleanup_chunks();
        assert_eq!(memory_management.chunks.len(), 0);
    }
    #[test]
    fn never_dealloc_strategy_never_deallocs() {
        let mut never_dealloc = DeallocStrategy::Never;
        for _ in 0..20 {
            assert!(!never_dealloc.should_dealloc())
        }
    }
    #[test]
    fn period_tick_dealloc_strategy_should_dealloc_after_period() {
        let period = 3;
        let mut period_tick_dealloc = DeallocStrategy::new_period_tick(period);
        for _ in 0..3 {
            for _ in 0..period - 1 {
                assert!(!period_tick_dealloc.should_dealloc());
            }
            assert!(period_tick_dealloc.should_dealloc());
        }
    }
    #[test]
    fn slice_strategy_minimum_bytes() {
        let strategy = SliceStrategy::MinimumSize(100);
        assert!(strategy.can_use_chunk(200, 101));
        assert!(!strategy.can_use_chunk(200, 99));
    }
    #[test]
    fn slice_strategy_maximum_bytes() {
        let strategy = SliceStrategy::MaximumSize(100);
        assert!(strategy.can_use_chunk(200, 99));
        assert!(!strategy.can_use_chunk(200, 101));
    }
    #[test]
    fn slice_strategy_ratio() {
        let strategy = SliceStrategy::Ratio(0.9);
        // 180/200 = 0.9 is accepted, 179/200 = 0.895 is rejected.
        assert!(strategy.can_use_chunk(200, 180));
        assert!(!strategy.can_use_chunk(200, 179));
    }
    #[test]
    fn test_handle_mutability() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Ratio(0.5),
        );
        let handle = memory_management.reserve_no_sync(10);
        let other_ref = handle.clone();
        assert!(!handle.can_mut(), "Handle can't be mut when multiple ref.");
        drop(other_ref);
        assert!(handle.can_mut(), "Handle should be mut when only one ref.");
    }
    #[test]
    fn test_slice_mutability() {
        let mut memory_management = SimpleMemoryManagement::new(
            BytesStorage::default(),
            DeallocStrategy::Never,
            SliceStrategy::Ratio(0.5),
        );
        let chunk = memory_management.reserve_no_sync(10);
        if let super::SimpleHandle::Slice(_) = chunk {
            panic!("Should be a chunk.")
        }
        drop(chunk);
        // The freed 10-byte chunk should now be reused through a slice.
        let slice = memory_management.reserve_no_sync(8);
        if let super::SimpleHandle::Chunk(_) = &slice {
            panic!("Should be a slice.")
        }
        if let super::SimpleHandle::Slice(slice) = slice {
            let other_ref = slice.clone();
            assert!(
                !slice.can_mut(),
                "Slice can't be mut when multiple ref to the same handle."
            );
            drop(other_ref);
            assert!(
                slice.can_mut(),
                "Slice should be mut when only one ref to the same handle."
            );
            assert!(
                !slice.is_free(),
                "Slice can't be reallocated when one ref still exist."
            );
        }
    }
}

View File

@ -1,105 +0,0 @@
use crate::{
memory_management::{MemoryHandle, MemoryManagement},
storage::ComputeStorage,
tune::AutotuneKey,
};
use alloc::vec::Vec;
use burn_common::{reader::Reader, sync_type::SyncType};
use core::fmt::Debug;
/// The compute server is responsible for handling resources and computations over resources.
///
/// Everything in the server is mutable, therefore it should be solely accessed through the
/// [compute channel](crate::channel::ComputeChannel) for thread safety.
pub trait ComputeServer: Send + core::fmt::Debug
where
    Self: Sized,
{
    /// The kernel type defines the computation algorithms.
    type Kernel: Send;
    /// Options when dispatching the kernel, eg. the number of executions.
    type DispatchOptions: Send;
    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
    type Storage: ComputeStorage;
    /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
    type MemoryManagement: MemoryManagement<Self::Storage>;
    /// The key used to cache operations used on specific inputs in autotune
    type AutotuneKey: AutotuneKey;
    /// Features supported by the compute server.
    type FeatureSet: Send + Sync;
    /// Given a handle, returns the owned resource as bytes.
    fn read(&mut self, binding: Binding<Self>) -> Reader;
    /// Given a resource handle, returns the storage resource.
    fn get_resource(
        &mut self,
        binding: Binding<Self>,
    ) -> <Self::Storage as ComputeStorage>::Resource;
    /// Given a resource as bytes, stores it and returns the memory handle.
    fn create(&mut self, data: &[u8]) -> Handle<Self>;
    /// Reserves `size` bytes in the storage, and returns a handle over them.
    ///
    /// The contents of the reserved memory are unspecified until written.
    fn empty(&mut self, size: usize) -> Handle<Self>;
    /// Executes the `kernel` over the given memory `handles`.
    ///
    /// Kernels have mutable access to every resource they are given
    /// and are responsible of determining which should be read or written.
    fn execute(
        &mut self,
        kernel: Self::Kernel,
        count: Self::DispatchOptions,
        bindings: Vec<Binding<Self>>,
    );
    /// Wait for the completion of every task in the server.
    ///
    /// The `command` describes which kind of synchronization is requested.
    fn sync(&mut self, command: SyncType);
}
/// Server handle containing the [memory handle](MemoryManagement::Handle).
///
/// A `Handle` represents ownership over a resource; convert it into a
/// [`Binding`] (via [`Handle::binding`]) to pass it to a kernel execution.
#[derive(new, Debug)]
pub struct Handle<Server: ComputeServer> {
    /// Memory handle.
    pub memory: <Server::MemoryManagement as MemoryManagement<Server::Storage>>::Handle,
}
/// Binding of a [tensor handle](Handle) to execute a kernel.
#[derive(new, Debug)]
pub struct Binding<Server: ComputeServer> {
    /// Memory binding.
    pub memory: <Server::MemoryManagement as MemoryManagement<Server::Storage>>::Binding,
}
// The two previously separate `impl Handle` blocks carried identical headers;
// they are merged into a single block for consistency.
impl<Server: ComputeServer> Handle<Server> {
    /// If the tensor handle can be reused inplace.
    pub fn can_mut(&self) -> bool {
        MemoryHandle::can_mut(&self.memory)
    }
    /// Convert the [handle](Handle) into a [binding](Binding).
    pub fn binding(self) -> Binding<Server> {
        Binding {
            memory: MemoryHandle::binding(self.memory),
        }
    }
}
// `Clone` is implemented manually instead of derived: `#[derive(Clone)]` would
// add a `Server: Clone` bound, while only the memory handle needs to be cloned.
impl<Server: ComputeServer> Clone for Handle<Server> {
    fn clone(&self) -> Self {
        Self {
            memory: self.memory.clone(),
        }
    }
}
// Same rationale as the `Handle` impl above: avoid a spurious `Server: Clone` bound.
impl<Server: ComputeServer> Clone for Binding<Server> {
    fn clone(&self) -> Self {
        Self {
            memory: self.memory.clone(),
        }
    }
}

View File

@ -1,64 +0,0 @@
use crate::storage_id_type;
// This ID is used to map a handle to its actual data.
storage_id_type!(StorageId);
/// Defines if data uses a full memory chunk or a slice of it.
#[derive(Clone, Debug)]
pub enum StorageUtilization {
    /// Full memory chunk of specified size (in bytes).
    Full(usize),
    /// Slice of memory chunk with start index and size.
    Slice {
        /// The offset in bytes from the chunk start.
        offset: usize,
        /// The size of the slice in bytes.
        size: usize,
    },
}
/// Contains the [storage id](StorageId) of a resource and the way it is used.
///
/// Several handles may share the same [`StorageId`] (e.g. a chunk and a slice
/// upon it) while differing in their [`StorageUtilization`].
#[derive(new, Clone, Debug)]
pub struct StorageHandle {
    /// Storage id.
    pub id: StorageId,
    /// How the storage is used.
    pub utilization: StorageUtilization,
}
impl StorageHandle {
    /// Returns the size the handle is pointing to in memory.
    pub fn size(&self) -> usize {
        match self.utilization {
            StorageUtilization::Full(size) => size,
            StorageUtilization::Slice { offset: _, size } => size,
        }
    }
    /// Returns the offset in bytes from the start of the underlying allocation.
    ///
    /// # Panics
    ///
    /// Panics for [`StorageUtilization::Full`] handles, which carry no offset.
    pub fn offset(&self) -> usize {
        match self.utilization {
            StorageUtilization::Full(..) => panic!("full size slice not supported anymore"),
            StorageUtilization::Slice { offset, .. } => offset,
        }
    }
}
/// Storage types are responsible for allocating and deallocating memory.
pub trait ComputeStorage: Send {
    /// The resource associated type determines the way data is implemented and how
    /// it can be accessed by kernels.
    type Resource: Send;
    /// Returns the underlying resource for a specified storage handle
    fn get(&mut self, handle: &StorageHandle) -> Self::Resource;
    /// Allocates `size` units of memory and returns a handle to it
    fn alloc(&mut self, size: usize) -> StorageHandle;
    /// Deallocates the memory pointed by the given storage id.
    fn dealloc(&mut self, id: StorageId);
    /// Copies the memory referenced by `from` into the memory referenced by `to`.
    ///
    /// Implementations may assume both handles cover the same number of bytes.
    fn copy(&mut self, from: &StorageHandle, to: &StorageHandle);
}

View File

@ -1,150 +0,0 @@
use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use alloc::alloc::{alloc, dealloc, Layout};
use hashbrown::HashMap;
/// The bytes storage maps ids to pointers of bytes in a contiguous layout.
#[derive(Default)]
pub struct BytesStorage {
    // Owned heap allocations, keyed by their storage id.
    memory: HashMap<StorageId, AllocatedBytes>,
}
impl core::fmt::Debug for BytesStorage {
    // Only the type name is printed: the map holds raw pointers, which are not
    // meaningful in debug output.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("BytesStorage")
    }
}
/// Can send to other threads.
// SAFETY: `BytesStorage` exclusively owns its allocations (raw pointers make it
// `!Send` by default); moving the whole storage to another thread is sound.
unsafe impl Send for BytesStorage {}
// SAFETY(review): `BytesResource` is a raw pointer into a `BytesStorage`
// allocation; sending it is sound only while the owning storage keeps that
// allocation alive — confirm callers uphold this invariant.
unsafe impl Send for BytesResource {}
/// This struct is a pointer to a memory chunk or slice.
pub struct BytesResource {
    // Base pointer of the *whole* allocation; the slice offset, if any, is
    // applied by `get_exact_location_and_length`.
    ptr: *mut u8,
    // Whether the resource covers the full allocation or a sub-range.
    utilization: StorageUtilization,
}
/// This struct refers to a specific (contiguous) layout of bytes.
struct AllocatedBytes {
    ptr: *mut u8,
    // Layout used at allocation time; required again for deallocation.
    layout: Layout,
}
impl BytesResource {
    // Resolves the (pointer, length) pair this resource actually covers,
    // applying the slice offset when the resource is a sub-range.
    fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
        match self.utilization {
            StorageUtilization::Full(len) => (self.ptr, len),
            StorageUtilization::Slice { offset, size } => unsafe { (self.ptr.add(offset), size) },
        }
    }
    /// Returns the resource as a mutable slice of bytes.
    ///
    /// NOTE(review): the returned lifetime `'a` is unbounded (not tied to
    /// `&self`), so the caller is responsible for not outliving the backing
    /// allocation and for avoiding aliased mutable access.
    pub fn write<'a>(&self) -> &'a mut [u8] {
        let (ptr, len) = self.get_exact_location_and_length();
        unsafe { core::slice::from_raw_parts_mut(ptr, len) }
    }
    /// Returns the resource as an immutable slice of bytes.
    ///
    /// NOTE(review): same unbounded-lifetime caveat as [`Self::write`].
    pub fn read<'a>(&self) -> &'a [u8] {
        let (ptr, len) = self.get_exact_location_and_length();
        unsafe { core::slice::from_raw_parts(ptr, len) }
    }
}
impl ComputeStorage for BytesStorage {
    type Resource = BytesResource;
    /// Wraps the allocation's base pointer together with the handle's
    /// utilization (full chunk or slice).
    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
        let allocated_bytes = self.memory.get_mut(&handle.id).unwrap();
        BytesResource {
            ptr: allocated_bytes.ptr,
            utilization: handle.utilization.clone(),
        }
    }
    /// Allocates `size` bytes on the heap and registers them under a fresh id.
    fn alloc(&mut self, size: usize) -> StorageHandle {
        let id = StorageId::new();
        let handle = StorageHandle {
            id: id.clone(),
            utilization: StorageUtilization::Full(size),
        };
        unsafe {
            let layout = Layout::array::<u8>(size).unwrap();
            let ptr = alloc(layout);
            let memory = AllocatedBytes { ptr, layout };
            self.memory.insert(id, memory);
        }
        handle
    }
    /// Frees the allocation for `id`; unknown ids are ignored.
    fn dealloc(&mut self, id: StorageId) {
        if let Some(memory) = self.memory.remove(&id) {
            unsafe {
                dealloc(memory.ptr, memory.layout);
            }
        }
    }
    /// Copies the bytes referenced by `from` into the region referenced by `to`.
    fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) {
        assert_eq!(from.size(), to.size());
        let input = self.get(from);
        let output = self.get(to);
        // Bug fix: the previous byte loop wrote *into* the source buffer
        // (`*ptr_in = *ptr_out`) and crossed the two handles' offsets, and it
        // panicked on `Full` handles via `offset()`. `read`/`write` resolve
        // each handle's own offset, so a plain slice copy is correct for both
        // full chunks and slices. The regions must not overlap.
        output.write().copy_from_slice(input.read());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_can_alloc_and_dealloc() {
        let mut storage = BytesStorage::default();
        let handle_1 = storage.alloc(64);
        assert_eq!(handle_1.size(), 64);
        storage.dealloc(handle_1.id);
    }
    #[test]
    fn test_slices() {
        let mut storage = BytesStorage::default();
        let handle_1 = storage.alloc(64);
        // 8-byte slice at offset 24 over the same allocation.
        let handle_2 = StorageHandle::new(
            handle_1.id.clone(),
            StorageUtilization::Slice {
                offset: 24,
                size: 8,
            },
        );
        // Fill the whole chunk with 0..64, then read it back through the slice.
        storage
            .get(&handle_1)
            .write()
            .iter_mut()
            .enumerate()
            .for_each(|(i, b)| {
                *b = i as u8;
            });
        let bytes = storage.get(&handle_2).read().to_vec();
        storage.dealloc(handle_1.id);
        assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]);
    }
}

View File

@ -1,8 +0,0 @@
mod base;
pub use base::*;
#[cfg(feature = "storage-bytes")]
mod bytes_cpu;
#[cfg(feature = "storage-bytes")]
pub use bytes_cpu::*;

View File

@ -1,9 +0,0 @@
mod operation;
mod tune_benchmark;
mod tune_cache;
mod tuner;
pub use operation::*;
pub use tune_benchmark::*;
pub use tune_cache::*;
pub use tuner::*;

View File

@ -1,69 +0,0 @@
use alloc::boxed::Box;
use alloc::string::String;
use alloc::vec::Vec;
use core::fmt::{Debug, Display};
use core::hash::Hash;
/// Default checksum for an operation set.
///
/// Concatenates the names of every candidate operation and hashes the result,
/// so the checksum changes whenever the operation set changes.
#[cfg(feature = "autotune-persistent-cache")]
pub fn compute_checksum(autotunables: &[Box<dyn AutotuneOperation>]) -> String {
    let names: String = autotunables.iter().map(|op| op.name()).collect();
    format!("{:x}", md5::compute(names))
}
/// Groups operations of the same type for autotune
pub trait AutotuneOperationSet<K>: Send {
    /// The key used in the tune cache
    fn key(&self) -> K;
    /// All candidate operations for autotuning this operation type
    /// Operations can run on toy tensors of relevant size
    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>>;
    /// Returns the operation for the given index, matching the order
    /// returned by autotunables. Operation obtained here runs on original tensors
    fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation>;
    /// Compute a checksum that can invalidate outdated cached auto-tune results.
    #[cfg(feature = "autotune-persistent-cache")]
    fn compute_checksum(&self) -> String {
        compute_checksum(&self.autotunables())
    }
}
/// Contains operation to run and inputs on which to run it
pub trait AutotuneOperation {
    /// Runs the operation
    fn execute(self: Box<Self>);
    /// The name of the operation.
    ///
    /// Defaults to the concrete type's full path name.
    fn name(&self) -> &str {
        core::any::type_name::<Self>()
    }
    /// Clones the operation and inputs
    fn clone(&self) -> Box<dyn AutotuneOperation>;
}
#[cfg(feature = "autotune-persistent-cache")]
/// Trait alias with support for persistent caching
///
/// The serde bounds allow keys to be written to and read back from disk.
pub trait AutotuneKey:
    Clone
    + Debug
    + PartialEq
    + Eq
    + Hash
    + Display
    + serde::Serialize
    + serde::de::DeserializeOwned
    + Send
    + Sync
{
}
#[cfg(not(feature = "autotune-persistent-cache"))]
/// Trait alias
pub trait AutotuneKey: Clone + Debug + PartialEq + Eq + Hash + Display {}
impl AutotuneKey for String {}

View File

@ -1,48 +0,0 @@
use burn_common::benchmark::Benchmark;
use burn_common::sync_type::SyncType;
use crate::channel::ComputeChannel;
use crate::client::ComputeClient;
use crate::server::ComputeServer;
use super::AutotuneOperation;
use alloc::boxed::Box;
use alloc::string::{String, ToString};
/// A benchmark that runs on server handles
#[derive(new)]
pub struct TuneBenchmark<S: ComputeServer, C> {
    // The candidate operation being timed.
    operation: Box<dyn AutotuneOperation>,
    // Client used to synchronize with the server between samples.
    client: ComputeClient<S, C>,
}
// `Box<dyn AutotuneOperation>` cannot derive `Clone`; delegate to the trait's
// own object-safe `clone` method.
impl Clone for Box<dyn AutotuneOperation> {
    fn clone(&self) -> Self {
        self.as_ref().clone()
    }
}
impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
    type Args = Box<dyn AutotuneOperation>;
    // Each sample consumes the operation, so a fresh clone is prepared per run.
    fn prepare(&self) -> Self::Args {
        self.operation.clone()
    }
    fn num_samples(&self) -> usize {
        10
    }
    fn execute(&self, operation: Self::Args) {
        AutotuneOperation::execute(operation);
    }
    fn name(&self) -> String {
        "autotune".to_string()
    }
    fn sync(&self) {
        // For benchmarks - we need to wait for all tasks to complete before returning.
        self.client.sync(SyncType::Wait);
    }
}

View File

@ -1,243 +0,0 @@
#[cfg(feature = "autotune-persistent-cache")]
mod std_imports {
pub use std::fs;
pub use std::fs::File;
pub use std::io;
pub use std::path::Path;
pub use std::path::PathBuf;
}
#[cfg(feature = "autotune-persistent-cache")]
use std_imports::*;
#[cfg(feature = "autotune-persistent-cache")]
use serde::{Deserialize, Serialize};
use super::AutotuneKey;
use super::AutotuneOperation;
use super::AutotuneOperationSet;
use alloc::boxed::Box;
use hashbrown::HashMap;
#[cfg(feature = "autotune-persistent-cache")]
/// Return the file path for the persistent cache on disk.
///
/// `prefix` should be the device id computed at the backend level; it keeps the
/// cache files of different devices apart.
///
/// # Panics
///
/// Panics when no home directory can be determined for the current user.
pub fn get_persistent_cache_file_path(prefix: &str) -> PathBuf {
    let home_dir = dirs::home_dir().expect("An home directory should exist");
    // `join` chains directly on the `PathBuf`; the previous `Path::new(&path_dir)`
    // round-trip was redundant.
    home_dir
        .join(".cache")
        .join("burn")
        .join("autotune")
        .join(format!("{}-autotune-cache.json", prefix))
}
/// In-memory cache entry
#[derive(Debug)]
pub(crate) struct InMemoryCacheEntry {
    // Whether the persistent checksum for this key has already been validated;
    // entries loaded from disk start unvalidated (see `TuneCache::load`).
    #[cfg(feature = "autotune-persistent-cache")]
    checksum_checked: bool,
    // Index of the fastest operation within the set's `autotunables()` order.
    fastest_index: usize,
}
/// Persistent cache entry
#[cfg(feature = "autotune-persistent-cache")]
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct PersistentCacheEntry {
    // Checksum of the operation set at the time the result was recorded.
    checksum: String,
    fastest_index: usize,
}
/// Use to find and reuse the best kernel for some input
#[derive(Debug)]
pub(crate) struct TuneCache<K> {
    in_memory_cache: HashMap<K, InMemoryCacheEntry>,
    #[cfg(feature = "autotune-persistent-cache")]
    persistent_cache: HashMap<K, PersistentCacheEntry>,
    // Device id and name together determine the on-disk cache file path.
    #[cfg(feature = "autotune-persistent-cache")]
    device_id: String,
    #[cfg(feature = "autotune-persistent-cache")]
    name: String,
}
/// Result of the cache try
pub enum TuneCacheResult<K> {
    /// An operation is found and given
    Hit(Box<dyn AutotuneOperation>),
    /// No operation is found and the set is given back for ownership
    Miss(Box<dyn AutotuneOperationSet<K>>),
}
impl<K: AutotuneKey> TuneCache<K> {
    /// Creates a cache, loading any persisted results from disk when the
    /// `autotune-persistent-cache` feature is enabled.
    pub(crate) fn new(
        #[cfg_attr(not(feature = "autotune-persistent-cache"), allow(unused_variables))] name: &str,
        #[cfg_attr(not(feature = "autotune-persistent-cache"), allow(unused_variables))]
        device_id: &str,
    ) -> Self {
        #[cfg(feature = "autotune-persistent-cache")]
        {
            let mut cache = TuneCache {
                in_memory_cache: HashMap::new(),
                persistent_cache: HashMap::new(),
                device_id: device_id.to_string(),
                name: name.to_string(),
            };
            // A failed load is not fatal: the cache simply starts empty.
            if let Err(e) = cache.load() {
                log::warn!(
                    "Unable to load autotune cache. Cache will be ignored ({}).",
                    e
                );
            }
            cache
        }
        #[cfg(not(feature = "autotune-persistent-cache"))]
        {
            TuneCache {
                in_memory_cache: HashMap::new(),
            }
        }
    }
    /// Returns the cached fastest index for `key`, but only once its checksum
    /// has been validated (entries loaded from disk start unvalidated).
    pub(crate) fn find_fastest(&self, key: &K) -> Option<usize> {
        let val = self.in_memory_cache.get(key)?;
        #[cfg(feature = "autotune-persistent-cache")]
        if val.checksum_checked {
            Some(val.fastest_index)
        } else {
            None
        }
        #[cfg(not(feature = "autotune-persistent-cache"))]
        Some(val.fastest_index)
    }
    /// Looks up the fastest operation for the set's key, validating the
    /// persisted checksum lazily on first access; returns the set back on miss.
    pub(crate) fn try_cache(
        &mut self,
        autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
    ) -> TuneCacheResult<K> {
        let key = autotune_operation_set.key();
        let result = self.in_memory_cache.get_mut(&key);
        #[cfg(feature = "autotune-persistent-cache")]
        {
            if let Some(InMemoryCacheEntry {
                checksum_checked,
                fastest_index,
            }) = result
            {
                if !*checksum_checked {
                    // A checksum mismatch means the operation set changed since
                    // the result was persisted: treat the entry as stale.
                    let checksum = autotune_operation_set.compute_checksum();
                    let persistent_entry = self
                        .persistent_cache
                        .get(&key)
                        .expect("Both caches should be in sync");
                    if checksum != persistent_entry.checksum {
                        return TuneCacheResult::Miss(autotune_operation_set);
                    }
                    *checksum_checked = true;
                }
                return TuneCacheResult::Hit(autotune_operation_set.fastest(*fastest_index));
            }
        }
        #[cfg(not(feature = "autotune-persistent-cache"))]
        {
            if let Some(InMemoryCacheEntry { fastest_index, .. }) = result {
                return TuneCacheResult::Hit(autotune_operation_set.fastest(*fastest_index));
            }
        }
        TuneCacheResult::Miss(autotune_operation_set)
    }
    /// Records a freshly benchmarked result; it is checksum-valid by construction.
    pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
        self.in_memory_cache.insert(
            key,
            InMemoryCacheEntry {
                #[cfg(feature = "autotune-persistent-cache")]
                checksum_checked: true,
                fastest_index,
            },
        );
    }
    /// Records a result in the persistent map; call [`Self::save`] to flush it to disk.
    #[cfg(feature = "autotune-persistent-cache")]
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
    ) {
        self.persistent_cache.insert(
            key,
            PersistentCacheEntry {
                checksum,
                fastest_index,
            },
        );
    }
    /// Load the persistent cache data from disk
    #[cfg(feature = "autotune-persistent-cache")]
    pub(crate) fn load(&mut self) -> Result<(), io::Error> {
        let file_path = self.get_persistent_cache_file_path();
        // note: reading file from memory is faster than using
        // serde from_reader with a buffered reader
        // see issue:
        // https://github.com/serde-rs/json/issues/160
        match fs::read_to_string(file_path) {
            Ok(data) => {
                let data: Vec<(K, PersistentCacheEntry)> = serde_json::from_str(&data)?;
                for (key, value) in data.into_iter() {
                    self.persistent_cache.insert(key, value);
                }
                Ok(())
            }
            Err(e) => {
                // A missing cache file is the normal first-run state.
                if e.kind() == std::io::ErrorKind::NotFound {
                    Ok(())
                } else {
                    Err(e)
                }
            }
        }?;
        // Mirror loaded entries into memory, unvalidated: their checksums are
        // verified lazily on first `try_cache` access.
        for (key, entry) in self.persistent_cache.iter() {
            self.in_memory_cache.insert(
                key.clone(),
                InMemoryCacheEntry {
                    checksum_checked: false,
                    fastest_index: entry.fastest_index,
                },
            );
        }
        Ok(())
    }
    /// Save the persistent cache on disk
    #[cfg(feature = "autotune-persistent-cache")]
    pub(crate) fn save(&self) {
        let file_path = self.get_persistent_cache_file_path();
        if let Some(parent_dir) = file_path.parent() {
            if !parent_dir.exists() {
                fs::create_dir_all(parent_dir).unwrap_or_else(|_| {
                    panic!(
                        "Should be able to create directory '{}' for autotune persistent cache file",
                        parent_dir.to_str().unwrap())
                });
            }
        }
        let file = File::create(file_path.clone()).unwrap_or_else(|_| {
            panic!(
                "Should be able to open autotune persistent cache file '{}'",
                file_path.to_str().unwrap()
            )
        });
        let data = self.persistent_cache.iter().collect::<Vec<_>>();
        serde_json::to_writer_pretty(file, &data)
            .expect("Should be able to write to autotune persistent cache");
    }
    /// Return the file path for the persistent cache on disk
    #[cfg(feature = "autotune-persistent-cache")]
    pub fn get_persistent_cache_file_path(&self) -> PathBuf {
        get_persistent_cache_file_path(&format!("{}-{}", self.name, self.device_id))
    }
}

View File

@ -1,124 +0,0 @@
#[cfg(target_family = "wasm")]
use web_time::Duration;
#[cfg(not(target_family = "wasm"))]
use core::time::Duration;
use alloc::boxed::Box;
use alloc::string::ToString;
use alloc::vec::Vec;
use burn_common::benchmark::{Benchmark, BenchmarkComputations, BenchmarkDurations};
use crate::channel::ComputeChannel;
use crate::client::ComputeClient;
use crate::server::ComputeServer;
use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCache};
use super::AutotuneKey;
#[derive(Debug)]
/// Executes autotune benchmarking and caching
pub struct Tuner<K: AutotuneKey> {
    // Holds the in-memory (and optionally persistent) results per autotune key.
    tune_cache: TuneCache<K>,
}
#[allow(clippy::new_without_default)]
impl<K: AutotuneKey> Tuner<K> {
    /// Returns a tuner with cache initialized from persistent cache
    pub fn new(name: &str, device_id: &str) -> Self {
        Self {
            tune_cache: TuneCache::new(name, device_id),
        }
    }
    /// Fetch the fastest autotune operation index for an autotune key.
    pub fn autotune_fastest(&self, key: &K) -> Option<usize> {
        self.tune_cache.find_fastest(key)
    }
    /// Execute the fastest autotune operation if known, otherwise perform some benchmarks before.
    pub fn execute_autotune<S, C>(
        &mut self,
        autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
        client: &ComputeClient<S, C>,
    ) where
        S: ComputeServer,
        C: ComputeChannel<S>,
    {
        let operation = match self.tune_cache.try_cache(autotune_operation_set) {
            super::TuneCacheResult::Hit(ops) => ops,
            super::TuneCacheResult::Miss(set) => self.autotuning(set, client),
        };
        AutotuneOperation::execute(operation);
    }
    /// Benchmarks every candidate in the set, caches the winner (also
    /// persistently when enabled) and returns the winning operation.
    fn autotuning<S, C>(
        &mut self,
        autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
        client: &ComputeClient<S, C>,
    ) -> Box<dyn AutotuneOperation>
    where
        S: ComputeServer,
        C: ComputeChannel<S>,
    {
        let key = autotune_operation_set.key();
        let autotunables = autotune_operation_set.autotunables();
        let mut names = Vec::with_capacity(autotunables.len());
        let results: Vec<BenchmarkDurations> = autotunables
            .into_iter()
            .map(|op| {
                names.push(op.name().to_string());
                self.run_benchmark(op, client)
            })
            .collect();
        // Finds the fastest operation, stores it and returns it
        let fastest_index = self.find_fastest(results);
        let fastest_name = names.get(fastest_index).unwrap();
        log::info!("Fastest result {fastest_name}-{key}");
        self.tune_cache.cache_insert(key.clone(), fastest_index);
        #[cfg(feature = "autotune-persistent-cache")]
        {
            let checksum = autotune_operation_set.compute_checksum();
            self.tune_cache
                .persistent_cache_insert(key, checksum, fastest_index);
            self.tune_cache.save();
        }
        // Fetch the winner back through the cache we just populated.
        match self.tune_cache.try_cache(autotune_operation_set) {
            super::TuneCacheResult::Hit(ops) => ops,
            super::TuneCacheResult::Miss(_) => panic!("We just inserted, should not miss"),
        }
    }
    /// Times a single candidate operation through the benchmark harness.
    fn run_benchmark<S, C>(
        &mut self,
        operation: Box<dyn AutotuneOperation>,
        client: &ComputeClient<S, C>,
    ) -> BenchmarkDurations
    where
        S: ComputeServer,
        C: ComputeChannel<S>,
    {
        TuneBenchmark::new(operation, client.clone()).run()
    }
    /// Returns the index of the result with the smallest median duration.
    fn find_fastest(&self, results: Vec<BenchmarkDurations>) -> usize {
        let mut smallest_duration = Duration::MAX;
        let mut fastest_tunable = None;
        for (i, result) in results.into_iter().enumerate() {
            let computed = BenchmarkComputations::new(&result);
            if computed.median < smallest_duration {
                smallest_duration = computed.median;
                fastest_tunable = Some(i);
            }
        }
        fastest_tunable.expect("At least one kernel needed. ")
    }
}

View File

@ -1,37 +0,0 @@
use std::sync::Arc;
use super::DummyServer;
use burn_common::stub::RwLock;
use burn_compute::channel::MutexComputeChannel;
use burn_compute::client::ComputeClient;
use burn_compute::memory_management::simple::{
DeallocStrategy, SimpleMemoryManagement, SliceStrategy,
};
use burn_compute::storage::BytesStorage;
use burn_compute::tune::Tuner;
use burn_compute::ComputeRuntime;
/// The dummy device.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct DummyDevice;
// Channel/client aliases used throughout the dummy test setup.
pub type DummyChannel = MutexComputeChannel<DummyServer>;
pub type DummyClient = ComputeClient<DummyServer, DummyChannel>;
// Process-wide runtime mapping devices to their lazily created clients.
static RUNTIME: ComputeRuntime<DummyDevice, DummyServer, DummyChannel> = ComputeRuntime::new();
// Identifiers used by the autotune tests to locate the persistent cache.
pub static TUNER_DEVICE_ID: &str = "tests/dummy-device";
pub static TUNER_PREFIX: &str = "dummy-tests/dummy-device";
/// Builds a fresh dummy client: bytes storage with simple memory management
/// (no deallocation, no slices), wrapped in a mutex channel, plus a new tuner.
pub fn init_client() -> ComputeClient<DummyServer, MutexComputeChannel<DummyServer>> {
    let memory_management = SimpleMemoryManagement::new(
        BytesStorage::default(),
        DeallocStrategy::Never,
        SliceStrategy::Never,
    );
    let channel = MutexComputeChannel::new(DummyServer::new(memory_management));
    let tuner = Tuner::new("dummy", TUNER_DEVICE_ID);
    ComputeClient::new(channel, Arc::new(RwLock::new(tuner)), Arc::new(()))
}
/// Returns the compute client for the given device, creating it through
/// [`init_client`] on first use.
pub fn client(device: &DummyDevice) -> DummyClient {
    RUNTIME.client(device, init_client)
}

View File

@ -1,25 +0,0 @@
use burn_compute::storage::BytesResource;
/// The DummyKernel trait should be implemented for every supported operation
pub trait DummyKernel: Sync + Send {
    // Runs the kernel over the bound resources; the kernel decides which
    // resources it reads and which it writes.
    fn compute(&self, resources: &mut [BytesResource]);
}
/// Contains the algorithm for element-wise addition
pub struct DummyElementwiseAddition;
impl DummyKernel for DummyElementwiseAddition {
    /// Element-wise `out[i] = lhs[i] + rhs[i]` over the three bound resources.
    fn compute(&self, resources: &mut [BytesResource]) {
        // The kernel itself decides which resources it reads and which it
        // writes: slots 0 and 1 are read-only operands, slot 2 is the output.
        let lhs = resources[0].read();
        let rhs = resources[1].read();
        let out = resources[2].write();
        for index in 0..lhs.len() {
            out[index] = lhs[index] + rhs[index];
        }
    }
}

View File

@ -1,9 +0,0 @@
mod compute;
mod kernel;
mod server;
mod tune;
pub use compute::*;
pub use kernel::*;
pub use server::*;
pub use tune::*;

View File

@ -1,74 +0,0 @@
use std::sync::Arc;
use burn_common::{reader::reader_from_concrete, sync_type::SyncType};
use burn_compute::{
memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement},
server::{Binding, ComputeServer, Handle},
storage::{BytesResource, BytesStorage},
};
use derive_new::new;
use super::DummyKernel;
/// The dummy server is used to test the burn-compute infrastructure.
/// It uses simple memory management with a bytes storage on CPU, without asynchronous tasks.
#[derive(new, Debug)]
pub struct DummyServer<MM = SimpleMemoryManagement<BytesStorage>> {
    // Owns all buffers handed out through `create`/`empty`.
    memory_management: MM,
}
impl<MM> ComputeServer for DummyServer<MM>
where
    MM: MemoryManagement<BytesStorage>,
{
    type DispatchOptions = ();
    type Kernel = Arc<dyn DummyKernel>;
    type Storage = BytesStorage;
    type MemoryManagement = MM;
    type AutotuneKey = String;
    type FeatureSet = ();
    /// Read a buffer back as an (already resolved) reader over a copy of its bytes.
    fn read(&mut self, binding: Binding<Self>) -> burn_common::reader::Reader {
        let bytes = self.memory_management.get(binding.memory);
        reader_from_concrete(bytes.read().to_vec())
    }
    /// Resolve a binding to its underlying bytes resource.
    fn get_resource(&mut self, binding: Binding<Self>) -> BytesResource {
        self.memory_management.get(binding.memory)
    }
    /// Allocate a buffer and copy `data` into it, byte by byte.
    fn create(&mut self, data: &[u8]) -> Handle<Self> {
        let handle = self.memory_management.reserve(data.len(), || {});
        let resource = self.memory_management.get(handle.clone().binding());
        let bytes = resource.write();
        for (i, val) in data.iter().enumerate() {
            bytes[i] = *val;
        }
        Handle::new(handle)
    }
    /// Allocate an uninitialized buffer of `size` bytes.
    fn empty(&mut self, size: usize) -> Handle<Self> {
        Handle::new(self.memory_management.reserve(size, || {}))
    }
    /// Resolve every binding, then run the kernel synchronously on the resources.
    fn execute(
        &mut self,
        kernel: Self::Kernel,
        _count: Self::DispatchOptions,
        bindings: Vec<Binding<Self>>,
    ) {
        let mut resources = bindings
            .into_iter()
            .map(|binding| self.memory_management.get(binding.memory))
            .collect::<Vec<_>>();
        kernel.compute(&mut resources);
    }
    fn sync(&mut self, _: SyncType) {
        // Nothing to do with dummy backend.
    }
}

View File

@ -1,32 +0,0 @@
use std::sync::Arc;
use burn_compute::{client::ComputeClient, server::Binding, tune::AutotuneOperation};
use derive_new::new;
use crate::dummy::{DummyChannel, DummyKernel, DummyServer};
#[derive(new)]
/// Extended kernel that accounts for additional parameters, i.e. needed
/// information that does not count as an input/output.
pub struct OneKernelAutotuneOperation {
    // The single kernel this operation runs.
    kernel: Arc<dyn DummyKernel>,
    client: ComputeClient<DummyServer, DummyChannel>,
    // Kept alongside the bindings; not read here (presumably used by key
    // generation in the owning operation set — confirm with callers).
    shapes: Vec<Vec<usize>>,
    bindings: Vec<Binding<DummyServer>>,
}
impl AutotuneOperation for OneKernelAutotuneOperation {
    /// Executes the operation on given bindings and server, with the additional parameters
    fn execute(self: Box<Self>) {
        self.client.execute(self.kernel.clone(), (), self.bindings);
    }
    /// Deep-clone into a boxed trait object (required by the autotune API).
    fn clone(&self) -> Box<dyn AutotuneOperation> {
        Box::new(Self {
            kernel: self.kernel.clone(),
            client: self.client.clone(),
            shapes: self.shapes.clone(),
            bindings: self.bindings.clone(),
        })
    }
}

View File

@ -1,106 +0,0 @@
use std::{thread::sleep, time::Duration};
use burn_compute::storage::BytesResource;
use crate::dummy::DummyKernel;
// Artificial latency injected per element by the deliberately slow kernels.
const SLEEP_MS: u64 = 1;
/// Slow AND wrong addition variant (copies lhs); autotune must reject it.
pub struct DummyElementwiseAdditionSlowWrong;
/// Correct element-wise multiplication.
pub struct DummyElementwiseMultiplication;
/// Slow AND wrong multiplication variant (copies lhs); autotune must reject it.
pub struct DummyElementwiseMultiplicationSlowWrong;
/// Cache-test kernel: fast when the input has exactly 3 elements, slow otherwise.
pub struct CacheTestFastOn3;
/// Cache-test kernel: slow when the input has exactly 3 elements, fast otherwise.
pub struct CacheTestSlowOn3;
/// Kernel reading an extra "info" buffer as a fourth binding.
pub struct ParameteredKernel;
impl DummyKernel for DummyElementwiseAdditionSlowWrong {
    fn compute(&self, inputs: &mut [BytesResource]) {
        // Slow and wrong on purpose, for tests
        let lhs = &inputs[0].read();
        let out = &mut inputs[2].write();
        let size = lhs.len();
        for i in 0..size {
            sleep(Duration::from_millis(SLEEP_MS));
            out[i] = lhs[i]
        }
    }
}
impl DummyKernel for DummyElementwiseMultiplication {
    fn compute(&self, inputs: &mut [BytesResource]) {
        let lhs = &inputs[0].read();
        let rhs = &inputs[1].read();
        let out = &mut inputs[2].write();
        let size = lhs.len();
        for i in 0..size {
            out[i] = lhs[i] * rhs[i];
        }
    }
}
impl DummyKernel for DummyElementwiseMultiplicationSlowWrong {
    fn compute(&self, inputs: &mut [BytesResource]) {
        // Slow and wrong on purpose, for tests
        let lhs = &inputs[0].read();
        let out = &mut inputs[2].write();
        let size = lhs.len();
        for i in 0..size {
            sleep(Duration::from_millis(SLEEP_MS));
            out[i] = lhs[i];
        }
    }
}
impl DummyKernel for CacheTestFastOn3 {
    fn compute(&self, inputs: &mut [BytesResource]) {
        // This is an artificial kernel designed for testing cache only
        let lhs = &inputs[0].read();
        let out = &mut inputs[2].write();
        let size = lhs.len();
        // Copies lhs: fast path on size 3, per-element sleep otherwise.
        if size == 3 {
            out[..size].copy_from_slice(&lhs[..size]);
        } else {
            for i in 0..size {
                sleep(Duration::from_millis(SLEEP_MS));
                out[i] = lhs[i];
            }
        }
    }
}
impl DummyKernel for CacheTestSlowOn3 {
    fn compute(&self, inputs: &mut [BytesResource]) {
        // This is an artificial kernel designed for testing cache only
        let lhs = &inputs[0].read();
        let rhs = &inputs[1].read();
        let out = &mut inputs[2].write();
        let size = lhs.len();
        // Copies rhs: per-element sleep on size 3, fast path otherwise.
        if size == 3 {
            for i in 0..size {
                sleep(Duration::from_millis(SLEEP_MS));
                out[i] = rhs[i];
            }
        } else {
            out[..size].copy_from_slice(&rhs[..size]);
        }
    }
}
impl DummyKernel for ParameteredKernel {
    fn compute(&self, inputs: &mut [BytesResource]) {
        // This is an artificial kernel designed for info buffer
        let lhs = &inputs[0].read();
        let rhs = &inputs[1].read();
        let out = &mut inputs[2].write();
        // inputs[3] is the info buffer; only its first byte is used.
        let info = &inputs[3].read();
        for i in 0..lhs.len() {
            out[i] = lhs[i] + rhs[i] + info[0];
        }
    }
}

View File

@ -1,8 +0,0 @@
mod autotune_operations;
mod kernels;
mod operation_sets;
pub use autotune_operations::*;
pub use kernels::*;
#[allow(unused)]
pub use operation_sets::*;

View File

@ -1,194 +0,0 @@
#[cfg(feature = "autotune-persistent-cache")]
use rand::{distributions::Alphanumeric, Rng};
use std::sync::Arc;
#[cfg(feature = "autotune-persistent-cache")]
use burn_compute::tune::compute_checksum;
use burn_compute::{
server::Binding,
tune::{AutotuneOperation, AutotuneOperationSet},
};
use crate::dummy::{
CacheTestFastOn3, CacheTestSlowOn3, DummyClient, DummyElementwiseAddition,
DummyElementwiseMultiplication, DummyElementwiseMultiplicationSlowWrong, DummyServer,
OneKernelAutotuneOperation,
};
use super::DummyElementwiseAdditionSlowWrong;
/// Operation set pitting the correct addition kernel against the slow/wrong one.
pub struct AdditionAutotuneOperationSet {
    client: DummyClient,
    // Key is derived from the shapes at construction time (see `log_shape_input_key`).
    key: String,
    shapes: Vec<Vec<usize>>,
    bindings: Vec<Binding<DummyServer>>,
}
impl AdditionAutotuneOperationSet {
    #[allow(dead_code)]
    pub fn new(
        client: DummyClient,
        shapes: Vec<Vec<usize>>,
        bindings: Vec<Binding<DummyServer>>,
    ) -> Self {
        Self {
            client,
            key: format!("{}-{}", "add", log_shape_input_key(&shapes)),
            shapes,
            bindings,
        }
    }
}
impl AutotuneOperationSet<String> for AdditionAutotuneOperationSet {
    fn key(&self) -> String {
        self.key.clone()
    }
    /// Candidates: index 0 is correct/fast, index 1 is slow and wrong.
    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
        vec![
            Box::new(OneKernelAutotuneOperation::new(
                Arc::new(DummyElementwiseAddition),
                self.client.clone(),
                self.shapes.clone(),
                self.bindings.clone(),
            )),
            Box::new(OneKernelAutotuneOperation::new(
                Arc::new(DummyElementwiseAdditionSlowWrong),
                self.client.clone(),
                self.shapes.clone(),
                self.bindings.clone(),
            )),
        ]
    }
    fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation> {
        self.autotunables()[fastest_index].clone()
    }
}
/// Operation set pitting the correct multiplication kernel against the
/// slow/wrong one (note: the wrong kernel is listed first).
pub struct MultiplicationAutotuneOperationSet {
    client: DummyClient,
    key: String,
    shapes: Vec<Vec<usize>>,
    bindings: Vec<Binding<DummyServer>>,
}
impl MultiplicationAutotuneOperationSet {
    #[allow(dead_code)]
    pub fn new(
        client: DummyClient,
        shapes: Vec<Vec<usize>>,
        bindings: Vec<Binding<DummyServer>>,
    ) -> Self {
        Self {
            client,
            key: format!("{}-{}", "mul", log_shape_input_key(&shapes)),
            shapes,
            bindings,
        }
    }
}
impl AutotuneOperationSet<String> for MultiplicationAutotuneOperationSet {
    fn key(&self) -> String {
        self.key.clone()
    }
    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
        vec![
            Box::new(OneKernelAutotuneOperation::new(
                Arc::new(DummyElementwiseMultiplicationSlowWrong),
                self.client.clone(),
                self.shapes.clone(),
                self.bindings.clone(),
            )),
            Box::new(OneKernelAutotuneOperation::new(
                Arc::new(DummyElementwiseMultiplication),
                self.client.clone(),
                self.shapes.clone(),
                self.bindings.clone(),
            )),
        ]
    }
    fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation> {
        self.autotunables()[fastest_index].clone()
    }
}
/// Operation set whose two kernels invert their fast/slow behavior around
/// size 3; used to exercise cache hits/misses and checksum invalidation.
pub struct CacheTestAutotuneOperationSet {
    client: DummyClient,
    key: String,
    shapes: Vec<Vec<usize>>,
    bindings: Vec<Binding<DummyServer>>,
    // When true, `compute_checksum` returns a random string so the persistent
    // cache entry never validates (forces a re-tune).
    pub generate_random_checksum: bool,
}
impl CacheTestAutotuneOperationSet {
    #[allow(dead_code)]
    pub fn new(
        client: DummyClient,
        shapes: Vec<Vec<usize>>,
        bindings: Vec<Binding<DummyServer>>,
    ) -> Self {
        Self {
            client,
            key: format!("{}-{}", "cache_test", log_shape_input_key(&shapes)),
            shapes,
            bindings,
            generate_random_checksum: false,
        }
    }
}
impl AutotuneOperationSet<String> for CacheTestAutotuneOperationSet {
    fn key(&self) -> String {
        self.key.clone()
    }
    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
        vec![
            Box::new(OneKernelAutotuneOperation::new(
                Arc::new(CacheTestFastOn3),
                self.client.clone(),
                self.shapes.clone(),
                self.bindings.clone(),
            )),
            Box::new(OneKernelAutotuneOperation::new(
                Arc::new(CacheTestSlowOn3),
                self.client.clone(),
                self.shapes.clone(),
                self.bindings.clone(),
            )),
        ]
    }
    fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation> {
        self.autotunables()[fastest_index].clone()
    }
    #[cfg(feature = "std")]
    fn compute_checksum(&self) -> String {
        if self.generate_random_checksum {
            // 16 random alphanumeric chars: guaranteed to mismatch the stored one.
            let rand_string: String = rand::thread_rng()
                .sample_iter(&Alphanumeric)
                .take(16)
                .map(char::from)
                .collect();
            rand_string
        } else {
            compute_checksum(&self.autotunables())
        }
    }
}
/// Derive an autotune-key fragment from the input shapes.
///
/// Only the first shape is considered; each of its dimensions is rounded up
/// to the next power of two so that similarly-sized inputs share a key.
/// The result is one rounded dimension per comma, e.g. `"1,4,"`.
pub fn log_shape_input_key(shapes: &[Vec<usize>]) -> String {
    shapes[0]
        .iter()
        .map(|&size| {
            let exp = f32::ceil(f32::log2(size as f32)) as u32;
            format!("{},", 2_u32.pow(exp))
        })
        .collect()
}

View File

@ -1,292 +0,0 @@
mod dummy;
use std::sync::Arc;
use crate::dummy::{client, DummyDevice, DummyElementwiseAddition};
use burn_compute::ComputeRuntime;
#[allow(unused)]
use serial_test::serial;
/// Reading back a freshly created buffer must return the original bytes.
#[test]
fn created_resource_is_the_same_when_read() {
    let client = client(&DummyDevice);
    let resource = Vec::from([0, 1, 2]);
    let resource_description = client.create(&resource);
    let obtained_resource = client.read(resource_description.binding());
    assert_eq!(resource, obtained_resource)
}
/// `empty` must allocate a readable buffer of the requested size.
#[test]
fn empty_allocates_memory() {
    let client = client(&DummyDevice);
    let size = 4;
    let resource_description = client.empty(size);
    let empty_resource = client.read(resource_description.binding());
    assert_eq!(empty_resource.len(), 4);
}
/// End-to-end execution of the element-wise addition kernel.
#[test]
fn execute_elementwise_addition() {
    let client = client(&DummyDevice);
    let lhs = client.create(&[0, 1, 2]);
    let rhs = client.create(&[4, 4, 4]);
    let out = client.empty(3);
    client.execute(
        Arc::new(DummyElementwiseAddition),
        (),
        vec![lhs.binding(), rhs.binding(), out.clone().binding()],
    );
    let obtained_resource = client.read(out.binding());
    assert_eq!(obtained_resource, Vec::from([4, 5, 6]))
}
/// Autotune must pick the fast & correct addition kernel, not the slow/wrong one.
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_basic_addition_execution() {
    let client = client(&DummyDevice);
    let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let lhs = client.create(&[0, 1, 2]);
    let rhs = client.create(&[4, 4, 4]);
    let out = client.empty(3);
    let handles = vec![lhs.binding(), rhs.binding(), out.clone().binding()];
    let addition_autotune_kernel =
        dummy::AdditionAutotuneOperationSet::new(client.clone(), shapes, handles);
    client.autotune_execute(Box::new(addition_autotune_kernel));
    let obtained_resource = client.read(out.binding());
    // If slow kernel was selected it would output [0, 1, 2]
    assert_eq!(obtained_resource, Vec::from([4, 5, 6]));
}
/// Autotune must pick the fast & correct multiplication kernel.
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_basic_multiplication_execution() {
    let client = client(&DummyDevice);
    let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let lhs = client.create(&[0, 1, 2]);
    let rhs = client.create(&[4, 4, 4]);
    let out = client.empty(3);
    let handles = vec![lhs.binding(), rhs.binding(), out.clone().binding()];
    let multiplication_autotune_kernel =
        dummy::MultiplicationAutotuneOperationSet::new(client.clone(), shapes, handles);
    client.autotune_execute(Box::new(multiplication_autotune_kernel));
    let obtained_resource = client.read(out.binding());
    // If slow kernel was selected it would output [0, 1, 2]
    assert_eq!(obtained_resource, Vec::from([0, 4, 8]));
}
/// Two operation sets mapping to the same key: the second run must reuse the
/// kernel tuned for the first (a cache hit), even though it is slower there.
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_same_key_return_a_cache_hit() {
    type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
    let runtime = Runtime::new();
    let client = runtime.client(&DummyDevice, dummy::init_client);
    // note: the key name depends on the shapes of the operation set
    // see log_shape_input_key for more info.
    // in this test both shapes [1,3] and [1,4] end up with the same key name
    // which is 'cache_test-1,4'
    let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let lhs_1 = client.create(&[0, 1, 2]);
    let rhs_1 = client.create(&[4, 4, 4]);
    let out_1 = client.empty(3);
    let handles_1 = vec![lhs_1.binding(), rhs_1.binding(), out_1.binding()];
    let shapes_2 = vec![vec![1, 4], vec![1, 4], vec![1, 4]];
    let lhs_2 = client.create(&[0, 1, 2, 3]);
    let rhs_2 = client.create(&[5, 6, 7, 8]);
    let out_2 = client.empty(4);
    let handles_2 = vec![lhs_2.binding(), rhs_2.binding(), out_2.clone().binding()];
    let cache_test_autotune_kernel_1 =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1);
    let cache_test_autotune_kernel_2 =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2);
    client.autotune_execute(Box::new(cache_test_autotune_kernel_1));
    client.autotune_execute(Box::new(cache_test_autotune_kernel_2));
    let obtained_resource = client.read(out_2.binding());
    // Cache should be hit, so CacheTestFastOn3 should be used, returning lhs
    assert_eq!(obtained_resource, Vec::from([0, 1, 2, 3]));
}
/// With the on-disk cache removed, different keys must each be re-tuned
/// (cache miss) and pick the kernel that is fastest for their own size.
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_no_cache_on_disk_return_a_cache_miss() {
    // delete the cache file
    let file_path = burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_PREFIX);
    let _ = std::fs::remove_file(file_path);
    type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
    let compute = Runtime::new();
    let client = compute.client(&DummyDevice, dummy::init_client);
    // in this test shapes [1,3] and [1,5] ends up with different key names
    // which are 'cache_test-1,4' and 'cache_test-1,8'
    let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let lhs_1 = client.create(&[0, 1, 2]);
    let rhs_1 = client.create(&[4, 4, 4]);
    let out_1 = client.empty(3);
    let handles_1 = vec![lhs_1.binding(), rhs_1.binding(), out_1.binding()];
    let shapes_2 = vec![vec![1, 5], vec![1, 5], vec![1, 5]];
    let lhs_2 = client.create(&[0, 1, 2, 3, 4]);
    let rhs_2 = client.create(&[5, 6, 7, 8, 9]);
    let out_2 = client.empty(5);
    let handles_2 = vec![lhs_2.binding(), rhs_2.binding(), out_2.clone().binding()];
    let cache_test_autotune_kernel_1 =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1);
    let cache_test_autotune_kernel_2 =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2);
    client.autotune_execute(Box::new(cache_test_autotune_kernel_1));
    client.autotune_execute(Box::new(cache_test_autotune_kernel_2));
    // read the resource which should update the cache on disk
    let obtained_resource = client.read(out_2.binding());
    // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs
    assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8, 9]));
}
/// The persistent cache must create its parent directories when missing.
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() {
    // delete the cache file
    use burn_common::sync_type::SyncType;
    let file_path = burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_PREFIX);
    let parent_dir = file_path
        .parent()
        .expect("Cache file should have a parent directory");
    // Delete the cache file's parent directory
    let _ = std::fs::remove_dir_all(parent_dir);
    type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
    let runtime = Runtime::new();
    let client = runtime.client(&DummyDevice, dummy::init_client);
    // in this test shapes [1,3] and [1,5] ends up with different key names
    // which are 'cache_test-1,4' and 'cache_test-1,8'
    let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let lhs = client.create(&[0, 1, 2]);
    let rhs = client.create(&[4, 4, 4]);
    let out = client.empty(3);
    let handles = vec![lhs.binding(), rhs.binding(), out.clone().binding()];
    let cache_test_autotune_kernel =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes, handles);
    client.autotune_execute(Box::new(cache_test_autotune_kernel));
    // ensure that the autotune operations are run and cached
    client.sync(SyncType::Wait);
    assert!(
        parent_dir.exists(),
        "Parent directory of the cache file should exist"
    );
    assert!(file_path.exists(), "Cache file should exist");
}
/// Operation sets with different keys must each be tuned independently.
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_different_keys_return_a_cache_miss() {
    let client = client(&DummyDevice);
    // in this test shapes [1,3] and [1,5] ends up with different key names
    // which are 'cache_test-1,4' and 'cache_test-1,8'
    let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let lhs_1 = client.create(&[0, 1, 2]);
    let rhs_1 = client.create(&[4, 4, 4]);
    let out_1 = client.empty(3);
    let handles_1 = vec![lhs_1.binding(), rhs_1.binding(), out_1.binding()];
    let shapes_2 = vec![vec![1, 5], vec![1, 5], vec![1, 5]];
    let lhs_2 = client.create(&[0, 1, 2, 3, 4]);
    let rhs_2 = client.create(&[5, 6, 7, 8, 9]);
    let out_2 = client.empty(5);
    let handles_2 = vec![lhs_2.binding(), rhs_2.binding(), out_2.clone().binding()];
    let cache_test_autotune_kernel_1 =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1);
    let cache_test_autotune_kernel_2 =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2);
    client.autotune_execute(Box::new(cache_test_autotune_kernel_1));
    client.autotune_execute(Box::new(cache_test_autotune_kernel_2));
    let obtained_resource = client.read(out_2.binding());
    // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs
    assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8, 9]));
}
/// A persistent-cache entry whose checksum no longer matches must be
/// invalidated, forcing a fresh tuning run.
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_different_checksums_return_a_cache_miss() {
    use burn_common::sync_type::SyncType;
    type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
    let runtime = Runtime::new();
    let client = runtime.client(&DummyDevice, dummy::init_client);
    // in this test both shapes [1,3] and [1,4] end up with the same key name
    // which is 'cache_test-1,4'
    let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]];
    let lhs_1 = client.create(&[0, 1, 2]);
    let rhs_1 = client.create(&[4, 4, 4]);
    let out_1 = client.empty(3);
    let handles_1 = vec![lhs_1.binding(), rhs_1.binding(), out_1.binding()];
    let cache_test_autotune_kernel_1 =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1);
    client.autotune_execute(Box::new(cache_test_autotune_kernel_1));
    client.sync(SyncType::Wait);
    // we use a second compute client in order to have freshly initialized autotune cache
    // and test invalidation of the cache when the checksum of the operation set is
    // different
    let runtime = Runtime::new();
    let client = runtime.client(&DummyDevice, dummy::init_client);
    let shapes_2 = vec![vec![1, 4], vec![1, 4], vec![1, 4]];
    let lhs_2 = client.create(&[0, 1, 2, 3]);
    let rhs_2 = client.create(&[5, 6, 7, 8]);
    let out_2 = client.empty(4);
    let handles_2 = vec![lhs_2.binding(), rhs_2.binding(), out_2.clone().binding()];
    let mut cache_test_autotune_kernel_2 =
        dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2);
    cache_test_autotune_kernel_2.generate_random_checksum = true;
    client.autotune_execute(Box::new(cache_test_autotune_kernel_2));
    client.sync(SyncType::Wait);
    let obtained_resource = client.read(out_2.binding());
    // Cache should be missed because the checksum on 4 is generated randomly
    // and thus is always different,
    // so CacheTestSlowOn3 (but faster on 4) should be used, returning rhs
    assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8]));
}

View File

@ -1,27 +0,0 @@
[package]
authors = [
"nathanielsimard <nathaniel.simard.42@gmail.com>",
    "louisfd <louisfd94@gmail.com>",
]
categories = ["science"]
description = "Procedural macros for the burn-cube GPU kernel language"
edition.workspace = true
keywords = []
license.workspace = true
name = "burn-cube-macros"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/burn-cube-macros"
version.workspace = true
[lib]
proc-macro = true
[features]
default = []
std = []
[dependencies]
proc-macro2 = { workspace = true }
quote = { workspace = true }
syn = { workspace = true }
derive-new = { workspace = true }

View File

@ -1,280 +0,0 @@
use syn::{Member, Pat, PathArguments, Stmt};
use crate::tracker::VariableTracker;
/// Builtin Cube identifiers (position/dimension/count constants); the analyzer
/// skips these when tracking user variables (see `find_occurrences_in_expr`).
pub const KEYWORDS: [&str; 20] = [
    "ABSOLUTE_POS",
    "ABSOLUTE_POS_X",
    "ABSOLUTE_POS_Y",
    "ABSOLUTE_POS_Z",
    "UNIT_POS",
    "UNIT_POS_X",
    "UNIT_POS_Y",
    "UNIT_POS_Z",
    "CUBE_POS",
    "CUBE_POS_X",
    "CUBE_POS_Y",
    "CUBE_POS_Z",
    "CUBE_DIM",
    "CUBE_DIM_X",
    "CUBE_DIM_Y",
    "CUBE_DIM_Z",
    "CUBE_COUNT",
    "CUBE_COUNT_X",
    "CUBE_COUNT_Y",
    "CUBE_COUNT_Z",
];
#[derive(Debug, Default)]
/// Reads the whole Cube code and accumulates information, generating a
/// VariableTracker that knows about variable uses ahead of codegen.
pub(crate) struct VariableAnalyzer {
    variable_tracker: VariableTracker,
}
impl VariableAnalyzer {
    /// Run the analysis on `func` and return the populated tracker.
    pub fn create_tracker(func: &syn::ItemFn) -> VariableTracker {
        let analyzer = VariableAnalyzer::default();
        analyzer.analyze(func)
    }
}
impl VariableAnalyzer {
    /// Consume the analyzer and produce the tracker for `func`: signature
    /// arguments are declared first (depth 0), then the body is walked
    /// recursively, incrementing the depth at every scope (loop/if/closure).
    fn analyze(mut self, func: &syn::ItemFn) -> VariableTracker {
        // Build the vector of (Id, depth), using recursion
        self.signature_declarations(&func.sig);
        self.find_occurrences_in_stmts(&func.block.stmts, 0);
        self.variable_tracker
    }
    /// Register every typed argument of the signature as a declaration at depth 0.
    fn signature_declarations(&mut self, sig: &syn::Signature) {
        for input in &sig.inputs {
            match input {
                syn::FnArg::Typed(pat) => {
                    let ident = &*pat.pat;
                    let is_comptime = is_ty_comptime(&pat.ty);
                    match ident {
                        syn::Pat::Ident(pat_ident) => {
                            let id = &pat_ident.ident;
                            self.variable_tracker
                                .analyze_declare(id.to_string(), 0, is_comptime);
                        }
                        _ => todo!("Analysis: unsupported ident {ident:?}"),
                    }
                }
                _ => todo!("Analysis: unsupported input {input:?}"),
            }
        }
    }
    /// Walk a list of statements at the given scope `depth`.
    // Takes `&[Stmt]` instead of `&Vec<Stmt>` (clippy::ptr_arg); `&Vec<Stmt>`
    // call sites coerce implicitly.
    fn find_occurrences_in_stmts(&mut self, stmts: &[Stmt], depth: u8) {
        for stmt in stmts {
            match stmt {
                // Declaration
                syn::Stmt::Local(local) => {
                    let mut is_comptime = false;
                    let id = match &local.pat {
                        syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident),
                        syn::Pat::Type(pat_type) => {
                            is_comptime = is_ty_comptime(&pat_type.ty);
                            match &*pat_type.pat {
                                syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident),
                                _ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat),
                            }
                        }
                        syn::Pat::Wild(_) => None,
                        _ => todo!("Analysis: unsupported path {:?}", local.pat),
                    };
                    if let Some(id) = id {
                        self.variable_tracker
                            .analyze_declare(id.to_string(), depth, is_comptime);
                    }
                    if let Some(local_init) = &local.init {
                        self.find_occurrences_in_expr(&local_init.expr, depth)
                    }
                }
                syn::Stmt::Expr(expr, _) => self.find_occurrences_in_expr(expr, depth),
                _ => todo!("Analysis: unsupported stmt {stmt:?}"),
            }
        }
    }
    /// Walk one expression, registering declarations (loop vars, closure args)
    /// and reuses (paths, field accesses) at the appropriate depth.
    fn find_occurrences_in_expr(&mut self, expr: &syn::Expr, depth: u8) {
        match expr {
            syn::Expr::ForLoop(expr) => {
                // The iterated expression is evaluated outside the loop scope.
                self.find_occurrences_in_expr(&expr.expr, depth);
                let depth = depth + 1;
                if let syn::Pat::Ident(pat_ident) = &*expr.pat {
                    let id = &pat_ident.ident;
                    self.variable_tracker
                        .analyze_declare(id.to_string(), depth, false);
                }
                self.find_occurrences_in_stmts(&expr.body.stmts, depth);
            }
            syn::Expr::While(expr) => {
                let depth = depth + 1;
                self.find_occurrences_in_expr(&expr.cond, depth);
                self.find_occurrences_in_stmts(&expr.body.stmts, depth);
            }
            syn::Expr::Loop(expr) => {
                let depth = depth + 1;
                self.find_occurrences_in_stmts(&expr.body.stmts, depth);
            }
            syn::Expr::If(expr) => {
                let depth = depth + 1;
                self.find_occurrences_in_expr(&expr.cond, depth);
                self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth);
                if let Some((_, expr)) = &expr.else_branch {
                    match &**expr {
                        syn::Expr::Block(expr_block) => {
                            self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
                        }
                        // `else if`: recurse on the expression itself; it matches
                        // the `If` arm again (no need to clone into a new Expr).
                        syn::Expr::If(_) => {
                            self.find_occurrences_in_expr(expr, depth);
                        }
                        _ => unreachable!(),
                    }
                }
            }
            syn::Expr::Assign(expr) => {
                self.find_occurrences_in_expr(&expr.left, depth);
                self.find_occurrences_in_expr(&expr.right, depth);
            }
            syn::Expr::Index(expr) => {
                self.find_occurrences_in_expr(&expr.expr, depth);
                self.find_occurrences_in_expr(&expr.index, depth);
            }
            syn::Expr::Path(expr) => {
                // Builtin Cube keywords are not user variables; skip them.
                if let Some(ident) = expr.path.get_ident() {
                    if !KEYWORDS.contains(&ident.to_string().as_str()) {
                        self.variable_tracker.analyze_reuse(ident, depth, None);
                    }
                }
            }
            syn::Expr::Binary(expr) => {
                self.find_occurrences_in_expr(&expr.left, depth);
                self.find_occurrences_in_expr(&expr.right, depth);
            }
            syn::Expr::Lit(_) => {}
            syn::Expr::Call(expr) => {
                match &*expr.func {
                    syn::Expr::Path(expr_path) => {
                        if let Some(first_segment) = expr_path.path.segments.first() {
                            // Check if the path segment has generic arguments
                            if let PathArguments::AngleBracketed(arguments) =
                                &first_segment.arguments
                            {
                                // Extract the generic arguments
                                for arg in &arguments.args {
                                    match arg {
                                        syn::GenericArgument::Type(_)
                                        | syn::GenericArgument::Constraint(_) => {}
                                        _ => todo!("Analysis: Generic {:?} not supported", arg),
                                    }
                                }
                            }
                        }
                    }
                    _ => todo!("Analysis: unsupported func expr {:?}", expr.func),
                }
                for arg in expr.args.iter() {
                    self.find_occurrences_in_expr(arg, depth);
                }
            }
            syn::Expr::MethodCall(expr) => {
                self.find_occurrences_in_expr(&expr.receiver, depth);
                for arg in expr.args.iter() {
                    self.find_occurrences_in_expr(arg, depth);
                }
            }
            syn::Expr::Break(_) => {}
            syn::Expr::Return(expr) => {
                if expr.expr.is_some() {
                    // Unsupported: handled in codegen.
                }
            }
            syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
            syn::Expr::Array(_expr) => {
                // No analysis since only literals are supported
            }
            syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
            syn::Expr::Closure(expr) => {
                // Closure arguments are declarations in the closure's own scope.
                let depth = depth + 1;
                for path in expr.inputs.iter() {
                    let mut is_comptime = false;
                    let ident = match path {
                        Pat::Ident(pat_ident) => &pat_ident.ident,
                        Pat::Type(pat_type) => {
                            is_comptime = is_ty_comptime(&pat_type.ty);
                            if let Pat::Ident(pat_ident) = &*pat_type.pat {
                                &pat_ident.ident
                            } else {
                                todo!("Analysis: {:?} not supported in closure inputs. ", path);
                            }
                        }
                        _ => todo!("Analysis: {:?} not supported in closure inputs. ", path),
                    };
                    self.variable_tracker
                        .analyze_declare(ident.to_string(), depth, is_comptime);
                }
                self.find_occurrences_in_expr(&expr.body, depth)
            }
            syn::Expr::Unary(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
            syn::Expr::Field(expr) => {
                // Only `ident.field` accesses are tracked (as reuse of `ident`).
                if let Member::Named(attribute_ident) = &expr.member {
                    if let syn::Expr::Path(struct_expr) = &*expr.base {
                        let struct_ident = struct_expr
                            .path
                            .get_ident()
                            .expect("Analysis: field access only supported on ident struct.");
                        self.variable_tracker.analyze_reuse(
                            struct_ident,
                            depth,
                            Some(attribute_ident.to_string()),
                        );
                    } else {
                        todo!("Analysis: field access only supported on ident struct.");
                    }
                } else {
                    todo!("Analysis: unnamed attribute not supported.");
                }
            }
            syn::Expr::Struct(expr) => {
                for field in expr.fields.iter() {
                    self.find_occurrences_in_expr(&field.expr, depth)
                }
            }
            syn::Expr::Range(_range) => {
                // Error is handled during codegen.
            }
            _ => {
                // Error is handled during codegen.
            }
        }
    }
}
/// Whether the type path mentions a `Comptime` segment anywhere,
/// marking the value as compile-time (expansion-time) evaluated.
fn is_ty_comptime(ty: &syn::Type) -> bool {
    match ty {
        syn::Type::Path(path) => path
            .path
            .segments
            .iter()
            .any(|segment| segment.ident == "Comptime"),
        _ => false,
    }
}

View File

@ -1 +0,0 @@
pub(crate) mod signature;

View File

@ -1,70 +0,0 @@
use quote::ToTokens;
use crate::tracker::VariableTracker;
/// How the expanded function should be named (see `expand_sig`).
#[derive(Copy, Clone, Debug)]
pub enum ExpandMode {
    // Free function: expanded as `__expand`.
    FuncImpl,
    // Method: expanded as `__expand_<name>`.
    MethodImpl,
}
/// Generate the signature of the "expand" companion function for a Cube
/// function: every input type `T` becomes `<T as CubeType>::ExpandType`
/// (with references stripped), a `&mut CubeContext` parameter is prepended,
/// and the function is renamed per `mode` (`__expand` or `__expand_<name>`).
pub fn expand_sig(
    sig: &syn::Signature,
    visibility: &syn::Visibility,
    mut variable_tracker: Option<&mut VariableTracker>,
    mode: ExpandMode,
) -> proc_macro2::TokenStream {
    let mut inputs = quote::quote!();
    for input in &sig.inputs {
        match input {
            syn::FnArg::Typed(pat) => {
                let ident = pat.pat.clone();
                // Record each argument as declared at depth 0, when tracking.
                if let syn::Pat::Ident(ident) = ident.as_ref() {
                    if let Some(vars) = &mut variable_tracker {
                        vars.codegen_declare(ident.ident.to_string(), 0);
                    }
                }
                let ty = no_ref(pat.ty.as_ref());
                inputs.extend(quote::quote! {
                    #ident: <#ty as burn_cube::frontend::CubeType>::ExpandType,
                });
            }
            _ => todo!("Only Typed inputs are supported"),
        }
    }
    // The return type gets the same CubeType::ExpandType mapping; `()` as-is.
    let mut output = quote::quote!();
    match &sig.output {
        syn::ReturnType::Default => output.extend(quote::quote! { ()}),
        syn::ReturnType::Type(_, ty) => {
            let ty = no_ref(ty.as_ref());
            output.extend(quote::quote! {
                <#ty as burn_cube::frontend::CubeType>::ExpandType
            });
        }
    }
    let ident = &sig.ident;
    let ident = match mode {
        ExpandMode::FuncImpl => syn::Ident::new("__expand".to_string().as_str(), ident.span()),
        _ => syn::Ident::new(format!("__expand_{ident}").as_str(), ident.span()),
    };
    let generics = sig.generics.clone().into_token_stream();
    quote::quote! {
        /// Expanded Cube function
        #visibility fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output
    }
}
/// Strip one level of reference from a type: `&T` becomes `T`;
/// any other type is returned unchanged.
pub fn no_ref(ty: &syn::Type) -> &syn::Type {
    if let syn::Type::Reference(reference) = ty {
        &reference.elem
    } else {
        ty
    }
}

View File

@ -1,94 +0,0 @@
use proc_macro2::TokenStream;
use quote::ToTokens;
use super::{expr::codegen_expr, variable::codegen_local};
use crate::tracker::VariableTracker;
/// Codegen for a statement (generally one line)
/// Entry point of code generation
pub fn codegen_statement(
    statement: &syn::Stmt,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    match statement {
        syn::Stmt::Local(local) => codegen_local(local, loop_level, variable_tracker),
        syn::Stmt::Expr(expr, semi) => {
            let expr = codegen_expr(expr, loop_level, variable_tracker).tokens;
            // Re-attach the trailing semicolon when the source statement had one.
            match semi {
                Some(_semi) => quote::quote!(
                    #expr;
                ),
                None => expr,
            }
        }
        _ => todo!("Codegen: statement {statement:?} not supported"),
    }
}
/// Codegen for a code block (a list of statements), wrapped in braces.
pub(crate) fn codegen_block(
    block: &syn::Block,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    // TokenStream implements FromIterator, so the per-statement streams can
    // be collected directly.
    let statements: TokenStream = block
        .stmts
        .iter()
        .map(|statement| codegen_statement(statement, loop_level, variable_tracker))
        .collect();
    quote::quote! {
        {
            #statements
        }
    }
}
/// Result of generating code for one expression: the tokens plus metadata
/// consumed by the surrounding codegen.
pub(crate) struct Codegen {
    // The generated token stream.
    pub tokens: proc_macro2::TokenStream,
    // True when the expression is evaluated at expansion time (Comptime).
    pub is_comptime: bool,
    // Set when the expression is an array indexing — presumably consumed by
    // assignment codegen to emit indexed stores (caller not visible here).
    pub array_indexing: Option<ArrayIndexing>,
}
/// The `array[index]` pair of an indexing expression, kept as raw tokens.
pub(crate) struct ArrayIndexing {
    pub array: proc_macro2::TokenStream,
    pub index: proc_macro2::TokenStream,
}
impl From<proc_macro2::TokenStream> for Codegen {
    // Plain tokens: runtime-evaluated, not an indexing expression.
    fn from(tokens: proc_macro2::TokenStream) -> Self {
        Self {
            tokens,
            is_comptime: false,
            array_indexing: None,
        }
    }
}
impl Codegen {
    /// Build a codegen result with an explicit comptime flag.
    pub fn new<S: Into<proc_macro2::TokenStream>>(tokens: S, is_comptime: bool) -> Self {
        Self {
            tokens: tokens.into(),
            is_comptime,
            array_indexing: None,
        }
    }
    /// Split into (tokens, is_comptime), discarding any indexing metadata.
    pub fn split(self) -> (proc_macro2::TokenStream, bool) {
        (self.tokens, self.is_comptime)
    }
}
impl ToTokens for Codegen {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        tokens.extend(self.tokens.clone());
    }
    // Overridden to move the stored tokens out instead of cloning.
    fn into_token_stream(self) -> TokenStream
    where
        Self: Sized,
    {
        self.tokens
    }
}

View File

@ -1,193 +0,0 @@
use proc_macro2::TokenStream;
use crate::{codegen_function::expr::codegen_expr, tracker::VariableTracker};
use super::{
base::{codegen_block, Codegen},
function::codegen_call,
operation::codegen_binary,
variable::{codegen_lit, codegen_path_var},
};
/// Codegen of for loops
/// Supports range:
/// for i in range(start, end, unroll) {...}
///
/// Any other iterated expression produces a compile error pointing the user
/// at `range`.
pub(crate) fn codegen_for_loop(
    for_loop: &syn::ExprForLoop,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    let i = &for_loop.pat;
    // The loop variable lives one level deeper than the surrounding code.
    if let syn::Pat::Ident(pat_ident) = &*for_loop.pat {
        let id = &pat_ident.ident;
        variable_tracker.codegen_declare(id.to_string(), loop_level as u8 + 1);
    }
    // Fix: the markdown link in this message was unbalanced
    // ("[range](...]" -> "[range](...)").
    let invalid_for_loop = || {
        syn::Error::new_spanned(
            &for_loop.expr,
            "Invalid for loop: use [range](cubecl::prelude::range) instead.",
        )
        .into_compile_error()
    };
    match for_loop.expr.as_ref() {
        syn::Expr::Call(call) => {
            let func_name = match call.func.as_ref() {
                syn::Expr::Path(path) => match path.path.get_ident() {
                    Some(ident) => ident,
                    None => return invalid_for_loop(),
                },
                _ => {
                    return invalid_for_loop();
                }
            };
            // Compare the ident directly instead of allocating a String
            // (clippy::cmp_owned).
            if func_name == "range" {
                // range(start, end, unroll): pop the args back to front.
                // NOTE(review): assumes exactly three arguments; fewer panics
                // at expansion time — confirm upstream validation.
                let mut args = call.args.clone();
                let unroll = codegen_expr(
                    &args.pop().unwrap().into_value(),
                    loop_level,
                    variable_tracker,
                );
                let end = codegen_expr(
                    &args.pop().unwrap().into_value(),
                    loop_level,
                    variable_tracker,
                );
                let start = codegen_expr(
                    &args.pop().unwrap().into_value(),
                    loop_level,
                    variable_tracker,
                );
                let block = codegen_block(&for_loop.body, loop_level + 1, variable_tracker);
                quote::quote! {
                    {
                        let _start = #start;
                        let _end = #end;
                        let _unroll = #unroll;
                        burn_cube::frontend::branch::range_expand(context, _start, _end, _unroll, |context, #i| #block);
                    }
                }
            } else {
                invalid_for_loop()
            }
        }
        _ => invalid_for_loop(),
    }
}
/// Codegen for condition of an if or a while
pub(crate) fn codegen_cond(
    cond: &syn::Expr,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> Codegen {
    match cond {
        // Literal conditions are always runtime values (never comptime).
        syn::Expr::Lit(expr) => Codegen::new(codegen_lit(expr), false),
        // These arms may carry a comptime flag, so they build their own Codegen.
        syn::Expr::Binary(expr) => codegen_binary(expr, loop_level, variable_tracker),
        syn::Expr::Path(expr) => codegen_path_var(expr, loop_level, variable_tracker),
        syn::Expr::Call(expr) => codegen_call(expr, loop_level, variable_tracker),
        _ => todo!("{cond:?} cond not supported"),
    }
}
/// Codegen for break statement
pub(crate) fn codegen_break() -> TokenStream {
    // A `break` maps directly onto the frontend's break expansion.
    quote::quote! {
        burn_cube::frontend::branch::break_expand(context);
    }
}
/// Codegen for return statement
pub(crate) fn codegen_return(expr_return: &syn::ExprReturn) -> TokenStream {
if expr_return.expr.is_some() {
return syn::Error::new_spanned(expr_return, "Only void return is supported.")
.into_compile_error();
}
quote::quote! {
burn_cube::frontend::branch::return_expand(context);
}
}
/// Codegen for if and if/else statements
/// Supports:
/// if cond {...}
/// if cond {...} else {...}
/// if Comptime::get(...) {...} [else {...}]
/// if Comptime::get(...) {...} [else if Comptime::get(...) {...}]* [else {...}]
pub(crate) fn codegen_if(
    expr_if: &syn::ExprIf,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    let (cond, is_comptime) = codegen_cond(&expr_if.cond, loop_level, variable_tracker).split();

    // When the condition is comptime, its value is forwarded so a single
    // branch can be selected during expansion.
    let comptime_bool = match is_comptime {
        true => quote::quote! { Some(#cond) },
        false => quote::quote! { None },
    };

    let then_block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_tracker);

    match &expr_if.else_branch {
        Some((_, expr)) => {
            // `else` is either a plain block or an `else if`, which recurses.
            let else_block = match &**expr {
                syn::Expr::Block(expr_block) => {
                    codegen_block(&expr_block.block, loop_level + 1, variable_tracker)
                }
                syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level + 1, variable_tracker),
                _ => unreachable!(),
            };
            quote::quote! {
                {
                    let _cond = #cond;
                    burn_cube::frontend::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block);
                }
            }
        }
        None => quote::quote! {
            let _cond = #cond;
            burn_cube::frontend::branch::if_expand(context, #comptime_bool, _cond.into(), |context| #then_block);
        },
    }
}
/// Codegen of loop
pub(crate) fn codegen_loop(
    loop_expr: &syn::ExprLoop,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    // The loop body lives one scope level deeper than the loop itself.
    let block = codegen_block(&loop_expr.body, loop_level + 1, variable_tracker);
    quote::quote! {
        burn_cube::frontend::branch::loop_expand(context, |context| #block);
    }
}
/// Codegen for while loop
pub(crate) fn codegen_while_loop(
    while_loop: &syn::ExprWhile,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    // The condition is evaluated inside the loop's scope (level + 1).
    let codegen = codegen_cond(&while_loop.cond, loop_level + 1, variable_tracker);
    let (cond, is_comptime) = codegen.split();

    if is_comptime {
        return syn::Error::new_spanned(while_loop.while_token, "Comptime not supported for while")
            .into_compile_error();
    }

    let block = codegen_block(&while_loop.body, loop_level + 1, variable_tracker);

    quote::quote! {
        burn_cube::frontend::branch::while_loop_expand(context, |context| #cond, |context| #block);
    }
}

View File

@ -1,99 +0,0 @@
use crate::tracker::VariableTracker;
use proc_macro2::TokenStream;
use super::{
base::{codegen_block, Codegen},
branch::{
codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_return,
codegen_while_loop,
},
function::{codegen_call, codegen_closure, codegen_expr_method_call},
operation::{codegen_binary, codegen_unary},
variable::{
codegen_array_lit, codegen_assign, codegen_field, codegen_index, codegen_lit,
codegen_path_var, codegen_struct,
},
};
/// Codegen for expressions
pub(crate) fn codegen_expr(
    expr: &syn::Expr,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> Codegen {
    match expr {
        syn::Expr::Call(call) => codegen_call(call, loop_level, variable_tracker),
        // Parentheses are transparent: expand the inner expression.
        syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_tracker),
        _ => {
            // Only the `Index` arm below populates this, so that compound
            // assignments can reuse the array/index tokens.
            let mut array_indexing = None;
            let tokens = match expr {
                // These three arms build a full `Codegen` themselves (they may
                // carry a comptime flag), hence the early returns.
                syn::Expr::Path(path) => {
                    return codegen_path_var(path, loop_level, variable_tracker)
                }
                syn::Expr::Binary(op) => return codegen_binary(op, loop_level, variable_tracker),
                syn::Expr::Unary(op) => return codegen_unary(op, loop_level, variable_tracker),
                syn::Expr::Lit(lit) => codegen_lit(lit),
                syn::Expr::Closure(closure) => {
                    codegen_closure(closure, loop_level, variable_tracker)
                }
                syn::Expr::Block(block) => codegen_expr_block(block, loop_level, variable_tracker),
                syn::Expr::Assign(assign) => codegen_assign(assign, loop_level, variable_tracker),
                syn::Expr::ForLoop(for_loop) => {
                    codegen_for_loop(for_loop, loop_level, variable_tracker)
                }
                syn::Expr::While(while_loop) => {
                    codegen_while_loop(while_loop, loop_level, variable_tracker)
                }
                syn::Expr::Loop(loop_expr) => codegen_loop(loop_expr, loop_level, variable_tracker),
                syn::Expr::Break(_) => codegen_break(),
                syn::Expr::Return(return_expr) => codegen_return(return_expr),
                syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_tracker),
                syn::Expr::MethodCall(call) => {
                    codegen_expr_method_call(call, loop_level, variable_tracker)
                }
                syn::Expr::Index(index) => {
                    // Capture the indexing metadata before taking the tokens.
                    let codegen = codegen_index(index, loop_level, variable_tracker);
                    array_indexing = codegen.array_indexing;
                    codegen.tokens
                }
                syn::Expr::Array(array) => codegen_array_lit(array),
                syn::Expr::Reference(reference) => {
                    codegen_ref(reference, loop_level, variable_tracker)
                }
                syn::Expr::Field(field) => codegen_field(field, loop_level, variable_tracker),
                syn::Expr::Struct(struct_) => codegen_struct(struct_, loop_level, variable_tracker),
                syn::Expr::Range(range) => syn::Error::new_spanned(
                    range,
                    "Range is not supported, use [range](cubecl::prelude::range) instead.",
                )
                .to_compile_error(),
                _ => {
                    syn::Error::new_spanned(expr, "Expression is not supported").to_compile_error()
                }
            };
            let mut codegen = Codegen::new(tokens, false);
            codegen.array_indexing = array_indexing;
            codegen
        }
    }
}
/// Codegen for an expression containing a block
pub(crate) fn codegen_expr_block(
    block: &syn::ExprBlock,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    // An `ExprBlock` is just a wrapper around a plain block.
    codegen_block(&block.block, loop_level, variable_tracker)
}
/// Codegen for a reference expression.
pub(crate) fn codegen_ref(
    reference: &syn::ExprReference,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    // References are transparent in the expansion: expand the referent and
    // drop the `&`.
    codegen_expr(&reference.expr, loop_level, variable_tracker).tokens
}

View File

@ -1,250 +0,0 @@
use proc_macro2::{Span, TokenStream};
use quote::quote_spanned;
use syn::{
punctuated::Punctuated, spanned::Spanned, AngleBracketedGenericArguments, Expr, Ident,
PathArguments, Token,
};
use crate::{codegen_function::expr::codegen_expr, tracker::VariableTracker};
use super::base::Codegen;
/// Codegen for method call
/// Supports [expr].method(args)
pub(crate) fn codegen_expr_method_call(
    call: &syn::ExprMethodCall,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    // `foo.bar(x)` expands to `foo.bar_expand(context, x)`.
    let receiver = codegen_expr(&call.receiver, loop_level, variable_tracker);
    let method_expand = syn::Ident::new(
        &format!("{}_expand", call.method),
        proc_macro2::Span::call_site(),
    );
    let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);

    quote::quote!( {
        #expansion
        #receiver . #method_expand ( #variables )
    })
}
/// Codegen for a closure
pub(crate) fn codegen_closure(
    closure: &syn::ExprClosure,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    // Collect the closure parameters, keeping an explicit type annotation
    // when the source had one.
    let mut params = quote::quote! {};
    for input in closure.inputs.iter() {
        match input {
            syn::Pat::Ident(ident) => {
                let ident = &ident.ident;
                params.extend(quote::quote! {
                    #ident,
                });
            }
            syn::Pat::Type(pat_type) => {
                let syn::Pat::Ident(ident) = &*pat_type.pat else {
                    return syn::Error::new_spanned(pat_type, "Unsupported input")
                        .into_compile_error();
                };
                let ident = &ident.ident;
                let ty = &pat_type.ty;
                params.extend(quote::quote! {
                    #ident : #ty,
                });
            }
            _ => return syn::Error::new_spanned(input, "Unsupported input").into_compile_error(),
        }
    }

    // The expanded closure always receives the expansion context first.
    let body = codegen_expr(closure.body.as_ref(), loop_level, variable_tracker);

    quote::quote! {
        |context: &mut CubeContext, #params| #body
    }
}
/// Maps
/// [A[::<...>]?::]^* func[::<...>] (args)
/// to
/// [A[::<...>]?::]^* func_expand[::<...>] (context, args)
///
/// Also returns a bool that is true if it's comptime
pub(crate) fn codegen_call(
    call: &syn::ExprCall,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> Codegen {
    // We start with parsing the function path
    let path: Vec<(&Ident, Option<&AngleBracketedGenericArguments>)> = match call.func.as_ref() {
        syn::Expr::Path(expr_path) => {
            let mut path = Vec::new();
            for segment in expr_path.path.segments.iter() {
                let generics = if let PathArguments::AngleBracketed(arguments) = &segment.arguments
                {
                    Some(arguments)
                } else {
                    None
                };
                path.push((&segment.ident, generics));
            }
            path
        }
        _ => {
            return Codegen::new(
                syn::Error::new_spanned(&call.func, "Unsupported").into_compile_error(),
                false,
            )
        }
    };

    // Path
    let mut path_tokens = TokenStream::new();
    let mut is_comptime = false;
    let mut is_plain_func = true;
    let mut comptime_func: Option<String> = None;

    for (i, (ident, generics)) in path.iter().enumerate() {
        let name = ident.to_string();

        // `Comptime::` prefixes are handled specially below and do not
        // contribute to the rewritten path.
        if name == "Comptime" {
            is_comptime = true;
            continue;
        }

        // An uppercase segment means the call goes through a type, so the
        // associated `__expand` function is used instead of a free
        // `__expand_<name>` function.
        if let Some(first_char) = name.chars().next() {
            if first_char.is_uppercase() {
                is_plain_func = false;
            }
        }

        if i == path.len() - 1 {
            // Last segment: this is the function name itself.
            if is_comptime {
                comptime_func = Some(ident.to_string());
                break;
            }

            let func_name_expand = if is_plain_func {
                quote::quote! {
                    #ident::__expand
                }
            } else {
                let ident = syn::Ident::new(
                    format!("__expand_{ident}").as_str(),
                    proc_macro2::Span::call_site(),
                );
                quote::quote! {
                    #ident
                }
            };
            path_tokens.extend(quote_spanned! {func_name_expand.span() => #func_name_expand });
            if let Some(generics) = generics {
                path_tokens.extend(quote_spanned! {generics.span() => #generics });
            }
        } else if let Some(generics) = generics {
            // Intermediate segment with generics: `A::<T>::`.
            path_tokens.extend(quote_spanned! {ident.span() => #ident });
            path_tokens.extend(quote_spanned! {generics.span() => #generics :: });
        } else {
            path_tokens.extend(quote_spanned! {ident.span() => #ident :: });
        }
    }

    // Arguments
    if let Some(func_name) = comptime_func {
        // `Comptime::xxx(...)` calls each have a dedicated expansion.
        let tokens = match func_name.as_str() {
            "get" | "new" => {
                // Pass the argument through unchanged: it is a plain value.
                let code = call.args.first().unwrap();
                quote::quote! {#code}
            }
            "map" => {
                let args = &call.args;

                // Codegen
                quote::quote! {
                    {
                        Comptime::map_expand(#args)
                    }
                }
            }
            "unwrap_or_else" => {
                let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);

                // Codegen
                quote::quote! {{
                    #expansion
                    Comptime::unwrap_or_else_expand(#variables)
                }}
            }
            "is_some" => {
                let code = call.args.first().unwrap();
                quote::quote! { #code.is_some() }
            }
            "vectorization" => {
                let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);

                // Codegen
                quote::quote! {{
                    #expansion
                    Comptime::vectorization_expand(#variables)
                }}
            }
            "vectorize" => {
                let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);

                // Codegen
                quote::quote! {{
                    #expansion
                    Comptime::vectorize_expand(#variables)
                }}
            }
            "runtime" => {
                let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);

                // Codegen
                quote::quote! {{
                    #expansion
                    Comptime::runtime_expand(#variables)
                }}
            }

            _ => panic!("Codegen: Comptime function {:?} does not exist", func_name),
        };

        Codegen::new(tokens, true)
    } else {
        let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);

        // Codegen
        let tokens = quote::quote! {{
            #expansion
            #path_tokens (#variables)
        }};

        Codegen::new(tokens, false)
    }
}
fn codegen_args(
args: &Punctuated<Expr, Token![,]>,
loop_level: usize,
variable_tracker: &mut VariableTracker,
) -> (TokenStream, TokenStream) {
let mut expansion = quote::quote! {};
let mut variables = quote::quote! {};
variables.extend(quote::quote! { context, });
for (i, argument) in args.iter().enumerate() {
let ident = Ident::new(format!("_var_{i}").as_str(), Span::call_site());
let arg_token = codegen_expr(argument, loop_level, variable_tracker);
expansion.extend(quote::quote! { let #ident = #arg_token; });
variables.extend(quote::quote! { #ident, });
}
(expansion, variables)
}

View File

@ -1,521 +0,0 @@
use proc_macro2::{Span, TokenStream};
use syn::{parse_quote, Generics, Ident};
/// Accumulates everything needed to generate the kernel struct, its `Kernel`
/// impl, and the launch function from a cube function signature.
#[derive(Default)]
struct Codegen {
    // Basic attributes.
    // PascalCase kernel struct name derived from the function name.
    name: String,
    // Generics copied from the function signature.
    generics: Generics,
    // Token streams for the generated launch function's parameter list and
    // return type.
    fn_inputs: TokenStream,
    fn_output: TokenStream,
    // States to generate code.
    // Comptime parameters: (inner generic type, identifier).
    state_comptimes: Vec<(syn::Type, Ident)>,
    // Arguments forwarded to the expand function.
    state_args: Vec<TokenStream>,
    // Runtime inputs: (identifier, de-referenced type).
    state_inputs: Vec<(Ident, syn::Type)>,
    // Mutable (output) arguments: (identifier, de-referenced type).
    state_outputs: Vec<(Ident, syn::Type)>,
}
impl Codegen {
    /// Parse a cube function signature into codegen state: kernel name,
    /// generics, and the comptime/input/output argument buckets.
    fn from_sig(sig: &syn::Signature) -> Self {
        let mut codegen = Codegen::default();

        // PascalCase the function name to get the kernel struct name.
        let mut first_letter = sig.ident.to_string();
        let second_part = first_letter.split_off(1);
        codegen.name = format!("{}{}", first_letter.to_uppercase(), second_part);

        codegen.generics = sig.generics.clone();

        let mut inputs = quote::quote!();

        for input in &sig.inputs {
            let mut is_output = false;
            let mut comptime = false;

            match input {
                syn::FnArg::Typed(pat) => {
                    let (ty, ident) = match pat.pat.as_ref() {
                        syn::Pat::Ident(ident) => {
                            // `mut` bindings and `&mut` types are outputs.
                            if ident.mutability.is_some() {
                                is_output = true;
                            }

                            if let syn::Type::Reference(ty) = pat.ty.as_ref() {
                                if ty.mutability.is_some() {
                                    is_output = true;
                                }
                            };

                            // A `Comptime<...>` type marks a comptime argument.
                            if let syn::Type::Path(pat) = pat.ty.as_ref() {
                                if let Some(name) = pat.path.segments.first() {
                                    let name = name.ident.to_string();

                                    if name == "Comptime" {
                                        comptime = true;
                                    }
                                }
                            };

                            (pat.ty.clone(), ident.ident.clone())
                        }
                        _ => panic!("Nop"),
                    };

                    // Comptime values are read from `self`; runtime values are
                    // forwarded straight to the expand function.
                    if comptime {
                        codegen.state_args.push(quote::quote! {
                            self.#ident
                        });
                    } else {
                        codegen.state_args.push(quote::quote! {
                            #ident
                        });
                    }

                    if comptime {
                        let ty = no_ref(&ty);
                        inputs.extend(quote::quote! {
                            #ident: <#ty as burn_cube::frontend::CubeType>::ExpandType,
                        });
                    } else {
                        let ty = no_ref(&ty);
                        inputs.extend(quote::quote! {
                            #ident: RuntimeArg<'a, #ty, R>,
                        });
                    }

                    if is_output {
                        codegen
                            .state_outputs
                            .push((ident.clone(), no_ref(&ty).clone()));
                    } else if comptime {
                        codegen
                            .state_comptimes
                            .push((first_generic_ty(&ty).clone(), ident.clone()));
                    } else {
                        codegen
                            .state_inputs
                            .push((ident.clone(), no_ref(&ty).clone()));
                    }
                }
                _ => panic!("Only Typed inputs are supported"),
            };
        }

        // The launch function returns the expanded form of the declared
        // return type; unit when none is declared.
        let mut output = quote::quote!();

        match &sig.output {
            syn::ReturnType::Default => output.extend(quote::quote! {()}),
            syn::ReturnType::Type(_, ty) => {
                output.extend(quote::quote! {
                    <#ty as burn_cube::frontend::CubeType>::ExpandType
                });
            }
        }

        codegen.fn_inputs = inputs;
        codegen.fn_output = output;

        codegen
    }

    /// Generate the kernel struct definition: settings, one field per
    /// comptime value, and phantom data for the generic parameters.
    fn gen_kernel_struct(&self) -> TokenStream {
        let ident = Ident::new(&self.name, Span::call_site());
        let generics = add_runtime(self.generics.clone());
        let phantoms = self.phantoms(&generics, true);

        let mut comptimes = quote::quote! {};

        for (ty, ident) in self.state_comptimes.iter() {
            comptimes.extend(quote::quote! {
                #ident: #ty,
            });
        }

        quote::quote! {
            /// Kernel
            pub struct #ident #generics {
                settings: KernelSettings,
                #comptimes
                #phantoms
            }
        }
    }

    /// Generate the `KernelSettings` construction, wiring every input and
    /// output argument into its positional slot.
    fn gen_settings(&self) -> TokenStream {
        let mut variables = quote::quote! {};

        for (pos, (ident, _ty)) in self.state_inputs.iter().enumerate() {
            variables.extend(quote::quote! {
                settings = ArgSettings::<R>::configure_input(&#ident, #pos, settings);
            });
        }

        for (pos, (ident, _ty)) in self.state_outputs.iter().enumerate() {
            variables.extend(quote::quote! {
                settings = ArgSettings::<R>::configure_output(&#ident, #pos, settings);
            });
        }

        quote::quote! {
            let mut settings = KernelSettings::default();
            settings = settings.cube_dim(cube_dim);
            #variables
        }
    }

    /// Generate the `register_input` helper that expands the input argument
    /// at a given position while building the kernel.
    fn gen_register_input(&self) -> TokenStream {
        let generics = &self.generics;
        let mut variables = quote::quote! {};

        for (pos, (_ident, ty)) in self.state_inputs.iter().enumerate() {
            variables.extend(quote::quote! {
                #pos => std::sync::Arc::new(<#ty as LaunchArgExpand>::expand(builder, settings.vectorization_input(#pos))),
            });
        }

        quote::quote! {
            #[allow(unused)]
            fn register_input #generics(
                builder: &mut KernelBuilder,
                settings: &KernelSettings,
                position: usize,
            ) -> std::sync::Arc<dyn core::any::Any> {
                match position {
                    #variables
                    _ => panic!("Input {position} is invalid."),
                }
            }
        }
    }

    /// Generate the `register_output` helper, the output-side counterpart of
    /// `register_input`.
    // NOTE(review): the generated panic message says "Input" for an output
    // position — looks like a copy-paste from `gen_register_input`; confirm.
    fn gen_register_output(&self) -> TokenStream {
        let generics = &self.generics;
        let mut variables = quote::quote! {};

        for (pos, (_ident, ty)) in self.state_outputs.iter().enumerate() {
            variables.extend(quote::quote! {
                #pos => std::sync::Arc::new(<#ty as LaunchArgExpand>::expand_output(builder, settings.vectorization_output(#pos))),
            });
        }

        quote::quote! {
            #[allow(unused)]
            fn register_output #generics (
                builder: &mut KernelBuilder,
                settings: &KernelSettings,
                position: usize,
            ) -> std::sync::Arc<dyn core::any::Any> {
                match position {
                    #variables
                    _ => panic!("Input {position} is invalid."),
                }
            }
        }
    }

    /// Generate the tail of `Kernel::define`: downcast every stored expand
    /// type, call the expand function, and build the kernel definition.
    fn gen_define_impl(&self, expand: &TokenStream) -> TokenStream {
        let mut expand_args = quote::quote! { &mut builder.context, };

        let mut variables = quote::quote! {};

        for (pos, (ident, ty)) in self.state_inputs.iter().enumerate() {
            variables.extend(quote::quote! {
                let #ident: &<#ty as CubeType>::ExpandType = inputs
                    .get(&#pos)
                    .unwrap()
                    .downcast_ref()
                    .expect("Input type should be correct. It could be caused by an invalid kernel input/output alias.");
            });
        }

        for (pos, (ident, ty)) in self.state_outputs.iter().enumerate() {
            variables.extend(quote::quote! {
                let #ident: &<#ty as CubeType>::ExpandType = outputs
                    .get(&#pos)
                    .unwrap()
                    .downcast_ref()
                    .expect("Output type should be correct. It could be caused by an invalid kernel input/output alias.");
            });
        }

        for arg in self.state_args.iter() {
            expand_args.extend(quote::quote! {
                #arg.clone(),
            })
        }

        // With generics, the expand function needs an explicit turbofish.
        let expand_func = match self.generics.params.is_empty() {
            true => quote::quote! { #expand },
            false => {
                let generics = self.generics.split_for_impl().1;
                quote::quote! { #expand::#generics }
            }
        };

        quote::quote! {
            #variables
            #expand_func(#expand_args);
            builder.build(self.settings.clone())
        }
    }

    /// Generate the head of `Kernel::define`: the builder, the input/output
    /// maps, alias mappings, and the register helper functions.
    fn gen_define_args(&self) -> TokenStream {
        let num_inputs = self.state_inputs.len();
        let num_outputs = self.state_outputs.len();

        let register_input = self.gen_register_input();
        let register_output = self.gen_register_output();

        let (register_input_call, register_output_call) = match self.generics.params.is_empty() {
            true => (
                quote::quote! { register_input },
                quote::quote! { register_output },
            ),
            false => {
                let generics = self.generics.split_for_impl().1;
                (
                    quote::quote! { register_input::#generics },
                    quote::quote! { register_output::#generics },
                )
            }
        };

        // NOTE(review): `variables` built below is never added to the returned
        // token stream (dead code), and the `let x = <...> = ...` form it
        // would generate is not valid Rust — presumably superseded by
        // `gen_define_impl`; confirm before reviving it.
        let mut variables = quote::quote! {};

        for (pos, (ident, ty)) in self.state_inputs.iter().enumerate() {
            variables.extend(quote::quote! {
                let #ident = <&#ty as CubeType>::ExpandType =
                    *inputs.remove(&#pos).unwrap().downcast().unwrap();
            });
        }

        for (pos, (ident, ty)) in self.state_outputs.iter().enumerate() {
            variables.extend(quote::quote! {
                let #ident = <&mut #ty as CubeType>::ExpandType =
                    *outputs.remove(&#pos).unwrap().downcast().unwrap();
            });
        }

        let mut tokens = quote::quote! {
            let mut builder = KernelBuilder::default();

            let mut inputs: std::collections::BTreeMap<usize, std::sync::Arc<dyn core::any::Any>> = std::collections::BTreeMap::new();
            let mut outputs: std::collections::BTreeMap<usize, std::sync::Arc<dyn core::any::Any>> = std::collections::BTreeMap::new();

            for mapping in self.settings.mappings.iter() {
                if !inputs.contains_key(&mapping.pos_input) {
                    inputs.insert(
                        mapping.pos_input,
                        #register_input_call(&mut builder, &self.settings, mapping.pos_input),
                    );
                }

                let input = inputs.get(&mapping.pos_input).unwrap();
                outputs.insert(mapping.pos_output, input.clone());
            }

            #register_input
            #register_output
        };

        if num_inputs > 0 {
            tokens.extend(quote::quote! {
                for i in 0..#num_inputs {
                    if !inputs.contains_key(&i) {
                        inputs.insert(i, #register_input_call(&mut builder, &self.settings, i));
                    }
                }
            });
        }

        if num_outputs > 0 {
            tokens.extend(quote::quote! {
                for i in 0..#num_outputs {
                    if !outputs.contains_key(&i) {
                        outputs.insert(i, #register_output_call(&mut builder, &self.settings, i));
                    }
                }
            });
        }

        tokens
    }

    /// Generate the `impl Kernel` block: `define` plus a unique `id` built
    /// from the type id, settings, and comptime values.
    fn gen_compile_impl(&self, expand: &TokenStream) -> TokenStream {
        let ident = Ident::new(&self.name, Span::call_site());
        let generics = add_runtime(self.generics.clone());
        let (impl_gen, ty_gen, where_gen) = generics.split_for_impl();

        // One extra `{:?}` per comptime value participates in the kernel id.
        let mut format_str = "{:?}-{}".to_string();
        for _ in 0..self.state_comptimes.len() {
            format_str.push_str("-{:?}");
        }

        let mut format_args = quote::quote! { core::any::TypeId::of::<Self>(), self.settings, };

        for (_, ident) in self.state_comptimes.iter() {
            format_args.extend(quote::quote! { self.#ident, });
        }

        let define_args = self.gen_define_args();
        let define_impl = self.gen_define_impl(expand);

        quote::quote! {
            impl #impl_gen Kernel for #ident #ty_gen #where_gen {
                fn define(&self) -> KernelDefinition {
                    #define_args
                    #define_impl
                }

                fn id(&self) -> String {
                    format!(#format_str, #format_args)
                }
            }
        }
    }

    /// Generate `PhantomData` fields (declaration = true) or values
    /// (declaration = false) for every generic type parameter.
    fn phantoms(&self, generics: &Generics, declaration: bool) -> TokenStream {
        let mut phantoms = quote::quote! {};

        for param in generics.params.iter() {
            let ty = match param {
                syn::GenericParam::Type(ty) => ty,
                _ => continue,
            };
            let ident = Ident::new(
                format!("_{}", ty.ident.to_string().to_lowercase()).as_str(),
                Span::call_site(),
            );
            let ty = &ty.ident;
            if declaration {
                phantoms.extend(quote::quote! {
                    #ident: core::marker::PhantomData<#ty>,
                });
            } else {
                phantoms.extend(quote::quote! {
                    #ident: core::marker::PhantomData::<#ty>,
                });
            }
        }
        phantoms
    }

    /// Generate the launch-function body: build settings, instantiate the
    /// kernel, register every runtime argument, and launch.
    fn gen_launch_body(&self) -> TokenStream {
        let ident = Ident::new(&self.name, Span::call_site());
        let generics = add_runtime(self.generics.clone());
        let phantoms = self.phantoms(&generics, false);

        let mut comptimes = quote::quote! {};
        let settings = self.gen_settings();

        let mut body = quote::quote! {
            let mut launcher = KernelLauncher::<R>::default();
        };

        for (input, _) in self.state_inputs.iter() {
            body.extend(quote::quote! {
                #input.register(&mut launcher);
            });
        }

        for (input, _) in self.state_outputs.iter() {
            body.extend(quote::quote! {
                #input.register(&mut launcher);
            });
        }

        for (_ty, ident) in self.state_comptimes.iter() {
            comptimes.extend(quote::quote! {
                #ident,
            });
        }

        let kernel = quote::quote! {
            #ident {
                settings,
                #comptimes
                #phantoms
            }
        };

        quote::quote! {
            #settings

            let kernel = #kernel;

            #body

            launcher.launch(cube_count, kernel, client);
        }
    }
}
/// Generate the kernel struct, its `Kernel` impl, and a `launch` function
/// for the given cube function signature.
pub fn codegen_launch(sig: &syn::Signature) -> TokenStream {
    let codegen = Codegen::from_sig(sig);

    let ident = &sig.ident;
    let ident_expand = quote::quote! {
        __expand
    };
    let doc = format!("Launch the kernel [{ident}()] on the given runtime.");

    // Generate the three pieces, then take ownership of the signature parts.
    let kernel = codegen.gen_kernel_struct();
    let compile = codegen.gen_compile_impl(&ident_expand);
    let body = codegen.gen_launch_body();
    let generics = add_runtime(add_lifetime(sig.generics.clone()));
    let (inputs, output) = (codegen.fn_inputs, codegen.fn_output);

    quote::quote! {
        #kernel
        #compile

        #[allow(clippy::too_many_arguments)]
        #[doc = #doc]
        pub fn launch #generics (
            client: ComputeClient<R::Server, R::Channel>,
            cube_count: CubeCount<R::Server>,
            cube_dim: CubeDim,
            #inputs
        ) -> #output {
            #body;
        }
    }
}
/// Prepend a `'a` lifetime parameter to the given generics.
pub fn add_lifetime(mut generics: Generics) -> Generics {
    let lifetime: syn::Generics = parse_quote! {<'a>};
    let param = lifetime.params.into_iter().next().unwrap();
    generics.params.insert(0, param);
    generics
}
/// Append an `R: Runtime` type parameter to the given generics.
pub fn add_runtime(mut generics: Generics) -> Generics {
    let runtime: syn::Generics = parse_quote! { <R: Runtime> };
    let param = runtime.params.into_iter().next().unwrap();
    generics.params.push(param);
    generics
}
/// Extract the first generic argument of a path type, e.g. `T` in `Comptime<T>`.
fn first_generic_ty(ty: &syn::Type) -> syn::Type {
    let syn::Type::Path(pat) = ty else { todo!() };
    let syn::PathArguments::AngleBracketed(args) = &pat.path.segments.first().unwrap().arguments
    else {
        panic!("Comptime must have a generic")
    };
    match args.args.first().unwrap() {
        syn::GenericArgument::Type(ty) => ty.clone(),
        _ => panic!("Should have a generic type"),
    }
}
/// Strip one level of reference from a type; non-reference types pass through.
fn no_ref(ty: &syn::Type) -> &syn::Type {
    if let syn::Type::Reference(reference) = ty {
        &reference.elem
    } else {
        ty
    }
}

View File

@ -1,10 +0,0 @@
// Submodules of the cube function code generation.
mod base;
mod branch;
mod expr;
mod function;
mod launch;
mod operation;
mod variable;

// Entry points re-exported for the rest of the crate.
pub(crate) use base::codegen_statement;
pub(crate) use launch::codegen_launch;

View File

@ -1,270 +0,0 @@
use crate::tracker::VariableTracker;
use super::{base::Codegen, expr::codegen_expr};
/// Codegen for binary operations (+, -, *, etc.)
pub(crate) fn codegen_binary(
    binary: &syn::ExprBinary,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> Codegen {
    // Keep the lhs array-indexing metadata: compound assignments on an
    // element (`a[i] += x`) expand through the array/index tokens.
    let lhs = codegen_expr(&binary.left, loop_level, variable_tracker);
    let (lhs, is_comptime_lhs, lhs_array) = (lhs.tokens, lhs.is_comptime, lhs.array_indexing);
    let (rhs, is_comptime_rhs) = codegen_expr(&binary.right, loop_level, variable_tracker).split();

    // When both sides are comptime, emit the original expression unchanged.
    if is_comptime_lhs && is_comptime_rhs {
        return Codegen::new(
            quote::quote! {
                #binary
            },
            true,
        );
    }

    // Each operator maps to its frontend expand function; operands are bound
    // to locals first so each is evaluated exactly once.
    Codegen::new(
        match binary.op {
            syn::BinOp::Add(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::add::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Sub(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::sub::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Mul(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::mul::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Div(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::div::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Rem(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::rem::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Ne(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::ne::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Gt(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::gt::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Ge(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::ge::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Lt(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::lt::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Le(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::le::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Eq(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::eq::expand(context, _lhs, _rhs)
                }
            },
            // Compound assignments pick the array form when the lhs was an
            // indexing expression.
            syn::BinOp::AddAssign(_) => {
                if let Some(array) = lhs_array {
                    let (array, index) = (array.array, array.index);

                    quote::quote! {
                        {
                            let _array = #array;
                            let _index = #index;
                            let _value = #rhs;
                            burn_cube::frontend::add_assign_array_op::expand(context, _array, _index, _value)
                        }
                    }
                } else {
                    quote::quote! {
                        {
                            let _lhs = #lhs;
                            let _rhs = #rhs;
                            burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs)
                        }
                    }
                }
            }
            syn::BinOp::SubAssign(_) => {
                if let Some(array) = lhs_array {
                    let (array, index) = (array.array, array.index);

                    quote::quote! {
                        {
                            let _array = #array;
                            let _index = #index;
                            let _value = #rhs;
                            burn_cube::frontend::sub_assign_array_op::expand(context, _array, _index, _value)
                        }
                    }
                } else {
                    quote::quote! {
                        {
                            let _lhs = #lhs;
                            let _rhs = #rhs;
                            burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs)
                        }
                    }
                }
            }
            syn::BinOp::MulAssign(_) => {
                if let Some(array) = lhs_array {
                    let (array, index) = (array.array, array.index);

                    quote::quote! {
                        {
                            let _array = #array;
                            let _index = #index;
                            let _value = #rhs;
                            burn_cube::frontend::mul_assign_array_op::expand(context, _array, _index, _value)
                        }
                    }
                } else {
                    quote::quote! {
                        {
                            let _lhs = #lhs;
                            let _rhs = #rhs;
                            burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs)
                        }
                    }
                }
            }
            syn::BinOp::DivAssign(_) => {
                if let Some(array) = lhs_array {
                    let (array, index) = (array.array, array.index);

                    quote::quote! {
                        {
                            let _array = #array;
                            let _index = #index;
                            let _value = #rhs;
                            burn_cube::frontend::div_assign_array_op::expand(context, _array, _index, _value)
                        }
                    }
                } else {
                    quote::quote! {
                        {
                            let _lhs = #lhs;
                            let _rhs = #rhs;
                            burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs)
                        }
                    }
                }
            }
            syn::BinOp::And(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::and::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Or(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::or::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::BitAnd(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::bitand::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::BitXor(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::bitxor::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Shl(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::shl::expand(context, _lhs, _rhs)
                }
            },
            syn::BinOp::Shr(_) => quote::quote! {
                {
                    let _lhs = #lhs;
                    let _rhs = #rhs;
                    burn_cube::frontend::shr::expand(context, _lhs, _rhs)
                }
            },
            _ => todo!("Codegen: unsupported op {:?}", binary.op),
        },
        false,
    )
}
/// Codegen for unary operations
pub(crate) fn codegen_unary(
    unary: &syn::ExprUnary,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> Codegen {
    let (inner, is_comptime) = codegen_expr(&unary.expr, loop_level, variable_tracker).split();

    // A comptime operand keeps the original expression unchanged.
    if is_comptime {
        return Codegen::new(
            quote::quote! {
                #unary
            },
            true,
        );
    }

    let tokens = match unary.op {
        syn::UnOp::Not(_) => quote::quote! {
            {
                let _inner = #inner;
                burn_cube::frontend::not::expand(context, _inner)
            }
        },
        _ => todo!("Codegen: unsupported op {:?}", unary.op),
    };

    Codegen::new(tokens, false)
}

View File

@ -1,322 +0,0 @@
use proc_macro2::TokenStream;
use quote::ToTokens;
use syn::{punctuated::Punctuated, FieldValue, Lit, Member, PathArguments, Token};
use crate::{analyzer::KEYWORDS, codegen_function::expr::codegen_expr, tracker::VariableTracker};
use super::base::Codegen;
/// Codegen for literals
pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream {
    // We treat floats differently to avoid getting 4..into() for instance
    if let Lit::Float(_) = lit.lit {
        let value: f32 = lit.lit.to_token_stream().to_string().parse().unwrap();
        quote::quote! { #value }
    } else {
        quote::quote! { #lit }
    }
}
/// Codegen for arrays of literals
pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream {
let mut tokens = quote::quote! {};
for element in array.elems.iter() {
let token = match element {
syn::Expr::Lit(lit) => codegen_lit(lit),
_ => {
return syn::Error::new_spanned(array, "Only arrays of literals are supported")
.into_compile_error()
}
};
tokens.extend(quote::quote! { #token, });
}
quote::quote! { [ #tokens ] }
}
/// Codegen for a local declaration (let ...)
/// Supports:
/// let x = ...
/// let x: T = ...
/// let _ = ...
/// let mut _ = ...
pub(crate) fn codegen_local(
    local: &syn::Local,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    let let_tok = local.let_token;

    // Extract the bound identifier (or `_` for wildcard bindings).
    let ident = match &local.pat {
        syn::Pat::Ident(ident) => ident.to_token_stream(),
        syn::Pat::Type(pat_type) => match &*pat_type.pat {
            syn::Pat::Ident(pat_ident) => pat_ident.to_token_stream(),
            _ => todo!("Codegen: Unsupported typed path {:?}", pat_type.pat),
        },
        syn::Pat::Wild(wild) => wild.underscore_token.to_token_stream(),
        _ => todo!("Codegen: Declaration {:?} is unsupported.", local.pat),
    };

    variable_tracker.codegen_declare(ident.to_string(), loop_level as u8);

    match local.init.as_ref() {
        Some(init) => {
            let (init, is_comptime) =
                codegen_expr(&init.expr, loop_level, variable_tracker).split();

            // The flag was previously tested in two consecutive `if`s; a
            // single branch covers both the tracking and the codegen.
            if is_comptime {
                variable_tracker
                    .set_as_comptime(ident.to_string(), loop_level as u8, None)
                    .unwrap();

                // Comptime values are plain Rust values: no init registration.
                quote::quote! {
                    #let_tok #ident = #init;
                }
            } else {
                // Runtime values are registered with the expansion context.
                quote::quote! {
                    #let_tok #ident = {
                        let _inner = #init;
                        burn_cube::frontend::Init::init(_inner, context)
                    };
                }
            }
        }
        None => {
            quote::quote! {
                #let_tok #ident;
            }
        }
    }
}
/// Codegen for indexed access
pub(crate) fn codegen_index(
    index: &syn::ExprIndex,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> Codegen {
    let lhs = codegen_expr(&index.expr, loop_level, variable_tracker);
    let idx = codegen_expr(&index.index, loop_level, variable_tracker);

    let tokens = quote::quote! {
        {
            let _array = #lhs;
            let _index = #idx;
            burn_cube::frontend::index::expand(context, _array, _index)
        }
    };

    let mut codegen = Codegen::new(tokens, false);

    // Remember the raw array/index tokens so `+=`-style compound assignments
    // can reuse them.
    codegen.array_indexing = Some(super::base::ArrayIndexing {
        array: lhs.tokens,
        index: idx.tokens,
    });

    codegen
}
/// Codegen for assignation
/// Supports:
/// - scalar
/// - indexed array
/// - struct field
pub(crate) fn codegen_assign(
    assign: &syn::ExprAssign,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    match assign.left.as_ref() {
        // `array[index] = value`
        syn::Expr::Index(index) => {
            let array = codegen_expr(&index.expr, loop_level, variable_tracker);
            let index = codegen_expr(&index.index, loop_level, variable_tracker);
            let value = codegen_expr(&assign.right, loop_level, variable_tracker);

            quote::quote! {
                {
                    let _array = #array;
                    let _index = #index;
                    let _value = #value;
                    burn_cube::frontend::index_assign::expand(context, _array, _index, _value)
                }
            }
        }
        // Plain variables and struct fields share the exact same expansion;
        // the two previously duplicated arms are merged here.
        syn::Expr::Path(_) | syn::Expr::Field(_) => {
            let lhs = codegen_expr(&assign.left, loop_level, variable_tracker);
            let rhs = codegen_expr(&assign.right, loop_level, variable_tracker);

            quote::quote! {
                {
                    let _assign_lhs = #lhs;
                    let _assign_rhs = #rhs;
                    burn_cube::frontend::assign::expand(context, _assign_rhs, _assign_lhs)
                }
            }
        }
        _ => todo!("Assign of expr {:?} unsupported", assign.left),
    }
}
/// Codegen for a path used as a variable.
///
/// Multi-segment paths and `None` are emitted as-is; known cube keywords are
/// expanded with the context; everything else is treated as a tracked
/// variable and cloned only when it will be used again.
pub(crate) fn codegen_path_var(
    path: &syn::ExprPath,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> Codegen {
    let ident = match path.path.get_ident() {
        Some(ident) => ident,
        None => return Codegen::new(quote::quote! { #path }, false),
    };

    let name = ident.to_string();

    if name == "None" {
        return Codegen::new(quote::quote! { None }, true);
    }

    if KEYWORDS.contains(&name.as_str()) {
        return Codegen::new(quote::quote! { #ident :: expand(context) }, false);
    }

    // Unknown variables (e.g. declared in a parent scope) default to
    // "will be used again, not comptime", which forces a clone.
    let (will_be_used_again, is_comptime) = variable_tracker
        .codegen_reuse(name, loop_level as u8, None)
        .unwrap_or((true, false));

    let output = if will_be_used_again {
        quote::quote! { #ident.clone() }
    } else {
        quote::quote! { #ident }
    };

    Codegen::new(output, is_comptime)
}
/// Codegen for a field used in rhs of a statement
/// This function adds cloning when necessary
pub(crate) fn codegen_field(
    field: &syn::ExprField,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    // Only `ident_struct.named_field` is supported.
    let attribute = match &field.member {
        Member::Named(ident) => ident,
        _ => todo!("Codegen: unnamed attribute not supported."),
    };

    let base = match &*field.base {
        syn::Expr::Path(struct_expr) => struct_expr
            .path
            .get_ident()
            .expect("Codegen: field access only supported on ident struct."),
        _ => todo!("Codegen: field access only supported on ident struct."),
    };

    let (will_be_used_again, _) = variable_tracker
        .codegen_reuse(
            base.to_string(),
            loop_level as u8,
            Some(attribute.to_string()),
        )
        .unwrap();

    if will_be_used_again {
        quote::quote! { #base . #attribute .clone() }
    } else {
        quote::quote! { #base . #attribute }
    }
}
// Codegen for a struct declaration
//
// Rewrites `path::MyStruct { a, b }` into `path::MyStructExpand { a, b }`,
// preserving any angle-bracketed generics on the final segment, since the
// expanded form of a cube struct is a sibling type named `<Name>Expand`.
pub(crate) fn codegen_struct(
    struct_: &syn::ExprStruct,
    loop_level: usize,
    variable_tracker: &mut VariableTracker,
) -> TokenStream {
    // Split the path into (ident, optional generics) pairs so the final
    // segment can be swapped for its Expand counterpart.
    let mut deconstructed_path = Vec::new();
    for segment in struct_.path.segments.iter() {
        let generics = if let PathArguments::AngleBracketed(arguments) = &segment.arguments {
            Some(arguments)
        } else {
            None
        };
        deconstructed_path.push((&segment.ident, generics));
    }

    let (struct_name, generics) = deconstructed_path
        .pop()
        .expect("At least one ident in the path");

    // This is hacky but using <struct_ as CubeType>::ExpandType {...} is experimental in Rust
    let expanded_struct_name = syn::Ident::new(
        format!("{}Expand", struct_name).as_str(),
        proc_macro2::Span::call_site(),
    );

    deconstructed_path.push((&expanded_struct_name, generics));

    // Reconstruct the path
    let mut path_tokens = quote::quote! {};
    for (ident, angle_bracketed_generics) in deconstructed_path {
        let ident_tokens = ident.to_token_stream();
        let generics_tokens = angle_bracketed_generics.to_token_stream();

        path_tokens.extend(quote::quote! {
            #ident_tokens #generics_tokens
        });
    }

    let fields = codegen_field_creation(&struct_.fields, loop_level, variable_tracker);

    quote::quote! {
        #path_tokens { #fields }
    }
}
fn codegen_field_creation(
fields: &Punctuated<FieldValue, Token![,]>,
loop_level: usize,
variable_tracker: &mut VariableTracker,
) -> TokenStream {
let mut field_tokens = quote::quote! {};
for field in fields.iter() {
let field_name_token = &field.member;
let field_value_token = codegen_expr(&field.expr, loop_level, variable_tracker);
field_tokens.extend(quote::quote! { #field_name_token : #field_value_token, });
}
field_tokens
}

View File

@ -1,110 +0,0 @@
use proc_macro2::TokenStream;
use crate::codegen_common::signature::{expand_sig, ExpandMode};
/// Appends an expand-method declaration for every function of a trait
/// definition, so implementors expose the expanded form as well.
pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream {
    let mut expand_items = Vec::new();

    for item in tr.items.iter() {
        if let syn::TraitItem::Fn(func) = item {
            let expand = expand_sig(
                &func.sig,
                &syn::Visibility::Inherited,
                None,
                ExpandMode::MethodImpl,
            );
            expand_items.push(syn::parse_quote!(#expand;));
        }
    }

    tr.items.append(&mut expand_items);

    quote::quote! { #tr }
}
/// Appends, for every function of a trait impl, a generated expand method
/// (built by `register_expand`) that forwards `context` and the original
/// arguments to the function's `__expand` counterpart.
pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
    let mut expand_items = Vec::new();

    for item in tr.items.iter() {
        match item {
            syn::ImplItem::Fn(func) => {
                // Path of the expand function generated for this method by
                // the nested #[cube] expansion.
                let ident = &func.sig.ident;
                let ident = quote::quote! {#ident::__expand};

                // Forward every typed argument by its pattern name.
                let mut inputs = quote::quote!();
                for input in &func.sig.inputs {
                    match input {
                        syn::FnArg::Typed(pat) => {
                            let ident = pat.pat.clone();
                            inputs.extend(quote::quote! {
                                #ident,
                            });
                        }
                        _ => todo!("Only Typed inputs are supported"),
                    }
                }

                let expand = expand_sig(
                    &func.sig,
                    &syn::Visibility::Inherited,
                    None,
                    ExpandMode::MethodImpl,
                );

                // Impl-level generic parameters are copied onto the function
                // signature so the nested #[cube] function is self-contained.
                let tokens = if !tr.generics.params.is_empty() {
                    let mut func = func.clone();
                    for param in tr.generics.params.iter() {
                        func.sig.generics.params.push(param.clone());
                    }
                    register_expand(&func, &ident, expand, inputs)
                } else {
                    register_expand(func, &ident, expand, inputs)
                };

                expand_items.push(syn::parse2(tokens).unwrap());
            }
            _ => continue,
        }
    }
    tr.items.append(&mut expand_items);

    quote::quote! {
        #tr
    }
}
/// Wraps an impl function into its expand method body: the generated method
/// re-declares the function under `#[cube]` and forwards `context` plus the
/// captured inputs to its `__expand` counterpart (`name`).
///
/// The original duplicated the `quote! { #func }` tuple element in both
/// branches; only the call expression actually differs, so just that is
/// branched on now.
fn register_expand(
    func: &syn::ImplItemFn,
    name: &TokenStream,
    expand: proc_macro2::TokenStream,
    inputs: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
    // Forward the turbofish generics only when the function declares any.
    let func_expand = if func.sig.generics.params.is_empty() {
        quote::quote! {
            #name(context, #inputs)
        }
    } else {
        let (_, gen, _) = &func.sig.generics.split_for_impl();
        quote::quote! {
            #name::#gen(context, #inputs)
        }
    };

    quote::quote! (
        #expand {
            #[cube]
            #func
            #func_expand
        }
    )
}

View File

@ -1,294 +0,0 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::Ident;
use super::GenericsCodegen;
/// Everything needed to generate the companion items of a cube struct derive.
struct TypeCodegen {
    // Original struct name.
    name: syn::Ident,
    // `<Name>Launch` ident for the runtime-argument struct.
    name_launch: syn::Ident,
    // `<Name>Expand` ident for the expand struct.
    name_expand: syn::Ident,
    // Fields of the original struct, mirrored in the generated types.
    fields: Vec<syn::Field>,
    // Helper combining the type's generics with the `'a`/`R` launch generics.
    generics: GenericsCodegen,
    // Visibility of the original struct, reused for generated items.
    vis: syn::Visibility,
}
impl TypeCodegen {
    /// Generate the `<Name>Expand` struct: same fields, each type replaced by
    /// its `CubeType::ExpandType`.
    pub fn expand_ty(&self) -> proc_macro2::TokenStream {
        let mut fields = quote::quote! {};
        let name = &self.name_expand;

        for field in self.fields.iter() {
            let ident = &field.ident;
            let ty = &field.ty;
            let vis = &field.vis;

            fields.extend(quote! {
                #vis #ident: <#ty as CubeType>::ExpandType,
            });
        }

        let generics = self.generics.type_definitions();
        let vis = &self.vis;

        quote! {
            #[derive(Clone)]
            #vis struct #name #generics {
                #fields
            }
        }
    }

    /// Generate the `<Name>Launch` struct holding each field's
    /// `LaunchArg::RuntimeArg<'a, R>`.
    pub fn launch_ty(&self) -> proc_macro2::TokenStream {
        let mut fields = quote::quote! {};
        let name = &self.name_launch;

        for field in self.fields.iter() {
            let ident = &field.ident;
            let ty = &field.ty;
            let vis = &field.vis;

            fields.extend(quote! {
                #vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>,
            });
        }

        let generics = self.generics.all_definitions();

        quote! {
            struct #name #generics {
                #fields
            }
        }
    }

    /// Generate the launch struct's `new` constructor, one argument per field.
    pub fn launch_new(&self) -> proc_macro2::TokenStream {
        let mut args = quote::quote! {};
        let mut fields = quote::quote! {};
        let name = &self.name_launch;

        for field in self.fields.iter() {
            let ident = &field.ident;
            let ty = &field.ty;
            let vis = &field.vis;

            args.extend(quote! {
                #vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>,
            });
            fields.extend(quote! {
                #ident,
            });
        }

        let generics_impl = self.generics.all_definitions();
        let generics_use = self.generics.all_in_use();
        let vis = &self.vis;

        quote! {
            impl #generics_impl #name #generics_use {
                /// New kernel
                #[allow(clippy::too_many_arguments)]
                #vis fn new(#args) -> Self {
                    Self {
                        #fields
                    }
                }
            }
        }
    }

    /// Generate the `ArgSettings<R>` impl, delegating to every field; the
    /// field's declaration position is used as its argument position.
    pub fn arg_settings_impl(&self) -> proc_macro2::TokenStream {
        let mut register_body = quote::quote! {};
        let mut configure_input_body = quote::quote! {};
        let mut configure_output_body = quote::quote! {};
        let name = &self.name_launch;

        for (pos, field) in self.fields.iter().enumerate() {
            let ident = &field.ident;

            register_body.extend(quote! {
                self.#ident.register(launcher);
            });
            configure_input_body.extend(quote! {
                settings = ArgSettings::<R>::configure_input(&self.#ident, #pos, settings);
            });
            configure_output_body.extend(quote! {
                settings = ArgSettings::<R>::configure_output(&self.#ident, #pos, settings);
            });
        }

        let generics_impl = self.generics.all_definitions();
        let generics_use = self.generics.all_in_use();

        quote! {
            impl #generics_impl ArgSettings<R> for #name #generics_use {
                fn register(&self, launcher: &mut KernelLauncher<R>) {
                    #register_body
                }

                fn configure_input(&self, position: usize, mut settings: KernelSettings) -> KernelSettings {
                    #configure_input_body

                    settings
                }

                fn configure_output(&self, position: usize, mut settings: KernelSettings) -> KernelSettings {
                    #configure_output_body

                    settings
                }
            }
        }
    }

    /// Generate the `CubeType` impl tying the original struct to its expand
    /// struct.
    pub fn cube_type_impl(&self) -> proc_macro2::TokenStream {
        let name = &self.name;
        let name_expand = &self.name_expand;
        let generics_impl = self.generics.type_definitions();
        let generics_use = self.generics.type_in_use();

        quote! {
            impl #generics_impl CubeType for #name #generics_use {
                type ExpandType = #name_expand #generics_use;
            }
        }
    }

    /// Generate the `LaunchArg`/`LaunchArgExpand` impls, expanding each field
    /// through its own `LaunchArgExpand` for both inputs and outputs.
    pub fn launch_arg_impl(&self) -> proc_macro2::TokenStream {
        let mut body_input = quote::quote! {};
        let mut body_output = quote::quote! {};
        let name = &self.name;
        let name_launch = &self.name_launch;
        let name_expand = &self.name_expand;

        for field in self.fields.iter() {
            let ident = &field.ident;
            let ty = &field.ty;
            let vis = &field.vis;

            body_input.extend(quote! {
                #vis #ident: <#ty as LaunchArgExpand>::expand(builder, vectorization),
            });
            body_output.extend(quote! {
                #vis #ident: <#ty as LaunchArgExpand>::expand_output(builder, vectorization),
            });
        }

        let type_generics_impl = self.generics.type_definitions();
        let type_generics_use = self.generics.type_in_use();
        let runtime_generics_impl = self.generics.runtime_definitions();
        let all_generics_use = self.generics.all_in_use();

        quote! {
            impl #type_generics_impl LaunchArg for #name #type_generics_use {
                type RuntimeArg #runtime_generics_impl = #name_launch #all_generics_use;
            }

            impl #type_generics_impl LaunchArgExpand for #name #type_generics_use {
                fn expand(
                    builder: &mut KernelBuilder,
                    vectorization: burn_cube::ir::Vectorization,
                ) -> <Self as CubeType>::ExpandType {
                    #name_expand {
                        #body_input
                    }
                }
                fn expand_output(
                    builder: &mut KernelBuilder,
                    vectorization: burn_cube::ir::Vectorization,
                ) -> <Self as CubeType>::ExpandType {
                    #name_expand {
                        #body_output
                    }
                }
            }
        }
    }

    /// Generate the `Init` impl of the expand struct, initializing every
    /// field in the cube context.
    pub fn expand_type_impl(&self) -> proc_macro2::TokenStream {
        let name_expand = &self.name_expand;
        let type_generics_impl = self.generics.type_definitions();
        let type_generics_use = self.generics.type_in_use();

        let mut body = quote::quote! {};
        for field in self.fields.iter() {
            let ident = &field.ident;
            body.extend(quote::quote! {
                #ident: Init::init(self.#ident, context),
            });
        }

        quote! {
            impl #type_generics_impl Init for #name_expand #type_generics_use {
                fn init(self, context: &mut CubeContext) -> Self {
                    Self {
                        #body
                    }
                }
            }
        }
    }
}
/// Entry point of the `CubeType`/`CubeLaunch` derives.
///
/// Always generates the `<Name>Expand` type with its `CubeType`/`Init`
/// impls; when `with_launch` is set, also generates the `<Name>Launch` type,
/// its constructor, and the `ArgSettings`/`LaunchArg` impls.
/// Panics on enums and unions: only structs can be derived.
pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> TokenStream {
    let name = ast.ident.clone();
    let generics = ast.generics.clone();
    let visibility = ast.vis.clone();

    let name_string = name.to_string();
    let name_expand = Ident::new(format!("{}Expand", name_string).as_str(), name.span());
    let name_launch = Ident::new(format!("{}Launch", name_string).as_str(), name.span());

    let mut fields = Vec::new();

    match &ast.data {
        syn::Data::Struct(struct_data) => {
            for field in struct_data.fields.iter() {
                fields.push(field.clone());
            }
        }
        syn::Data::Enum(_) => panic!("Only struct can be derived"),
        syn::Data::Union(_) => panic!("Only struct can be derived"),
    };

    let codegen = TypeCodegen {
        name,
        name_launch,
        name_expand,
        fields,
        generics: GenericsCodegen::new(generics),
        vis: visibility,
    };

    let expand_ty = codegen.expand_ty();
    let launch_ty = codegen.launch_ty();
    let launch_new = codegen.launch_new();

    let cube_type_impl = codegen.cube_type_impl();
    let arg_settings_impl = codegen.arg_settings_impl();
    let launch_arg_impl = codegen.launch_arg_impl();
    let expand_type_impl = codegen.expand_type_impl();

    if with_launch {
        quote! {
            #expand_ty
            #launch_ty
            #launch_new

            #cube_type_impl
            #arg_settings_impl
            #launch_arg_impl
            #expand_type_impl
        }
        .into()
    } else {
        quote! {
            #expand_ty
            #cube_type_impl
            #expand_type_impl
        }
        .into()
    }
}

View File

@ -1,81 +0,0 @@
use proc_macro2::{Span, TokenStream};
use quote::ToTokens;
use syn::{GenericParam, Generics, Ident, Lifetime, LifetimeParam, TypeParam};
/// Combines the generics of a derived type with the synthetic `'a` lifetime
/// and `R: Runtime` parameter required by launch/runtime-argument items.
pub(crate) struct GenericsCodegen {
    // Holds only the `'a` lifetime.
    arg_lifetime: syn::Generics,
    // Holds only the `R: Runtime` type parameter.
    arg_runtime: syn::Generics,
    // Generics declared on the user's type.
    type_gens: syn::Generics,
}
impl GenericsCodegen {
    pub(crate) fn new(type_gens: syn::Generics) -> Self {
        Self {
            arg_lifetime: Self::arg_lifetime(),
            arg_runtime: Self::arg_runtime(),
            type_gens,
        }
    }

    /// Builds generics containing only the `'a` lifetime.
    fn arg_lifetime() -> Generics {
        let mut generics = Generics::default();
        let lifetime =
            GenericParam::Lifetime(LifetimeParam::new(Lifetime::new("'a", Span::call_site())));
        generics.params.push(lifetime);
        generics
    }

    /// Builds generics containing only `R: Runtime`.
    fn arg_runtime() -> Generics {
        let mut generics = Generics::default();
        let mut runtime_param = TypeParam::from(Ident::new("R", Span::call_site()));
        runtime_param
            .bounds
            .push(syn::parse_str("Runtime").unwrap());
        let runtime = GenericParam::Type(runtime_param);
        generics.params.push(runtime);
        generics
    }

    /// The type's own generics with bounds, for `impl`/`struct` positions.
    pub(crate) fn type_definitions(&self) -> TokenStream {
        self.type_gens.to_token_stream()
    }

    /// The type's own generics as a use-site argument list (names only).
    pub(crate) fn type_in_use(&self) -> TokenStream {
        generics_in_use_codegen(self.type_gens.clone())
    }

    /// `R` then `'a` definitions, used on the `RuntimeArg` associated type.
    pub(crate) fn runtime_definitions(&self) -> TokenStream {
        let mut generics = self.arg_runtime.clone();
        generics.params.extend(self.arg_lifetime.params.clone());
        generics.to_token_stream()
    }

    /// `'a`, `R`, then the type's generics, with bounds.
    pub(crate) fn all_definitions(&self) -> TokenStream {
        let mut generics = self.arg_lifetime.clone();
        generics.params.extend(self.arg_runtime.params.clone());
        generics.params.extend(self.type_gens.params.clone());
        generics.to_token_stream()
    }

    /// `'a`, `R`, then the type's generics, as a use-site argument list.
    pub(crate) fn all_in_use(&self) -> TokenStream {
        let mut generics = self.arg_lifetime.clone();
        generics.params.extend(self.arg_runtime.params.clone());
        generics.params.extend(self.type_gens.params.clone());
        generics_in_use_codegen(generics)
    }
}
/// Emits the use-site `<...>` argument list for the given generics, keeping
/// only parameter names (lifetimes and type idents, no bounds).
fn generics_in_use_codegen(generics: Generics) -> TokenStream {
    let mut tokens = quote::quote! {<};

    for param in generics.params.iter() {
        let name = match param {
            GenericParam::Lifetime(lifetime) => lifetime.lifetime.to_token_stream(),
            GenericParam::Type(ty) => ty.ident.to_token_stream(),
            GenericParam::Const(_) => todo!("Const generic not supported"),
        };
        tokens.extend(quote::quote! { #name, });
    }

    tokens.extend(quote::quote! {>});
    tokens
}

View File

@ -1,5 +0,0 @@
mod base;
mod generics;
pub(crate) use base::*;
pub(crate) use generics::*;

View File

@ -1,182 +0,0 @@
#[macro_use]
extern crate derive_new;
mod analyzer;
mod codegen_function;
mod codegen_trait;
mod codegen_type;
mod tracker;
pub(crate) mod codegen_common;
use analyzer::VariableAnalyzer;
use codegen_common::signature::{expand_sig, ExpandMode};
use codegen_function::{codegen_launch, codegen_statement};
use codegen_trait::{expand_trait_def, expand_trait_impl};
use codegen_type::generate_cube_type;
use proc_macro::TokenStream;
use syn::{parse_macro_input, punctuated::Punctuated, token::Comma, Meta};
use tracker::VariableTracker;
/// Output mode of the `#[cube]` attribute macro.
enum CubeMode {
    /// Generates the expanded version of the function
    Default,
    /// Panics and prints the generated code, useful when debugging
    /// Use by writing #[cube(debug)]
    Debug,
}
// Derive macro to define a cube type that is launched with a kernel
#[proc_macro_derive(CubeLaunch)]
pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
    let ast = syn::parse(input).unwrap();

    // `true`: also generate the launch struct and its impls.
    generate_cube_type(&ast, true)
}
// Derive macro to define a cube type that is not launched
#[proc_macro_derive(CubeType)]
pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
    let ast = syn::parse(input).unwrap();

    // `false`: only the expand type and its impls are generated.
    generate_cube_type(&ast, false)
}
/// Flags parsed from the `#[cube(...)]` attribute arguments.
struct SupportedAttributes {
    // `debug` flag: panic-print the generated code instead of emitting it.
    mode: CubeMode,
    // `launch` flag: also generate kernel launch functions.
    launch: bool,
}
/// Attribute macro expanding functions, trait definitions, and trait impls
/// into their cube form. Accepts the `debug` and `launch` flags.
#[proc_macro_attribute]
pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
    let args = parse_macro_input!(attr with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
    let attrs = parse_attributes(&args);

    let code: TokenStream = match syn::parse::<syn::Item>(tokens).unwrap() {
        syn::Item::Fn(func) => cube_fn(func, &attrs),
        syn::Item::Impl(item) => expand_trait_impl(item).into(),
        syn::Item::Trait(item) => expand_trait_def(item).into(),
        _ => panic!("Cube annotations only supported for functions"),
    };

    match attrs.mode {
        CubeMode::Default => code,
        // Debug mode: panic so the compiler prints the generated code.
        CubeMode::Debug => panic!("{code}"),
    }
}
/// Expands a single `#[cube]` function; analysis errors become compile
/// errors in the emitted tokens.
fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream {
    let mut tracker = VariableAnalyzer::create_tracker(&func);

    match codegen_cube(&func, &mut tracker, attrs.launch) {
        Ok(code) => code.into(),
        Err(err) => err.into(),
    }
}
/// Parses the `#[cube(...)]` argument list; only the bare `debug` and
/// `launch` idents are accepted.
fn parse_attributes(args: &Punctuated<Meta, Comma>) -> SupportedAttributes {
    let mut mode = CubeMode::Default;
    let mut launch = false;

    for arg in args.iter() {
        match arg {
            Meta::Path(path) => {
                let ident = match path.get_ident().map(|id| id.to_string()) {
                    Some(ident) => ident,
                    None => panic!("Only ident attribute supported"),
                };
                match ident.as_str() {
                    "debug" => mode = CubeMode::Debug,
                    "launch" => launch = true,
                    _ => panic!("Attribute {ident} is not supported"),
                }
            }
            Meta::List(_) => panic!("No List attribute supported"),
            Meta::NameValue(_) => panic!("No NameValue attribute supported"),
        }
    }

    SupportedAttributes { mode, launch }
}
/// Generate the expanded version of a function marked with the cube macro
///
/// Returns `Ok` with the original function plus a module of the same name
/// containing the expand (and optionally launch) functions, or `Err` with
/// the original function plus compile errors collected during analysis.
fn codegen_cube(
    func: &syn::ItemFn,
    variable_tracker: &mut VariableTracker,
    launch: bool,
) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
    let signature = expand_sig(
        &func.sig,
        &syn::Visibility::Public(Default::default()), // Always public, otherwise we can't import
        // it from an outside module.
        Some(variable_tracker),
        ExpandMode::FuncImpl,
    );

    // Expand the function body statement by statement at loop level 0.
    let mut body = quote::quote! {};
    for statement in func.block.stmts.iter() {
        let tokens = codegen_statement(statement, 0, variable_tracker);
        body.extend(tokens);
    }

    let is_in_error = !variable_tracker.errors.is_empty();

    if is_in_error {
        // When there is an error, we don't generate the expand method, since it's only going to
        // create more errors that won't help fixing the issue.

        let mut code = quote::quote! {
            #[allow(dead_code)]
            #[allow(clippy::too_many_arguments)]
            #func
        };

        for err in variable_tracker.errors.drain(..) {
            code.extend(err.into_compile_error());
        }

        return Err(code);
    }

    let launch_doc = if launch {
        "and launch functions "
    } else {
        "function "
    };

    let launch = if launch {
        codegen_launch(&func.sig)
    } else {
        quote::quote! {}
    };

    let mod_name = &func.sig.ident;
    let vis = &func.vis;
    let doc = format!("Module containing the expand {launch_doc}of {mod_name}.");

    Ok(quote::quote! {
        #[allow(dead_code)]
        #[allow(clippy::too_many_arguments)]
        #func

        #[doc = #doc]
        #vis mod #mod_name {
            use super::*;

            #launch

            #[allow(unused_mut)]
            #[allow(clippy::too_many_arguments)]
            #signature {
                #body
            }
        }
    })
}

View File

@ -1,244 +0,0 @@
use std::collections::HashMap;
#[derive(new, Hash, PartialEq, Eq, Debug, Clone)]
/// Identifies a variable uniquely
pub struct VariableIdent {
name: String,
repeat: u8,
scope: u8,
field: Option<String>,
}
#[derive(new, Eq, PartialEq, Hash, Debug)]
/// Identifies a variable, with possible collisions when variables are redeclared
struct VariableKey {
    name: String,
    // Loop level of the declaration; no repeat counter, hence the collisions.
    scope: u8,
}
#[derive(Debug, Default)]
/// Tracks variable uses
pub(crate) struct VariableTracker {
    // For each variable name, every scope in which it was declared.
    scopes_declared: HashMap<String, Vec<u8>>,
    // Redeclaration counters accumulated during the analysis pass.
    analysis_repeats: HashMap<VariableKey, u8>,
    // Redeclaration counters accumulated again during the codegen pass.
    codegen_repeats: HashMap<VariableKey, u8>,
    // Remaining-use bookkeeping per unique variable (and field) identity.
    variable_uses: HashMap<VariableIdent, VariableUse>,
    // Errors gathered during analysis, emitted as compile errors.
    pub errors: Vec<syn::Error>,
}
#[derive(Debug, Default)]
/// Encapsulates number of uses and whether this implies cloning
pub(crate) struct VariableUse {
    // Remaining number of uses; decremented as codegen consumes them.
    pub num_used: usize,
    // Whether the variable holds a comptime value.
    pub is_comptime: bool,
}

impl VariableUse {
    // A variable read while more uses remain must be cloned.
    pub fn should_clone(&self) -> bool {
        self.num_used > 1
    }
}
impl VariableTracker {
    /// During analysis, tracks a variable declaration
    ///
    /// Records the declaring scope, bumps the redeclaration counter for
    /// (name, scope), and starts the use count at 1.
    pub(crate) fn analyze_declare(&mut self, name: String, scope: u8, is_comptime: bool) {
        if let Some(scopes) = self.scopes_declared.get_mut(&name) {
            if !scopes.contains(&scope) {
                scopes.push(scope);
            }
        } else {
            self.scopes_declared.insert(name.clone(), vec![scope]);
        }

        let key = VariableKey::new(name.clone(), scope);
        // First declaration gets repeat 0; later ones increment the counter.
        let repeat = if let Some(count) = self.analysis_repeats.get_mut(&key) {
            *count += 1;
            *count
        } else {
            self.analysis_repeats.insert(key, 0);
            0
        };

        let analysis = VariableUse {
            num_used: 1,
            is_comptime,
        };
        let variable_ident = VariableIdent::new(name, repeat, scope, None);
        self.variable_uses.insert(variable_ident, analysis);
    }

    /// During analysis, tracks a variable use
    pub(crate) fn analyze_reuse(&mut self, ident: &syn::Ident, scope: u8, field: Option<String>) {
        let name = ident.to_string();

        // `None` is a literal, not a variable.
        if name == "None" {
            return;
        }

        let scopes_declared = match self.scopes_declared.get(&name) {
            Some(val) => val,
            None => {
                self.errors
                    .push(syn::Error::new_spanned(ident, "Variable not declared"));
                return;
            }
        };

        // Resolve to the innermost enclosing scope that declared the name.
        let scope = *scopes_declared
            .iter()
            .filter(|s| **s <= scope)
            .max()
            .unwrap();
        let key = VariableKey::new(name.clone(), scope);

        // If the name and scope do not match a declared variable,
        // then we are using a variable declared in a parent scope, and
        // cloning must always happen, therefore no need for further analysis
        if let Some(repeat) = self.analysis_repeats.get(&key) {
            let variable = VariableIdent::new(name, *repeat, scope, field);
            self.analyze(&variable);
        }
    }

    /// Increments variable use and its parent struct if need be
    fn analyze(&mut self, variable_ident: &VariableIdent) {
        match self.variable_uses.get_mut(variable_ident) {
            Some(variable_use) => {
                variable_use.num_used += 1;
            }
            None => {
                // If variable was not inserted yet, it must be a field
                if variable_ident.field.is_some() {
                    let mut parent_ident = variable_ident.clone();
                    parent_ident.field = None;
                    let parent = self.variable_uses.get(&parent_ident).unwrap();

                    // A field inherits the comptime status of its struct.
                    let attr_analysis = VariableUse {
                        num_used: 1,
                        is_comptime: parent.is_comptime,
                    };
                    self.variable_uses
                        .insert(variable_ident.clone(), attr_analysis);
                } else {
                    panic!("Variable not declared");
                }
            }
        };

        // Whether a field was previously seen or not, we must increase the use of the parent struct
        if variable_ident.field.is_some() {
            let mut declaration_ident = variable_ident.clone();
            declaration_ident.field = None;
            let declaration = self
                .variable_uses
                .get_mut(&declaration_ident)
                .unwrap_or_else(|| panic!("Struct {:?} does not exist", declaration_ident));
            declaration.num_used += 1;
        }
    }

    /// During codegen, tracks a variable declaration.
    /// This must be done again to know on what repeat a use occurs
    pub(crate) fn codegen_declare(&mut self, name: String, scope: u8) {
        let key = VariableKey::new(name.clone(), scope);
        if let Some(count) = self.codegen_repeats.get_mut(&key) {
            *count += 1;
        } else {
            self.codegen_repeats.insert(key, 0);
        }
    }

    /// During codegen, tracks a variable use.
    ///
    /// Decrements the remaining-use counters (of the field and, when
    /// applicable, its parent struct) and returns
    /// `(should_clone, is_comptime)`.
    pub(crate) fn codegen_reuse(
        &mut self,
        name: String,
        scope: u8,
        field: Option<String>,
    ) -> Result<(bool, bool), VariableReuseError> {
        let scopes_declared = self
            .scopes_declared
            .get(&name)
            .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?;
        // Resolve to the innermost enclosing scope that declared the name.
        let scope_declared = *scopes_declared
            .iter()
            .filter(|s| **s <= scope)
            .max()
            .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?;

        let key = VariableKey::new(name.clone(), scope_declared);
        let repeat = self.codegen_repeats.get(&key).unwrap_or(&0);
        let ident = VariableIdent::new(name.clone(), *repeat, scope_declared, field.clone());

        // A field use also consumes one use of the parent struct.
        let should_clone_parent = if field.is_some() {
            let struct_ident = VariableIdent::new(name.clone(), *repeat, scope_declared, None);
            let parent_analysis = self
                .variable_uses
                .get_mut(&struct_ident)
                .ok_or_else(|| VariableNotFound::new(name.clone(), scope_declared, None))?;

            parent_analysis.num_used -= 1;
            parent_analysis.should_clone()
        } else {
            false
        };

        let analysis = self
            .variable_uses
            .get_mut(&ident)
            .ok_or_else(|| VariableNotFound::new(name, scope_declared, field))?;

        analysis.num_used -= 1;
        // Uses from a deeper scope than the declaration always clone.
        let should_clone =
            analysis.should_clone() || should_clone_parent || scope_declared != scope;
        Ok((should_clone, analysis.is_comptime))
    }

    /// Marks an already-declared variable (or field) as comptime.
    pub fn set_as_comptime(
        &mut self,
        name: String,
        scope: u8,
        field: Option<String>,
    ) -> Result<(), VariableReuseError> {
        let scopes_declared = self
            .scopes_declared
            .get(&name)
            .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?;
        let scope_declared = *scopes_declared
            .iter()
            .filter(|s| **s <= scope)
            .max()
            .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?;

        let key = VariableKey::new(name.clone(), scope_declared);
        let repeat = self.codegen_repeats.get(&key).unwrap_or(&0);
        let ident = VariableIdent::new(name.clone(), *repeat, scope_declared, field.clone());

        let analysis = self
            .variable_uses
            .get_mut(&ident)
            .ok_or_else(|| VariableNotFound::new(name, scope_declared, field))?;

        analysis.is_comptime = true;

        Ok(())
    }
}
#[derive(new, Debug)]
/// Raised when reuse tracking cannot resolve a variable declaration.
/// Fields are unused except through the `Debug` output, hence the prefixes.
pub struct VariableNotFound {
    _name: String,
    _scope: u8,
    _field: Option<String>,
}

#[derive(Debug)]
#[allow(dead_code)]
/// Errors that can occur while tracking variable reuse.
pub enum VariableReuseError {
    VariableNotFound(VariableNotFound),
}

impl From<VariableNotFound> for VariableReuseError {
    fn from(value: VariableNotFound) -> Self {
        Self::VariableNotFound(value)
    }
}

View File

@ -1,37 +0,0 @@
[package]
authors = [
"nathanielsimard <nathaniel.simard.42@gmail.com>",
"louisfd <louisfd94@gmail.com>",
]
categories = ["science"]
description = "Cube Compute Language (CubeCL) is a subset of Rust that can be executed on accelerators for compute intensive tasks."
edition.workspace = true
keywords = []
license.workspace = true
name = "burn-cube"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/burn-cube"
version.workspace = true
[features]
default = ["tensor"]
std = []
template = []
tensor = ["burn-tensor"]
export_tests = []
[dependencies]
burn-compute = { path = "../burn-compute", version = "0.14.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.14.0", default-features = false, optional = true }
bytemuck = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
serde = { workspace = true }
burn-cube-macros = { path = "../burn-cube-macros", version = "0.14.0" }
derive-new = { workspace = true }
num-traits = { workspace = true }
log = { workspace = true }
[dev-dependencies]
trybuild = "1"

View File

@ -1,202 +0,0 @@
<div align="center">
<img src="../burn-cube/assets/logo.drawio.svg" width="400px"/>
<br />
<br />
[![Rust Version](https://img.shields.io/badge/Rust-1.79.0+-blue)](https://releases.rs/docs/1.79.0)
![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)
---
**Multi-platform high-performance compute language extension for Rust.**
<br/>
</div>
## TL;DR
With CubeCL, you can program your GPU using Rust leveraging zero-cost abstraction to create maintainable, flexible and optimal compute kernels.
## Motivation
The goal of CubeCL is to ease the pain of writing highly optimized compute kernels that are portable across hardware.
There is currently no adequate solution when you want optimal performance while still being multi-platform.
You either have to write custom kernels for different hardware, often with different languages such as CUDA, Metal, or ROCm.
To fix this, we created a Just-in-Time compiler with three core features: **automatic vectorization**, **comptime**, and **autotune**!
These features are extremely useful for anyone writing high-performance kernels, even when portability is not a concern.
They improve code composability, reusability, testability, and maintainability, all while staying optimal.
### Disclaimer & History
CubeCL is currently in **alpha**.
The only supported runtimes are CUDA and WebGPU for now.
It's easy to add more GPU runtimes and we intend to support Metal, ROCm, and Vulkan; contributions are welcome!
We also want to have an optimized JIT CPU runtime with SIMD instructions, leveraging [Cranelift](https://cranelift.dev).
While CubeCL is currently in use in [Burn](https://burn.dev), there are still a lot of rough edges; it isn't refined yet.
The project started as a WebGPU-only backend for Burn.
As we optimized it, we realized that we needed an intermediate representation (IR) that could be optimized then compiled to WGSL.
Having an IR made it easy to support another compilation target, so we made a CUDA runtime.
However, writing kernels directly in that IR wasn't easy, so we created a Rust frontend using the [syn](https://github.com/dtolnay/syn) crate.
Navigating the differences between CUDA and WebGPU, while leveraging both platforms, forced us to come up with general concepts that worked everywhere.
Hence, CubeCL was born!
## Design
CubeCL is designed around - you guessed it - Cubes! More specifically, it's based on cuboids, because not all axes are the same size.
Since all compute APIs need to map to the hardware, which are tiles that can be accessed using a 3D representation, our topology can easily be mapped to concepts from other APIs.
<div align="center">
### CubeCL - Topology
<img src="./assets/cubecl.drawio.svg" width="100%"/>
<br />
</div>
<br />
_A cube is composed of units, so a 3x3x3 cube has 27 units that can be accessed by their positions along the x, y, and z axes.
Similarly, a hyper-cube is composed of cubes, just as a cube is composed of units.
Each cube in the hyper-cube can be accessed by its position relative to the hyper-cube along the x, y, and z axes.
Hence, a hyper-cube of 3x3x3 will have 27 cubes.
In this example, the total number of working units would be 27 x 27 = 729._
<details>
<summary>Topology Equivalence 👇</summary>
<br />
Since all topology variables are constant within the kernel entry point, we chose to use the Rust constant syntax with capital letters.
Often when creating kernels, we don't always care about the relative position of a unit within a cube along each axis, but often we only care about its position in general.
Therefore, each kind of variable also has its own axis-independent variable, which is often not present in other languages, except WebGPU with `local_invocation_index`.
<br />
| CubeCL | CUDA | WebGPU |
| -------------- | ----------- | ---------------------- |
| CUBE_COUNT | N/A | N/A |
| CUBE_COUNT_X | gridDim.x | num_workgroups.x |
| CUBE_COUNT_Y | gridDim.y | num_workgroups.y |
| CUBE_COUNT_Z | gridDim.z | num_workgroups.z |
| CUBE_POS | N/A | N/A |
| CUBE_POS_X | blockIdx.x | workgroup.x |
| CUBE_POS_Y | blockIdx.y | workgroup.y |
| CUBE_POS_Z | blockIdx.z | workgroup.z |
| CUBE_DIM | N/A | N/A |
| CUBE_DIM_X | blockDim.x | workgroup_size.x |
| CUBE_DIM_Y | blockDim.y | workgroup_size.y |
| CUBE_DIM_Z | blockDim.z | workgroup_size.z |
| UNIT_POS | N/A | local_invocation_index |
| UNIT_POS_X | threadIdx.x | local_invocation_id.x |
| UNIT_POS_Y | threadIdx.y | local_invocation_id.y |
| UNIT_POS_Z | threadIdx.z | local_invocation_id.z |
| SUBCUBE_DIM | warpSize | subgroup_size |
| ABSOLUTE_POS | N/A | N/A |
| ABSOLUTE_POS_X | N/A | global_id.x |
| ABSOLUTE_POS_Y | N/A | global_id.y |
| ABSOLUTE_POS_Z | N/A | global_id.z |
</details>
## Special Features
#### Automatic Vectorization
High-performance kernels should rely on SIMD instructions whenever possible, but doing so can quickly get pretty complicated!
With CubeCL, you can specify the vectorization factor of each input variable when launching a kernel.
Inside the kernel code, you still use only one type, which is dynamically vectorized and supports automatic broadcasting.
The runtimes are able to compile kernels and have all the necessary information to use the best instruction!
However, since the algorithmic behavior may depend on the vectorization factor, CubeCL allows you to access it directly in the kernel when needed, without any performance loss, using the comptime system!
#### Comptime
CubeCL isn't just a new compute language: though it feels like you are writing GPU kernels, you are, in fact, writing compiler plugins that you can fully customize!
Comptime is a way to modify the compiler IR at runtime when compiling a kernel for the first time.
This enables lots of optimizations and flexibility without having to write many separate variants of the same kernels to ensure maximal performance.
| Feature | Description |
| ------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Instruction Specialization** | Not all instructions are available on all hardware, but when a specialized one exists, it should be enabled with a simple if statement. |
| **Automatic Vectorization** | When you can use SIMD instructions, you should! But since not all hardware supports the same vectorization factors, it can be injected at runtime! |
| **Loop Unrolling** | You may want multiple flavors of the same kernel, with loop unrolling for only a certain range of values. This can be configured easily with Comptime. |
| **Shape Specialization** | For deep learning kernels, it's often crucial to rely on different kernels for different input sizes; you can do it by passing the shape information as Comptime values. |
| **Compile Time Calculation** | In general, you can calculate a constant using Rust runtime properties and inject it into a kernel during its compilation, to avoid recalculating it during each execution. |
#### Autotuning
Autotuning drastically simplifies kernel selection by running small benchmarks at runtime to figure out the best kernels with the best configurations to run on the current hardware; an essential feature for portability.
This feature combines gracefully with comptime to test the effect of different comptime values on performance; sometimes it can be surprising!
Even if the benchmarks may add some overhead when running the application for the first time, the information gets cached on the device and will be reused.
It is usually a no-brainer trade-off for throughput-oriented programs such as deep learning models.
You can even ship the autotune cache with your program, reducing cold start time when you have more control over the deployment target.
## Example
CubeCL is designed to be easy to use for Rust programmers: it relies on the same syntax and is fully integrated with the language.
You can simply add an attribute on the top of a Rust function for it to be executed on the GPU.
```rust
#[cube(launch)]
fn gelu<F: Float>(input: &Array<F>, output: &mut Array<F>) {
if ABSOLUTE_POS < input.len() {
let x = input[ABSOLUTE_POS];
let gelu = x * (1 + erf(x / sqrt(2))) / 2;
output[ABSOLUTE_POS] = gelu;
}
}
fn main() {
type Runtime = CudaRuntime;
let device = Default::default();
let client = Runtime::client(&device);
let input_handle = client.create(f32::as_bytes(&[-1., 0., 1., 5.]));
let output_handle = client.empty(4 * core::mem::size_of::<f32>());
gelu::launch::<F32, Runtime>(
client,
CubeCount::new(1, 1, 1),
CubeDim::new(4, 1, 1),
&input_handle,
&output_handle,
);
let output = client.read(output_handle.binding()).read_sync().unwrap();
let output = f32::from_bytes(&output);
// Should be [-0.1587, 0.0000, 0.8413, 5.0000]
println!("{output:?}");
}
```
The `cube` attribute generates the code that is needed to compile a kernel.
In the example above, the functions `gelu_expand` and `gelu_launch` are automatically generated.
This allows you to compose Cube functions easily:
```rust
#[cube]
fn gelu_scalar<F: Float>(x: F) -> F {
x * (1 + erf(x / sqrt(2))) / 2
}
#[cube(launch)]
fn gelu<F: Float>(input: Array<F>, mut output: Array<F>) {
if ABSOLUTE_POS < input.shape(0) {
output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
}
}
```
Note that you don't have to specify `launch` in a function that is only used by another Cube function.
In addition, you can have return types without problem, which isn't the case when you are writing an entry point to a kernel using the `launch` attribute.
The function `gelu_expand` will actually use `gelu_scalar_expand`, making it easy to combine your functions.
## Resource
If you have any questions or want to contribute, don't hesitate to join the [Discord](https://discord.gg/uPEBbYYDB6).

Binary file not shown.

Before

Width:  |  Height:  |  Size: 32 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 240 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 807 KiB

View File

@ -1,21 +0,0 @@
use crate::ir::{Elem, KernelDefinition};
use std::fmt::Display;
/// Trait for compiled code representation.
///
/// Implementors can be rendered to source text via [`Display`] and can report
/// how much shared memory the compiled kernel requires.
pub trait CompilerRepresentation: Display {
    /// Computes and returns the shared memory size.
    fn shared_memory_size(&self) -> usize;
}
/// Compiles the representation into its own representation that can be formatted into tokens.
///
/// Each runtime provides one compiler; the `Default + Clone + Send + Sync`
/// bounds let compilers be created and shared freely across launch machinery.
pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
    /// The representation for the compiled code.
    type Representation: CompilerRepresentation;
    /// Compiles the [kernel definition](KernelDefinition) into the compiler's representation.
    fn compile(kernel: KernelDefinition) -> Self::Representation;
    /// The size of the given element in bytes.
    fn elem_size(elem: Elem) -> usize;
    /// The maximal size of shared memory supported by the target.
    fn max_shared_memory_size() -> usize;
}

View File

@ -1,358 +0,0 @@
use crate::compute::{CubeCount, KernelTask};
use crate::frontend::TensorHandle;
use crate::ir::Elem;
use crate::pod::CubeElement;
use crate::{calculate_cube_count_elemwise, Kernel, Runtime, SUBCUBE_DIM_APPROX};
use burn_compute::client::ComputeClient;
use burn_compute::server::{Binding, ComputeServer, Handle};
/// The position of the input or output to calculate the number of cubes to launch.
pub enum CubeCountSettings<S: ComputeServer> {
    /// Derive the cube count from the element count of the input at `pos`.
    Input { pos: usize },
    /// Derive the cube count from the element count of the output at `pos`.
    Output { pos: usize },
    /// Launch with an explicitly provided cube count.
    Custom(CubeCount<S>),
}
/// Builder used to launch a kernel with tensor handles and up to three groups
/// of scalars. The `Scalars` type parameter tracks at compile time how many
/// scalar groups have been registered (unit or a tuple of slices).
pub struct Execution<'h, K, R: Runtime, Scalars> {
    // Scalar groups registered so far, encoded in the type.
    scalars: Scalars,
    // Client used to create buffers and submit the kernel.
    client: ComputeClient<R::Server, R::Channel>,
    // The kernel to launch.
    kernel: K,
    // Input tensor handles.
    inputs: &'h [TensorHandle<'h, R>],
    // Output tensor handles.
    outputs: &'h [TensorHandle<'h, R>],
}
impl<'h, K, R: Runtime> Execution<'h, K, R, ()> {
    /// Begins building an execution from a kernel and a compute client, with no
    /// inputs, outputs or scalars registered yet.
    pub fn start(
        kernel: K,
        client: ComputeClient<R::Server, R::Channel>,
    ) -> Execution<'h, K, R, ()> {
        Execution {
            scalars: (),
            client,
            kernel,
            inputs: &[],
            outputs: &[],
        }
    }
    /// Registers the input tensor handles for the launch.
    #[allow(unused)]
    pub fn inputs(self, inputs: &'h [TensorHandle<'h, R>]) -> Execution<'h, K, R, ()> {
        Execution { inputs, ..self }
    }
    /// Registers the output tensor handles for the launch.
    pub fn outputs(self, outputs: &'h [TensorHandle<'h, R>]) -> Execution<'h, K, R, ()> {
        Execution { outputs, ..self }
    }
}
impl<'h, K, R> Execution<'h, K, R, ()>
where
    K: Kernel + 'static,
    R: Runtime,
{
    /// Attaches a first group of scalars to the execution.
    pub fn with_scalars<E>(self, scalars: &[E]) -> Execution<'h, K, R, (&[E],)> {
        let Self {
            client,
            kernel,
            inputs,
            outputs,
            ..
        } = self;
        Execution {
            scalars: (scalars,),
            client,
            kernel,
            inputs,
            outputs,
        }
    }
    /// Execute a dynamic kernel without any scalars.
    #[allow(unused)]
    pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        execute_dynamic::<R, K, f32, f32, f32>(
            self.inputs,
            self.outputs,
            None,
            None,
            None,
            self.kernel,
            launch,
            self.client,
        )
    }
}
impl<'h, 'a, K, R, E> Execution<'h, K, R, (&'a [E],)>
where
    K: Kernel + 'static,
    R: Runtime,
    E: CubeElement,
{
    /// Attaches a second group of scalars to the execution.
    pub fn with_scalars<'b, E2>(
        self,
        scalars: &'b [E2],
    ) -> Execution<'h, K, R, (&'a [E], &'b [E2])> {
        let (first,) = self.scalars;
        Execution {
            scalars: (first, scalars),
            client: self.client,
            kernel: self.kernel,
            inputs: self.inputs,
            outputs: self.outputs,
        }
    }
    /// Execute a dynamic kernel with one group of scalars.
    #[allow(unused)]
    pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        let (scalars,) = self.scalars;
        execute_dynamic::<R, K, E, f32, f32>(
            self.inputs,
            self.outputs,
            Some(scalars),
            None,
            None,
            self.kernel,
            launch,
            self.client,
        )
    }
}
impl<'h, 'a, 'b, K, R, E1, E2> Execution<'h, K, R, (&'a [E1], &'b [E2])>
where
    K: Kernel + 'static,
    R: Runtime,
    E1: CubeElement,
    E2: CubeElement,
{
    /// Attaches a third group of scalars to the execution.
    #[allow(unused, clippy::type_complexity)]
    pub fn with_scalars<'c, E3>(
        self,
        scalars: &'c [E3],
    ) -> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])> {
        let (first, second) = self.scalars;
        Execution {
            scalars: (first, second, scalars),
            client: self.client,
            kernel: self.kernel,
            inputs: self.inputs,
            outputs: self.outputs,
        }
    }
    /// Execute a dynamic kernel with two groups of scalars.
    #[allow(clippy::too_many_arguments)]
    pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        let (first, second) = self.scalars;
        execute_dynamic::<R, K, E1, E2, f32>(
            self.inputs,
            self.outputs,
            Some(first),
            Some(second),
            None,
            self.kernel,
            launch,
            self.client,
        )
    }
}
impl<'h, 'a, 'b, 'c, K, R, E1, E2, E3> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])>
where
    K: Kernel + 'static,
    R: Runtime,
    E1: CubeElement,
    E2: CubeElement,
    E3: CubeElement,
{
    /// Execute a dynamic kernel with three groups of scalars.
    #[allow(unused)]
    pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        let (first, second, third) = self.scalars;
        execute_dynamic::<R, K, E1, E2, E3>(
            self.inputs,
            self.outputs,
            Some(first),
            Some(second),
            Some(third),
            self.kernel,
            launch,
            self.client,
        )
    }
}
/// Resolves all bindings and the cube count, then submits the kernel.
#[allow(clippy::too_many_arguments)]
fn execute_dynamic<R, K, E1, E2, E3>(
    inputs: &[TensorHandle<R>],
    outputs: &[TensorHandle<R>],
    scalars_1: Option<&[E1]>,
    scalars_2: Option<&[E2]>,
    scalars_3: Option<&[E3]>,
    kernel: K,
    launch: CubeCountSettings<R::Server>,
    client: ComputeClient<R::Server, R::Channel>,
) where
    K: Kernel + 'static,
    R: Runtime,
    E1: CubeElement,
    E2: CubeElement,
    E3: CubeElement,
{
    let settings = execute_settings(
        inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
    );
    // Binding order: tensors, then the info buffer, then the scalar buffers.
    let mut bindings = settings.handles_tensors;
    bindings.push(settings.handle_info.binding());
    bindings.extend(
        settings
            .handles_scalars
            .into_iter()
            .map(|handle| handle.binding()),
    );
    client.execute(
        Box::new(KernelTask::<R::Compiler, K>::new(kernel)),
        settings.cube_count,
        bindings,
    );
}
/// Everything resolved ahead of a kernel launch: tensor bindings, the metadata
/// handle, scalar handles and the cube count.
struct ExecuteSettings<R: Runtime> {
    // Bindings for input and output tensors, in registration order.
    handles_tensors: Vec<Binding<R::Server>>,
    // Handle of the `info` buffer (strides/shapes/lengths metadata).
    handle_info: Handle<R::Server>,
    // One handle per scalar element type, ordered float, int, uint.
    handles_scalars: Vec<Handle<R::Server>>,
    // Number of cubes to launch.
    cube_count: CubeCount<R::Server>,
}
/// Builds everything needed to launch the kernel: the metadata (`info`) buffer,
/// the tensor and scalar handles, and the cube count.
fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
    inputs: &'a [TensorHandle<R>],
    outputs: &'a [TensorHandle<R>],
    scalars_1: Option<&[E1]>,
    scalars_2: Option<&[E2]>,
    scalars_3: Option<&[E3]>,
    launch: CubeCountSettings<R::Server>,
    client: &ComputeClient<R::Server, R::Channel>,
) -> ExecuteSettings<R> {
    let mut info = Vec::new();
    let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2);
    // Inner function to fill the info buffer: the rank is written once (from the
    // first registered tensor), then each tensor contributes its strides
    // followed by its shape.
    let mut register_info_tensor = |strides: &[usize], shape: &[usize]| {
        if info.is_empty() {
            info.push(strides.len() as u32);
        }
        for s in strides.iter() {
            info.push(*s as u32);
        }
        for s in shape.iter() {
            info.push(*s as u32);
        }
    };
    let mut num_elems_output = 0;
    // We start by registering the inputs.
    for (i, input) in inputs.iter().enumerate() {
        if let CubeCountSettings::Input { pos } = &launch {
            if i == *pos {
                num_elems_output = calculate_num_elems_dyn_rank(input.shape);
            }
        };
        register_info_tensor(input.strides, input.shape);
        handles.push(input.handle.clone().binding());
    }
    // Then we follow with the outputs.
    for (i, output) in outputs.iter().enumerate() {
        if let CubeCountSettings::Output { pos } = &launch {
            if i == *pos {
                num_elems_output = calculate_num_elems_dyn_rank(output.shape);
            }
        };
        register_info_tensor(output.strides, output.shape);
        handles.push(output.handle.clone().binding());
    }
    // Full info layout: [rank, I0 strides, I0 shape, I1 ..., O0 ..., I0 len, I1 len, O0 len],
    // the per-tensor lengths being appended only when the runtime requires them.
    if R::require_array_lengths() {
        for input in inputs.iter() {
            let len = calculate_num_elems_dyn_rank(input.shape);
            info.push(len as u32);
        }
        for output in outputs.iter() {
            let len = calculate_num_elems_dyn_rank(output.shape);
            info.push(len as u32);
        }
    }
    let info = client.create(bytemuck::cast_slice(&info));
    // Finally we finish with the named bindings.
    let handles_scalars =
        create_scalar_handles::<R, E1, E2, E3>(scalars_1, scalars_2, scalars_3, client);
    let cube_count = match launch {
        CubeCountSettings::Custom(count) => count,
        // Input/Output settings derive the count from the selected tensor's element count.
        _ => calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX),
    };
    ExecuteSettings {
        handles_tensors: handles,
        handle_info: info,
        handles_scalars,
        cube_count,
    }
}
/// Creates one buffer per provided scalar group, ordered by element priority.
fn create_scalar_handles<R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
    scalars_0: Option<&[E1]>,
    scalars_1: Option<&[E2]>,
    scalars_2: Option<&[E3]>,
    client: &ComputeClient<R::Server, R::Channel>,
) -> Vec<Handle<R::Server>> {
    // It is crucial that scalars follow this order: float, int, uint
    let element_priority = |elem: Elem| match elem {
        Elem::Float(_) => 0,
        Elem::Int(_) => 1,
        Elem::UInt => 2,
        Elem::Bool => panic!("Bool scalars are not supported"),
    };
    let scalar_priorities: [usize; 3] = [
        element_priority(E1::cube_elem()),
        element_priority(E2::cube_elem()),
        element_priority(E3::cube_elem()),
    ];
    let mut handles_scalars = Vec::new();
    // Visit priorities in ascending order and emit the matching group(s).
    for priority in 0..3 {
        for (position, scalar_priority) in scalar_priorities.iter().enumerate() {
            if *scalar_priority != priority {
                continue;
            }
            let handle = match position {
                0 => scalars_0.map(|values| client.create(bytemuck::cast_slice(values))),
                1 => scalars_1.map(|values| client.create(bytemuck::cast_slice(values))),
                2 => scalars_2.map(|values| client.create(bytemuck::cast_slice(values))),
                _ => unreachable!(),
            };
            if let Some(handle) = handle {
                handles_scalars.push(handle);
            }
        }
    }
    handles_scalars
}
/// Computes the total number of elements for a shape of arbitrary rank.
///
/// An empty shape yields `1`, matching the convention for scalars.
pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
    shape.iter().product()
}

View File

@ -1,560 +0,0 @@
use super::Compiler;
use crate::{
ir::{
Binding, CubeDim, Elem, Item, KernelDefinition, Location, ReadingStrategy, Scope, Variable,
Vectorization, Visibility,
},
Runtime,
};
/// The kernel integrator allows you to create a [kernel definition](KernelDefinition) based on
/// [kernel expansion](KernelExpansion) and [kernel settings](KernelSettings).
#[derive(Clone)]
pub struct KernelIntegrator {
    // Expansion being integrated; its inputs/outputs are drained during integration.
    expansion: KernelExpansion,
    // Bindings accumulated for the input arrays.
    input_bindings: Vec<Binding>,
    // Bindings accumulated for the output arrays.
    output_bindings: Vec<Binding>,
    // Named bindings (the "info" buffer and per-element-type scalar buffers).
    named_bindings: Vec<(String, Binding)>,
}
/// The information necessary to compile a [kernel definition](KernelDefinition).
#[derive(Clone)]
pub struct KernelExpansion {
    /// Descriptions of the kernel inputs (arrays and scalars).
    pub inputs: Vec<InputInfo>,
    /// Descriptions of the kernel outputs.
    pub outputs: Vec<OutputInfo>,
    /// The scope containing the kernel body.
    pub scope: Scope,
}
/// Simply indicate the output that can be replaced by the input.
///
/// During integration the mapped output is rewritten to write directly into the
/// input's buffer instead of allocating a new output binding.
#[derive(new, Clone, Copy, Debug)]
pub struct InplaceMapping {
    /// Input position.
    pub pos_input: usize,
    /// Output position.
    pub pos_output: usize,
}
/// Vectorization factor attached to a single input or output position.
#[derive(Clone, Copy, Debug)]
enum VectorizationPartial {
    /// Factor for the input at `pos`.
    Input {
        pos: usize,
        vectorization: Vectorization,
    },
    /// Factor for the output at `pos`.
    Output {
        pos: usize,
        vectorization: Vectorization,
    },
}
/// Settings that drive how a kernel expansion is integrated into a
/// [kernel definition](KernelDefinition): inplace mappings, vectorization and
/// cube dimensions.
#[derive(Default, Clone)]
pub struct KernelSettings {
    /// Inplace input/output mappings applied during integration.
    pub mappings: Vec<InplaceMapping>,
    // Vectorization factor applied to all inputs and outputs, when set.
    vectorization_global: Option<Vectorization>,
    // Per-position vectorization overrides.
    vectorization_partial: Vec<VectorizationPartial>,
    // Cube dimensions used for the kernel launch.
    cube_dim: CubeDim,
    /// Reading strategy per input index.
    pub reading_strategy: Vec<(u16, ReadingStrategy)>,
}
impl core::fmt::Display for KernelSettings {
    /// Formats the settings as a compact identifier string.
    ///
    /// The goal of this implementation is to generate the shortest representation
    /// that won't clash with any other compilation settings. This is crucial since we rely on
    /// this representation to know when to compile a new version of a kernel.
    ///
    /// Each main section starts with a letter that can't be used by other main sections:
    ///
    /// * Mapping: m
    ///   * Input: i
    ///   * Output: o
    /// * Reading Strategy: r
    ///   * Output layout: o
    ///   * Plain: p
    /// * Vectorization Global: vg{factor} (or vn when unset)
    /// * Vectorization Partial Input: v{factor}i{pos}
    /// * Vectorization Partial Output: v{factor}o{pos}
    /// * Cube Dim X: x
    /// * Cube Dim Y: y
    /// * Cube Dim Z: z
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("m")?;
        for mapping in self.mappings.iter() {
            f.write_fmt(format_args!(
                "i{}o{}",
                mapping.pos_input, mapping.pos_output
            ))?;
        }
        f.write_str("r")?;
        for (input, strategy) in self.reading_strategy.iter() {
            match strategy {
                ReadingStrategy::OutputLayout => f.write_fmt(format_args!("i{}o", input)),
                ReadingStrategy::Plain => f.write_fmt(format_args!("i{}p", input)),
            }?;
        }
        match self.vectorization_global {
            Some(vectorization) => f.write_fmt(format_args!("vg{}", vectorization))?,
            None => f.write_str("vn")?,
        };
        for vectorization in self.vectorization_partial.iter() {
            match vectorization {
                VectorizationPartial::Input { pos, vectorization } => {
                    f.write_fmt(format_args!("v{vectorization}i{pos}"))?
                }
                VectorizationPartial::Output { pos, vectorization } => {
                    f.write_fmt(format_args!("v{vectorization}o{pos}"))?
                }
            };
        }
        // BUG FIX: this previously printed `cube_dim.x` twice, so two settings
        // differing only in `cube_dim.z` produced identical kernel ids and
        // could wrongly share a cached kernel.
        f.write_fmt(format_args!(
            "x{}y{}z{}",
            self.cube_dim.x, self.cube_dim.y, self.cube_dim.z
        ))
    }
}
impl KernelSettings {
    /// Compile the shader with vectorization enabled for all inputs and outputs.
    #[allow(dead_code)]
    pub fn vectorize_global(mut self, vectorization: Vectorization) -> Self {
        self.vectorization_global = Some(vectorization);
        self
    }
    /// Compile the shader with vectorization enabled for the input at `position`.
    #[allow(dead_code)]
    pub fn vectorize_input(mut self, position: usize, vectorization: Vectorization) -> Self {
        // Skipping the default factor keeps the kernel id representation short.
        if vectorization != 1 {
            self.vectorization_partial
                .push(VectorizationPartial::Input {
                    pos: position,
                    vectorization,
                });
        }
        self
    }
    /// Compile the shader with vectorization enabled for the output at `position`.
    #[allow(dead_code)]
    pub fn vectorize_output(mut self, position: usize, vectorization: Vectorization) -> Self {
        // Skipping the default factor keeps the kernel id representation short.
        if vectorization != 1 {
            self.vectorization_partial
                .push(VectorizationPartial::Output {
                    pos: position,
                    vectorization,
                });
        }
        self
    }
    /// Fetch the vectorization for the provided input position.
    ///
    /// The global factor wins over per-position factors; the default is 1.
    pub fn vectorization_input(&self, position: usize) -> Vectorization {
        if let Some(vectorization) = self.vectorization_global {
            return vectorization;
        }
        self.vectorization_partial
            .iter()
            .find_map(|partial| match partial {
                VectorizationPartial::Input { pos, vectorization } if *pos == position => {
                    Some(*vectorization)
                }
                _ => None,
            })
            .unwrap_or(1)
    }
    /// Fetch the vectorization for the provided output position.
    ///
    /// The global factor wins over per-position factors; the default is 1.
    pub fn vectorization_output(&self, position: usize) -> Vectorization {
        if let Some(vectorization) = self.vectorization_global {
            return vectorization;
        }
        self.vectorization_partial
            .iter()
            .find_map(|partial| match partial {
                VectorizationPartial::Output { pos, vectorization } if *pos == position => {
                    Some(*vectorization)
                }
                _ => None,
            })
            .unwrap_or(1)
    }
    /// Compile the shader with inplace enabled by the given [mapping](InplaceMapping).
    ///
    /// Notes:
    ///
    /// You should favor using `dynamic_settings` when using fusion, since the mapping is going to
    /// be created from the runtime information.
    pub fn inplace(mut self, mappings: Vec<InplaceMapping>) -> Self {
        self.mappings = mappings;
        self
    }
    /// Set cube dimension.
    #[allow(dead_code)]
    pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
        self.cube_dim = cube_dim;
        self
    }
}
/// Returns `true` when the strides are monotonically non-increasing from left
/// to right (the layout of a standard row-major buffer). Empty and
/// single-element stride lists are considered contiguous.
#[allow(dead_code)]
fn is_contiguous(strides: &[usize]) -> bool {
    strides.windows(2).all(|pair| pair[0] >= pair[1])
}
/// Information related to an input.
#[derive(Clone, Debug)]
pub enum InputInfo {
    /// An input array with its item type and visibility.
    Array { item: Item, visibility: Visibility },
    /// A group of scalars sharing one element type; `size` is the scalar count.
    Scalar { elem: Elem, size: usize },
}
impl InputInfo {
    /// The [item](Item) type of this input; scalars are wrapped into a
    /// non-vectorized item.
    #[allow(dead_code)]
    pub fn item(&self) -> Item {
        match self {
            InputInfo::Array { item, .. } => *item,
            InputInfo::Scalar { elem, .. } => Item::new(*elem),
        }
    }
}
impl OutputInfo {
    /// The [item](Item) type of this output (the original doc said "input",
    /// which was a copy-paste slip).
    #[allow(dead_code)]
    pub fn item(&self) -> Item {
        match self {
            OutputInfo::ArrayWrite { item, .. }
            | OutputInfo::InputArrayWrite { item, .. }
            | OutputInfo::Array { item } => *item,
        }
    }
}
/// Information related to an output.
#[derive(Clone, Debug)]
pub enum OutputInfo {
    /// Write the local variable to a new array.
    ///
    /// This will create a new binding in the [kernel definition](KernelDefinition).
    ArrayWrite {
        /// Item type of the array.
        item: Item,
        /// Id of the local variable holding the value to write.
        local: u16,
        /// Position variable at which the write happens.
        position: Variable,
    },
    /// Write the local variable to an existing input binding.
    InputArrayWrite {
        /// Item type of the write.
        item: Item,
        /// Index of the input binding written to.
        input: u16,
        /// Id of the local variable holding the value to write.
        local: u16,
        /// Position variable at which the write happens.
        position: Variable,
    },
    /// Simply register the output, but don't automatically add a write to it.
    ///
    /// Useful when a procedure writes to the output using operations.
    Array {
        /// Item type of the array.
        item: Item,
    },
}
impl OutputInfo {
    /// Size in bytes of one element of this output, after the bool-to-uint
    /// mapping, as reported by the runtime's compiler.
    #[allow(dead_code)]
    pub fn elem_size<R: Runtime>(&self) -> usize {
        let elem = match self {
            OutputInfo::ArrayWrite { item, .. }
            | OutputInfo::InputArrayWrite { item, .. }
            | OutputInfo::Array { item } => bool_elem(item.elem()),
        };
        <R::Compiler as Compiler>::elem_size(elem)
    }
}
impl KernelIntegrator {
    /// Starts a new compilation.
    pub fn new(info: KernelExpansion) -> Self {
        Self {
            expansion: info,
            input_bindings: Default::default(),
            output_bindings: Default::default(),
            named_bindings: Default::default(),
        }
    }
    /// Performs the compilation with the provided [settings](KernelSettings).
    ///
    /// Registers the input and output bindings, applies global vectorization
    /// and inplace mappings, and always prepends a named `"info"` binding that
    /// carries the strides/shapes metadata.
    pub fn integrate(mut self, mut settings: KernelSettings) -> KernelDefinition {
        // Global vectorization is applied to the whole scope up front.
        if let Some(vectorization) = settings.vectorization_global {
            self.expansion.scope.vectorize(vectorization);
        }
        self.register_inputs(&settings);
        self.register_outputs(&mut settings);
        let inputs = self.input_bindings;
        let outputs = self.output_bindings;
        let mut named = Vec::with_capacity(2);
        named.push((
            "info".to_string(),
            Binding {
                item: Item::new(Elem::UInt),
                visibility: Visibility::Read,
                location: Location::Storage,
                size: None, // We avoid putting the length here since it will force a new kernel
                // for each tensor rank.
            },
        ));
        for (name, binding) in self.named_bindings.into_iter() {
            named.push((name, binding));
        }
        KernelDefinition {
            inputs,
            outputs,
            named,
            cube_dim: settings.cube_dim,
            body: self.expansion.scope,
        }
    }
    /// Drains the expansion's inputs and registers them as bindings.
    fn register_inputs(&mut self, settings: &KernelSettings) {
        // Apply per-input reading strategies to the scope first.
        for (id, strategy) in settings.reading_strategy.iter() {
            self.expansion.scope.update_read(*id, *strategy);
        }
        for input in self.expansion.inputs.drain(..) {
            match input {
                InputInfo::Array { item, visibility } => {
                    // Global vectorization also applies to the binding's item.
                    let item = if let Some(vectorization) = settings.vectorization_global {
                        item.vectorize(vectorization)
                    } else {
                        item
                    };
                    self.input_bindings.push(Binding {
                        item: bool_item(item),
                        visibility,
                        location: Location::Storage,
                        size: None,
                    });
                }
                InputInfo::Scalar { elem, size } => {
                    let elem = bool_elem(elem);
                    // Scalars of a given element type share a single named buffer.
                    self.named_bindings.push((
                        format!("scalars_{}", elem),
                        Binding {
                            item: Item::new(elem),
                            visibility: Visibility::Read,
                            location: Location::Storage,
                            size: Some(size),
                        },
                    ));
                }
            }
        }
    }
    /// Drains the expansion's outputs, creating bindings and, when required,
    /// write instructions in the scope.
    fn register_outputs(&mut self, settings: &mut KernelSettings) {
        let mut index = 0;
        if !settings.mappings.is_empty() {
            // Take the mappings out of the settings and process them first:
            // they may rewrite outputs into `InputArrayWrite` variants below.
            let mut mappings = Vec::new();
            core::mem::swap(&mut settings.mappings, &mut mappings);
            for mapping in mappings {
                self.register_inplace_mapping(mapping);
            }
        }
        for array in self.expansion.outputs.drain(..) {
            match array {
                OutputInfo::ArrayWrite {
                    item,
                    local,
                    position,
                } => {
                    let item = if let Some(vectorization) = settings.vectorization_global {
                        item.vectorize(vectorization)
                    } else {
                        item
                    };
                    let item_adapted = bool_item(item);
                    self.output_bindings.push(Binding {
                        item: item_adapted,
                        visibility: Visibility::ReadWrite,
                        location: Location::Storage,
                        size: None,
                    });
                    // Emit the write of the local variable into the new output array.
                    self.expansion.scope.write_global(
                        Variable::Local {
                            id: local,
                            item,
                            depth: self.expansion.scope.depth,
                        },
                        Variable::GlobalOutputArray {
                            id: index,
                            item: item_adapted,
                        },
                        position,
                    );
                    index += 1;
                }
                OutputInfo::InputArrayWrite {
                    item,
                    input,
                    local,
                    position,
                } => {
                    let item = if let Some(vectorization) = settings.vectorization_global {
                        item.vectorize(vectorization)
                    } else {
                        item
                    };
                    // Inplace write: targets an existing input binding, so no new
                    // output binding is pushed and `index` is not incremented.
                    self.expansion.scope.write_global(
                        Variable::Local {
                            id: local,
                            item,
                            depth: self.expansion.scope.depth,
                        },
                        Variable::GlobalInputArray {
                            id: input,
                            item: bool_item(item),
                        },
                        position,
                    );
                }
                OutputInfo::Array { item } => {
                    let item = if let Some(vectorization) = settings.vectorization_global {
                        item.vectorize(vectorization)
                    } else {
                        item
                    };
                    let elem_adapted = bool_item(item);
                    // Register the binding only; the caller emits its own writes.
                    self.output_bindings.push(Binding {
                        item: elem_adapted,
                        visibility: Visibility::ReadWrite,
                        location: Location::Storage,
                        size: None,
                    });
                    index += 1;
                }
            }
        }
    }
    /// Rewrites the output at `mapping.pos_output` so that it writes into the
    /// input at `mapping.pos_input` (inplace execution).
    fn register_inplace_mapping(&mut self, mapping: InplaceMapping) {
        let output = match self.expansion.outputs.get_mut(mapping.pos_output) {
            Some(output) => output,
            None => {
                // The mapping is handled differently, normally by cube itself.
                return;
            }
        };
        let (item, local, position) = match output {
            OutputInfo::ArrayWrite { item, local, position } => (item, local, position),
            OutputInfo::InputArrayWrite {
                item: _,
                input,
                local: _,
                position: _,
            } => {
                assert_eq!(
                    *input, mapping.pos_input as u16,
                    "Can't use different inputs for the same output."
                );
                return;
            }
            OutputInfo::Array { item: _ } => panic!("Can't register an inplace operation for an array that isn't using a defined writing strategy."),
        };
        let item = match self.input_bindings.get_mut(mapping.pos_input) {
            Some(binding) => {
                // Update input visibility.
                binding.visibility = Visibility::ReadWrite;
                // Inputs modified inplace should be read without any specified layout.
                self.expansion
                    .scope
                    .update_read(mapping.pos_input as u16, ReadingStrategy::Plain);
                // Use the same item as the input.
                //
                // The output can be different (i.e inplace boolean operations on float bindings).
                binding.item
            }
            None => *item,
        };
        // Update the output.
        *output = OutputInfo::InputArrayWrite {
            item,
            input: mapping.pos_input as u16,
            local: *local,
            position: *position,
        };
    }
}
/// Maps an [item](Item)'s element type through [`bool_elem`], preserving the
/// vectorization factor.
fn bool_item(ty: Item) -> Item {
    let Item { elem, vectorization } = ty;
    Item {
        elem: bool_elem(elem),
        vectorization,
    }
}
/// Maps `Bool` to `UInt` (u32 backs bool tensors on the runtimes); every other
/// element type is returned untouched.
pub fn bool_elem(elem: Elem) -> Elem {
    if let Elem::Bool = elem {
        Elem::UInt
    } else {
        elem
    }
}

View File

@ -1,8 +0,0 @@
mod execution;
mod integrator;
mod compiler;
pub use compiler::*;
pub use execution::*;
pub use integrator::*;

View File

@ -1,104 +0,0 @@
use crate::ir::{Elem, Item, Visibility};
use crate::prelude::KernelDefinition;
use crate::KernelSettings;
use crate::{
frontend::{CubeContext, ExpandElement},
InputInfo, KernelExpansion, KernelIntegrator, OutputInfo,
};
use std::collections::HashMap;
/// Prepare a kernel to create a [kernel definition](crate::KernelDefinition).
pub struct KernelBuilder {
    /// Cube [context](CubeContext).
    pub context: CubeContext,
    // Input descriptions registered so far (arrays and scalars).
    inputs: Vec<InputInfo>,
    // Output descriptions registered so far.
    outputs: Vec<OutputInfo>,
    // Maps a scalar element type to its index in `inputs`.
    indices: HashMap<Elem, usize>,
    // Number of input arrays/tensors registered.
    num_input: u16,
    // Number of output arrays/tensors registered.
    num_output: u16,
}
impl KernelBuilder {
    /// Register a scalar and return the [element](ExpandElement) to be used for kernel expansion.
    ///
    /// Scalars of the same element type share one [InputInfo::Scalar] entry whose
    /// `size` is bumped for every new scalar; the returned element indexes into it.
    pub fn scalar(&mut self, elem: Elem) -> ExpandElement {
        let index = match self.indices.get_mut(&elem) {
            Some(index) => match self.inputs.get_mut(*index).unwrap() {
                InputInfo::Scalar { elem: _, size } => {
                    *size += 1;
                    *size as u16 - 1
                }
                _ => panic!("Should be a scalar."),
            },
            None => {
                self.indices.insert(elem, self.inputs.len());
                self.inputs.push(InputInfo::Scalar { size: 1, elem });
                0
            }
        };

        self.context.scalar(index, elem)
    }

    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
    pub fn output_tensor(&mut self, item: Item) -> ExpandElement {
        // Tensors and arrays are registered identically; delegate to avoid the
        // duplicated body the two methods used to carry.
        self.output_array(item)
    }

    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
    pub fn input_tensor(&mut self, item: Item) -> ExpandElement {
        // Same registration path as arrays; see `input_array`.
        self.input_array(item)
    }

    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
    pub fn output_array(&mut self, item: Item) -> ExpandElement {
        self.outputs.push(OutputInfo::Array { item });
        let variable = self.context.output(self.num_output, item);
        self.num_output += 1;

        variable
    }

    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
    pub fn input_array(&mut self, item: Item) -> ExpandElement {
        self.inputs.push(InputInfo::Array {
            item,
            visibility: Visibility::Read,
        });
        let variable = self.context.input(self.num_input, item);
        self.num_input += 1;

        variable
    }

    /// Build the [kernel definition](KernelDefinition).
    pub fn build(self, settings: KernelSettings) -> KernelDefinition {
        KernelIntegrator::new(KernelExpansion {
            scope: self.context.into_scope(),
            inputs: self.inputs,
            outputs: self.outputs,
        })
        .integrate(settings)
    }
}
impl Default for KernelBuilder {
fn default() -> Self {
Self {
context: CubeContext::root(),
inputs: Vec::new(),
outputs: Vec::new(),
indices: HashMap::new(),
num_input: 0,
num_output: 0,
}
}
}

View File

@ -1,88 +0,0 @@
use std::marker::PhantomData;
use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel};
use alloc::sync::Arc;
use burn_compute::server::{Binding, ComputeServer};
/// A kernel, compiled in the target language
pub struct CompiledKernel {
    /// Source code of the kernel
    pub source: String,
    /// Size of a cube for the compiled kernel
    pub cube_dim: CubeDim,
    /// The number of bytes used by the shared memory
    pub shared_mem_bytes: usize,
}
/// Kernel trait with the ComputeShader that will be compiled and cached based on the
/// provided id.
pub trait CubeTask: Send + Sync {
    /// Identifier for the kernel, used for caching kernel compilation.
    fn id(&self) -> String;
    /// Compile the kernel into source
    fn compile(&self) -> CompiledKernel;
}
/// Wraps a [kernel](Kernel) to create a [cube task](CubeTask).
#[derive(new)]
pub struct KernelTask<C: Compiler, K: Kernel> {
    // High-level kernel definition, compiled lazily by `compile`.
    kernel_definition: K,
    // The compiler is selected purely at the type level; no runtime state.
    _compiler: PhantomData<C>,
}
impl<C: Compiler, K: Kernel> CubeTask for KernelTask<C, K> {
    /// Lower the kernel definition through the compiler and render it to source.
    fn compile(&self) -> CompiledKernel {
        let kernel = self.kernel_definition.define();
        let cube_dim = kernel.cube_dim;
        let compiled = C::compile(kernel);
        let shared_mem_bytes = compiled.shared_memory_size();

        CompiledKernel {
            source: compiled.to_string(),
            cube_dim,
            shared_mem_bytes,
        }
    }

    fn id(&self) -> String {
        self.kernel_definition.id().clone()
    }
}
// Smart-pointer wrappers simply forward to the wrapped task.
impl CubeTask for Arc<dyn CubeTask> {
    fn id(&self) -> String {
        (**self).id()
    }

    fn compile(&self) -> CompiledKernel {
        (**self).compile()
    }
}

impl CubeTask for Box<dyn CubeTask> {
    fn id(&self) -> String {
        (**self).id()
    }

    fn compile(&self) -> CompiledKernel {
        (**self).compile()
    }
}
/// Provides launch information specifying the number of work groups to be used by a compute shader.
pub enum CubeCount<S: ComputeServer> {
    /// Dispatch x,y,z work groups.
    Static(u32, u32, u32),
    /// Dispatch work groups based on the values in this buffer. The buffer should contain a u32 array [x, y, z].
    Dynamic(Binding<S>),
}
// Implemented by hand: `#[derive(Clone)]` would add an `S: Clone` bound, but only
// the `Binding<S>` payload actually needs cloning.
impl<S: ComputeServer> Clone for CubeCount<S> {
    fn clone(&self) -> Self {
        match self {
            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
        }
    }
}

View File

@ -1,337 +0,0 @@
use crate::compute::{CubeCount, KernelTask};
use crate::ir::{Elem, FloatKind, IntKind};
use crate::prelude::ArrayHandle;
use crate::KernelSettings;
use crate::{calculate_num_elems_dyn_rank, frontend::TensorHandle, Kernel, Runtime};
use burn_compute::client::ComputeClient;
use burn_compute::server::Binding;
use bytemuck::NoUninit;
use num_traits::ToPrimitive;
/// Prepare a kernel for [launch](KernelLauncher::launch).
pub struct KernelLauncher<R: Runtime> {
    // Tensor/array handles plus their shared metadata buffer.
    tensors: TensorState<R>,
    // One grouped buffer of scalar values per element type.
    scalar_bf16: ScalarState<half::bf16>,
    scalar_f16: ScalarState<half::f16>,
    scalar_f32: ScalarState<f32>,
    scalar_f64: ScalarState<f64>,
    scalar_u32: ScalarState<u32>,
    scalar_i64: ScalarState<i64>,
    scalar_i32: ScalarState<i32>,
    // Element types in first-registration order; scalar bindings are emitted in this order.
    scalar_order: Vec<Elem>,
    /// Kernel settings forwarded at launch time.
    pub settings: KernelSettings,
}
impl<R: Runtime> KernelLauncher<R> {
/// Register a tensor to be launched.
pub fn register_tensor(&mut self, tensor: &TensorHandle<'_, R>) {
self.tensors.push(tensor);
}
/// Register an array to be launched.
pub fn register_array(&mut self, array: &ArrayHandle<'_, R>) {
self.tensors.push(&array.as_tensor());
}
/// Register a u32 scalar to be launched.
pub fn register_u32(&mut self, scalar: u32) {
self.register_scalar(Elem::UInt);
self.scalar_u32.push(scalar);
}
/// Register a i32 scalar to be launched.
pub fn register_i32(&mut self, scalar: i32) {
self.register_scalar(Elem::Int(IntKind::I32));
self.scalar_i32.push(scalar);
}
/// Register a i64 scalar to be launched.
pub fn register_i64(&mut self, scalar: i64) {
self.register_scalar(Elem::Int(IntKind::I64));
self.scalar_i64.push(scalar);
}
/// Register a bf16 scalar to be launched.
pub fn register_bf16(&mut self, scalar: half::bf16) {
self.register_scalar(Elem::Float(FloatKind::BF16));
self.scalar_bf16.push(scalar);
}
/// Register a f16 scalar to be launched.
pub fn register_f16(&mut self, scalar: half::f16) {
self.register_scalar(Elem::Float(FloatKind::F16));
self.scalar_f16.push(scalar);
}
/// Register a f32 scalar to be launched.
pub fn register_f32(&mut self, scalar: f32) {
self.register_scalar(Elem::Float(FloatKind::F32));
self.scalar_f32.push(scalar);
}
/// Register a f64 scalar to be launched.
pub fn register_f64(&mut self, scalar: f64) {
self.register_scalar(Elem::Float(FloatKind::F64));
self.scalar_f64.push(scalar);
}
/// Launch the kernel.
pub fn launch<K: Kernel>(
self,
cube_count: CubeCount<R::Server>,
kernel: K,
client: ComputeClient<R::Server, R::Channel>,
) {
let bindings = self.into_bindings(&client);
let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
client.execute(kernel, cube_count, bindings);
}
/// We need to create the bindings in the same order they are defined in the compilation step.
///
/// The function [crate::KernelIntegrator::integrate] stars by registering the input tensors followed
/// by the output tensors. Then the tensor metadata, and the scalars at the end. The scalars
/// are registered in the same order they are added. This is why we store the scalar data type
/// in the `scalar_order` vector, so that we can register them in the same order.
fn into_bindings(
mut self,
client: &ComputeClient<R::Server, R::Channel>,
) -> Vec<Binding<R::Server>> {
let mut bindings = Vec::new();
self.tensors.register(client, &mut bindings);
for elem in self.scalar_order.drain(..) {
match elem {
Elem::Float(kind) => match kind {
FloatKind::F16 => self.scalar_f16.register::<R>(client, &mut bindings),
FloatKind::BF16 => self.scalar_bf16.register::<R>(client, &mut bindings),
FloatKind::F32 => self.scalar_f32.register::<R>(client, &mut bindings),
FloatKind::F64 => self.scalar_f64.register::<R>(client, &mut bindings),
},
Elem::Int(kind) => match kind {
IntKind::I32 => self.scalar_i32.register::<R>(client, &mut bindings),
IntKind::I64 => self.scalar_i64.register::<R>(client, &mut bindings),
},
Elem::UInt => self.scalar_u32.register::<R>(client, &mut bindings),
Elem::Bool => panic!("Bool can't be passed as bindings."),
}
}
bindings
}
fn register_scalar(&mut self, elem: Elem) {
if !self.scalar_order.contains(&elem) {
self.scalar_order.push(elem);
}
}
}
/// Handles the tensor state.
pub enum TensorState<R: Runtime> {
    /// No tensor is registered yet.
    Empty,
    /// The registered tensors.
    Some {
        /// One binding per registered tensor handle.
        bindings: Vec<Binding<R::Server>>,
        /// Shared buffer: the (maximum) rank, followed by each tensor's strides and shape.
        metadata: Vec<u32>,
        /// Element counts, filled only when the runtime requires array lengths.
        lengths: Vec<u32>,
    },
}
/// Handles the scalar state of an element type
///
/// The scalars are grouped to reduce the number of buffers needed to send data to the compute device.
pub enum ScalarState<T> {
    /// No scalar of that type is registered yet.
    Empty,
    /// The registered scalars.
    Some(Vec<T>),
}
impl<R: Runtime> TensorState<R> {
    /// Push a new tensor to the state.
    ///
    /// Appends the tensor's handle to `bindings` and its strides/shape to the shared
    /// `metadata` buffer, whose first element holds the maximum rank seen so far.
    pub fn push(&mut self, tensor: &TensorHandle<'_, R>) {
        // Lazily switch from `Empty` to `Some` on the first tensor.
        if let TensorState::Empty = self {
            *self = TensorState::Some {
                bindings: Vec::with_capacity(1),
                metadata: Vec::new(),
                lengths: Vec::new(),
            };
        };
        let (bindings, metadata, lengths) = match self {
            TensorState::Empty => panic!("Should be init"),
            TensorState::Some {
                bindings,
                metadata,
                lengths,
            } => (bindings, metadata, lengths),
        };
        bindings.push(tensor.handle.clone().binding());
        // `metadata[0]` is the current rank; when a higher-rank tensor arrives,
        // re-encode all previously stored entries at the larger rank.
        let old_rank = if metadata.is_empty() {
            let rank = tensor.strides.len() as u32;
            metadata.push(rank);
            None
        } else if tensor.strides.len() > metadata[0] as usize {
            let old_rank = metadata[0];
            let rank = tensor.strides.len() as u32;
            // NOTE(review): `bindings.len()` already counts the tensor being pushed,
            // whose metadata is not in the buffer yet — confirm the intended count.
            Self::adjust_rank(metadata, bindings.len(), rank);
            Some(old_rank)
        } else {
            None
        };
        Self::register_strides(tensor.strides, tensor.shape, old_rank, metadata);
        Self::register_shape(tensor.shape, old_rank, metadata);
        if R::require_array_lengths() {
            let len = calculate_num_elems_dyn_rank(tensor.shape);
            lengths.push(len as u32);
        }
    }
    // Re-encode all previously registered strides/shapes at the new, larger rank.
    //
    // NOTE(review): this path looks suspect — `updated_metadata` starts empty and
    // never receives a rank header before the swap, yet `register_strides` and
    // `register_shape` read `output[0]`. Confirm whether this branch is ever
    // exercised (and panics/misbehaves) before relying on it.
    fn adjust_rank(metadata: &mut Vec<u32>, num_registered: usize, rank: u32) {
        let old_rank = metadata[0] as usize;
        let rank_diff = rank as usize - old_rank;
        let mut updated_metadata = Vec::with_capacity(2 * rank_diff * num_registered);
        for pos in 0..num_registered {
            // Each entry occupies `2 * old_rank` slots (strides then shape), offset
            // by 1 for the rank header.
            let stride_index = (pos * old_rank * 2) + 1;
            let shape_index = stride_index + old_rank;
            let strides_old = &metadata[stride_index..stride_index + old_rank];
            let shape_old = &metadata[shape_index..shape_index + old_rank];
            Self::register_strides(
                strides_old,
                shape_old,
                Some(old_rank as u32),
                &mut updated_metadata,
            );
            Self::register_shape(shape_old, Some(old_rank as u32), &mut updated_metadata);
        }
        core::mem::swap(&mut updated_metadata, metadata);
    }
    // Append `strides` to `output`, left-padding up to the stored rank when the
    // entry being written has a smaller (old) rank.
    fn register_strides<T: ToPrimitive>(
        strides: &[T],
        shape: &[T],
        old_rank: Option<u32>,
        output: &mut Vec<u32>,
    ) {
        let old_rank = if let Some(old_rank) = old_rank {
            let rank = output[0];
            // NOTE(review): u32 subtraction — underflows if `output[0] > old_rank`;
            // `register_shape` below computes the difference in the opposite
            // direction (`rank - old_rank`). Confirm which orientation is intended.
            let rank_diff = old_rank - rank;
            let padded_strides = if rank_diff > 0 {
                // Padding stride: sum of the leading `old_rank` shape extents.
                shape
                    .iter()
                    .take(old_rank as usize)
                    .map(|a| a.to_u32().unwrap())
                    .sum::<u32>()
            } else {
                0
            };
            for _ in 0..rank_diff {
                output.push(padded_strides.to_u32().unwrap());
            }
            old_rank as usize
        } else {
            output[0] as usize // same as current.
        };
        for stride in strides.iter().take(old_rank) {
            output.push(stride.to_u32().unwrap());
        }
    }
    // Append `shape` to `output`, left-padding with 1s up to the stored rank.
    fn register_shape<T: ToPrimitive>(shape: &[T], old_rank: Option<u32>, output: &mut Vec<u32>) {
        let old_rank = if let Some(old_rank) = old_rank {
            let rank = output[0];
            let rank_diff = rank - old_rank;
            for _ in 0..rank_diff {
                output.push(1);
            }
            old_rank as usize
        } else {
            output[0] as usize // same as current
        };
        for elem in shape.iter().take(old_rank) {
            output.push(elem.to_u32().unwrap());
        }
    }
    // Flush the collected handles into `bindings_global`, then append the metadata
    // (with lengths tacked on when the runtime requires them) as one extra buffer.
    fn register(
        self,
        client: &ComputeClient<R::Server, R::Channel>,
        bindings_global: &mut Vec<Binding<R::Server>>,
    ) {
        if let Self::Some {
            bindings,
            mut metadata,
            lengths,
        } = self
        {
            if R::require_array_lengths() {
                for len in lengths {
                    metadata.push(len);
                }
            }
            bindings_global.extend(bindings);
            bindings_global.push(client.create(bytemuck::cast_slice(&metadata)).binding());
        }
    }
}
impl<T: NoUninit> ScalarState<T> {
    /// Add a new scalar value to the state.
    pub fn push(&mut self, val: T) {
        match self {
            Self::Some(values) => values.push(val),
            Self::Empty => *self = Self::Some(vec![val]),
        }
    }

    /// Upload the collected scalars (if any) as a single buffer and record its binding.
    fn register<R: Runtime>(
        &self,
        client: &ComputeClient<R::Server, R::Channel>,
        bindings: &mut Vec<Binding<R::Server>>,
    ) {
        if let Self::Some(values) = self {
            let data = bytemuck::cast_slice(values);
            bindings.push(client.create(data).binding());
        }
    }
}
impl<R: Runtime> Default for KernelLauncher<R> {
fn default() -> Self {
Self {
tensors: TensorState::Empty,
scalar_bf16: ScalarState::Empty,
scalar_f16: ScalarState::Empty,
scalar_f32: ScalarState::Empty,
scalar_f64: ScalarState::Empty,
scalar_u32: ScalarState::Empty,
scalar_i64: ScalarState::Empty,
scalar_i32: ScalarState::Empty,
scalar_order: Vec::new(),
settings: Default::default(),
}
}
}

View File

@ -1,7 +0,0 @@
// Kernel construction and launch plumbing.
mod builder;
mod kernel;
mod launcher;
// Flatten the submodule items into this module's namespace.
pub use builder::*;
pub use kernel::*;
pub use launcher::*;

View File

@ -1,12 +0,0 @@
/// Panic helper for Cube frontend stubs.
///
/// Frontend functions are placeholders whose real behaviour lives in their
/// `*_expand` counterparts (see e.g. the `fill`/`fill::__expand` pair); calling
/// the stub directly at runtime is a bug, so it panics.
#[macro_export]
macro_rules! unexpanded {
    // No message: default explanation.
    () => ({
        panic!("Unexpanded Cube functions should not be called. ");
    });
    // Custom message.
    ($msg:expr) => ({
        panic!($msg);
    });
    // Custom format string with arguments.
    ($fmt:expr, $($arg:tt)*) => ({
        panic!($fmt, $($arg)*);
    });
}

View File

@ -1,153 +0,0 @@
use std::ops::Deref;
use crate::frontend::{CubeContext, ExpandElement, UInt};
use crate::ir::{Branch, Elem, If, IfElse, Item, Loop, RangeLoop, Variable};
use super::comptime::Comptime;
/// Iterate over the half-open interval `[start, end)` as [UInt] values.
///
/// The unroll flag only matters during expansion; at runtime this is a plain range.
pub fn range<S, E>(start: S, end: E, _unroll: Comptime<bool>) -> impl Iterator<Item = UInt>
where
    S: Into<UInt>,
    E: Into<UInt>,
{
    let begin = start.into().val;
    let stop = end.into().val;
    (begin..stop).map(UInt::new)
}
/// Expand a `range` loop: either unroll it at expansion time or emit a range loop in the IR.
pub fn range_expand<F, S, E>(context: &mut CubeContext, start: S, end: E, unroll: bool, mut func: F)
where
    F: FnMut(&mut CubeContext, ExpandElement),
    S: Into<ExpandElement>,
    E: Into<ExpandElement>,
{
    let start: ExpandElement = start.into();
    let end: ExpandElement = end.into();

    if !unroll {
        // Emit a single IR range loop whose body is expanded in a child scope.
        let mut child = context.child();
        let index_ty = Item::new(Elem::UInt);
        let i = ExpandElement::Plain(child.scope.borrow_mut().create_local_undeclared(index_ty));

        func(&mut child, i.clone());

        context.register(Branch::RangeLoop(RangeLoop {
            i: *i,
            start: *start,
            end: *end,
            scope: child.into_scope(),
        }));
        return;
    }

    // Unrolling requires both bounds to be compile-time constants.
    let first = match start.deref() {
        Variable::ConstantScalar { value, .. } => *value as usize,
        _ => panic!("Only constant start can be unrolled."),
    };
    let last = match end.deref() {
        Variable::ConstantScalar { value, .. } => *value as usize,
        _ => panic!("Only constant end can be unrolled."),
    };
    (first..last).for_each(|i| func(context, i.into()));
}
/// Expand an `if`: resolve it at expansion time when the condition is comptime,
/// otherwise emit an IR branch guarded by the runtime condition.
pub fn if_expand<IF>(
    context: &mut CubeContext,
    comptime_cond: Option<bool>,
    runtime_cond: ExpandElement,
    mut block: IF,
) where
    IF: FnMut(&mut CubeContext),
{
    if let Some(cond) = comptime_cond {
        // Comptime condition: the body is emitted inline only when true.
        if cond {
            block(context);
        }
    } else {
        let mut child = context.child();
        block(&mut child);

        context.register(Branch::If(If {
            cond: *runtime_cond,
            scope: child.into_scope(),
        }));
    }
}
/// Expand an `if/else`: keep only the taken branch when the condition is comptime,
/// otherwise emit an IR if-else with both branches expanded in child scopes.
pub fn if_else_expand<IF, EL>(
    context: &mut CubeContext,
    comptime_cond: Option<bool>,
    runtime_cond: ExpandElement,
    mut then_block: IF,
    mut else_block: EL,
) where
    IF: FnMut(&mut CubeContext),
    EL: FnMut(&mut CubeContext),
{
    if let Some(cond) = comptime_cond {
        // Comptime condition: emit exactly one of the two bodies inline.
        if cond {
            then_block(context);
        } else {
            else_block(context);
        }
    } else {
        let mut then_child = context.child();
        then_block(&mut then_child);

        let mut else_child = context.child();
        else_block(&mut else_child);

        context.register(Branch::IfElse(IfElse {
            cond: *runtime_cond,
            scope_if: then_child.into_scope(),
            scope_else: else_child.into_scope(),
        }));
    }
}
/// Register an unconditional `break` in the current scope.
pub fn break_expand(context: &mut CubeContext) {
    context.register(Branch::Break);
}
/// Register an early `return` in the current scope.
pub fn return_expand(context: &mut CubeContext) {
    context.register(Branch::Return);
}
/// Expand an infinite `loop`, with its body built in a child scope.
pub fn loop_expand<FB>(context: &mut CubeContext, mut block: FB)
where
    FB: FnMut(&mut CubeContext),
{
    let mut body = context.child();
    block(&mut body);

    let scope = body.into_scope();
    context.register(Branch::Loop(Loop { scope }));
}
/// Expand a `while` loop as `loop { if cond { break }; body }`: the exit check
/// produced by `cond_fn` is registered before the body on every iteration.
pub fn while_loop_expand<FC, FB>(context: &mut CubeContext, mut cond_fn: FC, mut block: FB)
where
    FC: FnMut(&mut CubeContext) -> ExpandElement,
    FB: FnMut(&mut CubeContext),
{
    let mut body = context.child();

    let cond: ExpandElement = cond_fn(&mut body);
    if_expand(&mut body, None, cond, break_expand);
    block(&mut body);

    let scope = body.into_scope();
    context.register(Branch::Loop(Loop { scope }));
}

View File

@ -1,238 +0,0 @@
//! This module exposes cooperative matrix-multiply and accumulate operations.
//!
//! Most of the functions are actually unsafe, since they mutate their input, even if they are
//! passed as reference.
//!
//! # Example
//!
//! This is a basic 16x16x16 matrix multiplication example.
//!
//! ```rust, ignore
//! #[cube(launch)]
//! pub fn example(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>) {
//! let a = cmma::Matrix::<F16>::new(
//! cmma::MatrixIdent::A,
//! 16,
//! 16,
//! 16,
//! cmma::MatrixLayout::RowMajor,
//! );
//! let b = cmma::Matrix::<F16>::new(
//! cmma::MatrixIdent::B,
//! 16,
//! 16,
//! 16,
//! cmma::MatrixLayout::ColMajor,
//! );
//! let c = cmma::Matrix::<F32>::new(
//! cmma::MatrixIdent::Accumulator,
//! 16,
//! 16,
//! 16,
//! cmma::MatrixLayout::Undefined,
//! );
//! cmma::fill::<F32>(&c, F32::new(0.0));
//! cmma::load::<F16>(&a, lhs.as_slice(), UInt::new(16));
//! cmma::load::<F16>(&b, rhs.as_slice(), UInt::new(16));
//!
//! cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
//!
//! cmma::store::<F32>(
//! out.as_slice_mut(),
//! &c,
//! UInt::new(16),
//! cmma::MatrixLayout::RowMajor,
//! );
//! }
//! ```
use std::marker::PhantomData;
use crate::{
ir::{self, Operation},
unexpanded,
};
use super::{
CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, Slice, SliceMut,
UInt,
};
pub use ir::{MatrixIdent, MatrixLayout};
/// A matrix represents a 2D grid of numbers.
///
/// They can either be in a [row major](MatrixLayout::RowMajor) or a
/// [column major](MatrixLayout::ColMajor) format.
pub struct Matrix<C: CubeType> {
    // Marker only: the element type is purely compile-time information.
    _c: PhantomData<C>,
}
/// Expand type of [Matrix].
#[derive(Clone)]
pub struct MatrixExpand {
    // IR handle produced by `CubeContext::create_matrix`.
    elem: ExpandElement,
}
impl<C: CubeType> CubeType for Matrix<C> {
    type ExpandType = MatrixExpand;
}
impl Init for MatrixExpand {
    // A matrix handle needs no per-scope initialization; it is returned unchanged.
    fn init(self, _context: &mut CubeContext) -> Self {
        self
    }
}
impl<C: CubePrimitive> Matrix<C> {
    /// Create a new matrix that is going to be used in the
    /// [matrix-multiply and accumulate](execute()) function.
    ///
    /// You have to declare the shape used for the execution.
    /// The shape of the current matrix is determined using the [MatrixIdent].
    ///
    /// * [MatrixIdent::A] Shape => (M, K)
    /// * [MatrixIdent::B] Shape => (K, N)
    /// * [MatrixIdent::Accumulator] Shape => (M, N)
    ///
    /// Not all shapes are supported, and the permitted shapes depend on the element type.
    ///
    /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes).
    #[allow(unused_variables)]
    pub fn new(ident: MatrixIdent, m: u8, n: u8, k: u8, layout: MatrixLayout) -> Self {
        Self { _c: PhantomData }
    }

    /// Expand method of [Matrix::new]: records the matrix declaration in the IR.
    pub fn __expand_new(
        context: &mut CubeContext,
        ident: MatrixIdent,
        m: u8,
        n: u8,
        k: u8,
        layout: MatrixLayout,
    ) -> MatrixExpand {
        let matrix = ir::Matrix {
            ident,
            m,
            n,
            k,
            elem: C::as_elem(),
            layout,
        };
        let elem = context.create_matrix(matrix);
        MatrixExpand { elem }
    }
}
/// Fill the matrix with the provided value.
///
/// Frontend stub: calling it outside of kernel expansion panics via `unexpanded!`;
/// the real behaviour is emitted by [fill::__expand].
#[allow(unused_variables)]
pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}
/// Module containing the expand function for [fill()].
pub mod fill {
use super::*;
/// Expand method of [fill()].
pub fn __expand<C: CubeType>(
context: &mut CubeContext,
mat: MatrixExpand,
value: ExpandElement,
) {
context.register(Operation::CoopMma(ir::CoopMma::Fill {
mat: *mat.elem,
value: *value,
}));
}
}
/// Load the matrix with the provided array using the stride.
///
/// Frontend stub: calling it outside of kernel expansion panics via `unexpanded!`;
/// the real behaviour is emitted by [load::__expand].
#[allow(unused_variables)]
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Slice<'_, C>, stride: UInt) {
    unexpanded!()
}
/// Module containing the expand function for [load()].
pub mod load {
use super::*;
/// Expand method of [load()].
#[allow(unused_variables)]
pub fn __expand<C: CubeType>(
context: &mut CubeContext,
mat: MatrixExpand,
value: ExpandElementTyped<Slice<'static, C>>,
stride: ExpandElement,
) {
context.register(Operation::CoopMma(ir::CoopMma::Load {
mat: *mat.elem,
value: *value.expand,
stride: *stride,
}));
}
}
/// Store the matrix in the given array following the given stride and layout.
///
/// Frontend stub: calling it outside of kernel expansion panics via `unexpanded!`;
/// the real behaviour is emitted by [store::__expand].
#[allow(unused_variables)]
pub fn store<C: CubePrimitive>(
    output: &mut SliceMut<'_, C>,
    mat: &Matrix<C>,
    stride: UInt,
    layout: MatrixLayout,
) {
    unexpanded!()
}
/// Module containing the expand function for [store()].
pub mod store {
use super::*;
/// Expand method of [store()].
#[allow(unused_variables)]
pub fn __expand<C: CubePrimitive>(
context: &mut CubeContext,
output: ExpandElementTyped<SliceMut<'static, C>>,
mat: MatrixExpand,
stride: ExpandElement,
layout: MatrixLayout,
) {
context.register(Operation::CoopMma(ir::CoopMma::Store {
output: *output.expand,
mat: *mat.elem,
stride: *stride,
layout,
}));
}
}
/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix).
///
/// Frontend stub: calling it outside of kernel expansion panics via `unexpanded!`;
/// the real behaviour is emitted by [execute::__expand]. Per the module docs, the
/// operation mutates its inputs even though they are passed by reference.
#[allow(unused_variables)]
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}
/// Module containing the expand function for [execute()].
pub mod execute {
use super::*;
/// Expand method of [execute()].
pub fn __expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
context: &mut CubeContext,
mat_a: MatrixExpand,
mat_b: MatrixExpand,
mat_c: MatrixExpand,
mat_d: MatrixExpand,
) {
context.register(Operation::CoopMma(ir::CoopMma::Execute {
mat_a: *mat_a.elem,
mat_b: *mat_b.elem,
mat_c: *mat_c.elem,
mat_d: *mat_d.elem,
}));
}
}

View File

@ -1,146 +0,0 @@
use crate::{
frontend::{CubeContext, CubeType},
unexpanded,
};
use super::{ExpandElement, Init, UInt, Vectorized};
#[derive(Clone, Copy)]
/// Encapsulates a value to signify it must be used at compilation time rather than in the kernel
///
/// Use `Comptime<Option<T>>` to have an alternate runtime behaviour if the compilation time value is not present
pub struct Comptime<T> {
    // The wrapped compile-time value; crate-visible so expand code can reach it.
    pub(crate) inner: T,
}
impl<T> Comptime<T> {
    /// Create a new Comptime. Useful when hardcoding values in
    /// Cube kernels. For instance:
    /// if Comptime::new(false) {...} never generates the inner code block
    pub fn new(inner: T) -> Self {
        Self { inner }
    }
    /// Get the inner value of a Comptime. For instance:
    /// let c = Comptime::new(false);
    /// if Comptime::get(c) {...}
    pub fn get(_comptime: Self) -> T {
        unexpanded!()
    }
    /// Map the inner value through `_closure` at compile time.
    ///
    /// Frontend stub (panics via `unexpanded!`); see [Self::map_expand].
    pub fn map<R, F: Fn(T) -> R>(_comptime: Self, _closure: F) -> Comptime<R> {
        unexpanded!()
    }
    /// Expanded version of [Self::map]: simply applies the closure.
    pub fn map_expand<R, F: Fn(T) -> R>(inner: T, closure: F) -> R {
        closure(inner)
    }
}
impl<T: CubeType + Into<T::ExpandType>> Comptime<Option<T>> {
    /// Map a Comptime optional to a Comptime boolean telling whether the
    /// optional contained a value.
    pub fn is_some(comptime: Self) -> Comptime<bool> {
        let present = comptime.inner.is_some();
        Comptime::new(present)
    }

    /// Return the inner value of the Comptime if it exists,
    /// otherwise tell how to compute it at runtime.
    pub fn unwrap_or_else<F>(_comptime: Self, mut _alt: F) -> T
    where
        F: FnOnce() -> T,
    {
        unexpanded!()
    }

    /// Expanded version of unwrap_or_else.
    pub fn unwrap_or_else_expand<F>(
        context: &mut CubeContext,
        t: Option<T>,
        alt: F,
    ) -> <T as CubeType>::ExpandType
    where
        F: FnOnce(&mut CubeContext) -> T::ExpandType,
    {
        if let Some(value) = t {
            value.into()
        } else {
            alt(context)
        }
    }
}
impl<T: Clone + Init> CubeType for Comptime<T> {
    // A comptime value expands to itself; no IR variable is involved.
    type ExpandType = T;
}
impl<T: Vectorized> Comptime<T> {
    /// Retrieve the vectorization factor of `_state` as a comptime value.
    ///
    /// Frontend stub (panics via `unexpanded!`); see [Self::vectorization_expand].
    pub fn vectorization(_state: &T) -> Comptime<UInt> {
        unexpanded!()
    }
    /// Expanded version of [Self::vectorization].
    pub fn vectorization_expand(_context: &mut CubeContext, state: T) -> UInt {
        state.vectorization_factor()
    }
}
impl<T: Into<ExpandElement>> Comptime<T> {
    /// Promote a comptime value into a runtime value.
    ///
    /// Frontend stub (panics via `unexpanded!`); see [Self::runtime_expand].
    pub fn runtime(_comptime: Self) -> T {
        unexpanded!()
    }
    /// Expanded version of [Self::runtime].
    pub fn runtime_expand(_context: &mut CubeContext, inner: T) -> ExpandElement {
        inner.into()
    }
}
// Arithmetic on comptime values is plain arithmetic on the wrapped values.
impl<T: core::ops::Add<T, Output = T>> core::ops::Add for Comptime<T> {
    type Output = Comptime<T>;

    fn add(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner + rhs.inner)
    }
}

impl<T: core::ops::Sub<T, Output = T>> core::ops::Sub for Comptime<T> {
    type Output = Comptime<T>;

    fn sub(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner - rhs.inner)
    }
}

impl<T: core::ops::Div<T, Output = T>> core::ops::Div for Comptime<T> {
    type Output = Comptime<T>;

    fn div(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner / rhs.inner)
    }
}

impl<T: core::ops::Mul<T, Output = T>> core::ops::Mul for Comptime<T> {
    type Output = Comptime<T>;

    fn mul(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner * rhs.inner)
    }
}

impl<T: core::ops::Rem<T, Output = T>> core::ops::Rem for Comptime<T> {
    type Output = Comptime<T>;

    fn rem(self, rhs: Self) -> Self::Output {
        Comptime::new(self.inner % rhs.inner)
    }
}
// Comparisons delegate to the wrapped value. The bounds are the minimal ones:
// `PartialEq` previously also required `T: PartialOrd`, which needlessly
// excluded equality-only types; relaxing the bound is backward compatible
// (`PartialOrd` already implies `PartialEq` via its supertrait).
impl<T: core::cmp::PartialEq> core::cmp::PartialEq for Comptime<T> {
    fn eq(&self, other: &Self) -> bool {
        self.inner == other.inner
    }
}
impl<T: core::cmp::PartialOrd> core::cmp::PartialOrd for Comptime<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        self.inner.partial_cmp(&other.inner)
    }
}

View File

@ -1,145 +0,0 @@
use crate::frontend::ExpandElement;
use crate::ir::{self, Elem, Item, Operation, Scope};
use alloc::rc::Rc;
use core::cell::RefCell;
use std::collections::HashMap;
/// Pool of local variables available for reuse across a kernel's scopes.
#[derive(Default, Clone)]
pub struct VariablePool {
    // Candidates for reuse, grouped by their Item; shared (Rc) between a context
    // and its children so all scopes draw from the same pool.
    map: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,
}
impl VariablePool {
    /// Returns an old, not used anymore variable, if there exists one.
    pub fn reuse(&self, item: Item) -> Option<ExpandElement> {
        let map = self.map.borrow();

        // Filter for candidate variables of the same Item.
        let variables = map.get(&item)?;

        // Among the candidates, take a variable if it's only referenced by the map
        // (strong count of 1 means no scope still holds it).
        // Arbitrarily takes the first it finds in reverse order.
        for variable in variables.iter().rev() {
            if let ExpandElement::Managed(var) = variable {
                if Rc::strong_count(var) == 1 {
                    return Some(variable.clone());
                }
            }
        }

        // If no candidate was found, a new var will be needed.
        None
    }

    /// Insert a new variable in the map, which is classified by Item.
    pub fn insert(&mut self, var: ExpandElement) {
        // Entry API: a single lookup instead of `get_mut` + `insert`, and the
        // variable is moved into the map instead of cloned twice (the caller's
        // copy was dropped immediately in the old version, so the net refcount
        // is unchanged).
        self.map
            .borrow_mut()
            .entry(var.item())
            .or_default()
            .push(var);
    }
}
/// Mutable state threaded through kernel expansion.
pub struct CubeContext {
    /// Root scope of the shader; shared memories, local arrays and reusable
    /// locals are allocated here (see the `create_*` methods below).
    pub root: Rc<RefCell<Scope>>,
    /// Scope operations are currently registered into (equal to `root` at the
    /// top level, a child scope inside control flow).
    pub scope: Rc<RefCell<Scope>>,
    /// Pool of locals available for reuse, shared with child contexts.
    pub pool: VariablePool,
}
impl CubeContext {
    /// Create a new cube context, with a root scope
    /// A root scope is at the root of a compute shader
    /// Therefore there is one cube context per shader
    pub fn root() -> CubeContext {
        let root = Rc::new(RefCell::new(Scope::root()));
        // At the root, the active scope *is* the root scope (same Rc).
        let scope = root.clone();
        Self {
            pool: Default::default(),
            scope,
            root,
        }
    }
    /// Register an operation into the currently active scope.
    pub fn register<O: Into<Operation>>(&mut self, op: O) {
        self.scope.borrow_mut().register(op)
    }
    /// Create a child context that shares this context's root scope and variable
    /// pool but writes into a fresh child scope (used for control-flow bodies).
    pub fn child(&mut self) -> CubeContext {
        let scope = self.scope.borrow_mut().child();
        Self {
            scope: Rc::new(RefCell::new(scope)),
            root: self.root.clone(),
            pool: self.pool.clone(),
        }
    }
    /// Consume the context and extract its scope.
    ///
    /// Dropping `root` first releases the extra reference a root context holds on
    /// its own scope; panics if the scope is still shared elsewhere.
    pub fn into_scope(self) -> Scope {
        core::mem::drop(self.root);
        Rc::into_inner(self.scope)
            .expect("Only one reference")
            .into_inner()
    }
    /// When a new variable is required, we check if we can reuse an old one
    /// Otherwise we create a new one.
    pub fn create_local(&mut self, item: Item) -> ExpandElement {
        // Reuse an old variable if possible
        if let Some(var) = self.pool.reuse(item) {
            return var;
        }
        // Create a new variable at the root scope
        // Insert it in the variable pool for potential reuse
        let new = ExpandElement::Managed(Rc::new(self.root.borrow_mut().create_local(item)));
        self.pool.insert(new.clone());
        new
    }
    /// Create a new matrix element.
    pub fn create_matrix(&mut self, matrix: ir::Matrix) -> ExpandElement {
        let variable = self.scope.borrow_mut().create_matrix(matrix);
        ExpandElement::Plain(variable)
    }
    /// Create a new slice element.
    pub fn create_slice(&mut self, item: Item) -> ExpandElement {
        let variable = self.scope.borrow_mut().create_slice(item);
        ExpandElement::Plain(variable)
    }
    /// Create a shared-memory variable (always allocated at the root scope).
    pub fn create_shared(&mut self, item: Item, size: u32) -> ExpandElement {
        ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size))
    }
    /// Create a fixed-size local array (always allocated at the root scope).
    pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {
        ExpandElement::Plain(self.root.borrow_mut().create_local_array(item, size))
    }
    /// Obtain the index-th input
    pub fn input(&mut self, id: u16, item: Item) -> ExpandElement {
        ExpandElement::Plain(crate::ir::Variable::GlobalInputArray { id, item })
    }
    /// Obtain the index-th output
    pub fn output(&mut self, id: u16, item: Item) -> ExpandElement {
        let var = crate::ir::Variable::GlobalOutputArray { id, item };
        // Outputs must be declared as written-to in the scope.
        self.scope.borrow_mut().write_global_custom(var);
        ExpandElement::Plain(var)
    }
    /// Obtain the index-th scalar
    pub fn scalar(&self, id: u16, elem: Elem) -> ExpandElement {
        ExpandElement::Plain(crate::ir::Variable::GlobalScalar { id, elem })
    }
}

View File

@ -1,238 +0,0 @@
use std::marker::PhantomData;
use crate::{
compute::{KernelBuilder, KernelLauncher},
frontend::CubeType,
ir::{Item, Vectorization},
unexpanded, KernelSettings, Runtime,
};
use crate::{
frontend::{indexation::Index, CubeContext},
prelude::{assign, index, index_assign, Comptime},
};
use super::{
ArgSettings, CubePrimitive, ExpandElement, ExpandElementTyped, Init, LaunchArg,
LaunchArgExpand, TensorHandle, UInt,
};
/// A contiguous array of elements.
pub struct Array<E> {
    // Marker only: the element type is purely compile-time information.
    _val: PhantomData<E>,
}
impl<C: CubeType> CubeType for Array<C> {
    type ExpandType = ExpandElementTyped<Array<C>>;
}
impl<T: CubePrimitive + Clone> Array<T> {
    /// Create a plain (non-vectorized) array of the given size. Frontend stub.
    pub fn new<S: Index>(_size: S) -> Self {
        Array { _val: PhantomData }
    }

    /// Create an array whose elements are vectorized by the given factor. Frontend stub.
    pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
        Array { _val: PhantomData }
    }

    /// Expand method of [Array::new]: allocates a local array in the IR.
    pub fn __expand_new<S: Index>(
        context: &mut CubeContext,
        size: S,
    ) -> <Self as CubeType>::ExpandType {
        let size = Self::constant_size(size);
        context
            .create_local_array(Item::new(T::as_elem()), size)
            .into()
    }

    /// Expand method of [Array::vectorized]: allocates a vectorized local array in the IR.
    pub fn __expand_vectorized<S: Index>(
        context: &mut CubeContext,
        size: S,
        vectorization_factor: UInt,
    ) -> <Self as CubeType>::ExpandType {
        let size = Self::constant_size(size);
        context
            .create_local_array(
                Item::vectorized(T::as_elem(), vectorization_factor.val as u8),
                size,
            )
            .into()
    }

    // Extract the compile-time constant size of the array.
    //
    // Local arrays must have a size known at expansion time, so anything other
    // than a constant scalar is rejected. Bug fix: the vectorized variant used to
    // panic with a "Shared memory" message copy-pasted from another type.
    fn constant_size<S: Index>(size: S) -> u32 {
        match size.value() {
            crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
            _ => panic!("Array need constant initialization value"),
        }
    }

    /// Reinterpret the whole array as one vectorized value. Frontend stub.
    pub fn to_vectorized(self, _vectorization_factor: Comptime<UInt>) -> T {
        unexpanded!()
    }
}
impl<C: CubeType> ExpandElementTyped<Array<C>> {
    /// Expand an array into a single vectorized local variable.
    ///
    /// With `factor == 1` the first element is copied into the new variable;
    /// otherwise each of the `factor` lanes is written with `index_assign`.
    pub fn to_vectorized_expand(
        self,
        context: &mut CubeContext,
        vectorization_factor: UInt,
    ) -> ExpandElement {
        let factor = vectorization_factor.val;
        let var = self.expand.clone();
        // The destination keeps the source element type, widened to `factor` lanes.
        let mut new_var = context.create_local(Item::vectorized(var.item().elem(), factor as u8));
        if vectorization_factor.val == 1 {
            let element = index::expand(context, self.clone(), 0u32);
            // NOTE(review): confirm the (source, destination) argument order of
            // `assign::expand` — as written, `element` is assigned into `new_var`.
            assign::expand(context, element, new_var.clone());
        } else {
            for i in 0..factor {
                let element = index::expand(context, self.expand.clone(), i);
                new_var = index_assign::expand(context, new_var, i, element);
            }
        }
        new_var
    }
}
// A borrowed array expands the same way as an owned one.
impl<C: CubeType> CubeType for &Array<C> {
    type ExpandType = ExpandElementTyped<Array<C>>;
}
impl<C: CubeType> Init for ExpandElementTyped<Array<C>> {
    fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
        // The type can't be deeply cloned/copied.
        self
    }
}
impl<E: CubeType> Array<E> {
    /// Obtain the array length
    ///
    /// Frontend stub: panics via `unexpanded!` if called outside of expansion.
    pub fn len(&self) -> UInt {
        unexpanded!()
    }
}
impl<C: CubePrimitive> LaunchArg for Array<C> {
type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>;
}
impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
fn expand(
builder: &mut KernelBuilder,
vectorization: Vectorization,
) -> ExpandElementTyped<Array<C>> {
builder
.input_array(Item::vectorized(C::as_elem(), vectorization))
.into()
}
fn expand_output(
builder: &mut KernelBuilder,
vectorization: Vectorization,
) -> ExpandElementTyped<Array<C>> {
builder
.output_array(Item::vectorized(C::as_elem(), vectorization))
.into()
}
}
/// Array representation with a reference to the [server handle](burn_compute::server::Handle).
///
/// NOTE(review): previous doc said "Tensor representation" — this type wraps a 1D array.
pub struct ArrayHandle<'a, R: Runtime> {
    // Handle to the backing compute-server buffer.
    pub handle: &'a burn_compute::server::Handle<R::Server>,
    // Element count stored as a one-element shape so the array can be viewed as a tensor.
    pub length: [usize; 1],
}
/// Kernel launch argument describing how an array is provided.
pub enum ArrayArg<'a, R: Runtime> {
    /// The array is passed with an array handle.
    Handle {
        /// The array handle.
        handle: ArrayHandle<'a, R>,
        /// The vectorization factor.
        vectorization_factor: u8,
    },
    /// The array is aliasing another input array.
    Alias {
        /// The position of the input array.
        input_pos: usize,
    },
}
impl<'a, R: Runtime> ArgSettings<R> for ArrayArg<'a, R> {
    /// Register the array handle with the launcher; aliases register nothing.
    fn register(&self, launcher: &mut KernelLauncher<R>) {
        match self {
            ArrayArg::Handle { handle, .. } => launcher.register_array(handle),
            ArrayArg::Alias { .. } => (),
        }
    }
    /// Propagate the vectorization factor to the input settings.
    fn configure_input(&self, position: usize, settings: KernelSettings) -> KernelSettings {
        match self {
            Self::Handle {
                vectorization_factor,
                ..
            } => settings.vectorize_input(position, *vectorization_factor),
            Self::Alias { .. } => {
                panic!("Not yet supported, only output can be aliased for now.");
            }
        }
    }
    /// Propagate the vectorization factor, or record an in-place mapping for aliases.
    fn configure_output(&self, position: usize, mut settings: KernelSettings) -> KernelSettings {
        match self {
            Self::Handle {
                vectorization_factor,
                ..
            } => settings.vectorize_output(position, *vectorization_factor),
            Self::Alias { input_pos } => {
                settings.mappings.push(crate::InplaceMapping {
                    pos_input: *input_pos,
                    pos_output: position,
                });
                settings
            }
        }
    }
}
impl<'a, R: Runtime> ArrayArg<'a, R> {
    /// Create a new array argument.
    ///
    /// Equivalent to using the [vectorized constructor](Self::vectorized) with a vectorization
    /// factor of 1.
    pub fn new(handle: &'a burn_compute::server::Handle<R::Server>, length: usize) -> Self {
        Self::vectorized(1, handle, length)
    }
    /// Create a new array argument specified with its vectorization factor.
    pub fn vectorized(
        vectorization_factor: u8,
        handle: &'a burn_compute::server::Handle<R::Server>,
        length: usize,
    ) -> Self {
        ArrayArg::Handle {
            handle: ArrayHandle::new(handle, length),
            vectorization_factor,
        }
    }
}
impl<'a, R: Runtime> ArrayHandle<'a, R> {
    /// Wrap a server handle together with its element count.
    pub fn new(handle: &'a burn_compute::server::Handle<R::Server>, length: usize) -> Self {
        Self {
            handle,
            length: [length],
        }
    }
    /// View the array as a rank-1 contiguous tensor (stride 1).
    pub fn as_tensor(&self) -> TensorHandle<'_, R> {
        TensorHandle {
            handle: self.handle,
            strides: &[1],
            shape: &self.length,
        }
    }
}

View File

@ -1,277 +0,0 @@
use std::marker::PhantomData;
use crate::{
ir::{Operator, Variable, Vectorization},
prelude::{init_expand, CubeContext, KernelBuilder, KernelLauncher},
KernelSettings, Runtime,
};
use alloc::rc::Rc;
use super::{UInt, Vectorized};
/// Types used in a cube function must implement this trait
///
/// Variables whose values will be known at runtime must
/// have ExpandElement as associated type
/// Variables whose values will be known at compile time
/// must have the primitive type as associated type
///
/// Note: Cube functions should be written using CubeTypes,
/// so that the code generated uses the associated ExpandType.
/// This allows Cube code to not necessitate cloning, which is cumbersome
/// in algorithmic code. The necessary cloning will automatically appear in
/// the generated code.
pub trait CubeType {
    /// Compile-time representation manipulated while expanding the kernel.
    type ExpandType: Clone + Init;
}
/// Trait to be implemented by [cube types](CubeType) implementations.
pub trait Init: Sized {
    /// Initialize a type within a [context](CubeContext).
    ///
    /// You can return the same value when the variable is a non-mutable data structure or
    /// if the type can not be deeply cloned/copied.
    fn init(self, context: &mut CubeContext) -> Self;
}
/// Defines how a [launch argument](LaunchArg) can be expanded.
///
/// Normally this type should be implemented two times for an argument.
/// Once for the reference and the other for the mutable reference. Often time, the reference
/// should expand the argument as an input while the mutable reference should expand the argument
/// as an output.
pub trait LaunchArgExpand: CubeType {
    /// Register an input variable during compilation that fill the [KernelBuilder].
    fn expand(
        builder: &mut KernelBuilder,
        vectorization: Vectorization,
    ) -> <Self as CubeType>::ExpandType;
    /// Register an output variable during compilation that fill the [KernelBuilder].
    fn expand_output(
        builder: &mut KernelBuilder,
        vectorization: Vectorization,
    ) -> <Self as CubeType>::ExpandType {
        // By default, outputs expand exactly like inputs.
        Self::expand(builder, vectorization)
    }
}
/// Defines a type that can be used as argument to a kernel.
pub trait LaunchArg: LaunchArgExpand + Send + Sync + 'static {
    /// The runtime argument for the kernel.
    type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
}
// The unit type is a valid (empty) launch argument: it registers nothing and
// expands to nothing, which lets kernels declare "no argument" uniformly.
impl LaunchArg for () {
    type RuntimeArg<'a, R: Runtime> = ();
}
impl<R: Runtime> ArgSettings<R> for () {
    fn register(&self, _launcher: &mut KernelLauncher<R>) {
        // nothing to do
    }
}
impl LaunchArgExpand for () {
    // Unit expands to unit: no variable is registered with the builder.
    fn expand(
        _builder: &mut KernelBuilder,
        _vectorization: Vectorization,
    ) -> <Self as CubeType>::ExpandType {
    }
}
impl CubeType for () {
    type ExpandType = ();
}
impl Init for () {
    fn init(self, _context: &mut CubeContext) -> Self {
        self
    }
}
/// Defines the argument settings used to launch a kernel.
pub trait ArgSettings<R: Runtime>: Send + Sync {
    /// Register the information to the [KernelLauncher].
    fn register(&self, launcher: &mut KernelLauncher<R>);
    /// Configure an input argument at the given position.
    ///
    /// Default is a no-op: the settings are returned unchanged.
    fn configure_input(&self, _position: usize, settings: KernelSettings) -> KernelSettings {
        settings
    }
    /// Configure an output argument at the given position.
    ///
    /// Default is a no-op: the settings are returned unchanged.
    fn configure_output(&self, _position: usize, settings: KernelSettings) -> KernelSettings {
        settings
    }
}
/// Reference to a JIT variable
#[derive(Clone, Debug)]
pub enum ExpandElement {
    /// Variable kept in the variable pool.
    Managed(Rc<Variable>),
    /// Variable not kept in the variable pool.
    Plain(Variable),
}
/// Expand type associated with a type.
#[derive(new)]
pub struct ExpandElementTyped<T> {
    // Untyped expand element carrying the actual IR variable.
    pub(crate) expand: ExpandElement,
    // Zero-sized marker tying this handle to the cube type `T`.
    pub(crate) _type: PhantomData<T>,
}
impl<T> Vectorized for ExpandElementTyped<T> {
    /// Delegate to the wrapped untyped expand element.
    fn vectorization_factor(&self) -> UInt {
        self.expand.vectorization_factor()
    }
    /// Re-wrap the vectorized inner element (derive-new skips the PhantomData field).
    fn vectorize(self, factor: UInt) -> Self {
        Self::new(self.expand.vectorize(factor))
    }
}
// Manual Clone: a derive would require `T: Clone`, but `T` is only a marker here.
impl<T> Clone for ExpandElementTyped<T> {
    fn clone(&self) -> Self {
        Self::new(self.expand.clone())
    }
}
impl<T> From<ExpandElement> for ExpandElementTyped<T> {
    fn from(expand: ExpandElement) -> Self {
        Self::new(expand)
    }
}
impl<T> From<ExpandElementTyped<T>> for ExpandElement {
    fn from(value: ExpandElementTyped<T>) -> Self {
        value.expand
    }
}
impl ExpandElement {
    /// Whether the underlying variable may be mutated in place.
    ///
    /// Only pool-managed `Local` variables qualify, and only when the strong count is
    /// at most 2 — presumably the pool's reference plus this handle, i.e. no other
    /// user holds the variable; TODO confirm against the variable-pool implementation.
    pub fn can_mut(&self) -> bool {
        match self {
            ExpandElement::Managed(var) => {
                if let Variable::Local { .. } = var.as_ref() {
                    Rc::strong_count(var) <= 2
                } else {
                    false
                }
            }
            ExpandElement::Plain(_) => false,
        }
    }
}
impl core::ops::Deref for ExpandElement {
    type Target = Variable;
    // Both variants dereference to the underlying IR variable.
    fn deref(&self) -> &Self::Target {
        match self {
            ExpandElement::Managed(var) => var.as_ref(),
            ExpandElement::Plain(var) => var,
        }
    }
}
impl From<ExpandElement> for Variable {
    fn from(value: ExpandElement) -> Self {
        match value {
            // Copies the variable out of the shared pool entry.
            ExpandElement::Managed(var) => *var,
            ExpandElement::Plain(var) => var,
        }
    }
}
impl Init for ExpandElement {
    /// Reuse the element when it is exclusively owned; otherwise emit an `Assign`
    /// into a fresh variable so later mutation can't corrupt shared state.
    fn init(self, context: &mut CubeContext) -> Self {
        if self.can_mut() {
            // Can reuse inplace :)
            return self;
        }
        let mut init = |elem: Self| init_expand(context, elem, Operator::Assign);
        match *self {
            Variable::GlobalScalar { .. } => init(self),
            Variable::LocalScalar { .. } => init(self),
            Variable::ConstantScalar { .. } => init(self),
            Variable::Local { .. } => init(self),
            // Constant should be initialized since the new variable can be mutated afterward.
            // And it is assumed those values are cloned.
            Variable::Rank
            | Variable::UnitPos
            | Variable::UnitPosX
            | Variable::UnitPosY
            | Variable::UnitPosZ
            | Variable::CubePos
            | Variable::CubePosX
            | Variable::CubePosY
            | Variable::CubePosZ
            | Variable::CubeDim
            | Variable::CubeDimX
            | Variable::CubeDimY
            | Variable::CubeDimZ
            | Variable::CubeCount
            | Variable::CubeCountX
            | Variable::CubeCountY
            | Variable::CubeCountZ
            | Variable::SubcubeDim
            | Variable::AbsolutePos
            | Variable::AbsolutePosX
            | Variable::AbsolutePosY
            | Variable::AbsolutePosZ => init(self),
            // Array types can't be copied, so we should simply return the same variable.
            Variable::SharedMemory { .. }
            | Variable::GlobalInputArray { .. }
            | Variable::GlobalOutputArray { .. }
            | Variable::LocalArray { .. }
            | Variable::Slice { .. }
            | Variable::Matrix { .. } => self,
        }
    }
}
/// Implement a panicking [Init] for plain comptime values: these are never
/// initialized inside a kernel context, so reaching `init` is a bug.
macro_rules! impl_init_for {
    ($($t:ty),*) => {
        $(
            impl Init for $t {
                fn init(self, _context: &mut CubeContext) -> Self {
                    // Fixed typo: message previously read "Shouln't".
                    panic!("Shouldn't be called, only for comptime.")
                }
            }
        )*
    };
}
// Add all types used within comptime
impl_init_for!(u32, bool, UInt);
impl<T: Init> Init for Option<T> {
    /// Initialize the inner value when present; `None` stays `None`.
    fn init(self, context: &mut CubeContext) -> Self {
        self.map(|value| value.init(context))
    }
}
impl<T: CubeType> CubeType for Vec<T> {
    type ExpandType = Vec<T::ExpandType>;
}
impl<T: CubeType> CubeType for &mut Vec<T> {
    type ExpandType = Vec<T::ExpandType>;
}
impl<T: Init> Init for Vec<T> {
    /// Initialize every element of the vector, preserving order.
    fn init(self, context: &mut CubeContext) -> Self {
        self.into_iter()
            .map(|element| element.init(context))
            .collect()
    }
}

View File

@ -1,28 +0,0 @@
use crate::frontend::{CubePrimitive, CubeType, ExpandElement};
use crate::ir::Elem;
use super::Vectorized;
// To be consistent with other primitive type.
/// Boolean type.
pub type Bool = bool;
impl CubeType for bool {
    type ExpandType = ExpandElement;
}
impl CubePrimitive for bool {
    /// Booleans map to the IR boolean element type.
    fn as_elem() -> Elem {
        Elem::Bool
    }
}
impl Vectorized for bool {
    // Vectorization of plain booleans is not implemented: both methods panic via todo!().
    fn vectorization_factor(&self) -> crate::prelude::UInt {
        todo!()
    }
    fn vectorize(self, _factor: crate::prelude::UInt) -> Self {
        todo!()
    }
}

View File

@ -1,31 +0,0 @@
use crate::frontend::{assign, CubeContext, CubePrimitive, CubeType};
use crate::ir::{Item, Variable};
use crate::{frontend::ExpandElement, unexpanded};
/// Enable elegant casting from any to any CubeElem
pub trait Cast: CubePrimitive {
    /// Placeholder cast; only valid inside a cube function (panics when called unexpanded).
    fn cast_from<From: CubePrimitive>(value: From) -> Self;
    /// Expand method of [cast_from](Cast::cast_from): assigns the source into a new
    /// variable of the target element type.
    fn __expand_cast_from<From>(
        context: &mut CubeContext,
        value: From,
    ) -> <Self as CubeType>::ExpandType
    where
        From: Into<ExpandElement>,
    {
        let value: ExpandElement = value.into();
        let var: Variable = *value;
        // New local with the destination element type but the source's lane count.
        let new_var = context.create_local(Item::vectorized(
            <Self as CubePrimitive>::as_elem(),
            var.item().vectorization,
        ));
        assign::expand(context, value, new_var.clone());
        new_var
    }
}
// Every primitive can be cast; the unexpanded form is never executed.
impl<P: CubePrimitive> Cast for P {
    fn cast_from<From: CubePrimitive>(_value: From) -> Self {
        unexpanded!()
    }
}

View File

@ -1,49 +0,0 @@
use crate::frontend::UInt;
use crate::frontend::{CubeType, ExpandElement};
use crate::ir::{Elem, Variable};
use super::Vectorized;
/// Form of CubeType that encapsulates all primitive types:
/// Numeric, UInt, Bool
pub trait CubePrimitive:
    CubeType<ExpandType = ExpandElement>
    + Vectorized
    + core::cmp::Eq
    + core::cmp::PartialEq
    + Send
    + Sync
    + 'static
    + Clone
    + Copy
{
    /// Return the element type to use on GPU
    fn as_elem() -> Elem;
}
/// Implement `From<$type> for ExpandElement` by wrapping the value as a plain
/// (non-pooled) constant variable.
macro_rules! impl_into_expand_element {
    ($type:ty) => {
        impl From<$type> for ExpandElement {
            fn from(value: $type) -> Self {
                ExpandElement::Plain(Variable::from(value))
            }
        }
    };
}
impl_into_expand_element!(u32);
impl_into_expand_element!(usize);
impl_into_expand_element!(bool);
impl_into_expand_element!(f32);
impl_into_expand_element!(i32);
impl_into_expand_element!(i64);
/// Useful for Comptime
impl From<UInt> for ExpandElement {
    fn from(value: UInt) -> Self {
        // u32 -> f64 is exact; scalar constants are stored as f64 in the IR.
        ExpandElement::Plain(crate::ir::Variable::ConstantScalar {
            value: value.val as f64,
            elem: UInt::as_elem(),
        })
    }
}

View File

@ -1,233 +0,0 @@
use half::{bf16, f16};
use crate::frontend::{Ceil, Cos, Erf, Exp, Floor, Log, Log1p, Powf, Recip, Sin, Sqrt, Tanh};
use crate::frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement, Numeric};
use crate::ir::{Elem, FloatKind, Item, Variable, Vectorization};
use crate::compute::{KernelBuilder, KernelLauncher};
use crate::prelude::index_assign;
use crate::{unexpanded, Runtime};
use super::{LaunchArgExpand, ScalarArgSettings, UInt, Vectorized};
/// Floating point numbers. Used as input in float kernels
pub trait Float:
    Numeric
    + Exp
    + Log
    + Log1p
    + Cos
    + Sin
    + Tanh
    + Powf
    + Sqrt
    + Floor
    + Ceil
    + Erf
    + Recip
    + core::ops::Index<UInt, Output = Self>
    + core::ops::IndexMut<UInt, Output = Self>
{
    /// Create a new scalar float constant.
    fn new(val: f32) -> Self;
    /// Create a vectorized float constant (every lane set to `val`).
    fn vectorized(val: f32, vectorization: UInt) -> Self;
    /// Create a zero-valued vectorized float (expanded form allocates without assigning).
    fn vectorized_empty(vectorization: UInt) -> Self;
    /// Expand method of [new](Float::new).
    fn __expand_new(context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType;
    /// Expand method of [vectorized](Float::vectorized).
    fn __expand_vectorized(
        context: &mut CubeContext,
        val: f32,
        vectorization: UInt,
    ) -> <Self as CubeType>::ExpandType;
    /// Expand method of [vectorized_empty](Float::vectorized_empty).
    fn __expand_vectorized_empty(
        context: &mut CubeContext,
        vectorization: UInt,
    ) -> <Self as CubeType>::ExpandType;
}
/// Generate a comptime float wrapper (`val` + lane count) and all its cube trait impls.
macro_rules! impl_float {
    ($type:ident, $primitive:ty) => {
        #[derive(Clone, Copy)]
        pub struct $type {
            pub val: f32,
            pub vectorization: u8,
        }
        impl CubeType for $type {
            type ExpandType = ExpandElement;
        }
        impl CubePrimitive for $type {
            /// Return the element type to use on GPU
            fn as_elem() -> Elem {
                Elem::Float(FloatKind::$type)
            }
        }
        impl Numeric for $type {
            type Primitive = $primitive;
        }
        impl Float for $type {
            fn new(val: f32) -> Self {
                Self {
                    val,
                    vectorization: 1,
                }
            }
            fn vectorized(val: f32, vectorization: UInt) -> Self {
                if vectorization.val == 1 {
                    Self::new(val)
                } else {
                    Self {
                        val,
                        vectorization: vectorization.val as u8,
                    }
                }
            }
            fn vectorized_empty(vectorization: UInt) -> Self {
                Self::vectorized(0., vectorization)
            }
            fn __expand_new(
                _context: &mut CubeContext,
                val: f32,
            ) -> <Self as CubeType>::ExpandType {
                // Scalar constants are stored as f64 in the IR.
                let new_var = Variable::ConstantScalar {
                    value: val as f64,
                    elem: Self::as_elem(),
                };
                ExpandElement::Plain(new_var)
            }
            fn __expand_vectorized(
                context: &mut CubeContext,
                val: f32,
                vectorization: UInt,
            ) -> <Self as CubeType>::ExpandType {
                if vectorization.val == 1 {
                    Self::__expand_new(context, val)
                } else {
                    // Broadcast `val` into every lane of a fresh local variable.
                    let mut new_var = context
                        .create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8));
                    for (i, element) in vec![val; vectorization.val as usize].iter().enumerate() {
                        new_var = index_assign::expand(context, new_var, i, *element);
                    }
                    new_var
                }
            }
            fn __expand_vectorized_empty(
                context: &mut CubeContext,
                vectorization: UInt,
            ) -> <Self as CubeType>::ExpandType {
                if vectorization.val == 1 {
                    Self::__expand_new(context, 0.)
                } else {
                    // No assignment is emitted: the variable is left unwritten.
                    context.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8))
                }
            }
        }
        impl core::ops::Index<UInt> for $type {
            type Output = Self;
            fn index(&self, _index: UInt) -> &Self::Output {
                unexpanded!()
            }
        }
        impl core::ops::IndexMut<UInt> for $type {
            fn index_mut(&mut self, _index: UInt) -> &mut Self::Output {
                unexpanded!()
            }
        }
        impl LaunchArgExpand for $type {
            fn expand(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
                assert_eq!(vectorization, 1, "Attempted to vectorize a scalar");
                builder.scalar($type::as_elem())
            }
        }
        impl Vectorized for $type {
            fn vectorization_factor(&self) -> UInt {
                UInt {
                    val: self.vectorization as u32,
                    vectorization: 1,
                }
            }
            fn vectorize(mut self, factor: UInt) -> Self {
                // NOTE(review): this copies `factor.vectorization` (the factor's own lane
                // count, typically 1) rather than `factor.val`, while
                // `vectorization_factor` reads `self.vectorization` back as `val`.
                // This asymmetry looks like a bug — confirm intended semantics.
                self.vectorization = factor.vectorization;
                self
            }
        }
    };
}
impl_float!(F16, f16);
impl_float!(BF16, bf16);
impl_float!(F32, f32);
impl_float!(F64, f64);
/// Implement `From<f32>` for the comptime float wrappers: scalar value, one lane.
macro_rules! impl_float_from_f32 {
    ($($ty:ident),*) => {
        $(
            impl From<f32> for $ty {
                fn from(value: f32) -> Self {
                    Self {
                        val: value,
                        vectorization: 1,
                    }
                }
            }
        )*
    };
}
impl_float_from_f32!(F32, BF16, F16, F64);
// Host-side scalar registration: each primitive float forwards to the matching
// typed register method on the launcher (f16/bf16 come from the `half` crate).
impl ScalarArgSettings for f16 {
    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
        settings.register_f16(*self);
    }
}
impl ScalarArgSettings for bf16 {
    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
        settings.register_bf16(*self);
    }
}
impl ScalarArgSettings for f32 {
    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
        settings.register_f32(*self);
    }
}
impl ScalarArgSettings for f64 {
    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
        settings.register_f64(*self);
    }
}

View File

@ -1,146 +0,0 @@
use crate::compute::{KernelBuilder, KernelLauncher};
use crate::frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement, Numeric};
use crate::ir::{Elem, IntKind, Item, Variable, Vectorization};
use crate::prelude::index_assign;
use crate::Runtime;
use super::{LaunchArgExpand, ScalarArgSettings, UInt, Vectorized};
/// Signed integer. Used as input in int kernels
pub trait Int: Numeric + std::ops::Rem<Output = Self> {
    /// Create a new scalar int constant (narrowed to the concrete primitive).
    fn new(val: i64) -> Self;
    /// Create a vectorized int constant (every lane set to `val`).
    fn vectorized(val: i64, vectorization: UInt) -> Self;
    /// Expand method of [new](Int::new).
    fn __expand_new(context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType;
    /// Expand method of [vectorized](Int::vectorized).
    fn __expand_vectorized(
        context: &mut CubeContext,
        val: i64,
        vectorization: UInt,
    ) -> <Self as CubeType>::ExpandType;
}
/// Generate a comptime int wrapper (`val` + lane count) and all its cube trait impls.
macro_rules! impl_int {
    ($type:ident, $primitive:ty) => {
        #[derive(Clone, Copy)]
        pub struct $type {
            pub val: $primitive,
            pub vectorization: u8,
        }
        impl CubeType for $type {
            type ExpandType = ExpandElement;
        }
        impl CubePrimitive for $type {
            fn as_elem() -> Elem {
                Elem::Int(IntKind::$type)
            }
        }
        impl Numeric for $type {
            type Primitive = $primitive;
        }
        impl Int for $type {
            fn new(val: i64) -> Self {
                Self {
                    val: val as $primitive,
                    vectorization: 1,
                }
            }
            fn vectorized(val: i64, vectorization: UInt) -> Self {
                if vectorization.val == 1 {
                    Self::new(val)
                } else {
                    Self {
                        val: val as $primitive,
                        vectorization: vectorization.val as u8,
                    }
                }
            }
            fn __expand_new(
                _context: &mut CubeContext,
                val: i64,
            ) -> <Self as CubeType>::ExpandType {
                // NOTE(review): i64 -> f64 loses precision above 2^53; scalar constants
                // are stored as f64 in the IR.
                let new_var = Variable::ConstantScalar {
                    value: val as f64,
                    elem: Self::as_elem(),
                };
                ExpandElement::Plain(new_var)
            }
            fn __expand_vectorized(
                context: &mut CubeContext,
                val: i64,
                vectorization: UInt,
            ) -> <Self as CubeType>::ExpandType {
                if vectorization.val == 1 {
                    Self::__expand_new(context, val)
                } else {
                    // Broadcast `val` into every lane of a fresh local variable.
                    let mut new_var = context
                        .create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8));
                    for (i, element) in vec![val; vectorization.val as usize].iter().enumerate() {
                        new_var = index_assign::expand(context, new_var, i, *element);
                    }
                    new_var
                }
            }
        }
        impl LaunchArgExpand for $type {
            fn expand(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement {
                assert_eq!(vectorization, 1, "Attempted to vectorize a scalar");
                builder.scalar($type::as_elem())
            }
        }
        impl Vectorized for $type {
            fn vectorization_factor(&self) -> UInt {
                UInt {
                    val: self.vectorization as u32,
                    vectorization: 1,
                }
            }
            fn vectorize(mut self, factor: UInt) -> Self {
                // NOTE(review): copies `factor.vectorization` (typically 1) rather than
                // `factor.val`, while `vectorization_factor` reads `self.vectorization`
                // back as `val`. Looks asymmetric — confirm intended semantics.
                self.vectorization = factor.vectorization;
                self
            }
        }
    };
}
impl_int!(I32, i32);
impl_int!(I64, i64);
// Host-side conversions: wrap a primitive as a scalar (one lane) comptime int.
impl From<i64> for I64 {
    fn from(value: i64) -> Self {
        Self {
            val: value,
            vectorization: 1,
        }
    }
}
impl From<i32> for I32 {
    fn from(value: i32) -> Self {
        Self {
            val: value,
            vectorization: 1,
        }
    }
}
// Scalar registration forwards to the matching typed launcher method.
impl ScalarArgSettings for i32 {
    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
        settings.register_i32(*self);
    }
}
impl ScalarArgSettings for i64 {
    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
        settings.register_i64(*self);
    }
}

View File

@ -1,27 +0,0 @@
// One submodule per frontend element kind.
mod array;
mod base;
mod bool;
mod cast;
mod cube_elem;
mod float;
mod int;
mod numeric;
mod shared_memory;
mod slice;
mod tensor;
mod uint;
mod vectorized;
// Flatten all submodules into this module's public API.
pub use array::*;
pub use base::*;
pub use bool::*;
pub use cast::*;
pub use cube_elem::*;
pub use float::*;
pub use int::*;
pub use numeric::*;
pub use shared_memory::*;
pub use slice::*;
pub use tensor::*;
pub use uint::*;
pub use vectorized::*;

View File

@ -1,93 +0,0 @@
use crate::compute::KernelLauncher;
use crate::frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement};
use crate::ir::{Item, Variable};
use crate::prelude::Clamp;
use crate::Runtime;
use crate::{
frontend::{index_assign, Abs, Max, Min, Remainder},
unexpanded,
};
use super::{ArgSettings, LaunchArg, LaunchArgExpand};
/// Type that encompasses both (unsigned or signed) integers and floats
/// Used in kernels that should work for both.
pub trait Numeric:
    Copy
    + CubePrimitive
    + LaunchArgExpand
    + std::ops::Add<Output = Self>
    + std::ops::AddAssign
    + std::ops::SubAssign
    + std::ops::MulAssign
    + std::ops::DivAssign
    + std::ops::Sub<Output = Self>
    + std::ops::Mul<Output = Self>
    + std::ops::Div<Output = Self>
    + std::cmp::PartialOrd
    + Abs
    + Max
    + Min
    + Clamp
    + Remainder
{
    /// Create a new constant numeric.
    ///
    /// Note: since this must work for both integer and float
    /// only the less expressive of both can be created (int)
    /// If a number with decimals is needed, use Float::new.
    ///
    /// This method panics when unexpanded. For creating an element
    /// with a val, use the new method of the sub type.
    fn from_int(_val: i64) -> Self {
        unexpanded!()
    }
    /// Host-side primitive used when registering this numeric as a launch scalar.
    type Primitive: ScalarArgSettings;
    /// Create a vectorized constant from `D` integer lanes (placeholder, used unexpanded).
    fn from_vec<const D: usize>(_vec: [i64; D]) -> Self {
        unexpanded!()
    }
    /// Expand method of [from_int](Numeric::from_int).
    fn __expand_from_int(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
        let new_var = Variable::ConstantScalar {
            value: val as f64,
            elem: Self::as_elem(),
        };
        ExpandElement::Plain(new_var)
    }
    /// Expand method of [from_vec](Numeric::from_vec): assigns each lane individually.
    fn __expand_from_vec<const D: usize>(
        context: &mut CubeContext,
        vec: [i64; D],
    ) -> <Self as CubeType>::ExpandType {
        let mut new_var = context.create_local(Item::vectorized(Self::as_elem(), vec.len() as u8));
        for (i, element) in vec.iter().enumerate() {
            new_var = index_assign::expand(context, new_var, i, *element);
        }
        new_var
    }
}
/// Similar to [ArgSettings], however only for scalar types that don't depend on the [Runtime]
/// trait.
pub trait ScalarArgSettings: Send + Sync {
    /// Register the information to the [KernelLauncher].
    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
}
/// Launch argument wrapping a single scalar of the numeric's host primitive type.
#[derive(new)]
pub struct ScalarArg<T: Numeric> {
    elem: T::Primitive,
}
impl<T: Numeric, R: Runtime> ArgSettings<R> for ScalarArg<T> {
    fn register(&self, launcher: &mut crate::compute::KernelLauncher<R>) {
        self.elem.register(launcher);
    }
}
// Every numeric is launchable as a scalar argument.
impl<T: Numeric> LaunchArg for T {
    type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
}

View File

@ -1,63 +0,0 @@
use std::marker::PhantomData;
use crate::{
frontend::{indexation::Index, CubeContext, CubePrimitive, CubeType},
ir::Item,
};
use super::{ExpandElementTyped, Init, UInt};
/// Shared memory container for elements of type `T` (comptime placeholder type).
#[derive(Clone, Copy)]
pub struct SharedMemory<T: CubeType> {
    // Marker only: the element type exists purely at the type level during expansion.
    _val: PhantomData<T>,
}
impl<T: CubePrimitive> Init for ExpandElementTyped<SharedMemory<T>> {
    fn init(self, _context: &mut CubeContext) -> Self {
        // Shared memory can't be deeply copied; reuse the same expand element.
        self
    }
}
impl<T: CubePrimitive> CubeType for SharedMemory<T> {
    type ExpandType = ExpandElementTyped<SharedMemory<T>>;
}
impl<T: CubePrimitive + Clone> SharedMemory<T> {
    /// Declare a shared memory of `size` elements (placeholder, used unexpanded).
    pub fn new<S: Index>(_size: S) -> Self {
        SharedMemory { _val: PhantomData }
    }
    /// Declare a vectorized shared memory (placeholder, used unexpanded).
    pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
        SharedMemory { _val: PhantomData }
    }
    /// Expand method of [vectorized](SharedMemory::vectorized).
    pub fn __expand_vectorized<S: Index>(
        context: &mut CubeContext,
        size: S,
        vectorization_factor: UInt,
    ) -> <Self as CubeType>::ExpandType {
        let size = Self::constant_size(size);
        let var = context.create_shared(
            Item::vectorized(T::as_elem(), vectorization_factor.val as u8),
            size,
        );
        ExpandElementTyped::new(var)
    }
    /// Expand method of [new](SharedMemory::new).
    pub fn __expand_new<S: Index>(
        context: &mut CubeContext,
        size: S,
    ) -> <Self as CubeType>::ExpandType {
        let size = Self::constant_size(size);
        let var = context.create_shared(Item::new(T::as_elem()), size);
        ExpandElementTyped::new(var)
    }
    /// Extract the comptime constant size, panicking when it is not a constant scalar.
    fn constant_size<S: Index>(size: S) -> u32 {
        match size.value() {
            crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
            _ => panic!("Shared memory need constant initialization value"),
        }
    }
}

Some files were not shown because too many files have changed in this diff Show More