Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Go] Refactored Stream() API to be more ergonomic. #766

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 61 additions & 29 deletions go/genkit/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -695,48 +695,80 @@ func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(conte
return finishedOpResponse(state.Operation)
}

// StreamFlowValue is either a streamed value or a final output of a flow.
type StreamFlowValue[Out, Stream any] struct {
Done bool
Output Out // valid if Done is true
Stream Stream // valid if Done is false
// FlowIterator defines the interface for iterating over flow results.
type FlowIterator[Out, Stream any] interface {
// Next returns the next streamed value and a boolean indicating whether the flow has completed.
Next() (Stream, bool)
// FinalOutput returns the final output of the flow if it has completed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it do if the flow hasn't completed?

FinalOutput() (Out, error)
}

// Stream runs the flow on input and delivers both the streamed values and the final output.
// It returns a function whose argument function (the "yield function") will be repeatedly
// called with the results.
//
// If the yield function is passed a non-nil error, the flow has failed with that
// error; the yield function will not be called again. An error is also passed if
// the flow fails to complete (that is, it has an interrupt).
// Genkit Go does not yet support interrupts.
//
// If the yield function's [StreamFlowValue] argument has Done == true, the value's
// Output field contains the final output; the yield function will not be called
// again.
//
// Otherwise the Stream field of the passed [StreamFlowValue] holds a streamed result.
func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...FlowRunOption) func(func(*StreamFlowValue[Out, Stream], error) bool) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I actually wrote this with range-over-func in mind. Its return type is exactly an iter.Seq2 (see https://pkg.go.dev/iter#Seq2). So think about keeping it. It seems weird now but won't in a few months.

return func(yield func(*StreamFlowValue[Out, Stream], error) bool) {
// flowIterator implements the FlowIterator interface.
type flowIterator[Out, Stream any] struct {
done bool // true if the flow has completed
output Out // the final output of the flow
err error // the error that occurred, if any
streamCh chan Stream // channel to receive streamed values
doneCh chan struct{} // channel to indicate that the flow has completed
}

// Next returns the next streamed value and a boolean indicating whether the flow has completed.
// If the flow has completed, the returned Stream pointer will be nil and the boolean will be true.
func (fi *flowIterator[Out, Stream]) Next() (*Stream, bool) {
select {
case stream := <-fi.streamCh:
return &stream, false
case <-fi.doneCh:
return nil, true
}
}

// FinalOutput returns the final output of the flow if it has completed.
// If the flow has not completed, it returns an error.
// If the flow completed with an error, it returns that error.
func (fi *flowIterator[Out, Stream]) FinalOutput() (*Out, error) {
if !fi.done {
return nil, errors.New("flow has not completed")
}
if fi.err != nil {
return nil, fi.err
}
return &fi.output, nil
}

// Stream returns a FlowIterator for the flow.
func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...FlowRunOption) FlowIterator[*Out, *Stream] {
fi := &flowIterator[Out, Stream]{
done: false,
streamCh: make(chan Stream),
doneCh: make(chan struct{}),
}

go func() {
cb := func(ctx context.Context, s Stream) error {
if ctx.Err() != nil {
fi.err = ctx.Err()
return ctx.Err()
}
if !yield(&StreamFlowValue[Out, Stream]{Stream: s}, nil) {
return errStop
select {
case fi.streamCh <- s:
return nil
case <-ctx.Done():
return ctx.Err()
}
return nil
}
output, err := f.run(ctx, input, cb, opts...)
if err != nil {
yield(nil, err)
fi.err = err
} else {
yield(&StreamFlowValue[Out, Stream]{Done: true, Output: output}, nil)
fi.output = output
}
}
}
fi.done = true
close(fi.doneCh)
}()

var errStop = errors.New("stop")
return fi
}

func finishedOpResponse[O any](op *operation[O]) (O, error) {
if !op.Done {
Expand Down
27 changes: 15 additions & 12 deletions go/genkit/genkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,25 @@ func TestStreamFlow(t *testing.T) {
f := DefineStreamingFlow("count", count)
iter := f.Stream(context.Background(), 2)
want := 0
iter(func(val *StreamFlowValue[int, int], err error) bool {
if err != nil {
t.Fatal(err)
}
var got int
if val.Done {
got = val.Output
} else {
got = val.Stream

for {
got, done := iter.Next()
if done {
break
}
if got != want {
if *got != want {
t.Errorf("got %d, want %d", got, want)
}
want++
return true
})
}

finalOutput, err := iter.FinalOutput()
if err != nil {
t.Fatal(err)
}
if *finalOutput != want {
t.Errorf("got %d, want %d", finalOutput, want)
}
}

// count streams the numbers from 0 to n-1, then returns n.
Expand Down
28 changes: 13 additions & 15 deletions go/internal/doc-snippets/flows.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,20 @@ func f3() {
// [END streaming]

// [START invoke-streaming]
menuSuggestionFlow.Stream(
context.Background(),
"French",
)(func(sfv *genkit.StreamFlowValue[OutputType, StreamType], err error) bool {
if err != nil {
// handle err
return false
iter := menuSuggestionFlow.Stream(context.Background(), InputType("French"))
for {
stream, done := iter.Next()
if done {
break
}
if !sfv.Done {
fmt.Print(sfv.Stream)
return true
} else {
fmt.Print(sfv.Output)
return false
}
})
fmt.Print(*stream)
}
finalOutput, err := iter.FinalOutput()
if err != nil {
// handle error
} else {
fmt.Print(*finalOutput)
}
// [END invoke-streaming]
}

Expand Down
20 changes: 19 additions & 1 deletion go/samples/coffee-shop/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func main() {
log.Fatal(err)
}

genkit.DefineStreamingFlow("simpleStructuredGreeting", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) {
simpleStructuredGreetingFlow := genkit.DefineStreamingFlow("simpleStructuredGreeting", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) {
var callback func(context.Context, *ai.GenerateResponseChunk) error
if cb != nil {
callback = func(ctx context.Context, c *ai.GenerateResponseChunk) error {
Expand All @@ -189,6 +189,24 @@ func main() {
return resp.Text(), nil
})

genkit.DefineStreamingFlow("nestedStreaming", func(ctx context.Context, input *simpleGreetingInput, cb func(context.Context, string) error) (string, error) {
iter := simpleStructuredGreetingFlow.Stream(ctx, input)
for {
stream, done := iter.Next()
if done {
break
}
if cb != nil {
cb(ctx, *stream)
}
}
finalOutput, err := iter.FinalOutput()
if err != nil {
return "", err
}
return *finalOutput, nil
})

genkit.DefineFlow("testAllCoffeeFlows", func(ctx context.Context, _ struct{}) (*testAllCoffeeFlowsOutput, error) {
test1, err := simpleGreetingFlow.Run(ctx, &simpleGreetingInput{
CustomerName: "Sam",
Expand Down
Loading