Initial commit

This commit is contained in:
2026-05-23 03:12:03 +09:00
commit 5219faa20f
14 changed files with 3707 additions and 0 deletions
+574
View File
@@ -0,0 +1,574 @@
use anyhow::{Context, Result, bail};
use bro::{
common::{copy_exact, env_var_is_set, env_var_is_truthy, init_tracing, parse_csv_env},
protocol::*,
};
use sendfd::SendWithFd;
use std::{
collections::{BTreeMap, HashMap},
env,
fs::File,
io::{self, IsTerminal, Seek, Write},
os::{
fd::{AsRawFd, RawFd},
unix::net::UnixStream,
},
path::{Path, PathBuf},
};
use tempfile::NamedTempFile;
fn main() {
if let Err(error) = init_tracing("bro-client") {
eprintln!("bro-client: failed to initialize tracing: {error:#}");
std::process::exit(WRAPPER_ERROR_EXIT_CODE);
}
let exit_code = match run() {
Ok(code) => code,
Err(error) => {
log::error!("bro-client failed: {error:#}");
WRAPPER_ERROR_EXIT_CODE
}
};
log::info!("bro-client exiting with code {exit_code}");
std::process::exit(exit_code);
}
fn run() -> Result<i32> {
let options = ClientOptions::from_env()?;
log::info!(
"starting bro-client request: socket={} args={} forward_env={} file_flags={} bro_file_args_enabled={}",
options.socket_path.display(),
options.target_args.len(),
options.forward_env.len(),
options.file_flags.len(),
options.file_args_enabled
);
if options.file_args_enabled {
log::info!(
"BRO_FILE_ARGS is enabled; bro-client will probe non-flag arguments and auto-forward readable regular files"
);
} else {
log::debug!(
"BRO_FILE_ARGS is disabled; bro-client will only forward paths matched via BRO_FILE_FLAGS"
);
}
if !options.file_flags.is_empty() {
log::debug!("BRO_FILE_FLAGS configured as: {:?}", options.file_flags);
}
let transports = choose_request_transports(&options.socket_path);
log::info!(
"selected request transports: uploads={:?} stdin={:?} responses={:?}",
transports.upload_transport,
transports.stdin_transport,
transports.response_transport
);
let mut stdin = prepare_stdin(transports.stdin_transport)?;
let request = build_request_header(
&options,
stdin.stdin_size(),
transports.upload_transport,
transports.stdin_transport,
transports.response_transport,
)?;
log::info!(
"prepared request: uploads={} rewrites={} stdin_size={:?}",
request.uploads.len(),
request.header.rewrites.len(),
request.header.stdin_size,
);
log::info!(
"connecting to bro-server socket at {}",
options.socket_path.display()
);
let mut stream = UnixStream::connect(&options.socket_path).with_context(|| {
format!(
"failed to connect to broker socket `{}`",
options.socket_path.display()
)
})?;
log::info!(
"connected to bro-server socket at {}",
options.socket_path.display()
);
write_execute_magic(&mut stream)?;
write_execute_request_header(&mut stream, &request.header)?;
if request.header.upload_transport == UploadTransport::StreamedBytes {
send_uploaded_file_bytes(&mut stream, &request.uploads)?;
}
let fd_uploads: &[PreparedUpload] =
if request.header.upload_transport == UploadTransport::PassedFileDescriptors {
&request.uploads[..]
} else {
&request.uploads[..0]
};
send_passed_descriptors(&stream, fd_uploads, stdin.passed_fd())?;
stdin.send_streamed_bytes_if_needed(&mut stream)?;
stream
.shutdown(std::net::Shutdown::Write)
.context("failed to close the client write side")?;
log::info!("waiting for response frames from bro-server");
receive_response(stream, request.header.response_transport)
}
struct ClientOptions {
socket_path: PathBuf,
forward_env: Vec<String>,
file_flags: Vec<String>,
file_args_enabled: bool,
target_args: Vec<String>,
}
impl ClientOptions {
fn from_env() -> Result<Self> {
let socket_path = env::var_os("BRO_SOCKET_PATH")
.map(PathBuf::from)
.context("BRO_SOCKET_PATH must be set")?;
Ok(Self {
socket_path,
forward_env: parse_csv_env("BRO_FORWARD_ENV"),
file_flags: parse_csv_env("BRO_FILE_FLAGS"),
file_args_enabled: env_var_is_set("BRO_FILE_ARGS"),
target_args: env::args().skip(1).collect(),
})
}
}
struct PreparedRequest {
header: RequestHeader,
uploads: Vec<PreparedUpload>,
}
struct PreparedUpload {
path: PathBuf,
size: u64,
file: File,
}
struct RequestTransports {
upload_transport: UploadTransport,
stdin_transport: UploadTransport,
response_transport: ResponseTransport,
}
enum PreparedStdin {
Streamed { spool: NamedTempFile, size: u64 },
Passed { fd: RawFd, _owner: Option<File> },
}
impl PreparedStdin {
fn stdin_size(&self) -> Option<u64> {
match self {
Self::Streamed { size, .. } => Some(*size),
Self::Passed { .. } => None,
}
}
fn passed_fd(&self) -> Option<RawFd> {
match self {
Self::Streamed { .. } => None,
Self::Passed { fd, .. } => Some(*fd),
}
}
fn send_streamed_bytes_if_needed(&mut self, stream: &mut UnixStream) -> Result<()> {
match self {
Self::Streamed { spool, size } => {
spool
.rewind()
.context("failed to rewind stdin spool before transmission")?;
copy_exact(spool, stream, *size).context("failed to stream stdin to bro-server")
}
Self::Passed { .. } => Ok(()),
}
}
}
#[derive(Default)]
struct UploadRegistry {
uploads_by_path: HashMap<String, u32>,
uploads: Vec<UploadSpec>,
prepared_uploads: Vec<PreparedUpload>,
next_upload_id: u32,
}
impl UploadRegistry {
fn register(&mut self, path_text: &str, file: File, size: u64) -> Result<u32> {
let upload_id = self.next_upload_id;
self.next_upload_id = self
.next_upload_id
.checked_add(1)
.context("too many transport files were requested")?;
self.uploads_by_path.insert(path_text.to_owned(), upload_id);
self.uploads.push(UploadSpec {
id: upload_id,
original_path: path_text.to_owned(),
size,
});
self.prepared_uploads.push(PreparedUpload {
path: PathBuf::from(path_text),
size,
file,
});
Ok(upload_id)
}
fn into_parts(self) -> (Vec<UploadSpec>, Vec<PreparedUpload>) {
(self.uploads, self.prepared_uploads)
}
}
impl UploadRegistrar for UploadRegistry {
fn ensure_upload(&mut self, path_text: &str) -> Result<u32> {
if let Some(&existing_id) = self.uploads_by_path.get(path_text) {
return Ok(existing_id);
}
let (file, size) = probe_regular_upload(path_text)?
.with_context(|| format!("failed to open transport file `{path_text}`"))?;
self.register(path_text, file, size)
}
fn maybe_ensure_upload(&mut self, path_text: &str) -> Result<Option<u32>> {
if let Some(&existing_id) = self.uploads_by_path.get(path_text) {
return Ok(Some(existing_id));
}
probe_regular_upload(path_text)?
.map(|(file, size)| self.register(path_text, file, size))
.transpose()
}
}
fn probe_regular_upload(path_text: &str) -> Result<Option<(File, u64)>> {
let file = match File::open(path_text) {
Ok(file) => file,
Err(_) => return Ok(None),
};
let metadata = match file.metadata() {
Ok(metadata) => metadata,
Err(_) => return Ok(None),
};
Ok(metadata.is_file().then_some((file, metadata.len())))
}
fn choose_request_transports(socket_path: &Path) -> RequestTransports {
match query_server_features(socket_path) {
Ok(server_features) => RequestTransports {
upload_transport: if server_features
.supports_upload_transport(UploadTransport::PassedFileDescriptors)
{
UploadTransport::PassedFileDescriptors
} else {
UploadTransport::StreamedBytes
},
stdin_transport: if server_features
.supports_stdin_transport(UploadTransport::PassedFileDescriptors)
{
UploadTransport::PassedFileDescriptors
} else {
UploadTransport::StreamedBytes
},
response_transport: if server_features
.supports_response_transport(ResponseTransport::BinaryFrames)
{
ResponseTransport::BinaryFrames
} else {
ResponseTransport::JsonMessages
},
},
Err(error) => {
log::warn!(
"failed to query server features, falling back to streamed byte uploads and stdin: {error:#}"
);
RequestTransports {
upload_transport: UploadTransport::StreamedBytes,
stdin_transport: UploadTransport::StreamedBytes,
response_transport: ResponseTransport::JsonMessages,
}
}
}
}
fn query_server_features(socket_path: &Path) -> Result<ServerFeatures> {
let mut stream = UnixStream::connect(socket_path).with_context(|| {
format!(
"failed to connect to broker socket `{}` for server-features query",
socket_path.display()
)
})?;
write_server_features_magic(&mut stream)?;
let socket_supports_fd_passing = probe_socket_for_fd_passing(&stream);
let mut server_features = read_server_features_response(&mut stream)?;
if !socket_supports_fd_passing {
let server_advertised_fd_passing = server_features
.supports_upload_transport(UploadTransport::PassedFileDescriptors)
|| server_features.supports_stdin_transport(UploadTransport::PassedFileDescriptors);
if server_advertised_fd_passing {
log::warn!(
"bro-server advertised fd passing, but the socket rejected an SCM_RIGHTS probe; falling back to streamed uploads and stdin"
);
}
server_features
.upload_transports
.retain(|transport| *transport != UploadTransport::PassedFileDescriptors);
server_features
.stdin_transports
.retain(|transport| *transport != UploadTransport::PassedFileDescriptors);
}
Ok(server_features)
}
fn probe_socket_for_fd_passing(stream: &UnixStream) -> bool {
let probe_file = match File::open("/dev/null") {
Ok(file) => file,
Err(error) => {
log::warn!(
"failed to open `/dev/null` for SCM_RIGHTS probe; assuming fd passing is unavailable: {error:#}"
);
return false;
}
};
match stream.send_with_fd(&[0_u8], &[probe_file.as_raw_fd()]) {
Ok(1) => {
log::debug!("socket SCM_RIGHTS probe succeeded");
true
}
Ok(sent) => {
log::warn!(
"socket SCM_RIGHTS probe wrote {sent} probe byte(s) instead of 1; assuming fd passing is unavailable"
);
false
}
Err(error) => {
log::debug!("socket SCM_RIGHTS probe failed: {error:#}");
false
}
}
}
fn build_request_header(
options: &ClientOptions,
stdin_size: Option<u64>,
upload_transport: UploadTransport,
stdin_transport: UploadTransport,
response_transport: ResponseTransport,
) -> Result<PreparedRequest> {
let mut upload_registry = UploadRegistry::default();
let planned_rewrites = plan_request_rewrites(
&options.target_args,
&RequestPlanningOptions {
file_flags: &options.file_flags,
file_args_enabled: options.file_args_enabled,
},
&mut upload_registry,
)?;
if !options.file_flags.is_empty() {
log::info!(
"BRO_FILE_FLAGS rewrote {} argument value(s) for transport",
planned_rewrites.stats.file_flag_rewrite_count
);
}
if options.file_args_enabled {
log::info!(
"BRO_FILE_ARGS examined {} non-flag argument(s) and auto-forwarded {} readable regular file(s)",
planned_rewrites.stats.auto_file_arg_probe_count,
planned_rewrites.stats.auto_file_arg_rewrite_count,
);
}
let (uploads, prepared_uploads) = upload_registry.into_parts();
let forwarded_env = options
.forward_env
.iter()
.filter_map(|key| env::var(key).ok().map(|value| (key.clone(), value)))
.collect::<BTreeMap<_, _>>();
Ok(PreparedRequest {
header: RequestHeader {
args: options.target_args.clone(),
env: forwarded_env,
uploads,
rewrites: planned_rewrites.rewrites,
stdin_size,
upload_transport,
stdin_transport,
response_transport,
},
uploads: prepared_uploads,
})
}
fn send_uploaded_file_bytes(stream: &mut UnixStream, uploads: &[PreparedUpload]) -> Result<()> {
for upload in uploads {
let mut source = &upload.file;
copy_exact(&mut source, stream, upload.size).with_context(|| {
format!(
"failed to stream transport file `{}` to bro-server",
upload.path.display()
)
})?;
}
Ok(())
}
fn send_passed_descriptors(
stream: &UnixStream,
uploads: &[PreparedUpload],
stdin_fd: Option<RawFd>,
) -> Result<()> {
if uploads.is_empty() && stdin_fd.is_none() {
return Ok(());
}
let marker_bytes = passed_fd_marker_bytes(uploads.len(), stdin_fd.is_some());
let mut fds = uploads
.iter()
.map(|upload| upload.file.as_raw_fd())
.collect::<Vec<_>>();
if let Some(stdin_fd) = stdin_fd {
fds.push(stdin_fd);
}
let sent = stream
.send_with_fd(&marker_bytes, &fds)
.context("failed to send passed file descriptors to bro-server")?;
if sent != marker_bytes.len() {
bail!(
"sent {sent} fd marker bytes but expected to send {}",
marker_bytes.len()
)
}
Ok(())
}
fn prepare_stdin(stdin_transport: UploadTransport) -> Result<PreparedStdin> {
match stdin_transport {
UploadTransport::StreamedBytes => {
let (spool, size) = spool_stdin()?;
Ok(PreparedStdin::Streamed { spool, size })
}
UploadTransport::PassedFileDescriptors => prepare_passed_stdin(),
}
}
fn prepare_passed_stdin() -> Result<PreparedStdin> {
let stdin = io::stdin();
if stdin.is_terminal() && !env_var_is_truthy("BRO_CAPTURE_TTY_STDIN") {
log::info!(
"stdin is a terminal; using /dev/null as fd-passed stdin to avoid blocking. Set BRO_CAPTURE_TTY_STDIN=1 to pass terminal stdin through to bro-server"
);
let devnull = File::open("/dev/null").context("failed to open `/dev/null` for stdin")?;
return Ok(PreparedStdin::Passed {
fd: devnull.as_raw_fd(),
_owner: Some(devnull),
});
}
if stdin.is_terminal() {
log::warn!(
"stdin is a terminal and BRO_CAPTURE_TTY_STDIN is enabled; passing terminal stdin directly to bro-server"
);
} else {
log::info!("using fd passing for stdin; not spooling local stdin before execution");
}
Ok(PreparedStdin::Passed {
fd: stdin.as_raw_fd(),
_owner: None,
})
}
fn spool_stdin() -> Result<(NamedTempFile, u64)> {
let mut spool = NamedTempFile::new().context("failed to create stdin spool file")?;
let stdin = io::stdin();
if stdin.is_terminal() && !env_var_is_truthy("BRO_CAPTURE_TTY_STDIN") {
log::info!(
"stdin is a terminal; using empty stdin to avoid blocking. Set BRO_CAPTURE_TTY_STDIN=1 to read terminal stdin until EOF"
);
spool
.rewind()
.context("failed to rewind empty stdin spool")?;
return Ok((spool, 0));
}
if stdin.is_terminal() {
log::warn!(
"stdin is a terminal and BRO_CAPTURE_TTY_STDIN is enabled; reading stdin until EOF"
);
}
let stdin_size = io::copy(&mut stdin.lock(), &mut spool)
.context("failed to spool stdin before transmission")?;
spool
.rewind()
.context("failed to rewind stdin spool after capture")?;
Ok((spool, stdin_size))
}
fn receive_response(mut stream: UnixStream, response_transport: ResponseTransport) -> Result<i32> {
let stdout_handle = io::stdout();
let stdout_is_terminal = stdout_handle.is_terminal();
let mut stdout = stdout_handle.lock();
let stderr_handle = io::stderr();
let stderr_is_terminal = stderr_handle.is_terminal();
let mut stderr = stderr_handle.lock();
loop {
let frame = read_response_frame(&mut stream, response_transport)?;
match frame {
ResponseFrame::Stdout(bytes) => {
stdout
.write_all(&bytes)
.context("failed to write remote stdout locally")?;
if stdout_is_terminal {
stdout.flush().context("failed to flush local stdout")?;
}
}
ResponseFrame::Stderr(bytes) => {
stderr
.write_all(&bytes)
.context("failed to write remote stderr locally")?;
if stderr_is_terminal {
stderr.flush().context("failed to flush local stderr")?;
}
}
ResponseFrame::Exit(status) => {
let exit_code = status.to_exit_code();
log::info!(
"received remote exit status {:?} -> code {}",
status,
exit_code
);
return Ok(exit_code);
}
ResponseFrame::Error(message) => {
log::warn!("received error response from bro-server: {message}");
writeln!(stderr, "bro-server: {message}")
.context("failed to print bro-server error")?;
stderr.flush().context("failed to flush local stderr")?;
return Ok(WRAPPER_ERROR_EXIT_CODE);
}
}
}
}
File diff suppressed because it is too large Load Diff
+71
View File
@@ -0,0 +1,71 @@
use anyhow::{Context, Result, anyhow, bail};
use std::{
env,
io::{self, Read, Write},
};
use tracing_subscriber::EnvFilter;
pub fn init_tracing(binary_name: &str) -> Result<()> {
let binary_target = binary_name.replace('-', "_");
let default_filter = format!("info,{binary_target}=debug");
let Ok(requested_filter) = env::var("RUST_LOG") else {
return Ok(());
};
let requested_filter = match requested_filter.trim() {
"" | "1" | "true" | "yes" | "on" => default_filter,
value => value.to_owned(),
};
let env_filter = EnvFilter::try_new(&requested_filter)
.map_err(|error| anyhow!("invalid log filter `{requested_filter}`: {error}"))?;
tracing_subscriber::fmt()
.with_env_filter(env_filter)
.with_target(true)
.with_thread_ids(true)
.with_thread_names(true)
.compact()
.try_init()
.map_err(|error| anyhow!("failed to initialize tracing subscriber: {error}"))?;
log::debug!("initialized tracing subscriber for {binary_name}");
Ok(())
}
pub fn copy_exact<R: Read, W: Write>(reader: &mut R, writer: &mut W, size: u64) -> Result<()> {
let mut limited_reader = reader.take(size);
let copied = io::copy(&mut limited_reader, writer).context("failed to copy payload bytes")?;
if copied != size {
bail!("expected to copy {size} bytes, but copied {copied}")
}
Ok(())
}
pub fn env_var_is_set(key: &str) -> bool {
env::var_os(key).is_some_and(|value| !value.is_empty())
}
pub fn env_var_is_truthy(key: &str) -> bool {
env::var_os(key).is_some_and(|value| {
matches!(
value.to_string_lossy().trim().to_ascii_lowercase().as_str(),
"1" | "true" | "yes" | "on"
)
})
}
pub fn parse_csv_env(key: &str) -> Vec<String> {
env::var(key).map_or_else(|_| Vec::new(), |value| parse_csv_list(&value))
}
pub fn parse_csv_list(value: &str) -> Vec<String> {
value
.split(',')
.map(str::trim)
.filter(|entry| !entry.is_empty())
.map(ToOwned::to_owned)
.collect()
}
+2
View File
@@ -0,0 +1,2 @@
pub mod common;
pub mod protocol;
+828
View File
@@ -0,0 +1,828 @@
use anyhow::{Context, Result, bail};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::{
collections::{BTreeMap, HashSet},
io::{Read, Write},
os::unix::process::ExitStatusExt,
process::ExitStatus,
};
pub const EXECUTE_MAGIC: [u8; 4] = *b"BRO1";
pub const SERVER_FEATURES_MAGIC: [u8; 4] = *b"BROC";
pub const FD_PASS_MARKER: u8 = b'F';
pub const STDIN_FD_PASS_MARKER: u8 = b'S';
pub const MAX_CONTROL_MESSAGE_BYTES: u64 = 16 * 1024 * 1024;
pub const MAX_RESPONSE_FRAME_BYTES: u64 = 16 * 1024 * 1024;
pub const WRAPPER_ERROR_EXIT_CODE: i32 = 125;
const RESPONSE_FRAME_STDOUT: u8 = b'O';
const RESPONSE_FRAME_STDERR: u8 = b'E';
const RESPONSE_FRAME_EXIT: u8 = b'X';
const RESPONSE_FRAME_ERROR: u8 = b'!';
const BINARY_EXIT_STATUS_PAYLOAD_LEN: usize = 10;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionKind {
Execute,
ServerFeaturesQuery,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum UploadTransport {
#[default]
StreamedBytes,
PassedFileDescriptors,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ResponseTransport {
#[default]
JsonMessages,
BinaryFrames,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestHeader {
pub args: Vec<String>,
pub env: BTreeMap<String, String>,
pub uploads: Vec<UploadSpec>,
pub rewrites: Vec<ArgRewrite>,
#[serde(default)]
pub stdin_size: Option<u64>,
#[serde(default)]
pub upload_transport: UploadTransport,
#[serde(default)]
pub stdin_transport: UploadTransport,
#[serde(default)]
pub response_transport: ResponseTransport,
}
impl RequestHeader {
pub fn validate_invariants(&self) -> Result<()> {
if self.stdin_transport == UploadTransport::StreamedBytes && self.stdin_size.is_none() {
bail!("request used streamed stdin but did not provide stdin_size")
}
let mut upload_ids = HashSet::with_capacity(self.uploads.len());
for upload in &self.uploads {
if !upload_ids.insert(upload.id) {
bail!("request referenced upload id {} more than once", upload.id)
}
}
for rewrite in &self.rewrites {
match rewrite {
ArgRewrite::Replace {
arg_index,
upload_id,
}
| ArgRewrite::Prefixed {
arg_index,
upload_id,
..
} => {
if *arg_index >= self.args.len() {
bail!(
"request tried to rewrite argument index {} but only {} arguments were provided",
arg_index,
self.args.len()
)
}
if !upload_ids.contains(upload_id) {
bail!(
"request tried to rewrite argument {} using unknown upload id {}",
arg_index,
upload_id
)
}
}
}
}
Ok(())
}
pub fn rewrite_args<F>(&self, mut resolve_upload: F) -> Result<Vec<String>>
where
F: FnMut(u32) -> Result<String>,
{
let mut rewritten = self.args.clone();
for rewrite in &self.rewrites {
match rewrite {
ArgRewrite::Replace {
arg_index,
upload_id,
} => {
rewritten[*arg_index] = resolve_upload(*upload_id)?;
}
ArgRewrite::Prefixed {
arg_index,
prefix,
upload_id,
} => {
rewritten[*arg_index] = format!("{prefix}{}", resolve_upload(*upload_id)?);
}
}
}
Ok(rewritten)
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct UploadSpec {
pub id: u32,
pub original_path: String,
pub size: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ArgRewrite {
Replace {
arg_index: usize,
upload_id: u32,
},
Prefixed {
arg_index: usize,
prefix: String,
upload_id: u32,
},
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ResponseFrame {
Stdout(Vec<u8>),
Stderr(Vec<u8>),
Exit(RemoteExitStatus),
Error(String),
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RemoteExitStatus {
pub code: Option<i32>,
pub signal: Option<i32>,
}
#[derive(Debug, Clone, Copy)]
pub struct RequestPlanningOptions<'a> {
pub file_flags: &'a [String],
pub file_args_enabled: bool,
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RequestPlanningStats {
pub file_flag_rewrite_count: usize,
pub auto_file_arg_probe_count: usize,
pub auto_file_arg_rewrite_count: usize,
}
#[derive(Debug, Default)]
pub struct PlannedArgRewrites {
pub rewrites: Vec<ArgRewrite>,
pub stats: RequestPlanningStats,
}
pub trait UploadRegistrar {
fn ensure_upload(&mut self, path_text: &str) -> Result<u32>;
fn maybe_ensure_upload(&mut self, path_text: &str) -> Result<Option<u32>>;
}
pub fn plan_request_rewrites<U>(
args: &[String],
options: &RequestPlanningOptions<'_>,
uploads: &mut U,
) -> Result<PlannedArgRewrites>
where
U: UploadRegistrar,
{
let mut rewrites = Vec::new();
let mut stats = RequestPlanningStats::default();
let mut skip_index = None;
for (index, argument) in args.iter().enumerate() {
if skip_index == Some(index) {
skip_index = None;
continue;
}
if let Some(matched_flag) = match_file_flag(args, index, options.file_flags)? {
match matched_flag {
MatchedFileFlag::Separate {
flag,
value_index,
value,
} => {
let upload_id = uploads.ensure_upload(value)?;
rewrites.push(ArgRewrite::Replace {
arg_index: value_index,
upload_id,
});
log::debug!(
"BRO_FILE_FLAGS matched separate flag `{flag}` at arg index {index}; forwarding argument index {value_index} -> `{value}`"
);
stats.file_flag_rewrite_count += 1;
skip_index = Some(value_index);
}
MatchedFileFlag::Joined {
flag,
prefix,
value,
} => {
let upload_id = uploads.ensure_upload(value)?;
rewrites.push(ArgRewrite::Prefixed {
arg_index: index,
prefix,
upload_id,
});
log::debug!(
"BRO_FILE_FLAGS matched joined flag `{flag}` at arg index {index}; forwarding value `{value}`"
);
stats.file_flag_rewrite_count += 1;
}
}
continue;
}
if options.file_args_enabled && is_nonflag_argument(argument) {
stats.auto_file_arg_probe_count += 1;
if let Some(upload_id) = uploads.maybe_ensure_upload(argument)? {
rewrites.push(ArgRewrite::Replace {
arg_index: index,
upload_id,
});
stats.auto_file_arg_rewrite_count += 1;
log::debug!(
"BRO_FILE_ARGS auto-forwarded non-flag argument index {index}: `{argument}`"
);
} else {
log::debug!(
"BRO_FILE_ARGS left non-flag argument index {index} unchanged because `{argument}` could not be opened as a readable regular file"
);
}
}
}
Ok(PlannedArgRewrites { rewrites, stats })
}
enum MatchedFileFlag<'a> {
Separate {
flag: &'a str,
value_index: usize,
value: &'a str,
},
Joined {
flag: &'a str,
prefix: String,
value: &'a str,
},
}
fn match_file_flag<'a>(
args: &'a [String],
index: usize,
file_flags: &'a [String],
) -> Result<Option<MatchedFileFlag<'a>>> {
let argument = &args[index];
for flag in file_flags {
if argument == flag {
let value_index = index + 1;
let value = args.get(value_index).with_context(|| {
format!("flag `{flag}` is configured as a file flag but has no value")
})?;
return Ok(Some(MatchedFileFlag::Separate {
flag: flag.as_str(),
value_index,
value,
}));
}
let prefix = format!("{flag}=");
if let Some(value) = argument.strip_prefix(&prefix) {
return Ok(Some(MatchedFileFlag::Joined {
flag: flag.as_str(),
prefix,
value,
}));
}
}
Ok(None)
}
fn is_nonflag_argument(argument: &str) -> bool {
!argument.starts_with('-') || argument == "-"
}
fn default_upload_transports() -> Vec<UploadTransport> {
vec![UploadTransport::StreamedBytes]
}
fn default_response_transports() -> Vec<ResponseTransport> {
vec![ResponseTransport::JsonMessages]
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerFeatures {
#[serde(default = "default_upload_transports")]
pub upload_transports: Vec<UploadTransport>,
#[serde(default = "default_upload_transports")]
pub stdin_transports: Vec<UploadTransport>,
#[serde(default = "default_response_transports")]
pub response_transports: Vec<ResponseTransport>,
}
impl ServerFeatures {
pub fn supports_transport(&self, transport: UploadTransport) -> bool {
self.supports_upload_transport(transport)
}
pub fn supports_upload_transport(&self, transport: UploadTransport) -> bool {
self.upload_transports.contains(&transport)
}
pub fn supports_stdin_transport(&self, transport: UploadTransport) -> bool {
self.stdin_transports.contains(&transport)
}
pub fn supports_response_transport(&self, transport: ResponseTransport) -> bool {
self.response_transports.contains(&transport)
}
}
impl RemoteExitStatus {
pub fn to_exit_code(&self) -> i32 {
if let Some(code) = self.code {
code
} else if let Some(signal) = self.signal {
128 + signal
} else {
WRAPPER_ERROR_EXIT_CODE
}
}
}
impl From<ExitStatus> for RemoteExitStatus {
fn from(status: ExitStatus) -> Self {
Self {
code: status.code(),
signal: status.signal(),
}
}
}
pub fn write_execute_magic<W: Write>(writer: &mut W) -> Result<()> {
write_magic(writer, &EXECUTE_MAGIC)
}
pub fn write_server_features_magic<W: Write>(writer: &mut W) -> Result<()> {
write_magic(writer, &SERVER_FEATURES_MAGIC)
}
fn write_magic<W: Write>(writer: &mut W, magic: &[u8; 4]) -> Result<()> {
writer
.write_all(magic)
.context("failed to write protocol magic")
}
pub fn read_connection_kind<R: Read>(reader: &mut R) -> Result<ConnectionKind> {
let mut magic = [0_u8; 4];
reader
.read_exact(&mut magic)
.context("failed to read protocol magic")?;
match magic {
EXECUTE_MAGIC => Ok(ConnectionKind::Execute),
SERVER_FEATURES_MAGIC => Ok(ConnectionKind::ServerFeaturesQuery),
_ => bail!(
"received an unsupported protocol header: expected `{EXECUTE_MAGIC:?}` or `{SERVER_FEATURES_MAGIC:?}`, got `{magic:?}`"
),
}
}
pub fn write_execute_request_header<W: Write>(
writer: &mut W,
header: &RequestHeader,
) -> Result<()> {
write_message(writer, header)
}
pub fn read_execute_request_header<R: Read>(reader: &mut R) -> Result<RequestHeader> {
read_message(reader)
}
pub fn write_server_features_response<W: Write>(
writer: &mut W,
server_features: &ServerFeatures,
) -> Result<()> {
write_message(writer, server_features)
}
pub fn read_server_features_response<R: Read>(reader: &mut R) -> Result<ServerFeatures> {
read_message(reader)
}
pub fn passed_fd_marker_bytes(upload_count: usize, include_stdin: bool) -> Vec<u8> {
let mut marker_bytes = std::iter::repeat_n(FD_PASS_MARKER, upload_count).collect::<Vec<_>>();
if include_stdin {
marker_bytes.push(STDIN_FD_PASS_MARKER);
}
marker_bytes
}
pub fn expected_passed_fd_count(
upload_transport: UploadTransport,
upload_count: usize,
stdin_transport: UploadTransport,
) -> usize {
let upload_fd_count = match upload_transport {
UploadTransport::StreamedBytes => 0,
UploadTransport::PassedFileDescriptors => upload_count,
};
upload_fd_count + usize::from(stdin_transport == UploadTransport::PassedFileDescriptors)
}
pub fn validate_passed_fd_marker_bytes(
marker_bytes: &[u8],
upload_count: usize,
expects_stdin: bool,
) -> Result<()> {
let expected_count = upload_count + usize::from(expects_stdin);
if marker_bytes.len() != expected_count {
bail!(
"received {} fd marker bytes but expected {}",
marker_bytes.len(),
expected_count
)
}
if marker_bytes[..upload_count]
.iter()
.any(|byte| *byte != FD_PASS_MARKER)
{
bail!("received invalid upload file descriptor marker bytes from bro-client")
}
if expects_stdin && marker_bytes[upload_count] != STDIN_FD_PASS_MARKER {
bail!("received invalid stdin file descriptor marker byte from bro-client")
}
Ok(())
}
pub fn write_message<W: Write, T: Serialize>(writer: &mut W, value: &T) -> Result<()> {
let payload =
serde_json::to_vec(value).context("failed to serialize protocol message as JSON")?;
let payload_len = u64::try_from(payload.len()).context("protocol message is too large")?;
writer
.write_all(&payload_len.to_le_bytes())
.context("failed to write protocol message length")?;
writer
.write_all(&payload)
.context("failed to write protocol message payload")?;
writer.flush().context("failed to flush protocol message")?;
Ok(())
}
pub fn read_message<R: Read, T: DeserializeOwned>(reader: &mut R) -> Result<T> {
let payload =
read_length_prefixed_payload(reader, MAX_CONTROL_MESSAGE_BYTES, "protocol message")?;
serde_json::from_slice(&payload).context("failed to deserialize protocol message from JSON")
}
pub fn write_response_frame<W: Write>(
writer: &mut W,
transport: ResponseTransport,
frame: &ResponseFrame,
) -> Result<()> {
match transport {
ResponseTransport::JsonMessages => write_message(writer, frame),
ResponseTransport::BinaryFrames => write_binary_response_frame(writer, frame),
}
}
pub fn read_response_frame<R: Read>(
reader: &mut R,
transport: ResponseTransport,
) -> Result<ResponseFrame> {
match transport {
ResponseTransport::JsonMessages => read_message(reader),
ResponseTransport::BinaryFrames => read_binary_response_frame(reader),
}
}
fn write_binary_response_frame<W: Write>(writer: &mut W, frame: &ResponseFrame) -> Result<()> {
match frame {
ResponseFrame::Stdout(bytes) => {
write_binary_response_payload(writer, RESPONSE_FRAME_STDOUT, bytes)
}
ResponseFrame::Stderr(bytes) => {
write_binary_response_payload(writer, RESPONSE_FRAME_STDERR, bytes)
}
ResponseFrame::Exit(status) => {
let mut payload = Vec::with_capacity(BINARY_EXIT_STATUS_PAYLOAD_LEN);
payload.push(u8::from(status.code.is_some()));
payload.extend_from_slice(&status.code.unwrap_or_default().to_le_bytes());
payload.push(u8::from(status.signal.is_some()));
payload.extend_from_slice(&status.signal.unwrap_or_default().to_le_bytes());
write_binary_response_payload(writer, RESPONSE_FRAME_EXIT, &payload)
}
ResponseFrame::Error(message) => {
write_binary_response_payload(writer, RESPONSE_FRAME_ERROR, message.as_bytes())
}
}
}
fn write_binary_response_payload<W: Write>(writer: &mut W, tag: u8, payload: &[u8]) -> Result<()> {
let payload_len =
u64::try_from(payload.len()).context("response frame payload is too large")?;
writer
.write_all(&[tag])
.context("failed to write response frame type")?;
writer
.write_all(&payload_len.to_le_bytes())
.context("failed to write response frame length")?;
writer
.write_all(payload)
.context("failed to write response frame payload")?;
Ok(())
}
fn read_binary_response_frame<R: Read>(reader: &mut R) -> Result<ResponseFrame> {
let mut tag = [0_u8; 1];
reader
.read_exact(&mut tag)
.context("failed to read response frame type")?;
let payload = read_length_prefixed_payload(reader, MAX_RESPONSE_FRAME_BYTES, "response frame")?;
match tag[0] {
RESPONSE_FRAME_STDOUT => Ok(ResponseFrame::Stdout(payload)),
RESPONSE_FRAME_STDERR => Ok(ResponseFrame::Stderr(payload)),
RESPONSE_FRAME_ERROR => Ok(ResponseFrame::Error(
String::from_utf8(payload).context("response error frame was not valid UTF-8")?,
)),
RESPONSE_FRAME_EXIT => read_binary_exit_status(&payload),
_ => bail!("received unknown response frame type `{}`", tag[0]),
}
}
fn read_binary_exit_status(payload: &[u8]) -> Result<ResponseFrame> {
if payload.len() != BINARY_EXIT_STATUS_PAYLOAD_LEN {
bail!(
"binary exit status frame payload had length {} but expected {}",
payload.len(),
BINARY_EXIT_STATUS_PAYLOAD_LEN
)
}
let code = if payload[0] == 0 {
None
} else {
Some(i32::from_le_bytes([
payload[1], payload[2], payload[3], payload[4],
]))
};
let signal = if payload[5] == 0 {
None
} else {
Some(i32::from_le_bytes([
payload[6], payload[7], payload[8], payload[9],
]))
};
Ok(ResponseFrame::Exit(RemoteExitStatus { code, signal }))
}
fn read_length_prefixed_payload<R: Read>(
reader: &mut R,
max_payload_len: u64,
payload_kind: &str,
) -> Result<Vec<u8>> {
let mut length_bytes = [0_u8; 8];
reader
.read_exact(&mut length_bytes)
.with_context(|| format!("failed to read {payload_kind} length"))?;
let payload_len = u64::from_le_bytes(length_bytes);
if payload_len > max_payload_len {
bail!(
"{payload_kind} length {payload_len} exceeds the maximum supported size of {max_payload_len} bytes"
)
}
let payload_len = usize::try_from(payload_len)
.with_context(|| format!("{payload_kind} does not fit in memory"))?;
let mut payload = vec![0_u8; payload_len];
reader
.read_exact(&mut payload)
.with_context(|| format!("failed to read {payload_kind} payload"))?;
Ok(payload)
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use std::io::Cursor;
#[derive(Default)]
struct MockUploadRegistrar {
available_paths: HashSet<String>,
upload_ids_by_path: HashMap<String, u32>,
next_upload_id: u32,
}
impl MockUploadRegistrar {
fn with_available_paths(paths: impl IntoIterator<Item = &'static str>) -> Self {
Self {
available_paths: paths.into_iter().map(ToOwned::to_owned).collect(),
..Self::default()
}
}
fn register(&mut self, path_text: &str) -> u32 {
if let Some(&upload_id) = self.upload_ids_by_path.get(path_text) {
return upload_id;
}
let upload_id = self.next_upload_id;
self.next_upload_id += 1;
self.upload_ids_by_path
.insert(path_text.to_owned(), upload_id);
upload_id
}
}
impl UploadRegistrar for MockUploadRegistrar {
fn ensure_upload(&mut self, path_text: &str) -> Result<u32> {
if !self.available_paths.contains(path_text) {
bail!("mock upload `{path_text}` is unavailable")
}
Ok(self.register(path_text))
}
fn maybe_ensure_upload(&mut self, path_text: &str) -> Result<Option<u32>> {
Ok(self
.available_paths
.contains(path_text)
.then(|| self.register(path_text)))
}
}
#[test]
fn plan_request_rewrites_handles_file_flags_and_file_args() -> Result<()> {
let args = vec![
"--config".to_owned(),
"config.toml".to_owned(),
"--input=data.txt".to_owned(),
"note.txt".to_owned(),
"-n".to_owned(),
];
let file_flags = vec!["--config".to_owned(), "--input".to_owned()];
let options = RequestPlanningOptions {
file_flags: &file_flags,
file_args_enabled: true,
};
let mut uploads =
MockUploadRegistrar::with_available_paths(["config.toml", "data.txt", "note.txt"]);
let planned = plan_request_rewrites(&args, &options, &mut uploads)?;
assert_eq!(
planned.rewrites,
vec![
ArgRewrite::Replace {
arg_index: 1,
upload_id: 0,
},
ArgRewrite::Prefixed {
arg_index: 2,
prefix: "--input=".to_owned(),
upload_id: 1,
},
ArgRewrite::Replace {
arg_index: 3,
upload_id: 2,
},
]
);
assert_eq!(
planned.stats,
RequestPlanningStats {
file_flag_rewrite_count: 2,
auto_file_arg_probe_count: 1,
auto_file_arg_rewrite_count: 1,
}
);
Ok(())
}
#[test]
fn plan_request_rewrites_treats_double_dash_like_any_other_argument() -> Result<()> {
let args = vec![
"--".to_owned(),
"--config".to_owned(),
"plain.txt".to_owned(),
];
let file_flags = vec!["--config".to_owned()];
let options = RequestPlanningOptions {
file_flags: &file_flags,
file_args_enabled: true,
};
let mut uploads = MockUploadRegistrar::with_available_paths(["--config", "plain.txt"]);
let planned = plan_request_rewrites(&args, &options, &mut uploads)?;
assert_eq!(
planned.rewrites,
vec![ArgRewrite::Replace {
arg_index: 2,
upload_id: 0,
},]
);
assert_eq!(planned.stats.file_flag_rewrite_count, 1);
assert_eq!(planned.stats.auto_file_arg_probe_count, 0);
assert_eq!(planned.stats.auto_file_arg_rewrite_count, 0);
Ok(())
}
#[test]
fn request_header_validates_upload_ids_and_rewrite_references() {
let header = RequestHeader {
args: vec!["a".to_owned()],
env: BTreeMap::new(),
uploads: vec![
UploadSpec {
id: 0,
original_path: "x".to_owned(),
size: 1,
},
UploadSpec {
id: 0,
original_path: "y".to_owned(),
size: 1,
},
],
rewrites: vec![ArgRewrite::Replace {
arg_index: 0,
upload_id: 1,
}],
stdin_size: Some(0),
upload_transport: UploadTransport::StreamedBytes,
stdin_transport: UploadTransport::StreamedBytes,
response_transport: ResponseTransport::JsonMessages,
};
assert!(header.validate_invariants().is_err());
}
#[test]
fn request_header_requires_streamed_stdin_size() {
let header = RequestHeader {
args: Vec::new(),
env: BTreeMap::new(),
uploads: Vec::new(),
rewrites: Vec::new(),
stdin_size: None,
upload_transport: UploadTransport::StreamedBytes,
stdin_transport: UploadTransport::StreamedBytes,
response_transport: ResponseTransport::JsonMessages,
};
assert!(header.validate_invariants().is_err());
}
#[test]
fn binary_response_frames_round_trip() -> Result<()> {
let frames = [
ResponseFrame::Stdout(vec![1, 2, 3]),
ResponseFrame::Stderr(vec![4, 5]),
ResponseFrame::Exit(RemoteExitStatus {
code: Some(42),
signal: None,
}),
ResponseFrame::Error("boom".to_owned()),
];
let mut buffer = Cursor::new(Vec::new());
for frame in &frames {
write_response_frame(&mut buffer, ResponseTransport::BinaryFrames, frame)?;
}
buffer.set_position(0);
for expected in frames {
let actual = read_response_frame(&mut buffer, ResponseTransport::BinaryFrames)?;
assert_eq!(actual, expected);
}
Ok(())
}
}