Last active
December 16, 2023 13:44
-
-
Save reu/8a703c0e57927050abcc3387b61c3c79 to your computer and use it in GitHub Desktop.
Simple TCP proxy using a thread per connection and non-blocking I/O
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
io::{self, Read, Write}, | |
net::{TcpListener, TcpStream}, | |
os::fd::{AsRawFd, RawFd}, | |
thread, | |
}; | |
const READ: libc::c_short = libc::POLLRDNORM; | |
const WRITE: libc::c_short = libc::POLLWRNORM; | |
#[repr(C)] | |
#[derive(Debug)] | |
pub struct PollFd { | |
fd: RawFd, | |
events: libc::c_short, | |
revents: libc::c_short, | |
} | |
impl PollFd { | |
fn new(fd: RawFd) -> Self { | |
Self { | |
fd, | |
events: READ, | |
revents: 0, | |
} | |
} | |
pub fn is_readable(&self) -> bool { | |
self.revents & READ != 0 | |
} | |
pub fn is_writable(&self) -> bool { | |
self.revents & WRITE != 0 | |
} | |
pub fn set(&mut self, events: libc::c_short) { | |
self.events |= events; | |
} | |
pub fn unset(&mut self, events: libc::c_short) { | |
self.events &= !events; | |
} | |
} | |
fn main() -> io::Result<()> { | |
let listener = TcpListener::bind(("0.0.0.0", 10000))?; | |
while let Ok((downstream, _addr)) = listener.accept() { | |
thread::spawn(move || { | |
if let Ok(upstream) = TcpStream::connect(("0.0.0.0", 4444)) { | |
if let Err(err) = proxy(downstream, upstream) { | |
eprintln!("{err}"); | |
} | |
} | |
}); | |
} | |
Ok(()) | |
} | |
fn proxy(mut downstream: TcpStream, mut upstream: TcpStream) -> io::Result<()> { | |
downstream.set_nonblocking(true)?; | |
upstream.set_nonblocking(true)?; | |
let mut downstream_buf = vec![0; 1024]; | |
let mut upstream_buf = vec![0; 1024]; | |
let mut write_to_upstream: &[u8] = &[]; | |
let mut write_to_downstream: &[u8] = &[]; | |
let mut fds = [ | |
PollFd::new(downstream.as_raw_fd()), | |
PollFd::new(upstream.as_raw_fd()), | |
]; | |
loop { | |
unsafe { | |
libc::poll( | |
fds.as_mut_ptr() as *mut libc::pollfd, | |
2 as libc::nfds_t, | |
-1 as libc::c_int, | |
); | |
}; | |
let (fds1, fds2) = fds.split_at_mut(1); | |
let poll_downstream = &mut fds1[0]; | |
let poll_upstream = &mut fds2[0]; | |
if poll_downstream.is_readable() { | |
let start = write_to_upstream.len(); | |
let end = downstream_buf.len(); | |
if !(start..end).is_empty() { | |
write_to_upstream = match downstream.read(&mut downstream_buf[start..end])? { | |
0 => break, | |
bytes => &downstream_buf[0..start + bytes], | |
}; | |
poll_upstream.set(WRITE); | |
} | |
} | |
if poll_downstream.is_writable() && !write_to_downstream.is_empty() { | |
let written = downstream.write(write_to_downstream)?; | |
write_to_downstream = &write_to_downstream[written..write_to_downstream.len()]; | |
if write_to_downstream.is_empty() { | |
poll_downstream.unset(WRITE); | |
} | |
} | |
if poll_upstream.is_readable() { | |
let start = write_to_downstream.len(); | |
let end = upstream_buf.len(); | |
if !(start..end).is_empty() { | |
write_to_downstream = match upstream.read(&mut upstream_buf[start..end])? { | |
0 => break, | |
bytes => &upstream_buf[0..start + bytes], | |
}; | |
poll_downstream.set(WRITE); | |
} | |
} | |
if poll_upstream.is_writable() && !write_to_upstream.is_empty() { | |
let written = upstream.write(write_to_upstream)?; | |
write_to_upstream = &write_to_upstream[written..write_to_upstream.len()]; | |
if write_to_upstream.is_empty() { | |
poll_upstream.unset(WRITE); | |
} | |
} | |
} | |
io::Result::Ok(()) | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.