Fix how we compute the final non-padding token for ForSequenceClassification models (#35911)
* Fix how we compute the final non-padding token for Gemma (and probably other models) * .size() -> .shape[] * Propagating changes to other models * Propagating changes to other models * Change it for all ForSequenceClassification models * Fix batch dim * More TF fixes * Copy the TF fix around as well * Correct layer name for TFCTRL * Cleaner .to() * Clean up the nested if-else * Use argmax() instead of .max().values
Showing
+244 -248
Please register or sign in to comment