Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added option for removing the object types when writing to file #144

Open
wants to merge 5 commits into
base: dev-0.9.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions docs/notebooks/mod-and-write-pddl-to-file.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Modify and write PDDL to file"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook shows how to read a PDDL domain and problem, extend the problem file and then write to a new file. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import PDDL domain and problem "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tarski.io import PDDLReader, FstripsWriter\n",
"\n",
"# Assuming that you have these files on disk\n",
"domain_file = \"domain.pddl\"\n",
"problem_file = \"problem.pddl\"\n",
"out_domain_file = \"domain_out.pddl\"\n",
"out_problem_file = \"prob_out.pddl\"\n",
"\n",
"reader = PDDLReader(raise_on_error=True)\n",
"\n",
"reader.parse_domain(domain_file)\n",
"problem = reader.parse_instance(problem_file)\n",
"\n",
"writer = FstripsWriter(problem)\n",
"\n",
"writer.write_instance(out_problem_file)\n",
"print(\"Wrote problem file to\", out_problem_file)\n",
"\n",
"writer.write_domain(out_domain_file)\n",
"print(\"Wrote domain file to\", out_domain_file)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Add predicates to your domain"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example we will add a block to the Blocksworld instance "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lang = problem.language\n",
"predicates = lang.predicates\n",
"sorts = lang.sorts\n",
"\n",
"clear_pred = [pred for pred in predicates if pred.name == \"clear\"][0]\n",
"block_name = \"b1\"\n",
"lang.constant(block_name, sorts[0]) # Assuming there is only the \"Object\" type\n",
"new_block = lang.get_constant(block_name)\n",
"# Make the new block clear\n",
"problem.init.add(clear_pred, new_block)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate PDDL without types\n",
"\n",
"By default Tarski will add types to the PDDL files, even if only the default `object` type is present. This can be removed by setting the parameter `no_types=True` in the Writer class"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"writer = FstripsWriter(problem, no_types=True)\n",
"\n",
"writer.write_instance(out_problem_file)\n",
"print(\"Wrote problem file to\", out_problem_file)\n",
"\n",
"writer.write_domain(out_domain_file)\n",
"print(\"Wrote domain file to\", out_domain_file)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
47 changes: 32 additions & 15 deletions src/tarski/io/fstrips.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def parse_string(self, string, start_rule):
"""


def print_objects(constants):
def print_objects(constants, no_types=False):
""" Print a PDDL object declaration with the given objects.
Objects are sorted by name and grouped by type, and types sorted by name as well """
constants_by_sort = defaultdict(list)
Expand All @@ -106,6 +106,9 @@ def print_objects(constants):
sobjects = " ".join(sorted(constants_by_sort[sort]))
elements.append("{} - {}".format(sobjects, sort))

if no_types:
return sobjects

return linebreaks(elements, indentation=2, indent_first=False)


Expand Down Expand Up @@ -175,9 +178,13 @@ def print_problem_metric(problem):

class FstripsWriter:

def __init__(self, problem):
def __init__(self, problem, no_types = False):
self.problem = problem
self.lang = problem.language
self.no_types = no_types

if self.no_types and self.lang.sorts:
raise RuntimeError("The sort information will be lost if no_types is set to True")

def write(self, domain_filename, instance_filename, domain_constants: Optional[List[Constant]] = None):
domain_constants = domain_constants or []
Expand All @@ -191,7 +198,10 @@ def print_domain(self, constant_objects: Optional[List[Constant]] = None):
constants", and which as "PDDL instance objects", which is something that cannot be determined from the problem
information alone. If `constant_objects` is None, all objects are considered instance objects.
"""
tpl = load_tpl("fstrips_domain.tpl")
if self.no_types:
tpl = load_tpl("fstrips_domain_no_types.tpl")
else:
tpl = load_tpl("fstrips_domain.tpl")
content = tpl.format(
header_info="",
domain_name=self.problem.domain_name,
Expand All @@ -205,7 +215,7 @@ def print_domain(self, constant_objects: Optional[List[Constant]] = None):
)
return content

def write_domain(self, filename, constant_objects):
def write_domain(self, filename, constant_objects: Optional[List[Constant]] = None):
with open(filename, 'w', encoding='utf8') as file:
file.write(self.print_domain(constant_objects))

Expand All @@ -227,7 +237,7 @@ def print_instance(self, constant_objects: Optional[List[Constant]] = None):
domain_name=self.problem.domain_name,
problem_name=self.problem.name,

objects=print_objects(instance_objects),
objects=print_objects(instance_objects, no_types=self.no_types),
init=print_init(self.problem),
goal=print_goal(self.problem),
constraints=print_problem_constraints(self.problem),
Expand All @@ -236,12 +246,14 @@ def print_instance(self, constant_objects: Optional[List[Constant]] = None):
)
return content

def write_instance(self, filename, constant_objects):
def write_instance(self, filename, constant_objects: Optional[List[Constant]] = None):
with open(filename, 'w', encoding='utf8') as file:
file.write(self.print_instance(constant_objects))

def get_types(self):
res = []
if self.no_types:
return ("\n" + _TAB * 2).join(res)
for t in self.lang.sorts:
if t.builtin or t == self.lang.Object:
continue # Don't declare builtin elements
Expand Down Expand Up @@ -269,46 +281,51 @@ def get_predicates(self):
for fun in self.lang.predicates:
if fun.builtin:
continue # Don't declare builtin elements
domain_str = build_signature_string(fun.sort)
domain_str = build_signature_string(fun.sort, no_types=self.no_types)
res.append("({} {})".format(fun.symbol, domain_str))
return ("\n" + _TAB * 2).join(res)

def get_actions(self):
return "\n".join(self.get_action(a) for a in self.problem.actions.values())
return "\n".join(self.get_action(a, no_types=self.no_types) for a in self.problem.actions.values())

@staticmethod
def get_action(a):
def get_action(a, no_types=False):
base_indentation = 1
return action_tpl.format(
name=a.name,
parameters=print_variable_list(a.parameters),
parameters=print_variable_list(a.parameters, no_types=no_types),
precondition=print_formula(a.precondition, base_indentation),
effect=print_effects(a.effects, a.cost, base_indentation)
)

def get_derived_predicates(self):
return "\n".join(self.get_derived(d) for d in self.problem.derived_predicates.values())
return "\n".join(self.get_derived(d, no_types=self.no_types) for d in self.problem.derived_predicates.values())

@staticmethod
def get_derived(d):
def get_derived(d, no_types=False):
return derived_tpl.format(
name=d.predicate.symbol,
parameters=print_variable_list(d.parameters),
parameters=print_variable_list(d.parameters, no_types=no_types),
formula=print_formula(d.formula))


def build_signature_string(domain):
def build_signature_string(domain, no_types=False):
if not domain:
return ""

if no_types:
return " ".join(f"{print_variable_name(f'x{i}')}" for i, t in enumerate(domain, 1))

return " ".join(f"{print_variable_name(f'x{i}')} - {tarski_to_pddl_type(t)}" for i, t in enumerate(domain, 1))


def print_variable_name(name: str):
return name if name.startswith("?") else f'?{name}'


def print_variable_list(parameters):
def print_variable_list(parameters, no_types=False):
if no_types:
return " ".join(f"{print_variable_name(p.symbol)}" for p in parameters)
return " ".join(f"{print_variable_name(p.symbol)} - {p.sort.name}" for p in parameters)


Expand Down
24 changes: 24 additions & 0 deletions src/tarski/io/templates/fstrips_domain_no_types.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;; Domain file automatically generated by the Tarski FSTRIPS writer
;;; {header_info}
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(define (domain {domain_name})
(:requirements {requirements})
(:constants
{constants}
)

(:predicates
{predicates}
)

(:functions
{functions}
)

{derived}

{actions}
)