Update inference.py to return correct SQL without depending on prompt-based splitting

This commit is contained in:
Rishabh Srivastava 2024-02-04 16:24:33 +08:00
parent 4d884e2083
commit c3ca3dd8ea
1 changed file with 2 additions and 1 deletion

View File

@ -39,6 +39,7 @@ def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql
tokenizer=tokenizer,
max_new_tokens=300,
do_sample=False,
return_full_text=False, # added return_full_text parameter to prevent splitting issues with prompt
num_beams=5, # do beam search with 5 beams for high quality results
)
generated_query = (
@ -48,8 +49,8 @@ def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql
eos_token_id=eos_token_id,
pad_token_id=eos_token_id,
)[0]["generated_text"]
.split("```")[1]
.split(";")[0]
.split("```")[0]
.strip()
+ ";"
)