diff --git a/inference.py b/inference.py index 2d6c4c1..d18ab7a 100644 --- a/inference.py +++ b/inference.py @@ -48,8 +48,7 @@ def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql eos_token_id=eos_token_id, pad_token_id=eos_token_id, )[0]["generated_text"] - .split("# SQL\n")[-1] - .split("```")[0] + .split("```")[1] .split(";")[0] .strip() + ";"