Skip to content

Commit

Permalink
feat: support gzip & zstd compression
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy committed Nov 17, 2024
1 parent 03a5fdd commit e5a5022
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 5 deletions.
99 changes: 98 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mosec"
version = "0.8.9"
version = "0.9.0"
authors = ["Keming <[email protected]>", "Zichen <[email protected]>"]
edition = "2021"
license = "Apache-2.0"
Expand All @@ -25,3 +25,5 @@ serde = "1.0"
serde_json = "1.0"
utoipa = "5"
utoipa-swagger-ui = { version = "8", features = ["axum"] }
tower = "0.5.1"
tower-http = {version = "0.6.1", features = ["compression-zstd", "decompression-zstd", "compression-gzip", "decompression-gzip"]}
7 changes: 7 additions & 0 deletions mosec/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ def build_arguments_parser() -> argparse.ArgumentParser:
"This will omit the worker number for each stage.",
action="store_true",
)

parser.add_argument(
"--compression",
help="Enable Zstd & Gzip compression for the request body",
action="store_true",
)

return parser


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Rust",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
Expand Down
1 change: 1 addition & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ ruff>=0.7
pre-commit>=2.15.0
httpx[http2]==0.27.2
httpx-sse==0.4.0
zstandard~=0.23
3 changes: 3 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ pub(crate) struct Config {
pub namespace: String,
// log level: (debug, info, warning, error)
pub log_level: String,
// Zstd & Gzip compression
pub compression: bool,
pub runtimes: Vec<Runtime>,
pub routes: Vec<Route>,
}
Expand All @@ -79,6 +81,7 @@ impl Default for Config {
port: 8000,
namespace: String::from("mosec_service"),
log_level: String::from("info"),
compression: false,
runtimes: vec![Runtime {
max_batch_size: 64,
max_wait_time: 3000,
Expand Down
15 changes: 14 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#![forbid(unsafe_code)]

mod apidoc;
mod config;
mod errors;
Expand All @@ -27,6 +29,9 @@ use std::net::SocketAddr;
use axum::routing::{get, post};
use axum::Router;
use tokio::signal::unix::{signal, SignalKind};
use tower::ServiceBuilder;
use tower_http::compression::CompressionLayer;
use tower_http::decompression::RequestDecompressionLayer;
use tracing::{debug, info};
use tracing_subscriber::fmt::time::UtcTime;
use tracing_subscriber::prelude::*;
Expand Down Expand Up @@ -90,12 +95,20 @@ async fn run(conf: &Config) {
}
}

if conf.compression {
router = router.layer(
ServiceBuilder::new()
.layer(RequestDecompressionLayer::new())
.layer(CompressionLayer::new()),
);
}

// wait until each stage has at least one worker alive
barrier.wait().await;
let addr: SocketAddr = format!("{}:{}", conf.address, conf.port).parse().unwrap();
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
info!(?addr, "http service is running");
axum::serve(listener, router.into_make_service())
axum::serve(listener, router)
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ fn build_response(status: StatusCode, content: Bytes) -> Response<Body> {
),
),
)]
pub(crate) async fn index(_: Request<Body>) -> Response<Body> {
pub(crate) async fn index() -> Response<Body> {
let task_manager = TaskManager::global();
if task_manager.is_shutdown() {
build_response(
Expand All @@ -79,7 +79,7 @@ pub(crate) async fn index(_: Request<Body>) -> Response<Body> {
(status = StatusCode::OK, description = "Get metrics", body = String),
),
)]
pub(crate) async fn metrics(_: Request<Body>) -> Response<Body> {
pub(crate) async fn metrics() -> Response<Body> {
let mut encoded = String::new();
let registry = REGISTRY.get().unwrap();
encode(&mut encoded, registry).unwrap();
Expand Down
41 changes: 41 additions & 0 deletions tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""End-to-end service tests."""

import gzip
import json
import random
import re
import shlex
Expand All @@ -26,6 +28,7 @@
import msgpack # type: ignore
import pytest
from httpx_sse import connect_sse
from zstandard import ZstdCompressor

from mosec.server import GUARD_CHECK_INTERVAL
from tests.utils import wait_for_port_free, wait_for_port_open
Expand Down Expand Up @@ -349,3 +352,41 @@ def test_multi_route_service(mosec_service, http_client):
assert resp.status_code == HTTPStatus.OK, resp
assert resp.headers["content-type"] == "application/msgpack"
assert msgpack.unpackb(resp.content) == {"length": len(data)}


@pytest.mark.parametrize(
"mosec_service, http_client",
[
pytest.param("square_service --compression --debug", "", id="compression"),
],
indirect=["mosec_service", "http_client"],
)
def test_compression_service(mosec_service, http_client):
zstd_compressor = ZstdCompressor()
req = {"x": 2}
expect = {"x": 4}

# test without compression
resp = http_client.post("/inference", json=req)
assert resp.status_code == HTTPStatus.OK, resp
assert resp.json() == expect, resp.content

# test with gzip compression
binary = gzip.compress(json.dumps(req).encode())
resp = http_client.post(
"/inference",
content=binary,
headers={"Accept-Encoding": "gzip", "Content-Encoding": "gzip"},
)
assert resp.status_code == HTTPStatus.OK, resp
assert resp.json() == expect, resp.content

# test with zstd compression
binary = zstd_compressor.compress(json.dumps(req).encode())
resp = http_client.post(
"/inference",
content=binary,
headers={"Accept-Encoding": "zstd", "Content-Encoding": "zstd"},
)
assert resp.status_code == HTTPStatus.OK, resp
assert resp.json() == expect, resp.content

0 comments on commit e5a5022

Please sign in to comment.