# 81 lines
# 2.8 KiB
# Python
import unittest
from unittest.mock import MagicMock, patch

from src.service.ollama.ollamaModelManager import OllamaModelManager


class TestOllamaModelManager(unittest.TestCase):
    """Unit tests for ``OllamaModelManager`` with its HTTP calls mocked out."""

    # Module under test, spelled exactly as this file imports it
    # (``from src.service.ollama.ollamaModelManager import ...``).  Patch
    # targets must use the same path: the former ``service.ollama...`` path
    # only resolved when ``src/`` was also on ``sys.path``, and in that case
    # it imported a second, separate copy of the module.
    _MODULE = 'src.service.ollama.ollamaModelManager'

    def setUp(self):
        """Create a fresh manager pointed at a fake base URL for each test."""
        self.model_manager = OllamaModelManager()
        self.model_manager.base_url = "http://test-url:11434"

    @patch(_MODULE + '.requests.get')
    def test_get_available_models_success(self, mock_get):
        """get_available_models GETs /api/tags and returns its 'models' list."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.json.return_value = {
            'models': [
                {'name': 'model1'},
                {'name': 'model2'},
            ]
        }
        mock_get.return_value = mock_response

        # Call method
        result = self.model_manager.get_available_models()

        # Assertions
        mock_get.assert_called_once_with("http://test-url:11434/api/tags")
        self.assertEqual(len(result), 2)
        self.assertEqual(result, [{'name': 'model1'}, {'name': 'model2'}])

    @patch(_MODULE + '.requests.get')
    def test_get_available_models_exception(self, mock_get):
        """get_available_models returns an empty list when the request raises."""
        # Setup mock to raise exception
        mock_get.side_effect = Exception("Connection error")

        # Call method
        result = self.model_manager.get_available_models()

        # Assertions
        self.assertEqual(result, [])

    @patch(_MODULE + '.requests.post')
    def test_get_model_details_success(self, mock_post):
        """get_model_details POSTs the name to /api/show and maps the reply."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.json.return_value = {
            'parameters': {'context_length': 4096},
            'modelfile': {'parameter': 'llama2'},
            'license': 'Apache 2.0',
        }
        mock_post.return_value = mock_response

        # Call method
        result = self.model_manager.get_model_details('llama2')

        # Assertions
        mock_post.assert_called_once_with(
            "http://test-url:11434/api/show",
            json={"name": "llama2"},
        )
        self.assertEqual(result['name'], 'llama2')
        # 'context_size' is expected to come from parameters.context_length.
        self.assertEqual(result['context_size'], 4096)
        self.assertEqual(result['license'], 'Apache 2.0')

    @patch(_MODULE + '.requests.post')
    def test_get_model_details_exception(self, mock_post):
        """On a request error, get_model_details reports the name and message."""
        # Setup mock to raise exception
        mock_post.side_effect = Exception("API error")

        # Call method
        result = self.model_manager.get_model_details('unknown_model')

        # Assertions
        self.assertEqual(result['name'], 'unknown_model')
        # assertIn reports the actual keys on failure, unlike assertTrue(x in y).
        self.assertIn('error', result)
        self.assertEqual(result['error'], 'API error')


if __name__ == "__main__":
    # Running this file directly hands control to the unittest CLI runner.
    unittest.main()