Unverified Commit 694aaa7f authored by Matt, committed by GitHub

Fix how we compute the final non-padding token for ForSequenceClassification models (#35911)

* Fix how we compute the final non-padding token for Gemma (and probably other models)

* .size() -> .shape[]

* Propagating changes to other models

* Propagating changes to other models

* Change it for all ForSequenceClassification models

* Fix batch dim

* More TF fixes

* Copy the TF fix around as well

* Correct layer name for TFCTRL

* Cleaner .to()

* Clean up the nested if-else

* Use argmax() instead of .max().values
parent 531d1511
Showing with 244 additions and 248 deletions
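
Taken together, the bullets above describe replacing the old final-token lookup with an argmax over masked token indices. Below is a minimal PyTorch sketch of that pooling step, assuming the usual transformers shapes for `logits` and `input_ids`; the helper name and signature are illustrative, not the verbatim diff from this commit.

```python
import torch

def pool_last_non_pad_token(logits: torch.Tensor,
                            input_ids: torch.Tensor,
                            pad_token_id: int) -> torch.Tensor:
    """Illustrative sketch: select logits at each sequence's last non-padding token.

    Assumed shapes: logits (batch, seq_len, num_labels), input_ids (batch, seq_len).
    """
    # 1 where the token is real, 0 where it is padding.
    non_pad_mask = (input_ids != pad_token_id).to(logits.device, torch.int32)
    # Position indices 0..seq_len-1; multiplying by the mask zeroes out padding
    # positions, so argmax picks the highest-indexed non-padding token per row.
    token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
    last_non_pad_token = (token_indices * non_pad_mask).argmax(dim=-1)  # (batch,)
    batch_size = logits.shape[0]
    return logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
```

Because the masked vector stores the position indices themselves, `argmax()` and `.max().values` yield the same number here; per the last bullet, `argmax()` simply states the intent (find the position) more directly, and the approach works even for sequences with no padding, where the mask leaves every index intact.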