Created
August 5, 2016 12:51
-
-
Save madjar/4c97e972694009bd6f3bbb5e33c73c13 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.ml.PipelineStage | |
import org.apache.spark.ml.feature.{OneHotEncoder, VectorAssembler, StringIndexer} | |
import org.apache.spark.mllib.linalg.VectorUDT | |
import scala.Function.const | |
import scala.language.{implicitConversions, reflectiveCalls} | |
/* The Spark Pipeline API is clumsy to use directly, so let's build a nicer, more compositional API on top of it!
 */
/**
 * A `Col` is just a DataFrame column waiting to be built.
 *
 * Call `build` with the name of the output column to get the pipeline stages that produce that column.
 * `suggestedName` is the name to use for this column if it ends up being an intermediate column.
 *
 * The type parameter `T` is a phantom type: it appears in no field and only tags the column with the kind of
 * values it holds. Because `Col` is invariant in `T`, this gives a strongly typed API on top of pipelines: a
 * transformation typed `Col[String] => Col[Double]` cannot be applied to a `Col[Double]`.
 *
 * Example (using the helpers defined in `PipelineBuilder`):
 * {{{
 * col[String]("cyber_channel") |> stringIndexer |> oneHotEncoder build "output"
 *
 * val col1 = col[String]("someCol")
 * val col2 = stringIndexer(col1)
 * val pipeline: Array[PipelineStage] = col2.build("output2")
 *
 * vectorAssembler(Array(col[String]("someCol"), col[String]("anotherCol"))).build("vector")
 * }}}
 *
 * @param suggestedName name to give this column when it is materialized as an intermediate column
 * @param build given the desired output column name, returns the pipeline stages that produce that column
 * @tparam T phantom tag describing the column's value type (never checked against the actual data)
 */
case class Col[T](suggestedName: String, build: String => Array[PipelineStage])
object PipelineBuilder {

  /**
   * An existing column that is already present, so building it needs no stages:
   * `build` always returns an empty array.
   *
   * @param name name of the existing column; also used as its `suggestedName`
   * @tparam T phantom tag for the column's value type (not checked against the actual data)
   */
  def col[T](name: String): Col[T] = Col[T](name, const(Array.empty[PipelineStage]))

  /**
   * With that definition of `Col`, a transformation (or any pipeline stage, for that matter) is just a
   * function that takes a column and returns another.
   */
  type Transfo[S, T] = Col[S] => Col[T]

  /**
   * Creates a `Transfo` from a `PipelineStage` constructor whose stage has `setInputCol` and `setOutputCol`.
   *
   * The resulting column's suggested name is `<input suggestedName>_<stage class simple name>`. Building it
   * first builds the input column under the input's suggested name, then appends a fresh stage that reads
   * that column and writes the requested output column.
   *
   * The bound is a structural type, so `setInputCol`/`setOutputCol` are invoked via reflection (hence the
   * `reflectiveCalls` import). NOTE(review): presumably chosen because Spark's shared `HasInputCol`/
   * `HasOutputCol` param traits are not public — confirm for the Spark version in use.
   *
   * @param a   factory for a new, unconfigured stage; invoked once here to derive the column name and once
   *            per `build` call, so every built pipeline gets its own stage instance
   * @param col the input column
   * @return the transformed column
   */
  def mkColTransform[A <: PipelineStage {def setInputCol(c: String): A; def setOutputCol(c: String): A}, S, T]
                    (a: () => A)(col: Col[S]): Col[T] = {
    val stageName = a().getClass.getSimpleName
    Col(s"${col.suggestedName}_$stageName",
      (nextCol: String) =>
        col.build(col.suggestedName) :+ a().setInputCol(col.suggestedName).setOutputCol(nextCol))
  }

  /**
   * Creates a function that builds a `Col` from an array of input columns (for `VectorAssembler`, for example).
   *
   * The resulting column's suggested name is the inputs' suggested names joined with `_`, followed by `_` and
   * the stage class simple name. Building it first builds each distinct input column under its own suggested
   * name, then appends a fresh stage reading all the input columns (in the given order, repeats included) and
   * writing the requested output column.
   *
   * @param a    factory for a new, unconfigured stage; invoked once here to derive the column name and once
   *             per `build` call
   * @param cols the input columns
   * @return the combined column
   */
  def mkColsTransform[A <: PipelineStage {def setInputCols(c: Array[String]): A; def setOutputCol(c: String): A}, S, T]
                     (a: () => A)(cols: Array[Col[S]]): Col[T] = {
    val newName = cols.map(_.suggestedName).mkString("_") + "_" + a().getClass.getSimpleName
    // Build each input column at most once (first occurrence wins). Repeating a derived input would otherwise
    // emit duplicate stages writing the same intermediate column, which Spark stages such as StringIndexer
    // reject during schema validation ("Output column ... already exists").
    // NOTE(review): an upstream column shared by two *different* inputs is still built once per branch.
    val distinctInputs = cols.foldLeft(Vector.empty[Col[S]]) { (acc, c) =>
      if (acc.exists(_.suggestedName == c.suggestedName)) acc else acc :+ c
    }
    Col(newName, (nextCol: String) => {
      val upstream: Array[PipelineStage] = distinctInputs.flatMap(c => c.build(c.suggestedName)).toArray
      upstream :+ a().setInputCols(cols.map(_.suggestedName)).setOutputCol(nextCol)
    })
  }

  // Now we can define some `Transfo`s.
  val stringIndexer: Transfo[String, Double] = mkColTransform(() => new StringIndexer())
  val oneHotEncoder: Transfo[Double, VectorUDT] = mkColTransform(() => new OneHotEncoder())
  // NOTE(review): Spark documents VectorAssembler inputs as numeric, boolean or vector columns, so typing the
  // inputs as Col[String] looks inconsistent; it also rejects Col[VectorUDT] values such as oneHotEncoder
  // outputs. Left unchanged to keep the public type stable — confirm the intended typing.
  val vectorAssembler: Array[Col[String]] => Col[VectorUDT] = mkColsTransform(() => new VectorAssembler())

  /** Adds the `|>` operator, copied from scalaz to avoid a scalaz dependency. */
  final class IdOps[A](val self: A) extends AnyVal {
    /** Applies `self` to the provided function. The Thrush combinator. */
    def |>[B](f: A => B): B = f(self)
  }

  /** Implicit conversion that makes `x |> f` available on any value `x`. */
  implicit def ToIdOps[A](a: A): IdOps[A] = new IdOps(a)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment