Perf: cube reuse shape and strides (#1939)

This commit is contained in:
Nathaniel Simard 2024-07-02 08:28:32 -04:00 committed by GitHub
parent 849c8f453b
commit ad81a997af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 109 additions and 13 deletions

View File

@ -1,6 +1,30 @@
use burn_cube::ir as cube;
use std::fmt::Display;
/// Key describing a shape lookup whose dimension index is a compile-time constant.
///
/// Two values compare (and hash) equal when both `position` and `dim` match, so
/// the type can be stored in a set to deduplicate repeated lookups.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConstantShape {
    /// Index of the global array whose shape is read.
    pub position: usize,
    /// Index of the dimension being read.
    pub dim: usize,
}
/// Key describing a stride lookup whose dimension index is a compile-time constant.
///
/// Two values compare (and hash) equal when both `position` and `dim` match, so
/// the type can be stored in a set to deduplicate repeated lookups.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConstantStride {
    /// Index of the global array whose stride is read.
    pub position: usize,
    /// Index of the dimension being read.
    pub dim: usize,
}
impl Display for ConstantStride {
    /// Writes the value as `stride_{position}_{dim}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { position, dim } = self;
        write!(f, "stride_{}_{}", position, dim)
    }
}
impl Display for ConstantShape {
    /// Writes the value as `shape_{position}_{dim}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { position, dim } = self;
        write!(f, "shape_{}_{}", position, dim)
    }
}
#[derive(Debug, Clone)]
pub enum Variable {
SubgroupSize,
@ -41,6 +65,8 @@ pub enum Variable {
NumWorkgroupsX,
NumWorkgroupsY,
NumWorkgroupsZ,
ConstantShape(ConstantShape),
ConstantStride(ConstantStride),
}
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
@ -106,6 +132,8 @@ impl Variable {
Variable::WorkgroupSize => true,
Variable::NumWorkgroups => true,
Variable::SubgroupSize => true,
Variable::ConstantShape(_) => true,
Variable::ConstantStride(_) => true,
}
}
pub fn index(&self, index: usize) -> IndexedVariable {
@ -155,6 +183,8 @@ impl Variable {
Self::NumWorkgroupsY => Item::Scalar(Elem::U32),
Self::NumWorkgroupsZ => Item::Scalar(Elem::U32),
Self::SubgroupSize => Item::Scalar(Elem::U32),
Self::ConstantShape(_) => Item::Scalar(Elem::U32),
Self::ConstantStride(_) => Item::Scalar(Elem::U32),
}
}
pub fn elem(&self) -> Elem {
@ -262,6 +292,8 @@ impl Display for Variable {
Variable::WorkgroupSize => f.write_str("workgroup_size_no_axis"),
Variable::NumWorkgroups => f.write_str("num_workgroups_no_axis"),
Variable::SubgroupSize => f.write_str("subgroup_size"),
Variable::ConstantShape(val) => f.write_fmt(format_args!("{val}")),
Variable::ConstantStride(val) => f.write_fmt(format_args!("{val}")),
}
}
}

View File

@ -1,4 +1,6 @@
use super::Instruction;
use hashbrown::HashSet;
use super::{ConstantShape, ConstantStride, Instruction};
use std::fmt::Display;
/// A body is composed of a list of [instructions](Instruction).
@ -12,6 +14,8 @@ pub struct Body {
pub id: bool,
pub stride: bool,
pub shape: bool,
pub constant_shapes: HashSet<ConstantShape>,
pub constant_strides: HashSet<ConstantStride>,
}
impl Display for Body {
@ -29,6 +33,24 @@ impl Display for Body {
f.write_str("let rank_2: u32 = rank * 2u;\n")?;
}
for shape in self.constant_shapes.iter() {
let declaration = Instruction::Shape {
dim: super::Variable::ConstantScalar(shape.dim as f64, super::Elem::U32),
position: shape.position,
out: super::Variable::ConstantShape(*shape),
};
f.write_fmt(format_args!("let {declaration};\n"))?;
}
for stride in self.constant_strides.iter() {
let declaration = Instruction::Stride {
dim: super::Variable::ConstantScalar(stride.dim as f64, super::Elem::U32),
position: stride.position,
out: super::Variable::ConstantStride(*stride),
};
f.write_fmt(format_args!("let {declaration};\n"))?;
}
for ops in self.instructions.iter() {
f.write_fmt(format_args!("{ops}"))?;
}

View File

@ -1,7 +1,8 @@
use super::{shader::ComputeShader, Item, SharedMemory};
use super::{LocalArray, Subgroup};
use super::{ConstantShape, ConstantStride, LocalArray, Subgroup};
use crate::compiler::wgsl;
use burn_cube::ir as cube;
use hashbrown::HashSet;
/// Wgsl Compiler.
#[derive(Clone, Default)]
@ -14,14 +15,16 @@ pub struct WgslCompiler {
workgroup_id: bool,
rank: bool,
id: bool,
stride: bool,
shape: bool,
num_workgroups: bool,
workgroup_id_no_axis: bool,
workgroup_size_no_axis: bool,
num_workgroup_no_axis: bool,
shared_memories: Vec<SharedMemory>,
local_arrays: Vec<LocalArray>,
shape: bool,
stride: bool,
constant_shapes: HashSet<ConstantShape>,
constant_strides: HashSet<ConstantStride>,
}
impl core::fmt::Debug for WgslCompiler {
@ -54,12 +57,21 @@ impl WgslCompiler {
let instructions = self.compile_scope(&mut value.body);
let extensions = register_extensions(&instructions);
let mut constant_shapes = HashSet::new();
let mut constant_strides = HashSet::new();
core::mem::swap(&mut self.constant_shapes, &mut constant_shapes);
core::mem::swap(&mut self.constant_strides, &mut constant_strides);
let body = wgsl::Body {
instructions,
rank: true,
id: self.id,
stride: self.stride,
shape: self.shape,
constant_shapes,
constant_strides,
};
wgsl::ComputeShader {
@ -420,28 +432,58 @@ impl WgslCompiler {
match metadata {
cube::Metadata::Stride { dim, var, out } => {
self.stride = true;
let position = match var {
cube::Variable::GlobalInputArray(idx, _) => idx as usize,
cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
_ => panic!("Only Input and Output have a stride, got: {:?}", var),
};
wgsl::Instruction::Stride {
dim: self.compile_variable(dim),
position,
out: self.compile_variable(out),
let dim = self.compile_variable(dim);
let out = self.compile_variable(out);
match dim {
wgsl::Variable::ConstantScalar(val, _) => {
let var = ConstantStride {
position,
dim: val as usize,
};
self.constant_strides.insert(var);
wgsl::Instruction::Assign {
input: wgsl::Variable::ConstantStride(var),
out,
}
}
_ => wgsl::Instruction::Stride { dim, position, out },
}
}
cube::Metadata::Shape { dim, var, out } => {
self.shape = true;
let position = match var {
cube::Variable::GlobalInputArray(idx, _) => idx as usize,
cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
_ => panic!("Only Input and Output have a shape, got {:?}", var),
_ => panic!("Only Input and Output have a shape, got: {:?}", var),
};
wgsl::Instruction::Shape {
dim: self.compile_variable(dim),
position,
out: self.compile_variable(out),
let dim = self.compile_variable(dim);
let out = self.compile_variable(out);
match dim {
wgsl::Variable::ConstantScalar(val, _) => {
let var = ConstantShape {
position,
dim: val as usize,
};
self.constant_shapes.insert(var);
wgsl::Instruction::Assign {
input: wgsl::Variable::ConstantShape(var),
out,
}
}
_ => wgsl::Instruction::Shape { dim, position, out },
}
}
cube::Metadata::ArrayLength { var, out } => wgsl::Instruction::ArrayLength {