Skip to content
GitLab
Explore
Projects
Groups
Topics
Snippets
Projects
Groups
Topics
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
某某某
transformers-new
Commits
d228f50a
Unverified
Commit
d228f50a
authored
2 months ago
by
Mohamed Mekkouri
Committed by
GitHub
2 months ago
Browse files
Options
Download
Patches
Plain Diff
Fixing gated repo issues (#37463)
using unsloth model
parent
a5dfb989
main
add-dia
add_kernelize
all_jobs_can_compare_against_prev_runs_clean_trigger
batched_handle_empty_string
change-mi250-ci-slack-channel
change_build_input_tests
check_push
check_torch_27
ci-test-huggingface-hub-v0.31.0.rc0-release
ci-test-huggingface-hub-v0.32.0.rc0-release
ci-test-huggingface-hub-v0.32.0.rc1-release
clean-modeling
continuous-batching
dep_create_token_type_id
dependabot/pip/examples/flax/vision/torch-2.6.0
dependabot/pip/examples/tensorflow/language-modeling-tpu/transformers-4.50.0
disable-mi210-ci
elie-temp-nope
find-test-failure-diff-between-envs
fix-apex
fix-modular
fix-tp-check
fix/default_cb_scheduler
fix_batch_test
fix_circleci_not_triggered
fix_sam_samhq
fix_tiny_gh
fixing_gptq_tests
fsdp2-checkpointing
get-our-efficiency-back
glm4
gpt2
hf-papers
mistral3-xpu-cpu-offload
more-cleaning
more_info_ci_temp
new_blt
non-model-inits
nouamane/context-parallel
one-class-to-rule-them-all
push-ci-image
remove_unused_test_attribs
run_amd_scheduled_ci_caller
skip_flaky_tests_double_check
skip_internvl_tests
slight-readme-reword
spm_converter
temp123
test-datasets-main
test_fast_only_refactor
tok_refactor
trigger_688f4707bfc5f6adc6f4f18c2081c5a66db590d1
trigger_doc_build_after_bot_push
trigger_via_api_backup
try_cpu_offload
try_torch_2.7_on_circleci_jobs
update-notification-service-amd-ci
update-patch-helper
update_loss
v4.51.3-BitNet-release
v4.51.3-CSM-release
v4.51.3-D-FINE-release
v4.51.3-GraniteMoeHybrid-release
v4.51.3-InternVL-release
v4.51.3-Janus-release
v4.51.3-LlamaGuard-release
v4.51.3-MLCD-release
v4.51.3-Qwen2.5-Omni-release
v4.51.3-SAM-HQ-release
v4.51.3-TimesFM-release
v4.52-release
vas-bert-attn-refactor
vas-bert-attn-refactors
vas-whisper-attn-refactor
why_no_trigger
working
working-version
v4.52.3
v4.52.2
v4.52.1
v4.52.0
v4.51.3-TimesFM-preview
v4.51.3-SAM-HQ-preview
v4.51.3-Qwen2.5-Omni-preview
v4.51.3-MLCD-preview
v4.51.3-LlamaGuard-preview
v4.51.3-Janus-preview
v4.51.3-InternVL-preview
v4.51.3-GraniteMoeHybrid-preview
v4.51.3-D-FINE-preview
v4.51.3-CSM-preview
v4.51.3-BitNet-preview
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
tests/quantization/quark_integration/test_quark.py
+1
-7
tests/quantization/quark_integration/test_quark.py
with
1 addition
and
7 deletions
+1
-7
tests/quantization/quark_integration/test_quark.py
+
1
−
7
View file @
d228f50a
...
...
@@ -19,7 +19,6 @@ from transformers.testing_utils import (
is_torch_available
,
require_accelerate
,
require_quark
,
require_read_token
,
require_torch_gpu
,
require_torch_multi_gpu
,
slow
,
...
...
@@ -44,7 +43,7 @@ class QuarkConfigTest(unittest.TestCase):
@require_quark
@require_torch_gpu
class
QuarkTest
(
unittest
.
TestCase
):
reference_model_name
=
"
meta-llama/
Llama-3.1-8B-Instruct"
reference_model_name
=
"
unsloth/Meta-
Llama-3.1-8B-Instruct"
quantized_model_name
=
"amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
input_text
=
"Today I am in Paris and"
...
...
@@ -76,13 +75,11 @@ class QuarkTest(unittest.TestCase):
device_map
=
cls
.
device_map
,
)
@require_read_token
def
test_memory_footprint
(
self
):
mem_quantized
=
self
.
quantized_model
.
get_memory_footprint
()
self
.
assertTrue
(
self
.
mem_fp16
/
mem_quantized
>
self
.
EXPECTED_RELATIVE_DIFFERENCE
)
@require_read_token
def
test_device_and_dtype_assignment
(
self
):
r
"""
Test whether trying to cast (or assigning a device to) a model after quantization will throw an error.
...
...
@@ -96,7 +93,6 @@ class QuarkTest(unittest.TestCase):
# Tries with a `dtype``
self
.
quantized_model
.
to
(
torch
.
float16
)
@require_read_token
def
test_original_dtype
(
self
):
r
"""
A simple test to check if the model succesfully stores the original dtype
...
...
@@ -107,7 +103,6 @@ class QuarkTest(unittest.TestCase):
self
.
assertTrue
(
isinstance
(
self
.
quantized_model
.
model
.
layers
[
0
].
mlp
.
gate_proj
,
QParamsLinear
))
@require_read_token
def
check_inference_correctness
(
self
,
model
):
r
"""
Test the generation quality of the quantized model and see that we are matching the expected output.
...
...
@@ -131,7 +126,6 @@ class QuarkTest(unittest.TestCase):
# Get the generation
self
.
assertIn
(
self
.
tokenizer
.
decode
(
output_sequences
[
0
],
skip_special_tokens
=
True
),
self
.
EXPECTED_OUTPUTS
)
@require_read_token
def
test_generate_quality
(
self
):
"""
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment
Menu
Explore
Projects
Groups
Topics
Snippets