Added parameter trust_remote_code to hf dataset call. (#2013)

* Added parameter trust_remote_code to hf dataset call.

* Removed test module as it may break, causing false negatives.
Set default trust_remote_code to false.
Added an example that highlights the use case.
José Manuel 2024-07-17 15:40:23 -06:00 committed by GitHub
parent 9804bf81b2
commit befe6c1601
3 changed files with 73 additions and 11 deletions

View File

@ -0,0 +1,22 @@
use burn_dataset::HuggingfaceDatasetLoader;
use burn_dataset::SqliteDataset;
use serde::Deserialize;

#[derive(Deserialize, Debug, Clone)]
struct MnistItemRaw {
    pub _image_bytes: Vec<u8>,
    pub _label: usize,
}

#[derive(Deserialize, Debug, Clone)]
struct HhRlhfItemRaw {
    // Field names mirror the hh-rlhf dataset's `chosen`/`rejected` columns.
    pub _chosen: String,
    pub _rejected: String,
}

fn main() {
    // Some datasets, such as https://huggingface.co/datasets/ylecun/mnist/tree/main, contain a loading script.
    // In those cases you must enable trusting remote code execution in order to use them.
    let _train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
        .with_trust_remote_code(true)
        .dataset("train")
        .unwrap();

    // However, not every dataset requires it, e.g. https://huggingface.co/datasets/Anthropic/hh-rlhf/tree/main
    let _train_ds: SqliteDataset<HhRlhfItemRaw> = HuggingfaceDatasetLoader::new("Anthropic/hh-rlhf")
        .dataset("train")
        .unwrap();
}
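Because enabling remote code execution runs arbitrary Python from the dataset repository, a caller may want to make the opt-in explicit rather than hard-coding `true`. A minimal sketch of one way to do that; the TRUST_HF_REMOTE_CODE environment variable is illustrative, not part of the crate:

use burn_dataset::{HuggingfaceDatasetLoader, SqliteDataset};
use serde::Deserialize;

#[derive(Deserialize, Debug, Clone)]
struct MnistItemRaw {
    pub _image_bytes: Vec<u8>,
    pub _label: usize,
}

fn main() {
    // Hypothetical opt-in: trust remote code only when the user
    // explicitly set TRUST_HF_REMOTE_CODE=1 in the environment.
    let trust = std::env::var("TRUST_HF_REMOTE_CODE")
        .map(|v| v == "1")
        .unwrap_or(false);

    let _ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
        .with_trust_remote_code(trust)
        .dataset("train")
        .unwrap();
}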

View File

@ -63,6 +63,7 @@ pub struct HuggingfaceDatasetLoader {
    base_dir: Option<PathBuf>,
    huggingface_token: Option<String>,
    huggingface_cache_dir: Option<String>,
    trust_remote_code: bool,
}

impl HuggingfaceDatasetLoader {
@ -74,6 +75,7 @@ impl HuggingfaceDatasetLoader {
            base_dir: None,
            huggingface_token: None,
            huggingface_cache_dir: None,
            trust_remote_code: false,
        }
    }
@ -111,6 +113,14 @@ impl HuggingfaceDatasetLoader {
        self
    }

    /// Specify whether or not to trust remote code.
    ///
    /// If not specified, trust remote code is set to false.
    pub fn with_trust_remote_code(mut self, trust_remote_code: bool) -> Self {
        self.trust_remote_code = trust_remote_code;
        self
    }

    /// Load the dataset.
    pub fn dataset<I: DeserializeOwned + Clone>(
        self,
@ -153,6 +163,7 @@ impl HuggingfaceDatasetLoader {
                base_dir,
                self.huggingface_token,
                self.huggingface_cache_dir,
                self.trust_remote_code,
            )?;
        }
@ -168,6 +179,7 @@ fn import(
    base_dir: PathBuf,
    huggingface_token: Option<String>,
    huggingface_cache_dir: Option<String>,
    trust_remote_code: bool,
) -> Result<(), ImporterError> {
    let venv_python_path = install_python_deps(&base_dir)?;
@ -195,7 +207,10 @@ fn import(
command.arg("--cache_dir");
command.arg(huggingface_cache_dir);
}
if trust_remote_code {
command.arg("--trust_remote_code");
command.arg("True");
}
let mut handle = command.spawn().unwrap();
handle
.wait()
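Taken together, the flag only reaches the Python importer when trust is enabled; otherwise the argument is omitted entirely and the script falls back to its default of None. A rough sketch of the resulting invocation; the interpreter path, script name, and dataset values below are placeholders, not taken from this commit:

use std::process::Command;

fn main() {
    // Placeholders: the real interpreter comes from install_python_deps and
    // the script path from the importer bundled with burn-dataset.
    let mut command = Command::new(".venv/bin/python");
    command
        .arg("importer.py")
        .arg("--name")
        .arg("mnist")
        .arg("--file")
        .arg("mnist.db")
        // Appended only when trust_remote_code is true.
        .arg("--trust_remote_code")
        .arg("True");
    println!("{command:?}");
}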

View File

@ -6,7 +6,14 @@ from sqlalchemy import Column, Integer, Table, create_engine, event, inspect
from sqlalchemy.types import LargeBinary
def download_and_export(name: str, subset: str, db_file: str, token: str, cache_dir: str):
def download_and_export(
    name: str,
    subset: str,
    db_file: str,
    token: str,
    cache_dir: str,
    trust_remote_code: bool,
):
    """
    Download a dataset using the HuggingFace datasets library and export it to a SQLite database.
    """
@ -15,18 +22,24 @@ def download_and_export(name: str, subset: str, db_file: str, token: str, cache_
    # bytes can be none {'bytes': None, 'path': 'healthy_train.265.jpg'}
    # We should handle this case, but we have not yet come across it to test.
    print("*"*80)
    print("*" * 80)
    print("Starting huggingface dataset download and export")
    print(f"Dataset Name: {name}")
    print(f"Subset Name: {subset}")
    print(f"Sqlite database file: {db_file}")
    print(f"Trust remote code: {trust_remote_code}")

    if cache_dir is not None:
        print(f"Custom cache dir: {cache_dir}")

    print("*"*80)
    print("*" * 80)
    # Load the dataset
    dataset_all = load_dataset(
        name, subset, cache_dir=cache_dir, use_auth_token=token)
        name,
        subset,
        cache_dir=cache_dir,
        use_auth_token=token,
        trust_remote_code=trust_remote_code,
    )

    print(f"Dataset: {dataset_all}")
@ -34,10 +47,10 @@ def download_and_export(name: str, subset: str, db_file: str, token: str, cache_
    engine = create_engine(f"sqlite:///{db_file}")

    # Set some sqlite pragmas to speed up the database
    event.listen(engine, 'connect', set_sqlite_pragma)
    event.listen(engine, "connect", set_sqlite_pragma)

    # Add a row_id column to each table as primary key (datasets does not have an API for this)
    event.listen(Table, 'before_create', add_pk_column)
    event.listen(Table, "before_create", add_pk_column)

    # Export each split in the dataset
    for key in dataset_all.keys():
@ -61,7 +74,7 @@ def download_and_export(name: str, subset: str, db_file: str, token: str, cache_
            engine,
            # don't save the index, use row_id instead (index is not unique)
            index=False,
            dtype=blob_columns(dataset), # save binary columns as blob
            dtype=blob_columns(dataset),  # save binary columns as blob
        )

    # Print the schema of the database so we can reference the columns in the rust code
@ -89,8 +102,8 @@ def rename_columns(dataset):
"""
for name in dataset.features.keys():
if '.' in name:
dataset = dataset.rename_column(name, name.replace('.', '_'))
if "." in name:
dataset = dataset.rename_column(name, name.replace(".", "_"))
return dataset
@ -151,11 +164,22 @@ def parse_args():
"--subset", type=str, help="Subset name", required=False, default=None
)
parser.add_argument(
"--token", type=str, help="HuggingFace authentication token", required=False, default=None
"--token",
type=str,
help="HuggingFace authentication token",
required=False,
default=None,
)
parser.add_argument(
"--cache_dir", type=str, help="Cache directory", required=False, default=None
)
parser.add_argument(
"--trust_remote_code",
type=bool,
help="Trust remote code",
required=False,
default=None,
)
return parser.parse_args()
@ -169,6 +193,7 @@ def run():
        args.file,
        args.token,
        args.cache_dir,
        args.trust_remote_code,
    )