diff --git a/src/Adapter/MSTest.TestAdapter/Discovery/TypeEnumerator.cs b/src/Adapter/MSTest.TestAdapter/Discovery/TypeEnumerator.cs index f26905d26d..76a52c1253 100644 --- a/src/Adapter/MSTest.TestAdapter/Discovery/TypeEnumerator.cs +++ b/src/Adapter/MSTest.TestAdapter/Discovery/TypeEnumerator.cs @@ -195,7 +195,7 @@ internal UnitTestElement GetTestFromMethod(MethodInfo method, bool isDeclaredInT testElement.WorkItemIds = workItemAttributes.Select(x => x.Id.ToString(CultureInfo.InvariantCulture)).ToArray(); } - // get DisplayName from TestMethodAttribute + // get DisplayName from TestMethodAttribute (or any inherited attribute) var testMethodAttribute = _reflectHelper.GetCustomAttribute(method); testElement.DisplayName = testMethodAttribute?.DisplayName ?? method.Name; diff --git a/src/TestFramework/TestFramework/Attributes/DataSource/DataTestMethodAttribute.cs b/src/TestFramework/TestFramework/Attributes/DataSource/DataTestMethodAttribute.cs index 6d7f40c213..a1603092e0 100644 --- a/src/TestFramework/TestFramework/Attributes/DataSource/DataTestMethodAttribute.cs +++ b/src/TestFramework/TestFramework/Attributes/DataSource/DataTestMethodAttribute.cs @@ -9,4 +9,21 @@ namespace Microsoft.VisualStudio.TestTools.UnitTesting; [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] public class DataTestMethodAttribute : TestMethodAttribute { + /// + /// Initializes a new instance of the class. + /// + public DataTestMethodAttribute() + { + } + + /// + /// Initializes a new instance of the class. + /// + /// + /// Display name for the test. + /// + public DataTestMethodAttribute(string? displayName) + : base(displayName) + { + } } diff --git a/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt b/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt index 2f65313e56..b6be4a4c1b 100644 --- a/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt +++ b/src/TestFramework/TestFramework/PublicAPI/PublicAPI.Unshipped.txt @@ -1,4 +1,5 @@ #nullable enable Microsoft.VisualStudio.TestTools.UnitTesting.DataRowAttribute.DataRowAttribute(object? data) -> void +Microsoft.VisualStudio.TestTools.UnitTesting.DataTestMethodAttribute.DataTestMethodAttribute(string? displayName) -> void Microsoft.VisualStudio.TestTools.UnitTesting.UnitTestOutcome.NotFound = 9 -> Microsoft.VisualStudio.TestTools.UnitTesting.UnitTestOutcome virtual Microsoft.VisualStudio.TestTools.UnitTesting.Logging.Logger.LogMessageHandler.Invoke(string! message) -> void diff --git a/test/UnitTests/MSTestAdapter.UnitTests/Discovery/TypeEnumeratorTests.cs b/test/UnitTests/MSTestAdapter.UnitTests/Discovery/TypeEnumeratorTests.cs index 4b6b75f4f0..2244898ae7 100644 --- a/test/UnitTests/MSTestAdapter.UnitTests/Discovery/TypeEnumeratorTests.cs +++ b/test/UnitTests/MSTestAdapter.UnitTests/Discovery/TypeEnumeratorTests.cs @@ -517,7 +517,7 @@ public void GetTestFromMethodShouldSetDisplayNameToTestMethodNameIfDisplayNameIs Verify(testElement.DisplayName == "MethodWithVoidReturnType"); } - public void GetTestFromMethodShouldSetDisplayNameFromAttribute() + public void GetTestFromMethodShouldSetDisplayNameFromTestMethodAttribute() { SetupTestClassAndTestMethods(isValidTestClass: true, isValidTestMethod: true, isMethodFromSameAssembly: true); TypeEnumerator typeEnumerator = GetTypeEnumeratorInstance(typeof(DummyTestClass), "DummyAssemblyName"); @@ -533,6 +533,22 @@ public void GetTestFromMethodShouldSetDisplayNameFromAttribute() Verify(testElement.DisplayName == "Test method display name."); } + public void GetTestFromMethodShouldSetDisplayNameFromDataTestMethodAttribute() + { + SetupTestClassAndTestMethods(isValidTestClass: true, isValidTestMethod: true, isMethodFromSameAssembly: true); + TypeEnumerator typeEnumerator = GetTypeEnumeratorInstance(typeof(DummyTestClass), "DummyAssemblyName"); + var methodInfo = typeof(DummyTestClass).GetMethod(nameof(DummyTestClass.MethodWithVoidReturnType)); + + // Setup mocks to behave like we have [DataTestMethod("Test method display name.")] attribute on the method + _mockReflectHelper.Setup( + rh => rh.GetCustomAttribute(methodInfo)).Returns(new DataTestMethodAttribute("Test method display name.")); + + var testElement = typeEnumerator.GetTestFromMethod(methodInfo, true, _warnings); + + Verify(testElement is not null); + Verify(testElement.DisplayName == "Test method display name."); + } + #endregion #region private methods