Skip to content

Instantly share code, notes, and snippets.

@izeigerman
Created January 30, 2021 22:19
Show Gist options
  • Save izeigerman/d1fe83519767de6514cfd02384075457 to your computer and use it in GitHub Desktop.
Save izeigerman/d1fe83519767de6514cfd02384075457 to your computer and use it in GitHub Desktop.
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.broadcast
import shapeless.ops.hlist.Prepend
import shapeless.{::, HList, HNil}
object flow {
/** Type-level list used to record which datasets have been joined into a frame. */
type JoinList = HList

/** A Spark `DataFrame` carrying compile-time tags only.
  *
  * Neither type parameter is stored at runtime; the wrapped frame is the only field.
  * The class is `final` because it is a plain data wrapper and is not meant to be subclassed.
  *
  * @tparam D marker type identifying the dataset this frame represents
  * @tparam J marker types of the datasets joined into this frame (`HNil` when nothing was joined)
  * @param toDF the underlying Spark frame
  */
final case class AnnotatedDataFrame[D, J <: JoinList](toDF: DataFrame) extends Serializable
/** DataSource type class: describes how the dataset tagged `D` is loaded.
  *
  * @tparam D marker type of the dataset that is produced
  * @tparam P type of the parameters required to read it (`Unit` when there are none)
  */
trait DataSource[D, P] {

  /** Loads the dataset.
    *
    * @param parameters source-specific read parameters
    * @param spark      session used to build the frame
    * @return the loaded frame, tagged `D`, with an empty join list
    */
  def read(
      parameters: P
  )(implicit spark: SparkSession): AnnotatedDataFrame[D, HNil]
}
object DataSource {

  /** Fixes the dataset type first (`DataSource[D]`), leaving the parameter type to be inferred
    * by the subsequent `read` call.
    */
  def apply[D]: Helper[D] = new Helper[D]

  /** Intermediate step of `DataSource[D].read(...)`; exists only to improve type inference and
    * keep the reading API short.
    */
  final class Helper[D] {

    /** Reads dataset `D` with explicit parameters, using the implicit `DataSource[D, P]`. */
    def read[P](parameters: P)(implicit
        S: DataSource[D, P],
        spark: SparkSession
    ): AnnotatedDataFrame[D, HNil] =
      S.read(parameters)(spark)

    /** Reads dataset `D` when its source takes no parameters (`DataSource[D, Unit]`). */
    def read(implicit
        S: DataSource[D, Unit],
        spark: SparkSession
    ): AnnotatedDataFrame[D, HNil] =
      S.read(())(spark)
  }
}
/** Join type class: evidence that a frame tagged `L` may be joined with a frame tagged `R`,
  * together with the way that join is performed.
  *
  * The trait is sealed, so instances can only be created by the factories in this file.
  *
  * @tparam L  marker type of the left dataset
  * @tparam LJ datasets already joined into the left frame
  * @tparam R  marker type of the right dataset
  * @tparam RJ datasets already joined into the right frame
  */
sealed trait Join[L, LJ <: JoinList, R, RJ <: JoinList] {

  /** Joins `right` into `left`.
    *
    * @param P shapeless evidence combining both join lists; its `Out` becomes the tail of the
    *          resulting join list
    * @return a frame still tagged `L` whose join list is `R :: P.Out`
    */
  def join(left: AnnotatedDataFrame[L, LJ], right: AnnotatedDataFrame[R, RJ])(implicit
      P: Prepend[LJ, RJ]
  ): AnnotatedDataFrame[L, R :: P.Out]
}
object Join {

  /** Builds a [[Join]] instance that joins on an arbitrary column expression.
    *
    * @param joinExprs   join condition passed to Spark's `join`
    * @param joinType    kind of join to perform (inner by default)
    * @param isBroadcast when true, the right frame is wrapped in Spark's `broadcast` before joining
    */
  def apply[L, LJ <: JoinList, R, RJ <: JoinList](
      joinExprs: Column,
      joinType: JoinType = Inner,
      isBroadcast: Boolean = false
  ): Join[L, LJ, R, RJ] =
    instance[L, LJ, R, RJ](isBroadcast) { (leftDF, rightDF) =>
      leftDF.join(rightDF, joinExprs, joinType.sparkName)
    }

  /** Builds a [[Join]] instance that joins on equally named columns of both frames.
    *
    * @param joinKeys    names of the columns to join on
    * @param joinType    kind of join to perform (inner by default)
    * @param isBroadcast when true, the right frame is wrapped in Spark's `broadcast` before joining
    */
  def usingColumns[L, LJ <: JoinList, R, RJ <: JoinList](
      joinKeys: Seq[String],
      joinType: JoinType = Inner,
      isBroadcast: Boolean = false
  ): Join[L, LJ, R, RJ] =
    instance[L, LJ, R, RJ](isBroadcast) { (leftDF, rightDF) =>
      leftDF.join(rightDF, joinKeys, joinType.sparkName)
    }

  /** Shared implementation of both factories: picks the (optionally broadcast) right frame,
    * delegates the actual Spark join to `joinFrames` and re-wraps the result.
    */
  private def instance[L, LJ <: JoinList, R, RJ <: JoinList](isBroadcast: Boolean)(
      joinFrames: (DataFrame, DataFrame) => DataFrame
  ): Join[L, LJ, R, RJ] =
    new Join[L, LJ, R, RJ] {
      override def join(
          left: AnnotatedDataFrame[L, LJ],
          right: AnnotatedDataFrame[R, RJ]
      )(implicit P: Prepend[LJ, RJ]): AnnotatedDataFrame[L, R :: P.Out] = {
        val rightDF = if (isBroadcast) broadcast(right.toDF) else right.toDF
        AnnotatedDataFrame(joinFrames(left.toDF, rightDF))
      }
    }

  /** Supported join kinds; `sparkName` is the string handed to Spark's `join`. */
  sealed abstract class JoinType(val sparkName: String)
  case object Inner extends JoinType("inner")
  case object LeftOuter extends JoinType("left_outer")
  case object RightOuter extends JoinType("right_outer")
  case object FullOuter extends JoinType("full_outer")
}
/** Transform type class: describes how a frame tagged `I` is turned into dataset `O`.
  *
  * @tparam I  marker type of the input dataset
  * @tparam IJ datasets that must already be joined into the input frame
  * @tparam O  marker type of the dataset produced
  * @tparam P  type of the parameters the transformation takes
  */
trait Transform[I, IJ <: JoinList, O, P] {

  /** Applies the transformation.
    *
    * @return the transformed frame, tagged `O`, with an empty join list
    */
  def transform(
      input: AnnotatedDataFrame[I, IJ],
      parameters: P
  )(implicit spark: SparkSession): AnnotatedDataFrame[O, HNil]
}
/** Syntax for AnnotatedDataFrame */
object implicits {

  /** Adds `join` and `transform` methods to any [[AnnotatedDataFrame]]. */
  implicit class AnnotatedDataFrameSyntax[L, LJ <: JoinList](left: AnnotatedDataFrame[L, LJ]) {

    /** Joins `right` into this frame; compiles only when a matching [[Join]] instance exists. */
    def join[R, RJ <: JoinList](right: AnnotatedDataFrame[R, RJ])(implicit
        J: Join[L, LJ, R, RJ],
        P: Prepend[LJ, RJ]
    ): AnnotatedDataFrame[L, R :: P.Out] =
      J.join(left, right)(P)

    /** Transforms this frame into dataset `R` using explicit parameters. */
    def transform[R, P](parameters: P)(implicit
        T: Transform[L, LJ, R, P],
        spark: SparkSession
    ): AnnotatedDataFrame[R, HNil] =
      T.transform(left, parameters)(spark)

    /** Transforms this frame into dataset `R` when the transformation takes no parameters. */
    def transform[R](implicit
        T: Transform[L, LJ, R, Unit],
        spark: SparkSession
    ): AnnotatedDataFrame[R, HNil] =
      T.transform(left, ())(spark)
  }
}
}
object Example {
import flow._
import flow.implicits._
// Placeholder session: `???` throws `NotImplementedError` as soon as this val is initialized,
// so a real `SparkSession` has to be supplied here before the example can actually run.
implicit val spark: SparkSession = ???
/** DeviceModel dataset definition */
sealed trait DeviceModel
object DeviceModel {

  /** Parameterless source producing a small in-memory frame with the columns
    * `device_model_id` and `model_name`.
    *
    * The result type is spelled out explicitly: implicit definitions should not rely on
    * type inference.
    */
  implicit val deviceModelDataSource: DataSource[DeviceModel, Unit] =
    new DataSource[DeviceModel, Unit] {
      override def read(parameters: Unit)(implicit spark: SparkSession): AnnotatedDataFrame[DeviceModel, HNil] =
        AnnotatedDataFrame(
          spark
            .createDataFrame(Seq(
              (0, "model_0"),
              (1, "model_1")
            ))
            .toDF("device_model_id", "model_name")
        )
    }

  /** Allows a `DeviceMeasurement` frame (left) to be joined with a `DeviceModel` frame (right)
    * on the shared `device_model_id` column, whatever has already been joined into either side.
    */
  implicit def deviceModelToMeasurementJoin[LJ <: JoinList, RJ <: JoinList]: Join[DeviceMeasurement, LJ, DeviceModel, RJ] =
    Join.usingColumns[DeviceMeasurement, LJ, DeviceModel, RJ](Seq("device_model_id"))
}
/** DeviceMeasurement dataset definition */
sealed trait DeviceMeasurement
object DeviceMeasurement {

  /** Parameterless source producing a small in-memory frame with the columns
    * `device_model_id` and `measurement_value`.
    *
    * The result type is spelled out explicitly: implicit definitions should not rely on
    * type inference.
    */
  implicit val deviceMeasurementDataSource: DataSource[DeviceMeasurement, Unit] =
    new DataSource[DeviceMeasurement, Unit] {
      override def read(parameters: Unit)(implicit spark: SparkSession): AnnotatedDataFrame[DeviceMeasurement, HNil] =
        AnnotatedDataFrame(
          spark
            .createDataFrame(Seq(
              (0, 1.0),
              (0, 2.0),
              (1, 3.0)
            ))
            .toDF("device_model_id", "measurement_value")
        )
    }
}
// Both reads use the parameterless `read`, which needs an implicit `DataSource[_, Unit]` for the
// requested dataset; a freshly read frame carries an empty join list (`HNil`).
val deviceModel: AnnotatedDataFrame[DeviceModel, HNil] =
DataSource[DeviceModel].read
val deviceMeasurement: AnnotatedDataFrame[DeviceMeasurement, HNil] =
DataSource[DeviceMeasurement].read
// The join compiles only because an implicit `Join` from DeviceMeasurement to DeviceModel is
// available; the declared result type records `DeviceModel` in the join list.
val joined: AnnotatedDataFrame[DeviceMeasurement, DeviceModel :: HNil] =
deviceMeasurement.join(deviceModel)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment