fix(cli,tcp): replies are sent on the requesting channel

Replace the client socket with replies sent on the other side of the
querying stream, for both UDS and TCP clients. This has two results:
unix socket clients such as komorebic no longer race on the socket bind,
fixing the out of order bind error, and the response mixup conditions
that could occur. Queries over TCP now receive replies over TCP, rather
than replies being sent to a future or in-flight command line client.
This commit is contained in:
James Tucker
2024-02-11 01:57:43 -08:00
committed by جاد
parent afd93c34a2
commit c8f6502b02
2 changed files with 59 additions and 120 deletions

View File

@@ -4,7 +4,6 @@ use std::fs::OpenOptions;
use std::io::BufRead;
use std::io::BufReader;
use std::io::Read;
use std::io::Write;
use std::net::TcpListener;
use std::net::TcpStream;
use std::num::NonZeroUsize;
@@ -60,7 +59,6 @@ use crate::BORDER_OFFSET;
use crate::BORDER_OVERFLOW_IDENTIFIERS;
use crate::BORDER_WIDTH;
use crate::CUSTOM_FFM;
use crate::DATA_DIR;
use crate::DISPLAY_INDEX_PREFERENCES;
use crate::FLOAT_IDENTIFIERS;
use crate::HIDING_BEHAVIOUR;
@@ -144,8 +142,15 @@ pub fn listen_for_commands_tcp(wm: Arc<Mutex<WindowManager>>, port: usize) {
}
impl WindowManager {
#[tracing::instrument(skip(self))]
pub fn process_command(&mut self, message: SocketMessage) -> Result<()> {
// TODO(raggi): wrap reply in a newtype that can decorate a human friendly
// name for the peer, such as getting the pid of the komorebic process for
// the UDS or the IP:port for TCP.
#[tracing::instrument(skip(self, reply))]
pub fn process_command(
&mut self,
message: SocketMessage,
mut reply: impl std::io::Write,
) -> Result<()> {
if let Some(virtual_desktop_id) = &self.virtual_desktop_id {
if let Some(id) = current_virtual_desktop() {
if id != *virtual_desktop_id {
@@ -743,15 +748,11 @@ impl WindowManager {
Err(error) => error.to_string(),
};
let socket = DATA_DIR.join("komorebic.sock");
tracing::info!("replying to state");
let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(state.as_bytes())?;
}
}
reply.write_all(state.as_bytes())?;
tracing::info!("replying to state done");
}
SocketMessage::VisibleWindows => {
let mut monitor_visible_windows = HashMap::new();
@@ -774,15 +775,7 @@ impl WindowManager {
Err(error) => error.to_string(),
};
let socket = DATA_DIR.join("komorebic.sock");
let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(visible_windows_state.as_bytes())?;
}
}
reply.write_all(visible_windows_state.as_bytes())?;
}
SocketMessage::Query(query) => {
@@ -801,15 +794,7 @@ impl WindowManager {
}
.to_string();
let socket = DATA_DIR.join("komorebic.sock");
let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(response.as_bytes())?;
}
}
reply.write_all(response.as_bytes())?;
}
SocketMessage::ResizeWindowEdge(direction, sizing) => {
self.resize_window(direction, sizing, self.resize_delta, true)?;
@@ -1275,41 +1260,20 @@ impl WindowManager {
SocketMessage::ApplicationSpecificConfigurationSchema => {
let asc = schema_for!(Vec<ApplicationConfiguration>);
let schema = serde_json::to_string_pretty(&asc)?;
let socket = DATA_DIR.join("komorebic.sock");
let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(schema.as_bytes())?;
}
}
reply.write_all(schema.as_bytes())?;
}
SocketMessage::NotificationSchema => {
let notification = schema_for!(Notification);
let schema = serde_json::to_string_pretty(&notification)?;
let socket = DATA_DIR.join("komorebic.sock");
let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(schema.as_bytes())?;
}
}
reply.write_all(schema.as_bytes())?;
}
SocketMessage::SocketSchema => {
let socket_message = schema_for!(SocketMessage);
let schema = serde_json::to_string_pretty(&socket_message)?;
let socket = DATA_DIR.join("komorebic.sock");
let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(schema.as_bytes())?;
}
}
reply.write_all(schema.as_bytes())?;
}
SocketMessage::StaticConfigSchema => {
let settings = SchemaSettings::default().with(|s| {
@@ -1321,27 +1285,13 @@ impl WindowManager {
let gen = settings.into_generator();
let socket_message = gen.into_root_schema_for::<StaticConfig>();
let schema = serde_json::to_string_pretty(&socket_message)?;
let socket = DATA_DIR.join("komorebic.sock");
let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(schema.as_bytes())?;
}
}
reply.write_all(schema.as_bytes())?;
}
SocketMessage::GenerateStaticConfig => {
let config = serde_json::to_string_pretty(&StaticConfig::from(&*self))?;
let socket = DATA_DIR.join("komorebic.sock");
let mut connected = false;
while !connected {
if let Ok(mut stream) = UnixStream::connect(&socket) {
connected = true;
stream.write_all(config.as_bytes())?;
}
}
reply.write_all(config.as_bytes())?;
}
SocketMessage::RemoveTitleBar(_, ref id) => {
let mut identifiers = NO_TITLEBAR.lock();
@@ -1526,9 +1476,13 @@ impl WindowManager {
}
}
pub fn read_commands_uds(wm: &Arc<Mutex<WindowManager>>, stream: UnixStream) -> Result<()> {
let stream = BufReader::new(stream);
for line in stream.lines() {
pub fn read_commands_uds(wm: &Arc<Mutex<WindowManager>>, mut stream: UnixStream) -> Result<()> {
let reader = BufReader::new(stream.try_clone()?);
// TODO(raggi): while this processes more than one command, if there are
// replies there is no clearly defined protocol for framing yet - it's
// perhaps whole-json objects for now, but termination is signalled by
// socket shutdown.
for line in reader.lines() {
let message = SocketMessage::from_str(&line?)?;
let mut wm = wm.lock();
@@ -1536,7 +1490,7 @@ pub fn read_commands_uds(wm: &Arc<Mutex<WindowManager>>, stream: UnixStream) ->
if wm.is_paused {
return match message {
SocketMessage::TogglePause | SocketMessage::State | SocketMessage::Stop => {
Ok(wm.process_command(message)?)
Ok(wm.process_command(message, &mut stream)?)
}
_ => {
tracing::trace!("ignoring while paused");
@@ -1545,7 +1499,7 @@ pub fn read_commands_uds(wm: &Arc<Mutex<WindowManager>>, stream: UnixStream) ->
};
}
wm.process_command(message.clone())?;
wm.process_command(message.clone(), &mut stream)?;
notify_subscribers(&serde_json::to_string(&Notification {
event: NotificationEvent::Socket(message.clone()),
state: wm.as_ref().into(),
@@ -1560,11 +1514,11 @@ pub fn read_commands_tcp(
stream: &mut TcpStream,
addr: &str,
) -> Result<()> {
let mut stream = BufReader::new(stream);
let mut reader = BufReader::new(stream.try_clone()?);
loop {
let mut buf = vec![0; 1024];
match stream.read(&mut buf) {
match reader.read(&mut buf) {
Err(..) => {
tracing::warn!("removing disconnected tcp client: {addr}");
let mut connections = TCP_CONNECTIONS.lock();
@@ -1585,7 +1539,7 @@ pub fn read_commands_tcp(
if wm.is_paused {
return match message {
SocketMessage::TogglePause | SocketMessage::State | SocketMessage::Stop => {
Ok(wm.process_command(message)?)
Ok(wm.process_command(message, stream)?)
}
_ => {
tracing::trace!("ignoring while paused");
@@ -1594,7 +1548,7 @@ pub fn read_commands_tcp(
};
}
wm.process_command(message.clone())?;
wm.process_command(message.clone(), &mut *stream)?;
notify_subscribers(&serde_json::to_string(&Notification {
event: NotificationEvent::Socket(message.clone()),
state: wm.as_ref().into(),

View File

@@ -5,8 +5,9 @@ use std::fs::File;
use std::fs::OpenOptions;
use std::io::BufRead;
use std::io::BufReader;
use std::io::ErrorKind;
use std::io::Read;
use std::io::Write;
use std::net::Shutdown;
use std::path::Path;
use std::path::PathBuf;
use std::process::Command;
@@ -30,7 +31,6 @@ use miette::Report;
use miette::SourceOffset;
use miette::SourceSpan;
use paste::paste;
use uds_windows::UnixListener;
use uds_windows::UnixStream;
use which::which;
use windows::Win32::Foundation::HWND;
@@ -1172,35 +1172,26 @@ pub fn send_message(bytes: &[u8]) -> Result<()> {
Ok(())
}
fn with_komorebic_socket<F: Fn() -> Result<()>>(f: F) -> Result<()> {
let socket = DATA_DIR.join("komorebic.sock");
pub fn send_query(bytes: &[u8]) -> Result<String> {
let socket = DATA_DIR.join("komorebi.sock");
match std::fs::remove_file(&socket) {
Ok(()) => {}
Err(error) => match error.kind() {
// Doing this because ::exists() doesn't work reliably on Windows via IntelliJ
ErrorKind::NotFound => {}
_ => {
return Err(error.into());
}
},
};
let mut stream = UnixStream::connect(&socket)?;
stream.write_all(bytes)?;
stream.shutdown(Shutdown::Write)?;
f()?;
let mut reader = BufReader::new(stream);
let mut response = String::new();
reader.read_to_string(&mut response)?;
let listener = UnixListener::bind(socket)?;
match listener.accept() {
Ok(incoming) => {
let stream = BufReader::new(incoming.0);
for line in stream.lines() {
println!("{}", line?);
}
Ok(response)
}
Ok(())
}
Err(error) => {
panic!("{}", error);
}
// print_query is a helper that queries komorebi and prints the response.
// panics on error.
pub fn print_query(bytes: &[u8]) {
match send_query(bytes) {
Ok(response) => println!("{}", response),
Err(error) => panic!("{}", error),
}
}
@@ -2000,15 +1991,13 @@ Stop-Process -Name:whkd -ErrorAction SilentlyContinue
)?;
}
SubCommand::State => {
with_komorebic_socket(|| send_message(&SocketMessage::State.as_bytes()?))?;
print_query(&SocketMessage::State.as_bytes()?);
}
SubCommand::VisibleWindows => {
with_komorebic_socket(|| send_message(&SocketMessage::VisibleWindows.as_bytes()?))?;
print_query(&SocketMessage::VisibleWindows.as_bytes()?);
}
SubCommand::Query(arg) => {
with_komorebic_socket(|| {
send_message(&SocketMessage::Query(arg.state_query).as_bytes()?)
})?;
print_query(&SocketMessage::Query(arg.state_query).as_bytes()?);
}
SubCommand::RestoreWindows => {
let hwnd_json = DATA_DIR.join("komorebi.hwnd.json");
@@ -2239,23 +2228,19 @@ Stop-Process -Name:whkd -ErrorAction SilentlyContinue
);
}
SubCommand::ApplicationSpecificConfigurationSchema => {
with_komorebic_socket(|| {
send_message(&SocketMessage::ApplicationSpecificConfigurationSchema.as_bytes()?)
})?;
print_query(&SocketMessage::ApplicationSpecificConfigurationSchema.as_bytes()?);
}
SubCommand::NotificationSchema => {
with_komorebic_socket(|| send_message(&SocketMessage::NotificationSchema.as_bytes()?))?;
print_query(&SocketMessage::NotificationSchema.as_bytes()?);
}
SubCommand::SocketSchema => {
with_komorebic_socket(|| send_message(&SocketMessage::SocketSchema.as_bytes()?))?;
print_query(&SocketMessage::SocketSchema.as_bytes()?);
}
SubCommand::StaticConfigSchema => {
with_komorebic_socket(|| send_message(&SocketMessage::StaticConfigSchema.as_bytes()?))?;
print_query(&SocketMessage::StaticConfigSchema.as_bytes()?);
}
SubCommand::GenerateStaticConfig => {
with_komorebic_socket(|| {
send_message(&SocketMessage::GenerateStaticConfig.as_bytes()?)
})?;
print_query(&SocketMessage::GenerateStaticConfig.as_bytes()?);
}
}