Scala map function

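Array.map can go a long way toward removing duplicated code. This post uses a Spark ML Pipeline as its example: it first shows how to rewrite repetitive code with map, then how to plug the result into a Pipeline.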

Before

Here, five columns each need a StringIndexer:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.StringIndexer

val indexer1 = new StringIndexer().setInputCol("sex").setOutputCol("sexIndex")
val indexer2 = new StringIndexer().setInputCol("educ").setOutputCol("educIndex")
val indexer3 = new StringIndexer().setInputCol("factory").setOutputCol("factoryIndex")
val indexer4 = new StringIndexer().setInputCol("home").setOutputCol("homeIndex")
val indexer5 = new StringIndexer().setInputCol("course").setOutputCol("courseIndex")
////////////////////////////////////////////////////////////////////////////////////
// other code
////////////////////////////////////////////////////////////////////////////////////
val kmeansPipeline = new Pipeline().setStages(Array(
  indexer1, indexer2, indexer3, indexer4, indexer5,
  vectorAssembler, featureIndexer, kmeans, labelConverter
))

A simple map example

map applies the function you pass in to every element of the Array, turning each one into a new element, and returns a new Array with the results.

val numbers = Array(1, 2, 3, 4, 5)
val squaredNumbers = numbers.map(n => n * n)
//=> squaredNumbers: Array[Int] = Array(1, 4, 9, 16, 25)

Reference: documentation tutorial
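The function passed to map does not have to return the same type it receives. Here is a small plain-Scala sketch (no Spark needed; the column names are borrowed from the example above) where an Array[String] becomes an Array[Int], and then an Array[String] with "Index" appended, the same naming rule the StringIndexer rewrite below uses:

```scala
// map keeps one output element per input element, but the element type may change
val columns = Array("sex", "educ", "factory", "home", "course")

// String => Int: the length of each column name
val nameLengths: Array[Int] = columns.map(name => name.length)

// String => String: append "Index" to each column name
val outputNames: Array[String] = columns.map(name => s"${name}Index")

println(nameLengths.mkString(", "))  // 3, 4, 7, 4, 6
println(outputNames.mkString(", "))  // sexIndex, educIndex, factoryIndex, homeIndex, courseIndex
```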

Rewriting the code with map

  • Put the columns to be indexed into an Array
  • Write the transformation function for map
  • End up with an Array[StringIndexer]
val columnsToBeIndexed = Array("sex", "educ", "factory", "home", "course")
// each indexed column is named after its input column, with "Index" appended
val stringIndexers = columnsToBeIndexed.map { columnName =>
  new StringIndexer()
    .setInputCol(columnName)
    .setOutputCol(s"${columnName}Index")
}
//=> stringIndexers: Array[StringIndexer] = Array(strIdx_1, strIdx_2, ...
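To try the map rewrite without Spark, here is a minimal sketch that swaps StringIndexer for a hypothetical case class IndexerSpec (not part of Spark) that only stores the two column names. The map call has the same shape as the one above:

```scala
// Hypothetical stand-in for StringIndexer: it only records input/output column names
case class IndexerSpec(inputCol: String, outputCol: String)

val columnsToBeIndexed = Array("sex", "educ", "factory", "home", "course")

// One map call replaces five nearly identical lines
val specs: Array[IndexerSpec] = columnsToBeIndexed.map { columnName =>
  IndexerSpec(inputCol = columnName, outputCol = s"${columnName}Index")
}

specs.foreach(s => println(s"${s.inputCol} -> ${s.outputCol}"))
// sex -> sexIndex
// educ -> educIndex
// factory -> factoryIndex
// home -> homeIndex
// course -> courseIndex
```

Adding a sixth column now only means adding one string to columnsToBeIndexed.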

Putting the result into the Pipeline

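In the end we need every element of the stringIndexers Array to go into the pipeline alongside the other Transformers/Estimators. We can't just drop stringIndexers in as-is, because it is an Array, not a single stage. Here is the incorrect version: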

val kmeansPipeline = new Pipeline().setStages(Array(
  stringIndexers, vectorAssembler, featureIndexer,
  kmeans, labelConverter
))
/*
<console>:29: error: type mismatch;
 found   : Array[org.apache.spark.ml.feature.StringIndexer]
 required: org.apache.spark.ml.PipelineStage
       val kmeansPipeline = new Pipeline().setStages(Array(stringIndexers, vectorAssemb...
                                                     ^
*/

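The pipeline's stage Array may only contain PipelineStages (Transformers and Estimators). The following compiles, but it is not a good approach because every index is hard-coded: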

val kmeansPipeline = new Pipeline().setStages(Array(
  stringIndexers(0), stringIndexers(1),
  stringIndexers(2), stringIndexers(3),
  stringIndexers(4), vectorAssembler, featureIndexer,
  kmeans, labelConverter
))

There are two options here:

  • Wrap the indexers in their own Pipeline
  • Put each StringIndexer into the stage Array in a cleaner way (flatMap or ++)

Using a Pipeline

This is my favorite, so I've listed it first.

Because a Pipeline is itself an Estimator (and therefore a PipelineStage), one pipeline can be placed inside another. So we put stringIndexers into a pipeline of their own:

val columnsToBeIndexed = Array("sex", "educ", "factory", "home", "course")
// each indexed column is named after its input column, with "Index" appended
val stringIndexerPipeline = {
  val stringIndexers = columnsToBeIndexed.map { columnName =>
    new StringIndexer()
      .setInputCol(columnName)
      .setOutputCol(s"${columnName}Index")
  }
  new Pipeline().setStages(stringIndexers)
}
val kmeansPipeline = new Pipeline().setStages(Array(
  stringIndexerPipeline, vectorAssembler, featureIndexer,
  kmeans, labelConverter
))
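Nesting works because Spark's Pipeline is itself a PipelineStage, so it can sit in another pipeline's stage Array. The sketch below models that relationship in plain Scala with hypothetical types Stage, NamedStage and MiniPipeline (simplified stand-ins, not Spark's API). It only shows the type structure, not fitting or transforming data:

```scala
// Hypothetical, simplified model: every stage is a Stage, and a pipeline is also a Stage
trait Stage { def name: String }

case class NamedStage(name: String) extends Stage

// A pipeline is a Stage too, so it can be nested inside another pipeline
case class MiniPipeline(stages: Array[Stage]) extends Stage {
  def name: String = stages.map(_.name).mkString("Pipeline(", ", ", ")")
}

val columnsToBeIndexed = Array("sex", "educ", "factory", "home", "course")

// Inner pipeline built from the mapped "indexers" (ascribed to Stage because Arrays are invariant)
val indexerPipeline = MiniPipeline(
  columnsToBeIndexed.map(c => (NamedStage(s"${c}Indexer"): Stage))
)

// Outer pipeline: the inner pipeline is just one more element of the stage Array
val kmeansPipeline = MiniPipeline(Array[Stage](
  indexerPipeline, NamedStage("vectorAssembler"), NamedStage("kmeans")
))

println(kmeansPipeline.name)
// Pipeline(Pipeline(sexIndexer, educIndexer, factoryIndexer, homeIndexer, courseIndexer), vectorAssembler, kmeans)
```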

Using flatMap

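flatMap is map's sibling: both apply a function to every element, but flatMap then flattens the results by one level, so the function has to return a collection for every element. With s => s, that means every element of the outer Array must already be a collection: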

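import org.apache.spark.ml.PipelineStage
// flatMap needs every element to be a collection, so single stages are wrapped in Seq
val nestedArray = Array[Seq[PipelineStage]](stringIndexers.toSeq, Seq(vectorAssembler), Seq(featureIndexer), Seq(kmeans), Seq(labelConverter))
val stageArray = nestedArray.flatMap(s => s)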
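A plain-Scala sketch of flatMap using strings as hypothetical stage names (no Spark): the first call flattens an Array whose elements are already collections, and the second shows the function itself returning a collection for each element:

```scala
// Hypothetical stage names standing in for real pipeline stages
val indexers = Array("sexIndexer", "educIndexer", "factoryIndexer")

// Single values are wrapped in Seq(...) so every element is a collection
val nested: Array[Seq[String]] = Array(indexers.toSeq, Seq("vectorAssembler"), Seq("kmeans"))
val flat: Array[String] = nested.flatMap(s => s)
// flat: Array(sexIndexer, educIndexer, factoryIndexer, vectorAssembler, kmeans)

// flatMap with a real transformation: each column yields two names
val columns = Array("sex", "educ")
val inAndOut: Array[String] = columns.flatMap(c => Seq(c, s"${c}Index"))
// inAndOut: Array(sex, sexIndex, educ, educIndex)
```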

Using ++ to concatenate two Arrays

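Honestly, I mainly used flatMap above as an excuse to introduce it. It can do plenty of other neat things, but here it is overkill; concatenating with ++ is enough.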

val stageArray = stringIndexers ++ Array(vectorAssembler, featureIndexer, kmeans, labelConverter)
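A plain-Scala sketch of ++ across a small type hierarchy, assuming hypothetical stand-ins Stage (for PipelineStage), Indexer (for StringIndexer) and Other (for the remaining stages). Declaring the result as Array[Stage] states the common element type explicitly:

```scala
// Hypothetical stand-ins for Spark's PipelineStage hierarchy
trait Stage { def name: String }
case class Indexer(name: String) extends Stage
case class Other(name: String) extends Stage

val indexers: Array[Indexer] = Array("sex", "educ").map(c => Indexer(s"${c}Indexer"))

// ++ copies both Arrays into a new one; the element type widens to Stage
val stageArray: Array[Stage] = indexers ++ Array[Stage](Other("vectorAssembler"), Other("kmeans"))

println(stageArray.map(_.name).mkString(", "))
// sexIndexer, educIndexer, vectorAssembler, kmeans
```

++ returns a new Array and leaves both input Arrays unchanged.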