Skip to content

Commit 1330dc5

Browse files
committed
CSHARP-4457: LINQ3 should support filters with bool fields as well as properties.
1 parent 0aaebf6 commit 1330dc5

File tree

6 files changed

+229
-18
lines changed

6 files changed

+229
-18
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/DocumentSerializerHelper.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
2020
{
2121
internal static class DocumentSerializerHelper
2222
{
23-
public static FieldInfo GetFieldInfo(IBsonSerializer serializer, string memberName)
23+
public static MemberSerializationInfo GetMemberSerializationInfo(IBsonSerializer serializer, string memberName)
2424
{
2525
if (!(serializer is IBsonDocumentSerializer documentSerializer))
2626
{
@@ -32,10 +32,10 @@ public static FieldInfo GetFieldInfo(IBsonSerializer serializer, string memberNa
3232
throw new InvalidOperationException($"Serializer for {serializer.ValueType} does not have a member named {memberName}.");
3333
}
3434

35-
return new FieldInfo(serializationInfo.ElementName, serializationInfo.Serializer);
35+
return new MemberSerializationInfo(serializationInfo.ElementName, serializationInfo.Serializer);
3636
}
3737

38-
public static bool HasFieldInfo(IBsonSerializer serializer, string memberName)
38+
public static bool HasMemberSerializationInfo(IBsonSerializer serializer, string memberName)
3939
{
4040
return
4141
serializer is IBsonDocumentSerializer documentSerializer &&

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/FieldInfo.cs src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MemberSerializationInfo.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818

1919
namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
2020
{
21-
internal class FieldInfo
21+
internal class MemberSerializationInfo
2222
{
2323
// private fields
2424
private readonly string _elementName;
2525
private readonly IBsonSerializer _serializer;
2626

2727
// constructors
28-
public FieldInfo(string elementName, IBsonSerializer serializer)
28+
public MemberSerializationInfo(string elementName, IBsonSerializer serializer)
2929
{
3030
_elementName = Ensure.IsNotNullOrEmpty(elementName, nameof(elementName));
3131
_serializer = Ensure.IsNotNull(serializer, nameof(serializer));

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberExpressionToAggregationExpressionTranslator.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public static AggregationExpression Translate(TranslationContext context, Member
5252
containerTranslation = new AggregationExpression(containerExpression, unwrappedValueAst, wrappedValueSerializer.ValueSerializer);
5353
}
5454

55-
if (!DocumentSerializerHelper.HasFieldInfo(containerTranslation.Serializer, member.Name))
55+
if (!DocumentSerializerHelper.HasMemberSerializationInfo(containerTranslation.Serializer, member.Name))
5656
{
5757
if (member is PropertyInfo propertyInfo && propertyInfo.Name == "Length")
5858
{
@@ -70,9 +70,9 @@ public static AggregationExpression Translate(TranslationContext context, Member
7070
}
7171
}
7272

73-
var fieldInfo = DocumentSerializerHelper.GetFieldInfo(containerTranslation.Serializer, member.Name);
74-
var ast = AstExpression.GetField(containerTranslation.Ast, fieldInfo.ElementName);
75-
return new AggregationExpression(expression, ast, fieldInfo.Serializer);
73+
var serializationInfo = DocumentSerializerHelper.GetMemberSerializationInfo(containerTranslation.Serializer, member.Name);
74+
var ast = AstExpression.GetField(containerTranslation.Ast, serializationInfo.ElementName);
75+
return new AggregationExpression(expression, ast, serializationInfo.Serializer);
7676
}
7777

7878
private static bool TryTranslateCollectionCountProperty(MemberExpression expression, AggregationExpression container, MemberInfo memberInfo, out AggregationExpression result)

src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/ExpressionTranslators/MemberExpressionToFilterTranslator.cs

+9
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ public static AstFilter Translate(TranslationContext context, MemberExpression e
2929
{
3030
var memberInfo = expression.Member;
3131

32+
if (memberInfo is FieldInfo fieldInfo)
33+
{
34+
if (fieldInfo.FieldType == typeof(bool))
35+
{
36+
var field = ExpressionToFilterFieldTranslator.Translate(context, expression);
37+
return AstFilter.Eq(field, true);
38+
}
39+
}
40+
3241
if (memberInfo is PropertyInfo propertyInfo)
3342
{
3443
if (propertyInfo.Is(NullableProperty.HasValue))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/* Copyright 2010-present MongoDB Inc.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License.
14+
*/
15+
16+
using System.Linq;
17+
using FluentAssertions;
18+
using MongoDB.Bson;
19+
using MongoDB.Bson.Serialization;
20+
using MongoDB.Driver.Linq;
21+
using MongoDB.TestHelpers.XunitExtensions;
22+
using Xunit;
23+
24+
namespace MongoDB.Driver.Tests.Linq.Linq3ImplementationTests.Jira
25+
{
26+
public class CSharp4457Tests : Linq3IntegrationTest
27+
{
28+
[Theory]
29+
[ParameterAttributeData]
30+
public void Filter_with_bool_field_should_work([Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
31+
{
32+
var collection = CreateCollection(linqProvider);
33+
var builder = Builders<C>.Filter;
34+
var filter = builder.Where(x => x.BoolField);
35+
36+
var rendered = RenderFilter(filter, linqProvider);
37+
rendered.Should().Be("{ BoolField : true }");
38+
39+
var results = collection.FindSync(filter).ToList();
40+
results.Select(x => x.Id).Should().Equal(1);
41+
}
42+
43+
[Theory]
44+
[ParameterAttributeData]
45+
public void Filter_with_bool_property_should_work([Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
46+
{
47+
var collection = CreateCollection(linqProvider);
48+
var builder = Builders<C>.Filter;
49+
var filter = builder.Where(x => x.BoolProperty);
50+
51+
var rendered = RenderFilter(filter, linqProvider);
52+
rendered.Should().Be("{ BoolProperty : true }");
53+
54+
var results = collection.FindSync(filter).ToList();
55+
results.Select(x => x.Id).Should().Equal(1);
56+
}
57+
58+
[Theory]
59+
[ParameterAttributeData]
60+
public void Filter_with_not_bool_field_should_work([Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
61+
{
62+
var collection = CreateCollection(linqProvider);
63+
var builder = Builders<C>.Filter;
64+
var filter = builder.Where(x => !x.BoolField);
65+
66+
var rendered = RenderFilter(filter, linqProvider);
67+
rendered.Should().Be("{ BoolField : { $ne : true } }");
68+
69+
var results = collection.FindSync(filter).ToList();
70+
results.Select(x => x.Id).Should().Equal(2);
71+
}
72+
73+
[Theory]
74+
[ParameterAttributeData]
75+
public void Filter_with_not_bool_property_should_work([Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
76+
{
77+
var collection = CreateCollection(linqProvider);
78+
var builder = Builders<C>.Filter;
79+
var filter = builder.Where(x => !x.BoolProperty);
80+
81+
var rendered = RenderFilter(filter, linqProvider);
82+
rendered.Should().Be("{ BoolProperty : { $ne : true } }");
83+
84+
var results = collection.FindSync(filter).ToList();
85+
results.Select(x => x.Id).Should().Equal(2);
86+
}
87+
88+
[Theory]
89+
[ParameterAttributeData]
90+
public void Where_with_bool_field_should_work([Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
91+
{
92+
var collection = CreateCollection(linqProvider);
93+
94+
var queryable =
95+
collection.AsQueryable()
96+
.Where(x => x.BoolField);
97+
98+
var stages = Translate(collection, queryable);
99+
AssertStages(stages, "{ $match : { BoolField : true } }");
100+
101+
var results = queryable.ToList();
102+
results.Select(x => x.Id).Should().Equal(1);
103+
}
104+
105+
[Theory]
106+
[ParameterAttributeData]
107+
public void Where_with_bool_property_should_work([Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
108+
{
109+
var collection = CreateCollection(linqProvider);
110+
111+
var queryable =
112+
collection.AsQueryable()
113+
.Where(x => x.BoolProperty);
114+
115+
var stages = Translate(collection, queryable);
116+
AssertStages(stages, "{ $match : { BoolProperty : true } }");
117+
118+
var results = queryable.ToList();
119+
results.Select(x => x.Id).Should().Equal(1);
120+
}
121+
122+
[Theory]
123+
[ParameterAttributeData]
124+
public void Where_with_not_bool_field_should_work([Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
125+
{
126+
var collection = CreateCollection(linqProvider);
127+
128+
var queryable =
129+
collection.AsQueryable()
130+
.Where(x => !x.BoolField);
131+
132+
var stages = Translate(collection, queryable);
133+
AssertStages(stages, "{ $match : { BoolField : { $ne : true } } }");
134+
135+
var results = queryable.ToList();
136+
results.Select(x => x.Id).Should().Equal(2);
137+
}
138+
139+
[Theory]
140+
[ParameterAttributeData]
141+
public void Where_with_not_bool_property_should_work([Values(LinqProvider.V2, LinqProvider.V3)] LinqProvider linqProvider)
142+
{
143+
var collection = CreateCollection(linqProvider);
144+
145+
var queryable =
146+
collection.AsQueryable()
147+
.Where(x => !x.BoolProperty);
148+
149+
var stages = Translate(collection, queryable);
150+
AssertStages(stages, "{ $match : { BoolProperty : { $ne : true } } }");
151+
152+
var results = queryable.ToList();
153+
results.Select(x => x.Id).Should().Equal(2);
154+
}
155+
156+
private IMongoCollection<C> CreateCollection(LinqProvider linqProvider)
157+
{
158+
var collection = GetCollection<C>("C", linqProvider);
159+
160+
CreateCollection(
161+
collection,
162+
new C { Id = 1, BoolField = true, BoolProperty = true },
163+
new C { Id = 2, BoolField = false, BoolProperty = false });
164+
165+
return collection;
166+
}
167+
168+
private BsonDocument RenderFilter<TDocument>(FilterDefinition<TDocument> filter, LinqProvider linqProvider)
169+
{
170+
var serializerRegistry = BsonSerializer.SerializerRegistry;
171+
var documentSerializer = serializerRegistry.GetSerializer<TDocument>();
172+
return filter.Render(documentSerializer, serializerRegistry, linqProvider);
173+
}
174+
175+
private class C
176+
{
177+
public bool BoolField;
178+
179+
public int Id { get; set; }
180+
public bool BoolProperty { get; set; }
181+
}
182+
}
183+
}

tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationTests/Linq3IntegrationTest.cs

+28-9
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
* limitations under the License.
1414
*/
1515

16-
using System;
1716
using System.Collections.Generic;
1817
using System.Linq;
1918
using FluentAssertions;
@@ -86,7 +85,8 @@ protected static List<BsonDocument> Translate<TDocument, TResult>(IMongoCollecti
8685
// in this overload the collection argument is used only to infer the TDocument type
8786
protected List<BsonDocument> Translate<TDocument, TResult>(IMongoCollection<TDocument> collection, IQueryable<TResult> queryable)
8887
{
89-
return Translate<TDocument, TResult>(queryable);
88+
var linqProvider = collection.Database.Client.Settings.LinqProvider;
89+
return Translate<TDocument, TResult>(queryable, linqProvider);
9090
}
9191

9292
// in this overload the collection argument is used only to infer the TDocument type
@@ -108,18 +108,37 @@ protected List<BsonDocument> Translate<TResult>(IMongoDatabase database, IQuerya
108108
return Translate<NoPipelineInput, TResult>(queryable);
109109
}
110110

111-
protected List<BsonDocument> Translate<TDocument, TResult>(IQueryable<TResult> queryable)
111+
protected List<BsonDocument> Translate<TDocument, TResult>(IQueryable<TResult> queryable, LinqProvider linqProvider = LinqProvider.V3)
112112
{
113-
return Translate<TDocument, TResult>(queryable, out _);
113+
return Translate<TDocument, TResult>(queryable, linqProvider, out _);
114114
}
115115

116116
protected List<BsonDocument> Translate<TDocument, TResult>(IQueryable<TResult> queryable, out IBsonSerializer<TResult> outputSerializer)
117117
{
118-
var provider = (MongoQueryProvider<TDocument>)queryable.Provider;
119-
var executableQuery = ExpressionToExecutableQueryTranslator.Translate<TDocument, TResult>(provider, queryable.Expression);
120-
var stages = executableQuery.Pipeline.Stages;
121-
outputSerializer = (IBsonSerializer<TResult>)executableQuery.Pipeline.OutputSerializer;
122-
return stages.Select(s => s.Render().AsBsonDocument).ToList();
118+
return Translate<TDocument, TResult>(queryable, LinqProvider.V3, out outputSerializer);
119+
}
120+
121+
protected List<BsonDocument> Translate<TDocument, TResult>(IQueryable<TResult> queryable, LinqProvider linqProvider, out IBsonSerializer<TResult> outputSerializer)
122+
{
123+
if (linqProvider == LinqProvider.V2)
124+
{
125+
var linq2QueryProvider = (MongoDB.Driver.Linq.Linq2Implementation.MongoQueryProviderImpl<TDocument>)queryable.Provider;
126+
var executionModel = linq2QueryProvider.GetExecutionModel(queryable.Expression);
127+
var executionModelType = executionModel.GetType();
128+
var stagesPropertyInfo = executionModelType.GetProperty("Stages");
129+
var stages = (IEnumerable<BsonDocument>)stagesPropertyInfo.GetValue(executionModel);
130+
var outputSerializerPropertyInfo = executionModelType.GetProperty("OutputSerializer");
131+
outputSerializer = (IBsonSerializer<TResult>)outputSerializerPropertyInfo.GetValue(executionModel);
132+
return stages.ToList();
133+
}
134+
else
135+
{
136+
var linq3QueryProvider = (MongoQueryProvider<TDocument>)queryable.Provider;
137+
var executableQuery = ExpressionToExecutableQueryTranslator.Translate<TDocument, TResult>(linq3QueryProvider, queryable.Expression);
138+
var stages = executableQuery.Pipeline.Stages;
139+
outputSerializer = (IBsonSerializer<TResult>)executableQuery.Pipeline.OutputSerializer;
140+
return stages.Select(s => s.Render().AsBsonDocument).ToList();
141+
}
123142
}
124143

125144
protected static List<BsonDocument> Translate<TDocument, TResult>(

0 commit comments

Comments
 (0)