fix(task): revert to context-based approach, fix stuck tasks, and improve error handling

This commit is contained in:
yusing
2025-05-26 00:32:59 +08:00
parent 2e9f113224
commit 216c03c5ff
11 changed files with 200 additions and 183 deletions

View File

@@ -4,23 +4,38 @@ import (
"context"
"errors"
"fmt"
"sync"
"time"
_ "unsafe"
"github.com/rs/zerolog/log"
)
var (
taskPool = make(chan *Task, 100)
root = newRoot()
voidTask = &Task{ctx: context.Background()}
root = newRoot()
cancelCtx context.Context
)
func init() {
ctx, cancel := context.WithCancel(context.Background())
cancel()
cancelCtx = ctx //nolint:fatcontext
voidTask.parent = root
}
// testCleanup replaces root with a fresh root task and re-points
// voidTask.parent at it, so the root/voidTask linkage that init establishes
// keeps referring to the live root instead of the discarded one.
func testCleanup() {
	root = newRoot()
	voidTask.parent = root
}
// newRoot builds the root task as a child of voidTask.
//
// The parent must be non-nil: with needFinish set, newTask derives the new
// task's context from parent.ctx, so a nil parent would be dereferenced.
func newRoot() *Task {
	return newTask("root", voidTask, true)
}
// noCancel accepts a cause and ignores it. newTask installs it as the cancel
// function of tasks that share their parent's context.
func noCancel(_ error) {}
// newTask returns a Task named name under parent, reusing one from taskPool
// when one is available and allocating a new Task otherwise.
//
// With needFinish set, the task receives its own cancelable context derived
// from parent.ctx. Without it, the task shares parent.ctx and its cancel
// function is the no-op noCancel. parent must be non-nil in both cases
// because parent.ctx is read.
//
// NOTE(review): this span looks like a diff rendered without +/- markers. The
// `t.canceled` assignments sit next to the `t.ctx, t.cancel` assignments that
// appear to replace them; confirm which lines belong to the current revision.
//
//go:inline
@@ -28,20 +43,31 @@ func newTask(name string, parent *Task, needFinish bool) *Task {
var t *Task
// Non-blocking receive: take a pooled Task if one is ready, else allocate.
select {
case t = <-taskPool:
// Clear the finished flag on a reused Task.
t.finished.Store(false)
default:
t = &Task{}
}
t.name = name
t.parent = parent
if needFinish {
t.canceled = make(chan struct{})
// Own context: canceling this task does not cancel the parent.
t.ctx, t.cancel = context.WithCancelCause(parent.ctx)
} else {
// it will not be nil, because root task always has a canceled channel
t.canceled = parent.canceled
// Shared context: cancel is a no-op, so only the parent can end it.
t.ctx, t.cancel = parent.ctx, noCancel
}
return t
}
// needFinish reports whether t holds a context of its own rather than the one
// its parent holds. newTask hands a task its parent's context exactly when it
// is created with needFinish unset.
//
//go:inline
func (t *Task) needFinish() bool {
	own, inherited := t.ctx, t.parent.ctx
	return own != inherited
}
// isCanceled reports whether t's cancel function is nil.
// NOTE(review): newTask always assigns a non-nil cancel, so the field is
// presumably cleared when the task is finished — confirm against Finish.
//
//go:inline
func (t *Task) isCanceled() bool {
	cancelCleared := t.cancel == nil
	return cancelCleared
}
//go:inline
func putTask(t *Task) {
select {
case taskPool <- t:
@@ -50,49 +76,76 @@ func putTask(t *Task) {
}
}
// setCause records cause on t, substituting context.Canceled when cause is
// nil so that the stored cause is never nil afterwards.
//
//go:inline
func (t *Task) setCause(cause error) {
	if cause != nil {
		t.cause = cause
		return
	}
	t.cause = context.Canceled
}
// addCallback registers fn, described by about, on t.
//
// As far as this span shows, the context-based path works as follows:
//   - a task for which needFinish reports false forwards the callback to its
//     parent; with waitSubTasks set, fn is wrapped so that it runs only after
//     t.waitFinish(taskTimeout) returns, and reportStucked is called when
//     that wait fails;
//   - with waitSubTasks unset, fn is stored in callbacksOnCancel; the first
//     such registration starts a goroutine that waits for t.ctx.Done() and
//     then runs every stored callback in its own goroutine;
//   - otherwise fn is stored in callbacksOnFinish, and registering on a task
//     for which isCanceled reports true is logged through log.Panic().
//
// NOTE(review): this span looks like a diff rendered without +/- markers. The
// statements that use t.cause, t.callbacks and waitChildren appear to be the
// replaced lines interleaved with their replacements, and the span does not
// parse as it stands. Confirm which lines belong to the current revision.
//
//go:inline
func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) {
t.mu.Lock()
if t.cause != nil {
t.mu.Unlock()
// A task that shares its parent's context delegates to the parent.
if !t.needFinish() {
if waitSubTasks {
waitEmpty(t.children, taskTimeout)
// Wrap fn so the parent runs it only after t's own wait has completed.
t.parent.addCallback(about, func() {
if !t.waitFinish(taskTimeout) {
t.reportStucked()
}
fn()
}, false)
} else {
t.parent.addCallback(about, fn, false)
}
fn()
return
}
defer t.mu.Unlock()
if t.callbacks == nil {
t.callbacks = make(callbacksSet)
if !waitSubTasks {
t.mu.Lock()
defer t.mu.Unlock()
// The set and its watcher goroutine are created together, once per task.
if t.callbacksOnCancel == nil {
t.callbacksOnCancel = make(callbacksSet)
go func() {
<-t.ctx.Done()
// NOTE(review): this range reads callbacksOnCancel without holding t.mu,
// while the goroutines started below delete from the same map under t.mu
// and addCallback inserts into it under t.mu. That looks like an
// unsynchronized map access — confirm.
for c := range t.callbacksOnCancel {
go func() {
invokeWithRecover(c)
// Each callback removes itself from the set once it has run.
t.mu.Lock()
delete(t.callbacksOnCancel, c)
t.mu.Unlock()
}()
}
}()
}
t.callbacksOnCancel[&Callback{fn: fn, about: about}] = struct{}{}
return
}
t.callbacks[&Callback{
fn: fn,
about: about,
waitChildren: waitSubTasks,
t.mu.Lock()
defer t.mu.Unlock()
if t.isCanceled() {
// NOTE(review): zerolog documents that Msg on a Panic-level event calls
// panic, which would make the return below unreachable — confirm.
log.Panic().
Str("task", t.String()).
Str("callback", about).
Msg("callback added to canceled task")
return
}
if t.callbacksOnFinish == nil {
t.callbacksOnFinish = make(callbacksSet)
}
t.callbacksOnFinish[&Callback{
fn: fn,
about: about,
}] = struct{}{}
}
//go:inline
func (t *Task) addChild(child *Task) {
t.mu.Lock()
if t.cause != nil {
t.mu.Unlock()
child.Finish(t.FinishCause())
defer t.mu.Unlock()
if t.isCanceled() {
log.Panic().
Str("task", t.String()).
Str("child", child.Name()).
Msg("child added to canceled task")
return
}
defer t.mu.Unlock()
if t.children == nil {
t.children = make(childrenSet)
}
@@ -106,67 +159,19 @@ func (t *Task) removeChild(child *Task) {
delete(t.children, child)
}
// NOTE(review): two function headers follow, finishChildren and
// runOnFinishCallbacks. This span looks like a diff rendered without +/-
// markers, with finishChildren's lines interleaved with the function that
// appears to replace it. Confirm which lines belong to the current revision.
//
// As shown, runOnFinishCallbacks returns early when callbacksOnFinish is
// empty and otherwise starts one goroutine per callback. Each goroutine
// invokes its callback through invokeWithRecover and then deletes it from
// callbacksOnFinish while holding t.mu.
func (t *Task) finishChildren() {
t.mu.Lock()
if len(t.children) == 0 {
t.mu.Unlock()
func (t *Task) runOnFinishCallbacks() {
if len(t.callbacksOnFinish) == 0 {
return
}
var wg sync.WaitGroup
for child := range t.children {
wg.Add(1)
// NOTE(review): nothing in the runOnFinishCallbacks lines takes t.mu around
// this range, while the goroutines it starts delete from the same map under
// t.mu. Unless the caller serializes this some other way, that is an
// unsynchronized map access — confirm.
for c := range t.callbacksOnFinish {
go func() {
defer wg.Done()
child.Finish(t.cause)
invokeWithRecover(c)
// Each callback removes itself from the set once it has run.
t.mu.Lock()
delete(t.callbacksOnFinish, c)
t.mu.Unlock()
}()
}
clear(t.children)
t.mu.Unlock()
wg.Wait()
}
// runCallbacks runs every callback stored in t.callbacks, each in its own
// goroutine, in two phases. Callbacks with waitChildren unset are started
// first. If any callback has waitChildren set, waitEmpty(t.children,
// taskTimeout) is called and those callbacks are started afterwards. The set
// is then cleared, t.mu is released, and the function blocks until every
// started callback has returned.
func (t *Task) runCallbacks() {
t.mu.Lock()
if len(t.callbacks) == 0 {
t.mu.Unlock()
return
}
var wg sync.WaitGroup
var needWait bool
// runs callbacks that does not need wait first
for c := range t.callbacks {
if !c.waitChildren {
wg.Add(1)
go func() {
defer wg.Done()
invokeWithRecover(c)
}()
} else {
needWait = true
}
}
// runs callbacks that need to wait for children
if needWait {
// NOTE(review): t.mu is still held here and the result of waitEmpty is
// discarded. If removing a child needs t.mu (addChild takes it), children
// cannot leave the set during this wait, so it would only end on timeout —
// confirm against removeChild.
waitEmpty(t.children, taskTimeout)
for c := range t.callbacks {
if c.waitChildren {
wg.Add(1)
go func() {
defer wg.Done()
invokeWithRecover(c)
}()
}
}
}
// The set is emptied while the goroutines started above may still be running.
clear(t.callbacks)
t.mu.Unlock()
wg.Wait()
}
func (t *Task) waitFinish(timeout time.Duration) bool {
@@ -175,16 +180,24 @@ func (t *Task) waitFinish(timeout time.Duration) bool {
return true
}
if len(t.children) == 0 && len(t.callbacks) == 0 {
return true
t.mu.Lock()
children, callbacksOnCancel, callbacksOnFinish := t.children, t.callbacksOnCancel, t.callbacksOnFinish
t.mu.Unlock()
ok := true
if len(children) != 0 {
ok = waitEmpty(children, timeout)
}
ok := waitEmpty(t.children, timeout) && waitEmpty(t.callbacks, timeout)
if !ok {
return false
if len(callbacksOnCancel) != 0 {
ok = ok && waitEmpty(callbacksOnCancel, timeout)
}
t.finished.Store(true)
return true
if len(callbacksOnFinish) != 0 {
ok = ok && waitEmpty(callbacksOnFinish, timeout)
}
return ok
}
//go:inline
@@ -193,8 +206,6 @@ func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool {
return true
}
var sema uint32
timer := time.NewTimer(timeout)
defer timer.Stop()
@@ -206,7 +217,7 @@ func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool {
case <-timer.C:
return false
default:
runtime_Semacquire(&sema)
time.Sleep(100 * time.Millisecond)
}
}
}
@@ -224,6 +235,3 @@ func fmtCause(cause any) error {
return fmt.Errorf("%v", cause)
}
}
// runtime_Semacquire is declared without a body; the go:linkname directive
// below binds it to sync.runtime_Semacquire. The file's blank import of
// "unsafe" is what permits the use of go:linkname.
//
//go:linkname runtime_Semacquire sync.runtime_Semacquire
func runtime_Semacquire(s *uint32)