Skip to content

Commit

Permalink
fixing support for async neo4j driver
Browse files Browse the repository at this point in the history
  • Loading branch information
EvanDietzMorris committed Sep 30, 2024
1 parent 6367ff5 commit 2a775b9
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 11 deletions.
10 changes: 5 additions & 5 deletions reasoner_transpiler/cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json

from collections import defaultdict
from neo4j import AsyncResult, Result

from .attributes import transform_attributes, EDGE_SOURCE_PROPS
from .matching import match_query
Expand Down Expand Up @@ -112,10 +113,10 @@ def get_query(qgraph, **kwargs):
return " ".join(clauses)


def transform_result(cypher_result,
def transform_result(cypher_record,
qgraph: dict):

nodes, edges, paths = unpack_bolt_result(cypher_result)
nodes, edges, paths = unpack_bolt_record(cypher_record)

# Convert the list of unique result nodes from cypher results to dictionaries
# then convert them to TRAPI format, constructing the knowledge_graph["nodes"] section of the TRAPI response
Expand Down Expand Up @@ -449,9 +450,8 @@ def convert_jolt_edge_to_dict(jolt_edges, jolt_element_id_lookup):
return convert_edges


def unpack_bolt_result(bolt_response):
record = bolt_response.single()
return record['nodes'], record['edges'], record['paths']
def unpack_bolt_record(bolt_record):
return bolt_record['nodes'], bolt_record['edges'], bolt_record['paths']


def unpack_jolt_result(jolt_response):
Expand Down
62 changes: 57 additions & 5 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Initialize neo4j database helper function."""
import pytest
import requests
import base64
import neo4j
import asyncio

from neo4j import GraphDatabase
from reasoner_transpiler.cypher import transform_result


Expand All @@ -14,11 +13,17 @@ def fixture_neo4j_driver():
driver.close()


@pytest.fixture(name="async_neo4j_driver", scope="module")
def fixture_async_neo4j_driver():
driver = TranspilerAsyncNeo4jBoltDriver()
yield driver


class TranspilerNeo4jBoltDriver:
def __init__(self):
"""Pytest fixture for Neo4j database connection."""
url = "bolt://localhost:7687"
self.driver = GraphDatabase.driver(url, auth=("neo4j", "plater_testing_pw"))
self.driver = neo4j.GraphDatabase.driver(url, auth=("neo4j", "plater_testing_pw"))

@staticmethod
def _cypher_tx_function(tx,
Expand All @@ -31,7 +36,8 @@ def _cypher_tx_function(tx,

neo4j_result = tx.run(cypher, parameters=query_parameters)
if convert_to_trapi:
return transform_result(neo4j_result, qgraph)
neo4j_record = neo4j_result.single()
return transform_result(neo4j_record, qgraph)
return neo4j_result

def run(self,
Expand All @@ -53,3 +59,49 @@ def run(self,

def close(self):
self.driver.close()


class TranspilerAsyncNeo4jBoltDriver:

def __init__(self):
"""Pytest fixture for Neo4j database connection."""
self.driver = asyncio.run(self.get_async_driver())

async def get_async_driver(self):
url = "bolt://localhost:7687"
return neo4j.AsyncGraphDatabase.driver(url, auth=("neo4j", "plater_testing_pw"))

@staticmethod
async def _cypher_tx_function(tx,
cypher,
query_parameters=None,
convert_to_trapi=False,
qgraph=None):
if not query_parameters:
query_parameters = {}

neo4j_result: neo4j.AsyncResult = await tx.run(cypher, parameters=query_parameters)
if convert_to_trapi:
neo4j_record = await neo4j_result.single()
return transform_result(neo4j_record, qgraph)
return neo4j_result

async def run(self,
query,
query_parameters: dict = None,
convert_to_trapi=False,
qgraph=None):

if not query_parameters:
query_parameters = {}

async with self.driver.session(database="neo4j") as session:
result = await session.execute_read(self._cypher_tx_function,
cypher=query,
query_parameters=query_parameters,
convert_to_trapi=convert_to_trapi,
qgraph=qgraph)
return result

async def close(self):
await self.driver.close()
28 changes: 27 additions & 1 deletion tests/test_transform_results.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .fixtures import fixture_neo4j_driver
from .fixtures import fixture_neo4j_driver, fixture_async_neo4j_driver
from reasoner_transpiler.cypher import get_query
import asyncio


def test_bolt_driver_transform_results(neo4j_driver):
Expand All @@ -25,3 +26,28 @@ def test_bolt_driver_transform_results(neo4j_driver):
assert len(result["analyses"]) == 1
assert len(output['knowledge_graph']['nodes']) == 13
assert len(output['auxiliary_graphs']) == 14


def test_bolt_async_driver_transform_results(async_neo4j_driver):
qgraph = {
"nodes": {
"n0": {"ids": [
"MONDO:0000001",
"HP:0000118",
]},
"n1": {},
},
"edges": {
"e01": {
"subject": "n0",
"object": "n1",
},
},
}
output = asyncio.run(async_neo4j_driver.run(get_query(qgraph), convert_to_trapi=True, qgraph=qgraph))
assert len(output['results']) == 15
for result in output["results"]:
assert len(result["node_bindings"]) == 2
assert len(result["analyses"]) == 1
assert len(output['knowledge_graph']['nodes']) == 13
assert len(output['auxiliary_graphs']) == 14

0 comments on commit 2a775b9

Please sign in to comment.