mirror of https://github.com/tracel-ai/burn.git
Added parameter trust_remote_code to hf dataset call. (#2013)
* Added parameter trust_remote_code to hf dataset call. * Removed test module as it may break, causing false negatives. Set default trust_remote_code to false. Added an example that highlights the use case.
This commit is contained in:
parent
9804bf81b2
commit
befe6c1601
|
@ -0,0 +1,22 @@
|
|||
use burn_dataset::HuggingfaceDatasetLoader;
|
||||
use burn_dataset::SqliteDataset;
|
||||
use serde::Deserialize;
|
||||
|
||||
/// Raw MNIST row as read back from the SQLite database produced by the
/// Hugging Face dataset export.
///
/// Field names are what `serde` matches against the stored column names —
/// assumes the columns are literally `_image_bytes` / `_label`; TODO confirm
/// against the exporter's schema. The leading underscores also exempt the
/// fields from dead-code warnings since this example never reads them.
#[derive(Deserialize, Debug, Clone)]
struct MnistItemRaw {
    // Encoded image bytes for one sample.
    pub _image_bytes: Vec<u8>,
    // Class label for the sample.
    pub _label: usize,
}
|
||||
fn main() {
|
||||
// There are some datasets, such as https://huggingface.co/datasets/ylecun/mnist/tree/main that contains a script,
|
||||
// In this cases you must enable trusting remote code execution if you want to use it.
|
||||
let _train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("mnist")
|
||||
.with_trust_remote_code(true)
|
||||
.dataset("train")
|
||||
.unwrap();
|
||||
|
||||
// However not all dataset requires it https://huggingface.co/datasets/Anthropic/hh-rlhf/tree/main
|
||||
let _train_ds: SqliteDataset<MnistItemRaw> = HuggingfaceDatasetLoader::new("Anthropic/hh-rlhf")
|
||||
.dataset("train")
|
||||
.unwrap();
|
||||
}
|
|
@ -63,6 +63,7 @@ pub struct HuggingfaceDatasetLoader {
|
|||
base_dir: Option<PathBuf>,
|
||||
huggingface_token: Option<String>,
|
||||
huggingface_cache_dir: Option<String>,
|
||||
trust_remote_code: bool,
|
||||
}
|
||||
|
||||
impl HuggingfaceDatasetLoader {
|
||||
|
@ -74,6 +75,7 @@ impl HuggingfaceDatasetLoader {
|
|||
base_dir: None,
|
||||
huggingface_token: None,
|
||||
huggingface_cache_dir: None,
|
||||
trust_remote_code: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -111,6 +113,14 @@ impl HuggingfaceDatasetLoader {
|
|||
self
|
||||
}
|
||||
|
||||
/// Specify whether or not to trust remote code.
|
||||
///
|
||||
/// If not specified, trust remote code is set to true.
|
||||
pub fn with_trust_remote_code(mut self, trust_remote_code: bool) -> Self {
|
||||
self.trust_remote_code = trust_remote_code;
|
||||
self
|
||||
}
|
||||
|
||||
/// Load the dataset.
|
||||
pub fn dataset<I: DeserializeOwned + Clone>(
|
||||
self,
|
||||
|
@ -153,6 +163,7 @@ impl HuggingfaceDatasetLoader {
|
|||
base_dir,
|
||||
self.huggingface_token,
|
||||
self.huggingface_cache_dir,
|
||||
self.trust_remote_code,
|
||||
)?;
|
||||
}
|
||||
|
||||
|
@ -168,6 +179,7 @@ fn import(
|
|||
base_dir: PathBuf,
|
||||
huggingface_token: Option<String>,
|
||||
huggingface_cache_dir: Option<String>,
|
||||
trust_remote_code: bool,
|
||||
) -> Result<(), ImporterError> {
|
||||
let venv_python_path = install_python_deps(&base_dir)?;
|
||||
|
||||
|
@ -195,7 +207,10 @@ fn import(
|
|||
command.arg("--cache_dir");
|
||||
command.arg(huggingface_cache_dir);
|
||||
}
|
||||
|
||||
if trust_remote_code {
|
||||
command.arg("--trust_remote_code");
|
||||
command.arg("True");
|
||||
}
|
||||
let mut handle = command.spawn().unwrap();
|
||||
handle
|
||||
.wait()
|
||||
|
|
|
@ -6,7 +6,14 @@ from sqlalchemy import Column, Integer, Table, create_engine, event, inspect
|
|||
from sqlalchemy.types import LargeBinary
|
||||
|
||||
|
||||
def download_and_export(name: str, subset: str, db_file: str, token: str, cache_dir: str):
|
||||
def download_and_export(
|
||||
name: str,
|
||||
subset: str,
|
||||
db_file: str,
|
||||
token: str,
|
||||
cache_dir: str,
|
||||
trust_remote_code: bool,
|
||||
):
|
||||
"""
|
||||
Download a dataset from using HuggingFace dataset and export it to a sqlite database.
|
||||
"""
|
||||
|
@ -15,18 +22,24 @@ def download_and_export(name: str, subset: str, db_file: str, token: str, cache_
|
|||
# bytes can be none {'bytes': None, 'path': 'healthy_train.265.jpg'}
|
||||
# We should handle this case, but unfortunately we did not come across this case yet to test it.
|
||||
|
||||
print("*"*80)
|
||||
print("*" * 80)
|
||||
print("Starting huggingface dataset download and export")
|
||||
print(f"Dataset Name: {name}")
|
||||
print(f"Subset Name: {subset}")
|
||||
print(f"Sqlite database file: {db_file}")
|
||||
print(f"Trust remote code: {trust_remote_code}")
|
||||
if cache_dir is None:
|
||||
print(f"Custom cache dir: {cache_dir}")
|
||||
print("*"*80)
|
||||
print("*" * 80)
|
||||
|
||||
# Load the dataset
|
||||
dataset_all = load_dataset(
|
||||
name, subset, cache_dir=cache_dir, use_auth_token=token)
|
||||
name,
|
||||
subset,
|
||||
cache_dir=cache_dir,
|
||||
use_auth_token=token,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
print(f"Dataset: {dataset_all}")
|
||||
|
||||
|
@ -34,10 +47,10 @@ def download_and_export(name: str, subset: str, db_file: str, token: str, cache_
|
|||
engine = create_engine(f"sqlite:///{db_file}")
|
||||
|
||||
# Set some sqlite pragmas to speed up the database
|
||||
event.listen(engine, 'connect', set_sqlite_pragma)
|
||||
event.listen(engine, "connect", set_sqlite_pragma)
|
||||
|
||||
# Add an row_id column to each table as primary key (datasets does not have API for this)
|
||||
event.listen(Table, 'before_create', add_pk_column)
|
||||
event.listen(Table, "before_create", add_pk_column)
|
||||
|
||||
# Export each split in the dataset
|
||||
for key in dataset_all.keys():
|
||||
|
@ -61,7 +74,7 @@ def download_and_export(name: str, subset: str, db_file: str, token: str, cache_
|
|||
engine,
|
||||
# don't save the index, use row_id instead (index is not unique)
|
||||
index=False,
|
||||
dtype=blob_columns(dataset), # save binary columns as blob
|
||||
dtype=blob_columns(dataset), # save binary columns as blob
|
||||
)
|
||||
|
||||
# Print the schema of the database so we can reference the columns in the rust code
|
||||
|
@ -89,8 +102,8 @@ def rename_columns(dataset):
|
|||
"""
|
||||
|
||||
for name in dataset.features.keys():
|
||||
if '.' in name:
|
||||
dataset = dataset.rename_column(name, name.replace('.', '_'))
|
||||
if "." in name:
|
||||
dataset = dataset.rename_column(name, name.replace(".", "_"))
|
||||
|
||||
return dataset
|
||||
|
||||
|
@ -151,11 +164,22 @@ def parse_args():
|
|||
"--subset", type=str, help="Subset name", required=False, default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token", type=str, help="HuggingFace authentication token", required=False, default=None
|
||||
"--token",
|
||||
type=str,
|
||||
help="HuggingFace authentication token",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir", type=str, help="Cache directory", required=False, default=None
|
||||
)
|
||||
parser.add_argument(
    "--trust_remote_code",
    # argparse's `type=bool` converts ANY non-empty string — including the
    # literal "False" — to True, which would silently enable remote code.
    # Parse the text explicitly instead; "--trust_remote_code True" (as sent
    # by the Rust loader) still yields True, everything else yields False.
    type=lambda s: str(s).strip().lower() in ("1", "true", "yes"),
    help="Trust remote code",
    required=False,
    # None keeps the library default when the flag is omitted.
    default=None,
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
@ -169,6 +193,7 @@ def run():
|
|||
args.file,
|
||||
args.token,
|
||||
args.cache_dir,
|
||||
args.trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue