Commit 47eacad8 authored by Matt

Disable head masking

Showing with 8 additions and 19 deletions
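In short, the hunks below drop `head_mask` support from the TF GPT-Neo port: the argument and its docstring entry are removed from the `call` signatures, the attention layer now raises instead of applying a mask, `@unpack_inputs` is added to each `call`, and the `Dense` layers are built with only their output size. For reference, a small sketch of the mask shapes the deleted docstring used to describe (illustrative sizes, not taken from the config):

    import tensorflow as tf

    # Shapes the removed `head_mask` argument used to accept:
    # `(num_heads,)` or `(num_layers, num_heads)`, with 1 = keep head, 0 = mask head.
    num_layers, num_heads = 24, 16
    per_model_mask = tf.ones((num_heads,))
    per_layer_mask = tf.ones((num_layers, num_heads))
    # After this commit, neither form can be passed to the TF GPT-Neo call() methods.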
@@ -26,7 +26,7 @@ from ...modeling_tf_outputs import (
TFCausalLMOutputWithPast,
TFSequenceClassifierOutputWithPast,
)
-from ...modeling_tf_utils import TFPreTrainedModel, get_tf_activation
+from ...modeling_tf_utils import TFPreTrainedModel, get_tf_activation, unpack_inputs
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_gpt_neo import GPTNeoConfig
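The newly imported `unpack_inputs` is the transformers TF helper that normalizes whatever reaches `call` (positional tensors, keyword arguments, or a single dict of named inputs) into plain keyword arguments before the body runs; it is applied as a decorator to each `call` in the hunks below. A rough sketch of the usage pattern, with a hypothetical `MySubLayer` that is not part of this commit (the wrapped object is expected to expose a `config`, as these layers do):

    import tensorflow as tf
    from transformers import GPTNeoConfig
    from transformers.modeling_tf_utils import unpack_inputs

    class MySubLayer(tf.keras.layers.Layer):
        def __init__(self, config, **kwargs):
            super().__init__(**kwargs)
            self.config = config  # unpack_inputs reads the config off the wrapped object

        @unpack_inputs
        def call(self, input_ids=None, attention_mask=None, training=False):
            # By the time the body runs, the inputs have been unpacked into keywords.
            return input_ids

    layer = MySubLayer(GPTNeoConfig())
    out = layer({"input_ids": tf.constant([[1, 2, 3]])})  # a dict input is unpacked too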
@@ -101,7 +101,7 @@ class TFGPTNeoSelfAttention(tf.keras.layers.Layer):
attn_weights = self.attn_dropout(attn_weights)
# Mask heads if we want to
if head_mask is not None:
-attn_weights = attn_weights * head_mask
+raise ValueError("Head masking is not supported in TF right now!")
attn_output = tf.matmul(attn_weights, value)
return attn_output, attn_weights
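For context on the replaced line: head masking multiplied the post-softmax attention weights by a broadcast 0/1 mask, zeroing out whole heads. A standalone sketch of that now-removed behaviour, with illustrative shapes:

    import tensorflow as tf

    # attn_weights: (batch, num_heads, query_len, key_len); mask head 0, keep the rest.
    attn_weights = tf.random.uniform((1, 12, 8, 8))
    head_mask = tf.constant([0.0] + [1.0] * 11)
    masked_weights = attn_weights * head_mask[None, :, None, None]
    # This commit replaces the multiplication with an explicit ValueError, so the
    # TF port refuses head masks outright rather than silently applying them.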
@@ -181,8 +181,8 @@ class TFGPTNeoMLP(tf.keras.layers.Layer):
def __init__(self, intermediate_size, config, **kwargs): # in MLP: intermediate_size= 4 * hidden_size
super().__init__(**kwargs)
embed_dim = config.hidden_size
-self.c_fc = tf.keras.layers.Dense(embed_dim, intermediate_size, name="c_fc")
-self.c_proj = tf.keras.layers.Dense(intermediate_size, embed_dim, name="c_proj")
+self.c_fc = tf.keras.layers.Dense(intermediate_size, name="c_fc")
+self.c_proj = tf.keras.layers.Dense(embed_dim, name="c_proj")
self.act = get_tf_activation(config.activation_function)
self.dropout = tf.keras.layers.Dropout(float(config.resid_dropout))
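The `Dense` fix above reflects a Keras/PyTorch difference: `torch.nn.Linear(in_features, out_features)` takes both dimensions, whereas `tf.keras.layers.Dense(units, ...)` takes only the output size (a second positional argument would be interpreted as `activation`) and infers the input size when the layer is first built. The same applies to the `score` head in the sequence-classification hunk further down. A small sketch with illustrative sizes:

    import tensorflow as tf

    hidden_size, intermediate_size = 768, 3072
    c_fc = tf.keras.layers.Dense(intermediate_size, name="c_fc")
    c_proj = tf.keras.layers.Dense(hidden_size, name="c_proj")

    x = tf.random.uniform((2, 10, hidden_size))  # (batch, seq, hidden)
    h = c_proj(c_fc(x))                          # input dims are inferred on first call
    print(h.shape)                               # (2, 10, 768)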
@@ -296,11 +296,6 @@ GPT_NEO_INPUTS_DOCSTRING = r"""
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
-head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
-    Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-    - 1 indicates the head is **not masked**,
-    - 0 indicates the head is **masked**.
inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
@@ -352,8 +347,6 @@ class TFGPTNeoModel(TFGPTNeoPreTrainedModel, tf.keras.layers.Layer):
self.h = [TFGPTNeoBlock(config, layer_id=i, name=f"h_{i}") for i in range(config.num_layers)]
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.wte
@@ -367,6 +360,7 @@ class TFGPTNeoModel(TFGPTNeoPreTrainedModel, tf.keras.layers.Layer):
output_type=TFBaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
+@unpack_inputs
def call(
self,
input_ids: Optional[tf.Tensor] = None,
@@ -374,7 +368,6 @@ class TFGPTNeoModel(TFGPTNeoPreTrainedModel, tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[tf.Tensor] = None,
position_ids: Optional[tf.Tensor] = None,
-head_mask: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -417,7 +410,6 @@ class TFGPTNeoModel(TFGPTNeoPreTrainedModel, tf.keras.layers.Layer):
attention_mask = attention_mask[:, None, None, :]
attention_mask = tf.cast(attention_mask, self.dtype)
attention_mask = (1.0 - attention_mask) * tf.reduce_min(self.dtype)
-head_mask = self.get_head_mask(head_mask, self.config.num_layers)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
@@ -437,7 +429,6 @@ class TFGPTNeoModel(TFGPTNeoPreTrainedModel, tf.keras.layers.Layer):
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
-head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
)
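The two deletions above remove the per-layer handling of the head mask: `get_head_mask` expanded a `(num_heads,)` or `(num_layers, num_heads)` tensor into one broadcastable mask per layer, and the loop passed it to each block as `head_mask[i]`. A rough sketch of that expansion under those assumptions (hypothetical `expand_head_mask`, not the library implementation):

    import tensorflow as tf

    def expand_head_mask(head_mask, num_layers):
        # None -> one None per layer, so `head_mask[i]` stays a no-op everywhere.
        if head_mask is None:
            return [None] * num_layers
        # (num_heads,) -> (num_layers, num_heads): the same mask reused for every layer.
        if head_mask.shape.rank == 1:
            head_mask = tf.tile(head_mask[None, :], (num_layers, 1))
        # One (1, num_heads, 1, 1) mask per layer, broadcastable against attention weights.
        return [m[None, :, None, None] for m in tf.unstack(head_mask, axis=0)]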
@@ -518,6 +509,7 @@ class TFGPTNeoForCausalLM(TFGPTNeoPreTrainedModel, tf.keras.layers.Layer):
output_type=TFCausalLMOutputWithCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
+@unpack_inputs
def call(
self,
input_ids: Optional[tf.Tensor] = None,
@@ -525,7 +517,6 @@ class TFGPTNeoForCausalLM(TFGPTNeoPreTrainedModel, tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[tf.Tensor] = None,
position_ids: Optional[tf.Tensor] = None,
-head_mask: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
labels: Optional[tf.Tensor] = None,
use_cache: Optional[bool] = None,
@@ -546,7 +537,6 @@ class TFGPTNeoForCausalLM(TFGPTNeoPreTrainedModel, tf.keras.layers.Layer):
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
-head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
@@ -614,7 +604,7 @@ class TFGPTNeoForSequenceClassification(TFGPTNeoPreTrainedModel, tf.keras.layers.Layer):
super().__init__(config, *args, **kwargs)
self.num_labels = config.num_labels
self.transformer = TFGPTNeoModel(config, name="transformer")
-self.score = tf.keras.layers.Dense(config.hidden_size, self.num_labels, use_bias=False, name="score")
+self.score = tf.keras.layers.Dense(self.num_labels, use_bias=False, name="score")
# Initialize weights and apply final processing
self.post_init()
@@ -624,6 +614,7 @@ class TFGPTNeoForSequenceClassification(TFGPTNeoPreTrainedModel, tf.keras.layers.Layer):
output_type=TFSequenceClassifierOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
+@unpack_inputs
def call(
self,
input_ids: Optional[tf.Tensor] = None,
@@ -631,7 +622,6 @@ class TFGPTNeoForSequenceClassification(TFGPTNeoPreTrainedModel, tf.keras.layers.Layer):
attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[tf.Tensor] = None,
position_ids: Optional[tf.Tensor] = None,
-head_mask: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
labels: Optional[tf.Tensor] = None,
use_cache: Optional[bool] = None,
@@ -653,7 +643,6 @@ class TFGPTNeoForSequenceClassification(TFGPTNeoPreTrainedModel, tf.keras.layers.Layer):
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
-head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,