mirror of https://github.com/tracel-ai/burn.git
Add example for custom CSV dataset (#1129)
This commit is contained in:
parent
f43b686366
commit
535458e7b9
|
@ -81,6 +81,7 @@ wasm-bindgen-futures = "0.4.38"
|
|||
wasm-logger = "0.2.0"
|
||||
wasm-timer = "0.2.5"
|
||||
console_error_panic_hook = "0.1.7"
|
||||
reqwest = "0.11.23"
|
||||
|
||||
|
||||
# WGPU stuff
|
||||
|
|
|
@ -62,15 +62,18 @@ where
|
|||
|
||||
/// Create from a csv file.
|
||||
///
|
||||
/// The first line of the csv file must be the header. The header must contain the name of the fields in the struct.
|
||||
/// The provided `csv::ReaderBuilder` can be configured to fit your csv format.
|
||||
///
|
||||
/// The supported field types are: String, integer, float, and bool.
|
||||
///
|
||||
/// See: [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde)
|
||||
pub fn from_csv<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
|
||||
let file = File::open(path)?;
|
||||
let reader = BufReader::new(file);
|
||||
let mut rdr = csv::Reader::from_reader(reader);
|
||||
/// See:
|
||||
/// - [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde)
|
||||
/// - [Delimiters, quotes and variable length records](https://docs.rs/csv/latest/csv/tutorial/index.html#delimiters-quotes-and-variable-length-records)
|
||||
pub fn from_csv<P: AsRef<Path>>(
|
||||
path: P,
|
||||
builder: &csv::ReaderBuilder,
|
||||
) -> Result<Self, std::io::Error> {
|
||||
let mut rdr = builder.from_path(path)?;
|
||||
|
||||
let mut items = Vec::new();
|
||||
|
||||
|
@ -97,6 +100,7 @@ mod tests {
|
|||
const DB_FILE: &str = "tests/data/sqlite-dataset.db";
|
||||
const JSON_FILE: &str = "tests/data/dataset.json";
|
||||
const CSV_FILE: &str = "tests/data/dataset.csv";
|
||||
const CSV_FMT_FILE: &str = "tests/data/dataset-fmt.csv";
|
||||
|
||||
type SqlDs = SqliteDataset<Sample>;
|
||||
|
||||
|
@ -110,7 +114,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct SampleCvs {
|
||||
pub struct SampleCsv {
|
||||
column_str: String,
|
||||
column_int: i64,
|
||||
column_bool: bool,
|
||||
|
@ -147,7 +151,24 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
pub fn from_csv_rows() {
|
||||
let dataset = InMemDataset::<SampleCvs>::from_csv(CSV_FILE).unwrap();
|
||||
let rdr = csv::ReaderBuilder::new();
|
||||
let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FILE, &rdr).unwrap();
|
||||
|
||||
let non_existing_record_index: usize = 10;
|
||||
let record_index: usize = 1;
|
||||
|
||||
assert_eq!(dataset.get(non_existing_record_index), None);
|
||||
assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
|
||||
assert_eq!(dataset.get(record_index).unwrap().column_int, 1);
|
||||
assert!(!dataset.get(record_index).unwrap().column_bool);
|
||||
assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn from_csv_rows_fmt() {
|
||||
let mut rdr = csv::ReaderBuilder::new();
|
||||
let rdr = rdr.delimiter(b' ').has_headers(false);
|
||||
let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FMT_FILE, rdr).unwrap();
|
||||
|
||||
let non_existing_record_index: usize = 10;
|
||||
let record_index: usize = 1;
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
HI1 1 true 1.0
|
||||
HI2 1 false 1.0
|
|
|
@ -0,0 +1,2 @@
|
|||
# Ignore downloaded csv file
|
||||
*.csv
|
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
authors = ["guillaumelagrange <lagrange.guillaume.1@gmail.com>"]
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
name = "custom-csv-dataset"
|
||||
description = "Example implementation for loading a custom CSV dataset from disk"
|
||||
publish = false
|
||||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["burn/dataset"]
|
||||
|
||||
[dependencies]
|
||||
burn = {path = "../../burn"}
|
||||
|
||||
# File download
|
||||
reqwest = {workspace = true, features = ["blocking"]}
|
||||
tempfile = {workspace = true}
|
||||
|
||||
# CSV parsing
|
||||
csv = {workspace = true}
|
||||
serde = {workspace = true, features = ["std", "derive"]}
|
|
@ -0,0 +1,11 @@
|
|||
# Custom CSV Dataset
|
||||
|
||||
The [custom-csv-dataset](src/dataset.rs) example implements the `Dataset` trait to retrieve dataset elements from a `.csv` file on disk. For this example, we use the [diabetes dataset](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset) (original [source](https://www4.stat.ncsu.edu/~boos/var.select/diabetes.html)).
|
||||
|
||||
The dataset only contains 442 records, so we use [`InMemDataset::from_csv(path)`](src/dataset.rs#L80) method to read the csv dataset file into a vector (in-memory) of [`DiabetesPatient`](src/dataset.rs#L13) records (struct) with the help of `serde`.
|
||||
|
||||
## Example Usage
|
||||
|
||||
```sh
|
||||
cargo run --example custom-csv-dataset
|
||||
```
|
|
@ -0,0 +1,15 @@
|
|||
use burn::data::dataset::Dataset;
|
||||
use custom_csv_dataset::dataset::DiabetesDataset;
|
||||
|
||||
fn main() {
|
||||
let dataset = DiabetesDataset::new().expect("Could not load diabetes dataset");
|
||||
|
||||
println!("Dataset loaded with {} rows", dataset.len());
|
||||
|
||||
// Display first and last elements
|
||||
let item = dataset.get(0).unwrap();
|
||||
println!("First item:\n{:?}", item);
|
||||
|
||||
let item = dataset.get(441).unwrap();
|
||||
println!("Last item:\n{:?}", item);
|
||||
}
|
|
@ -0,0 +1,121 @@
|
|||
use burn::data::dataset::{Dataset, InMemDataset};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fs::File,
|
||||
io::copy,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
/// Diabetes patient record.
|
||||
/// For each field, we manually specify the expected header name for serde as all names
|
||||
/// are capitalized and some field names are not very informative.
|
||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||
pub struct DiabetesPatient {
|
||||
/// Age in years
|
||||
#[serde(rename = "AGE")]
|
||||
pub age: u8,
|
||||
|
||||
/// Sex categorical label
|
||||
#[serde(rename = "SEX")]
|
||||
pub sex: u8,
|
||||
|
||||
/// Body mass index
|
||||
#[serde(rename = "BMI")]
|
||||
pub bmi: f32,
|
||||
|
||||
/// Average blood pressure
|
||||
#[serde(rename = "BP")]
|
||||
pub bp: f32,
|
||||
|
||||
/// S1: total serum cholesterol
|
||||
#[serde(rename = "S1")]
|
||||
pub tc: u16,
|
||||
|
||||
/// S2: low-density lipoproteins
|
||||
#[serde(rename = "S2")]
|
||||
pub ldl: f32,
|
||||
|
||||
/// S3: high-density lipoproteins
|
||||
#[serde(rename = "S3")]
|
||||
pub hdl: f32,
|
||||
|
||||
/// S4: total cholesterol
|
||||
#[serde(rename = "S4")]
|
||||
pub tch: f32,
|
||||
|
||||
/// S5: possibly log of serum triglycerides level
|
||||
#[serde(rename = "S5")]
|
||||
pub ltg: f32,
|
||||
|
||||
/// S6: blood sugar level
|
||||
#[serde(rename = "S6")]
|
||||
pub glu: u8,
|
||||
|
||||
/// Y: quantitative measure of disease progression one year after baseline
|
||||
#[serde(rename = "Y")]
|
||||
pub response: u16,
|
||||
}
|
||||
|
||||
/// Diabetes patients dataset, also used in [scikit-learn](https://scikit-learn.org/stable/).
|
||||
/// See [Diabetes dataset](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset).
|
||||
///
|
||||
/// The data is parsed from a single csv file (tab as the delimiter).
|
||||
/// The dataset contains 10 baseline variables (age, sex, body mass index, average blood pressure and
|
||||
/// 6 blood serum measurements for a total of 442 diabetes patients.
|
||||
/// For each patient, the response of interest, a quantitative measure of disease progression one year
|
||||
/// after baseline, was collected. This represents the target variable.
|
||||
pub struct DiabetesDataset {
|
||||
dataset: InMemDataset<DiabetesPatient>,
|
||||
}
|
||||
|
||||
impl DiabetesDataset {
|
||||
pub fn new() -> Result<Self, std::io::Error> {
|
||||
// Download dataset csv file
|
||||
let path = DiabetesDataset::download();
|
||||
|
||||
// Build dataset from csv with tab ('\t') delimiter
|
||||
let mut rdr = csv::ReaderBuilder::new();
|
||||
let rdr = rdr.delimiter(b'\t');
|
||||
|
||||
let dataset = InMemDataset::from_csv(path, rdr).unwrap();
|
||||
|
||||
let dataset = Self { dataset };
|
||||
|
||||
Ok(dataset)
|
||||
}
|
||||
/// Download the CSV file from its original source on the web.
|
||||
/// Panics if the download cannot be completed or the content of the file cannot be written to disk.
|
||||
fn download() -> PathBuf {
|
||||
// Point file to current example directory
|
||||
let example_dir = Path::new(file!()).parent().unwrap().parent().unwrap();
|
||||
let file_name = example_dir.join("diabetes.csv");
|
||||
|
||||
if file_name.exists() {
|
||||
println!("File already downloaded at {:?}", file_name);
|
||||
} else {
|
||||
// Get file from web
|
||||
println!("Downloading file to {:?}", file_name);
|
||||
let url = "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt";
|
||||
let mut response = reqwest::blocking::get(url).unwrap();
|
||||
|
||||
// Create file to write the downloaded content to
|
||||
let mut file = File::create(&file_name).unwrap();
|
||||
|
||||
// Copy the downloaded contents
|
||||
copy(&mut response, &mut file).unwrap();
|
||||
};
|
||||
|
||||
file_name
|
||||
}
|
||||
}
|
||||
|
||||
// Implement the `Dataset` trait which requires `get` and `len`
|
||||
impl Dataset<DiabetesPatient> for DiabetesDataset {
|
||||
fn get(&self, index: usize) -> Option<DiabetesPatient> {
|
||||
self.dataset.get(index)
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.dataset.len()
|
||||
}
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
pub mod dataset;
|
Loading…
Reference in New Issue