stream.go
package assistant
import (
	"bufio"
	"errors"
	"io"
	"strings"
	"sync"
)
// EventType identifies the kind of a StreamEvent.
type EventType int

// Event kinds produced while reading a stream. Values are assigned in
// declaration order starting at zero.
const (
	// EventContentDelta carries a chunk of text in StreamEvent.Content.
	EventContentDelta EventType = iota
	// EventToolCallStart announces a new tool call in StreamEvent.ToolCall.
	EventToolCallStart
	// EventToolCallDelta carries an argument fragment for the tool call at
	// StreamEvent.ToolIndex.
	EventToolCallDelta
	// EventDone marks the end of the stream; StreamEvent.Usage may be set.
	EventDone
	// EventError reports a failure in StreamEvent.Error.
	EventError
)
// StreamEvent is one decoded event from a stream. Which fields carry data
// depends on Type.
type StreamEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type EventType

	// Content is the text chunk of an EventContentDelta.
	Content string

	// ToolCall holds the (possibly partial) tool call of an
	// EventToolCallStart or EventToolCallDelta.
	ToolCall *ToolCall

	// ToolIndex is the position of the tool call being updated.
	ToolIndex int

	// Error describes the failure of an EventError.
	Error error

	// Usage is optionally attached to an EventDone.
	Usage *Usage
}
// StreamReader consumes a server-sent events (SSE) stream, turning each data
// payload into a StreamEvent while accumulating the streamed text and tool
// calls.
type StreamReader struct {
	// mu guards the mutable stream state: reader position, accumulators,
	// done and err.
	mu sync.Mutex

	// reader buffers the raw stream so it can be consumed line by line.
	reader *bufio.Reader
	// closer is the original stream, released by Close.
	closer io.Closer

	// content collects the text of every content-delta event seen so far.
	content strings.Builder
	// toolCalls collects tool calls; argument deltas extend them in place.
	toolCalls []ToolCall

	// done is set once the stream has ended or been closed.
	done bool
	// err records a failure reported while streaming; once set, Next
	// returns nil.
	err error

	// parseEvent is supplied by the provider and decodes one SSE data
	// payload into a StreamEvent.
	parseEvent func(data string) (*StreamEvent, error)
}
// NewStreamReader wraps r in a buffered StreamReader. parseEvent decodes the
// payload of each SSE data line into a StreamEvent. r is closed by Close.
func NewStreamReader(r io.ReadCloser, parseEvent func(data string) (*StreamEvent, error)) *StreamReader {
	sr := new(StreamReader)
	sr.reader = bufio.NewReader(r)
	sr.closer = r
	sr.parseEvent = parseEvent
	return sr
}
// Next returns the next event in the stream.
//
// Lines are read from the underlying reader one at a time. Blank lines,
// comment lines (starting with ":") and fields other than "data:" are
// skipped, and each data payload is decoded by parseEvent. Content deltas and
// tool-call fragments are accumulated, so Content and ToolCalls reflect
// everything read so far. Payloads that parseEvent rejects, or decodes to a
// nil event, are skipped rather than aborting the stream.
//
// A "[DONE]" payload or the end of input yields an EventDone event. A read
// error yields an EventError event and is recorded for Err. After the stream
// has finished, failed, or been closed, Next returns nil.
func (s *StreamReader) Next() *StreamEvent {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.done || s.err != nil {
		return nil
	}

	for {
		line, err := s.reader.ReadString('\n')
		if err != nil && err != io.EOF {
			s.err = err
			return &StreamEvent{Type: EventError, Error: err}
		}

		// ReadString returns an unterminated final line together with
		// io.EOF, so handle the line before acting on the EOF.
		if event := s.handleLine(line); event != nil {
			return event
		}

		if err == io.EOF {
			s.done = true
			return &StreamEvent{Type: EventDone}
		}
	}
}

// handleLine interprets one raw SSE line and updates the accumulated state.
// It returns the event to report, or nil when the line carries nothing to
// report. The caller must hold s.mu.
func (s *StreamReader) handleLine(line string) *StreamEvent {
	line = strings.TrimSpace(line)
	if line == "" || strings.HasPrefix(line, ":") || !strings.HasPrefix(line, "data:") {
		return nil
	}

	// SSE makes the single space after "data:" optional; strip it only if present.
	data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
	if data == "[DONE]" {
		s.done = true
		return &StreamEvent{Type: EventDone}
	}

	event, err := s.parseEvent(data)
	if err != nil || event == nil {
		// Skip unparseable or ignorable payloads so one bad chunk does not
		// abort the whole stream.
		return nil
	}

	switch event.Type {
	case EventContentDelta:
		s.content.WriteString(event.Content)
	case EventToolCallStart:
		if event.ToolCall != nil {
			s.toolCalls = append(s.toolCalls, *event.ToolCall)
		}
	case EventToolCallDelta:
		// Guard both bounds: a negative index would otherwise panic.
		if event.ToolCall != nil && event.ToolIndex >= 0 && event.ToolIndex < len(s.toolCalls) {
			s.toolCalls[event.ToolIndex].Arguments += event.ToolCall.Arguments
		}
	case EventDone:
		s.done = true
	case EventError:
		s.err = event.Error
	}
	return event
}
// Close closes the underlying reader and marks the stream as done, so later
// calls to Next return nil. It returns the error from closing the underlying
// reader, if any.
//
// The underlying closer is closed before acquiring mu. Next holds mu while it
// blocks reading from the stream, so taking the lock first would make Close
// wait until more data arrived, or forever on a stalled connection. Closing
// first lets a blocked Next fail out of its read. That Next then reports the
// read error as an EventError, which is also what Err returns afterwards.
//
// NOTE(review): this assumes the provider's io.ReadCloser tolerates Close
// being called while a Read is in progress; confirm for each provider.
func (s *StreamReader) Close() error {
	// s.closer is only assigned in NewStreamReader, so reading it without mu
	// does not race with the other methods.
	var err error
	if s.closer != nil {
		err = s.closer.Close()
	}

	s.mu.Lock()
	s.done = true
	s.mu.Unlock()

	return err
}
// Content returns the text accumulated from every content-delta event read
// so far.
func (s *StreamReader) Content() string {
	s.mu.Lock()
	text := s.content.String()
	s.mu.Unlock()
	return text
}
// ToolCalls returns a snapshot of the tool calls accumulated so far. The
// returned slice is a fresh copy (never nil), so appending to or reassigning
// its elements does not affect the reader. The copy is shallow: any
// reference-typed fields inside ToolCall are still shared.
func (s *StreamReader) ToolCalls() []ToolCall {
	s.mu.Lock()
	defer s.mu.Unlock()
	return append(make([]ToolCall, 0, len(s.toolCalls)), s.toolCalls...)
}
// Done reports whether the stream has finished, either because its end was
// reached or because Close was called.
func (s *StreamReader) Done() bool {
	s.mu.Lock()
	finished := s.done
	s.mu.Unlock()
	return finished
}
// Err returns the error recorded while streaming, or nil if none occurred.
func (s *StreamReader) Err() error {
	s.mu.Lock()
	recorded := s.err
	s.mu.Unlock()
	return recorded
}
// Collect drains the stream and assembles the accumulated content and tool
// calls into a ChatResponse.
//
// If an EventError is read, Collect returns a nil response and that event's
// error. When the event carries no error, a descriptive one is substituted,
// so a failure is never reported as (nil, nil). Otherwise Collect returns the
// response together with Err. Err is non-nil only when the stream had already
// failed before Collect started; in that case the response holds whatever was
// accumulated before the failure.
//
// FinishReason is "tool_calls" when at least one tool call was collected,
// and "stop" otherwise.
func (s *StreamReader) Collect() (*ChatResponse, error) {
	for {
		event := s.Next()
		if event == nil || event.Type == EventDone {
			break
		}
		if event.Type == EventError {
			if event.Error == nil {
				return nil, errors.New("stream reported an error event without details")
			}
			return nil, event.Error
		}
	}

	resp := &ChatResponse{
		Content:   s.Content(),
		ToolCalls: s.ToolCalls(),
	}
	if len(resp.ToolCalls) > 0 {
		resp.FinishReason = "tool_calls"
	} else {
		resp.FinishReason = "stop"
	}
	return resp, s.Err()
}