WIP: Initial commit

Oystein Kristoffer Tveit 2024-04-29 10:11:14 +02:00
commit 8216775885
4 changed files with 1371 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

1123
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

24
Cargo.toml Normal file
View File

@ -0,0 +1,24 @@
[package]
name = "woossh"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.82"
clap = { version = "4.5.4", features = ["derive"] }
env_logger = "0.11.3"
file-descriptors = { version = "0.9.1", optional = true }
futures-util = "0.3.30"
log = "0.4.21"
tokio = { version = "1.37.0", features = ["full"] }
tokio-tungstenite = "0.21.0"
tungstenite = "0.21.0"
[features]
socket-activation = ["dep:file-descriptors"]
[[bin]]
name = "wssh"
path = "src/main.rs"

223
src/main.rs Normal file
View File

@ -0,0 +1,223 @@
use anyhow::Context;
use futures_util::{SinkExt, StreamExt};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
use clap::Parser;
use tungstenite::Message;
#[cfg(feature = "socket_activation")]
use std::os::fd::FromRawFd;
#[derive(Parser)]
struct Args {
#[command(subcommand)]
command: SubCommand,
}
#[derive(Parser)]
enum SubCommand {
#[command()]
Connect(ConnectArgs),
#[command()]
Server(ServerArgs),
#[cfg(feature = "socket_activation")]
#[command()]
SocketActivation,
}
#[derive(Parser)]
struct ConnectArgs {
#[arg()]
uri: String,
}
#[derive(Parser)]
struct ServerArgs {
#[arg(long, default_value = "0.0.0.0")]
host: String,
#[arg(short, long, default_value = "2222")]
port: u16,
#[arg(short, long, default_value = "localhost:22")]
ssh_socket: String,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
env_logger::init();
let args = Args::parse();
match args.command {
SubCommand::Connect(args) => connect(args).await,
SubCommand::Server(args) => server(args).await,
#[cfg(feature = "socket_activation")]
SubCommand::SocketActivation => socket_activation().await,
}
}
async fn connect(args: ConnectArgs) -> anyhow::Result<()> {
log::info!("Connecting to {}", args.uri);
let (stream, _) = tokio_tungstenite::connect_async(args.uri)
.await
.context("Failed to connect to the server")?;
log::info!("Connected to the server");
let (mut tx, mut rx) = stream.split();
let stdin_forwarder = tokio::spawn(async move {
log::trace!("Starting stdin forwarder");
let mut buf_stdin = tokio::io::stdin();
loop {
let mut buffer = Vec::new();
match buf_stdin
.read_buf(&mut buffer)
.await
.context("Failed to read from stdin")
{
Err(e) => {
log::error!("Error reading from stdin: {}", e);
break;
}
Ok(0) => {
continue;
}
Ok(1..) => {
log::trace!("C->S {:?}", String::from_utf8_lossy(&buffer));
}
}
if let Err(e) = tx
.send(Message::Binary(buffer))
.await
.context("Failed to forward data from stdin")
{
log::error!("Error sending data to server: {}", e);
break;
}
}
});
let stdout_receiver = tokio::spawn(async move {
log::trace!("Starting stdout receiver");
let mut buf_stdout = tokio::io::stdout();
loop {
match rx.next().await.unwrap().unwrap() {
Message::Binary(data) => {
log::trace!("C<-S {:?}", String::from_utf8_lossy(&data));
buf_stdout.write_all(&data).await.unwrap();
buf_stdout.flush().await.unwrap();
}
_ => panic!("Unexpected message type"),
}
}
});
tokio::try_join!(stdin_forwarder, stdout_receiver)?;
Ok(())
}
// --------------------------------------------------------- //
async fn server(args: ServerArgs) -> anyhow::Result<()> {
let addr = format!("{}:{}", args.host, args.port);
let listener = tokio::net::TcpListener::bind(&addr)
.await
.context("Failed to bind to address")?;
log::info!("Listening on {}", &addr);
loop {
tokio::select! {
Ok((stream, _)) = listener.accept() => {
log::info!("Accepted connection from {}", stream.peer_addr()?);
tokio::spawn(handle_client(stream, args.ssh_socket.clone()));
}
_ = tokio::signal::ctrl_c() => break,
}
}
Ok(())
}
async fn handle_client(stream: TcpStream, ssh_socket: String) -> anyhow::Result<()> {
log::trace!("Accepting WebSocket connection");
let stream = tokio_tungstenite::accept_async(stream)
.await
.expect("Error during WebSocket handshake");
let (mut tx, mut rx) = stream.split();
log::trace!("Connecting to SSH server");
let ssh_stream = TcpStream::connect(ssh_socket)
.await
.context("Failed to connect to SSH server")?;
let (mut ssh_rx, mut ssh_tx) = tokio::io::split(ssh_stream);
let stdin_receiver = tokio::spawn(async move {
log::trace!("Starting stdin receiver");
loop {
match rx.next().await {
Some(Ok(Message::Binary(data))) => {
log::trace!("S<-C {:?}", String::from_utf8_lossy(&data));
if let Err(e) = ssh_tx.write_all(&data).await {
log::error!("Error sending data to SSH server: {}", e);
break;
}
ssh_tx.flush().await.unwrap();
}
None => {
eprintln!("Websocket connection closed");
break;
}
Some(Err(e)) => {
eprintln!("Websocket error: {}", e);
break;
}
Some(Ok(m)) => panic!("Unexpected message: {:?}", m),
}
}
});
let stdout_forwarder = tokio::spawn(async move {
log::trace!("Starting stdout forwarder");
loop {
let mut buffer = Vec::new();
match ssh_rx.read_buf(&mut buffer).await {
Err(e) => {
log::error!("Error reading from SSH server: {}", e);
break;
}
Ok(0) => {
continue;
}
Ok(1..) => {
log::info!("S->C {:?}", String::from_utf8_lossy(&buffer));
}
}
if let Err(e) = tx.send(Message::Binary(buffer)).await {
log::error!("Error sending data to client: {}", e);
break;
}
}
});
tokio::try_join!(stdin_receiver, stdout_forwarder)?;
Ok(())
}
// --------------------------------------------------------- //
#[cfg(feature = "socket_activation")]
async fn socket_activation() -> anyhow::Result<()> {
let raw_connection = unsafe { std::net::TcpStream::from_raw_fd(3) };
Ok(())
}