2023-04-15 12:50:54 +00:00
|
|
|
import unittest
|
2023-04-17 20:41:42 +00:00
|
|
|
|
2023-04-15 12:50:54 +00:00
|
|
|
import tests.context
|
2023-04-15 13:20:19 +00:00
|
|
|
from autogpt.token_counter import count_message_tokens, count_string_tokens
|
2023-04-15 12:50:54 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestTokenCounter(unittest.TestCase):
|
|
|
|
def test_count_message_tokens(self):
|
|
|
|
messages = [
|
|
|
|
{"role": "user", "content": "Hello"},
|
2023-04-15 19:55:13 +00:00
|
|
|
{"role": "assistant", "content": "Hi there!"},
|
2023-04-15 12:50:54 +00:00
|
|
|
]
|
|
|
|
self.assertEqual(count_message_tokens(messages), 17)
|
|
|
|
|
|
|
|
def test_count_message_tokens_with_name(self):
|
|
|
|
messages = [
|
|
|
|
{"role": "user", "content": "Hello", "name": "John"},
|
2023-04-15 19:55:13 +00:00
|
|
|
{"role": "assistant", "content": "Hi there!"},
|
2023-04-15 12:50:54 +00:00
|
|
|
]
|
|
|
|
self.assertEqual(count_message_tokens(messages), 17)
|
|
|
|
|
|
|
|
def test_count_message_tokens_empty_input(self):
|
|
|
|
self.assertEqual(count_message_tokens([]), 3)
|
|
|
|
|
|
|
|
def test_count_message_tokens_invalid_model(self):
|
|
|
|
messages = [
|
|
|
|
{"role": "user", "content": "Hello"},
|
2023-04-15 19:55:13 +00:00
|
|
|
{"role": "assistant", "content": "Hi there!"},
|
2023-04-15 12:50:54 +00:00
|
|
|
]
|
|
|
|
with self.assertRaises(KeyError):
|
|
|
|
count_message_tokens(messages, model="invalid_model")
|
|
|
|
|
|
|
|
def test_count_message_tokens_gpt_4(self):
|
|
|
|
messages = [
|
|
|
|
{"role": "user", "content": "Hello"},
|
2023-04-15 19:55:13 +00:00
|
|
|
{"role": "assistant", "content": "Hi there!"},
|
2023-04-15 12:50:54 +00:00
|
|
|
]
|
|
|
|
self.assertEqual(count_message_tokens(messages, model="gpt-4-0314"), 15)
|
|
|
|
|
|
|
|
def test_count_string_tokens(self):
|
|
|
|
string = "Hello, world!"
|
2023-04-15 19:55:13 +00:00
|
|
|
self.assertEqual(
|
|
|
|
count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4
|
|
|
|
)
|
2023-04-15 12:50:54 +00:00
|
|
|
|
|
|
|
def test_count_string_tokens_empty_input(self):
|
|
|
|
self.assertEqual(count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0)
|
|
|
|
|
|
|
|
def test_count_message_tokens_invalid_model(self):
|
|
|
|
messages = [
|
|
|
|
{"role": "user", "content": "Hello"},
|
2023-04-15 19:55:13 +00:00
|
|
|
{"role": "assistant", "content": "Hi there!"},
|
2023-04-15 12:50:54 +00:00
|
|
|
]
|
|
|
|
with self.assertRaises(NotImplementedError):
|
|
|
|
count_message_tokens(messages, model="invalid_model")
|
|
|
|
|
|
|
|
def test_count_string_tokens_gpt_4(self):
|
|
|
|
string = "Hello, world!"
|
|
|
|
self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4)
|
|
|
|
|
|
|
|
|
2023-04-15 19:55:13 +00:00
|
|
|
if __name__ == "__main__":
    # Allow this module to be executed directly as a script, in addition to
    # being collected by a test runner's discovery.
    unittest.main()
|