摘要:在GEE中如何利用随机森林特征重要性和袋外误差进行特征优选。

随机森林中有两个重要概念:特征重要性(Feature Importance)、袋外误差(Out-of-Bag Error,简称OOB Error)。

什么是特征重要性:

特征重要性是指每个特征的重要性评分,用于评估各个输入特征对模型预测能力的贡献。这对于特征选择和降维非常有帮助,有助于去除冗余或不相关的特征,从而提升模型的性能。

计算重要性评分主要涉及以下两种方法:

1.基于Gini指数的重要性(Gini Importance):

对于每棵树,随机森林会根据每个特征的使用情况计算出Gini指数的下降量。特征在每次分裂中带来的Gini指数的减少越大,说明该特征对模型越重要。

2.基于精度的重要性(Mean Decrease Accuracy):

在进行特征重要性计算时,可以通过打乱某一特征的值来评估对模型性能的影响。具体来说,通过随机打乱某个特征的值并计算模型的预测精度变化,帮助确定该特征的重要性。如果打乱该特征值后模型的精度显著下降,说明该特征对模型预测至关重要。

什么是袋外误差:

在随机森林中,构建每棵决策树时,通常会使用一种叫做bootstrap的方法:从训练数据集中有放回地随机抽取样本。也就是说,每棵树在训练时只会使用训练集的一个子集,部分样本不会被选入该树的训练数据中,这些未被选中的样本称为“袋外样本”(Out-of-Bag Samples)。由于每棵树的训练数据都是随机抽样的,因此每棵树会有大约三分之一的样本作为它的袋外样本。袋外误差是随机森林中评估模型性能的一种有效方法,它利用袋外样本进行验证,能够提供对模型泛化能力的无偏估计。

如何利用随机森林重要性排序和袋外误差进行特征优选:

1.将特征重要性进行排序。

var importance = ee.Dictionary(classifier.explain().get("importance"));
var keys = importance.keys();
var values = importance.values();
var sortedValues = values.sort().reverse();
var sortedKeysList = sortedValues
  .map(function (value) {
    return keys.get(values.indexOf(value));
  })
  .flatten();
var importance_sort = sortedKeysList.zip(sortedValues);
print("importance", importance_sort);

2.根据特征重要性排序,依次选取前n(1,2,3,...N; N为特征总数)个重要特征,从而构造N个不同的特征子集,并分别训练随机森林模型,然后获取每个模型的OOB误差。

var bandNum = ee.List.sequence(1, SELECT_BANDS.size());
var oob_sco = bandNum.map(function (i) {
  var band_slice = SELECT_BANDS.slice(0, i);
  var classifier = ee.Classifier.smileRandomForest(100).train({
    features: training,
    classProperty: "landcover",
    inputProperties: band_slice
  });
  var outOfBagError = classifier.explain().get("outOfBagErrorEstimate");
  return ee.Number(outOfBagError);
});
print("袋外误差与光谱特征数量之间的关系", oob_sco);
var chart = ui.Chart.array.values({
                  array:oob_sco, 
                  axis:0, 
                  xLabels: bandNum
                })
                .setSeriesNames(["袋外误差"])
                .setOptions({
                  hAxis: {title: "光谱特征数量" },
                  vAxis: {title: "袋外误差" },
                  pointSize: 1,
                  legend: 'none',
                  series: {
        0: { color: "blue", lineWidth: 1, lineDashStyle: [1, 1], pointSize: 1 }
      }
});
print(chart);

3.比较所有模型的OOB误差,OOB误差最小的模型所使用的特征子集即为最优特征。

var oob_min = oob_sco.reduce(ee.Reducer.min());
var oob_min_index = oob_sco.indexOf(oob_min);
SELECT_BANDS = sortedKeysList.slice(0, oob_min_index.add(1));
print("最优特征", SELECT_BANDS);

完整代码如下:

var roi = ee.Geometry.Polygon(
        [[[98.76874122458284, 27.476046782884104],
          [98.76874122458284, 26.36424630156598],
          [100.16400489645784, 26.36424630156598],
          [100.16400489645784, 27.476046782884104]]], null, false);
Map.centerObject(roi, 8);                 
var maskclouds = function (img) {
    var cloudShadowBitMask = 1 << 4;
    var cloudsBitMask = 1 << 3;
    var qa = img.select("QA_PIXEL");
    var mask = qa
      .bitwiseAnd(cloudShadowBitMask)
      .eq(0)
      .and(qa.bitwiseAnd(cloudsBitMask).eq(0));
    return img.addBands(img.updateMask(mask), null, true);
};

var applyScaleFactors = function (image) {
    var opticalBands = image
      .select(["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"])
      .multiply(0.0000275)
      .add(-0.2).float();
    return image.addBands(opticalBands, null, true);
};

var imageCol = ee.ImageCollection("LANDSAT/LC08/C02/T1_L2")
                  .filterBounds(roi)
                  .filterDate('2021-01-01', '2021-12-31')
                  .map(maskclouds)
                  .select(
                    ["SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"],
                    ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"])
                  .map(applyScaleFactors);

var index_dict = {
  'EVI': '2.5 * ((NIR - Red) / (NIR + 6 * Red - 7.5 * Blue + 1))',
  'NDVI': '(NIR - Red) / (NIR + Red)',
  'DVI': 'NIR - Red',
  'RVI': 'NIR / Red',
  'NDWI': '(Green - NIR) / (Green + NIR)',
  'MNDWI': '(Green - SWIR1) / (Green + SWIR1)',
  'NDBI': '(SWIR1 - NIR) / (SWIR1 + NIR)',
  'NDSI': '(Green - SWIR1) / (Green + SWIR1)',
  'LSWI': '(NIR-SWIR1)/(NIR+SWIR1)',
  'AWEI_nsh': '4*(Green-SWIR1)-(0.25*NIR+2.75*SWIR2)',
  'AWEI_sh': 'Blue+2.5*Green-1.5*(NIR+SWIR1)-0.25*SWIR2',
  'NWI': '(Blue-(NIR+SWIR1+SWIR2)) / (Blue+(NIR+SWIR1+SWIR2))',
  'RRI': '(NIR - Red) / (NIR + Red)',
  'BSI': '(Red - Blue) / (Red + Blue)',
  'HUE': '(Green - NIR) / (Green + NIR)'
};                  
var culIndexs = function (imageCol, index_dict) {
    for (var key in index_dict) {
      var map_function = function (image) {
        return image.addBands(
          image.expression(index_dict[key], {
              Blue: image.select("Blue"),
              Green: image.select("Green"),
              Red: image.select("Red"),
              NIR: image.select("NIR"),
              SWIR1: image.select("SWIR1"),
              SWIR2: image.select("SWIR2"),
            }).float().rename(key)
        );
      };
      imageCol = imageCol.map(map_function);
    }
    return imageCol;
};
imageCol = culIndexs(imageCol, index_dict);
var image_median = imageCol.median().clip(roi);
var SELECT_BANDS = image_median.bandNames();
var sampleData = ee.FeatureCollection("users/lijian960708/examples/sampleData");
sampleData = sampleData.randomColumn("random");
var sample_training = sampleData.filter(ee.Filter.lte("random", 0.7));
var sample_validate = sampleData.filter(ee.Filter.gt("random", 0.7));
var training = image_median.sampleRegions({
    collection: sample_training,
    properties: ["landcover"],
    scale: 30,
    tileScale: 16
});
var validation = image_median.sampleRegions({
    collection: sample_validate,
    properties: ["landcover"],
    scale: 30,
    tileScale: 16
});
/**
* 训练
**/
var classifier = ee.Classifier.smileRandomForest(100).train({
    features: training,
    classProperty: "landcover",
    inputProperties: SELECT_BANDS
  });
print(classifier.explain());
/**
* 重要性排序
**/
var importance = ee.Dictionary(classifier.explain().get("importance"));
var keys = importance.keys();
var values = importance.values();
var sortedValues = values.sort().reverse();
var sortedKeysList = sortedValues
  .map(function (value) {
    return keys.get(values.indexOf(value));
  })
  .flatten();
var importance_sort = sortedKeysList.zip(sortedValues);
print("importance", importance_sort);
/**
* 袋外误差与光谱特征数量之间的关系
**/
var bandNum = ee.List.sequence(1, SELECT_BANDS.size());
var oob_sco = bandNum.map(function (i) {
  var band_slice = SELECT_BANDS.slice(0, i);
  var classifier = ee.Classifier.smileRandomForest(100).train({
    features: training,
    classProperty: "landcover",
    inputProperties: band_slice
  });
  var outOfBagError = classifier.explain().get("outOfBagErrorEstimate");
  return ee.Number(outOfBagError);
});
print("袋外误差与光谱特征数量之间的关系", oob_sco);
var chart = ui.Chart.array.values({
                  array:oob_sco, 
                  axis:0, 
                  xLabels: bandNum
                })
                .setSeriesNames(["袋外误差"])
                .setOptions({
                  hAxis: {title: "光谱特征数量" },
                  vAxis: {title: "袋外误差" },
                  pointSize: 1,
                  legend: 'none',
                  series: {
        0: { color: "blue", lineWidth: 1, lineDashStyle: [1, 1], pointSize: 1 }
      }
});
print(chart);
/**
* 最优特征
**/
var oob_min = oob_sco.reduce(ee.Reducer.min());
var oob_min_index = oob_sco.indexOf(oob_min);
SELECT_BANDS = sortedKeysList.slice(0, oob_min_index.add(1));
print("最优特征", SELECT_BANDS);

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐