    Add jamba (#29943) · 3f20877d
    tomeras91 authored
    * Add jamba arch
    
    * apply "make fix-copies" changes
    
    * fix link to model in JambaConfig docstring
    
    * Add n_ctx in modeling file because repo-consistency wants that
    
    * Add jamba to flash attention and sdpa documentation
    
    * mamba dt_proj quant fix now works for LoRA as well
    
    * override test_left_padding_compatibility and use a more permissive tolerance. left padding numerical differences are accentuated by mamba layers
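
    A minimal sketch of what such an override checks (not the exact test from the PR; the tolerance value and helper names are illustrative): the last-token logits with and without left padding are compared, with a looser atol because the recurrent mamba state amplifies the numerical effect of padded positions.

    ```python
    import torch

    def check_left_padding(model, input_ids, pad_token_id, atol=3e-3):
        # logits of the last real token without any padding
        logits_no_pad = model(input_ids).logits[:, -1, :]

        # left-pad by a few tokens and mask them out
        pad = torch.full((input_ids.shape[0], 3), pad_token_id, dtype=input_ids.dtype)
        padded_ids = torch.cat([pad, input_ids], dim=1)
        attention_mask = torch.cat([torch.zeros_like(pad), torch.ones_like(input_ids)], dim=1)
        logits_pad = model(padded_ids, attention_mask=attention_mask).logits[:, -1, :]

        # mamba layers accentuate the difference, hence the permissive tolerance
        assert torch.allclose(logits_no_pad, logits_pad, atol=atol)
    ```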
    
    * add jamba to tokenization auto
    
    * fix shape comments (PR #24 on the model page: https://huggingface.co/ai21labs/Jamba-v0.1/discussions/24)
    
    * simple PR fixes
    
    * remove unnecessary kwargs from JambaAttentionDecoderLayer and JambaMambaDecoderLayer
    
    * remove the LoRA hack for the mamba dt_proj bias. It was solved in huggingface/peft#1530 (https://github.com/huggingface/peft/pull/1530)
    
    * Add copied comment on JambaMLP (it's the same as MixtralMLP)
    
    * remove padding_mask warnings. It's not supported anymore
    
    * fix docstring. Float instead of int
    
    * A few more minor PR fixes
    
    * (1) lowercase names for mamba layernorms (2) remove _apply_inner_layernorms and do it directly in the forward pass
    
    * Return None attention weights from mamba layers. Append to all attentions only if not None.
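
    A hedged sketch of the collection pattern (names are illustrative, not the exact PR code): mamba layers simply return None in the attention-weights slot, and None entries are skipped when building the tuple of all attentions.

    ```python
    from typing import Iterable, Optional, Tuple
    import torch

    def collect_attentions(
        layer_outputs: Iterable[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, ...]:
        all_attentions = ()
        for _, attn_weights in layer_outputs:
            if attn_weights is not None:  # None for mamba layers
                all_attentions += (attn_weights,)
        return all_attentions
    ```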
    
    * remove some leftover jamba archive lists
    
    * Better separation between expert and non-expert layers. Non-expert layers return None as router_logits, and None values are not concatenated into the all_router_logits returned from JambaModel
    
    * no need to take router_logits at config.expert_layer_offset anymore. result.router_logits now holds results only for expert layers
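
    A sketch of the bookkeeping behind the last two items, assuming Jamba-style expert_layer_offset / expert_layer_period config attributes (the exact condition in the modeling code may differ): only MoE layers produce router logits, so the returned tuple already contains expert layers only and no post-hoc slicing by offset is needed.

    ```python
    def is_expert_layer(layer_idx: int, offset: int, period: int) -> bool:
        # a layer gets experts only at the configured period/offset
        return layer_idx % period == offset

    def collect_router_logits(layer_router_logits):
        # non-expert layers return None and are skipped entirely
        return tuple(r for r in layer_router_logits if r is not None)
    ```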
    
    * Add Jamba paper on READMEs
    
    * (1) rename n_ctx -> max_position_embeddings (2) don't use it in the modeling file since it's not needed (set it as an exception to check_config_attributes)
    
    * Add copied from comment
    
    * remove the code path for apply_inner_layernorms=False. Jamba always has the inner mamba layernorms
    
    * clearer docstring for _convert_to_standard_cache
    
    * style fixes
    
    * Change calc_logits_for_entire_prompt (bool) to num_logits_to_keep (int). Adapt assisted decoding code to use it. Also a small change in the low-memory beam search decoding path to support this new int value in model_inputs
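
    A hedged sketch of the num_logits_to_keep idea (not the PR's exact code; treating 0 as "keep everything" is an assumption of this sketch): an int says how many trailing positions need logits, so regular decoding can project only the last token while assisted decoding can still score all candidate tokens.

    ```python
    import torch

    def project_logits(hidden_states: torch.Tensor, lm_head: torch.nn.Module,
                       num_logits_to_keep: int = 0) -> torch.Tensor:
        if num_logits_to_keep == 0:
            return lm_head(hidden_states)  # e.g. assisted decoding scoring the whole prompt
        return lm_head(hidden_states[:, -num_logits_to_keep:, :])  # usually just the last token
    ```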
    
    * rename test so it still overrides what it's meant to override
    
    * draft
    
    * oops
    
    * nit
    
    * remove more complex logic
    
    * fix names used in config
    
    * fix fix fix
    
    * style
    
    * fix some more failing tests
    
    * generate did not init the cache 🙃
    
    * more small nits
    
    * typo
    
    * config.mamba_expand * config.hidden_size for the intermediate size of the mamba shapes
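
    A small sketch of the shape arithmetic under stated assumptions (attribute names follow the Jamba config; the exact cache layout is illustrative): the mamba branch works in an expanded inner dimension, and the per-layer conv/ssm state tensors are sized from it.

    ```python
    import torch

    batch_size, hidden_size = 2, 4096
    mamba_expand, mamba_d_conv, mamba_d_state = 2, 4, 16

    intermediate_size = mamba_expand * hidden_size                        # 8192
    conv_state = torch.zeros(batch_size, intermediate_size, mamba_d_conv)
    ssm_state = torch.zeros(batch_size, intermediate_size, mamba_d_state)
    ```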
    
    * fix init of pkv with torch.tensor()
    
    * empty tensor
    
    * fix some init issues
    
    * stupid changes required by generate because it does not even support its own DynamicCache class
    
    * more fixes
    
    * fix general assisted gen cache_position bug
    
    * tests passing
    
    * Add offsets and periods as SPECIAL_CASES_TO_ALLOW in check_config_attributes.py
    
    * fix reorder_cache to reorder mamba states and override some more functions in HybridMambaAttentionDynamicCache
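
    A simplified sketch of that reordering (condensed from the idea in the PR, not the exact implementation): during beam search both the attention key/values and the mamba conv/ssm states have to follow beam_idx.

    ```python
    import torch

    def reorder_cache(cache, beam_idx: torch.LongTensor):
        for layer_idx in range(len(cache.key_cache)):
            idx = beam_idx.to(cache.key_cache[layer_idx].device)
            cache.key_cache[layer_idx] = cache.key_cache[layer_idx].index_select(0, idx)
            cache.value_cache[layer_idx] = cache.value_cache[layer_idx].index_select(0, idx)

            idx = beam_idx.to(cache.conv_states[layer_idx].device)
            cache.conv_states[layer_idx] = cache.conv_states[layer_idx].index_select(0, idx)
            cache.ssm_states[layer_idx] = cache.ssm_states[layer_idx].index_select(0, idx)
    ```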
    
    * no need to override test_past_key_values_format() and _check_past_key_values_for_generate() in tests anymore
    
    * fix docstrings and typehints for past_key_values
    
    * style fixes
    
    * fix docs
    
    * change typehint due to copy from Mixtral
    
    * forgot import
    
    * import order
    
    * Add configuration_jamba and modeling_jamba to not_doctested because the model is too big to download (in docstring of JambaForCausalLM.forward)
    
    * Add integration test with tiny random Jamba model on the hub
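
    A minimal sketch of such a test (the checkpoint id is an assumption here, and a real test would compare against a frozen expected string):

    ```python
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "ai21labs/Jamba-tiny-random"  # tiny randomly-initialized checkpoint on the hub
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

    inputs = tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")
    out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
    print(tokenizer.batch_decode(out, skip_special_tokens=True))
    ```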
    
    * fix flash attention cache shapes
    
    * bring back forgotten hidden states
    
    * rename HybridMambaAttentionDynamicCache.seqlen_offset to has_previous_state (and make it a bool) and bugfix - it should be set to True after a finished forward pass of the entire model
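
    A sketch of the has_previous_state flag (simplified; not the actual class): it flips to True only once a forward pass over the entire model has finished, so the first call still runs the full mamba scan instead of the single-step state update.

    ```python
    class HybridCacheSketch:
        def __init__(self):
            self.has_previous_state = False  # no usable mamba state cached yet

    def model_forward(cache, hidden_states):
        # ... run all decoder layers, writing conv/ssm states into `cache` ...
        cache.has_previous_state = True      # flip only after the whole pass is done
    ```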
    
    * align integration test after modeling fixes
    
    * bugfix - mamba can use precomputed states only if the forward pass is on a single token
    
    * bugfix - mamba can use precomputed states only if they match the batch size
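
    A sketch of the two guards from the last two items (names are illustrative): the fast single-step mamba path is taken only when decoding exactly one new token and the cached state was built for the same batch size.

    ```python
    def can_use_precomputed_states(cache, batch_size: int, seq_len: int) -> bool:
        return (
            cache is not None
            and cache.has_previous_state
            and seq_len == 1                                 # single-token decode step
            and cache.conv_states[0].shape[0] == batch_size  # state matches the current batch
        )
    ```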
    
    * typo
    
    * remove making _prepare_4d_causal_attention_mask a leaf function
    
    * stop using past_seq_len.get_seq_length(). Use cache positions instead. Adjust test (test_decoder_model_past_with_large_inputs) accordingly
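
    A hedged sketch of the cache-position switch (values are illustrative): instead of asking the cache for its length, the caller passes explicit absolute positions for the current tokens.

    ```python
    import torch

    seq_len = 8
    cache_position = torch.arange(seq_len)     # prefill: positions 0 .. seq_len - 1
    # then, for every generated token:
    cache_position = cache_position[-1:] + 1   # decode step: one absolute position
    ```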
    
    ---------
    
    Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
    Co-authored-by: Joao Gante <joao@huggingface.co>