Skip to content

Commit

Permalink
changing instruction vector list API
Browse files Browse the repository at this point in the history
  • Loading branch information
nishant-sachdeva committed Dec 28, 2023
1 parent 0e840a7 commit e9ffdcb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 21 deletions.
20 changes: 9 additions & 11 deletions Manylinux2014_Compliant_Source/pkg/ir2vec/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,25 +143,23 @@ class ir2vecHandler {
}

// Function to get Instruction Vector Dictionary
PyObject *createInstructionVectorDict(
PyObject *createInstructionVectorList(
llvm::SmallMapVector<const llvm::Instruction *, IR2Vec::Vector, 128>
llvmInstVecMap) {
PyObject *InstVecDict = PyDict_New();

PyObject *instructionVector = PyList_New(0);
for (auto &Inst_it : llvmInstVecMap) {
std::string demangledName = IR2Vec::getDemagledName(Inst_it.first);

PyObject *instructionVector = PyList_New(0);
PyObject *InstVec = PyList_New(0);
// copy this SmallVector into c++ Vector
for (auto &Vec_it : Inst_it.second) {
PyList_Append(instructionVector, PyFloat_FromDouble(Vec_it));
PyList_Append(InstVec, PyFloat_FromDouble(Vec_it));
}
PyDict_SetDefault(InstVecDict,
PyUnicode_FromString(demangledName.c_str()), Py_None);
PyDict_SetItemString(InstVecDict, demangledName.c_str(),
instructionVector);

// add InstVec to instructionVector
PyList_Append(instructionVector, InstVec);
}
return InstVecDict;
return instructionVector;
}

// generateEncodings
Expand Down Expand Up @@ -205,7 +203,7 @@ class ir2vecHandler {
} else if (type == OpType::Instruction) {
llvm::SmallMapVector<const llvm::Instruction *, IR2Vec::Vector, 128>
instVecMap = emb->getInstVecMap();
return this->createInstructionVectorDict(instVecMap);
return this->createInstructionVectorList(instVecMap);
} else {
PyErr_SetString(PyExc_TypeError, "Invalid OpType");
Py_RETURN_NONE;
Expand Down
19 changes: 9 additions & 10 deletions Manylinux2014_Compliant_Source/pkg/tests/test_ir2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,10 @@ def assert_valid_progVector(progVector):
return True


def assert_valid_insructionVectors(insVecMap):
assert insVecMap is not None
def assert_valid_instructionVectors(instVecList):
assert instVecList is not None

keys = list(insVecMap.keys())
assert len(keys) > 0

values = list(insVecMap.values())
assert len(values) > 0

for ins, vec in insVecMap.items():
assert ins is not None
for vec in instVecList:
assert vec is not None
assert isinstance(vec, list)
assert all(isinstance(x, float) for x in vec)
Expand Down Expand Up @@ -142,6 +135,12 @@ def test_sym_p():
progVector2 = initObj.getProgramVector()
assert_valid_progVector(progVector2)

instVecList = ir2vec.getInstructionVectors(initObj)
assert_valid_instructionVectors(instVecList)

instVecList2 = initObj.getInstructionVectors()
assert_valid_instructionVectors(instVecList2)

for idx, vec in enumerate(progVector1):
assert vec == pytest.approx(progVector2[idx], abs=ABS_ACCURACY)

Expand Down

0 comments on commit e9ffdcb

Please sign in to comment.