update inference.py to return correct SQL without worrying about prompt
This commit is contained in:
parent
4d884e2083
commit
c3ca3dd8ea
|
@ -39,6 +39,7 @@ def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql
|
|||
tokenizer=tokenizer,
|
||||
max_new_tokens=300,
|
||||
do_sample=False,
|
||||
return_full_text=False, # added return_full_text parameter to prevent splitting issues with prompt
|
||||
num_beams=5, # do beam search with 5 beams for high quality results
|
||||
)
|
||||
generated_query = (
|
||||
|
@ -48,8 +49,8 @@ def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql
|
|||
eos_token_id=eos_token_id,
|
||||
pad_token_id=eos_token_id,
|
||||
)[0]["generated_text"]
|
||||
.split("```")[1]
|
||||
.split(";")[0]
|
||||
.split("```")[0]
|
||||
.strip()
|
||||
+ ";"
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue