Add example for custom CSV dataset (#1129)

Guillaume Lagrange 2024-01-11 09:24:25 -05:00 committed by GitHub
parent f43b686366
commit 535458e7b9
9 changed files with 204 additions and 8 deletions

View File

@ -81,6 +81,7 @@ wasm-bindgen-futures = "0.4.38"
wasm-logger = "0.2.0"
wasm-timer = "0.2.5"
console_error_panic_hook = "0.1.7"
reqwest = "0.11.23"
# WGPU stuff

View File

@ -62,15 +62,18 @@ where
/// Create from a csv file.
///
/// The first line of the csv file must be the header. The header must contain the names of the fields in the struct.
/// The provided `csv::ReaderBuilder` can be configured to fit your csv format.
///
/// The supported field types are: String, integer, float, and bool.
///
/// See: [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde)
pub fn from_csv<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
let file = File::open(path)?;
let reader = BufReader::new(file);
let mut rdr = csv::Reader::from_reader(reader);
/// See:
/// - [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde)
/// - [Delimiters, quotes and variable length records](https://docs.rs/csv/latest/csv/tutorial/index.html#delimiters-quotes-and-variable-length-records)
pub fn from_csv<P: AsRef<Path>>(
path: P,
builder: &csv::ReaderBuilder,
) -> Result<Self, std::io::Error> {
let mut rdr = builder.from_path(path)?;
let mut items = Vec::new();
@ -97,6 +100,7 @@ mod tests {
const DB_FILE: &str = "tests/data/sqlite-dataset.db";
const JSON_FILE: &str = "tests/data/dataset.json";
const CSV_FILE: &str = "tests/data/dataset.csv";
const CSV_FMT_FILE: &str = "tests/data/dataset-fmt.csv";
type SqlDs = SqliteDataset<Sample>;
@ -110,7 +114,7 @@ mod tests {
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SampleCvs {
pub struct SampleCsv {
column_str: String,
column_int: i64,
column_bool: bool,
@ -147,7 +151,24 @@ mod tests {
#[test]
pub fn from_csv_rows() {
let dataset = InMemDataset::<SampleCvs>::from_csv(CSV_FILE).unwrap();
let rdr = csv::ReaderBuilder::new();
let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FILE, &rdr).unwrap();
let non_existing_record_index: usize = 10;
let record_index: usize = 1;
assert_eq!(dataset.get(non_existing_record_index), None);
assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
assert_eq!(dataset.get(record_index).unwrap().column_int, 1);
assert!(!dataset.get(record_index).unwrap().column_bool);
assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0);
}
#[test]
pub fn from_csv_rows_fmt() {
let mut rdr = csv::ReaderBuilder::new();
let rdr = rdr.delimiter(b' ').has_headers(false);
let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FMT_FILE, rdr).unwrap();
let non_existing_record_index: usize = 10;
let record_index: usize = 1;
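Since `from_csv` now takes a reference to a `csv::ReaderBuilder`, downstream callers construct and pass a builder explicitly. Below is a minimal sketch of the default case (comma-delimited with a header row) and a custom format, assuming a hypothetical `Record` struct and file paths, with `csv` and `serde` as dependencies and the `burn::data::dataset` re-export enabled as in the example crate:

```rust
use burn::data::dataset::InMemDataset;
use serde::{Deserialize, Serialize};

// Hypothetical record type; field names must match the CSV header
// (or the column order when headers are disabled).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Record {
    name: String,
    value: f32,
}

fn load() -> Result<(), std::io::Error> {
    // Default format: comma-delimited with a header row.
    let rdr = csv::ReaderBuilder::new();
    let _with_header = InMemDataset::<Record>::from_csv("data/records.csv", &rdr)?;

    // Custom format: semicolon-delimited, no header row.
    let mut rdr = csv::ReaderBuilder::new();
    let rdr = rdr.delimiter(b';').has_headers(false);
    let _custom = InMemDataset::<Record>::from_csv("data/records-fmt.csv", rdr)?;

    Ok(())
}
```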

View File

@ -0,0 +1,2 @@
HI1 1 true 1.0
HI2 1 false 1.0

View File

@ -0,0 +1,2 @@
# Ignore downloaded csv file
*.csv

View File

@ -0,0 +1,22 @@
[package]
authors = ["guillaumelagrange <lagrange.guillaume.1@gmail.com>"]
edition.workspace = true
license.workspace = true
name = "custom-csv-dataset"
description = "Example implementation for loading a custom CSV dataset from disk"
publish = false
version.workspace = true
[features]
default = ["burn/dataset"]
[dependencies]
burn = {path = "../../burn"}
# File download
reqwest = {workspace = true, features = ["blocking"]}
tempfile = {workspace = true}
# CSV parsing
csv = {workspace = true}
serde = {workspace = true, features = ["std", "derive"]}

View File

@ -0,0 +1,11 @@
# Custom CSV Dataset
The [custom-csv-dataset](src/dataset.rs) example implements the `Dataset` trait to retrieve dataset elements from a `.csv` file on disk. For this example, we use the [diabetes dataset](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset) (original [source](https://www4.stat.ncsu.edu/~boos/var.select/diabetes.html)).
The dataset only contains 442 records, so we use the [`InMemDataset::from_csv(path, builder)`](src/dataset.rs#L80) method to read the CSV dataset file into an in-memory vector of [`DiabetesPatient`](src/dataset.rs#L13) records (structs) with the help of `serde`.
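Condensed, the pattern the example follows looks like the sketch below; the field list, download step, and error handling are trimmed for brevity (see [`src/dataset.rs`](src/dataset.rs) for the full version), and the sketch assumes the tab-delimited `diabetes.csv` file is already on disk next to the program.

```rust
use burn::data::dataset::{Dataset, InMemDataset};
use serde::{Deserialize, Serialize};

// Each row of the CSV file is deserialized into one record via serde.
// Only two of the eleven columns are shown here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiabetesPatient {
    #[serde(rename = "AGE")]
    pub age: u8,
    #[serde(rename = "Y")]
    pub response: u16,
}

pub struct DiabetesDataset {
    dataset: InMemDataset<DiabetesPatient>,
}

impl DiabetesDataset {
    pub fn new() -> Result<Self, std::io::Error> {
        // The source file is tab-delimited, so configure the reader accordingly.
        let mut rdr = csv::ReaderBuilder::new();
        let rdr = rdr.delimiter(b'\t');
        let dataset = InMemDataset::from_csv("diabetes.csv", rdr)?;
        Ok(Self { dataset })
    }
}

// Expose the in-memory records through the `Dataset` trait.
impl Dataset<DiabetesPatient> for DiabetesDataset {
    fn get(&self, index: usize) -> Option<DiabetesPatient> {
        self.dataset.get(index)
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}
```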
## Example Usage
```sh
cargo run --example custom-csv-dataset
```

View File

@ -0,0 +1,15 @@
use burn::data::dataset::Dataset;
use custom_csv_dataset::dataset::DiabetesDataset;
fn main() {
let dataset = DiabetesDataset::new().expect("Could not load diabetes dataset");
println!("Dataset loaded with {} rows", dataset.len());
// Display first and last elements
let item = dataset.get(0).unwrap();
println!("First item:\n{:?}", item);
let item = dataset.get(441).unwrap();
println!("Last item:\n{:?}", item);
}

View File

@ -0,0 +1,121 @@
use burn::data::dataset::{Dataset, InMemDataset};
use serde::{Deserialize, Serialize};
use std::{
fs::File,
io::copy,
path::{Path, PathBuf},
};
/// Diabetes patient record.
/// For each field, we manually specify the expected header name for serde as all names
/// are capitalized and some field names are not very informative.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct DiabetesPatient {
/// Age in years
#[serde(rename = "AGE")]
pub age: u8,
/// Sex categorical label
#[serde(rename = "SEX")]
pub sex: u8,
/// Body mass index
#[serde(rename = "BMI")]
pub bmi: f32,
/// Average blood pressure
#[serde(rename = "BP")]
pub bp: f32,
/// S1: total serum cholesterol
#[serde(rename = "S1")]
pub tc: u16,
/// S2: low-density lipoproteins
#[serde(rename = "S2")]
pub ldl: f32,
/// S3: high-density lipoproteins
#[serde(rename = "S3")]
pub hdl: f32,
/// S4: total cholesterol
#[serde(rename = "S4")]
pub tch: f32,
/// S5: possibly log of serum triglycerides level
#[serde(rename = "S5")]
pub ltg: f32,
/// S6: blood sugar level
#[serde(rename = "S6")]
pub glu: u8,
/// Y: quantitative measure of disease progression one year after baseline
#[serde(rename = "Y")]
pub response: u16,
}
/// Diabetes patients dataset, also used in [scikit-learn](https://scikit-learn.org/stable/).
/// See [Diabetes dataset](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset).
///
/// The data is parsed from a single csv file (tab as the delimiter).
/// The dataset contains 10 baseline variables (age, sex, body mass index, average blood pressure, and
/// 6 blood serum measurements) for a total of 442 diabetes patients.
/// For each patient, the response of interest, a quantitative measure of disease progression one year
/// after baseline, was collected. This represents the target variable.
pub struct DiabetesDataset {
dataset: InMemDataset<DiabetesPatient>,
}
impl DiabetesDataset {
pub fn new() -> Result<Self, std::io::Error> {
// Download dataset csv file
let path = DiabetesDataset::download();
// Build dataset from csv with tab ('\t') delimiter
let mut rdr = csv::ReaderBuilder::new();
let rdr = rdr.delimiter(b'\t');
let dataset = InMemDataset::from_csv(path, rdr).unwrap();
let dataset = Self { dataset };
Ok(dataset)
}
/// Download the CSV file from its original source on the web.
/// Panics if the download cannot be completed or the content of the file cannot be written to disk.
fn download() -> PathBuf {
// Point file to current example directory
let example_dir = Path::new(file!()).parent().unwrap().parent().unwrap();
let file_name = example_dir.join("diabetes.csv");
if file_name.exists() {
println!("File already downloaded at {:?}", file_name);
} else {
// Get file from web
println!("Downloading file to {:?}", file_name);
let url = "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt";
let mut response = reqwest::blocking::get(url).unwrap();
// Create file to write the downloaded content to
let mut file = File::create(&file_name).unwrap();
// Copy the downloaded contents
copy(&mut response, &mut file).unwrap();
};
file_name
}
}
// Implement the `Dataset` trait which requires `get` and `len`
impl Dataset<DiabetesPatient> for DiabetesDataset {
fn get(&self, index: usize) -> Option<DiabetesPatient> {
self.dataset.get(index)
}
fn len(&self) -> usize {
self.dataset.len()
}
}

View File

@ -0,0 +1 @@
pub mod dataset;