Initial commit
This commit is contained in:
@@ -0,0 +1,574 @@
|
||||
use anyhow::{Context, Result, bail};
|
||||
use bro::{
|
||||
common::{copy_exact, env_var_is_set, env_var_is_truthy, init_tracing, parse_csv_env},
|
||||
protocol::*,
|
||||
};
|
||||
use sendfd::SendWithFd;
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
env,
|
||||
fs::File,
|
||||
io::{self, IsTerminal, Seek, Write},
|
||||
os::{
|
||||
fd::{AsRawFd, RawFd},
|
||||
unix::net::UnixStream,
|
||||
},
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
fn main() {
|
||||
if let Err(error) = init_tracing("bro-client") {
|
||||
eprintln!("bro-client: failed to initialize tracing: {error:#}");
|
||||
std::process::exit(WRAPPER_ERROR_EXIT_CODE);
|
||||
}
|
||||
|
||||
let exit_code = match run() {
|
||||
Ok(code) => code,
|
||||
Err(error) => {
|
||||
log::error!("bro-client failed: {error:#}");
|
||||
WRAPPER_ERROR_EXIT_CODE
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("bro-client exiting with code {exit_code}");
|
||||
std::process::exit(exit_code);
|
||||
}
|
||||
|
||||
fn run() -> Result<i32> {
|
||||
let options = ClientOptions::from_env()?;
|
||||
log::info!(
|
||||
"starting bro-client request: socket={} args={} forward_env={} file_flags={} bro_file_args_enabled={}",
|
||||
options.socket_path.display(),
|
||||
options.target_args.len(),
|
||||
options.forward_env.len(),
|
||||
options.file_flags.len(),
|
||||
options.file_args_enabled
|
||||
);
|
||||
if options.file_args_enabled {
|
||||
log::info!(
|
||||
"BRO_FILE_ARGS is enabled; bro-client will probe non-flag arguments and auto-forward readable regular files"
|
||||
);
|
||||
} else {
|
||||
log::debug!(
|
||||
"BRO_FILE_ARGS is disabled; bro-client will only forward paths matched via BRO_FILE_FLAGS"
|
||||
);
|
||||
}
|
||||
if !options.file_flags.is_empty() {
|
||||
log::debug!("BRO_FILE_FLAGS configured as: {:?}", options.file_flags);
|
||||
}
|
||||
|
||||
let transports = choose_request_transports(&options.socket_path);
|
||||
log::info!(
|
||||
"selected request transports: uploads={:?} stdin={:?} responses={:?}",
|
||||
transports.upload_transport,
|
||||
transports.stdin_transport,
|
||||
transports.response_transport
|
||||
);
|
||||
|
||||
let mut stdin = prepare_stdin(transports.stdin_transport)?;
|
||||
let request = build_request_header(
|
||||
&options,
|
||||
stdin.stdin_size(),
|
||||
transports.upload_transport,
|
||||
transports.stdin_transport,
|
||||
transports.response_transport,
|
||||
)?;
|
||||
log::info!(
|
||||
"prepared request: uploads={} rewrites={} stdin_size={:?}",
|
||||
request.uploads.len(),
|
||||
request.header.rewrites.len(),
|
||||
request.header.stdin_size,
|
||||
);
|
||||
|
||||
log::info!(
|
||||
"connecting to bro-server socket at {}",
|
||||
options.socket_path.display()
|
||||
);
|
||||
let mut stream = UnixStream::connect(&options.socket_path).with_context(|| {
|
||||
format!(
|
||||
"failed to connect to broker socket `{}`",
|
||||
options.socket_path.display()
|
||||
)
|
||||
})?;
|
||||
log::info!(
|
||||
"connected to bro-server socket at {}",
|
||||
options.socket_path.display()
|
||||
);
|
||||
|
||||
write_execute_magic(&mut stream)?;
|
||||
write_execute_request_header(&mut stream, &request.header)?;
|
||||
|
||||
if request.header.upload_transport == UploadTransport::StreamedBytes {
|
||||
send_uploaded_file_bytes(&mut stream, &request.uploads)?;
|
||||
}
|
||||
|
||||
let fd_uploads: &[PreparedUpload] =
|
||||
if request.header.upload_transport == UploadTransport::PassedFileDescriptors {
|
||||
&request.uploads[..]
|
||||
} else {
|
||||
&request.uploads[..0]
|
||||
};
|
||||
send_passed_descriptors(&stream, fd_uploads, stdin.passed_fd())?;
|
||||
stdin.send_streamed_bytes_if_needed(&mut stream)?;
|
||||
stream
|
||||
.shutdown(std::net::Shutdown::Write)
|
||||
.context("failed to close the client write side")?;
|
||||
|
||||
log::info!("waiting for response frames from bro-server");
|
||||
receive_response(stream, request.header.response_transport)
|
||||
}
|
||||
|
||||
struct ClientOptions {
|
||||
socket_path: PathBuf,
|
||||
forward_env: Vec<String>,
|
||||
file_flags: Vec<String>,
|
||||
file_args_enabled: bool,
|
||||
target_args: Vec<String>,
|
||||
}
|
||||
|
||||
impl ClientOptions {
|
||||
fn from_env() -> Result<Self> {
|
||||
let socket_path = env::var_os("BRO_SOCKET_PATH")
|
||||
.map(PathBuf::from)
|
||||
.context("BRO_SOCKET_PATH must be set")?;
|
||||
|
||||
Ok(Self {
|
||||
socket_path,
|
||||
forward_env: parse_csv_env("BRO_FORWARD_ENV"),
|
||||
file_flags: parse_csv_env("BRO_FILE_FLAGS"),
|
||||
file_args_enabled: env_var_is_set("BRO_FILE_ARGS"),
|
||||
target_args: env::args().skip(1).collect(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct PreparedRequest {
|
||||
header: RequestHeader,
|
||||
uploads: Vec<PreparedUpload>,
|
||||
}
|
||||
|
||||
struct PreparedUpload {
|
||||
path: PathBuf,
|
||||
size: u64,
|
||||
file: File,
|
||||
}
|
||||
|
||||
struct RequestTransports {
|
||||
upload_transport: UploadTransport,
|
||||
stdin_transport: UploadTransport,
|
||||
response_transport: ResponseTransport,
|
||||
}
|
||||
|
||||
enum PreparedStdin {
|
||||
Streamed { spool: NamedTempFile, size: u64 },
|
||||
Passed { fd: RawFd, _owner: Option<File> },
|
||||
}
|
||||
|
||||
impl PreparedStdin {
|
||||
fn stdin_size(&self) -> Option<u64> {
|
||||
match self {
|
||||
Self::Streamed { size, .. } => Some(*size),
|
||||
Self::Passed { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn passed_fd(&self) -> Option<RawFd> {
|
||||
match self {
|
||||
Self::Streamed { .. } => None,
|
||||
Self::Passed { fd, .. } => Some(*fd),
|
||||
}
|
||||
}
|
||||
|
||||
fn send_streamed_bytes_if_needed(&mut self, stream: &mut UnixStream) -> Result<()> {
|
||||
match self {
|
||||
Self::Streamed { spool, size } => {
|
||||
spool
|
||||
.rewind()
|
||||
.context("failed to rewind stdin spool before transmission")?;
|
||||
copy_exact(spool, stream, *size).context("failed to stream stdin to bro-server")
|
||||
}
|
||||
Self::Passed { .. } => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct UploadRegistry {
|
||||
uploads_by_path: HashMap<String, u32>,
|
||||
uploads: Vec<UploadSpec>,
|
||||
prepared_uploads: Vec<PreparedUpload>,
|
||||
next_upload_id: u32,
|
||||
}
|
||||
|
||||
impl UploadRegistry {
|
||||
fn register(&mut self, path_text: &str, file: File, size: u64) -> Result<u32> {
|
||||
let upload_id = self.next_upload_id;
|
||||
self.next_upload_id = self
|
||||
.next_upload_id
|
||||
.checked_add(1)
|
||||
.context("too many transport files were requested")?;
|
||||
|
||||
self.uploads_by_path.insert(path_text.to_owned(), upload_id);
|
||||
self.uploads.push(UploadSpec {
|
||||
id: upload_id,
|
||||
original_path: path_text.to_owned(),
|
||||
size,
|
||||
});
|
||||
self.prepared_uploads.push(PreparedUpload {
|
||||
path: PathBuf::from(path_text),
|
||||
size,
|
||||
file,
|
||||
});
|
||||
Ok(upload_id)
|
||||
}
|
||||
|
||||
fn into_parts(self) -> (Vec<UploadSpec>, Vec<PreparedUpload>) {
|
||||
(self.uploads, self.prepared_uploads)
|
||||
}
|
||||
}
|
||||
|
||||
impl UploadRegistrar for UploadRegistry {
|
||||
fn ensure_upload(&mut self, path_text: &str) -> Result<u32> {
|
||||
if let Some(&existing_id) = self.uploads_by_path.get(path_text) {
|
||||
return Ok(existing_id);
|
||||
}
|
||||
|
||||
let (file, size) = probe_regular_upload(path_text)?
|
||||
.with_context(|| format!("failed to open transport file `{path_text}`"))?;
|
||||
self.register(path_text, file, size)
|
||||
}
|
||||
|
||||
fn maybe_ensure_upload(&mut self, path_text: &str) -> Result<Option<u32>> {
|
||||
if let Some(&existing_id) = self.uploads_by_path.get(path_text) {
|
||||
return Ok(Some(existing_id));
|
||||
}
|
||||
|
||||
probe_regular_upload(path_text)?
|
||||
.map(|(file, size)| self.register(path_text, file, size))
|
||||
.transpose()
|
||||
}
|
||||
}
|
||||
|
||||
fn probe_regular_upload(path_text: &str) -> Result<Option<(File, u64)>> {
|
||||
let file = match File::open(path_text) {
|
||||
Ok(file) => file,
|
||||
Err(_) => return Ok(None),
|
||||
};
|
||||
let metadata = match file.metadata() {
|
||||
Ok(metadata) => metadata,
|
||||
Err(_) => return Ok(None),
|
||||
};
|
||||
|
||||
Ok(metadata.is_file().then_some((file, metadata.len())))
|
||||
}
|
||||
|
||||
fn choose_request_transports(socket_path: &Path) -> RequestTransports {
|
||||
match query_server_features(socket_path) {
|
||||
Ok(server_features) => RequestTransports {
|
||||
upload_transport: if server_features
|
||||
.supports_upload_transport(UploadTransport::PassedFileDescriptors)
|
||||
{
|
||||
UploadTransport::PassedFileDescriptors
|
||||
} else {
|
||||
UploadTransport::StreamedBytes
|
||||
},
|
||||
stdin_transport: if server_features
|
||||
.supports_stdin_transport(UploadTransport::PassedFileDescriptors)
|
||||
{
|
||||
UploadTransport::PassedFileDescriptors
|
||||
} else {
|
||||
UploadTransport::StreamedBytes
|
||||
},
|
||||
response_transport: if server_features
|
||||
.supports_response_transport(ResponseTransport::BinaryFrames)
|
||||
{
|
||||
ResponseTransport::BinaryFrames
|
||||
} else {
|
||||
ResponseTransport::JsonMessages
|
||||
},
|
||||
},
|
||||
Err(error) => {
|
||||
log::warn!(
|
||||
"failed to query server features, falling back to streamed byte uploads and stdin: {error:#}"
|
||||
);
|
||||
RequestTransports {
|
||||
upload_transport: UploadTransport::StreamedBytes,
|
||||
stdin_transport: UploadTransport::StreamedBytes,
|
||||
response_transport: ResponseTransport::JsonMessages,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn query_server_features(socket_path: &Path) -> Result<ServerFeatures> {
|
||||
let mut stream = UnixStream::connect(socket_path).with_context(|| {
|
||||
format!(
|
||||
"failed to connect to broker socket `{}` for server-features query",
|
||||
socket_path.display()
|
||||
)
|
||||
})?;
|
||||
write_server_features_magic(&mut stream)?;
|
||||
let socket_supports_fd_passing = probe_socket_for_fd_passing(&stream);
|
||||
let mut server_features = read_server_features_response(&mut stream)?;
|
||||
|
||||
if !socket_supports_fd_passing {
|
||||
let server_advertised_fd_passing = server_features
|
||||
.supports_upload_transport(UploadTransport::PassedFileDescriptors)
|
||||
|| server_features.supports_stdin_transport(UploadTransport::PassedFileDescriptors);
|
||||
if server_advertised_fd_passing {
|
||||
log::warn!(
|
||||
"bro-server advertised fd passing, but the socket rejected an SCM_RIGHTS probe; falling back to streamed uploads and stdin"
|
||||
);
|
||||
}
|
||||
server_features
|
||||
.upload_transports
|
||||
.retain(|transport| *transport != UploadTransport::PassedFileDescriptors);
|
||||
server_features
|
||||
.stdin_transports
|
||||
.retain(|transport| *transport != UploadTransport::PassedFileDescriptors);
|
||||
}
|
||||
|
||||
Ok(server_features)
|
||||
}
|
||||
|
||||
fn probe_socket_for_fd_passing(stream: &UnixStream) -> bool {
|
||||
let probe_file = match File::open("/dev/null") {
|
||||
Ok(file) => file,
|
||||
Err(error) => {
|
||||
log::warn!(
|
||||
"failed to open `/dev/null` for SCM_RIGHTS probe; assuming fd passing is unavailable: {error:#}"
|
||||
);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
match stream.send_with_fd(&[0_u8], &[probe_file.as_raw_fd()]) {
|
||||
Ok(1) => {
|
||||
log::debug!("socket SCM_RIGHTS probe succeeded");
|
||||
true
|
||||
}
|
||||
Ok(sent) => {
|
||||
log::warn!(
|
||||
"socket SCM_RIGHTS probe wrote {sent} probe byte(s) instead of 1; assuming fd passing is unavailable"
|
||||
);
|
||||
false
|
||||
}
|
||||
Err(error) => {
|
||||
log::debug!("socket SCM_RIGHTS probe failed: {error:#}");
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_request_header(
|
||||
options: &ClientOptions,
|
||||
stdin_size: Option<u64>,
|
||||
upload_transport: UploadTransport,
|
||||
stdin_transport: UploadTransport,
|
||||
response_transport: ResponseTransport,
|
||||
) -> Result<PreparedRequest> {
|
||||
let mut upload_registry = UploadRegistry::default();
|
||||
let planned_rewrites = plan_request_rewrites(
|
||||
&options.target_args,
|
||||
&RequestPlanningOptions {
|
||||
file_flags: &options.file_flags,
|
||||
file_args_enabled: options.file_args_enabled,
|
||||
},
|
||||
&mut upload_registry,
|
||||
)?;
|
||||
|
||||
if !options.file_flags.is_empty() {
|
||||
log::info!(
|
||||
"BRO_FILE_FLAGS rewrote {} argument value(s) for transport",
|
||||
planned_rewrites.stats.file_flag_rewrite_count
|
||||
);
|
||||
}
|
||||
if options.file_args_enabled {
|
||||
log::info!(
|
||||
"BRO_FILE_ARGS examined {} non-flag argument(s) and auto-forwarded {} readable regular file(s)",
|
||||
planned_rewrites.stats.auto_file_arg_probe_count,
|
||||
planned_rewrites.stats.auto_file_arg_rewrite_count,
|
||||
);
|
||||
}
|
||||
|
||||
let (uploads, prepared_uploads) = upload_registry.into_parts();
|
||||
let forwarded_env = options
|
||||
.forward_env
|
||||
.iter()
|
||||
.filter_map(|key| env::var(key).ok().map(|value| (key.clone(), value)))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
|
||||
Ok(PreparedRequest {
|
||||
header: RequestHeader {
|
||||
args: options.target_args.clone(),
|
||||
env: forwarded_env,
|
||||
uploads,
|
||||
rewrites: planned_rewrites.rewrites,
|
||||
stdin_size,
|
||||
upload_transport,
|
||||
stdin_transport,
|
||||
response_transport,
|
||||
},
|
||||
uploads: prepared_uploads,
|
||||
})
|
||||
}
|
||||
|
||||
fn send_uploaded_file_bytes(stream: &mut UnixStream, uploads: &[PreparedUpload]) -> Result<()> {
|
||||
for upload in uploads {
|
||||
let mut source = &upload.file;
|
||||
copy_exact(&mut source, stream, upload.size).with_context(|| {
|
||||
format!(
|
||||
"failed to stream transport file `{}` to bro-server",
|
||||
upload.path.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn send_passed_descriptors(
|
||||
stream: &UnixStream,
|
||||
uploads: &[PreparedUpload],
|
||||
stdin_fd: Option<RawFd>,
|
||||
) -> Result<()> {
|
||||
if uploads.is_empty() && stdin_fd.is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let marker_bytes = passed_fd_marker_bytes(uploads.len(), stdin_fd.is_some());
|
||||
|
||||
let mut fds = uploads
|
||||
.iter()
|
||||
.map(|upload| upload.file.as_raw_fd())
|
||||
.collect::<Vec<_>>();
|
||||
if let Some(stdin_fd) = stdin_fd {
|
||||
fds.push(stdin_fd);
|
||||
}
|
||||
|
||||
let sent = stream
|
||||
.send_with_fd(&marker_bytes, &fds)
|
||||
.context("failed to send passed file descriptors to bro-server")?;
|
||||
if sent != marker_bytes.len() {
|
||||
bail!(
|
||||
"sent {sent} fd marker bytes but expected to send {}",
|
||||
marker_bytes.len()
|
||||
)
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare_stdin(stdin_transport: UploadTransport) -> Result<PreparedStdin> {
|
||||
match stdin_transport {
|
||||
UploadTransport::StreamedBytes => {
|
||||
let (spool, size) = spool_stdin()?;
|
||||
Ok(PreparedStdin::Streamed { spool, size })
|
||||
}
|
||||
UploadTransport::PassedFileDescriptors => prepare_passed_stdin(),
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_passed_stdin() -> Result<PreparedStdin> {
|
||||
let stdin = io::stdin();
|
||||
|
||||
if stdin.is_terminal() && !env_var_is_truthy("BRO_CAPTURE_TTY_STDIN") {
|
||||
log::info!(
|
||||
"stdin is a terminal; using /dev/null as fd-passed stdin to avoid blocking. Set BRO_CAPTURE_TTY_STDIN=1 to pass terminal stdin through to bro-server"
|
||||
);
|
||||
let devnull = File::open("/dev/null").context("failed to open `/dev/null` for stdin")?;
|
||||
return Ok(PreparedStdin::Passed {
|
||||
fd: devnull.as_raw_fd(),
|
||||
_owner: Some(devnull),
|
||||
});
|
||||
}
|
||||
|
||||
if stdin.is_terminal() {
|
||||
log::warn!(
|
||||
"stdin is a terminal and BRO_CAPTURE_TTY_STDIN is enabled; passing terminal stdin directly to bro-server"
|
||||
);
|
||||
} else {
|
||||
log::info!("using fd passing for stdin; not spooling local stdin before execution");
|
||||
}
|
||||
|
||||
Ok(PreparedStdin::Passed {
|
||||
fd: stdin.as_raw_fd(),
|
||||
_owner: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn spool_stdin() -> Result<(NamedTempFile, u64)> {
|
||||
let mut spool = NamedTempFile::new().context("failed to create stdin spool file")?;
|
||||
let stdin = io::stdin();
|
||||
|
||||
if stdin.is_terminal() && !env_var_is_truthy("BRO_CAPTURE_TTY_STDIN") {
|
||||
log::info!(
|
||||
"stdin is a terminal; using empty stdin to avoid blocking. Set BRO_CAPTURE_TTY_STDIN=1 to read terminal stdin until EOF"
|
||||
);
|
||||
spool
|
||||
.rewind()
|
||||
.context("failed to rewind empty stdin spool")?;
|
||||
return Ok((spool, 0));
|
||||
}
|
||||
|
||||
if stdin.is_terminal() {
|
||||
log::warn!(
|
||||
"stdin is a terminal and BRO_CAPTURE_TTY_STDIN is enabled; reading stdin until EOF"
|
||||
);
|
||||
}
|
||||
|
||||
let stdin_size = io::copy(&mut stdin.lock(), &mut spool)
|
||||
.context("failed to spool stdin before transmission")?;
|
||||
spool
|
||||
.rewind()
|
||||
.context("failed to rewind stdin spool after capture")?;
|
||||
Ok((spool, stdin_size))
|
||||
}
|
||||
|
||||
fn receive_response(mut stream: UnixStream, response_transport: ResponseTransport) -> Result<i32> {
|
||||
let stdout_handle = io::stdout();
|
||||
let stdout_is_terminal = stdout_handle.is_terminal();
|
||||
let mut stdout = stdout_handle.lock();
|
||||
let stderr_handle = io::stderr();
|
||||
let stderr_is_terminal = stderr_handle.is_terminal();
|
||||
let mut stderr = stderr_handle.lock();
|
||||
|
||||
loop {
|
||||
let frame = read_response_frame(&mut stream, response_transport)?;
|
||||
match frame {
|
||||
ResponseFrame::Stdout(bytes) => {
|
||||
stdout
|
||||
.write_all(&bytes)
|
||||
.context("failed to write remote stdout locally")?;
|
||||
if stdout_is_terminal {
|
||||
stdout.flush().context("failed to flush local stdout")?;
|
||||
}
|
||||
}
|
||||
ResponseFrame::Stderr(bytes) => {
|
||||
stderr
|
||||
.write_all(&bytes)
|
||||
.context("failed to write remote stderr locally")?;
|
||||
if stderr_is_terminal {
|
||||
stderr.flush().context("failed to flush local stderr")?;
|
||||
}
|
||||
}
|
||||
ResponseFrame::Exit(status) => {
|
||||
let exit_code = status.to_exit_code();
|
||||
log::info!(
|
||||
"received remote exit status {:?} -> code {}",
|
||||
status,
|
||||
exit_code
|
||||
);
|
||||
return Ok(exit_code);
|
||||
}
|
||||
ResponseFrame::Error(message) => {
|
||||
log::warn!("received error response from bro-server: {message}");
|
||||
writeln!(stderr, "bro-server: {message}")
|
||||
.context("failed to print bro-server error")?;
|
||||
stderr.flush().context("failed to flush local stderr")?;
|
||||
return Ok(WRAPPER_ERROR_EXIT_CODE);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,71 @@
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use std::{
|
||||
env,
|
||||
io::{self, Read, Write},
|
||||
};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
pub fn init_tracing(binary_name: &str) -> Result<()> {
|
||||
let binary_target = binary_name.replace('-', "_");
|
||||
let default_filter = format!("info,{binary_target}=debug");
|
||||
|
||||
let Ok(requested_filter) = env::var("RUST_LOG") else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let requested_filter = match requested_filter.trim() {
|
||||
"" | "1" | "true" | "yes" | "on" => default_filter,
|
||||
value => value.to_owned(),
|
||||
};
|
||||
|
||||
let env_filter = EnvFilter::try_new(&requested_filter)
|
||||
.map_err(|error| anyhow!("invalid log filter `{requested_filter}`: {error}"))?;
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(env_filter)
|
||||
.with_target(true)
|
||||
.with_thread_ids(true)
|
||||
.with_thread_names(true)
|
||||
.compact()
|
||||
.try_init()
|
||||
.map_err(|error| anyhow!("failed to initialize tracing subscriber: {error}"))?;
|
||||
|
||||
log::debug!("initialized tracing subscriber for {binary_name}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn copy_exact<R: Read, W: Write>(reader: &mut R, writer: &mut W, size: u64) -> Result<()> {
|
||||
let mut limited_reader = reader.take(size);
|
||||
let copied = io::copy(&mut limited_reader, writer).context("failed to copy payload bytes")?;
|
||||
if copied != size {
|
||||
bail!("expected to copy {size} bytes, but copied {copied}")
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn env_var_is_set(key: &str) -> bool {
|
||||
env::var_os(key).is_some_and(|value| !value.is_empty())
|
||||
}
|
||||
|
||||
pub fn env_var_is_truthy(key: &str) -> bool {
|
||||
env::var_os(key).is_some_and(|value| {
|
||||
matches!(
|
||||
value.to_string_lossy().trim().to_ascii_lowercase().as_str(),
|
||||
"1" | "true" | "yes" | "on"
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn parse_csv_env(key: &str) -> Vec<String> {
|
||||
env::var(key).map_or_else(|_| Vec::new(), |value| parse_csv_list(&value))
|
||||
}
|
||||
|
||||
pub fn parse_csv_list(value: &str) -> Vec<String> {
|
||||
value
|
||||
.split(',')
|
||||
.map(str::trim)
|
||||
.filter(|entry| !entry.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.collect()
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
pub mod common;
|
||||
pub mod protocol;
|
||||
+828
@@ -0,0 +1,828 @@
|
||||
use anyhow::{Context, Result, bail};
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||
use std::{
|
||||
collections::{BTreeMap, HashSet},
|
||||
io::{Read, Write},
|
||||
os::unix::process::ExitStatusExt,
|
||||
process::ExitStatus,
|
||||
};
|
||||
|
||||
pub const EXECUTE_MAGIC: [u8; 4] = *b"BRO1";
|
||||
pub const SERVER_FEATURES_MAGIC: [u8; 4] = *b"BROC";
|
||||
pub const FD_PASS_MARKER: u8 = b'F';
|
||||
pub const STDIN_FD_PASS_MARKER: u8 = b'S';
|
||||
pub const MAX_CONTROL_MESSAGE_BYTES: u64 = 16 * 1024 * 1024;
|
||||
pub const MAX_RESPONSE_FRAME_BYTES: u64 = 16 * 1024 * 1024;
|
||||
pub const WRAPPER_ERROR_EXIT_CODE: i32 = 125;
|
||||
|
||||
const RESPONSE_FRAME_STDOUT: u8 = b'O';
|
||||
const RESPONSE_FRAME_STDERR: u8 = b'E';
|
||||
const RESPONSE_FRAME_EXIT: u8 = b'X';
|
||||
const RESPONSE_FRAME_ERROR: u8 = b'!';
|
||||
const BINARY_EXIT_STATUS_PAYLOAD_LEN: usize = 10;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ConnectionKind {
|
||||
Execute,
|
||||
ServerFeaturesQuery,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum UploadTransport {
|
||||
#[default]
|
||||
StreamedBytes,
|
||||
PassedFileDescriptors,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseTransport {
|
||||
#[default]
|
||||
JsonMessages,
|
||||
BinaryFrames,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RequestHeader {
|
||||
pub args: Vec<String>,
|
||||
pub env: BTreeMap<String, String>,
|
||||
pub uploads: Vec<UploadSpec>,
|
||||
pub rewrites: Vec<ArgRewrite>,
|
||||
#[serde(default)]
|
||||
pub stdin_size: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub upload_transport: UploadTransport,
|
||||
#[serde(default)]
|
||||
pub stdin_transport: UploadTransport,
|
||||
#[serde(default)]
|
||||
pub response_transport: ResponseTransport,
|
||||
}
|
||||
|
||||
impl RequestHeader {
|
||||
pub fn validate_invariants(&self) -> Result<()> {
|
||||
if self.stdin_transport == UploadTransport::StreamedBytes && self.stdin_size.is_none() {
|
||||
bail!("request used streamed stdin but did not provide stdin_size")
|
||||
}
|
||||
|
||||
let mut upload_ids = HashSet::with_capacity(self.uploads.len());
|
||||
for upload in &self.uploads {
|
||||
if !upload_ids.insert(upload.id) {
|
||||
bail!("request referenced upload id {} more than once", upload.id)
|
||||
}
|
||||
}
|
||||
|
||||
for rewrite in &self.rewrites {
|
||||
match rewrite {
|
||||
ArgRewrite::Replace {
|
||||
arg_index,
|
||||
upload_id,
|
||||
}
|
||||
| ArgRewrite::Prefixed {
|
||||
arg_index,
|
||||
upload_id,
|
||||
..
|
||||
} => {
|
||||
if *arg_index >= self.args.len() {
|
||||
bail!(
|
||||
"request tried to rewrite argument index {} but only {} arguments were provided",
|
||||
arg_index,
|
||||
self.args.len()
|
||||
)
|
||||
}
|
||||
if !upload_ids.contains(upload_id) {
|
||||
bail!(
|
||||
"request tried to rewrite argument {} using unknown upload id {}",
|
||||
arg_index,
|
||||
upload_id
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn rewrite_args<F>(&self, mut resolve_upload: F) -> Result<Vec<String>>
|
||||
where
|
||||
F: FnMut(u32) -> Result<String>,
|
||||
{
|
||||
let mut rewritten = self.args.clone();
|
||||
|
||||
for rewrite in &self.rewrites {
|
||||
match rewrite {
|
||||
ArgRewrite::Replace {
|
||||
arg_index,
|
||||
upload_id,
|
||||
} => {
|
||||
rewritten[*arg_index] = resolve_upload(*upload_id)?;
|
||||
}
|
||||
ArgRewrite::Prefixed {
|
||||
arg_index,
|
||||
prefix,
|
||||
upload_id,
|
||||
} => {
|
||||
rewritten[*arg_index] = format!("{prefix}{}", resolve_upload(*upload_id)?);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(rewritten)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct UploadSpec {
|
||||
pub id: u32,
|
||||
pub original_path: String,
|
||||
pub size: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum ArgRewrite {
|
||||
Replace {
|
||||
arg_index: usize,
|
||||
upload_id: u32,
|
||||
},
|
||||
Prefixed {
|
||||
arg_index: usize,
|
||||
prefix: String,
|
||||
upload_id: u32,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum ResponseFrame {
|
||||
Stdout(Vec<u8>),
|
||||
Stderr(Vec<u8>),
|
||||
Exit(RemoteExitStatus),
|
||||
Error(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct RemoteExitStatus {
|
||||
pub code: Option<i32>,
|
||||
pub signal: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct RequestPlanningOptions<'a> {
|
||||
pub file_flags: &'a [String],
|
||||
pub file_args_enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
pub struct RequestPlanningStats {
|
||||
pub file_flag_rewrite_count: usize,
|
||||
pub auto_file_arg_probe_count: usize,
|
||||
pub auto_file_arg_rewrite_count: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct PlannedArgRewrites {
|
||||
pub rewrites: Vec<ArgRewrite>,
|
||||
pub stats: RequestPlanningStats,
|
||||
}
|
||||
|
||||
pub trait UploadRegistrar {
|
||||
fn ensure_upload(&mut self, path_text: &str) -> Result<u32>;
|
||||
fn maybe_ensure_upload(&mut self, path_text: &str) -> Result<Option<u32>>;
|
||||
}
|
||||
|
||||
pub fn plan_request_rewrites<U>(
|
||||
args: &[String],
|
||||
options: &RequestPlanningOptions<'_>,
|
||||
uploads: &mut U,
|
||||
) -> Result<PlannedArgRewrites>
|
||||
where
|
||||
U: UploadRegistrar,
|
||||
{
|
||||
let mut rewrites = Vec::new();
|
||||
let mut stats = RequestPlanningStats::default();
|
||||
let mut skip_index = None;
|
||||
|
||||
for (index, argument) in args.iter().enumerate() {
|
||||
if skip_index == Some(index) {
|
||||
skip_index = None;
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(matched_flag) = match_file_flag(args, index, options.file_flags)? {
|
||||
match matched_flag {
|
||||
MatchedFileFlag::Separate {
|
||||
flag,
|
||||
value_index,
|
||||
value,
|
||||
} => {
|
||||
let upload_id = uploads.ensure_upload(value)?;
|
||||
rewrites.push(ArgRewrite::Replace {
|
||||
arg_index: value_index,
|
||||
upload_id,
|
||||
});
|
||||
log::debug!(
|
||||
"BRO_FILE_FLAGS matched separate flag `{flag}` at arg index {index}; forwarding argument index {value_index} -> `{value}`"
|
||||
);
|
||||
stats.file_flag_rewrite_count += 1;
|
||||
skip_index = Some(value_index);
|
||||
}
|
||||
MatchedFileFlag::Joined {
|
||||
flag,
|
||||
prefix,
|
||||
value,
|
||||
} => {
|
||||
let upload_id = uploads.ensure_upload(value)?;
|
||||
rewrites.push(ArgRewrite::Prefixed {
|
||||
arg_index: index,
|
||||
prefix,
|
||||
upload_id,
|
||||
});
|
||||
log::debug!(
|
||||
"BRO_FILE_FLAGS matched joined flag `{flag}` at arg index {index}; forwarding value `{value}`"
|
||||
);
|
||||
stats.file_flag_rewrite_count += 1;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if options.file_args_enabled && is_nonflag_argument(argument) {
|
||||
stats.auto_file_arg_probe_count += 1;
|
||||
if let Some(upload_id) = uploads.maybe_ensure_upload(argument)? {
|
||||
rewrites.push(ArgRewrite::Replace {
|
||||
arg_index: index,
|
||||
upload_id,
|
||||
});
|
||||
stats.auto_file_arg_rewrite_count += 1;
|
||||
log::debug!(
|
||||
"BRO_FILE_ARGS auto-forwarded non-flag argument index {index}: `{argument}`"
|
||||
);
|
||||
} else {
|
||||
log::debug!(
|
||||
"BRO_FILE_ARGS left non-flag argument index {index} unchanged because `{argument}` could not be opened as a readable regular file"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(PlannedArgRewrites { rewrites, stats })
|
||||
}
|
||||
|
||||
enum MatchedFileFlag<'a> {
|
||||
Separate {
|
||||
flag: &'a str,
|
||||
value_index: usize,
|
||||
value: &'a str,
|
||||
},
|
||||
Joined {
|
||||
flag: &'a str,
|
||||
prefix: String,
|
||||
value: &'a str,
|
||||
},
|
||||
}
|
||||
|
||||
fn match_file_flag<'a>(
|
||||
args: &'a [String],
|
||||
index: usize,
|
||||
file_flags: &'a [String],
|
||||
) -> Result<Option<MatchedFileFlag<'a>>> {
|
||||
let argument = &args[index];
|
||||
|
||||
for flag in file_flags {
|
||||
if argument == flag {
|
||||
let value_index = index + 1;
|
||||
let value = args.get(value_index).with_context(|| {
|
||||
format!("flag `{flag}` is configured as a file flag but has no value")
|
||||
})?;
|
||||
return Ok(Some(MatchedFileFlag::Separate {
|
||||
flag: flag.as_str(),
|
||||
value_index,
|
||||
value,
|
||||
}));
|
||||
}
|
||||
|
||||
let prefix = format!("{flag}=");
|
||||
if let Some(value) = argument.strip_prefix(&prefix) {
|
||||
return Ok(Some(MatchedFileFlag::Joined {
|
||||
flag: flag.as_str(),
|
||||
prefix,
|
||||
value,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn is_nonflag_argument(argument: &str) -> bool {
|
||||
!argument.starts_with('-') || argument == "-"
|
||||
}
|
||||
|
||||
fn default_upload_transports() -> Vec<UploadTransport> {
|
||||
vec![UploadTransport::StreamedBytes]
|
||||
}
|
||||
|
||||
fn default_response_transports() -> Vec<ResponseTransport> {
|
||||
vec![ResponseTransport::JsonMessages]
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ServerFeatures {
|
||||
#[serde(default = "default_upload_transports")]
|
||||
pub upload_transports: Vec<UploadTransport>,
|
||||
#[serde(default = "default_upload_transports")]
|
||||
pub stdin_transports: Vec<UploadTransport>,
|
||||
#[serde(default = "default_response_transports")]
|
||||
pub response_transports: Vec<ResponseTransport>,
|
||||
}
|
||||
|
||||
impl ServerFeatures {
|
||||
pub fn supports_transport(&self, transport: UploadTransport) -> bool {
|
||||
self.supports_upload_transport(transport)
|
||||
}
|
||||
|
||||
pub fn supports_upload_transport(&self, transport: UploadTransport) -> bool {
|
||||
self.upload_transports.contains(&transport)
|
||||
}
|
||||
|
||||
pub fn supports_stdin_transport(&self, transport: UploadTransport) -> bool {
|
||||
self.stdin_transports.contains(&transport)
|
||||
}
|
||||
|
||||
pub fn supports_response_transport(&self, transport: ResponseTransport) -> bool {
|
||||
self.response_transports.contains(&transport)
|
||||
}
|
||||
}
|
||||
|
||||
impl RemoteExitStatus {
|
||||
pub fn to_exit_code(&self) -> i32 {
|
||||
if let Some(code) = self.code {
|
||||
code
|
||||
} else if let Some(signal) = self.signal {
|
||||
128 + signal
|
||||
} else {
|
||||
WRAPPER_ERROR_EXIT_CODE
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ExitStatus> for RemoteExitStatus {
|
||||
fn from(status: ExitStatus) -> Self {
|
||||
Self {
|
||||
code: status.code(),
|
||||
signal: status.signal(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_execute_magic<W: Write>(writer: &mut W) -> Result<()> {
|
||||
write_magic(writer, &EXECUTE_MAGIC)
|
||||
}
|
||||
|
||||
pub fn write_server_features_magic<W: Write>(writer: &mut W) -> Result<()> {
|
||||
write_magic(writer, &SERVER_FEATURES_MAGIC)
|
||||
}
|
||||
|
||||
fn write_magic<W: Write>(writer: &mut W, magic: &[u8; 4]) -> Result<()> {
|
||||
writer
|
||||
.write_all(magic)
|
||||
.context("failed to write protocol magic")
|
||||
}
|
||||
|
||||
pub fn read_connection_kind<R: Read>(reader: &mut R) -> Result<ConnectionKind> {
|
||||
let mut magic = [0_u8; 4];
|
||||
reader
|
||||
.read_exact(&mut magic)
|
||||
.context("failed to read protocol magic")?;
|
||||
|
||||
match magic {
|
||||
EXECUTE_MAGIC => Ok(ConnectionKind::Execute),
|
||||
SERVER_FEATURES_MAGIC => Ok(ConnectionKind::ServerFeaturesQuery),
|
||||
_ => bail!(
|
||||
"received an unsupported protocol header: expected `{EXECUTE_MAGIC:?}` or `{SERVER_FEATURES_MAGIC:?}`, got `{magic:?}`"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_execute_request_header<W: Write>(
|
||||
writer: &mut W,
|
||||
header: &RequestHeader,
|
||||
) -> Result<()> {
|
||||
write_message(writer, header)
|
||||
}
|
||||
|
||||
pub fn read_execute_request_header<R: Read>(reader: &mut R) -> Result<RequestHeader> {
|
||||
read_message(reader)
|
||||
}
|
||||
|
||||
pub fn write_server_features_response<W: Write>(
|
||||
writer: &mut W,
|
||||
server_features: &ServerFeatures,
|
||||
) -> Result<()> {
|
||||
write_message(writer, server_features)
|
||||
}
|
||||
|
||||
pub fn read_server_features_response<R: Read>(reader: &mut R) -> Result<ServerFeatures> {
|
||||
read_message(reader)
|
||||
}
|
||||
|
||||
pub fn passed_fd_marker_bytes(upload_count: usize, include_stdin: bool) -> Vec<u8> {
|
||||
let mut marker_bytes = std::iter::repeat_n(FD_PASS_MARKER, upload_count).collect::<Vec<_>>();
|
||||
if include_stdin {
|
||||
marker_bytes.push(STDIN_FD_PASS_MARKER);
|
||||
}
|
||||
marker_bytes
|
||||
}
|
||||
|
||||
pub fn expected_passed_fd_count(
|
||||
upload_transport: UploadTransport,
|
||||
upload_count: usize,
|
||||
stdin_transport: UploadTransport,
|
||||
) -> usize {
|
||||
let upload_fd_count = match upload_transport {
|
||||
UploadTransport::StreamedBytes => 0,
|
||||
UploadTransport::PassedFileDescriptors => upload_count,
|
||||
};
|
||||
|
||||
upload_fd_count + usize::from(stdin_transport == UploadTransport::PassedFileDescriptors)
|
||||
}
|
||||
|
||||
pub fn validate_passed_fd_marker_bytes(
|
||||
marker_bytes: &[u8],
|
||||
upload_count: usize,
|
||||
expects_stdin: bool,
|
||||
) -> Result<()> {
|
||||
let expected_count = upload_count + usize::from(expects_stdin);
|
||||
if marker_bytes.len() != expected_count {
|
||||
bail!(
|
||||
"received {} fd marker bytes but expected {}",
|
||||
marker_bytes.len(),
|
||||
expected_count
|
||||
)
|
||||
}
|
||||
|
||||
if marker_bytes[..upload_count]
|
||||
.iter()
|
||||
.any(|byte| *byte != FD_PASS_MARKER)
|
||||
{
|
||||
bail!("received invalid upload file descriptor marker bytes from bro-client")
|
||||
}
|
||||
if expects_stdin && marker_bytes[upload_count] != STDIN_FD_PASS_MARKER {
|
||||
bail!("received invalid stdin file descriptor marker byte from bro-client")
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write_message<W: Write, T: Serialize>(writer: &mut W, value: &T) -> Result<()> {
|
||||
let payload =
|
||||
serde_json::to_vec(value).context("failed to serialize protocol message as JSON")?;
|
||||
let payload_len = u64::try_from(payload.len()).context("protocol message is too large")?;
|
||||
|
||||
writer
|
||||
.write_all(&payload_len.to_le_bytes())
|
||||
.context("failed to write protocol message length")?;
|
||||
writer
|
||||
.write_all(&payload)
|
||||
.context("failed to write protocol message payload")?;
|
||||
writer.flush().context("failed to flush protocol message")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn read_message<R: Read, T: DeserializeOwned>(reader: &mut R) -> Result<T> {
|
||||
let payload =
|
||||
read_length_prefixed_payload(reader, MAX_CONTROL_MESSAGE_BYTES, "protocol message")?;
|
||||
|
||||
serde_json::from_slice(&payload).context("failed to deserialize protocol message from JSON")
|
||||
}
|
||||
|
||||
pub fn write_response_frame<W: Write>(
|
||||
writer: &mut W,
|
||||
transport: ResponseTransport,
|
||||
frame: &ResponseFrame,
|
||||
) -> Result<()> {
|
||||
match transport {
|
||||
ResponseTransport::JsonMessages => write_message(writer, frame),
|
||||
ResponseTransport::BinaryFrames => write_binary_response_frame(writer, frame),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_response_frame<R: Read>(
|
||||
reader: &mut R,
|
||||
transport: ResponseTransport,
|
||||
) -> Result<ResponseFrame> {
|
||||
match transport {
|
||||
ResponseTransport::JsonMessages => read_message(reader),
|
||||
ResponseTransport::BinaryFrames => read_binary_response_frame(reader),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_binary_response_frame<W: Write>(writer: &mut W, frame: &ResponseFrame) -> Result<()> {
|
||||
match frame {
|
||||
ResponseFrame::Stdout(bytes) => {
|
||||
write_binary_response_payload(writer, RESPONSE_FRAME_STDOUT, bytes)
|
||||
}
|
||||
ResponseFrame::Stderr(bytes) => {
|
||||
write_binary_response_payload(writer, RESPONSE_FRAME_STDERR, bytes)
|
||||
}
|
||||
ResponseFrame::Exit(status) => {
|
||||
let mut payload = Vec::with_capacity(BINARY_EXIT_STATUS_PAYLOAD_LEN);
|
||||
payload.push(u8::from(status.code.is_some()));
|
||||
payload.extend_from_slice(&status.code.unwrap_or_default().to_le_bytes());
|
||||
payload.push(u8::from(status.signal.is_some()));
|
||||
payload.extend_from_slice(&status.signal.unwrap_or_default().to_le_bytes());
|
||||
write_binary_response_payload(writer, RESPONSE_FRAME_EXIT, &payload)
|
||||
}
|
||||
ResponseFrame::Error(message) => {
|
||||
write_binary_response_payload(writer, RESPONSE_FRAME_ERROR, message.as_bytes())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn write_binary_response_payload<W: Write>(writer: &mut W, tag: u8, payload: &[u8]) -> Result<()> {
|
||||
let payload_len =
|
||||
u64::try_from(payload.len()).context("response frame payload is too large")?;
|
||||
writer
|
||||
.write_all(&[tag])
|
||||
.context("failed to write response frame type")?;
|
||||
writer
|
||||
.write_all(&payload_len.to_le_bytes())
|
||||
.context("failed to write response frame length")?;
|
||||
writer
|
||||
.write_all(payload)
|
||||
.context("failed to write response frame payload")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_binary_response_frame<R: Read>(reader: &mut R) -> Result<ResponseFrame> {
|
||||
let mut tag = [0_u8; 1];
|
||||
reader
|
||||
.read_exact(&mut tag)
|
||||
.context("failed to read response frame type")?;
|
||||
let payload = read_length_prefixed_payload(reader, MAX_RESPONSE_FRAME_BYTES, "response frame")?;
|
||||
|
||||
match tag[0] {
|
||||
RESPONSE_FRAME_STDOUT => Ok(ResponseFrame::Stdout(payload)),
|
||||
RESPONSE_FRAME_STDERR => Ok(ResponseFrame::Stderr(payload)),
|
||||
RESPONSE_FRAME_ERROR => Ok(ResponseFrame::Error(
|
||||
String::from_utf8(payload).context("response error frame was not valid UTF-8")?,
|
||||
)),
|
||||
RESPONSE_FRAME_EXIT => read_binary_exit_status(&payload),
|
||||
_ => bail!("received unknown response frame type `{}`", tag[0]),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_binary_exit_status(payload: &[u8]) -> Result<ResponseFrame> {
|
||||
if payload.len() != BINARY_EXIT_STATUS_PAYLOAD_LEN {
|
||||
bail!(
|
||||
"binary exit status frame payload had length {} but expected {}",
|
||||
payload.len(),
|
||||
BINARY_EXIT_STATUS_PAYLOAD_LEN
|
||||
)
|
||||
}
|
||||
|
||||
let code = if payload[0] == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(i32::from_le_bytes([
|
||||
payload[1], payload[2], payload[3], payload[4],
|
||||
]))
|
||||
};
|
||||
let signal = if payload[5] == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(i32::from_le_bytes([
|
||||
payload[6], payload[7], payload[8], payload[9],
|
||||
]))
|
||||
};
|
||||
|
||||
Ok(ResponseFrame::Exit(RemoteExitStatus { code, signal }))
|
||||
}
|
||||
|
||||
fn read_length_prefixed_payload<R: Read>(
|
||||
reader: &mut R,
|
||||
max_payload_len: u64,
|
||||
payload_kind: &str,
|
||||
) -> Result<Vec<u8>> {
|
||||
let mut length_bytes = [0_u8; 8];
|
||||
reader
|
||||
.read_exact(&mut length_bytes)
|
||||
.with_context(|| format!("failed to read {payload_kind} length"))?;
|
||||
|
||||
let payload_len = u64::from_le_bytes(length_bytes);
|
||||
if payload_len > max_payload_len {
|
||||
bail!(
|
||||
"{payload_kind} length {payload_len} exceeds the maximum supported size of {max_payload_len} bytes"
|
||||
)
|
||||
}
|
||||
|
||||
let payload_len = usize::try_from(payload_len)
|
||||
.with_context(|| format!("{payload_kind} does not fit in memory"))?;
|
||||
let mut payload = vec![0_u8; payload_len];
|
||||
reader
|
||||
.read_exact(&mut payload)
|
||||
.with_context(|| format!("failed to read {payload_kind} payload"))?;
|
||||
Ok(payload)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
use std::io::Cursor;
|
||||
|
||||
#[derive(Default)]
|
||||
struct MockUploadRegistrar {
|
||||
available_paths: HashSet<String>,
|
||||
upload_ids_by_path: HashMap<String, u32>,
|
||||
next_upload_id: u32,
|
||||
}
|
||||
|
||||
impl MockUploadRegistrar {
|
||||
fn with_available_paths(paths: impl IntoIterator<Item = &'static str>) -> Self {
|
||||
Self {
|
||||
available_paths: paths.into_iter().map(ToOwned::to_owned).collect(),
|
||||
..Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn register(&mut self, path_text: &str) -> u32 {
|
||||
if let Some(&upload_id) = self.upload_ids_by_path.get(path_text) {
|
||||
return upload_id;
|
||||
}
|
||||
|
||||
let upload_id = self.next_upload_id;
|
||||
self.next_upload_id += 1;
|
||||
self.upload_ids_by_path
|
||||
.insert(path_text.to_owned(), upload_id);
|
||||
upload_id
|
||||
}
|
||||
}
|
||||
|
||||
impl UploadRegistrar for MockUploadRegistrar {
|
||||
fn ensure_upload(&mut self, path_text: &str) -> Result<u32> {
|
||||
if !self.available_paths.contains(path_text) {
|
||||
bail!("mock upload `{path_text}` is unavailable")
|
||||
}
|
||||
Ok(self.register(path_text))
|
||||
}
|
||||
|
||||
fn maybe_ensure_upload(&mut self, path_text: &str) -> Result<Option<u32>> {
|
||||
Ok(self
|
||||
.available_paths
|
||||
.contains(path_text)
|
||||
.then(|| self.register(path_text)))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn plan_request_rewrites_handles_file_flags_and_file_args() -> Result<()> {
|
||||
let args = vec![
|
||||
"--config".to_owned(),
|
||||
"config.toml".to_owned(),
|
||||
"--input=data.txt".to_owned(),
|
||||
"note.txt".to_owned(),
|
||||
"-n".to_owned(),
|
||||
];
|
||||
let file_flags = vec!["--config".to_owned(), "--input".to_owned()];
|
||||
let options = RequestPlanningOptions {
|
||||
file_flags: &file_flags,
|
||||
file_args_enabled: true,
|
||||
};
|
||||
let mut uploads =
|
||||
MockUploadRegistrar::with_available_paths(["config.toml", "data.txt", "note.txt"]);
|
||||
|
||||
let planned = plan_request_rewrites(&args, &options, &mut uploads)?;
|
||||
|
||||
assert_eq!(
|
||||
planned.rewrites,
|
||||
vec![
|
||||
ArgRewrite::Replace {
|
||||
arg_index: 1,
|
||||
upload_id: 0,
|
||||
},
|
||||
ArgRewrite::Prefixed {
|
||||
arg_index: 2,
|
||||
prefix: "--input=".to_owned(),
|
||||
upload_id: 1,
|
||||
},
|
||||
ArgRewrite::Replace {
|
||||
arg_index: 3,
|
||||
upload_id: 2,
|
||||
},
|
||||
]
|
||||
);
|
||||
assert_eq!(
|
||||
planned.stats,
|
||||
RequestPlanningStats {
|
||||
file_flag_rewrite_count: 2,
|
||||
auto_file_arg_probe_count: 1,
|
||||
auto_file_arg_rewrite_count: 1,
|
||||
}
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn plan_request_rewrites_treats_double_dash_like_any_other_argument() -> Result<()> {
|
||||
let args = vec![
|
||||
"--".to_owned(),
|
||||
"--config".to_owned(),
|
||||
"plain.txt".to_owned(),
|
||||
];
|
||||
let file_flags = vec!["--config".to_owned()];
|
||||
let options = RequestPlanningOptions {
|
||||
file_flags: &file_flags,
|
||||
file_args_enabled: true,
|
||||
};
|
||||
let mut uploads = MockUploadRegistrar::with_available_paths(["--config", "plain.txt"]);
|
||||
|
||||
let planned = plan_request_rewrites(&args, &options, &mut uploads)?;
|
||||
|
||||
assert_eq!(
|
||||
planned.rewrites,
|
||||
vec![ArgRewrite::Replace {
|
||||
arg_index: 2,
|
||||
upload_id: 0,
|
||||
},]
|
||||
);
|
||||
assert_eq!(planned.stats.file_flag_rewrite_count, 1);
|
||||
assert_eq!(planned.stats.auto_file_arg_probe_count, 0);
|
||||
assert_eq!(planned.stats.auto_file_arg_rewrite_count, 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_header_validates_upload_ids_and_rewrite_references() {
|
||||
let header = RequestHeader {
|
||||
args: vec!["a".to_owned()],
|
||||
env: BTreeMap::new(),
|
||||
uploads: vec![
|
||||
UploadSpec {
|
||||
id: 0,
|
||||
original_path: "x".to_owned(),
|
||||
size: 1,
|
||||
},
|
||||
UploadSpec {
|
||||
id: 0,
|
||||
original_path: "y".to_owned(),
|
||||
size: 1,
|
||||
},
|
||||
],
|
||||
rewrites: vec![ArgRewrite::Replace {
|
||||
arg_index: 0,
|
||||
upload_id: 1,
|
||||
}],
|
||||
stdin_size: Some(0),
|
||||
upload_transport: UploadTransport::StreamedBytes,
|
||||
stdin_transport: UploadTransport::StreamedBytes,
|
||||
response_transport: ResponseTransport::JsonMessages,
|
||||
};
|
||||
|
||||
assert!(header.validate_invariants().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_header_requires_streamed_stdin_size() {
|
||||
let header = RequestHeader {
|
||||
args: Vec::new(),
|
||||
env: BTreeMap::new(),
|
||||
uploads: Vec::new(),
|
||||
rewrites: Vec::new(),
|
||||
stdin_size: None,
|
||||
upload_transport: UploadTransport::StreamedBytes,
|
||||
stdin_transport: UploadTransport::StreamedBytes,
|
||||
response_transport: ResponseTransport::JsonMessages,
|
||||
};
|
||||
|
||||
assert!(header.validate_invariants().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_response_frames_round_trip() -> Result<()> {
|
||||
let frames = [
|
||||
ResponseFrame::Stdout(vec![1, 2, 3]),
|
||||
ResponseFrame::Stderr(vec![4, 5]),
|
||||
ResponseFrame::Exit(RemoteExitStatus {
|
||||
code: Some(42),
|
||||
signal: None,
|
||||
}),
|
||||
ResponseFrame::Error("boom".to_owned()),
|
||||
];
|
||||
|
||||
let mut buffer = Cursor::new(Vec::new());
|
||||
for frame in &frames {
|
||||
write_response_frame(&mut buffer, ResponseTransport::BinaryFrames, frame)?;
|
||||
}
|
||||
|
||||
buffer.set_position(0);
|
||||
for expected in frames {
|
||||
let actual = read_response_frame(&mut buffer, ResponseTransport::BinaryFrames)?;
|
||||
assert_eq!(actual, expected);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user