Skip to content

Commit

Permalink
Introduce file util
Browse files Browse the repository at this point in the history
  • Loading branch information
goFrendiAsgard committed Dec 15, 2024
1 parent 540e43b commit 4b5d224
Show file tree
Hide file tree
Showing 19 changed files with 139 additions and 174 deletions.
4 changes: 2 additions & 2 deletions src/zrb/builtin/llm/tool/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
RAG_OVERLAP,
)
from zrb.util.cli.style import stylize_error, stylize_faint
from zrb.util.file import read_file
from zrb.util.run import run_async

Document = str | Callable[[], str]
Expand Down Expand Up @@ -137,8 +138,7 @@ def get_documents() -> list[Callable[[], str]]:
def _get_text_reader(file_path: str):
def read():
print(stylize_faint(f"Start reading {file_path}"), file=sys.stderr)
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
content = read_file(file_path)
print(stylize_faint(f"Complete reading {file_path}"), file=sys.stderr)
return content

Expand Down
36 changes: 18 additions & 18 deletions src/zrb/builtin/project/add/fastapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from zrb.task.make_task import make_task
from zrb.task.scaffolder import Scaffolder
from zrb.task.task import Task
from zrb.util.file import read_file, write_file
from zrb.util.string.conversion import double_quote, to_snake_case
from zrb.util.string.name import get_random_name

Expand Down Expand Up @@ -50,28 +51,27 @@ def register_fastapp_automation(ctx: AnyContext):
project_dir_path = ctx.input["project-dir"]
zrb_init_path = os.path.join(project_dir_path, "zrb_init.py")
app_dir_path = ctx.input.app
app_name = to_snake_case(ctx.input.app)
with open(zrb_init_path, "r") as f:
file_content = f.read().strip()
# Assemble new content
new_content_list = [file_content]
# Check if import load_file is exists, if not exists, add
snake_app_name = to_snake_case(ctx.input.app)
old_code = read_file(zrb_init_path).strip()
# Assemble new content components
import_load_file_script = "from zrb import load_file"
if import_load_file_script not in file_content:
new_content_list = [import_load_file_script] + new_content_list
# Add fastapp-automation script
automation_file_part = ", ".join(
[double_quote(part) for part in [app_dir_path, "_zrb", "main.py"]]
)
new_content_list = new_content_list + [
f"{app_name} = load_file(os.path.join(_DIR, {automation_file_part}))",
f"assert {app_name}",
"",
]
new_content = "\n".join(new_content_list)
# Write new content
with open(zrb_init_path, "w") as f:
f.write(new_content)
write_file(
zrb_init_path,
[
(
import_load_file_script
if import_load_file_script not in old_code
else None
),
old_code,
f"{snake_app_name} = load_file(os.path.join(_DIR, {automation_file_part}))",
f"assert {snake_app_name}",
"",
],
)


scaffold_fastapp >> register_fastapp_automation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from zrb.util.codemod.add_code_to_function import add_code_to_function
from zrb.util.codemod.add_code_to_module import add_code_to_module
from zrb.util.codemod.add_parent_to_class import add_parent_to_class
from zrb.util.file import read_file, write_file
from zrb.util.string.conversion import to_pascal_case, to_snake_case


Expand Down Expand Up @@ -102,8 +103,7 @@ async def register_my_app_name_migration(ctx: AnyContext):
APP_DIR, "module", to_snake_case(ctx.input.module), "migration_metadata.py"
)
app_name = os.path.basename(APP_DIR)
with open(migration_metadata_file_path, "r") as f:
file_content = f.read()
file_content = read_file(migration_metadata_file_path)
entity_name = to_snake_case(ctx.input.entity)
entity_class = to_pascal_case(ctx.input.entity)
new_file_content_list = (
Expand All @@ -115,8 +115,7 @@ async def register_my_app_name_migration(ctx: AnyContext):
"",
]
)
with open(migration_metadata_file_path, "w") as f:
f.write("\n".join(new_file_content_list))
write_file(migration_metadata_file_path, "\n".join(new_file_content_list))


@make_task(
Expand All @@ -129,8 +128,7 @@ async def register_my_app_name_api_client(ctx: AnyContext):
api_client_file_path = os.path.join(
APP_DIR, "module", to_snake_case(ctx.input.module), "client", "api_client.py"
)
with open(api_client_file_path, "r") as f:
file_content = f.read()
file_content = read_file(api_client_file_path)
module_config_name = to_snake_case(ctx.input.module).upper()
new_code = add_code_to_module(
file_content,
Expand All @@ -149,8 +147,7 @@ async def register_my_app_name_api_client(ctx: AnyContext):
new_code.strip(),
"",
]
with open(api_client_file_path, "w") as f:
f.write("\n".join(new_file_content_list))
write_file(api_client_file_path, "\n".join(new_file_content_list))


@make_task(
Expand All @@ -163,8 +160,7 @@ async def register_my_app_name_direct_client(ctx: AnyContext):
direct_client_file_path = os.path.join(
APP_DIR, "module", to_snake_case(ctx.input.module), "client", "direct_client.py"
)
with open(direct_client_file_path, "r") as f:
file_content = f.read()
file_content = read_file(direct_client_file_path)
new_code = add_code_to_module(
file_content, "user_direct_client = user_usecase.as_direct_client()"
)
Expand All @@ -181,8 +177,7 @@ async def register_my_app_name_direct_client(ctx: AnyContext):
new_code.strip(),
"",
]
with open(direct_client_file_path, "w") as f:
f.write("\n".join(new_file_content_list))
write_file(direct_client_file_path, "\n".join(new_file_content_list))


@make_task(
Expand All @@ -195,8 +190,7 @@ async def register_my_app_name_route(ctx: AnyContext):
direct_client_file_path = os.path.join(
APP_DIR, "module", to_snake_case(ctx.input.module), "route.py"
)
with open(direct_client_file_path, "r") as f:
file_content = f.read()
file_content = read_file(direct_client_file_path)
entity_name = to_snake_case(ctx.input.entity)
new_code = add_code_to_function(
file_content, "serve_route", f"{entity_name}_usecase.serve_route(app)"
Expand All @@ -208,8 +202,7 @@ async def register_my_app_name_route(ctx: AnyContext):
new_code.strip(),
"",
]
with open(direct_client_file_path, "w") as f:
f.write("\n".join(new_file_content_list))
write_file(direct_client_file_path, "\n".join(new_file_content_list))


@make_task(
Expand All @@ -222,8 +215,7 @@ async def register_my_app_name_client_method(ctx: AnyContext):
any_client_file_path = os.path.join(
APP_DIR, "module", to_snake_case(ctx.input.module), "route.py"
)
with open(any_client_file_path, "r") as f:
file_content = f.read()
file_content = read_file(any_client_file_path)
app_name = os.path.basename(APP_DIR)
snake_entity_name = to_snake_case(ctx.input.entity)
pascal_entity_name = to_pascal_case(ctx.input.entity)
Expand All @@ -232,8 +224,7 @@ async def register_my_app_name_client_method(ctx: AnyContext):
any_client_method_template_path = (
os.path.join(os.path.dirname(__file__), "any_client_method.template.py"),
)
with open(any_client_method_template_path, "r") as f:
any_client_method_template = f.read()
any_client_method_template = read_file(any_client_method_template_path)
any_client_method = any_client_method_template.replace(
"my_entity", snake_entity_name
).replace("MyEntity", pascal_entity_name)
Expand All @@ -245,8 +236,7 @@ async def register_my_app_name_client_method(ctx: AnyContext):
new_code.strip(),
"",
]
with open(any_client_file_path, "w") as f:
f.write("\n".join(new_file_content_list))
write_file(any_client_file_path, "\n".join(new_file_content_list))


# TODO: Register gateway route
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapp_template._zrb.input import new_module_input

from zrb import AnyContext, Scaffolder, Task, make_task
from zrb.util.file import read_file, write_file
from zrb.util.string.conversion import to_kebab_case, to_pascal_case, to_snake_case


Expand Down Expand Up @@ -61,8 +62,13 @@ async def register_my_app_name_module_config(ctx: AnyContext):
config_name = f"APP_{upper_module_name}_BASE_URL"
env_name = f"MY_APP_NAME_{upper_module_name}_BASE_URL"
# TODO: check before write
with open(config_file_name, "a") as f:
f.write(f'{config_name} = os.getenv("{env_name}", "{module_base_url}")\n')
write_file(
config_file_name,
[
read_file(config_file_name),
f'{config_name} = os.getenv("{env_name}", "{module_base_url}")\n',
],
)


@make_task(
Expand All @@ -77,12 +83,10 @@ async def register_my_app_name_module(ctx: AnyContext):
module_name = to_snake_case(ctx.input.module)
import_code = f"from fastapp_template.module.{module_name} import route as {module_name}_route" # noqa
assert_code = f"assert {module_name}_route"
with open(app_main_file_name, "r") as f:
code = f.read()
code = read_file(app_main_file_name)
new_code = "\n".join([import_code, code.strip(), assert_code, ""])
# TODO: check before write
with open(app_main_file_name, "w") as f:
f.write(new_code)
write_file(app_main_file_name, new_code)


# TODO: Register config
Expand All @@ -108,20 +112,19 @@ async def register_my_app_name_module_runner(ctx: AnyContext):
module_snake_name = to_snake_case(ctx.input.module)
module_kebab_name = to_kebab_case(ctx.input.module)
module_pascal_name = to_pascal_case(ctx.input.module)
with open(os.path.join(os.path.dirname(__file__), "run_module.template.py")) as f:
module_runner_code = (
f.read()
.replace("my_module", module_snake_name)
.replace("my-module", module_kebab_name)
.replace("My Module", module_pascal_name)
.replace("3000", f"{module_port}")
)
with open(task_main_file_name, "r") as f:
code = f.read()
module_runner_code = read_file(
os.path.join(os.path.dirname(__file__), "run_module.template.py"),
{
"my_module": module_snake_name,
"my-module": module_kebab_name,
"My Module": module_pascal_name,
"3000": f"{module_port}",
},
)
code = read_file(task_main_file_name)
new_code = "\n".join([code.strip(), "", module_runner_code, ""])
# TODO: check before write
with open(task_main_file_name, "w") as f:
f.write(new_code)
write_file(task_main_file_name, new_code)


create_my_app_name_module = app_create_group.add_task(
Expand Down
12 changes: 4 additions & 8 deletions src/zrb/builtin/setup/asdf/asdf_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

from zrb.context.any_context import AnyContext
from zrb.util.file import read_file, write_file


def get_install_prerequisites_cmd(ctx: AnyContext) -> str:
Expand Down Expand Up @@ -31,14 +32,9 @@ def setup_asdf_ps_config(file_path: str):


def _setup_asdf_config(file_path: str, asdf_config: str):
dir_path = os.path.dirname(file_path)
os.makedirs(dir_path, exist_ok=True)
if not os.path.isfile(file_path):
with open(file_path, "w") as f:
f.write("")
with open(file_path, "r") as f:
content = f.read()
write_file(file_path, "")
content = read_file(file_path)
if asdf_config in content:
return
with open(file_path, "a") as f:
f.write(f"\n{asdf_config}\n")
write_file(file_path, [content, asdf_config, ""])
19 changes: 7 additions & 12 deletions src/zrb/builtin/setup/tmux/tmux.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from zrb.input.str_input import StrInput
from zrb.task.cmd_task import CmdTask
from zrb.task.make_task import make_task
from zrb.util.file import read_file, write_file

install_tmux = CmdTask(
name="install-tmux",
Expand All @@ -28,22 +29,16 @@
alias="tmux",
)
def setup_tmux(ctx: AnyContext):
with open(os.path.join(os.path.dirname(__file__), "tmux_config.sh"), "r") as f:
tmux_config_template = f.read()
tmux_config = read_file(os.path.join(os.path.dirname(__file__), "tmux_config.sh"))
tmux_config_file = os.path.expanduser(ctx.input["tmux-config"])
tmux_config_dir = os.path.dirname(tmux_config_file)
# Make sure config file exists
os.makedirs(tmux_config_dir, exist_ok=True)
if not os.path.isfile(tmux_config_file):
with open(tmux_config_file, "w") as f:
f.write("")
with open(tmux_config_file, "r") as f:
# config file already contain the config
if tmux_config_template in f.read():
return
write_file(tmux_config_file, "")
content = read_file(tmux_config_file)
if tmux_config in content:
return
# Write config
with open(tmux_config_file, "a") as f:
f.write(f"\n{tmux_config_template}\n")
write_file(tmux_config_file, [content, tmux_config, ""])
ctx.print("Setup complete, restart your terminal to continue")


Expand Down
19 changes: 7 additions & 12 deletions src/zrb/builtin/todo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from zrb.input.str_input import StrInput
from zrb.input.text_input import TextInput
from zrb.task.make_task import make_task
from zrb.util.file import read_file, write_file
from zrb.util.todo import (
TodoTaskModel,
add_durations,
Expand Down Expand Up @@ -119,8 +120,7 @@ def show_todo(ctx: AnyContext):
log_work_path = os.path.join(TODO_DIR, "log-work", f"{task_id}.json")
log_work_list = []
if os.path.isfile(log_work_path):
with open(log_work_path, "r") as f:
log_work_list = json.loads(f.read())
log_work_list = json.loads(read_file(log_work_path))
return get_visual_todo_card(todo_task, log_work_list)


Expand Down Expand Up @@ -237,24 +237,21 @@ def log_todo(ctx: AnyContext):
log_work_dir, f"{todo_task.keyval.get('id')}.json"
)
if os.path.isfile(log_work_file_path):
with open(log_work_file_path, "r") as f:
log_work_json = f.read()
log_work_json = read_file(log_work_file_path)
else:
log_work_json = "[]"
log_work: list[dict[str, Any]] = json.loads(log_work_json)
log_work.append(
{"log": ctx.input.log, "duration": ctx.input.duration, "start": ctx.input.start}
)
# save todo with log work
with open(log_work_file_path, "w") as f:
f.write(json.dumps(log_work, indent=2))
write_file(log_work_file_path, json.dumps(log_work, indent=2))
# get log work list
task_id = todo_task.keyval.get("id", "")
log_work_path = os.path.join(TODO_DIR, "log-work", f"{task_id}.json")
log_work_list = []
if os.path.isfile(log_work_path):
with open(log_work_path, "r") as f:
log_work_list = json.loads(f.read())
log_work_list = json.loads(read_file(log_work_path))
return "\n".join(
[
get_visual_todo_list(todo_list, TODO_VISUAL_FILTER),
Expand Down Expand Up @@ -290,8 +287,7 @@ def edit_todo(ctx: AnyContext):
]
new_content = "\n".join(todo_task_to_line(todo_task) for todo_task in todo_list)
todo_file_path = os.path.join(TODO_DIR, "todo.txt")
with open(todo_file_path, "w") as f:
f.write(new_content)
write_file(todo_file_path, new_content)
todo_list = load_todo_list(todo_file_path)
return get_visual_todo_list(todo_list, TODO_VISUAL_FILTER)

Expand All @@ -300,5 +296,4 @@ def _get_todo_txt_content() -> str:
todo_file_path = os.path.join(TODO_DIR, "todo.txt")
if not os.path.isfile(todo_file_path):
return ""
with open(todo_file_path, "r") as f:
return f.read()
return read_file(todo_file_path)
Loading

0 comments on commit 4b5d224

Please sign in to comment.