Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save seriyps/2cd0a6e28ac47b9d00681c3e864c4ee6 to your computer and use it in GitHub Desktop.
Save seriyps/2cd0a6e28ac47b9d00681c3e864c4ee6 to your computer and use it in GitHub Desktop.
`epgsql` command with OOM protection. It discards the incoming results if the number of bytes received exceeds the threshold
-module(epgsql_cmd_prepared_query_oom_protection).
-behaviour(epgsql_command).
-export([init/1, execute/2, handle_message/4]).
-export_type([response/0]).
%% Result delivered to the caller of this command.
%% The `ok' shapes are the regular prepared-query results. The error branch
%% additionally allows `{response_too_big, Bytes}', which this module emits
%% once the data rows received exceed the configured byte limit; `Bytes' is
%% the value of the byte counter at the time the error is produced.
-type response() :: {ok, Count :: non_neg_integer(), Cols :: [epgsql:column()], Rows :: [tuple()]}
| {ok, Count :: non_neg_integer()}
| {ok, Cols :: [epgsql:column()], Rows :: [tuple()]}
| {error, epgsql:query_error() | {response_too_big, pos_integer()}}.
-include_lib("epgsql/include/epgsql.hrl").
-include_lib("epgsql/include/protocol.hrl").
%% Command state: wraps the state of epgsql_cmd_prepared_query and adds the
%% bookkeeping needed to enforce the byte limit.
-record(state,
%% State returned by epgsql_cmd_prepared_query:init/1 and updated by its callbacks.
{orig_state :: tuple(),
%% Maximum number of data-row bytes to accept. The atom `unlimited' needs no
%% special handling: handle_message/4 compares with plain `>' / `=<', and in
%% Erlang's term order every number sorts before every atom, so a byte count
%% is never greater than `unlimited'.
bytes_limit :: pos_integer() | unlimited,
%% Running total of byte_size/1 of every DATA_ROW payload seen so far,
%% including rows that were dropped after the limit was exceeded.
bytes_received :: non_neg_integer()}).
%% @doc Initialises the command state.
%% `Stmt' and `TypedParams' are handed unchanged to
%% epgsql_cmd_prepared_query:init/1; its result is kept as the wrapped state.
%% `BytesLimit' is stored as the byte budget (a positive integer or the atom
%% `unlimited', per the record's field type) and the byte counter starts at 0.
init({Stmt, TypedParams, BytesLimit}) ->
    #state{bytes_received = 0,
           bytes_limit = BytesLimit,
           orig_state = epgsql_cmd_prepared_query:init({Stmt, TypedParams})}.
%% @doc Produces the packets to send by delegating to
%% epgsql_cmd_prepared_query:execute/2 with the wrapped state.
%% The inner command is expected to answer with a `send_multi' tuple (any
%% other reply crashes with `badmatch'); the same tuple is returned with the
%% inner state stored back into this module's #state{} record.
execute(Sock, #state{orig_state = OrigState} = State) ->
    %% The socket state is bound to a fresh variable on purpose: matching it
    %% against the already-bound `Sock' would turn any legitimately updated
    %% socket state returned by the inner command into a `badmatch' crash.
    {send_multi, Commands, NewSock, NewOrigState} =
        epgsql_cmd_prepared_query:execute(Sock, OrigState),
    {send_multi, Commands, NewSock, State#state{orig_state = NewOrigState}}.
%% @doc Handles backend messages while enforcing the byte budget.
%%
%% DATA_ROW: the payload size is added to the byte counter, then
%%   - if the counter was already above the limit, the row is dropped;
%%   - if this row pushes the counter above the limit, a cancel request is
%%     issued via cancel_request/1 and the row is dropped;
%%   - otherwise the row is forwarded to the wrapped command.
%% COMMAND_COMPLETE / READY_FOR_QUERY / ERROR(query_canceled): when the
%% counter is above the limit these are intercepted so that the caller gets
%% `{error, {response_too_big, Bytes}}'.
%% Every other message goes to the wrapped command untouched.
handle_message(?DATA_ROW, Data, SockState,
               #state{bytes_limit = Limit, bytes_received = Seen} = State0) ->
    Total = Seen + byte_size(Data),
    State = State0#state{bytes_received = Total},
    case {Seen > Limit, Total > Limit} of
        {true, _} ->
            %% The budget was already exhausted by earlier rows: drop this one.
            {noaction, SockState, State};
        {false, true} ->
            %% This row crosses the limit: cancel the query and drop the row.
            {noaction, cancel_request(SockState), State};
        {false, false} ->
            %% Still within the budget: let the wrapped command process the row.
            call_original_handle_message(?DATA_ROW, Data, SockState, State)
    end;
handle_message(?COMMAND_COMPLETE, _Bin, SockState,
               #state{bytes_limit = Limit, bytes_received = Seen} = State)
  when Seen > Limit ->
    %% The query completed before PostgreSQL reacted to our cancel request;
    %% report the oversize error instead of the (partial) result.
    TooBig = response_too_big(Seen),
    {add_result, TooBig, TooBig, SockState, State};
handle_message(?READY_FOR_QUERY, _Status, SockState,
               #state{bytes_limit = Limit, bytes_received = Seen})
  when Seen > Limit ->
    {finish, response_too_big(Seen), done, SockState};
handle_message(?ERROR, #error{codename = query_canceled}, SockState,
               #state{bytes_limit = Limit, bytes_received = Seen} = State)
  when Seen > Limit ->
    %% This is the cancellation we initiated ourselves; nothing to record.
    {noaction, SockState, State};
handle_message(Type, Bin, SockState, State) ->
    call_original_handle_message(Type, Bin, SockState, State).

%% Builds the error term reported when the byte limit has been exceeded.
response_too_big(Bytes) ->
    {error, {response_too_big, Bytes}}.
%% Dialyzer treats SockState as any(). When it checks the call to epgsql_sock:handle_cast/2,
%% it concludes that the last branch of that function can never match. As a result it prints a
%% very misleading warning in which the expected type of handle_cast/2 contains none of the types
%% from the third branch of the function. Unfortunately, there is no easy way around it, so we
%% just isolate the call in its own function and suppress the warning for that function.
-dialyzer({nowarn_function, cancel_request/1}).
%% Requests cancellation of the running query by invoking
%% epgsql_sock:handle_cast/2 directly with the `cancel' message, and returns
%% the socket state from its `{noreply, _}' reply. Any other reply crashes
%% with `badmatch'.
cancel_request(SockState) ->
    {noreply, UpdatedSockState} = epgsql_sock:handle_cast(cancel, SockState),
    UpdatedSockState.
%% Forwards a message to epgsql_cmd_prepared_query:handle_message/4 using the
%% wrapped command state. When the reply carries an updated inner state as its
%% last element, that state is stored back into this module's #state{} record
%% and the record takes its place in the reply; replies that carry no command
%% state are returned unchanged.
call_original_handle_message(Type, Bin, SockState, #state{orig_state = OrigState} = State) ->
    Result = epgsql_cmd_prepared_query:handle_message(Type, Bin, SockState, OrigState),
    case Result of
        unknown ->
            Result;
        %% `finish' and `stop' replies are 4-tuples whose last element is the
        %% socket state, not command state (per the epgsql_command behaviour).
        %% They must be matched before the generic 4-tuple clause below,
        %% otherwise the socket state would be swallowed into `orig_state'
        %% and replaced by our record.
        {finish, _, _, _} ->
            Result;
        {stop, _, _, _} ->
            Result;
        %% 2-tuple replies (e.g. `{noaction, Sock}', `{sync_required, Why}')
        %% have no room for command state: pass them through instead of
        %% crashing with `case_clause'.
        {_, _} ->
            Result;
        %% Remaining shapes end with the inner command state: re-wrap it.
        {_, _, NewOrigState} ->
            setelement(3, Result, State#state{orig_state = NewOrigState});
        {_, _, _, NewOrigState} ->
            setelement(4, Result, State#state{orig_state = NewOrigState});
        {_, _, _, _, NewOrigState} ->
            setelement(5, Result, State#state{orig_state = NewOrigState})
    end.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment