Merge pull request #1 from defog-ai/wendy/testing
Better instructions, comments, and README from Wendy
commit 36ecbc8867
@@ -37,7 +37,14 @@ We classified each generated question into one of 5 categories. The table displa
| where | 80.0 | 65.7 | 62.9 | 60.0 | 60.0 | 60.0 | 45.7 |
## Using SQLCoder
-You can use SQLCoder via the `transformers` library by downloading our model weights from the HuggingFace repo. We have added sample code for inference [here](./inference.py). You can also use a demo on our website [here](https://defog.ai/sqlcoder), or run SQLCoder in Colab [here](https://colab.research.google.com/drive/13BIKsqHnPOBcQ-ba2p77L5saiepTIwu0#scrollTo=ZpbVgVHMkJvC)
+You can use SQLCoder via the `transformers` library by downloading our model weights from the HuggingFace repo. We have added sample code for [inference](./inference.py) on a [sample database](./metadata.sql).
```bash
python inference.py -q "Question about the sample database goes here"
# Sample questions:
```
You can also use a demo on our website [here](https://defog.ai/sqlcoder), or run SQLCoder in Colab [here](https://colab.research.google.com/drive/13BIKsqHnPOBcQ-ba2p77L5saiepTIwu0#scrollTo=ZpbVgVHMkJvC).
## Hardware Requirements
SQLCoder has been tested on an A100 40GB GPU with `bfloat16` weights. You can also load an 8-bit quantized version of the model on consumer GPUs with 20GB or more of memory, such as the RTX 4090 or RTX 3090, or on Apple M2 Pro, M2 Max, or M2 Ultra chips with 20GB or more of unified memory.
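These figures line up with a rough weights-only estimate. Assuming SQLCoder has roughly 15B parameters (the size of its StarCoder base model — an assumption here, not stated in this diff), bfloat16 needs 2 bytes per parameter and int8 needs 1; activations and the KV cache add overhead on top:

```python
# Back-of-the-envelope memory estimate for model weights alone.
# Assumption: ~15B parameters, per the StarCoder model family.

def weight_memory_gb(n_params: float, bytes_per_param: int) -> float:
    """Approximate GB needed just to hold the weights."""
    return n_params * bytes_per_param / 1e9

N_PARAMS = 15e9  # assumed parameter count

bf16_gb = weight_memory_gb(N_PARAMS, 2)  # bfloat16: 2 bytes/param -> 30 GB
int8_gb = weight_memory_gb(N_PARAMS, 1)  # int8: 1 byte/param -> 15 GB

print(f"bfloat16 weights: ~{bf16_gb:.0f} GB")  # fits an A100 40GB
print(f"int8 weights:     ~{int8_gb:.0f} GB")  # fits a 20GB+ consumer GPU
```

This is only the weight footprint; real usage is higher, which is why the quantized model is described as needing 20GB rather than 15GB.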
@@ -27,7 +27,7 @@ def get_tokenizer_model(model_name):
    return tokenizer, model
def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql"):
-    tokenizer, model = get_tokenizer_model("defog/starcoder-finetune-v3")
+    tokenizer, model = get_tokenizer_model("defog/sqlcoder")
    prompt = generate_prompt(question, prompt_file, metadata_file)
    # make sure the model stops generating at triple ticks
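The comment about stopping at triple ticks refers to cutting the generated SQL off at the closing code fence. As an illustrative standalone sketch of that idea (a simplified post-processing helper, not the repo's actual stopping-criteria code):

```python
def truncate_at_triple_ticks(generated: str) -> str:
    """Keep only the text before the first ``` marker, so the SQL
    answer is not followed by extra generated chatter.
    Hypothetical helper for illustration only."""
    marker = "```"
    idx = generated.find(marker)
    return generated if idx == -1 else generated[:idx]

print(truncate_at_triple_ticks("SELECT COUNT(*) FROM users;\n``` extra text"))
```

In practice the same effect can also be achieved during generation itself, by telling the model to treat the fence token as an end-of-sequence marker.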
@@ -58,7 +58,7 @@ def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql
if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser(description="Run inference on a question")
-    parser.add_argument("--question", type=str, help="Question to run inference on")
+    parser.add_argument("-q", "--question", type=str, help="Question to run inference on")
    args = parser.parse_args()
    question = args.question
    print("Loading a model and generating a SQL query for answering your question...")
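The change above adds a `-q` short alias, so `inference.py` accepts either spelling of the flag. A minimal standalone sketch of the same `argparse` setup (the question string below is just an example):

```python
import argparse

# Mirrors the parser added in inference.py: -q and --question both work.
parser = argparse.ArgumentParser(description="Run inference on a question")
parser.add_argument("-q", "--question", type=str, help="Question to run inference on")

# Parse a sample command line instead of sys.argv, for demonstration.
args = parser.parse_args(["-q", "How many users signed up last week?"])
print(args.question)  # same value whichever flag was used
```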
@@ -0,0 +1,3 @@
argparse
torch
transformers