Skip to content

Instantly share code, notes, and snippets.

@overing
Created November 26, 2019 02:13
Show Gist options
  • Save overing/1a5b035abe044456dea0736461c8da08 to your computer and use it in GitHub Desktop.
Save overing/1a5b035abe044456dea0736461c8da08 to your computer and use it in GitHub Desktop.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Text;
namespace Microsoft.EntityFrameworkCore
{
/// <summary>
/// Reads the primary-key values of an entity instance and returns them as an array.
/// </summary>
/// <param name="entry">
/// The entity instance to read. It is typed as <see cref="object"/> so that one delegate type can serve every entity type.
/// </param>
/// <returns>The key values, one array element per key component.</returns>
internal delegate object[] KeySelector(object entry);
/// <summary>
/// Relational metadata for one entity type, read from a <see cref="DbContext"/> model.
/// It holds the mapped table name, the primary-key column names, and a compiled delegate
/// that reads the primary-key values from an instance.
/// </summary>
class EntryMeta
{
    /// <summary>Name of the table the entity type is mapped to.</summary>
    public string TableName { get; private set; }

    /// <summary>Primary-key column names, in the same order as the key values returned by <see cref="KeySelector"/>.</summary>
    public string[] KeyColumnNames { get; private set; }

    /// <summary>Compiled delegate that returns the primary-key values of an entity instance.</summary>
    public KeySelector KeySelector { get; private set; }

    EntryMeta(string tableName, string[] keyColumns, KeySelector keySelector)
    {
        TableName = tableName;
        KeyColumnNames = keyColumns;
        KeySelector = keySelector;
    }

    /// <summary>
    /// Builds a query that loads every row whose primary key matches one of <paramref name="entities"/>:
    /// <c>SELECT * FROM `table` WHERE (`k1`,`k2`) IN (({0},{1}),({2},{3}));</c>
    /// Identifiers are backtick-quoted (MySQL style). Key values are never inlined; they are emitted
    /// as <c>{n}</c> placeholders whose values are returned in <c>parameters</c>.
    /// </summary>
    /// <param name="entities">Entities whose keys go into the IN list.</param>
    /// <returns>
    /// <c>sql</c>: the query text.
    /// <c>parameters</c>: the placeholder values, flattened row by row.
    /// <c>recordKeys</c>: the key array of each entity, in input order.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="entities"/> is empty; an empty IN list is not valid SQL.</exception>
    public (string sql, object[] parameters, object[][] recordKeys) GetFromSqlParamaters<TEntity>(IList<TEntity> entities) where TEntity : class
    {
        if (entities == null)
            throw new ArgumentNullException(nameof(entities));
        if (entities.Count == 0)
            throw new ArgumentException("At least one entity is required to build the IN list.", nameof(entities));

        var columnCount = KeyColumnNames.Length;
        var recordKeys = new object[entities.Count][];
        var parameters = new object[entities.Count * columnCount];

        var sql = new StringBuilder("SELECT * FROM ")
            .Append(QuoteIdentifier(TableName))
            .Append(" WHERE (")
            .Append(string.Join(",", KeyColumnNames.Select(QuoteIdentifier)))
            .Append(") IN (");

        for (var entryIndex = 0; entryIndex < entities.Count; entryIndex++)
        {
            var key = KeySelector(entities[entryIndex]);
            recordKeys[entryIndex] = key;

            if (entryIndex > 0)
                sql.Append(',');
            sql.Append('(');
            for (var column = 0; column < columnCount; column++)
            {
                // The placeholder index is the value's slot in the flattened parameter array.
                var parameterIndex = entryIndex * columnCount + column;
                parameters[parameterIndex] = key[column];
                if (column > 0)
                    sql.Append(',');
                sql.Append('{').Append(parameterIndex).Append('}');
            }
            sql.Append(')');
        }

        sql.Append(");");
        return (sql.ToString(), parameters, recordKeys);
    }

    /// <summary>Wraps an identifier in backticks and doubles any backtick it contains.</summary>
    static string QuoteIdentifier(string identifier) => "`" + identifier.Replace("`", "``") + "`";

    /// <summary>
    /// Reads the table name and primary key of <paramref name="type"/> from the model of
    /// <paramref name="context"/>. It also compiles a delegate that extracts the key values.
    /// </summary>
    /// <param name="context">Context whose model describes <paramref name="type"/>.</param>
    /// <param name="type">CLR type of the entity.</param>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="type"/> is not part of the model, or it has no primary key.
    /// </exception>
    public static EntryMeta Create(DbContext context, Type type)
    {
        var entityType = context.Model.FindEntityType(type)
            ?? throw new InvalidOperationException($"{type.FullName} is not part of the model for {context.GetType().FullName}.");

        // FindPrimaryKey() can return null. Check it before dereferencing, so callers get this
        // message instead of a NullReferenceException.
        var properties = entityType.FindPrimaryKey()?.Properties;
        if (properties == null || properties.Count == 0)
            throw new InvalidOperationException($"{type.FullName} does not have a primary key specified.");

        var tableName = entityType.Relational().TableName;
        var keyColumnNames = properties.Select(p => p.Relational().ColumnName).ToArray();

        // Compiles: (object e) => { var t = (TEntity)e; return new object[] { (object)t.Key1, ... }; }
        // Casting once into a local avoids repeating the cast for every key property.
        var entry = Expression.Parameter(typeof(object), "e");
        var typed = Expression.Variable(type, "t");
        var keyValues = properties.Select(p => Expression.Convert(Expression.Property(typed, p.Name), typeof(object)));
        var body = Expression.Block(
            typeof(object[]),
            new[] { typed },
            Expression.Assign(typed, Expression.Convert(entry, type)),
            Expression.NewArrayInit(typeof(object), keyValues));
        var keySelector = Expression.Lambda<KeySelector>(body, entry).Compile();

        return new EntryMeta(tableName, keyColumnNames, keySelector);
    }
}
/// <summary>
/// "Add or update" extensions for <see cref="DbContext"/>.
/// Each entity whose primary key matches an existing database row has its values copied onto that row.
/// Every other entity is added to the context. These methods do not call SaveChanges.
/// </summary>
public static class DbContextAddOrUpdateExtensions
{
    // Metadata is derived from context.Model, so it is cached per model and entity type.
    // Two contexts may map the same CLR type to different tables.
    // ConcurrentDictionary is used because a plain Dictionary is not safe to read while another thread writes to it.
    static readonly ConcurrentDictionary<(object Model, Type EntityType), EntryMeta> Cache =
        new ConcurrentDictionary<(object Model, Type EntityType), EntryMeta>();

    // Returns the cached metadata for TEntity in this context's model, building it on first use.
    // Under a race GetOrAdd may build the metadata twice. Building only reads the model and compiles
    // an expression, so the extra instance is simply discarded.
    static EntryMeta ResolveMeta<TEntity>(DbContext context)
    {
        return Cache.GetOrAdd((context.Model, typeof(TEntity)), key => EntryMeta.Create(context, key.EntityType));
    }

    /// <summary>Adds or updates each of <paramref name="entities"/>, matching them by primary key.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="entities"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="entities"/> is empty.</exception>
    public static void AddOrUpdate<TEntity>(this DbContext context, params TEntity[] entities) where TEntity : class
    {
        AddOrUpdateCore(context, entities);
    }

    /// <summary>Adds or updates each of <paramref name="entities"/>, matching them by primary key.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="entities"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="entities"/> is empty.</exception>
    public static void AddOrUpdate<TEntity>(this DbContext context, List<TEntity> entities) where TEntity : class
    {
        AddOrUpdateCore(context, entities);
    }

    // Argument validation and dispatch shared by both public overloads.
    static void AddOrUpdateCore<TEntity>(DbContext context, IList<TEntity> entities) where TEntity : class
    {
        if (context == null)
            throw new ArgumentNullException(nameof(context));
        if (entities == null)
            throw new ArgumentNullException(nameof(entities));
        if (entities.Count == 0)
            throw new ArgumentException("Entities not any element.", nameof(entities));

        // A single entity is looked up with Find(key). Larger batches are loaded with one query.
        if (entities.Count == 1)
            AddOrUpdateSingle(context, entities[0]);
        else
            AddOrUpdateBatch(context, entities);
    }

    // Looks up one entity by primary key with Find(), then updates the found row or adds the entity.
    static void AddOrUpdateSingle<TEntity>(DbContext context, TEntity entity) where TEntity : class
    {
        var key = ResolveMeta<TEntity>(context).KeySelector(entity);
        var dbEntity = context.Set<TEntity>().Find(key);
        if (dbEntity == null)
            context.Add(entity);
        else
            context.Entry(dbEntity).CurrentValues.SetValues(entity);
    }

    // Loads all existing rows for the batch with one raw-SQL query.
    // Each input entity then updates its matching row, or is added if no row matches.
    static void AddOrUpdateBatch<TEntity>(DbContext context, IList<TEntity> entities) where TEntity : class
    {
        var meta = ResolveMeta<TEntity>(context);
        var (sql, parameters, recordKeys) = meta.GetFromSqlParamaters(entities);

        // Key arrays are compared element by element. Each lookup is O(1), not a linear scan per entity.
        var existing = new Dictionary<object[], TEntity>(KeyComparer.Instance);
        foreach (var row in context.Set<TEntity>().FromSql(sql, parameters))
            existing[meta.KeySelector(row)] = row;

        for (var i = 0; i < entities.Count; i++)
        {
            if (existing.TryGetValue(recordKeys[i], out var dbEntity))
                context.Entry(dbEntity).CurrentValues.SetValues(entities[i]);
            else
                context.Add(entities[i]);
        }
    }

    // Structural equality for key arrays: two keys are equal when they have the same length and
    // every pair of elements is equal under object.Equals.
    sealed class KeyComparer : IEqualityComparer<object[]>
    {
        public static readonly KeyComparer Instance = new KeyComparer();

        public bool Equals(object[] x, object[] y)
        {
            if (ReferenceEquals(x, y))
                return true;
            if (x == null || y == null || x.Length != y.Length)
                return false;
            for (var i = 0; i < x.Length; i++)
            {
                if (!object.Equals(x[i], y[i]))
                    return false;
            }
            return true;
        }

        public int GetHashCode(object[] key)
        {
            if (key == null)
                return 0;
            unchecked
            {
                var hash = 17;
                foreach (var value in key)
                    hash = hash * 31 + (value?.GetHashCode() ?? 0);
                return hash;
            }
        }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment