Files
godoxy-yusing/src/go-proxy/udp_route.go

196 lines
3.9 KiB
Go
Executable File

package main
import (
"fmt"
"io"
"net"
"sync"
"github.com/sirupsen/logrus"
)
// UDPRoute is the UDP implementation of a stream route. It owns the socket
// that receives datagrams from clients, the socket connected to the target,
// and a table of sockets dialed back to each client.
type UDPRoute struct {
	*StreamRouteBase

	// connMap holds the connections dialed back to client addresses.
	// Keys are compared as interface values, so two *net.UDPAddr keys are
	// equal only when they are the same pointer, not when they merely
	// describe the same IP and port.
	connMap map[net.Addr]net.Conn
	// connMapMutex is held while a new entry is inserted into connMap.
	connMapMutex sync.Mutex

	// listeningConn is the packet socket opened by Setup on the listening port.
	listeningConn *net.UDPConn
	// targetConn is the socket dialed by Setup to the target host and port.
	targetConn *net.UDPConn
}
// UDPConn describes one datagram that has been read from a socket, together
// with the scratch buffer it was read into.
type UDPConn struct {
	// remoteAddr is the peer address associated with the datagram.
	remoteAddr net.Addr
	// buffer is the full scratch buffer the datagram was read into.
	buffer []byte
	// bytesReceived is the prefix of buffer that holds the datagram payload.
	bytesReceived []byte
	// nReceived is the number of payload bytes, i.e. len(bytesReceived).
	nReceived int
}
// NewUDPRoute builds a stream route from config and installs a UDPRoute as
// its stream implementation.
//
// It returns the error from newStreamRouteBase unchanged, or an "unsupported"
// nested error when the target scheme is anything other than StreamType_UDP.
func NewUDPRoute(config *ProxyConfig) (StreamRoute, error) {
	base, err := newStreamRouteBase(config)
	switch {
	case err != nil:
		return nil, err
	case base.TargetScheme != StreamType_UDP:
		return nil, NewNestedError("unsupported").Subjectf("udp->%s", base.TargetScheme)
	}

	impl := &UDPRoute{
		StreamRouteBase: base,
		connMap:         make(map[net.Addr]net.Conn),
	}
	base.StreamImpl = impl
	return base, nil
}
// Setup opens the listening packet socket on ListeningPort and dials the
// target at TargetHost:TargetPort, storing both on the route.
//
// On any failure every socket opened so far is closed and an error is
// returned; the route's listeningConn and targetConn are only assigned when
// both sockets were created successfully.
func (route *UDPRoute) Setup() error {
	source, err := net.ListenPacket(route.ListeningScheme, fmt.Sprintf(":%v", route.ListeningPort))
	if err != nil {
		return err
	}
	// Checked assertion: a non-UDP listening scheme must surface as an error
	// rather than a panic.
	listeningConn, ok := source.(*net.UDPConn)
	if !ok {
		source.Close()
		return fmt.Errorf("listening scheme %q did not yield a UDP socket (got %T)", route.ListeningScheme, source)
	}

	target, err := net.Dial(route.TargetScheme, fmt.Sprintf("%s:%v", route.TargetHost, route.TargetPort))
	if err != nil {
		listeningConn.Close()
		return err
	}
	targetConn, ok := target.(*net.UDPConn)
	if !ok {
		// Close both sockets so neither leaks when the target is not UDP.
		listeningConn.Close()
		target.Close()
		return fmt.Errorf("target scheme %q did not yield a UDP socket (got %T)", route.TargetScheme, target)
	}

	route.listeningConn = listeningConn
	route.targetConn = targetConn
	return nil
}
// Accept reads one datagram from the listening socket into a freshly
// allocated buffer of udpBufferSize bytes and returns it as a *UDPConn
// carrying the sender's address.
//
// A read error is returned as-is; a read of zero bytes is reported as
// io.ErrShortBuffer.
func (route *UDPRoute) Accept() (interface{}, error) {
	listener := route.listeningConn
	packet := make([]byte, udpBufferSize)

	size, sender, err := listener.ReadFromUDP(packet)
	switch {
	case err != nil:
		return nil, err
	case size == 0:
		return nil, io.ErrShortBuffer
	}

	received := &UDPConn{
		remoteAddr:    sender,
		buffer:        packet,
		bytesReceived: packet[:size],
		nReceived:     size,
	}
	return received, nil
}
// HandleConnection relays traffic for one datagram returned by Accept.
//
// c must be the *UDPConn produced by Accept. The datagram is first forwarded
// to the target; the method then loops, alternately reading a reply from the
// target and forwarding it to the source connection, then reading from the
// source connection and forwarding that to the target. The loop ends with a
// nil error once stopCh is readable, or with the first error from reading the
// target or from any forward. A failed read from the source connection is not
// fatal: the loop simply starts its next round.
func (route *UDPRoute) HandleConnection(c interface{}) error {
	conn, ok := c.(*UDPConn)
	if !ok {
		return fmt.Errorf("unexpected connection type %T", c)
	}

	srcConn, err := route.sourceConnFor(conn.remoteAddr)
	if err != nil {
		return err
	}

	forwarder := route.forwardReceivedReal
	if logLevel == logrus.DebugLevel {
		forwarder = route.forwardReceivedDebug
	}

	// Keep the scratch buffer in its own variable: readFrom returns a nil
	// *UDPConn on error, so the buffer must not be reached through the most
	// recent read result.
	buffer := conn.buffer

	// initiate connection to target
	if err := forwarder(conn, route.targetConn); err != nil {
		return err
	}

	for {
		select {
		case <-route.stopCh:
			return nil
		default:
			// receive from target
			reply, err := route.readFrom(route.targetConn, buffer)
			if err != nil {
				return err
			}
			// forward to source
			if err := forwarder(reply, srcConn); err != nil {
				return err
			}
			// read from source
			request, err := route.readFrom(srcConn, buffer)
			if err != nil {
				continue
			}
			// forward to target
			if err := forwarder(request, route.targetConn); err != nil {
				return err
			}
		}
	}
}

// sourceConnFor returns the connection dialed back to addr, dialing and
// recording a new one when none is known yet.
//
// The lookup and the insert both happen under connMapMutex, which is released
// on every return path. Entries are matched by the textual form of the
// address because connMap is keyed by net.Addr interface values, and distinct
// *net.UDPAddr pointers never compare equal even for the same IP and port.
func (route *UDPRoute) sourceConnFor(addr net.Addr) (net.Conn, error) {
	udpAddr, ok := addr.(*net.UDPAddr)
	if !ok {
		return nil, fmt.Errorf("unexpected source address type %T", addr)
	}
	wanted := udpAddr.String()

	route.connMapMutex.Lock()
	defer route.connMapMutex.Unlock()

	for known, existing := range route.connMap {
		if known.String() == wanted {
			return existing, nil
		}
	}

	dialed, err := net.DialUDP("udp", nil, udpAddr)
	if err != nil {
		return nil, err
	}
	route.connMap[addr] = dialed
	return dialed, nil
}
// CloseListeners closes the listening socket, the target socket and every
// connection recorded in connMap, then replaces connMap with an empty map.
// Sockets that are already nil are skipped, so the method may be called more
// than once.
//
// NOTE(review): connMap is read and replaced here without holding
// connMapMutex; confirm that callers stop HandleConnection before calling
// this.
func (route *UDPRoute) CloseListeners() {
	if route.listeningConn != nil {
		route.listeningConn.Close()
		route.listeningConn = nil
	}
	if route.targetConn != nil {
		route.targetConn.Close()
		route.targetConn = nil
	}
	for _, conn := range route.connMap {
		// net.Conn already provides Close, so no assertion to *net.UDPConn
		// is needed and other connection types are closed as well.
		conn.Close()
	}
	route.connMap = make(map[net.Addr]net.Conn)
}
// readFrom performs a single Read on src into buffer and wraps the result in
// a *UDPConn whose remoteAddr is src's remote address.
//
// A read error is returned unchanged; a read of zero bytes is reported as
// io.ErrShortBuffer. In both cases the returned *UDPConn is nil.
func (route *UDPRoute) readFrom(src net.Conn, buffer []byte) (*UDPConn, error) {
	size, err := src.Read(buffer)
	switch {
	case err != nil:
		return nil, err
	case size == 0:
		return nil, io.ErrShortBuffer
	}

	received := new(UDPConn)
	received.remoteAddr = src.RemoteAddr()
	received.buffer = buffer
	received.bytesReceived = buffer[:size]
	received.nReceived = size
	return received, nil
}
// forwardReceivedReal writes the payload of receivedConn to dest.
//
// It returns the error from Write when there is one. io.ErrShortWrite is
// returned only when Write reported success but wrote a different number of
// bytes than were received, so a genuine write failure is never masked.
func (route *UDPRoute) forwardReceivedReal(receivedConn *UDPConn, dest net.Conn) error {
	nWritten, err := dest.Write(receivedConn.bytesReceived)
	if err != nil {
		return err
	}
	if nWritten != receivedConn.nReceived {
		return io.ErrShortWrite
	}
	return nil
}
// forwardReceivedDebug logs the payload size together with the source and
// destination addresses at debug level, then forwards the datagram through
// forwardReceivedReal and returns its result.
func (route *UDPRoute) forwardReceivedDebug(receivedConn *UDPConn, dest net.Conn) error {
	entry := route.l.WithField("size", receivedConn.nReceived)
	from := receivedConn.remoteAddr.String()
	to := dest.RemoteAddr().String()
	entry.Debugf("forwarding from %s to %s", from, to)

	return route.forwardReceivedReal(receivedConn, dest)
}