* some nits
* update test
* add sdpa support
* remove some dummy inputs
* all good
* style
* nits
* fixes
* fix more copies
* nits
* styling
* fix
* Update src/transformers/models/mistral/modeling_mistral.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* add a slow test just to be sure
* fixup
---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* add sdpa
* wip
* cleaning
* add ref
* yet more cleaning
* and more :)
* wip llama
* working llama
* add output_attentions=True support
* bigcode sdpa support
* fixes
* gpt-bigcode support, require torch>=2.1.1
* add falcon support
* fix conflicts falcon
* style
* fix attention_mask definition
* remove output_attentions from attnmaskconverter
* support whisper without removing any Copied from statement
* fix mbart default to eager renaming
* fix typo in falcon
* fix is_causal in SDPA
* check is_flash_attn_2_available in the model's init as well, in case the model is not initialized through from_pretrained
* add warnings when falling back on the manual implementation
* precise doc
* wip replace _flash_attn_enabled with config.attn_implementation
* fix typo
* add tests
* style
* add a copy.deepcopy on the config in from_pretrained, as we do not want to modify it inplace
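A minimal sketch of why the deepcopy matters (hypothetical helper, not the actual `from_pretrained` code): the caller's config object must stay untouched when we resolve the attention implementation.

```python
# Minimal sketch, not the actual from_pretrained code: deep-copy the incoming
# config so that setting the attention implementation does not mutate the
# caller's object in place.
import copy

def _resolve_config(config, attn_implementation=None):
    config = copy.deepcopy(config)  # caller's config stays pristine
    if attn_implementation is not None:
        config._attn_implementation = attn_implementation  # internal field, see later commits
    return config
```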
* obey config.attn_implementation if a config is passed in from_pretrained
* fix is_torch_sdpa_available when torch is not installed
* remove dead code
* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/bart/modeling_bart.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* remove duplicate pretraining_tp code
* add dropout in llama
* precise comment on attn_mask
* add fmt: off for _unmask_unattended docstring
* precise num_masks comment
* nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion
* cleanup modeling_utils
* backward compatibility
* fix style as requested
* style
* improve documentation
* test pass
* style
* add _unmask_unattended tests
* skip meaningless tests for idefics
* hard_check SDPA requirements when specifically requested
* standardize the use of XXX_ATTENTION_CLASSES
* fix SDPA bug with mem-efficient backend on CUDA when using fp32
* fix test
* rely on SDPA is_causal parameter to handle the causal mask in some cases
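For context, a sketch of what relying on `is_causal` means: when no padding mask is needed, SDPA can apply causal masking internally instead of us materializing a 4D mask tensor (shapes below are illustrative).

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) -- illustrative sizes
q = torch.rand(1, 8, 16, 64)
k, v = torch.rand_like(q), torch.rand_like(q)

# No attn_mask materialized: SDPA builds the causal mask itself.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
```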
* fix FALCON_ATTENTION_CLASSES
* remove _flash_attn_2_enabled occurrences
* fix test
* add OPT to the list of supported flash models
* improve test
* properly test on different SDPA backends, on different dtypes & properly handle separately the pad tokens in the test
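A sketch of how a test can pin SDPA to one backend at a time, via torch's `sdp_kernel` context manager (available since torch 2.0; the flash and mem-efficient paths need CUDA):

```python
import torch
import torch.nn.functional as F

q = torch.rand(1, 8, 16, 64, device="cuda", dtype=torch.float16)
k, v = torch.rand_like(q), torch.rand_like(q)

# Force the math backend only, then compare against the other kernels.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    out_math = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```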
* remove remaining _flash_attn_2_enabled occurrence
* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/modeling_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/modeling_attn_mask_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update docs/source/en/perf_infer_gpu_one.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* remove use_attn_implementation
* fix docstring & slight bug
* make attn_implementation internal (_attn_implementation)
* typos
* fix tests
* deprecate use_flash_attention_2=True
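The replacement for the deprecated flag, as introduced in this PR (checkpoint name is just an example):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    attn_implementation="flash_attention_2",  # or "sdpa" / "eager"
    torch_dtype="auto",
)
```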
* fix test
* add back llama that was removed by mistake
* fix tests
* remove _flash_attn_2_enabled occurrences (bis)
* add check & test that passed attn_implementation is valid
* fix falcon torchscript export
* fix device of mask in tests
* add tip about torch.jit.trace and move bt doc below sdpa
* fix parameterized.expand order
* move tests from test_modeling_attn_mask_utils to test_modeling_utils as a relevant test class is already there
* update sdpaattention class with the new cache
* Update src/transformers/configuration_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/bark/modeling_bark.py
* address review comments
* WIP torch.jit.trace fix. left: test both eager & sdpa
* add test for torch.jit.trace for both eager/sdpa
* fix falcon with torch==2.0 that needs to use sdpa
* fix doc
* hopefully last fix
* fix key_value_length that has no default now in mask converter
* is it flaky?
* fix speculative decoding bug
* tests do pass
* fix following #27907
---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Copies `modeling_flax_gpt_neo.py` to start
* MLP Block. WIP Attention and Block
* Adds Flax implementation of `LlamaMLP`
Validated with in-file test.
Some slight numeric differences, but assuming it isn't an issue
* Adds `FlaxLlamaRMSNorm` layer
`flax.linen` includes `RMSNorm` layer but not necessarily in all
versions. Hence, we add in-file.
* Adds FlaxLlamaAttention
Copied from GPT-J as it has efficient caching implementation as well as
rotary embeddings.
Noticed numerical differences, but not by a huge amount; needs
investigating.
* Adds `FlaxLlamaDecoderLayer`
numerically inaccurate, debugging..
* debugging rotary mismatch
gptj uses interleaved whilst llama uses contiguous
I think they match now, but the final result is still wrong.
maybe drop back to just debugging attention layer?
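A sketch of the mismatch being debugged here: the two rotation conventions (function names are descriptive, not the modeling-file ones).

```python
import jax.numpy as jnp

def rotate_half_llama(x):
    # Llama: rotate the two contiguous halves of the head dimension.
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)

def rotate_every_two_gptj(x):
    # GPT-J: rotate interleaved (even, odd) channel pairs.
    x1, x2 = x[..., ::2], x[..., 1::2]
    return jnp.stack([-x2, x1], axis=-1).reshape(x.shape)
```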
* fixes bug with decoder layer
still somewhat numerically inaccurate, but close enough for now
* adds markers for what to implement next
the structure here diverges a lot from the PT version.
not a big fan of it, but just get something working for now
* implements `FlaxLlamaBlockCollection`
tolerance must be higher than expected, kinda disconcerting
* Adds `FlaxLlamaModule`
equivalent PyTorch model is `LlamaModel`
yay! a language model🤗
* adds `FlaxLlamaForCausalLMModule`
equivalent to `LlamaForCausalLM`
still missing returning dict or tuple, will add later
* start porting pretrained wrappers
realised it probably needs return dict as a prereq
* cleanup, quality, style
* readds `return_dict` and model output named tuples
* (tentatively) pretrained wrappers work 🔥
* fixes numerical mismatch in `FlaxLlamaRMSNorm`
seems `jax.lax.rsqrt` does not match `torch.sqrt`.
manually computing `1 / jax.numpy.sqrt` results in matching values.
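A sketch of the fix, assuming the usual RMSNorm formulation (variance accumulated in float32; eps is illustrative):

```python
import jax.numpy as jnp

def rms_norm(hidden, weight, eps=1e-6):
    variance = jnp.mean(jnp.square(hidden.astype(jnp.float32)), axis=-1, keepdims=True)
    # 1 / jnp.sqrt matches the PyTorch reference; jax.lax.rsqrt drifted.
    hidden = hidden * (1.0 / jnp.sqrt(variance + eps))
    return weight * hidden.astype(weight.dtype)
```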
* [WIP] debugging numerics
* numerical match
I think the issue was an accidental change of backend; forcing CPU fixes the test.
We expect some mismatch on GPU.
* adds in model and integration tests for Flax Llama
summary of failing:
- mul invalid combination of dimensions
- one numerical mismatch
- bf16 conversion (maybe my local backend issue)
- params are not FrozenDict
* adds missing TYPE_CHECKING import and `make fixup`
* adds back missing docstrings
needs review on quality of docstrings, not sure what is required.
Furthermore, need to check if `CHECKPOINT_FOR_DOC` is valid. See TODO
* commenting out equivalence test as can just use common
* debugging
* Fixes bug where mask and pos_ids were swapped in pretrained models
This results in all tests passing now 🔥
* cleanup of modeling file
* cleanup of test file
* Resolving simpler review comments
* addresses more minor review comments
* fixing introduced pytest errors from review
* wip additional slow tests
* wip tests
need to grab a GPU machine to get real logits for comparison
otherwise, slow tests should be okay
* `make quality`, `make style`
* adds slow integration tests
- checking logits
- checking hidden states
- checking generation outputs
* `make fix-copies`
* fix mangled function following `make fix-copies`
* adds missing type checking imports
* fixes missing parameter checkpoint warning
* more finegrained 'Copied from' tags
avoids issue of overwriting `LLAMA_INPUTS_DOCSTRING`
* swaps import guards
??? how did these get swapped initially?
* removing `inv_freq` again as the pytorch version has now removed it
* attempting to get CI to pass
* adds doc entries for llama flax models
* fixes typo in __init__.py imports
* adds back special equivalence tests
these come from the gpt neo flax tests. there is special behaviour for these models that needs to override the common version
* overrides tests with dummy to see if CI passes
need to fill in these tests later
* adds my contribution to docs
* `make style; make quality`
* replaces random masking with fixed to work with flax version
* `make quality; make style`
* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* updates `x`->`tensor` in `rotate_half`
* addresses smaller review comments
* Update docs/source/en/model_doc/llama.md
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* adds integration test class
* adds `dtype` to rotary embedding to cast outputs
* adds type to flax llama rotary layer
* `make style`
* `make fix-copies`
* Apply suggestions from code review
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* applies suggestions from review
* Update modeling_flax_llama.py
* `make fix-copies`
* Update tests/models/llama/test_modeling_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update src/transformers/models/llama/modeling_flax_llama.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* fixes shape mismatch in FlaxLlamaMLP
* applies some suggestions from reviews
* casts attn output logits to f32 regardless of dtype
* adds attn bias using `LlamaConfig.attention_bias`
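Sketch of wiring `attention_bias` into the projections (illustrative module, assuming `flax.linen.Dense`, not the exact modeling code):

```python
import flax.linen as nn

class LlamaAttentionProjections(nn.Module):
    hidden_size: int
    attention_bias: bool = False  # mirrors LlamaConfig.attention_bias

    def setup(self):
        dense = lambda: nn.Dense(self.hidden_size, use_bias=self.attention_bias)
        self.q_proj, self.k_proj, self.v_proj, self.o_proj = dense(), dense(), dense(), dense()
```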
* adds Copied From comments to Flax Llama test
* mistral and persimmon test change - copied from llama
* updates docs index
* removes Copied from in tests
it was preventing `make fix-copies` from succeeding
* quality and style
* ignores FlaxLlama input docstring
* adds revision to `_CHECKPOINT_FOR_DOC`
* repo consistency and quality
* removes unused import
* removes copied from from Phi test
now diverges from llama tests following FlaxLlama changes
* adds `_REAL_CHECKPOINT_FOR_DOC`
* removes refs from pr tests
* reformat to make ruff happy
---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Fix mistral generate for long prompt / response
* Add unit test
* fix linter
* fix linter
* fix test
* add assisted generation test for mistral and load the model in 4 bit + fa2
* try to stylify using ruff
* might need to remove these changes?
* use ruff format and ruff check
* use isinstance instead of type comparison
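The pattern applied here, with hypothetical classes for illustration:

```python
class Config: ...
class MistralConfig(Config): ...

cfg = MistralConfig()
print(isinstance(cfg, Config))  # True: isinstance also covers subclasses
print(type(cfg) == Config)      # False: exact type comparison (ruff flags this as E721)
```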
* use # fmt: skip
* use # fmt: skip
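What `# fmt: skip` buys us, shown on a hypothetical constant: the formatter leaves the statement's manual alignment intact.

```python
# Without the trailing comment, the formatter would collapse this alignment.
SPARSE_LEVELS = [1,    2,    4,    8,    16]  # fmt: skip
```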
* nits
* some styling changes
* update ci job
* nits isinstance
* more files update
* nits
* more nits
* small nits
* check and format
* revert wrong changes
* actually use formatter instead of checker
* nits
* well docbuilder is overwriting this commit
* revert notebook changes
* try to nuke docbuilder
* style
* fix feature extraction test
* remove `indent-width = 4`
* fixup
* more nits
* update the ruff version that we use
* style
* nuke docbuilder styling
* leave the print for detected changes
* nits
* Remove file I/O
Co-authored-by: charliermarsh <charlie.r.marsh@gmail.com>
* style
* nits
* revert notebook changes
* Add # fmt skip when possible
* Add # fmt skip when possible
* Fix
* More ` # fmt: skip` usage
* More ` # fmt: skip` usage
* More ` # fmt: skip` usage
* Nits
* more fixes
* fix tapas
* Another way to skip
* Recommended way
* Fix two more files
* Remove asynch
---------
Co-authored-by: charliermarsh <charlie.r.marsh@gmail.com>
* add FA-2 support for mistral
* fixup
* add sliding windows
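A hedged sketch of the sliding-window constraint (illustrative helper, not the modeling code): each token attends only to itself and the previous `window - 1` positions, on top of causality.

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]               # no attending to the future
    in_window = (idx[:, None] - idx[None, :]) < window  # at most `window` tokens back
    return causal & in_window  # True where attention is allowed

print(sliding_window_causal_mask(5, 3).int())
```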
* fixing a few nits
* v1 slicing cache - logits do not match
* add comment
* fix bugs
* more mem efficient
* add warning once
* add warning once
* oops
* fixup
* more comments
* copy
* add safety checker
* fixup
* Update src/transformers/models/mistral/modeling_mistral.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* copied from
* up
* raise when padding side is right
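Sketch of the check added here (illustrative guard and message, not the exact transformers code): batched generation with Flash Attention 2 needs left padding, so right padding raises.

```python
# Illustrative guard, not the exact transformers code.
def check_padding_side(padding_side: str) -> None:
    if padding_side == "right":
        raise ValueError(
            "Batched generation with Flash Attention 2 requires "
            "tokenizer.padding_side='left'."
        )
```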
* fixup
* add doc + few minor changes
* fixup
---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>