046-模拟测试unittest.mock

概述

unittest.mock是Python标准库中的模拟测试模块,提供了强大的Mock对象功能,用于在测试中替换和模拟真实对象的行为。Mock对象允许我们隔离被测试的代码,控制外部依赖的行为,从而编写更可靠和快速的单元测试。

主要特性

  • Mock对象:创建可配置的模拟对象
  • MagicMock:支持魔术方法的Mock对象
  • patch装饰器:临时替换对象和属性
  • spec参数:基于真实对象创建Mock
  • side_effect:模拟异常和复杂行为
  • assert方法:验证Mock对象的调用情况

基本Mock对象

Mock类基础

import unittest
from unittest.mock import Mock, MagicMock, patch, call
import requests
import json
from datetime import datetime

# 示例:简单的服务类
class EmailService:
    """邮件服务类"""
    
    def __init__(self, smtp_server, port=587):
        self.smtp_server = smtp_server
        self.port = port
        self.connected = False
    
    def connect(self):
        """连接到SMTP服务器"""
        # 模拟连接逻辑
        if self.smtp_server and self.port:
            self.connected = True
            return True
        return False
    
    def send_email(self, to_email, subject, body):
        """发送邮件"""
        if not self.connected:
            raise ConnectionError("未连接到SMTP服务器")
        
        # 模拟发送邮件
        email_data = {
            'to': to_email,
            'subject': subject,
            'body': body,
            'timestamp': datetime.now().isoformat()
        }
        
        # 这里通常会有实际的SMTP发送逻辑
        return email_data
    
    def disconnect(self):
        """断开连接"""
        self.connected = False

class UserNotificationService:
    """用户通知服务"""
    
    def __init__(self, email_service):
        self.email_service = email_service
    
    def send_welcome_email(self, user_email, username):
        """发送欢迎邮件"""
        subject = f"欢迎, {username}!"
        body = f"亲爱的 {username},\n\n欢迎加入我们的平台!"
        
        try:
            self.email_service.connect()
            result = self.email_service.send_email(user_email, subject, body)
            self.email_service.disconnect()
            return result
        except Exception as e:
            return {'error': str(e)}
    
    def send_password_reset_email(self, user_email, reset_token):
        """发送密码重置邮件"""
        subject = "密码重置请求"
        body = f"您的密码重置令牌是: {reset_token}"
        
        try:
            self.email_service.connect()
            result = self.email_service.send_email(user_email, subject, body)
            self.email_service.disconnect()
            return result
        except Exception as e:
            return {'error': str(e)}

class TestBasicMock(unittest.TestCase):
    """基本Mock测试"""
    
    def test_mock_creation(self):
        """测试Mock对象创建"""
        # 创建基本Mock对象
        mock_obj = Mock()
        
        # Mock对象可以调用任何方法
        result = mock_obj.some_method()
        self.assertIsInstance(result, Mock)
        
        # 可以设置返回值
        mock_obj.some_method.return_value = "mocked result"
        self.assertEqual(mock_obj.some_method(), "mocked result")
    
    def test_mock_attributes(self):
        """测试Mock对象属性"""
        mock_obj = Mock()
        
        # 设置属性
        mock_obj.name = "Test Mock"
        mock_obj.value = 42
        
        self.assertEqual(mock_obj.name, "Test Mock")
        self.assertEqual(mock_obj.value, 42)
        
        # 动态属性也会返回Mock对象
        self.assertIsInstance(mock_obj.dynamic_attr, Mock)
    
    def test_mock_method_calls(self):
        """测试Mock方法调用"""
        mock_obj = Mock()
        
        # 调用方法
        mock_obj.method1("arg1", "arg2")
        mock_obj.method2(key="value")
        
        # 验证调用
        mock_obj.method1.assert_called_with("arg1", "arg2")
        mock_obj.method2.assert_called_with(key="value")
        
        # 验证调用次数
        self.assertEqual(mock_obj.method1.call_count, 1)
        self.assertEqual(mock_obj.method2.call_count, 1)
    
    def test_mock_side_effect(self):
        """测试Mock副作用"""
        mock_obj = Mock()
        
        # 设置副作用为异常
        mock_obj.failing_method.side_effect = ValueError("模拟错误")
        
        with self.assertRaises(ValueError):
            mock_obj.failing_method()
        
        # 设置副作用为函数
        def custom_side_effect(x):
            return x * 2
        
        mock_obj.doubler.side_effect = custom_side_effect
        self.assertEqual(mock_obj.doubler(5), 10)
        
        # 设置副作用为值列表
        mock_obj.sequence.side_effect = [1, 2, 3]
        self.assertEqual(mock_obj.sequence(), 1)
        self.assertEqual(mock_obj.sequence(), 2)
        self.assertEqual(mock_obj.sequence(), 3)

class TestEmailServiceMock(unittest.TestCase):
    """邮件服务Mock测试"""
    
    def test_user_notification_with_mock(self):
        """使用Mock测试用户通知服务"""
        # 创建Mock邮件服务
        mock_email_service = Mock(spec=EmailService)
        
        # 配置Mock行为
        mock_email_service.connect.return_value = True
        mock_email_service.send_email.return_value = {
            'to': 'user@example.com',
            'subject': '欢迎, John!',
            'body': '亲爱的 John,\n\n欢迎加入我们的平台!',
            'timestamp': '2023-01-01T12:00:00'
        }
        
        # 创建通知服务
        notification_service = UserNotificationService(mock_email_service)
        
        # 测试发送欢迎邮件
        result = notification_service.send_welcome_email('user@example.com', 'John')
        
        # 验证结果
        self.assertIn('to', result)
        self.assertEqual(result['to'], 'user@example.com')
        
        # 验证Mock调用
        mock_email_service.connect.assert_called_once()
        mock_email_service.send_email.assert_called_once_with(
            'user@example.com',
            '欢迎, John!',
            '亲爱的 John,\n\n欢迎加入我们的平台!'
        )
        mock_email_service.disconnect.assert_called_once()
    
    def test_email_service_connection_failure(self):
        """测试邮件服务连接失败"""
        mock_email_service = Mock(spec=EmailService)
        
        # 模拟连接失败
        mock_email_service.connect.side_effect = ConnectionError("连接失败")
        
        notification_service = UserNotificationService(mock_email_service)
        
        # 测试连接失败情况
        result = notification_service.send_welcome_email('user@example.com', 'John')
        
        # 验证错误处理
        self.assertIn('error', result)
        self.assertIn('连接失败', result['error'])
    
    def test_email_service_send_failure(self):
        """测试邮件发送失败"""
        mock_email_service = Mock(spec=EmailService)
        
        # 配置连接成功但发送失败
        mock_email_service.connect.return_value = True
        mock_email_service.send_email.side_effect = Exception("发送失败")
        
        notification_service = UserNotificationService(mock_email_service)
        
        result = notification_service.send_password_reset_email('user@example.com', 'reset123')
        
        # 验证错误处理
        self.assertIn('error', result)
        self.assertIn('发送失败', result['error'])

MagicMock和魔术方法

MagicMock高级功能

import unittest
from unittest.mock import MagicMock, Mock
import operator

class DataContainer:
    """数据容器类"""
    
    def __init__(self, data=None):
        self.data = data or []
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]
    
    def __setitem__(self, index, value):
        self.data[index] = value
    
    def __contains__(self, item):
        return item in self.data
    
    def __iter__(self):
        return iter(self.data)
    
    def __str__(self):
        return f"DataContainer({self.data})"
    
    def __repr__(self):
        return f"DataContainer(data={self.data!r})"

class Calculator:
    """计算器类"""
    
    def __init__(self, data_source):
        self.data_source = data_source
    
    def sum_all(self):
        """计算所有数据的和"""
        return sum(self.data_source)
    
    def get_item_at(self, index):
        """获取指定索引的项目"""
        return self.data_source[index]
    
    def count_items(self):
        """计算项目数量"""
        return len(self.data_source)
    
    def contains_value(self, value):
        """检查是否包含指定值"""
        return value in self.data_source

class TestMagicMock(unittest.TestCase):
    """MagicMock测试"""
    
    def test_magic_mock_basic(self):
        """测试MagicMock基本功能"""
        # MagicMock支持魔术方法
        magic_mock = MagicMock()
        
        # 设置魔术方法的返回值
        magic_mock.__len__.return_value = 5
        magic_mock.__getitem__.return_value = "mocked_item"
        magic_mock.__contains__.return_value = True
        
        # 测试魔术方法
        self.assertEqual(len(magic_mock), 5)
        self.assertEqual(magic_mock[0], "mocked_item")
        self.assertTrue("anything" in magic_mock)
    
    def test_magic_mock_with_calculator(self):
        """使用MagicMock测试计算器"""
        # 创建MagicMock数据源
        mock_data_source = MagicMock()
        
        # 配置魔术方法
        mock_data_source.__iter__.return_value = iter([1, 2, 3, 4, 5])
        mock_data_source.__len__.return_value = 5
        mock_data_source.__getitem__.return_value = 10
        mock_data_source.__contains__.return_value = True
        
        # 创建计算器
        calculator = Calculator(mock_data_source)
        
        # 测试各种操作
        self.assertEqual(calculator.sum_all(), 15)  # sum([1,2,3,4,5])
        self.assertEqual(calculator.count_items(), 5)
        self.assertEqual(calculator.get_item_at(0), 10)
        self.assertTrue(calculator.contains_value(3))
        
        # 验证魔术方法调用
        mock_data_source.__iter__.assert_called_once()
        mock_data_source.__len__.assert_called_once()
        mock_data_source.__getitem__.assert_called_once_with(0)
        mock_data_source.__contains__.assert_called_once_with(3)
    
    def test_magic_mock_side_effects(self):
        """测试MagicMock副作用"""
        mock_container = MagicMock()
        
        # 设置__getitem__的副作用
        def getitem_side_effect(index):
            if index < 0 or index >= 3:
                raise IndexError("索引超出范围")
            return f"item_{index}"
        
        mock_container.__getitem__.side_effect = getitem_side_effect
        mock_container.__len__.return_value = 3
        
        calculator = Calculator(mock_container)
        
        # 测试正常索引
        self.assertEqual(calculator.get_item_at(0), "item_0")
        self.assertEqual(calculator.get_item_at(2), "item_2")
        
        # 测试异常情况
        with self.assertRaises(IndexError):
            calculator.get_item_at(5)
    
    def test_magic_mock_call_tracking(self):
        """测试MagicMock调用跟踪"""
        mock_obj = MagicMock()
        
        # 进行各种调用
        len(mock_obj)
        mock_obj[0]
        mock_obj[1] = "value"
        "test" in mock_obj
        str(mock_obj)
        
        # 验证调用记录
        mock_obj.__len__.assert_called_once()
        mock_obj.__getitem__.assert_called_once_with(0)
        mock_obj.__setitem__.assert_called_once_with(1, "value")
        mock_obj.__contains__.assert_called_once_with("test")
        mock_obj.__str__.assert_called_once()
    
    def test_magic_mock_comparison_operations(self):
        """测试MagicMock比较操作"""
        mock_obj = MagicMock()
        
        # 配置比较操作
        mock_obj.__eq__.return_value = True
        mock_obj.__lt__.return_value = False
        mock_obj.__gt__.return_value = True
        
        # 测试比较
        self.assertTrue(mock_obj == "anything")
        self.assertFalse(mock_obj < 10)
        self.assertTrue(mock_obj > 5)
        
        # 验证调用
        mock_obj.__eq__.assert_called_with("anything")
        mock_obj.__lt__.assert_called_with(10)
        mock_obj.__gt__.assert_called_with(5)
    
    def test_magic_mock_arithmetic_operations(self):
        """测试MagicMock算术操作"""
        mock_num = MagicMock()
        
        # 配置算术操作
        mock_num.__add__.return_value = 15
        mock_num.__sub__.return_value = 5
        mock_num.__mul__.return_value = 50
        mock_num.__truediv__.return_value = 2.5
        
        # 测试算术操作
        self.assertEqual(mock_num + 5, 15)
        self.assertEqual(mock_num - 5, 5)
        self.assertEqual(mock_num * 5, 50)
        self.assertEqual(mock_num / 4, 2.5)
        
        # 验证调用
        mock_num.__add__.assert_called_with(5)
        mock_num.__sub__.assert_called_with(5)
        mock_num.__mul__.assert_called_with(5)
        mock_num.__truediv__.assert_called_with(4)

patch装饰器和上下文管理器

patch装饰器详解

import unittest
from unittest.mock import patch, Mock, MagicMock
import requests
import os
import json
from datetime import datetime

# 示例:API客户端
class APIClient:
    """API客户端"""
    
    def __init__(self, base_url, api_key):
        self.base_url = base_url
        self.api_key = api_key
    
    def get_user(self, user_id):
        """获取用户信息"""
        url = f"{self.base_url}/users/{user_id}"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        
        response = requests.get(url, headers=headers)
        response.raise_for_status()
        return response.json()
    
    def create_user(self, user_data):
        """创建用户"""
        url = f"{self.base_url}/users"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        
        response = requests.post(url, json=user_data, headers=headers)
        response.raise_for_status()
        return response.json()
    
    def upload_file(self, file_path, user_id):
        """上传文件"""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件不存在: {file_path}")
        
        url = f"{self.base_url}/users/{user_id}/files"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        
        with open(file_path, 'rb') as f:
            files = {'file': f}
            response = requests.post(url, files=files, headers=headers)
        
        response.raise_for_status()
        return response.json()

class FileManager:
    """文件管理器"""
    
    def __init__(self, base_path):
        self.base_path = base_path
    
    def create_file(self, filename, content):
        """创建文件"""
        file_path = os.path.join(self.base_path, filename)
        
        # 确保目录存在
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        
        return file_path
    
    def read_file(self, filename):
        """读取文件"""
        file_path = os.path.join(self.base_path, filename)
        
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件不存在: {file_path}")
        
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    
    def delete_file(self, filename):
        """删除文件"""
        file_path = os.path.join(self.base_path, filename)
        
        if os.path.exists(file_path):
            os.remove(file_path)
            return True
        return False
    
    def list_files(self):
        """列出文件"""
        if not os.path.exists(self.base_path):
            return []
        
        return [f for f in os.listdir(self.base_path) 
                if os.path.isfile(os.path.join(self.base_path, f))]

class TestPatchDecorator(unittest.TestCase):
    """patch装饰器测试"""
    
    @patch('requests.get')
    def test_api_client_get_user(self, mock_get):
        """测试API客户端获取用户"""
        # 配置mock响应
        mock_response = Mock()
        mock_response.json.return_value = {
            'id': 1,
            'name': 'John Doe',
            'email': 'john@example.com'
        }
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response
        
        # 创建API客户端并测试
        client = APIClient("https://api.example.com", "test_key")
        result = client.get_user(1)
        
        # 验证结果
        self.assertEqual(result['id'], 1)
        self.assertEqual(result['name'], 'John Doe')
        
        # 验证请求调用
        mock_get.assert_called_once_with(
            "https://api.example.com/users/1",
            headers={"Authorization": "Bearer test_key"}
        )
    
    @patch('requests.post')
    def test_api_client_create_user(self, mock_post):
        """测试API客户端创建用户"""
        # 配置mock响应
        mock_response = Mock()
        mock_response.json.return_value = {
            'id': 2,
            'name': 'Jane Doe',
            'email': 'jane@example.com',
            'created_at': '2023-01-01T12:00:00Z'
        }
        mock_response.raise_for_status.return_value = None
        mock_post.return_value = mock_response
        
        # 测试创建用户
        client = APIClient("https://api.example.com", "test_key")
        user_data = {
            'name': 'Jane Doe',
            'email': 'jane@example.com'
        }
        result = client.create_user(user_data)
        
        # 验证结果
        self.assertEqual(result['id'], 2)
        self.assertIn('created_at', result)
        
        # 验证请求调用
        mock_post.assert_called_once_with(
            "https://api.example.com/users",
            json=user_data,
            headers={
                "Authorization": "Bearer test_key",
                "Content-Type": "application/json"
            }
        )
    
    @patch('requests.get')
    def test_api_client_error_handling(self, mock_get):
        """测试API客户端错误处理"""
        # 配置mock抛出异常
        mock_get.side_effect = requests.exceptions.HTTPError("404 Not Found")
        
        client = APIClient("https://api.example.com", "test_key")
        
        # 测试异常处理
        with self.assertRaises(requests.exceptions.HTTPError):
            client.get_user(999)
    
    @patch('os.path.exists')
    @patch('builtins.open', create=True)
    @patch('requests.post')
    def test_api_client_upload_file(self, mock_post, mock_open, mock_exists):
        """测试API客户端文件上传"""
        # 配置mocks
        mock_exists.return_value = True
        mock_file = Mock()
        mock_open.return_value.__enter__.return_value = mock_file
        
        mock_response = Mock()
        mock_response.json.return_value = {
            'file_id': 'file123',
            'filename': 'test.txt',
            'size': 1024
        }
        mock_response.raise_for_status.return_value = None
        mock_post.return_value = mock_response
        
        # 测试文件上传
        client = APIClient("https://api.example.com", "test_key")
        result = client.upload_file("/path/to/test.txt", 1)
        
        # 验证结果
        self.assertEqual(result['file_id'], 'file123')
        
        # 验证调用
        mock_exists.assert_called_once_with("/path/to/test.txt")
        mock_open.assert_called_once_with("/path/to/test.txt", 'rb')
        mock_post.assert_called_once()

class TestPatchContextManager(unittest.TestCase):
    """patch上下文管理器测试"""
    
    def test_file_manager_with_patch_context(self):
        """使用patch上下文管理器测试文件管理器"""
        file_manager = FileManager("/test/path")
        
        # 使用patch作为上下文管理器
        with patch('os.makedirs') as mock_makedirs, \
             patch('builtins.open', create=True) as mock_open, \
             patch('os.path.dirname') as mock_dirname:
            
            # 配置mocks
            mock_dirname.return_value = "/test/path"
            mock_file = Mock()
            mock_open.return_value.__enter__.return_value = mock_file
            
            # 测试创建文件
            result = file_manager.create_file("test.txt", "Hello, World!")
            
            # 验证调用
            mock_makedirs.assert_called_once_with("/test/path", exist_ok=True)
            mock_open.assert_called_once_with("/test/path/test.txt", 'w', encoding='utf-8')
            mock_file.write.assert_called_once_with("Hello, World!")
    
    def test_multiple_patches(self):
        """测试多个patch"""
        file_manager = FileManager("/test/path")
        
        with patch('os.path.exists') as mock_exists, \
             patch('os.listdir') as mock_listdir, \
             patch('os.path.isfile') as mock_isfile, \
             patch('os.path.join') as mock_join:
            
            # 配置mocks
            mock_exists.return_value = True
            mock_listdir.return_value = ['file1.txt', 'file2.txt', 'dir1']
            mock_isfile.side_effect = lambda path: path.endswith('.txt')
            mock_join.side_effect = lambda base, name: f"{base}/{name}"
            
            # 测试列出文件
            files = file_manager.list_files()
            
            # 验证结果
            self.assertEqual(files, ['file1.txt', 'file2.txt'])
            
            # 验证调用
            mock_exists.assert_called_once_with("/test/path")
            mock_listdir.assert_called_once_with("/test/path")
            self.assertEqual(mock_isfile.call_count, 3)

class TestPatchObject(unittest.TestCase):
    """patch.object测试"""
    
    def test_patch_object_method(self):
        """测试patch.object方法"""
        file_manager = FileManager("/test/path")
        
        # 使用patch.object替换特定方法
        with patch.object(file_manager, 'read_file') as mock_read:
            mock_read.return_value = "Mocked content"
            
            # 测试被patch的方法
            content = file_manager.read_file("test.txt")
            self.assertEqual(content, "Mocked content")
            
            # 验证调用
            mock_read.assert_called_once_with("test.txt")
    
    @patch.object(FileManager, 'create_file')
    def test_patch_object_decorator(self, mock_create):
        """测试patch.object装饰器"""
        mock_create.return_value = "/test/path/created.txt"
        
        file_manager = FileManager("/test/path")
        result = file_manager.create_file("created.txt", "content")
        
        self.assertEqual(result, "/test/path/created.txt")
        mock_create.assert_called_once_with("created.txt", "content")

class TestPatchMultiple(unittest.TestCase):
    """patch.multiple测试"""
    
    @patch.multiple('os', makedirs=Mock(), remove=Mock())
    def test_patch_multiple_decorator(self):
        """测试patch.multiple装饰器"""
        file_manager = FileManager("/test/path")
        
        # 这些调用会使用mocked版本
        file_manager.delete_file("test.txt")
        
        # 验证调用
        os.remove.assert_called()
    
    def test_patch_multiple_context(self):
        """测试patch.multiple上下文管理器"""
        with patch.multiple('os.path', 
                          exists=Mock(return_value=True),
                          join=Mock(side_effect=lambda *args: '/'.join(args))):
            
            file_manager = FileManager("/test/path")
            
            # 使用mocked版本
            result = file_manager.delete_file("test.txt")
            
            # 验证调用
            os.path.exists.assert_called()

spec参数和自动规范

使用spec创建规范化Mock

import unittest
from unittest.mock import Mock, MagicMock, patch, create_autospec
from datetime import datetime
import json

class DatabaseConnection:
    """数据库连接类"""
    
    def __init__(self, host, port, database):
        self.host = host
        self.port = port
        self.database = database
        self.connected = False
    
    def connect(self):
        """连接数据库"""
        # 模拟连接逻辑
        self.connected = True
        return True
    
    def disconnect(self):
        """断开连接"""
        self.connected = False
    
    def execute_query(self, query, params=None):
        """执行查询"""
        if not self.connected:
            raise ConnectionError("数据库未连接")
        
        # 模拟查询执行
        return {"rows": [], "affected": 0}
    
    def execute_transaction(self, queries):
        """执行事务"""
        if not self.connected:
            raise ConnectionError("数据库未连接")
        
        results = []
        for query in queries:
            result = self.execute_query(query)
            results.append(result)
        
        return results
    
    def get_table_info(self, table_name):
        """获取表信息"""
        query = f"DESCRIBE {table_name}"
        return self.execute_query(query)

class UserRepository:
    """用户仓库类"""
    
    def __init__(self, db_connection):
        self.db = db_connection
    
    def create_user(self, user_data):
        """创建用户"""
        query = "INSERT INTO users (name, email) VALUES (?, ?)"
        params = (user_data['name'], user_data['email'])
        
        result = self.db.execute_query(query, params)
        return result
    
    def get_user_by_id(self, user_id):
        """根据ID获取用户"""
        query = "SELECT * FROM users WHERE id = ?"
        params = (user_id,)
        
        result = self.db.execute_query(query, params)
        return result
    
    def update_user(self, user_id, user_data):
        """更新用户"""
        query = "UPDATE users SET name = ?, email = ? WHERE id = ?"
        params = (user_data['name'], user_data['email'], user_id)
        
        result = self.db.execute_query(query, params)
        return result
    
    def delete_user(self, user_id):
        """删除用户"""
        query = "DELETE FROM users WHERE id = ?"
        params = (user_id,)
        
        result = self.db.execute_query(query, params)
        return result

class TestSpecParameter(unittest.TestCase):
    """spec参数测试"""
    
    def test_mock_without_spec(self):
        """测试没有spec的Mock"""
        # 没有spec的Mock可以调用任何方法
        mock_db = Mock()
        
        # 这些调用都会成功,即使原始类没有这些方法
        mock_db.nonexistent_method()
        mock_db.another_fake_method("arg")
        
        # 这可能导致测试中的错误被隐藏
        self.assertTrue(True)  # 测试总是通过
    
    def test_mock_with_spec(self):
        """测试带spec的Mock"""
        # 使用spec限制Mock的接口
        mock_db = Mock(spec=DatabaseConnection)
        
        # 可以调用真实类的方法
        mock_db.connect()
        mock_db.execute_query("SELECT * FROM users")
        
        # 尝试调用不存在的方法会引发AttributeError
        with self.assertRaises(AttributeError):
            mock_db.nonexistent_method()
    
    def test_mock_with_spec_set(self):
        """测试带spec_set的Mock"""
        # spec_set更严格,不允许设置新属性
        mock_db = Mock(spec_set=DatabaseConnection)
        
        # 可以设置现有属性
        mock_db.connected = True
        
        # 尝试设置不存在的属性会引发AttributeError
        with self.assertRaises(AttributeError):
            mock_db.new_attribute = "value"
    
    def test_autospec(self):
        """测试autospec"""
        # create_autospec自动创建规范化的Mock
        mock_db = create_autospec(DatabaseConnection)
        
        # 配置方法返回值
        mock_db.connect.return_value = True
        mock_db.execute_query.return_value = {"rows": [{"id": 1, "name": "John"}]}
        
        # 创建用户仓库
        user_repo = UserRepository(mock_db)
        
        # 测试创建用户
        user_data = {"name": "John", "email": "john@example.com"}
        result = user_repo.create_user(user_data)
        
        # 验证调用
        mock_db.execute_query.assert_called_once_with(
            "INSERT INTO users (name, email) VALUES (?, ?)",
            ("John", "john@example.com")
        )
        
        # 验证返回值
        self.assertEqual(result["rows"][0]["name"], "John")
    
    def test_autospec_with_instance(self):
        """测试基于实例的autospec"""
        # 创建真实实例
        real_db = DatabaseConnection("localhost", 5432, "testdb")
        
        # 基于实例创建autospec
        mock_db = create_autospec(real_db, spec_set=True)
        
        # 配置Mock行为
        mock_db.connect.return_value = True
        mock_db.execute_query.return_value = {"rows": [], "affected": 1}
        
        user_repo = UserRepository(mock_db)
        
        # 测试更新用户
        user_data = {"name": "Jane", "email": "jane@example.com"}
        result = user_repo.update_user(1, user_data)
        
        # 验证调用
        mock_db.execute_query.assert_called_once_with(
            "UPDATE users SET name = ?, email = ? WHERE id = ?",
            ("Jane", "jane@example.com", 1)
        )
    
    def test_spec_with_inheritance(self):
        """测试继承类的spec"""
        class ExtendedDatabaseConnection(DatabaseConnection):
            def backup_database(self):
                """备份数据库"""
                return {"backup_id": "backup_123"}
            
            def restore_database(self, backup_id):
                """恢复数据库"""
                return {"status": "restored"}
        
        # 使用扩展类作为spec
        mock_db = Mock(spec=ExtendedDatabaseConnection)
        
        # 可以调用基类和扩展类的方法
        mock_db.connect()
        mock_db.backup_database()
        mock_db.restore_database("backup_123")
        
        # 验证方法存在
        self.assertTrue(hasattr(mock_db, 'connect'))
        self.assertTrue(hasattr(mock_db, 'backup_database'))
        self.assertTrue(hasattr(mock_db, 'restore_database'))

class TestAdvancedSpecUsage(unittest.TestCase):
    """高级spec使用测试"""
    
    def test_partial_mock_with_spec(self):
        """测试部分Mock与spec"""
        # 创建真实对象
        real_db = DatabaseConnection("localhost", 5432, "testdb")
        
        # 只Mock特定方法
        with patch.object(real_db, 'execute_query', spec=True) as mock_execute:
            mock_execute.return_value = {"rows": [{"id": 1}], "affected": 1}
            
            user_repo = UserRepository(real_db)
            result = user_repo.get_user_by_id(1)
            
            # 验证Mock调用
            mock_execute.assert_called_once_with(
                "SELECT * FROM users WHERE id = ?",
                (1,)
            )
            
            # 验证返回值
            self.assertEqual(result["rows"][0]["id"], 1)
    
    def test_nested_spec(self):
        """测试嵌套对象的spec"""
        class DatabaseManager:
            def __init__(self):
                self.connection = DatabaseConnection("localhost", 5432, "db")
            
            def get_connection(self):
                return self.connection
        
        # 创建嵌套Mock
        mock_manager = create_autospec(DatabaseManager)
        mock_connection = create_autospec(DatabaseConnection)
        
        # 配置嵌套Mock
        mock_manager.get_connection.return_value = mock_connection
        mock_connection.execute_query.return_value = {"rows": []}
        
        # 测试嵌套调用
        connection = mock_manager.get_connection()
        result = connection.execute_query("SELECT 1")
        
        # 验证调用
        mock_manager.get_connection.assert_called_once()
        mock_connection.execute_query.assert_called_once_with("SELECT 1")
    
    def test_spec_with_properties(self):
        """测试带属性的spec"""
        class ConfigurableDB(DatabaseConnection):
            @property
            def connection_string(self):
                return f"{self.host}:{self.port}/{self.database}"
            
            @property
            def is_connected(self):
                return self.connected
            
            @is_connected.setter
            def is_connected(self, value):
                self.connected = value
        
        # 创建带属性的Mock
        mock_db = create_autospec(ConfigurableDB)
        
        # 配置属性
        mock_db.connection_string = "localhost:5432/testdb"
        mock_db.is_connected = True
        
        # 测试属性访问
        self.assertEqual(mock_db.connection_string, "localhost:5432/testdb")
        self.assertTrue(mock_db.is_connected)
    
    def test_spec_validation_errors(self):
        """测试spec验证错误"""
        mock_db = Mock(spec=DatabaseConnection)
        
        # 正确的方法调用
        mock_db.connect()
        mock_db.execute_query("SELECT 1")
        
        # 错误的方法名会引发AttributeError
        with self.assertRaises(AttributeError):
            mock_db.invalid_method()
        
        # 但是可以访问不存在的属性(除非使用spec_set)
        mock_db.some_attribute = "value"  # 这不会引发错误
        
        # 使用spec_set会更严格
        strict_mock = Mock(spec_set=DatabaseConnection)
        with self.assertRaises(AttributeError):
            strict_mock.invalid_attribute = "value"

高级Mock技巧

复杂场景的Mock策略

import unittest
from unittest.mock import Mock, MagicMock, patch, call, ANY
import asyncio
from datetime import datetime, timedelta
import json

class CacheService:
    """缓存服务"""
    
    def __init__(self, redis_client):
        self.redis = redis_client
    
    def get(self, key):
        """获取缓存值"""
        return self.redis.get(key)
    
    def set(self, key, value, expire=None):
        """设置缓存值"""
        if expire:
            return self.redis.setex(key, expire, value)
        return self.redis.set(key, value)
    
    def delete(self, key):
        """删除缓存"""
        return self.redis.delete(key)

class NotificationService:
    """通知服务"""
    
    def __init__(self, email_client, sms_client, push_client):
        self.email = email_client
        self.sms = sms_client
        self.push = push_client
    
    def send_notification(self, user, message, channels=None):
        """发送通知"""
        if channels is None:
            channels = ['email']
        
        results = {}
        
        if 'email' in channels:
            results['email'] = self.email.send(user.email, message)
        
        if 'sms' in channels:
            results['sms'] = self.sms.send(user.phone, message)
        
        if 'push' in channels:
            results['push'] = self.push.send(user.device_id, message)
        
        return results

class OrderService:
    """订单服务"""
    
    def __init__(self, db, cache, notification):
        self.db = db
        self.cache = cache
        self.notification = notification
    
    def create_order(self, user, items):
        """创建订单"""
        # 计算总价
        total = sum(item['price'] * item['quantity'] for item in items)
        
        # 创建订单数据
        order_data = {
            'user_id': user.id,
            'items': items,
            'total': total,
            'status': 'pending',
            'created_at': datetime.now().isoformat()
        }
        
        # 保存到数据库
        order_id = self.db.insert('orders', order_data)
        order_data['id'] = order_id
        
        # 缓存订单
        cache_key = f"order:{order_id}"
        self.cache.set(cache_key, json.dumps(order_data), expire=3600)
        
        # 发送通知
        message = f"订单 {order_id} 已创建,总金额: {total}"
        self.notification.send_notification(user, message, ['email', 'push'])
        
        return order_data
    
    def get_order(self, order_id):
        """获取订单"""
        # 先从缓存获取
        cache_key = f"order:{order_id}"
        cached_order = self.cache.get(cache_key)
        
        if cached_order:
            return json.loads(cached_order)
        
        # 从数据库获取
        order = self.db.select('orders', {'id': order_id})
        
        if order:
            # 更新缓存
            self.cache.set(cache_key, json.dumps(order), expire=3600)
        
        return order

class TestAdvancedMockTechniques(unittest.TestCase):
    """高级Mock技巧测试"""
    
    def test_mock_call_tracking(self):
        """测试Mock调用跟踪"""
        mock_redis = Mock()
        cache_service = CacheService(mock_redis)
        
        # 执行多个操作
        cache_service.set("key1", "value1")
        cache_service.set("key2", "value2", expire=60)
        cache_service.get("key1")
        cache_service.delete("key1")
        
        # 验证调用顺序
        expected_calls = [
            call.set("key1", "value1"),
            call.setex("key2", 60, "value2"),
            call.get("key1"),
            call.delete("key1")
        ]
        
        # 验证所有调用
        mock_redis.assert_has_calls(expected_calls, any_order=False)
        
        # 验证特定方法的调用次数
        self.assertEqual(mock_redis.set.call_count, 1)
        self.assertEqual(mock_redis.setex.call_count, 1)
        self.assertEqual(mock_redis.get.call_count, 1)
        self.assertEqual(mock_redis.delete.call_count, 1)
    
    def test_mock_with_any_matcher(self):
        """测试使用ANY匹配器"""
        mock_db = Mock()
        mock_cache = Mock()
        mock_notification = Mock()
        
        order_service = OrderService(mock_db, mock_cache, mock_notification)
        
        # 创建用户和订单项
        user = Mock()
        user.id = 1
        user.email = "user@example.com"
        user.device_id = "device123"
        
        items = [
            {'price': 10.0, 'quantity': 2},
            {'price': 5.0, 'quantity': 1}
        ]
        
        # 配置Mock返回值
        mock_db.insert.return_value = 12345
        
        # 创建订单
        order_service.create_order(user, items)
        
        # 使用ANY匹配器验证调用
        mock_db.insert.assert_called_once_with('orders', ANY)
        mock_cache.set.assert_called_once_with(
            "order:12345", 
            ANY,  # JSON字符串,具体内容可能变化
            expire=3600
        )
        
        # 验证通知调用
        mock_notification.send_notification.assert_called_once_with(
            user, 
            ANY,  # 消息内容包含订单ID和总金额
            ['email', 'push']
        )
    
    def test_mock_side_effect_sequence(self):
        """测试Mock副作用序列"""
        mock_cache = Mock()
        
        # 设置get方法的副作用序列
        mock_cache.get.side_effect = [
            None,  # 第一次调用返回None(缓存未命中)
            '{"id": 1, "status": "pending"}',  # 第二次调用返回缓存数据
            '{"id": 1, "status": "confirmed"}'  # 第三次调用返回更新后的数据
        ]
        
        cache_service = CacheService(mock_cache)
        
        # 第一次获取(缓存未命中)
        result1 = cache_service.get("order:1")
        self.assertIsNone(result1)
        
        # 第二次获取(缓存命中)
        result2 = cache_service.get("order:1")
        self.assertEqual(result2, '{"id": 1, "status": "pending"}')
        
        # 第三次获取(缓存更新)
        result3 = cache_service.get("order:1")
        self.assertEqual(result3, '{"id": 1, "status": "confirmed"}')
        
        # 验证调用次数
        self.assertEqual(mock_cache.get.call_count, 3)
    
    def test_mock_side_effect_function(self):
        """测试Mock副作用函数"""
        mock_db = Mock()
        
        # 模拟数据库数据
        orders_data = {
            1: {'id': 1, 'status': 'pending', 'total': 100},
            2: {'id': 2, 'status': 'confirmed', 'total': 200}
        }
        
        def db_select_side_effect(table, conditions):
            """模拟数据库查询"""
            if table == 'orders':
                order_id = conditions.get('id')
                return orders_data.get(order_id)
            return None
        
        mock_db.select.side_effect = db_select_side_effect
        
        mock_cache = Mock()
        mock_cache.get.return_value = None  # 缓存未命中
        
        mock_notification = Mock()
        
        order_service = OrderService(mock_db, mock_cache, mock_notification)
        
        # 测试获取存在的订单
        order1 = order_service.get_order(1)
        self.assertEqual(order1['id'], 1)
        self.assertEqual(order1['status'], 'pending')
        
        # 测试获取不存在的订单
        order_none = order_service.get_order(999)
        self.assertIsNone(order_none)
        
        # 验证数据库调用
        expected_calls = [
            call('orders', {'id': 1}),
            call('orders', {'id': 999})
        ]
        mock_db.select.assert_has_calls(expected_calls)
    
    def test_mock_exception_handling(self):
        """测试Mock异常处理"""
        mock_email = Mock()
        mock_sms = Mock()
        mock_push = Mock()
        
        # 配置不同的异常
        mock_email.send.side_effect = ConnectionError("邮件服务不可用")
        mock_sms.send.return_value = {"status": "success"}
        mock_push.send.side_effect = TimeoutError("推送服务超时")
        
        notification_service = NotificationService(mock_email, mock_sms, mock_push)
        
        user = Mock()
        user.email = "user@example.com"
        user.phone = "1234567890"
        user.device_id = "device123"
        
        # 测试异常处理
        try:
            results = notification_service.send_notification(
                user, 
                "测试消息", 
                ['email', 'sms', 'push']
            )
            
            # 如果没有异常处理,这里会失败
            self.fail("应该抛出异常")
            
        except (ConnectionError, TimeoutError):
            # 预期的异常
            pass
        
        # 验证调用
        mock_email.send.assert_called_once()
        mock_sms.send.assert_called_once()
    
    def test_mock_property_access(self):
        """测试Mock属性访问"""
        mock_user = Mock()
        
        # 配置属性
        mock_user.id = 1
        mock_user.email = "user@example.com"
        mock_user.phone = "1234567890"
        mock_user.device_id = "device123"
        
        # 配置动态属性
        type(mock_user).full_name = Mock(return_value="John Doe")
        
        # 测试属性访问
        self.assertEqual(mock_user.id, 1)
        self.assertEqual(mock_user.email, "user@example.com")
        self.assertEqual(mock_user.full_name, "John Doe")
        
        # 测试属性设置
        mock_user.status = "active"
        self.assertEqual(mock_user.status, "active")
    
    def test_mock_context_manager(self):
        """测试Mock上下文管理器"""
        mock_connection = MagicMock()
        
        # 配置上下文管理器
        mock_connection.__enter__.return_value = mock_connection
        mock_connection.__exit__.return_value = None
        
        # 使用上下文管理器
        with mock_connection as conn:
            conn.execute("SELECT 1")
            conn.commit()
        
        # 验证上下文管理器调用
        mock_connection.__enter__.assert_called_once()
        mock_connection.__exit__.assert_called_once()
        
        # 验证内部调用
        mock_connection.execute.assert_called_once_with("SELECT 1")
        mock_connection.commit.assert_called_once()
    
    def test_mock_async_methods(self):
        """测试Mock异步方法"""
        async def test_async():
            mock_async_client = Mock()
            
            # 配置异步方法
            async def async_side_effect(*args, **kwargs):
                return {"status": "success", "data": args[0]}
            
            mock_async_client.fetch_data = Mock(side_effect=async_side_effect)
            
            # 测试异步调用
            result = await mock_async_client.fetch_data("test_data")
            
            self.assertEqual(result["status"], "success")
            self.assertEqual(result["data"], "test_data")
            
            # 验证调用
            mock_async_client.fetch_data.assert_called_once_with("test_data")
        
        # 运行异步测试
        asyncio.run(test_async())

总结

unittest.mock模块是Python测试中不可或缺的工具,提供了强大而灵活的模拟测试能力。

主要特性

  1. Mock对象:创建可配置的模拟对象,替代真实依赖
  2. MagicMock:支持魔术方法的高级Mock对象
  3. patch装饰器:临时替换模块、类或方法
  4. spec参数:基于真实对象创建规范化Mock
  5. side_effect:模拟复杂行为和异常情况
  6. 调用验证:详细的调用跟踪和验证机制

最佳实践

  1. 使用spec:通过spec参数确保Mock接口的正确性
  2. 适度Mock:只Mock必要的依赖,避免过度Mock
  3. 验证调用:使用assert方法验证Mock的调用情况
  4. 清晰配置:明确配置Mock的返回值和副作用
  5. 异常测试:使用side_effect测试异常情况
  6. 上下文管理:合理使用patch上下文管理器

应用场景

  • 外部依赖隔离:数据库、网络服务、文件系统
  • 异常情况测试:网络错误、超时、资源不足
  • 性能测试:避免真实操作的性能开销
  • 并发测试:模拟复杂的并发场景
  • 集成测试:替换复杂的外部系统
  • 回归测试:确保代码修改不影响现有功能

通过熟练掌握unittest.mock,可以编写更可靠、更快速、更易维护的单元测试。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐