Skip to content

Commit

Permalink
Add tests for function metadata tags
Browse files Browse the repository at this point in the history
  • Loading branch information
YuAo committed Apr 4, 2022
1 parent cb37596 commit aae3453
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 11 deletions.
42 changes: 31 additions & 11 deletions Sources/MetalLibraryArchive/Archive.swift
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ public struct Archive: Hashable {
_ = try dataScanner.scan(UInt64.self) //48...55

// Private metadata offset
_ = try dataScanner.scan(UInt64.self) //56...63
let privateMetadataOffset = try dataScanner.scan(UInt64.self) //56...63
// Private metadata size
_ = try dataScanner.scan(UInt64.self) //64...71

Expand Down Expand Up @@ -297,12 +297,33 @@ public struct Archive: Hashable {
let functions: [Function] = try {
var entries: [Function] = []
for info in functionInfos {
try dataScanner.seek(to: bitcodeOffset + Int(info.bitcodeOffset))
try dataScanner.seek(to: bitcodeOffset + Int(info.offsets.bitcode))
let data = try dataScanner.scanData(byteCount: Int(info.bitcodeSize))
guard SHA256.hash(data: data) == info.hash else {
throw Error.invalidBitcodeHash
}
entries.append(Function(name: info.name, type: info.type, languageVersion: info.languageVersion, tags: info.tags, bitcode: data))

try dataScanner.seek(to: publicMetadataOffset + Int(info.offsets.publicMetadata))
let publicMetadataTagSize = try dataScanner.scan(UInt32.self)
guard publicMetadataTagSize > 0 else {
throw Error.invalidTagGroupSize
}
let publicMetadataTags = try dataScanner.scanTags(contentSizeType: UInt16.self)

try dataScanner.seek(to: privateMetadataOffset + Int(info.offsets.privateMetadata))
let privateMetadataTagsSize = try dataScanner.scan(UInt32.self)
guard privateMetadataTagsSize > 0 else {
throw Error.invalidTagGroupSize
}
let privateMetadataTags = try dataScanner.scanTags(contentSizeType: UInt16.self)

entries.append(Function(name: info.name,
type: info.type,
languageVersion: info.languageVersion,
tags: info.tags,
publicMetadataTags: publicMetadataTags,
privateMetadataTags: privateMetadataTags,
bitcode: data))
}
return entries
}()
Expand Down Expand Up @@ -348,7 +369,7 @@ extension Archive {
private struct FunctionInfo {
var name: String
var bitcodeSize: UInt64
var bitcodeOffset: UInt64
var offsets: (publicMetadata: UInt64, privateMetadata: UInt64, bitcode: UInt64)
var type: FunctionType?
var languageVersion: LanguageVersion
var hash: Data
Expand All @@ -359,7 +380,7 @@ extension Archive {
let tags: [Tag] = try scanner.scanTags(contentSizeType: UInt16.self)
var name: String?
var bitcodeSize: UInt64?
var bitcodeOffset: UInt64?
var offsets: (publicMetadata: UInt64, privateMetadata: UInt64, bitcode: UInt64)?
var type: FunctionType?
var hash: Data?
var languageVersion: LanguageVersion?
Expand Down Expand Up @@ -391,10 +412,9 @@ extension Archive {
guard tag.content.count == MemoryLayout<UInt64>.size * 3 else {
throw Error.unexpectedTagContentSize(tagName: tag.name)
}
bitcodeOffset = tag.content.withUnsafeBytes({ pointer in
// 0: public metadata offset
// 1: private metadata offset
pointer.bindMemory(to: UInt64.self)[2]
offsets = tag.content.withUnsafeBytes({ pointer in
let offsetValues = pointer.bindMemory(to: UInt64.self)
return (publicMetadata: offsetValues[0], privateMetadata: offsetValues[1], bitcode: offsetValues[2])
})
case "VERS":
guard tag.content.count == MemoryLayout<UInt16>.size * 4 else {
Expand All @@ -410,9 +430,9 @@ extension Archive {
break
}
}
guard let name = name, let bitcodeSize = bitcodeSize, let bitcodeOffset = bitcodeOffset, let hash = hash, let languageVersion = languageVersion else {
guard let name = name, let bitcodeSize = bitcodeSize, let offsets = offsets, let hash = hash, let languageVersion = languageVersion else {
throw Error.incompleteFunctionInfo
}
return FunctionInfo(name: name, bitcodeSize: bitcodeSize, bitcodeOffset: bitcodeOffset, type: type, languageVersion: languageVersion, hash: hash, tags: tags)
return FunctionInfo(name: name, bitcodeSize: bitcodeSize, offsets: offsets, type: type, languageVersion: languageVersion, hash: hash, tags: tags)
}
}
2 changes: 2 additions & 0 deletions Sources/MetalLibraryArchive/Function.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ public struct Function: Hashable {
public let type: FunctionType?
public let languageVersion: LanguageVersion
public let tags: [Tag]
public let publicMetadataTags: [Tag]
public let privateMetadataTags: [Tag]
public let bitcode: Data
}

Expand Down
121 changes: 121 additions & 0 deletions Tests/MetalLibraryArchiveTests/Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ class MetalLibraryArchiveTests_macOSSDK: XCTestCase {
XCTAssertEqual(function.name, "test")
XCTAssertEqual(function.type, .extern)
XCTAssert(function.bitcode.count > 0)
XCTAssert(archive.functions[0].publicMetadataTags.contains(where: { $0.name == "RETR" }))
XCTAssert(archive.functions[0].publicMetadataTags.contains(where: { $0.name == "ARGR" }))
}

func testLanguageVersion_2_0() throws {
Expand Down Expand Up @@ -441,6 +443,124 @@ class MetalLibraryArchiveTests_macOSSDK: XCTestCase {
}
}

func testTessellationFuntion() throws {
let source = """
#include <metal_stdlib>
using namespace metal;
struct ControlPoint {
float3 position [[attribute(0)]];
float3 normal [[attribute(1)]];
};
struct PatchIn {
patch_control_point<ControlPoint> controlPoints;
};
struct VertexOut {
float4 position [[position]];
};
[[patch(quad, 4)]]
vertex VertexOut vertex_subdiv_quad(PatchIn patch [[stage_in]],
float2 positionInPatch [[position_in_patch]]) {
VertexOut out = {0};
return out;
}
"""
let data = try self.makeLibrary(source: source)
let archive = try Archive(data: data)
XCTAssertEqual(archive.functions.count, 1)
XCTAssertEqual(archive.targetPlatform, sdk.targetPlatform)
XCTAssertEqual(archive.libraryType, .executable)
let tessellationTag = try XCTUnwrap(archive.functions[0].tags.first(where: { $0.name == "TESS" }))
XCTAssertEqual(tessellationTag.content.withUnsafeBytes({ $0.bindMemory(to: UInt8.self)[0] }), 4 << 2 | 2)
XCTAssert(archive.functions[0].publicMetadataTags.contains(where: { $0.name == "VATT" }))
XCTAssert(archive.functions[0].publicMetadataTags.contains(where: { $0.name == "VATY" }))
}

func testFuntionConstants() throws {
let source = """
#include <metal_stdlib>
constant int constantValueA [[function_constant(0)]];
kernel void testKernel(device float *io) {
for (int i = 0; i < constantValueA; i += 1) {
io[i] = 0;
}
}
"""
let data = try self.makeLibrary(source: source)
let archive = try Archive(data: data)
XCTAssertEqual(archive.functions.count, 1)
XCTAssertEqual(archive.targetPlatform, sdk.targetPlatform)
XCTAssertEqual(archive.libraryType, .executable)
XCTAssert(archive.functions[0].publicMetadataTags.contains(where: { $0.name == "CNST" }))
}

func testFuntionConstants_unused() throws {
let source = """
#include <metal_stdlib>
constant int constantValueA [[function_constant(0)]];
kernel void testKernel(device float *io) {
for (int i = 0; i < constantValueA; i += 1) {
break;
}
}
"""
let data = try self.makeLibrary(source: source)
let archive = try Archive(data: data)
XCTAssertEqual(archive.functions.count, 1)
XCTAssertEqual(archive.targetPlatform, sdk.targetPlatform)
XCTAssertEqual(archive.libraryType, .executable)
XCTAssertEqual(archive.functions[0].publicMetadataTags.contains(where: { $0.name == "CNST" }), false)
}

func testLayeredRendering_uint() throws {
let source = """
typedef struct {
uint layer [[render_target_array_index]];
float4 position [[position]];
} ColorInOut;
vertex ColorInOut vertexTransform(){
ColorInOut out;
out.layer = 0;
out.position = float4(0);
return out;
}
"""
let data = try self.makeLibrary(source: source)
let archive = try Archive(data: data)
XCTAssertEqual(archive.functions.count, 1)
XCTAssertEqual(archive.targetPlatform, sdk.targetPlatform)
XCTAssertEqual(archive.libraryType, .executable)
let layerTag = try XCTUnwrap(archive.functions[0].tags.first(where: { $0.name == "LAYR" }))
XCTAssertEqual(layerTag.content.withUnsafeBytes({ $0.bindMemory(to: UInt8.self)[0] }), 0x21)
}

func testLayeredRendering_ushort() throws {
let source = """
typedef struct {
ushort layer [[render_target_array_index]];
float4 position [[position]];
} ColorInOut;
vertex ColorInOut vertexTransform(){
ColorInOut out;
out.layer = 0;
out.position = float4(0);
return out;
}
"""
let data = try self.makeLibrary(source: source)
let archive = try Archive(data: data)
XCTAssertEqual(archive.functions.count, 1)
XCTAssertEqual(archive.targetPlatform, sdk.targetPlatform)
XCTAssertEqual(archive.libraryType, .executable)
let layerTag = try XCTUnwrap(archive.functions[0].tags.first(where: { $0.name == "LAYR" }))
XCTAssertEqual(layerTag.content.withUnsafeBytes({ $0.bindMemory(to: UInt8.self)[0] }), 0x29)
}

func testSourceArchives_executable() throws {
let source = """
#include <metal_stdlib>
Expand All @@ -452,6 +572,7 @@ class MetalLibraryArchiveTests_macOSSDK: XCTestCase {
XCTAssertEqual(archive.targetPlatform, sdk.targetPlatform)
XCTAssertEqual(archive.libraryType, .executable)
XCTAssert(archive.sourceArchives.count > 0)
XCTAssert(archive.functions[0].privateMetadataTags.contains(where: { $0.name == "DEPF" }))
}

func testSourceArchives_dynamic() throws {
Expand Down

0 comments on commit aae3453

Please sign in to comment.