From b345f4e7f7f852f7def0ebe473d21419ae235fdc Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Tue, 26 Mar 2024 20:15:07 +0900
Subject: [PATCH] warning was in info mode.

---
 src/transformers/modeling_utils.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 2f24170ee5..6b4f780a59 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -595,12 +595,12 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
         state_dict._metadata = metadata
 
     error_msgs = []
+    unexpected_keys, missing_keys = [], []
 
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
     def load(module: nn.Module, state_dict, prefix=""):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-        unexpected_keys, missing_keys = [], []
         args = (state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
         # Parameters of module and children will start with prefix. We can exit early if there are none in this
         # state_dict
@@ -631,7 +631,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # it's safe to delete it.
     del state_dict
 
-    return error_msgs
+    return error_msgs, unexpected_keys, missing_keys
 
 
 def find_submodule_and_param_name(model, long_key, start_prefix):
@@ -3902,7 +3902,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     remove_prefix_from_model,
                     ignore_mismatched_sizes,
                 )
-                error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+                error_msgs, unexpected_keys, missing_keys = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
                 offload_index = None
         else:
             # Sharded checkpoint or whole but low_cpu_mem_usage==True
@@ -3975,8 +3975,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         )
                         error_msgs += new_error_msgs
                 else:
-                    error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
-
+                    mmsg, unexpected_keys, missing_keys = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+                    error_msgs += mmsg
                 # force memory release
                 del state_dict
                 gc.collect()
@@ -4010,9 +4010,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
 
         if len(unexpected_keys) > 0:
-            archs = [] if model.config.architectures is None else model.config.architectures
-            warner = logger.warning if model.__class__.__name__ in archs else logger.info
-            warner(
+            logger.warning(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                 f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                 f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
--
GitLab
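
Note: the sketch below is not part of the patch. It is a minimal standalone illustration, assuming plain PyTorch, of the loading pattern the patch moves to: the missing/unexpected-key accumulators live outside the recursive helper, so every nested call appends to the same lists, and both are returned alongside error_msgs. TinyModel, the checkpoint dict, and load_state_dict_collecting_keys are made-up names used only for this example; they do not exist in transformers.

# Minimal sketch (not the transformers implementation) of collecting
# missing/unexpected keys across a recursive state_dict load and returning
# them together with error_msgs, as the patched function does.
import torch
from torch import nn


def load_state_dict_collecting_keys(model: nn.Module, state_dict: dict):
    error_msgs = []
    # Accumulators are defined once, outside the recursive helper, so every
    # recursive call extends the same lists instead of fresh local ones.
    unexpected_keys = []
    missing_keys = []

    def load(module: nn.Module, prefix: str = ""):
        local_metadata = {}
        # `_load_from_state_dict` only copies this module's own parameters,
        # so recurse into children, extending the prefix the way PyTorch does.
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model)
    return error_msgs, unexpected_keys, missing_keys


if __name__ == "__main__":
    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 2)

    # Checkpoint with one key the model does not have; linear.bias is absent.
    checkpoint = {
        "linear.weight": torch.zeros(2, 2),
        "classifier.weight": torch.zeros(2, 2),
    }
    errors, unexpected, missing = load_state_dict_collecting_keys(TinyModel(), checkpoint)
    print("unexpected:", unexpected)  # ['classifier.weight']
    print("missing:", missing)        # ['linear.bias']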