test: added tests for memory embeder

pull/1320/head
Tymec 2023-04-14 14:56:58 +02:00
parent b042376db4
commit fb6684450c
1 changed files with 33 additions and 0 deletions

33
tests/embeder_test.py Normal file
View File

@ -0,0 +1,33 @@
import os
import sys
# Probably a better way:
sys.path.append(os.path.abspath('../scripts'))
from memory.base import get_embedding
def MockConfig():
return type('MockConfig', (object,), {
'debug_mode': False,
'continuous_mode': False,
'speak_mode': False,
'memory_embeder': 'sbert'
})
class TestMemoryEmbeder(unittest.TestCase):
def setUp(self):
self.cfg = MockConfig()
def test_ada(self):
self.cfg.memory_embeder = "ada"
text = "Sample text"
result = get_embedding(text)
self.assertEqual(result.shape, (1536,))
def test_sbert(self):
self.cfg.memory_embeder = "sbert"
text = "Sample text"
result = get_embedding(text)
self.assertEqual(result.shape, (768,))
if __name__ == '__main__':
unittest.main()