Skip to content

Commit

Permalink
feat(Event streams): add support for cancellation and IAsyncEnumerator
Browse files Browse the repository at this point in the history
  • Loading branch information
Timothée Lecomte committed Nov 6, 2024
1 parent ad13fd4 commit 0887038
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*******************************************************************************
/*******************************************************************************
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use
* this file except in compliance with the License. A copy of the License is located at
Expand All @@ -25,6 +25,7 @@
using System.Diagnostics.CodeAnalysis;
using System.IO;
#if AWS_ASYNC_API
using System.Threading;
using System.Threading.Tasks;
#endif

Expand All @@ -36,7 +37,12 @@ namespace Amazon.Runtime.EventStreams.Internal
/// <typeparam name="T">An implementation of IEventStreamEvent (e.g. IS3Event).</typeparam>
/// <typeparam name="TE">An implementation of EventStreamException (e.g. S3EventStreamException).</typeparam>
[SuppressMessage("Microsoft.Naming", "CA1710", Justification = "IEventStreamCollection is not descriptive.")]
#if AWS_ASYNC_ENUMERABLES_API

public interface IEnumerableEventStream<T, TE> : IEventStream<T, TE>, IEnumerable<T>, IAsyncEnumerable<T> where T : IEventStreamEvent where TE : EventStreamException, new()
#else
public interface IEnumerableEventStream<T, TE> : IEventStream<T, TE>, IEnumerable<T> where T : IEventStreamEvent where TE : EventStreamException, new()
#endif
{
}

Expand Down Expand Up @@ -171,13 +177,72 @@ public override void StartProcessing()
///
/// The Task will be completed when all of the events from the stream have been processed.
/// </summary>
public override async Task StartProcessingAsync()
public override async Task StartProcessingAsync(CancellationToken cancellationToken = default)
{
// If they are/have enumerated, the event-driven mode should be disabled
if (IsEnumerated) throw new InvalidOperationException(MutuallyExclusiveExceptionMessage);

await base.StartProcessingAsync().ConfigureAwait(false);
await base.StartProcessingAsync(cancellationToken).ConfigureAwait(false);
}
#endif

#if AWS_ASYNC_ENUMERABLES_API
/// <summary>
/// Returns an async enumerator that iterates through the collection.
/// </summary>
/// <returns>An async enumerator that can be used to iterate through the collection.</returns>
public async IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken)
{
if (IsProcessing)
{
// If the queue has already begun processing, refuse to enumerate.
throw new InvalidOperationException(MutuallyExclusiveExceptionMessage);
}

// There could be more than 1 message created per decoder cycle.
var events = new Queue<T>();

// Opting out of events - letting the enumeration handle everything.
IsEnumerated = true;
IsProcessing = true;

// Enumeration is just magic over the event driven mechanism.
EventReceived += (sender, args) => events.Enqueue(args.EventStreamEvent);

var buffer = new byte[BufferSize];

while (IsProcessing)
{
// If there are already events ready to be served, do not ask for more.
if (events.Count > 0)
{
var ev = events.Dequeue();
// Enumeration handles terminal events on behalf of the user.
if (ev is IEventStreamTerminalEvent)
{
IsProcessing = false;
Dispose();
}

yield return ev;
}
else
{
try
{
await ReadFromStreamAsync(buffer, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
IsProcessing = false;
Dispose();

// Wrap exceptions as needed to match event-driven behavior.
throw WrapException(ex);
}
}
}
}
#endif
}
}
}
22 changes: 11 additions & 11 deletions sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EventStream.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
Expand All @@ -17,10 +17,9 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Threading;
#if AWS_ASYNC_API
using System.Threading.Tasks;
#else
using System.Threading;
#endif

namespace Amazon.Runtime.EventStreams.Internal
Expand Down Expand Up @@ -55,7 +54,7 @@ namespace Amazon.Runtime.EventStreams.Internal
///
/// The Task will be completed when all of the events from the stream have been processed.
/// </summary>
Task StartProcessingAsync();
Task StartProcessingAsync(CancellationToken cancellationToken = default);
#endif
}

Expand Down Expand Up @@ -262,23 +261,23 @@ protected void Process()
{
#if AWS_ASYNC_API
// Task only exists in framework 4.5 and up, and Standard.
Task.Run(() => ProcessLoopAsync());
Task.Run(() => ProcessLoopAsync(CancellationToken.None));
#else
// ThreadPool only exists in 3.5 and below. These implementations do not have the Task library.
ThreadPool.QueueUserWorkItem(ProcessLoop);
#endif
}

#if AWS_ASYNC_API
private async Task ProcessLoopAsync()
private async Task ProcessLoopAsync(CancellationToken cancellationToken)
{
var buffer = new byte[BufferSize];

try
{
while (IsProcessing)
{
await ReadFromStreamAsync(buffer).ConfigureAwait(false);
await ReadFromStreamAsync(buffer, cancellationToken).ConfigureAwait(false);
}
}
// These exceptions are raised on the background thread. They are fired as events for visibility.
Expand Down Expand Up @@ -351,9 +350,10 @@ protected void ReadFromStream(byte[] buffer)
/// each message it decodes.
/// </summary>
/// <param name="buffer">The buffer to store the read bytes from the stream.</param>
protected async Task ReadFromStreamAsync(byte[] buffer)
/// <param name="cancellationToken">A cancellation token.</param>
protected async Task ReadFromStreamAsync(byte[] buffer, CancellationToken cancellationToken)
{
var bytesRead = await NetworkStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false);
var bytesRead = await NetworkStream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
if (bytesRead > 0)
{
// Decoder raises MessageReceived for every message it encounters.
Expand Down Expand Up @@ -408,13 +408,13 @@ public virtual void StartProcessing()
///
/// The Task will be completed when all of the events from the stream have been processed.
/// </summary>
public virtual async Task StartProcessingAsync()
public virtual async Task StartProcessingAsync(CancellationToken cancellationToken = default)
{
if (IsProcessing)
return;

IsProcessing = true;
await ProcessLoopAsync().ConfigureAwait(false);
await ProcessLoopAsync(cancellationToken).ConfigureAwait(false);
}
#endif

Expand Down
9 changes: 8 additions & 1 deletion sdk/src/Services/BedrockRuntime/BedrockRuntime.sln
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.SecurityToken.Net45"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.SimpleNotificationService.Net45", "../SimpleNotificationService/AWSSDK.SimpleNotificationService.Net45.csproj", "{A657D500-DDA4-45FF-9459-8351CDD96B78}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.IntegrationTests.BedrockRuntime.NetStandard", "../../../test/Services/BedrockRuntime/IntegrationTests/AWSSDK.IntegrationTests.BedrockRuntime.NetStandard.csproj", "{9F726137-4C28-4FEA-9A6C-962DEA25951D}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -186,6 +188,10 @@ Global
{A657D500-DDA4-45FF-9459-8351CDD96B78}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A657D500-DDA4-45FF-9459-8351CDD96B78}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A657D500-DDA4-45FF-9459-8351CDD96B78}.Release|Any CPU.Build.0 = Release|Any CPU
{9F726137-4C28-4FEA-9A6C-962DEA25951D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{9F726137-4C28-4FEA-9A6C-962DEA25951D}.Debug|Any CPU.Build.0 = Debug|Any CPU
{9F726137-4C28-4FEA-9A6C-962DEA25951D}.Release|Any CPU.ActiveCfg = Release|Any CPU
{9F726137-4C28-4FEA-9A6C-962DEA25951D}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -220,8 +226,9 @@ Global
{7BD5B7F3-2ED9-4747-9FCE-8F5622BFCC36} = {939EC5C2-8345-43E2-8F97-72EEEBEEA0AC}
{EE034587-0A31-4841-A4BB-055DB990990F} = {939EC5C2-8345-43E2-8F97-72EEEBEEA0AC}
{A657D500-DDA4-45FF-9459-8351CDD96B78} = {939EC5C2-8345-43E2-8F97-72EEEBEEA0AC}
{9F726137-4C28-4FEA-9A6C-962DEA25951D} = {12EC4E4B-7E2C-4B63-8EF9-7B959F82A89B}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {CE2F2305-8E72-44B6-9FAD-AA2E347C2B6A}
EndGlobalSection
EndGlobal
EndGlobal
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<RunAnalyzersDuringBuild Condition="'$(RunAnalyzersDuringBuild)'==''">true</RunAnalyzersDuringBuild>
<TargetFrameworks>netstandard2.0;netcoreapp3.1;net8.0</TargetFrameworks>
<DefineConstants>$(DefineConstants);NETSTANDARD;AWS_ASYNC_API</DefineConstants>
<DefineConstants Condition="'$(TargetFramework)' == 'netstandard2.0'">$(DefineConstants);NETSTANDARD20;AWS_ASYNC_ENUMERABLES_API</DefineConstants>
<DefineConstants Condition="'$(TargetFramework)' == 'netcoreapp3.1'">$(DefineConstants);AWS_ASYNC_ENUMERABLES_API</DefineConstants>
<DefineConstants Condition="'$(TargetFramework)' == 'net8.0'">$(DefineConstants);AWS_ASYNC_ENUMERABLES_API</DefineConstants>
<DebugType>portable</DebugType>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<AssemblyName>AWSSDK.IntegrationTests.BedrockRuntime.NetStandard</AssemblyName>
<PackageId>AWSSDK.IntegrationTests.BedrockRuntime.NetStandard</PackageId>

<GenerateAssemblyTitleAttribute>false</GenerateAssemblyTitleAttribute>
<GenerateAssemblyConfigurationAttribute>false</GenerateAssemblyConfigurationAttribute>
<GenerateAssemblyProductAttribute>false</GenerateAssemblyProductAttribute>
<GenerateAssemblyCompanyAttribute>false</GenerateAssemblyCompanyAttribute>
<GenerateAssemblyCopyrightAttribute>false</GenerateAssemblyCopyrightAttribute>
<GenerateAssemblyVersionAttribute>false</GenerateAssemblyVersionAttribute>
<GenerateAssemblyFileVersionAttribute>false</GenerateAssemblyFileVersionAttribute>
<GenerateAssemblyDescriptionAttribute>false</GenerateAssemblyDescriptionAttribute>
<SignAssembly>true</SignAssembly>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>

<NoWarn>CA1822</NoWarn>
</PropertyGroup>

<!-- Async Enumerable Compatibility -->
<PropertyGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<LangVersion>8.0</LangVersion>
</PropertyGroup>

<ItemGroup>
<Compile Remove="**/35/**" />
<None Remove="**/35/**" />
<Compile Remove="**/obj/**" />
<None Remove="**/obj/**" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.11.1" />
<PackageReference Include="MSTest.TestAdapter" Version="3.6.2" />
<PackageReference Include="MSTest.TestFramework" Version="3.6.2" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="../../../../src/Core/AWSSDK.Core.NetStandard.csproj" />
<ProjectReference Include="../../../../src/Services/BedrockRuntime/AWSSDK.BedrockRuntime.NetStandard.csproj" />
</ItemGroup>


</Project>
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Amazon.BedrockRuntime;
using Amazon.BedrockRuntime.Model;
using System.Threading.Tasks;
Expand All @@ -8,6 +8,9 @@
using System.Threading;
using System;
using System.Diagnostics.Contracts;
using System.Collections.Generic;
using Amazon;
using Amazon.Runtime;
namespace AWSSDK_DotNet.IntegrationTests.Tests
{
/// <summary>
Expand All @@ -22,7 +25,11 @@ namespace AWSSDK_DotNet.IntegrationTests.Tests
/// </summary>
[Ignore]
[TestClass]
#if NETSTANDARD
public class BedrockRuntimeEventStreamTests
#else
public class BedrockRuntimeEventStreamTests : TestBase<AmazonBedrockRuntimeClient>
#endif
{
#if BCL35
[TestMethod]
Expand Down Expand Up @@ -145,6 +152,48 @@ public async Task RequestWithInvalidBodyReturnsValidationException()
}

#endif

#if AWS_ASYNC_ENUMERABLES_API
[TestMethod]
public async Task ConverseStreamCanBeEnumeratedAsynchronously()
{
// configure with credentials and region
var client = new AmazonBedrockRuntimeClient();

var request = new ConverseStreamRequest
{
ModelId = "meta.llama3-1-8b-instruct-v1:0"
};

request.Messages.Add(new Message
{
Content = new List<ContentBlock> { new ContentBlock { Text = "Who was the first US president" } },
Role = ConversationRole.User
});

var response = await client.ConverseStreamAsync(request);

Assert.IsNotNull(response);
Assert.IsNotNull(response.Stream);

using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20));

var contentStringBuilder = new StringBuilder();
await foreach (var item in response.Stream.WithCancellation(cts.Token))
{
if (item is ContentBlockDeltaEvent deltaEvent)
{
contentStringBuilder.Append(deltaEvent.Delta.Text);
}
}

var responseContent = contentStringBuilder.ToString();

// Since we don't know the contents of the response from Bedrock, we just assert that we received a response
Assert.IsTrue(responseContent.Length > 10);
}
#endif

static MemoryStream CreateStream(string query, bool createInvalidInput = false)
{
StringBuilder promptValueBuilder = new StringBuilder();
Expand Down

0 comments on commit 0887038

Please sign in to comment.