-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
303 additions
and
207 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
#include "AddDescriptor.hpp" | ||
#include "AddKernel.hpp" | ||
#include "../ccv_nnc_mfa_hash.hpp" | ||
#include "../ccv_nnc_mfa_error.hpp" | ||
|
||
bool AddDescriptor::operator==(const AddDescriptor& rhs) const { | ||
return | ||
memoryPrecision == rhs.memoryPrecision && | ||
value == rhs.value && | ||
length == rhs.length; | ||
} | ||
|
||
std::size_t std::hash<AddDescriptor>::operator()(const AddDescriptor& hash) const noexcept { | ||
using namespace ccv::nnc::mfa::hash; | ||
std::size_t seed = 0; | ||
combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.memoryPrecision.value, (unsigned int)hash.value })); | ||
combine_64(seed, pack_64(simd::uint2 { (unsigned int)hash.length, 0 })); | ||
return seed; | ||
} | ||
|
||
std::pair<AddKernelDescriptor, PipelineValue<AddKernel> *> AddDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, std::unordered_map<AddKernelDescriptor, std::unique_ptr<AddKernel>> *const libraryCache) const noexcept { | ||
// The caller is not responsible for calling 'delete' on this pointer. The | ||
// reference is saved in the 'libraryCache'. It will be deallocated whenever | ||
// the shader cache itself is cleaned up. | ||
auto createKernel = | ||
[=](AddKernelDescriptor descriptor) -> AddKernel* { | ||
auto iterator = libraryCache->find(descriptor); | ||
if (iterator != libraryCache->end()) { | ||
return iterator->second.get(); | ||
} else { | ||
AddKernel* kernel = new AddKernel(descriptor, device); | ||
(*libraryCache)[descriptor] = std::unique_ptr<AddKernel>(kernel); | ||
return kernel; | ||
} | ||
}; | ||
|
||
AddKernelDescriptor kernelDesc; | ||
kernelDesc.value = value; | ||
kernelDesc.memoryPrecision = memoryPrecision; | ||
|
||
// WARNING: The owner must explicitly retain the compute pipeline. | ||
auto createPipeline = | ||
[=](MTL::Library* library) -> MTL::ComputePipelineState* { | ||
// Set the function constants. | ||
auto constants = NS::TransferPtr | ||
(MTL::FunctionConstantValues::alloc()->init()); | ||
uint32_t count; | ||
if (value == 0) { | ||
} else if (value == 1) { | ||
count = length / 4; | ||
constants->setConstantValue(&count, MTL::DataTypeUInt, NS::UInteger(0)); | ||
} else { | ||
count = length; | ||
constants->setConstantValue(&count, MTL::DataTypeUInt, NS::UInteger(0)); | ||
} | ||
|
||
NS::String* swiftName = NS::String::string("add", NS::UTF8StringEncoding); | ||
NS::Error* error = nil; | ||
|
||
auto function = NS::TransferPtr | ||
(library->newFunction(swiftName, constants.get(), &error)); | ||
CCV_NNC_MFA_CHECK_ERROR(error); | ||
|
||
auto pipeline = device->newComputePipelineState(function.get(), &error); | ||
CCV_NNC_MFA_CHECK_ERROR(error); | ||
return pipeline; | ||
}; | ||
AddKernel* kernel = createKernel(kernelDesc); | ||
auto pipeline = NS::TransferPtr(createPipeline(kernel->library.get())); | ||
|
||
// Force the user to retrieve the return value from the cache. We ensure | ||
// the cache takes ownership, and the pointer doesn't become a zombie | ||
// object. | ||
PipelineValue<AddKernel>* output = new PipelineValue<AddKernel> { kernel, pipeline }; | ||
return std::make_pair(kernelDesc, output); | ||
} |
Oops, something went wrong.