diff --git a/internal/config/config.go b/internal/config/config.go index 235fd8e3..6771fc67 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -27,7 +27,7 @@ type Config struct { value *types.Config providers F.Map[string, *proxy.Provider] autocertProvider *autocert.Provider - task task.Task + task *task.Task } var ( @@ -88,7 +88,7 @@ func WatchChanges() { eventQueue.Start(cfgWatcher.Events(task.Context())) } -func OnConfigChange(flushTask task.Task, ev []events.Event) { +func OnConfigChange(flushTask *task.Task, ev []events.Event) { defer flushTask.Finish("config reload complete") // no matter how many events during the interval @@ -136,7 +136,7 @@ func GetAutoCertProvider() *autocert.Provider { return instance.autocertProvider } -func (cfg *Config) Task() task.Task { +func (cfg *Config) Task() *task.Task { return cfg.task } diff --git a/internal/docker/idlewatcher/waker.go b/internal/docker/idlewatcher/waker.go index 2aa8480e..93c5091e 100644 --- a/internal/docker/idlewatcher/waker.go +++ b/internal/docker/idlewatcher/waker.go @@ -39,7 +39,7 @@ const ( // TODO: support stream -func newWaker(providerSubTask task.Task, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) { +func newWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy, stream net.Stream) (Waker, E.Error) { hcCfg := entry.HealthCheckConfig() hcCfg.Timeout = idleWakerCheckTimeout @@ -72,16 +72,16 @@ func newWaker(providerSubTask task.Task, entry route.Entry, rp *gphttp.ReversePr } // lifetime should follow route provider. -func NewHTTPWaker(providerSubTask task.Task, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) { +func NewHTTPWaker(providerSubTask *task.Task, entry route.Entry, rp *gphttp.ReverseProxy) (Waker, E.Error) { return newWaker(providerSubTask, entry, rp, nil) } -func NewStreamWaker(providerSubTask task.Task, entry route.Entry, stream net.Stream) (Waker, E.Error) { +func NewStreamWaker(providerSubTask *task.Task, entry route.Entry, stream net.Stream) (Waker, E.Error) { return newWaker(providerSubTask, entry, nil, stream) } // Start implements health.HealthMonitor. -func (w *Watcher) Start(routeSubTask task.Task) E.Error { +func (w *Watcher) Start(routeSubTask *task.Task) E.Error { routeSubTask.Finish("ignored") w.task.OnCancel("stop route and cleanup", func() { routeSubTask.Parent().Finish(w.task.FinishCause()) diff --git a/internal/docker/idlewatcher/watcher.go b/internal/docker/idlewatcher/watcher.go index f8e9f2ae..17afcb80 100644 --- a/internal/docker/idlewatcher/watcher.go +++ b/internal/docker/idlewatcher/watcher.go @@ -32,7 +32,7 @@ type ( client D.Client stopByMethod StopCallback // send a docker command w.r.t. `stop_method` ticker *time.Ticker - task task.Task + task *task.Task } WakeDone <-chan error @@ -51,7 +51,7 @@ var ( const dockerReqTimeout = 3 * time.Second -func registerWatcher(providerSubtask task.Task, entry route.Entry, waker *waker) (*Watcher, error) { +func registerWatcher(providerSubtask *task.Task, entry route.Entry, waker *waker) (*Watcher, error) { cfg := entry.IdlewatcherConfig() if cfg.IdleTimeout == 0 { @@ -209,7 +209,7 @@ func (w *Watcher) resetIdleTimer() { w.ticker.Reset(w.IdleTimeout) } -func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) { +func (w *Watcher) getEventCh(dockerWatcher watcher.DockerWatcher) (eventTask *task.Task, eventCh <-chan events.Event, errCh <-chan E.Error) { eventTask = w.task.Subtask("docker event watcher") eventCh, errCh = dockerWatcher.EventsWithOptions(eventTask.Context(), watcher.DockerListOptions{ Filters: watcher.NewDockerFilter( diff --git a/internal/net/http/loadbalancer/loadbalancer.go b/internal/net/http/loadbalancer/loadbalancer.go index 89b06f6a..26192b77 100644 --- a/internal/net/http/loadbalancer/loadbalancer.go +++ b/internal/net/http/loadbalancer/loadbalancer.go @@ -27,7 +27,7 @@ type ( impl *Config - task task.Task + task *task.Task pool Pool poolMu sync.Mutex @@ -52,7 +52,7 @@ func New(cfg *Config) *LoadBalancer { } // Start implements task.TaskStarter. -func (lb *LoadBalancer) Start(routeSubtask task.Task) E.Error { +func (lb *LoadBalancer) Start(routeSubtask *task.Task) E.Error { lb.startTime = time.Now() lb.task = routeSubtask lb.task.OnFinished("loadbalancer cleanup", func() { diff --git a/internal/net/http/middleware/errorpage/error_page.go b/internal/net/http/middleware/errorpage/error_page.go index 7226f676..113a467d 100644 --- a/internal/net/http/middleware/errorpage/error_page.go +++ b/internal/net/http/middleware/errorpage/error_page.go @@ -73,7 +73,7 @@ func loadContent() { } } -func watchDir(task task.Task) { +func watchDir(task *task.Task) { eventCh, errCh := dirWatcher.Events(task.Context()) for { select { diff --git a/internal/net/http/server/server.go b/internal/net/http/server/server.go index 2597c46e..b38faf48 100644 --- a/internal/net/http/server/server.go +++ b/internal/net/http/server/server.go @@ -24,7 +24,7 @@ type Server struct { httpsStarted bool startTime time.Time - task task.Task + task *task.Task l zerolog.Logger } diff --git a/internal/notif/dispatcher.go b/internal/notif/dispatcher.go index 8bc4438c..42bd27be 100644 --- a/internal/notif/dispatcher.go +++ b/internal/notif/dispatcher.go @@ -13,7 +13,7 @@ import ( type ( Dispatcher struct { - task task.Task + task *task.Task logCh chan *LogMessage providers F.Set[Provider] } @@ -35,7 +35,7 @@ var ( const dispatchErr = "notification dispatch error" -func StartNotifDispatcher(parent task.Task) *Dispatcher { +func StartNotifDispatcher(parent *task.Task) *Dispatcher { dispatcher = &Dispatcher{ task: parent.Subtask("notification dispatcher"), logCh: make(chan *LogMessage), diff --git a/internal/route/http.go b/internal/route/http.go index ac903d13..a4fb33b6 100755 --- a/internal/route/http.go +++ b/internal/route/http.go @@ -31,7 +31,7 @@ type ( handler http.Handler rp *gphttp.ReverseProxy - task task.Task + task *task.Task l zerolog.Logger } @@ -74,8 +74,8 @@ func (r *HTTPRoute) String() string { return string(r.Alias) } -// Start implements task.TaskStarter. -func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error { +// Start implements*task.TaskStarter. +func (r *HTTPRoute) Start(providerSubtask *task.Task) E.Error { if entry.ShouldNotServe(r) { providerSubtask.Finish("should not serve") return nil @@ -148,7 +148,7 @@ func (r *HTTPRoute) Start(providerSubtask task.Task) E.Error { return nil } -// Finish implements task.TaskFinisher. +// Finish implements*task.TaskFinisher. func (r *HTTPRoute) Finish(reason any) { r.task.Finish(reason) } diff --git a/internal/route/provider/event_handler.go b/internal/route/provider/event_handler.go index 8af8b61e..653db118 100644 --- a/internal/route/provider/event_handler.go +++ b/internal/route/provider/event_handler.go @@ -28,7 +28,7 @@ func (p *Provider) newEventHandler() *EventHandler { } } -func (handler *EventHandler) Handle(parent task.Task, events []watcher.Event) { +func (handler *EventHandler) Handle(parent *task.Task, events []watcher.Event) { oldRoutes := handler.provider.routes newRoutes, err := handler.provider.loadRoutesImpl() if err != nil { @@ -97,7 +97,7 @@ func (handler *EventHandler) match(event watcher.Event, route *route.Route) bool return false } -func (handler *EventHandler) Add(parent task.Task, route *route.Route) { +func (handler *EventHandler) Add(parent *task.Task, route *route.Route) { err := handler.provider.startRoute(parent, route) if err != nil { handler.errs.Add(err.Subject("add")) @@ -112,7 +112,7 @@ func (handler *EventHandler) Remove(route *route.Route) { handler.removed.Adds(route.Entry.Alias) } -func (handler *EventHandler) Update(parent task.Task, oldRoute *route.Route, newRoute *route.Route) { +func (handler *EventHandler) Update(parent *task.Task, oldRoute *route.Route, newRoute *route.Route) { oldRoute.Finish("route update") err := handler.provider.startRoute(parent, newRoute) if err != nil { diff --git a/internal/route/provider/provider.go b/internal/route/provider/provider.go index 0c87caf1..cb09fb2c 100644 --- a/internal/route/provider/provider.go +++ b/internal/route/provider/provider.go @@ -100,7 +100,7 @@ func (p *Provider) MarshalText() ([]byte, error) { return []byte(p.String()), nil } -func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error { +func (p *Provider) startRoute(parent *task.Task, r *R.Route) E.Error { subtask := parent.Subtask(p.String() + "/" + r.Entry.Alias) err := r.Start(subtask) if err != nil { @@ -115,8 +115,8 @@ func (p *Provider) startRoute(parent task.Task, r *R.Route) E.Error { return nil } -// Start implements task.TaskStarter. -func (p *Provider) Start(configSubtask task.Task) E.Error { +// Start implements*task.TaskStarter. +func (p *Provider) Start(configSubtask *task.Task) E.Error { // routes and event queue will stop on parent cancel providerTask := configSubtask @@ -128,7 +128,7 @@ func (p *Provider) Start(configSubtask task.Task) E.Error { eventQueue := events.NewEventQueue( providerTask, providerEventFlushInterval, - func(flushTask task.Task, events []events.Event) { + func(flushTask *task.Task, events []events.Event) { handler := p.newEventHandler() // routes' lifetime should follow the provider's lifetime handler.Handle(providerTask, events) diff --git a/internal/route/stream.go b/internal/route/stream.go index 9d94cad2..6b7b88fe 100755 --- a/internal/route/stream.go +++ b/internal/route/stream.go @@ -24,7 +24,7 @@ type StreamRoute struct { HealthMon health.HealthMonitor `json:"health"` - task task.Task + task *task.Task l zerolog.Logger } @@ -47,8 +47,8 @@ func (r *StreamRoute) String() string { return fmt.Sprintf("stream %s", r.Alias) } -// Start implements task.TaskStarter. -func (r *StreamRoute) Start(providerSubtask task.Task) E.Error { +// Start implements*task.TaskStarter. +func (r *StreamRoute) Start(providerSubtask *task.Task) E.Error { if entry.ShouldNotServe(r) { providerSubtask.Finish("should not serve") return nil diff --git a/internal/task/task.go b/internal/task/task.go index 6584a2f5..0d0cedc7 100644 --- a/internal/task/task.go +++ b/internal/task/task.go @@ -16,11 +16,11 @@ import ( var globalTask = createGlobalTask() -func createGlobalTask() (t *task) { - t = new(task) +func createGlobalTask() (t *Task) { + t = new(Task) t.name = "root" t.ctx, t.cancel = context.WithCancelCause(context.Background()) - t.subtasks = F.NewSet[*task]() + t.subtasks = F.NewSet[*Task]() return } @@ -29,52 +29,6 @@ func testResetGlobalTask() { } type ( - // Task controls objects' lifetime. - // - // Objects that uses a task should implement the TaskStarter and the TaskFinisher interface. - // - // When passing a Task object to another function, - // it must be a sub-task of the current task, - // in name of "`currentTaskName`Subtask" - // - // Use Task.Finish to stop all subtasks of the task. - Task interface { - TaskFinisher - fmt.Stringer - // Name returns the name of the task. - Name() string - // Context returns the context associated with the task. This context is - // canceled when Finish of the task is called, or parent task is canceled. - Context() context.Context - // FinishCause returns the reason / error that caused the task to be finished. - FinishCause() error - // Parent returns the parent task of the current task. - Parent() Task - // Subtask returns a new subtask with the given name, derived from the parent's context. - // - // If the parent's context is already canceled, the returned subtask will be canceled immediately. - // - // This should not be called after Finish, Wait, or WaitSubTasks is called. - Subtask(name string) Task - // OnFinished calls fn when all subtasks are finished. - // - // It cannot be called after Finish or Wait is called. - OnFinished(about string, fn func()) - // OnCancel calls fn when the task is canceled. - // - // It cannot be called after Finish or Wait is called. - OnCancel(about string, fn func()) - // Wait waits for all subtasks, itself, OnFinished and OnSubtasksFinished to finish. - // - // It must be called only after Finish is called. - Wait() - // WaitSubTasks waits for all subtasks of the task to finish. - // - // No more subtasks can be added after this call. - // - // It can be called before Finish is called. - WaitSubTasks() - } TaskStarter interface { // Start starts the object that implements TaskStarter, // and returns an error if it fails to start. @@ -82,7 +36,7 @@ type ( // The task passed must be a subtask of the caller task. // // callerSubtask.Finish must be called when start fails or the object is finished. - Start(callerSubtask Task) E.Error + Start(callerSubtask *Task) E.Error } TaskFinisher interface { // Finish marks the task as finished and cancel its context. @@ -93,12 +47,21 @@ type ( // Note that it will also cancel all subtasks. Finish(reason any) } - task struct { + // Task controls objects' lifetime. + // + // Objects that uses a Task should implement the TaskStarter and the TaskFinisher interface. + // + // When passing a Task object to another function, + // it must be a sub-Task of the current Task, + // in name of "`currentTaskName`Subtask" + // + // Use Task.Finish to stop all subtasks of the Task. + Task struct { ctx context.Context cancel context.CancelCauseFunc - parent *task - subtasks F.Set[*task] + parent *Task + subtasks F.Set[*Task] subTasksWg sync.WaitGroup name string @@ -119,7 +82,7 @@ var ( ) // GlobalTask returns a new Task with the given name, derived from the global context. -func GlobalTask(format string, args ...any) Task { +func GlobalTask(format string, args ...any) *Task { if len(args) > 0 { format = fmt.Sprintf(format, args...) } @@ -168,11 +131,12 @@ func GlobalContextWait(timeout time.Duration) (err error) { } } -func (t *task) trace(msg string) { +func (t *Task) trace(msg string) { logger.Trace().Str("name", t.name).Msg(msg) } -func (t *task) Name() string { +// Name returns the name of the task. +func (t *Task) Name() string { if !common.IsTrace { return t.name } @@ -180,15 +144,19 @@ func (t *task) Name() string { return parts[len(parts)-1] } -func (t *task) String() string { +// String returns the name of the task. +func (t *Task) String() string { return t.name } -func (t *task) Context() context.Context { +// Context returns the context associated with the task. This context is +// canceled when Finish of the task is called, or parent task is canceled. +func (t *Task) Context() context.Context { return t.ctx } -func (t *task) FinishCause() error { +// FinishCause returns the reason / error that caused the task to be finished. +func (t *Task) FinishCause() error { cause := context.Cause(t.ctx) if cause == nil { return t.ctx.Err() @@ -196,11 +164,12 @@ func (t *task) FinishCause() error { return cause } -func (t *task) Parent() Task { +// Parent returns the parent task of the current task. +func (t *Task) Parent() *Task { return t.parent } -func (t *task) runAllOnFinished(onCompTask Task) { +func (t *Task) runAllOnFinished(onCompTask *Task) { <-t.ctx.Done() t.WaitSubTasks() for _, OnFinishedFunc := range t.OnFinishedFuncs { @@ -210,7 +179,10 @@ func (t *task) runAllOnFinished(onCompTask Task) { onCompTask.Finish(fmt.Errorf("%w: %s, reason: %s", ErrTaskCanceled, t.name, "done")) } -func (t *task) OnFinished(about string, fn func()) { +// OnFinished calls fn when all subtasks are finished. +// +// It cannot be called after Finish or Wait is called. +func (t *Task) OnFinished(about string, fn func()) { if t.parent == globalTask { t.OnCancel(about, fn) return @@ -239,7 +211,10 @@ func (t *task) OnFinished(about string, fn func()) { t.OnFinishedFuncs = append(t.OnFinishedFuncs, wrapped) } -func (t *task) OnCancel(about string, fn func()) { +// OnCancel calls fn when the task is canceled. +// +// It cannot be called after Finish or Wait is called. +func (t *Task) OnCancel(about string, fn func()) { onCompTask := GlobalTask(t.name + " > OnFinished") go func() { <-t.ctx.Done() @@ -249,7 +224,13 @@ func (t *task) OnCancel(about string, fn func()) { }() } -func (t *task) Finish(reason any) { +// Finish marks the task as finished and cancel its context. +// +// Then call Wait to wait for all subtasks, OnFinished and OnSubtasksFinished +// of the task to finish. +// +// Note that it will also cancel all subtasks. +func (t *Task) Finish(reason any) { var format string switch reason.(type) { case error: @@ -265,22 +246,27 @@ func (t *task) Finish(reason any) { t.Wait() } -func (t *task) Subtask(name string) Task { +// Subtask returns a new subtask with the given name, derived from the parent's context. +// +// If the parent's context is already canceled, the returned subtask will be canceled immediately. +// +// This should not be called after Finish, Wait, or WaitSubTasks is called. +func (t *Task) Subtask(name string) *Task { ctx, cancel := context.WithCancelCause(t.ctx) return t.newSubTask(ctx, cancel, name) } -func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *task { +func (t *Task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, name string) *Task { parent := t if common.IsTrace { name = parent.name + " > " + name } - subtask := &task{ + subtask := &Task{ ctx: ctx, cancel: cancel, name: name, parent: parent, - subtasks: F.NewSet[*task](), + subtasks: F.NewSet[*Task](), } parent.subTasksWg.Add(1) parent.subtasks.Add(subtask) @@ -299,13 +285,21 @@ func (t *task) newSubTask(ctx context.Context, cancel context.CancelCauseFunc, n return subtask } -func (t *task) Wait() { +// Wait waits for all subtasks, itself, OnFinished and OnSubtasksFinished to finish. +// +// It must be called only after Finish is called. +func (t *Task) Wait() { <-t.ctx.Done() t.WaitSubTasks() t.onFinishedWg.Wait() } -func (t *task) WaitSubTasks() { +// WaitSubTasks waits for all subtasks of the task to finish. +// +// No more subtasks can be added after this call. +// +// It can be called before Finish is called. +func (t *Task) WaitSubTasks() { t.subTasksWg.Wait() } @@ -322,7 +316,7 @@ func (t *task) WaitSubTasks() { // // The returned string is not guaranteed to be stable, and may change between // runs of the program. It is intended for debugging purposes only. -func (t *task) tree(prefix ...string) string { +func (t *Task) tree(prefix ...string) string { var sb strings.Builder var pre string if len(prefix) > 0 { @@ -330,7 +324,7 @@ func (t *task) tree(prefix ...string) string { sb.WriteString(pre + "- ") } sb.WriteString(t.Name() + "\n") - t.subtasks.RangeAll(func(subtask *task) { + t.subtasks.RangeAll(func(subtask *Task) { sb.WriteString(subtask.tree(pre + " ")) }) return sb.String() @@ -350,13 +344,13 @@ func (t *task) tree(prefix ...string) string { // The returned map is not guaranteed to be stable, and may change // between runs of the program. It is intended for debugging purposes // only. -func (t *task) serialize() map[string]any { +func (t *Task) serialize() map[string]any { m := make(map[string]any) parts := strings.Split(t.name, " > ") m["name"] = parts[len(parts)-1] if t.subtasks.Size() > 0 { m["subtasks"] = make([]map[string]any, 0, t.subtasks.Size()) - t.subtasks.RangeAll(func(subtask *task) { + t.subtasks.RangeAll(func(subtask *Task) { m["subtasks"] = append(m["subtasks"].([]map[string]any), subtask.serialize()) }) } diff --git a/internal/watcher/directory_watcher.go b/internal/watcher/directory_watcher.go index d910650b..4fd06484 100644 --- a/internal/watcher/directory_watcher.go +++ b/internal/watcher/directory_watcher.go @@ -26,7 +26,7 @@ type DirWatcher struct { eventCh chan Event errCh chan E.Error - task task.Task + task *task.Task } // NewDirectoryWatcher returns a DirWatcher instance. @@ -37,7 +37,7 @@ type DirWatcher struct { // // Note that the returned DirWatcher is not ready to use until the goroutine // started by NewDirectoryWatcher has finished. -func NewDirectoryWatcher(callerSubtask task.Task, dirPath string) *DirWatcher { +func NewDirectoryWatcher(callerSubtask *task.Task, dirPath string) *DirWatcher { //! subdirectories are not watched w, err := fsnotify.NewWatcher() if err != nil { diff --git a/internal/watcher/events/event_queue.go b/internal/watcher/events/event_queue.go index f899d12b..8ed7be3f 100644 --- a/internal/watcher/events/event_queue.go +++ b/internal/watcher/events/event_queue.go @@ -10,14 +10,14 @@ import ( type ( EventQueue struct { - task task.Task + task *task.Task queue []Event ticker *time.Ticker flushInterval time.Duration onFlush OnFlushFunc onError OnErrorFunc } - OnFlushFunc = func(flushTask task.Task, events []Event) + OnFlushFunc = func(flushTask *task.Task, events []Event) OnErrorFunc = func(err E.Error) ) @@ -38,7 +38,7 @@ const eventQueueCapacity = 10 // but the onFlush function can return earlier (e.g. run in another goroutine). // // If task is canceled before the flushInterval is reached, the events in queue will be discarded. -func NewEventQueue(parent task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue { +func NewEventQueue(parent *task.Task, flushInterval time.Duration, onFlush OnFlushFunc, onError OnErrorFunc) *EventQueue { return &EventQueue{ task: parent.Subtask("event queue"), queue: make([]Event, 0, eventQueueCapacity), @@ -53,7 +53,7 @@ func (e *EventQueue) Start(eventCh <-chan Event, errCh <-chan E.Error) { if common.IsProduction { origOnFlush := e.onFlush // recover panic in onFlush when in production mode - e.onFlush = func(flushTask task.Task, events []Event) { + e.onFlush = func(flushTask *task.Task, events []Event) { defer func() { if err := recover(); err != nil { e.onError(E.New("recovered panic in onFlush"). diff --git a/internal/watcher/health/monitor/monitor.go b/internal/watcher/health/monitor/monitor.go index ef4250f3..b2d4d1fa 100644 --- a/internal/watcher/health/monitor/monitor.go +++ b/internal/watcher/health/monitor/monitor.go @@ -36,7 +36,7 @@ type ( metric *metrics.Gauge - task task.Task + task *task.Task } ) @@ -61,7 +61,7 @@ func (mon *monitor) ContextWithTimeout(cause string) (ctx context.Context, cance } // Start implements task.TaskStarter. -func (mon *monitor) Start(routeSubtask task.Task) E.Error { +func (mon *monitor) Start(routeSubtask *task.Task) E.Error { mon.service = routeSubtask.Parent().Name() mon.task = routeSubtask