From a6937898c117a2f75c3ee354eb2e4916f428f441 Mon Sep 17 00:00:00 2001
From: LSinev <LSinev@users.noreply.github.com>
Date: Wed, 3 Aug 2022 18:35:22 +0300
Subject: [PATCH] Fix torch version comparisons

Comparisons like
version.parse(torch.__version__) > version.parse("1.6")
are True for torch==1.6.0+cu101 or torch==1.6.0+cpu, because the local
version label (+cu101, +cpu) sorts above the bare release.

Comparing against
version.parse(version.parse(torch.__version__).base_version)
is preferred instead (and such comparisons are now provided as flags in
pytorch_utils.py).
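
For illustration, a minimal sketch of the behaviour with packaging.version
(not part of this patch):

    from packaging import version

    v = "1.6.0+cu101"
    # The local version label sorts above the bare release, so this is True
    # even though the base release is exactly 1.6.0.
    version.parse(v) > version.parse("1.6")              # True
    base = version.parse(version.parse(v).base_version)  # Version("1.6.0")
    base > version.parse("1.6")                          # False
    base >= version.parse("1.6")                         # True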
---
 examples/research_projects/wav2vec2/run_asr.py  |  2 +-
 .../wav2vec2/run_common_voice.py                |  2 +-
 .../research_projects/wav2vec2/run_pretrain.py  |  2 +-
 src/transformers/activations.py                 |  6 +++---
 src/transformers/convert_graph_to_onnx.py       |  4 +++-
 .../models/albert/modeling_albert.py            | 10 +++++++---
 src/transformers/models/bert/modeling_bert.py   | 10 +++++++---
 .../models/big_bird/modeling_big_bird.py        |  5 ++---
 .../models/convbert/modeling_convbert.py        | 10 +++++++---
 .../models/data2vec/modeling_data2vec_text.py   | 10 +++++++---
 .../modeling_decision_transformer.py            | 10 +++++++---
 .../models/distilbert/modeling_distilbert.py    | 10 +++++++---
 .../models/electra/modeling_electra.py          | 10 +++++++---
 .../models/flaubert/modeling_flaubert.py        |  4 ++--
 src/transformers/models/flava/modeling_flava.py |  4 ++--
 src/transformers/models/fnet/modeling_fnet.py   |  5 ++---
 src/transformers/models/gpt2/modeling_gpt2.py   | 11 ++++++++---
 .../models/imagegpt/modeling_imagegpt.py        | 11 ++++++++---
 src/transformers/models/mctct/modeling_mctct.py |  4 ++--
 src/transformers/models/nezha/modeling_nezha.py | 10 +++++++---
 .../nystromformer/modeling_nystromformer.py     | 10 +++++++---
 .../models/qdqbert/modeling_qdqbert.py          |  5 ++---
 src/transformers/models/realm/modeling_realm.py | 10 +++++++---
 .../models/roberta/modeling_roberta.py          | 10 +++++++---
 src/transformers/models/vilt/modeling_vilt.py   | 12 ++++++++----
 .../xlm_roberta_xl/modeling_xlm_roberta_xl.py   | 10 +++++++---
 src/transformers/models/yoso/modeling_yoso.py   | 10 +++++++---
 src/transformers/onnx/convert.py                |  3 ++-
 src/transformers/pipelines/base.py              |  4 +++-
 src/transformers/pytorch_utils.py               |  8 ++++++--
 src/transformers/trainer.py                     | 17 +++++++++++------
 src/transformers/trainer_pt_utils.py            |  2 +-
 src/transformers/utils/import_utils.py          |  6 +++---
 ...ling_{{cookiecutter.lowercase_modelname}}.py |  4 ++--
 34 files changed, 164 insertions(+), 87 deletions(-)

diff --git a/examples/research_projects/wav2vec2/run_asr.py b/examples/research_projects/wav2vec2/run_asr.py
index ab9db11d2a..692aa39796 100755
--- a/examples/research_projects/wav2vec2/run_asr.py
+++ b/examples/research_projects/wav2vec2/run_asr.py
@@ -30,7 +30,7 @@ from transformers import (
 if is_apex_available():
     from apex import amp
 
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
     _is_native_amp_available = True
     from torch.cuda.amp import autocast
 
diff --git a/examples/research_projects/wav2vec2/run_common_voice.py b/examples/research_projects/wav2vec2/run_common_voice.py
index 10a3a77fa7..01a877a809 100644
--- a/examples/research_projects/wav2vec2/run_common_voice.py
+++ b/examples/research_projects/wav2vec2/run_common_voice.py
@@ -33,7 +33,7 @@ if is_apex_available():
     from apex import amp
 
 
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
     _is_native_amp_available = True
     from torch.cuda.amp import autocast
 
diff --git a/examples/research_projects/wav2vec2/run_pretrain.py b/examples/research_projects/wav2vec2/run_pretrain.py
index fb430d1407..8e0801429e 100755
--- a/examples/research_projects/wav2vec2/run_pretrain.py
+++ b/examples/research_projects/wav2vec2/run_pretrain.py
@@ -26,7 +26,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
 if is_apex_available():
     from apex import amp
 
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
     _is_native_amp_available = True
     from torch.cuda.amp import autocast
 
diff --git a/src/transformers/activations.py b/src/transformers/activations.py
index fad8d10613..5d413bba72 100644
--- a/src/transformers/activations.py
+++ b/src/transformers/activations.py
@@ -44,7 +44,7 @@ class GELUActivation(nn.Module):
 
     def __init__(self, use_gelu_python: bool = False):
         super().__init__()
-        if version.parse(torch.__version__) < version.parse("1.4") or use_gelu_python:
+        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.4") or use_gelu_python:
             self.act = self._gelu_python
         else:
             self.act = nn.functional.gelu
@@ -110,7 +110,7 @@ class SiLUActivation(nn.Module):
 
     def __init__(self):
         super().__init__()
-        if version.parse(torch.__version__) < version.parse("1.7"):
+        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
             self.act = self._silu_python
         else:
             self.act = nn.functional.silu
@@ -130,7 +130,7 @@ class MishActivation(nn.Module):
 
     def __init__(self):
         super().__init__()
-        if version.parse(torch.__version__) < version.parse("1.9"):
+        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.9"):
             self.act = self._mish_python
         else:
             self.act = nn.functional.mish
diff --git a/src/transformers/convert_graph_to_onnx.py b/src/transformers/convert_graph_to_onnx.py
index c757fab8ff..59fb8ed39b 100644
--- a/src/transformers/convert_graph_to_onnx.py
+++ b/src/transformers/convert_graph_to_onnx.py
@@ -273,6 +273,8 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
     import torch
     from torch.onnx import export
 
+    from .pytorch_utils import is_torch_less_than_1_11
+
     print(f"Using framework PyTorch: {torch.__version__}")
 
     with torch.no_grad():
@@ -281,7 +283,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
 
         # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
         # so we check the torch version for backwards compatibility
-        if parse(torch.__version__) <= parse("1.10.99"):
+        if is_torch_less_than_1_11:
             export(
                 nlp.model,
                 model_args,
diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py
index 169f1faeb8..78df7911a2 100755
--- a/src/transformers/models/albert/modeling_albert.py
+++ b/src/transformers/models/albert/modeling_albert.py
@@ -20,7 +20,6 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -35,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -212,7 +216,7 @@ class AlbertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index c1ef87551b..495bbe2e49 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -24,7 +24,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -41,7 +40,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -195,7 +199,7 @@ class BertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py
index 06bc9251d7..fb30671927 100755
--- a/src/transformers/models/big_bird/modeling_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_big_bird.py
@@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -38,7 +37,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, is_torch_greater_than_1_6
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -260,7 +259,7 @@ class BigBirdEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py
index 9884d32aca..136685ad6c 100755
--- a/src/transformers/models/convbert/modeling_convbert.py
+++ b/src/transformers/models/convbert/modeling_convbert.py
@@ -22,7 +22,6 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -36,7 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_convbert import ConvBertConfig
 
@@ -194,7 +198,7 @@ class ConvBertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py
index 9c85d34617..8a7d6308bf 100644
--- a/src/transformers/models/data2vec/modeling_data2vec_text.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_text.py
@@ -19,7 +19,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -35,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -83,7 +87,7 @@ class Data2VecTextForTextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
index 959b9763d0..77804e7554 100755
--- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -21,12 +21,16 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 
 from ...activations import ACT2FN
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...pytorch_utils import (
+    Conv1D,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_1_6,
+    prune_conv1d_layer,
+)
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -36,7 +40,7 @@ from ...utils import (
 )
 
 
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
     is_amp_available = True
     from torch.cuda.amp import autocast
 else:
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index fc5b5a7b0f..1282788a57 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -23,7 +23,6 @@ from typing import Dict, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -40,7 +39,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -102,7 +106,7 @@ class Embeddings(nn.Module):
 
         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
             )
diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py
index 3f488fbcf5..c215256b3e 100644
--- a/src/transformers/models/electra/modeling_electra.py
+++ b/src/transformers/models/electra/modeling_electra.py
@@ -21,7 +21,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -37,7 +36,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -165,7 +169,7 @@ class ElectraEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/flaubert/modeling_flaubert.py b/src/transformers/models/flaubert/modeling_flaubert.py
index 9721880ac9..4733c5d09b 100644
--- a/src/transformers/models/flaubert/modeling_flaubert.py
+++ b/src/transformers/models/flaubert/modeling_flaubert.py
@@ -19,10 +19,10 @@ import random
 from typing import Dict, Optional, Tuple, Union
 
 import torch
-from packaging import version
 from torch import nn
 
 from ...modeling_outputs import BaseModelOutput
+from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from ..xlm.modeling_xlm import (
     XLMForMultipleChoice,
@@ -139,7 +139,7 @@ class FlaubertModel(XLMModel):
         super().__init__(config)
         self.layerdrop = getattr(config, "layerdrop", 0.0)
         self.pre_norm = getattr(config, "pre_norm", False)
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
             )
diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py
index c0841a0e27..9201a98760 100644
--- a/src/transformers/models/flava/modeling_flava.py
+++ b/src/transformers/models/flava/modeling_flava.py
@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 
 from transformers.utils.doc import add_code_sample_docstrings
@@ -30,6 +29,7 @@ from transformers.utils.doc import add_code_sample_docstrings
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -392,7 +392,7 @@ class FlavaTextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py
index 8ed6718231..e2347adce9 100755
--- a/src/transformers/models/fnet/modeling_fnet.py
+++ b/src/transformers/models/fnet/modeling_fnet.py
@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -44,7 +43,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, is_torch_greater_than_1_6
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -118,7 +117,7 @@ class FNetEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
 
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 1c61adb10d..4c6495d353 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -22,12 +22,18 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ...pytorch_utils import (
+    Conv1D,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_1_6,
+    prune_conv1d_layer,
+)
+
 
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
     is_amp_available = True
     from torch.cuda.amp import autocast
 else:
@@ -41,7 +47,6 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py
index c80f16267c..e71ea4a272 100755
--- a/src/transformers/models/imagegpt/modeling_imagegpt.py
+++ b/src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -21,12 +21,18 @@ from typing import Any, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ...pytorch_utils import (
+    Conv1D,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_1_6,
+    prune_conv1d_layer,
+)
+
 
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
     is_amp_available = True
     from torch.cuda.amp import autocast
 else:
@@ -39,7 +45,6 @@ from ...modeling_outputs import (
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_imagegpt import ImageGPTConfig
 
diff --git a/src/transformers/models/mctct/modeling_mctct.py b/src/transformers/models/mctct/modeling_mctct.py
index 25d368b7dc..3eb59a0c41 100755
--- a/src/transformers/models/mctct/modeling_mctct.py
+++ b/src/transformers/models/mctct/modeling_mctct.py
@@ -21,7 +21,6 @@ from typing import Optional
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 
 from ...activations import ACT2FN
@@ -34,6 +33,7 @@ from ...modeling_utils import (
     find_pruneable_heads_and_indices,
     prune_linear_layer,
 )
+from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import logging
 from .configuration_mctct import MCTCTConfig
 
@@ -153,7 +153,7 @@ class MCTCTEmbeddings(nn.Module):
 
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py
index ab37c142bc..4fa38b3ed4 100644
--- a/src/transformers/models/nezha/modeling_nezha.py
+++ b/src/transformers/models/nezha/modeling_nezha.py
@@ -23,7 +23,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -39,7 +38,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -183,7 +187,7 @@ class NezhaEmbeddings(nn.Module):
         # any TensorFlow checkpoint file
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros((1, config.max_position_embeddings), dtype=torch.long),
diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py
index b5813af781..e1f352d2c8 100755
--- a/src/transformers/models/nystromformer/modeling_nystromformer.py
+++ b/src/transformers/models/nystromformer/modeling_nystromformer.py
@@ -20,7 +20,6 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -34,7 +33,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_nystromformer import NystromformerConfig
 
@@ -68,7 +72,7 @@ class NystromformerEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py
index 805da6516f..35890625b1 100755
--- a/src/transformers/models/qdqbert/modeling_qdqbert.py
+++ b/src/transformers/models/qdqbert/modeling_qdqbert.py
@@ -23,7 +23,6 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -40,7 +39,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, is_torch_greater_than_1_6, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -167,7 +166,7 @@ class QDQBertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py
index f63ea07ad9..6ee2b1fd14 100644
--- a/src/transformers/models/realm/modeling_realm.py
+++ b/src/transformers/models/realm/modeling_realm.py
@@ -20,7 +20,6 @@ from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
 import torch
-from packaging import version
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
@@ -32,7 +31,12 @@ from ...modeling_outputs import (
     ModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_realm import RealmConfig
 
@@ -181,7 +185,7 @@ class RealmEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index 0b57b1031e..46add0be50 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -20,7 +20,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -36,7 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -83,7 +87,7 @@ class RobertaEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py
index f20573d0d5..308358850c 100755
--- a/src/transformers/models/vilt/modeling_vilt.py
+++ b/src/transformers/models/vilt/modeling_vilt.py
@@ -21,7 +21,6 @@ from typing import List, Optional, Tuple
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
@@ -35,14 +34,19 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_1_10,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_vilt import ViltConfig
 
 
 logger = logging.get_logger(__name__)
 
-if version.parse(torch.__version__) < version.parse("1.10.0"):
+if not is_torch_greater_or_equal_than_1_10:
     logger.warning(
         f"You are using torch=={torch.__version__}, but torch>=1.10.0 is required to use "
         "ViltModel. Please upgrade torch."
@@ -251,7 +255,7 @@ class TextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
index 70dd422157..aa41466767 100644
--- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
+++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
@@ -19,7 +19,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -35,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -76,7 +80,7 @@ class XLMRobertaXLEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py
index 2977cfe64c..085d46bdfb 100644
--- a/src/transformers/models/yoso/modeling_yoso.py
+++ b/src/transformers/models/yoso/modeling_yoso.py
@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -35,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_yoso import YosoConfig
 
@@ -257,7 +261,7 @@ class YosoEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py
index d4deb79668..a896b76a1c 100644
--- a/src/transformers/onnx/convert.py
+++ b/src/transformers/onnx/convert.py
@@ -34,6 +34,7 @@ from .config import OnnxConfig
 
 if is_torch_available():
     from ..modeling_utils import PreTrainedModel
+    from ..pytorch_utils import is_torch_less_than_1_11
 
 if is_tf_available():
     from ..modeling_tf_utils import TFPreTrainedModel
@@ -155,7 +156,7 @@ def export_pytorch(
 
             # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
             # so we check the torch version for backwards compatibility
-            if parse(torch.__version__) < parse("1.10"):
+            if is_torch_less_than_1_11:
                 # export can work with named args but the dict containing named args
                 # has to be the last element of the args tuple.
                 try:
diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 29a12e7df2..a3e11eb600 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -967,7 +967,9 @@ class Pipeline(_ScikitCompat):
 
     def get_inference_context(self):
         inference_context = (
-            torch.inference_mode if version.parse(torch.__version__) >= version.parse("1.9.0") else torch.no_grad
+            torch.inference_mode
+            if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.9.0")
+            else torch.no_grad
         )
         return inference_context
 
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index c7bfba81fb..571a5d7d3c 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -25,8 +25,12 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
 
 logger = logging.get_logger(__name__)
 
-is_torch_less_than_1_8 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.8.0")
-is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
+parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
+is_torch_greater_or_equal_than_1_6 = parsed_torch_version_base >= version.parse("1.6.0")
+is_torch_greater_than_1_6 = parsed_torch_version_base > version.parse("1.6.0")
+is_torch_less_than_1_8 = parsed_torch_version_base < version.parse("1.8.0")
+is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10")
+is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
 
 
 def torch_int_div(tensor1, tensor2):
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 37a21b0939..90a30aaa9f 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -71,7 +71,12 @@ from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
 from .optimization import Adafactor, get_scheduler
-from .pytorch_utils import ALL_LAYERNORM_LAYERS
+from .pytorch_utils import (
+    ALL_LAYERNORM_LAYERS,
+    is_torch_greater_or_equal_than_1_6,
+    is_torch_greater_or_equal_than_1_10,
+    is_torch_less_than_1_11,
+)
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -165,11 +170,11 @@ if is_in_notebook():
 if is_apex_available():
     from apex import amp
 
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
     _is_torch_generator_available = True
     _is_native_cuda_amp_available = True
 
-if version.parse(torch.__version__) >= version.parse("1.10"):
+if is_torch_greater_or_equal_than_1_10:
     _is_native_cpu_amp_available = True
 
 if is_datasets_available():
@@ -405,7 +410,7 @@ class Trainer:
             # Would have to update setup.py with torch>=1.12.0
             # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
             # below is the current alternative.
-            if version.parse(torch.__version__) < version.parse("1.12.0"):
+            if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
                 raise ValueError("FSDP requires PyTorch >= 1.12.0")
 
             from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
@@ -1676,7 +1681,7 @@ class Trainer:
                 is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                     train_dataloader.sampler, RandomSampler
                 )
-                if version.parse(torch.__version__) < version.parse("1.11") or not is_random_sampler:
+                if is_torch_less_than_1_11 or not is_random_sampler:
                     # We just need to begin an iteration to create the randomization of the sampler.
                     # That was before PyTorch 1.11 however...
                     for _ in train_dataloader:
@@ -2430,7 +2435,7 @@ class Trainer:
         arguments, depending on the situation.
         """
         if self.use_cuda_amp or self.use_cpu_amp:
-            if version.parse(torch.__version__) >= version.parse("1.10"):
+            if is_torch_greater_or_equal_than_1_10:
                 ctx_manager = (
                     torch.cpu.amp.autocast(dtype=self.amp_dtype)
                     if self.use_cpu_amp
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index c264f89d07..e1ad471b07 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -835,7 +835,7 @@ def _get_learning_rate(self):
         last_lr = (
             # backward compatibility for pytorch schedulers
             self.lr_scheduler.get_last_lr()[0]
-            if version.parse(torch.__version__) >= version.parse("1.4")
+            if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.4")
             else self.lr_scheduler.get_lr()[0]
         )
     return last_lr
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 363d337e2b..37172d14fc 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -300,7 +300,7 @@ def is_torch_bf16_gpu_available():
     # 4. torch.autocast exists
     # XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
     # really only correct for the 0th gpu (or currently set default device if different from 0)
-    if version.parse(torch.__version__) < version.parse("1.10"):
+    if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.10"):
         return False
 
     if torch.cuda.is_available() and torch.version.cuda is not None:
@@ -322,7 +322,7 @@ def is_torch_bf16_cpu_available():
 
     import torch
 
-    if version.parse(torch.__version__) < version.parse("1.10"):
+    if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.10"):
         return False
 
     try:
@@ -357,7 +357,7 @@ def is_torch_tf32_available():
         return False
     if int(torch.version.cuda.split(".")[0]) < 11:
         return False
-    if version.parse(torch.__version__) < version.parse("1.7"):
+    if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
         return False
 
     return True
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
index b2ffcbb6c2..cbe8153c0e 100755
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
@@ -22,7 +22,6 @@ import os
 
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from typing import Optional, Tuple, Union
@@ -48,6 +47,7 @@ from ...pytorch_utils import (
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
     prune_linear_layer,
 )
 from ...utils import logging
 from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -157,7 +157,7 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
-- 
GitLab