mirror of
https://github.com/yusing/godoxy.git
synced 2026-04-23 17:28:53 +02:00
added round_robin, least_conn and ip_hash load balance support, small refactoring
This commit is contained in:
33
internal/net/http/loadbalancer/ip_hash.go
Normal file
33
internal/net/http/loadbalancer/ip_hash.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ipHash struct{ *LoadBalancer }
|
||||
|
||||
func (lb *LoadBalancer) newIPHash() impl { return &ipHash{lb} }
|
||||
func (ipHash) OnAddServer(srv *Server) {}
|
||||
func (ipHash) OnRemoveServer(srv *Server) {}
|
||||
|
||||
func (impl ipHash) ServeHTTP(_ servers, rw http.ResponseWriter, r *http.Request) {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
logger.Errorf("invalid remote address %s: %s", r.RemoteAddr, err)
|
||||
return
|
||||
}
|
||||
idx := hashIP(ip) % uint32(len(impl.pool))
|
||||
if !impl.pool[idx].available.Load() {
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
}
|
||||
impl.pool[idx].handler.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
func hashIP(ip string) uint32 {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(ip))
|
||||
return h.Sum32()
|
||||
}
|
||||
53
internal/net/http/loadbalancer/least_conn.go
Normal file
53
internal/net/http/loadbalancer/least_conn.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
type leastConn struct {
|
||||
*LoadBalancer
|
||||
nConn F.Map[*Server, *atomic.Int64]
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) newLeastConn() impl {
|
||||
return &leastConn{
|
||||
LoadBalancer: lb,
|
||||
nConn: F.NewMapOf[*Server, *atomic.Int64](),
|
||||
}
|
||||
}
|
||||
|
||||
func (impl *leastConn) OnAddServer(srv *Server) {
|
||||
impl.nConn.Store(srv, new(atomic.Int64))
|
||||
}
|
||||
|
||||
func (impl *leastConn) OnRemoveServer(srv *Server) {
|
||||
impl.nConn.Delete(srv)
|
||||
}
|
||||
|
||||
func (impl *leastConn) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
|
||||
srv := srvs[0]
|
||||
minConn, ok := impl.nConn.Load(srv)
|
||||
if !ok {
|
||||
logger.Errorf("[BUG] server %s not found", srv.Name)
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
for i := 1; i < len(srvs); i++ {
|
||||
nConn, ok := impl.nConn.Load(srvs[i])
|
||||
if !ok {
|
||||
logger.Errorf("[BUG] server %s not found", srv.Name)
|
||||
http.Error(rw, "Internal error", http.StatusInternalServerError)
|
||||
}
|
||||
if nConn.Load() < minConn.Load() {
|
||||
minConn = nConn
|
||||
srv = srvs[i]
|
||||
}
|
||||
}
|
||||
|
||||
minConn.Add(1)
|
||||
srv.handler.ServeHTTP(rw, r)
|
||||
minConn.Add(-1)
|
||||
}
|
||||
241
internal/net/http/loadbalancer/loadbalancer.go
Normal file
241
internal/net/http/loadbalancer/loadbalancer.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-acme/lego/v4/log"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
// TODO: stats of each server
|
||||
// TODO: support weighted mode
|
||||
type (
|
||||
impl interface {
|
||||
ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request)
|
||||
OnAddServer(srv *Server)
|
||||
OnRemoveServer(srv *Server)
|
||||
}
|
||||
Config struct {
|
||||
Link string
|
||||
Mode Mode
|
||||
Weight weightType
|
||||
}
|
||||
LoadBalancer struct {
|
||||
impl
|
||||
Config
|
||||
|
||||
pool servers
|
||||
poolMu sync.RWMutex
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
|
||||
sumWeight weightType
|
||||
}
|
||||
|
||||
weightType uint16
|
||||
)
|
||||
|
||||
const maxWeight weightType = 100
|
||||
|
||||
func New(cfg Config) *LoadBalancer {
|
||||
lb := &LoadBalancer{Config: cfg, pool: servers{}}
|
||||
mode := cfg.Mode
|
||||
if !cfg.Mode.ValidateUpdate() {
|
||||
logger.Warnf("%s: invalid loadbalancer mode: %s, fallback to %s", cfg.Link, mode, cfg.Mode)
|
||||
}
|
||||
switch mode {
|
||||
case RoundRobin:
|
||||
lb.impl = lb.newRoundRobin()
|
||||
case LeastConn:
|
||||
lb.impl = lb.newLeastConn()
|
||||
case IPHash:
|
||||
lb.impl = lb.newIPHash()
|
||||
default: // should happen in test only
|
||||
lb.impl = lb.newRoundRobin()
|
||||
}
|
||||
return lb
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) AddServer(srv *Server) {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
lb.pool = append(lb.pool, srv)
|
||||
lb.sumWeight += srv.Weight
|
||||
|
||||
lb.impl.OnAddServer(srv)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) RemoveServer(srv *Server) {
|
||||
lb.poolMu.RLock()
|
||||
defer lb.poolMu.RUnlock()
|
||||
|
||||
lb.impl.OnRemoveServer(srv)
|
||||
|
||||
for i, s := range lb.pool {
|
||||
if s == srv {
|
||||
lb.pool = append(lb.pool[:i], lb.pool[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if lb.IsEmpty() {
|
||||
lb.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) IsEmpty() bool {
|
||||
return len(lb.pool) == 0
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Rebalance() {
|
||||
if lb.sumWeight == maxWeight {
|
||||
return
|
||||
}
|
||||
if lb.sumWeight == 0 { // distribute evenly
|
||||
weightEach := maxWeight / weightType(len(lb.pool))
|
||||
remainer := maxWeight % weightType(len(lb.pool))
|
||||
for _, s := range lb.pool {
|
||||
s.Weight = weightEach
|
||||
lb.sumWeight += weightEach
|
||||
if remainer > 0 {
|
||||
s.Weight++
|
||||
}
|
||||
remainer--
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// scale evenly
|
||||
scaleFactor := float64(maxWeight) / float64(lb.sumWeight)
|
||||
lb.sumWeight = 0
|
||||
|
||||
for _, s := range lb.pool {
|
||||
s.Weight = weightType(float64(s.Weight) * scaleFactor)
|
||||
lb.sumWeight += s.Weight
|
||||
}
|
||||
|
||||
delta := maxWeight - lb.sumWeight
|
||||
if delta == 0 {
|
||||
return
|
||||
}
|
||||
for _, s := range lb.pool {
|
||||
if delta == 0 {
|
||||
break
|
||||
}
|
||||
if delta > 0 {
|
||||
s.Weight++
|
||||
lb.sumWeight++
|
||||
delta--
|
||||
} else {
|
||||
s.Weight--
|
||||
lb.sumWeight--
|
||||
delta++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
srvs := lb.availServers()
|
||||
if len(srvs) == 0 {
|
||||
http.Error(rw, "Service unavailable", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
lb.impl.ServeHTTP(srvs, rw, r)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Start() {
|
||||
if lb.IsEmpty() {
|
||||
return
|
||||
}
|
||||
|
||||
if lb.sumWeight != 0 && lb.sumWeight != maxWeight {
|
||||
msg := E.NewBuilder("loadbalancer %s total weight %d != %d", lb.Link, lb.sumWeight, maxWeight)
|
||||
for _, s := range lb.pool {
|
||||
msg.Addf("%s: %d", s.Name, s.Weight)
|
||||
}
|
||||
lb.Rebalance()
|
||||
inner := E.NewBuilder("After rebalancing")
|
||||
for _, s := range lb.pool {
|
||||
inner.Addf("%s: %d", s.Name, s.Weight)
|
||||
}
|
||||
msg.Addf("%s", inner)
|
||||
logger.Warn(msg)
|
||||
}
|
||||
|
||||
if lb.sumWeight != 0 {
|
||||
log.Warnf("Weighted mode not supported yet")
|
||||
}
|
||||
|
||||
switch lb.Mode {
|
||||
case RoundRobin:
|
||||
lb.impl = lb.newRoundRobin()
|
||||
case LeastConn:
|
||||
lb.impl = lb.newLeastConn()
|
||||
case IPHash:
|
||||
lb.impl = lb.newIPHash()
|
||||
}
|
||||
|
||||
lb.done = make(chan struct{}, 1)
|
||||
lb.ctx, lb.cancel = context.WithCancel(context.Background())
|
||||
|
||||
updateAll := func() {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(lb.pool))
|
||||
for _, s := range lb.pool {
|
||||
go func(s *Server) {
|
||||
defer wg.Done()
|
||||
s.checkUpdateAvail(lb.ctx)
|
||||
}(s)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer lb.cancel()
|
||||
defer close(lb.done)
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
updateAll()
|
||||
for {
|
||||
select {
|
||||
case <-lb.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
lb.poolMu.RLock()
|
||||
updateAll()
|
||||
lb.poolMu.RUnlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Stop() {
|
||||
if lb.impl == nil {
|
||||
return
|
||||
}
|
||||
|
||||
lb.cancel()
|
||||
|
||||
<-lb.done
|
||||
lb.pool = nil
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) availServers() servers {
|
||||
lb.poolMu.Lock()
|
||||
defer lb.poolMu.Unlock()
|
||||
|
||||
avail := servers{}
|
||||
for _, s := range lb.pool {
|
||||
if s.available.Load() {
|
||||
avail = append(avail, s)
|
||||
}
|
||||
}
|
||||
return avail
|
||||
}
|
||||
43
internal/net/http/loadbalancer/loadbalancer_test.go
Normal file
43
internal/net/http/loadbalancer/loadbalancer_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
func TestRebalance(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("zero", func(t *testing.T) {
|
||||
lb := New(Config{})
|
||||
for range 10 {
|
||||
lb.AddServer(&Server{})
|
||||
}
|
||||
lb.Rebalance()
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("less", func(t *testing.T) {
|
||||
lb := New(Config{})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.Rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
t.Run("more", func(t *testing.T) {
|
||||
lb := New(Config{})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .4)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .3)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .2)})
|
||||
lb.AddServer(&Server{Weight: weightType(float64(maxWeight) * .1)})
|
||||
lb.Rebalance()
|
||||
// t.Logf("%s", U.Must(json.MarshalIndent(lb.pool, "", " ")))
|
||||
ExpectEqual(t, lb.sumWeight, maxWeight)
|
||||
})
|
||||
}
|
||||
5
internal/net/http/loadbalancer/logger.go
Normal file
5
internal/net/http/loadbalancer/logger.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package loadbalancer
|
||||
|
||||
import "github.com/sirupsen/logrus"
|
||||
|
||||
var logger = logrus.WithField("module", "load_balancer")
|
||||
29
internal/net/http/loadbalancer/mode.go
Normal file
29
internal/net/http/loadbalancer/mode.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
RoundRobin Mode = "roundrobin"
|
||||
LeastConn Mode = "leastconn"
|
||||
IPHash Mode = "iphash"
|
||||
)
|
||||
|
||||
func (mode *Mode) ValidateUpdate() bool {
|
||||
switch U.ToLowerNoSnake(string(*mode)) {
|
||||
case "", string(RoundRobin):
|
||||
*mode = RoundRobin
|
||||
return true
|
||||
case string(LeastConn):
|
||||
*mode = LeastConn
|
||||
return true
|
||||
case string(IPHash):
|
||||
*mode = IPHash
|
||||
return true
|
||||
}
|
||||
*mode = RoundRobin
|
||||
return false
|
||||
}
|
||||
22
internal/net/http/loadbalancer/round_robin.go
Normal file
22
internal/net/http/loadbalancer/round_robin.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type roundRobin struct {
|
||||
index atomic.Uint32
|
||||
}
|
||||
|
||||
func (*LoadBalancer) newRoundRobin() impl { return &roundRobin{} }
|
||||
func (lb *roundRobin) OnAddServer(srv *Server) {}
|
||||
func (lb *roundRobin) OnRemoveServer(srv *Server) {}
|
||||
|
||||
func (lb *roundRobin) ServeHTTP(srvs servers, rw http.ResponseWriter, r *http.Request) {
|
||||
index := lb.index.Add(1)
|
||||
srvs[index%uint32(len(srvs))].handler.ServeHTTP(rw, r)
|
||||
if lb.index.Load() >= 2*uint32(len(srvs)) {
|
||||
lb.index.Store(0)
|
||||
}
|
||||
}
|
||||
67
internal/net/http/loadbalancer/server.go
Normal file
67
internal/net/http/loadbalancer/server.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
type (
|
||||
Server struct {
|
||||
Name string
|
||||
URL types.URL
|
||||
Weight weightType
|
||||
handler http.Handler
|
||||
|
||||
pinger *http.Client
|
||||
available atomic.Bool
|
||||
}
|
||||
servers []*Server
|
||||
)
|
||||
|
||||
func NewServer(name string, url types.URL, weight weightType, handler http.Handler) *Server {
|
||||
srv := &Server{
|
||||
Name: name,
|
||||
URL: url,
|
||||
Weight: weightType(weight),
|
||||
handler: handler,
|
||||
pinger: &http.Client{Timeout: 3 * time.Second},
|
||||
}
|
||||
srv.available.Store(true)
|
||||
return srv
|
||||
}
|
||||
|
||||
func (srv *Server) checkUpdateAvail(ctx context.Context) {
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodHead,
|
||||
srv.URL.String(),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Error("failed to create request: ", err)
|
||||
srv.available.Store(false)
|
||||
}
|
||||
|
||||
resp, err := srv.pinger.Do(req)
|
||||
if err == nil && resp.StatusCode != http.StatusServiceUnavailable {
|
||||
if !srv.available.Swap(true) {
|
||||
logger.Infof("server %s is up", srv.Name)
|
||||
}
|
||||
} else if err != nil {
|
||||
if srv.available.Swap(false) {
|
||||
logger.Warnf("server %s is down: %s", srv.Name, err)
|
||||
}
|
||||
} else {
|
||||
if srv.available.Swap(false) {
|
||||
logger.Warnf("server %s is down: status %s", srv.Name, resp.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) String() string {
|
||||
return srv.Name
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
F "github.com/yusing/go-proxy/internal/utils/functional"
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -29,7 +29,6 @@ func init() {
|
||||
"setxforwarded": SetXForwarded,
|
||||
"hidexforwarded": HideXForwarded,
|
||||
"redirecthttp": RedirectHTTP,
|
||||
"forwardauth": ForwardAuth.m,
|
||||
"modifyresponse": ModifyResponse.m,
|
||||
"modifyrequest": ModifyRequest.m,
|
||||
"errorpage": CustomErrorPage,
|
||||
@@ -37,6 +36,10 @@ func init() {
|
||||
"realip": RealIP.m,
|
||||
"cloudflarerealip": CloudflareRealIP.m,
|
||||
"cidrwhitelist": CIDRWhiteList.m,
|
||||
|
||||
// !experimental
|
||||
"forwardauth": ForwardAuth.m,
|
||||
"oauth2": OAuth2.m,
|
||||
}
|
||||
names := make(map[*Middleware][]string)
|
||||
for name, m := range middlewares {
|
||||
|
||||
129
internal/net/http/middleware/oauth2.go
Normal file
129
internal/net/http/middleware/oauth2.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type oAuth2 struct {
|
||||
*oAuth2Opts
|
||||
m *Middleware
|
||||
}
|
||||
|
||||
type oAuth2Opts struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthURL string // Authorization Endpoint
|
||||
TokenURL string // Token Endpoint
|
||||
}
|
||||
|
||||
var OAuth2 = &oAuth2{
|
||||
m: &Middleware{withOptions: NewAuthentikOAuth2},
|
||||
}
|
||||
|
||||
func NewAuthentikOAuth2(opts OptionsRaw) (*Middleware, E.NestedError) {
|
||||
oauth := new(oAuth2)
|
||||
oauth.m = &Middleware{
|
||||
impl: oauth,
|
||||
before: oauth.handleOAuth2,
|
||||
}
|
||||
oauth.oAuth2Opts = &oAuth2Opts{}
|
||||
err := Deserialize(opts, oauth.oAuth2Opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b := E.NewBuilder("missing required fields")
|
||||
optV := reflect.ValueOf(oauth.oAuth2Opts)
|
||||
for _, field := range reflect.VisibleFields(reflect.TypeFor[oAuth2Opts]()) {
|
||||
if optV.FieldByName(field.Name).Len() == 0 {
|
||||
b.Add(E.Missing(field.Name))
|
||||
}
|
||||
}
|
||||
if b.HasError() {
|
||||
return nil, b.Build().Subject("oAuth2")
|
||||
}
|
||||
return oauth.m, nil
|
||||
}
|
||||
|
||||
func (oauth *oAuth2) handleOAuth2(next http.HandlerFunc, rw ResponseWriter, r *Request) {
|
||||
// Check if the user is authenticated (you may use session, cookie, etc.)
|
||||
if !userIsAuthenticated(r) {
|
||||
// TODO: Redirect to OAuth2 auth URL
|
||||
http.Redirect(rw, r, fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code",
|
||||
oauth.oAuth2Opts.AuthURL, oauth.oAuth2Opts.ClientID, ""), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// If you have a token in the query string, process it
|
||||
if code := r.URL.Query().Get("code"); code != "" {
|
||||
// Exchange the authorization code for a token here
|
||||
// Use the TokenURL and authenticate the user
|
||||
token, err := exchangeCodeForToken(code, oauth.oAuth2Opts, r.RequestURI)
|
||||
if err != nil {
|
||||
// handle error
|
||||
http.Error(rw, "failed to get token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Save token and user info based on your requirements
|
||||
saveToken(rw, token)
|
||||
|
||||
// Redirect to the originally requested URL
|
||||
http.Redirect(rw, r, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// If user is authenticated, go to the next handler
|
||||
next(rw, r)
|
||||
}
|
||||
|
||||
func userIsAuthenticated(r *http.Request) bool {
|
||||
// Example: Check for a session or cookie
|
||||
session, err := r.Cookie("session_token")
|
||||
if err != nil || session.Value == "" {
|
||||
return false
|
||||
}
|
||||
// Validate the session_token if necessary
|
||||
return true
|
||||
}
|
||||
|
||||
func exchangeCodeForToken(code string, opts *oAuth2Opts, requestUri string) (string, error) {
|
||||
// Prepare the request body
|
||||
data := url.Values{
|
||||
"client_id": {opts.ClientID},
|
||||
"client_secret": {opts.ClientSecret},
|
||||
"code": {code},
|
||||
"grant_type": {"authorization_code"},
|
||||
"redirect_uri": {requestUri},
|
||||
}
|
||||
resp, err := http.PostForm(opts.TokenURL, data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to request token: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("received non-ok status from token endpoint: %s", resp.Status)
|
||||
}
|
||||
// Decode the response
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode token response: %v", err)
|
||||
}
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
|
||||
func saveToken(rw ResponseWriter, token string) {
|
||||
// Example: Save token in cookie
|
||||
http.SetCookie(rw, &http.Cookie{
|
||||
Name: "auth_token",
|
||||
Value: token,
|
||||
// set other properties as necessary, such as Secure and HttpOnly
|
||||
})
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"net"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
// https://nginx.org/en/docs/http/ngx_http_realip_module.html
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/types"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
gphttp "github.com/yusing/go-proxy/internal/net/http"
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
)
|
||||
|
||||
//go:embed test_data/sample_headers.json
|
||||
@@ -110,7 +111,7 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.N
|
||||
} else {
|
||||
proxyURL, _ = url.Parse("https://" + testHost) // dummy url, no actual effect
|
||||
}
|
||||
rp := gphttp.NewReverseProxy(proxyURL, rr)
|
||||
rp := gphttp.NewReverseProxy(types.NewURL(proxyURL), rr)
|
||||
mid, setOptErr := middleware.WithOptionsClone(args.middlewareOpt)
|
||||
if setOptErr != nil {
|
||||
return nil, setOptErr
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
|
||||
"github.com/yusing/go-proxy/internal/net/types"
|
||||
U "github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
@@ -86,7 +87,7 @@ type ReverseProxy struct {
|
||||
|
||||
ServeHTTP http.HandlerFunc
|
||||
|
||||
TargetURL *url.URL
|
||||
TargetURL types.URL
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
@@ -144,7 +145,7 @@ func joinURLPath(a, b *url.URL) (path, rawpath string) {
|
||||
// }
|
||||
//
|
||||
|
||||
func NewReverseProxy(target *url.URL, transport http.RoundTripper) *ReverseProxy {
|
||||
func NewReverseProxy(target types.URL, transport http.RoundTripper) *ReverseProxy {
|
||||
if transport == nil {
|
||||
panic("nil transport")
|
||||
}
|
||||
@@ -263,7 +264,7 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
|
||||
}
|
||||
|
||||
rewriteRequestURL(outreq, p.TargetURL)
|
||||
rewriteRequestURL(outreq, p.TargetURL.URL)
|
||||
outreq.Close = false
|
||||
|
||||
reqUpType := UpgradeType(outreq.Header)
|
||||
@@ -348,18 +349,16 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
roundTripMutex.Unlock()
|
||||
if err != nil {
|
||||
p.errorHandler(rw, outreq, err, false)
|
||||
errMsg := err.Error()
|
||||
res = &http.Response{
|
||||
Status: http.StatusText(http.StatusBadGateway),
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Proto: outreq.Proto,
|
||||
ProtoMajor: outreq.ProtoMajor,
|
||||
ProtoMinor: outreq.ProtoMinor,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
|
||||
Request: outreq,
|
||||
ContentLength: int64(len(errMsg)),
|
||||
TLS: outreq.TLS,
|
||||
Status: http.StatusText(http.StatusBadGateway),
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Proto: outreq.Proto,
|
||||
ProtoMajor: outreq.ProtoMajor,
|
||||
ProtoMinor: outreq.ProtoMinor,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(bytes.NewReader([]byte("Origin server is not reachable."))),
|
||||
Request: outreq,
|
||||
TLS: outreq.TLS,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
38
internal/net/types/cidr.go
Normal file
38
internal/net/types/cidr.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
)
|
||||
|
||||
type CIDR net.IPNet
|
||||
|
||||
func (*CIDR) ConvertFrom(val any) (any, E.NestedError) {
|
||||
cidr, ok := val.(string)
|
||||
if !ok {
|
||||
return nil, E.TypeMismatch[string](val)
|
||||
}
|
||||
|
||||
if !strings.Contains(cidr, "/") {
|
||||
cidr += "/32" // single IP
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, E.Invalid("CIDR", cidr)
|
||||
}
|
||||
return (*CIDR)(ipnet), nil
|
||||
}
|
||||
|
||||
func (cidr *CIDR) Contains(ip net.IP) bool {
|
||||
return (*net.IPNet)(cidr).Contains(ip)
|
||||
}
|
||||
|
||||
func (cidr *CIDR) String() string {
|
||||
return (*net.IPNet)(cidr).String()
|
||||
}
|
||||
|
||||
func (cidr *CIDR) Equals(other *CIDR) bool {
|
||||
return (*net.IPNet)(cidr).IP.Equal(other.IP) && cidr.Mask.String() == other.Mask.String()
|
||||
}
|
||||
24
internal/net/types/url.go
Normal file
24
internal/net/types/url.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package types
|
||||
|
||||
import "net/url"
|
||||
|
||||
type URL struct{ *url.URL }
|
||||
|
||||
func NewURL(url *url.URL) URL {
|
||||
return URL{url}
|
||||
}
|
||||
|
||||
func (u URL) String() string {
|
||||
if u.URL == nil {
|
||||
return "nil"
|
||||
}
|
||||
return u.URL.String()
|
||||
}
|
||||
|
||||
func (u URL) MarshalText() (text []byte, err error) {
|
||||
return []byte(u.String()), nil
|
||||
}
|
||||
|
||||
func (u URL) Equals(other URL) bool {
|
||||
return u.URL == other.URL || u.String() == other.String()
|
||||
}
|
||||
Reference in New Issue
Block a user