refactor: improve task management with xsync for concurrent access and enhance callback and subtasks handling as well as memory allocation

This commit is contained in:
yusing
2025-05-25 15:01:44 +08:00
parent ade93d49a3
commit c1221e61d4
16 changed files with 447 additions and 211 deletions

View File

@@ -1,70 +1,217 @@
package task
import (
"context"
"errors"
"fmt"
"sync"
"time"
_ "unsafe"
)
var (
taskPool = make(chan *Task, 100)
root = newRoot()
)
func testCleanup() {
root = newRoot()
}
func newRoot() *Task {
return newTask("root", nil, true)
}
//go:inline
func newTask(name string, parent *Task, needFinish bool) *Task {
var t *Task
select {
case t = <-taskPool:
default:
t = &Task{}
}
t.name = name
t.parent = parent
if needFinish {
t.canceled = make(chan struct{})
} else {
// it will not be nil, because root task always has a canceled channel
t.canceled = parent.canceled
}
return t
}
func putTask(t *Task) {
select {
case taskPool <- t:
default:
return
}
}
//go:inline
func (t *Task) setCause(cause error) {
if cause == nil {
t.cause = context.Canceled
} else {
t.cause = cause
}
}
//go:inline
func (t *Task) addCallback(about string, fn func(), waitSubTasks bool) {
t.mu.Lock()
if t.cause != nil {
t.mu.Unlock()
if waitSubTasks {
waitEmpty(t.children, taskTimeout)
}
fn()
return
}
defer t.mu.Unlock()
if t.callbacks == nil {
t.callbacks = make(map[*Callback]struct{})
t.callbacks = make(callbacksSet)
}
if t.callbacksDone == nil {
t.callbacksDone = make(chan struct{})
}
t.callbacks[&Callback{fn, about, waitSubTasks}] = struct{}{}
t.callbacks[&Callback{
fn: fn,
about: about,
waitChildren: waitSubTasks,
}] = struct{}{}
}
func (t *Task) addChildCount() {
//go:inline
func (t *Task) addChild(child *Task) {
t.mu.Lock()
defer t.mu.Unlock()
t.children++
if t.children == 1 {
t.childrenDone = make(chan struct{})
if t.cause != nil {
t.mu.Unlock()
child.Finish(t.FinishCause())
return
}
defer t.mu.Unlock()
if t.children == nil {
t.children = make(childrenSet)
}
t.children[child] = struct{}{}
}
func (t *Task) subChildCount() {
//go:inline
func (t *Task) removeChild(child *Task) {
t.mu.Lock()
defer t.mu.Unlock()
t.children--
switch t.children {
case 0:
close(t.childrenDone)
case ^uint32(0):
panic("negative child count")
delete(t.children, child)
}
func (t *Task) finishChildren() {
t.mu.Lock()
if len(t.children) == 0 {
t.mu.Unlock()
return
}
var wg sync.WaitGroup
for child := range t.children {
wg.Add(1)
go func() {
defer wg.Done()
child.Finish(t.cause)
}()
}
clear(t.children)
t.mu.Unlock()
wg.Wait()
}
func (t *Task) runCallbacks() {
t.mu.Lock()
if len(t.callbacks) == 0 {
t.mu.Unlock()
return
}
var wg sync.WaitGroup
var needWait bool
// runs callbacks that does not need wait first
for c := range t.callbacks {
if c.waitChildren {
waitWithTimeout(t.childrenDone)
if !c.waitChildren {
wg.Add(1)
go func() {
defer wg.Done()
invokeWithRecover(c)
}()
} else {
needWait = true
}
t.invokeWithRecover(c.fn, c.about)
delete(t.callbacks, c)
}
close(t.callbacksDone)
// runs callbacks that need to wait for children
if needWait {
waitEmpty(t.children, taskTimeout)
for c := range t.callbacks {
if c.waitChildren {
wg.Add(1)
go func() {
defer wg.Done()
invokeWithRecover(c)
}()
}
}
}
clear(t.callbacks)
t.mu.Unlock()
wg.Wait()
}
func waitWithTimeout(ch <-chan struct{}) bool {
if ch == nil {
func (t *Task) waitFinish(timeout time.Duration) bool {
// return directly if already finished
if t.isFinished() {
return true
}
select {
case <-ch:
if len(t.children) == 0 && len(t.callbacks) == 0 {
return true
case <-time.After(taskTimeout):
}
ok := waitEmpty(t.children, timeout) && waitEmpty(t.callbacks, timeout)
if !ok {
return false
}
t.finished.Store(true)
return true
}
//go:inline
func waitEmpty[T comparable](set map[T]struct{}, timeout time.Duration) bool {
if len(set) == 0 {
return true
}
var sema uint32
timer := time.NewTimer(timeout)
defer timer.Stop()
for {
if len(set) == 0 {
return true
}
select {
case <-timer.C:
return false
default:
runtime_Semacquire(&sema)
}
}
}
//go:inline
func fmtCause(cause any) error {
switch cause := cause.(type) {
case nil:
@@ -77,3 +224,6 @@ func fmtCause(cause any) error {
return fmt.Errorf("%v", cause)
}
}
//go:linkname runtime_Semacquire sync.runtime_Semacquire
func runtime_Semacquire(s *uint32)