mirror of https://github.com/tracel-ai/burn.git
Fix bench load record benchmarks (#1826)
This commit is contained in:
parent
cccd96de48
commit
c59a3b8b8a
|
@ -53,7 +53,7 @@ struct LoadRecordBenchmark<B: Backend> {
|
|||
}
|
||||
|
||||
impl<B: Backend> Benchmark for LoadRecordBenchmark<B> {
|
||||
type Args = BenchmarkModuleRecord<B>;
|
||||
type Args = BenchmarkModule<B>;
|
||||
|
||||
fn name(&self) -> String {
|
||||
format!("load_record_{:?}", self.kind).to_lowercase()
|
||||
|
@ -67,7 +67,9 @@ impl<B: Backend> Benchmark for LoadRecordBenchmark<B> {
|
|||
10
|
||||
}
|
||||
|
||||
fn execute(&self, record: Self::Args) {
|
||||
fn execute(&self, module: Self::Args) {
|
||||
let record = module.into_record();
|
||||
|
||||
let _ = match self.kind {
|
||||
Kind::Lazy => {
|
||||
let module = self.config.init(&self.device);
|
||||
|
@ -86,9 +88,8 @@ impl<B: Backend> Benchmark for LoadRecordBenchmark<B> {
|
|||
fn prepare(&self) -> Self::Args {
|
||||
let module = self.config.init(&self.device);
|
||||
// Force sync.
|
||||
let module_initialized = module.clone();
|
||||
|
||||
module_initialized.into_record()
|
||||
module.clone()
|
||||
}
|
||||
|
||||
fn sync(&self) {
|
||||
|
|
Loading…
Reference in New Issue