メタプログラミングには以前から興味があったので、ぜひ式木について知りたい!というわけで。
こんなケースで考えてみました。
Equals と GetHashCode の実装
Equals と GetHashCode を実装するとき、大抵はプロパティや内部フィールドの比較を繋げるだけのことが多いので、その式を作ってくれる EqualityComparer があると嬉しいなあ。
※ まあ ReSharper があれば自動で実装してくれますが…
たぶんこんなクラスがあったら、
例
class Version
{
public Version(int major, int minor)
{
Major = major;
Minor = minor;
}
public int Major { get; set; }
public int Minor { get; set; }
}
こんな感じに実装すると思います。
class Version
{
public Version(int major, int minor)
{
Major = major;
Minor = minor;
}
public int Major { get; set; }
public int Minor { get; set; }
public override bool Equals(object obj)
{
if (ReferenceEquals(null, obj)) return false;
if (ReferenceEquals(this, obj)) return true;
if (obj.GetType() != this.GetType()) return false;
return Equals((Version) obj);
}
private bool Equals(Version other)
{
return Major == other.Major && Minor == other.Minor;
}
public override int GetHashCode()
{
return Major.GetHashCode() ^ Minor.GetHashCode();
}
}
これの実装を EqualityComparer に移したらこんな感じになるでしょうか。※ null判定は省略します
class VersionEqualityComparer : IEqualityComparer<Version>
{
public bool Equals(Version x, Version y)
{
return x.Major == y.Major && x.Minor == y.Minor;
}
public int GetHashCode(Version obj)
{
return obj.Major.GetHashCode() ^ obj.Minor.GetHashCode();
}
}
この EqualityComparer の Equals と GetHashCode の式を作ってくれる汎用EqualityComparer を作りたいと思います。
使い方をきめる
どんな風に使うかですが…
class Version
{
public Version(int major, int minor)
{
Major = major;
Minor = minor;
}
public int Major { get; set; }
public int Minor { get; set; }
private static readonly ComplexEqualityComparer<Version> Comparer = new ComplexEqualityComparer<Version>()
.AddMember(v => v.Major)
.AddMember(v => v.Minor)
.Compile();
public override bool Equals(object obj)
{
return Comparer.Equals(this, obj as Version);
}
public override int GetHashCode()
{
return Comparer.GetHashCode(this);
}
}
こんな感じで実装に利用するプロパティやフィールドをラムダ式で指定できるようにします。
Compile メソッドを呼ぶと指定されたプロパティやフィールドを用いた式がコンパイルされて、あとは使うだけになる、という感じです。
書きやすくするために安直な感じですが、thisを返してメソッドチェーンできるようにしておきます。
メンバの指定を実装する
Equals に利用する “x.PropertyName == y.PropertyName” という式と、GetHashCode に利用する “obj.GetHashCode()” という式を作ります。
それぞれあとで “&&” や “^” でつなぐので、リストに放り込んでいきます。
private readonly ParameterExpression _paramX = Expression.Parameter(typeof(T), "x"); // Equals(x, y) の引数 "x" を表す式
private readonly ParameterExpression _paramY = Expression.Parameter(typeof(T), "y"); // Equals(x, y) の引数 "y" を表す式
private readonly ParameterExpression _paramObj = Expression.Parameter(typeof(T), "obj"); // GetHashCode(obj) の引数 "obj" を表す式
private readonly IList<BinaryExpression> _equalsList = new List<BinaryExpression>();
private readonly IList<Expression> _getHashCodeList = new List<Expression>();
public ComplexEqualityComparer<T> AddMember<TMember>(Expression<Func<T, TMember>> member)
{
if (member.Body.NodeType != ExpressionType.MemberAccess)
throw new ArgumentException("メンバの指定はメンバアクセスの式でお願いします。");
var memberExpression = (MemberExpression) member.Body;
var memberInfo = memberExpression.Member;
var memberX = Expression.PropertyOrField(_paramX, memberInfo.Name); // x.PropertyName
var memberY = Expression.PropertyOrField(_paramY, memberInfo.Name); // y.PropertyName
var equals = Expression.Equal(memberX, memberY); // x.PropertyName == y.PropertyName
var paramMember = Expression.PropertyOrField(_paramObj, memberInfo.Name); // obj.PropertyName
var getHashCode = Expression.Call(paramMember, typeof(TMember).GetMethod("GetHashCode")); // obj.PropertyName.GetHashCode()
_equalsList.Add(equals);
_getHashCodeList.Add(getHashCode);
return this;
}
式を完成させる
Complile メソッドで式を完成させて、コンパイルをかけます。
メンバの指定の際にリストに放り込んでおいた式を “&&” や “^” で繋いで組み立てます。
最後に Expression.Lamda<>() でコンパイルして出来上がったデリゲートを保存します。
// コンパイルした式のデリゲート
private Func<T, T, bool> _compiledEquals;
private Func<T, int> _compiledGetHashCode;
public ComplexEqualityComparer<T> Compile()
{
// x.Property1 == y.Property1 && x.Property2 == y.Property2 && ...
BinaryExpression equalsExpression = null;
foreach (var equals in _equalsList)
{
equalsExpression = equalsExpression == null
? @equals
: Expression.AndAlso(equalsExpression, @equals);
}
if (equalsExpression == null)
throw new InvalidOperationException();
_compiledEquals = Expression.Lambda<Func<T, T, bool>>(equalsExpression, _paramX, _paramY).Compile();
// obj.Property1.GetHashCode() ^ obj.Property2.GetHashCode() ^ ...
Expression getHashCodeExpression = null;
foreach (var getHashCode in _getHashCodeList)
{
getHashCodeExpression = getHashCodeExpression == null
? getHashCode
: Expression.ExclusiveOr(getHashCodeExpression, getHashCode);
}
if (getHashCodeExpression == null)
throw new InvalidOperationException();
_compiledGetHashCode = Expression.Lambda<Func<T, int>>(getHashCodeExpression, _paramObj).Compile();
_compiled = true;
return this;
}
できあがり
だいたいこんな感じになりました。ほんとは null のチェックとかも必要ですが面倒なのでとりあえず。
public class ComplexEqualityComparer<T> : IEqualityComparer<T>
{
public ComplexEqualityComparer()
{
_compiled = false;
_paramX = Expression.Parameter(typeof(T), "x");
_paramY = Expression.Parameter(typeof(T), "y");
_paramObj = Expression.Parameter(typeof(T), "obj");
_equalsList = new List<BinaryExpression>();
_getHashCodeList = new List<Expression>();
}
private bool _compiled;
private readonly ParameterExpression _paramX; // Equals(x, y) の引数 "x" を表す式
private readonly ParameterExpression _paramY; // Equals(x, y) の引数 "y" を表す式
private readonly ParameterExpression _paramObj; // GetHashCode(obj) の引数 "obj" を表す式
private readonly IList<BinaryExpression> _equalsList;
private readonly IList<Expression> _getHashCodeList;
public ComplexEqualityComparer<T> AddMember<TMember>(Expression<Func<T, TMember>> member)
{
if (member.Body.NodeType != ExpressionType.MemberAccess)
throw new ArgumentException("メンバの指定はメンバアクセスの式でお願いします。");
var memberExpression = (MemberExpression) member.Body;
var memberInfo = memberExpression.Member;
var memberX = Expression.PropertyOrField(_paramX, memberInfo.Name); // x.PropertyName
var memberY = Expression.PropertyOrField(_paramY, memberInfo.Name); // y.PropertyName
var equals = Expression.Equal(memberX, memberY); // x.PropertyName == y.PropertyName
var paramMember = Expression.PropertyOrField(_paramObj, memberInfo.Name); // obj.PropertyName
var getHashCode = Expression.Call(paramMember, typeof(TMember).GetMethod("GetHashCode")); // obj.PropertyName.GetHashCode()
_equalsList.Add(equals);
_getHashCodeList.Add(getHashCode);
return this;
}
// コンパイルした式のデリゲート
private Func<T, T, bool> _compiledEquals;
private Func<T, int> _compiledGetHashCode;
public ComplexEqualityComparer<T> Compile()
{
// x.Property1 == y.Property1 && x.Property2 == y.Property2 && ...
BinaryExpression equalsExpression = null;
foreach (var equals in _equalsList)
{
equalsExpression = equalsExpression == null
? @equals
: Expression.AndAlso(equalsExpression, @equals);
}
if (equalsExpression == null)
throw new InvalidOperationException();
_compiledEquals = Expression.Lambda<Func<T, T, bool>>(equalsExpression, _paramX, _paramY).Compile();
// obj.Property1.GetHashCode() ^ obj.Property2.GetHashCode() ^ ...
Expression getHashCodeExpression = null;
foreach (var getHashCode in _getHashCodeList)
{
getHashCodeExpression = getHashCodeExpression == null
? getHashCode
: Expression.ExclusiveOr(getHashCodeExpression, getHashCode);
}
if (getHashCodeExpression == null)
throw new InvalidOperationException();
_compiledGetHashCode = Expression.Lambda<Func<T, int>>(getHashCodeExpression, _paramObj).Compile();
_compiled = true;
return this;
}
private void CompileIfNotCompiled()
{
if (_compiled) return;
Compile();
}
public bool Equals(T x, T y)
{
CompileIfNotCompiled();
return _compiledEquals(x, y);
}
public int GetHashCode(T obj)
{
CompileIfNotCompiled();
return _compiledGetHashCode(obj);
}
}
テスト
[Test]
public void UseCase()
{
var eqComparer = new ComplexEqualityComparer<Version>()
.AddMember(v => v.Major)
.AddMember(v => v.Minor)
.Compile();
var version_1_0 = new Version(1, 0);
var version_2_3 = new Version(2, 3);
Assert.IsTrue(eqComparer.Equals(version_1_0, version_1_0));
Assert.IsTrue(eqComparer.Equals(version_2_3, version_2_3));
Assert.IsTrue(eqComparer.Equals(version_2_3, new Version(2, 3)));
Assert.IsTrue(eqComparer.Equals(new Version(2, 3), version_2_3));
Assert.IsFalse(eqComparer.Equals(version_1_0, version_2_3));
Assert.IsFalse(eqComparer.Equals(version_2_3, version_1_0));
}
テスト結果
ちゃんと動くっぽいですね。
これで Equals と GetHashCode の実装は楽できそうです。
式木おもしろい!
式木とてもおもしろいですね。メタプログラミングって面白いなあと思います。
使い方を誤ると魔物が生まれそうですが、上手に使えばすごく役立ちそうです。(ぼくの頭ではなかなかアイデアが浮かびませんが…
参考にさせていただきました
大変参考にさせていただきました。