Skip to content

Commit

Permalink
feat: publisher에 operation type추가
Browse files Browse the repository at this point in the history
- 이상상황 영상 저장 삭제
- 위험지역 침입 operation 추가
  • Loading branch information
sukkyun2 committed Aug 21, 2024
1 parent f4ef4df commit 0818468
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 16 deletions.
24 changes: 9 additions & 15 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import asyncio
from io import BytesIO
from typing import Optional

import cv2
import numpy as np
from PIL import Image
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi import UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.params import Query

from app.api_response import ApiResponse, ApiListResponse
from app.config import settings
from app.connection_manager import ConnectionManager
from app.history import async_save_history
from model.detect import detect, estimate_distance, DetectionResult
from model.detect import detect, estimate_distance, DetectionResult, area_intrusion
from model.operations import OperationType, define_operation
from model.video_recorder import VideoRecorder

app = FastAPI()
Expand Down Expand Up @@ -47,37 +50,28 @@ def exists_publisher() -> ApiListResponse[str]:


@app.websocket("/ws/publishers/{location_name}")
async def websocket_publisher(websocket: WebSocket, location_name: str):
async def websocket_publisher(websocket: WebSocket,
location_name: str,
op: Optional[OperationType] = Query(OperationType.ESTIMATE_DISTANCE)):
await manager.connect(location_name, websocket)
video_recorder = VideoRecorder()

try:
while True:
# pre-processing
data = await websocket.receive_bytes()
img = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR) # byte to nparr

pattern_detected, result = handle_estimate_distance(img)
operation = define_operation(op)
pattern_detected, result = operation(img)
if pattern_detected:
print("Pattern Detected")
video_recorder.start_record_if_not()
await async_save_history(result, location_name)
if video_recorder.is_recording:
video_recorder.record_frame(result.plot_image)

await manager.broadcast(location_name, result.get_encoded_nparr().tobytes())
except WebSocketDisconnect:
manager.disconnect(location_name)
print("Publisher disconnected")


def handle_estimate_distance(img) -> tuple[bool, DetectionResult]:
distances, result = estimate_distance(img)
pattern_detected = any(distance <= 200 for _, _, distance in distances)

return pattern_detected, result


@app.websocket("/ws/subscribers/{location_name}")
async def websocket_subscriber(location_name: str, websocket: WebSocket):
await manager.subscribe(location_name, websocket)
Expand Down
43 changes: 42 additions & 1 deletion model/detect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
from typing import List, Dict, Tuple

import cv2
Expand Down Expand Up @@ -35,6 +34,48 @@ def track(image_np: ndarray) -> DetectionResult:
return DetectionResult(result.plot(), detections)


def define_zone(image_np: ndarray) -> list[int]:
height, width = image_np.shape[:2]
return [width // 4, height // 4, width // 2, height // 2]


def draw_zone(image_np: ndarray, zone: list[int]) -> ndarray:
image_with_zone = cv2.rectangle(image_np.copy(), (zone[0], zone[1]), (zone[2], zone[3]), (0, 0, 255), 2)
label_text = "Danger Zone"
label_position = (zone[0] + 10, zone[1] - 10)
cv2.putText(image_with_zone, label_text, label_position, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

return image_with_zone


def check_intruded(bbox: list[float], zone: list[int]) -> bool:
x1, y1, x2, y2 = map(int, bbox)
return not (x2 < zone[0] or x1 > zone[2] or y2 < zone[1] or y1 > zone[3])


def area_intrusion(image_np: ndarray) -> tuple[bool, DetectionResult]:
zone = define_zone(image_np)
result = track(image_np)

image_with_zone = draw_zone(image_np, zone)
intrusion_detections = []
intrusion = False

for d in result.detections:
class_name, track_id, confidence, bbox = d.class_name, d.track_id, d.confidence, d.bbox
x1, y1, x2, y2 = map(int, bbox)

if class_name == 'person' and check_intruded(bbox, zone):
intrusion = True
intrusion_detections.append(d)

image_with_zone = cv2.rectangle(image_with_zone, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.putText(image_with_zone, f'{class_name}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
(0, 255, 0), 2)

return intrusion, DetectionResult(image_with_zone, intrusion_detections)


def estimate_distance(image_np: ndarray) -> tuple[list[tuple[int, int, float]], DetectionResult]:
result = track(image_np)
non_person_detections = []
Expand Down
28 changes: 28 additions & 0 deletions model/operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from enum import StrEnum

from model.detect import estimate_distance, area_intrusion
from model.schema import DetectionResult


class OperationType(StrEnum):
ESTIMATE_DISTANCE = "estimate_distance",
AREA_INTRUSION = "area_intrusion"


def define_operation(op: OperationType):
if op == OperationType.ESTIMATE_DISTANCE:
return handle_estimate_distance
elif op == OperationType.AREA_INTRUSION:
return handle_area_intrusion


def handle_estimate_distance(img) -> tuple[bool, DetectionResult]:
distances, result = estimate_distance(img)
pattern_detected = any(distance <= 200 for _, _, distance in distances)

return pattern_detected, result


def handle_area_intrusion(img) -> tuple[bool, DetectionResult]:
intrusion, result = area_intrusion(img)
return intrusion, result
23 changes: 23 additions & 0 deletions tests/test_area_intrusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import cv2

from model.operations import handle_area_intrusion


def test_area_intrusion():
cap = cv2.VideoCapture("resources/person.mp4")
assert cap.isOpened(), "Error reading video file"

while cap.isOpened():
success, im0 = cap.read()
if not success:
print("Video frame is empty or video processing has been successfully completed.")
break

pattern_detected, result = handle_area_intrusion(im0)

cv2.imshow("Object Tracking", result.plot_image)
if cv2.waitKey(1) & 0xFF == ord('q'):
break

cap.release()
cv2.destroyAllWindows()

0 comments on commit 0818468

Please sign in to comment.