Skip to content

Commit

Permalink
Merge pull request #255 from ranking-agent/MCQ
Browse files Browse the repository at this point in the history
Mcq
  • Loading branch information
cbizon authored Jul 17, 2024
2 parents e1d768a + 10b082a commit 58e677c
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 27 deletions.
2 changes: 1 addition & 1 deletion openapi-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ servers:
# url: http://127.0.0.1:5000
termsOfService: http://robokop.renci.org:7055/tos?service_long=ARAGORN&provider_long=RENCI
title: ARAGORN
version: 2.7.5
version: 2.8.0
tags:
- name: translator
- name: ARA
Expand Down
15 changes: 10 additions & 5 deletions src/results_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,18 @@ def __init__(
password=redis_password,
)

def get_query_key(self, input_id, predicate, qualifiers, source_input, caller, workflow):
def get_query_key(self, input_id, predicate, qualifiers, source_input, caller, workflow, mcq, member_ids):
    """Build the canonical, deterministic cache key for an infer query.

    Args:
        input_id: curie of the pinned query node.
        predicate: biolink predicate of the inferred edge.
        qualifiers: dict of qualifier constraints, merged into the key.
        source_input: True if the pinned node is the edge subject.
        caller: identifier of the calling ARA.
        workflow: the workflow definition for the query.
        mcq: True for multi-curie (set-input) queries.
        member_ids: curies of the set members (MCQ queries only).

    Returns:
        A JSON string with sorted keys, so equivalent queries always
        produce the same key.
    """
    keydict = {'predicate': predicate, 'source_input': source_input, 'input_id': input_id, 'caller': caller, 'workflow': workflow}
    keydict.update(qualifiers)
    if mcq:
        # Because we already have a bunch of cached keys without mcq, we only
        # add these entries for the new MCQ queries, keeping old keys valid.
        # sorted() makes the key order-independent without mutating the
        # caller's member_ids list (the original list.sort() did mutate it).
        keydict['mcq'] = True
        keydict['member_ids'] = sorted(member_ids)
    return json.dumps(keydict, sort_keys=True)

def get_result(self, input_id, predicate, qualifiers, source_input, caller, workflow):
key = self.get_query_key(input_id, predicate, qualifiers, source_input, caller, workflow)
def get_result(self, input_id, predicate, qualifiers, source_input, caller, workflow, mcq, member_ids):
key = self.get_query_key(input_id, predicate, qualifiers, source_input, caller, workflow, mcq, member_ids)
try:
result = self.creative_redis.get(key)
if result is not None:
Expand All @@ -51,8 +56,8 @@ def get_result(self, input_id, predicate, qualifiers, source_input, caller, work
return result


def set_result(self, input_id, predicate, qualifiers, source_input, caller, workflow, final_answer):
key = self.get_query_key(input_id, predicate, qualifiers, source_input, caller, workflow)
def set_result(self, input_id, predicate, qualifiers, source_input, caller, workflow, mcq, member_ids, final_answer):
key = self.get_query_key(input_id, predicate, qualifiers, source_input, caller, workflow, mcq, member_ids)

try:
self.creative_redis.set(key, gzip.compress(json.dumps(final_answer).encode()))
Expand Down
100 changes: 100 additions & 0 deletions src/rules/MCQ.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
{
"{\"mcq\": true, \"predicate\": \"biolink:genetically_associated_with\"}": [
{
"Rule": "?a phenotype of ?b is genetically associated with ?c => ?a genetically associated with ?c",
"template": {
"query_graph": {
"nodes": {
"$source": {
"ids": [
"$source_id"
],
"categories": [
"biolink:PhenotypicFeature"
],
"set_interpretation": "MANY"
},
"$target": {
"ids": [
"$target_id"
],
"categories": [
"biolink:Gene"
]
},
"b": {
"categories": [
"biolink:Disease"
],
"set_interpretation": "MANY"
}
},
"edges": {
"edge_0": {
"subject": "b",
"object": "$source",
"predicates": [
"biolink:has_phenotype"
]
},
"edge_1": {
"subject": "b",
"object": "$target",
"predicates": [
"biolink:genetically_associated_with"
]
}
}
}
}
},
{
"Rule": "?a contributed to by ?b affects ?c => ?a genetically associated with ?c",
"template": {
"query_graph": {
"nodes": {
"$source": {
"ids": [
"$source_id"
],
"categories": [
"biolink:PhenotypicFeature"
],
"set_interpretation": "MANY"
},
"$target": {
"ids": [
"$target_id"
],
"categories": [
"biolink:Gene"
]
},
"b": {
"categories": [
"biolink:ChemicalEntity"
],
"set_interpretation": "MANY"
}
},
"edges": {
"edge_0": {
"subject": "b",
"object": "$source",
"predicates": [
"biolink:contributes_to"
]
},
"edge_1": {
"subject": "b",
"object": "$target",
"predicates": [
"biolink:affects"
]
}
}
}
}
}
]
}
55 changes: 34 additions & 21 deletions src/service_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,19 @@

DUMPTRUCK = False

#from src.rules.rules import rules as AMIE_EXPANSIONS

logger = logging.getLogger(__name__)

# declare the directory where the async data files will exist
queue_file_dir = "./queue-files"

#Load in the AMIE rules. I'm not sure how this works wrt startup and workers.
#Load in the AMIE rules.
thisdir = os.path.dirname(__file__)
#Temporarily point to a typed rules file. In the future, we will get types in the basic rules and use the config
# to generate "rules.json" in the "rules" directory.
#rulefile = os.path.join(thisdir,"rules","rules.json")
rulefile = os.path.join(thisdir,"rules","kara_typed_rules","rules_with_types_cleaned_finalized.json")
with open(rulefile,'r') as inf:
AMIE_EXPANSIONS = json.load(inf)
# Merge the typed AMIE rules and the MCQ rules into a single lookup table,
# keyed by the canonical rule key, loaded once at import time.
rulefiles = [
    os.path.join(thisdir, "rules", "kara_typed_rules", "rules_with_types_cleaned_finalized.json"),
    os.path.join(thisdir, "rules", "MCQ.json"),
]
AMIE_EXPANSIONS = {}
for rulefile in rulefiles:
    with open(rulefile, 'r') as inf:
        AMIE_EXPANSIONS.update(json.load(inf))

def examine_query(message):
"""Decides whether the input is an infer. Returns the grouping node"""
Expand Down Expand Up @@ -86,7 +84,7 @@ def match_results_to_query(results, query_message, query_source, query_target, q
# rewrite the results to match the query.

#First, get the source, target, and qedge id's from the results
_, _, _, results_source, _, results_target, results_qedge_id = get_infer_parameters(results)
_, _, _, results_source, _, results_target, results_qedge_id, _, _ = get_infer_parameters(results)

#Now replace the results query graph with the input query graph
results["message"]["query_graph"] = query_message["message"]["query_graph"]
Expand Down Expand Up @@ -184,9 +182,9 @@ async def entry(message, guid, coalesce_type, caller) -> (dict, int):
if infer:
# We're going to cache infer queries, and we need to do that even if we're overriding the cache
# because we need these values to post to the cache at the end.
input_id, predicate, qualifiers, source, source_input, target, qedge_id = get_infer_parameters(message)
input_id, predicate, qualifiers, source, source_input, target, qedge_id, mcq, member_ids = get_infer_parameters(message)
if read_from_cache:
results = results_cache.get_result(input_id, predicate, qualifiers, source_input, caller, workflow_def)
results = results_cache.get_result(input_id, predicate, qualifiers, source_input, caller, workflow_def, mcq, member_ids)
if results is not None:
logger.info(f"{guid}: Returning results cache lookup")
# The results can't go verbatim. While the essense of the query is the same as the cached result,
Expand All @@ -196,6 +194,8 @@ async def entry(message, guid, coalesce_type, caller) -> (dict, int):
else:
logger.info(f"{guid}: Results cache miss")
else:
mcq = False
member_ids = []
if read_from_cache:
results = results_cache.get_lookup_result(workflow_def, query_graph)
if results is not None:
Expand All @@ -220,7 +220,7 @@ async def entry(message, guid, coalesce_type, caller) -> (dict, int):
# so we want to write to the cache if bypass cache is false or overwrite_cache is true
if overwrite_cache or (not bypass_cache):
if infer:
results_cache.set_result(input_id, predicate, qualifiers, source_input, caller, workflow_def, final_answer)
results_cache.set_result(input_id, predicate, qualifiers, source_input, caller, workflow_def, mcq, member_ids, final_answer)
elif {"id": "lookup"} in workflow_def:
results_cache.set_lookup_result(workflow_def, query_graph, final_answer)

Expand Down Expand Up @@ -854,31 +854,44 @@ def get_infer_parameters(input_message):
qualifiers = {}
else:
qualifiers = {"qualifier_constraints": qc}
if ("ids" in input_message["message"]["query_graph"]["nodes"][source]) \
and (input_message["message"]["query_graph"]["nodes"][source]["ids"] is not None):
input_id = input_message["message"]["query_graph"]["nodes"][source]["ids"][0]
mcq = False
snode = input_message["message"]["query_graph"]["nodes"][source]
tnode = input_message["message"]["query_graph"]["nodes"][target]
if ("ids" in snode) and (snode["ids"] is not None):
input_id = snode["ids"][0]
member_ids = snode.get("member_ids",[])
if "set_interpretation" in snode and snode["set_interpretation"] == "MANY":
mcq = True
source_input = True
else:
input_id = input_message["message"]["query_graph"]["nodes"][target]["ids"][0]
input_id = tnode["ids"][0]
member_ids = tnode.get("member_ids",[])
if "set_interpretation" in tnode and tnode["set_interpretation"] == "MANY":
mcq = True
source_input = False
#key = get_key(predicate, qualifiers)
return input_id, predicate, qualifiers, source, source_input, target, query_edge
return input_id, predicate, qualifiers, source, source_input, target, query_edge, mcq, member_ids

def get_rule_key(predicate, qualifiers):
def get_rule_key(predicate, qualifiers, mcq):
    """Return the canonical JSON key used to select AMIE expansion rules.

    The key is a sorted-key JSON object holding the predicate, any
    qualifier entries, and — for multi-curie queries only — an "mcq"
    flag, so MCQ and non-MCQ queries select different rule sets.
    """
    key_parts = {'predicate': predicate, **qualifiers}
    if mcq:
        key_parts["mcq"] = True
    return json.dumps(key_parts, sort_keys=True)

def expand_query(input_message, params, guid):
#Contract: 1. there is a single edge in the query graph 2. The edge is marked inferred. 3. Either the source
# or the target has IDs, but not both. 4. The number of ids on the query node is 1.
input_id, predicate, qualifiers, source, source_input, target, qedge_id = get_infer_parameters(input_message)
key = get_rule_key(predicate, qualifiers)
input_id, predicate, qualifiers, source, source_input, target, qedge_id, mcq, member_ids = get_infer_parameters(input_message)
key = get_rule_key(predicate, qualifiers, mcq)
#We want to run the non-inferred version of the query as well
qg = deepcopy(input_message["message"]["query_graph"])
for eid,edge in qg["edges"].items():
del edge["knowledge_type"]
messages = [{"message": {"query_graph":qg}, "parameters": input_message.get("parameters") or {}}]
#If it's an MCQ, then we also copy the KG which has the member_of edges
if mcq:
messages[0]["message"]["knowledge_graph"] = deepcopy(input_message["message"]["knowledge_graph"])
#If we don't have any AMIE expansions, this will just generate the direct query
for rule_def in AMIE_EXPANSIONS.get(key,[]):
query_template = Template(json.dumps(rule_def["template"]))
Expand Down

0 comments on commit 58e677c

Please sign in to comment.