diff --git a/go/genkit/flow.go b/go/genkit/flow.go index 7f62ef024..b840999d0 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -695,48 +695,80 @@ func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(conte return finishedOpResponse(state.Operation) } -// StreamFlowValue is either a streamed value or a final output of a flow. -type StreamFlowValue[Out, Stream any] struct { - Done bool - Output Out // valid if Done is true - Stream Stream // valid if Done is false +// FlowIterator defines the interface for iterating over flow results. +type FlowIterator[Out, Stream any] interface { + // Next returns the next streamed value and a boolean indicating whether the flow has completed. + Next() (Stream, bool) + // FinalOutput returns the final output of the flow if it has completed. + FinalOutput() (Out, error) } -// Stream runs the flow on input and delivers both the streamed values and the final output. -// It returns a function whose argument function (the "yield function") will be repeatedly -// called with the results. -// -// If the yield function is passed a non-nil error, the flow has failed with that -// error; the yield function will not be called again. An error is also passed if -// the flow fails to complete (that is, it has an interrupt). -// Genkit Go does not yet support interrupts. -// -// If the yield function's [StreamFlowValue] argument has Done == true, the value's -// Output field contains the final output; the yield function will not be called -// again. -// -// Otherwise the Stream field of the passed [StreamFlowValue] holds a streamed result. -func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...FlowRunOption) func(func(*StreamFlowValue[Out, Stream], error) bool) { - return func(yield func(*StreamFlowValue[Out, Stream], error) bool) { +// flowIterator implements the FlowIterator interface. +type flowIterator[Out, Stream any] struct { + done bool // true if the flow has completed + output Out // the final output of the flow + err error // the error that occurred, if any + streamCh chan Stream // channel to receive streamed values + doneCh chan struct{} // channel to indicate that the flow has completed +} + +// Next returns the next streamed value and a boolean indicating whether the flow has completed. +// If the flow has completed, the returned Stream pointer will be nil and the boolean will be true. +func (fi *flowIterator[Out, Stream]) Next() (*Stream, bool) { + select { + case stream := <-fi.streamCh: + return &stream, false + case <-fi.doneCh: + return nil, true + } +} + +// FinalOutput returns the final output of the flow if it has completed. +// If the flow has not completed, it returns an error. +// If the flow completed with an error, it returns that error. +func (fi *flowIterator[Out, Stream]) FinalOutput() (*Out, error) { + if !fi.done { + return nil, errors.New("flow has not completed") + } + if fi.err != nil { + return nil, fi.err + } + return &fi.output, nil +} + +// Stream returns a FlowIterator for the flow. +func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...FlowRunOption) FlowIterator[*Out, *Stream] { + fi := &flowIterator[Out, Stream]{ + done: false, + streamCh: make(chan Stream), + doneCh: make(chan struct{}), + } + + go func() { cb := func(ctx context.Context, s Stream) error { if ctx.Err() != nil { + fi.err = ctx.Err() return ctx.Err() } - if !yield(&StreamFlowValue[Out, Stream]{Stream: s}, nil) { - return errStop + select { + case fi.streamCh <- s: + return nil + case <-ctx.Done(): + return ctx.Err() } - return nil } output, err := f.run(ctx, input, cb, opts...) if err != nil { - yield(nil, err) + fi.err = err } else { - yield(&StreamFlowValue[Out, Stream]{Done: true, Output: output}, nil) + fi.output = output } - } -} + fi.done = true + close(fi.doneCh) + }() -var errStop = errors.New("stop") + return fi +} func finishedOpResponse[O any](op *operation[O]) (O, error) { if !op.Done { diff --git a/go/genkit/genkit_test.go b/go/genkit/genkit_test.go index adc818cdf..fad281eb8 100644 --- a/go/genkit/genkit_test.go +++ b/go/genkit/genkit_test.go @@ -23,22 +23,25 @@ func TestStreamFlow(t *testing.T) { f := DefineStreamingFlow("count", count) iter := f.Stream(context.Background(), 2) want := 0 - iter(func(val *StreamFlowValue[int, int], err error) bool { - if err != nil { - t.Fatal(err) - } - var got int - if val.Done { - got = val.Output - } else { - got = val.Stream + + for { + got, done := iter.Next() + if done { + break } - if got != want { + if *got != want { t.Errorf("got %d, want %d", got, want) } want++ - return true - }) + } + + finalOutput, err := iter.FinalOutput() + if err != nil { + t.Fatal(err) + } + if *finalOutput != want { + t.Errorf("got %d, want %d", finalOutput, want) + } } // count streams the numbers from 0 to n-1, then returns n. diff --git a/go/internal/doc-snippets/flows.go b/go/internal/doc-snippets/flows.go index b91f02e36..d0e2b7cec 100644 --- a/go/internal/doc-snippets/flows.go +++ b/go/internal/doc-snippets/flows.go @@ -103,22 +103,20 @@ func f3() { // [END streaming] // [START invoke-streaming] - menuSuggestionFlow.Stream( - context.Background(), - "French", - )(func(sfv *genkit.StreamFlowValue[OutputType, StreamType], err error) bool { - if err != nil { - // handle err - return false + iter := menuSuggestionFlow.Stream(context.Background(), InputType("French")) + for { + stream, done := iter.Next() + if done { + break } - if !sfv.Done { - fmt.Print(sfv.Stream) - return true - } else { - fmt.Print(sfv.Output) - return false - } - }) + fmt.Print(*stream) + } + finalOutput, err := iter.FinalOutput() + if err != nil { + // handle error + } else { + fmt.Print(*finalOutput) + } // [END invoke-streaming] } diff --git a/go/samples/coffee-shop/main.go b/go/samples/coffee-shop/main.go index 7d03a9f7d..066b2b4bf 100755 --- a/go/samples/coffee-shop/main.go +++ b/go/samples/coffee-shop/main.go @@ -170,7 +170,7 @@ func main() { log.Fatal(err) } - genkit.DefineStreamingFlow("simpleStructuredGreeting", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) { + simpleStructuredGreetingFlow := genkit.DefineStreamingFlow("simpleStructuredGreeting", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) { var callback func(context.Context, *ai.GenerateResponseChunk) error if cb != nil { callback = func(ctx context.Context, c *ai.GenerateResponseChunk) error { @@ -189,6 +189,24 @@ func main() { return resp.Text(), nil }) + genkit.DefineStreamingFlow("nestedStreaming", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) { + iter := simpleStructuredGreetingFlow.Stream(ctx, input) + for { + stream, done := iter.Next() + if done { + break + } + if cb != nil { + cb(ctx, *stream) + } + } + finalOutput, err := iter.FinalOutput() + if err != nil { + return "", err + } + return *finalOutput, nil + }) + genkit.DefineFlow("testAllCoffeeFlows", func(ctx context.Context, _ struct{}) (*testAllCoffeeFlowsOutput, error) { test1, err := simpleGreetingFlow.Run(ctx, &simpleGreetingInput{ CustomerName: "Sam",