Created June 9, 2020 09:21
-
-
Save phillip-haydon/621a4977af7bf2f135e91778262ccbee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// Registration helpers for <see cref="PostgreSqlLockProvider"/>.
/// </summary>
public static class PostgresServiceCollectionExtensions
{
    /// <summary>
    /// Registers <see cref="PostgreSqlLockProvider"/> as the distributed lock manager on these workflow options.
    /// </summary>
    /// <param name="options">The workflow options to configure.</param>
    /// <param name="connectionString">Npgsql connection string for the database that holds the lock table.</param>
    /// <param name="schemaName">
    /// Schema containing the <c>workflow_lock_provider</c> table. It is interpolated into SQL as an
    /// identifier, so it must come from trusted configuration. Defaults to <c>wfc</c>.
    /// </param>
    /// <returns>The same <paramref name="options"/> instance, to allow call chaining.</returns>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public static WorkflowOptions UsePostgreSqlLocking(this WorkflowOptions options, string connectionString, string schemaName = "wfc")
    {
        if (options == null)
        {
            throw new ArgumentNullException(nameof(options));
        }
        if (connectionString == null)
        {
            throw new ArgumentNullException(nameof(connectionString));
        }
        if (schemaName == null)
        {
            throw new ArgumentNullException(nameof(schemaName));
        }

        // The provider cannot work without a logger factory. GetRequiredService reports a missing
        // registration with a descriptive error instead of a NullReferenceException later on.
        options.UseDistributedLockManager(sp => new PostgreSqlLockProvider(connectionString, schemaName, sp.GetRequiredService<ILoggerFactory>()));
        return options;
    }
}
/// <summary>
/// Distributed lock provider that records lock ownership as rows in
/// <c>{schema}.workflow_lock_provider</c>. Each instance identifies itself with a random node id.
/// The locks it acquires expire after a TTL unless the background heartbeat started by
/// <see cref="Start"/> keeps refreshing them.
/// </summary>
/// <remarks>
/// Expiry timestamps are computed and compared on the database server as UTC wall-clock values
/// (<c>timezone('utc', now())</c>). Results therefore do not depend on the client clock or on
/// the session's <c>TimeZone</c> setting.
/// </remarks>
public class PostgreSqlLockProvider : IDistributedLockProvider
{
    private readonly string _connectionString;
    private readonly ILogger _logger;

    // Written into lock_owner for every row this instance inserts.
    private readonly Guid _nodeId;

    // Lock lifetime in milliseconds. Each acquire/heartbeat sets expiry this far ahead of the server's now().
    private readonly long _ttl = 30000;

    // Delay between heartbeats in milliseconds. It is well below _ttl, so held locks are refreshed before they lapse.
    private readonly int _heartbeat = 10000;

    // Ids of locks this node currently holds. Callers and the heartbeat loop both touch it,
    // so every access goes through _localLocksGate.
    private readonly List<string> _localLocks = new List<string>();
    private readonly object _localLocksGate = new object();

    private Task _heartbeatTask;
    private CancellationTokenSource _cancellationTokenSource;

    private string _acquireLockCommand = "";
    private string _releaseLockCommand = "";
    private string _checkTableCommand = "";
    private string _createTableCommand = "";
    private string _heartbeatCommand = "";

    /// <summary>
    /// Creates a provider bound to one database and schema.
    /// </summary>
    /// <param name="connectionString">Npgsql connection string.</param>
    /// <param name="schemaName">Schema holding the lock table. It is interpolated into SQL, so it must be trusted configuration.</param>
    /// <param name="logFactory">Factory used to create this provider's logger.</param>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public PostgreSqlLockProvider(string connectionString, string schemaName, ILoggerFactory logFactory)
    {
        _connectionString = connectionString ?? throw new ArgumentNullException(nameof(connectionString));
        if (schemaName == null)
        {
            throw new ArgumentNullException(nameof(schemaName));
        }
        if (logFactory == null)
        {
            throw new ArgumentNullException(nameof(logFactory));
        }

        _logger = logFactory.CreateLogger<PostgreSqlLockProvider>();
        _nodeId = Guid.NewGuid();
        SetupCommands(schemaName);
    }

    /// <summary>
    /// Tries to take the lock named <paramref name="id"/> for this node.
    /// </summary>
    /// <param name="id">Lock name.</param>
    /// <param name="cancellationToken">Flowed to the connection open, prepare and execute calls; not caught here.</param>
    /// <returns>
    /// <c>true</c> when this call inserted a new lock row. <c>false</c> in three cases: an unexpired
    /// row for <paramref name="id"/> already exists (whichever node owns it), or a
    /// <see cref="NpgsqlException"/> occurred, which is logged.
    /// </returns>
    public async Task<bool> AcquireLock(string id, CancellationToken cancellationToken)
    {
        try
        {
            await using var conn = new NpgsqlConnection(_connectionString);
            await conn.OpenAsync(cancellationToken);
            await using var cmd = new NpgsqlCommand(_acquireLockCommand, conn);
            cmd.Parameters.AddWithValue("id", NpgsqlDbType.Text, id);
            cmd.Parameters.AddWithValue("lock_owner", NpgsqlDbType.Uuid, _nodeId);
            cmd.Parameters.AddWithValue("ttl", NpgsqlDbType.Interval, TimeSpan.FromMilliseconds(_ttl));
            await cmd.PrepareAsync(cancellationToken);

            // Only the INSERT modifies rows in this command, so 1 means this node now owns the lock.
            var inserted = await cmd.ExecuteNonQueryAsync(cancellationToken);
            if (inserted == 1)
            {
                lock (_localLocksGate)
                {
                    _localLocks.Add(id);
                }
                return true;
            }
        }
        catch (NpgsqlException exception)
        {
            _logger.LogError(exception, "Could not acquire lock");
        }
        return false;
    }

    /// <summary>
    /// Stops refreshing <paramref name="id"/> and deletes this node's row for it.
    /// Rows owned by other nodes are never touched.
    /// </summary>
    /// <param name="id">Lock name.</param>
    /// <remarks>
    /// A <see cref="NpgsqlException"/> is logged, not thrown. In that case the row remains until it expires.
    /// </remarks>
    public async Task ReleaseLock(string id)
    {
        // Drop it from the heartbeat set first, so the lock is no longer extended even if the delete fails.
        lock (_localLocksGate)
        {
            _localLocks.Remove(id);
        }

        try
        {
            await using var conn = new NpgsqlConnection(_connectionString);
            await using var cmd = new NpgsqlCommand(_releaseLockCommand, conn);
            cmd.Parameters.AddWithValue("id", NpgsqlDbType.Text, id);
            cmd.Parameters.AddWithValue("lock_owner", NpgsqlDbType.Uuid, _nodeId);
            await conn.OpenAsync(CancellationToken.None);
            await cmd.PrepareAsync(CancellationToken.None);
            _ = await cmd.ExecuteNonQueryAsync(CancellationToken.None);
        }
        catch (NpgsqlException exception)
        {
            _logger.LogError(exception, "Could not release lock");
        }
    }

    /// <summary>
    /// Ensures the lock table exists, then starts the background heartbeat loop.
    /// </summary>
    /// <exception cref="InvalidOperationException">The provider is already started.</exception>
    public async Task Start()
    {
        if (_heartbeatTask != null)
        {
            throw new InvalidOperationException("The lock provider has already been started.");
        }

        await EnsureTable();

        _cancellationTokenSource = new CancellationTokenSource();
        var token = _cancellationTokenSource.Token;
        // Task.Run(Func<Task>) returns a proxy that completes only when the loop itself exits,
        // so Stop can genuinely wait for it.
        _heartbeatTask = Task.Run(() => SendHeartbeat(token));
    }

    /// <summary>
    /// Cancels the heartbeat loop and waits for it to exit. Rows for locks still held are not
    /// deleted here; they lapse once their expiry passes. Does nothing if the provider is not started.
    /// </summary>
    public async Task Stop()
    {
        var heartbeatTask = _heartbeatTask;
        var cancellationTokenSource = _cancellationTokenSource;
        if (heartbeatTask == null || cancellationTokenSource == null)
        {
            return;
        }

        cancellationTokenSource.Cancel();
        try
        {
            await heartbeatTask;
        }
        finally
        {
            cancellationTokenSource.Dispose();
            _cancellationTokenSource = null;
            _heartbeatTask = null;
        }
    }

    // Runs every _heartbeat ms until cancelled. Each run extends expiry on this node's held locks,
    // then deletes all expired rows. Errors are logged and the loop keeps going.
    private async Task SendHeartbeat(CancellationToken cancellationToken)
    {
        while (!cancellationToken.IsCancellationRequested)
        {
            try
            {
                await Task.Delay(_heartbeat, cancellationToken);

                // Snapshot under the gate so the DB round-trip does not block Acquire/Release.
                string[] ids;
                lock (_localLocksGate)
                {
                    ids = _localLocks.ToArray();
                }

                await using var conn = new NpgsqlConnection(_connectionString);
                await using var cmd = new NpgsqlCommand(_heartbeatCommand, conn);
                cmd.Parameters.AddWithValue("ids", NpgsqlDbType.Array | NpgsqlDbType.Text, ids);
                cmd.Parameters.AddWithValue("lock_owner", NpgsqlDbType.Uuid, _nodeId);
                cmd.Parameters.AddWithValue("ttl", NpgsqlDbType.Interval, TimeSpan.FromMilliseconds(_ttl));
                await conn.OpenAsync(CancellationToken.None);
                await cmd.PrepareAsync(CancellationToken.None);
                _ = await cmd.ExecuteNonQueryAsync(CancellationToken.None);
            }
            catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
            {
                // Normal shutdown via Stop(); not an error.
                break;
            }
            catch (NpgsqlException exception)
            {
                _logger.LogError(exception, "Exception occured when sending heartbeat.");
            }
            catch (Exception exception)
            {
                _logger.LogError(exception, "Unexpected error in lock heartbeat loop.");
            }
        }
    }

    // Creates the schema, table and indexes when the table is not found in pg_tables.
    // Npgsql errors are logged and swallowed, so Start continues even if this fails.
    private async Task EnsureTable()
    {
        try
        {
            await using var conn = new NpgsqlConnection(_connectionString);
            await conn.OpenAsync(CancellationToken.None);

            await using var checkCommand = new NpgsqlCommand(_checkTableCommand, conn);
            var exists = (bool)await checkCommand.ExecuteScalarAsync(CancellationToken.None);
            if (!exists)
            {
                await using var createCommand = new NpgsqlCommand(_createTableCommand, conn);
                _ = await createCommand.ExecuteNonQueryAsync(CancellationToken.None);
            }
        }
        catch (NpgsqlException exception)
        {
            _logger.LogError(exception, "Error occured when calling EnsureTable");
        }
    }

    // Builds all SQL for the given schema. schemaName is interpolated as an unquoted identifier
    // (and as a literal in the existence check), so it must come from trusted configuration.
    private void SetupCommands(string schemaName)
    {
        // Serialize acquirers with a table lock: the table has no unique constraint, so
        // NOT EXISTS + INSERT would otherwise race between nodes. An unexpired row for the id
        // blocks acquisition no matter which node owns it.
        _acquireLockCommand = $@"
begin;
lock table {schemaName}.workflow_lock_provider in access exclusive mode;
insert into {schemaName}.workflow_lock_provider (id, lock_owner, expires)
select @id, @lock_owner, timezone('utc', now()) + @ttl
where not exists (
    select from {schemaName}.workflow_lock_provider
    where id = @id
    and expires > timezone('utc', now())
);
commit;
";

        // Only delete this node's row, so a stale release cannot remove another node's lock.
        _releaseLockCommand = $@"
delete from {schemaName}.workflow_lock_provider
where id = @id
and lock_owner = @lock_owner
";

        _checkTableCommand = $@"
select exists (
    select from pg_tables
    where schemaname = '{schemaName}'
    and tablename = 'workflow_lock_provider'
);
";

        // Every statement is idempotent, so concurrent or repeated startup cannot fail on existing objects.
        _createTableCommand = $@"
create schema if not exists {schemaName};
create table if not exists {schemaName}.workflow_lock_provider (
    id text not null,
    lock_owner uuid not null,
    expires timestamp not null
);
create index if not exists ix_workflow_lock_provider_id on {schemaName}.workflow_lock_provider using hash(id);
create index if not exists ix_workflow_lock_provider_lock_owner on {schemaName}.workflow_lock_provider using hash(lock_owner);
create index if not exists ix_workflow_lock_provider_expires on {schemaName}.workflow_lock_provider using btree(expires);
";

        // expires is 'timestamp without time zone' holding UTC wall-clock time. Compare it with
        // timezone('utc', now()) rather than now(), which would be shifted by the session TimeZone.
        _heartbeatCommand = $@"
-- Extend expiry of the locks this node still holds
update {schemaName}.workflow_lock_provider
set expires = timezone('utc', now()) + @ttl
where lock_owner = @lock_owner
and id = any(@ids);
-- Delete any expired locks
delete from {schemaName}.workflow_lock_provider
where expires < timezone('utc', now());
";
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.