Update cuda-jit (#1799)

Nathaniel Simard 2024-05-24 11:31:47 -04:00 committed by GitHub
parent 23c622a9f8
commit c7ad25ab60
15 changed files with 80 additions and 104 deletions

Cargo.lock

@@ -247,6 +247,7 @@ dependencies = [
  "arboard",
  "burn",
  "burn-common",
+ "burn-cuda",
  "burn-wgpu",
  "clap 4.5.4",
  "colored",
@@ -546,6 +547,23 @@ dependencies = [
  "syn 2.0.65",
 ]
 
+[[package]]
+name = "burn-cuda"
+version = "0.15.0"
+dependencies = [
+ "burn-common",
+ "burn-compute",
+ "burn-cube",
+ "burn-fusion",
+ "burn-jit",
+ "burn-tensor",
+ "bytemuck",
+ "cudarc 0.10.0",
+ "derive-new",
+ "half",
+ "log",
+]
+
 [[package]]
 name = "burn-dataset"
 version = "0.15.0"
@@ -829,7 +847,7 @@ dependencies = [
  "byteorder",
  "candle-kernels",
  "candle-metal-kernels",
- "cudarc",
+ "cudarc 0.11.0",
  "gemm",
  "half",
  "libc",
@@ -1323,6 +1341,12 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "cudarc"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9395df0cab995685664e79cc35ad6302bf08fb9c5d82301875a183affe1278b1"
+
 [[package]]
 name = "cudarc"
 version = "0.11.0"


@@ -16,7 +16,7 @@ members = [
 exclude = [
     "examples/notebook",
-    "crates/burn-cuda" # comment this line to work on burn-cuda
+    # "crates/burn-cuda" # comment this line to work on burn-cuda
 ]
 
 [workspace.package]


@@ -24,12 +24,14 @@ tch-cpu = ["burn/tch"]
 tch-gpu = ["burn/tch"]
 wgpu = ["burn/wgpu", "burn/autotune"]
 wgpu-fusion = ["wgpu", "burn/fusion"]
+cuda-jit = ["burn-cuda"]
 
 [dependencies]
 arboard = { workspace = true }
 burn = { path = "../crates/burn", default-features = false }
 burn-common = { path = "../crates/burn-common", version = "0.15.0" }
 burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.15.0" }
+burn-cuda = { path = "../crates/burn-cuda", version = "0.15.0", optional = true }
 clap = { workspace = true }
 colored = { workspace = true }
 derive-new = { workspace = true }
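
The new cuda-jit feature is purely additive: enabling it pulls in the optional burn-cuda dependency. A minimal way to exercise it from this crate, assuming nothing beyond standard Cargo feature selection (the concrete bench target is left to the reader):

    cargo bench --features cuda-jit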


@@ -78,6 +78,8 @@ enum BackendValues {
     Wgpu,
     #[strum(to_string = "wgpu-fusion")]
     WgpuFusion,
+    #[strum(to_string = "cuda-jit")]
+    CudaJit,
 }
 
 #[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
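
The #[strum(to_string = "...")] attribute is what binds the new variant to its CLI spelling. A self-contained sketch of that mapping (only the Display derive is shown here; the real enum also derives ValueEnum and EnumIter):

    use strum_macros::Display;

    #[derive(Display)]
    enum BackendValues {
        #[strum(to_string = "cuda-jit")]
        CudaJit,
    }

    fn main() {
        // Display honors the to_string attribute, so this prints "cuda-jit".
        println!("{}", BackendValues::CudaJit);
    }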


@@ -57,6 +57,8 @@ macro_rules! bench_on_backend {
         let feature_name = "wgpu";
         #[cfg(feature = "wgpu-fusion")]
         let feature_name = "wgpu-fusion";
+        #[cfg(feature = "cuda-jit")]
+        let feature_name = "cuda-jit";
 
         #[cfg(feature = "wgpu")]
         {
@@ -129,6 +131,13 @@ macro_rules! bench_on_backend {
             let device = CandleDevice::Metal(0);
             bench::<Candle>(&device, feature_name, url, token);
         }
+
+        #[cfg(feature = "cuda-jit")]
+        {
+            use burn_cuda::{Cuda, CudaDevice};
+
+            bench::<Cuda>(&CudaDevice::default(), feature_name, url, token);
+        }
     };
}
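
The feature_name selection above relies on cfg-gated let bindings shadowing one another: when several backend features are enabled the last matching binding wins, and when none is enabled the name is unbound and compilation fails at the use site. A stripped-down, standalone sketch of the pattern (feature names as in the macro, everything else illustrative):

    fn main() {
        #[cfg(feature = "wgpu")]
        let feature_name = "wgpu";
        #[cfg(feature = "cuda-jit")]
        let feature_name = "cuda-jit"; // shadows the wgpu binding when both are on

        println!("benchmarking with {feature_name}");
    }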


@@ -112,7 +112,7 @@ impl BenchmarkComputations {
 /// Benchmark trait.
 pub trait Benchmark {
     /// Benchmark arguments.
-    type Args;
+    type Args: Clone;
 
     /// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't
     /// count as included in the duration.
@@ -149,19 +149,20 @@ pub trait Benchmark {
         #[cfg(feature = "std")]
         {
             // Warmup
-            self.execute(self.prepare());
+            let args = self.prepare();
+            self.execute(args.clone());
             self.sync();
 
             let mut durations = Vec::with_capacity(self.num_samples());
 
             for _ in 0..self.num_samples() {
                 // Prepare
                 let args = self.prepare();
                 self.sync();
 
                 // Execute the benchmark
                 let start = Instant::now();
-                self.execute(args);
+                self.execute(args.clone());
                 self.sync();
                 let end = Instant::now();
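
The new Clone bound exists because execute takes Args by value while the same prepared arguments may now be run more than once (warmup plus timed samples). A self-contained sketch of the shape this imposes on implementors (a simplified stand-in for the real trait, with a hard-coded sample count):

    use std::time::{Duration, Instant};

    trait Benchmark {
        type Args: Clone;
        fn prepare(&self) -> Self::Args;
        fn execute(&self, args: Self::Args);

        fn run(&self) -> Vec<Duration> {
            // Warmup consumes a clone so the prepared args stay reusable.
            let args = self.prepare();
            self.execute(args.clone());

            (0..10)
                .map(|_| {
                    let args = self.prepare();
                    let start = Instant::now();
                    self.execute(args.clone());
                    start.elapsed()
                })
                .collect()
        }
    }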


@@ -15,6 +15,12 @@ pub struct TuneBenchmark<S: ComputeServer, C> {
     client: ComputeClient<S, C>,
 }
 
+impl Clone for Box<dyn AutotuneOperation> {
+    fn clone(&self) -> Self {
+        self.as_ref().clone()
+    }
+}
+
 impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
     type Args = Box<dyn AutotuneOperation>;
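
This impl can only compile if AutotuneOperation itself exposes a clone method returning a boxed trait object (Clone itself is not object safe), so self.as_ref().clone() dispatches to the trait method rather than recursing. A self-contained sketch of this object-safe clone pattern (trait simplified, method named clone_box here to make the dispatch explicit):

    trait Op {
        fn run(&self);
        // Object-safe clone: return a boxed copy instead of Self.
        fn clone_box(&self) -> Box<dyn Op>;
    }

    impl Clone for Box<dyn Op> {
        fn clone(&self) -> Self {
            self.as_ref().clone_box()
        }
    }

    #[derive(Clone)]
    struct Noop;

    impl Op for Noop {
        fn run(&self) {}
        fn clone_box(&self) -> Box<dyn Op> {
            Box::new(self.clone())
        }
    }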


@@ -22,9 +22,10 @@ burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false
 burn-compute = { path = "../burn-compute", version = "0.15.0" }
 burn-tensor = { path = "../burn-tensor", version = "0.15.0" }
 burn-common = { path = "../burn-common", version = "0.15.0" }
+burn-cube = { path = "../burn-cube", version = "0.15.0" }
 burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
 
-half = { workspace = true }
+half = { workspace = true }
 bytemuck = { workspace = true }
 cudarc = "0.10.0"
@@ -37,4 +38,4 @@ burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false,
 ] }
 
 [package.metadata.docs.rs]
-features = ["doc"]
+features = ["doc"]


@@ -1,5 +1,6 @@
+use burn_cube::{dialect as gpu, Compiler};
 use super::Instruction;
-use burn_jit::gpu::{self};
 
 #[allow(clippy::too_many_arguments)]
 #[derive(new, Clone, Debug, Default)]
@@ -16,15 +17,15 @@ pub struct CudaCompiler {
     global_invocation_id: (bool, bool, bool),
 }
 
-impl burn_jit::Compiler for CudaCompiler {
+impl Compiler for CudaCompiler {
     type Representation = super::ComputeShader;
 
-    fn compile(shader: burn_jit::gpu::ComputeShader) -> Self::Representation {
+    fn compile(shader: burn_cube::dialect::ComputeShader) -> Self::Representation {
         let compiler = Self::default();
         compiler.compile_shader(shader)
     }
 
-    fn elem_size(elem: burn_jit::gpu::Elem) -> usize {
+    fn elem_size(elem: gpu::Elem) -> usize {
         Self::compile_elem(elem).size()
     }
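
The Compiler trait now comes from burn-cube rather than burn-jit. Judging only from the impl above, its surface is roughly the following; this is a reconstruction from this diff with stand-in types, not the crate's actual definition:

    struct ComputeShader; // stand-in for burn_cube::dialect::ComputeShader
    #[derive(Clone, Copy)]
    struct Elem; // stand-in for burn_cube::dialect::Elem

    trait Compiler {
        type Representation;

        fn compile(shader: ComputeShader) -> Self::Representation;
        fn elem_size(elem: Elem) -> usize;
    }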
@@ -75,44 +76,7 @@ impl CudaCompiler {
     fn compile_scope(&mut self, value: &mut gpu::Scope) -> Vec<Instruction> {
         let mut instructions = Vec::new();
 
-        let mut processing = value.process();
-
-        for operation in &mut processing.operations {
-            if let gpu::Operation::Operator(gpu::Operator::Index(operands)) = operation {
-                // Replace all Index operators for global arrays with CheckedIndexAssign procedures
-                match operands.lhs {
-                    gpu::Variable::GlobalInputArray(_, _)
-                    | gpu::Variable::GlobalOutputArray(_, _) => {
-                        *operation = gpu::Operation::Procedure(gpu::Procedure::CheckedIndex(
-                            gpu::CheckedIndex {
-                                lhs: operands.lhs,
-                                rhs: operands.rhs,
-                                out: operands.out,
-                            },
-                        ));
-                    }
-                    // Cannot perform bound check on non-global arrays, do nothing.
-                    _ => (),
-                }
-            }
-
-            if let gpu::Operation::Operator(gpu::Operator::IndexAssign(operands)) = operation {
-                // Replace all IndexAssign operators of global arrays with CheckedIndexAssign procedures
-                match operands.out {
-                    gpu::Variable::GlobalInputArray(_, _)
-                    | gpu::Variable::GlobalOutputArray(_, _) => {
-                        *operation = gpu::Operation::Procedure(gpu::Procedure::CheckedIndexAssign(
-                            gpu::CheckedIndexAssign {
-                                lhs: operands.lhs,
-                                rhs: operands.rhs,
-                                out: operands.out,
-                            },
-                        ));
-                    }
-                    // Cannot perform bound check on non-global arrays, do nothing.
-                    _ => (),
-                }
-            }
-        }
+        let processing = value.process();
 
         for var in processing.variables {
             instructions.push(Instruction::DeclareVariable {
@@ -415,11 +379,12 @@ impl CudaCompiler {
     }
 
     fn compile_item(item: gpu::Item) -> super::Item {
-        match item {
-            gpu::Item::Vec4(elem) => super::Item::Vec4(Self::compile_elem(elem)),
-            gpu::Item::Vec3(elem) => super::Item::Vec3(Self::compile_elem(elem)),
-            gpu::Item::Vec2(elem) => super::Item::Vec2(Self::compile_elem(elem)),
-            gpu::Item::Scalar(elem) => super::Item::Scalar(Self::compile_elem(elem)),
+        match item.vectorization {
+            4 => super::Item::Vec4(Self::compile_elem(item.elem)),
+            3 => super::Item::Vec3(Self::compile_elem(item.elem)),
+            2 => super::Item::Vec2(Self::compile_elem(item.elem)),
+            1 => super::Item::Scalar(Self::compile_elem(item.elem)),
+            _ => panic!("Vectorization factor unsupported {:?}", item.vectorization),
         }
     }
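
For orientation: the old dialect encoded the vector width in the Item enum variant, while the new burn-cube dialect carries it as a numeric field next to the element type, which is why the match now branches on an integer and needs a fallback arm. A rough sketch of the two shapes (the vectorization and elem names appear above; everything else is illustrative):

    #[derive(Debug, Clone, Copy)]
    enum Elem { F32, I32, U32 }

    // Old shape: width baked into the variant.
    enum OldItem { Vec4(Elem), Vec3(Elem), Vec2(Elem), Scalar(Elem) }

    // New shape: width as data, matched on numerically by the compiler.
    #[derive(Debug, Clone, Copy)]
    struct Item { vectorization: u8, elem: Elem }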


@@ -1,4 +1,4 @@
-use burn_jit::gpu;
+use burn_cube::dialect as gpu;
 use half::{bf16, f16};
 use std::fmt::Display;


@@ -1,6 +1,7 @@
+use burn_cube::{dialect::WorkgroupSize, CompilerRepresentation};
+
 // use super::{Body, Extension, Item};
 use super::{Body, Item};
-use burn_jit::{gpu::WorkgroupSize, CompilerRepresentation};
 use std::fmt::Display;
 
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]


@@ -4,8 +4,11 @@ use burn_compute::{
     memory_management::MemoryManagement,
     server::{self, ComputeServer},
 };
-use burn_jit::compute::{JitAutotuneKey, Kernel, WorkGroup};
-use burn_jit::gpu::WorkgroupSize;
+use burn_cube::dialect::WorkgroupSize;
+use burn_cube::JitKernel;
+use burn_cube::Kernel;
+use burn_cube::WorkGroup;
+use burn_jit::JitAutotuneKey;
 use cudarc::driver::sys::CUctx_st;
 use cudarc::driver::sys::CUfunc_st;
 use std::collections::HashMap;


@@ -1,42 +0,0 @@
-use burn_jit::JitElement;
-
-use crate::compiler;
-
-/// The base element trait for the wgpu backend.
-pub trait CudaElement: JitElement {
-    fn cuda_elem() -> compiler::Elem;
-}
-
-/// The float element type for the wgpu backend.
-pub trait FloatElement: CudaElement + burn_jit::FloatElement {}
-
-/// The int element type for the wgpu backend.
-pub trait IntElement: CudaElement + burn_jit::IntElement {}
-
-impl CudaElement for u32 {
-    fn cuda_elem() -> compiler::Elem {
-        compiler::Elem::U32
-    }
-}
-
-impl CudaElement for i32 {
-    fn cuda_elem() -> compiler::Elem {
-        compiler::Elem::I32
-    }
-}
-
-impl CudaElement for f32 {
-    fn cuda_elem() -> compiler::Elem {
-        compiler::Elem::F32
-    }
-}
-
-impl CudaElement for half::bf16 {
-    fn cuda_elem() -> compiler::Elem {
-        compiler::Elem::BF16
-    }
-}
-
-impl FloatElement for f32 {}
-impl FloatElement for half::bf16 {}
-impl IntElement for i32 {}

@@ -4,7 +4,6 @@ extern crate alloc;
 
 mod compute;
 mod device;
-mod element;
 mod runtime;
 
 pub mod compiler;


@@ -6,7 +6,7 @@ use burn_compute::{
     tune::Tuner,
     ComputeRuntime,
 };
-use burn_jit::Runtime;
+use burn_cube::Runtime;
 use std::sync::Arc;
 
 use crate::{
@@ -18,6 +18,11 @@ use crate::{
 #[derive(Debug)]
 pub struct CudaRuntime;
 
+impl burn_jit::JitRuntime for CudaRuntime {
+    type JitDevice = CudaDevice;
+    type JitServer = CudaServer<SimpleMemoryManagement<CudaStorage>>;
+}
+
 // static RUNTIME: ComputeRuntime<CudaDevice, Server, MutexComputeChannel<Server>> =
 static RUNTIME: ComputeRuntime<CudaDevice, Server, MutexComputeChannel<Server>> =
     ComputeRuntime::new();
@@ -51,7 +56,7 @@ impl Runtime for CudaRuntime {
         let memory_management = SimpleMemoryManagement::new(
             storage,
             DeallocStrategy::new_period_tick(1),
-            SliceStrategy::Never,
+            SliceStrategy::Ratio(0.8),
         );
 
         CudaContext::new(memory_management, stream, ctx)
     }