|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
|
|
from transformers import is_tf_available |
|
|
from transformers.testing_utils import require_tf |
|
|
|
|
|
|
|
|
if is_tf_available(): |
|
|
from transformers.activations_tf import get_tf_activation |
|
|
|
|
|
|
|
|
@require_tf
class TestTFActivations(unittest.TestCase):
    """Tests for ``get_tf_activation`` name resolution."""

    def test_get_activation(self):
        """Known activation names resolve; unknown or missing names raise KeyError."""
        supported_names = (
            "swish",
            "silu",
            "gelu",
            "relu",
            "tanh",
            "gelu_new",
            "gelu_fast",
            "mish",
        )
        for name in supported_names:
            # Only resolution is checked: the lookup must not raise.
            get_tf_activation(name)

        # An unrecognised string and a missing name (None) are both rejected.
        for invalid_name in ("bogus", None):
            with self.assertRaises(KeyError):
                get_tf_activation(invalid_name)
|
|
|