From f62be7a31a40711f50cc289f256531b3d711533a Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 8 Aug 2024 17:28:31 -0700 Subject: [PATCH 1/3] Prototype. --- go/genkit/flow.go | 91 ++++++++++++++++++++++++++++------------ go/genkit/genkit_test.go | 25 ++++++----- 2 files changed, 78 insertions(+), 38 deletions(-) diff --git a/go/genkit/flow.go b/go/genkit/flow.go index 7f62ef024..45f99eb41 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -695,45 +695,82 @@ func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(conte return finishedOpResponse(state.Operation) } -// StreamFlowValue is either a streamed value or a final output of a flow. -type StreamFlowValue[Out, Stream any] struct { - Done bool - Output Out // valid if Done is true - Stream Stream // valid if Done is false +// FlowIterator defines the interface for iterating over flow results. +type FlowIterator[Out, Stream any] interface { + IsDone() bool + Next() (Stream, error) + FinalOutput() (Out, error) +} + +// flowIterator implements the FlowIterator interface. +type flowIterator[Out, Stream any] struct { + ctx context.Context + done bool + output Out + streamCh chan Stream + doneCh chan struct{} + errCh chan error +} + +// IsDone returns true if the flow has completed. +func (fi *flowIterator[Out, Stream]) IsDone() bool { + return fi.done +} + +// Next returns the next streamed value or an error if the flow has completed or failed. +func (fi *flowIterator[Out, Stream]) Next() (*Stream, error) { + select { + case stream := <-fi.streamCh: + return &stream, nil + case err := <-fi.errCh: + return nil, err + case <-fi.doneCh: + return nil, errors.New("flow has completed") + } } -// Stream runs the flow on input and delivers both the streamed values and the final output. -// It returns a function whose argument function (the "yield function") will be repeatedly -// called with the results. -// -// If the yield function is passed a non-nil error, the flow has failed with that -// error; the yield function will not be called again. An error is also passed if -// the flow fails to complete (that is, it has an interrupt). -// Genkit Go does not yet support interrupts. -// -// If the yield function's [StreamFlowValue] argument has Done == true, the value's -// Output field contains the final output; the yield function will not be called -// again. -// -// Otherwise the Stream field of the passed [StreamFlowValue] holds a streamed result. -func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...FlowRunOption) func(func(*StreamFlowValue[Out, Stream], error) bool) { - return func(yield func(*StreamFlowValue[Out, Stream], error) bool) { +// FinalOutput returns the final output of the flow if it has completed. +func (fi *flowIterator[Out, Stream]) FinalOutput() (*Out, error) { + if !fi.done { + return nil, errors.New("flow has not completed") + } + return &fi.output, nil +} + +// Stream returns a FlowIterator for the flow. +func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...FlowRunOption) FlowIterator[*Out, *Stream] { + fi := &flowIterator[Out, Stream]{ + ctx: ctx, + done: false, + streamCh: make(chan Stream), + doneCh: make(chan struct{}), + errCh: make(chan error), + } + + go func() { cb := func(ctx context.Context, s Stream) error { if ctx.Err() != nil { + fi.errCh <- ctx.Err() return ctx.Err() } - if !yield(&StreamFlowValue[Out, Stream]{Stream: s}, nil) { - return errStop + select { + case fi.streamCh <- s: + return nil + case <-ctx.Done(): + return ctx.Err() } - return nil } output, err := f.run(ctx, input, cb, opts...) if err != nil { - yield(nil, err) + fi.errCh <- err } else { - yield(&StreamFlowValue[Out, Stream]{Done: true, Output: output}, nil) + fi.output = output + fi.done = true + close(fi.doneCh) } - } + }() + + return fi } var errStop = errors.New("stop") diff --git a/go/genkit/genkit_test.go b/go/genkit/genkit_test.go index adc818cdf..0c629fb97 100644 --- a/go/genkit/genkit_test.go +++ b/go/genkit/genkit_test.go @@ -23,22 +23,25 @@ func TestStreamFlow(t *testing.T) { f := DefineStreamingFlow("count", count) iter := f.Stream(context.Background(), 2) want := 0 - iter(func(val *StreamFlowValue[int, int], err error) bool { + + for !iter.IsDone() { + got, err := iter.Next() if err != nil { - t.Fatal(err) - } - var got int - if val.Done { - got = val.Output - } else { - got = val.Stream + break } - if got != want { + if *got != want { t.Errorf("got %d, want %d", got, want) } want++ - return true - }) + } + + finalOutput, err := iter.FinalOutput() + if err != nil { + t.Fatal(err) + } + if *finalOutput != want { + t.Errorf("got %d, want %d", finalOutput, want) + } } // count streams the numbers from 0 to n-1, then returns n. From c71692a9347693496a275fdd692e765c7967099d Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 8 Aug 2024 18:23:07 -0700 Subject: [PATCH 2/3] Cleaned up API. --- go/genkit/flow.go | 32 ++++++++++++-------------------- go/genkit/genkit_test.go | 6 +++--- go/samples/coffee-shop/main.go | 20 +++++++++++++++++++- 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/go/genkit/flow.go b/go/genkit/flow.go index 45f99eb41..b4f59fbac 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -697,35 +697,26 @@ func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(conte // FlowIterator defines the interface for iterating over flow results. type FlowIterator[Out, Stream any] interface { - IsDone() bool - Next() (Stream, error) + Next() (Stream, bool) FinalOutput() (Out, error) } // flowIterator implements the FlowIterator interface. type flowIterator[Out, Stream any] struct { - ctx context.Context done bool output Out + err error streamCh chan Stream doneCh chan struct{} - errCh chan error -} - -// IsDone returns true if the flow has completed. -func (fi *flowIterator[Out, Stream]) IsDone() bool { - return fi.done } // Next returns the next streamed value or an error if the flow has completed or failed. -func (fi *flowIterator[Out, Stream]) Next() (*Stream, error) { +func (fi *flowIterator[Out, Stream]) Next() (*Stream, bool) { select { case stream := <-fi.streamCh: - return &stream, nil - case err := <-fi.errCh: - return nil, err + return &stream, false case <-fi.doneCh: - return nil, errors.New("flow has completed") + return nil, true } } @@ -734,23 +725,24 @@ func (fi *flowIterator[Out, Stream]) FinalOutput() (*Out, error) { if !fi.done { return nil, errors.New("flow has not completed") } + if fi.err != nil { + return nil, fi.err + } return &fi.output, nil } // Stream returns a FlowIterator for the flow. func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...FlowRunOption) FlowIterator[*Out, *Stream] { fi := &flowIterator[Out, Stream]{ - ctx: ctx, done: false, streamCh: make(chan Stream), doneCh: make(chan struct{}), - errCh: make(chan error), } go func() { cb := func(ctx context.Context, s Stream) error { if ctx.Err() != nil { - fi.errCh <- ctx.Err() + fi.err = ctx.Err() return ctx.Err() } select { @@ -762,12 +754,12 @@ func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...Fl } output, err := f.run(ctx, input, cb, opts...) if err != nil { - fi.errCh <- err + fi.err = err } else { fi.output = output - fi.done = true - close(fi.doneCh) } + fi.done = true + close(fi.doneCh) }() return fi diff --git a/go/genkit/genkit_test.go b/go/genkit/genkit_test.go index 0c629fb97..fad281eb8 100644 --- a/go/genkit/genkit_test.go +++ b/go/genkit/genkit_test.go @@ -24,9 +24,9 @@ func TestStreamFlow(t *testing.T) { iter := f.Stream(context.Background(), 2) want := 0 - for !iter.IsDone() { - got, err := iter.Next() - if err != nil { + for { + got, done := iter.Next() + if done { break } if *got != want { diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index 7d03a9f7d..066b2b4bf 100755 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -170,7 +170,7 @@ func main() { log.Fatal(err) } - genkit.DefineStreamingFlow("simpleStructuredGreeting", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) { + simpleStructuredGreetingFlow := genkit.DefineStreamingFlow("simpleStructuredGreeting", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) { var callback func(context.Context, *ai.GenerateResponseChunk) error if cb != nil { callback = func(ctx context.Context, c *ai.GenerateResponseChunk) error { @@ -189,6 +189,24 @@ func main() { return resp.Text(), nil }) + genkit.DefineStreamingFlow("nestedStreaming", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) { + iter := simpleStructuredGreetingFlow.Stream(ctx, input) + for { + stream, done := iter.Next() + if done { + break + } + if cb != nil { + cb(ctx, *stream) + } + } + finalOutput, err := iter.FinalOutput() + if err != nil { + return "", err + } + return *finalOutput, nil + }) + genkit.DefineFlow("testAllCoffeeFlows", func(ctx context.Context, _ struct{}) (*testAllCoffeeFlowsOutput, error) { test1, err := simpleGreetingFlow.Run(ctx, &simpleGreetingInput{ CustomerName: "Sam", From 70e2cae36b26fa593ccd8f3be7b40af3a04cf8ee Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Thu, 8 Aug 2024 20:00:37 -0700 Subject: [PATCH 3/3] Updated docs. --- go/genkit/flow.go | 19 +++++++++++-------- go/internal/doc-snippets/flows.go | 28 +++++++++++++--------------- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/go/genkit/flow.go b/go/genkit/flow.go index b4f59fbac..b840999d0 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -697,20 +697,23 @@ func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(conte // FlowIterator defines the interface for iterating over flow results. type FlowIterator[Out, Stream any] interface { + // Next returns the next streamed value and a boolean indicating whether the flow has completed. Next() (Stream, bool) + // FinalOutput returns the final output of the flow if it has completed. FinalOutput() (Out, error) } // flowIterator implements the FlowIterator interface. type flowIterator[Out, Stream any] struct { - done bool - output Out - err error - streamCh chan Stream - doneCh chan struct{} + done bool // true if the flow has completed + output Out // the final output of the flow + err error // the error that occurred, if any + streamCh chan Stream // channel to receive streamed values + doneCh chan struct{} // channel to indicate that the flow has completed } -// Next returns the next streamed value or an error if the flow has completed or failed. +// Next returns the next streamed value and a boolean indicating whether the flow has completed. +// If the flow has completed, the returned Stream pointer will be nil and the boolean will be true. func (fi *flowIterator[Out, Stream]) Next() (*Stream, bool) { select { case stream := <-fi.streamCh: @@ -721,6 +724,8 @@ func (fi *flowIterator[Out, Stream]) Next() (*Stream, bool) { } // FinalOutput returns the final output of the flow if it has completed. +// If the flow has not completed, it returns an error. +// If the flow completed with an error, it returns that error. func (fi *flowIterator[Out, Stream]) FinalOutput() (*Out, error) { if !fi.done { return nil, errors.New("flow has not completed") @@ -765,8 +770,6 @@ func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...Fl return fi } -var errStop = errors.New("stop") - func finishedOpResponse[O any](op *operation[O]) (O, error) { if !op.Done { return base.Zero[O](), fmt.Errorf("flow %s did not finish execution", op.FlowID) diff --git a/go/internal/doc-snippets/flows.go b/go/internal/doc-snippets/flows.go index b91f02e36..d0e2b7cec 100644 --- a/go/internal/doc-snippets/flows.go +++ b/go/internal/doc-snippets/flows.go @@ -103,22 +103,20 @@ func f3() { // [END streaming] // [START invoke-streaming] - menuSuggestionFlow.Stream( - context.Background(), - "French", - )(func(sfv *genkit.StreamFlowValue[OutputType, StreamType], err error) bool { - if err != nil { - // handle err - return false + iter := menuSuggestionFlow.Stream(context.Background(), InputType("French")) + for { + stream, done := iter.Next() + if done { + break } - if !sfv.Done { - fmt.Print(sfv.Stream) - return true - } else { - fmt.Print(sfv.Output) - return false - } - }) + fmt.Print(*stream) + } + finalOutput, err := iter.FinalOutput() + if err != nil { + // handle error + } else { + fmt.Print(*finalOutput) + } // [END invoke-streaming] }