mirror of
https://github.com/mountain-loop/yaak.git
synced 2026-06-30 18:11:39 +02:00
Add request message size setting
This commit is contained in:
@@ -20,6 +20,7 @@ pub async fn ws_connect(
|
||||
headers: HeaderMap<HeaderValue>,
|
||||
validate_certificates: bool,
|
||||
client_cert: Option<ClientCertificateConfig>,
|
||||
request_message_size: i32,
|
||||
) -> Result<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response)> {
|
||||
info!("Connecting to WS {url}");
|
||||
let tls_config = get_tls_config(validate_certificates, WITH_ALPN, client_cert.clone())?;
|
||||
@@ -34,7 +35,7 @@ pub async fn ws_connect(
|
||||
|
||||
let (stream, response) = connect_async_tls_with_config(
|
||||
req,
|
||||
Some(WebSocketConfig::default()),
|
||||
Some(websocket_config(request_message_size)),
|
||||
false,
|
||||
Some(Connector::Rustls(Arc::new(tls_config))),
|
||||
)
|
||||
@@ -48,3 +49,12 @@ pub async fn ws_connect(
|
||||
|
||||
Ok((stream, response))
|
||||
}
|
||||
|
||||
fn websocket_config(request_message_size: i32) -> WebSocketConfig {
|
||||
let max_message_size = message_size_limit(request_message_size);
|
||||
WebSocketConfig::default().max_message_size(max_message_size).max_frame_size(max_message_size)
|
||||
}
|
||||
|
||||
pub(crate) fn message_size_limit(setting: i32) -> Option<usize> {
|
||||
setting.try_into().ok().filter(|limit| *limit > 0)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::connect::ws_connect;
|
||||
use crate::connect::{message_size_limit, ws_connect};
|
||||
use crate::error::Error::GenericError;
|
||||
use crate::error::Result;
|
||||
use futures_util::stream::SplitSink;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
@@ -15,10 +16,16 @@ use tokio_tungstenite::tungstenite::http::HeaderValue;
|
||||
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
|
||||
use yaak_tls::ClientCertificateConfig;
|
||||
|
||||
type WebsocketSink = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
|
||||
|
||||
struct WebsocketConnection {
|
||||
max_message_size: Option<usize>,
|
||||
sink: WebsocketSink,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WebsocketManager {
|
||||
connections:
|
||||
Arc<Mutex<HashMap<String, SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>>,
|
||||
connections: Arc<Mutex<HashMap<String, WebsocketConnection>>>,
|
||||
read_tasks: Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,
|
||||
}
|
||||
|
||||
@@ -35,14 +42,20 @@ impl WebsocketManager {
|
||||
receive_tx: mpsc::Sender<Message>,
|
||||
validate_certificates: bool,
|
||||
client_cert: Option<ClientCertificateConfig>,
|
||||
request_message_size: i32,
|
||||
) -> Result<Response> {
|
||||
let tx = receive_tx.clone();
|
||||
let max_message_size = message_size_limit(request_message_size);
|
||||
|
||||
let (stream, response) =
|
||||
ws_connect(url, headers, validate_certificates, client_cert).await?;
|
||||
ws_connect(url, headers, validate_certificates, client_cert, request_message_size)
|
||||
.await?;
|
||||
let (write, mut read) = stream.split();
|
||||
|
||||
self.connections.lock().await.insert(id.to_string(), write);
|
||||
self.connections
|
||||
.lock()
|
||||
.await
|
||||
.insert(id.to_string(), WebsocketConnection { max_message_size, sink: write });
|
||||
|
||||
let handle = {
|
||||
let connection_id = id.to_string();
|
||||
@@ -76,7 +89,15 @@ impl WebsocketManager {
|
||||
None => return Ok(()),
|
||||
Some(c) => c,
|
||||
};
|
||||
connection.send(msg).await?;
|
||||
if let Some(limit) = connection.max_message_size {
|
||||
let message_size = msg.len();
|
||||
if message_size > limit {
|
||||
return Err(GenericError(format!(
|
||||
"WebSocket message too large: found {message_size} bytes, the limit is {limit} bytes"
|
||||
)));
|
||||
}
|
||||
}
|
||||
connection.sink.send(msg).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -84,7 +105,7 @@ impl WebsocketManager {
|
||||
info!("Closing websocket");
|
||||
if let Some(mut connection) = self.connections.lock().await.remove(id) {
|
||||
// Wait a maximum of 1 second for the connection to close
|
||||
if let Err(e) = connection.close().await {
|
||||
if let Err(e) = connection.sink.close().await {
|
||||
warn!("Failed to close websocket connection {e:?}");
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user