Skip to content

Commit

Permalink
WIP2
Browse files Browse the repository at this point in the history
  • Loading branch information
luke-li-2003 committed Dec 10, 2024
1 parent 4fc1fa8 commit c7311bc
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 19 deletions.
4 changes: 2 additions & 2 deletions runtime/compiler/control/JITClientCompilationThread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2981,15 +2981,15 @@ handleServerMessage(JITServer::ClientStream *client, TR_J9VM *fe, JITServer::Mes
TR::KnownObjectTable::Index baseObjectIndex = std::get<0>(recv);
intptr_t fieldOffset = std::get<1>(recv);

UDATA data = 0;
J9::TransformUtil::value data;

{
TR::VMAccessCriticalSection addFieldAddressFromBaseIndex(fe);
uintptr_t baseObjectAddress = knot->getPointer(baseObjectIndex);

uintptr_t fieldAddress = baseObjectAddress + fieldOffset;

data = *(UDATA *) fieldAddress;
data = *(J9::TransformUtil::value *) fieldAddress;
}

client->write(response, data);
Expand Down
167 changes: 151 additions & 16 deletions runtime/compiler/optimizer/J9TransformUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,107 @@ static void *dereferenceStructPointerChain(void *baseStruct, TR::Node *baseNode,
return NULL;
}

static void *dereferenceStructPointer(TR::KnownObjectTable::Index baseKnownObject,
TR::Node *node,
TR::Node *baseExpression,
bool isBaseStableArray,
TR::Compilation *comp,
J9::TransformUtil::value *valuePtr)
{
TR_J9VMBase *fej9 = comp->fej9();
TR::SymbolReference *symRef = node->getSymbolReference();
TR::Symbol *field = symRef->getSymbol();

TR::Node *addressChildNode = field->isArrayShadowSymbol() ?
node->getFirstChild()->getFirstChild() :
node->getFirstChild();
// Abort if the indirection is more than a single level.
if (!addressChildNode->getOpCode().hasSymbolReference()
|| addressChildNode != baseExpression)
return NULL;

// We only consider the case where isJavaField is true for verifyFieldAccess
if (isJavaField(symRef, comp))
{
TR_OpaqueClassBlock *fieldClass = NULL;

if (symRef->getCPIndex() < 0 &&
field->getRecognizedField() != TR::Symbol::UnknownField)
{
const char* className;
int32_t length;
className = field->owningClassNameCharsForRecognizedField(length);
fieldClass = fej9->getClassFromSignature(className, length, symRef->getOwningMethod(comp));
}
else
fieldClass = symRef->getOwningMethod(comp)->getDeclaringClassFromFieldOrStatic(comp,
symRef->getCPIndex());

TR_OpaqueClassBlock *objectClass =
fej9->getObjectClassFromKnownObjectIndex(comp, baseKnownObject);

// field access verified
if ((fieldClass != NULL) && (fej9->isInstanceOf(objectClass, fieldClass, true) == TR_yes))
{
// check the recognized fields case of avoidFoldingInstanceField
// the non-null checks are done when we get the actual values
if (field->getRecognizedField() == TR::Symbol::Java_lang_invoke_CallSite_target ||
field->getRecognizedField() == TR::Symbol::Java_lang_invoke_MethodHandle_form)
return NULL;

TR::DataType loadType = node->getDataType();

switch (loadType)
{
case TR::Int32:
case TR::Int64:
case TR::Float:
case TR::Double:
{
// not address
auto stream = comp->getStream();
stream->write(JITServer::MessageType::KnownObjectTable_getFieldAddressData,
baseKnownObject, symRef->getOffset());
J9::TransformUtil::value value = std::get<0>(stream->read<J9::TransformUtil::value>());
*valuePtr = value;

return valuePtr;
}
break;
case TR::Address:
{
if (isFinalFieldPointingAtRepresentableNativeStruct(symRef, comp) ||
isFinalFieldPointingAtNativeStruct(symRef, comp))
{
return NULL;
}
else if (field->isCollectedReference())
{
auto stream = comp->getStream();
stream->write(
JITServer::MessageType::KnownObjectTable_addFieldAddressFromBaseIndex,
baseKnownObject, symRef->getOffset());
auto recv = stream->read<TR::KnownObjectTable::Index, uintptr_t *>();
TR::KnownObjectTable::Index value = std::get<0>(recv);
uintptr_t *objectReferenceLocationClient = std::get<1>(recv);
comp->getKnownObjectTable()->updateKnownObjectTableAtServer(
value,
objectReferenceLocationClient
);
valuePtr->idx = value;
return valuePtr;
}
}
break;
default:
return NULL;
}
}
}
return NULL;
}


bool J9::TransformUtil::foldFinalFieldsIn(TR_OpaqueClassBlock *clazz, const char *className, int32_t classNameLength, bool isStatic, TR::Compilation *comp)
{
TR::SimpleRegex *classRegex = comp->getOptions()->getClassesWithFoldableFinalFields();
Expand Down Expand Up @@ -1691,6 +1792,8 @@ J9::TransformUtil::transformIndirectLoadChainAt(TR::Compilation *comp, TR::Node
}
#endif /* defined(J9VM_OPT_JITSERVER) */

return false;
/*
TR::VMAccessCriticalSection transformIndirectLoadChainAt(comp->fej9());
uintptr_t baseAddress;
if (baseExpression->getOpCode().hasSymbolReference() && baseExpression->getSymbol()->isStatic())
Expand All @@ -1703,6 +1806,7 @@ J9::TransformUtil::transformIndirectLoadChainAt(TR::Compilation *comp, TR::Node
}
bool result = TR::TransformUtil::transformIndirectLoadChainImpl(comp, node, baseExpression, (void*)baseAddress, 0, removedNode);
return result;
*/
}

/** Dereference node and fold it into a constant when possible.
Expand Down Expand Up @@ -2073,23 +2177,38 @@ J9::TransformUtil::transformIndirectLoadChainImpl(TR::Compilation *comp, TR::Nod
return false;
}

// in non-jitserver mode, we need to hold a mutex until the end of the function to ensure
// fieldAddress holds
TR::VMAccessCriticalSection transformIndirectLoadChain(comp,
TR::VMAccessCriticalSection::tryToAcquireVMAccess);
if (!transformIndirectLoadChain.hasVMAccess() && !isServer)
return false;

void *valuePtr;
J9::TransformUtil::value value;
if (isServer)
{
// Instead of the recursive dereferenceStructPointerChain, we only consider a single level
// of indirection
void *result = dereferenceStructPointer(baseKnownObject, node, baseExpression,
isBaseStableArray, comp, &value);
valuePtr = &value;
if (result != valuePtr)
return false;
}
else // not server
else // not jitserver
{
TR::VMAccessCriticalSection transformIndirectLoadChain(comp->fej9());
// Dereference the chain starting from baseAddress and get the field address
void *fieldAddress = dereferenceStructPointerChain(baseAddress, baseExpression,
isBaseStableArray, node, comp);
void *fieldAddress = dereferenceStructPointerChain(
(void *) comp->getKnownObjectTable()->getPointer(baseKnownObject),
baseExpression,
isBaseStableArray, node, comp
);
if (!fieldAddress)
{
if (comp->getOption(TR_TraceOptDetails))
{
traceMsg(comp, "Abort transformIndirectLoadChain - cannot verify/dereference field access to %s in %p!\n", symRef->getName(comp->getDebug()), baseAddress);
traceMsg(comp, "Abort transformIndirectLoadChain - cannot verify/dereference field access to %s in %d!\n", symRef->getName(comp->getDebug()), baseKnownObject);
}
return false;
}
Expand All @@ -2107,7 +2226,7 @@ J9::TransformUtil::transformIndirectLoadChainImpl(TR::Compilation *comp, TR::Nod
{
case TR::Int32:
{
int32_t value = *(int32_t*)fieldAddress;
int32_t value = *(int32_t*)valuePtr;
if (changeIndirectLoadIntoConst(node, TR::iconst, removedNode, comp))
node->setInt(value);
else
Expand All @@ -2116,7 +2235,7 @@ J9::TransformUtil::transformIndirectLoadChainImpl(TR::Compilation *comp, TR::Nod
break;
case TR::Int64:
{
int64_t value = *(int64_t*)fieldAddress;
int64_t value = *(int64_t*)valuePtr;
if (changeIndirectLoadIntoConst(node, TR::lconst, removedNode, comp))
node->setLongInt(value);
else
Expand All @@ -2125,7 +2244,7 @@ J9::TransformUtil::transformIndirectLoadChainImpl(TR::Compilation *comp, TR::Nod
break;
case TR::Float:
{
float value = *(float*)fieldAddress;
float value = *(float*)valuePtr;
if (changeIndirectLoadIntoConst(node, TR::fconst, removedNode, comp))
node->setFloat(value);
else
Expand All @@ -2134,7 +2253,7 @@ J9::TransformUtil::transformIndirectLoadChainImpl(TR::Compilation *comp, TR::Nod
break;
case TR::Double:
{
double value = *(double*)fieldAddress;
double value = *(double*)valuePtr;
if (changeIndirectLoadIntoConst(node, TR::dconst, removedNode, comp))
node->setDouble(value);
else
Expand All @@ -2145,11 +2264,13 @@ J9::TransformUtil::transformIndirectLoadChainImpl(TR::Compilation *comp, TR::Nod
{
if (isFinalFieldPointingAtRepresentableNativeStruct(symRef, comp))
{
if (isServer)
return false;
if (fej9->isFinalFieldPointingAtJ9Class(symRef, comp))
{
if (changeIndirectLoadIntoConst(node, TR::loadaddr, removedNode, comp))
{
TR_OpaqueClassBlock *value = *(TR_OpaqueClassBlock**)fieldAddress;
TR_OpaqueClassBlock *value = *(TR_OpaqueClassBlock**)valuePtr;
node->setSymbolReference(comp->getSymRefTab()->findOrCreateClassSymbol(comp->getMethodSymbol(), -1, value));
}
else
Expand All @@ -2166,9 +2287,11 @@ J9::TransformUtil::transformIndirectLoadChainImpl(TR::Compilation *comp, TR::Nod
}
else if (isFinalFieldPointingAtNativeStruct(symRef, comp))
{
if (isServer)
return false;
if (symRef->getReferenceNumber() - comp->getSymRefTab()->getNumHelperSymbols() == TR::SymbolReferenceTable::ramStaticsFromClassSymbol)
{
uintptr_t value = *(uintptr_t*)fieldAddress;
uintptr_t value = *(uintptr_t*)valuePtr;
if (changeIndirectLoadIntoConst(node, TR::aconst, removedNode, comp))
{
node->setAddress(value);
Expand All @@ -2183,10 +2306,22 @@ J9::TransformUtil::transformIndirectLoadChainImpl(TR::Compilation *comp, TR::Nod
}
else if (symRef->getSymbol()->isCollectedReference())
{
uintptr_t value = fej9->getReferenceFieldAtAddress((uintptr_t)fieldAddress);
if (value)
TR::KnownObjectTable::Index knotIndex = -1;
if (isServer)
{
knotIndex = *(TR::KnownObjectTable::Index *)valuePtr;
}
else
{
uintptr_t value = fej9->getReferenceFieldAtAddress((uintptr_t)valuePtr);
knotIndex = comp->getKnownObjectTable()->getOrCreateIndexAt(&value,
isArrayWithConstantElements(symRef, comp));
}

if (knotIndex != -1)
{
TR::SymbolReference *improvedSymRef = comp->getSymRefTab()->findOrCreateSymRefWithKnownObject(symRef, &value, isArrayWithConstantElements(symRef, comp));
TR::SymbolReference *improvedSymRef =
comp->getSymRefTab()->findOrCreateSymRefWithKnownObject(symRef, knotIndex);

if (improvedSymRef->hasKnownObjectIndex()
&& performTransformation(comp, "O^O transformIndirectLoadChain: %s [%p] with fieldOffset %d is obj%d referenceAddr is %p\n", node->getOpCode().getName(), node, improvedSymRef->getKnownObjectIndex(), symRef->getOffset(), value))
Expand All @@ -2196,8 +2331,8 @@ J9::TransformUtil::transformIndirectLoadChainImpl(TR::Compilation *comp, TR::Nod
node->setIsNonNull(true);

int32_t stableArrayRank = isArrayWithStableElements(symRef->getCPIndex(),
symRef->getOwningMethod(comp),
comp);
symRef->getOwningMethod(comp),
comp);
if (isBaseStableArray)
stableArrayRank = baseStableArrayRank - 1;

Expand Down
10 changes: 10 additions & 0 deletions runtime/compiler/optimizer/J9TransformUtil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,19 @@ class OMR_EXTENSIBLE TransformUtil : public OMR::TransformUtilConnector
TR_ResolvedMethod *owningMethod,
int32_t cpIndex);

union value{
int32_t i;
int64_t l;
float f;
double d;
void *p;
TR::KnownObjectTable::Index idx;
};

static bool transformIndirectLoadChain(TR::Compilation *, TR::Node *node, TR::Node *baseExpression, TR::KnownObjectTable::Index baseKnownObject, TR::Node **removedNode);
static bool transformIndirectLoadChainAt(TR::Compilation *, TR::Node *node, TR::Node *baseExpression, uintptr_t *baseReferenceLocation, TR::Node **removedNode);
static bool transformIndirectLoadChainImpl( TR::Compilation *, TR::Node *node, TR::Node *baseExpression, TR::KnownObjectTable::Index baseKnownObject, int32_t baseStableArrayRank, TR::Node **removedNode);
static bool transformIndirectLoadChainServerImpl( TR::Compilation *, TR::Node *node, TR::Node *baseExpression, TR::KnownObjectTable::Index baseKnownObject, int32_t baseStableArrayRank, TR::Node **removedNode);

static bool fieldShouldBeCompressed(TR::Node *node, TR::Compilation *comp);

Expand Down
2 changes: 1 addition & 1 deletion runtime/compiler/optimizer/VectorAPIExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ TR_VectorAPIExpansion::getOpaqueClassBlockFromClassNode(TR::Compilation *comp, T
{
auto stream = comp->getStream();
stream->write(JITServer::MessageType::KnownObjectTable_getOpaqueClass,
symRef->getKnownObjectIndex());
knownObjectIndex);

clazz = (TR_OpaqueClassBlock *)std::get<0>(stream->read<uintptr_t>());
}
Expand Down

0 comments on commit c7311bc

Please sign in to comment.