"""Unit tests for ``OllamaModelManager``.

HTTP traffic is stubbed by patching ``requests.get`` / ``requests.post`` as
they are looked up inside ``src.service.ollama.ollamaModelManager``, so no
Ollama server is needed to run these tests.
"""

import unittest
from unittest.mock import patch, MagicMock

from src.service.ollama.ollamaModelManager import OllamaModelManager


class TestOllamaModelManager(unittest.TestCase):
    """Tests for model listing and model-detail lookup."""

    # Base URL injected into the manager under test. Expected request URLs are
    # derived from it so the configured value and the assertions cannot drift.
    BASE_URL = "http://test-url:11434"

    def setUp(self):
        self.model_manager = OllamaModelManager()
        self.model_manager.base_url = self.BASE_URL

    @patch('src.service.ollama.ollamaModelManager.requests.get')
    def test_get_available_models_success(self, mock_get):
        """GET /api/tags is called once and its 'models' list is returned."""
        # Stub the response object; only .json() is consumed here.
        mock_response = MagicMock()
        mock_response.json.return_value = {
            'models': [
                {'name': 'model1'},
                {'name': 'model2'},
            ]
        }
        mock_get.return_value = mock_response

        result = self.model_manager.get_available_models()

        mock_get.assert_called_once_with(f"{self.BASE_URL}/api/tags")
        # Full equality also pins the length, so no separate len() check.
        self.assertEqual(result, [{'name': 'model1'}, {'name': 'model2'}])

    @patch('src.service.ollama.ollamaModelManager.requests.get')
    def test_get_available_models_exception(self, mock_get):
        """A request that raises yields an empty list rather than propagating."""
        mock_get.side_effect = Exception("Connection error")

        result = self.model_manager.get_available_models()

        self.assertEqual(result, [])

    @patch('src.service.ollama.ollamaModelManager.requests.post')
    def test_get_model_details_success(self, mock_post):
        """POST /api/show is called with the model name; the reply is mapped
        to the 'name', 'context_size' and 'license' keys."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            'parameters': {'context_length': 4096},
            'modelfile': {'parameter': 'llama2'},
            'license': 'Apache 2.0',
        }
        mock_post.return_value = mock_response

        result = self.model_manager.get_model_details('llama2')

        mock_post.assert_called_once_with(
            f"{self.BASE_URL}/api/show",
            json={"name": "llama2"},
        )
        self.assertEqual(result['name'], 'llama2')
        # 4096 appears only under parameters.context_length in the stub.
        self.assertEqual(result['context_size'], 4096)
        self.assertEqual(result['license'], 'Apache 2.0')

    @patch('src.service.ollama.ollamaModelManager.requests.post')
    def test_get_model_details_exception(self, mock_post):
        """A request that raises yields a dict carrying the name and error text."""
        mock_post.side_effect = Exception("API error")

        result = self.model_manager.get_model_details('unknown_model')

        self.assertEqual(result['name'], 'unknown_model')
        # assertIn reports the actual container on failure; assertTrue(x in y)
        # only reports "False is not true".
        self.assertIn('error', result)
        self.assertEqual(result['error'], 'API error')


if __name__ == '__main__':
    unittest.main()