mirror of https://github.com/tracel-ai/burn.git
Feat/wgpu/prng bernoulli (#571)
This commit is contained in:
parent
a69788ad4b
commit
87125da6c9
|
@ -0,0 +1,133 @@
|
|||
use burn_tensor::Shape;
|
||||
|
||||
use crate::{
|
||||
context::WorkGroup,
|
||||
element::WgpuElement,
|
||||
kernel::{prng::base::get_seeds, KernelSettings},
|
||||
kernel_wgsl,
|
||||
pool::get_context,
|
||||
tensor::WgpuTensor,
|
||||
GraphicsApi, WgpuDevice,
|
||||
};
|
||||
|
||||
kernel_wgsl!(BernoulliPRNG, "../../template/prng/bernoulli.wgsl");
|
||||
|
||||
/// Pseudo-random generator for bernoulli
|
||||
pub fn random_bernoulli<G: GraphicsApi, E: WgpuElement, const D: usize>(
|
||||
shape: Shape<D>,
|
||||
device: &WgpuDevice,
|
||||
prob: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
let context = get_context::<G>(device);
|
||||
const WORKGROUP: usize = 32;
|
||||
const N_VALUES_PER_THREAD: u32 = 128;
|
||||
let num_elems = shape.num_elements();
|
||||
let num_threads = f32::ceil(num_elems as f32 / N_VALUES_PER_THREAD as f32);
|
||||
let num_invocations = f32::ceil(num_threads / (WORKGROUP * WORKGROUP) as f32);
|
||||
let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
|
||||
let workgroup_y = f32::ceil(num_invocations / workgroup_x);
|
||||
let workgroup = WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1);
|
||||
|
||||
let buffer = context.create_buffer(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(context.clone(), shape, buffer);
|
||||
|
||||
let mut info = get_seeds();
|
||||
info.insert(0, N_VALUES_PER_THREAD);
|
||||
let info_buffer = context.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let args = [prob];
|
||||
let args_buffer = context.create_buffer_with_data(E::as_bytes(&args));
|
||||
|
||||
let kernel =
|
||||
context.compile_static::<KernelSettings<BernoulliPRNG, E, i32, WORKGROUP, WORKGROUP, 1>>();
|
||||
|
||||
context.execute(
|
||||
workgroup,
|
||||
kernel,
|
||||
&[&output.buffer, &info_buffer, &args_buffer],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use core::f32;

    use burn_tensor::{backend::Backend, Distribution, Shape, Tensor};

    use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice};

    /// Two consecutive draws with the same parameters must differ in at least one element.
    #[test]
    fn subsequent_calls_give_different_tensors() {
        TestBackend::seed(0);
        let shape: Shape<2> = [40, 40].into();
        let device = WgpuDevice::default();

        let tensor_1 = Tensor::<TestBackend, 2>::random_device(
            shape.clone(),
            Distribution::Bernoulli(0.5),
            &device,
        );
        let tensor_2 = Tensor::<TestBackend, 2>::random_device(
            shape,
            Distribution::Bernoulli(0.5),
            &device,
        );

        // Read each tensor back exactly once instead of once per compared element.
        let values_1 = tensor_1.into_data().value;
        let values_2 = tensor_2.into_data().value;

        let diff_exists = values_1
            .iter()
            .zip(values_2.iter())
            .any(|(lhs, rhs)| lhs != rhs);
        assert!(diff_exists);
    }

    /// The fraction of ones must be within 0.01 of the requested probability.
    #[test]
    fn number_of_1_proportional_to_prob() {
        TestBackend::seed(0);
        let shape: Shape<2> = [40, 40].into();
        let device = WgpuDevice::default();
        let prob = 0.7;

        let tensor_1 = Tensor::<TestBackend, 2>::random_device(
            shape.clone(),
            Distribution::Bernoulli(prob),
            &device,
        );

        // High bound slightly over 1 so 1.0 is included in second bin
        let bin_stats = calculate_bin_stats(tensor_1.into_data().value, 2, 0., 1.1);
        assert!(
            f32::abs((bin_stats[1].count as f32 / shape.num_elements() as f32) - prob as f32)
                < 0.01
        );
    }

    /// Runs test on the generated sequence: the observed number of runs is compared, as a
    /// z-score, with the expectation and variance computed from the counts of zeros and ones.
    #[test]
    fn runs_test() {
        TestBackend::seed(0);
        let shape = Shape::new([512, 512]);
        let device = WgpuDevice::default();
        let tensor =
            Tensor::<TestBackend, 2>::random_device(shape, Distribution::Bernoulli(0.5), &device);

        let numbers = tensor.into_data().value;
        let stats = calculate_bin_stats(numbers, 2, 0., 1.1);
        let n_0 = stats[0].count as f32;
        let n_1 = stats[1].count as f32;
        let n_runs = (stats[0].n_runs + stats[1].n_runs) as f32;

        let expectation = (2. * n_0 * n_1) / (n_0 + n_1) + 1.0;
        let variance = ((2. * n_0 * n_1) * (2. * n_0 * n_1 - n_0 - n_1))
            / ((n_0 + n_1).powf(2.) * (n_0 + n_1 - 1.));
        let z = (n_runs - expectation) / variance.sqrt();

        // below 2 means we can have good confidence in the randomness
        assert!(z.abs() < 2.);
    }
}
|
|
@ -1,6 +1,8 @@
|
|||
mod base;
|
||||
mod bernoulli;
|
||||
mod normal;
|
||||
mod uniform;
|
||||
|
||||
pub use bernoulli::*;
|
||||
pub use normal::*;
|
||||
pub use uniform::*;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::{numeric, BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor};
|
||||
use crate::kernel::prng::{random_normal, random_uniform};
|
||||
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
|
||||
use crate::kernel::{
|
||||
self, unary_default, unary_inplace_default, unary_scalar_default, unary_scalar_inplace_default,
|
||||
};
|
||||
|
@ -7,9 +7,8 @@ use crate::kernel::{
|
|||
use crate::unary_scalar_inplace;
|
||||
use crate::{
|
||||
element::{FloatElement, IntElement},
|
||||
unary, unary_inplace, unary_scalar, GraphicsApi, WgpuBackend, SEED,
|
||||
unary, unary_inplace, unary_scalar, GraphicsApi, WgpuBackend,
|
||||
};
|
||||
use burn_common::rand::get_seeded_rng;
|
||||
use burn_tensor::ElementConversion;
|
||||
use burn_tensor::{ops::TensorOps, Data, Distribution, Shape};
|
||||
|
||||
|
@ -33,24 +32,15 @@ where
|
|||
distribution: Distribution<FloatElem<Self>>,
|
||||
device: &Device<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
// TODO other distributions than default
|
||||
match distribution {
|
||||
Distribution::Default => random_uniform::<G, F, D>(shape, device, 0.elem(), 1.elem()),
|
||||
Distribution::Uniform(low, high) => random_uniform::<G, F, D>(shape, device, low, high),
|
||||
Distribution::Bernoulli(prob) => {
|
||||
random_bernoulli::<G, F, D>(shape, device, prob.elem())
|
||||
}
|
||||
Distribution::Normal(mean, std) => {
|
||||
random_normal::<G, F, D>(shape, device, mean.elem(), std.elem())
|
||||
}
|
||||
_ => {
|
||||
let mut seed = SEED.lock().unwrap();
|
||||
let mut rng = if let Some(rng_seeded) = seed.as_ref() {
|
||||
rng_seeded.clone()
|
||||
} else {
|
||||
get_seeded_rng()
|
||||
};
|
||||
let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device);
|
||||
*seed = Some(rng);
|
||||
tensor
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
// Destination buffer, written by `main`.
@group(0)
@binding(0)
var<storage, read_write> output: array<{{ elem }}>;

// info[0] is the number of values each thread generates; info[1..4] are the four PRNG seeds.
@group(0)
@binding(1)
var<storage, read> info: array<u32, 5>;

// Kernel arguments. Only args[0] (the threshold compared against each generated float) is
// read, and the host uploads exactly one element, so the array is declared with size 1.
@group(0)
@binding(2)
var<storage, read> args: array<{{ elem }}, 1>;
|
||||
|
||||
// Entry point: every thread seeds its own generator state from the shared seeds in `info`,
// then produces `info[0]` values. Each value is converted to a float with `cast_float` and
// stored as `elem(float < args[0])`.
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(local_invocation_index) local_id: u32,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    let wg_size_x = {{ workgroup_size_x }}u;
    let wg_size_y = {{ workgroup_size_y }}u;

    // Flatten the 2D workgroup grid into a single workgroup index.
    let wg = workgroup_id.x * num_workgroups.y + workgroup_id.y;
    let n_threads_per_workgroup = wg_size_x * wg_size_y;
    let wg_offset = wg * n_threads_per_workgroup;
    let unique_thread_id = wg_offset + local_id;

    // Offset the shared seeds by a per-thread amount so threads start from different states.
    let large_prime = 1000000007u;
    let thread_seed = large_prime * unique_thread_id;

    var state: array<u32, 4u>;
    for (var i = 0u; i < 4u; i++) {
        state[i] = info[i + 1u] + thread_seed;
    }

    let n_values_per_thread = info[0u];

    // Loop invariants: the comparison threshold and the number of writable elements.
    let prob = args[0];
    let n_outputs = arrayLength(&output);

    for (var i = 0u; i < n_values_per_thread; i++) {
        state[0u] = taus_step(state[0u], 13u, 19u, 12u, 4294967294u);
        state[1u] = taus_step(state[1u], 2u, 25u, 4u, 4294967288u);
        state[2u] = taus_step(state[2u], 3u, 11u, 17u, 4294967280u);
        state[3u] = lcg_step(state[3u]);
        let hybrid_taus = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];

        // Within one iteration, the threads of a workgroup write to consecutive indices.
        let write_index = wg_offset * n_values_per_thread + local_id + i * n_threads_per_workgroup;
        let float = cast_float(hybrid_taus);

        // The dispatch is rounded up, so trailing threads can compute indices past the end
        // of `output`; skip those stores explicitly instead of relying on the
        // implementation-dependent handling of out-of-bounds writes.
        if write_index < n_outputs {
            output[write_index] = {{ elem }}(float < prob);
        }
    }
}
|
||||
|
||||
// One linear congruential step: multiplier * z + increment.
// u32 arithmetic wraps, so no explicit modulo 2^32 is needed.
fn lcg_step(z: u32) -> u32 {
    let multiplier = 1664525u;
    let increment = 1013904223u;
    return multiplier * z + increment;
}
|
||||
|
||||
// One Tausworthe-style step on `z` using shift amounts `s1`, `s2`, `s3` and bit mask `m`.
fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
    let b = ((z << s1) ^ z) >> s2;
    // Explicit parentheses: WGSL does not allow mixing a shift with `^` unparenthesized.
    // The shift is applied first, then the xor.
    return ((z & m) << s3) ^ b;
}
|
||||
|
||||
// Converts a u32 to the element type and scales it by 2.3283064365387e-10 (about 2^-32).
fn cast_float(number: u32) -> {{ elem }} {
    let as_elem = {{ elem }}(number);
    return 2.3283064365387e-10 * as_elem;
}
|
Loading…
Reference in New Issue