046-模拟测试unittest.mock
Python unittest.mock 模块提供了强大的测试工具,用于创建模拟对象、隔离测试代码和控制依赖项行为。核心功能包括:1) 创建可配置的Mock对象;2) 使用MagicMock支持魔术方法;3) 通过patch装饰器临时替换对象;4) 基于真实对象创建Mock(spec参数);5) 模拟异常和复杂行为(side_effect);6) 验证方法调用情况。示例展示了如何测试EmailSe
·
046-模拟测试unittest.mock
概述
unittest.mock是Python标准库中的模拟测试模块,提供了强大的Mock对象功能,用于在测试中替换和模拟真实对象的行为。Mock对象允许我们隔离被测试的代码,控制外部依赖的行为,从而编写更可靠和快速的单元测试。
主要特性
- Mock对象:创建可配置的模拟对象
- MagicMock:支持魔术方法的Mock对象
- patch装饰器:临时替换对象和属性
- spec参数:基于真实对象创建Mock
- side_effect:模拟异常和复杂行为
- assert方法:验证Mock对象的调用情况
基本Mock对象
Mock类基础
import unittest
from unittest.mock import Mock, MagicMock, patch, call
import requests
import json
from datetime import datetime
# 示例:简单的服务类
class EmailService:
"""邮件服务类"""
def __init__(self, smtp_server, port=587):
self.smtp_server = smtp_server
self.port = port
self.connected = False
def connect(self):
"""连接到SMTP服务器"""
# 模拟连接逻辑
if self.smtp_server and self.port:
self.connected = True
return True
return False
def send_email(self, to_email, subject, body):
"""发送邮件"""
if not self.connected:
raise ConnectionError("未连接到SMTP服务器")
# 模拟发送邮件
email_data = {
'to': to_email,
'subject': subject,
'body': body,
'timestamp': datetime.now().isoformat()
}
# 这里通常会有实际的SMTP发送逻辑
return email_data
def disconnect(self):
"""断开连接"""
self.connected = False
class UserNotificationService:
"""用户通知服务"""
def __init__(self, email_service):
self.email_service = email_service
def send_welcome_email(self, user_email, username):
"""发送欢迎邮件"""
subject = f"欢迎, {username}!"
body = f"亲爱的 {username},\n\n欢迎加入我们的平台!"
try:
self.email_service.connect()
result = self.email_service.send_email(user_email, subject, body)
self.email_service.disconnect()
return result
except Exception as e:
return {'error': str(e)}
def send_password_reset_email(self, user_email, reset_token):
"""发送密码重置邮件"""
subject = "密码重置请求"
body = f"您的密码重置令牌是: {reset_token}"
try:
self.email_service.connect()
result = self.email_service.send_email(user_email, subject, body)
self.email_service.disconnect()
return result
except Exception as e:
return {'error': str(e)}
class TestBasicMock(unittest.TestCase):
"""基本Mock测试"""
def test_mock_creation(self):
"""测试Mock对象创建"""
# 创建基本Mock对象
mock_obj = Mock()
# Mock对象可以调用任何方法
result = mock_obj.some_method()
self.assertIsInstance(result, Mock)
# 可以设置返回值
mock_obj.some_method.return_value = "mocked result"
self.assertEqual(mock_obj.some_method(), "mocked result")
def test_mock_attributes(self):
"""测试Mock对象属性"""
mock_obj = Mock()
# 设置属性
mock_obj.name = "Test Mock"
mock_obj.value = 42
self.assertEqual(mock_obj.name, "Test Mock")
self.assertEqual(mock_obj.value, 42)
# 动态属性也会返回Mock对象
self.assertIsInstance(mock_obj.dynamic_attr, Mock)
def test_mock_method_calls(self):
"""测试Mock方法调用"""
mock_obj = Mock()
# 调用方法
mock_obj.method1("arg1", "arg2")
mock_obj.method2(key="value")
# 验证调用
mock_obj.method1.assert_called_with("arg1", "arg2")
mock_obj.method2.assert_called_with(key="value")
# 验证调用次数
self.assertEqual(mock_obj.method1.call_count, 1)
self.assertEqual(mock_obj.method2.call_count, 1)
def test_mock_side_effect(self):
"""测试Mock副作用"""
mock_obj = Mock()
# 设置副作用为异常
mock_obj.failing_method.side_effect = ValueError("模拟错误")
with self.assertRaises(ValueError):
mock_obj.failing_method()
# 设置副作用为函数
def custom_side_effect(x):
return x * 2
mock_obj.doubler.side_effect = custom_side_effect
self.assertEqual(mock_obj.doubler(5), 10)
# 设置副作用为值列表
mock_obj.sequence.side_effect = [1, 2, 3]
self.assertEqual(mock_obj.sequence(), 1)
self.assertEqual(mock_obj.sequence(), 2)
self.assertEqual(mock_obj.sequence(), 3)
class TestEmailServiceMock(unittest.TestCase):
"""邮件服务Mock测试"""
def test_user_notification_with_mock(self):
"""使用Mock测试用户通知服务"""
# 创建Mock邮件服务
mock_email_service = Mock(spec=EmailService)
# 配置Mock行为
mock_email_service.connect.return_value = True
mock_email_service.send_email.return_value = {
'to': 'user@example.com',
'subject': '欢迎, John!',
'body': '亲爱的 John,\n\n欢迎加入我们的平台!',
'timestamp': '2023-01-01T12:00:00'
}
# 创建通知服务
notification_service = UserNotificationService(mock_email_service)
# 测试发送欢迎邮件
result = notification_service.send_welcome_email('user@example.com', 'John')
# 验证结果
self.assertIn('to', result)
self.assertEqual(result['to'], 'user@example.com')
# 验证Mock调用
mock_email_service.connect.assert_called_once()
mock_email_service.send_email.assert_called_once_with(
'user@example.com',
'欢迎, John!',
'亲爱的 John,\n\n欢迎加入我们的平台!'
)
mock_email_service.disconnect.assert_called_once()
def test_email_service_connection_failure(self):
"""测试邮件服务连接失败"""
mock_email_service = Mock(spec=EmailService)
# 模拟连接失败
mock_email_service.connect.side_effect = ConnectionError("连接失败")
notification_service = UserNotificationService(mock_email_service)
# 测试连接失败情况
result = notification_service.send_welcome_email('user@example.com', 'John')
# 验证错误处理
self.assertIn('error', result)
self.assertIn('连接失败', result['error'])
def test_email_service_send_failure(self):
"""测试邮件发送失败"""
mock_email_service = Mock(spec=EmailService)
# 配置连接成功但发送失败
mock_email_service.connect.return_value = True
mock_email_service.send_email.side_effect = Exception("发送失败")
notification_service = UserNotificationService(mock_email_service)
result = notification_service.send_password_reset_email('user@example.com', 'reset123')
# 验证错误处理
self.assertIn('error', result)
self.assertIn('发送失败', result['error'])
MagicMock和魔术方法
MagicMock高级功能
import unittest
from unittest.mock import MagicMock, Mock
import operator
class DataContainer:
"""数据容器类"""
def __init__(self, data=None):
self.data = data or []
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
def __setitem__(self, index, value):
self.data[index] = value
def __contains__(self, item):
return item in self.data
def __iter__(self):
return iter(self.data)
def __str__(self):
return f"DataContainer({self.data})"
def __repr__(self):
return f"DataContainer(data={self.data!r})"
class Calculator:
"""计算器类"""
def __init__(self, data_source):
self.data_source = data_source
def sum_all(self):
"""计算所有数据的和"""
return sum(self.data_source)
def get_item_at(self, index):
"""获取指定索引的项目"""
return self.data_source[index]
def count_items(self):
"""计算项目数量"""
return len(self.data_source)
def contains_value(self, value):
"""检查是否包含指定值"""
return value in self.data_source
class TestMagicMock(unittest.TestCase):
"""MagicMock测试"""
def test_magic_mock_basic(self):
"""测试MagicMock基本功能"""
# MagicMock支持魔术方法
magic_mock = MagicMock()
# 设置魔术方法的返回值
magic_mock.__len__.return_value = 5
magic_mock.__getitem__.return_value = "mocked_item"
magic_mock.__contains__.return_value = True
# 测试魔术方法
self.assertEqual(len(magic_mock), 5)
self.assertEqual(magic_mock[0], "mocked_item")
self.assertTrue("anything" in magic_mock)
def test_magic_mock_with_calculator(self):
"""使用MagicMock测试计算器"""
# 创建MagicMock数据源
mock_data_source = MagicMock()
# 配置魔术方法
mock_data_source.__iter__.return_value = iter([1, 2, 3, 4, 5])
mock_data_source.__len__.return_value = 5
mock_data_source.__getitem__.return_value = 10
mock_data_source.__contains__.return_value = True
# 创建计算器
calculator = Calculator(mock_data_source)
# 测试各种操作
self.assertEqual(calculator.sum_all(), 15) # sum([1,2,3,4,5])
self.assertEqual(calculator.count_items(), 5)
self.assertEqual(calculator.get_item_at(0), 10)
self.assertTrue(calculator.contains_value(3))
# 验证魔术方法调用
mock_data_source.__iter__.assert_called_once()
mock_data_source.__len__.assert_called_once()
mock_data_source.__getitem__.assert_called_once_with(0)
mock_data_source.__contains__.assert_called_once_with(3)
def test_magic_mock_side_effects(self):
"""测试MagicMock副作用"""
mock_container = MagicMock()
# 设置__getitem__的副作用
def getitem_side_effect(index):
if index < 0 or index >= 3:
raise IndexError("索引超出范围")
return f"item_{index}"
mock_container.__getitem__.side_effect = getitem_side_effect
mock_container.__len__.return_value = 3
calculator = Calculator(mock_container)
# 测试正常索引
self.assertEqual(calculator.get_item_at(0), "item_0")
self.assertEqual(calculator.get_item_at(2), "item_2")
# 测试异常情况
with self.assertRaises(IndexError):
calculator.get_item_at(5)
def test_magic_mock_call_tracking(self):
"""测试MagicMock调用跟踪"""
mock_obj = MagicMock()
# 进行各种调用
len(mock_obj)
mock_obj[0]
mock_obj[1] = "value"
"test" in mock_obj
str(mock_obj)
# 验证调用记录
mock_obj.__len__.assert_called_once()
mock_obj.__getitem__.assert_called_once_with(0)
mock_obj.__setitem__.assert_called_once_with(1, "value")
mock_obj.__contains__.assert_called_once_with("test")
mock_obj.__str__.assert_called_once()
def test_magic_mock_comparison_operations(self):
"""测试MagicMock比较操作"""
mock_obj = MagicMock()
# 配置比较操作
mock_obj.__eq__.return_value = True
mock_obj.__lt__.return_value = False
mock_obj.__gt__.return_value = True
# 测试比较
self.assertTrue(mock_obj == "anything")
self.assertFalse(mock_obj < 10)
self.assertTrue(mock_obj > 5)
# 验证调用
mock_obj.__eq__.assert_called_with("anything")
mock_obj.__lt__.assert_called_with(10)
mock_obj.__gt__.assert_called_with(5)
def test_magic_mock_arithmetic_operations(self):
"""测试MagicMock算术操作"""
mock_num = MagicMock()
# 配置算术操作
mock_num.__add__.return_value = 15
mock_num.__sub__.return_value = 5
mock_num.__mul__.return_value = 50
mock_num.__truediv__.return_value = 2.5
# 测试算术操作
self.assertEqual(mock_num + 5, 15)
self.assertEqual(mock_num - 5, 5)
self.assertEqual(mock_num * 5, 50)
self.assertEqual(mock_num / 4, 2.5)
# 验证调用
mock_num.__add__.assert_called_with(5)
mock_num.__sub__.assert_called_with(5)
mock_num.__mul__.assert_called_with(5)
mock_num.__truediv__.assert_called_with(4)
patch装饰器和上下文管理器
patch装饰器详解
import unittest
from unittest.mock import patch, Mock, MagicMock
import requests
import os
import json
from datetime import datetime
# 示例:API客户端
class APIClient:
"""API客户端"""
def __init__(self, base_url, api_key):
self.base_url = base_url
self.api_key = api_key
def get_user(self, user_id):
"""获取用户信息"""
url = f"{self.base_url}/users/{user_id}"
headers = {"Authorization": f"Bearer {self.api_key}"}
response = requests.get(url, headers=headers)
response.raise_for_status()
return response.json()
def create_user(self, user_data):
"""创建用户"""
url = f"{self.base_url}/users"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
response = requests.post(url, json=user_data, headers=headers)
response.raise_for_status()
return response.json()
def upload_file(self, file_path, user_id):
"""上传文件"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
url = f"{self.base_url}/users/{user_id}/files"
headers = {"Authorization": f"Bearer {self.api_key}"}
with open(file_path, 'rb') as f:
files = {'file': f}
response = requests.post(url, files=files, headers=headers)
response.raise_for_status()
return response.json()
class FileManager:
"""文件管理器"""
def __init__(self, base_path):
self.base_path = base_path
def create_file(self, filename, content):
"""创建文件"""
file_path = os.path.join(self.base_path, filename)
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)
return file_path
def read_file(self, filename):
"""读取文件"""
file_path = os.path.join(self.base_path, filename)
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在: {file_path}")
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
def delete_file(self, filename):
"""删除文件"""
file_path = os.path.join(self.base_path, filename)
if os.path.exists(file_path):
os.remove(file_path)
return True
return False
def list_files(self):
"""列出文件"""
if not os.path.exists(self.base_path):
return []
return [f for f in os.listdir(self.base_path)
if os.path.isfile(os.path.join(self.base_path, f))]
class TestPatchDecorator(unittest.TestCase):
"""patch装饰器测试"""
@patch('requests.get')
def test_api_client_get_user(self, mock_get):
"""测试API客户端获取用户"""
# 配置mock响应
mock_response = Mock()
mock_response.json.return_value = {
'id': 1,
'name': 'John Doe',
'email': 'john@example.com'
}
mock_response.raise_for_status.return_value = None
mock_get.return_value = mock_response
# 创建API客户端并测试
client = APIClient("https://api.example.com", "test_key")
result = client.get_user(1)
# 验证结果
self.assertEqual(result['id'], 1)
self.assertEqual(result['name'], 'John Doe')
# 验证请求调用
mock_get.assert_called_once_with(
"https://api.example.com/users/1",
headers={"Authorization": "Bearer test_key"}
)
@patch('requests.post')
def test_api_client_create_user(self, mock_post):
"""测试API客户端创建用户"""
# 配置mock响应
mock_response = Mock()
mock_response.json.return_value = {
'id': 2,
'name': 'Jane Doe',
'email': 'jane@example.com',
'created_at': '2023-01-01T12:00:00Z'
}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
# 测试创建用户
client = APIClient("https://api.example.com", "test_key")
user_data = {
'name': 'Jane Doe',
'email': 'jane@example.com'
}
result = client.create_user(user_data)
# 验证结果
self.assertEqual(result['id'], 2)
self.assertIn('created_at', result)
# 验证请求调用
mock_post.assert_called_once_with(
"https://api.example.com/users",
json=user_data,
headers={
"Authorization": "Bearer test_key",
"Content-Type": "application/json"
}
)
@patch('requests.get')
def test_api_client_error_handling(self, mock_get):
"""测试API客户端错误处理"""
# 配置mock抛出异常
mock_get.side_effect = requests.exceptions.HTTPError("404 Not Found")
client = APIClient("https://api.example.com", "test_key")
# 测试异常处理
with self.assertRaises(requests.exceptions.HTTPError):
client.get_user(999)
@patch('os.path.exists')
@patch('builtins.open', create=True)
@patch('requests.post')
def test_api_client_upload_file(self, mock_post, mock_open, mock_exists):
"""测试API客户端文件上传"""
# 配置mocks
mock_exists.return_value = True
mock_file = Mock()
mock_open.return_value.__enter__.return_value = mock_file
mock_response = Mock()
mock_response.json.return_value = {
'file_id': 'file123',
'filename': 'test.txt',
'size': 1024
}
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
# 测试文件上传
client = APIClient("https://api.example.com", "test_key")
result = client.upload_file("/path/to/test.txt", 1)
# 验证结果
self.assertEqual(result['file_id'], 'file123')
# 验证调用
mock_exists.assert_called_once_with("/path/to/test.txt")
mock_open.assert_called_once_with("/path/to/test.txt", 'rb')
mock_post.assert_called_once()
class TestPatchContextManager(unittest.TestCase):
"""patch上下文管理器测试"""
def test_file_manager_with_patch_context(self):
"""使用patch上下文管理器测试文件管理器"""
file_manager = FileManager("/test/path")
# 使用patch作为上下文管理器
with patch('os.makedirs') as mock_makedirs, \
patch('builtins.open', create=True) as mock_open, \
patch('os.path.dirname') as mock_dirname:
# 配置mocks
mock_dirname.return_value = "/test/path"
mock_file = Mock()
mock_open.return_value.__enter__.return_value = mock_file
# 测试创建文件
result = file_manager.create_file("test.txt", "Hello, World!")
# 验证调用
mock_makedirs.assert_called_once_with("/test/path", exist_ok=True)
mock_open.assert_called_once_with("/test/path/test.txt", 'w', encoding='utf-8')
mock_file.write.assert_called_once_with("Hello, World!")
def test_multiple_patches(self):
"""测试多个patch"""
file_manager = FileManager("/test/path")
with patch('os.path.exists') as mock_exists, \
patch('os.listdir') as mock_listdir, \
patch('os.path.isfile') as mock_isfile, \
patch('os.path.join') as mock_join:
# 配置mocks
mock_exists.return_value = True
mock_listdir.return_value = ['file1.txt', 'file2.txt', 'dir1']
mock_isfile.side_effect = lambda path: path.endswith('.txt')
mock_join.side_effect = lambda base, name: f"{base}/{name}"
# 测试列出文件
files = file_manager.list_files()
# 验证结果
self.assertEqual(files, ['file1.txt', 'file2.txt'])
# 验证调用
mock_exists.assert_called_once_with("/test/path")
mock_listdir.assert_called_once_with("/test/path")
self.assertEqual(mock_isfile.call_count, 3)
class TestPatchObject(unittest.TestCase):
"""patch.object测试"""
def test_patch_object_method(self):
"""测试patch.object方法"""
file_manager = FileManager("/test/path")
# 使用patch.object替换特定方法
with patch.object(file_manager, 'read_file') as mock_read:
mock_read.return_value = "Mocked content"
# 测试被patch的方法
content = file_manager.read_file("test.txt")
self.assertEqual(content, "Mocked content")
# 验证调用
mock_read.assert_called_once_with("test.txt")
@patch.object(FileManager, 'create_file')
def test_patch_object_decorator(self, mock_create):
"""测试patch.object装饰器"""
mock_create.return_value = "/test/path/created.txt"
file_manager = FileManager("/test/path")
result = file_manager.create_file("created.txt", "content")
self.assertEqual(result, "/test/path/created.txt")
mock_create.assert_called_once_with("created.txt", "content")
class TestPatchMultiple(unittest.TestCase):
"""patch.multiple测试"""
@patch.multiple('os', makedirs=Mock(), remove=Mock())
def test_patch_multiple_decorator(self):
"""测试patch.multiple装饰器"""
file_manager = FileManager("/test/path")
# 这些调用会使用mocked版本
file_manager.delete_file("test.txt")
# 验证调用
os.remove.assert_called()
def test_patch_multiple_context(self):
"""测试patch.multiple上下文管理器"""
with patch.multiple('os.path',
exists=Mock(return_value=True),
join=Mock(side_effect=lambda *args: '/'.join(args))):
file_manager = FileManager("/test/path")
# 使用mocked版本
result = file_manager.delete_file("test.txt")
# 验证调用
os.path.exists.assert_called()
spec参数和自动规范
使用spec创建规范化Mock
import unittest
from unittest.mock import Mock, MagicMock, patch, create_autospec
from datetime import datetime
import json
class DatabaseConnection:
"""数据库连接类"""
def __init__(self, host, port, database):
self.host = host
self.port = port
self.database = database
self.connected = False
def connect(self):
"""连接数据库"""
# 模拟连接逻辑
self.connected = True
return True
def disconnect(self):
"""断开连接"""
self.connected = False
def execute_query(self, query, params=None):
"""执行查询"""
if not self.connected:
raise ConnectionError("数据库未连接")
# 模拟查询执行
return {"rows": [], "affected": 0}
def execute_transaction(self, queries):
"""执行事务"""
if not self.connected:
raise ConnectionError("数据库未连接")
results = []
for query in queries:
result = self.execute_query(query)
results.append(result)
return results
def get_table_info(self, table_name):
"""获取表信息"""
query = f"DESCRIBE {table_name}"
return self.execute_query(query)
class UserRepository:
"""用户仓库类"""
def __init__(self, db_connection):
self.db = db_connection
def create_user(self, user_data):
"""创建用户"""
query = "INSERT INTO users (name, email) VALUES (?, ?)"
params = (user_data['name'], user_data['email'])
result = self.db.execute_query(query, params)
return result
def get_user_by_id(self, user_id):
"""根据ID获取用户"""
query = "SELECT * FROM users WHERE id = ?"
params = (user_id,)
result = self.db.execute_query(query, params)
return result
def update_user(self, user_id, user_data):
"""更新用户"""
query = "UPDATE users SET name = ?, email = ? WHERE id = ?"
params = (user_data['name'], user_data['email'], user_id)
result = self.db.execute_query(query, params)
return result
def delete_user(self, user_id):
"""删除用户"""
query = "DELETE FROM users WHERE id = ?"
params = (user_id,)
result = self.db.execute_query(query, params)
return result
class TestSpecParameter(unittest.TestCase):
"""spec参数测试"""
def test_mock_without_spec(self):
"""测试没有spec的Mock"""
# 没有spec的Mock可以调用任何方法
mock_db = Mock()
# 这些调用都会成功,即使原始类没有这些方法
mock_db.nonexistent_method()
mock_db.another_fake_method("arg")
# 这可能导致测试中的错误被隐藏
self.assertTrue(True) # 测试总是通过
def test_mock_with_spec(self):
"""测试带spec的Mock"""
# 使用spec限制Mock的接口
mock_db = Mock(spec=DatabaseConnection)
# 可以调用真实类的方法
mock_db.connect()
mock_db.execute_query("SELECT * FROM users")
# 尝试调用不存在的方法会引发AttributeError
with self.assertRaises(AttributeError):
mock_db.nonexistent_method()
def test_mock_with_spec_set(self):
"""测试带spec_set的Mock"""
# spec_set更严格,不允许设置新属性
mock_db = Mock(spec_set=DatabaseConnection)
# 可以设置现有属性
mock_db.connected = True
# 尝试设置不存在的属性会引发AttributeError
with self.assertRaises(AttributeError):
mock_db.new_attribute = "value"
def test_autospec(self):
"""测试autospec"""
# create_autospec自动创建规范化的Mock
mock_db = create_autospec(DatabaseConnection)
# 配置方法返回值
mock_db.connect.return_value = True
mock_db.execute_query.return_value = {"rows": [{"id": 1, "name": "John"}]}
# 创建用户仓库
user_repo = UserRepository(mock_db)
# 测试创建用户
user_data = {"name": "John", "email": "john@example.com"}
result = user_repo.create_user(user_data)
# 验证调用
mock_db.execute_query.assert_called_once_with(
"INSERT INTO users (name, email) VALUES (?, ?)",
("John", "john@example.com")
)
# 验证返回值
self.assertEqual(result["rows"][0]["name"], "John")
def test_autospec_with_instance(self):
"""测试基于实例的autospec"""
# 创建真实实例
real_db = DatabaseConnection("localhost", 5432, "testdb")
# 基于实例创建autospec
mock_db = create_autospec(real_db, spec_set=True)
# 配置Mock行为
mock_db.connect.return_value = True
mock_db.execute_query.return_value = {"rows": [], "affected": 1}
user_repo = UserRepository(mock_db)
# 测试更新用户
user_data = {"name": "Jane", "email": "jane@example.com"}
result = user_repo.update_user(1, user_data)
# 验证调用
mock_db.execute_query.assert_called_once_with(
"UPDATE users SET name = ?, email = ? WHERE id = ?",
("Jane", "jane@example.com", 1)
)
def test_spec_with_inheritance(self):
"""测试继承类的spec"""
class ExtendedDatabaseConnection(DatabaseConnection):
def backup_database(self):
"""备份数据库"""
return {"backup_id": "backup_123"}
def restore_database(self, backup_id):
"""恢复数据库"""
return {"status": "restored"}
# 使用扩展类作为spec
mock_db = Mock(spec=ExtendedDatabaseConnection)
# 可以调用基类和扩展类的方法
mock_db.connect()
mock_db.backup_database()
mock_db.restore_database("backup_123")
# 验证方法存在
self.assertTrue(hasattr(mock_db, 'connect'))
self.assertTrue(hasattr(mock_db, 'backup_database'))
self.assertTrue(hasattr(mock_db, 'restore_database'))
class TestAdvancedSpecUsage(unittest.TestCase):
"""高级spec使用测试"""
def test_partial_mock_with_spec(self):
"""测试部分Mock与spec"""
# 创建真实对象
real_db = DatabaseConnection("localhost", 5432, "testdb")
# 只Mock特定方法
with patch.object(real_db, 'execute_query', spec=True) as mock_execute:
mock_execute.return_value = {"rows": [{"id": 1}], "affected": 1}
user_repo = UserRepository(real_db)
result = user_repo.get_user_by_id(1)
# 验证Mock调用
mock_execute.assert_called_once_with(
"SELECT * FROM users WHERE id = ?",
(1,)
)
# 验证返回值
self.assertEqual(result["rows"][0]["id"], 1)
def test_nested_spec(self):
"""测试嵌套对象的spec"""
class DatabaseManager:
def __init__(self):
self.connection = DatabaseConnection("localhost", 5432, "db")
def get_connection(self):
return self.connection
# 创建嵌套Mock
mock_manager = create_autospec(DatabaseManager)
mock_connection = create_autospec(DatabaseConnection)
# 配置嵌套Mock
mock_manager.get_connection.return_value = mock_connection
mock_connection.execute_query.return_value = {"rows": []}
# 测试嵌套调用
connection = mock_manager.get_connection()
result = connection.execute_query("SELECT 1")
# 验证调用
mock_manager.get_connection.assert_called_once()
mock_connection.execute_query.assert_called_once_with("SELECT 1")
def test_spec_with_properties(self):
"""测试带属性的spec"""
class ConfigurableDB(DatabaseConnection):
@property
def connection_string(self):
return f"{self.host}:{self.port}/{self.database}"
@property
def is_connected(self):
return self.connected
@is_connected.setter
def is_connected(self, value):
self.connected = value
# 创建带属性的Mock
mock_db = create_autospec(ConfigurableDB)
# 配置属性
mock_db.connection_string = "localhost:5432/testdb"
mock_db.is_connected = True
# 测试属性访问
self.assertEqual(mock_db.connection_string, "localhost:5432/testdb")
self.assertTrue(mock_db.is_connected)
def test_spec_validation_errors(self):
"""测试spec验证错误"""
mock_db = Mock(spec=DatabaseConnection)
# 正确的方法调用
mock_db.connect()
mock_db.execute_query("SELECT 1")
# 错误的方法名会引发AttributeError
with self.assertRaises(AttributeError):
mock_db.invalid_method()
# 但是可以访问不存在的属性(除非使用spec_set)
mock_db.some_attribute = "value" # 这不会引发错误
# 使用spec_set会更严格
strict_mock = Mock(spec_set=DatabaseConnection)
with self.assertRaises(AttributeError):
strict_mock.invalid_attribute = "value"
高级Mock技巧
复杂场景的Mock策略
import unittest
from unittest.mock import Mock, MagicMock, patch, call, ANY
import asyncio
from datetime import datetime, timedelta
import json
class CacheService:
"""缓存服务"""
def __init__(self, redis_client):
self.redis = redis_client
def get(self, key):
"""获取缓存值"""
return self.redis.get(key)
def set(self, key, value, expire=None):
"""设置缓存值"""
if expire:
return self.redis.setex(key, expire, value)
return self.redis.set(key, value)
def delete(self, key):
"""删除缓存"""
return self.redis.delete(key)
class NotificationService:
"""通知服务"""
def __init__(self, email_client, sms_client, push_client):
self.email = email_client
self.sms = sms_client
self.push = push_client
def send_notification(self, user, message, channels=None):
"""发送通知"""
if channels is None:
channels = ['email']
results = {}
if 'email' in channels:
results['email'] = self.email.send(user.email, message)
if 'sms' in channels:
results['sms'] = self.sms.send(user.phone, message)
if 'push' in channels:
results['push'] = self.push.send(user.device_id, message)
return results
class OrderService:
"""订单服务"""
def __init__(self, db, cache, notification):
self.db = db
self.cache = cache
self.notification = notification
def create_order(self, user, items):
"""创建订单"""
# 计算总价
total = sum(item['price'] * item['quantity'] for item in items)
# 创建订单数据
order_data = {
'user_id': user.id,
'items': items,
'total': total,
'status': 'pending',
'created_at': datetime.now().isoformat()
}
# 保存到数据库
order_id = self.db.insert('orders', order_data)
order_data['id'] = order_id
# 缓存订单
cache_key = f"order:{order_id}"
self.cache.set(cache_key, json.dumps(order_data), expire=3600)
# 发送通知
message = f"订单 {order_id} 已创建,总金额: {total}"
self.notification.send_notification(user, message, ['email', 'push'])
return order_data
def get_order(self, order_id):
"""获取订单"""
# 先从缓存获取
cache_key = f"order:{order_id}"
cached_order = self.cache.get(cache_key)
if cached_order:
return json.loads(cached_order)
# 从数据库获取
order = self.db.select('orders', {'id': order_id})
if order:
# 更新缓存
self.cache.set(cache_key, json.dumps(order), expire=3600)
return order
class TestAdvancedMockTechniques(unittest.TestCase):
"""高级Mock技巧测试"""
def test_mock_call_tracking(self):
"""测试Mock调用跟踪"""
mock_redis = Mock()
cache_service = CacheService(mock_redis)
# 执行多个操作
cache_service.set("key1", "value1")
cache_service.set("key2", "value2", expire=60)
cache_service.get("key1")
cache_service.delete("key1")
# 验证调用顺序
expected_calls = [
call.set("key1", "value1"),
call.setex("key2", 60, "value2"),
call.get("key1"),
call.delete("key1")
]
# 验证所有调用
mock_redis.assert_has_calls(expected_calls, any_order=False)
# 验证特定方法的调用次数
self.assertEqual(mock_redis.set.call_count, 1)
self.assertEqual(mock_redis.setex.call_count, 1)
self.assertEqual(mock_redis.get.call_count, 1)
self.assertEqual(mock_redis.delete.call_count, 1)
def test_mock_with_any_matcher(self):
"""测试使用ANY匹配器"""
mock_db = Mock()
mock_cache = Mock()
mock_notification = Mock()
order_service = OrderService(mock_db, mock_cache, mock_notification)
# 创建用户和订单项
user = Mock()
user.id = 1
user.email = "user@example.com"
user.device_id = "device123"
items = [
{'price': 10.0, 'quantity': 2},
{'price': 5.0, 'quantity': 1}
]
# 配置Mock返回值
mock_db.insert.return_value = 12345
# 创建订单
order_service.create_order(user, items)
# 使用ANY匹配器验证调用
mock_db.insert.assert_called_once_with('orders', ANY)
mock_cache.set.assert_called_once_with(
"order:12345",
ANY, # JSON字符串,具体内容可能变化
expire=3600
)
# 验证通知调用
mock_notification.send_notification.assert_called_once_with(
user,
ANY, # 消息内容包含订单ID和总金额
['email', 'push']
)
def test_mock_side_effect_sequence(self):
"""测试Mock副作用序列"""
mock_cache = Mock()
# 设置get方法的副作用序列
mock_cache.get.side_effect = [
None, # 第一次调用返回None(缓存未命中)
'{"id": 1, "status": "pending"}', # 第二次调用返回缓存数据
'{"id": 1, "status": "confirmed"}' # 第三次调用返回更新后的数据
]
cache_service = CacheService(mock_cache)
# 第一次获取(缓存未命中)
result1 = cache_service.get("order:1")
self.assertIsNone(result1)
# 第二次获取(缓存命中)
result2 = cache_service.get("order:1")
self.assertEqual(result2, '{"id": 1, "status": "pending"}')
# 第三次获取(缓存更新)
result3 = cache_service.get("order:1")
self.assertEqual(result3, '{"id": 1, "status": "confirmed"}')
# 验证调用次数
self.assertEqual(mock_cache.get.call_count, 3)
def test_mock_side_effect_function(self):
"""测试Mock副作用函数"""
mock_db = Mock()
# 模拟数据库数据
orders_data = {
1: {'id': 1, 'status': 'pending', 'total': 100},
2: {'id': 2, 'status': 'confirmed', 'total': 200}
}
def db_select_side_effect(table, conditions):
"""模拟数据库查询"""
if table == 'orders':
order_id = conditions.get('id')
return orders_data.get(order_id)
return None
mock_db.select.side_effect = db_select_side_effect
mock_cache = Mock()
mock_cache.get.return_value = None # 缓存未命中
mock_notification = Mock()
order_service = OrderService(mock_db, mock_cache, mock_notification)
# 测试获取存在的订单
order1 = order_service.get_order(1)
self.assertEqual(order1['id'], 1)
self.assertEqual(order1['status'], 'pending')
# 测试获取不存在的订单
order_none = order_service.get_order(999)
self.assertIsNone(order_none)
# 验证数据库调用
expected_calls = [
call('orders', {'id': 1}),
call('orders', {'id': 999})
]
mock_db.select.assert_has_calls(expected_calls)
def test_mock_exception_handling(self):
"""测试Mock异常处理"""
mock_email = Mock()
mock_sms = Mock()
mock_push = Mock()
# 配置不同的异常
mock_email.send.side_effect = ConnectionError("邮件服务不可用")
mock_sms.send.return_value = {"status": "success"}
mock_push.send.side_effect = TimeoutError("推送服务超时")
notification_service = NotificationService(mock_email, mock_sms, mock_push)
user = Mock()
user.email = "user@example.com"
user.phone = "1234567890"
user.device_id = "device123"
# 测试异常处理
try:
results = notification_service.send_notification(
user,
"测试消息",
['email', 'sms', 'push']
)
# 如果没有异常处理,这里会失败
self.fail("应该抛出异常")
except (ConnectionError, TimeoutError):
# 预期的异常
pass
# 验证调用
mock_email.send.assert_called_once()
mock_sms.send.assert_called_once()
def test_mock_property_access(self):
"""测试Mock属性访问"""
mock_user = Mock()
# 配置属性
mock_user.id = 1
mock_user.email = "user@example.com"
mock_user.phone = "1234567890"
mock_user.device_id = "device123"
# 配置动态属性
type(mock_user).full_name = Mock(return_value="John Doe")
# 测试属性访问
self.assertEqual(mock_user.id, 1)
self.assertEqual(mock_user.email, "user@example.com")
self.assertEqual(mock_user.full_name, "John Doe")
# 测试属性设置
mock_user.status = "active"
self.assertEqual(mock_user.status, "active")
def test_mock_context_manager(self):
"""测试Mock上下文管理器"""
mock_connection = MagicMock()
# 配置上下文管理器
mock_connection.__enter__.return_value = mock_connection
mock_connection.__exit__.return_value = None
# 使用上下文管理器
with mock_connection as conn:
conn.execute("SELECT 1")
conn.commit()
# 验证上下文管理器调用
mock_connection.__enter__.assert_called_once()
mock_connection.__exit__.assert_called_once()
# 验证内部调用
mock_connection.execute.assert_called_once_with("SELECT 1")
mock_connection.commit.assert_called_once()
def test_mock_async_methods(self):
"""测试Mock异步方法"""
async def test_async():
mock_async_client = Mock()
# 配置异步方法
async def async_side_effect(*args, **kwargs):
return {"status": "success", "data": args[0]}
mock_async_client.fetch_data = Mock(side_effect=async_side_effect)
# 测试异步调用
result = await mock_async_client.fetch_data("test_data")
self.assertEqual(result["status"], "success")
self.assertEqual(result["data"], "test_data")
# 验证调用
mock_async_client.fetch_data.assert_called_once_with("test_data")
# 运行异步测试
asyncio.run(test_async())
总结
unittest.mock模块是Python测试中不可或缺的工具,提供了强大而灵活的模拟测试能力。
主要特性
- Mock对象:创建可配置的模拟对象,替代真实依赖
- MagicMock:支持魔术方法的高级Mock对象
- patch装饰器:临时替换模块、类或方法
- spec参数:基于真实对象创建规范化Mock
- side_effect:模拟复杂行为和异常情况
- 调用验证:详细的调用跟踪和验证机制
最佳实践
- 使用spec:通过spec参数确保Mock接口的正确性
- 适度Mock:只Mock必要的依赖,避免过度Mock
- 验证调用:使用assert方法验证Mock的调用情况
- 清晰配置:明确配置Mock的返回值和副作用
- 异常测试:使用side_effect测试异常情况
- 上下文管理:合理使用patch上下文管理器
应用场景
- 外部依赖隔离:数据库、网络服务、文件系统
- 异常情况测试:网络错误、超时、资源不足
- 性能测试:避免真实操作的性能开销
- 并发测试:模拟复杂的并发场景
- 集成测试:替换复杂的外部系统
- 回归测试:确保代码修改不影响现有功能
通过熟练掌握unittest.mock,可以编写更可靠、更快速、更易维护的单元测试。
更多推荐


所有评论(0)