diff --git a/Sources/MetalLibraryArchive/Archive.swift b/Sources/MetalLibraryArchive/Archive.swift index 57b45e3..686a0bc 100644 --- a/Sources/MetalLibraryArchive/Archive.swift +++ b/Sources/MetalLibraryArchive/Archive.swift @@ -241,7 +241,7 @@ public struct Archive: Hashable { _ = try dataScanner.scan(UInt64.self) //48...55 // Private metadata offset - _ = try dataScanner.scan(UInt64.self) //56...63 + let privateMetadataOffset = try dataScanner.scan(UInt64.self) //56...63 // Private metadata size _ = try dataScanner.scan(UInt64.self) //64...71 @@ -297,12 +297,33 @@ public struct Archive: Hashable { let functions: [Function] = try { var entries: [Function] = [] for info in functionInfos { - try dataScanner.seek(to: bitcodeOffset + Int(info.bitcodeOffset)) + try dataScanner.seek(to: bitcodeOffset + Int(info.offsets.bitcode)) let data = try dataScanner.scanData(byteCount: Int(info.bitcodeSize)) guard SHA256.hash(data: data) == info.hash else { throw Error.invalidBitcodeHash } - entries.append(Function(name: info.name, type: info.type, languageVersion: info.languageVersion, tags: info.tags, bitcode: data)) + + try dataScanner.seek(to: publicMetadataOffset + Int(info.offsets.publicMetadata)) + let publicMetadataTagSize = try dataScanner.scan(UInt32.self) + guard publicMetadataTagSize > 0 else { + throw Error.invalidTagGroupSize + } + let publicMetadataTags = try dataScanner.scanTags(contentSizeType: UInt16.self) + + try dataScanner.seek(to: privateMetadataOffset + Int(info.offsets.privateMetadata)) + let privateMetadataTagsSize = try dataScanner.scan(UInt32.self) + guard privateMetadataTagsSize > 0 else { + throw Error.invalidTagGroupSize + } + let privateMetadataTags = try dataScanner.scanTags(contentSizeType: UInt16.self) + + entries.append(Function(name: info.name, + type: info.type, + languageVersion: info.languageVersion, + tags: info.tags, + publicMetadataTags: publicMetadataTags, + privateMetadataTags: privateMetadataTags, + bitcode: data)) } return entries }() @@ -348,7 +369,7 @@ extension Archive { private struct FunctionInfo { var name: String var bitcodeSize: UInt64 - var bitcodeOffset: UInt64 + var offsets: (publicMetadata: UInt64, privateMetadata: UInt64, bitcode: UInt64) var type: FunctionType? var languageVersion: LanguageVersion var hash: Data @@ -359,7 +380,7 @@ extension Archive { let tags: [Tag] = try scanner.scanTags(contentSizeType: UInt16.self) var name: String? var bitcodeSize: UInt64? - var bitcodeOffset: UInt64? + var offsets: (publicMetadata: UInt64, privateMetadata: UInt64, bitcode: UInt64)? var type: FunctionType? var hash: Data? var languageVersion: LanguageVersion? @@ -391,10 +412,9 @@ extension Archive { guard tag.content.count == MemoryLayout.size * 3 else { throw Error.unexpectedTagContentSize(tagName: tag.name) } - bitcodeOffset = tag.content.withUnsafeBytes({ pointer in - // 0: public metadata offset - // 1: private metadata offset - pointer.bindMemory(to: UInt64.self)[2] + offsets = tag.content.withUnsafeBytes({ pointer in + let offsetValues = pointer.bindMemory(to: UInt64.self) + return (publicMetadata: offsetValues[0], privateMetadata: offsetValues[1], bitcode: offsetValues[2]) }) case "VERS": guard tag.content.count == MemoryLayout.size * 4 else { @@ -410,9 +430,9 @@ extension Archive { break } } - guard let name = name, let bitcodeSize = bitcodeSize, let bitcodeOffset = bitcodeOffset, let hash = hash, let languageVersion = languageVersion else { + guard let name = name, let bitcodeSize = bitcodeSize, let offsets = offsets, let hash = hash, let languageVersion = languageVersion else { throw Error.incompleteFunctionInfo } - return FunctionInfo(name: name, bitcodeSize: bitcodeSize, bitcodeOffset: bitcodeOffset, type: type, languageVersion: languageVersion, hash: hash, tags: tags) + return FunctionInfo(name: name, bitcodeSize: bitcodeSize, offsets: offsets, type: type, languageVersion: languageVersion, hash: hash, tags: tags) } } diff --git a/Sources/MetalLibraryArchive/Function.swift b/Sources/MetalLibraryArchive/Function.swift index e26fcbf..df23eac 100644 --- a/Sources/MetalLibraryArchive/Function.swift +++ b/Sources/MetalLibraryArchive/Function.swift @@ -41,6 +41,8 @@ public struct Function: Hashable { public let type: FunctionType? public let languageVersion: LanguageVersion public let tags: [Tag] + public let publicMetadataTags: [Tag] + public let privateMetadataTags: [Tag] public let bitcode: Data } diff --git a/Tests/MetalLibraryArchiveTests/Tests.swift b/Tests/MetalLibraryArchiveTests/Tests.swift index a9c0dea..854868a 100644 --- a/Tests/MetalLibraryArchiveTests/Tests.swift +++ b/Tests/MetalLibraryArchiveTests/Tests.swift @@ -274,6 +274,8 @@ class MetalLibraryArchiveTests_macOSSDK: XCTestCase { XCTAssertEqual(function.name, "test") XCTAssertEqual(function.type, .extern) XCTAssert(function.bitcode.count > 0) + XCTAssert(archive.functions[0].publicMetadataTags.contains(where: { $0.name == "RETR" })) + XCTAssert(archive.functions[0].publicMetadataTags.contains(where: { $0.name == "ARGR" })) } func testLanguageVersion_2_0() throws { @@ -441,6 +443,124 @@ class MetalLibraryArchiveTests_macOSSDK: XCTestCase { } } + func testTessellationFuntion() throws { + let source = """ + #include + using namespace metal; + + struct ControlPoint { + float3 position [[attribute(0)]]; + float3 normal [[attribute(1)]]; + }; + + struct PatchIn { + patch_control_point controlPoints; + }; + + struct VertexOut { + float4 position [[position]]; + }; + + [[patch(quad, 4)]] + vertex VertexOut vertex_subdiv_quad(PatchIn patch [[stage_in]], + float2 positionInPatch [[position_in_patch]]) { + VertexOut out = {0}; + return out; + } + """ + let data = try self.makeLibrary(source: source) + let archive = try Archive(data: data) + XCTAssertEqual(archive.functions.count, 1) + XCTAssertEqual(archive.targetPlatform, sdk.targetPlatform) + XCTAssertEqual(archive.libraryType, .executable) + let tessellationTag = try XCTUnwrap(archive.functions[0].tags.first(where: { $0.name == "TESS" })) + XCTAssertEqual(tessellationTag.content.withUnsafeBytes({ $0.bindMemory(to: UInt8.self)[0] }), 4 << 2 | 2) + XCTAssert(archive.functions[0].publicMetadataTags.contains(where: { $0.name == "VATT" })) + XCTAssert(archive.functions[0].publicMetadataTags.contains(where: { $0.name == "VATY" })) + } + + func testFuntionConstants() throws { + let source = """ + #include + constant int constantValueA [[function_constant(0)]]; + kernel void testKernel(device float *io) { + for (int i = 0; i < constantValueA; i += 1) { + io[i] = 0; + } + } + """ + let data = try self.makeLibrary(source: source) + let archive = try Archive(data: data) + XCTAssertEqual(archive.functions.count, 1) + XCTAssertEqual(archive.targetPlatform, sdk.targetPlatform) + XCTAssertEqual(archive.libraryType, .executable) + XCTAssert(archive.functions[0].publicMetadataTags.contains(where: { $0.name == "CNST" })) + } + + func testFuntionConstants_unused() throws { + let source = """ + #include + constant int constantValueA [[function_constant(0)]]; + kernel void testKernel(device float *io) { + for (int i = 0; i < constantValueA; i += 1) { + break; + } + } + """ + let data = try self.makeLibrary(source: source) + let archive = try Archive(data: data) + XCTAssertEqual(archive.functions.count, 1) + XCTAssertEqual(archive.targetPlatform, sdk.targetPlatform) + XCTAssertEqual(archive.libraryType, .executable) + XCTAssertEqual(archive.functions[0].publicMetadataTags.contains(where: { $0.name == "CNST" }), false) + } + + func testLayeredRendering_uint() throws { + let source = """ + typedef struct { + uint layer [[render_target_array_index]]; + float4 position [[position]]; + } ColorInOut; + + vertex ColorInOut vertexTransform(){ + ColorInOut out; + out.layer = 0; + out.position = float4(0); + return out; + } + """ + let data = try self.makeLibrary(source: source) + let archive = try Archive(data: data) + XCTAssertEqual(archive.functions.count, 1) + XCTAssertEqual(archive.targetPlatform, sdk.targetPlatform) + XCTAssertEqual(archive.libraryType, .executable) + let layerTag = try XCTUnwrap(archive.functions[0].tags.first(where: { $0.name == "LAYR" })) + XCTAssertEqual(layerTag.content.withUnsafeBytes({ $0.bindMemory(to: UInt8.self)[0] }), 0x21) + } + + func testLayeredRendering_ushort() throws { + let source = """ + typedef struct { + ushort layer [[render_target_array_index]]; + float4 position [[position]]; + } ColorInOut; + + vertex ColorInOut vertexTransform(){ + ColorInOut out; + out.layer = 0; + out.position = float4(0); + return out; + } + """ + let data = try self.makeLibrary(source: source) + let archive = try Archive(data: data) + XCTAssertEqual(archive.functions.count, 1) + XCTAssertEqual(archive.targetPlatform, sdk.targetPlatform) + XCTAssertEqual(archive.libraryType, .executable) + let layerTag = try XCTUnwrap(archive.functions[0].tags.first(where: { $0.name == "LAYR" })) + XCTAssertEqual(layerTag.content.withUnsafeBytes({ $0.bindMemory(to: UInt8.self)[0] }), 0x29) + } + func testSourceArchives_executable() throws { let source = """ #include @@ -452,6 +572,7 @@ class MetalLibraryArchiveTests_macOSSDK: XCTestCase { XCTAssertEqual(archive.targetPlatform, sdk.targetPlatform) XCTAssertEqual(archive.libraryType, .executable) XCTAssert(archive.sourceArchives.count > 0) + XCTAssert(archive.functions[0].privateMetadataTags.contains(where: { $0.name == "DEPF" })) } func testSourceArchives_dynamic() throws {