readysite / pkg / assistant / stream.go
4.2 KB
stream.go
package assistant

import (
	"bufio"
	"io"
	"strings"
	"sync"
)

// EventType identifies the kind of a StreamEvent.
// Its values are the Event* constants declared in this file.
type EventType int

// Event kinds carried in StreamEvent.Type. The values are assigned with iota,
// so EventContentDelta is 0 (the zero value of EventType) and the declaration
// order determines the remaining values.
const (
	EventContentDelta   EventType = iota // A chunk of text content
	EventToolCallStart                   // A new tool call has started
	EventToolCallDelta                   // A chunk of arguments for an existing tool call
	EventDone                            // The stream is complete
	EventError                           // An error occurred
)

// StreamEvent represents a single event in the stream.
// Type selects which of the remaining fields are meaningful.
type StreamEvent struct {
	Type       EventType // Kind of event
	Content    string    // Text chunk; used for EventContentDelta
	ToolCall   *ToolCall // Partial tool call; used for EventToolCallStart/EventToolCallDelta
	ToolIndex  int       // Index of the accumulated tool call that an EventToolCallDelta updates
	Error      error     // Failure detail; used for EventError
	Usage      *Usage    // Usage data for EventDone events (optional)
}

// StreamReader reads events from an SSE stream and accumulates the text
// content and tool calls seen so far. Every method acquires mu before
// touching the fields below.
type StreamReader struct {
	mu        sync.Mutex      // guards all fields of the reader
	reader    *bufio.Reader   // buffered source that Next reads lines from
	closer    io.Closer       // closed by Close; may be nil
	content   strings.Builder // concatenation of all EventContentDelta content
	toolCalls []ToolCall      // tool calls collected from EventToolCallStart/Delta events
	done      bool            // set once the stream has ended or Close was called
	err       error           // first error recorded while streaming

	// parseEvent is set by the provider to parse raw SSE data
	// (the payload of a "data" line) into a StreamEvent.
	parseEvent func(data string) (*StreamEvent, error)
}

// NewStreamReader builds a StreamReader on top of r.
//
// r is wrapped in a bufio.Reader for line-oriented reads and is also kept as
// the closer that Close will invoke. parseEvent is the provider-specific
// function that Next uses to turn the payload of each SSE data line into a
// StreamEvent.
func NewStreamReader(r io.ReadCloser, parseEvent func(data string) (*StreamEvent, error)) *StreamReader {
	sr := new(StreamReader)
	sr.reader = bufio.NewReader(r)
	sr.closer = r
	sr.parseEvent = parseEvent
	return sr
}

// Next returns the next event in the stream.
//
// It reads lines from the underlying reader until it finds an SSE "data"
// field, passes the payload to parseEvent, folds the parsed event into the
// accumulated content and tool calls, and returns it. Blank lines, comment
// lines (leading ':'), other SSE fields, empty payloads, and payloads that
// parseEvent rejects are skipped.
//
// Termination:
//   - A "[DONE]" payload or end of input yields an EventDone event.
//   - A read failure other than io.EOF is recorded and yields an EventError
//     event carrying that error.
//   - Once the stream is done or an error has been recorded, Next returns nil.
//
// Next holds the reader's mutex for the whole call, including while it is
// blocked on the underlying read.
func (s *StreamReader) Next() *StreamEvent {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.done || s.err != nil {
		return nil
	}

	for {
		line, readErr := s.reader.ReadString('\n')
		if readErr != nil && readErr != io.EOF {
			s.err = readErr
			return &StreamEvent{Type: EventError, Error: readErr}
		}

		// ReadString returns whatever it read before hitting EOF, so a final
		// line without a trailing newline must still be processed. The next
		// call will then observe an empty read at EOF and report EventDone.
		if event := s.eventFromLine(line); event != nil {
			return event
		}

		if readErr == io.EOF {
			s.done = true
			return &StreamEvent{Type: EventDone}
		}
	}
}

// eventFromLine interprets a single raw line of the stream. It returns the
// event to hand to the caller, or nil when the line carries nothing to report
// (blank line, comment, non-data field, empty or unparseable payload).
// The caller must hold s.mu.
func (s *StreamReader) eventFromLine(line string) *StreamEvent {
	line = strings.TrimSpace(line)

	// Only "data" fields carry payloads. The SSE format makes the single
	// space after the colon optional, so accept both "data: x" and "data:x".
	if !strings.HasPrefix(line, "data:") {
		return nil
	}
	data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
	if data == "" {
		return nil
	}

	// Sentinel payload marking the end of the stream.
	if data == "[DONE]" {
		s.done = true
		return &StreamEvent{Type: EventDone}
	}

	event, err := s.parseEvent(data)
	if err != nil || event == nil {
		// Best effort: skip payloads the provider parser cannot turn into an event.
		return nil
	}

	// Accumulate content and tool calls. Malformed tool-call events (missing
	// ToolCall, out-of-range index) are still returned to the caller but are
	// not folded into the accumulated state.
	switch event.Type {
	case EventContentDelta:
		s.content.WriteString(event.Content)
	case EventToolCallStart:
		if event.ToolCall != nil {
			s.toolCalls = append(s.toolCalls, *event.ToolCall)
		}
	case EventToolCallDelta:
		if event.ToolCall != nil && event.ToolIndex >= 0 && event.ToolIndex < len(s.toolCalls) {
			s.toolCalls[event.ToolIndex].Arguments += event.ToolCall.Arguments
		}
	case EventDone:
		s.done = true
	case EventError:
		s.err = event.Error
	}

	return event
}

// Close marks the stream as done and closes the underlying source.
// It returns the error reported by the source's Close, or nil when the
// reader has no closer attached.
func (s *StreamReader) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.done = true
	if s.closer == nil {
		return nil
	}
	return s.closer.Close()
}

// Content returns the text accumulated so far from content events.
func (s *StreamReader) Content() string {
	s.mu.Lock()
	text := s.content.String()
	s.mu.Unlock()
	return text
}

// ToolCalls returns the tool calls accumulated so far.
// The returned slice is a fresh copy, so callers may modify it without
// affecting the reader's internal state.
func (s *StreamReader) ToolCalls() []ToolCall {
	s.mu.Lock()
	defer s.mu.Unlock()

	snapshot := make([]ToolCall, 0, len(s.toolCalls))
	return append(snapshot, s.toolCalls...)
}

// Done reports whether the stream has been marked complete.
func (s *StreamReader) Done() bool {
	s.mu.Lock()
	finished := s.done
	s.mu.Unlock()
	return finished
}

// Err returns the error recorded while streaming, or nil if none was recorded.
func (s *StreamReader) Err() error {
	s.mu.Lock()
	recorded := s.err
	s.mu.Unlock()
	return recorded
}

// Collect drains the stream and assembles the final response.
//
// It calls Next until the stream yields nil or an EventDone event. If an
// EventError event is encountered first, Collect stops and returns a nil
// response together with that event's error. Otherwise the response carries
// the accumulated content and tool calls; FinishReason is "tool_calls" when at
// least one tool call was collected and "stop" otherwise. The second return
// value is whatever Err reports after draining.
func (s *StreamReader) Collect() (*ChatResponse, error) {
	for ev := s.Next(); ev != nil && ev.Type != EventDone; ev = s.Next() {
		if ev.Type == EventError {
			return nil, ev.Error
		}
	}

	text := s.Content()
	calls := s.ToolCalls()

	reason := "stop"
	if len(calls) > 0 {
		reason = "tool_calls"
	}

	resp := &ChatResponse{
		Content:      text,
		ToolCalls:    calls,
		FinishReason: reason,
	}
	return resp, s.Err()
}
← Back