Add an examples folder for code downstream tasks (#18679)

* add examples subfolder

* mention examples in codeparrot readme

* use Trainer optimizer and scheduler type and add output_dir as argument

* add example of text-to-python and python-to-text models

* mention the downstream examples in the readme

* fix typo
Loubna Ben Allal 2022-08-18 18:24:24 +02:00 committed by GitHub
parent a123eee9df
commit bbbb453e58
5 changed files with 200 additions and 2 deletions


@@ -12,7 +12,11 @@ This is an open-source effort to train and evaluate code generation models. Code
- continuously push checkpoints to the hub with `huggingface_hub`
- stream the dataset with `datasets` during training to avoid disk bottlenecks
- apply the `code_eval` metric in `datasets` to evaluate on [OpenAI's _HumanEval_ benchmark](https://huggingface.co/datasets/openai_humaneval)
- showcase examples of downstream tasks with code models in the [examples](https://github.com/huggingface/transformers/tree/main/examples/research_projects/codeparrot/examples) folder:
    - Algorithmic complexity prediction
    - Code generation from English text
    - Code explanation
## Installation
To install the dependencies simply run the following command:
```bash


@@ -0,0 +1,58 @@
# Examples
In this folder we showcase some examples of using code models for downstream tasks.
## Complexity prediction
In this task we want to predict the complexity of Java programs in the [CodeComplex](https://huggingface.co/datasets/codeparrot/codecomplex) dataset. Using the Hugging Face `Trainer`, we fine-tuned [multilingual CodeParrot](https://huggingface.co/codeparrot/codeparrot-small-multi) and [UniXcoder](https://huggingface.co/microsoft/unixcoder-base-nine) on it, and we used the latter to build this Java complexity prediction [space](https://huggingface.co/spaces/codeparrot/code-complexity-predictor) on the Hugging Face Hub.
To fine-tune a model on this dataset you can use the following command:
```bash
python train_complexity_predictor.py \
--model_ckpt microsoft/unixcoder-base-nine \
--num_epochs 60 \
--num_warmup_steps 10 \
--batch_size 8 \
--learning_rate 5e-4
```
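After training, the saved checkpoint can be used to classify new programs. Below is a minimal inference sketch; the checkpoint path is an assumption (one of the `checkpoint-*` folders the `Trainer` writes under `--output_dir`), and the mapping from class id to complexity class is the `ClassLabel` built during training:
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hypothetical path: adjust to a checkpoint folder written under --output_dir
checkpoint = "./results/checkpoint-100"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

java_code = """
public static int sum(int[] a) {
    int s = 0;
    for (int x : a) s += x;
    return s;
}
"""

inputs = tokenizer(java_code, truncation=True, max_length=1024, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print("predicted class id:", logits.argmax(dim=-1).item())
```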
## Code generation: text to Python
In this task we want to train a model to generate code from English text. We fine-tuned CodeParrot-small on [github-jupyter-text-to-code](https://huggingface.co/datasets/codeparrot/github-jupyter-text-to-code), a dataset where the samples are a succession of docstrings and their Python code, originally extracted from Jupyter notebooks parsed in this [dataset](https://huggingface.co/datasets/codeparrot/github-jupyter-parsed).
To fine-tune a model on this dataset we use the same [script](https://github.com/huggingface/transformers/blob/main/examples/research_projects/codeparrot/scripts/codeparrot_training.py) as for the pretraining of CodeParrot:
```bash
accelerate launch scripts/codeparrot_training.py \
--model_ckpt codeparrot/codeparrot-small \
--dataset_name_train codeparrot/github-jupyter-text-to-code \
--dataset_name_valid codeparrot/github-jupyter-text-to-code \
--train_batch_size 12 \
--valid_batch_size 12 \
--learning_rate 5e-4 \
--num_warmup_steps 100 \
--gradient_accumulation 1 \
--gradient_checkpointing False \
--max_train_steps 3000 \
--save_checkpoint_steps 200 \
--save_dir jupyter-text-to-python
```
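Once fine-tuned, the model can be prompted with a docstring-style description and asked to complete it with Python code. A minimal generation sketch, assuming the checkpoint saved under `--save_dir` above is available locally (the exact prompt formatting should mirror the samples in the dataset):
```python
from transformers import pipeline

# Assumption: "jupyter-text-to-python" is the local --save_dir of the training run;
# fine-tuning does not change the tokenizer, so we reuse the base one.
pipe = pipeline(
    "text-generation",
    model="jupyter-text-to-python",
    tokenizer="codeparrot/codeparrot-small",
)

prompt = '"""Plot a histogram of the "age" column of the DataFrame df."""\n'
output = pipe(prompt, max_new_tokens=64, do_sample=True, temperature=0.2)
print(output[0]["generated_text"])
```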
## Code explanation: Python to text
In this task we want to train a model to explain Python code. We fine-tuned CodeParrot-small on [github-jupyter-code-to-text](https://huggingface.co/datasets/codeparrot/github-jupyter-code-to-text), a dataset where the samples are a succession of Python code and its explanation as a docstring. It was built by inverting the order of the text and code pairs in the github-jupyter-text-to-code dataset and adding the delimiters "Explanation:" and "End of explanation" inside the docstrings.
To fine-tune a model on this dataset we use the same [script](https://github.com/huggingface/transformers/blob/main/examples/research_projects/codeparrot/scripts/codeparrot_training.py) as for the pretraining of CodeParrot:
```bash
accelerate launch scripts/codeparrot_training.py \
--model_ckpt codeparrot/codeparrot-small \
--dataset_name_train codeparrot/github-jupyter-code-to-text \
--dataset_name_valid codeparrot/github-jupyter-code-to-text \
--train_batch_size 12 \
--valid_batch_size 12 \
--learning_rate 5e-4 \
--num_warmup_steps 100 \
--gradient_accumulation 1 \
--gradient_checkpointing False \
--max_train_steps 3000 \
--save_checkpoint_steps 200 \
--save_dir jupyter-python-to-text
```
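At inference time the model is prompted with Python code followed by the opening delimiter, and generation is cut at the closing one. A minimal sketch under the same assumptions as above (local `--save_dir` checkpoint, delimiters exactly as inserted in the dataset):
```python
from transformers import pipeline

# Assumption: "jupyter-python-to-text" is the local --save_dir of the training run
pipe = pipeline(
    "text-generation",
    model="jupyter-python-to-text",
    tokenizer="codeparrot/codeparrot-small",
)

code = "df = pd.read_csv('data.csv')\ndf.head()\n"
prompt = code + '"""Explanation:'
output = pipe(prompt, max_new_tokens=60, do_sample=False)[0]["generated_text"]

# Keep only the newly generated text, up to the closing delimiter if it appears
explanation = output[len(prompt):].split("End of explanation")[0].strip()
print(explanation)
```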


@@ -0,0 +1,5 @@
datasets==2.3.2
transformers==4.21.1
wandb==0.13.1
evaluate==0.2.2
scikit-learn==1.1.2


@@ -0,0 +1,132 @@
import argparse
from copy import deepcopy

import numpy as np
from datasets import ClassLabel, DatasetDict, load_dataset
from evaluate import load
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    set_seed,
)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_ckpt", type=str, default="microsoft/unixcoder-base-nine")
    parser.add_argument("--num_epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=6)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    # NOTE: with type=bool, any non-empty string passed on the command line is treated as True
    parser.add_argument("--freeze", type=bool, default=True)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--num_warmup_steps", type=int, default=10)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--output_dir", type=str, default="./results")
    return parser.parse_args()


metric = load("accuracy")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)


# Callback that also evaluates on the training set at the end of each epoch,
# logging the results under the "train" metric prefix.
class CustomCallback(TrainerCallback):
    def __init__(self, trainer) -> None:
        super().__init__()
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        if control.should_evaluate:
            control_copy = deepcopy(control)
            self._trainer.evaluate(eval_dataset=self._trainer.train_dataset, metric_key_prefix="train")
            return control_copy


def main():
    args = get_args()
    set_seed(args.seed)

    # Split CodeComplex into 80% train, 10% test, 10% validation
    dataset = load_dataset("codeparrot/codecomplex", split="train")
    train_test = dataset.train_test_split(test_size=0.2)
    test_validation = train_test["test"].train_test_split(test_size=0.5)
    train_test_validation = DatasetDict(
        {
            "train": train_test["train"],
            "test": test_validation["train"],
            "valid": test_validation["test"],
        }
    )

    print("Loading tokenizer and model")
    tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForSequenceClassification.from_pretrained(args.model_ckpt, num_labels=7)
    model.config.pad_token_id = model.config.eos_token_id

    if args.freeze:
        # Freeze the encoder body and train only the classification head
        for param in model.roberta.parameters():
            param.requires_grad = False

    labels = ClassLabel(num_classes=7, names=list(set(train_test_validation["train"]["complexity"])))

    def tokenize(example):
        inputs = tokenizer(example["src"], truncation=True, max_length=1024)
        label = labels.str2int(example["complexity"])
        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "label": label,
        }

    tokenized_datasets = train_test_validation.map(
        tokenize,
        batched=True,
        remove_columns=train_test_validation["train"].column_names,
    )
    # Dynamically pad each batch to its longest sequence
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="epoch",
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.num_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        weight_decay=args.weight_decay,
        metric_for_best_model="accuracy",
        run_name="complexity-java",
        report_to="wandb",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["valid"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    print("Training...")
    trainer.add_callback(CustomCallback(trainer))
    trainer.train()


if __name__ == "__main__":
    main()


@@ -63,7 +63,6 @@ class DuplicationIndex:
        self._index.insert(code_key, min_hash)
        if len(close_duplicates) > 0:
            for base_duplicate in close_duplicates:
                if base_duplicate in self._duplicate_clusters:
                    self._duplicate_clusters[base_duplicate].add(code_key)