transformers/tests/models/gptj
bytebarde be3fd8a262
[Flash Attention 2] Add flash attention 2 for GPT-J (#28295)
* initial implementation of flash attention for gptj

* modify flash attention and overwrite test_flash_attn_2_generate_padding_right

* update flash attention support list

* remove the copy line in the `CodeGenBlock`

* address the `# Copied from` copy mechanism

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add GPTJ attention classes

* add expected outputs in the gptj test

* Ensure repo consistency with 'make fix-copies'

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2024-03-13 08:43:00 +01:00
__init__.py Move test model folders (#17034) 2022-05-03 14:42:02 +02:00
test_modeling_flax_gptj.py Update all references to canonical models (#29001) 2024-02-16 08:16:58 +01:00
test_modeling_gptj.py [Flash Attention 2] Add flash attention 2 for GPT-J (#28295) 2024-03-13 08:43:00 +01:00
test_modeling_tf_gptj.py [`Styling`] stylify using ruff (#27144) 2023-11-16 17:43:19 +01:00
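The PR above wires the standard Flash Attention 2 path into GPT-J. A minimal sketch of how a user would opt in, assuming a CUDA device with the flash-attn package installed (the checkpoint name and prompt are just examples):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-j-6b"  # example checkpoint; any GPT-J checkpoint works

# FA2 kernels only run in half precision, so fp16/bf16 is required.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_id)
# FA2 generation expects left padding; the overwritten
# test_flash_attn_2_generate_padding_right exercises the right-padding case.
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token  # GPT-J has no pad token by default

inputs = tokenizer(["Hello, my dog is"], return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

Passing `attn_implementation="eager"` instead selects the original attention class, which is the fallback the new GPTJ attention classes dispatch against.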