diff --git a/.coverage b/.coverage
new file mode 100644
index 0000000..3e63c09
Binary files /dev/null and b/.coverage differ
diff --git a/main.py b/main.py
index c9a7715..7f77b18 100644
--- a/main.py
+++ b/main.py
@@ -1,24 +1,34 @@
-import argparse
 import os
+import time
+import argparse
 
 import uvicorn
 from fastapi import FastAPI, UploadFile
 from fastapi.encoders import jsonable_encoder
+from fastapi.responses import PlainTextResponse
+from starlette.middleware.base import BaseHTTPMiddleware
 
+os.environ['PROMETHEUS_DISABLE_CREATED_SERIES'] = 'true'
+# pylint: disable=wrong-import-position
+from prometheus_client import generate_latest, Counter, Histogram, Summary, REGISTRY
 
 from config import TEMP_DIR
 
 os.makedirs(TEMP_DIR, exist_ok=True)
 
+
 # Specify mode
 parser = argparse.ArgumentParser(description='Start service with different modes.')
 parser.add_argument('--langchain', action='store_true')
 parser.add_argument('--towhee', action='store_true')
+parser.add_argument('--max_observation', type=int, default=1000)
 args = parser.parse_args()
 
+
USE_LANGCHAIN = args.langchain
 USE_TOWHEE = args.towhee
+MAX_OBSERVATION = args.max_observation
 
 assert (USE_LANGCHAIN and not USE_TOWHEE ) or (USE_TOWHEE and not USE_LANGCHAIN), \
     'The service should start with either "--langchain" or "--towhee".'
@@ -28,12 +38,90 @@
 if USE_TOWHEE:
     from src_towhee.operations import chat, insert, drop  # pylint: disable=C0413
 
+
 app = FastAPI()
 origins = ['*']
 
+
+# Define metrics
+requests_total = Counter('requests_total', 'Cumulative requests')
+requests_success_total = Counter('requests_success_total', 'Cumulative successful requests')
+requests_failed_total = Counter('requests_failed_total', 'Cumulative failed requests')
+
+endpoint_requests_total = Counter('endpoint_requests_total', 'Cumulative requests of each endpoint', ['endpoint'])
+endpoint_requests_success_total = Counter('endpoint_requests_success_total', 'Cumulative successful requests of each endpoint', ['endpoint'])
+endpoint_requests_failed_total = Counter('endpoint_requests_failed_total', 'Cumulative failed requests of each endpoint', ['endpoint'])
+
+latency_seconds_histogram = Histogram('latency_seconds_histogram', 'Request process latency histogram')
+endpoint_latency_seconds_histogram = Histogram(
+    'endpoint_latency_seconds_histogram', 'Request process latency histogram of each endpoint', ['endpoint']
+)
+
+latency_seconds_summary = Summary('latency_seconds_summary', 'Request process latency summary')
+endpoint_latency_seconds_summary = Summary('endpoint_latency_seconds_summary', 'Request process latency summary of each endpoint', ['endpoint'])
+
+# Rolling window of raw latency observations, per endpoint and overall,
+# capped at MAX_OBSERVATION entries each.
+latencies = {'all': []}
+
+
+# Define middleware to collect metrics
+class RequestMetricsMiddleware(BaseHTTPMiddleware):
+    """
+    Middleware that records request counts and latencies for every endpoint
+    except the /metrics scrape endpoint itself.
+    """
+    async def dispatch(self, request, call_next):
+        path = request.scope.get('path')
+        # Skip metric collection for Prometheus scrapes of /metrics.
+        is_req = path != '/metrics'
+
+        if not is_req:
+            try:
+                response = await call_next(request)
+                return response
+            except Exception as e:
+                raise e
+
+        begin = time.time()
+        requests_total.inc()
+        endpoint_requests_total.labels(path).inc()
+        try:
+            response = await call_next(request)
+            if response.status_code / 100 < 4:
+                requests_success_total.inc()
+                endpoint_requests_success_total.labels(path).inc()
+                end = time.time()
+                if path in latencies:
+                    latencies[path].append(end - begin)
+                    latencies[path] = latencies[path][-MAX_OBSERVATION:]
+                else:
+                    latencies[path] = [end - begin]
+                latencies['all'].append(end - begin)
+                latencies['all'] = latencies['all'][-MAX_OBSERVATION:]
+                latency_seconds_histogram.observe(end - begin)
+                endpoint_latency_seconds_histogram.labels(path).observe(end - begin)
+                latency_seconds_summary.observe(end - begin)
+                endpoint_latency_seconds_summary.labels(path).observe(end - begin)
+                return response
+            else:
+                requests_failed_total.inc()
+                endpoint_requests_failed_total.labels(path).inc()
+                return response
+        except Exception as e:
+            requests_failed_total.inc()
+            endpoint_requests_failed_total.labels(path).inc()
+            raise e
+
+
+app.add_middleware(RequestMetricsMiddleware)
+
+
 @app.get('/')
 def check_api():
-    return jsonable_encoder({'status': True, 'msg': 'ok'}), 200
+    res = jsonable_encoder({'status': True, 'msg': 'ok'}), 200
+    return res
+
+
+@app.get('/metrics')
+async def metrics():
+    registry = REGISTRY
+    data = generate_latest(registry)
+    return PlainTextResponse(content=data, media_type='text/plain')
+
 
 @app.get('/answer')
 def do_answer_api(session_id: str, project: str, question: str):
diff --git a/requirements.txt b/requirements.txt
index c1516ab..35c55cf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,3 +11,4 @@ uvicorn
 towhee>=1.1.0
 pymilvus
 elasticsearch>=8.0.0
+prometheus-client