Refactor/wgpu/prng (#576)

This commit is contained in:
Louis Fortier-Dubois 2023-08-02 16:08:50 -04:00 committed by GitHub
parent 73fb0eaa7e
commit d5f9f69cea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 263 additions and 290 deletions

View File

@ -41,6 +41,7 @@ burn-tensor = {path = "../burn-tensor", version = "0.9.0", default-features = fa
"export_tests",
]}
burn-ndarray = {path = "../burn-ndarray", version = "0.9.0" }
serial_test = "0.5.0"
[[bench]]
name = "unary"

View File

@ -188,6 +188,20 @@ pub(crate) fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> Wor
WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1)
}
pub(crate) fn prng_workgroup(
num_elems: usize,
workgroup_size: usize,
n_values_per_thread: usize,
) -> WorkGroup {
let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
let num_elem_per_invocation = workgroup_size * workgroup_size;
let num_invocations = f32::ceil(num_threads / num_elem_per_invocation as f32);
let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
let workgroup_y = f32::ceil(num_invocations / workgroup_x);
WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1)
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -1,7 +1,13 @@
use burn_common::rand::get_seeded_rng;
use rand::Rng;
use std::sync::Arc;
use crate::SEED;
use burn_common::rand::get_seeded_rng;
use burn_tensor::Shape;
use rand::Rng;
use wgpu::Buffer;
use crate::{context::Context, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor, SEED};
kernel_wgsl!(Prng, "../../template/prng/prng.wgsl");
pub(crate) fn get_seeds() -> Vec<u32> {
let mut seed = SEED.lock().unwrap();
@ -17,6 +23,24 @@ pub(crate) fn get_seeds() -> Vec<u32> {
seeds
}
pub(crate) fn make_output_tensor<E: WgpuElement, const D: usize>(
context: Arc<Context>,
shape: Shape<D>,
) -> WgpuTensor<E, D> {
let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
WgpuTensor::new(context.clone(), shape, buffer)
}
pub(crate) fn make_info_buffer(context: Arc<Context>, n_values_per_thread: usize) -> Arc<Buffer> {
let mut info = get_seeds();
info.insert(0, n_values_per_thread as u32);
context.create_buffer_with_data(bytemuck::cast_slice(&info))
}
pub(crate) fn make_args_buffer<E: WgpuElement>(context: Arc<Context>, args: &[E]) -> Arc<Buffer> {
context.create_buffer_with_data(E::as_bytes(args))
}
#[cfg(test)]
pub mod tests {
use burn_tensor::Element;

View File

@ -1,16 +1,31 @@
use burn_tensor::Shape;
use crate::{
context::WorkGroup,
element::WgpuElement,
kernel::{prng::base::get_seeds, KernelSettings},
kernel_wgsl,
kernel::{
prng::base::{make_args_buffer, make_info_buffer, make_output_tensor},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernel,
},
pool::get_context,
tensor::WgpuTensor,
GraphicsApi, WgpuDevice,
};
kernel_wgsl!(BernoulliPRNG, "../../template/prng/bernoulli.wgsl");
use super::base::Prng;
struct BernoulliPrng;
impl StaticKernel for BernoulliPrng {
fn source_template() -> SourceTemplate {
Prng::source_template()
.register("num_args", "1")
.register(
"prng_loop",
include_str!("../../template/prng/bernoulli_inner_loop.wgsl"),
)
.add_template("fn cast_elem(e: bool) -> {{ elem }} {return {{elem}}(e);}")
}
}
/// Pseudo-random generator for bernoulli
pub fn random_bernoulli<G: GraphicsApi, E: WgpuElement, const D: usize>(
@ -18,32 +33,17 @@ pub fn random_bernoulli<G: GraphicsApi, E: WgpuElement, const D: usize>(
device: &WgpuDevice,
prob: E,
) -> WgpuTensor<E, D> {
let context = get_context::<G>(device);
const WORKGROUP: usize = 32;
const N_VALUES_PER_THREAD: u32 = 128;
let num_elems = shape.num_elements();
let num_threads = f32::ceil(num_elems as f32 / N_VALUES_PER_THREAD as f32);
let num_invocations = f32::ceil(num_threads / (WORKGROUP * WORKGROUP) as f32);
let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
let workgroup_y = f32::ceil(num_invocations / workgroup_x);
let workgroup = WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1);
const N_VALUES_PER_THREAD: usize = 128;
let buffer = context.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(context.clone(), shape, buffer);
let mut info = get_seeds();
info.insert(0, N_VALUES_PER_THREAD);
let info_buffer = context.create_buffer_with_data(bytemuck::cast_slice(&info));
let args = [prob];
let args_buffer = context.create_buffer_with_data(E::as_bytes(&args));
let kernel =
context.compile_static::<KernelSettings<BernoulliPRNG, E, i32, WORKGROUP, WORKGROUP, 1>>();
let context = get_context::<G>(device);
let output = make_output_tensor(context.clone(), shape.clone());
let info_buffer = make_info_buffer(context.clone(), N_VALUES_PER_THREAD);
let args_buffer = make_args_buffer(context.clone(), &[prob]);
context.execute(
workgroup,
kernel,
prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD),
context.compile_static::<KernelSettings<BernoulliPrng, E, i32, WORKGROUP, WORKGROUP, 1>>(),
&[&output.buffer, &info_buffer, &args_buffer],
);
@ -55,10 +55,12 @@ mod tests {
use core::f32;
use burn_tensor::{backend::Backend, Distribution, Shape, Tensor};
use serial_test::serial;
use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice};
#[test]
#[serial]
fn subsequent_calls_give_different_tensors() {
TestBackend::seed(0);
let shape: Shape<2> = [40, 40].into();
@ -85,6 +87,7 @@ mod tests {
}
#[test]
#[serial]
fn number_of_1_proportional_to_prob() {
TestBackend::seed(0);
let shape: Shape<2> = [40, 40].into();
@ -101,11 +104,12 @@ mod tests {
let bin_stats = calculate_bin_stats(tensor_1.into_data().value, 2, 0., 1.1);
assert!(
f32::abs((bin_stats[1].count as f32 / shape.num_elements() as f32) - prob as f32)
< 0.01
< 0.05
);
}
#[test]
#[serial]
fn runs_test() {
TestBackend::seed(0);
let shape = Shape::new([512, 512]);
@ -128,6 +132,7 @@ mod tests {
let z = (n_runs - expectation) / variance.sqrt();
// below 2 means we can have good confidence in the randomness
assert!(z.abs() < 2.);
// we put 2.5 to make sure it passes even when very unlucky
assert!(z.abs() < 2.5);
}
}

View File

@ -1,16 +1,33 @@
use burn_tensor::Shape;
use crate::{
context::WorkGroup,
element::WgpuElement,
kernel::{prng::base::get_seeds, KernelSettings},
kernel_wgsl,
kernel::{
prng::base::{make_args_buffer, make_info_buffer, make_output_tensor},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernel,
},
pool::get_context,
tensor::WgpuTensor,
GraphicsApi, WgpuDevice,
};
kernel_wgsl!(NormalPRNG, "../../template/prng/normal.wgsl");
use super::base::Prng;
struct NormalPrng;
impl StaticKernel for NormalPrng {
fn source_template() -> SourceTemplate {
Prng::source_template()
.register("num_args", "2")
.register(
"prng_loop",
include_str!("../../template/prng/normal_inner_loop.wgsl"),
)
.add_template(include_str!(
"../../template/prng/box_muller_transform.wgsl"
))
}
}
/// Pseudo-random generator for normal distribution
pub fn random_normal<G: GraphicsApi, E: WgpuElement, const D: usize>(
@ -19,33 +36,17 @@ pub fn random_normal<G: GraphicsApi, E: WgpuElement, const D: usize>(
mean: E,
std: E,
) -> WgpuTensor<E, D> {
let context = get_context::<G>(device);
const WORKGROUP: usize = 32;
const N_VALUES_PER_THREAD: u32 = 128; // must be even
const N_VALUES_PER_THREAD: usize = 128; // must be even
let num_elems = shape.num_elements();
let num_threads = f32::ceil(num_elems as f32 / N_VALUES_PER_THREAD as f32);
let num_invocations = f32::ceil(num_threads / (WORKGROUP * WORKGROUP) as f32);
let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
let workgroup_y = f32::ceil(num_invocations / workgroup_x);
let workgroup = WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1);
let buffer = context.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(context.clone(), shape, buffer);
let mut info = get_seeds();
info.insert(0, N_VALUES_PER_THREAD);
let info_buffer = context.create_buffer_with_data(bytemuck::cast_slice(&info));
let args = [mean, std];
let args_buffer = context.create_buffer_with_data(E::as_bytes(&args));
let kernel =
context.compile_static::<KernelSettings<NormalPRNG, E, i32, WORKGROUP, WORKGROUP, 1>>();
let context = get_context::<G>(device);
let output = make_output_tensor(context.clone(), shape.clone());
let info_buffer = make_info_buffer(context.clone(), N_VALUES_PER_THREAD);
let args_buffer = make_args_buffer(context.clone(), &[mean, std]);
context.execute(
workgroup,
kernel,
prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD),
context.compile_static::<KernelSettings<NormalPrng, E, i32, WORKGROUP, WORKGROUP, 1>>(),
&[&output.buffer, &info_buffer, &args_buffer],
);
@ -56,10 +57,12 @@ pub fn random_normal<G: GraphicsApi, E: WgpuElement, const D: usize>(
mod tests {
use burn_tensor::{backend::Backend, Data, Distribution, Shape, Tensor};
use serial_test::serial;
use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice};
#[test]
#[serial]
fn subsequent_calls_give_different_tensors() {
TestBackend::seed(0);
let shape = [4, 5];
@ -75,6 +78,7 @@ mod tests {
}
#[test]
#[serial]
fn empirical_mean_close_to_expectation() {
TestBackend::seed(0);
let shape = [128, 128];
@ -87,6 +91,7 @@ mod tests {
}
#[test]
#[serial]
fn normal_respects_68_95_99_rule() {
// https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule
let shape: Shape<2> = [1000, 1000].into();
@ -106,7 +111,7 @@ mod tests {
);
let assert_approx_eq = |count, percent| {
let expected = percent * shape.num_elements() as f32 / 100.;
assert!(f32::abs(count as f32 - expected) < 1000.);
assert!(f32::abs(count as f32 - expected) < 2000.);
};
assert_approx_eq(stats[0].count, 2.1);
assert_approx_eq(stats[1].count, 13.6);

View File

@ -1,16 +1,28 @@
use burn_tensor::Shape;
use crate::{
context::WorkGroup,
element::WgpuElement,
kernel::{prng::base::get_seeds, KernelSettings},
kernel_wgsl,
kernel::{
prng::base::{make_args_buffer, make_info_buffer, make_output_tensor},
prng_workgroup, KernelSettings, SourceTemplate, StaticKernel,
},
pool::get_context,
tensor::WgpuTensor,
GraphicsApi, WgpuDevice,
};
kernel_wgsl!(UniformPRNG, "../../template/prng/uniform.wgsl");
use super::base::Prng;
struct UniformPrng;
impl StaticKernel for UniformPrng {
fn source_template() -> SourceTemplate {
Prng::source_template().register("num_args", "2").register(
"prng_loop",
include_str!("../../template/prng/uniform_inner_loop.wgsl"),
)
}
}
/// Pseudo-random generator for uniform distribution
pub fn random_uniform<G: GraphicsApi, E: WgpuElement, const D: usize>(
@ -19,32 +31,17 @@ pub fn random_uniform<G: GraphicsApi, E: WgpuElement, const D: usize>(
low: E,
high: E,
) -> WgpuTensor<E, D> {
let context = get_context::<G>(device);
const WORKGROUP: usize = 32;
const N_VALUES_PER_THREAD: u32 = 128;
let num_elems = shape.num_elements();
let num_threads = f32::ceil(num_elems as f32 / N_VALUES_PER_THREAD as f32);
let num_invocations = f32::ceil(num_threads / (WORKGROUP * WORKGROUP) as f32);
let workgroup_x = f32::ceil(f32::sqrt(num_invocations));
let workgroup_y = f32::ceil(num_invocations / workgroup_x);
let workgroup = WorkGroup::new(workgroup_x as u32, workgroup_y as u32, 1);
const N_VALUES_PER_THREAD: usize = 128;
let buffer = context.create_buffer(num_elems * core::mem::size_of::<E>());
let output = WgpuTensor::new(context.clone(), shape, buffer);
let mut info = get_seeds();
info.insert(0, N_VALUES_PER_THREAD);
let info_buffer = context.create_buffer_with_data(bytemuck::cast_slice(&info));
let args = [low, high];
let args_buffer = context.create_buffer_with_data(E::as_bytes(&args));
let kernel =
context.compile_static::<KernelSettings<UniformPRNG, E, i32, WORKGROUP, WORKGROUP, 1>>();
let context = get_context::<G>(device);
let output = make_output_tensor(context.clone(), shape.clone());
let info_buffer = make_info_buffer(context.clone(), N_VALUES_PER_THREAD);
let args_buffer = make_args_buffer(context.clone(), &[low, high]);
context.execute(
workgroup,
kernel,
prng_workgroup(shape.num_elements(), WORKGROUP, N_VALUES_PER_THREAD),
context.compile_static::<KernelSettings<UniformPrng, E, i32, WORKGROUP, WORKGROUP, 1>>(),
&[&output.buffer, &info_buffer, &args_buffer],
);
@ -56,10 +53,12 @@ mod tests {
use core::f32;
use burn_tensor::{backend::Backend, Distribution, Shape, Tensor};
use serial_test::serial;
use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice};
#[test]
#[serial]
fn subsequent_calls_give_different_tensors() {
TestBackend::seed(0);
let shape = [4, 5];
@ -75,6 +74,7 @@ mod tests {
}
#[test]
#[serial]
fn values_all_within_interval_default() {
TestBackend::seed(0);
let shape = [24, 24];
@ -85,6 +85,7 @@ mod tests {
}
#[test]
#[serial]
fn values_all_within_interval_uniform() {
TestBackend::seed(0);
let shape = [24, 24];
@ -96,6 +97,7 @@ mod tests {
}
#[test]
#[serial]
fn at_least_one_value_per_bin_uniform() {
TestBackend::seed(0);
let shape = [64, 64];
@ -114,6 +116,7 @@ mod tests {
}
#[test]
#[serial]
fn runs_test() {
TestBackend::seed(0);
let shape = Shape::new([512, 512]);
@ -133,6 +136,7 @@ mod tests {
let z = (n_runs - expectation) / variance.sqrt();
// below 2 means we can have good confidence in the randomness
assert!(z.abs() < 2.);
// we put 2.5 to make sure it passes even when very unlucky
assert!(z.abs() < 2.5);
}
}

View File

@ -1,60 +0,0 @@
@group(0)
@binding(0)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> info: array<u32, 5>;
@group(0)
@binding(2)
var<storage, read> args: array<{{ elem }}, 2>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(local_invocation_index) local_id: u32,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
let wg_size_x = {{ workgroup_size_x }}u;
let wg_size_y = {{ workgroup_size_y }}u;
let wg = workgroup_id.x * num_workgroups.y + workgroup_id.y;
let n_threads_per_workgroup = wg_size_x * wg_size_y;
let wg_offset = wg * n_threads_per_workgroup;
let unique_thread_id = wg_offset + local_id;
let large_prime = 1000000007u;
let thread_seed = large_prime * unique_thread_id;
var state: array<u32, 4u>;
for (var i = 0u; i < 4u; i++) {
state[i] = info[i + 1u] + thread_seed;
}
let n_values_per_thread = info[0u];
for (var i = 0u; i < n_values_per_thread; i++) {
state[0u] = taus_step(state[0u], 13u, 19u, 12u, 4294967294u);
state[1u] = taus_step(state[1u], 2u, 25u, 4u, 4294967288u);
state[2u] = taus_step(state[2u], 3u, 11u, 17u, 4294967280u);
state[3u] = lcg_step(state[3u]);
let hybrid_taus = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
let write_index = wg_offset * n_values_per_thread + local_id + i * n_threads_per_workgroup;
let float = cast_float(hybrid_taus);
let prob = args[0];
output[write_index] = {{ elem }}(float < prob);
}
}
fn lcg_step(z: u32) -> u32 {
return (1664525u * z + 1013904223u); // modulo 2^32, not necessary in u32
}
fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
let b = ((z << s1) ^ z) >> s2;
return (z & m) << s3 ^ b;
}
fn cast_float(number: u32) -> {{ elem }} {
return 2.3283064365387e-10 * {{ elem }}(number);
}

View File

@ -0,0 +1,13 @@
let prob = args[0];
for (var i = 0u; i < n_values_per_thread; i++) {
let write_index = write_index_base + i * n_threads_per_workgroup;
state[0u] = taus_step_0(state[0u]);
state[1u] = taus_step_1(state[1u]);
state[2u] = taus_step_2(state[2u]);
state[3u] = lcg_step(state[3u]);
let random_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
let float = cast_float(random_u32);
output[write_index] = cast_elem(float < prob);
}

View File

@ -0,0 +1,12 @@
fn box_muller_transform(unit_1: {{ elem }}, unit_2: {{ elem }}) -> array<{{ elem }}, 2> {
let mean = args[0];
let stdev = args[1];
let coeff = stdev * sqrt(-2.0 * log(unit_1));
let pi = 3.141592653589793238;
let trigo_arg = 2.0 * pi * unit_2;
let cos_ = cos(trigo_arg);
let sin_ = sin(trigo_arg);
return array(coeff * cos_ + mean, coeff * sin_ + mean);
}

View File

@ -1,81 +0,0 @@
@group(0)
@binding(0)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> info: array<u32, 5>;
@group(0)
@binding(2)
var<storage, read> args: array<{{ elem }}, 2>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(local_invocation_index) local_id: u32,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
let wg_size_x = {{ workgroup_size_x }}u;
let wg_size_y = {{ workgroup_size_y }}u;
let wg = workgroup_id.x * num_workgroups.y + workgroup_id.y;
let n_threads_per_workgroup = wg_size_x * wg_size_y;
let wg_offset = wg * n_threads_per_workgroup;
let unique_thread_id = wg_offset + local_id;
let large_prime = 1000000007u;
let thread_seed = large_prime * unique_thread_id;
var state: array<u32, 4u>;
for (var i = 0u; i < 4u; i++) {
state[i] = info[i + 1u] + thread_seed;
}
let n_values_per_thread = info[0u];
// TODO ASSERT random_normal n threads is even
for (var i = 0u; i < n_values_per_thread / 2u; i++) {
var units: array<{{elem}}, 2>;
for (var j = 0u; j < 2u; j++) {
state[0u] = taus_step(state[0u], 13u, 19u, 12u, 4294967294u);
state[1u] = taus_step(state[1u], 2u, 25u, 4u, 4294967288u);
state[2u] = taus_step(state[2u], 3u, 11u, 17u, 4294967280u);
state[3u] = lcg_step(state[3u]);
let hybrid_taus = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
units[j] = cast_float(hybrid_taus);
}
let transformed = box_muller_transform(units[0], units[1]);
let write_index_0 = wg_offset * n_values_per_thread + local_id + (2u * i) * n_threads_per_workgroup;
let write_index_1 = write_index_0 + n_threads_per_workgroup;
output[write_index_0] = transformed[0];
output[write_index_1] = transformed[1];
}
}
fn lcg_step(z: u32) -> u32 {
return (1664525u * z + 1013904223u);
}
fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
let b = ((z << s1) ^ z) >> s2;
return (z & m) << s3 ^ b;
}
fn cast_float(number: u32) -> {{ elem }} {
return 2.3283064365387e-10 * {{ elem }}(number);
}
fn box_muller_transform(unit_1: {{ elem }}, unit_2: {{ elem }}) -> array<{{elem}}, 2> {
let mean = args[0];
let stdev = args[1];
let coeff = stdev * sqrt(-2.0 * log(unit_1));
let pi = 3.141592653589793238;
let trigo_arg = 2.0 * pi * unit_2;
let cos_ = cos(trigo_arg);
let sin_ = sin(trigo_arg);
return array(coeff * cos_ + mean, coeff * sin_ + mean);
}

View File

@ -0,0 +1,23 @@
for (var i = 0u; i < n_values_per_thread / 2u; i++) {
let write_index_0 = write_index_base + (2u * i) * n_threads_per_workgroup;
let write_index_1 = write_index_0 + n_threads_per_workgroup;
state[0u] = taus_step_0(state[0u]);
state[1u] = taus_step_1(state[1u]);
state[2u] = taus_step_2(state[2u]);
state[3u] = lcg_step(state[3u]);
let random_1_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
let random_1 = cast_float(random_1_u32);
state[0u] = taus_step_0(state[0u]);
state[1u] = taus_step_1(state[1u]);
state[2u] = taus_step_2(state[2u]);
state[3u] = lcg_step(state[3u]);
let random_2_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
let random_2 = cast_float(random_2_u32);
let transformed = box_muller_transform(random_1, random_2);
output[write_index_0] = transformed[0];
output[write_index_1] = transformed[1];
}

View File

@ -0,0 +1,60 @@
@group(0)
@binding(0)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> info: array<u32, 5>;
@group(0)
@binding(2)
var<storage, read> args: array<{{ elem }}, {{ num_args }}>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(local_invocation_index) local_id: u32,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
// Thread preparation
let n_threads_per_workgroup = {{ workgroup_size }}u;
let workgroup_offset = (workgroup_id.x * num_workgroups.y + workgroup_id.y) * n_threads_per_workgroup;
let n_values_per_thread = info[0u];
let write_index_base = workgroup_offset * n_values_per_thread + local_id;
// Set state with unique seeds
let thread_seed = 1000000007u * (workgroup_offset + local_id);
var state: array<u32, 4u>;
for (var i = 0u; i < 4u; i++) {
state[i] = info[i + 1u] + thread_seed;
}
// Creation of n_values_per_thread values, specific to the distribution
{{ prng_loop }}
}
fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
let b = ((z << s1) ^ z) >> s2;
return (z & m) << s3 ^ b;
}
fn taus_step_0(z: u32) -> u32 {
return taus_step(z, 13u, 19u, 12u, 4294967294u);
}
fn taus_step_1(z: u32) -> u32 {
return taus_step(z, 2u, 25u, 4u, 4294967288u);
}
fn taus_step_2(z: u32) -> u32 {
return taus_step(z, 3u, 11u, 17u, 4294967280u);
}
fn lcg_step(z: u32) -> u32 {
return (1664525u * z + 1013904223u);
}
fn cast_float(number: u32) -> {{ elem }} {
return 2.3283064365387e-10 * {{ elem }}(number);
}

View File

@ -1,63 +0,0 @@
@group(0)
@binding(0)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> info: array<u32, 5>;
@group(0)
@binding(2)
var<storage, read> args: array<{{ elem }}, 2>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(local_invocation_index) local_id: u32,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
let wg_size_x = {{ workgroup_size_x }}u;
let wg_size_y = {{ workgroup_size_y }}u;
let wg = workgroup_id.x * num_workgroups.y + workgroup_id.y;
let n_threads_per_workgroup = wg_size_x * wg_size_y;
let wg_offset = wg * n_threads_per_workgroup;
let unique_thread_id = wg_offset + local_id;
let large_prime = 1000000007u;
let thread_seed = large_prime * unique_thread_id;
var state: array<u32, 4u>;
for (var i = 0u; i < 4u; i++) {
state[i] = info[i + 1u] + thread_seed;
}
let n_values_per_thread = info[0u];
for (var i = 0u; i < n_values_per_thread; i++) {
state[0u] = taus_step(state[0u], 13u, 19u, 12u, 4294967294u);
state[1u] = taus_step(state[1u], 2u, 25u, 4u, 4294967288u);
state[2u] = taus_step(state[2u], 3u, 11u, 17u, 4294967280u);
state[3u] = lcg_step(state[3u]);
let hybrid_taus = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
let write_index = wg_offset * n_values_per_thread + local_id + i * n_threads_per_workgroup;
let float = cast_float(hybrid_taus);
let low = args[0];
let high = args[1];
let scale = high - low;
let bias = low;
output[write_index] = float * scale + bias;
}
}
fn lcg_step(z: u32) -> u32 {
return (1664525u * z + 1013904223u); // modulo 2^32, not necessary in u32
}
fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
let b = ((z << s1) ^ z) >> s2;
return (z & m) << s3 ^ b;
}
fn cast_float(number: u32) -> {{ elem }} {
return 2.3283064365387e-10 * {{ elem }}(number);
}

View File

@ -0,0 +1,16 @@
let low = args[0];
let high = args[1];
let scale = high - low;
let bias = low;
for (var i = 0u; i < n_values_per_thread; i++) {
let write_index = write_index_base + i * n_threads_per_workgroup;
state[0u] = taus_step_0(state[0u]);
state[1u] = taus_step_1(state[1u]);
state[2u] = taus_step_2(state[2u]);
state[3u] = lcg_step(state[3u]);
let random_u32 = state[0u] ^ state[1u] ^ state[2u] ^ state[3u];
let float = cast_float(random_u32);
output[write_index] = float * scale + bias;
}