# transformers/.circleci/config.yml


version: 2.1
orbs:
  gcp-gke: circleci/gcp-gke@1.0.4
  go: circleci/go@1.3.0

# TPU REFERENCES
references:
  checkout_ml_testing: &checkout_ml_testing
    run:
      name: Checkout ml-testing-accelerators
      command: |
        git clone https://github.com/GoogleCloudPlatform/ml-testing-accelerators.git
        cd ml-testing-accelerators
        # Pin ml-testing-accelerators to a known-good commit.
        git fetch origin 5e88ac24f631c27045e62f0e8d5dfcf34e425e25:stable
        git checkout stable
  build_push_docker: &build_push_docker
    run:
      name: Configure Docker
      command: |
        gcloud --quiet auth configure-docker
        cd docker/transformers-pytorch-tpu
        # On a pull request, build from the PR head; otherwise build the default branch.
        if [ -z "$CIRCLE_PR_NUMBER" ]; then
          docker build --tag "$GCR_IMAGE_PATH:$CIRCLE_WORKFLOW_JOB_ID" -f Dockerfile --build-arg "TEST_IMAGE=1" .
        else
          docker build --tag "$GCR_IMAGE_PATH:$CIRCLE_WORKFLOW_JOB_ID" -f Dockerfile --build-arg "TEST_IMAGE=1" --build-arg "GITHUB_REF=pull/$CIRCLE_PR_NUMBER/head" .
        fi
        docker push "$GCR_IMAGE_PATH:$CIRCLE_WORKFLOW_JOB_ID"
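  # Hedged sketch: the GITHUB_REF build-arg above is typically consumed inside
  # docker/transformers-pytorch-tpu/Dockerfile along these lines (the Dockerfile
  # itself is not part of this config, so treat this as illustrative only):
  #
  #   ARG GITHUB_REF=master
  #   RUN git clone https://github.com/huggingface/transformers.git && \
  #       cd transformers && \
  #       git fetch origin "$GITHUB_REF" && \
  #       git checkout FETCH_HEAD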
  deploy_cluster: &deploy_cluster
    run:
      name: Deploy the job on the kubernetes cluster
      command: |
        go get github.com/google/go-jsonnet/cmd/jsonnet && \
        export PATH=$PATH:$HOME/go/bin && \
        kubectl create -f docker/transformers-pytorch-tpu/dataset.yaml || true && \
        job_name=$(jsonnet -J ml-testing-accelerators/ docker/transformers-pytorch-tpu/bert-base-cased.jsonnet --ext-str image=$GCR_IMAGE_PATH --ext-str image-tag=$CIRCLE_WORKFLOW_JOB_ID | kubectl create -f -) && \
        # `kubectl create` prints "job.batch/<name> created"; strip both ends to
        # keep only the job name.
        job_name=${job_name#job.batch/} && \
        job_name=${job_name% created} && \
        echo "Waiting on kubernetes job: $job_name" && \
        i=0 && \
        # 30 checks spaced 30s apart = 900s total.
        max_checks=30 && \
        status_code=2 && \
        # Check on the job periodically. Set the status code depending on what
        # happened to the job in Kubernetes. If we try max_checks times and the
        # job still hasn't finished, give up and return the starting non-zero
        # status code.
        while [ $i -lt $max_checks ]; do
          ((i++))
          if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then
            status_code=1 && break
          elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1"; then
            status_code=0 && break
          else
            echo "Job not finished yet"
          fi
          sleep 30
        done && \
        echo "Done waiting. Job status code: $status_code" && \
        pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk '!/NAME/ {print $1}') && \
        echo "GKE pod name: $pod_name" && \
        # The chain deliberately breaks here: even if log streaming fails, the
        # image cleanup below still runs.
        kubectl logs -f $pod_name --container=train
        echo "Done with log retrieval attempt." && \
        gcloud container images delete "$GCR_IMAGE_PATH:$CIRCLE_WORKFLOW_JOB_ID" --force-delete-tags && \
        exit $status_code
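  # For local debugging, the render-and-submit step above can be run standalone
  # (a sketch, assuming gcloud and kubectl are already authenticated against the
  # same cluster, with a throwaway image tag):
  #
  #   jsonnet -J ml-testing-accelerators/ \
  #     docker/transformers-pytorch-tpu/bert-base-cased.jsonnet \
  #     --ext-str image=$GCR_IMAGE_PATH --ext-str image-tag=dev | kubectl create -f -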
  delete_gke_jobs: &delete_gke_jobs
    run:
      name: Delete GKE Jobs
      command: |
        # Match jobs whose age matches patterns like '1h' or '1d', i.e. any job
        # that has been around longer than 1hr. First print all columns for
        # matches, then execute the delete.
        kubectl get job | awk 'match($4,/[0-9]+[dh]/) {print $0}'
        kubectl delete job $(kubectl get job | awk 'match($4,/[0-9]+[dh]/) {print $1}')
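# Hedged sketch of how a job could consume the anchors above via the gcp-gke orb
# (illustrative only; the actual job and workflow wiring is presumably defined
# later in the full file, and the job name here is hypothetical):
#
#   cleanup-gke-jobs:
#     docker:
#       - image: circleci/python:3.7
#     steps:
#       - gcp-gke/install
#       - gcp-gke/update-kubeconfig-with-credentials:
#           cluster: $GKE_CLUSTER
#           perform-login: true
#       - *delete_gke_jobs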

jobs:
  run_tests_torch_and_tf:
    working_directory: ~/transformers
    docker:
      - image: circleci/python:3.6
        environment:
          OMP_NUM_THREADS: 1
    resource_class: xlarge
    parallelism: 1
    steps:
      - checkout
      - restore_cache:
          keys:
            # Restore the most specific cache first, falling back to the shared
            # v0.4 key so a fresh suite still starts from a warm pip cache.
            - v0.4-torch_and_tf-{{ checksum "setup.py" }}
            - v0.4-{{ checksum "setup.py" }}
      - run: pip install --upgrade pip
      - run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece]
      - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
      - save_cache:
          key: v0.4-{{ checksum "setup.py" }}
          paths:
            - '~/.cache/pip'
      - run: RUN_PT_TF_CROSS_TESTS=1 python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_torch_and_tf ./tests/ -m is_pt_tf_cross_test --durations=0 | tee tests_output.txt
      - store_artifacts:
          path: ~/transformers/tests_output.txt
      - store_artifacts:
          path: ~/transformers/reports
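  # To reproduce this suite locally (a sketch): install the same extras, then run
  #   RUN_PT_TF_CROSS_TESTS=1 python -m pytest -m is_pt_tf_cross_test ./tests/
  # Note that --make-reports is a transformers-specific pytest option defined in
  # the repo's conftest.py, not a stock pytest flag.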
  run_tests_torch:
    working_directory: ~/transformers
    docker:
      - image: circleci/python:3.7
        environment:
          OMP_NUM_THREADS: 1
    resource_class: xlarge
    parallelism: 1
    steps:
      - checkout
      - restore_cache:
          keys:
            - v0.4-torch-{{ checksum "setup.py" }}
            - v0.4-{{ checksum "setup.py" }}
      - run: pip install --upgrade pip
      - run: pip install .[sklearn,torch,testing,sentencepiece]
      - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
      - save_cache:
          key: v0.4-torch-{{ checksum "setup.py" }}
          paths:
            - '~/.cache/pip'
      - run: python -m pytest -n 8 --dist=loadfile -s --make-reports=tests_torch ./tests/ | tee tests_output.txt
      - store_artifacts:
          path: ~/transformers/tests_output.txt
      - store_artifacts:
          path: ~/transformers/reports
  run_tests_tf:
    working_directory: ~/transformers
    docker:
      - image: circleci/python:3.7
        environment:
          OMP_NUM_THREADS: 1
    resource_class: xlarge
    parallelism: 1
    steps:
      - checkout
      - restore_cache:
          keys:
            - v0.4-tf-{{ checksum "setup.py" }}
            - v0.4-{{ checksum "setup.py" }}
      - run: pip install --upgrade pip
      - run: pip install .[sklearn,tf-cpu,testing,sentencepiece]
      - save_cache:
          key: v0.4-tf-{{ checksum "setup.py" }}
          paths:
            - '~/.cache/pip'
      - run: python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_tf ./tests/ | tee tests_output.txt
      - store_artifacts:
          path: ~/transformers/tests_output.txt
      - store_artifacts:
          path: ~/transformers/reports
  run_tests_flax:
    working_directory: ~/transformers
    docker:
      - image: circleci/python:3.7
        environment:
          OMP_NUM_THREADS: 1
    resource_class: xlarge
    parallelism: 1
    steps:
      - checkout
      - restore_cache:
          keys:
            - v0.4-flax-{{ checksum "setup.py" }}
            - v0.4-{{ checksum "setup.py" }}
      - run: pip install --upgrade pip
      - run: sudo pip install .[flax,sklearn,torch,testing,sentencepiece]
      - save_cache:
          key: v0.4-flax-{{ checksum "setup.py" }}
          paths:
            - '~/.cache/pip'
      - run: python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_flax ./tests/ | tee tests_output.txt
- store_artifacts:
path: ~/transformers/tests_output.txt
- store_artifacts:
path: ~/transformers/reports
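# Runs the pipeline test suite (tests marked is_pipeline_test) against PyTorch on CPU.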
run_tests_pipelines_torch:
working_directory: ~/transformers
docker:
- image: circleci/python:3.7
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- restore_cache:
keys:
- v0.4-torch-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
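# restore_cache tries these keys in order (prefix match), so the second,
# broader key acts as a fallback when no job-specific cache exists yet.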
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: RUN_PIPELINE_TESTS=1 python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_pipelines_torch -m is_pipeline_test ./tests/ | tee tests_output.txt
- store_artifacts:
path: ~/transformers/tests_output.txt
- store_artifacts:
path: ~/transformers/reports
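# Same pipeline test suite, run against TensorFlow (tf-cpu) instead of PyTorch.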
run_tests_pipelines_tf:
working_directory: ~/transformers
docker:
- image: circleci/python:3.7
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- restore_cache:
keys:
- v0.4-tf-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,testing,sentencepiece]
- save_cache:
key: v0.4-tf-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: RUN_PIPELINE_TESTS=1 python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_pipelines_tf ./tests/ -m is_pipeline_test | tee tests_output.txt
- store_artifacts:
path: ~/transformers/tests_output.txt
- store_artifacts:
path: ~/transformers/reports
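# Exercises tokenizers with non-default dependencies, e.g. the Japanese BERT
# tokenizer, which needs the [ja] extras and a unidic dictionary download.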
run_tests_custom_tokenizers:
working_directory: ~/transformers
docker:
- image: circleci/python:3.7
environment:
RUN_CUSTOM_TOKENIZERS: yes
steps:
- checkout
- restore_cache:
keys:
- v0.4-custom_tokenizers-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[ja,testing,sentencepiece]
- run: python -m unidic download
- save_cache:
key: v0.4-custom_tokenizers-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: python -m pytest -s --make-reports=tests_custom_tokenizers ./tests/test_tokenization_bert_japanese.py | tee tests_output.txt
- store_artifacts:
path: ~/transformers/tests_output.txt
- store_artifacts:
path: ~/transformers/reports
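# Runs the example scripts under examples/ with PyTorch; they carry their own
# requirements file on top of the library extras.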
run_examples_torch:
working_directory: ~/transformers
docker:
- image: circleci/python:3.6
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- restore_cache:
keys:
- v0.4-torch_examples-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,sentencepiece,testing]
- run: pip install -r examples/_tests_requirements.txt
- save_cache:
key: v0.4-torch_examples-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_torch ./examples/ | tee examples_output.txt
- store_artifacts:
path: ~/transformers/examples_output.txt
- store_artifacts:
path: ~/transformers/reports
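# Smoke-tests git-lfs based uploads (HfLargefilesTest); a dummy git identity
# is configured because the test creates commits.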
run_tests_git_lfs:
working_directory: ~/transformers
docker:
- image: circleci/python:3.7
resource_class: xlarge
parallelism: 1
steps:
- checkout
- run: sudo apt-get install git-lfs
- run: |
git config --global user.email "ci@dummy.com"
git config --global user.name "ci"
- run: pip install --upgrade pip
- run: pip install .[testing]
- run: RUN_GIT_LFS_TESTS=1 python -m pytest -sv ./tests/test_hf_api.py -k "HfLargefilesTest"
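# Builds the Sphinx documentation; SPHINXOPTS="-W" turns warnings into errors
# so doc regressions fail the build.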
build_doc:
working_directory: ~/transformers
docker:
- image: circleci/python:3.6
steps:
- checkout
- restore_cache:
keys:
- v0.4-build_doc-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install ."[all, docs]"
- save_cache:
key: v0.4-build_doc-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: cd docs && make html SPHINXOPTS="-W"
- store_artifacts:
path: ./docs/_build
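# Builds and publishes the documentation via .circleci/deploy.sh; the SSH key
# fingerprint grants push access. Only runs on master (see workflow_filters).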
deploy_doc:
working_directory: ~/transformers
docker:
- image: circleci/python:3.6
steps:
- add_ssh_keys:
fingerprints:
- "5b:7a:95:18:07:8c:aa:76:4c:60:35:88:ad:60:56:71"
- checkout
- restore_cache:
keys:
- v0.4-deploy_doc-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: pip install ."[all,docs]"
- save_cache:
key: v0.4-deploy_doc-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: ./.circleci/deploy.sh
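# Static checks: black, isort and flake8 for style, plus the utils/check_*.py
# scripts that keep copies, doc tables and dummy objects in sync.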
check_code_quality:
working_directory: ~/transformers
docker:
- image: circleci/python:3.6
resource_class: medium
parallelism: 1
steps:
- checkout
- restore_cache:
keys:
- v0.4-code_quality-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install isort
- run: pip install .[all,quality]
- save_cache:
key: v0.4-code_quality-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: black --check examples tests src utils
- run: isort --check-only examples tests src utils
- run: flake8 examples tests src utils
- run: python utils/style_doc.py src/transformers docs/source --max_len 119 --check_only
- run: python utils/check_copies.py
- run: python utils/check_table.py
- run: python utils/check_dummies.py
- run: python utils/check_repo.py
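# Scans the repository for hardcoded links and verifies they still resolve.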
check_repository_consistency:
working_directory: ~/transformers
docker:
- image: circleci/python:3.6
resource_class: small
parallelism: 1
steps:
- checkout
- run: pip install requests
- run: python ./utils/link_tester.py
# TPU JOBS
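# Builds the TPU test image, pushes it to GCR, and launches the training job
# on the GKE cluster named by $GKE_CLUSTER (see the references at the top of
# this file).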
run_examples_tpu:
docker:
- image: circleci/python:3.6
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- go/install
- *checkout_ml_testing
- gcp-gke/install
- gcp-gke/update-kubeconfig-with-credentials:
cluster: $GKE_CLUSTER
perform-login: true
- setup_remote_docker
- *build_push_docker
- *deploy_cluster
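# Scheduled cleanup: deletes GKE jobs older than one hour so the cluster does
# not accumulate stale runs.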
cleanup-gke-jobs:
docker:
- image: circleci/python:3.6
steps:
- gcp-gke/install
- gcp-gke/update-kubeconfig-with-credentials:
cluster: $GKE_CLUSTER
perform-login: true
- *delete_gke_jobs
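# YAML anchor shared by jobs that must only run on master (currently deploy_doc).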
workflow_filters: &workflow_filters
filters:
branches:
only:
- master
workflows:
version: 2
build_and_test:
jobs:
- check_code_quality
- check_repository_consistency
- run_examples_torch
- run_tests_custom_tokenizers
- run_tests_torch_and_tf
- run_tests_torch
- run_tests_tf
- run_tests_flax
- run_tests_pipelines_torch
- run_tests_pipelines_tf
- run_tests_git_lfs
- build_doc
- deploy_doc: *workflow_filters
tpu_testing_jobs:
triggers:
- schedule:
# Set to run once per day at 08:00 UTC.
cron: "0 8 * * *"
filters:
branches:
only:
- master
jobs:
- cleanup-gke-jobs
- run_examples_tpu