mirror of https://github.com/tracel-ai/burn.git
Update cuda-jit (#1799)
This commit is contained in:
parent 23c622a9f8
commit c7ad25ab60
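This commit wires the burn-cuda backend into the benchmark tooling: backend-comparison gains a cuda-jit feature backed by an optional burn-cuda dependency, crates/burn-cuda returns to the workspace build, the crate's compiler is ported from the burn_jit::gpu dialect to burn_cube::dialect, and burn-common's Benchmark trait now requires its Args to be Clone so prepared inputs can be executed repeatedly. The per-file hunks follow.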
Cargo.lock

@@ -247,6 +247,7 @@ dependencies = [
  "arboard",
  "burn",
  "burn-common",
+ "burn-cuda",
  "burn-wgpu",
  "clap 4.5.4",
  "colored",
@@ -546,6 +547,23 @@ dependencies = [
  "syn 2.0.65",
 ]
 
+[[package]]
+name = "burn-cuda"
+version = "0.15.0"
+dependencies = [
+ "burn-common",
+ "burn-compute",
+ "burn-cube",
+ "burn-fusion",
+ "burn-jit",
+ "burn-tensor",
+ "bytemuck",
+ "cudarc 0.10.0",
+ "derive-new",
+ "half",
+ "log",
+]
+
 [[package]]
 name = "burn-dataset"
 version = "0.15.0"
@@ -829,7 +847,7 @@ dependencies = [
  "byteorder",
  "candle-kernels",
  "candle-metal-kernels",
- "cudarc",
+ "cudarc 0.11.0",
  "gemm",
  "half",
  "libc",
@@ -1323,6 +1341,12 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "cudarc"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9395df0cab995685664e79cc35ad6302bf08fb9c5d82301875a183affe1278b1"
+
 [[package]]
 name = "cudarc"
 version = "0.11.0"
Cargo.toml

@@ -16,7 +16,7 @@ members = [
 
 exclude = [
     "examples/notebook",
-    "crates/burn-cuda" # comment this line to work on burn-cuda
+    # "crates/burn-cuda" # comment this line to work on burn-cuda
 ]
 
 [workspace.package]
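Flipping the exclude entry to a comment brings crates/burn-cuda back into the workspace, so it is built and checked together with the other crates; the Cargo.lock additions above are the direct consequence.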
backend-comparison/Cargo.toml

@@ -24,12 +24,14 @@ tch-cpu = ["burn/tch"]
 tch-gpu = ["burn/tch"]
 wgpu = ["burn/wgpu", "burn/autotune"]
 wgpu-fusion = ["wgpu", "burn/fusion"]
+cuda-jit = ["burn-cuda"]
 
 [dependencies]
 arboard = { workspace = true }
 burn = { path = "../crates/burn", default-features = false }
 burn-common = { path = "../crates/burn-common", version = "0.15.0" }
 burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.15.0" }
+burn-cuda = { path = "../crates/burn-cuda", version = "0.15.0", optional = true }
 clap = { workspace = true }
 colored = { workspace = true }
 derive-new = { workspace = true }
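The new cuda-jit feature activates the optional burn-cuda dependency, so the CUDA backend is only compiled when explicitly requested, for example by passing --features cuda-jit to cargo; default builds are unaffected.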
backend-comparison/src/burnbenchapp/base.rs

@@ -78,6 +78,8 @@ enum BackendValues {
     Wgpu,
     #[strum(to_string = "wgpu-fusion")]
     WgpuFusion,
+    #[strum(to_string = "cuda-jit")]
+    CudaJit,
 }
 
 #[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
backend-comparison/src/lib.rs

@@ -57,6 +57,8 @@ macro_rules! bench_on_backend {
         let feature_name = "wgpu";
         #[cfg(feature = "wgpu-fusion")]
         let feature_name = "wgpu-fusion";
+        #[cfg(feature = "cuda-jit")]
+        let feature_name = "cuda-jit";
 
         #[cfg(feature = "wgpu")]
         {
@@ -129,6 +131,13 @@ macro_rules! bench_on_backend {
             let device = CandleDevice::Metal(0);
             bench::<Candle>(&device, feature_name, url, token);
         }
+
+        #[cfg(feature = "cuda-jit")]
+        {
+            use burn_cuda::{Cuda, CudaDevice};
+
+            bench::<Cuda>(&CudaDevice::default(), feature_name, url, token);
+        }
     };
 }
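The macro selects the reported backend name through a chain of #[cfg]-gated let bindings: each enabled feature re-binds feature_name, so the shadowing binding for the enabled backend wins. A self-contained illustration of the pattern (the surrounding main and the "none" fallback are illustrative; the real macro relies on at least one backend feature being enabled):

fn main() {
    // Fallback for this sketch only.
    #[allow(unused_variables)]
    let feature_name = "none";

    // With `--features cuda-jit`, this binding shadows the one above.
    #[cfg(feature = "cuda-jit")]
    let feature_name = "cuda-jit";

    println!("benchmarking with: {feature_name}");
}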
crates/burn-common/src/benchmark.rs

@@ -112,7 +112,7 @@ impl BenchmarkComputations {
 /// Benchmark trait.
 pub trait Benchmark {
     /// Benchmark arguments.
-    type Args;
+    type Args: Clone;
 
     /// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't
     /// count as included in the duration.
@@ -149,19 +149,20 @@ pub trait Benchmark {
         #[cfg(feature = "std")]
         {
             // Warmup
-            self.execute(self.prepare());
+            let args = self.prepare();
+            self.execute(args.clone());
             self.sync();
 
             let mut durations = Vec::with_capacity(self.num_samples());
 
             for _ in 0..self.num_samples() {
                 // Prepare
                 let args = self.prepare();
                 self.sync();
 
                 // Execute the benchmark
                 let start = Instant::now();
-                self.execute(args);
+                self.execute(args.clone());
                 self.sync();
                 let end = Instant::now();
 
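The Clone bound is what lets run() hand a fresh copy of the prepared arguments to the warmup pass and to every timed sample while keeping the original around. A compilable sketch of the flow, using a trimmed-down copy of the trait limited to the members these hunks touch (the real trait in burn-common has more items; VecAdd is a hypothetical implementor):

use std::time::{Duration, Instant};

// Trimmed-down stand-in for burn-common's Benchmark trait.
trait Benchmark {
    type Args: Clone;

    fn prepare(&self) -> Self::Args;
    fn execute(&self, args: Self::Args);
    fn sync(&self) {}
    fn num_samples(&self) -> usize {
        10
    }

    fn run(&self) -> Vec<Duration> {
        // Warmup: prepare once, execute a clone so `args` stays reusable.
        let args = self.prepare();
        self.execute(args.clone());
        self.sync();

        let mut durations = Vec::with_capacity(self.num_samples());
        for _ in 0..self.num_samples() {
            let args = self.prepare();
            self.sync();

            let start = Instant::now();
            self.execute(args.clone());
            self.sync();
            durations.push(start.elapsed());
        }
        durations
    }
}

// Hypothetical implementor: any cloneable argument type satisfies the bound.
struct VecAdd;

impl Benchmark for VecAdd {
    type Args = (Vec<f32>, Vec<f32>);

    fn prepare(&self) -> Self::Args {
        (vec![1.0; 1 << 20], vec![2.0; 1 << 20])
    }

    fn execute(&self, (lhs, rhs): Self::Args) {
        // The clone handed to each sample is consumed here.
        let _out: Vec<f32> = lhs.iter().zip(&rhs).map(|(a, b)| a + b).collect();
    }
}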
crates/burn-compute/src/tune/tune_benchmark.rs

@@ -15,6 +15,12 @@ pub struct TuneBenchmark<S: ComputeServer, C> {
     client: ComputeClient<S, C>,
 }
 
+impl Clone for Box<dyn AutotuneOperation> {
+    fn clone(&self) -> Self {
+        self.as_ref().clone()
+    }
+}
+
 impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
     type Args = Box<dyn AutotuneOperation>;
 
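This blanket impl is what lets Box<dyn AutotuneOperation> satisfy the new Args: Clone bound. It can only compile if AutotuneOperation itself exposes a clone method returning Box<dyn AutotuneOperation>, so the Clone impl delegates through the vtable rather than recursing. A self-contained sketch of the pattern with stand-in types (Op and Noop are illustrative, not burn-compute's real definitions):

trait Op {
    // Object-safe clone: returns a fresh boxed copy of the concrete type.
    fn clone(&self) -> Box<dyn Op>;
}

impl Clone for Box<dyn Op> {
    fn clone(&self) -> Self {
        // `as_ref()` yields `&dyn Op`, so this resolves to `Op::clone`
        // above, not to `Clone::clone` recursively.
        self.as_ref().clone()
    }
}

#[derive(Clone)]
struct Noop;

impl Op for Noop {
    fn clone(&self) -> Box<dyn Op> {
        // Explicit path avoids any ambiguity with the trait method.
        Box::new(Clone::clone(self))
    }
}

fn main() {
    let op: Box<dyn Op> = Box::new(Noop);
    let _copy = op.clone(); // uses the impl above
}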
crates/burn-cuda/Cargo.toml

@@ -22,9 +22,10 @@ burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false
 burn-compute = { path = "../burn-compute", version = "0.15.0" }
 burn-tensor = { path = "../burn-tensor", version = "0.15.0" }
 burn-common = { path = "../burn-common", version = "0.15.0" }
+burn-cube = { path = "../burn-cube", version = "0.15.0" }
 burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
-half = { workspace = true }
 
+half = { workspace = true }
 bytemuck = { workspace = true }
 cudarc = "0.10.0"
 
@@ -37,4 +38,4 @@ burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false,
 ] }
 
 [package.metadata.docs.rs]
-features = ["doc"]
\ No newline at end of file
+features = ["doc"]
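Note the version split this creates: burn-cuda pins cudarc = "0.10.0" while candle-core (per the Cargo.lock hunk above) moves to cudarc 0.11.0, which is why the lock file now carries both cudarc entries and disambiguates dependents as "cudarc 0.10.0" and "cudarc 0.11.0".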
crates/burn-cuda/src/compiler/base.rs

@@ -1,5 +1,6 @@
+use burn_cube::{dialect as gpu, Compiler};
+
 use super::Instruction;
-use burn_jit::gpu::{self};
 
 #[allow(clippy::too_many_arguments)]
 #[derive(new, Clone, Debug, Default)]
@@ -16,15 +17,15 @@ pub struct CudaCompiler {
     global_invocation_id: (bool, bool, bool),
 }
 
-impl burn_jit::Compiler for CudaCompiler {
+impl Compiler for CudaCompiler {
     type Representation = super::ComputeShader;
 
-    fn compile(shader: burn_jit::gpu::ComputeShader) -> Self::Representation {
+    fn compile(shader: burn_cube::dialect::ComputeShader) -> Self::Representation {
         let compiler = Self::default();
         compiler.compile_shader(shader)
     }
 
-    fn elem_size(elem: burn_jit::gpu::Elem) -> usize {
+    fn elem_size(elem: gpu::Elem) -> usize {
         Self::compile_elem(elem).size()
     }
 
@@ -75,44 +76,7 @@ impl CudaCompiler {
 
     fn compile_scope(&mut self, value: &mut gpu::Scope) -> Vec<Instruction> {
         let mut instructions = Vec::new();
-        let mut processing = value.process();
-
-        for operation in &mut processing.operations {
-            if let gpu::Operation::Operator(gpu::Operator::Index(operands)) = operation {
-                // Replace all Index operators for global arrays with CheckedIndexAssign procedures
-                match operands.lhs {
-                    gpu::Variable::GlobalInputArray(_, _)
-                    | gpu::Variable::GlobalOutputArray(_, _) => {
-                        *operation = gpu::Operation::Procedure(gpu::Procedure::CheckedIndex(
-                            gpu::CheckedIndex {
-                                lhs: operands.lhs,
-                                rhs: operands.rhs,
-                                out: operands.out,
-                            },
-                        ));
-                    }
-                    // Cannot perform bound check on non-global arrays, do nothing.
-                    _ => (),
-                }
-            }
-            if let gpu::Operation::Operator(gpu::Operator::IndexAssign(operands)) = operation {
-                // Replace all IndexAssign operators of global arrays with CheckedIndexAssign procedures
-                match operands.out {
-                    gpu::Variable::GlobalInputArray(_, _)
-                    | gpu::Variable::GlobalOutputArray(_, _) => {
-                        *operation = gpu::Operation::Procedure(gpu::Procedure::CheckedIndexAssign(
-                            gpu::CheckedIndexAssign {
-                                lhs: operands.lhs,
-                                rhs: operands.rhs,
-                                out: operands.out,
-                            },
-                        ));
-                    }
-                    // Cannot perform bound check on non-global arrays, do nothing.
-                    _ => (),
-                }
-            }
-        }
+        let processing = value.process();
 
         for var in processing.variables {
             instructions.push(Instruction::DeclareVariable {
@@ -415,11 +379,12 @@ impl CudaCompiler {
     }
 
     fn compile_item(item: gpu::Item) -> super::Item {
-        match item {
-            gpu::Item::Vec4(elem) => super::Item::Vec4(Self::compile_elem(elem)),
-            gpu::Item::Vec3(elem) => super::Item::Vec3(Self::compile_elem(elem)),
-            gpu::Item::Vec2(elem) => super::Item::Vec2(Self::compile_elem(elem)),
-            gpu::Item::Scalar(elem) => super::Item::Scalar(Self::compile_elem(elem)),
+        match item.vectorization {
+            4 => super::Item::Vec4(Self::compile_elem(item.elem)),
+            3 => super::Item::Vec3(Self::compile_elem(item.elem)),
+            2 => super::Item::Vec2(Self::compile_elem(item.elem)),
+            1 => super::Item::Scalar(Self::compile_elem(item.elem)),
+            _ => panic!("Vectorization factor unsupported {:?}", item.vectorization),
         }
     }
 
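Two things happen in this file: the hand-rolled bounds-check rewrite (Index and IndexAssign replaced by CheckedIndex and CheckedIndexAssign procedures) disappears from the CUDA compiler, and compile_item switches from matching an Item enum to reading a numeric vectorization factor, which reflects burn-cube's dialect representing Item as element-plus-vectorization rather than as Vec4/Vec3/Vec2/Scalar variants. A minimal sketch of that shape change with stand-in types (illustrations, not the real burn_cube definitions):

#[derive(Clone, Copy, Debug)]
enum Elem {
    F32,
}

// Old shape: one enum variant per vector width.
#[allow(dead_code)]
enum ItemOld {
    Vec4(Elem),
    Vec3(Elem),
    Vec2(Elem),
    Scalar(Elem),
}

// New shape: the element and its vector width travel together,
// so passes can branch on `vectorization` numerically.
#[derive(Clone, Copy, Debug)]
struct Item {
    elem: Elem,
    vectorization: u8,
}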
crates/burn-cuda/src/compiler/element.rs

@@ -1,4 +1,4 @@
-use burn_jit::gpu;
+use burn_cube::dialect as gpu;
 use half::{bf16, f16};
 use std::fmt::Display;
 
crates/burn-cuda/src/compiler/shader.rs

@@ -1,6 +1,7 @@
+use burn_cube::{dialect::WorkgroupSize, CompilerRepresentation};
+
 // use super::{Body, Extension, Item};
 use super::{Body, Item};
-use burn_jit::{gpu::WorkgroupSize, CompilerRepresentation};
 use std::fmt::Display;
 
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
crates/burn-cuda/src/compute/server.rs

@@ -4,8 +4,11 @@ use burn_compute::{
     memory_management::MemoryManagement,
     server::{self, ComputeServer},
 };
-use burn_jit::compute::{JitAutotuneKey, Kernel, WorkGroup};
-use burn_jit::gpu::WorkgroupSize;
+use burn_cube::dialect::WorkgroupSize;
+use burn_cube::JitKernel;
+use burn_cube::Kernel;
+use burn_cube::WorkGroup;
+use burn_jit::JitAutotuneKey;
 use cudarc::driver::sys::CUctx_st;
 use cudarc::driver::sys::CUfunc_st;
 use std::collections::HashMap;
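The import migration here follows the same pattern as the compiler files: workgroup, kernel, and dialect types now come from burn_cube, while burn_jit is retained only for the JIT-specific JitAutotuneKey.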
crates/burn-cuda/src/element.rs (deleted)

@@ -1,42 +0,0 @@
-use burn_jit::JitElement;
-
-use crate::compiler;
-
-/// The base element trait for the wgpu backend.
-pub trait CudaElement: JitElement {
-    fn cuda_elem() -> compiler::Elem;
-}
-
-/// The float element type for the wgpu backend.
-pub trait FloatElement: CudaElement + burn_jit::FloatElement {}
-
-/// The int element type for the wgpu backend.
-pub trait IntElement: CudaElement + burn_jit::IntElement {}
-
-impl CudaElement for u32 {
-    fn cuda_elem() -> compiler::Elem {
-        compiler::Elem::U32
-    }
-}
-
-impl CudaElement for i32 {
-    fn cuda_elem() -> compiler::Elem {
-        compiler::Elem::I32
-    }
-}
-
-impl CudaElement for f32 {
-    fn cuda_elem() -> compiler::Elem {
-        compiler::Elem::F32
-    }
-}
-
-impl CudaElement for half::bf16 {
-    fn cuda_elem() -> compiler::Elem {
-        compiler::Elem::BF16
-    }
-}
-
-impl FloatElement for f32 {}
-impl FloatElement for half::bf16 {}
-impl IntElement for i32 {}
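Removing this file deletes the CUDA-specific element traits outright (note the stale "wgpu backend" doc comments, apparently copied from the wgpu crate); with mod element; gone from lib.rs below, the backend presumably relies on burn-jit's element traits directly.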
crates/burn-cuda/src/lib.rs

@@ -4,7 +4,6 @@ extern crate alloc;
 
 mod compute;
 mod device;
-mod element;
 mod runtime;
 
 pub mod compiler;
crates/burn-cuda/src/runtime.rs

@@ -6,7 +6,7 @@ use burn_compute::{
     tune::Tuner,
     ComputeRuntime,
 };
-use burn_jit::Runtime;
+use burn_cube::Runtime;
 use std::sync::Arc;
 
 use crate::{
@@ -18,6 +18,11 @@ use crate::{
 #[derive(Debug)]
 pub struct CudaRuntime;
 
+impl burn_jit::JitRuntime for CudaRuntime {
+    type JitDevice = CudaDevice;
+    type JitServer = CudaServer<SimpleMemoryManagement<CudaStorage>>;
+}
+
-// static RUNTIME: ComputeRuntime<CudaDevice, Server, MutexComputeChannel<Server>> =
+static RUNTIME: ComputeRuntime<CudaDevice, Server, MutexComputeChannel<Server>> =
     ComputeRuntime::new();
 
@@ -51,7 +56,7 @@ impl Runtime for CudaRuntime {
         let memory_management = SimpleMemoryManagement::new(
             storage,
             DeallocStrategy::new_period_tick(1),
-            SliceStrategy::Never,
+            SliceStrategy::Ratio(0.8),
        );
         CudaContext::new(memory_management, stream, ctx)
     }
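Beyond the import swap, the runtime now implements burn_jit::JitRuntime, tying CudaDevice and CudaServer into burn-jit's runtime machinery, and the memory manager moves from SliceStrategy::Never to SliceStrategy::Ratio(0.8), turning on slice reuse subject to a 0.8 size-ratio threshold (the exact semantics of the threshold live in burn-compute).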