Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure EntraID (Token Based Authentication) integration tests #344

Merged
merged 12 commits into from
Nov 29, 2024
1 change: 1 addition & 0 deletions tests/NRedisStack.Tests/NRedisStack.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
<PackageReference Include="xunit" Version="2.4.1" />
<PackageReference Include="xunit.assert" Version="2.4.1" />
<PackageReference Include="BouncyCastle.Cryptography" Version="2.2.0" />
<PackageReference Include="Microsoft.Azure.StackExchangeRedis" Version="3.1.0" />
</ItemGroup>

<ItemGroup>
Expand Down
16 changes: 12 additions & 4 deletions tests/NRedisStack.Tests/RedisFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,12 @@ public class RedisFixture : IDisposable
public bool isEnterprise = Environment.GetEnvironmentVariable("IS_ENTERPRISE") == "true";
public bool isOSSCluster;

private ConnectionMultiplexer redis;
private ConfigurationOptions defaultConfig;

public RedisFixture()
{
ConfigurationOptions clusterConfig = new ConfigurationOptions
defaultConfig = new ConfigurationOptions
{
AsyncTimeout = 10000,
SyncTimeout = 10000
Expand Down Expand Up @@ -93,16 +96,21 @@ public RedisFixture()
isOSSCluster = true;
}
}

Redis = GetConnectionById(clusterConfig, defaultEndpointId);
}

public void Dispose()
{
Redis.Close();
}

public ConnectionMultiplexer Redis { get; }
public ConnectionMultiplexer Redis
{
get
{
redis = redis ?? GetConnectionById(defaultConfig, defaultEndpointId);
return redis;
}
}

public ConnectionMultiplexer GetConnectionById(ConfigurationOptions configurationOptions, string id)
{
Expand Down
4 changes: 3 additions & 1 deletion tests/NRedisStack.Tests/SkipIfRedisAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public class SkipIfRedisAttribute : FactAttribute
private readonly Comparison _comparison;
private readonly List<Is> _environments = new List<Is>();

private static Version serverVersion = null;

public SkipIfRedisAttribute(
Is environment,
Comparison comparison = Comparison.LessThan,
Expand Down Expand Up @@ -95,7 +97,7 @@ public override string? Skip
}
// Version check (if Is.Standalone/Is.OSSCluster is set then )

var serverVersion = redisFixture.Redis.GetServer(redisFixture.Redis.GetEndPoints()[0]).Version;
serverVersion = serverVersion ?? redisFixture.Redis.GetServer(redisFixture.Redis.GetEndPoints()[0]).Version;
var targetVersion = new Version(_targetVersion);
int comparisonResult = serverVersion.CompareTo(targetVersion);

Expand Down
36 changes: 36 additions & 0 deletions tests/NRedisStack.Tests/TargetEnvironmentAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using Xunit;

namespace NRedisStack.Tests;
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
public class TargetEnvironmentAttribute : SkipIfRedisAttribute
{
private string targetEnv;
public TargetEnvironmentAttribute(string targetEnv) : base(Comparison.LessThan, "0.0.0")
{
this.targetEnv = targetEnv;
}

public TargetEnvironmentAttribute(string targetEnv, Is environment, Comparison comparison = Comparison.LessThan,
string targetVersion = "0.0.0") : base(environment, comparison, targetVersion)
{
this.targetEnv = targetEnv;
}

public TargetEnvironmentAttribute(string targetEnv, Is environment1, Is environment2, Comparison comparison = Comparison.LessThan,
string targetVersion = "0.0.0") : base(environment1, environment2, comparison, targetVersion)
{
this.targetEnv = targetEnv;
}

public override string? Skip
{
get
{
if (!new RedisFixture().IsTargetConnectionExist(targetEnv))
{
return "Test skipped, because: target environment not found.";
}
return base.Skip;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using Xunit;
using StackExchange.Redis;
using Azure.Identity;
using NRedisStack.RedisStackCommands;
using NRedisStack.Search;

namespace NRedisStack.Tests.TokenBasedAuthentication
{
public class AuthenticationTests : AbstractNRedisStackTest
{
static readonly string key = "myKey";
static readonly string value = "myValue";
static readonly string index = "myIndex";
static readonly string field = "myField";
static readonly string alias = "myAlias";
public AuthenticationTests(RedisFixture redisFixture) : base(redisFixture) { }

[TargetEnvironment("standalone-entraid-acl")]
public void TestTokenBasedAuthentication()
{

var configurationOptions = new ConfigurationOptions().ConfigureForAzureWithTokenCredentialAsync(new DefaultAzureCredential()).Result!;
configurationOptions.Ssl = false;
configurationOptions.AbortOnConnectFail = true; // Fail fast for the purposes of this sample. In production code, this should remain false to retry connections on startup

ConnectionMultiplexer? connectionMultiplexer = redisFixture.GetConnectionById(configurationOptions, "standalone-entraid-acl");

IDatabase db = connectionMultiplexer.GetDatabase();

db.KeyDelete(key);
try
{
db.FT().DropIndex(index);
}
catch { }

db.StringSet(key, value);
string result = db.StringGet(key);
Assert.Equal(value, result);

var ft = db.FT();
Schema sc = new Schema().AddTextField(field);
Assert.True(ft.Create(index, FTCreateParams.CreateParams(), sc));

db.HashSet(index, new HashEntry[] { new HashEntry(field, value) });

Assert.True(ft.AliasAdd(alias, index));
SearchResult res1 = ft.Search(alias, new Query("*").ReturnFields(field));
Assert.Equal(1, res1.TotalResults);
Assert.Equal(value, res1.Documents[0][field]);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
using System.Text;
using System.Text.Json;
using System.Net.Http;

public class FaultInjectorClient
{
private static readonly string BASE_URL;

static FaultInjectorClient()
{
BASE_URL = Environment.GetEnvironmentVariable("FAULT_INJECTION_API_URL") ?? "http://127.0.0.1:20324";
}

public class TriggerActionResponse
{
public string ActionId { get; }
private DateTime? LastRequestTime { get; set; }
private DateTime? CompletedAt { get; set; }
private DateTime? FirstRequestAt { get; set; }

public TriggerActionResponse(string actionId)
{
ActionId = actionId;
}

public async Task<bool> IsCompletedAsync(TimeSpan checkInterval, TimeSpan delayAfter, TimeSpan timeout)
{
if (CompletedAt.HasValue)
{
return DateTime.UtcNow - CompletedAt.Value >= delayAfter;
}

if (FirstRequestAt.HasValue && DateTime.UtcNow - FirstRequestAt.Value >= timeout)
{
throw new TimeoutException("Timeout");
}

if (!LastRequestTime.HasValue || DateTime.UtcNow - LastRequestTime.Value >= checkInterval)
{
LastRequestTime = DateTime.UtcNow;

if (!FirstRequestAt.HasValue)
{
FirstRequestAt = LastRequestTime;
}

using var httpClient = GetHttpClient();
var request = new HttpRequestMessage(HttpMethod.Get, $"{BASE_URL}/action/{ActionId}");

try
{
var response = await httpClient.SendAsync(request);
var result = await response.Content.ReadAsStringAsync();


if (result.Contains("success"))
{
CompletedAt = DateTime.UtcNow;
return DateTime.UtcNow - CompletedAt.Value >= delayAfter;
}
}
catch (HttpRequestException e)
{
throw new Exception("Fault injection proxy error", e);
}
}
return false;
}
}

private static HttpClient GetHttpClient()
{
var httpClient = new HttpClient
{
Timeout = TimeSpan.FromMilliseconds(5000)
};
return httpClient;
}

public async Task<TriggerActionResponse> TriggerActionAsync(string actionType, Dictionary<string, object> parameters)
{
var payload = new Dictionary<string, object>
{
{ "type", actionType },
{ "parameters", parameters }
};

var jsonString = JsonSerializer.Serialize(payload, new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase
});

using var httpClient = GetHttpClient();
var request = new HttpRequestMessage(HttpMethod.Post, $"{BASE_URL}/action")
{
Content = new StringContent(jsonString, Encoding.UTF8, "application/json")
};

try
{
var response = await httpClient.SendAsync(request);
var result = await response.Content.ReadAsStringAsync();
return JsonSerializer.Deserialize<TriggerActionResponse>(result, new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase
});
}
catch (HttpRequestException e)
{
throw;
}
}
}
Loading