Created
November 26, 2019 02:13
-
-
Save overing/1a5b035abe044456dea0736461c8da08 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Linq.Expressions; | |
using System.Text; | |
namespace Microsoft.EntityFrameworkCore | |
{ | |
/// <summary>
/// Returns the primary-key values of <paramref name="entry"/>, one element per key
/// property, in key order. Instances are compiled by <see cref="EntryMeta.Create"/>.
/// </summary>
delegate object[] KeySelector(object entry);

/// <summary>
/// Per-entity-type metadata used by the add-or-update helpers: the mapped table name,
/// the primary-key column names, and a compiled <see cref="KeySelector"/> that reads the
/// key values off an entity instance.
/// </summary>
class EntryMeta
{
    /// <summary>Table name from the relational model metadata.</summary>
    public string TableName { get; private set; }

    /// <summary>Primary-key column names, in the order of the key's properties.</summary>
    public string[] KeyColumnNames { get; private set; }

    /// <summary>Extracts key values in the same order as <see cref="KeyColumnNames"/>.</summary>
    public KeySelector KeySelector { get; private set; }

    EntryMeta(string tableName, string[] keyColumns, KeySelector keySelector)
    {
        TableName = tableName;
        KeyColumnNames = keyColumns;
        KeySelector = keySelector;
    }

    /// <summary>
    /// Builds a query selecting every row whose primary key matches one of
    /// <paramref name="entities"/>, e.g.
    /// <c>SELECT * FROM `T` WHERE (`A`,`B`) IN (({0},{1}),({2},{3}));</c>
    /// Identifiers are quoted with backticks (MySQL-style quoting).
    /// </summary>
    /// <param name="entities">Entities whose keys to look up; must be non-null and non-empty.</param>
    /// <returns>
    /// <c>sql</c>: the query text with <c>{n}</c> placeholders;
    /// <c>parameters</c>: the key values, where placeholder <c>{n}</c> maps to <c>parameters[n]</c>;
    /// <c>recordKeys</c>: the extracted key of each entity, index-aligned with <paramref name="entities"/>.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="entities"/> is empty.</exception>
    public (string sql, object[] parameters, object[][] recordKeys) GetFromSqlParamaters<TEntity>(IList<TEntity> entities) where TEntity : class
    {
        if (entities == null)
            throw new ArgumentNullException(nameof(entities));
        // An empty IN list would produce invalid SQL ("IN );").
        if (entities.Count == 0)
            throw new ArgumentException("At least one entity is required.", nameof(entities));

        var count = entities.Count;
        var keyWidth = KeyColumnNames.Length;
        var recordKeys = new object[count][];
        var parameters = new object[count * keyWidth];

        var sqlBuilder = new StringBuilder("SELECT * FROM `").Append(TableName).Append("` WHERE (");
        for (var columnIndex = 0; columnIndex < keyWidth; columnIndex++)
        {
            if (columnIndex > 0)
                sqlBuilder.Append(',');
            sqlBuilder.Append('`').Append(KeyColumnNames[columnIndex]).Append('`');
        }
        sqlBuilder.Append(") IN (");

        for (var entryIndex = 0; entryIndex < count; entryIndex++)
        {
            var key = KeySelector(entities[entryIndex]);
            recordKeys[entryIndex] = key;
            if (entryIndex > 0)
                sqlBuilder.Append(',');
            sqlBuilder.Append('(');
            for (var propIndex = 0; propIndex < keyWidth; propIndex++)
            {
                var parameterIndex = entryIndex * keyWidth + propIndex;
                if (propIndex > 0)
                    sqlBuilder.Append(',');
                // Key values go into `parameters`, never into the SQL text itself.
                sqlBuilder.Append('{').Append(parameterIndex).Append('}');
                parameters[parameterIndex] = key[propIndex];
            }
            sqlBuilder.Append(')');
        }
        sqlBuilder.Append(");");

        return (sqlBuilder.ToString(), parameters, recordKeys);
    }

    /// <summary>
    /// Reads the table name and primary-key columns of <paramref name="type"/> from
    /// <paramref name="context"/>'s model and compiles a key selector for it.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="type"/> is null.</exception>
    /// <exception cref="InvalidOperationException">
    /// The type is not part of the model, or it has no primary key.
    /// </exception>
    public static EntryMeta Create(DbContext context, Type type)
    {
        if (context == null)
            throw new ArgumentNullException(nameof(context));
        if (type == null)
            throw new ArgumentNullException(nameof(type));

        var entityType = context.Model.FindEntityType(type);
        if (entityType == null)
            throw new InvalidOperationException($"{type.FullName} is not an entity type of {context.GetType().FullName}.");

        var keyProperties = entityType.FindPrimaryKey()?.Properties;
        if (keyProperties == null || keyProperties.Count == 0)
            throw new InvalidOperationException($"{type.FullName} does not have a primary key specified.");

        var tableName = entityType.Relational().TableName;
        var keyColumnNames = keyProperties.Select(p => p.Relational().ColumnName).ToArray();

        // Compiles: (object e) => { T t = (T)e; return new object[] { (object)t.Key1, (object)t.Key2, ... }; }
        var entryParam = Expression.Parameter(typeof(object), "e");
        var typedEntry = Expression.Variable(type, "t");
        var boxedKeys = keyProperties.Select(p => Expression.Convert(Expression.Property(typedEntry, p.Name), typeof(object)));
        var body = Expression.Block(
            typeof(object[]),
            new[] { typedEntry },
            Expression.Assign(typedEntry, Expression.Convert(entryParam, type)),
            Expression.NewArrayInit(typeof(object), boxedKeys));
        // Compile straight to KeySelector rather than wrapping a Func<object, object[]> in a second delegate.
        var keySelector = Expression.Lambda<KeySelector>(body, entryParam).Compile();

        return new EntryMeta(tableName, keyColumnNames, keySelector);
    }
}
/// <summary>
/// Add-or-update helpers for <see cref="DbContext"/>.
/// If a database row with the same primary key exists, the entity's values are copied onto the
/// instance loaded from the database; otherwise the entity is added to the context.
/// Nothing is saved here; call SaveChanges afterwards.
/// </summary>
public static class DbContextAddOrUpdateExtensions
{
    // Keyed by (context type, entity type): the same CLR type may be mapped to different
    // tables or keys by different DbContext models.
    static readonly Dictionary<(Type ContextType, Type EntityType), EntryMeta> Cache =
        new Dictionary<(Type ContextType, Type EntityType), EntryMeta>(48);

    // Dictionary<,> does not support reads concurrent with writes, so every access,
    // not only the insert, happens under this lock.
    static readonly object CacheGate = new object();

    /// <summary>Returns the cached metadata for <typeparamref name="TEntity"/> in this context's model, creating it on first use.</summary>
    static EntryMeta ResolveMeta<TEntity>(DbContext context)
    {
        var cacheKey = (context.GetType(), typeof(TEntity));
        lock (CacheGate)
        {
            if (!Cache.TryGetValue(cacheKey, out var meta))
            {
                meta = EntryMeta.Create(context, typeof(TEntity));
                Cache[cacheKey] = meta;
            }
            return meta;
        }
    }

    /// <summary>Adds each entity, or updates the existing row with the same primary key.</summary>
    /// <param name="context">Context to operate on.</param>
    /// <param name="entities">Entities to add or update; must be non-empty with no null elements.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="entities"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="entities"/> is empty or contains null.</exception>
    public static void AddOrUpdate<TEntity>(this DbContext context, params TEntity[] entities) where TEntity : class
        => AddOrUpdateCore(context, entities);

    /// <summary>Adds each entity, or updates the existing row with the same primary key.</summary>
    /// <param name="context">Context to operate on.</param>
    /// <param name="entities">Entities to add or update; must be non-empty with no null elements.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="entities"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="entities"/> is empty or contains null.</exception>
    public static void AddOrUpdate<TEntity>(this DbContext context, List<TEntity> entities) where TEntity : class
        => AddOrUpdateCore(context, entities);

    /// <summary>Shared validation for both public overloads, then dispatch to the single or batch path.</summary>
    static void AddOrUpdateCore<TEntity>(DbContext context, IList<TEntity> entities) where TEntity : class
    {
        if (context == null)
            throw new ArgumentNullException(nameof(context));
        if (entities == null)
            throw new ArgumentNullException(nameof(entities));
        if (entities.Count == 0)
            throw new ArgumentException("Entities not any element.", nameof(entities));
        // A null element would otherwise fail deep inside the compiled key selector.
        if (entities.Any(e => e == null))
            throw new ArgumentException("Entities contains a null element.", nameof(entities));

        if (entities.Count == 1)
            AddOrUpdateSingle(context, entities[0]);
        else
            AddOrUpdateBatch(context, entities);
    }

    /// <summary>Looks the entity's key up with DbSet.Find, then adds the entity or copies its values onto the found instance.</summary>
    static void AddOrUpdateSingle<TEntity>(DbContext context, TEntity entity) where TEntity : class
    {
        var meta = ResolveMeta<TEntity>(context);
        var existing = context.Set<TEntity>().Find(meta.KeySelector(entity));
        if (existing == null)
            context.Add(entity);
        else
            context.Entry(existing).CurrentValues.SetValues(entity);
    }

    /// <summary>Loads all matching rows with one raw SQL query, then adds or updates each entity.</summary>
    static void AddOrUpdateBatch<TEntity>(DbContext context, IList<TEntity> entities) where TEntity : class
    {
        var meta = ResolveMeta<TEntity>(context);
        var query = meta.GetFromSqlParamaters(entities);
        // Do not change this to .ToDictionary(meta.KeySelector); doing so makes the mapping fail.
        // The structural comparer lets key arrays be looked up by value instead of by reference.
        var existingByKey = context.Set<TEntity>()
            .FromSql(query.sql, query.parameters)
            .ToDictionary(e => meta.KeySelector(e), KeyArrayComparer.Instance);

        for (var i = 0; i < entities.Count; i++)
        {
            if (existingByKey.TryGetValue(query.recordKeys[i], out var existing))
                context.Entry(existing).CurrentValues.SetValues(entities[i]);
            else
                context.Add(entities[i]);
        }
    }

    /// <summary>
    /// Value equality for primary-key arrays. Elements are compared with object.Equals,
    /// the same comparison Enumerable.SequenceEqual uses by default.
    /// </summary>
    sealed class KeyArrayComparer : IEqualityComparer<object[]>
    {
        public static readonly KeyArrayComparer Instance = new KeyArrayComparer();

        public bool Equals(object[] x, object[] y)
            => ReferenceEquals(x, y) || (x != null && y != null && x.SequenceEqual(y));

        public int GetHashCode(object[] key)
        {
            if (key == null)
                return 0;
            unchecked
            {
                var hash = 17;
                foreach (var part in key)
                    hash = hash * 31 + (part?.GetHashCode() ?? 0);
                return hash;
            }
        }
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment