Skip to content

Commit

Permalink
feat: custom model 추가
Browse files Browse the repository at this point in the history
- 여러 카메라를 동시에 송출시 연결끊기는 이슈 수정
  • Loading branch information
sukkyun2 committed Aug 27, 2024
1 parent 7a1b4ef commit 5ca1396
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 4 deletions.
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
HISTORY_API_HOST=http://3.34.196.131:8080
YOLO_WEIGHT_PATH=yolov5su.pt
CUSTOM_YOLO_WEIGHT_PATH=model/weight.pt
ALLOW_ORIGINS=http://localhost:5173;https://apap-ict.p-e.kr
1 change: 1 addition & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
class Settings(BaseSettings):
history_api: str = Field(alias='HISTORY_API_HOST', default='http://localhost:8080')
yolo_weight_path: str = Field(alias='YOLO_WEIGHT_PATH', default='yolov5su.pt')
custom_yolo_weight_path: str = Field(alias='CUSTOM_YOLO_WEIGHT_PATH', default='yolov5su.pt')
allow_origins: str = Field(alias='ALLOW_ORIGINS', default='') # TODO Convert to List

class Config:
Expand Down
26 changes: 23 additions & 3 deletions model/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@
from model.schema import Detection, TrackedObject, DetectionResult

model = YOLO(settings.yolo_weight_path)
model2 = YOLO(settings.yolo_weight_path)
custom_model = YOLO(settings.custom_yolo_weight_path)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model2.to(device)
custom_model.to(device)
print(f"Model run on the {device}")
print(f"General Model weight path : {settings.yolo_weight_path}")
print(f"Custom Model weight path : {settings.custom_yolo_weight_path}")

tracked_objects: Dict[int, TrackedObject] = {}


def track(image_np: ndarray) -> DetectionResult:
def track(image_np: ndarray, model: YOLO) -> DetectionResult:
result = model.track(image_np, persist=True)[0]
boxes = result.boxes

Expand Down Expand Up @@ -55,7 +61,7 @@ def check_intruded(bbox: list[float], zone: list[int]) -> bool:

def area_intrusion(image_np: ndarray) -> tuple[bool, DetectionResult]:
zone = define_zone(image_np)
result = track(image_np)
result = track(image_np, model)

image_with_zone = draw_zone(image_np, zone)
intrusion_detections = []
Expand All @@ -77,7 +83,7 @@ def area_intrusion(image_np: ndarray) -> tuple[bool, DetectionResult]:


def estimate_distance(image_np: ndarray) -> tuple[list[tuple[int, int, float]], DetectionResult]:
result = track(image_np)
result = track(image_np, model2)
non_person_detections = []
person_detections = []

Expand Down Expand Up @@ -131,6 +137,20 @@ def calculate_distance(p1, p2):
return distances


def detect_by_custom_model(image_np: ndarray) -> DetectionResult:
result = custom_model.predict(image_np)[0]
boxes = result.boxes

class_idxes = boxes.cls.int().cpu().tolist()
confidences = boxes.conf.int().cpu().tolist()
bboxes = boxes.xyxy.cpu().tolist()

detections = [Detection(model.names[ci], c, None, bbox) for ci, c, bbox in
zip(class_idxes, confidences, bboxes)]

return DetectionResult(result.plot(), detections)


def detect(image_np: ndarray) -> DetectionResult:
target_image = img.fromarray(image_np)
result = model.predict(target_image)[0]
Expand Down
10 changes: 9 additions & 1 deletion model/operations.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from enum import StrEnum

from model.detect import estimate_distance, area_intrusion
from model.detect import estimate_distance, area_intrusion, detect_by_custom_model
from model.schema import DetectionResult


class OperationType(StrEnum):
ESTIMATE_DISTANCE = "estimate_distance",
AREA_INTRUSION = "area_intrusion"
CUSTOM_MODEL = "custom_model"


def define_operation(op: OperationType):
if op == OperationType.ESTIMATE_DISTANCE:
return handle_estimate_distance
elif op == OperationType.AREA_INTRUSION:
return handle_area_intrusion
elif op == OperationType.CUSTOM_MODEL:
return handle_detect_by_custom_model


def handle_estimate_distance(img) -> tuple[bool, DetectionResult]:
Expand All @@ -26,3 +29,8 @@ def handle_estimate_distance(img) -> tuple[bool, DetectionResult]:
def handle_area_intrusion(img) -> tuple[bool, DetectionResult]:
intrusion, result = area_intrusion(img)
return intrusion, result


def handle_detect_by_custom_model(img) -> tuple[bool, DetectionResult]:
result = detect_by_custom_model(img)
return len(result.detections) > 0, result
Binary file added model/weight.pt
Binary file not shown.

0 comments on commit 5ca1396

Please sign in to comment.