mirror of https://github.com/tracel-ai/burn.git
Can configure wgpu max tasks (#603)
This commit is contained in:
parent
00d3d208b8
commit
d18d1b0bb9
|
@ -25,3 +25,9 @@ mod wgpu {
|
|||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
You can set `BURN_WGPU_MAX_TASKS` to a positive integer that determines how many computing tasks are submitted in batches to the graphics API.
|
||||
The best value should be the smallest one that allows 100% GPU usage.
|
||||
A high value might increase GPU memory usage with no benefit.
|
||||
|
|
|
@ -29,6 +29,7 @@ pub struct SyncContextServer {
|
|||
queue: wgpu::Queue,
|
||||
encoder: CommandEncoder,
|
||||
tasks: Vec<ComputeTask>,
|
||||
max_tasks: usize,
|
||||
}
|
||||
|
||||
/// Basic building block to execute computing tasks on the GPU.
|
||||
|
@ -50,21 +51,27 @@ impl SyncContextServer {
|
|||
label: Some("Command Encoder"),
|
||||
});
|
||||
|
||||
// TODO: Support a way to modify this value without std.
|
||||
let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") {
|
||||
Ok(value) => value
|
||||
.parse::<usize>()
|
||||
.expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
|
||||
Err(_) => 16, // 16 tasks by default
|
||||
};
|
||||
|
||||
Self {
|
||||
device,
|
||||
queue,
|
||||
encoder,
|
||||
tasks: Vec::new(),
|
||||
max_tasks,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_compute(&mut self, task: ComputeTask) {
|
||||
self.tasks.push(task);
|
||||
|
||||
// Submit the tasks to the GPU when more than 50 tasks are accumulated.
|
||||
const MAX_TASKS: usize = 50;
|
||||
|
||||
if self.tasks.len() > MAX_TASKS {
|
||||
if self.tasks.len() > self.max_tasks {
|
||||
self.register_tasks();
|
||||
self.submit();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue