Java AI 之 DJL 实战(第 3 篇):数据加载与预处理
在深度学习模型的训练过程中,数据加载和预处理是至关重要的步骤,直接影响模型的性能和训练效率。DJL 数据处理核心组件DJL 数据处理工作流程DJL 的数处理流程遵循 “数据源定义 → 数据集封装 → 预处理流水线 → 迭代器加载” 的核心逻辑,具体步骤如下:DJL 的是所有数据集的基类,定义了数据集的基本行为(如获取样本数、获取单个样本)。常用实现类:DJL 的核心数据结构,替代 DL4J 的,提
数据加载与预处理
在深度学习模型的训练过程中,数据加载和预处理是至关重要的步骤,直接影响模型的性能和训练效率。
DJL 数据处理核心组件
| DJL 组件 | 功能描述 | 对应 DL4J/DataVec 组件 |
|---|---|---|
Dataset接口 |
定义数据集的核心规范,所有自定义数据集需实现该接口 | 通用数据源定义 |
RandomAccessDataset |
支持随机访问的数据集(如 CSV、图像数据集) | RecordReader + FileSplit |
Transform接口 |
定义数据预处理的单步操作(如归一化、编码、增强) | TransformProcess |
Pipeline |
组合多个 Transform,实现预处理流程的链式执行 | TransformProcess |
NDArray |
核心数据结构,承载张量数据并提供标准化、归一化、One-Hot 编码等操作 | DataNormalization + 各类 Transform |
ImageTransform |
图像专用预处理工具(旋转、缩放、裁剪、数据增强) | DataAugmentation |
DJL 数据处理工作流程
DJL 的数处理流程遵循 “数据源定义 → 数据集封装 → 预处理流水线 → 迭代器加载” 的核心逻辑,具体步骤如下:
- 定义数据源:指定数据文件路径 / 存储位置、数据格式;
- 构建数据集:基于
RandomAccessDataset实现自定义数据集,读取原始数据; - 定义预处理流水线(Pipeline):组合多个
Transform,实现数据清洗、标准化、编码等操作; - 执行预处理:通过
Pipeline自动对数据进行批量预处理; - 加载数据到模型:通过
DatasetIterator将预处理后的数据批量输入模型训练 / 评估。
DJL 数据处理核心概念和架构
核心组件详解
Dataset 接口
DJL 的Dataset是所有数据集的基类,定义了数据集的基本行为(如获取样本数、获取单个样本)。常用实现类:
RandomAccessDataset:支持随机访问的数据集(推荐用于 CSV、图像等本地文件数据);ArrayDataset:基于内存数组的轻量级数据集(适合小批量数据);ImageFolder:图像分类专用数据集(自动按文件夹分类加载图像)。
Transform 与 Pipeline
Transform:单步预处理操作的接口,自定义预处理需实现transform(NDList)方法;Pipeline:将多个Transform按顺序组合,形成完整的预处理流水线,支持批量数据的自动处理。
NDArray
DJL 的核心数据结构,替代 DL4J 的DataSet,提供以下关键预处理能力:
- 数值标准化 / 归一化(
sub()/div()/mean()/std()); - 类别特征编码(
oneHot()); - 数据类型转换(
toType()); - 缺失值处理(
where()/fill())。
数据迭代器
DatasetIterator:将数据集转换为批量迭代器,支持按批次加载数据到模型;Batch:封装单批次的特征和标签数据,直接输入模型训练。
加载和预处理 CSV 数据
准备 CSV 样例数据
在resources目录下创建data.csv:
id,name,age,gender,income
1,Alice,34,F,50000
2,Bob,45,M,60000
3,Charlie,23,M,45000
4,Diana,56,F,70000
5,Eva,29,F,55000
6,Frank,38,M,65000
7,Grace,42,F,58000
8,Henry,51,M,72000
9,Ivy,26,F,48000
10,Jack,33,M,62000
数据字段说明:
- id: 唯一标识符(整型);
- name: 姓名(字符串);
- age: 年龄(整型);
- gender: 性别(F/M,类别型);
- income: 收入(浮点型)。
加载 CSV 数据
在pom.xml中添加 DJL 核心依赖和 CSV 解析依赖:
<!-- CSV解析依赖 -->
<dependency>
<groupId>com.opencsv</groupId>
<artifactId>opencsv</artifactId>
<version>5.6</version>
</dependency>
自定义 CSV 数据集加载数据
package com.woniuxy.base.load;
import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.dataset.Record;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import ai.djl.util.Progress;
import com.opencsv.CSVReader;
import com.opencsv.exceptions.CsvException;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.List;
/**
* 自定义CSV数据集
*/
public class CSVDataset extends RandomAccessDataset {
// 存储CSV解析后的数据
private List<String[]> csvData;
// 数据集实际大小(对应availableSize())
private long dataSize;
/**
* 构建器
*/
public static class Builder extends BaseBuilder<Builder> {
private NDManager manager;
public Builder setManager(NDManager manager) {
this.manager = manager;
return self();
}
@Override
protected Builder self() {
return this;
}
public CSVDataset build() throws IOException, URISyntaxException, CsvException {
return new CSVDataset(this);
}
}
/**
* 私有化构造器,通过Builder初始化(
*/
private CSVDataset(Builder builder) throws IOException, URISyntaxException, CsvException {
super(builder); // 必须调用父类构造器初始化采样器、Pipeline等参数
// 1. 加载CSV文件
URL resourceUrl = getClass().getClassLoader().getResource("data.csv");
if (resourceUrl == null) {
throw new IOException("CSV文件 data.csv 未找到");
}
File csvFile = Paths.get(resourceUrl.toURI()).toFile();
// 2. 解析CSV
try (CSVReader reader = new CSVReader(new FileReader(csvFile))) {
List<String[]> allLines = reader.readAll();
// 跳过表头(第一行)
this.csvData = allLines.subList(1, allLines.size());
this.dataSize = csvData.size();
}
}
/**
* 核心抽象方法:获取数据集实际大小
*/
@Override
protected long availableSize() {
return dataSize;
}
/**
* 核心抽象方法:读取单个样本(必须声明throws IOException + 适配Record构造器)
*/
@Override
public Record get(NDManager manager, long index) throws IOException {
// 校验索引合法性
if (index < 0 || index >= dataSize) {
throw new IndexOutOfBoundsException("CSV索引超出范围: " + index + ", 数据集大小: " + dataSize);
}
String[] line = csvData.get((int) index);
// 解析字段:id(0), name(1), age(2), gender(3), income(4)
int id = Integer.parseInt(line[0]);
int age = Integer.parseInt(line[2]);
String gender = line[3];
double income = Double.parseDouble(line[4]);
// 将数据转换为NDArray(特征:id, age, genderCode, income)
// 手动编码性别:F→0,M→1
int genderCode = "F".equals(gender) ? 0 : 1;
NDArray featuresArray = manager.create(new float[]{id, age, genderCode, (float) income});
// 将单个NDArray封装为NDList
NDList features = new NDList(featuresArray);
// 标签:暂设为0(封装为NDList),可根据业务需求修改(如预测收入则设为income)
NDArray labelArray = manager.create(0);
NDList labels = new NDList(labelArray);
// 返回Record
return new Record(features, labels);
}
/**
* 可选:数据集预处理准备(如加载数据、初始化资源)
*/
@Override
public void prepare(Progress progress) {
// 此处可添加数据加载进度展示等逻辑
System.out.println("CSV数据集加载完成,共" + dataSize + "条样本");
}
/**
* 测试CSV加载
*/
public static void main(String[] args) throws Exception {
try (NDManager manager = NDManager.newBaseManager()) {
// 构建数据集(设置采样器:批量大小2,顺序读取)
CSVDataset dataset = new CSVDataset.Builder()
.setManager(manager)
.setSampling(2, false) // 批量大小2,非随机采样
.optDataBatchifier(Batchifier.STACK) // 数据批处理方式
.optLabelBatchifier(Batchifier.STACK) // 标签批处理方式
.optDevice(Device.cpu()) // 指定CPU设备
.build();
// 准备数据集
dataset.prepare(null);
// 遍历所有样本
for (long i = 0; i < dataset.size(); i++) {
Record record = dataset.get(manager, i);
// 读取特征(NDList的第一个元素)
NDArray feature = record.getData().get(0);
// 读取标签(NDList的第一个元素)
NDArray label = record.getLabels().get(0);
System.out.println("样本" + (i + 1) + " - 特征:" + feature + " | 标签:" + label);
}
}
}
}
运行结果
CSV数据集加载完成,共4条样本
样本1 - 特征:ND: (4) cpu() float32
[ 1.00000000e+00, 3.40000000e+01, 0.00000000e+00, 5.00000000e+04]
| 标签:ND: () cpu() int32
0
样本2 - 特征:ND: (4) cpu() float32
[ 2.00000000e+00, 4.50000000e+01, 1.00000000e+00, 6.00000000e+04]
| 标签:ND: () cpu() int32
0
样本3 - 特征:ND: (4) cpu() float32
[ 3.00000000e+00, 2.30000000e+01, 1.00000000e+00, 4.50000000e+04]
| 标签:ND: () cpu() int32
0
样本4 - 特征:ND: (4) cpu() float32
[ 4.00000000e+00, 5.60000000e+01, 0.00000000e+00, 7.00000000e+04]
| 标签:ND: () cpu() int32
0
数据预处理
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import java.io.IOException;
import java.io.Reader;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
public class SimpleCSVDemo {
public static void main(String[] args) throws URISyntaxException, IOException {
// 1. 模拟CSV数据(不用真的读文件)
List<Person> people = createSampleData();
for (Person p : people) {
System.out.println(p);
}
// 2. 创建NDManager(理解成"数据管理器")
try (NDManager manager = NDManager.newBaseManager()) {
// 3. 预处理:把文字变成数字
System.out.println("\n2. 预处理:文字变数字");
SimpleData data = simplePreprocess(manager, people);
// 4. 查看处理后的张量
System.out.println("\n3. 处理后的张量:");
System.out.println("特征(年龄,性别编码):\n" + data.features);
System.out.println("\n标签(收入):\n" + data.labels);
// 5. 简单计算演示
System.out.println("\n4. 简单统计:");
showSimpleStats(data);
} catch (Exception e) {
e.printStackTrace();
}
}
// 1. 数据结构(简单的Java类)
static class Person {
int id;
String name;
int age;
String gender; // "男" 或 "女"
double income;
Person(int id, String name, int age, String gender, double income) {
this.id = id;
this.name = name;
this.age = age;
this.gender = gender;
this.income = income;
}
@Override
public String toString() {
return id + "," + name + ",年龄:" + age + ",性别:" + gender + ",收入:" + income;
}
}
// 2. 简单的数据容器
static class SimpleData {
NDArray features; // 特征矩阵
NDArray labels; // 标签
SimpleData(NDArray features, NDArray labels) {
this.features = features;
this.labels = labels;
}
}
// 3. 创建模拟数据(代替读CSV文件)
private static List<Person> createSampleData() throws URISyntaxException, IOException {
List<Person> people = new ArrayList<>();
URL resourceUrl = CSVDataProcessingExample.class.getClassLoader().getResource("data.csv");
Path path = Paths.get(resourceUrl.toURI());
try (Reader reader = Files.newBufferedReader(path);
CSVParser csvParser = new CSVParser(reader,
CSVFormat.DEFAULT.withFirstRecordAsHeader())) {
for (CSVRecord record : csvParser) {
int id = Integer.parseInt(record.get("id"));
String name = record.get("name");
int age = Integer.parseInt(record.get("age"));
String gender = record.get("gender");
double income = Double.parseDouble(record.get("income"));
people.add(new Person(id, name, age, gender, income));
}
}
return people;
}
// 4. 简化版预处理
private static SimpleData simplePreprocess(NDManager manager, List<Person> people) {
int count = people.size(); // 4个人
// 4.1 准备数组(更容易理解)
double[][] features = new double[count][2]; // 4行×2列
double[] labels = new double[count]; // 4个收入
// 4.2 手动填充数据
for (int i = 0; i < count; i++) {
Person p = people.get(i);
// 特征1:年龄(原样保留)
features[i][0] = p.age; // 第一列:年龄
// 特征2:性别编码(男→0,女→1)
int genderCode = p.gender.equals("F") ? 0 : 1;
features[i][1] = genderCode; // 第二列:性别编码
// 标签:收入
labels[i] = p.income;
}
// 4.3 创建张量
NDArray featuresTensor = manager.create(features);
NDArray labelsTensor = manager.create(labels);
return new SimpleData(featuresTensor, labelsTensor);
}
// 5. 显示简单统计
private static void showSimpleStats(SimpleData data) {
NDArray features = data.features;
NDArray labels = data.labels;
// 5.1 计算总和
double totalIncome = labels.sum().getDouble();
System.out.println("总收入:" + totalIncome);
// 5.2 计算平均
double avgAge = features.get(":, 0").mean().getDouble(); // 第0列是年龄
double avgIncome = labels.mean().getDouble();
System.out.println("平均年龄:" + avgAge);
System.out.println("平均收入:" + avgIncome);
// 5.3 找出最大值
double maxIncome = labels.max().getDouble();
System.out.println("最高收入:" + maxIncome);
// 5.4 按性别统计
System.out.println("\n按性别统计:");
// 男性数据(性别编码=0)
NDArray maleMask = features.get(":, 1").eq(0); // 第1列是性别,等于0的是男性
NDArray maleIncomes = labels.get(maleMask);
// 女性数据(性别编码=1)
NDArray femaleMask = features.get(":, 1").eq(1);
NDArray femaleIncomes = labels.get(femaleMask);
System.out.println("男性平均收入:" + maleIncomes.mean().getDouble());
System.out.println("女性平均收入:" + femaleIncomes.mean().getDouble());
}
}
运行结果
1,Alice,年龄:34,性别:F,收入:50000.0
2,Bob,年龄:45,性别:M,收入:60000.0
3,Charlie,年龄:23,性别:M,收入:45000.0
4,Diana,年龄:56,性别:F,收入:70000.0
5,Eva,年龄:29,性别:F,收入:55000.0
6,Frank,年龄:38,性别:M,收入:65000.0
7,Grace,年龄:42,性别:F,收入:58000.0
8,Henry,年龄:51,性别:M,收入:72000.0
9,Ivy,年龄:26,性别:F,收入:48000.0
10,Jack,年龄:33,性别:M,收入:62000.0
2. 预处理:文字变数字
3. 处理后的张量:
特征(年龄,性别编码):
ND: (10, 2) cpu() float64
[[34., 0.],
[45., 1.],
[23., 1.],
[56., 0.],
[29., 0.],
[38., 1.],
[42., 0.],
[51., 1.],
[26., 0.],
[33., 1.],
]
标签(收入):
ND: (10) cpu() float64
[50000., 60000., 45000., 70000., 55000., 65000., 58000., 72000., 48000., 62000.]
4. 简单统计:
总收入:585000.0
平均年龄:37.7
平均收入:58500.0
最高收入:72000.0
按性别统计:
男性平均收入:56200.0
女性平均收入:60800.0
模块化数据预处理
上面我们写了一个完整的数据处理程序,把所有代码都放在一个main方法里。这样写容易理解,但不方便复用。
模块化的意思就是把代码按功能拆分成几个部分:
原来的:一个大文件做所有事
现在的:
├── DataLoader.java 只负责读CSV
├── DataPreprocessor.java 只负责处理数据
├── DataAnalyzer.java 只负责统计分析
└── Main.java 把上面几个组合起来
好处:
- 每部分代码更短,更容易看懂
- 可以在其他项目里直接使用这些模块
- 修改一个功能不影响其他部分
就像做菜:原来一个人负责所有步骤,现在分工明确:有人洗菜、有人切菜、有人炒菜。
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import java.io.IOException;
import java.io.Reader;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
public class ModularCSVDemo {
public static void main(String[] args) throws URISyntaxException, IOException {
System.out.println("=== 模块化CSV数据处理 ===\n");
try (NDManager manager = NDManager.newBaseManager()) {
// 1. 分模块处理数据
System.out.println("1. 加载数据...");
List<Person> people = DataLoader.loadCSV("data.csv");
System.out.println("\n2. 预处理数据...");
ProcessedData data = DataPreprocessor.process(manager, people);
System.out.println("\n3. 分析数据...");
DataAnalyzer.analyze(data);
System.out.println("\n4. 显示结果...");
displayResults(data, people);
}
}
// ==================== 1. 数据加载模块 ====================
static class DataLoader {
static List<Person> loadCSV(String fileName) throws URISyntaxException, IOException {
List<Person> people = new ArrayList<>();
URL resourceUrl = ModularCSVDemo.class.getClassLoader().getResource(fileName);
Path path = Paths.get(resourceUrl.toURI());
try (Reader reader = Files.newBufferedReader(path);
CSVParser csvParser = new CSVParser(reader,
CSVFormat.DEFAULT.withFirstRecordAsHeader())) {
for (CSVRecord record : csvParser) {
int id = Integer.parseInt(record.get("id"));
String name = record.get("name");
int age = Integer.parseInt(record.get("age"));
String gender = record.get("gender");
double income = Double.parseDouble(record.get("income"));
people.add(new Person(id, name, age, gender, income));
}
}
System.out.println(" 加载完成:" + people.size() + "条记录");
return people;
}
}
// ==================== 2. 数据预处理模块 ====================
static class DataPreprocessor {
static ProcessedData process(NDManager manager, List<Person> people) {
int count = people.size();
// 准备数组
double[][] features = new double[count][2];
double[] labels = new double[count];
// 填充数据
for (int i = 0; i < count; i++) {
Person p = people.get(i);
// 特征1:年龄
features[i][0] = p.age;
// 特征2:性别编码(M→0, F→1)
features[i][1] = p.gender.equals("M") ? 0 : 1;
// 标签:收入
labels[i] = p.income;
}
// 创建张量
NDArray featuresTensor = manager.create(features);
NDArray labelsTensor = manager.create(labels);
System.out.println(" 预处理完成:特征" + featuresTensor.getShape() +
",标签" + labelsTensor.getShape());
return new ProcessedData(featuresTensor, labelsTensor, people);
}
}
// ==================== 3. 数据分析模块 ====================
static class DataAnalyzer {
static void analyze(ProcessedData data) {
NDArray features = data.features;
NDArray labels = data.labels;
// 基本统计
double totalIncome = labels.sum().getDouble();
double avgAge = features.get(":, 0").mean().getDouble();
double avgIncome = labels.mean().getDouble();
double maxIncome = labels.max().getDouble();
// 按性别统计
double[] genders = features.get(":, 1").toDoubleArray();
double[] incomes = labels.toDoubleArray();
double maleSum = 0, femaleSum = 0;
int maleCount = 0, femaleCount = 0;
for (int i = 0; i < genders.length; i++) {
if (genders[i] == 0) {
maleSum += incomes[i];
maleCount++;
} else {
femaleSum += incomes[i];
femaleCount++;
}
}
System.out.println(" 分析完成:");
System.out.println(" - 平均年龄: " + String.format("%.1f", avgAge));
System.out.println(" - 平均收入: " + String.format("%,.0f", avgIncome));
System.out.println(" - 最高收入: " + String.format("%,.0f", maxIncome));
System.out.println(" - 男性平均收入: " +
String.format("%,.0f", maleCount > 0 ? maleSum/maleCount : 0));
System.out.println(" - 女性平均收入: " +
String.format("%,.0f", femaleCount > 0 ? femaleSum/femaleCount : 0));
}
}
// ==================== 4. 结果显示模块 ====================
static void displayResults(ProcessedData data, List<Person> people) {
System.out.println("\n=== 最终结果 ===");
System.out.println("\n原始数据(前5条):");
for (int i = 0; i < Math.min(5, people.size()); i++) {
System.out.println(" " + people.get(i));
}
System.out.println("\n特征张量(前5行):");
for (int i = 0; i < Math.min(5, data.features.getShape().get(0)); i++) {
System.out.printf(" 样本%d: %s%n", i+1,
Arrays.toString(data.features.get(i).toDoubleArray()));
}
System.out.println("\n标签张量(前5行):");
for (int i = 0; i < Math.min(5, data.labels.getShape().get(0)); i++) {
System.out.printf(" 样本%d: %.0f%n", i+1,
data.labels.getDouble(i));
}
}
// ==================== 数据结构 ====================
static class Person {
int id;
String name;
int age;
String gender;
double income;
Person(int id, String name, int age, String gender, double income) {
this.id = id;
this.name = name;
this.age = age;
this.gender = gender;
this.income = income;
}
@Override
public String toString() {
return String.format("ID:%d %-10s 年龄:%2d 性别:%s 收入:%,.0f",
id, name, age, gender, income);
}
}
static class ProcessedData {
NDArray features;
NDArray labels;
List<Person> rawData;
ProcessedData(NDArray features, NDArray labels, List<Person> rawData) {
this.features = features;
this.labels = labels;
this.rawData = rawData;
}
}
}
运行结果
=== 模块化CSV数据处理 ===
1. 加载数据...
加载完成:10条记录
2. 预处理数据...
预处理完成:特征(10, 2),标签(10)
3. 分析数据...
分析完成:
- 平均年龄: 37.7
- 平均收入: 58,500
- 最高收入: 72,000
- 男性平均收入: 60,800
- 女性平均收入: 56,200
4. 显示结果...
=== 最终结果 ===
原始数据(前5条):
ID:1 Alice 年龄:34 性别:F 收入:50,000
ID:2 Bob 年龄:45 性别:M 收入:60,000
ID:3 Charlie 年龄:23 性别:M 收入:45,000
ID:4 Diana 年龄:56 性别:F 收入:70,000
ID:5 Eva 年龄:29 性别:F 收入:55,000
特征张量(前5行):
样本1: [34.0, 1.0]
样本2: [45.0, 0.0]
样本3: [23.0, 0.0]
样本4: [56.0, 1.0]
样本5: [29.0, 1.0]
标签张量(前5行):
样本1: 50000
样本2: 60000
样本3: 45000
样本4: 70000
样本5: 55000
组合Pipeline预处理
上一个我们把代码分成了几个模块,结构清晰多了。但还有个问题:处理步骤是固定的。
现在我们要用Pipeline(流水线)模式,让数据处理流程可以灵活配置。
核心思想:
- 每个处理步骤都是一个"小插件"
- 可以随意组合这些插件
- 不需要改代码就能调整处理流程
就像乐高积木:每个步骤是一个积木块,你可以按需要拼出不同的形状。
Pipeline 设计的好处
-
模块化设计:每个步骤独立,易于理解和维护
-
灵活配置:可以轻松添加、删除或重新排序步骤
-
可重用性:步骤可以在不同项目间重用
示例代码
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVRecord;
import java.io.IOException;
import java.io.Reader;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.text.Normalizer;
import java.util.*;
public class SimpleCSVDemoWithPipeline {
public static void main(String[] args) throws Exception {
System.out.println("=== 带Pipeline的CSV数据处理 ===\n");
// 1. 创建数据处理Pipeline
DataPipeline pipeline = new DataPipeline();
// 2. 添加处理步骤
pipeline.addStep(new CSVLoader()); // 步骤1:加载CSV
pipeline.addStep(new GenderEncoder()); // 步骤2:性别编码
pipeline.addStep(new FeatureSelector()); // 步骤3:选择特征
pipeline.addStep(new Statistics()); // 步骤4:统计分析
// 3. 执行Pipeline
try (NDManager manager = NDManager.newBaseManager()) {
ProcessedResult result = pipeline.execute(manager, "data.csv");
// 4. 显示结果
result.showResults();
}
}
// ==================== 数据结构 ====================
static class Person {
int id;
String name;
int age;
String gender;
double income;
Person(int id, String name, int age, String gender, double income) {
this.id = id;
this.name = name;
this.age = age;
this.gender = gender;
this.income = income;
}
@Override
public String toString() {
return String.format("ID:%d %-10s 年龄:%2d 性别:%s 收入:%,.0f",
id, name, age, gender, income);
}
}
// ==================== Pipeline处理结果 ====================
static class ProcessedResult {
List<Person> rawData; // 原始数据
List<Person> encodedData; // 编码后的数据(内存中)
NDArray features; // 特征张量
NDArray labels; // 标签张量
Map<String, Object> stats; // 统计信息
void showResults() {
System.out.println("\n=== 数据处理结果 ===");
System.out.println("\n1. 原始数据(" + rawData.size() + "条):");
rawData.forEach(p -> System.out.println(" " + p));
System.out.println("\n2. 特征张量(" + features.getShape() + "):");
printArray(features, 5); // 只显示前5行
System.out.println("\n3. 标签张量(" + labels.getShape() + "):");
printArray(labels, 5);
System.out.println("\n4. 统计信息:");
stats.forEach((key, value) -> System.out.printf(" %-20s: %s%n", key, value));
}
private void printArray(NDArray arr, int maxRows) {
long rows = Math.min(arr.getShape().get(0), maxRows);
for (int i = 0; i < rows; i++) {
System.out.printf(" 样本%d: %s%n", i+1,
Arrays.toString(arr.get(i).toDoubleArray()));
}
if (arr.getShape().get(0) > maxRows) {
System.out.printf(" ... 还有%d行未显示%n", arr.getShape().get(0) - maxRows);
}
}
}
// ==================== Pipeline接口 ====================
interface PipelineStep {
void process(ProcessedResult result, NDManager manager) throws Exception;
String getStepName();
}
// ==================== 数据处理Pipeline ====================
static class DataPipeline {
private List<PipelineStep> steps = new ArrayList<>();
void addStep(PipelineStep step) {
steps.add(step);
System.out.println("添加步骤: " + step.getStepName());
}
ProcessedResult execute(NDManager manager, String csvPath) throws Exception {
ProcessedResult result = new ProcessedResult();
result.stats = new LinkedHashMap<>(); // 保持插入顺序
System.out.println("\n开始执行Pipeline:");
for (int i = 0; i < steps.size(); i++) {
PipelineStep step = steps.get(i);
System.out.printf("%d. %s...%n", i+1, step.getStepName());
long startTime = System.currentTimeMillis();
step.process(result, manager);
long time = System.currentTimeMillis() - startTime;
System.out.printf(" 完成 (耗时: %d ms)%n", time);
}
return result;
}
}
// ==================== 步骤1:CSV加载器 ====================
static class CSVLoader implements PipelineStep {
@Override
public void process(ProcessedResult result, NDManager manager) throws Exception {
URL resourceUrl = SimpleCSVDemoWithPipeline.class.getClassLoader()
.getResource("data.csv");
Path path = Paths.get(resourceUrl.toURI());
List<Person> people = new ArrayList<>();
try (Reader reader = Files.newBufferedReader(path);
CSVParser csvParser = new CSVParser(reader,
CSVFormat.DEFAULT.withFirstRecordAsHeader())) {
for (CSVRecord record : csvParser) {
int id = Integer.parseInt(record.get("id"));
String name = record.get("name");
int age = Integer.parseInt(record.get("age"));
String gender = record.get("gender");
double income = Double.parseDouble(record.get("income"));
people.add(new Person(id, name, age, gender, income));
}
}
result.rawData = people;
result.stats.put("加载完成, 数据条数", people.size());
}
@Override
public String getStepName() {
return "加载CSV文件";
}
}
// ==================== 步骤2:性别编码器 ====================
static class GenderEncoder implements PipelineStep {
@Override
public void process(ProcessedResult result, NDManager manager) {
List<Person> encodedPeople = new ArrayList<>();
int maleCount = 0, femaleCount = 0;
for (Person p : result.rawData) {
// 创建编码后的副本(在实际应用中可能会修改原对象)
Person encoded = new Person(
p.id, p.name, p.age,
p.gender.equals("M") ? "0" : "1", // 编码为字符串"0"/"1"
p.income
);
encodedPeople.add(encoded);
if (p.gender.equals("M")){
maleCount++;
}else {
femaleCount++;
}
}
result.encodedData = encodedPeople;
result.stats.put("男性人数", maleCount);
result.stats.put("女性人数", femaleCount);
}
@Override
public String getStepName() {
return "性别编码(M→0, F→1)";
}
}
// ==================== 步骤3:特征选择器 ====================
static class FeatureSelector implements PipelineStep {
@Override
public void process(ProcessedResult result, NDManager manager) {
int count = result.rawData.size();
// 创建特征矩阵:使用double保证类型一致
double[][] features = new double[count][2]; // [年龄, 性别编码]
double[] labels = new double[count]; // 收入
for (int i = 0; i < count; i++) {
Person p = result.rawData.get(i);
Person encoded = result.encodedData.get(i);
features[i][0] = p.age; // 年龄
features[i][1] = Double.parseDouble(encoded.gender); // 性别编码(转为double)
labels[i] = p.income; // 收入
}
// 创建张量
result.features = manager.create(features);
result.labels = manager.create(labels).reshape(-1, 1); // 转为列向量
result.stats.put("特征维度", features[0].length);
result.stats.put("样本数量", count);
}
@Override
public String getStepName() {
return "特征选择与张量创建";
}
}
// ==================== 步骤5:统计分析 ====================
static class Statistics implements PipelineStep {
@Override
public void process(ProcessedResult result, NDManager manager) {
NDArray features = result.features;
NDArray labels = result.labels;
// 基本统计
double totalIncome = labels.sum().getDouble();
double avgAge = features.get(":, 0").mean().getDouble();
double avgIncome = labels.mean().getDouble();
double maxIncome = labels.max().getDouble();
double minIncome = labels.min().getDouble();
// 按性别统计
NDArray genderCol = features.get(":, 1");
NDArray maleMask = genderCol.eq(0);
NDArray femaleMask = genderCol.eq(1);
double maleAvgIncome = 0, femaleAvgIncome = 0;
int maleCount = 0, femaleCount = 0;
// 使用更安全的方法统计
for (int i = 0; i < labels.getShape().get(0); i++) {
double gender = genderCol.getDouble(i);
double income = labels.getDouble(i);
if (gender == 0) {
maleAvgIncome += income;
maleCount++;
} else {
femaleAvgIncome += income;
femaleCount++;
}
}
if (maleCount > 0){
maleAvgIncome /= maleCount;
}
if (femaleCount > 0){
femaleAvgIncome /= femaleCount;
}
// 存储统计结果
result.stats.put("总收入", String.format("%,.0f", totalIncome));
result.stats.put("平均年龄", String.format("%.1f岁", avgAge));
result.stats.put("平均收入", String.format("%,.0f", avgIncome));
result.stats.put("最高收入", String.format("%,.0f", maxIncome));
result.stats.put("最低收入", String.format("%,.0f", minIncome));
result.stats.put("男性人数", maleCount);
result.stats.put("女性人数", femaleCount);
result.stats.put("男性平均收入", String.format("%,.0f", maleAvgIncome));
result.stats.put("女性平均收入", String.format("%,.0f", femaleAvgIncome));
// 计算年龄与收入的相关性(使用修正的方法)
NDArray ages = features.get(":, 0");
NDArray incomes = labels.flatten();
double corr = calculateCorrelation(ages, incomes);
result.stats.put("年龄-收入相关性", String.format("%.3f", corr));
// 添加更多统计信息
result.stats.put("年龄范围", String.format("%.0f-%.0f岁",
ages.min().getDouble(), ages.max().getDouble()));
result.stats.put("收入标准差", String.format("%,.0f",
calculateStd(incomes)));
result.stats.put("年龄标准差", String.format("%.1f岁",
calculateStd(ages)));
}
// 计算标准差的方法
private double calculateStd(NDArray arr) {
double mean = arr.mean().getDouble();
NDArray centered = arr.sub(mean);
double variance = centered.pow(2).mean().getDouble();
return Math.sqrt(variance);
}
// 计算相关系数的方法
private double calculateCorrelation(NDArray x, NDArray y) {
// 1. 计算均值
double xMean = x.mean().getDouble();
double yMean = y.mean().getDouble();
// 2. 中心化
NDArray xCentered = x.sub(xMean);
NDArray yCentered = y.sub(yMean);
// 3. 计算协方差
double cov = xCentered.mul(yCentered).mean().getDouble();
// 4. 计算标准差
double xStd = calculateStd(x);
double yStd = calculateStd(y);
// 5. 计算相关系数
if (xStd == 0 || yStd == 0) {
return 0; // 避免除零
}
return cov / (xStd * yStd);
}
@Override
public String getStepName() {
return "数据统计分析";
}
}
// ==================== 额外步骤示例:数据分割 ====================
static class DataSplitter implements PipelineStep {
private double trainRatio = 0.8;
public DataSplitter(double trainRatio) {
this.trainRatio = trainRatio;
}
@Override
public void process(ProcessedResult result, NDManager manager) {
int total = (int) result.features.getShape().get(0);
int trainSize = (int) (total * trainRatio);
int testSize = total - trainSize;
// 分割特征
NDArray trainFeatures = result.features.get("0:" + trainSize);
NDArray testFeatures = result.features.get(trainSize + ":" + total);
// 分割标签
NDArray trainLabels = result.labels.get("0:" + trainSize);
NDArray testLabels = result.labels.get(trainSize + ":" + total);
// 在实际应用中,这里会将数据存储到result中
result.stats.put("训练集大小", trainSize);
result.stats.put("测试集大小", testSize);
result.stats.put("分割比例", String.format("%.0f%%/%.0f%%",
trainRatio*100, (1-trainRatio)*100));
}
@Override
public String getStepName() {
return String.format("数据分割(%.0f%%训练)", trainRatio*100);
}
}
}
运行结果
== 带Pipeline的CSV数据处理 ===
添加步骤: 加载CSV文件
添加步骤: 性别编码(M→0, F→1)
添加步骤: 特征选择与张量创建
添加步骤: 数据统计分析
开始执行Pipeline:
1. 加载CSV文件...
完成 (耗时: 31 ms)
2. 性别编码(M→0, F→1)...
完成 (耗时: 0 ms)
3. 特征选择与张量创建...
完成 (耗时: 25 ms)
4. 数据统计分析...
完成 (耗时: 32 ms)
=== 数据处理结果 ===
1. 原始数据(10条):
ID:1 Alice 年龄:34 性别:F 收入:50,000
ID:2 Bob 年龄:45 性别:M 收入:60,000
ID:3 Charlie 年龄:23 性别:M 收入:45,000
ID:4 Diana 年龄:56 性别:F 收入:70,000
ID:5 Eva 年龄:29 性别:F 收入:55,000
ID:6 Frank 年龄:38 性别:M 收入:65,000
ID:7 Grace 年龄:42 性别:F 收入:58,000
ID:8 Henry 年龄:51 性别:M 收入:72,000
ID:9 Ivy 年龄:26 性别:F 收入:48,000
ID:10 Jack 年龄:33 性别:M 收入:62,000
2. 特征张量((10, 2)):
样本1: [34.0, 1.0]
样本2: [45.0, 0.0]
样本3: [23.0, 0.0]
样本4: [56.0, 1.0]
样本5: [29.0, 1.0]
... 还有5行未显示
3. 标签张量((10, 1)):
样本1: [50000.0]
样本2: [60000.0]
样本3: [45000.0]
样本4: [70000.0]
样本5: [55000.0]
... 还有5行未显示
4. 统计信息:
加载完成, 数据条数 : 10
男性人数 : 5
女性人数 : 5
特征维度 : 2
样本数量 : 10
总收入 : 585,000
平均年龄 : 37.7岁
平均收入 : 58,500
最高收入 : 72,000
最低收入 : 45,000
男性平均收入 : 60,800
女性平均收入 : 56,200
年龄-收入相关性 : 0.867
年龄范围 : 23-56岁
收入标准差 : 8,652
年龄标准差 : 10.2岁
更多推荐


所有评论(0)