transformers/tests/models/gpt2
Eduardo Pacheco 22d159ddf9
Adding Flash Attention 2 Support for GPT2 (#29226)
* First commit to add flash attention 2 for GPT-2

* more improvements

* Make GPT2 pass tests and fixed Decision Transformer copies

* Fixed missing arg

* fix copies

* Added expected speedup

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Added test

* Fixed attn attribute

* Update docs/source/en/model_doc/gpt2.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/gpt2.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update Decision transformer attentions

* More updates

* Passing tests

* Fix copies

* Fix copies part 2

* Decision transformer updates

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix copies

* Decision transformer not supporting flash attn

* Addressed comments

* Addressed comments

* Addressed comments

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2024-03-28 09:31:24 +00:00
__init__.py Move test model folders (#17034) 2022-05-03 14:42:02 +02:00
test_modeling_flax_gpt2.py Update all references to canonical models (#29001) 2024-02-16 08:16:58 +01:00
test_modeling_gpt2.py Adding Flash Attention 2 Support for GPT2 (#29226) 2024-03-28 09:31:24 +00:00
test_modeling_tf_gpt2.py Remove static pretrained maps from the library's internals (#29112) 2024-03-25 10:33:38 +01:00
test_tokenization_gpt2.py Adds pretrained IDs directly in the tests (#29534) 2024-03-13 14:53:27 +01:00
test_tokenization_gpt2_tf.py Update all references to canonical models (#29001) 2024-02-16 08:16:58 +01:00