Skip to content

Commit

Permalink
feat: add unittest.yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
vndee committed Jul 7, 2024
1 parent 84c3f43 commit 5ed870e
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 25 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/unittest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: unittest

on: [pull_request, push]

jobs:
unittest:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.11'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Run unittests
run: python -m unittest discover -s tests -v
1 change: 1 addition & 0 deletions llm_sandbox/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .session import SandboxSession # noqa: F401
101 changes: 78 additions & 23 deletions llm_sandbox/session.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import io
import os
import docker
import tarfile
from typing import List, Optional, Union

from docker.models.images import Image
from llm_sandbox.utils import image_exists
from llm_sandbox.utils import (
image_exists,
get_libraries_installation_command,
get_code_file_extension,
get_code_execution_command,
)
from llm_sandbox.const import SupportedLanguage, SupportedLanguageValues


Expand Down Expand Up @@ -47,14 +54,14 @@ def __init__(

def open(self):
warning_str = (
"Since the `keep_image` flag is set to True the image and container will not be removed after the session "
"ends and remains for future use."
"Since the `keep_template` flag is set to True the docker image will not be removed after the session ends "
"and remains for future use."
)
if self.dockerfile:
self.path = os.path.dirname(self.dockerfile)
if self.verbose:
f_str = f"Building docker image from {self.dockerfile}"
f_str = f"{f_str}. {warning_str}" if self.keep_template else f_str
f_str = f"{f_str}\n{warning_str}" if self.keep_template else f_str
print(f_str)

self.image, _ = self.client.images.build(
Expand All @@ -67,8 +74,8 @@ def open(self):
if isinstance(self.image, str):
if not image_exists(self.client, self.image):
if self.verbose:
f_str = f"Pulling image {self.image}"
f_str = f"{f_str}. {warning_str}" if self.keep_template else f_str
f_str = f"Pulling image {self.image}.."
f_str = f"{f_str}\n{warning_str}" if self.keep_template else f_str
print(f_str)

self.image = self.client.images.pull(self.image)
Expand All @@ -81,7 +88,7 @@ def open(self):
self.container = self.client.containers.run(self.image, detach=True, tty=True)

def close(self):
if self.container and not self.keep_template:
if self.container:
self.container.remove(force=True)
self.container = None

Expand All @@ -107,31 +114,79 @@ def close(self):
else:
if self.verbose:
print(
f"Image {self.image.tags[-1]} is in use by other containers. Skipping removal."
f"Image {self.image.tags[-1]} is in use by other containers. Skipping removal.."
)

def run(self, code: str, libraries: List = []):
raise NotImplementedError
def run(self, code: str, libraries: Optional[List] = None):
if not self.container:
raise RuntimeError(
"Session is not open. Please call open() method before running code."
)

if libraries:
command = get_libraries_installation_command(self.lang, libraries)
self.execute_command(command)

code_file = f"/tmp/code.{get_code_file_extension(self.lang)}"
with open(code_file, "w") as f:
f.write(code)

self.copy_to_runtime(code_file, code_file)
result = self.execute_command(get_code_execution_command(self.lang, code_file))
return result

def copy_from_runtime(self, src: str, dest: str):
if not self.container:
raise RuntimeError(
"Session is not open. Please call open() method before copying files."
)

if self.verbose:
print(f"Copying {self.container.short_id}:{src} to {dest}..")

bits, stat = self.container.get_archive(src)
if stat["size"] == 0:
raise FileNotFoundError(f"File {src} not found in the container")

tarstream = io.BytesIO(b"".join(bits))
with tarfile.open(fileobj=tarstream, mode="r") as tar:
tar.extractall(os.path.dirname(dest))

def copy_from_runtime(self, path: str):
raise NotImplementedError
def copy_to_runtime(self, src: str, dest: str):
if not self.container:
raise RuntimeError(
"Session is not open. Please call open() method before copying files."
)

if self.verbose:
print(f"Copying {src} to {self.container.short_id}:{dest}..")

tarstream = io.BytesIO()
with tarfile.open(fileobj=tarstream, mode="w") as tar:
tar.add(src, arcname=os.path.basename(src))

def copy_to_runtime(self, path: str):
raise NotImplementedError
tarstream.seek(0)
self.container.put_archive(os.path.dirname(dest), tarstream)

def execute_command(self, command: str):
if not self.container:
raise RuntimeError(
"Session is not open. Please call open() method before executing commands."
)

def execute_command(self, command: str, shell: str = "/bin/sh"):
raise NotImplementedError
if self.verbose:
print(f"Executing command: {command}")

exit_code, output = self.container.exec_run(command)
if self.verbose:
print(f"Output: {output.decode()}")
print(f"Exit code: {exit_code}")

return exit_code, output.decode()

def __enter__(self):
self.open()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()


if __name__ == "__main__":
with SandboxSession(
dockerfile="tests/busybox.Dockerfile", keep_template=False, lang="python"
) as session:
session.run("print('Hello, World!')")
72 changes: 72 additions & 0 deletions llm_sandbox/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import docker
import docker.errors
from typing import List, Optional

from docker import DockerClient
from llm_sandbox.const import SupportedLanguage


def image_exists(client: DockerClient, image: str) -> bool:
Expand All @@ -18,3 +20,73 @@ def image_exists(client: DockerClient, image: str) -> bool:
return False
except Exception as e:
raise e


def get_libraries_installation_command(
lang: str, libraries: List[str]
) -> Optional[str]:
"""
Get the command to install libraries for the given language
:param lang: Programming language
:param libraries: List of libraries
:return: Installation command
"""
if lang == SupportedLanguage.PYTHON:
return f"pip install {' '.join(libraries)}"
elif lang == SupportedLanguage.JAVA:
return f"mvn install:install-file -Dfile={' '.join(libraries)}"
elif lang == SupportedLanguage.JAVASCRIPT:
return f"npm install {' '.join(libraries)}"
elif lang == SupportedLanguage.CPP:
return f"apt-get install {' '.join(libraries)}"
elif lang == SupportedLanguage.GO:
return f"go get {' '.join(libraries)}"
elif lang == SupportedLanguage.RUBY:
return f"gem install {' '.join(libraries)}"
else:
raise ValueError(f"Language {lang} is not supported")


def get_code_file_extension(lang: str) -> str:
"""
Get the file extension for the given language
:param lang: Programming language
:return: File extension
"""
if lang == SupportedLanguage.PYTHON:
return "py"
elif lang == SupportedLanguage.JAVA:
return "java"
elif lang == SupportedLanguage.JAVASCRIPT:
return "js"
elif lang == SupportedLanguage.CPP:
return "cpp"
elif lang == SupportedLanguage.GO:
return "go"
elif lang == SupportedLanguage.RUBY:
return "rb"
else:
raise ValueError(f"Language {lang} is not supported")


def get_code_execution_command(lang: str, code_file: str) -> str:
"""
Get the command to execute the code
:param lang: Programming language
:param code_file: Path to the code file
:return: Execution command
"""
if lang == SupportedLanguage.PYTHON:
return f"python {code_file}"
elif lang == SupportedLanguage.JAVA:
return f"java {code_file}"
elif lang == SupportedLanguage.JAVASCRIPT:
return f"node {code_file}"
elif lang == SupportedLanguage.CPP:
return f"./{code_file}"
elif lang == SupportedLanguage.GO:
return f"go run {code_file}"
elif lang == SupportedLanguage.RUBY:
return f"ruby {code_file}"
else:
raise ValueError(f"Language {lang} is not supported")
1 change: 0 additions & 1 deletion tests/Dockerfile

This file was deleted.

1 change: 0 additions & 1 deletion tests/busybox.Dockerfile

This file was deleted.

125 changes: 125 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import os
import tarfile
import unittest
from io import BytesIO
from unittest.mock import patch, MagicMock
from llm_sandbox import SandboxSession


class TestSandboxSession(unittest.TestCase):
@patch("docker.from_env")
def setUp(self, mock_docker_from_env):
self.mock_docker_client = MagicMock()
mock_docker_from_env.return_value = self.mock_docker_client

self.image = "python:3.9.19-bullseye"
self.dockerfile = None
self.lang = "python"
self.keep_template = False
self.verbose = False

self.session = SandboxSession(
image=self.image,
dockerfile=self.dockerfile,
lang=self.lang,
keep_template=self.keep_template,
verbose=self.verbose,
)

def test_init_with_invalid_lang(self):
with self.assertRaises(ValueError):
SandboxSession(lang="invalid_language")

def test_init_with_both_image_and_dockerfile(self):
with self.assertRaises(ValueError):
SandboxSession(image="some_image", dockerfile="some_dockerfile")

def test_open_with_image(self):
self.mock_docker_client.images.get.return_value = MagicMock(
tags=["python:3.9.19-bullseye"]
)
self.mock_docker_client.containers.run.return_value = MagicMock()

self.session.open()
self.mock_docker_client.containers.run.assert_called_once()
self.assertIsNotNone(self.session.container)

def test_close(self):
mock_container = MagicMock()
self.session.container = mock_container

self.session.close()
mock_container.remove.assert_called_once()
self.assertIsNone(self.session.container)

def test_run_without_open(self):
with self.assertRaises(RuntimeError):
self.session.run("print('Hello')")

def test_run_with_code(self):
self.session.container = MagicMock()
self.session.execute_command = MagicMock(return_value=(0, "Output"))

result = self.session.run("print('Hello')")
self.session.execute_command.assert_called()
self.assertEqual(result, (0, "Output"))

def test_copy_to_runtime(self):
self.session.container = MagicMock()
src = "test.txt"
dest = "/tmp/test.txt"
with open(src, "w") as f:
f.write("test content")

self.session.copy_to_runtime(src, dest)
self.session.container.put_archive.assert_called()

os.remove(src)

@patch("tarfile.open")
def test_copy_from_runtime(self, mock_tarfile_open):
self.session.container = MagicMock()
src = "/tmp/test.txt"
dest = "test.txt"

# Create a mock tarfile
tarstream = BytesIO()
with tarfile.open(fileobj=tarstream, mode="w") as tar:
tarinfo = tarfile.TarInfo(name=os.path.basename(dest))
tarinfo.size = len(b"test content")
tar.addfile(tarinfo, BytesIO(b"test content"))

tarstream.seek(0)
self.session.container.get_archive.return_value = (
[tarstream.read()],
{"size": tarstream.__sizeof__()},
)

def mock_extractall(path):
with open(dest, "wb") as f:
f.write(b"test content")

mock_tarfile = MagicMock()
mock_tarfile.extractall.side_effect = mock_extractall
mock_tarfile_open.return_value.__enter__.return_value = mock_tarfile

self.session.copy_from_runtime(src, dest)
self.assertTrue(os.path.exists(dest))

os.remove(dest)

def test_execute_command(self):
mock_container = MagicMock()
self.session.container = mock_container

command = "echo 'Hello'"
mock_container.exec_run.return_value = (0, b"Hello\n")

exit_code, output = self.session.execute_command(command)
mock_container.exec_run.assert_called_with(command)
self.assertEqual(exit_code, 0)
self.assertEqual(output, "Hello\n")


if __name__ == "__main__":
unittest.main()

0 comments on commit 5ed870e

Please sign in to comment.