Skip to content

Commit

Permalink
Create torch_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jeonghoonkang authored Oct 9, 2024
1 parent 0bec396 commit 1ff0d26
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions apps/deeplearning/PyTorch/torch_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
import time

# 테스트용 간단한 CNN 모델 정의
class SimpleCNN(torch.nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 16, 3, stride=1, padding=1)
self.pool = torch.nn.MaxPool2d(2, 2)
self.fc1 = torch.nn.Linear(16 * 16 * 16, 10)

def forward(self, x):
x = self.pool(torch.nn.functional.relu(self.conv1(x)))
x = x.view(-1, 16 * 16 * 16)
x = self.fc1(x)
return x

# 모델 생성 및 임의 데이터 생성
model = SimpleCNN()
dummy_input = torch.randn(1, 3, 32, 32)

# 모델 성능 테스트
start_time = time.time()
with torch.no_grad():
for _ in range(100):
output = model(dummy_input)
end_time = time.time()

# 평균 처리 시간 계산
avg_time = (end_time - start_time) / 100
print(f"Average inference time per batch: {avg_time:.6f} seconds")

0 comments on commit 1ff0d26

Please sign in to comment.