diff --git a/c++/main.cu b/c++/main.cu index 6038119..9224ebe 100644 --- a/c++/main.cu +++ b/c++/main.cu @@ -31,17 +31,18 @@ struct BufferData { u32 depth; }; -struct BufferDataVec { +class BufferDataVec { + __device__ BufferDataVec(Ray* ray, float4* attenuation_and_pixel_index, u32* depth) : ray(ray), attenuation_and_pixel_index(attenuation_and_pixel_index), depth(depth) {} +public: Ray* ray; float4* attenuation_and_pixel_index; u32* depth; - BufferDataVec(size_t state_size) { + BufferDataVec(i32 state_size) { checkCudaErrors(cudaMalloc(&ray, state_size * sizeof(Ray))); checkCudaErrors(cudaMalloc(&attenuation_and_pixel_index, state_size * sizeof(float4))); checkCudaErrors(cudaMalloc(&depth, state_size * sizeof(u32))); } - void free () { // clean up checkCudaErrors(cudaGetLastError()); @@ -49,6 +50,25 @@ struct BufferDataVec { checkCudaErrors(cudaFree(attenuation_and_pixel_index)); checkCudaErrors(cudaFree(depth)); } + + __device__ BufferData operator[](i32 i) const { + float4 a_p = attenuation_and_pixel_index[i]; + colour attenuation(a_p.x, a_p.y, a_p.z); + u32 pixel_index = __float_as_uint(a_p.w); + + return BufferData(ray[i], attenuation, pixel_index, depth[i]); + } + + __device__ BufferDataVec operator[](i32 i) { + return BufferDataVec(&ray[i], &attenuation_and_pixel_index[i], &depth[i]); + } + + __device__ BufferDataVec& operator=(BufferData x) { + *ray = x.ray; + *attenuation_and_pixel_index = make_float4(x.attenuation, __uint_as_float(x.pixel_index)); + *depth = x.depth; + return *this; + } }; __device__ colour world_colour(Ray ray) { @@ -152,29 +172,20 @@ __device__ void scatter(colour* img, BufferDataVec next_state, BufferData curren } else { i32 old_index = atomicAdd(next_state_index, 1); - next_state.ray[old_index] = Ray(position, direction); - next_state.attenuation_and_pixel_index[old_index] = make_float4(new_attenuation, __uint_as_float(pixel_index)); - next_state.depth[old_index] = current_state.depth + 1u; + next_state[old_index] = BufferData(Ray(position, direction), new_attenuation, pixel_index, current_state.depth + 1u); } } } } -__global__ void intersect_and_scatter(colour* img, BufferDataVec next_state, BufferDataVec current_state, u32 max_depth, i32* next_state_index, i32 current_state_size, f32 tmin, f32 tmax, u32 number_of_rays_generated) { +__global__ void intersect_and_scatter(colour* img, BufferDataVec next_state, const BufferDataVec current_state, u32 max_depth, i32* next_state_index, i32 current_state_size, f32 tmin, f32 tmax, u32 number_of_rays_generated) { i32 index = blockIdx.x * blockDim.x + threadIdx.x; for (i32 i = index; i < current_state_size; i += gridDim.x * blockDim.x) { - Ray ray = current_state.ray[i]; - - HitRecord hit_record = hit(ray, tmin, tmax); - - float4 a_p = current_state.attenuation_and_pixel_index[i]; - colour attenuation(a_p.x, a_p.y, a_p.z); - u32 pixel_index = __float_as_uint(a_p.w); + BufferData state = current_state[i]; - u32 depth = current_state.depth[i]; - RNG rng((1u + pixel_index) * ((1u + i) + number_of_rays_generated) + depth); - BufferData state(ray, attenuation, pixel_index, depth); + HitRecord hit_record = hit(state.ray, tmin, tmax); + RNG rng((1u + state.pixel_index) * ((1u + i) + number_of_rays_generated) + state.depth); scatter(img, next_state, state, next_state_index, rng, hit_record, max_depth); } }