import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.broadcast
import shapeless.ops.hlist.Prepend
import shapeless.{::, HList, HNil}

object flow {

  type JoinList = HList

  case class AnnotatedDataFrame[D, J <: JoinList](toDF: DataFrame) extends Serializable
  /** DataSource type class */
  trait DataSource[D, P] {
    def read(parameters: P)(implicit spark: SparkSession): AnnotatedDataFrame[D, HNil]
  }

  object DataSource {
    def apply[D]: Helper[D] = new Helper[D]

    // Helper used to improve type inference and make the reading API cleaner.
    final class Helper[D] {
      def read[P](parameters: P)(implicit S: DataSource[D, P], spark: SparkSession): AnnotatedDataFrame[D, HNil] =
        S.read(parameters)

      def read(implicit S: DataSource[D, Unit], spark: SparkSession): AnnotatedDataFrame[D, HNil] =
        S.read(())
    }
  }
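
  // Usage sketch (MyDataset is a hypothetical name; the Example object below
  // defines concrete instances):
  //
  //   val frame: AnnotatedDataFrame[MyDataset, HNil] = DataSource[MyDataset].read
  //
  // The intermediate Helper lets callers name `D` explicitly while `P` is
  // still inferred from the argument (or defaulted to Unit by the second
  // overload).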
  /** Join type class */
  sealed trait Join[L, LJ <: JoinList, R, RJ <: JoinList] {
    def join(
        left: AnnotatedDataFrame[L, LJ],
        right: AnnotatedDataFrame[R, RJ]
    )(implicit P: Prepend[LJ, RJ]): AnnotatedDataFrame[L, R :: P.Out]
  }

  object Join {
    def apply[L, LJ <: JoinList, R, RJ <: JoinList](
        joinExprs: Column,
        joinType: JoinType = Inner,
        isBroadcast: Boolean = false
    ): Join[L, LJ, R, RJ] = new Join[L, LJ, R, RJ] {
      override def join(
          left: AnnotatedDataFrame[L, LJ],
          right: AnnotatedDataFrame[R, RJ]
      )(implicit P: Prepend[LJ, RJ]): AnnotatedDataFrame[L, R :: P.Out] =
        AnnotatedDataFrame(
          left.toDF.join(
            if (isBroadcast) broadcast(right.toDF) else right.toDF,
            joinExprs,
            joinType.sparkName
          )
        )
    }

    def usingColumns[L, LJ <: JoinList, R, RJ <: JoinList](
        joinKeys: Seq[String],
        joinType: JoinType = Inner,
        isBroadcast: Boolean = false
    ): Join[L, LJ, R, RJ] = new Join[L, LJ, R, RJ] {
      override def join(
          left: AnnotatedDataFrame[L, LJ],
          right: AnnotatedDataFrame[R, RJ]
      )(implicit P: Prepend[LJ, RJ]): AnnotatedDataFrame[L, R :: P.Out] =
        AnnotatedDataFrame(
          left.toDF.join(
            if (isBroadcast) broadcast(right.toDF) else right.toDF,
            joinKeys,
            joinType.sparkName
          )
        )
    }

    sealed abstract class JoinType(val sparkName: String)
    case object Inner extends JoinType("inner")
    case object LeftOuter extends JoinType("left_outer")
    case object FullOuter extends JoinType("full_outer")
  }
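
  // Instance sketch (Order and User are hypothetical datasets; the Example
  // object below defines a real instance): a broadcast equi-join keyed on
  // "user_id".
  //
  //   implicit def orderToUserJoin[LJ <: JoinList, RJ <: JoinList]
  //       : Join[Order, LJ, User, RJ] =
  //     Join.usingColumns[Order, LJ, User, RJ](Seq("user_id"), isBroadcast = true)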
  /** Transform type class */
  trait Transform[I, IJ <: JoinList, O, P] {
    def transform(input: AnnotatedDataFrame[I, IJ], parameters: P)(implicit
        spark: SparkSession
    ): AnnotatedDataFrame[O, HNil]
  }
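
  // Note that a Transform produces a fresh dataset: the output's join list is
  // reset to HNil regardless of what was joined into the input. For an
  // illustrative instance, see AverageMeasurement in the Example object below
  // (a sketch, not part of the original gist).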
  /** Syntax for AnnotatedDataFrame */
  object implicits {
    implicit class AnnotatedDataFrameSyntax[L, LJ <: JoinList](left: AnnotatedDataFrame[L, LJ]) {
      def join[R, RJ <: JoinList](
          right: AnnotatedDataFrame[R, RJ]
      )(implicit J: Join[L, LJ, R, RJ], P: Prepend[LJ, RJ]): AnnotatedDataFrame[L, R :: P.Out] =
        J.join(left, right)

      def transform[R, P](
          parameters: P
      )(implicit T: Transform[L, LJ, R, P], spark: SparkSession): AnnotatedDataFrame[R, HNil] =
        T.transform(left, parameters)

      def transform[R](implicit T: Transform[L, LJ, R, Unit], spark: SparkSession): AnnotatedDataFrame[R, HNil] =
        T.transform(left, ())
    }
  }
}

object Example {
  import flow._
  import flow.implicits._

  implicit val spark: SparkSession = ???

  /** DeviceModel dataset definition */
  sealed trait DeviceModel

  object DeviceModel {
    implicit val deviceModelDataSource: DataSource[DeviceModel, Unit] =
      new DataSource[DeviceModel, Unit] {
        override def read(parameters: Unit)(implicit spark: SparkSession): AnnotatedDataFrame[DeviceModel, HNil] =
          AnnotatedDataFrame(
            spark
              .createDataFrame(Seq(
                (0, "model_0"),
                (1, "model_1")
              ))
              .toDF("device_model_id", "model_name")
          )
      }

    implicit def deviceModelToMeasurementJoin[LJ <: JoinList, RJ <: JoinList]
        : Join[DeviceMeasurement, LJ, DeviceModel, RJ] =
      Join.usingColumns[DeviceMeasurement, LJ, DeviceModel, RJ](Seq("device_model_id"))
  }
  /** DeviceMeasurement dataset definition */
  sealed trait DeviceMeasurement

  object DeviceMeasurement {
    implicit val deviceMeasurementDataSource: DataSource[DeviceMeasurement, Unit] =
      new DataSource[DeviceMeasurement, Unit] {
        override def read(parameters: Unit)(implicit spark: SparkSession): AnnotatedDataFrame[DeviceMeasurement, HNil] =
          AnnotatedDataFrame(
            spark
              .createDataFrame(Seq(
                (0, 1.0),
                (0, 2.0),
                (1, 3.0)
              ))
              .toDF("device_model_id", "measurement_value")
          )
      }
  }
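
  /** AverageMeasurement dataset definition.
    *
    * An illustrative sketch (the dataset and the output column name are
    * hypothetical): a Transform instance that averages measurement values per
    * model over the joined frame, showing the Transform type class in use.
    */
  sealed trait AverageMeasurement

  object AverageMeasurement {
    import org.apache.spark.sql.functions.avg

    implicit val averageMeasurementTransform
        : Transform[DeviceMeasurement, DeviceModel :: HNil, AverageMeasurement, Unit] =
      new Transform[DeviceMeasurement, DeviceModel :: HNil, AverageMeasurement, Unit] {
        override def transform(
            input: AnnotatedDataFrame[DeviceMeasurement, DeviceModel :: HNil],
            parameters: Unit
        )(implicit spark: SparkSession): AnnotatedDataFrame[AverageMeasurement, HNil] =
          AnnotatedDataFrame(
            input.toDF
              .groupBy("model_name")
              .agg(avg("measurement_value").as("avg_measurement_value"))
          )
      }
  }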
  val deviceModel: AnnotatedDataFrame[DeviceModel, HNil] =
    DataSource[DeviceModel].read

  val deviceMeasurement: AnnotatedDataFrame[DeviceMeasurement, HNil] =
    DataSource[DeviceMeasurement].read

  val joined: AnnotatedDataFrame[DeviceMeasurement, DeviceModel :: HNil] =
    deviceMeasurement.join(deviceModel)
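
  // Usage sketch for the AverageMeasurement transform above: the
  // no-parameter `transform` overload resolves the Transform instance
  // implicitly and yields a fresh dataset with an empty join list.
  val averaged: AnnotatedDataFrame[AverageMeasurement, HNil] =
    joined.transform[AverageMeasurement]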
}