Scala map function
Array.map
可以有效幫你去除重複程式碼,這篇文章以Spark ML Pipeline
為例子,先示範如何用map改寫重複程式,接著示範如何跟Pipeline
結合
原本
這邊有五個column要使用StringIndexer
val indexer1 = new StringIndexer().setInputCol("sex").setOutputCol("sexIndex")
val indexer2 = new StringIndexer().setInputCol("educ").setOutputCol("educIndex")
val indexer3 = new StringIndexer().setInputCol("factory").setOutputCol("factoryIndex")
val indexer4 = new StringIndexer().setInputCol("home").setOutputCol("homeIndex")
val indexer5 = new StringIndexer().setInputCol("course").setOutputCol("courseIndex")
////////////////////////////////////////////////////////////////////////////////////
// 其他程式碼
////////////////////////////////////////////////////////////////////////////////////
val kmeansPipeline = new Pipeline().setStages(Array(
indexer1, indexer2, indexer3, indexer4, indexer5,
vectorAssembler, featureIndexer, kmeans, labelConverter
))
map簡單範例
map
會用傳入的function將Array的每個小孩轉換成另外一個小孩,最後回傳一個新Array
val numbers = Array(1, 2, 3, 4, 5)
val squaredNumbers = numberArray.map(n => n * n)
//=> squaredNumbers: Array[Int] = Array(1, 4, 9, 16, 25)
用map改造程式
- 先把要index的欄位放到Array裡
- 寫map的轉換function
- 最後產生一個
Array[StringIndexer]
val columnsToBeIndexed = Array("sex", "educ", "factory", "home", "course")
// indexed columns's name will have "Index" appended
val stringIndexers = columnsToBeIndexed.map { columnName =>
new StringIndexer()
.setInputCol(columnName)
.setOutputCol(s"${columnName}Index")
}
//=> stringIndexers: Array[StringIndexer] = Array(strIdx_1, strInx_2, ...
把結果放到Pipeline
因為我們最後要把stringIndexers
這個Array裡面的每個小孩和其他Transformer
跟Estimator
一起放到pipeline裡,我們不能這樣直接放進去,因為stringIndexers
是Array,以下錯誤範例:
val kmeansPipeline = new Pipeline().setStages(Array(
stringIndexers, vectorAssembler, featureIndexer,
kmeans, labelConverter
))
/*
<console>:29: error: type mismatch;
found : Array[org.apache.spark.ml.feature.StringIndexer]
required: org.apache.spark.ml.PipelineStage
val kmeansPipeline = new Pipeline().setStages(Array(stringIndexers, vectorAssemb...
^
*/
我們只能把Transformer
或Estimator
放到pipeline的stage Array,以下是不會錯但是不好的範例:
val kmeansPipeline = new Pipeline().setStages(Array(
stringIndexers(0), stringIndexers(1),
stringIndexers(2), stringIndexers(3),
stringIndexers(4), vectorAssembler, featureIndexer,
kmeans, labelConverter
))
這邊有兩個選擇:
- 用pipeline
- 用比較好的方法把每個StringIndexer放到stage Array
用pipeline
我把我最喜歡的放在最上面
因為pipeline是Estimator
,把一個pipeline放到另一個pipeline是可以的,我們把stringIndexers
放到pipeline裡面:
val columnsToBeIndexed = Array("sex", "educ", "factory", "home", "course")
// indexed columns's name will have "Index" appended
val stringIndexerPipeline = {
val stringIndexers = columnsToBeIndexed.map { columnName =>
new StringIndexer()
.setInputCol(columnName)
.setOutputCol(s"${columnName}Index")
}
new Pipeline().setStages(stringIndexers)
}
val kmeansPipeline = new Pipeline().setStages(Array(
stringIndexerPipeline, vectorAssembler, featureIndexer,
kmeans, labelConverter
))
用flatMap
flatMap是map的兄弟,他們都會做map,不過flatMap會把結果flatten
val nestedArray = Array(stringIndexers, vectorAssembler, featureIndexer, kmeans, labelConverter)
val stageArray = nestedArray.flatMap(s => s)
用 ++ 把兩個Array合併
其實我上面只是想介紹一下flatMap,他有其他很酷炫的功能,但是在這邊有點大材小用。
val stageArray = stringIndexers ++ Array(vectorAssembler, featureIndexer, kmeans, labelConverter)