mirror of https://github.com/tracel-ai/burn.git
Added parameter trust_remote_code to hf dataset call. (#2013)
* Added parameter trust_remote_code to the hf dataset call.
* Removed the test module as it may break, causing false negatives. Set default trust_remote_code to false. Added an example that highlights the use case.
This commit is contained in:
parent
9804bf81b2
commit
befe6c1601

@@ -0,0 +1,22 @@
+use burn_dataset::HuggingfaceDatasetLoader;
+use burn_dataset::SqliteDataset;
+use serde::Deserialize;
+
+#[derive(Deserialize, Debug, Clone)]
+struct MnistItemRaw {
+    pub _image_bytes: Vec<u8>,
+    pub _label: usize,
+}
+fn main() {
+    // Some datasets, such as https://huggingface.co/datasets/ylecun/mnist/tree/main, contain a loading script.
+    // In these cases you must enable trusting remote code execution if you want to use them.
+    let _train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
+        .with_trust_remote_code(true)
+        .dataset("train")
+        .unwrap();
+
+    // However, not all datasets require it, e.g. https://huggingface.co/datasets/Anthropic/hh-rlhf/tree/main
+    let _train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("Anthropic/hh-rlhf")
+        .dataset("train")
+        .unwrap();
+}
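Note on the example above: the loader defaults to trust_remote_code = false (set in new() below), so the explicit with_trust_remote_code(true) call is what lets the script-backed mnist dataset be imported; the Anthropic/hh-rlhf dataset ships plain data files and needs no opt-in.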
@@ -63,6 +63,7 @@ pub struct HuggingfaceDatasetLoader {
     base_dir: Option<PathBuf>,
     huggingface_token: Option<String>,
     huggingface_cache_dir: Option<String>,
+    trust_remote_code: bool,
 }

 impl HuggingfaceDatasetLoader {

@@ -74,6 +75,7 @@ impl HuggingfaceDatasetLoader {
             base_dir: None,
             huggingface_token: None,
             huggingface_cache_dir: None,
+            trust_remote_code: false,
         }
     }
@@ -111,6 +113,14 @@ impl HuggingfaceDatasetLoader {
         self
     }

+    /// Specify whether or not to trust remote code.
+    ///
+    /// If not specified, trust remote code is set to false.
+    pub fn with_trust_remote_code(mut self, trust_remote_code: bool) -> Self {
+        self.trust_remote_code = trust_remote_code;
+        self
+    }
+
     /// Load the dataset.
     pub fn dataset<I: DeserializeOwned + Clone>(
         self,
@@ -153,6 +163,7 @@ impl HuggingfaceDatasetLoader {
                 base_dir,
                 self.huggingface_token,
                 self.huggingface_cache_dir,
+                self.trust_remote_code,
             )?;
         }

@@ -168,6 +179,7 @@ fn import(
     base_dir: PathBuf,
     huggingface_token: Option<String>,
     huggingface_cache_dir: Option<String>,
+    trust_remote_code: bool,
 ) -> Result<(), ImporterError> {
     let venv_python_path = install_python_deps(&base_dir)?;

@@ -195,7 +207,10 @@ fn import(
         command.arg("--cache_dir");
         command.arg(huggingface_cache_dir);
     }
-
+    if trust_remote_code {
+        command.arg("--trust_remote_code");
+        command.arg("True");
+    }
     let mut handle = command.spawn().unwrap();
     handle
         .wait()
@@ -6,7 +6,14 @@ from sqlalchemy import Column, Integer, Table, create_engine, event, inspect
 from sqlalchemy.types import LargeBinary


-def download_and_export(name: str, subset: str, db_file: str, token: str, cache_dir: str):
+def download_and_export(
+    name: str,
+    subset: str,
+    db_file: str,
+    token: str,
+    cache_dir: str,
+    trust_remote_code: bool,
+):
     """
     Download a dataset from using HuggingFace dataset and export it to a sqlite database.
     """
@@ -15,18 +22,24 @@ def download_and_export(name: str, subset: str, db_file: str, token: str, cache_
     # bytes can be none {'bytes': None, 'path': 'healthy_train.265.jpg'}
     # We should handle this case, but unfortunately we did not come across this case yet to test it.

-    print("*"*80)
+    print("*" * 80)
     print("Starting huggingface dataset download and export")
     print(f"Dataset Name: {name}")
     print(f"Subset Name: {subset}")
     print(f"Sqlite database file: {db_file}")
+    print(f"Trust remote code: {trust_remote_code}")
     if cache_dir is None:
         print(f"Custom cache dir: {cache_dir}")
-    print("*"*80)
+    print("*" * 80)

     # Load the dataset
     dataset_all = load_dataset(
-        name, subset, cache_dir=cache_dir, use_auth_token=token)
+        name,
+        subset,
+        cache_dir=cache_dir,
+        use_auth_token=token,
+        trust_remote_code=trust_remote_code,
+    )

     print(f"Dataset: {dataset_all}")
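For context, the trust_remote_code value forwarded to load_dataset above is the datasets library's opt-in for executing a dataset repository's loading script. A minimal sketch of the behavior, assuming a datasets release that accepts the keyword (dataset names mirror the Rust example above):

from datasets import load_dataset

# Script-backed repositories (like the mnist one in the example) need the explicit opt-in.
mnist = load_dataset("mnist", trust_remote_code=True)

# Repositories that ship plain data files load without it.
hh_rlhf = load_dataset("Anthropic/hh-rlhf")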
@@ -34,10 +47,10 @@ def download_and_export(name: str, subset: str, db_file: str, token: str, cache_
     engine = create_engine(f"sqlite:///{db_file}")

     # Set some sqlite pragmas to speed up the database
-    event.listen(engine, 'connect', set_sqlite_pragma)
+    event.listen(engine, "connect", set_sqlite_pragma)

     # Add an row_id column to each table as primary key (datasets does not have API for this)
-    event.listen(Table, 'before_create', add_pk_column)
+    event.listen(Table, "before_create", add_pk_column)

     # Export each split in the dataset
     for key in dataset_all.keys():

@@ -61,7 +74,7 @@ def download_and_export(name: str, subset: str, db_file: str, token: str, cache_
             engine,
             # don't save the index, use row_id instead (index is not unique)
             index=False,
             dtype=blob_columns(dataset),  # save binary columns as blob
         )

     # Print the schema of the database so we can reference the columns in the rust code

@@ -89,8 +102,8 @@ def rename_columns(dataset):
     """

     for name in dataset.features.keys():
-        if '.' in name:
-            dataset = dataset.rename_column(name, name.replace('.', '_'))
+        if "." in name:
+            dataset = dataset.rename_column(name, name.replace(".", "_"))

     return dataset
@@ -151,11 +164,22 @@ def parse_args():
         "--subset", type=str, help="Subset name", required=False, default=None
     )
     parser.add_argument(
-        "--token", type=str, help="HuggingFace authentication token", required=False, default=None
+        "--token",
+        type=str,
+        help="HuggingFace authentication token",
+        required=False,
+        default=None,
     )
     parser.add_argument(
         "--cache_dir", type=str, help="Cache directory", required=False, default=None
     )
+    parser.add_argument(
+        "--trust_remote_code",
+        type=bool,
+        help="Trust remote code",
+        required=False,
+        default=None,
+    )

     return parser.parse_args()
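One caveat with the new argument: argparse's type=bool applies bool() to the raw string, so any non-empty value, including "False", parses as True. The Rust side above only ever appends --trust_remote_code True when the option is enabled and omits the flag otherwise, so the default of None still means disabled. A minimal illustration of that stdlib behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--trust_remote_code", type=bool, required=False, default=None)

# bool() on any non-empty string is True, regardless of its content.
print(parser.parse_args(["--trust_remote_code", "True"]).trust_remote_code)   # True
print(parser.parse_args(["--trust_remote_code", "False"]).trust_remote_code)  # also True
print(parser.parse_args([]).trust_remote_code)                                # None (flag omitted)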
@@ -169,6 +193,7 @@ def run():
         args.file,
         args.token,
         args.cache_dir,
+        args.trust_remote_code,
     )