Created
April 3, 2024 08:35
-
-
Save knight42/bb74da69d110c2dce1d7747e21fc919c to your computer and use it in GitHub Desktop.
TCP forwarder that retries dialing the upstream when the connection is refused
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"errors" | |
"flag" | |
"fmt" | |
"io" | |
"log/slog" | |
"net" | |
"os" | |
"syscall" | |
"time" | |
) | |
func handleConn(downConn net.Conn, remoteAddr string, maxRetries int) { | |
defer downConn.Close() | |
slog.Info("Handle connection") | |
var ( | |
upConn net.Conn | |
err error | |
) | |
for i := range maxRetries { | |
upConn, err = net.Dial("tcp", remoteAddr) | |
if err == nil { | |
break | |
} | |
if !errors.Is(err, syscall.ECONNREFUSED) { | |
slog.Error("Fail to dial", "error", err) | |
return | |
} | |
if i == maxRetries-1 { | |
slog.Error("Fail to dial after retries", "error", err) | |
return | |
} | |
slog.Info("Retry dial after 3 secs", "error", err) | |
time.Sleep(3 * time.Second) | |
} | |
slog.Info("Connected to upstream") | |
defer upConn.Close() | |
done := make(chan struct{}) | |
go func() { | |
_, _ = io.Copy(upConn, downConn) | |
close(done) | |
}() | |
_, _ = io.Copy(downConn, upConn) | |
<-done | |
} | |
func run(localAddr, remoteAddr string, maxRetries int) error { | |
slog.Info("Waiting for connections", "localAddr", localAddr, "remoteAddr", remoteAddr) | |
lis, err := net.Listen("tcp", localAddr) | |
if err != nil { | |
return err | |
} | |
for { | |
downConn, err := lis.Accept() | |
if err != nil { | |
var tempErr interface{ Temporary() bool } | |
if errors.As(err, &tempErr) && tempErr.Temporary() { | |
continue | |
} | |
return fmt.Errorf("accept: %w", err) | |
} | |
go handleConn(downConn, remoteAddr, maxRetries) | |
} | |
} | |
func main() { | |
var ( | |
localAddr string | |
remoteAddr string | |
retries int | |
) | |
flag.StringVar(&localAddr, "l", "127.0.0.1:8090", "local listen address") | |
flag.StringVar(&remoteAddr, "u", "127.0.0.1:8091", "upstream address") | |
flag.IntVar(&retries, "r", 10, "max retries") | |
flag.Parse() | |
err := run(localAddr, remoteAddr, retries) | |
if err != nil { | |
slog.Error("Fail to run", "error", err) | |
os.Exit(1) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment