Autotune: fix inputs (#926)

Louis Fortier-Dubois 2023-11-06 08:59:31 -05:00 committed by GitHub
parent 6548f1a730
commit a0297530ea
4 changed files with 49 additions and 32 deletions


@@ -1,10 +1,11 @@
 use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet};
-use burn_tensor::Element;
+use burn_tensor::{Element, ElementConversion};
 use crate::{
     compute::WgpuAutotuneKey,
     element::WgpuElement,
-    kernel::matmul::{tune::utils::autotune_tensors, utils::init_matmul_output},
+    kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform},
+    ops::numeric::empty_device,
     tensor::WgpuTensor,
 };
@@ -37,32 +38,42 @@ impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet<WgpuAutotune
     }
     fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
-        let lhs = autotune_tensors(&self.lhs);
-        let rhs = autotune_tensors(&self.rhs);
-        let out = autotune_tensors(&self.out);
+        let random_bounds: (E, E) = ((-10.0).elem::<E>(), (10.0).elem::<E>());
+        let lhs = random_like_uniform(&self.lhs, random_bounds.0, random_bounds.1);
+        let rhs = random_like_uniform(&self.rhs, random_bounds.0, random_bounds.1);
+        let out = empty_device(
+            self.out.client.clone(),
+            self.out.device.clone(),
+            self.out.shape.clone(),
+        );
         vec![
-            Box::new(MemoryCoalescingMatmulDefault::<E, 3>::new(
+            Box::new(MemoryCoalescingMatmulDefault::<E, D>::new(
                 lhs.clone(),
                 rhs.clone(),
                 out.clone(),
             )),
-            Box::new(MemoryCoalescingMatmulW16x16::<E, 3>::new(
+            Box::new(MemoryCoalescingMatmulW16x16::<E, D>::new(
                 lhs.clone(),
                 rhs.clone(),
                 out.clone(),
             )),
-            Box::new(Vec4TilingMatmulDefault::<E, 3>::new(
+            Box::new(Vec4TilingMatmulDefault::<E, D>::new(
                 lhs.clone(),
                 rhs.clone(),
                 out.clone(),
             )),
-            Box::new(Vec4TilingMatmulUnpaddedDefault::<E, 3>::new(
+            Box::new(Vec4TilingMatmulUnpaddedDefault::<E, D>::new(
                 lhs.clone(),
                 rhs.clone(),
                 out.clone(),
             )),
-            Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, 3>::new(lhs, rhs, out)),
+            Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, D>::new(
+                lhs.clone(),
+                rhs.clone(),
+                out.clone(),
+            )),
         ]
     }
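
The net effect of the hunk above: the autotuner no longer benchmarks on fixed 2 x M x N tensors filled with ones, but on random tensors that keep the real operands' shapes (plus an empty output buffer), at the operands' actual rank D. Below is a standalone sketch of that idea; FakeTensor, random_like and the xorshift filler are hypothetical stand-ins for illustration only, not burn-wgpu code.

// Hypothetical, self-contained illustration -- not burn-wgpu code.
struct FakeTensor {
    shape: Vec<usize>,
    data: Vec<f32>,
}

// Benchmark input with the *same shape* as the real operand but pseudo-random
// contents, analogous in spirit to `random_like_uniform` in the diff above.
fn random_like(real: &FakeTensor, low: f32, high: f32, seed: &mut u64) -> FakeTensor {
    let len: usize = real.shape.iter().product();
    let data = (0..len)
        .map(|_| {
            // xorshift64: cheap filler values, not statistically strong.
            *seed ^= *seed << 13;
            *seed ^= *seed >> 7;
            *seed ^= *seed << 17;
            // Top 24 bits give an f32 fraction in [0, 1).
            let unit = (*seed >> 40) as f32 / (1u64 << 24) as f32;
            low + unit * (high - low)
        })
        .collect();
    FakeTensor { shape: real.shape.clone(), data }
}

fn main() {
    let mut seed = 0x9E37_79B9_7F4A_7C15;
    let lhs = FakeTensor { shape: vec![8, 256, 512], data: vec![0.0; 8 * 256 * 512] };
    // Same bounds as the diff picks for its benchmark inputs.
    let bench_lhs = random_like(&lhs, -10.0, 10.0, &mut seed);
    assert_eq!(bench_lhs.shape, lhs.shape);
    println!("first benchmark value: {}", bench_lhs.data[0]);
}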


@@ -1,6 +1,5 @@
 mod base;
 mod key;
-mod utils;
 pub use base::*;
 pub use key::*;


@@ -1,19 +0,0 @@
-use burn_tensor::Element;
-use crate::{element::WgpuElement, ops::numeric::ones_device, tensor::WgpuTensor};
-pub(crate) fn autotune_tensors<E: WgpuElement + Element, const D: usize>(
-    tensor: &WgpuTensor<E, D>,
-) -> WgpuTensor<E, 3> {
-    let n_batches = 2;
-    ones_device(
-        tensor.client.clone(),
-        tensor.device.clone(),
-        [
-            n_batches,
-            tensor.shape.dims[D - 2],
-            tensor.shape.dims[D - 1],
-        ]
-        .into(),
-    )
-}


@@ -1,7 +1,7 @@
 use burn_tensor::Shape;
 use crate::{
-    compute::{compute_client, StaticKernel},
+    compute::{compute_client, StaticKernel, WgpuComputeClient},
     element::WgpuElement,
     kernel::{
         prng::base::{make_args_buffer, make_info_buffer},
@@ -31,10 +31,36 @@ pub fn random_uniform<G: GraphicsApi, E: WgpuElement, const D: usize>(
     device: &WgpuDevice,
     low: E,
     high: E,
 ) -> WgpuTensor<E, D> {
+    let client = compute_client::<G>(device);
+    uniform_kernel(client, device, &shape, low, high)
+}
+
+/// Pseudo-random generator for uniform distribution, based on
+/// another tensor's client, device and shape
+pub fn random_like_uniform<E: WgpuElement, const D: usize>(
+    tensor: &WgpuTensor<E, D>,
+    low: E,
+    high: E,
+) -> WgpuTensor<E, D> {
+    uniform_kernel(
+        tensor.client.clone(),
+        &tensor.device,
+        &tensor.shape,
+        low,
+        high,
+    )
+}
+
+fn uniform_kernel<E: WgpuElement, const D: usize>(
+    client: WgpuComputeClient,
+    device: &WgpuDevice,
+    shape: &Shape<D>,
+    low: E,
+    high: E,
+) -> WgpuTensor<E, D> {
     const N_VALUES_PER_THREAD: usize = 128;
-    let client = compute_client::<G>(device);
     let output = empty_device(client.clone(), device.clone(), shape.clone());
     let info_handle = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
     let args_handle = make_args_buffer(client.clone(), &[low, high]);
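
For orientation, a hedged usage sketch of the two public entry points after this refactor. It assumes the relevant burn-wgpu items (WgpuTensor, WgpuDevice, AutoGraphicsApi, random_uniform, random_like_uniform) are in scope, and that random_uniform's first parameter is the shape taken by value (it sits just above the hunk and is not shown here); treat those details as assumptions rather than guarantees about the API.

// Sketch only, under the assumptions stated above.
fn demo(existing: &WgpuTensor<f32, 3>, device: &WgpuDevice) {
    // Fresh random tensor: caller supplies the shape, device and bounds.
    let fresh: WgpuTensor<f32, 3> =
        random_uniform::<AutoGraphicsApi, f32, 3>([2, 64, 64].into(), device, -1.0, 1.0);

    // Random tensor that mirrors another tensor's client, device and shape --
    // the helper the matmul autotuner now uses to build its benchmark inputs.
    let like: WgpuTensor<f32, 3> = random_like_uniform(existing, -10.0, 10.0);

    let _ = (fresh, like);
}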