Skip to content

Instantly share code, notes, and snippets.

@kevin-lee
Last active March 15, 2022 13:17
Show Gist options
  • Save kevin-lee/94df3994b22aa625eaf484b8d2df8552 to your computer and use it in GitHub Desktop.
Save kevin-lee/94df3994b22aa625eaf484b8d2df8552 to your computer and use it in GitHub Desktop.
Examples: Refinement types in Scala 3
/** Examples of refinement types built on Scala 3 opaque types.
  *
  * Every refined type in this object exposes the same three constructors:
  *
  *   - `apply`      – an `inline` constructor whose check is an `inline if`, so it is
  *                    evaluated by the compiler; a failing check is reported through
  *                    `compiletime.error` as a compilation error.
  *   - `from`       – runtime validation returning `Right(value)` or `Left(message)`.
  *   - `unsafeFrom` – runtime validation that calls `sys.error(message)` (which throws a
  *                    `RuntimeException`) when the value is rejected.
  *
  * The underlying value is read back through the `value` extension method.
  */
object RefinementTypes {

  /** An `Int` that is strictly greater than zero. */
  type PositiveInt = PositiveInt.PositiveInt
  object PositiveInt {
    import compiletime.*

    opaque type PositiveInt = Int

    /** Compile-time checked constructor: compiles only when `positiveInt > 0` can be
      * decided by the compiler and holds.
      */
    inline def apply(positiveInt: Int): PositiveInt =
      inline if positiveInt > 0 then positiveInt
      else error("PositiveInt must be > 0. value: " + codeOf(positiveInt))

    /** Runtime validation.
      *
      * @return `Right(n)` when `n > 0`, otherwise `Left` with an explanatory message.
      */
    def from(n: Int): Either[String, PositiveInt] =
      if n > 0 then Right(n) else Left(s"PositiveInt must be > 0. value: $n")

    /** Like [[from]], but calls `sys.error` with the validation message instead of
      * returning a `Left`.
      */
    def unsafeFrom(n: Int): PositiveInt = from(n).fold(sys.error, identity)

    given positiveIntCanEqual: CanEqual[PositiveInt, PositiveInt] =
      CanEqual.derived

    extension (positiveInt: PositiveInt) {

      /** The underlying `Int`. */
      def value: Int = positiveInt
    }
  }

  /** A `String` that is not the empty string. */
  type NonEmptyString = NonEmptyString.NonEmptyString
  object NonEmptyString {
    import compiletime.*
    // NOTE(review): nothing in this object appears to use this import — confirm before removing.
    import scala.compiletime.ops.any.*

    opaque type NonEmptyString = String

    /** Compile-time checked constructor: compiles only when the argument can be shown
      * by the compiler to differ from `""`.
      */
    inline def apply(nonEmptyString: String): NonEmptyString =
      inline if nonEmptyString != "" then nonEmptyString
      else error("NonEmptyString must not be an empty String.")

    /** Runtime validation.
      *
      * @return `Right(s)` when `s` is non-empty, otherwise `Left` with an explanatory message.
      */
    def from(s: String): Either[String, NonEmptyString] =
      if s.nonEmpty then Right(s)
      else Left("NonEmptyString must not be an empty String.")

    /** Like [[from]], but calls `sys.error` with the validation message instead of
      * returning a `Left`.
      */
    def unsafeFrom(s: String): NonEmptyString =
      from(s).fold(sys.error, identity)

    given nonEmptyStringCanEqual: CanEqual[NonEmptyString, NonEmptyString] =
      CanEqual.derived

    extension (nonEmptyString: NonEmptyString) {

      /** The underlying `String`. */
      def value: String = nonEmptyString
    }
  }

  /** A port number: an `Int` in the inclusive range 0 to 65535 (the 16-bit port range).
    *
    * A >= 0 && A <= 65535
    */
  type PortNumber = PortNumber.PortNumber
  object PortNumber {
    import compiletime.*

    opaque type PortNumber = Int

    /** Compile-time checked constructor: compiles only when the argument can be shown
      * by the compiler to lie within 0 to 65535 inclusive.
      */
    inline def apply(portNumber: Int): PortNumber =
      inline if portNumber >= 0 && portNumber <= 65535 then portNumber
      else
        error(
          "PortNumber must be Int between 0 and 65535 but got " + codeOf(
            portNumber
          ) + " instead."
        )

    /** Runtime validation.
      *
      * @return `Right(portNumber)` when `0 <= portNumber <= 65535`, otherwise `Left`
      *         with an explanatory message.
      */
    def from(portNumber: Int): Either[String, PortNumber] =
      if portNumber >= 0 && portNumber <= 65535 then Right(portNumber)
      else
        Left(
          s"PortNumber must be Int between 0 and 65535 but got $portNumber instead"
        )

    /** Like [[from]], but calls `sys.error` with the validation message instead of
      * returning a `Left`.
      */
    def unsafeFrom(portNumber: Int): PortNumber =
      from(portNumber).fold(sys.error, identity)

    given portNumberCanEqual: CanEqual[PortNumber, PortNumber] =
      CanEqual.derived

    extension (portNumber: PortNumber) {

      /** The underlying `Int`. */
      def value: Int = portNumber
    }
  }
}

/** Demo: exercises every constructor of each refinement type.
  *
  * Lines that would fail (at compile time or at run time) are left commented out, with
  * the expected diagnostic shown next to them.
  */
@main def run(): Unit = {
  import RefinementTypes.*

  println(PositiveInt(1)) // 1
  // println(PositiveInt(0)) // compile-time error: PositiveInt must be > 0. value: 0
  println(PositiveInt.from(1)) // Right(1)
  println(PositiveInt.from(0)) // Left(PositiveInt must be > 0. value: 0)
  println(PositiveInt.unsafeFrom(1)) // 1
  // println(PositiveInt.unsafeFrom(0)) // runtime error: RuntimeException: PositiveInt must be > 0. value: 0

  println(NonEmptyString("aaa")) // aaa
  // println(NonEmptyString("")) // compile-time error: NonEmptyString must not be an empty String.
  println(NonEmptyString.from("aaa")) // Right(aaa)
  println(NonEmptyString.from("")) // Left(NonEmptyString must not be an empty String.)
  println(NonEmptyString.unsafeFrom("aaa")) // aaa
  // println(NonEmptyString.unsafeFrom("")) // runtime error: RuntimeException: NonEmptyString must not be an empty String.

  println(PortNumber(8080)) // 8080
  println(PortNumber(65535)) // 65535 (largest valid port)
  // println(PortNumber(65536)) // compile-time error: PortNumber must be Int between 0 and 65535 but got 65536 instead.
  println(PortNumber.from(8080)) // Right(8080)
  println(PortNumber.from(65536)) // Left(PortNumber must be Int between 0 and 65535 but got 65536 instead)
  println(PortNumber.unsafeFrom(8080)) // 8080
  // println(PortNumber.unsafeFrom(65536)) // runtime error: RuntimeException: PortNumber must be Int between 0 and 65535 but got 65536 instead
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment