From 5bf9caa06df47acda751ff5bcef95b937b86c71f Mon Sep 17 00:00:00 2001 From: Lu Teng Date: Wed, 22 May 2024 01:14:11 +0800 Subject: [PATCH] Fix inhomogeneous shape error in example (#30434) Fix inhomogeneous shape error in example. --- examples/flax/question-answering/run_qa.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py index 7c705090f1..16a744ddc3 100644 --- a/examples/flax/question-answering/run_qa.py +++ b/examples/flax/question-answering/run_qa.py @@ -426,7 +426,8 @@ def eval_data_collator(dataset: Dataset, batch_size: int): for idx in batch_idx: batch = dataset[idx] - batch = {k: np.array(v) for k, v in batch.items()} + # Ignore `offset_mapping` to avoid numpy/JAX array conversion issue. + batch = {k: np.array(v) for k, v in batch.items() if k != "offset_mapping"} yield batch @@ -1000,7 +1001,6 @@ def main(): position=2, ): _ = batch.pop("example_id") - _ = batch.pop("offset_mapping") predictions = pad_shard_unpad(p_eval_step)( state, batch, min_device_batch=per_device_eval_batch_size ) @@ -1055,7 +1055,6 @@ def main(): eval_loader, total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2 ): _ = batch.pop("example_id") - _ = batch.pop("offset_mapping") predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size) start_logits = np.array(predictions[0]) end_logits = np.array(predictions[1])