引言

虽然 Serde 的派生宏能够自动生成序列化代码,但在许多实际场景中,我们需要对序列化过程进行精确控制。无论是为了兼容特定的数据格式、优化存储空间、保护敏感信息,还是实现复杂的业务逻辑,自定义序列化逻辑都是不可或缺的能力。Serde 通过灵活的 trait 系统和丰富的辅助方法,让开发者能够在保持类型安全的前提下,完全掌控数据的编码和解码过程。本文将深入探讨自定义序列化的各种技术,从字段级别的微调到完全自定义的实现,展示如何在 Rust 中优雅地处理复杂的序列化需求。

自定义序列化的三个层次

Serde 提供了三个层次的自定义能力,满足不同的需求场景。第一层是属性级别的定制,通过 #[serde(serialize_with = "...")] 等属性,针对特定字段使用自定义函数,这是最轻量级的方式,适合小范围的调整。第二层是 trait 方法的部分重写,保留派生宏生成的框架,但覆盖关键方法的实现,在保持大部分自动化的同时引入特殊处理。第三层是完全手动实现 SerializeDeserialize trait,拥有绝对的控制权,适合需要复杂逻辑或特殊优化的场景。

选择合适的层次是重要的权衡。属性级别的定制开发效率最高,维护成本最低,但灵活性有限;完全手动实现最灵活,但需要处理所有细节,容易出错。在实践中,应该遵循"最小定制"原则:优先使用派生宏和属性,只在必要时才手动实现。这种渐进式的方法让代码既保持了 Serde 的便利性,又能够应对复杂需求。

实践一:字段级别的序列化转换

最常见的自定义需求是转换字段的表示形式。例如,时间戳在内存中使用 SystemTime,但序列化为 Unix 时间戳整数;或者密码哈希在序列化时使用 Base64 编码。这些场景都可以通过 serialize_withdeserialize_with 属性优雅地实现:

use serde::{Serialize, Deserialize, Serializer, Deserializer};
use std::time::{SystemTime, UNIX_EPOCH, Duration};

#[derive(Serialize, Deserialize, Debug)]
struct Event {
    pub id: u64,
    
    #[serde(serialize_with = "serialize_timestamp")]
    #[serde(deserialize_with = "deserialize_timestamp")]
    pub created_at: SystemTime,
    
    #[serde(serialize_with = "serialize_base64")]
    #[serde(deserialize_with = "deserialize_base64")]
    pub signature: Vec<u8>,
    
    // 序列化时转换为小写
    #[serde(serialize_with = "serialize_lowercase")]
    pub event_type: String,
}

fn serialize_timestamp<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let duration = time.duration_since(UNIX_EPOCH)
        .map_err(serde::ser::Error::custom)?;
    serializer.serialize_u64(duration.as_secs())
}

fn deserialize_timestamp<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
where
    D: Deserializer<'de>,
{
    let secs = u64::deserialize(deserializer)?;
    Ok(UNIX_EPOCH + Duration::from_secs(secs))
}

fn serialize_base64<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    use base64::{Engine as _, engine::general_purpose};
    let encoded = general_purpose::STANDARD.encode(bytes);
    serializer.serialize_str(&encoded)
}

fn deserialize_base64<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
    D: Deserializer<'de>,
{
    use base64::{Engine as _, engine::general_purpose};
    let s = String::deserialize(deserializer)?;
    general_purpose::STANDARD.decode(&s)
        .map_err(serde::de::Error::custom)
}

fn serialize_lowercase<S>(s: &str, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    serializer.serialize_str(&s.to_lowercase())
}

// 使用示例
fn example_field_transformation() {
    let event = Event {
        id: 1,
        created_at: SystemTime::now(),
        signature: vec![0xDE, 0xAD, 0xBE, 0xEF],
        event_type: "USER_LOGIN".to_string(),
    };
    
    let json = serde_json::to_string_pretty(&event).unwrap();
    println!("{}", json);
    // 输出:
    // {
    //   "id": 1,
    //   "created_at": 1706140800,
    //   "signature": "3q2+7w==",
    //   "event_type": "user_login"
    // }
}

这个实现展示了几个重要的技术点:错误处理的转换 —— 使用 map_err(serde::ser::Error::custom) 将外部错误转换为 Serde 的错误类型;类型安全的保证 —— 自定义函数的签名确保了类型匹配,编译器会捕获任何不兼容的转换;单向转换的可能性 —— event_type 只有自定义序列化而没有自定义反序列化,反序列化时会直接读取字符串,这在某些场景下很有用。

实践二:条件序列化与动态字段

有时需要根据运行时条件决定是否序列化某些字段,或者动态地添加/移除字段。虽然 Serde 的 skip_serializing_if 提供了基本支持,但更复杂的逻辑需要手动实现:

use serde::{Serialize, Serializer};
use serde::ser::SerializeStruct;
use std::collections::HashMap;

#[derive(Debug)]
struct ApiResponse {
    pub status: String,
    pub data: Option<HashMap<String, String>>,
    pub error: Option<String>,
    pub metadata: HashMap<String, String>,
    pub include_debug: bool,
}

impl Serialize for ApiResponse {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // 动态计算字段数量
        let mut field_count = 1; // status 总是存在
        if self.data.is_some() { field_count += 1; }
        if self.error.is_some() { field_count += 1; }
        if !self.metadata.is_empty() { field_count += 1; }
        if self.include_debug { field_count += 1; }
        
        let mut state = serializer.serialize_struct("ApiResponse", field_count)?;
        
        // 基础字段
        state.serialize_field("status", &self.status)?;
        
        // 条件序列化:只在有数据时包含
        if let Some(ref data) = self.data {
            state.serialize_field("data", data)?;
        }
        
        // 条件序列化:只在有错误时包含
        if let Some(ref error) = self.error {
            state.serialize_field("error", error)?;
        }
        
        // 条件序列化:只在非空时包含
        if !self.metadata.is_empty() {
            state.serialize_field("metadata", &self.metadata)?;
        }
        
        // 调试模式:添加额外信息
        if self.include_debug {
            let debug_info = format!(
                "data_present: {}, error_present: {}, metadata_count: {}",
                self.data.is_some(),
                self.error.is_some(),
                self.metadata.len()
            );
            state.serialize_field("debug_info", &debug_info)?;
        }
        
        state.end()
    }
}

// 使用示例
fn example_conditional_serialization() {
    let success_response = ApiResponse {
        status: "success".to_string(),
        data: Some(HashMap::from([
            ("user_id".to_string(), "123".to_string()),
            ("username".to_string(), "alice".to_string()),
        ])),
        error: None,
        metadata: HashMap::new(),
        include_debug: false,
    };
    
    let error_response = ApiResponse {
        status: "error".to_string(),
        data: None,
        error: Some("User not found".to_string()),
        metadata: HashMap::from([
            ("request_id".to_string(), "req-456".to_string()),
        ]),
        include_debug: true,
    };
    
    println!("Success: {}", serde_json::to_string_pretty(&success_response).unwrap());
    println!("Error: {}", serde_json::to_string_pretty(&error_response).unwrap());
}

这个实现的关键在于动态字段计数serialize_struct 的第二个参数指定了字段数量,必须与实际序列化的字段匹配。通过运行时计算,确保了结构的一致性。这种技术在构建 API 响应、处理可选特性或实现多态序列化时非常有用。

实践三:复杂嵌套结构的扁平化

有时需要将嵌套的结构体序列化为扁平的 JSON 对象,或者反过来将扁平的数据反序列化为嵌套结构。这在对接第三方 API 或优化数据库存储时很常见:

use serde::{Serialize, Deserialize, Serializer, Deserializer};
use serde::ser::SerializeMap;
use serde::de::{self, MapAccess, Visitor};
use std::fmt;

#[derive(Debug)]
struct UserProfile {
    pub user_id: u64,
    pub username: String,
    pub contact: ContactInfo,
    pub preferences: Preferences,
}

#[derive(Debug, Serialize, Deserialize)]
struct ContactInfo {
    pub email: String,
    pub phone: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
struct Preferences {
    pub theme: String,
    pub language: String,
    pub notifications_enabled: bool,
}

// 自定义序列化:将嵌套结构扁平化
impl Serialize for UserProfile {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut map = serializer.serialize_map(None)?;
        
        // 顶层字段
        map.serialize_entry("user_id", &self.user_id)?;
        map.serialize_entry("username", &self.username)?;
        
        // 扁平化 contact 字段
        map.serialize_entry("contact_email", &self.contact.email)?;
        if let Some(ref phone) = self.contact.phone {
            map.serialize_entry("contact_phone", phone)?;
        }
        
        // 扁平化 preferences 字段
        map.serialize_entry("pref_theme", &self.preferences.theme)?;
        map.serialize_entry("pref_language", &self.preferences.language)?;
        map.serialize_entry("pref_notifications", &self.preferences.notifications_enabled)?;
        
        map.end()
    }
}

// 自定义反序列化:从扁平结构重建嵌套
impl<'de> Deserialize<'de> for UserProfile {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct UserProfileVisitor;
        
        impl<'de> Visitor<'de> for UserProfileVisitor {
            type Value = UserProfile;
            
            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a flattened UserProfile")
            }
            
            fn visit_map<V>(self, mut map: V) -> Result<UserProfile, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut user_id = None;
                let mut username = None;
                let mut contact_email = None;
                let mut contact_phone = None;
                let mut pref_theme = None;
                let mut pref_language = None;
                let mut pref_notifications = None;
                
                while let Some(key) = map.next_key::<String>()? {
                    match key.as_str() {
                        "user_id" => user_id = Some(map.next_value()?),
                        "username" => username = Some(map.next_value()?),
                        "contact_email" => contact_email = Some(map.next_value()?),
                        "contact_phone" => contact_phone = Some(map.next_value()?),
                        "pref_theme" => pref_theme = Some(map.next_value()?),
                        "pref_language" => pref_language = Some(map.next_value()?),
                        "pref_notifications" => pref_notifications = Some(map.next_value()?),
                        _ => {
                            // 忽略未知字段以实现向前兼容
                            let _ = map.next_value::<de::IgnoredAny>()?;
                        }
                    }
                }
                
                Ok(UserProfile {
                    user_id: user_id.ok_or_else(|| de::Error::missing_field("user_id"))?,
                    username: username.ok_or_else(|| de::Error::missing_field("username"))?,
                    contact: ContactInfo {
                        email: contact_email.ok_or_else(|| de::Error::missing_field("contact_email"))?,
                        phone: contact_phone,
                    },
                    preferences: Preferences {
                        theme: pref_theme.ok_or_else(|| de::Error::missing_field("pref_theme"))?,
                        language: pref_language.ok_or_else(|| de::Error::missing_field("pref_language"))?,
                        notifications_enabled: pref_notifications
                            .ok_or_else(|| de::Error::missing_field("pref_notifications"))?,
                    },
                })
            }
        }
        
        deserializer.deserialize_map(UserProfileVisitor)
    }
}

// 使用示例
fn example_flattening() {
    let profile = UserProfile {
        user_id: 42,
        username: "alice".to_string(),
        contact: ContactInfo {
            email: "alice@example.com".to_string(),
            phone: Some("+1234567890".to_string()),
        },
        preferences: Preferences {
            theme: "dark".to_string(),
            language: "en".to_string(),
            notifications_enabled: true,
        },
    };
    
    let json = serde_json::to_string_pretty(&profile).unwrap();
    println!("Serialized:\n{}", json);
    
    let deserialized: UserProfile = serde_json::from_str(&json).unwrap();
    println!("Deserialized: {:?}", deserialized);
}

这个实现展示了结构转换的完全控制:序列化时使用 serialize_map 动态添加键值对,反序列化时使用访问者模式逐个读取字段并重建嵌套结构。关键技术包括使用 de::IgnoredAny 忽略未知字段以实现向前兼容,以及通过 ok_or_else 提供清晰的错误消息。

实践四:多态序列化与类型标记

在处理继承或多态数据时,需要在序列化中包含类型信息以便反序列化时恢复正确的类型:

use serde::{Serialize, Deserialize, Serializer, Deserializer};
use serde::ser::SerializeStruct;
use serde::de::{self, MapAccess, Visitor};
use std::fmt;

#[derive(Debug)]
enum Shape {
    Circle { radius: f64 },
    Rectangle { width: f64, height: f64 },
    Triangle { base: f64, height: f64 },
}

impl Serialize for Shape {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            Shape::Circle { radius } => {
                let mut state = serializer.serialize_struct("Shape", 2)?;
                state.serialize_field("type", "circle")?;
                state.serialize_field("radius", radius)?;
                state.end()
            }
            Shape::Rectangle { width, height } => {
                let mut state = serializer.serialize_struct("Shape", 3)?;
                state.serialize_field("type", "rectangle")?;
                state.serialize_field("width", width)?;
                state.serialize_field("height", height)?;
                state.end()
            }
            Shape::Triangle { base, height } => {
                let mut state = serializer.serialize_struct("Shape", 3)?;
                state.serialize_field("type", "triangle")?;
                state.serialize_field("base", base)?;
                state.serialize_field("height", height)?;
                state.end()
            }
        }
    }
}

impl<'de> Deserialize<'de> for Shape {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "lowercase")]
        enum Field {
            Type,
            Radius,
            Width,
            Height,
            Base,
        }
        
        struct ShapeVisitor;
        
        impl<'de> Visitor<'de> for ShapeVisitor {
            type Value = Shape;
            
            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a shape object with type field")
            }
            
            fn visit_map<V>(self, mut map: V) -> Result<Shape, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut shape_type: Option<String> = None;
                let mut radius = None;
                let mut width = None;
                let mut height = None;
                let mut base = None;
                
                while let Some(key) = map.next_key()? {
                    match key {
                        Field::Type => {
                            shape_type = Some(map.next_value()?);
                        }
                        Field::Radius => radius = Some(map.next_value()?),
                        Field::Width => width = Some(map.next_value()?),
                        Field::Height => height = Some(map.next_value()?),
                        Field::Base => base = Some(map.next_value()?),
                    }
                }
                
                let shape_type = shape_type.ok_or_else(|| de::Error::missing_field("type"))?;
                
                match shape_type.as_str() {
                    "circle" => {
                        let radius = radius.ok_or_else(|| de::Error::missing_field("radius"))?;
                        Ok(Shape::Circle { radius })
                    }
                    "rectangle" => {
                        let width = width.ok_or_else(|| de::Error::missing_field("width"))?;
                        let height = height.ok_or_else(|| de::Error::missing_field("height"))?;
                        Ok(Shape::Rectangle { width, height })
                    }
                    "triangle" => {
                        let base = base.ok_or_else(|| de::Error::missing_field("base"))?;
                        let height = height.ok_or_else(|| de::Error::missing_field("height"))?;
                        Ok(Shape::Triangle { base, height })
                    }
                    _ => Err(de::Error::unknown_variant(&shape_type, &["circle", "rectangle", "triangle"])),
                }
            }
        }
        
        deserializer.deserialize_map(ShapeVisitor)
    }
}

fn example_polymorphic() {
    let shapes = vec![
        Shape::Circle { radius: 5.0 },
        Shape::Rectangle { width: 10.0, height: 20.0 },
        Shape::Triangle { base: 8.0, height: 12.0 },
    ];
    
    for shape in &shapes {
        let json = serde_json::to_string(shape).unwrap();
        println!("Serialized: {}", json);
        
        let deserialized: Shape = serde_json::from_str(&json).unwrap();
        println!("Deserialized: {:?}\n", deserialized);
    }
}

这种标记联合类型(Tagged Union)的实现是处理多态数据的标准模式。通过显式的 type 字段标识变体,反序列化器能够正确恢复原始类型。这种方法在跨语言通信、版本演化和动态类型系统中特别有用。

深层思考:性能优化与零拷贝

自定义序列化不仅是为了功能性,也可以用于性能优化。一个关键技术是零拷贝反序列化,通过借用输入数据而非复制来减少内存分配。这需要正确使用生命周期参数 'de,并确保反序列化的类型能够借用数据。例如,使用 &'de str 而非 String,使用 &'de [u8] 而非 Vec<u8>

另一个优化是流式处理,对于大型数据集,避免一次性加载整个结构到内存。通过实现自定义的 Deserializer,可以边解析边处理数据,只在内存中保留当前处理的部分。这在处理日志文件、大型配置或网络流时能够显著降低内存使用。

总结

自定义序列化逻辑是 Serde 强大而灵活的核心能力。通过字段级别的转换函数、条件序列化、结构扁平化、多态处理等技术,你可以精确控制数据的编码和解码过程,满足各种复杂的业务需求。关键是理解 Serde 的 trait 架构,掌握访问者模式的应用,以及善用生命周期参数实现性能优化。在实践中,应该遵循渐进式定制的原则,在保持代码可维护性的同时获得必要的控制力。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐