Created January 23, 2014 16:03
-
-
Save whysoserious/8581195 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import net.liftweb.common._ | |
import net.liftweb.http.LiftRules._ | |
import net.liftweb.http.rest.RestHelper | |
import net.liftweb.http.{S, JsonResponse, LiftResponse, Req} | |
import net.liftweb.json.JsonDSL._ | |
import scala.concurrent.duration.Duration | |
object RateLimit {

  /** How (or whether) a [[RateLimit]] helper throttles the handlers it wraps. */
  sealed trait Config

  /**
   * Throttling is on: at most `maxRequests` requests are served per `duration` window.
   *
   * @param maxRequests requests allowed per window; a value `<= 0` rejects every request
   * @param duration    length of one window; a non-finite duration (e.g. `Duration.Inf`)
   *                    means the window never resets
   */
  case class Enabled(maxRequests: Int, duration: Duration) extends Config

  /** Throttling is off: handlers are served unwrapped. */
  case object Disabled extends Config

  /**
   * A [[DispatchPF]] that forwards to `handler` while the current window still has budget,
   * and to `errorHandler` once the budget is exhausted.
   *
   * One counter is shared by every request this instance sees; it is not keyed per client.
   * Every call to `apply` sets `X-RateLimit-Limit`, `X-RateLimit-Remaining` and
   * `X-RateLimit-Reset` (window end, epoch milliseconds) through `S.setResponseHeader`.
   *
   * @param config       the limit and window length
   * @param handler      the dispatch function being throttled
   * @param errorHandler builds the response for a rejected request; it receives the request,
   *                     the config, and the epoch-millis time at which the window resets
   */
  class RateLimitingDispatchPF(config: Enabled, handler: DispatchPF)
      (errorHandler: (Req, Enabled, Long) => () => Box[LiftResponse]) extends DispatchPF {

    /** Requests still allowed in the current window. This class updates it only while holding `this` lock. */
    @volatile var remainingRequests: Int = config.maxRequests

    /** Epoch millis after which the next request starts a fresh window. */
    @volatile var nextReset: Long = System.currentTimeMillis

    def isDefinedAt(req: Req): Boolean = handler.isDefinedAt(req)

    def apply(req: Req): () => Box[LiftResponse] = withCheckLimit(req)

    /**
     * Spends one unit of budget for `req`, first starting a new window if the current one
     * ended before `now`. Then sets the rate-limit headers and returns either the wrapped
     * handler's response thunk or the error handler's.
     *
     * @param req the incoming request
     * @param now current time in epoch millis; exposed so callers/tests can supply a clock
     */
    def withCheckLimit(req: Req, now: Long = System.currentTimeMillis): () => Box[LiftResponse] = {
      // The headers and the error handler use the same snapshot the decision was made on.
      val (allowed, remaining, reset) = acquire(now)
      setHeaders(remaining, reset)
      if (allowed) handler(req) else errorHandler(req, config, reset)
    }

    /** Sets the rate-limit headers from the current counter values (read without locking). */
    def addHeaders: Unit = setHeaders(remainingRequests, nextReset)

    // The check-and-decrement must be atomic. With only @volatile, two threads could both
    // observe remainingRequests == 1 and both be served, exceeding the configured limit.
    // Returns (served?, remaining after this request, window end).
    private def acquire(now: Long): (Boolean, Int, Long) = synchronized {
      if (nextReset < now) resetState(now)
      val allowed = remainingRequests > 0
      if (allowed) remainingRequests -= 1
      (allowed, remainingRequests, nextReset)
    }

    // Writes the three X-RateLimit-* headers for the current response.
    private def setHeaders(remaining: Int, reset: Long): Unit = {
      S.setResponseHeader("X-RateLimit-Limit", config.maxRequests.toString)
      S.setResponseHeader("X-RateLimit-Remaining", remaining.toString)
      S.setResponseHeader("X-RateLimit-Reset", reset.toString)
    }

    // Starts a new window at `now` with the full budget; the caller then spends from it.
    // Only called from `acquire`, i.e. while holding `this` lock.
    private def resetState(now: Long): Unit = {
      nextReset =
        if (config.duration.isFinite) now + config.duration.toMillis
        else Long.MaxValue // toMillis throws for Inf / MinusInf / Undefined
      remainingRequests = config.maxRequests
    }
  }
}
/**
 * Mixin for a [[RestHelper]] that can wrap dispatch functions with a rate limit.
 *
 * Implementors supply `rateLimitConfig`. `withRateLimit` reads it on every call, so it must
 * already be initialised at that point. A `val` declared after a `serve { withRateLimit { ... } }`
 * statement in a class body would still be null there, and the match would throw a MatchError.
 */
trait RateLimit extends Loggable {
  self: RestHelper =>

  import RateLimit._

  /** Whether and how `withRateLimit` throttles handlers. */
  def rateLimitConfig: RateLimit.Config

  /**
   * Default response for a request rejected by the limiter. It logs the request path at
   * debug level and answers with a JSON object carrying a `message` field and status 403.
   * Override to customise; this default ignores `config` and `nextReset`.
   */
  def errorHandler(req: Req, config: RateLimit.Enabled, nextReset: Long): () => Box[LiftResponse] = {
    logger.debug(s"API rate limit exceeded for request path ${req.path.wholePath.mkString("/")}")
    JsonResponse("message" -> "API rate limit exceeded.", 403)
  }

  /**
   * Returns `handler` unchanged when rate limiting is disabled. Otherwise returns a
   * [[RateLimit.RateLimitingDispatchPF]] around it. Each call creates an independent counter.
   */
  def withRateLimit(handler: DispatchPF): DispatchPF =
    rateLimitConfig match {
      case Disabled         => handler
      case enabled: Enabled => new RateLimitingDispatchPF(enabled, handler)(errorHandler)
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import net.liftweb.http.rest.RestHelper | |
import net.liftweb.http.OkResponse | |
import scala.concurrent.duration._ | |
/**
 * Example endpoints: `/hello` is limited to 100 requests per minute (one shared counter),
 * while `/no_limits` is served without throttling.
 */
class RestHelperWithRateLimit extends RestHelper with RateLimit {
  import net.liftweb.common.Full
  import net.liftweb.http.Req

  // Must be declared before the `serve` calls below. Class-body statements run in order,
  // and `withRateLimit` reads `rateLimitConfig` immediately. `1.minute` avoids the
  // postfix-operator feature warning that `1 minute` triggers.
  val rateLimitConfig: RateLimit.Config = RateLimit.Enabled(100, 1.minute)

  serve {
    withRateLimit {
      // Match on the request path via Req's extractor; a bare `"hello" :: Nil` pattern
      // cannot match a Req and does not compile. Suffix and request method are ignored.
      case Req("hello" :: Nil, _, _) => () => Full(OkResponse())
    }
  }

  serve {
    case Req("no_limits" :: Nil, _, _) => () => Full(OkResponse())
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hey! Is there a license for these files?