diff --git a/src/main/java/me/zoarial/networkArbiter/ZoarialNetworkArbiter.kt b/src/main/java/me/zoarial/networkArbiter/ZoarialNetworkArbiter.kt index ce9ae59..a511459 100644 --- a/src/main/java/me/zoarial/networkArbiter/ZoarialNetworkArbiter.kt +++ b/src/main/java/me/zoarial/networkArbiter/ZoarialNetworkArbiter.kt @@ -23,6 +23,7 @@ object ZoarialNetworkArbiter { private val networkElementToByteMap = HashMap() private val byteToNetworkElementMap = HashMap() private val networkElementLengthMap = HashMap() + private val networkObjectCache = HashMap() private val inc = AutoIncrementInt() private const val NOT_ZNA_ERR_STR: String = "Not ZNA" @@ -41,13 +42,15 @@ object ZoarialNetworkArbiter { init { // TODO: Clean up and make these maps static at some point + classToNetworkElementMap[String::class.javaObjectType] = NetworkElementType.STRING + classToNetworkElementMap[UUID::class.javaObjectType] = NetworkElementType.UUID + // Add the Java boxed types classToNetworkElementMap[Byte::class.javaObjectType] = NetworkElementType.BYTE classToNetworkElementMap[Short::class.javaObjectType] = NetworkElementType.SHORT classToNetworkElementMap[Int::class.javaObjectType] = NetworkElementType.INT classToNetworkElementMap[Long::class.javaObjectType] = NetworkElementType.LONG classToNetworkElementMap[Boolean::class.javaObjectType] = NetworkElementType.BOOLEAN - classToNetworkElementMap[String::class.javaObjectType] = NetworkElementType.STRING - classToNetworkElementMap[UUID::class.javaObjectType] = NetworkElementType.UUID + // Add the Java primitive types classToNetworkElementMap[Byte::class.javaPrimitiveType!!] = classToNetworkElementMap[Byte::class.javaObjectType]!! classToNetworkElementMap[Short::class.javaPrimitiveType!!] = classToNetworkElementMap[Short::class.javaObjectType]!! classToNetworkElementMap[Int::class.javaPrimitiveType!!] = classToNetworkElementMap[Int::class.javaObjectType]!! @@ -86,7 +89,7 @@ object ZoarialNetworkArbiter { } println() println() - val objectStructure = getObjectStructure(obj::class.java, obj) + val objectStructure = getObjectStructure(obj) val basicElements = objectStructure.basicElements val advancedElements = objectStructure.advancedElements val totalLen = AtomicInteger(7) @@ -218,9 +221,12 @@ object ZoarialNetworkArbiter { println("Byte: " + buf[i]) } // TODO: Pass in the right object - val objectNetworkRepresentation = getObjectStructure(clazz, Object()) + val networkObjectStructureOpt = getObjectStructure(clazz) + if(networkObjectStructureOpt.isEmpty) { + throw NotANetworkObject("Object is not registered") + } val networkRepresentation = decodeNetworkObject(ByteArrayInputStream(buf)) - if (!objectNetworkRepresentation.equalsStructure(networkRepresentation)) { + if (!networkObjectStructureOpt.get().equalsStructure(networkRepresentation)) { throw MismatchedObject("Objects don't match") } else { println("Objects match signature") @@ -425,7 +431,7 @@ object ZoarialNetworkArbiter { throw RuntimeException(e) } ?: throw RuntimeException("Constructed a null object") - val objectStructure = getObjectStructure(obj::class.java, obj) + val objectStructure = getObjectStructure(obj) val basicElements = objectStructure.basicElements val advancedElements = objectStructure.advancedElements val inputDataStream = DataInputStream(inputStream) @@ -566,12 +572,12 @@ object ZoarialNetworkArbiter { } - /** - * - * @param clazz - * @return A sorted list of [NetworkElements][NetworkElement] - */ - private fun getObjectStructure(clazz: Class<*>, obj: Any): NetworkObject { + private fun getObjectStructure(clazz: Class<*>): Optional { + return Optional.ofNullable(networkObjectCache[clazz.canonicalName]) + } + + private fun registerNetworkObjectStructure(obj: Any) { + val clazz = obj.javaClass if (!clazz.isAnnotationPresent(ZoarialNetworkObject::class.java)) { throw NotANetworkObject("Object is not a ZoarialNetworkObject.") } @@ -580,37 +586,50 @@ object ZoarialNetworkArbiter { val advancedList = ArrayList() for (f in fields) { if (f.isAnnotationPresent(ZoarialObjectElement::class.java)) { + val objectElementAnnotation = f.getAnnotation(ZoarialObjectElement::class.java) + val optional: Boolean = objectElementAnnotation.optional val fieldClass = f.type val placement: Int = objectElementAnnotation.placement - val optional: Boolean = objectElementAnnotation.optional - val isArray: Boolean = fieldClass.isAssignableFrom(java.util.List::class.javaObjectType) - if (!classToNetworkElementMap.containsKey(fieldClass)) { - if (fieldClass == Optional::class.java) { - throw NotANetworkObject("The Optional class is not supported. Please use the `optional` attribute on @ZoarialObjectElement") - } - val str = "Object is not in map: $fieldClass" - println(str) - throw NotANetworkObject(str) - } else if (optional && fieldClass.isPrimitive) { - throw RuntimeException("A primitive type cannot be optional: $objectElementAnnotation") - } - val networkType = if(classToNetworkElementMap[fieldClass] != null) { - classToNetworkElementMap[fieldClass] - } else { - if(isArray) { + val isArray: Boolean = !fieldClass.isPrimitive && f[obj] is List<*> + //val isArray: Boolean = fieldClass.isInterface && f[obj].javaClass.isAssignableFrom(java.util.List::class.javaObjectType) + + // Error Checking + when { + isArray -> { + println("Object is an array: $fieldClass") val list = f[obj] as List<*> if(list.isEmpty()) { throw RuntimeException("List is empty") } - if(list.stream().filter { item -> item != null }.count() != 0L) { - throw RuntimeException("List contains null object") + if(list.stream().filter { item -> item == null }.count() != 0L) { + throw RuntimeException("List contains a null object") } - val objInList = list[0] ?: throw RuntimeException("The list has a null object") - classToNetworkElementMap[objInList.javaClass]!! + if (!classToNetworkElementMap.containsKey(list[0]!!.javaClass)) { + throw RuntimeException("Object type not supported") + } + } + !classToNetworkElementMap.containsKey(fieldClass) -> { + if (fieldClass == Optional::class.java) { + throw NotANetworkObject("The Optional class is not supported. Please use the `optional` attribute on @ZoarialObjectElement") + } + val str = "Object is not in map: $fieldClass" + throw NotANetworkObject(str) } - throw RuntimeException("Object type not supported") + optional && fieldClass.isPrimitive -> { + throw RuntimeException("A primitive type cannot be optional: $objectElementAnnotation") + } + } + + // Get Type + val networkType = if(isArray) { + val list = f[obj] as List<*> + classToNetworkElementMap[list[0]] + } else { + classToNetworkElementMap[fieldClass] } + + // Add element to the basic or advanced list val correctList: MutableList = if (isBasicElement(networkType)) basicList else advancedList if (correctList.stream().filter { e: NetworkElement -> e.index == placement }.findFirst().isEmpty) { correctList.add(NetworkElement(obj, placement, networkType, f, optional, isArray)) @@ -621,9 +640,30 @@ object ZoarialNetworkArbiter { } } - // Print list - //sortedElements.forEach(e -> System.out.println("Entry " + e.getIndex() + ": " + e.getType())); - return NetworkObject(basicList.stream().sorted(Comparator.comparingInt { obj: NetworkElement -> obj.index }).collect(Collectors.toList()), advancedList.stream().sorted(Comparator.comparingInt { obj: NetworkElement -> obj.index }).collect(Collectors.toList())) + // Sort the lists and create the new NetworkObject + val ret = NetworkObject( + basicList.stream().sorted(Comparator.comparingInt { obj: NetworkElement -> obj.index }).collect(Collectors.toList()), + advancedList.stream().sorted(Comparator.comparingInt { obj: NetworkElement -> obj.index }).collect(Collectors.toList()) + ) + networkObjectCache[clazz.canonicalName] = ret + } + + /** + * + * Return the object's network structure and cache the value for future use + * @param obj is an object which has the [ZoarialNetworkObject] annotation + * @return A sorted list of [NetworkElements][NetworkElement] + */ + private fun getObjectStructure(obj: Any): NetworkObject { + + val clazz = obj.javaClass + val objectIsCached = networkObjectCache.containsKey(clazz.canonicalName) + if(objectIsCached) { + return getObjectStructure(clazz).get() + } + + registerNetworkObjectStructure(obj) + return getObjectStructure(clazz).get() } private fun isBasicElement(type: NetworkElementType?): Boolean { diff --git a/src/test/java/Tests.java b/src/test/java/Tests.java index f6c1803..966c272 100644 --- a/src/test/java/Tests.java +++ b/src/test/java/Tests.java @@ -10,7 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -public class Tests { +class Tests { static final AtomicReference returnedObject = new AtomicReference<>(); static final WorkingObject sendingWorkingObject = new WorkingObject(); @@ -36,7 +36,7 @@ void WorkingObjectTest() { try { System.out.println("Waiting..."); synchronized (returnedObject) { - returnedObject.wait(10000); + returnedObject.wait(2000); } System.out.println("Done waiting."); } catch (InterruptedException e) {