mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-18 22:49:52 +02:00
simplify task package implementation
This commit is contained in:
@@ -4,355 +4,194 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
var globalTask = createGlobalTask()
|
||||
|
||||
func createGlobalTask() (t *Task) {
|
||||
t = new(Task)
|
||||
t.name = "root"
|
||||
t.ctx, t.cancel = context.WithCancelCause(context.Background())
|
||||
t.subtasks = F.NewSet[*Task]()
|
||||
return
|
||||
}
|
||||
|
||||
func testResetGlobalTask() {
|
||||
globalTask = createGlobalTask()
|
||||
}
|
||||
|
||||
type (
|
||||
TaskStarter interface {
|
||||
// Start starts the object that implements TaskStarter,
|
||||
// and returns an error if it fails to start.
|
||||
//
|
||||
// The task passed must be a subtask of the caller task.
|
||||
//
|
||||
// callerSubtask.Finish must be called when start fails or the object is finished.
|
||||
Start(callerSubtask *Task) E.Error
|
||||
Start(parent Parent) E.Error
|
||||
Task() *Task
|
||||
}
|
||||
TaskFinisher interface {
|
||||
// Finish marks the task as finished and cancel its context.
|
||||
//
|
||||
// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
|
||||
// of the task to finish.
|
||||
//
|
||||
// Note that it will also cancel all subtasks.
|
||||
Finish(reason any)
|
||||
}
|
||||
// Task controls objects' lifetime.
|
||||
//
|
||||
// Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface.
|
||||
//
|
||||
// When passing a Task object to another function,
|
||||
// it must be a sub-Task of the current Task,
|
||||
// in name of "`currentTaskName`Subtask"
|
||||
//
|
||||
// Use Task.Finish to stop all subtasks of the Task.
|
||||
Task struct {
|
||||
name string
|
||||
|
||||
children sync.WaitGroup
|
||||
|
||||
onFinished sync.WaitGroup
|
||||
finished chan struct{}
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelCauseFunc
|
||||
|
||||
parent *Task
|
||||
subtasks F.Set[*Task]
|
||||
subTasksWg sync.WaitGroup
|
||||
|
||||
name string
|
||||
|
||||
OnFinishedFuncs []func()
|
||||
OnFinishedMu sync.Mutex
|
||||
onFinishedWg sync.WaitGroup
|
||||
|
||||
finishOnce sync.Once
|
||||
once sync.Once
|
||||
}
|
||||
Parent interface {
|
||||
Context() context.Context
|
||||
Subtask(name string, needFinish ...bool) *Task
|
||||
Name() string
|
||||
Finish(reason any)
|
||||
OnCancel(name string, f func())
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrProgramExiting = errors.New("program exiting")
|
||||
ErrTaskCanceled = errors.New("task canceled")
|
||||
|
||||
logger = logging.With().Str("module", "task").Logger()
|
||||
)
|
||||
|
||||
// GlobalTask returns a new Task with the given name, derived from the global context.
|
||||
func GlobalTask(format string, args ...any) *Task {
|
||||
if len(args) > 0 {
|
||||
format = fmt.Sprintf(format, args...)
|
||||
}
|
||||
return globalTask.Subtask(format)
|
||||
}
|
||||
|
||||
// DebugTaskMap returns a map[string]any representation of the global task tree.
|
||||
//
|
||||
// The returned map is suitable for encoding to JSON, and can be used
|
||||
// to debug the task tree.
|
||||
//
|
||||
// The returned map is not guaranteed to be stable, and may change
|
||||
// between runs of the program. It is intended for debugging purposes
|
||||
// only.
|
||||
func DebugTaskMap() map[string]any {
|
||||
return globalTask.serialize()
|
||||
}
|
||||
|
||||
// CancelGlobalContext cancels the global task context, which will cause all tasks
|
||||
// created to be canceled. This should be called before exiting the program
|
||||
// to ensure that all tasks are properly cleaned up.
|
||||
func CancelGlobalContext() {
|
||||
globalTask.cancel(ErrProgramExiting)
|
||||
}
|
||||
|
||||
// GlobalContextWait waits for all tasks to finish, up to the given timeout.
|
||||
//
|
||||
// If the timeout is exceeded, it prints a list of all tasks that were
|
||||
// still running when the timeout was reached, and their current tree
|
||||
// of subtasks.
|
||||
func GlobalContextWait(timeout time.Duration) (err error) {
|
||||
done := make(chan struct{})
|
||||
after := time.After(timeout)
|
||||
go func() {
|
||||
globalTask.Wait()
|
||||
close(done)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-after:
|
||||
logger.Warn().Msg("Timeout waiting for these tasks to finish:\n" + globalTask.tree())
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Task) trace(msg string) {
|
||||
logger.Trace().Str("name", t.name).Msg(msg)
|
||||
}
|
||||
|
||||
// Name returns the name of the task.
|
||||
func (t *Task) Name() string {
|
||||
if !common.IsTrace {
|
||||
return t.name
|
||||
}
|
||||
parts := strings.Split(t.name, " > ")
|
||||
return parts[len(parts)-1]
|
||||
}
|
||||
|
||||
// String returns the name of the task.
|
||||
func (t *Task) String() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
// Context returns the context associated with the task. This context is
|
||||
// canceled when Finish of the task is called, or parent task is canceled.
|
||||
func (t *Task) Context() context.Context {
|
||||
return t.ctx
|
||||
}
|
||||
|
||||
func (t *Task) Finished() <-chan struct{} {
|
||||
return t.finished
|
||||
}
|
||||
|
||||
// FinishCause returns the reason / error that caused the task to be finished.
|
||||
func (t *Task) FinishCause() error {
|
||||
cause := context.Cause(t.ctx)
|
||||
if cause == nil {
|
||||
return t.ctx.Err()
|
||||
}
|
||||
return cause
|
||||
return context.Cause(t.ctx)
|
||||
}
|
||||
|
||||
// Parent returns the parent task of the current task.
|
||||
func (t *Task) Parent() *Task {
|
||||
return t.parent
|
||||
}
|
||||
|
||||
func (t *Task) runAllOnFinished(onCompTask *Task) {
|
||||
<-t.ctx.Done()
|
||||
t.WaitSubTasks()
|
||||
for _, OnFinishedFunc := range t.OnFinishedFuncs {
|
||||
OnFinishedFunc()
|
||||
t.onFinishedWg.Done()
|
||||
}
|
||||
onCompTask.Finish(fmt.Errorf("%w: %s, reason: %s", ErrTaskCanceled, t.name, "done"))
|
||||
}
|
||||
|
||||
// OnFinished calls fn when all subtasks are finished.
|
||||
// OnFinished calls fn when the task is canceled and all subtasks are finished.
|
||||
//
|
||||
// It cannot be called after Finish or Wait is called.
|
||||
// It should not be called after Finish is called.
|
||||
func (t *Task) OnFinished(about string, fn func()) {
|
||||
if t.parent == globalTask {
|
||||
t.OnCancel(about, fn)
|
||||
return
|
||||
}
|
||||
t.onFinishedWg.Add(1)
|
||||
t.OnFinishedMu.Lock()
|
||||
defer t.OnFinishedMu.Unlock()
|
||||
|
||||
if t.OnFinishedFuncs == nil {
|
||||
onCompTask := GlobalTask(t.name + " > OnFinished > " + about)
|
||||
go t.runAllOnFinished(onCompTask)
|
||||
}
|
||||
idx := len(t.OnFinishedFuncs)
|
||||
wrapped := func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Error().
|
||||
Str("name", t.name).
|
||||
Interface("err", err).
|
||||
Msg("panic in " + about)
|
||||
}
|
||||
}()
|
||||
fn()
|
||||
logger.Trace().Str("name", t.name).Msgf("OnFinished[%d] done: %s", idx, about)
|
||||
}
|
||||
t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped)
|
||||
t.onCancel(about, fn, true)
|
||||
}
|
||||
|
||||
// OnCancel calls fn when the task is canceled.
|
||||
//
|
||||
// It cannot be called after Finish or Wait is called.
|
||||
// It should not be called after Finish is called.
|
||||
func (t *Task) OnCancel(about string, fn func()) {
|
||||
onCompTask := GlobalTask(t.name + " > OnFinished")
|
||||
t.onCancel(about, fn, false)
|
||||
}
|
||||
|
||||
func (t *Task) onCancel(about string, fn func(), waitSubTasks bool) {
|
||||
t.onFinished.Add(1)
|
||||
go func() {
|
||||
<-t.ctx.Done()
|
||||
fn()
|
||||
onCompTask.Finish("done")
|
||||
t.trace("onCancel done: " + about)
|
||||
if waitSubTasks {
|
||||
t.children.Wait()
|
||||
}
|
||||
t.invokeWithRecover(fn, about)
|
||||
t.onFinished.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
// Finish marks the task as finished and cancel its context.
|
||||
//
|
||||
// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished
|
||||
// of the task to finish.
|
||||
//
|
||||
// Note that it will also cancel all subtasks.
|
||||
// Finish cancel all subtasks and wait for them to finish,
|
||||
// then marks the task as finished, with the given reason (if any).
|
||||
func (t *Task) Finish(reason any) {
|
||||
var format string
|
||||
switch reason.(type) {
|
||||
case error:
|
||||
format = "%w"
|
||||
case string, fmt.Stringer:
|
||||
format = "%s"
|
||||
select {
|
||||
case <-t.finished:
|
||||
return
|
||||
default:
|
||||
format = "%v"
|
||||
t.once.Do(func() {
|
||||
t.finish(reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Task) finish(reason any) {
|
||||
t.cancel(fmtCause(reason))
|
||||
t.children.Wait()
|
||||
t.onFinished.Wait()
|
||||
if t.finished != nil {
|
||||
close(t.finished)
|
||||
}
|
||||
logger.Trace().Msg("task " + t.name + " finished")
|
||||
}
|
||||
|
||||
func fmtCause(cause any) error {
|
||||
switch cause := cause.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case error:
|
||||
return cause
|
||||
case string:
|
||||
return errors.New(cause)
|
||||
default:
|
||||
return fmt.Errorf("%v", cause)
|
||||
}
|
||||
t.finishOnce.Do(func() {
|
||||
t.cancel(fmt.Errorf("%w: %s, reason: "+format, ErrTaskCanceled, t.name, reason))
|
||||
})
|
||||
t.Wait()
|
||||
}
|
||||
|
||||
// Subtask returns a new subtask with the given name, derived from the parent's context.
|
||||
//
|
||||
// If the parent's context is already canceled, the returned subtask will be canceled immediately.
|
||||
//
|
||||
// This should not be called after Finish, Wait, or WaitSubTasks is called.
|
||||
func (t *Task) Subtask(name string) *Task {
|
||||
ctx, cancel := context.WithCancelCause(t.ctx)
|
||||
return t.newSubTask(ctx, cancel, name)
|
||||
}
|
||||
// This should not be called after Finish is called.
|
||||
func (t *Task) Subtask(name string, needFinish ...bool) *Task {
|
||||
nf := len(needFinish) == 0 || needFinish[0]
|
||||
|
||||
func (t *Task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *Task {
|
||||
parent := t
|
||||
if common.IsTrace {
|
||||
name = parent.name + " > " + name
|
||||
}
|
||||
subtask := &Task{
|
||||
ctx, cancel := context.WithCancelCause(t.ctx)
|
||||
child := &Task{
|
||||
finished: make(chan struct{}, 1),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
name: name,
|
||||
parent: parent,
|
||||
subtasks: F.NewSet[*Task](),
|
||||
}
|
||||
parent.subTasksWg.Add(1)
|
||||
parent.subtasks.Add(subtask)
|
||||
if common.IsTrace {
|
||||
subtask.trace("started")
|
||||
if t != root {
|
||||
child.name = t.name + "." + name
|
||||
allTasks.Add(child)
|
||||
} else {
|
||||
child.name = name
|
||||
}
|
||||
|
||||
allTasksWg.Add(1)
|
||||
t.children.Add(1)
|
||||
|
||||
if !nf {
|
||||
go func() {
|
||||
subtask.Wait()
|
||||
subtask.trace("finished: " + subtask.FinishCause().Error())
|
||||
<-child.ctx.Done()
|
||||
child.Finish(nil)
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
subtask.Wait()
|
||||
parent.subtasks.Remove(subtask)
|
||||
parent.subTasksWg.Done()
|
||||
<-child.finished
|
||||
allTasksWg.Done()
|
||||
t.children.Done()
|
||||
allTasks.Remove(child)
|
||||
}()
|
||||
return subtask
|
||||
|
||||
logger.Trace().Msg("task " + child.name + " started")
|
||||
return child
|
||||
}
|
||||
|
||||
// Wait waits for all subtasks, itself, OnFinished and OnSubtasksFinished to finish.
|
||||
//
|
||||
// It must be called only after Finish is called.
|
||||
func (t *Task) Wait() {
|
||||
<-t.ctx.Done()
|
||||
t.WaitSubTasks()
|
||||
t.onFinishedWg.Wait()
|
||||
// Name returns the name of the task without parent names.
|
||||
func (t *Task) Name() string {
|
||||
parts := strutils.SplitRune(t.name, '.')
|
||||
return parts[len(parts)-1]
|
||||
}
|
||||
|
||||
// WaitSubTasks waits for all subtasks of the task to finish.
|
||||
//
|
||||
// No more subtasks can be added after this call.
|
||||
//
|
||||
// It can be called before Finish is called.
|
||||
func (t *Task) WaitSubTasks() {
|
||||
t.subTasksWg.Wait()
|
||||
// String returns the full name of the task.
|
||||
func (t *Task) String() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
// tree returns a string representation of the task tree, with the given
|
||||
// prefix prepended to each line. The prefix is used to indent the tree,
|
||||
// and should be a string of spaces or a similar separator.
|
||||
//
|
||||
// The resulting string is suitable for printing to the console, and can be
|
||||
// used to debug the task tree.
|
||||
//
|
||||
// The tree is traversed in a depth-first manner, with each task's name and
|
||||
// line number (if available) printed on a separate line. The line number is
|
||||
// only printed if the task was created with a non-empty line argument.
|
||||
//
|
||||
// The returned string is not guaranteed to be stable, and may change between
|
||||
// runs of the program. It is intended for debugging purposes only.
|
||||
func (t *Task) tree(prefix ...string) string {
|
||||
var sb strings.Builder
|
||||
var pre string
|
||||
if len(prefix) > 0 {
|
||||
pre = prefix[0]
|
||||
sb.WriteString(pre + "- ")
|
||||
}
|
||||
sb.WriteString(t.Name() + "\n")
|
||||
t.subtasks.RangeAll(func(subtask *Task) {
|
||||
sb.WriteString(subtask.tree(pre + " "))
|
||||
})
|
||||
return sb.String()
|
||||
// MarshalText implements encoding.TextMarshaler.
|
||||
func (t *Task) MarshalText() ([]byte, error) {
|
||||
return []byte(t.name), nil
|
||||
}
|
||||
|
||||
// serialize returns a map[string]any representation of the task tree.
|
||||
//
|
||||
// The map contains the following keys:
|
||||
// - name: the name of the task
|
||||
// - subtasks: a slice of maps, each representing a subtask
|
||||
//
|
||||
// The subtask maps contain the same keys, recursively.
|
||||
//
|
||||
// The returned map is suitable for encoding to JSON, and can be used
|
||||
// to debug the task tree.
|
||||
//
|
||||
// The returned map is not guaranteed to be stable, and may change
|
||||
// between runs of the program. It is intended for debugging purposes
|
||||
// only.
|
||||
func (t *Task) serialize() map[string]any {
|
||||
m := make(map[string]any)
|
||||
parts := strings.Split(t.name, " > ")
|
||||
m["name"] = parts[len(parts)-1]
|
||||
if t.subtasks.Size() > 0 {
|
||||
m["subtasks"] = make([]map[string]any, 0, t.subtasks.Size())
|
||||
t.subtasks.RangeAll(func(subtask *Task) {
|
||||
m["subtasks"] = append(m["subtasks"].([]map[string]any), subtask.serialize())
|
||||
})
|
||||
}
|
||||
return m
|
||||
func (t *Task) invokeWithRecover(fn func(), caller string) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Error().
|
||||
Interface("err", err).
|
||||
Msg("panic in task " + t.name + "." + caller)
|
||||
if common.IsDebug {
|
||||
panic(string(debug.Stack()))
|
||||
}
|
||||
}
|
||||
}()
|
||||
fn()
|
||||
}
|
||||
|
||||
@@ -2,132 +2,112 @@ package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
const (
|
||||
rootTaskName = "root-task"
|
||||
subTaskName = "subtask"
|
||||
)
|
||||
|
||||
func TestTaskCreation(t *testing.T) {
|
||||
rootTask := GlobalTask(rootTaskName)
|
||||
subTask := rootTask.Subtask(subTaskName)
|
||||
|
||||
ExpectEqual(t, rootTaskName, rootTask.Name())
|
||||
ExpectEqual(t, subTaskName, subTask.Name())
|
||||
func testTask() *Task {
|
||||
return RootTask("test", false)
|
||||
}
|
||||
|
||||
func TestTaskCancellation(t *testing.T) {
|
||||
subTaskDone := make(chan struct{})
|
||||
func TestChildTaskCancellation(t *testing.T) {
|
||||
t.Cleanup(testCleanup)
|
||||
|
||||
rootTask := GlobalTask(rootTaskName)
|
||||
subTask := rootTask.Subtask(subTaskName)
|
||||
parent := testTask()
|
||||
child := parent.Subtask("")
|
||||
|
||||
go func() {
|
||||
subTask.Wait()
|
||||
close(subTaskDone)
|
||||
defer child.Finish(nil)
|
||||
for {
|
||||
select {
|
||||
case <-child.Context().Done():
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go rootTask.Finish(nil)
|
||||
parent.cancel(nil) // should also cancel child
|
||||
|
||||
select {
|
||||
case <-subTaskDone:
|
||||
err := subTask.Context().Err()
|
||||
ExpectError(t, context.Canceled, err)
|
||||
cause := context.Cause(subTask.Context())
|
||||
ExpectError(t, ErrTaskCanceled, cause)
|
||||
case <-time.After(1 * time.Second):
|
||||
case <-child.Finished():
|
||||
ExpectError(t, context.Canceled, child.Context().Err())
|
||||
default:
|
||||
t.Fatal("subTask context was not canceled as expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnComplete(t *testing.T) {
|
||||
rootTask := GlobalTask(rootTaskName)
|
||||
task := rootTask.Subtask(subTaskName)
|
||||
func TestTaskOnCancelOnFinished(t *testing.T) {
|
||||
t.Cleanup(testCleanup)
|
||||
task := testTask()
|
||||
|
||||
var value atomic.Int32
|
||||
task.OnFinished("set value", func() {
|
||||
value.Store(1234)
|
||||
var shouldTrueOnCancel bool
|
||||
var shouldTrueOnFinish bool
|
||||
|
||||
task.OnCancel("", func() {
|
||||
shouldTrueOnCancel = true
|
||||
})
|
||||
task.OnFinished("", func() {
|
||||
shouldTrueOnFinish = true
|
||||
})
|
||||
|
||||
ExpectFalse(t, shouldTrueOnFinish)
|
||||
task.Finish(nil)
|
||||
ExpectEqual(t, value.Load(), 1234)
|
||||
ExpectTrue(t, shouldTrueOnCancel)
|
||||
ExpectTrue(t, shouldTrueOnFinish)
|
||||
}
|
||||
|
||||
func TestGlobalContextWait(t *testing.T) {
|
||||
testResetGlobalTask()
|
||||
defer CancelGlobalContext()
|
||||
func TestCommonFlowWithGracefulShutdown(t *testing.T) {
|
||||
t.Cleanup(testCleanup)
|
||||
task := testTask()
|
||||
|
||||
rootTask := GlobalTask(rootTaskName)
|
||||
finished := false
|
||||
|
||||
finished1, finished2 := false, false
|
||||
|
||||
subTask1 := rootTask.Subtask(subTaskName)
|
||||
subTask2 := rootTask.Subtask(subTaskName)
|
||||
subTask1.OnFinished("", func() {
|
||||
finished1 = true
|
||||
})
|
||||
subTask2.OnFinished("", func() {
|
||||
finished2 = true
|
||||
task.OnFinished("", func() {
|
||||
finished = true
|
||||
})
|
||||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
subTask1.Finish(nil)
|
||||
defer task.Finish(nil)
|
||||
for {
|
||||
select {
|
||||
case <-task.Context().Done():
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
subTask2.Finish(nil)
|
||||
}()
|
||||
ExpectNoError(t, GracefulShutdown(1*time.Second))
|
||||
ExpectTrue(t, finished)
|
||||
|
||||
go func() {
|
||||
subTask1.Wait()
|
||||
subTask2.Wait()
|
||||
rootTask.Finish(nil)
|
||||
}()
|
||||
|
||||
_ = GlobalContextWait(1 * time.Second)
|
||||
ExpectTrue(t, finished1)
|
||||
ExpectTrue(t, finished2)
|
||||
ExpectError(t, context.Canceled, rootTask.Context().Err())
|
||||
ExpectError(t, ErrTaskCanceled, context.Cause(subTask1.Context()))
|
||||
ExpectError(t, ErrTaskCanceled, context.Cause(subTask2.Context()))
|
||||
<-root.finished
|
||||
ExpectError(t, context.Canceled, task.Context().Err())
|
||||
ExpectError(t, ErrProgramExiting, task.FinishCause())
|
||||
}
|
||||
|
||||
func TestTimeoutOnGlobalContextWait(t *testing.T) {
|
||||
testResetGlobalTask()
|
||||
func TestTimeoutOnGracefulShutdown(t *testing.T) {
|
||||
t.Cleanup(testCleanup)
|
||||
_ = testTask()
|
||||
|
||||
rootTask := GlobalTask(rootTaskName)
|
||||
rootTask.Subtask(subTaskName)
|
||||
|
||||
ExpectError(t, context.DeadlineExceeded, GlobalContextWait(200*time.Millisecond))
|
||||
ExpectError(t, context.DeadlineExceeded, GracefulShutdown(time.Millisecond))
|
||||
}
|
||||
|
||||
func TestGlobalContextCancellation(t *testing.T) {
|
||||
testResetGlobalTask()
|
||||
|
||||
taskDone := make(chan struct{})
|
||||
rootTask := GlobalTask(rootTaskName)
|
||||
|
||||
go func() {
|
||||
rootTask.Wait()
|
||||
close(taskDone)
|
||||
}()
|
||||
|
||||
CancelGlobalContext()
|
||||
|
||||
select {
|
||||
case <-taskDone:
|
||||
err := rootTask.Context().Err()
|
||||
ExpectError(t, context.Canceled, err)
|
||||
cause := context.Cause(rootTask.Context())
|
||||
ExpectError(t, ErrProgramExiting, cause)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("subTask context was not canceled as expected")
|
||||
func TestFinishMultipleCalls(t *testing.T) {
|
||||
t.Cleanup(testCleanup)
|
||||
task := testTask()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(5)
|
||||
for range 5 {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
task.Finish(nil)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
96
internal/task/utils.go
Normal file
96
internal/task/utils.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
var ErrProgramExiting = errors.New("program exiting")
|
||||
|
||||
var logger = logging.With().Str("module", "task").Logger()
|
||||
|
||||
var root = newRoot()
|
||||
var allTasks = F.NewSet[*Task]()
|
||||
var allTasksWg sync.WaitGroup
|
||||
|
||||
func testCleanup() {
|
||||
root = newRoot()
|
||||
allTasks.Clear()
|
||||
allTasksWg = sync.WaitGroup{}
|
||||
}
|
||||
|
||||
// RootTask returns a new Task with the given name, derived from the root context.
|
||||
func RootTask(name string, needFinish bool) *Task {
|
||||
return root.Subtask(name, needFinish)
|
||||
}
|
||||
|
||||
func newRoot() *Task {
|
||||
t := &Task{name: "root"}
|
||||
t.ctx, t.cancel = context.WithCancelCause(context.Background())
|
||||
return t
|
||||
}
|
||||
|
||||
func RootContext() context.Context {
|
||||
return root.ctx
|
||||
}
|
||||
|
||||
func RootContextCanceled() <-chan struct{} {
|
||||
return root.ctx.Done()
|
||||
}
|
||||
|
||||
func OnProgramExit(about string, fn func()) {
|
||||
root.OnFinished(about, fn)
|
||||
}
|
||||
|
||||
// GracefulShutdown waits for all tasks to finish, up to the given timeout.
|
||||
//
|
||||
// If the timeout is exceeded, it prints a list of all tasks that were
|
||||
// still running when the timeout was reached, and their current tree
|
||||
// of subtasks.
|
||||
func GracefulShutdown(timeout time.Duration) (err error) {
|
||||
root.cancel(ErrProgramExiting)
|
||||
|
||||
done := make(chan struct{})
|
||||
after := time.After(timeout)
|
||||
|
||||
go func() {
|
||||
allTasksWg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-after:
|
||||
b, err := json.Marshal(DebugTaskList())
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("failed to marshal tasks")
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
logger.Warn().RawJSON("tasks", b).Msgf("Timeout waiting for these %d tasks to finish", allTasks.Size())
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DebugTaskList returns list of all tasks.
|
||||
//
|
||||
// The returned string is suitable for printing to the console.
|
||||
func DebugTaskList() []string {
|
||||
l := make([]string, 0, allTasks.Size())
|
||||
|
||||
allTasks.RangeAll(func(t *Task) {
|
||||
l = append(l, t.name)
|
||||
})
|
||||
|
||||
slices.Sort(l)
|
||||
return l
|
||||
}
|
||||
Reference in New Issue
Block a user