Fix inhomogeneous shape error in example (#30434)

Fix inhomogeneous shape error in example.
This commit is contained in:
Lu Teng 2024-05-22 01:14:11 +08:00 committed by GitHub
parent d24097e022
commit 5bf9caa06d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 3 deletions

View File

@ -426,7 +426,8 @@ def eval_data_collator(dataset: Dataset, batch_size: int):
for idx in batch_idx: for idx in batch_idx:
batch = dataset[idx] batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()} # Ignore `offset_mapping` to avoid numpy/JAX array conversion issue.
batch = {k: np.array(v) for k, v in batch.items() if k != "offset_mapping"}
yield batch yield batch
@ -1000,7 +1001,6 @@ def main():
position=2, position=2,
): ):
_ = batch.pop("example_id") _ = batch.pop("example_id")
_ = batch.pop("offset_mapping")
predictions = pad_shard_unpad(p_eval_step)( predictions = pad_shard_unpad(p_eval_step)(
state, batch, min_device_batch=per_device_eval_batch_size state, batch, min_device_batch=per_device_eval_batch_size
) )
@ -1055,7 +1055,6 @@ def main():
eval_loader, total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2 eval_loader, total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2
): ):
_ = batch.pop("example_id") _ = batch.pop("example_id")
_ = batch.pop("offset_mapping")
predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size) predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size)
start_logits = np.array(predictions[0]) start_logits = np.array(predictions[0])
end_logits = np.array(predictions[1]) end_logits = np.array(predictions[1])