Feat/wgpu/prng bernoulli (#571)

This commit is contained in:
Louis Fortier-Dubois 2023-08-01 12:54:22 -04:00 committed by GitHub
parent a69788ad4b
commit 87125da6c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 200 additions and 15 deletions

View File

@ -0,0 +1,133 @@
use burn_tensor::Shape;
use crate::{
context::WorkGroup,
element::WgpuElement,
kernel::{prng::base::get_seeds, KernelSettings},
kernel_wgsl,
pool::get_context,
tensor::WgpuTensor,
GraphicsApi, WgpuDevice,
};
kernel_wgsl!(BernoulliPRNG, "../../template/prng/bernoulli.wgsl");
/// Pseudo-random generator for the Bernoulli distribution.
///
/// Allocates an output buffer large enough for `shape` and launches the `BernoulliPRNG`
/// kernel on it. The kernel stores, for every generated number, whether that number is
/// strictly below `prob`, converted to the element type.
///
/// # Arguments
///
/// * `shape` - Shape of the tensor to generate.
/// * `device` - Device whose context creates the buffers and executes the kernel.
/// * `prob` - Threshold each generated number is compared against.
///
/// # Returns
///
/// A tensor of shape `shape` backed by the buffer the kernel writes to.
pub fn random_bernoulli<G: GraphicsApi, E: WgpuElement, const D: usize>(
    shape: Shape<D>,
    device: &WgpuDevice,
    prob: E,
) -> WgpuTensor<E, D> {
    const WORKGROUP: usize = 32;
    const N_VALUES_PER_THREAD: u32 = 128;

    let context = get_context::<G>(device);
    let num_elems = shape.num_elements();

    let (workgroup_x, workgroup_y) = bernoulli_dispatch_grid(
        num_elems,
        N_VALUES_PER_THREAD as usize,
        WORKGROUP * WORKGROUP,
    );
    let workgroup = WorkGroup::new(workgroup_x, workgroup_y, 1);

    let buffer = context.create_buffer(num_elems * core::mem::size_of::<E>());
    let output = WgpuTensor::new(context.clone(), shape, buffer);

    // The kernel reads `info[0]` as the number of values per thread and the
    // following entries as the seeds, so the count goes in front of the seeds.
    let mut info = get_seeds();
    info.insert(0, N_VALUES_PER_THREAD);
    let info_buffer = context.create_buffer_with_data(bytemuck::cast_slice(&info));

    let args = [prob];
    let args_buffer = context.create_buffer_with_data(E::as_bytes(&args));

    let kernel =
        context.compile_static::<KernelSettings<BernoulliPRNG, E, i32, WORKGROUP, WORKGROUP, 1>>();

    context.execute(
        workgroup,
        kernel,
        &[&output.buffer, &info_buffer, &args_buffer],
    );

    output
}

/// Computes the `(x, y)` number of workgroups to dispatch so that at least `num_elems`
/// values are generated when each thread produces `values_per_thread` values and each
/// workgroup holds `threads_per_workgroup` threads.
///
/// Integer arithmetic is used on purpose: `f32` cannot represent every element count
/// above 2^24, which could round the grid down and leave the end of the buffer unwritten.
/// Returns `(0, 0)` when there is nothing to generate.
fn bernoulli_dispatch_grid(
    num_elems: usize,
    values_per_thread: usize,
    threads_per_workgroup: usize,
) -> (u32, u32) {
    // Overflow-free ceiling division.
    let ceil_div = |numerator: usize, denominator: usize| {
        numerator / denominator + usize::from(numerator % denominator != 0)
    };

    let num_threads = ceil_div(num_elems, values_per_thread);
    let num_invocations = ceil_div(num_threads, threads_per_workgroup);
    if num_invocations == 0 {
        return (0, 0);
    }

    // Lay the workgroups out on a near-square 2D grid. The workgroup count is far below
    // 2^53, so the `f64` square root is computed on an exactly represented value.
    let grid_x = (num_invocations as f64).sqrt().ceil() as usize;
    let grid_y = ceil_div(num_invocations, grid_x);

    // Both dimensions are at most about sqrt(num_invocations) + 1 and fit in `u32`.
    (grid_x as u32, grid_y as u32)
}
#[cfg(test)]
mod tests {
    use core::f32;

    use burn_tensor::{backend::Backend, Distribution, Shape, Tensor};

    use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice};

    /// Two consecutive draws with the same parameters must not produce identical tensors.
    #[test]
    fn subsequent_calls_give_different_tensors() {
        TestBackend::seed(0);
        let shape: Shape<2> = [40, 40].into();
        let device = WgpuDevice::default();

        let tensor_1 = Tensor::<TestBackend, 2>::random_device(
            shape.clone(),
            Distribution::Bernoulli(0.5),
            &device,
        );
        let tensor_2 =
            Tensor::<TestBackend, 2>::random_device(shape, Distribution::Bernoulli(0.5), &device);

        // Read each tensor back once instead of once per compared element.
        let values_1 = tensor_1.into_data().value;
        let values_2 = tensor_2.into_data().value;

        let diff_exists = values_1
            .iter()
            .zip(values_2.iter())
            .any(|(lhs, rhs)| lhs != rhs);
        assert!(diff_exists);
    }

    /// The fraction of ones must be close to the requested probability.
    #[test]
    fn number_of_1_proportional_to_prob() {
        TestBackend::seed(0);
        let shape: Shape<2> = [40, 40].into();
        let device = WgpuDevice::default();
        let prob = 0.7;

        let tensor_1 = Tensor::<TestBackend, 2>::random_device(
            shape.clone(),
            Distribution::Bernoulli(prob),
            &device,
        );

        // High bound slightly over 1 so 1.0 is included in second bin
        let bin_stats = calculate_bin_stats(tensor_1.into_data().value, 2, 0., 1.1);
        let observed = bin_stats[1].count as f32 / shape.num_elements() as f32;

        // With n = 1600 and p = 0.7 the standard deviation of the observed proportion is
        // sqrt(p * (1 - p) / n) ~= 0.0115, so the tolerance must be several times larger
        // than that for the test to be statistically meaningful; 0.05 is more than 4 sigma.
        assert!(f32::abs(observed - prob as f32) < 0.05);
    }

    /// Wald-Wolfowitz runs test on the sequence of zeros and ones.
    #[test]
    fn runs_test() {
        TestBackend::seed(0);
        let shape = Shape::new([512, 512]);
        let device = WgpuDevice::default();
        let tensor =
            Tensor::<TestBackend, 2>::random_device(shape, Distribution::Bernoulli(0.5), &device);

        let numbers = tensor.into_data().value;
        let stats = calculate_bin_stats(numbers, 2, 0., 1.1);
        let n_0 = stats[0].count as f32;
        let n_1 = stats[1].count as f32;
        let n_runs = (stats[0].n_runs + stats[1].n_runs) as f32;

        // Expected number of runs and its variance for a random sequence with n_0 zeros
        // and n_1 ones.
        let expectation = (2. * n_0 * n_1) / (n_0 + n_1) + 1.0;
        let variance = ((2. * n_0 * n_1) * (2. * n_0 * n_1 - n_0 - n_1))
            / ((n_0 + n_1).powf(2.) * (n_0 + n_1 - 1.));
        let z = (n_runs - expectation) / variance.sqrt();

        // below 2 means we can have good confidence in the randomness
        assert!(z.abs() < 2.);
    }
}

View File

@ -1,6 +1,8 @@
mod base;
mod bernoulli;
mod normal;
mod uniform;
pub use bernoulli::*;
pub use normal::*;
pub use uniform::*;

View File

@ -1,5 +1,5 @@
use super::{numeric, BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor};
use crate::kernel::prng::{random_normal, random_uniform};
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
use crate::kernel::{
self, unary_default, unary_inplace_default, unary_scalar_default, unary_scalar_inplace_default,
};
@ -7,9 +7,8 @@ use crate::kernel::{
use crate::unary_scalar_inplace;
use crate::{
element::{FloatElement, IntElement},
unary, unary_inplace, unary_scalar, GraphicsApi, WgpuBackend, SEED,
unary, unary_inplace, unary_scalar, GraphicsApi, WgpuBackend,
};
use burn_common::rand::get_seeded_rng;
use burn_tensor::ElementConversion;
use burn_tensor::{ops::TensorOps, Data, Distribution, Shape};
@ -33,24 +32,15 @@ where
distribution: Distribution<FloatElem<Self>>,
device: &Device<Self>,
) -> FloatTensor<Self, D> {
// TODO other distributions than default
match distribution {
Distribution::Default => random_uniform::<G, F, D>(shape, device, 0.elem(), 1.elem()),
Distribution::Uniform(low, high) => random_uniform::<G, F, D>(shape, device, low, high),
Distribution::Bernoulli(prob) => {
random_bernoulli::<G, F, D>(shape, device, prob.elem())
}
Distribution::Normal(mean, std) => {
random_normal::<G, F, D>(shape, device, mean.elem(), std.elem())
}
_ => {
let mut seed = SEED.lock().unwrap();
let mut rng = if let Some(rng_seeded) = seed.as_ref() {
rng_seeded.clone()
} else {
get_seeded_rng()
};
let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device);
*seed = Some(rng);
tensor
}
}
}

View File

@ -0,0 +1,60 @@
// Destination buffer: one element per generated value.
@group(0)
@binding(0)
var<storage, read_write> output: array<{{ elem }}>;

// info[0] is the number of values each thread generates; info[1] to info[4] are the seeds.
@group(0)
@binding(1)
var<storage, read> info: array<u32, 5>;

// args[0] is the probability threshold. Only one argument is read by this kernel, so the
// array holds exactly one element.
@group(0)
@binding(2)
var<storage, read> args: array<{{ elem }}, 1>;
// Entry point: every thread generates info[0] pseudo-random numbers and stores, for each
// one, whether it is strictly below args[0], converted to the element type.
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(local_invocation_index) local_id: u32,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    // Flatten the 2D workgroup grid into a linear workgroup index.
    let wg_size_x = {{ workgroup_size_x }}u;
    let wg_size_y = {{ workgroup_size_y }}u;
    let wg = workgroup_id.x * num_workgroups.y + workgroup_id.y;
    let n_threads_per_workgroup = wg_size_x * wg_size_y;
    let wg_offset = wg * n_threads_per_workgroup;
    let unique_thread_id = wg_offset + local_id;

    // Offset the shared seeds by a per-thread amount so threads start from different states.
    let large_prime = 1000000007u;
    let thread_seed = large_prime * unique_thread_id;

    var state: array<u32, 4u>;
    for (var i = 0u; i < 4u; i++) {
        state[i] = info[i + 1u] + thread_seed;
    }

    let n_values_per_thread = info[0u];

    // Loop-invariant values, read once.
    let prob = args[0];
    let n_output = arrayLength(&output);

    for (var i = 0u; i < n_values_per_thread; i++) {
        state[0u] = taus_step(state[0u], 13u, 19u, 12u, 4294967294u);
        state[1u] = taus_step(state[1u], 2u, 25u, 4u, 4294967288u);
        state[2u] = taus_step(state[2u], 3u, 11u, 17u, 4294967280u);
        state[3u] = lcg_step(state[3u]);
        let hybrid_taus = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];

        // Each workgroup owns a contiguous region of
        // n_threads_per_workgroup * n_values_per_thread values; within one iteration the
        // threads of a workgroup write to consecutive indices.
        let write_index = wg_offset * n_values_per_thread + local_id + i * n_threads_per_workgroup;
        let random_float = cast_float(hybrid_taus);

        // The dispatch is rounded up to whole workgroups, so trailing indices can lie past
        // the end of the output; skip them rather than rely on out-of-bounds behavior.
        if write_index < n_output {
            output[write_index] = {{ elem }}(random_float < prob);
        }
    }
}
// One linear congruential generator step: multiply by 1664525 and add 1013904223.
// No explicit modulo 2^32 is needed because the arithmetic is done in u32.
fn lcg_step(z: u32) -> u32 {
    let scaled = 1664525u * z;
    return scaled + 1013904223u;
}
// One shift/mask generator step used for the first three state words of the hybrid
// generator in main.
//   z          - current state word
//   s1, s2, s3 - shift amounts
//   m          - bit mask applied to z before the final left shift
fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
    let b = ((z << s1) ^ z) >> s2;
    // Shift and bitwise operators must be grouped explicitly; the shift is applied
    // before the xor.
    return ((z & m) << s3) ^ b;
}
// Converts a u32 to the element type and scales it by 2.3283064365387e-10
// (approximately 1 / 2^32).
fn cast_float(number: u32) -> {{ elem }} {
    let as_elem = {{ elem }}(number);
    return as_elem * 2.3283064365387e-10;
}