Skip to content
GitLab
Explore
Projects
Groups
Topics
Snippets
Projects
Groups
Topics
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
zhusg
transformers-new
Commits
1141eff1
Unverified
Commit
1141eff1
authored
7 months ago
by
Vladislav Bronzov
Committed by
GitHub
7 months ago
Browse files
Options
Download
Patches
Plain Diff
Add Pytorch Tensor Parallel support for Mistral (#34927)
add base tp support
parent
4d1d0f29
main
35597_custom_tokenizer
action_to_notify_new_model_push_backup
add-deci-lm
add-dia
add-spinquant
add_eagle
add_kernelize
albertvillanova-patch-1
all_jobs_can_compare_against_prev_runs_clean_trigger
allow-disabling-compile
base-model-loading
batched_handle_empty_string
build_ci_docker_image_amd1
build_ci_docker_image_amd2
build_ci_docker_image_amd3
change-mi250-ci-slack-channel
change_build_input_tests
change_to_draft_2
change_to_draft_3
change_to_draft_4
change_to_draft_4-release
chat-template-quick-fix
check-v4.49-release
check_circleci_new_trigger
check_circleci_tokenizer
check_compile_if_flaky
check_doc_image
check_draft_4
check_env_runner
check_push
check_quantized_param_bnb4
check_temp
check_test_from_pretrained_low_cpu_mem_usage_equal
check_torch_27
ci-test-huggingface-hub-0.29.0.rc6
ci-test-huggingface-hub-0.30.0.rc1
ci-test-huggingface-hub-v0.27.0.rc0
ci-test-huggingface-hub-v0.27.0.rc1
ci-test-huggingface-hub-v0.27.0rc1
ci-test-huggingface-hub-v0.28.0.rc0
ci-test-huggingface-hub-v0.28.0.rc5
ci-test-huggingface-hub-v0.29.0.rc0
ci-test-huggingface-hub-v0.29.0.rc1
ci-test-huggingface-hub-v0.29.0.rc2
ci-test-huggingface-hub-v0.29.0.rc5
ci-test-huggingface-hub-v0.29.0.rc7
ci-test-huggingface-hub-v0.29.3.rc0
ci-test-huggingface-hub-v0.30.0.rc3-release
ci-test-huggingface-hub-v0.31.0.rc0-release
ci-test-huggingface-hub-v0.32.0.rc0-release
ci-test-huggingface-hub-v0.32.0.rc1-release
ci_with_commit_41b9b92b52215bed472c9a534a06abbc3a9a95cd
ci_with_torch_2.7
ci_with_torch_2.7.1_commit_0ef339ff1b63bb03a388c79bfbebec9085e10564
ci_with_torch_2.7_commit_0ef339ff1b63bb03a388c79bfbebec9085e10564
ci_with_torch_version_base
circleci_debug_base
circleci_debug_base_MobileNetV1ModelTest_test_batching_equivalence
circleci_debug_base_timm
circleci_debug_base_timm_3
clean-modeling
composable-tp
continuous-batching
custom-compute-loss-num-batches
dduf-compability
dduf-compatibility-with-file-explorer
debug+_audio
deepspeed-amd-pytorch-version-fix
dep_create_token_type_id
dependabot/pip/examples/flax/vision/torch-2.6.0
dependabot/pip/examples/tensorflow/language-modeling-tpu/transformers-4.50.0
disable-mi210-ci
dummy-pr
elie-temp-nope
faster_set_initialized_submodules
feature/#35425
find-test-failure-diff-between-envs
fix-Seq2SeqTrainingArguments-doc
fix-apex
fix-autoprocessor-import-order
fix-ci
fix-compressed-tensors
fix-device-map
fix-doc-builder
fix-flash-attention-with-static-cache
fix-gemma2-sliding-window
fix-gemma3-grad-acc
fix-kwargs-issues
fix-modular
fix-pytorch-deepspeed-image
fix-quantizer
fix-tp-check
fix/default_cb_scheduler
fix_aria_ci
fix_batch_test
fix_circleci_not_triggered
fix_docker_autoawq
fix_docker_autogptq_from_source
fix_falcon_processor
fix_flaky_4
fix_flaky_test_assisted_decoding_matches_greedy_search
fix_flaky_test_pt_tf_model_equivalence
fix_module_conversion_util_ci
fix_offload_disk_gguf
fix_print
fix_quanto_llama27b
fix_require_class
fix_sam_samhq
fix_tie3
fix_tiny_gh
fixing_gptq_tests
flex_attention_qwen2
fsdp2-checkpointing
get-our-efficiency-back
glm4
gpt2
hf-papers
ifix_aqlm_modules_to_not_convert
image-chunked-prefill
init_round_2
init_round_5
llama-refactor
llama4-unhardcode
merging_to_test
metadata_job_2
mistral3-xpu-cpu-offload
more-cleaning
more_info_ci_temp
muellerzr-fixup-warning
muellerzr-more-models-sadface
muellerzr-speedup-modular-conversion
multiple-modular
new_blt
nit-ga-condition
nit_cleanup
no-more-pointing-at-remote-repos
no_overwrite_test_batching_equivalence
non-model-inits
nouamane/context-parallel
one-class-to-rule-them-all
parallel
pcuenca-patch-1
pixtral_batchmixfeature_fix
pixtral_processor_structure_fix
prefill-chunking
processor-template-duplicated-tokens
push-ci-image
raise-from
random_dispatch
refactor-from-pretrained-base-commit
remove-items
remove-torch-pre-releases-amd-image
remove_unused_test_attribs
revert-37178-revert-loadibng-issue
run_amd_scheduled_ci_caller
run_amd_scheduled_ci_caller_testing
run_amd_scheduled_ci_caller_testing1
run_ci_without_kenlm
secure-amd-ci
skip_flaky_test
skip_flaky_tests_double_check
skip_internvl_tests
slight-readme-reword
spm_converter
stop_repeating_setup
temp-disable-scheduled-amd-ci
temp-kosmos25
temp123
tensor-cache
test-datasets-main
test-deepseek-fp8
test-fused-moe
test-tp-old-version
test_fast_only_refactor
test_safetensors_0.5.0
tests-fetcher-test-all
timm_wrapper_kwargs
tiny-fixes-qwen2.5-vl
tok_refactor
tokenizers_prerelease
tp-support
tp-test
transformers-should-not-set-env-vars
trigger_688f4707bfc5f6adc6f4f18c2081c5a66db590d1
trigger_all
trigger_all_2
trigger_build
trigger_doc_build_after_bot_push
trigger_via_api_backup
try_cpu_offload
try_torch_2.7_on_circleci_jobs
update-from-pretrained
update-min-safetensors
update-notification-service-amd-ci
update-patch-helper
update-recommended-reviewers
update-special-tokens
update-tp-nits
update_loss
use-hfh-loading-saving-state-dict-helpers
use-process-retry-on-amd-smi
use_uv
v4.47-release
v4.48-release
v4.49-release
v4.49.0-AyaVision-release
v4.49.0-Gemma-3-release
v4.49.0-Mistral-3-release
v4.49.0-SigLIP-2-release
v4.49.0-SmolVLM-2-release
v4.50-release
v4.50.3-DeepSeek-3-release
v4.51-release
v4.51.3-BitNet-release
v4.51.3-CSM-release
v4.51.3-D-FINE-release
v4.51.3-GraniteMoeHybrid-release
v4.51.3-InternVL-release
v4.51.3-Janus-release
v4.51.3-LlamaGuard-release
v4.51.3-MLCD-release
v4.51.3-Qwen2.5-Omni-release
v4.51.3-SAM-HQ-release
v4.51.3-TimesFM-release
v4.52-release
vas-bert-attn-refactor
vas-bert-attn-refactors
vas-whisper-attn-refactor
vb/add-baichuan
vision_visualizer
why_no_trigger
working
working-version
ydshieh-push-ci-image
v4.52.3
v4.52.2
v4.52.1
v4.52.0
v4.51.3
v4.51.3-TimesFM-preview
v4.51.3-SAM-HQ-preview
v4.51.3-Qwen2.5-Omni-preview
v4.51.3-MLCD-preview
v4.51.3-LlamaGuard-preview
v4.51.3-Janus-preview
v4.51.3-InternVL-preview
v4.51.3-GraniteMoeHybrid-preview
v4.51.3-D-FINE-preview
v4.51.3-CSM-preview
v4.51.3-BitNet-preview
v4.51.2
v4.51.1
v4.51.0
v4.50.3
v4.50.3-DeepSeek-3
v4.50.2
v4.50.1
v4.50.0
v4.50.r32
v4.50.r3
v4.49.0
v4.49.0-SmolVLM-2
v4.49.0-SigLIP-2
v4.49.0-Mistral-3
v4.49.0-Gemma-3
v4.49.0-AyaVision
v4.48.3
v4.48.2
v4.48.1
v4.48.0
v4.47.1
v4.47.0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/transformers/models/mistral/configuration_mistral.py
+10
-0
src/transformers/models/mistral/configuration_mistral.py
src/transformers/models/mistral/modeling_mistral.py
+4
-3
src/transformers/models/mistral/modeling_mistral.py
with
14 additions
and
3 deletions
+14
-3
src/transformers/models/mistral/configuration_mistral.py
+
10
−
0
View file @
1141eff1
...
...
@@ -97,6 +97,16 @@ class MistralConfig(PretrainedConfig):
model_type
=
"mistral"
keys_to_ignore_at_inference
=
[
"past_key_values"
]
# Default tensor parallel plan for base model `MistralModel`
base_model_tp_plan
=
{
"layers.*.self_attn.q_proj"
:
"colwise"
,
"layers.*.self_attn.k_proj"
:
"colwise"
,
"layers.*.self_attn.v_proj"
:
"colwise"
,
"layers.*.self_attn.o_proj"
:
"rowwise"
,
"layers.*.mlp.gate_proj"
:
"colwise"
,
"layers.*.mlp.up_proj"
:
"colwise"
,
"layers.*.mlp.down_proj"
:
"rowwise"
,
}
def
__init__
(
self
,
...
...
This diff is collapsed.
Click to expand it.
src/transformers/models/mistral/modeling_mistral.py
+
4
−
3
View file @
1141eff1
...
...
@@ -227,9 +227,9 @@ class MistralAttention(nn.Module):
key_states
=
self
.
k_proj
(
hidden_states
)
value_states
=
self
.
v_proj
(
hidden_states
)
query_states
=
query_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
key_states
=
key_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
query_states
=
query_states
.
view
(
bsz
,
q_len
,
-
1
,
self
.
head_dim
).
transpose
(
1
,
2
)
key_states
=
key_states
.
view
(
bsz
,
q_len
,
-
1
,
self
.
head_dim
).
transpose
(
1
,
2
)
value_states
=
value_states
.
view
(
bsz
,
q_len
,
-
1
,
self
.
head_dim
).
transpose
(
1
,
2
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
position_ids
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
)
...
...
@@ -983,6 +983,7 @@ class MistralModel(MistralPreTrainedModel):
class
MistralForCausalLM
(
MistralPreTrainedModel
,
GenerationMixin
):
_tied_weights_keys
=
[
"lm_head.weight"
]
_tp_plan
=
{
"lm_head"
:
"colwise_rep"
}
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment
Menu
Explore
Projects
Groups
Topics
Snippets