一. 架构与实现:总览设计框架,深入源码细节

SGI-STL30版本源代码,map和set的源代码在map/set/stl_map.h/stl_set.h/stl_tree.h等几个头文件中。map和set的实现框架核心部分截取下来如下:

// set
#ifndef __SGI_STL_INTERNAL_TREE_H
#include <stl_tree.h>
#endif
#include <stl_set.h>
#include <stl_multiset.h>

// map
#ifndef __SGI_STL_INTERNAL_TREE_H
#include <stl_tree.h>
#endif
#include <stl_map.h>
#include <stl_multimap.h>

// stl_set.h
template <class Key, class Compare = less<Key>, class Alloc = alloc>
class set {
public:
	// typedefs:
	typedef Key key_type;
	typedef Key value_type;
private:
	typedef rb_tree<key_type, value_type,
		identity<value_type>, key_compare, Alloc> rep_type;
	rep_type t; // red-black tree representing set
};

// stl_map.h
template <class Key, class T, class Compare = less<Key>, class Alloc = alloc>
class map {
public:
	// typedefs:
	typedef Key key_type;
	typedef T mapped_type;
	typedef pair<const Key, T> value_type;
private:
	typedef rb_tree<key_type, value_type,
		select1st<value_type>, key_compare, Alloc> rep_type;
	rep_type t; // red-black tree representing map
};

// stl_tree.h
struct __rb_tree_node_base
{
	typedef __rb_tree_color_type color_type;
	typedef __rb_tree_node_base* base_ptr;
	color_type color;
	base_ptr parent;
	base_ptr left;
	base_ptr right;
};

// stl_tree.h
template <class Key, class Value, class KeyOfValue, class Compare, class Alloc
	= alloc>
class rb_tree {
protected:
	typedef void* void_pointer;
	typedef __rb_tree_node_base* base_ptr;
	typedef __rb_tree_node<Value> rb_tree_node;
	typedef rb_tree_node* link_type;
	typedef Key key_type;
	typedef Value value_type;
public:
	// insert⽤的是第⼆个模板参数左形参
	pair<iterator, bool> insert_unique(const value_type& x);
	// erase和find⽤第⼀个模板参数做形参
	size_type erase(const key_type& x);
	iterator find(const key_type& x);
protected:
	size_type node_count; // keeps track of size of tree
	link_type header;
};
template <class Value>
struct __rb_tree_node : public __rb_tree_node_base
{
	typedef __rb_tree_node<Value>* link_type;
	Value value_field;
};
  • 通过下图对框架的分析,我们可以看到源码中 rb_tree 用了一个巧妙的泛型思想实现,rb_tree 实现key的搜索场景,还是key/value的搜索场景不是直接写死的,而是由第二个模板参数Value决定 _rb_tree_node 中存储的数据类型。
  • set实例化 rb_tree 时第二个模板参数给的是Key,map实例化 rb_tree 时第二个模板参数给的时pair<const key,T>,这样一颗红黑树既可以实现key搜索场景,也可以实现key/value搜索场景的map。
  • rb_tree 第二个模板参数Value已经控制了红黑树结点中的存储的数据类型,为什么还要传第一个模板参数Key呢?尤其是set,两个模板参数是一样的,这是很多同学这时的一个疑问。要注意的是对于map和set,find/erase时的函数参数都是Key,所以第一个模板参数是传给find/erase等函数做形参的类型的。对于set而言两个参数是一样的,但是对于map而言就完全不一样了,map insert的是 pair对象,但是find和ease的Key对象。
  • 吐槽一下,这里源码命名风格比较乱,set模板参数用的Key命令,map用的是Key和T命名,而 rb_tree用的又是Key和Value,可见大佬有时写代码也不规范,乱弹琴。

二. 核心设计思路:红黑树的泛型复用

STL 中 map 和 set 复用同一颗红黑树的核心是泛型编程 + 仿函数提取 key,解决了 “一颗树适配两种数据场景” 的问题,具体设计思路如下:

2.1 红黑树的模板参数设计

红黑树需要支持存储两种数据类型:

  • set 场景:存储单个 key(如intstring);
  • map 场景:存储pair<const Key, Value>(key 不可修改)。

因此红黑树的模板参数需抽象为 3 个:

template<class K, class T, class KeyOfT>
class RBTree {
    // K:find/erase时的key类型(统一接口参数)
    // T:红黑树节点存储的实际数据类型(set为 K,map为 pair<const K, V>)
    // KeyOfT:仿函数,从T中提取K(解决T类型不统一的比较问题)
};

2.2 仿函数 KeyOfT:统一 key 提取逻辑

由于 T 的类型不固定(K 或 pair),红黑树插入 / 查找时无法直接获取 key,需通过仿函数KeyOfT统一提取,由 map 和 set 分别实现适配:

  • set 的仿函数:直接返回 key(T=K);
  • map 的仿函数:返回 pair 的 first 成员(T=pair<const K, V>)。

2.3 核心约束:key 不可修改

  • set 的 key 是唯一标识,需禁止修改:红黑树存储const K
  • map 的 key 是索引,需禁止修改:pair 的 first 设为const K,value 可正常修改。

三. 基础组件实现:红黑树与仿函数

3.1 红黑树节点结构

节点存储模板类型 T,包含左右子指针、父指针和颜色标记:

#pragma once
#include<iostream>
#include<assert.h>
using namespace std;


// 枚举结点颜色
enum Colour
{
	Red,  // 红色结点
	Black // 黑色结点
};

// 红黑树结构
template<class T>
struct RBTreeNode
{
	T _data;				 //存储实际数据(K或pair<const K, V>)
	RBTreeNode<T>* _parent;  // 左子节点指针
	RBTreeNode<T>* _left;    // 右子节点指针
	RBTreeNode<T>* _right;   // 父节点指针(回溯平衡需用到)
	Colour _col;			 // 节点颜色

	RBTreeNode(const T& data)
		:_parent(nullptr)
		, _left(nullptr)
		, _right(nullptr)
		, _data(data)
		, _col(Red) // 非空树插入时设为红色,避免破坏规则4
	{}
};

3.2 仿函数实现(map/set 层)

3.2.1 set 的仿函数:直接返回 key
#pragma once
#include"RBTree.h"

namespace Scy
{
	template<class K>
	class set
	{
		// 仿函数:从T(const K)中提取key
		struct SetKeyofT
		{
			const K& operator() (const K& key)
			{
				return key;
			}
		};
	private:
		// 红黑树:存储const K,禁止修改
		RBTree<K, const K, SetKeyofT> _t;
	};
}
3.2.2 map 的仿函数:提取 pair 的 first
#pragma once
#include"RBTree.h"

template<class K, class V>
class map
{
	struct MapKeyofT
	{
		// 仿函数:从T(pair<const K, V>)中提取key
		const K& operator() (const pair<K,V>& kv)
		{
			return kv.first;
		}
	};
private:
	// 红黑树:存储pair<const K, V>,key不可修改
	RBTree<K, pair<const K, V>, MapKeyofT> _t;
};

3.3 红黑树核心接口(附迭代器)

重点实现Insert(返回pair<Iterator, bool>,支持 map 的 [])和Find,平衡维护逻辑与基础红黑树一致:包括迭代器,operator++这里实现一下,- - 的话就不展示了,要实现的话还需要额外带一个_root;

#pragma once
#include<iostream>
#include<assert.h>
using namespace std;


// 枚举结点颜色
enum Colour
{
	Red,  // 红色结点
	Black // 黑色结点
};

// 红黑树结构
template<class T>
struct RBTreeNode
{
	T _data;				 //存储实际数据(K或pair<const K, V>)
	RBTreeNode<T>* _parent;  // 左子节点指针
	RBTreeNode<T>* _left;    // 右子节点指针
	RBTreeNode<T>* _right;   // 父节点指针(回溯平衡需用到)
	Colour _col;			 // 节点颜色

	RBTreeNode(const T& data)
		:_parent(nullptr)
		, _left(nullptr)
		, _right(nullptr)
		, _data(data)
		, _col(Red) // 非空树插入时设为红色,避免破坏规则4
	{}
};


template<class T,class Ref,class Ptr>
struct RBTreeIterator
{
	typedef RBTreeNode<T> Node;
	typedef RBTreeIterator<T, Ref, Ptr> Self;

	Node* _node;

	RBTreeIterator(Node* node)
		:_node(node)
	{}

	Self& operator++()
	{
		if (_node->_right)
		{
			Node* minRight = _node->_right;
			while (minRight->_left)
			{
				minRight = minRight->_left;
			}
			_node = minRight;
		}
		else
		{
			Node* cur = _node;
			Node* parent = cur->_parent;
			while (parent && cur == parent->_right)
			{
				cur = parent;
				parent = parent->_parent;
			}
			_node = parent;
		}
		return *this;
	}
	Ref operator* ()
	{
		return _node->_data;
	}

	Ptr operator->()
	{
		return &(_node->_data);
	}

	bool operator!=(const Self& s) const
	{
		return _node != s._node;
	}

	bool operator-(const Self& s) const
	{
		return _node == s._node;
	}
};

// RBTree<K, pair<K, V>> _t;-> // map
// RBTree<K, K> _t;->		   // set

template<class K, class T,class KeyofT>
class RBTree
{
	typedef RBTreeNode<T> Node;
public:
	typedef RBTreeIterator<T, T&, T*> Iterator;
	typedef RBTreeIterator<T, const T&, const T*> ConstIterator;
	
	~RBTree()
	{
		Destory(_root);
		_root = nullptr;
	}

	void Destory(Node* root)
	{
		if (root == nullptr)
			return;
		Destory(root->_left);
		Destory(root->_right);
		delete root;
	}

	Iterator Begin()
	{
		Node* minLeft = _root;
		while (minLeft && minLeft->_left)
		{
			minLeft = minLeft->_left;
		}
		return Iterator(minLeft);
	}

	Iterator End()
	{
		return Iterator(nullptr);
	}

	ConstIterator Begin() const
	{
		Node* minLeft = _root;
		while (minLeft && minLeft->_left)
		{
			minLeft = minLeft->_left;
		}
		return ConstIterator(minLeft);
	}

	ConstIterator End() const
	{
		return ConstIterator(nullptr);
	}
	
	// 插入接口:返回pair<迭代器, bool>(bool标记是否插入成功)
	pair<Iterator,bool> Insert(const T& data)
	{
		if (_root == nullptr)
		{
			_root = new Node(data);
			_root->_col = Black;
			return { Iterator(_root),true };
		}

		KeyofT kot;
		Node* parent = nullptr;
		Node* cur = _root;

		while (cur)
		{
			if (kot(cur->_data) < kot(data))
			{
				parent = cur;
				cur = cur->_right;
			}
			//else if (kot(cur->_data) > kot(data))
			else if (kot(data) < kot(cur->_data))
			{
				parent = cur;
				cur = cur->_left;
			}
			else
			{
				return {Iterator(cur),false};
			}
		}

		cur = new Node(data);
		Node* newnode = cur;
		cur->_col = Red;
		if (kot(parent->_data) < kot(data))
		{
			parent->_right = cur;
		}
		else
		{
			parent->_left = cur;
		}
		cur->_parent = parent;

		while (parent && parent->_col == Red)
		{
			Node* grandparent = parent->_parent;
			if (grandparent->_left == parent)
			{
				Node* uncle = grandparent->_right;
				// uncle存在且为红色
				if (uncle && uncle->_col == Red)
				{
					// 变色+继续向上处理
					parent->_col = Black;
					uncle->_col = Black;
					grandparent->_col = Red;

					cur = grandparent;
					parent = cur->_parent;
				}
				else //uncle不存在或者存在且为黑色 
				{
					if (cur == parent->_left) // 单旋+变色
					{
						//   g
						// p   u
						//c
						RotateR(grandparent);
						parent->_col = Black;
						grandparent->_col = Red;
					}
					else // 双旋+变色
					{
						//   g
						// p   u
						//  c
						RotateL(parent);
						RotateR(grandparent);

						cur->_col = Black;
						grandparent->_col = Red;
					}
					break;
				}
			}

			else
			{
				Node* uncle = grandparent->_left;
				if (uncle && uncle->_col == Red)
				{
					// 变色+继续向上处理
					uncle->_col = Black;
					parent->_col = Black;
					grandparent->_col = Red;

					cur = grandparent;
					parent = cur->_parent;
				}
				else
				{
					if (parent->_right == cur) // 单旋+变色
					{
						//   g
						// u   p
						//      c
						RotateL(grandparent);

						parent->_col = Black;
						grandparent->_col = Red;
					}
					else // 双旋+变色
					{
						//   g
						// u   p
						//   c

						RotateR(parent);
						RotateL(grandparent);

						cur->_col = Black;
						grandparent->_col = Red;
					}

					break;
				}
			}
		}
		// 确保根节点始终为黑色(防止回溯时根被设为红色)
		_root->_col = Black;

		return {Iterator(newnode),true};
	}

	// 查找接口:按K查找,返回迭代器
	Iterator* Find(const K& key) 
	{
		KeyofT kot;
		Node* cur = _root;
		while (cur) {
			if (kot(cur->_data) < key) {
				cur = cur->_right;
			}
			else if (kot(cur->_data) > key) {
				cur = cur->_left;
			}
			else {
				return Iterator(cur);  // 找到,返回节点指针
			}
		}
		return End();  // 未找到
	}

private:
	void RotateR(Node* parent)
	{
		Node* subL = parent->_left;
		Node* subLR = subL->_right;

		parent->_left = subLR;
		if (subLR)
			subLR->_parent = parent;

		Node* grandparent = parent->_parent;
		subL->_right = parent;
		parent->_parent = subL;

		if (parent == _root)
		{
			_root = subL;
			subL->_parent = nullptr;
		}
		else {
			if (grandparent->_left == parent)
				grandparent->_left = subL;
			else
				grandparent->_right = subL;

			subL->_parent = grandparent;
		}
	}

	void RotateL(Node* parent)
	{
		Node* subR = parent->_right;
		Node* subRL = subR->_left;

		parent->_right = subRL;
		if (subRL)
			subRL->_parent = parent;

		Node* grandparent = parent->_parent;
		subR->_left = parent;
		parent->_parent = subR;

		if (_root == parent)
		{
			_root = subR;
			subR->_parent = nullptr;
		}
		else
		{
			if (grandparent->_left == parent)
				grandparent->_left = subR;
			else
				grandparent->_right = subR;

			subR->_parent = grandparent;
		}
	}
private:
	Node* _root = nullptr;
};

3.4 iterator 实现思路分析:

  • iterator 实现的大框架跟list的iterator思路是一致的,用一个类型封装结点的指针,再通过重载运算符实现,迭代器像指针一样访问的行为。
  • 这里的难点是operator++和operator–的实现,之前使用部分,我们分析了,map和set的迭代器走的是中序遍历,左子树->根结点->右子树,那么begin()会返回中序第一个结点的iterator也就是10所在结点的迭代器。
  • 迭代器++的核心逻辑就是不看全局,只看局部,只考虑当前中序局部要访问的下一个结点。
  • 迭代器++时,如果it指向的结点的右子树不为空,代表当前结点已经访问完了,要访问下一个结点是右子树的中序第一个,一棵树中序第一个是最左结点,所以直接找右子树的最左结点即可。
  • 迭代器++时,如果it指向的结点的右子树空,代表当前结点已经访问完了且当前结点所在的子树也访问完了,要访问的下一个结点在当前结点的祖先里面,所以要沿着当前结点到根的祖先路径向上找。
  • 如果当前结点是父亲的左,根据中序左子树->根结点->右子树,那么下一个访问的结点就是当前结点的父亲;如下图:it指向25,25右为空,25是30的左,所以下一个访问的结点就是30.
  • 如果当前结点时父亲的右,根据中序左子树->根结点->右子树,当前结点所在的子树访问完了,当前结点所在父亲的子树也已经访问完了,那么下一个访问的需要继续往根的祖先中去找,直到找到孩子是父亲左的那个祖先就是中序要走的下一个结点。如下图:it指向15,15为空,15是10的右,15所在子树访问完了,10所在的子树也访问完了,继续往上找,10是18的左,那么下一个访问的结点就是18.
  • end()如何表示呢?如下图:当it指向50时,++it时,50是40的右,40是30的右,30是18的右,18到根没有父亲,没有找到孩子是父亲左的那个祖先,这时父亲为空了,那么我们就把it 中的结点指针置为nullptr,我们用去充当end。需要注意的是stl源空,红黑树增加了一个哨兵位头结点做为end(),这哨兵位头结点和根互为父亲,左指向最左结点,右指向最右结点。相比我们用nullptr作为end(),差别不大,他能实现的,我们也能实现。只是–end()判断到结点是空,特殊处理一下,让迭代器结点指向最右结点。具体参考迭代器一个个实现。
  • 迭代器–的实现跟++的思路完全类似,逻辑正好反过来即可,因为他访问顺序是右子树->根结点->左子树。但是需要一个_root
  • set的iterator也不支持修改,我们把set的第⼆个模板参数改成const K即可, RBTree<K,const K, SetKeyOfT> _t;
  • map的iterator不支持修改key但是可以修改value,我们把map的第二个模板参数pair的第⼀个参数改成const K即可, RBTree<K, pair<const K, V>, MapKeyOfT> _t;
  • 支持完整的迭代器还有很多细节需要修改,具体参考上面的代码。

别的实现方式:大家可以自己看看STL源码剖析。

四. mySet 与 myMap 完整实现

map支持[]主要修改insert返回值支持,修改RBTree中的insert返回值为 pair<Iterator,bool> Insert(const T& data)

4.1 mySet 实现

#pragma once
#include"RBTree.h"

namespace Scy
{
	template<class K>
	class set
	{
		// 仿函数:从T(const K)中提取key
		struct SetKeyofT
		{
			const K& operator() (const K& key)
			{
				return key;
			}
		};
	public:
		// typename 是为了防止这里没实例化报错
		typedef typename RBTree<K, const K, SetKeyofT>::Iterator iterator;
		typedef typename RBTree<K, const K, SetKeyofT>::ConstIterator const_iterator;

		iterator begin()
		{
			return _t.Begin();
		}
		
		iterator end()
		{
			return _t.End();
		}


		const_iterator begin() const
		{
			return _t.Begin();
		}

		const_iterator end() const
		{
			return _t.End();
		}
		
		pair<iterator, bool> insert(const K& key)
		{
			return _t.Insert(key);
		}

		iterator find(const K& key)
		{
			return _t.Find(key);
		}
	private:
		// 红黑树:存储const K,禁止修改
		RBTree<K, const K, SetKeyofT> _t;
	};
}

4.1 myMap 实现

#pragma once
#include"RBTree.h"

namespace Scy
{
	template<class K, class V>
	class map
	{
		struct MapKeyofT
		{
			// 仿函数:从T(pair<const K, V>)中提取key
			const K& operator() (const pair<K, V>& kv)
			{
				return kv.first;
			}
		};
	public:
		typedef typename RBTree<K, pair<const K, V>, MapKeyofT>::Iterator iterator;
		typedef typename RBTree<K, pair<const K, V>, MapKeyofT>::ConstIterator const_iterator;

		iterator begin()
		{
			return _t.Begin();
		}

		iterator end()
		{
			return _t.End();
		}


		const_iterator begin() const
		{
			return _t.Begin();
		}

		const_iterator end() const
		{
			return _t.End();
		}

		pair<iterator, bool> insert(const pair<K, V>& kv)
		{
			return _t.Insert(kv);
		}

		iterator find(const pair<K, V>& kv)
		{
			return _t.Find(kv);
		}
 	  // []运算符:支持插入+访问/修改value
		V& operator[](const K& key)
		{
			//pair<iterator, bool> ret = _t.Insert({ key,V() });
			auto [it, flag] = _t.Insert({ key,V() });
			return it->second;
		}
	private:
		// 红黑树:存储pair<const K, V>,key不可修改
		RBTree<K, pair<const K, V>, MapKeyofT> _t;
	};
}

五. 测试代码:验证功能

#define _CRT_SECURE_NO_WARNINGS 1
#include"RBTree.h"

#include"Map.h"
#include"Set.h"

template<class T>
void func(const Lotso::set<T>& s)
{
	typename Scy::set<T>::const_iterator it = s.begin();
	while (it != s.end())
	{
		//*it = 1;
		cout << *it << " ";
		++it;
	}
	cout << endl;
}

void test_set()
{
	Scy::set<int> s;
	s.insert(1);
	s.insert(2);
	s.insert(1);
	s.insert(5);
	s.insert(0);
	s.insert(10);
	s.insert(8);

	Scy::set<int>::iterator it = s.begin();
	// *it += 10;
	while (it != s.end())
	{
		cout << *it << " ";
		++it;
	}
	cout << endl;

	func(s);
}

void test_map()
{
	Scy::map<string, string> dict;
	dict.insert({ "sort", "排序" });
	dict.insert({ "left", "左边" });
	dict.insert({ "right", "右边" });

	dict["string"] = "字符串"; // 插入+修改
	dict["left"] = "左边xxx";  // 修改

	auto it = dict.begin();
	while (it != dict.end())
	{
		// it->first += 'x'; // 不能修改
		it->second += 'x';

		cout << it->first << ":" << it->second << endl;
		++it;
	}
	cout << endl;

	for (auto& [k, v] : dict)
	{
		cout << k << ":" << v << endl;
	}
	cout << endl;

	string arr[] = { "苹果", "西瓜", "苹果", "西瓜", "苹果", "苹果", "西瓜", "苹果", "香蕉", "苹果", "香蕉" };
	Lotso::map<string, int> countMap;
	for (auto& e : arr)
	{
		/*auto it = countMap.find(e);
		if (it != countMap.end())
		{
			it->second++;
		}
		else
		{
			countMap.insert({ e, 1 });
		}*/
		countMap[e]++;
	}

	for (auto& [k, v] : countMap)
	{
		cout << k << ":" << v << endl;
	}
	cout << endl;
}

int main()
{
	cout << "测试set:" << endl;
	test_set();
	cout << "------------------" << endl;
	cout << "测试map:" << endl;
	test_map();

	return 0;
}
  • 测试没有问题,可以正常使用

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐